Bug 4823: Notify findPrimary callbacks on ReachableMember event
[controller.git] / opendaylight / md-sal / sal-distributed-datastore / src / main / java / org / opendaylight / controller / cluster / datastore / ShardManager.java
index 4f3d4aa7f931b1af11e1198de001d02831ac6d63..a91109c64b3fa5aac674677b95d2a31c83477fbd 100644 (file)
@@ -16,12 +16,16 @@ import akka.actor.Cancellable;
 import akka.actor.OneForOneStrategy;
 import akka.actor.PoisonPill;
 import akka.actor.Props;
+import akka.actor.Status;
 import akka.actor.SupervisorStrategy;
 import akka.cluster.ClusterEvent;
 import akka.dispatch.OnComplete;
-import akka.japi.Creator;
 import akka.japi.Function;
 import akka.persistence.RecoveryCompleted;
+import akka.persistence.SaveSnapshotFailure;
+import akka.persistence.SaveSnapshotSuccess;
+import akka.persistence.SnapshotOffer;
+import akka.persistence.SnapshotSelectionCriteria;
 import akka.serialization.Serialization;
 import akka.util.Timeout;
 import com.google.common.annotations.VisibleForTesting;
@@ -31,19 +35,28 @@ import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
 import com.google.common.base.Supplier;
 import com.google.common.collect.Sets;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import org.apache.commons.lang3.SerializationUtils;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedPersistentActorWithMetering;
 import org.opendaylight.controller.cluster.datastore.config.Configuration;
 import org.opendaylight.controller.cluster.datastore.config.ModuleShardConfiguration;
+import org.opendaylight.controller.cluster.datastore.exceptions.AlreadyExistsException;
 import org.opendaylight.controller.cluster.datastore.exceptions.NoShardLeaderException;
 import org.opendaylight.controller.cluster.datastore.exceptions.NotInitializedException;
 import org.opendaylight.controller.cluster.datastore.exceptions.PrimaryNotFoundException;
@@ -53,7 +66,7 @@ import org.opendaylight.controller.cluster.datastore.jmx.mbeans.shardmanager.Sha
 import org.opendaylight.controller.cluster.datastore.messages.ActorInitialized;
 import org.opendaylight.controller.cluster.datastore.messages.AddShardReplica;
 import org.opendaylight.controller.cluster.datastore.messages.CreateShard;
-import org.opendaylight.controller.cluster.datastore.messages.CreateShardReply;
+import org.opendaylight.controller.cluster.datastore.messages.DatastoreSnapshot;
 import org.opendaylight.controller.cluster.datastore.messages.FindLocalShard;
 import org.opendaylight.controller.cluster.datastore.messages.FindPrimary;
 import org.opendaylight.controller.cluster.datastore.messages.LocalPrimaryShardFound;
@@ -75,9 +88,13 @@ import org.opendaylight.controller.cluster.notifications.RoleChangeNotification;
 import org.opendaylight.controller.cluster.raft.RaftState;
 import org.opendaylight.controller.cluster.raft.base.messages.FollowerInitialSyncUpStatus;
 import org.opendaylight.controller.cluster.raft.base.messages.SwitchBehavior;
+import org.opendaylight.controller.cluster.raft.client.messages.GetSnapshot;
 import org.opendaylight.controller.cluster.raft.messages.AddServer;
 import org.opendaylight.controller.cluster.raft.messages.AddServerReply;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServer;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServerReply;
 import org.opendaylight.controller.cluster.raft.messages.ServerChangeStatus;
+import org.opendaylight.controller.cluster.raft.messages.ServerRemoved;
 import org.opendaylight.controller.cluster.raft.policy.DisableElectionsRaftPolicy;
 import org.opendaylight.yangtools.yang.data.api.schema.tree.DataTree;
 import org.opendaylight.yangtools.yang.model.api.SchemaContext;
@@ -115,7 +132,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     private final String shardDispatcherPath;
 
-    private ShardManagerInfo mBean;
+    private final ShardManagerInfo mBean;
 
     private DatastoreContextFactory datastoreContextFactory;
 
@@ -127,43 +144,42 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     private SchemaContext schemaContext;
 
+    private DatastoreSnapshot restoreFromSnapshot;
+
+    private ShardManagerSnapshot currentSnapshot;
+
+    private final Set<String> shardReplicaOperationsInProgress = new HashSet<>();
+
+    private final String persistenceId;
+
     /**
      */
