Elide front-end 3PC for single-shard Tx
[controller.git] / opendaylight / md-sal / sal-distributed-datastore / src / test / java / org / opendaylight / controller / cluster / datastore / ShardTest.java
index e3b82df1743e75c433cec193d54a2cbfbd696319..72f672794ab16c9bf89920bcf855cb72fd4bcc63 100644 (file)
@@ -8,6 +8,7 @@ import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
 import static org.opendaylight.controller.cluster.datastore.DataStoreVersions.CURRENT_VERSION;
 import akka.actor.ActorRef;
 import akka.actor.ActorSelection;
@@ -469,7 +470,7 @@ public class ShardTest extends AbstractShardTest {
             // by the ShardTransaction.
 
             shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
-                    cohort1, modification1, true), getRef());
+                    cohort1, modification1, true, false), getRef());
             ReadyTransactionReply readyReply = ReadyTransactionReply.fromSerializable(
                     expectMsgClass(duration, ReadyTransactionReply.class));
             assertEquals("Cohort path", shard.path().toString(), readyReply.getCohortPath());
@@ -484,11 +485,11 @@ public class ShardTest extends AbstractShardTest {
             // Send the ForwardedReadyTransaction for the next 2 Tx's.
 
             shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
-                    cohort2, modification2, true), getRef());
+                    cohort2, modification2, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             shard.tell(new ForwardedReadyTransaction(transactionID3, CURRENT_VERSION,
-                    cohort3, modification3, true), getRef());
+                    cohort3, modification3, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // Send the CanCommitTransaction message for the next 2 Tx's. These should get queued and
@@ -595,18 +596,7 @@ public class ShardTest extends AbstractShardTest {
 
             // Verify data in the data store.
 
-            NormalizedNode<?, ?> outerList = readStore(shard, TestModel.OUTER_LIST_PATH);
-            assertNotNull(TestModel.OUTER_LIST_QNAME.getLocalName() + " not found", outerList);
-            assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " value is not Iterable",
-                    outerList.getValue() instanceof Iterable);
-            Object entry = ((Iterable<Object>)outerList.getValue()).iterator().next();
-            assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " entry is not MapEntryNode",
-                       entry instanceof MapEntryNode);
-            MapEntryNode mapEntry = (MapEntryNode)entry;
-            Optional<DataContainerChild<? extends PathArgument, ?>> idLeaf =
-                    mapEntry.getChild(new YangInstanceIdentifier.NodeIdentifier(TestModel.ID_QNAME));
-            assertTrue("Missing leaf " + TestModel.ID_QNAME.getLocalName(), idLeaf.isPresent());
-            assertEquals(TestModel.ID_QNAME.getLocalName() + " value", 1, idLeaf.get().getValue());
+            verifyOuterListEntry(shard, 1);
 
             verifyLastApplied(shard, 2);
 
@@ -615,25 +605,25 @@ public class ShardTest extends AbstractShardTest {
     }
 
     private BatchedModifications newBatchedModifications(String transactionID, YangInstanceIdentifier path,
-            NormalizedNode<?, ?> data, boolean ready) {
-        return newBatchedModifications(transactionID, null, path, data, ready);
+            NormalizedNode<?, ?> data, boolean ready, boolean doCommitOnReady) {
+        return newBatchedModifications(transactionID, null, path, data, ready, doCommitOnReady);
     }
 
     private BatchedModifications newBatchedModifications(String transactionID, String transactionChainID,
-            YangInstanceIdentifier path, NormalizedNode<?, ?> data, boolean ready) {
+            YangInstanceIdentifier path, NormalizedNode<?, ?> data, boolean ready, boolean doCommitOnReady) {
         BatchedModifications batched = new BatchedModifications(transactionID, CURRENT_VERSION, transactionChainID);
         batched.addModification(new WriteModification(path, data));
         batched.setReady(ready);
+        batched.setDoCommitOnReady(doCommitOnReady);
         return batched;
     }
 
-    @SuppressWarnings("unchecked")
     @Test
