BUG 1815 - Do not allow Shards to be created till an appropriate schema context is... 46/11346/4
authorMoiz Raja <moraja@cisco.com>
Thu, 18 Sep 2014 17:54:47 +0000 (10:54 -0700)
committerMoiz Raja <moraja@cisco.com>
Fri, 19 Sep 2014 23:18:57 +0000 (23:18 +0000)
The fix works like so,
- ShardManager maintains a list of all the modules that it ever knew about
- ShardManager persists the known modules to disk using persistence
- When ShardManager recovers it reads back the knownModules from persistence
- As ShardManager gets new SchemaContext's it checks whether the modules in
  the new SchemaContext are a superset of the knownModules. If they are then
  ShardManager persists it and let's the Shards know about the new SchemaContext
  otherwise the new SchemaContext is rejected and a message is logged

Also reduced the log level of some log messages in RaftActor from info to debug
it was too verbose

Change-Id: If388f690114c58e6a8df30f34ddac32a99f255e5
Signed-off-by: Moiz Raja <moraja@cisco.com>
opendaylight/md-sal/sal-akka-raft/src/main/java/org/opendaylight/controller/cluster/raft/RaftActor.java
opendaylight/md-sal/sal-clustering-commons/src/main/java/org/opendaylight/controller/cluster/common/actor/AbstractUntypedPersistentActor.java [new file with mode: 0644]
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/ShardManager.java
opendaylight/md-sal/sal-distributed-datastore/src/test/java/org/opendaylight/controller/cluster/datastore/ShardManagerTest.java

