CONTROLLER-1641: Integrate DataTreeCohortActor with pipe-lining 61/66861/5
authorTom Pantelis <tompantelis@gmail.com>
Thu, 4 Jan 2018 04:52:12 +0000 (23:52 -0500)
committerRobert Varga <nite@hq.sk>
Mon, 29 Jan 2018 19:03:19 +0000 (19:03 +0000)
The DataTreeCohortActor was originally written assuming that only
one 3-phase commit was in progress at any time and thus maintained
a single state. However with transaction pipe-lining there can be
multiple simultaneous 3-phase commits so DataTreeCohortActor was
modified to maintain/track state per transaction.

In addition, it now also handles the DOMDataTreeCommitCohort
Futures async.

Change-Id: Ib7588ea2e32b297a2db0b532726549f9ec61a1a4
Signed-off-by: Tom Pantelis <tompantelis@gmail.com>
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/DataTreeCohortActor.java
opendaylight/md-sal/sal-distributed-datastore/src/test/java/org/opendaylight/controller/cluster/datastore/DataTreeCohortActorTest.java [new file with mode: 0644]

index da1117764f8b8146ef7f8f49e22bcc7f2ef64d8b..e8db09a7cab570587615d1b1850ce84b6bbb98a0 100644 (file)
@@ -11,8 +11,17 @@ package org.opendaylight.controller.cluster.datastore;
 import akka.actor.ActorRef;
 import akka.actor.Props;
 import akka.actor.Status;
 import akka.actor.ActorRef;
 import akka.actor.Props;
 import akka.actor.Status;
-import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
 import java.util.Collection;
 import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.Executor;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 import org.opendaylight.controller.cluster.access.concepts.TransactionIdentifier;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedActor;
 import org.opendaylight.mdsal.common.api.PostCanCommitStep;
 import org.opendaylight.controller.cluster.access.concepts.TransactionIdentifier;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedActor;
 import org.opendaylight.mdsal.common.api.PostCanCommitStep;