-    public void testMultipleBatchedModifications() throws Throwable {
+    public void testBatchedModificationsWithNoCommitOnReady() throws Throwable {
         new ShardTestKit(getSystem()) {{
             final TestActorRef<Shard> shard = TestActorRef.create(getSystem(),
                     newShardProps().withDispatcher(Dispatchers.DefaultDispatcherId()),
-                    "testMultipleBatchedModifications");
+                    "testBatchedModificationsWithNoCommitOnReady");
 
             waitUntilLeader(shard);
 
@@ -657,18 +647,18 @@ public class ShardTest extends AbstractShardTest {
             // Send a BatchedModifications to start a transaction.
 
             shard.tell(newBatchedModifications(transactionID, TestModel.TEST_PATH,
-                    ImmutableNodes.containerNode(TestModel.TEST_QNAME), false), getRef());
+                    ImmutableNodes.containerNode(TestModel.TEST_QNAME), false, false), getRef());
             expectMsgClass(duration, BatchedModificationsReply.class);
 
             // Send a couple more BatchedModifications.
 
             shard.tell(newBatchedModifications(transactionID, TestModel.OUTER_LIST_PATH,
-                    ImmutableNodes.mapNodeBuilder(TestModel.OUTER_LIST_QNAME).build(), false), getRef());
+                    ImmutableNodes.mapNodeBuilder(TestModel.OUTER_LIST_QNAME).build(), false, false), getRef());
             expectMsgClass(duration, BatchedModificationsReply.class);
 
             shard.tell(newBatchedModifications(transactionID, YangInstanceIdentifier.builder(
                     TestModel.OUTER_LIST_PATH).nodeWithKey(TestModel.OUTER_LIST_QNAME, TestModel.ID_QNAME, 1).build(),
-                    ImmutableNodes.mapEntry(TestModel.OUTER_LIST_QNAME, TestModel.ID_QNAME, 1), true), getRef());
+                    ImmutableNodes.mapEntry(TestModel.OUTER_LIST_QNAME, TestModel.ID_QNAME, 1), true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // Send the CanCommitTransaction message.
@@ -690,23 +680,85 @@ public class ShardTest extends AbstractShardTest {
 
             // Verify data in the data store.
 
-            NormalizedNode<?, ?> outerList = readStore(shard, TestModel.OUTER_LIST_PATH);
-            assertNotNull(TestModel.OUTER_LIST_QNAME.getLocalName() + " not found", outerList);
-            assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " value is not Iterable",
-                    outerList.getValue() instanceof Iterable);
-            Object entry = ((Iterable<Object>)outerList.getValue()).iterator().next();
-            assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " entry is not MapEntryNode",
-                       entry instanceof MapEntryNode);
-            MapEntryNode mapEntry = (MapEntryNode)entry;
-            Optional<DataContainerChild<? extends PathArgument, ?>> idLeaf =
-                    mapEntry.getChild(new YangInstanceIdentifier.NodeIdentifier(TestModel.ID_QNAME));
-            assertTrue("Missing leaf " + TestModel.ID_QNAME.getLocalName(), idLeaf.isPresent());
-            assertEquals(TestModel.ID_QNAME.getLocalName() + " value", 1, idLeaf.get().getValue());
+            verifyOuterListEntry(shard, 1);
 
             shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
         }};
     }
 