index 6e1a13cf0c19669443b9273e1d2703a3ff2dede9..0a4e2170e527e0dd74cae73802ae7893eba436be 100644 (file)
@@ -148,14 +148,18 @@ public abstract class RaftActor extends UntypedPersistentActor {
 
         } else if (message instanceof ReplicatedLogEntry) {
             ReplicatedLogEntry logEntry = (ReplicatedLogEntry) message;
 
         } else if (message instanceof ReplicatedLogEntry) {
             ReplicatedLogEntry logEntry = (ReplicatedLogEntry) message;
-            LOG.info("Received ReplicatedLogEntry for recovery:{}", logEntry.getIndex());
+            if(LOG.isDebugEnabled()) {
+                LOG.debug("Received ReplicatedLogEntry for recovery:{}", logEntry.getIndex());
+            }
             replicatedLog.append(logEntry);
 
         } else if (message instanceof ApplyLogEntries) {
             ApplyLogEntries ale = (ApplyLogEntries) message;
 
             replicatedLog.append(logEntry);
 
         } else if (message instanceof ApplyLogEntries) {
             ApplyLogEntries ale = (ApplyLogEntries) message;
 
-            LOG.info("Received ApplyLogEntries for recovery, applying to state:{} to {}",
-                context.getLastApplied() + 1, ale.getToIndex());
+            if(LOG.isDebugEnabled()) {
+                LOG.debug("Received ApplyLogEntries for recovery, applying to state:{} to {}",
+                    context.getLastApplied() + 1, ale.getToIndex());
+            }
 
             for (long i = context.getLastApplied() + 1; i <= ale.getToIndex(); i++) {
                 applyState(null, "recovery", replicatedLog.get(i).getData());
 
             for (long i = context.getLastApplied() + 1; i <= ale.getToIndex(); i++) {
                 applyState(null, "recovery", replicatedLog.get(i).getData());
@@ -198,7 +202,9 @@ public abstract class RaftActor extends UntypedPersistentActor {
 
         } else if (message instanceof ApplyLogEntries){
             ApplyLogEntries ale = (ApplyLogEntries) message;
 
         } else if (message instanceof ApplyLogEntries){
             ApplyLogEntries ale = (ApplyLogEntries) message;
-            LOG.info("Persisting ApplyLogEntries with index={}", ale.getToIndex());
+            if(LOG.isDebugEnabled()) {
+                LOG.debug("Persisting ApplyLogEntries with index={}", ale.getToIndex());
+            }
             persist(new ApplyLogEntries(ale.getToIndex()), new Procedure<ApplyLogEntries>() {
                 @Override
                 public void apply(ApplyLogEntries param) throws Exception {
             persist(new ApplyLogEntries(ale.getToIndex()), new Procedure<ApplyLogEntries>() {
                 @Override
                 public void apply(ApplyLogEntries param) throws Exception {
diff --git a/opendaylight/md-sal/sal-clustering-commons/src/main/java/org/opendaylight/controller/cluster/common/actor/AbstractUntypedPersistentActor.java b/opendaylight/md-sal/sal-clustering-commons/src/main/java/org/opendaylight/controller/cluster/common/actor/AbstractUntypedPersistentActor.java
new file mode 100644 (file)
index 0000000..36b2866
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2014 Cisco Systems, Inc. and others.  All rights reserved.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License v1.0 which accompanies this distribution,
+ * and is available at http://www.eclipse.org/legal/epl-v10.html
+ */
+
+package org.opendaylight.controller.cluster.common.actor;
+
+import akka.event.Logging;
+import akka.event.LoggingAdapter;
+import akka.persistence.UntypedPersistentActor;
+
+public abstract class AbstractUntypedPersistentActor extends UntypedPersistentActor {
+
+    protected final LoggingAdapter LOG =
+        Logging.getLogger(getContext().system(), this);
+
+    public AbstractUntypedPersistentActor() {
+        if(LOG.isDebugEnabled()) {
+            LOG.debug("Actor created {}", getSelf());
+        }
+        getContext().
+            system().
+            actorSelection("user/termination-monitor").
+            tell(new Monitor(getSelf()), getSelf());
+
+    }
+
+
+    @Override public void onReceiveCommand(Object message) throws Exception {
+        final String messageType = message.getClass().getSimpleName();
+        if(LOG.isDebugEnabled()) {
+            LOG.debug("Received message {}", messageType);
+        }
+        handleCommand(message);
+        if(LOG.isDebugEnabled()) {
+            LOG.debug("Done handling message {}", messageType);
+        }
+
+    }
+
+    @Override public void onReceiveRecover(Object message) throws Exception {
+        final String messageType = message.getClass().getSimpleName();
+        if(LOG.isDebugEnabled()) {
+            LOG.debug("Received message {}", messageType);
+        }
+        handleRecover(message);
+        if(LOG.isDebugEnabled()) {
+            LOG.debug("Done handling message {}", messageType);
+        }
+
+    }
+
+    protected abstract void handleRecover(Object message) throws Exception;
+
+    protected abstract void handleCommand(Object message) throws Exception;
+
+    protected void ignoreMessage(Object message) {
+        LOG.debug("Unhandled message {} ", message);
+    }
+
+    protected void unknownMessage(Object message) throws Exception {
+        if(LOG.isDebugEnabled()) {
+            LOG.debug("Received unhandled message {}", message);
+        }
+        unhandled(message);
+    }
+}
index a97c00f1d88227fb9d01c90ce38a80b8ccbb1e50..5874c5296f0ebd8d1b5085abd3797f6c77907618 100644 (file)
@@ -15,11 +15,16 @@ import akka.actor.OneForOneStrategy;
 import akka.actor.Props;
 import akka.actor.SupervisorStrategy;
 import akka.cluster.ClusterEvent;
 import akka.actor.Props;
 import akka.actor.SupervisorStrategy;
 import akka.cluster.ClusterEvent;
+import akka.event.Logging;
+import akka.event.LoggingAdapter;
 import akka.japi.Creator;
 import akka.japi.Function;
 import akka.japi.Creator;
 import akka.japi.Function;
+import akka.japi.Procedure;
+import akka.persistence.RecoveryCompleted;
+import akka.persistence.RecoveryFailure;
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Preconditions;
-import org.opendaylight.controller.cluster.common.actor.AbstractUntypedActorWithMetering;
-
+import org.opendaylight.controller.cluster.common.actor.AbstractUntypedPersistentActor;
 import org.opendaylight.controller.cluster.datastore.identifiers.ShardIdentifier;
 import org.opendaylight.controller.cluster.datastore.identifiers.ShardManagerIdentifier;
 import org.opendaylight.controller.cluster.datastore.jmx.mbeans.shardmanager.ShardManagerInfo;
 import org.opendaylight.controller.cluster.datastore.identifiers.ShardIdentifier;
 import org.opendaylight.controller.cluster.datastore.identifiers.ShardManagerIdentifier;
 import org.opendaylight.controller.cluster.datastore.jmx.mbeans.shardmanager.ShardManagerInfo;
@@ -33,13 +38,18 @@ import org.opendaylight.controller.cluster.datastore.messages.PrimaryFound;
 import org.opendaylight.controller.cluster.datastore.messages.PrimaryNotFound;
 import org.opendaylight.controller.cluster.datastore.messages.UpdateSchemaContext;
 import org.opendaylight.controller.cluster.datastore.utils.ActorContext;
 import org.opendaylight.controller.cluster.datastore.messages.PrimaryNotFound;
 import org.opendaylight.controller.cluster.datastore.messages.UpdateSchemaContext;
 import org.opendaylight.controller.cluster.datastore.utils.ActorContext;
+import org.opendaylight.yangtools.yang.model.api.ModuleIdentifier;
 import org.opendaylight.yangtools.yang.model.api.SchemaContext;
 import scala.concurrent.duration.Duration;
 
 import org.opendaylight.yangtools.yang.model.api.SchemaContext;
 import scala.concurrent.duration.Duration;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 /**
  * The ShardManager has the following jobs,
 
 /**
  * The ShardManager has the following jobs,
@@ -50,7 +60,10 @@ import java.util.Map;
  * <li> Monitor the cluster members and store their addresses
  * <ul>
  */
  * <li> Monitor the cluster members and store their addresses
  * <ul>
  */
-public class ShardManager extends AbstractUntypedActorWithMetering {
+public class ShardManager extends AbstractUntypedPersistentActor {
+
+    protected final LoggingAdapter LOG =
+        Logging.getLogger(getContext().system(), this);
 
     // Stores a mapping between a member name and the address of the member
     // Member names look like "member-1", "member-2" etc and are as specified
 
     // Stores a mapping between a member name and the address of the member
     // Member names look like "member-1", "member-2" etc and are as specified
@@ -74,6 +87,8 @@ public class ShardManager extends AbstractUntypedActorWithMetering {
 
     private final DatastoreContext datastoreContext;
 
 
     private final DatastoreContext datastoreContext;
 
+    private final Collection<String> knownModules = new HashSet<>(128);
+
     /**
      * @param type defines the kind of data that goes into shards created by this shard manager. Examples of type would be
      *             configuration or operational
     /**
      * @param type defines the kind of data that goes into shards created by this shard manager. Examples of type would be
      *             configuration or operational
@@ -105,7 +120,7 @@ public class ShardManager extends AbstractUntypedActorWithMetering {
     }
 
     @Override
     }
 
     @Override
-    public void handleReceive(Object message) throws Exception {
+    public void handleCommand(Object message) throws Exception {
         if (message.getClass().equals(FindPrimary.SERIALIZABLE_CLASS)) {
             findPrimary(
                 FindPrimary.fromSerializable(message));
         if (message.getClass().equals(FindPrimary.SERIALIZABLE_CLASS)) {
             findPrimary(
                 FindPrimary.fromSerializable(message));
@@ -125,6 +140,23 @@ public class ShardManager extends AbstractUntypedActorWithMetering {
 
     }
 
 
     }
 
+    @Override protected void handleRecover(Object message) throws Exception {
+
+        if(message instanceof SchemaContextModules){
+            SchemaContextModules msg = (SchemaContextModules) message;
+            knownModules.clear();
+            knownModules.addAll(msg.getModules());
+        } else if(message instanceof RecoveryFailure){
+            RecoveryFailure failure = (RecoveryFailure) message;
+            LOG.error(failure.cause(), "Recovery failed");
+        } else if(message instanceof RecoveryCompleted){
+            LOG.info("Recovery complete : {}", persistenceId());
+
+            // Delete all the messages from the akka journal except the last one
+            deleteMessages(lastSequenceNr() - 1);
+        }
+    }
+
     private void findLocalShard(FindLocalShard message) {
         ShardInformation shardInformation =
             localShards.get(message.getShardName());
     private void findLocalShard(FindLocalShard message) {
         ShardInformation shardInformation =
             localShards.get(message.getShardName());
@@ -159,16 +191,42 @@ public class ShardManager extends AbstractUntypedActorWithMetering {
      *
      * @param message
      */
      *
      * @param message
      */
-    private void updateSchemaContext(Object message) {
-        SchemaContext schemaContext = ((UpdateSchemaContext) message).getSchemaContext();
+    private void updateSchemaContext(final Object message) {
+        final SchemaContext schemaContext = ((UpdateSchemaContext) message).getSchemaContext();
+
+        Set<ModuleIdentifier> allModuleIdentifiers = schemaContext.getAllModuleIdentifiers();
+        Set<String> newModules = new HashSet<>(128);
+
+        for(ModuleIdentifier moduleIdentifier : allModuleIdentifiers){
+            String s = moduleIdentifier.getNamespace().toString();
+            newModules.add(s);
+        }
+
+        if(newModules.containsAll(knownModules)) {
+
+            LOG.info("New SchemaContext has a super set of current knownModules - persisting info");
+
+            knownModules.clear();
+            knownModules.addAll(newModules);
+
+            persist(new SchemaContextModules(newModules), new Procedure<SchemaContextModules>() {
 
 
-        if(localShards.size() == 0){
-            createLocalShards(schemaContext);
+                @Override public void apply(SchemaContextModules param) throws Exception {
+                    LOG.info("Sending new SchemaContext to Shards");
+                    if (localShards.size() == 0) {
+                        createLocalShards(schemaContext);
+                    } else {
+                        for (ShardInformation info : localShards.values()) {
+                            info.getActor().tell(message, getSelf());
+                        }
+                    }
+                }
+
+            });
         } else {
         } else {
-            for (ShardInformation info : localShards.values()) {
-                info.getActor().tell(message, getSelf());
-            }
+            LOG.info("Rejecting schema context update because it is not a super set of previously known modules");
         }
         }
+
     }
 
     private void findPrimary(FindPrimary message) {
     }
 
     private void findPrimary(FindPrimary message) {
@@ -306,6 +364,14 @@ public class ShardManager extends AbstractUntypedActorWithMetering {
 
     }
 
 
     }
 
+    @Override public String persistenceId() {
+        return "shard-manager-" + type;
+    }
+
+    @VisibleForTesting public Collection<String> getKnownModules() {
+        return knownModules;
+    }
+
     private class ShardInformation {
         private final String shardName;
         private final ActorRef actor;
     private class ShardInformation {
         private final String shardName;
         private final ActorRef actor;
@@ -371,6 +437,18 @@ public class ShardManager extends AbstractUntypedActorWithMetering {
             return new ShardManager(type, cluster, configuration, datastoreContext);
         }
     }
             return new ShardManager(type, cluster, configuration, datastoreContext);
         }
     }
+
+    static class SchemaContextModules implements Serializable {
+        private final Set<String> modules;
+
+        SchemaContextModules(Set<String> modules){
+            this.modules = modules;
+        }
+
+        public Set<String> getModules() {
+            return modules;
+        }
+    }
 }
 
 
 }
 
 
index 02201f7cd1672d2c8d02415ae4d65d8b71432f8e..8a3cdd0c8aa3b9890811c8a52318c8c18051d7b8 100644 (file)
@@ -3,10 +3,23 @@ package org.opendaylight.controller.cluster.datastore;
 import akka.actor.ActorRef;
 import akka.actor.ActorSystem;
 import akka.actor.Props;
 import akka.actor.ActorRef;
 import akka.actor.ActorSystem;
 import akka.actor.Props;
+import akka.dispatch.Futures;
+import akka.japi.Procedure;
+import akka.persistence.PersistentConfirmation;
+import akka.persistence.PersistentId;
+import akka.persistence.PersistentImpl;
+import akka.persistence.PersistentRepr;
+import akka.persistence.journal.japi.AsyncWriteJournal;
 import akka.testkit.JavaTestKit;
 import akka.testkit.TestActorRef;
 import akka.testkit.JavaTestKit;
 import akka.testkit.TestActorRef;
-import junit.framework.Assert;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.Uninterruptibles;
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigFactory;
+import com.typesafe.config.ConfigValueFactory;
 import org.junit.AfterClass;
 import org.junit.AfterClass;
+import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.opendaylight.controller.cluster.datastore.messages.FindLocalShard;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.opendaylight.controller.cluster.datastore.messages.FindLocalShard;
@@ -19,17 +32,41 @@ import org.opendaylight.controller.cluster.datastore.messages.UpdateSchemaContex
 import org.opendaylight.controller.cluster.datastore.utils.MockClusterWrapper;
 import org.opendaylight.controller.cluster.datastore.utils.MockConfiguration;
 import org.opendaylight.controller.md.cluster.datastore.model.TestModel;
 import org.opendaylight.controller.cluster.datastore.utils.MockClusterWrapper;
 import org.opendaylight.controller.cluster.datastore.utils.MockConfiguration;
 import org.opendaylight.controller.md.cluster.datastore.model.TestModel;
-import scala.concurrent.duration.Duration;
+import org.opendaylight.yangtools.yang.model.api.ModuleIdentifier;
+import org.opendaylight.yangtools.yang.model.api.SchemaContext;
+import scala.concurrent.Future;
+
+import java.net.URI;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.TimeUnit;
 
 import static junit.framework.Assert.assertEquals;
 
 import static junit.framework.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class ShardManagerTest {
     private static ActorSystem system;
 
     @BeforeClass
 
 public class ShardManagerTest {
     private static ActorSystem system;
 
     @BeforeClass
-    public static void setUp() {
-        system = ActorSystem.create("test");
+    public static void setUpClass() {
+        Map<String, String> myJournal = new HashMap<>();
+        myJournal.put("class", "org.opendaylight.controller.cluster.datastore.ShardManagerTest$MyJournal");
+        myJournal.put("plugin-dispatcher", "akka.actor.default-dispatcher");
+        Config config = ConfigFactory.load()
+            .withValue("akka.persistence.journal.plugin",
+                ConfigValueFactory.fromAnyRef("my-journal"))
+            .withValue("my-journal", ConfigValueFactory.fromMap(myJournal));
+
+        MyJournal.clear();
+
+        system = ActorSystem.create("test", config);
     }
 
     @AfterClass
     }
 
     @AfterClass
@@ -38,29 +75,27 @@ public class ShardManagerTest {
         system = null;
     }
 
         system = null;
     }
 
+    @Before
+    public void setUpTest(){
+        MyJournal.clear();
+    }
+
     @Test
     public void testOnReceiveFindPrimaryForNonExistentShard() throws Exception {
 
     @Test
     public void testOnReceiveFindPrimaryForNonExistentShard() throws Exception {
 
-        new JavaTestKit(system) {{
-            final Props props = ShardManager
-                .props("config", new MockClusterWrapper(),
-                    new MockConfiguration(), new DatastoreContext());
-            final TestActorRef<ShardManager> subject =
-                TestActorRef.create(system, props);
+        new JavaTestKit(system) {
+            {
+                final Props props = ShardManager
+                    .props("config", new MockClusterWrapper(),
+                        new MockConfiguration(), new DatastoreContext());
 
 
-            new Within(duration("10 seconds")) {
-                @Override
-                protected void run() {
+                final ActorRef subject = getSystem().actorOf(props);
 
 
-                    subject.tell(new FindPrimary("inventory").toSerializable(), getRef());
+                subject.tell(new FindPrimary("inventory").toSerializable(), getRef());
 
 
-                    expectMsgEquals(Duration.Zero(),
-                        new PrimaryNotFound("inventory").toSerializable());
-
-                    expectNoMsg();
-                }
-            };
-        }};
+                expectMsgEquals(duration("2 seconds"),
+                    new PrimaryNotFound("inventory").toSerializable());
+            }};
     }
 
     @Test
     }
 
     @Test
@@ -70,22 +105,14 @@ public class ShardManagerTest {
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
                     new MockConfiguration(), new DatastoreContext());
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
                     new MockConfiguration(), new DatastoreContext());
-            final TestActorRef<ShardManager> subject =
-                TestActorRef.create(system, props);
-
-            subject.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
 
 
-            new Within(duration("10 seconds")) {
-                @Override
-                protected void run() {
+            final ActorRef subject = getSystem().actorOf(props);
 
 
-                    subject.tell(new FindPrimary(Shard.DEFAULT_NAME).toSerializable(), getRef());
+            subject.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
 
 
-                    expectMsgClass(duration("1 seconds"), PrimaryFound.SERIALIZABLE_CLASS);
+            subject.tell(new FindPrimary(Shard.DEFAULT_NAME).toSerializable(), getRef());
 
 
-                    expectNoMsg();
-                }
-            };
+            expectMsgClass(duration("1 seconds"), PrimaryFound.SERIALIZABLE_CLASS);
         }};
     }
 
         }};
     }
 
@@ -96,31 +123,23 @@ public class ShardManagerTest {
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
                     new MockConfiguration(), new DatastoreContext());
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
                     new MockConfiguration(), new DatastoreContext());
-            final TestActorRef<ShardManager> subject =
-                TestActorRef.create(system, props);
 
 
-            new Within(duration("10 seconds")) {
-                @Override
-                protected void run() {
-
-                    subject.tell(new FindLocalShard("inventory"), getRef());
-
-                    final String out = new ExpectMsg<String>(duration("10 seconds"), "find local") {
-                        @Override
-                        protected String match(Object in) {
-                            if (in instanceof LocalShardNotFound) {
-                                return ((LocalShardNotFound) in).getShardName();
-                            } else {
-                                throw noMatch();
-                            }
-                        }
-                    }.get(); // this extracts the received message
+            final ActorRef subject = getSystem().actorOf(props);
 
 
-                    assertEquals("inventory", out);
+            subject.tell(new FindLocalShard("inventory"), getRef());
 
 
-                    expectNoMsg();
+            final String out = new ExpectMsg<String>(duration("3 seconds"), "find local") {
+                @Override
+                protected String match(Object in) {
+                    if (in instanceof LocalShardNotFound) {
+                        return ((LocalShardNotFound) in).getShardName();
+                    } else {
+                        throw noMatch();
+                    }
                 }
                 }
-            };
+            }.get(); // this extracts the received message
+
+            assertEquals("inventory", out);
         }};
     }
 
         }};
     }
 