-    protected ShardManager(ClusterWrapper cluster, Configuration configuration,
-            DatastoreContextFactory datastoreContextFactory, CountDownLatch waitTillReadyCountdownLatch,
-            PrimaryShardInfoFutureCache primaryShardInfoCache) {
-
-        this.cluster = Preconditions.checkNotNull(cluster, "cluster should not be null");
-        this.configuration = Preconditions.checkNotNull(configuration, "configuration should not be null");
-        this.datastoreContextFactory = datastoreContextFactory;
-        this.type = datastoreContextFactory.getBaseDatastoreContext().getDataStoreType();
+    protected ShardManager(AbstractBuilder<?> builder) {
+
+        this.cluster = builder.cluster;
+        this.configuration = builder.configuration;
+        this.datastoreContextFactory = builder.datastoreContextFactory;
+        this.type = builder.datastoreContextFactory.getBaseDatastoreContext().getDataStoreName();
         this.shardDispatcherPath =
                 new Dispatchers(context().system().dispatchers()).getDispatcherPath(Dispatchers.DispatcherType.Shard);
-        this.waitTillReadyCountdownLatch = waitTillReadyCountdownLatch;
-        this.primaryShardInfoCache = primaryShardInfoCache;
+        this.waitTillReadyCountdownLatch = builder.waitTillReadyCountdownLatch;
+        this.primaryShardInfoCache = builder.primaryShardInfoCache;
+        this.restoreFromSnapshot = builder.restoreFromSnapshot;
+
+        String possiblePersistenceId = datastoreContextFactory.getBaseDatastoreContext().getShardManagerPersistenceId();
+        persistenceId = possiblePersistenceId != null ? possiblePersistenceId : "shard-manager-" + type;
 
         peerAddressResolver = new ShardPeerAddressResolver(type, cluster.getCurrentMemberName());
 
         // Subscribe this actor to cluster member events
         cluster.subscribeToMemberEvents(getSelf());
 
-        createLocalShards();
-    }
-
-    public static Props props(
-            final ClusterWrapper cluster,
-            final Configuration configuration,
-            final DatastoreContextFactory datastoreContextFactory,
-            final CountDownLatch waitTillReadyCountdownLatch,
-            final PrimaryShardInfoFutureCache primaryShardInfoCache) {
-
-        Preconditions.checkNotNull(cluster, "cluster should not be null");
-        Preconditions.checkNotNull(configuration, "configuration should not be null");
-        Preconditions.checkNotNull(waitTillReadyCountdownLatch, "waitTillReadyCountdownLatch should not be null");
-        Preconditions.checkNotNull(primaryShardInfoCache, "primaryShardInfoCache should not be null");
-
-        return Props.create(new ShardManagerCreator(cluster, configuration, datastoreContextFactory,
-                waitTillReadyCountdownLatch, primaryShardInfoCache));
+        List<String> localShardActorNames = new ArrayList<>();
+        mBean = ShardManagerInfo.createShardManagerMBean(cluster.getCurrentMemberName(),
+                "shard-manager-" + this.type,
+                datastoreContextFactory.getBaseDatastoreContext().getDataStoreMXBeanType(),
+                localShardActorNames);
+        mBean.setShardManager(this);
     }
 
     @Override
@@ -209,53 +225,158 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             onCreateShard((CreateShard)message);
         } else if(message instanceof AddShardReplica){
             onAddShardReplica((AddShardReplica)message);
-        } else if(message instanceof RemoveShardReplica){
-            onRemoveShardReplica((RemoveShardReplica)message);
+        } else if(message instanceof ForwardedAddServerReply) {
+            ForwardedAddServerReply msg = (ForwardedAddServerReply)message;
+            onAddServerReply(msg.shardInfo, msg.addServerReply, getSender(), msg.leaderPath,
+                    msg.removeShardOnFailure);
+        } else if(message instanceof ForwardedAddServerFailure) {
+            ForwardedAddServerFailure msg = (ForwardedAddServerFailure)message;
+            onAddServerFailure(msg.shardName, msg.failureMessage, msg.failure, getSender(), msg.removeShardOnFailure);
+        } else if(message instanceof PrimaryShardFoundForContext) {
+            PrimaryShardFoundForContext primaryShardFoundContext = (PrimaryShardFoundForContext)message;
+            onPrimaryShardFoundContext(primaryShardFoundContext);
+        } else if(message instanceof RemoveShardReplica) {
+            onRemoveShardReplica((RemoveShardReplica) message);
+        } else if(message instanceof WrappedShardResponse){
+            onWrappedShardResponse((WrappedShardResponse) message);
+        } else if(message instanceof GetSnapshot) {
+            onGetSnapshot();
+        } else if(message instanceof ServerRemoved){
+            onShardReplicaRemoved((ServerRemoved) message);
+        } else if (message instanceof SaveSnapshotSuccess) {
+            onSaveSnapshotSuccess((SaveSnapshotSuccess)message);
+        } else if (message instanceof SaveSnapshotFailure) {
+            LOG.error("{}: SaveSnapshotFailure received for saving snapshot of shards",
+                    persistenceId(), ((SaveSnapshotFailure) message).cause());
         } else {
             unknownMessage(message);
         }
+    }
 
+    private void onWrappedShardResponse(WrappedShardResponse message) {
+        if (message.getResponse() instanceof RemoveServerReply) {
+            onRemoveServerReply(getSender(), message.getShardName(), (RemoveServerReply) message.getResponse());
+        }
     }
 