+    @Test
+    public void testBatchedModificationsWithCommitOnReady() throws Throwable {
+        new ShardTestKit(getSystem()) {{
+            final TestActorRef<Shard> shard = TestActorRef.create(getSystem(),
+                    newShardProps().withDispatcher(Dispatchers.DefaultDispatcherId()),
+                    "testBatchedModificationsWithCommitOnReady");
+
+            waitUntilLeader(shard);
+
+            final String transactionID = "tx";
+            FiniteDuration duration = duration("5 seconds");
+
+            final AtomicReference<DOMStoreThreePhaseCommitCohort> mockCohort = new AtomicReference<>();
+            ShardCommitCoordinator.CohortDecorator cohortDecorator = new ShardCommitCoordinator.CohortDecorator() {
+                @Override
+                public DOMStoreThreePhaseCommitCohort decorate(String txID, DOMStoreThreePhaseCommitCohort actual) {
+                    if(mockCohort.get() == null) {
+                        mockCohort.set(createDelegatingMockCohort("cohort", actual));
+                    }
+
+                    return mockCohort.get();
+                }
+            };
+
+            shard.underlyingActor().getCommitCoordinator().setCohortDecorator(cohortDecorator);
+
+            // Send a BatchedModifications to start a transaction.
+
+            shard.tell(newBatchedModifications(transactionID, TestModel.TEST_PATH,
+                    ImmutableNodes.containerNode(TestModel.TEST_QNAME), false, false), getRef());
+            expectMsgClass(duration, BatchedModificationsReply.class);
+
+            // Send a couple more BatchedModifications.
+
+            shard.tell(newBatchedModifications(transactionID, TestModel.OUTER_LIST_PATH,
+                    ImmutableNodes.mapNodeBuilder(TestModel.OUTER_LIST_QNAME).build(), false, false), getRef());
+            expectMsgClass(duration, BatchedModificationsReply.class);
+
+            shard.tell(newBatchedModifications(transactionID, YangInstanceIdentifier.builder(
+                    TestModel.OUTER_LIST_PATH).nodeWithKey(TestModel.OUTER_LIST_QNAME, TestModel.ID_QNAME, 1).build(),
+                    ImmutableNodes.mapEntry(TestModel.OUTER_LIST_QNAME, TestModel.ID_QNAME, 1), true, true), getRef());
+
+            expectMsgClass(duration, CommitTransactionReply.SERIALIZABLE_CLASS);
+
+            InOrder inOrder = inOrder(mockCohort.get());
+            inOrder.verify(mockCohort.get()).canCommit();
+            inOrder.verify(mockCohort.get()).preCommit();
+            inOrder.verify(mockCohort.get()).commit();
+
+            // Verify data in the data store.
+
+            verifyOuterListEntry(shard, 1);
+
+            shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
+        }};
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyOuterListEntry(final TestActorRef<Shard> shard, Object expIDValue) throws Exception {
+        NormalizedNode<?, ?> outerList = readStore(shard, TestModel.OUTER_LIST_PATH);
+        assertNotNull(TestModel.OUTER_LIST_QNAME.getLocalName() + " not found", outerList);
+        assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " value is not Iterable",
+                outerList.getValue() instanceof Iterable);
+        Object entry = ((Iterable<Object>)outerList.getValue()).iterator().next();
+        assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " entry is not MapEntryNode",
+                entry instanceof MapEntryNode);
+        MapEntryNode mapEntry = (MapEntryNode)entry;
+        Optional<DataContainerChild<? extends PathArgument, ?>> idLeaf =
+                mapEntry.getChild(new YangInstanceIdentifier.NodeIdentifier(TestModel.ID_QNAME));
+        assertTrue("Missing leaf " + TestModel.ID_QNAME.getLocalName(), idLeaf.isPresent());
+        assertEquals(TestModel.ID_QNAME.getLocalName() + " value", expIDValue, idLeaf.get().getValue());
+    }
+
     @Test
     public void testBatchedModificationsOnTransactionChain() throws Throwable {
         new ShardTestKit(getSystem()) {{
@@ -727,7 +779,7 @@ public class ShardTest extends AbstractShardTest {
             ContainerNode containerNode = ImmutableNodes.containerNode(TestModel.TEST_QNAME);
             YangInstanceIdentifier path = TestModel.TEST_PATH;
             shard.tell(newBatchedModifications(transactionID1, transactionChainID, path,
-                    containerNode, true), getRef());
+                    containerNode, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // Create a read Tx on the same chain.
@@ -800,6 +852,45 @@ public class ShardTest extends AbstractShardTest {
         }};
     }
 
+    @Test
+    public void testForwardedReadyTransactionWithImmediateCommit() throws Exception{
+        new ShardTestKit(getSystem()) {{
+            final TestActorRef<Shard> shard = TestActorRef.create(getSystem(),
+                    newShardProps().withDispatcher(Dispatchers.DefaultDispatcherId()),
+                    "testForwardedReadyTransactionWithImmediateCommit");
+
+            waitUntilLeader(shard);
+
+            InMemoryDOMDataStore dataStore = shard.underlyingActor().getDataStore();
+
+            String transactionID = "tx1";
+            MutableCompositeModification modification = new MutableCompositeModification();
+            NormalizedNode<?, ?> containerNode = ImmutableNodes.containerNode(TestModel.TEST_QNAME);
+            DOMStoreThreePhaseCommitCohort cohort = setupMockWriteTransaction("cohort", dataStore,
+                    TestModel.TEST_PATH, containerNode, modification);
+
+            FiniteDuration duration = duration("5 seconds");
+
+            // Simulate the ForwardedReadyTransaction messages that would be sent
+            // by the ShardTransaction.
+
+            shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
+                    cohort, modification, true, true), getRef());
+
+            expectMsgClass(duration, ThreePhaseCommitCohortMessages.CommitTransactionReply.class);
+
+            InOrder inOrder = inOrder(cohort);
+            inOrder.verify(cohort).canCommit();
+            inOrder.verify(cohort).preCommit();
+            inOrder.verify(cohort).commit();
+
+            NormalizedNode<?, ?> actualNode = readStore(shard, TestModel.TEST_PATH);
+            assertEquals(TestModel.TEST_QNAME.getLocalName(), containerNode, actualNode);
+
+            shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
+        }};
+    }
+
     @Test
     public void testCommitWithPersistenceDisabled() throws Throwable {
         dataStoreContextBuilder.persistent(false);
@@ -826,7 +917,7 @@ public class ShardTest extends AbstractShardTest {
             // by the ShardTransaction.
 
             shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
-                    cohort, modification, true), getRef());
+                    cohort, modification, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // Send the CanCommitTransaction message.
@@ -878,7 +969,7 @@ public class ShardTest extends AbstractShardTest {
                 // by the ShardTransaction.
 
                 shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
-                        cohort, modification, true), getRef());
+                        cohort, modification, true, false), getRef());
                 expectMsgClass(duration, ReadyTransactionReply.class);
 
                 // Send the CanCommitTransaction message.
