CDS: ShardManagerTest cleanup 36/30036/5
authorTom Pantelis <tpanteli@brocade.com>
Thu, 19 Nov 2015 18:55:05 +0000 (13:55 -0500)
committerGerrit Code Review <gerrit@opendaylight.org>
Wed, 2 Dec 2015 10:57:29 +0000 (10:57 +0000)
Added a Builder to TestShardManager and modified tests to use the
Builder to gain more consistency between the tests. As a result, I made the
ShardManager builder class abstract for derivation same as was done for
the Shard builder.

I removed the ForwardingShardManager class and merged the functionality
into TestShardManager.

Change-Id: I55471b388a40b9da68bdb249f4cc597d2a0e7f90
Signed-off-by: Tom Pantelis <tpanteli@brocade.com>
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/ShardManager.java
opendaylight/md-sal/sal-distributed-datastore/src/test/java/org/opendaylight/controller/cluster/datastore/ShardManagerTest.java

index b5e1ca4f1a5b3a3793de1e63646f88d593bb627f..fabfd096104dcab91ee6a4ffabd5758a50466e64 100644 (file)
@@ -145,7 +145,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     /**
      */
-    protected ShardManager(Builder builder) {
+    protected ShardManager(AbstractBuilder<?> builder) {
 
         this.cluster = builder.cluster;
         this.configuration = builder.configuration;
@@ -1415,7 +1415,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         return new Builder();
     }
 
-    public static class Builder {
+    public static abstract class AbstractBuilder<T extends AbstractBuilder<T>> {
         private ClusterWrapper cluster;
         private Configuration configuration;
         private DatastoreContextFactory datastoreContextFactory;
@@ -1424,44 +1424,49 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         private DatastoreSnapshot restoreFromSnapshot;
         private volatile boolean sealed;
 
+        @SuppressWarnings("unchecked")
+        private T self() {
+            return (T) this;
+        }
+
         protected void checkSealed() {
             Preconditions.checkState(!sealed, "Builder is already sealed - further modifications are not allowed");
         }
 
-        public Builder cluster(ClusterWrapper cluster) {
+        public T cluster(ClusterWrapper cluster) {
             checkSealed();
             this.cluster = cluster;
-            return this;
+            return self();
         }
 
-        public Builder configuration(Configuration configuration) {
+        public T configuration(Configuration configuration) {
             checkSealed();
             this.configuration = configuration;
-            return this;
+            return self();
         }
 
-        public Builder datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
+        public T datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
             checkSealed();
             this.datastoreContextFactory = datastoreContextFactory;
-            return this;
+            return self();
         }
 
-        public Builder waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
+        public T waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
             checkSealed();
             this.waitTillReadyCountdownLatch = waitTillReadyCountdownLatch;
-            return this;
+            return self();
         }
 
-        public Builder primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
+        public T primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
             checkSealed();
             this.primaryShardInfoCache = primaryShardInfoCache;
-            return this;
+            return self();
         }
 
-        public Builder restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
+        public T restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
             checkSealed();
             this.restoreFromSnapshot = restoreFromSnapshot;
-            return this;
+            return self();
         }
 
         protected void verify() {
@@ -1479,6 +1484,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    public static class Builder extends AbstractBuilder<Builder> {
+    }
+
     private void findPrimary(final String shardName, final FindPrimaryResponseHandler handler) {
         Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
                 getShardInitializationTimeout().duration().$times(2));
index 3257e8f910e0538ec99bbd0f608daf2b5056e1bd..e5f2ee0b9d8b81bcb01ba7ab467d86ef2a6562ab 100644 (file)
@@ -40,16 +40,17 @@ import com.google.common.base.Function;
 import com.google.common.base.Optional;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 import com.typesafe.config.ConfigFactory;
 import java.net.URI;
 import java.util.AbstractMap;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -104,8 +105,8 @@ import org.opendaylight.controller.cluster.raft.client.messages.GetSnapshot;
 import org.opendaylight.controller.cluster.raft.messages.AddServer;
 import org.opendaylight.controller.cluster.raft.messages.AddServerReply;
 import org.opendaylight.controller.cluster.raft.messages.ServerChangeStatus;
-import org.opendaylight.controller.cluster.raft.policy.DisableElectionsRaftPolicy;
 import org.opendaylight.controller.cluster.raft.messages.ServerRemoved;
+import org.opendaylight.controller.cluster.raft.policy.DisableElectionsRaftPolicy;
 import org.opendaylight.controller.cluster.raft.utils.InMemoryJournal;
 import org.opendaylight.controller.cluster.raft.utils.InMemorySnapshotStore;
 import org.opendaylight.controller.cluster.raft.utils.MessageCollectorActor;
@@ -138,7 +139,7 @@ public class ShardManagerTest extends AbstractActorTest {
         return TestActorRef.create(system, Props.create(MessageCollectorActor.class), name);
     }
 
-    private final PrimaryShardInfoFutureCache primaryShardInfoCache = new PrimaryShardInfoFutureCache();
+    private final Collection<ActorSystem> actorSystems = new ArrayList<>();
 
     @Before
     public void setUp() {
@@ -159,6 +160,16 @@ public class ShardManagerTest extends AbstractActorTest {
     public void tearDown() {
         InMemoryJournal.clear();
         InMemorySnapshotStore.clear();
+
+        for(ActorSystem system: actorSystems) {
+            JavaTestKit.shutdownActorSystem(system, null, Boolean.TRUE);
+        }
+    }
+
+    private ActorSystem newActorSystem(String config) {
+        ActorSystem system = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig(config));
+        actorSystems.add(system);
+        return system;
     }
 
     private Props newShardMgrProps() {
@@ -172,28 +183,24 @@ public class ShardManagerTest extends AbstractActorTest {
         return mockFactory;
     }
 
-    private Props newShardMgrProps(Configuration config) {
-        return TestShardManager.builder(datastoreContextBuilder).configuration(config).props();
+    private TestShardManager.Builder newTestShardMgrBuilder() {
+        return TestShardManager.builder(datastoreContextBuilder);
     }
 
-    private Props newPropsShardMgrWithMockShardActor() {
-        return newPropsShardMgrWithMockShardActor("shardManager", mockShardActor, new MockClusterWrapper(),
-                new MockConfiguration());
+    private TestShardManager.Builder newTestShardMgrBuilder(Configuration config) {
+        return TestShardManager.builder(datastoreContextBuilder).configuration(config);
     }
 
-    private Props newPropsShardMgrWithMockShardActor(final String name, final ActorRef shardActor,
-            final ClusterWrapper clusterWrapper, final Configuration config) {
-        Creator<ShardManager> creator = new Creator<ShardManager>() {
-            private static final long serialVersionUID = 1L;
-            @Override
-            public ShardManager create() throws Exception {
-                return new ForwardingShardManager(ShardManager.builder().cluster(clusterWrapper).configuration(config).
-                        datastoreContextFactory(newDatastoreContextFactory(datastoreContextBuilder.build())).
-                        waitTillReadyCountdownLatch(ready).primaryShardInfoCache(primaryShardInfoCache), name, shardActor);
-            }
-        };
+    private Props newShardMgrProps(Configuration config) {
+        return newTestShardMgrBuilder(config).props();
+    }
+
+    private TestShardManager.Builder newTestShardMgrBuilderWithMockShardActor() {
+        return TestShardManager.builder(datastoreContextBuilder).shardActor(mockShardActor);
+    }
 
-        return Props.create(new DelegatingShardManagerCreator(creator)).withDispatcher(Dispatchers.DefaultDispatcherId());
+    private Props newPropsShardMgrWithMockShardActor() {
+        return newTestShardMgrBuilderWithMockShardActor().props();
     }
 
     private TestShardManager newTestShardManager() {
@@ -240,27 +247,34 @@ public class ShardManagerTest extends AbstractActorTest {
         shardInfoMap.put("default", new AbstractMap.SimpleEntry<ActorRef, DatastoreContext>(defaultShardActor, null));
         shardInfoMap.put("topology", new AbstractMap.SimpleEntry<ActorRef, DatastoreContext>(topologyShardActor, null));
 
+        final PrimaryShardInfoFutureCache primaryShardInfoCache = new PrimaryShardInfoFutureCache();
         final CountDownLatch newShardActorLatch = new CountDownLatch(2);
+        class LocalShardManager extends ShardManager {
+            public LocalShardManager(AbstractBuilder<?> builder) {
+                super(builder);
+            }
+
+            @Override
+            protected ActorRef newShardActor(SchemaContext schemaContext, ShardInformation info) {
+                Entry<ActorRef, DatastoreContext> entry = shardInfoMap.get(info.getShardName());
+                ActorRef ref = null;
+                if(entry != null) {
+                    ref = entry.getKey();
+                    entry.setValue(info.getDatastoreContext());
+                }
+
+                newShardActorLatch.countDown();
+                return ref;
+            }
+        }
+
         final Creator<ShardManager> creator = new Creator<ShardManager>() {
             private static final long serialVersionUID = 1L;
             @Override
             public ShardManager create() throws Exception {
-                return new ShardManager(ShardManager.builder().cluster(new MockClusterWrapper()).configuration(mockConfig).
-                        datastoreContextFactory(mockFactory).waitTillReadyCountdownLatch(ready).
-                        primaryShardInfoCache(primaryShardInfoCache)) {
-                    @Override
-                    protected ActorRef newShardActor(SchemaContext schemaContext, ShardInformation info) {
-                        Entry<ActorRef, DatastoreContext> entry = shardInfoMap.get(info.getShardName());
-                        ActorRef ref = null;
-                        if(entry != null) {
-                            ref = entry.getKey();
-                            entry.setValue(info.getDatastoreContext());
-                        }
-
-                        newShardActorLatch.countDown();
-                        return ref;
-                    }
-                };
+                return new LocalShardManager(new GenericBuilder<LocalShardManager>(LocalShardManager.class).
+                        datastoreContextFactory(mockFactory).primaryShardInfoCache(primaryShardInfoCache).
+                        configuration(mockConfig));
             }
         };
 
@@ -543,18 +557,16 @@ public class ShardManagerTest extends AbstractActorTest {
 
         // Create an ActorSystem ShardManager actor for member-1.
 
-        final ActorSystem system1 = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig("Member1"));
+        final ActorSystem system1 = newActorSystem("Member1");
         Cluster.get(system1).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
 
-        ActorRef mockShardActor1 = newMockShardActor(system1, Shard.DEFAULT_NAME, "member-1");
-
-        final TestActorRef<ForwardingShardManager> shardManager1 = TestActorRef.create(system1,
-                newPropsShardMgrWithMockShardActor("shardManager1", mockShardActor1, new ClusterWrapperImpl(system1),
-                        new MockConfiguration()), shardManagerID);
+        final TestActorRef<TestShardManager> shardManager1 = TestActorRef.create(system1,
+                newTestShardMgrBuilderWithMockShardActor().cluster(
+                        new ClusterWrapperImpl(system1)).props(), shardManagerID);
 
         // Create an ActorSystem ShardManager actor for member-2.
 
-        final ActorSystem system2 = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig("Member2"));
+        final ActorSystem system2 = newActorSystem("Member2");
 
         Cluster.get(system2).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
 
@@ -564,9 +576,9 @@ public class ShardManagerTest extends AbstractActorTest {
                 put("default", Arrays.asList("member-1", "member-2")).
                 put("astronauts", Arrays.asList("member-2")).build());
 
-        final TestActorRef<ForwardingShardManager> shardManager2 = TestActorRef.create(system2,
-                newPropsShardMgrWithMockShardActor("shardManager2", mockShardActor2, new ClusterWrapperImpl(system2),
-                        mockConfig2), shardManagerID);
+        final TestActorRef<TestShardManager> shardManager2 = TestActorRef.create(system2,
+                newTestShardMgrBuilder(mockConfig2).shardActor(mockShardActor2).cluster(
+                        new ClusterWrapperImpl(system2)).props(), shardManagerID);
 
         new JavaTestKit(system1) {{
 
@@ -601,9 +613,6 @@ public class ShardManagerTest extends AbstractActorTest {
 
             expectMsgClass(duration("5 seconds"), PrimaryNotFoundException.class);
         }};
-
-        JavaTestKit.shutdownActorSystem(system1);
-        JavaTestKit.shutdownActorSystem(system2);
     }
 
     @Test
@@ -612,18 +621,18 @@ public class ShardManagerTest extends AbstractActorTest {
 
         // Create an ActorSystem ShardManager actor for member-1.
 
-        final ActorSystem system1 = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig("Member1"));
+        final ActorSystem system1 = newActorSystem("Member1");
         Cluster.get(system1).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
 
         final ActorRef mockShardActor1 = newMockShardActor(system1, Shard.DEFAULT_NAME, "member-1");
 
-        final TestActorRef<ForwardingShardManager> shardManager1 = TestActorRef.create(system1,
-            newPropsShardMgrWithMockShardActor("shardManager1", mockShardActor1, new ClusterWrapperImpl(system1),
-                new MockConfiguration()), shardManagerID);
+        final TestActorRef<TestShardManager> shardManager1 = TestActorRef.create(system1,
+                newTestShardMgrBuilder().shardActor(mockShardActor1).cluster(
+                        new ClusterWrapperImpl(system1)).props(), shardManagerID);
 
         // Create an ActorSystem ShardManager actor for member-2.
 
-        final ActorSystem system2 = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig("Member2"));
+        final ActorSystem system2 = newActorSystem("Member2");
 
         Cluster.get(system2).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
 
@@ -632,9 +641,9 @@ public class ShardManagerTest extends AbstractActorTest {
         MockConfiguration mockConfig2 = new MockConfiguration(ImmutableMap.<String, List<String>>builder().
             put("default", Arrays.asList("member-1", "member-2")).build());
 
-        final TestActorRef<ForwardingShardManager> shardManager2 = TestActorRef.create(system2,
-            newPropsShardMgrWithMockShardActor("shardManager2", mockShardActor2, new ClusterWrapperImpl(system2),
-                mockConfig2), shardManagerID);
+        final TestActorRef<TestShardManager> shardManager2 = TestActorRef.create(system2,
+                newTestShardMgrBuilder(mockConfig2).shardActor(mockShardActor2).cluster(
+                        new ClusterWrapperImpl(system2)).props(), shardManagerID);
 
         new JavaTestKit(system1) {{
 
@@ -701,9 +710,6 @@ public class ShardManagerTest extends AbstractActorTest {
             MessageCollectorActor.expectFirstMatching(mockShardActor1, PeerUp.class);
 
         }};
-
-        JavaTestKit.shutdownActorSystem(system1);
-        JavaTestKit.shutdownActorSystem(system2);
     }
 
     @Test
@@ -712,18 +718,20 @@ public class ShardManagerTest extends AbstractActorTest {
 
         // Create an ActorSystem ShardManager actor for member-1.
 
-        final ActorSystem system1 = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig("Member1"));
+        final ActorSystem system1 = newActorSystem("Member1");
         Cluster.get(system1).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
 
         final ActorRef mockShardActor1 = newMockShardActor(system1, Shard.DEFAULT_NAME, "member-1");
 
-        final TestActorRef<ForwardingShardManager> shardManager1 = TestActorRef.create(system1,
-            newPropsShardMgrWithMockShardActor("shardManager1", mockShardActor1, new ClusterWrapperImpl(system1),
-                new MockConfiguration()), shardManagerID);
+        final PrimaryShardInfoFutureCache primaryShardInfoCache = new PrimaryShardInfoFutureCache();
+        final TestActorRef<TestShardManager> shardManager1 = TestActorRef.create(system1,
+                newTestShardMgrBuilder().shardActor(mockShardActor1).cluster(
+                        new ClusterWrapperImpl(system1)).primaryShardInfoCache(primaryShardInfoCache).props(),
+                shardManagerID);
 
         // Create an ActorSystem ShardManager actor for member-2.
 
-        final ActorSystem system2 = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig("Member2"));
+        final ActorSystem system2 = newActorSystem("Member2");
 
         Cluster.get(system2).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
 
@@ -732,12 +740,11 @@ public class ShardManagerTest extends AbstractActorTest {
         MockConfiguration mockConfig2 = new MockConfiguration(ImmutableMap.<String, List<String>>builder().
             put("default", Arrays.asList("member-1", "member-2")).build());
 
-        final TestActorRef<ForwardingShardManager> shardManager2 = TestActorRef.create(system2,
-            newPropsShardMgrWithMockShardActor("shardManager2", mockShardActor2, new ClusterWrapperImpl(system2),
-                mockConfig2), shardManagerID);
+        final TestActorRef<TestShardManager> shardManager2 = TestActorRef.create(system2,
+                newTestShardMgrBuilder(mockConfig2).shardActor(mockShardActor2).cluster(
+                        new ClusterWrapperImpl(system2)).props(), shardManagerID);
 
         new JavaTestKit(system1) {{
-
             shardManager1.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
             shardManager2.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
             shardManager1.tell(new ActorInitialized(), mockShardActor1);
@@ -788,9 +795,6 @@ public class ShardManagerTest extends AbstractActorTest {
             assertTrue("Unexpected primary path " + path1, path1.contains("member-1-shard-default-config"));
 
         }};
-
-        JavaTestKit.shutdownActorSystem(system1);
-        JavaTestKit.shutdownActorSystem(system2);
     }
 
 
@@ -1185,13 +1189,17 @@ public class ShardManagerTest extends AbstractActorTest {
         DatastoreSnapshot datastoreSnapshot = kit.expectMsgClass(DatastoreSnapshot.class);
 
         assertEquals("getType", shardMrgIDSuffix, datastoreSnapshot.getType());
-        List<ShardSnapshot> shardSnapshots = datastoreSnapshot.getShardSnapshots();
-        Set<String> actualShardNames = new HashSet<>();
-        for(ShardSnapshot s: shardSnapshots) {
-            actualShardNames.add(s.getName());
-        }
+        assertNull("Expected null ShardManagerSnapshot", datastoreSnapshot.getShardManagerSnapshot());
+
+        Function<ShardSnapshot, String> shardNameTransformer = new Function<ShardSnapshot, String>() {
+            @Override
+            public String apply(ShardSnapshot s) {
+                return s.getName();
+            }
+        };
 
-        assertEquals("Shard names", Sets.newHashSet("shard1", "shard2"), actualShardNames);
+        assertEquals("Shard names", Sets.newHashSet("shard1", "shard2"), Sets.newHashSet(
+                Lists.transform(datastoreSnapshot.getShardSnapshots(), shardNameTransformer)));
 
         shardManager.tell(PoisonPill.getInstance(), ActorRef.noSender());
     }
@@ -1220,24 +1228,23 @@ public class ShardManagerTest extends AbstractActorTest {
         String shardManagerID = ShardManagerIdentifier.builder().type(shardMrgIDSuffix).build().toString();
 
         // Create an ActorSystem ShardManager actor for member-1.
-        final ActorSystem system1 = ActorSystem.create("cluster-test", ConfigFactory.load().getConfig("Member1"));
+        final ActorSystem system1 = newActorSystem("Member1");
         Cluster.get(system1).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
         ActorRef mockDefaultShardActor = newMockShardActor(system1, Shard.DEFAULT_NAME, "member-1");
-        final TestActorRef<ForwardingShardManager> newReplicaShardManager = TestActorRef.create(system1,
-                newPropsShardMgrWithMockShardActor("shardManager1", mockDefaultShardActor,
-                   new ClusterWrapperImpl(system1), mockConfig), shardManagerID);
+        final TestActorRef<TestShardManager> newReplicaShardManager = TestActorRef.create(system1,
+                newTestShardMgrBuilder(mockConfig).shardActor(mockDefaultShardActor).cluster(
+                        new ClusterWrapperImpl(system1)).props(), shardManagerID);
 
         // Create an ActorSystem ShardManager actor for member-2.
-        final ActorSystem system2 = ActorSystem.create("cluster-test",
-            ConfigFactory.load().getConfig("Member2"));
+        final ActorSystem system2 = newActorSystem("Member2");
         Cluster.get(system2).join(AddressFromURIString.parse("akka.tcp://cluster-test@127.0.0.1:2558"));
 
         String name = new ShardIdentifier("astronauts", "member-2", "config").toString();
         final TestActorRef<MockRespondActor> mockShardLeaderActor =
-            TestActorRef.create(system2, Props.create(MockRespondActor.class), name);
-        final TestActorRef<ForwardingShardManager> leaderShardManager = TestActorRef.create(system2,
-                newPropsShardMgrWithMockShardActor("shardManager2", mockShardLeaderActor,
-                        new ClusterWrapperImpl(system2), mockConfig), shardManagerID);
+                TestActorRef.create(system2, Props.create(MockRespondActor.class), name);
+        final TestActorRef<TestShardManager> leaderShardManager = TestActorRef.create(system2,
+                newTestShardMgrBuilder(mockConfig).shardActor(mockShardLeaderActor).cluster(
+                        new ClusterWrapperImpl(system2)).props(), shardManagerID);
 
         new JavaTestKit(system1) {{
 
@@ -1268,15 +1275,12 @@ public class ShardManagerTest extends AbstractActorTest {
                 .verifySnapshotPersisted(Sets.newHashSet("default", "astronauts"));
             expectMsgClass(duration("5 seconds"), Status.Success.class);
         }};
-
-        JavaTestKit.shutdownActorSystem(system1);
-        JavaTestKit.shutdownActorSystem(system2);
     }
 
     @Test
     public void testAddShardReplicaWithPreExistingReplicaInRemoteShardLeader() throws Exception {
         new JavaTestKit(getSystem()) {{
-            TestActorRef<ForwardingShardManager> shardManager = TestActorRef.create(getSystem(),
+            TestActorRef<TestShardManager> shardManager = TestActorRef.create(getSystem(),
                     newPropsShardMgrWithMockShardActor(), shardMgrID);
 
             shardManager.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
@@ -1358,9 +1362,8 @@ public class ShardManagerTest extends AbstractActorTest {
                        put("astronauts", Arrays.asList("member-2")).build());
 
             ActorRef mockNewReplicaShardActor = newMockShardActor(getSystem(), "astronauts", "member-1");
-            TestActorRef<ForwardingShardManager> shardManager = TestActorRef.create(getSystem(),
-                    newPropsShardMgrWithMockShardActor("newReplicaShardManager", mockNewReplicaShardActor,
-                            new MockClusterWrapper(), mockConfig), shardMgrID);
+            final TestActorRef<TestShardManager> shardManager = TestActorRef.create(getSystem(),
+                    newTestShardMgrBuilder(mockConfig).shardActor(mockNewReplicaShardActor).props(), shardMgrID);
             shardManager.underlyingActor().setMessageInterceptor(newFindPrimaryInterceptor(mockShardLeaderKit.getRef()));
 
             shardManager.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
@@ -1401,9 +1404,8 @@ public class ShardManagerTest extends AbstractActorTest {
                     new MockConfiguration(ImmutableMap.<String, List<String>>builder().
                        put("astronauts", Arrays.asList("member-2")).build());
 
-            TestActorRef<ForwardingShardManager> shardManager = TestActorRef.create(getSystem(),
-                    newPropsShardMgrWithMockShardActor("newReplicaShardManager", mockShardActor,
-                            new MockClusterWrapper(), mockConfig), shardMgrID);
+            final TestActorRef<TestShardManager> shardManager = TestActorRef.create(getSystem(),
+                    newTestShardMgrBuilder(mockConfig).shardActor(mockShardActor).props(), shardMgrID);
             shardManager.underlyingActor().setMessageInterceptor(newFindPrimaryInterceptor(mockShardLeaderKit.getRef()));
 
             shardManager.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
@@ -1424,8 +1426,8 @@ public class ShardManagerTest extends AbstractActorTest {
             MockConfiguration mockConfig = new MockConfiguration(ImmutableMap.<String, List<String>>builder().
                        put("astronauts", Arrays.asList("member-2")).build());
 
-            ActorRef newReplicaShardManager = getSystem().actorOf(newPropsShardMgrWithMockShardActor(
-                    "shardManager", mockShardActor, new MockClusterWrapper(), mockConfig));
+            final ActorRef newReplicaShardManager = getSystem().actorOf(newTestShardMgrBuilder(mockConfig).
+                    shardActor(mockShardActor).props(), shardMgrID);
 
             newReplicaShardManager.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
             MockClusterWrapper.sendMemberUp(newReplicaShardManager, "member-2", getRef().path().toString());
@@ -1489,17 +1491,15 @@ public class ShardManagerTest extends AbstractActorTest {
                             put("astronauts", Arrays.asList("member-2")).
                             put("people", Arrays.asList("member-1", "member-2")).build());
 
-            TestActorRef<TestShardManager> shardManager = TestActorRef.create(getSystem(),
-                    newShardMgrProps(mockConfig));
-
             TestActorRef<MessageCollectorActor> shard = TestActorRef.create(getSystem(), MessageCollectorActor.props());
 
+            TestActorRef<TestShardManager> shardManager = TestActorRef.create(getSystem(),
+                    newTestShardMgrBuilder(mockConfig).addShardActor("default", shard).props());
+
             watch(shard);
 
             shardManager.underlyingActor().waitForRecoveryComplete();
 
-            shardManager.underlyingActor().addShardActor("default", shard);
-
             shardManager.tell(new UpdateSchemaContext(TestModel.createTestContext()), getRef());
 
             shardManager.tell(new FindLocalShard("people", false), getRef());
@@ -1558,10 +1558,19 @@ public class ShardManagerTest extends AbstractActorTest {
         private final CountDownLatch recoveryComplete = new CountDownLatch(1);
         private final CountDownLatch snapshotPersist = new CountDownLatch(1);
         private ShardManagerSnapshot snapshot;
-        private Map<String, ActorRef> shardActors = new HashMap<>();
+        private final Map<String, ActorRef> shardActors;
+        private final ActorRef shardActor;
+        private CountDownLatch findPrimaryMessageReceived = new CountDownLatch(1);
+        private CountDownLatch memberUpReceived = new CountDownLatch(1);
+        private CountDownLatch memberRemovedReceived = new CountDownLatch(1);
+        private CountDownLatch memberUnreachableReceived = new CountDownLatch(1);
+        private CountDownLatch memberReachableReceived = new CountDownLatch(1);
+        private volatile MessageInterceptor messageInterceptor;
 
         private TestShardManager(Builder builder) {
             super(builder);
+            shardActor = builder.shardActor;
+            shardActors = builder.shardActors;
         }
 
         @Override
@@ -1575,109 +1584,6 @@ public class ShardManagerTest extends AbstractActorTest {
             }
         }
 
-        void waitForRecoveryComplete() {
-            assertEquals("Recovery complete", true,
-                    Uninterruptibles.awaitUninterruptibly(recoveryComplete, 5, TimeUnit.SECONDS));
-        }
-
-        public static Builder builder(DatastoreContext.Builder datastoreContextBuilder) {
-            return new Builder(datastoreContextBuilder);
-        }
-
-        private static class Builder extends ShardManager.Builder {
-            Builder(DatastoreContext.Builder datastoreContextBuilder) {
-                cluster(new MockClusterWrapper()).configuration(new MockConfiguration());
-                datastoreContextFactory(newDatastoreContextFactory(datastoreContextBuilder.build()));
-                waitTillReadyCountdownLatch(ready).primaryShardInfoCache(new PrimaryShardInfoFutureCache());
-            }
-
-            @Override
-            public Props props() {
-                verify();
-                return Props.create(TestShardManager.class, this);
-            }
-        }
-
-        @Override
-        public void saveSnapshot(Object obj) {
-            snapshot = (ShardManagerSnapshot) obj;
-            snapshotPersist.countDown();
-        }
-
-        void verifySnapshotPersisted(Set<String> shardList) {
-            assertEquals("saveSnapshot invoked", true,
-                    Uninterruptibles.awaitUninterruptibly(snapshotPersist, 5, TimeUnit.SECONDS));
-            assertEquals("Shard Persisted", shardList, Sets.newHashSet(snapshot.getShardList()));
-        }
-
-        @Override
-        protected ActorRef newShardActor(SchemaContext schemaContext, ShardInformation info) {
-            if(shardActors.get(info.getShardName()) != null){
-                return shardActors.get(info.getShardName());
-            }
-            return super.newShardActor(schemaContext, info);
-        }
-
-        public void addShardActor(String shardName, ActorRef actorRef){
-            shardActors.put(shardName, actorRef);
-        }
-    }
-
-    private static class DelegatingShardManagerCreator implements Creator<ShardManager> {
-        private static final long serialVersionUID = 1L;
-        private final Creator<ShardManager> delegate;
-
-        public DelegatingShardManagerCreator(Creator<ShardManager> delegate) {
-            this.delegate = delegate;
-        }
-
-        @Override
-        public ShardManager create() throws Exception {
-            return delegate.create();
-        }
-    }
-
-    interface MessageInterceptor extends Function<Object, Object> {
-        boolean canIntercept(Object message);
-    }
-
-    private MessageInterceptor newFindPrimaryInterceptor(final ActorRef primaryActor) {
-        return new MessageInterceptor(){
-            @Override
-            public Object apply(Object message) {
-                return new RemotePrimaryShardFound(Serialization.serializedActorPath(primaryActor), (short) 1);
-            }
-
-            @Override
-            public boolean canIntercept(Object message) {
-                return message instanceof FindPrimary;
-            }
-        };
-    }
-
-    private static class ForwardingShardManager extends ShardManager {
-        private CountDownLatch findPrimaryMessageReceived = new CountDownLatch(1);
-        private CountDownLatch memberUpReceived = new CountDownLatch(1);
-        private CountDownLatch memberRemovedReceived = new CountDownLatch(1);
-        private CountDownLatch memberUnreachableReceived = new CountDownLatch(1);
-        private CountDownLatch memberReachableReceived = new CountDownLatch(1);
-        private final ActorRef shardActor;
-        private final String name;
-        private final CountDownLatch snapshotPersist = new CountDownLatch(1);
-        private ShardManagerSnapshot snapshot;
-        private volatile MessageInterceptor messageInterceptor;
-
-        public ForwardingShardManager(Builder builder, String name, ActorRef shardActor) {
-            super(builder);
-            this.shardActor = shardActor;
-            this.name = name;
-        }
-
-        void setMessageInterceptor(MessageInterceptor messageInterceptor) {
-            this.messageInterceptor = messageInterceptor;
-        }
-
-
         @Override
         public void handleCommand(Object message) throws Exception {
             try{
@@ -1713,14 +1619,13 @@ public class ShardManagerTest extends AbstractActorTest {
             }
         }
 
-        @Override
-        public String persistenceId() {
-            return name;
+        void setMessageInterceptor(MessageInterceptor messageInterceptor) {
+            this.messageInterceptor = messageInterceptor;
         }
 
-        @Override
-        protected ActorRef newShardActor(SchemaContext schemaContext, ShardInformation info) {
-            return shardActor;
+        void waitForRecoveryComplete() {
+            assertEquals("Recovery complete", true,
+                    Uninterruptibles.awaitUninterruptibly(recoveryComplete, 5, TimeUnit.SECONDS));
         }
 
         void waitForMemberUp() {
@@ -1754,6 +1659,30 @@ public class ShardManagerTest extends AbstractActorTest {
             findPrimaryMessageReceived = new CountDownLatch(1);
         }
 
+        public static Builder builder(DatastoreContext.Builder datastoreContextBuilder) {
+            return new Builder(datastoreContextBuilder);
+        }
+
+        private static class Builder extends AbstractGenericBuilder<Builder, TestShardManager> {
+            private ActorRef shardActor;
+            private final Map<String, ActorRef> shardActors = new HashMap<>();
+
+            Builder(DatastoreContext.Builder datastoreContextBuilder) {
+                super(TestShardManager.class);
+                datastoreContextFactory(newDatastoreContextFactory(datastoreContextBuilder.build()));
+            }
+
+            Builder shardActor(ActorRef shardActor) {
+                this.shardActor = shardActor;
+                return this;
+            }
+
+            Builder addShardActor(String shardName, ActorRef actorRef){
+                shardActors.put(shardName, actorRef);
+                return this;
+            }
+        }
+
         @Override
         public void saveSnapshot(Object obj) {
             snapshot = (ShardManagerSnapshot) obj;
@@ -1762,9 +1691,77 @@ public class ShardManagerTest extends AbstractActorTest {
 
         void verifySnapshotPersisted(Set<String> shardList) {
             assertEquals("saveSnapshot invoked", true,
-                Uninterruptibles.awaitUninterruptibly(snapshotPersist, 5, TimeUnit.SECONDS));
+                    Uninterruptibles.awaitUninterruptibly(snapshotPersist, 5, TimeUnit.SECONDS));
             assertEquals("Shard Persisted", shardList, Sets.newHashSet(snapshot.getShardList()));
         }
+
+        @Override
+        protected ActorRef newShardActor(SchemaContext schemaContext, ShardInformation info) {
+            if(shardActors.get(info.getShardName()) != null){
+                return shardActors.get(info.getShardName());
+            }
+
+            if(shardActor != null) {
+                return shardActor;
+            }
+
+            return super.newShardActor(schemaContext, info);
+        }
+    }
+
+    private static abstract class AbstractGenericBuilder<T extends AbstractGenericBuilder<T, ?>, C extends ShardManager>
+                                                     extends ShardManager.AbstractBuilder<T> {
+        private final Class<C> shardManagerClass;
+
+        AbstractGenericBuilder(Class<C> shardManagerClass) {
+            this.shardManagerClass = shardManagerClass;
+            cluster(new MockClusterWrapper()).configuration(new MockConfiguration()).
+                    waitTillReadyCountdownLatch(ready).primaryShardInfoCache(new PrimaryShardInfoFutureCache());
+        }
+
+        @Override
+        public Props props() {
+            verify();
+            return Props.create(shardManagerClass, this);
+        }
+    }
+
+    private static class GenericBuilder<C extends ShardManager> extends AbstractGenericBuilder<GenericBuilder<C>, C> {
+        GenericBuilder(Class<C> shardManagerClass) {
+            super(shardManagerClass);
+        }
+    }
+
+    private static class DelegatingShardManagerCreator implements Creator<ShardManager> {
+        private static final long serialVersionUID = 1L;
+        private final Creator<ShardManager> delegate;
+
+        public DelegatingShardManagerCreator(Creator<ShardManager> delegate) {
+            this.delegate = delegate;
+        }
+
+        @Override
+        public ShardManager create() throws Exception {
+            return delegate.create();
+        }
+    }
+
+    interface MessageInterceptor extends Function<Object, Object> {
+        boolean canIntercept(Object message);
+    }
+
+    private MessageInterceptor newFindPrimaryInterceptor(final ActorRef primaryActor) {
+        return new MessageInterceptor(){
+            @Override
+            public Object apply(Object message) {
+                return new RemotePrimaryShardFound(Serialization.serializedActorPath(primaryActor), (short) 1);
+            }
+
+            @Override
+            public boolean canIntercept(Object message) {
+                return message instanceof FindPrimary;
+            }
+        };
     }
 
     private static class MockRespondActor extends MessageCollectorActor {