-    private void onCreateShard(CreateShard createShard) {
-        Object reply;
-        try {
-            ModuleShardConfiguration moduleShardConfig = createShard.getModuleShardConfig();
-            if(localShards.containsKey(moduleShardConfig.getShardName())) {
-                throw new IllegalStateException(String.format("Shard with name %s already exists",
-                        moduleShardConfig.getShardName()));
-            }
+    private void onRemoveServerReply(ActorRef originalSender, String shardName, RemoveServerReply response) {
+        shardReplicaOperationsInProgress.remove(shardName);
+        originalSender.tell(new Status.Success(null), self());
+    }
+
+    private void onPrimaryShardFoundContext(PrimaryShardFoundForContext primaryShardFoundContext) {
+        if(primaryShardFoundContext.getContextMessage() instanceof AddShardReplica) {
+            addShard(primaryShardFoundContext.getShardName(), primaryShardFoundContext.getRemotePrimaryShardFound(),
+                    getSender());
+        } else if(primaryShardFoundContext.getContextMessage() instanceof RemoveShardReplica){
+            removeShardReplica((RemoveShardReplica) primaryShardFoundContext.getContextMessage(),
+                    primaryShardFoundContext.getShardName(), primaryShardFoundContext.getPrimaryPath(), getSender());
+        }
+    }
+
+    private void removeShardReplica(RemoveShardReplica contextMessage, final String shardName, final String primaryPath,
+            final ActorRef sender) {
+        if(isShardReplicaOperationInProgress(shardName, sender)) {
+            return;
+        }
 
-            configuration.addModuleShardConfiguration(moduleShardConfig);
+        shardReplicaOperationsInProgress.add(shardName);
 
-            ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), moduleShardConfig.getShardName());
-            Map<String, String> peerAddresses = getPeerAddresses(moduleShardConfig.getShardName()/*,
-                    moduleShardConfig.getShardMemberNames()*/);
+        final ShardIdentifier shardId = getShardIdentifier(contextMessage.getMemberName(), shardName);
 
-            LOG.debug("onCreateShard: shardId: {}, memberNames: {}. peerAddresses: {}", shardId,
-                    moduleShardConfig.getShardMemberNames(), peerAddresses);
+        final DatastoreContext datastoreContext = newShardDatastoreContextBuilder(shardName).build();
 
-            DatastoreContext shardDatastoreContext = createShard.getDatastoreContext();
-            if(shardDatastoreContext == null) {
-                shardDatastoreContext = newShardDatastoreContext(moduleShardConfig.getShardName());
-            } else {
-                shardDatastoreContext = DatastoreContext.newBuilderFrom(shardDatastoreContext).shardPeerAddressResolver(
-                        peerAddressResolver).build();
+        //inform ShardLeader to remove this shard as a replica by sending an RemoveServer message
+        LOG.debug ("{}: Sending RemoveServer message to peer {} for shard {}", persistenceId(),
+                primaryPath, shardId);
+
+        Timeout removeServerTimeout = new Timeout(datastoreContext.getShardLeaderElectionTimeout().
+                duration());
+        Future<Object> futureObj = ask(getContext().actorSelection(primaryPath),
+                new RemoveServer(shardId.toString()), removeServerTimeout);
+
+        futureObj.onComplete(new OnComplete<Object>() {
+            @Override
+            public void onComplete(Throwable failure, Object response) {
+                if (failure != null) {
+                    String msg = String.format("RemoveServer request to leader %s for shard %s failed",
+                            primaryPath, shardName);
+
+                    LOG.debug ("{}: {}", persistenceId(), msg, failure);
+
+                    // FAILURE
+                    sender.tell(new Status.Failure(new RuntimeException(msg, failure)), self());
+                } else {
+                    // SUCCESS
+                    self().tell(new WrappedShardResponse(shardName, response), sender);
+                }
             }
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
 
-            ShardInformation info = new ShardInformation(moduleShardConfig.getShardName(), shardId, peerAddresses,
-                    shardDatastoreContext, createShard.getShardPropsCreator(), peerAddressResolver);
-            localShards.put(info.getShardName(), info);
+    private void onShardReplicaRemoved(ServerRemoved message) {
+        final ShardIdentifier shardId = new ShardIdentifier.Builder().fromShardIdString(message.getServerId()).build();
+        final ShardInformation shardInformation = localShards.remove(shardId.getShardName());
+        if(shardInformation == null) {
+            LOG.debug("{} : Shard replica {} is not present in list", persistenceId(), shardId.toString());
+            return;
+        } else if(shardInformation.getActor() != null) {
+            LOG.debug("{} : Sending PoisonPill to Shard actor {}", persistenceId(), shardInformation.getActor());
+            shardInformation.getActor().tell(PoisonPill.getInstance(), self());
+        }
+        LOG.debug("{} : Local Shard replica for shard {} has been removed", persistenceId(), shardId.getShardName());
+        persistShardList();
+    }
 
-            mBean.addLocalShard(shardId.toString());
+    private void onGetSnapshot() {
+        LOG.debug("{}: onGetSnapshot", persistenceId());
 
-            if(schemaContext != null) {
-                info.setActor(newShardActor(schemaContext, info));
+        List<String> notInitialized = null;
+        for(ShardInformation shardInfo: localShards.values()) {
+            if(!shardInfo.isShardInitialized()) {
+                if(notInitialized == null) {
+                    notInitialized = new ArrayList<>();
+                }
+
+                notInitialized.add(shardInfo.getShardName());
             }
+        }
+
+        if(notInitialized != null) {
+            getSender().tell(new akka.actor.Status.Failure(new IllegalStateException(String.format(
+                    "%d shard(s) %s are not initialized", notInitialized.size(), notInitialized))), getSelf());
+            return;
+        }
+
+        byte[] shardManagerSnapshot = null;
+        if(currentSnapshot != null) {
+            shardManagerSnapshot = SerializationUtils.serialize(currentSnapshot);
+        }
+
+        ActorRef replyActor = getContext().actorOf(ShardManagerGetSnapshotReplyActor.props(
+                new ArrayList<>(localShards.keySet()), type, shardManagerSnapshot , getSender(), persistenceId(),
+                datastoreContextFactory.getBaseDatastoreContext().getShardInitializationTimeout().duration()));
+
+        for(ShardInformation shardInfo: localShards.values()) {
+            shardInfo.getActor().tell(GetSnapshot.INSTANCE, replyActor);
+        }
+    }
+
+    private void onCreateShard(CreateShard createShard) {
+        LOG.debug("{}: onCreateShard: {}", persistenceId(), createShard);
 
-            reply = new CreateShardReply();
+        Object reply;
+        try {
+            String shardName = createShard.getModuleShardConfig().getShardName();
+            if(localShards.containsKey(shardName)) {
+                LOG.debug("{}: Shard {} already exists", persistenceId(), shardName);
+                reply = new akka.actor.Status.Success(String.format("Shard with name %s already exists", shardName));
+            } else {
+                doCreateShard(createShard);
+                reply = new akka.actor.Status.Success(null);
+            }
         } catch (Exception e) {
-            LOG.error("onCreateShard failed", e);
+            LOG.error("{}: onCreateShard failed", persistenceId(), e);
             reply = new akka.actor.Status.Failure(e);
         }
 
@@ -264,6 +385,58 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    private void doCreateShard(CreateShard createShard) {
+        ModuleShardConfiguration moduleShardConfig = createShard.getModuleShardConfig();
+        String shardName = moduleShardConfig.getShardName();
+
+        configuration.addModuleShardConfiguration(moduleShardConfig);
+
+        DatastoreContext shardDatastoreContext = createShard.getDatastoreContext();
+        if(shardDatastoreContext == null) {
+            shardDatastoreContext = newShardDatastoreContext(shardName);
+        } else {
+            shardDatastoreContext = DatastoreContext.newBuilderFrom(shardDatastoreContext).shardPeerAddressResolver(
+                    peerAddressResolver).build();
+        }
+
+        ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
+
+        boolean shardWasInRecoveredSnapshot = currentSnapshot != null &&
+                currentSnapshot.getShardList().contains(shardName);
+
+        Map<String, String> peerAddresses;
+        boolean isActiveMember;
+        if(shardWasInRecoveredSnapshot || configuration.getMembersFromShardName(shardName).
+                contains(cluster.getCurrentMemberName())) {
+            peerAddresses = getPeerAddresses(shardName);
+            isActiveMember = true;
+        } else {
+            // The local member is not in the static shard member configuration and the shard did not
+            // previously exist (ie !shardWasInRecoveredSnapshot). In this case we'll create
+            // the shard with no peers and with elections disabled so it stays as follower. A
+            // subsequent AddServer request will be needed to make it an active member.
+            isActiveMember = false;
+            peerAddresses = Collections.emptyMap();
+            shardDatastoreContext = DatastoreContext.newBuilderFrom(shardDatastoreContext).
+                    customRaftPolicyImplementation(DisableElectionsRaftPolicy.class.getName()).build();
+        }
+
+        LOG.debug("{} doCreateShard: shardId: {}, memberNames: {}, peerAddresses: {}, isActiveMember: {}",
+                persistenceId(), shardId, moduleShardConfig.getShardMemberNames(), peerAddresses,
+                isActiveMember);
+
+        ShardInformation info = new ShardInformation(shardName, shardId, peerAddresses,
+                shardDatastoreContext, createShard.getShardBuilder(), peerAddressResolver);
+        info.setActiveMember(isActiveMember);
+        localShards.put(info.getShardName(), info);
+
+        mBean.addLocalShard(shardId.toString());
+
+        if(schemaContext != null) {
+            info.setActor(newShardActor(schemaContext, info));
+        }
+    }
+
     private DatastoreContext.Builder newShardDatastoreContextBuilder(String shardName) {
         return DatastoreContext.newBuilderFrom(datastoreContextFactory.getShardDatastoreContext(shardName)).
                 shardPeerAddressResolver(peerAddressResolver);
@@ -405,12 +578,34 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
     @Override
     protected void handleRecover(Object message) throws Exception {
         if (message instanceof RecoveryCompleted) {
-            LOG.info("Recovery complete : {}", persistenceId());
+            onRecoveryCompleted();
+        } else if (message instanceof SnapshotOffer) {
+            applyShardManagerSnapshot((ShardManagerSnapshot)((SnapshotOffer) message).snapshot());
+        }
+    }
 
-            // We no longer persist SchemaContext modules so delete all the prior messages from the akka
-            // journal on upgrade from Helium.
-            deleteMessages(lastSequenceNr());
+    private void onRecoveryCompleted() {
+        LOG.info("Recovery complete : {}", persistenceId());
+
+        // We no longer persist SchemaContext modules so delete all the prior messages from the akka
+        // journal on upgrade from Helium.
+        deleteMessages(lastSequenceNr());
+
+        if(currentSnapshot == null && restoreFromSnapshot != null &&
+                restoreFromSnapshot.getShardManagerSnapshot() != null) {
+            try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(
+                    restoreFromSnapshot.getShardManagerSnapshot()))) {
+                ShardManagerSnapshot snapshot = (ShardManagerSnapshot) ois.readObject();
+
+                LOG.debug("{}: Deserialized restored ShardManagerSnapshot: {}", persistenceId(), snapshot);
+
+                applyShardManagerSnapshot(snapshot);
+            } catch(Exception e) {
+                LOG.error("{}: Error deserializing restored ShardManagerSnapshot", persistenceId(), e);
+            }
         }
+
+        createLocalShards();
     }
 
     private void findLocalShard(FindLocalShard message) {
@@ -448,8 +643,6 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
                 shardInformation.addOnShardInitialized(onShardInitialized);
 
-                LOG.debug("{}: Scheduling timer to wait for shard {}", persistenceId(), shardInformation.getShardName());
-
                 FiniteDuration timeout = shardInformation.getDatastoreContext().getShardInitializationTimeout().duration();
                 if(shardInformation.isShardInitialized()) {
                     // If the shard is already initialized then we'll wait enough time for the shard to
@@ -458,6 +651,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
                             .getElectionTimeOutInterval().toMillis() * 2, TimeUnit.MILLISECONDS);
                 }
 
+                LOG.debug("{}: Scheduling {} ms timer to wait for shard {}", persistenceId(), timeout.toMillis(),
+                        shardInformation.getShardName());
+
                 Cancellable timeoutSchedule = getContext().system().scheduler().scheduleOnce(
                         timeout, getSelf(),
                         new ShardNotInitializedTimeout(shardInformation, onShardInitialized, sender),
@@ -641,7 +837,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         // First see if the there is a local replica for the shard
         final ShardInformation info = localShards.get(shardName);
-        if (info != null) {
+        if (info != null && info.isActiveMember()) {
             sendResponse(info, message.isWaitUntilReady(), true, new Supplier<Object>() {
                 @Override
                 public Object get() {
@@ -661,12 +857,25 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return;
         }
 
+        Collection<String> visitedAddresses;
+        if(message instanceof RemoteFindPrimary) {
+            visitedAddresses = ((RemoteFindPrimary)message).getVisitedAddresses();
+        } else {
+            visitedAddresses = new ArrayList<>();
+        }
+
+        visitedAddresses.add(peerAddressResolver.getShardManagerActorPathBuilder(cluster.getSelfAddress()).toString());
+
         for(String address: peerAddressResolver.getShardManagerPeerActorAddresses()) {
+            if(visitedAddresses.contains(address)) {
+                continue;
+            }
+
             LOG.debug("{}: findPrimary for {} forwarding to remote ShardManager {}", persistenceId(),
                     shardName, address);
 
             getContext().actorSelection(address).forward(new RemoteFindPrimary(shardName,
-                    message.isWaitUntilReady()), getContext());
+                    message.isWaitUntilReady(), visitedAddresses), getContext());
             return;
         }
 
@@ -697,20 +906,27 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         String memberName = this.cluster.getCurrentMemberName();
         Collection<String> memberShardNames = this.configuration.getMemberShardNames(memberName);
 
-        ShardPropsCreator shardPropsCreator = new DefaultShardPropsCreator();
-        List<String> localShardActorNames = new ArrayList<>();
+        Map<String, DatastoreSnapshot.ShardSnapshot> shardSnapshots = new HashMap<>();
+        if(restoreFromSnapshot != null)
+        {
+            for(DatastoreSnapshot.ShardSnapshot snapshot: restoreFromSnapshot.getShardSnapshots()) {
+                shardSnapshots.put(snapshot.getName(), snapshot);
+            }
+        }
+
+        restoreFromSnapshot = null; // null out to GC
+
         for(String shardName : memberShardNames){
             ShardIdentifier shardId = getShardIdentifier(memberName, shardName);
+
+            LOG.debug("{}: Creating local shard: {}", persistenceId(), shardId);
+
             Map<String, String> peerAddresses = getPeerAddresses(shardName);
-            localShardActorNames.add(shardId.toString());
             localShards.put(shardName, new ShardInformation(shardName, shardId, peerAddresses,
-                    newShardDatastoreContext(shardName), shardPropsCreator, peerAddressResolver));
+                    newShardDatastoreContext(shardName), Shard.builder().restoreFromSnapshot(
+                        shardSnapshots.get(shardName)), peerAddressResolver));
+            mBean.addLocalShard(shardId.toString());
         }
-
-        mBean = ShardManagerInfo.createShardManagerMBean(memberName, "shard-manager-" + this.type,
-                datastoreContextFactory.getBaseDatastoreContext().getDataStoreMXBeanType(), localShardActorNames);
-
-        mBean.setShardManager(this);
     }
 
     /**
@@ -751,7 +967,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     @Override
     public String persistenceId() {
-        return "shard-manager-" + type;
+        return persistenceId;
     }
 
     @VisibleForTesting
@@ -759,21 +975,21 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         return mBean;
     }
 
-    private void checkLocalShardExists(final String shardName, final ActorRef sender) {
-        if (localShards.containsKey(shardName)) {
-            String msg = String.format("Local shard %s already exists", shardName);
+    private boolean isShardReplicaOperationInProgress(final String shardName, final ActorRef sender) {
+        if (shardReplicaOperationsInProgress.contains(shardName)) {
+            String msg = String.format("A shard replica operation for %s is already in progress", shardName);
             LOG.debug ("{}: {}", persistenceId(), msg);
-            sender.tell(new akka.actor.Status.Failure(new IllegalArgumentException(msg)), getSelf());
+            sender.tell(new akka.actor.Status.Failure(new IllegalStateException(msg)), getSelf());
+            return true;
         }
+
+        return false;
     }
 
-    private void onAddShardReplica (AddShardReplica shardReplicaMsg) {
+    private void onAddShardReplica (final AddShardReplica shardReplicaMsg) {
         final String shardName = shardReplicaMsg.getShardName();
 
-        // verify the local shard replica is already available in the controller node
-        LOG.debug ("onAddShardReplica: {}", shardReplicaMsg);
-
-        checkLocalShardExists(shardName, getSender());
+        LOG.debug("{}: onAddShardReplica: {}", persistenceId(), shardReplicaMsg);
 
         // verify the shard with the specified name is present in the cluster configuration
         if (!(this.configuration.isShardConfigured(shardName))) {
@@ -792,65 +1008,63 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return;
         }
 
-        Map<String, String> peerAddresses = getPeerAddresses(shardName);
-        if (peerAddresses.isEmpty()) {
-            String msg = String.format("Cannot add replica for shard %s because no peer is available", shardName);
-            LOG.debug ("{}: {}", persistenceId(), msg);
-            getSender().tell(new akka.actor.Status.Failure(new IllegalStateException(msg)), getSelf());
-            return;
-        }
-
-        Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
-                getShardInitializationTimeout().duration().$times(2));
-
-        final ActorRef sender = getSender();
-        Future<Object> futureObj = ask(getSelf(), new RemoteFindPrimary(shardName, true), findPrimaryTimeout);
-        futureObj.onComplete(new OnComplete<Object>() {
+        findPrimary(shardName, new AutoFindPrimaryFailureResponseHandler(getSender(), shardName, persistenceId(), getSelf()) {
             @Override
-            public void onComplete(Throwable failure, Object response) {
-                if (failure != null) {
-                    LOG.debug ("{}: Received failure from FindPrimary for shard {}", persistenceId(), shardName, failure);
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(
-                        String.format("Failed to find leader for shard %s", shardName), failure)),
-                        getSelf());
-                } else {
-                    if (!(response instanceof RemotePrimaryShardFound)) {
-                        String msg = String.format("Failed to find leader for shard %s: received response: %s",
-                                shardName, response);
-                        LOG.debug ("{}: {}", persistenceId(), msg);
-                        sender.tell(new akka.actor.Status.Failure(new RuntimeException(msg)), getSelf());
-                        return;
-                    }
+            public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
 
-                    RemotePrimaryShardFound message = (RemotePrimaryShardFound)response;
-                    addShard (shardName, message, sender);
-                }
+            @Override
+            public void onLocalPrimaryFound(LocalPrimaryShardFound message) {
+                sendLocalReplicaAlreadyExistsReply(getShardName(), getTargetActor());
             }
-        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+
+        });
+    }
+
+    private void sendLocalReplicaAlreadyExistsReply(String shardName, ActorRef sender) {
+        String msg = String.format("Local shard %s already exists", shardName);
+        LOG.debug ("{}: {}", persistenceId(), msg);
+        sender.tell(new akka.actor.Status.Failure(new AlreadyExistsException(msg)), getSelf());
     }
 
     private void addShard(final String shardName, final RemotePrimaryShardFound response, final ActorRef sender) {
-        checkLocalShardExists(shardName, sender);
+        if(isShardReplicaOperationInProgress(shardName, sender)) {
+            return;
+        }
 
-        ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
-        String localShardAddress = peerAddressResolver.getShardActorAddress(shardName, cluster.getCurrentMemberName());
+        shardReplicaOperationsInProgress.add(shardName);
+
+        final ShardInformation shardInfo;
+        final boolean removeShardOnFailure;
+        ShardInformation existingShardInfo = localShards.get(shardName);
+        if(existingShardInfo == null) {
+            removeShardOnFailure = true;
+            ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
+
+            DatastoreContext datastoreContext = newShardDatastoreContextBuilder(shardName).customRaftPolicyImplementation(
+                    DisableElectionsRaftPolicy.class.getName()).build();
 
-        DatastoreContext datastoreContext = newShardDatastoreContextBuilder(shardName).customRaftPolicyImplementation(
-                DisableElectionsRaftPolicy.class.getName()).build();
+            shardInfo = new ShardInformation(shardName, shardId, getPeerAddresses(shardName), datastoreContext,
+                    Shard.builder(), peerAddressResolver);
+            shardInfo.setActiveMember(false);
+            localShards.put(shardName, shardInfo);
+            shardInfo.setActor(newShardActor(schemaContext, shardInfo));
+        } else {
+            removeShardOnFailure = false;
+            shardInfo = existingShardInfo;
+        }
 
-        final ShardInformation shardInfo = new ShardInformation(shardName, shardId,
-                          getPeerAddresses(shardName), datastoreContext,
-                          new DefaultShardPropsCreator(), peerAddressResolver);
-        localShards.put(shardName, shardInfo);
-        shardInfo.setActor(newShardActor(schemaContext, shardInfo));
+        String localShardAddress = peerAddressResolver.getShardActorAddress(shardName, cluster.getCurrentMemberName());
 
         //inform ShardLeader to add this shard as a replica by sending an AddServer message
         LOG.debug ("{}: Sending AddServer message to peer {} for shard {}", persistenceId(),
-                response.getPrimaryPath(), shardId);
+                response.getPrimaryPath(), shardInfo.getShardId());
 
-        Timeout addServerTimeout = new Timeout(datastoreContext.getShardLeaderElectionTimeout().duration().$times(4));
+        Timeout addServerTimeout = new Timeout(shardInfo.getDatastoreContext().getShardLeaderElectionTimeout().
+                duration());
         Future<Object> futureObj = ask(getContext().actorSelection(response.getPrimaryPath()),
-            new AddServer(shardId.toString(), localShardAddress, true), addServerTimeout);
+            new AddServer(shardInfo.getShardId().toString(), localShardAddress, true), addServerTimeout);
 
         futureObj.onComplete(new OnComplete<Object>() {
             @Override
@@ -859,27 +1073,37 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
                     LOG.debug ("{}: AddServer request to {} for {} failed", persistenceId(),
                             response.getPrimaryPath(), shardName, failure);
 
-                    // Remove the shard
-                    localShards.remove(shardName);
-                    if (shardInfo.getActor() != null) {
-                        shardInfo.getActor().tell(PoisonPill.getInstance(), getSelf());
-                    }
-
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(
-                        String.format("AddServer request to leader %s for shard %s failed",
-                            response.getPrimaryPath(), shardName), failure)), getSelf());
+                    String msg = String.format("AddServer request to leader %s for shard %s failed",
+                            response.getPrimaryPath(), shardName);
+                    self().tell(new ForwardedAddServerFailure(shardName, msg, failure, removeShardOnFailure), sender);
                 } else {
-                    AddServerReply reply = (AddServerReply)addServerResponse;
-                    onAddServerReply(shardName, shardInfo, reply, sender, response.getPrimaryPath());
+                    self().tell(new ForwardedAddServerReply(shardInfo, (AddServerReply)addServerResponse,
+                            response.getPrimaryPath(), removeShardOnFailure), sender);
                 }
             }
-        }, new Dispatchers(context().system().dispatchers()).
-            getDispatcher(Dispatchers.DispatcherType.Client));
-        return;
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
+
+    private void onAddServerFailure(String shardName, String message, Throwable failure, ActorRef sender,
+            boolean removeShardOnFailure) {
+        shardReplicaOperationsInProgress.remove(shardName);
+
+        if(removeShardOnFailure) {
+            ShardInformation shardInfo = localShards.remove(shardName);
+            if (shardInfo.getActor() != null) {
+                shardInfo.getActor().tell(PoisonPill.getInstance(), getSelf());
+            }
+        }
+
+        sender.tell(new akka.actor.Status.Failure(message == null ? failure :
+            new RuntimeException(message, failure)), getSelf());
     }
 
-    private void onAddServerReply (String shardName, ShardInformation shardInfo,
-                                   AddServerReply replyMsg, ActorRef sender, String leaderPath) {
+    private void onAddServerReply(ShardInformation shardInfo, AddServerReply replyMsg, ActorRef sender,
+            String leaderPath, boolean removeShardOnFailure) {
+        String shardName = shardInfo.getShardName();
+        shardReplicaOperationsInProgress.remove(shardName);
+
         LOG.debug ("{}: Received {} for shard {} from leader {}", persistenceId(), replyMsg, shardName, leaderPath);
 
         if (replyMsg.getStatus() == ServerChangeStatus.OK) {
@@ -887,49 +1111,139 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
             // Make the local shard voting capable
             shardInfo.setDatastoreContext(newShardDatastoreContext(shardName), getSelf());
+            shardInfo.setActiveMember(true);
+            persistShardList();
 
             mBean.addLocalShard(shardInfo.getShardId().toString());
-            sender.tell(new akka.actor.Status.Success(true), getSelf());
+            sender.tell(new akka.actor.Status.Success(null), getSelf());
+        } else if(replyMsg.getStatus() == ServerChangeStatus.ALREADY_EXISTS) {
+            sendLocalReplicaAlreadyExistsReply(shardName, sender);
         } else {
-            LOG.warn ("{}: Leader failed to add shard replica {} with status {} - removing the local shard",
+            LOG.warn ("{}: Leader failed to add shard replica {} with status {}",
                     persistenceId(), shardName, replyMsg.getStatus());
 
-            //remove the local replica created
-            localShards.remove(shardName);
-            if (shardInfo.getActor() != null) {
-                shardInfo.getActor().tell(PoisonPill.getInstance(), getSelf());
+            Exception failure = getServerChangeException(AddServer.class, replyMsg.getStatus(), leaderPath, shardInfo.getShardId());
+
+            onAddServerFailure(shardName, null, failure, sender, removeShardOnFailure);
+        }
+    }
+
+    private Exception getServerChangeException(Class<?> serverChange, ServerChangeStatus serverChangeStatus,
+                                               String leaderPath, ShardIdentifier shardId) {
+        Exception failure;
+        switch (serverChangeStatus) {
+            case TIMEOUT:
+                failure = new TimeoutException(String.format(
+                        "The shard leader %s timed out trying to replicate the initial data to the new shard %s." +
+                        "Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
+                        leaderPath, shardId.getShardName()));
+                break;
+            case NO_LEADER:
+                failure = createNoShardLeaderException(shardId);
+                break;
+            case NOT_SUPPORTED:
+                failure = new UnsupportedOperationException(String.format("%s request is not supported for shard %s",
+                        serverChange.getSimpleName(), shardId.getShardName()));
+                break;
+            default :
+                failure = new RuntimeException(String.format(
+                        "%s request to leader %s for shard %s failed with status %s",
+                        serverChange.getSimpleName(), leaderPath, shardId.getShardName(), serverChangeStatus));
+        }
+        return failure;
+    }
+
+    private void onRemoveShardReplica (final RemoveShardReplica shardReplicaMsg) {
+        LOG.debug("{}: onRemoveShardReplica: {}", persistenceId(), shardReplicaMsg);
+
+        findPrimary(shardReplicaMsg.getShardName(), new AutoFindPrimaryFailureResponseHandler(getSender(),
+                shardReplicaMsg.getShardName(), persistenceId(), getSelf()) {
+            @Override
+            public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+
+            @Override
+            public void onLocalPrimaryFound(LocalPrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+        });
+    }
+
+    private void persistShardList() {
+        List<String> shardList = new ArrayList<>(localShards.keySet());
+        for (ShardInformation shardInfo : localShards.values()) {
+            if (!shardInfo.isActiveMember()) {
+                shardList.remove(shardInfo.getShardName());
             }
-            switch (replyMsg.getStatus()) {
-                case TIMEOUT:
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(
-                        String.format("The shard leader %s timed out trying to replicate the initial data to the new shard %s. Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
-                            leaderPath, shardName))), getSelf());
-                    break;
-                case NO_LEADER:
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(String.format(
-                        "There is no shard leader available for shard %s", shardName))), getSelf());
-                    break;
-                default :
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(String.format(
-                        "AddServer request to leader %s for shard %s failed with status %s",
-                        leaderPath, shardName, replyMsg.getStatus()))), getSelf());
+        }
+        LOG.debug ("{}: persisting the shard list {}", persistenceId(), shardList);
+        saveSnapshot(updateShardManagerSnapshot(shardList));
+    }
+
+    private ShardManagerSnapshot updateShardManagerSnapshot(List<String> shardList) {
+        currentSnapshot = new ShardManagerSnapshot(shardList);
+        return currentSnapshot;
+    }
+
+    private void applyShardManagerSnapshot(ShardManagerSnapshot snapshot) {
+        currentSnapshot = snapshot;
+
+        LOG.debug ("{}: onSnapshotOffer: {}", persistenceId(), currentSnapshot);
+
+        String currentMember = cluster.getCurrentMemberName();
+        Set<String> configuredShardList =
+            new HashSet<>(configuration.getMemberShardNames(currentMember));
+        for (String shard : currentSnapshot.getShardList()) {
+            if (!configuredShardList.contains(shard)) {
+                // add the current member as a replica for the shard
+                LOG.debug ("{}: adding shard {}", persistenceId(), shard);
+                configuration.addMemberReplicaForShard(shard, currentMember);
+            } else {
+                configuredShardList.remove(shard);
             }
         }
+        for (String shard : configuredShardList) {
+            // remove the member as a replica for the shard
+            LOG.debug ("{}: removing shard {}", persistenceId(), shard);
+            configuration.removeMemberReplicaForShard(shard, currentMember);
+        }
     }
 
-    private void onRemoveShardReplica (RemoveShardReplica shardReplicaMsg) {
-        String shardName = shardReplicaMsg.getShardName();
+    private void onSaveSnapshotSuccess (SaveSnapshotSuccess successMessage) {
+        LOG.debug ("{} saved ShardManager snapshot successfully. Deleting the prev snapshot if available",
+            persistenceId());
+        deleteSnapshots(new SnapshotSelectionCriteria(scala.Long.MaxValue(), (successMessage.metadata().timestamp() - 1)));
+    }
 
-        // verify the local shard replica is available in the controller node
-        if (!localShards.containsKey(shardName)) {
-            String msg = String.format("Local shard %s does not", shardName);
-            LOG.debug ("{}: {}", persistenceId(), msg);
-            getSender().tell(new akka.actor.Status.Failure(new IllegalArgumentException(msg)), getSelf());
-            return;
+    private static class ForwardedAddServerReply {
+        ShardInformation shardInfo;
+        AddServerReply addServerReply;
+        String leaderPath;
+        boolean removeShardOnFailure;
+
+        ForwardedAddServerReply(ShardInformation shardInfo, AddServerReply addServerReply, String leaderPath,
+                boolean removeShardOnFailure) {
+            this.shardInfo = shardInfo;
+            this.addServerReply = addServerReply;
+            this.leaderPath = leaderPath;
+            this.removeShardOnFailure = removeShardOnFailure;
+        }
+    }
+
+    private static class ForwardedAddServerFailure {
+        String shardName;
+        String failureMessage;
+        Throwable failure;
+        boolean removeShardOnFailure;
+
+        ForwardedAddServerFailure(String shardName, String failureMessage, Throwable failure,
+                boolean removeShardOnFailure) {
+            this.shardName = shardName;
+            this.failureMessage = failureMessage;
+            this.failure = failure;
+            this.removeShardOnFailure = removeShardOnFailure;
         }
-        // call RemoveShard for the shardName
-        getSender().tell(new akka.actor.Status.Success(true), getSelf());
-        return;
     }
 
     @VisibleForTesting
@@ -953,28 +1267,34 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         private short leaderVersion;
 
         private DatastoreContext datastoreContext;
-        private final ShardPropsCreator shardPropsCreator;
+        private Shard.AbstractBuilder<?, ?> builder;
         private final ShardPeerAddressResolver addressResolver;
+        private boolean isActiveMember = true;
 
         private ShardInformation(String shardName, ShardIdentifier shardId,
                 Map<String, String> initialPeerAddresses, DatastoreContext datastoreContext,
-                ShardPropsCreator shardPropsCreator, ShardPeerAddressResolver addressResolver) {
+                Shard.AbstractBuilder<?, ?> builder, ShardPeerAddressResolver addressResolver) {
             this.shardName = shardName;
             this.shardId = shardId;
             this.initialPeerAddresses = initialPeerAddresses;
             this.datastoreContext = datastoreContext;
-            this.shardPropsCreator = shardPropsCreator;
+            this.builder = builder;
             this.addressResolver = addressResolver;
         }
 
         Props newProps(SchemaContext schemaContext) {
-            return shardPropsCreator.newProps(shardId, initialPeerAddresses, datastoreContext, schemaContext);
+            Preconditions.checkNotNull(builder);
+            Props props = builder.id(shardId).peerAddresses(initialPeerAddresses).datastoreContext(datastoreContext).
+                    schemaContext(schemaContext).props();
+            builder = null;
+            return props;
         }
 
         String getShardName() {
             return shardName;
         }
 
+        @Nullable
         ActorRef getActor(){
             return actor;
         }
@@ -1140,6 +1460,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         void setLeaderAvailable(boolean leaderAvailable) {
             this.leaderAvailable = leaderAvailable;
+
+            if(leaderAvailable) {
+                notifyOnShardInitializedCallbacks();
+            }
         }
 
         short getLeaderVersion() {
@@ -1149,31 +1473,13 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         void setLeaderVersion(short leaderVersion) {
             this.leaderVersion = leaderVersion;
         }
-    }
-
-    private static class ShardManagerCreator implements Creator<ShardManager> {
-        private static final long serialVersionUID = 1L;
 
-        final ClusterWrapper cluster;
-        final Configuration configuration;
-        final DatastoreContextFactory datastoreContextFactory;
-        private final CountDownLatch waitTillReadyCountdownLatch;
-        private final PrimaryShardInfoFutureCache primaryShardInfoCache;
-
-        ShardManagerCreator(ClusterWrapper cluster, Configuration configuration,
-                DatastoreContextFactory datastoreContextFactory, CountDownLatch waitTillReadyCountdownLatch,
-                PrimaryShardInfoFutureCache primaryShardInfoCache) {
-            this.cluster = cluster;
-            this.configuration = configuration;
-            this.datastoreContextFactory = datastoreContextFactory;
-            this.waitTillReadyCountdownLatch = waitTillReadyCountdownLatch;
-            this.primaryShardInfoCache = primaryShardInfoCache;
+        boolean isActiveMember() {
+            return isActiveMember;
         }
 
-        @Override
-        public ShardManager create() throws Exception {
-            return new ShardManager(cluster, configuration, datastoreContextFactory, waitTillReadyCountdownLatch,
-                    primaryShardInfoCache);
+        void setActiveMember(boolean isActiveMember) {
+            this.isActiveMember = isActiveMember;
         }
     }
 
@@ -1246,6 +1552,256 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return modules;
         }
     }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    public static abstract class AbstractBuilder<T extends AbstractBuilder<T>> {
+        private ClusterWrapper cluster;
+        private Configuration configuration;
+        private DatastoreContextFactory datastoreContextFactory;
+        private CountDownLatch waitTillReadyCountdownLatch;
+        private PrimaryShardInfoFutureCache primaryShardInfoCache;
+        private DatastoreSnapshot restoreFromSnapshot;
+
+        private volatile boolean sealed;
+
+        @SuppressWarnings("unchecked")
+        private T self() {
+            return (T) this;
+        }
+
+        protected void checkSealed() {
+            Preconditions.checkState(!sealed, "Builder is already sealed - further modifications are not allowed");
+        }
+
+        public T cluster(ClusterWrapper cluster) {
+            checkSealed();
+            this.cluster = cluster;
+            return self();
+        }
+
+        public T configuration(Configuration configuration) {
+            checkSealed();
+            this.configuration = configuration;
+            return self();
+        }
+
+        public T datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
+            checkSealed();
+            this.datastoreContextFactory = datastoreContextFactory;
+            return self();
+        }
+
+        public T waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
+            checkSealed();
+            this.waitTillReadyCountdownLatch = waitTillReadyCountdownLatch;
+            return self();
+        }
+
+        public T primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
+            checkSealed();
+            this.primaryShardInfoCache = primaryShardInfoCache;
+            return self();
+        }
+
+        public T restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
+            checkSealed();
+            this.restoreFromSnapshot = restoreFromSnapshot;
+            return self();
+        }
+
+        protected void verify() {
+            sealed = true;
+            Preconditions.checkNotNull(cluster, "cluster should not be null");
+            Preconditions.checkNotNull(configuration, "configuration should not be null");
+            Preconditions.checkNotNull(datastoreContextFactory, "datastoreContextFactory should not be null");
+            Preconditions.checkNotNull(waitTillReadyCountdownLatch, "waitTillReadyCountdownLatch should not be null");
+            Preconditions.checkNotNull(primaryShardInfoCache, "primaryShardInfoCache should not be null");
+        }
+
+        public Props props() {
+            verify();
+            return Props.create(ShardManager.class, this);
+        }
+    }
+
+    public static class Builder extends AbstractBuilder<Builder> {
+    }
+
+    private void findPrimary(final String shardName, final FindPrimaryResponseHandler handler) {
+        Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
+                getShardInitializationTimeout().duration().$times(2));
+
+
+        Future<Object> futureObj = ask(getSelf(), new FindPrimary(shardName, true), findPrimaryTimeout);
+        futureObj.onComplete(new OnComplete<Object>() {
+            @Override
+            public void onComplete(Throwable failure, Object response) {
+                if (failure != null) {
+                    handler.onFailure(failure);
+                } else {
+                    if(response instanceof RemotePrimaryShardFound) {
+                        handler.onRemotePrimaryShardFound((RemotePrimaryShardFound) response);
+                    } else if(response instanceof LocalPrimaryShardFound) {
+                        handler.onLocalPrimaryFound((LocalPrimaryShardFound) response);
+                    } else {
+                        handler.onUnknownResponse(response);
+                    }
+                }
+            }
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
+
+    /**
+     * The FindPrimaryResponseHandler provides specific callback methods which are invoked when a response to the
+     * a remote or local find primary message is processed
+     */
+    private static interface FindPrimaryResponseHandler {
+        /**
+         * Invoked when a Failure message is received as a response
+         *
+         * @param failure
+         */
+        void onFailure(Throwable failure);
+
+        /**
+         * Invoked when a RemotePrimaryShardFound response is received
+         *
+         * @param response
+         */
+        void onRemotePrimaryShardFound(RemotePrimaryShardFound response);
+
+        /**
+         * Invoked when a LocalPrimaryShardFound response is received
+         * @param response
+         */
+        void onLocalPrimaryFound(LocalPrimaryShardFound response);
+
+        /**
+         * Invoked when an unknown response is received. This is another type of failure.
+         *
+         * @param response
+         */
+        void onUnknownResponse(Object response);
+    }
+
+    /**
+     * The AutoFindPrimaryFailureResponseHandler automatically processes Failure responses when finding a primary
+     * replica and sends a wrapped Failure response to some targetActor
+     */
+    private static abstract class AutoFindPrimaryFailureResponseHandler implements FindPrimaryResponseHandler {
+        private final ActorRef targetActor;
+        private final String shardName;
+        private final String persistenceId;
+        private final ActorRef shardManagerActor;
+
+        /**
+         * @param targetActor The actor to whom the Failure response should be sent when a FindPrimary failure occurs
+         * @param shardName The name of the shard for which the primary replica had to be found
+         * @param persistenceId The persistenceId for the ShardManager
+         * @param shardManagerActor The ShardManager actor which triggered the call to FindPrimary
+         */
+        protected AutoFindPrimaryFailureResponseHandler(ActorRef targetActor, String shardName, String persistenceId, ActorRef shardManagerActor){
+            this.targetActor = Preconditions.checkNotNull(targetActor);
+            this.shardName = Preconditions.checkNotNull(shardName);
+            this.persistenceId = Preconditions.checkNotNull(persistenceId);
+            this.shardManagerActor = Preconditions.checkNotNull(shardManagerActor);
+        }
+
+        public ActorRef getTargetActor() {
+            return targetActor;
+        }
+
+        public String getShardName() {
+            return shardName;
+        }
+
+        @Override
+        public void onFailure(Throwable failure) {
+            LOG.debug ("{}: Received failure from FindPrimary for shard {}", persistenceId, shardName, failure);
+            targetActor.tell(new akka.actor.Status.Failure(new RuntimeException(
+                    String.format("Failed to find leader for shard %s", shardName), failure)), shardManagerActor);
+        }
+
+        @Override
+        public void onUnknownResponse(Object response) {
+            String msg = String.format("Failed to find leader for shard %s: received response: %s",
+                    shardName, response);
+            LOG.debug ("{}: {}", persistenceId, msg);
+            targetActor.tell(new akka.actor.Status.Failure(response instanceof Throwable ? (Throwable) response :
+                    new RuntimeException(msg)), shardManagerActor);
+        }
+    }
+
+
+    /**
+     * The PrimaryShardFoundForContext is a DTO which puts together a message (aka 'Context' message) which needs to be
+     * forwarded to the primary replica of a shard and the message (aka 'PrimaryShardFound' message) that is received
+     * as a successful response to find primary.
+     */
+    private static class PrimaryShardFoundForContext {
+        private final String shardName;
+        private final Object contextMessage;
+        private final RemotePrimaryShardFound remotePrimaryShardFound;
+        private final LocalPrimaryShardFound localPrimaryShardFound;
+
+        public PrimaryShardFoundForContext(@Nonnull String shardName, @Nonnull Object contextMessage,
+                @Nonnull Object primaryFoundMessage) {
+            this.shardName = Preconditions.checkNotNull(shardName);
+            this.contextMessage = Preconditions.checkNotNull(contextMessage);
+            Preconditions.checkNotNull(primaryFoundMessage);
+            this.remotePrimaryShardFound = (primaryFoundMessage instanceof RemotePrimaryShardFound) ?
+                    (RemotePrimaryShardFound) primaryFoundMessage : null;
+            this.localPrimaryShardFound = (primaryFoundMessage instanceof LocalPrimaryShardFound) ?
+                    (LocalPrimaryShardFound) primaryFoundMessage : null;
+        }
+
+        @Nonnull
+        String getPrimaryPath(){
+            if(remotePrimaryShardFound != null) {
+                return remotePrimaryShardFound.getPrimaryPath();
+            }
+            return localPrimaryShardFound.getPrimaryPath();
+        }
+
+        @Nonnull
+        Object getContextMessage() {
+            return contextMessage;
+        }
+
+        @Nullable
+        RemotePrimaryShardFound getRemotePrimaryShardFound() {
+            return remotePrimaryShardFound;
+        }
+
+        @Nonnull
+        String getShardName() {
+            return shardName;
+        }
+    }
+
+    /**
+     * The WrappedShardResponse class wraps a response from a Shard.
+     */
+    private static class WrappedShardResponse {
+        private final String shardName;
+        private final Object response;
+
+        private WrappedShardResponse(String shardName, Object response) {
+            this.shardName = shardName;
+            this.response = response;
+        }
+
+        String getShardName() {
+            return shardName;
+        }
+
+        Object getResponse() {
+            return response;
+        }
+    }
 }