@@ -933,7 +1024,7 @@ public class ShardTest extends AbstractShardTest {
                 // by the ShardTransaction.
 
                 shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
-                        cohort, modification, true), getRef());
+                        cohort, modification, true, false), getRef());
                 expectMsgClass(duration, ReadyTransactionReply.class);
 
                 // Send the CanCommitTransaction message.
@@ -973,7 +1064,7 @@ public class ShardTest extends AbstractShardTest {
 
             waitUntilLeader(shard);
 
-         // Setup 2 simulated transactions with mock cohorts. The first one fails in the
+            // Setup 2 simulated transactions with mock cohorts. The first one fails in the
             // commit phase.
 
             String transactionID1 = "tx1";
@@ -995,11 +1086,11 @@ public class ShardTest extends AbstractShardTest {
             // by the ShardTransaction.
 
             shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
-                    cohort1, modification1, true), getRef());
+                    cohort1, modification1, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
-                    cohort2, modification2, true), getRef());
+                    cohort2, modification2, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // Send the CanCommitTransaction message for the first Tx.
@@ -1052,37 +1143,66 @@ public class ShardTest extends AbstractShardTest {
 
             waitUntilLeader(shard);
 
-            String transactionID = "tx1";
-            MutableCompositeModification modification = new MutableCompositeModification();
-            DOMStoreThreePhaseCommitCohort cohort = mock(DOMStoreThreePhaseCommitCohort.class, "cohort1");
-            doReturn(Futures.immediateFuture(Boolean.TRUE)).when(cohort).canCommit();
-            doReturn(Futures.immediateFailedFuture(new IllegalStateException("mock"))).when(cohort).preCommit();
+            String transactionID1 = "tx1";
+            MutableCompositeModification modification1 = new MutableCompositeModification();
+            DOMStoreThreePhaseCommitCohort cohort1 = mock(DOMStoreThreePhaseCommitCohort.class, "cohort1");
+            doReturn(Futures.immediateFuture(Boolean.TRUE)).when(cohort1).canCommit();
+            doReturn(Futures.immediateFailedFuture(new IllegalStateException("mock"))).when(cohort1).preCommit();
+
+            String transactionID2 = "tx2";
+            MutableCompositeModification modification2 = new MutableCompositeModification();
+            DOMStoreThreePhaseCommitCohort cohort2 = mock(DOMStoreThreePhaseCommitCohort.class, "cohort2");
+            doReturn(Futures.immediateFuture(Boolean.TRUE)).when(cohort2).canCommit();
 
             FiniteDuration duration = duration("5 seconds");
+            final Timeout timeout = new Timeout(duration);
 
             // Simulate the ForwardedReadyTransaction messages that would be sent
             // by the ShardTransaction.
 
-            shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
-                    cohort, modification, true), getRef());
+            shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
+                    cohort1, modification1, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
-            // Send the CanCommitTransaction message.
+            shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
+                    cohort2, modification2, true, false), getRef());
+            expectMsgClass(duration, ReadyTransactionReply.class);
 
