Merge "BUG 720 - YANG leaf as JSON input *<*:* couldn't be saved"
[controller.git] / opendaylight / md-sal / sal-distributed-datastore / src / test / java / org / opendaylight / controller / cluster / datastore / ShardTest.java
index fc45efcdea854bea791c44b12a4437c2c74bc2f1..2051c9debe88c69c071663df511fbff130f325ed 100644 (file)
 package org.opendaylight.controller.cluster.datastore;
 
 import akka.actor.ActorRef;
+import akka.actor.ActorSystem;
 import akka.actor.Props;
 import akka.event.Logging;
+import akka.japi.Creator;
 import akka.testkit.JavaTestKit;
-
+import akka.testkit.TestActorRef;
+import com.google.common.base.Optional;
+import com.google.common.util.concurrent.CheckedFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
 import org.opendaylight.controller.cluster.datastore.identifiers.ShardIdentifier;
+import org.opendaylight.controller.cluster.datastore.messages.CommitTransactionReply;
 import org.opendaylight.controller.cluster.datastore.messages.CreateTransaction;
-import org.opendaylight.controller.cluster.datastore.messages.CreateTransactionChain;
-import org.opendaylight.controller.cluster.datastore.messages.CreateTransactionChainReply;
 import org.opendaylight.controller.cluster.datastore.messages.EnableNotification;
+import org.opendaylight.controller.cluster.datastore.messages.ForwardedCommitTransaction;
 import org.opendaylight.controller.cluster.datastore.messages.PeerAddressResolved;
 import org.opendaylight.controller.cluster.datastore.messages.RegisterChangeListener;
 import org.opendaylight.controller.cluster.datastore.messages.RegisterChangeListenerReply;
 import org.opendaylight.controller.cluster.datastore.messages.UpdateSchemaContext;
+import org.opendaylight.controller.cluster.datastore.modification.MergeModification;
+import org.opendaylight.controller.cluster.datastore.modification.Modification;
+import org.opendaylight.controller.cluster.datastore.modification.MutableCompositeModification;
+import org.opendaylight.controller.cluster.datastore.modification.WriteModification;
+import org.opendaylight.controller.cluster.datastore.node.NormalizedNodeToNodeCodec;
+import org.opendaylight.controller.cluster.datastore.utils.InMemoryJournal;
+import org.opendaylight.controller.cluster.datastore.utils.InMemorySnapshotStore;
+import org.opendaylight.controller.cluster.raft.ReplicatedLogEntry;
+import org.opendaylight.controller.cluster.raft.ReplicatedLogImplEntry;
+import org.opendaylight.controller.cluster.raft.Snapshot;
+import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
+import org.opendaylight.controller.cluster.raft.base.messages.ApplySnapshot;
+import org.opendaylight.controller.cluster.raft.base.messages.ApplyState;
+import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshot;
+import org.opendaylight.controller.cluster.raft.protobuff.client.messages.CompositeModificationPayload;
+import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload;
 import org.opendaylight.controller.md.cluster.datastore.model.SchemaContextHelper;
 import org.opendaylight.controller.md.cluster.datastore.model.TestModel;
 import org.opendaylight.controller.md.sal.common.api.data.AsyncDataBroker;
 import org.opendaylight.controller.md.sal.common.api.data.AsyncDataChangeEvent;
 import org.opendaylight.controller.md.sal.common.api.data.AsyncDataChangeListener;
+import org.opendaylight.controller.md.sal.common.api.data.ReadFailedException;
+import org.opendaylight.controller.md.sal.dom.store.impl.InMemoryDOMDataStore;
+import org.opendaylight.controller.md.sal.dom.store.impl.InMemoryDOMDataStoreFactory;
+import org.opendaylight.controller.protobuff.messages.common.NormalizedNodeMessages;
 import org.opendaylight.controller.protobuff.messages.transaction.ShardTransactionMessages.CreateTransactionReply;
+import org.opendaylight.controller.sal.core.spi.data.DOMStoreReadTransaction;
+import org.opendaylight.controller.sal.core.spi.data.DOMStoreThreePhaseCommitCohort;
+import org.opendaylight.controller.sal.core.spi.data.DOMStoreWriteTransaction;
 import org.opendaylight.yangtools.yang.data.api.YangInstanceIdentifier;