@@ -133,40 +152,109 @@ public class ShardManagerTest {
             final Props props = ShardManager
                 .props("config", mockClusterWrapper,
                     new MockConfiguration(), new DatastoreContext());
             final Props props = ShardManager
                 .props("config", mockClusterWrapper,
                     new MockConfiguration(), new DatastoreContext());
-            final TestActorRef<ShardManager> subject =
-                TestActorRef.create(system, props);
+
+            final ActorRef subject = getSystem().actorOf(props);
 
             subject.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
 
 
             subject.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
 
-            new Within(duration("10 seconds")) {
+            subject.tell(new FindLocalShard(Shard.DEFAULT_NAME), getRef());
+
+            final ActorRef out = new ExpectMsg<ActorRef>(duration("3 seconds"), "find local") {
                 @Override
                 @Override
-                protected void run() {
-
-                    subject.tell(new FindLocalShard(Shard.DEFAULT_NAME), getRef());
-
-                    final ActorRef out = new ExpectMsg<ActorRef>(duration("10 seconds"), "find local") {
-                        @Override
-                        protected ActorRef match(Object in) {
-                            if (in instanceof LocalShardFound) {
-                                return ((LocalShardFound) in).getPath();
-                            } else {
-                                throw noMatch();
-                            }
-                        }
-                    }.get(); // this extracts the received message
+                protected ActorRef match(Object in) {
+                    if (in instanceof LocalShardFound) {
+                        return ((LocalShardFound) in).getPath();
+                    } else {
+                        throw noMatch();
+                    }
+                }
+            }.get(); // this extracts the received message
+
+            assertTrue(out.path().toString(),
+                out.path().toString().contains("member-1-shard-default-config"));
+        }};
+    }
+
+    @Test
+    public void testOnReceiveMemberUp() throws Exception {
+
+        new JavaTestKit(system) {{
+            final Props props = ShardManager
+                .props("config", new MockClusterWrapper(),
+                    new MockConfiguration(), new DatastoreContext());
 
 
-                    assertTrue(out.path().toString(), out.path().toString().contains("member-1-shard-default-config"));
+            final ActorRef subject = getSystem().actorOf(props);
 
 
+            MockClusterWrapper.sendMemberUp(subject, "member-2", getRef().path().toString());
 
 
-                    expectNoMsg();
+            subject.tell(new FindPrimary("astronauts").toSerializable(), getRef());
+
+            final String out = new ExpectMsg<String>(duration("3 seconds"), "primary found") {
+                // do not put code outside this method, will run afterwards
+                @Override
+                protected String match(Object in) {
+                    if (in.getClass().equals(PrimaryFound.SERIALIZABLE_CLASS)) {
+                        PrimaryFound f = PrimaryFound.fromSerializable(in);
+                        return f.getPrimaryPath();
+                    } else {
+                        throw noMatch();
+                    }
                 }
                 }
-            };
+            }.get(); // this extracts the received message
+
+            assertTrue(out, out.contains("member-2-shard-astronauts-config"));
         }};
     }
 
     @Test
         }};
     }
 
     @Test