-            shard.tell(new CanCommitTransaction(transactionID).toSerializable(), getRef());
+            // Send the CanCommitTransaction message for the first Tx.
+
+            shard.tell(new CanCommitTransaction(transactionID1).toSerializable(), getRef());
             CanCommitTransactionReply canCommitReply = CanCommitTransactionReply.fromSerializable(
                     expectMsgClass(duration, CanCommitTransactionReply.SERIALIZABLE_CLASS));
             assertEquals("Can commit", true, canCommitReply.getCanCommit());
 
-            // Send the CommitTransaction message. This should send back an error
-            // for preCommit failure.
+            // Send the CanCommitTransaction message for the 2nd Tx. This should get queued and
+            // processed after the first Tx completes.
 
-            shard.tell(new CommitTransaction(transactionID).toSerializable(), getRef());
+            Future<Object> canCommitFuture = Patterns.ask(shard,
+                    new CanCommitTransaction(transactionID2).toSerializable(), timeout);
+
+            // Send the CommitTransaction message for the first Tx. This should send back an error
+            // and trigger the 2nd Tx to proceed.
+
+            shard.tell(new CommitTransaction(transactionID1).toSerializable(), getRef());
             expectMsgClass(duration, akka.actor.Status.Failure.class);
 
-            InOrder inOrder = inOrder(cohort);
-            inOrder.verify(cohort).canCommit();
-            inOrder.verify(cohort).preCommit();
+            // Wait for the 2nd Tx to complete the canCommit phase.
+
+            final CountDownLatch latch = new CountDownLatch(1);
+            canCommitFuture.onComplete(new OnComplete<Object>() {
+                @Override
+                public void onComplete(final Throwable t, final Object resp) {
+                    latch.countDown();
+                }
+            }, getSystem().dispatcher());
+
+            assertEquals("2nd CanCommit complete", true, latch.await(5, TimeUnit.SECONDS));
+
+            InOrder inOrder = inOrder(cohort1, cohort2);
+            inOrder.verify(cohort1).canCommit();
+            inOrder.verify(cohort1).preCommit();
+            inOrder.verify(cohort2).canCommit();
 
             shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
         }};
@@ -1099,7 +1219,7 @@ public class ShardTest extends AbstractShardTest {
 
             final FiniteDuration duration = duration("5 seconds");
 
-            String transactionID = "tx1";
+            String transactionID1 = "tx1";
             MutableCompositeModification modification = new MutableCompositeModification();
             DOMStoreThreePhaseCommitCohort cohort = mock(DOMStoreThreePhaseCommitCohort.class, "cohort1");
             doReturn(Futures.immediateFailedFuture(new IllegalStateException("mock"))).when(cohort).canCommit();
@@ -1107,15 +1227,165 @@ public class ShardTest extends AbstractShardTest {
             // Simulate the ForwardedReadyTransaction messages that would be sent
             // by the ShardTransaction.
 
-            shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
-                    cohort, modification, true), getRef());
+            shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
+                    cohort, modification, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // Send the CanCommitTransaction message.
 
-            shard.tell(new CanCommitTransaction(transactionID).toSerializable(), getRef());
+            shard.tell(new CanCommitTransaction(transactionID1).toSerializable(), getRef());
             expectMsgClass(duration, akka.actor.Status.Failure.class);
 