+import org.opendaylight.yangtools.yang.data.api.YangInstanceIdentifier.PathArgument;
+import org.opendaylight.yangtools.yang.data.api.schema.DataContainerChild;
+import org.opendaylight.yangtools.yang.data.api.schema.MapEntryNode;
 import org.opendaylight.yangtools.yang.data.api.schema.NormalizedNode;
-
+import org.opendaylight.yangtools.yang.data.impl.schema.ImmutableNodes;
+import org.opendaylight.yangtools.yang.model.api.SchemaContext;
+import scala.concurrent.duration.Duration;
+import java.io.IOException;
 import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.junit.Assert.assertFalse;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.verify;
 
 public class ShardTest extends AbstractActorTest {
 
-    private static final ShardContext shardContext = new ShardContext();
+    private static final DatastoreContext DATA_STORE_CONTEXT =
+            new DatastoreContext("", null, Duration.create(10, TimeUnit.MINUTES), 5, 3, 5000, 500);
+
+    private static final SchemaContext SCHEMA_CONTEXT = TestModel.createTestContext();
+
+    private static final ShardIdentifier IDENTIFIER = ShardIdentifier.builder().memberName("member-1")
+            .shardName("inventory").type("config").build();
+
+    @Before
+    public void setUp() {
+        System.setProperty("shard.persistent", "false");
+
+        InMemorySnapshotStore.clear();
+        InMemoryJournal.clear();
+    }
+
+    @After
+    public void tearDown() {
+        InMemorySnapshotStore.clear();
+        InMemoryJournal.clear();
+    }
+
+    private Props newShardProps() {
+        return Shard.props(IDENTIFIER, Collections.<ShardIdentifier,String>emptyMap(),
+                DATA_STORE_CONTEXT, SCHEMA_CONTEXT);
+    }
 
     @Test
-    public void testOnReceiveCreateTransactionChain() throws Exception {
+    public void testOnReceiveRegisterListener() throws Exception {
         new JavaTestKit(getSystem()) {{
-            final ShardIdentifier identifier =
-                ShardIdentifier.builder().memberName("member-1")
-                    .shardName("inventory").type("config").build();
+            ActorRef subject = getSystem().actorOf(newShardProps(), "testRegisterChangeListener");
 
-            final Props props = Shard.props(identifier, Collections.EMPTY_MAP, shardContext);
-            final ActorRef subject =
-                getSystem().actorOf(props, "testCreateTransactionChain");
+            subject.tell(new UpdateSchemaContext(SchemaContextHelper.full()), getRef());
 
+            subject.tell(new RegisterChangeListener(TestModel.TEST_PATH,
+                    getRef().path(), AsyncDataBroker.DataChangeScope.BASE), getRef());
 
-            // Wait for a specific log message to show up
-            final boolean result =
-                new JavaTestKit.EventFilter<Boolean>(Logging.Info.class
-                ) {
-                    @Override
-                    protected Boolean run() {
-                        return true;
-                    }
-                }.from(subject.path().toString())
-                    .message("Switching from state Candidate to Leader")
-                    .occurrences(1).exec();
+            EnableNotification enable = expectMsgClass(duration("3 seconds"), EnableNotification.class);
+            assertEquals("isEnabled", false, enable.isEnabled());
 
-            Assert.assertEquals(true, result);
+            RegisterChangeListenerReply reply = expectMsgClass(duration("3 seconds"),
+                    RegisterChangeListenerReply.class);
+            assertTrue(reply.getListenerRegistrationPath().toString().matches(
+                    "akka:\\/\\/test\\/user\\/testRegisterChangeListener\\/\\$.*"));
+        }};
+    }
 
-            new Within(duration("3 seconds")) {
-                @Override
-                protected void run() {
+    @Test
+    public void testCreateTransaction(){
+        new ShardTestKit(getSystem()) {{
+            ActorRef subject = getSystem().actorOf(newShardProps(), "testCreateTransaction");
 
-                    subject.tell(new CreateTransactionChain().toSerializable(), getRef());
-
-                    final String out = new ExpectMsg<String>(duration("3 seconds"), "match hint") {
-                        // do not put code outside this method, will run afterwards
-                        @Override
-                        protected String match(Object in) {
-                            if (in.getClass().equals(CreateTransactionChainReply.SERIALIZABLE_CLASS)){
-                                CreateTransactionChainReply reply =
-                                    CreateTransactionChainReply.fromSerializable(getSystem(),in);
-                                return reply.getTransactionChainPath()
-                                    .toString();
-                            } else {
-                                throw noMatch();
-                            }
-                        }
-                    }.get(); // this extracts the received message
+            waitUntilLeader(subject);
 
-                    assertEquals("Unexpected transaction path " + out,
-                        "akka://test/user/testCreateTransactionChain/$a",
-                        out);
+            subject.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
 
-                    expectNoMsg();
-                }
+            subject.tell(new CreateTransaction("txn-1",
+                    TransactionProxy.TransactionType.READ_ONLY.ordinal() ).toSerializable(), getRef());
 
+            CreateTransactionReply reply = expectMsgClass(duration("3 seconds"),
+                    CreateTransactionReply.class);
 
-            };
+            String path = reply.getTransactionActorPath().toString();
+            assertTrue("Unexpected transaction path " + path,
+                    path.contains("akka://test/user/testCreateTransaction/shard-txn-1"));
+            expectNoMsg();
         }};
     }
 
     @Test
-    public void testOnReceiveRegisterListener() throws Exception {
+    public void testCreateTransactionOnChain(){
+        new ShardTestKit(getSystem()) {{
+            final ActorRef subject = getSystem().actorOf(newShardProps(), "testCreateTransactionOnChain");
+
+            waitUntilLeader(subject);
+
+            subject.tell(new CreateTransaction("txn-1",
+                    TransactionProxy.TransactionType.READ_ONLY.ordinal() , "foobar").toSerializable(),
+                    getRef());
+
+            CreateTransactionReply reply = expectMsgClass(duration("3 seconds"),
+                    CreateTransactionReply.class);
+
+            String path = reply.getTransactionActorPath().toString();
+            assertTrue("Unexpected transaction path " + path,
+                    path.contains("akka://test/user/testCreateTransactionOnChain/shard-txn-1"));
+            expectNoMsg();
+        }};
+    }
+
+    @Test
+    public void testPeerAddressResolved(){
         new JavaTestKit(getSystem()) {{
             final ShardIdentifier identifier =
                 ShardIdentifier.builder().memberName("member-1")
                     .shardName("inventory").type("config").build();
 
-            final Props props = Shard.props(identifier, Collections.EMPTY_MAP, shardContext);
-            final ActorRef subject =
-                getSystem().actorOf(props, "testRegisterChangeListener");
+            Props props = Shard.props(identifier,
+                    Collections.<ShardIdentifier, String>singletonMap(identifier, null),
+                    DATA_STORE_CONTEXT, SCHEMA_CONTEXT);
+            final ActorRef subject = getSystem().actorOf(props, "testPeerAddressResolved");
 
             new Within(duration("3 seconds")) {
                 @Override
                 protected void run() {
 
                     subject.tell(
-                        new UpdateSchemaContext(SchemaContextHelper.full()),
+                        new PeerAddressResolved(identifier, "akka://foobar"),
                         getRef());
 
-                    subject.tell(new RegisterChangeListener(TestModel.TEST_PATH,
-                        getRef().path(), AsyncDataBroker.DataChangeScope.BASE),
-                        getRef());
+                    expectNoMsg();
+                }
+            };
+        }};
+    }
 
-                    final Boolean notificationEnabled = new ExpectMsg<Boolean>(
-                                                   duration("3 seconds"), "enable notification") {
-                        // do not put code outside this method, will run afterwards
-                        @Override
-                        protected Boolean match(Object in) {
-                            if(in instanceof EnableNotification){
-                                return ((EnableNotification) in).isEnabled();
-                            } else {
-                                throw noMatch();
-                            }
-                        }
-                    }.get(); // this extracts the received message
-
-                    assertFalse(notificationEnabled);
-
-                    final String out = new ExpectMsg<String>(duration("3 seconds"), "match hint") {
-                        // do not put code outside this method, will run afterwards
-                        @Override
-                        protected String match(Object in) {
-                            if (in.getClass().equals(RegisterChangeListenerReply.class)) {
-                                RegisterChangeListenerReply reply =
-                                    (RegisterChangeListenerReply) in;
-                                return reply.getListenerRegistrationPath()
-                                    .toString();
-                            } else {
-                                throw noMatch();
-                            }
-                        }
-                    }.get(); // this extracts the received message
+    @Test
+    public void testApplySnapshot() throws ExecutionException, InterruptedException {
+        TestActorRef<Shard> ref = TestActorRef.create(getSystem(), newShardProps());
 
-                    assertTrue(out.matches(
-                        "akka:\\/\\/test\\/user\\/testRegisterChangeListener\\/\\$.*"));
-                }
+        NormalizedNodeToNodeCodec codec =
+            new NormalizedNodeToNodeCodec(SCHEMA_CONTEXT);
 
+        ref.underlyingActor().writeToStore(TestModel.TEST_PATH, ImmutableNodes.containerNode(
+                TestModel.TEST_QNAME));
 
-            };
-        }};
+        YangInstanceIdentifier root = YangInstanceIdentifier.builder().build();
+        NormalizedNode<?,?> expected = ref.underlyingActor().readStore(root);
+
+        NormalizedNodeMessages.Container encode = codec.encode(expected);
+
+        ApplySnapshot applySnapshot = new ApplySnapshot(Snapshot.create(
+                encode.getNormalizedNode().toByteString().toByteArray(),
+                Collections.<ReplicatedLogEntry>emptyList(), 1, 2, 3, 4));
+
+        ref.underlyingActor().onReceiveCommand(applySnapshot);
+
+        NormalizedNode<?,?> actual = ref.underlyingActor().readStore(root);
+
+        assertEquals(expected, actual);
     }
 
     @Test