-    public void testOnReceiveMemberUp() throws Exception {
+    public void testOnReceiveMemberDown() throws Exception {
 
 
+        new JavaTestKit(system) {{
+            final Props props = ShardManager
+                .props("config", new MockClusterWrapper(),
+                    new MockConfiguration(), new DatastoreContext());
+
+            final ActorRef subject = getSystem().actorOf(props);
+
+            MockClusterWrapper.sendMemberUp(subject, "member-2", getRef().path().toString());
+
+            subject.tell(new FindPrimary("astronauts").toSerializable(), getRef());
+
+            expectMsgClass(duration("3 seconds"), PrimaryFound.SERIALIZABLE_CLASS);
+
+            MockClusterWrapper.sendMemberRemoved(subject, "member-2", getRef().path().toString());
+
+            subject.tell(new FindPrimary("astronauts").toSerializable(), getRef());
+
+            expectMsgClass(duration("1 seconds"), PrimaryNotFound.SERIALIZABLE_CLASS);
+        }};
+    }
+
+    @Test
+    public void testOnRecoveryJournalIsEmptied(){
+        MyJournal.addToJournal(1L, new ShardManager.SchemaContextModules(
+            ImmutableSet.of("foo")));
+
+        assertEquals(1, MyJournal.get().size());
+
+        new JavaTestKit(system) {{
+            final Props props = ShardManager
+                .props("config", new MockClusterWrapper(),
+                    new MockConfiguration(), new DatastoreContext());
+
+            final ActorRef subject = getSystem().actorOf(props);
+
+            // Send message to check that ShardManager is ready
+            subject.tell(new FindPrimary("unknown").toSerializable(), getRef());
+
+            expectMsgClass(duration("3 seconds"), PrimaryNotFound.SERIALIZABLE_CLASS);
+
+            assertEquals(0, MyJournal.get().size());
+        }};
+    }
+
+    @Test
+    public void testOnRecoveryPreviouslyKnownModulesAreDiscovered() throws Exception {
         new JavaTestKit(system) {{
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
         new JavaTestKit(system) {{
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
@@ -174,39 +262,63 @@ public class ShardManagerTest {
             final TestActorRef<ShardManager> subject =
                 TestActorRef.create(system, props);
 
             final TestActorRef<ShardManager> subject =
                 TestActorRef.create(system, props);
 
-            // the run() method needs to finish within 3 seconds
-            new Within(duration("10 seconds")) {
-                @Override
-                protected void run() {
-
-                    MockClusterWrapper.sendMemberUp(subject, "member-2", getRef().path().toString());
-
-                    subject.tell(new FindPrimary("astronauts").toSerializable(), getRef());
-
-                    final String out = new ExpectMsg<String>(duration("1 seconds"), "primary found") {
-                        // do not put code outside this method, will run afterwards
-                        @Override
-                        protected String match(Object in) {
-                            if (in.getClass().equals(PrimaryFound.SERIALIZABLE_CLASS)) {
-                                PrimaryFound f = PrimaryFound.fromSerializable(in);
-                                return f.getPrimaryPath();
-                            } else {
-                                throw noMatch();
-                            }
-                        }
-                    }.get(); // this extracts the received message
+            subject.underlyingActor().onReceiveRecover(new ShardManager.SchemaContextModules(ImmutableSet.of("foo")));
 
 
-                    Assert.assertTrue(out, out.contains("member-2-shard-astronauts-config"));
+            Collection<String> knownModules = subject.underlyingActor().getKnownModules();
 
 
-                    expectNoMsg();
-                }
-            };
+            assertTrue(knownModules.contains("foo"));
         }};
     }
 
     @Test
         }};
     }
 
     @Test
-    public void testOnReceiveMemberDown() throws Exception {
+    public void testOnUpdateSchemaContextUpdateKnownModulesIfTheyContainASuperSetOfTheKnownModules()
+        throws Exception {
+        new JavaTestKit(system) {{
+            final Props props = ShardManager
+                .props("config", new MockClusterWrapper(),
+                    new MockConfiguration(), new DatastoreContext());
+            final TestActorRef<ShardManager> subject =
+                TestActorRef.create(system, props);
+
+            Collection<String> knownModules = subject.underlyingActor().getKnownModules();
+
+            assertEquals(0, knownModules.size());
+
+            SchemaContext schemaContext = mock(SchemaContext.class);
+            Set<ModuleIdentifier> moduleIdentifierSet = new HashSet<>();
+
+            ModuleIdentifier foo = mock(ModuleIdentifier.class);
+            when(foo.getNamespace()).thenReturn(new URI("foo"));
+
+            moduleIdentifierSet.add(foo);
+
+            when(schemaContext.getAllModuleIdentifiers()).thenReturn(moduleIdentifierSet);
+
+            subject.underlyingActor().onReceiveCommand(new UpdateSchemaContext(schemaContext));
+
+            assertTrue(knownModules.contains("foo"));
+
+            assertEquals(1, knownModules.size());
+
+            ModuleIdentifier bar = mock(ModuleIdentifier.class);
+            when(bar.getNamespace()).thenReturn(new URI("bar"));
+
+            moduleIdentifierSet.add(bar);
+
+            subject.underlyingActor().onReceiveCommand(new UpdateSchemaContext(schemaContext));
+
+            assertTrue(knownModules.contains("bar"));
 
 
+            assertEquals(2, knownModules.size());
+
+        }};
+
+    }
+
+
+    @Test
+    public void testOnUpdateSchemaContextDoNotUpdateKnownModulesIfTheyDoNotContainASuperSetOfKnownModules()
+        throws Exception {
         new JavaTestKit(system) {{
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
         new JavaTestKit(system) {{
             final Props props = ShardManager
                 .props("config", new MockClusterWrapper(),
@@ -214,28 +326,117 @@ public class ShardManagerTest {
             final TestActorRef<ShardManager> subject =
                 TestActorRef.create(system, props);
 
             final TestActorRef<ShardManager> subject =
                 TestActorRef.create(system, props);
 
-            // the run() method needs to finish within 3 seconds
-            new Within(duration("10 seconds")) {
-                @Override
-                protected void run() {
+            Collection<String> knownModules = subject.underlyingActor().getKnownModules();
 
 
-                    MockClusterWrapper.sendMemberUp(subject, "member-2", getRef().path().toString());
+            assertEquals(0, knownModules.size());
 
 
-                    subject.tell(new FindPrimary("astronauts").toSerializable(), getRef());
+            SchemaContext schemaContext = mock(SchemaContext.class);
+            Set<ModuleIdentifier> moduleIdentifierSet = new HashSet<>();
 
 
-                    expectMsgClass(duration("1 seconds"), PrimaryFound.SERIALIZABLE_CLASS);
+            ModuleIdentifier foo = mock(ModuleIdentifier.class);
+            when(foo.getNamespace()).thenReturn(new URI("foo"));
 
 
-                    MockClusterWrapper.sendMemberRemoved(subject, "member-2", getRef().path().toString());
+            moduleIdentifierSet.add(foo);
 
 
-                    subject.tell(new FindPrimary("astronauts").toSerializable(), getRef());
+            when(schemaContext.getAllModuleIdentifiers()).thenReturn(moduleIdentifierSet);
 
 
-                    expectMsgClass(duration("1 seconds"), PrimaryNotFound.SERIALIZABLE_CLASS);
+            subject.underlyingActor().onReceiveCommand(new UpdateSchemaContext(schemaContext));
+
+            assertTrue(knownModules.contains("foo"));
+
+            assertEquals(1, knownModules.size());
+
+            //Create a completely different SchemaContext with only the bar module in it
+            schemaContext = mock(SchemaContext.class);
+            moduleIdentifierSet = new HashSet<>();
+            ModuleIdentifier bar = mock(ModuleIdentifier.class);
+            when(bar.getNamespace()).thenReturn(new URI("bar"));
+
+            moduleIdentifierSet.add(bar);
+
+            subject.underlyingActor().onReceiveCommand(new UpdateSchemaContext(schemaContext));
+
+            assertFalse(knownModules.contains("bar"));
+
+            assertEquals(1, knownModules.size());
 
 
-                    expectNoMsg();
-                }
-            };
         }};
         }};
+
+    }
+
+
+    private void sleep(long period){
+        Uninterruptibles.sleepUninterruptibly(period, TimeUnit.MILLISECONDS);
     }
 
     }
 
+    public static class MyJournal extends AsyncWriteJournal {
+
+        private static Map<Long, Object> journal = Maps.newTreeMap();
+
+        public static void addToJournal(Long sequenceNr, Object value){
+            journal.put(sequenceNr, value);
+        }
+
+        public static Map<Long, Object> get(){
+            return journal;
+        }
+
+        public static void clear(){
+            journal.clear();
+        }
 
 
+        @Override public Future<Void> doAsyncReplayMessages(final String persistenceId, long fromSequenceNr, long toSequenceNr, long max,
+            final Procedure<PersistentRepr> replayCallback) {
+            if(journal.size() == 0){
+                return Futures.successful(null);
+            }
+            return Futures.future(new Callable<Void>() {
+                @Override
+                public Void call() throws Exception {
+                    for (Map.Entry<Long, Object> entry : journal.entrySet()) {
+                        PersistentRepr persistentMessage =
+                            new PersistentImpl(entry.getValue(), entry.getKey(), persistenceId,
+                                false, null, null);
+                        replayCallback.apply(persistentMessage);
+                    }
+                    return null;
+                }
+            }, context().dispatcher());
+        }
+
+        @Override public Future<Long> doAsyncReadHighestSequenceNr(String s, long l) {
+            return Futures.successful(-1L);
+        }
+
+        @Override public Future<Void> doAsyncWriteMessages(
+            final Iterable<PersistentRepr> persistentReprs) {
+            return Futures.future(new Callable<Void>() {
+                @Override
+                public Void call() throws Exception {
+                    for (PersistentRepr repr : persistentReprs){
+                        if(repr.payload() instanceof ShardManager.SchemaContextModules) {
+                            journal.put(repr.sequenceNr(), repr.payload());
+                        }
+                    }
+                    return null;
+                }
+            }, context().dispatcher());
+        }
+
+        @Override public Future<Void> doAsyncWriteConfirmations(
+            Iterable<PersistentConfirmation> persistentConfirmations) {
+            return Futures.successful(null);
+        }
+
+        @Override public Future<Void> doAsyncDeleteMessages(Iterable<PersistentId> persistentIds,
+            boolean b) {
+            clear();
+            return Futures.successful(null);
+        }
+
+        @Override public Future<Void> doAsyncDeleteMessagesTo(String s, long l, boolean b) {
+            clear();
+            return Futures.successful(null);
+        }
+    }
 }
 }