+            // Send another can commit to ensure the failed one got cleaned up.
+
+            reset(cohort);
+
+            String transactionID2 = "tx2";
+            doReturn(Futures.immediateFuture(Boolean.TRUE)).when(cohort).canCommit();
+
+            shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
+                    cohort, modification, true, false), getRef());
+            expectMsgClass(duration, ReadyTransactionReply.class);
+
+            shard.tell(new CanCommitTransaction(transactionID2).toSerializable(), getRef());
+            CanCommitTransactionReply reply = CanCommitTransactionReply.fromSerializable(
+                    expectMsgClass(CanCommitTransactionReply.SERIALIZABLE_CLASS));
+            assertEquals("getCanCommit", true, reply.getCanCommit());
+
+            shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
+        }};
+    }
+
+    @Test
+    public void testCanCommitPhaseFalseResponse() throws Throwable {
+        new ShardTestKit(getSystem()) {{
+            final TestActorRef<Shard> shard = TestActorRef.create(getSystem(),
+                    newShardProps().withDispatcher(Dispatchers.DefaultDispatcherId()),
+                    "testCanCommitPhaseFalseResponse");
+
+            waitUntilLeader(shard);
+
+            final FiniteDuration duration = duration("5 seconds");
+
+            String transactionID1 = "tx1";
+            MutableCompositeModification modification = new MutableCompositeModification();
+            DOMStoreThreePhaseCommitCohort cohort = mock(DOMStoreThreePhaseCommitCohort.class, "cohort1");
+            doReturn(Futures.immediateFuture(Boolean.FALSE)).when(cohort).canCommit();
+
+            // Simulate the ForwardedReadyTransaction messages that would be sent
+            // by the ShardTransaction.
+
+            shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
+                    cohort, modification, true, false), getRef());
+            expectMsgClass(duration, ReadyTransactionReply.class);
+
+            // Send the CanCommitTransaction message.
+
+            shard.tell(new CanCommitTransaction(transactionID1).toSerializable(), getRef());
+            CanCommitTransactionReply reply = CanCommitTransactionReply.fromSerializable(
+                    expectMsgClass(CanCommitTransactionReply.SERIALIZABLE_CLASS));
+            assertEquals("getCanCommit", false, reply.getCanCommit());
+
+            // Send another can commit to ensure the failed one got cleaned up.
+
+            reset(cohort);
+
+            String transactionID2 = "tx2";
+            doReturn(Futures.immediateFuture(Boolean.TRUE)).when(cohort).canCommit();
+
+            shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
+                    cohort, modification, true, false), getRef());
+            expectMsgClass(duration, ReadyTransactionReply.class);
+
+            shard.tell(new CanCommitTransaction(transactionID2).toSerializable(), getRef());
+            reply = CanCommitTransactionReply.fromSerializable(
+                    expectMsgClass(CanCommitTransactionReply.SERIALIZABLE_CLASS));
+            assertEquals("getCanCommit", true, reply.getCanCommit());
+
+            shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
+        }};
+    }
+
+    @Test
+    public void testImmediateCommitWithCanCommitPhaseFailure() throws Throwable {
+        new ShardTestKit(getSystem()) {{
+            final TestActorRef<Shard> shard = TestActorRef.create(getSystem(),
+                    newShardProps().withDispatcher(Dispatchers.DefaultDispatcherId()),
+                    "testImmediateCommitWithCanCommitPhaseFailure");
+
+            waitUntilLeader(shard);
+
+            final FiniteDuration duration = duration("5 seconds");
+
+            String transactionID1 = "tx1";
+            MutableCompositeModification modification = new MutableCompositeModification();
+            DOMStoreThreePhaseCommitCohort cohort = mock(DOMStoreThreePhaseCommitCohort.class, "cohort1");
+            doReturn(Futures.immediateFailedFuture(new IllegalStateException("mock"))).when(cohort).canCommit();
+
+            // Simulate the ForwardedReadyTransaction messages that would be sent
+            // by the ShardTransaction.
+
+            shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
+                    cohort, modification, true, true), getRef());
+
+            expectMsgClass(duration, akka.actor.Status.Failure.class);
+
+            // Send another can commit to ensure the failed one got cleaned up.
+
+            reset(cohort);
+
+            String transactionID2 = "tx2";
+            doReturn(Futures.immediateFuture(Boolean.TRUE)).when(cohort).canCommit();
+            doReturn(Futures.immediateFuture(null)).when(cohort).preCommit();
+            doReturn(Futures.immediateFuture(null)).when(cohort).commit();
+
+            shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
+                    cohort, modification, true, true), getRef());
+
+            expectMsgClass(duration, CommitTransactionReply.SERIALIZABLE_CLASS);
+
+            shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
+        }};
+    }
+
+    @Test
+    public void testImmediateCommitWithCanCommitPhaseFalseResponse() throws Throwable {
+        new ShardTestKit(getSystem()) {{
+            final TestActorRef<Shard> shard = TestActorRef.create(getSystem(),
+                    newShardProps().withDispatcher(Dispatchers.DefaultDispatcherId()),
+                    "testImmediateCommitWithCanCommitPhaseFalseResponse");
+
+            waitUntilLeader(shard);
+
+            final FiniteDuration duration = duration("5 seconds");
+
+            String transactionID = "tx1";
+            MutableCompositeModification modification = new MutableCompositeModification();
+            DOMStoreThreePhaseCommitCohort cohort = mock(DOMStoreThreePhaseCommitCohort.class, "cohort1");
+            doReturn(Futures.immediateFuture(Boolean.FALSE)).when(cohort).canCommit();
+
+            // Simulate the ForwardedReadyTransaction messages that would be sent
+            // by the ShardTransaction.
+
+            shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
+                    cohort, modification, true, true), getRef());
+
+            expectMsgClass(duration, akka.actor.Status.Failure.class);
+
+            // Send another can commit to ensure the failed one got cleaned up.
+
+            reset(cohort);
+
+            String transactionID2 = "tx2";
+            doReturn(Futures.immediateFuture(Boolean.TRUE)).when(cohort).canCommit();
+            doReturn(Futures.immediateFuture(null)).when(cohort).preCommit();
+            doReturn(Futures.immediateFuture(null)).when(cohort).commit();
+
+            shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
+                    cohort, modification, true, true), getRef());
+
+            expectMsgClass(duration, CommitTransactionReply.SERIALIZABLE_CLASS);
+
             shard.tell(PoisonPill.getInstance(), ActorRef.noSender());
         }};
     }