-    public void testCreateTransaction(){
-        new JavaTestKit(getSystem()) {{
-            final ShardIdentifier identifier =
-                ShardIdentifier.builder().memberName("member-1")
-                    .shardName("inventory").type("config").build();
+    public void testApplyState() throws Exception {
 
-            final Props props = Shard.props(identifier, Collections.EMPTY_MAP, shardContext);
-            final ActorRef subject =
-                getSystem().actorOf(props, "testCreateTransaction");
+        TestActorRef<Shard> shard = TestActorRef.create(getSystem(), newShardProps());
 
-            // Wait for a specific log message to show up
-            final boolean result =
-                new JavaTestKit.EventFilter<Boolean>(Logging.Info.class
-                ) {
+        NormalizedNode<?, ?> node = ImmutableNodes.containerNode(TestModel.TEST_QNAME);
+
+        MutableCompositeModification compMod = new MutableCompositeModification();
+        compMod.addModification(new WriteModification(TestModel.TEST_PATH, node, SCHEMA_CONTEXT));
+        Payload payload = new CompositeModificationPayload(compMod.toSerializable());
+        ApplyState applyState = new ApplyState(null, "test",
+                new ReplicatedLogImplEntry(1, 2, payload));
+
+        shard.underlyingActor().onReceiveCommand(applyState);
+
+        NormalizedNode<?,?> actual = shard.underlyingActor().readStore(TestModel.TEST_PATH);
+        assertEquals("Applied state", node, actual);
+    }
+
+    @SuppressWarnings("serial")
+    @Test
+    public void testRecovery() throws Exception {
+
+        // Set up the InMemorySnapshotStore.
+
+        InMemoryDOMDataStore testStore = InMemoryDOMDataStoreFactory.create("Test", null, null);
+        testStore.onGlobalContextUpdated(SCHEMA_CONTEXT);
+
+        DOMStoreWriteTransaction writeTx = testStore.newWriteOnlyTransaction();
+        writeTx.write(TestModel.TEST_PATH, ImmutableNodes.containerNode(TestModel.TEST_QNAME));
+        DOMStoreThreePhaseCommitCohort commitCohort = writeTx.ready();
+        commitCohort.preCommit().get();
+        commitCohort.commit().get();
+
+        DOMStoreReadTransaction readTx = testStore.newReadOnlyTransaction();
+        NormalizedNode<?, ?> root = readTx.read(YangInstanceIdentifier.builder().build()).get().get();
+
+        InMemorySnapshotStore.addSnapshot(IDENTIFIER.toString(), Snapshot.create(
+                new NormalizedNodeToNodeCodec(SCHEMA_CONTEXT).encode(
+                        root).
+                                getNormalizedNode().toByteString().toByteArray(),
+                                Collections.<ReplicatedLogEntry>emptyList(), 0, 1, -1, -1));
+
+        // Set up the InMemoryJournal.
+
+        InMemoryJournal.addEntry(IDENTIFIER.toString(), 0, new ReplicatedLogImplEntry(0, 1, newPayload(
+                  new WriteModification(TestModel.OUTER_LIST_PATH,
+                          ImmutableNodes.mapNodeBuilder(TestModel.OUTER_LIST_QNAME).build(),
+                          SCHEMA_CONTEXT))));
+
+        int nListEntries = 11;
+        Set<Integer> listEntryKeys = new HashSet<>();
+        for(int i = 1; i <= nListEntries; i++) {
+            listEntryKeys.add(Integer.valueOf(i));
+            YangInstanceIdentifier path = YangInstanceIdentifier.builder(TestModel.OUTER_LIST_PATH)
+                    .nodeWithKey(TestModel.OUTER_LIST_QNAME, TestModel.ID_QNAME, i).build();
+            Modification mod = new MergeModification(path,
+                    ImmutableNodes.mapEntry(TestModel.OUTER_LIST_QNAME, TestModel.ID_QNAME, i),
+                    SCHEMA_CONTEXT);
+            InMemoryJournal.addEntry(IDENTIFIER.toString(), i, new ReplicatedLogImplEntry(i, 1,
+                    newPayload(mod)));
+        }
+
+        InMemoryJournal.addEntry(IDENTIFIER.toString(), nListEntries + 1,
+                new ApplyLogEntries(nListEntries));
+
+        // Create the actor and wait for recovery complete.
+
+        final CountDownLatch recoveryComplete = new CountDownLatch(1);
+
+        Creator<Shard> creator = new Creator<Shard>() {
+            @Override
+            public Shard create() throws Exception {
+                return new Shard(IDENTIFIER, Collections.<ShardIdentifier,String>emptyMap(),
+                        DATA_STORE_CONTEXT, SCHEMA_CONTEXT) {
                     @Override
-                    protected Boolean run() {
-                        return true;
+                    protected void onRecoveryComplete() {
+                        try {
+                            super.onRecoveryComplete();
+                        } finally {
+                            recoveryComplete.countDown();
+                        }
                     }
-                }.from(subject.path().toString())
-                    .message("Switching from state Candidate to Leader")
-                    .occurrences(1).exec();
+                };
+            }
+        };
 
-            Assert.assertEquals(true, result);
+        TestActorRef<Shard> shard = TestActorRef.create(getSystem(),
+                Props.create(new DelegatingShardCreator(creator)), "testRecovery");
+
+        assertEquals("Recovery complete", true, recoveryComplete.await(5, TimeUnit.SECONDS));
+
+        // Verify data in the data store.
+
+        NormalizedNode<?, ?> outerList = shard.underlyingActor().readStore(TestModel.OUTER_LIST_PATH);
+        assertNotNull(TestModel.OUTER_LIST_QNAME.getLocalName() + " not found", outerList);
+        assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " value is not Iterable",
+                outerList.getValue() instanceof Iterable);
+        for(Object entry: (Iterable<?>) outerList.getValue()) {
+            assertTrue(TestModel.OUTER_LIST_QNAME.getLocalName() + " entry is not MapEntryNode",
+                    entry instanceof MapEntryNode);
+            MapEntryNode mapEntry = (MapEntryNode)entry;
+            Optional<DataContainerChild<? extends PathArgument, ?>> idLeaf =
+                    mapEntry.getChild(new YangInstanceIdentifier.NodeIdentifier(TestModel.ID_QNAME));
+            assertTrue("Missing leaf " + TestModel.ID_QNAME.getLocalName(), idLeaf.isPresent());
+            Object value = idLeaf.get().getValue();
+            assertTrue("Unexpected value for leaf "+ TestModel.ID_QNAME.getLocalName() + ": " + value,
+                    listEntryKeys.remove(value));
+        }
+
+        if(!listEntryKeys.isEmpty()) {
+            fail("Missing " + TestModel.OUTER_LIST_QNAME.getLocalName() + " entries with keys: " +
+                    listEntryKeys);
+        }
+
+        assertEquals("Last log index", nListEntries,
+                shard.underlyingActor().getShardMBean().getLastLogIndex());
+        assertEquals("Commit index", nListEntries,
+                shard.underlyingActor().getShardMBean().getCommitIndex());
+        assertEquals("Last applied", nListEntries,
+                shard.underlyingActor().getShardMBean().getLastApplied());
+    }
 
-            new Within(duration("3 seconds")) {
-                @Override
-                protected void run() {
+    private CompositeModificationPayload newPayload(Modification... mods) {
+        MutableCompositeModification compMod = new MutableCompositeModification();
+        for(Modification mod: mods) {
+            compMod.addModification(mod);
+        }
 
-                    subject.tell(
-                        new UpdateSchemaContext(TestModel.createTestContext()),
-                        getRef());
+        return new CompositeModificationPayload(compMod.toSerializable());
+    }
 
-                    subject.tell(new CreateTransaction("txn-1", TransactionProxy.TransactionType.READ_ONLY.ordinal() ).toSerializable(),
-                        getRef());
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testForwardedCommitTransactionWithPersistence() throws IOException {
+        System.setProperty("shard.persistent", "true");
 
-                    final String out = new ExpectMsg<String>(duration("3 seconds"), "match hint") {
-                        // do not put code outside this method, will run afterwards
-                        @Override
-                        protected String match(Object in) {
-                            if (in instanceof CreateTransactionReply) {
-                                CreateTransactionReply reply =
-                                    (CreateTransactionReply) in;
-                                return reply.getTransactionActorPath()
-                                    .toString();
-                            } else {
-                                throw noMatch();
-                            }
-                        }
-                    }.get(); // this extracts the received message
+        new ShardTestKit(getSystem()) {{
+            TestActorRef<Shard> shard = TestActorRef.create(getSystem(), newShardProps());
 
-                    assertTrue("Unexpected transaction path " + out,
-                        out.contains("akka://test/user/testCreateTransaction/shard-txn-1"));
-                    expectNoMsg();
-                }
-            };
+            waitUntilLeader(shard);
+
+            NormalizedNode<?, ?> node = ImmutableNodes.containerNode(TestModel.TEST_QNAME);
+
+            DOMStoreThreePhaseCommitCohort cohort = mock(DOMStoreThreePhaseCommitCohort.class);
+            doReturn(Futures.immediateFuture(null)).when(cohort).commit();
+
+            MutableCompositeModification modification = new MutableCompositeModification();
+            modification.addModification(new WriteModification(TestModel.TEST_PATH, node,
+                    SCHEMA_CONTEXT));
+
+            shard.tell(new ForwardedCommitTransaction(cohort, modification), getRef());
+
+            expectMsgClass(duration("5 seconds"), CommitTransactionReply.SERIALIZABLE_CLASS);
+
+            verify(cohort).commit();
+
+            assertEquals("Last log index", 0, shard.underlyingActor().getShardMBean().getLastLogIndex());
         }};
     }
 
     @Test
-    public void testPeerAddressResolved(){
-        new JavaTestKit(getSystem()) {{
-            Map<ShardIdentifier, String> peerAddresses = new HashMap<>();
+    public void testCreateSnapshot() throws IOException, InterruptedException {
+        new ShardTestKit(getSystem()) {{
+            final ActorRef subject = getSystem().actorOf(newShardProps(), "testCreateSnapshot");
 
-            final ShardIdentifier identifier =
-                ShardIdentifier.builder().memberName("member-1")
-                    .shardName("inventory").type("config").build();
+            waitUntilLeader(subject);
 
-            peerAddresses.put(identifier, null);
-            final Props props = Shard.props(identifier, peerAddresses, shardContext);
-            final ActorRef subject =
-                getSystem().actorOf(props, "testPeerAddressResolved");
+            subject.tell(new CaptureSnapshot(-1,-1,-1,-1), getRef());
 
-            new Within(duration("3 seconds")) {
-                @Override
-                protected void run() {
+            waitForLogMessage(Logging.Info.class, subject, "CaptureSnapshotReply received by actor");
 
-                    subject.tell(
-                        new PeerAddressResolved(identifier, "akka://foobar"),
-                        getRef());
+            subject.tell(new CaptureSnapshot(-1,-1,-1,-1), getRef());
 
-                    expectNoMsg();
-                }
-            };
+            waitForLogMessage(Logging.Info.class, subject, "CaptureSnapshotReply received by actor");
         }};
     }
 
+    /**
+     * This test simply verifies that the applySnapShot logic will work
+     * @throws ReadFailedException
+     */
+    @Test
+    public void testInMemoryDataStoreRestore() throws ReadFailedException {
+        InMemoryDOMDataStore store = new InMemoryDOMDataStore("test", MoreExecutors.listeningDecorator(
+            MoreExecutors.sameThreadExecutor()), MoreExecutors.sameThreadExecutor());
+
+        store.onGlobalContextUpdated(SCHEMA_CONTEXT);
+
+        DOMStoreWriteTransaction putTransaction = store.newWriteOnlyTransaction();
+        putTransaction.write(TestModel.TEST_PATH,
+            ImmutableNodes.containerNode(TestModel.TEST_QNAME));
+        commitTransaction(putTransaction);
+
+
+        NormalizedNode expected = readStore(store);
+
+        DOMStoreWriteTransaction writeTransaction = store.newWriteOnlyTransaction();
+
+        writeTransaction.delete(YangInstanceIdentifier.builder().build());
+        writeTransaction.write(YangInstanceIdentifier.builder().build(), expected);
+
+        commitTransaction(writeTransaction);
+
+        NormalizedNode actual = readStore(store);
+
+        assertEquals(expected, actual);
+
+    }
+
+    private NormalizedNode readStore(InMemoryDOMDataStore store) throws ReadFailedException {
+        DOMStoreReadTransaction transaction = store.newReadOnlyTransaction();
+        CheckedFuture<Optional<NormalizedNode<?, ?>>, ReadFailedException> read =
+            transaction.read(YangInstanceIdentifier.builder().build());
+
+        Optional<NormalizedNode<?, ?>> optional = read.checkedGet();
+
+        NormalizedNode<?, ?> normalizedNode = optional.get();
+
+        transaction.close();
+
+        return normalizedNode;
+    }
+
+    private void commitTransaction(DOMStoreWriteTransaction transaction) {
+        DOMStoreThreePhaseCommitCohort commitCohort = transaction.ready();
+        ListenableFuture<Void> future =
+            commitCohort.preCommit();
+        try {
+            future.get();
+            future = commitCohort.commit();
+            future.get();
+        } catch (InterruptedException | ExecutionException e) {
+        }
+    }
+
     private AsyncDataChangeListener<YangInstanceIdentifier, NormalizedNode<?, ?>> noOpDataChangeListener() {
         return new AsyncDataChangeListener<YangInstanceIdentifier, NormalizedNode<?, ?>>() {
             @Override
@@ -254,4 +466,46 @@ public class ShardTest extends AbstractActorTest {
             }
         };
     }
+
+    private static final class DelegatingShardCreator implements Creator<Shard> {
+        private final Creator<Shard> delegate;
+
+        DelegatingShardCreator(Creator<Shard> delegate) {
+            this.delegate = delegate;
+        }
+
+        @Override
+        public Shard create() throws Exception {
+            return delegate.create();
+        }
+    }
+
+    private static class ShardTestKit extends JavaTestKit {
+
+        private ShardTestKit(ActorSystem actorSystem) {
+            super(actorSystem);
+        }
+
+        protected void waitForLogMessage(final Class logLevel, ActorRef subject, String logMessage){
+            // Wait for a specific log message to show up
+            final boolean result =
+                new JavaTestKit.EventFilter<Boolean>(logLevel
+                ) {
+                    @Override
+                    protected Boolean run() {
+                        return true;
+                    }
+                }.from(subject.path().toString())
+                    .message(logMessage)
+                    .occurrences(1).exec();
+
+            Assert.assertEquals(true, result);
+
+        }
+
+        protected void waitUntilLeader(ActorRef subject) {
+            waitForLogMessage(Logging.Info.class, subject,
+                    "Switching from state Candidate to Leader");
+        }
+    }
 }