@@ -28,25 +37,32 @@ import org.opendaylight.yangtools.yang.model.api.SchemaContext;
  * decapsulating DataTreeChanged messages and dispatching their context to the user.
  */
 final class DataTreeCohortActor extends AbstractUntypedActor {
  * decapsulating DataTreeChanged messages and dispatching their context to the user.
  */
 final class DataTreeCohortActor extends AbstractUntypedActor {
-    private final CohortBehaviour<?> idleState = new Idle();
+    private final Idle idleState = new Idle();
     private final DOMDataTreeCommitCohort cohort;
     private final YangInstanceIdentifier registeredPath;
     private final DOMDataTreeCommitCohort cohort;
     private final YangInstanceIdentifier registeredPath;
-    private CohortBehaviour<?> currentState = idleState;
+    private final Map<TransactionIdentifier, CohortBehaviour<?, ?>> currentStateMap = new HashMap<>();
 
     private DataTreeCohortActor(final DOMDataTreeCommitCohort cohort, final YangInstanceIdentifier registeredPath) {
 
     private DataTreeCohortActor(final DOMDataTreeCommitCohort cohort, final YangInstanceIdentifier registeredPath) {
-        this.cohort = Preconditions.checkNotNull(cohort);
-        this.registeredPath = Preconditions.checkNotNull(registeredPath);
+        this.cohort = Objects.requireNonNull(cohort);
+        this.registeredPath = Objects.requireNonNull(registeredPath);
     }
 
     @Override
     protected void handleReceive(final Object message) {
     }
 
     @Override
     protected void handleReceive(final Object message) {
+        if (!(message instanceof CommitProtocolCommand)) {
+            unknownMessage(message);
+            return;
+        }
+
+        CommitProtocolCommand<?> command = (CommitProtocolCommand<?>)message;
+        CohortBehaviour<?, ?> currentState = currentStateMap.computeIfAbsent(command.getTxId(), key -> idleState);
+
         LOG.debug("handleReceive for cohort {} - currentState: {}, message: {}", cohort.getClass().getName(),
                 currentState, message);
 
         LOG.debug("handleReceive for cohort {} - currentState: {}, message: {}", cohort.getClass().getName(),
                 currentState, message);
 
-        currentState = currentState.handle(message);
+        currentState.handle(command);
     }
 
     }
 
-
     /**
      * Abstract message base for messages handled by {@link DataTreeCohortActor}.
      *
     /**
      * Abstract message base for messages handled by {@link DataTreeCohortActor}.
      *
@@ -61,7 +77,7 @@ final class DataTreeCohortActor extends AbstractUntypedActor {
         }
 
         protected CommitProtocolCommand(TransactionIdentifier txId) {
         }
 
         protected CommitProtocolCommand(TransactionIdentifier txId) {
-            this.txId = Preconditions.checkNotNull(txId);
+            this.txId = Objects.requireNonNull(txId);
         }
 
         @Override
         }
 
         @Override
@@ -79,9 +95,9 @@ final class DataTreeCohortActor extends AbstractUntypedActor {
         CanCommit(TransactionIdentifier txId, Collection<DOMDataTreeCandidate> candidates, SchemaContext schema,
                 ActorRef cohort) {
             super(txId);
         CanCommit(TransactionIdentifier txId, Collection<DOMDataTreeCandidate> candidates, SchemaContext schema,
                 ActorRef cohort) {
             super(txId);
-            this.cohort = Preconditions.checkNotNull(cohort);
-            this.candidates = Preconditions.checkNotNull(candidates);
-            this.schema = Preconditions.checkNotNull(schema);
+            this.cohort = Objects.requireNonNull(cohort);
+            this.candidates = Objects.requireNonNull(candidates);
+            this.schema = Objects.requireNonNull(schema);
         }
 
         Collection<DOMDataTreeCandidate> getCandidates() {
         }
 
         Collection<DOMDataTreeCandidate> getCandidates() {
@@ -108,8 +124,8 @@ final class DataTreeCohortActor extends AbstractUntypedActor {
         private final TransactionIdentifier txId;
 
         protected CommitReply(ActorRef cohortRef, TransactionIdentifier txId) {
         private final TransactionIdentifier txId;
 
         protected CommitReply(ActorRef cohortRef, TransactionIdentifier txId) {
-            this.cohortRef = Preconditions.checkNotNull(cohortRef);
-            this.txId = Preconditions.checkNotNull(txId);
+            this.cohortRef = Objects.requireNonNull(cohortRef);
+            this.txId = Objects.requireNonNull(txId);
         }
 
         ActorRef getCohort() {
         }
 
         ActorRef getCohort() {
@@ -154,23 +170,78 @@ final class DataTreeCohortActor extends AbstractUntypedActor {
         }
     }
 
         }
     }
 
-    private abstract static class CohortBehaviour<E> {
+    private abstract class CohortBehaviour<M extends CommitProtocolCommand<?>, S extends ThreePhaseCommitStep> {
+        private final Class<M> handledMessageType;
 
 
-        abstract Class<E> getHandledMessageType();
+        CohortBehaviour(Class<M> handledMessageType) {
+            this.handledMessageType = Objects.requireNonNull(handledMessageType);
+        }
 
 
-        CohortBehaviour<?> handle(Object message) {
-            if (getHandledMessageType().isInstance(message)) {
-                return process(getHandledMessageType().cast(message));
-            } else if (message instanceof Abort) {
-                return abort();
+        void handle(CommitProtocolCommand<?> command) {
+            if (handledMessageType.isInstance(command)) {
+                onMessage(command);
+            } else if (command instanceof Abort) {
+                onAbort(((Abort)command).getTxId());
+            } else {
+                getSender().tell(new Status.Failure(new IllegalArgumentException(String.format(
+                        "Unexpected message %s in cohort behavior %s", command.getClass(),
+                        getClass().getSimpleName()))), getSelf());
             }
             }
-            throw new UnsupportedOperationException(String.format("Unexpected message %s in cohort behavior %s",
-                    message.getClass(), getClass().getSimpleName()));
         }
 
         }
 
-        abstract CohortBehaviour<?> abort();
+        private void onMessage(CommitProtocolCommand<?> message) {
+            final ActorRef sender = getSender();
+            TransactionIdentifier txId = message.getTxId();
+            ListenableFuture<S> future = process(handledMessageType.cast(message));
+            Executor callbackExecutor = future.isDone() ? MoreExecutors.directExecutor()
+                    : runnable -> executeInSelf(runnable);
+            Futures.addCallback(future, new FutureCallback<S>() {
+                @Override
+                public void onSuccess(S nextStep) {
+                    success(txId, sender, nextStep);
+                }
+
+                @Override
+                public void onFailure(Throwable failure) {
+                    failed(txId, sender, failure);
+                }
+            }, callbackExecutor);
+        }
+
+        private void failed(TransactionIdentifier txId, ActorRef sender, Throwable failure) {
+            currentStateMap.remove(txId);
+            sender.tell(new Status.Failure(failure), getSelf());
+        }
+
+        private void success(TransactionIdentifier txId, ActorRef sender, S nextStep) {
+            currentStateMap.computeIfPresent(txId, (key, behaviour) -> nextBehaviour(txId, nextStep));
+            sender.tell(new Success(getSelf(), txId), getSelf());
+        }
+
+        private void onAbort(TransactionIdentifier txId) {
+            currentStateMap.remove(txId);
+            final ActorRef sender = getSender();
+            Futures.addCallback(abort(), new FutureCallback<Object>() {
+                @Override
+                public void onSuccess(Object noop) {
+                    sender.tell(new Success(getSelf(), txId), getSelf());
+                }
+
+                @Override
+                public void onFailure(Throwable failure) {
+                    LOG.warn("Abort of transaction {} failed for cohort {}", txId, cohort, failure);
+                    sender.tell(new Status.Failure(failure), getSelf());
+                }
+            }, MoreExecutors.directExecutor());
+        }
+
+        @Nullable
+        abstract CohortBehaviour<?, ?> nextBehaviour(TransactionIdentifier txId, S nextStep);
+
+        @Nonnull
+        abstract ListenableFuture<S> process(M command);
 
 
-        abstract CohortBehaviour<?> process(E message);
+        abstract ListenableFuture<?> abort();
 
         @Override
         public String toString() {
 
         @Override
         public String toString() {
@@ -178,65 +249,45 @@ final class DataTreeCohortActor extends AbstractUntypedActor {
         }
     }
 
         }
     }
 
-    private class Idle extends CohortBehaviour<CanCommit> {
+    private class Idle extends CohortBehaviour<CanCommit, PostCanCommitStep> {
+        Idle() {
+            super(CanCommit.class);
+        }
 
         @Override
 
         @Override
-        Class<CanCommit> getHandledMessageType() {
-            return CanCommit.class;
+        ListenableFuture<PostCanCommitStep> process(CanCommit message) {
+            return cohort.canCommit(message.getTxId(), message.getCandidates(), message.getSchema());
         }
 
         @Override
         }
 
         @Override
-        @SuppressWarnings("checkstyle:IllegalCatch")
-        CohortBehaviour<?> process(CanCommit message) {
-            final PostCanCommitStep nextStep;
-            try {
-                nextStep = cohort.canCommit(message.getTxId(), message.getCandidates(), message.getSchema()).get();
-            } catch (final Exception e) {
-                getSender().tell(new Status.Failure(e), getSelf());
-                return this;
-            }
-            getSender().tell(new Success(getSelf(), message.getTxId()), getSelf());
-            return new PostCanCommit(message.getTxId(), nextStep);
+        CohortBehaviour<?, ?> nextBehaviour(TransactionIdentifier txId, PostCanCommitStep nextStep) {
+            return new PostCanCommit(txId, nextStep);
         }
 
         @Override
         }
 
         @Override
-        CohortBehaviour<?> abort() {
-            return this;
+        ListenableFuture<?> abort() {
+            return ThreePhaseCommitStep.NOOP_ABORT_FUTURE;
         }
     }
 
         }
     }
 
-
-    private abstract class CohortStateWithStep<M extends CommitProtocolCommand<?>, S extends ThreePhaseCommitStep>
-            extends CohortBehaviour<M> {
-
+    private abstract class CohortStateWithStep<M extends CommitProtocolCommand<?>, S extends ThreePhaseCommitStep,
+            N extends ThreePhaseCommitStep> extends CohortBehaviour<M, N> {
         private final S step;
         private final TransactionIdentifier txId;
 
         private final S step;
         private final TransactionIdentifier txId;
 
-        CohortStateWithStep(TransactionIdentifier txId, S step) {
-            this.txId = Preconditions.checkNotNull(txId);
-            this.step = Preconditions.checkNotNull(step);
+        CohortStateWithStep(Class<M> handledMessageType, TransactionIdentifier txId, S step) {
+            super(handledMessageType);
+            this.txId = Objects.requireNonNull(txId);
+            this.step = Objects.requireNonNull(step);
         }
 
         final S getStep() {
             return step;
         }
 
         }
 
         final S getStep() {
             return step;
         }
 
-        final TransactionIdentifier getTxId() {
-            return txId;
-        }
-
         @Override
         @Override
-        @SuppressWarnings("checkstyle:IllegalCatch")
-        final CohortBehaviour<?> abort() {
-            try {
-                getStep().abort().get();
-            } catch (final Exception e) {
-                LOG.warn("Abort of transaction {} failed for cohort {}", txId, cohort, e);
-                getSender().tell(new Status.Failure(e), getSelf());
-                return idleState;
-            }
-            getSender().tell(new Success(getSelf(), txId), getSelf());
-            return idleState;
+        ListenableFuture<?> abort() {
+            return getStep().abort();
         }
 
         @Override
         }
 
         @Override
@@ -245,57 +296,44 @@ final class DataTreeCohortActor extends AbstractUntypedActor {
         }
     }
 
         }
     }
 
-    private class PostCanCommit extends CohortStateWithStep<PreCommit, PostCanCommitStep> {
+    private class PostCanCommit extends CohortStateWithStep<PreCommit, PostCanCommitStep, PostPreCommitStep> {
 
         PostCanCommit(TransactionIdentifier txId, PostCanCommitStep nextStep) {
 
         PostCanCommit(TransactionIdentifier txId, PostCanCommitStep nextStep) {
-            super(txId, nextStep);
+            super(PreCommit.class, txId, nextStep);
         }
 
         }
 
+        @SuppressWarnings("unchecked")
         @Override
         @Override
-        Class<PreCommit> getHandledMessageType() {
-            return PreCommit.class;
+        ListenableFuture<PostPreCommitStep> process(PreCommit message) {
+            return (ListenableFuture<PostPreCommitStep>) getStep().preCommit();
         }
 
         @Override
         }
 
         @Override
-        @SuppressWarnings("checkstyle:IllegalCatch")
-        CohortBehaviour<?> process(PreCommit message) {
-            final PostPreCommitStep nextStep;
-            try {
-                nextStep = getStep().preCommit().get();
-            } catch (final Exception e) {
-                getSender().tell(new Status.Failure(e), getSelf());
-                return idleState;
-            }
-            getSender().tell(new Success(getSelf(), message.getTxId()), getSelf());
-            return new PostPreCommit(getTxId(), nextStep);
+        CohortBehaviour<?, ?> nextBehaviour(TransactionIdentifier txId, PostPreCommitStep nextStep) {
+            return new PostPreCommit(txId, nextStep);
         }
 
     }
 
         }
 
     }
 
-    private class PostPreCommit extends CohortStateWithStep<Commit, PostPreCommitStep> {
+    private class PostPreCommit extends CohortStateWithStep<Commit, PostPreCommitStep, NoopThreePhaseCommitStep> {
 
         PostPreCommit(TransactionIdentifier txId, PostPreCommitStep step) {
 
         PostPreCommit(TransactionIdentifier txId, PostPreCommitStep step) {
-            super(txId, step);
+            super(Commit.class, txId, step);
         }
 
         }
 
+        @SuppressWarnings("unchecked")
         @Override
         @Override
-        @SuppressWarnings("checkstyle:IllegalCatch")
-        CohortBehaviour<?> process(Commit message) {
-            try {
-                getStep().commit().get();
-            } catch (final Exception e) {
-                getSender().tell(new Status.Failure(e), getSender());
-                return idleState;
-            }
-            getSender().tell(new Success(getSelf(), getTxId()), getSelf());
-            return idleState;
+        ListenableFuture<NoopThreePhaseCommitStep> process(Commit message) {
+            return (ListenableFuture<NoopThreePhaseCommitStep>) getStep().commit();
         }
 
         @Override
         }
 
         @Override
-        Class<Commit> getHandledMessageType() {
-            return Commit.class;
+        CohortBehaviour<?, ?> nextBehaviour(TransactionIdentifier txId, NoopThreePhaseCommitStep nextStep) {
+            return null;
         }
         }
+    }
 
 
+    private interface NoopThreePhaseCommitStep extends ThreePhaseCommitStep {
     }
 
     static Props props(final DOMDataTreeCommitCohort cohort, final YangInstanceIdentifier registeredPath) {
     }
 
     static Props props(final DOMDataTreeCommitCohort cohort, final YangInstanceIdentifier registeredPath) {
diff --git a/opendaylight/md-sal/sal-distributed-datastore/src/test/java/org/opendaylight/controller/cluster/datastore/DataTreeCohortActorTest.java b/opendaylight/md-sal/sal-distributed-datastore/src/test/java/org/opendaylight/controller/cluster/datastore/DataTreeCohortActorTest.java
new file mode 100644 (file)
index 0000000..b1b3ff4
--- /dev/null
@@ -0,0 +1,222 @@
+/*
+ * Copyright (c) 2018 Inocybe Technologies and others.  All rights reserved.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License v1.0 which accompanies this distribution,
+ * and is available at http://www.eclipse.org/legal/epl-v10.html
+ */
+package org.opendaylight.controller.cluster.datastore;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.verify;
+
+import akka.actor.ActorRef;
+import akka.pattern.Patterns;
+import akka.util.Timeout;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.JdkFutureAdapters;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.Uninterruptibles;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.opendaylight.controller.cluster.access.concepts.TransactionIdentifier;
+import org.opendaylight.controller.cluster.datastore.DataTreeCohortActor.Abort;
+import org.opendaylight.controller.cluster.datastore.DataTreeCohortActor.CanCommit;
+import org.opendaylight.controller.cluster.datastore.DataTreeCohortActor.Commit;
+import org.opendaylight.controller.cluster.datastore.DataTreeCohortActor.CommitProtocolCommand;
+import org.opendaylight.controller.cluster.datastore.DataTreeCohortActor.PreCommit;
+import org.opendaylight.controller.cluster.datastore.DataTreeCohortActor.Success;
+import org.opendaylight.controller.cluster.raft.TestActorFactory;
+import org.opendaylight.mdsal.common.api.PostCanCommitStep;
+import org.opendaylight.mdsal.common.api.PostPreCommitStep;
+import org.opendaylight.mdsal.common.api.ThreePhaseCommitStep;
+import org.opendaylight.mdsal.dom.api.DOMDataTreeCandidate;
+import org.opendaylight.mdsal.dom.api.DOMDataTreeCommitCohort;
+import org.opendaylight.yangtools.yang.data.api.YangInstanceIdentifier;
+import org.opendaylight.yangtools.yang.data.api.schema.tree.DataValidationFailedException;
+import org.opendaylight.yangtools.yang.model.api.SchemaContext;
+import scala.concurrent.Await;
+
+/**
+ * Unit tests for DataTreeCohortActor.
+ *
+ * @author Thomas Pantelis
+ */
+public class DataTreeCohortActorTest extends AbstractActorTest {
+    private static final Collection<DOMDataTreeCandidate> CANDIDATES = new ArrayList<>();
+    private static final SchemaContext MOCK_SCHEMA = mock(SchemaContext.class);
+    private final TestActorFactory actorFactory = new TestActorFactory(getSystem());
+    private final DOMDataTreeCommitCohort mockCohort = mock(DOMDataTreeCommitCohort.class);
+    private final PostCanCommitStep mockPostCanCommit = mock(PostCanCommitStep.class);
+    private final PostPreCommitStep mockPostPreCommit = mock(PostPreCommitStep.class);
+
+    @Before
+    public void setup() {
+        resetMockCohort();
+    }
+
+    @After
+    public void tearDown() {
+        actorFactory.close();
+    }
+
+    @Test
+    public void testSuccessfulThreePhaseCommit() throws Exception {
+        ActorRef cohortActor = newCohortActor("testSuccessfulThreePhaseCommit");
+
+        TransactionIdentifier txId = nextTransactionId();
+        askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        verify(mockCohort).canCommit(txId, CANDIDATES, MOCK_SCHEMA);
+
+        askAndAwait(cohortActor, new PreCommit(txId));
+        verify(mockPostCanCommit).preCommit();
+
+        askAndAwait(cohortActor, new Commit(txId));
+        verify(mockPostPreCommit).commit();
+
+        resetMockCohort();
+        askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        verify(mockCohort).canCommit(txId, CANDIDATES, MOCK_SCHEMA);
+    }
+
+    @Test
+    public void testMultipleThreePhaseCommits() throws Exception {
+        ActorRef cohortActor = newCohortActor("testMultipleThreePhaseCommits");
+
+        TransactionIdentifier txId1 = nextTransactionId();
+        TransactionIdentifier txId2 = nextTransactionId();
+
+        askAndAwait(cohortActor, new CanCommit(txId1, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        askAndAwait(cohortActor, new CanCommit(txId2, CANDIDATES, MOCK_SCHEMA, cohortActor));
+
+        askAndAwait(cohortActor, new PreCommit(txId1));
+        askAndAwait(cohortActor, new PreCommit(txId2));
+
+        askAndAwait(cohortActor, new Commit(txId1));
+        askAndAwait(cohortActor, new Commit(txId2));
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testAsyncCohort() throws Exception {
+        ExecutorService executor = Executors.newSingleThreadExecutor();
+
+        doReturn(Futures.makeChecked(executeWithDelay(executor, mockPostCanCommit),
+            ex -> new DataValidationFailedException(YangInstanceIdentifier.EMPTY, "mock")))
+                .when(mockCohort).canCommit(any(Object.class), any(Collection.class), any(SchemaContext.class));
+
+        doReturn(JdkFutureAdapters.listenInPoolThread(executor.submit(() ->
+            mockPostPreCommit), MoreExecutors.directExecutor())).when(mockPostCanCommit).preCommit();
+
+        doReturn(JdkFutureAdapters.listenInPoolThread(executor.submit(() ->
+            null), MoreExecutors.directExecutor())).when(mockPostPreCommit).commit();
+
+        ActorRef cohortActor = newCohortActor("testAsyncCohort");
+
+        TransactionIdentifier txId = nextTransactionId();
+        askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        verify(mockCohort).canCommit(txId, CANDIDATES, MOCK_SCHEMA);
+
+        askAndAwait(cohortActor, new PreCommit(txId));
+        verify(mockPostCanCommit).preCommit();
+
+        askAndAwait(cohortActor, new Commit(txId));
+        verify(mockPostPreCommit).commit();
+
+        executor.shutdownNow();
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testFailureOnCanCommit() throws Exception {
+        DataValidationFailedException failure = new DataValidationFailedException(YangInstanceIdentifier.EMPTY, "mock");
+        doReturn(Futures.immediateFailedCheckedFuture(failure)).when(mockCohort).canCommit(any(Object.class),
+                any(Collection.class), any(SchemaContext.class));
+
+        ActorRef cohortActor = newCohortActor("testFailureOnCanCommit");
+
+        TransactionIdentifier txId = nextTransactionId();
+        try {
+            askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        } catch (DataValidationFailedException e) {
+            assertEquals("DataValidationFailedException", failure, e);
+        }
+
+        resetMockCohort();
+        askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        verify(mockCohort).canCommit(txId, CANDIDATES, MOCK_SCHEMA);
+    }
+
+    @Test
+    public void testAbortAfterCanCommit() throws Exception {
+        ActorRef cohortActor = newCohortActor("testAbortAfterCanCommit");
+
+        TransactionIdentifier txId = nextTransactionId();
+        askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        verify(mockCohort).canCommit(txId, CANDIDATES, MOCK_SCHEMA);
+
+        askAndAwait(cohortActor, new Abort(txId));
+        verify(mockPostCanCommit).abort();
+
+        resetMockCohort();
+        askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        verify(mockCohort).canCommit(txId, CANDIDATES, MOCK_SCHEMA);
+    }
+
+    @Test
+    public void testAbortAfterPreCommit() throws Exception {
+        ActorRef cohortActor = newCohortActor("testAbortAfterPreCommit");
+
+        TransactionIdentifier txId = nextTransactionId();
+        askAndAwait(cohortActor, new CanCommit(txId, CANDIDATES, MOCK_SCHEMA, cohortActor));
+        verify(mockCohort).canCommit(txId, CANDIDATES, MOCK_SCHEMA);
+
+        askAndAwait(cohortActor, new PreCommit(txId));
+        verify(mockPostCanCommit).preCommit();
+
+        askAndAwait(cohortActor, new Abort(txId));
+        verify(mockPostPreCommit).abort();
+    }
+
+    private <T> ListenableFuture<T> executeWithDelay(ExecutorService executor, T result) {
+        return JdkFutureAdapters.listenInPoolThread(executor.submit(() -> {
+            Uninterruptibles.sleepUninterruptibly(500, TimeUnit.MILLISECONDS);
+            return result;
+        }), MoreExecutors.directExecutor());
+    }
+
+    private ActorRef newCohortActor(String name) {
+        return actorFactory.createActor(DataTreeCohortActor.props(mockCohort, YangInstanceIdentifier.EMPTY), name);
+    }
+
+    @SuppressWarnings("unchecked")
+    private void resetMockCohort() {
+        reset(mockCohort);
+        doReturn(ThreePhaseCommitStep.NOOP_ABORT_FUTURE).when(mockPostCanCommit).abort();
+        doReturn(Futures.immediateFuture(mockPostPreCommit)).when(mockPostCanCommit).preCommit();
+        doReturn(Futures.immediateCheckedFuture(mockPostCanCommit)).when(mockCohort).canCommit(any(Object.class),
+                any(Collection.class), any(SchemaContext.class));
+
+        doReturn(ThreePhaseCommitStep.NOOP_ABORT_FUTURE).when(mockPostPreCommit).abort();
+        doReturn(Futures.immediateFuture(null)).when(mockPostPreCommit).commit();
+    }
+
+    private static void askAndAwait(ActorRef actor, CommitProtocolCommand<?> message) throws Exception {
+        Timeout timeout = new Timeout(5, TimeUnit.SECONDS);
+        Object result = Await.result(Patterns.ask(actor, message, timeout), timeout.duration());
+        assertTrue("Expected Success but was " + result, result instanceof Success);
+        assertEquals("Success", message.getTxId(), ((Success)result).getTxId());
+    }
+}