@@ -1159,7 +1429,7 @@ public class ShardTest extends AbstractShardTest {
                     modification, preCommit);
 
             shard.tell(new ForwardedReadyTransaction(transactionID, CURRENT_VERSION,
-                    cohort, modification, true), getRef());
+                    cohort, modification, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             shard.tell(new CanCommitTransaction(transactionID).toSerializable(), getRef());
@@ -1224,11 +1494,11 @@ public class ShardTest extends AbstractShardTest {
             // Ready the Tx's
 
             shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
-                    cohort1, modification1, true), getRef());
+                    cohort1, modification1, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
-                    cohort2, modification2, true), getRef());
+                    cohort2, modification2, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // canCommit 1st Tx. We don't send the commit so it should timeout.
@@ -1241,6 +1511,11 @@ public class ShardTest extends AbstractShardTest {
             shard.tell(new CanCommitTransaction(transactionID2).toSerializable(), getRef());
             expectMsgClass(duration, CanCommitTransactionReply.SERIALIZABLE_CLASS);
 
+            // Try to commit the 1st Tx - should fail as it's not the current Tx.
+
+            shard.tell(new CommitTransaction(transactionID1).toSerializable(), getRef());
+            expectMsgClass(duration, akka.actor.Status.Failure.class);
+
             // Commit the 2nd Tx.
 
             shard.tell(new CommitTransaction(transactionID2).toSerializable(), getRef());
@@ -1288,15 +1563,15 @@ public class ShardTest extends AbstractShardTest {
             // Ready the Tx's
 
             shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
-                    cohort1, modification1, true), getRef());
+                    cohort1, modification1, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
-                    cohort2, modification2, true), getRef());
+                    cohort2, modification2, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             shard.tell(new ForwardedReadyTransaction(transactionID3, CURRENT_VERSION,
-                    cohort3, modification3, true), getRef());
+                    cohort3, modification3, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // canCommit 1st Tx.
@@ -1360,11 +1635,11 @@ public class ShardTest extends AbstractShardTest {
             // by the ShardTransaction.
 
             shard.tell(new ForwardedReadyTransaction(transactionID1, CURRENT_VERSION,
-                    cohort1, modification1, true), getRef());
+                    cohort1, modification1, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             shard.tell(new ForwardedReadyTransaction(transactionID2, CURRENT_VERSION,
-                    cohort2, modification2, true), getRef());
+                    cohort2, modification2, true, false), getRef());
             expectMsgClass(duration, ReadyTransactionReply.class);
 
             // Send the CanCommitTransaction message for the first Tx.