Bug 4823: Notify findPrimary callbacks on ReachableMember event
[controller.git] / opendaylight / md-sal / sal-distributed-datastore / src / main / java / org / opendaylight / controller / cluster / datastore / ShardManager.java
index 98a6090514c9549f2f506c82a85fce7376e35cf6..a91109c64b3fa5aac674677b95d2a31c83477fbd 100644 (file)
@@ -16,6 +16,7 @@ import akka.actor.Cancellable;
 import akka.actor.OneForOneStrategy;
 import akka.actor.PoisonPill;
 import akka.actor.Props;
+import akka.actor.Status;
 import akka.actor.SupervisorStrategy;
 import akka.cluster.ClusterEvent;
 import akka.dispatch.OnComplete;
@@ -24,6 +25,7 @@ import akka.persistence.RecoveryCompleted;
 import akka.persistence.SaveSnapshotFailure;
 import akka.persistence.SaveSnapshotSuccess;
 import akka.persistence.SnapshotOffer;
+import akka.persistence.SnapshotSelectionCriteria;
 import akka.serialization.Serialization;
 import akka.util.Timeout;
 import com.google.common.annotations.VisibleForTesting;
@@ -33,9 +35,12 @@ import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
 import com.google.common.base.Supplier;
 import com.google.common.collect.Sets;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -44,9 +49,14 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import org.apache.commons.lang3.SerializationUtils;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedPersistentActorWithMetering;
 import org.opendaylight.controller.cluster.datastore.config.Configuration;
 import org.opendaylight.controller.cluster.datastore.config.ModuleShardConfiguration;
+import org.opendaylight.controller.cluster.datastore.exceptions.AlreadyExistsException;
 import org.opendaylight.controller.cluster.datastore.exceptions.NoShardLeaderException;
 import org.opendaylight.controller.cluster.datastore.exceptions.NotInitializedException;
 import org.opendaylight.controller.cluster.datastore.exceptions.PrimaryNotFoundException;
@@ -56,7 +66,6 @@ import org.opendaylight.controller.cluster.datastore.jmx.mbeans.shardmanager.Sha
 import org.opendaylight.controller.cluster.datastore.messages.ActorInitialized;
 import org.opendaylight.controller.cluster.datastore.messages.AddShardReplica;
 import org.opendaylight.controller.cluster.datastore.messages.CreateShard;
-import org.opendaylight.controller.cluster.datastore.messages.CreateShardReply;
 import org.opendaylight.controller.cluster.datastore.messages.DatastoreSnapshot;
 import org.opendaylight.controller.cluster.datastore.messages.FindLocalShard;
 import org.opendaylight.controller.cluster.datastore.messages.FindPrimary;
@@ -82,7 +91,10 @@ import org.opendaylight.controller.cluster.raft.base.messages.SwitchBehavior;
 import org.opendaylight.controller.cluster.raft.client.messages.GetSnapshot;
 import org.opendaylight.controller.cluster.raft.messages.AddServer;
 import org.opendaylight.controller.cluster.raft.messages.AddServerReply;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServer;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServerReply;
 import org.opendaylight.controller.cluster.raft.messages.ServerChangeStatus;
+import org.opendaylight.controller.cluster.raft.messages.ServerRemoved;
 import org.opendaylight.controller.cluster.raft.policy.DisableElectionsRaftPolicy;
 import org.opendaylight.yangtools.yang.data.api.schema.tree.DataTree;
 import org.opendaylight.yangtools.yang.model.api.SchemaContext;
@@ -134,20 +146,29 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     private DatastoreSnapshot restoreFromSnapshot;
 
+    private ShardManagerSnapshot currentSnapshot;
+
+    private final Set<String> shardReplicaOperationsInProgress = new HashSet<>();
+
+    private final String persistenceId;
+
     /**
      */
-    protected ShardManager(Builder builder) {
+    protected ShardManager(AbstractBuilder<?> builder) {
 
         this.cluster = builder.cluster;
         this.configuration = builder.configuration;
         this.datastoreContextFactory = builder.datastoreContextFactory;
-        this.type = builder.datastoreContextFactory.getBaseDatastoreContext().getDataStoreType();
+        this.type = builder.datastoreContextFactory.getBaseDatastoreContext().getDataStoreName();
         this.shardDispatcherPath =
                 new Dispatchers(context().system().dispatchers()).getDispatcherPath(Dispatchers.DispatcherType.Shard);
         this.waitTillReadyCountdownLatch = builder.waitTillReadyCountdownLatch;
         this.primaryShardInfoCache = builder.primaryShardInfoCache;
         this.restoreFromSnapshot = builder.restoreFromSnapshot;
 
+        String possiblePersistenceId = datastoreContextFactory.getBaseDatastoreContext().getShardManagerPersistenceId();
+        persistenceId = possiblePersistenceId != null ? possiblePersistenceId : "shard-manager-" + type;
+
         peerAddressResolver = new ShardPeerAddressResolver(type, cluster.getCurrentMemberName());
 
         // Subscribe this actor to cluster member events
@@ -204,19 +225,107 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             onCreateShard((CreateShard)message);
         } else if(message instanceof AddShardReplica){
             onAddShardReplica((AddShardReplica)message);
-        } else if(message instanceof RemoveShardReplica){
-            onRemoveShardReplica((RemoveShardReplica)message);
+        } else if(message instanceof ForwardedAddServerReply) {
+            ForwardedAddServerReply msg = (ForwardedAddServerReply)message;
+            onAddServerReply(msg.shardInfo, msg.addServerReply, getSender(), msg.leaderPath,
+                    msg.removeShardOnFailure);
+        } else if(message instanceof ForwardedAddServerFailure) {
+            ForwardedAddServerFailure msg = (ForwardedAddServerFailure)message;
+            onAddServerFailure(msg.shardName, msg.failureMessage, msg.failure, getSender(), msg.removeShardOnFailure);
+        } else if(message instanceof PrimaryShardFoundForContext) {
+            PrimaryShardFoundForContext primaryShardFoundContext = (PrimaryShardFoundForContext)message;
+            onPrimaryShardFoundContext(primaryShardFoundContext);
+        } else if(message instanceof RemoveShardReplica) {
+            onRemoveShardReplica((RemoveShardReplica) message);
+        } else if(message instanceof WrappedShardResponse){
+            onWrappedShardResponse((WrappedShardResponse) message);
         } else if(message instanceof GetSnapshot) {
             onGetSnapshot();
+        } else if(message instanceof ServerRemoved){
+            onShardReplicaRemoved((ServerRemoved) message);
         } else if (message instanceof SaveSnapshotSuccess) {
-            LOG.debug ("{} saved ShardManager snapshot successfully", persistenceId());
+            onSaveSnapshotSuccess((SaveSnapshotSuccess)message);
         } else if (message instanceof SaveSnapshotFailure) {
-            LOG.error ("{}: SaveSnapshotFailure received for saving snapshot of shards",
-                persistenceId(), ((SaveSnapshotFailure)message).cause());
+            LOG.error("{}: SaveSnapshotFailure received for saving snapshot of shards",
+                    persistenceId(), ((SaveSnapshotFailure) message).cause());
         } else {
             unknownMessage(message);
         }
+    }
+
+    /**
+     * Handles a WrappedShardResponse that was sent to this actor. Only a wrapped
+     * RemoveServerReply is acted upon; it is handed to onRemoveServerReply together with the
+     * current sender and the shard name carried by the wrapper.
+     */
+    private void onWrappedShardResponse(WrappedShardResponse message) {
+        final Object wrapped = message.getResponse();
+        if (!(wrapped instanceof RemoveServerReply)) {
+            // Any other payload type is not handled here.
+            return;
+        }
+
+        onRemoveServerReply(getSender(), message.getShardName(), (RemoveServerReply) wrapped);
+    }
 
+    /**
+     * Completes a remove-replica operation once the RemoveServer reply has arrived.
+     *
+     * The in-progress marker for the shard is always cleared. The original requester is told
+     * Status.Success only when the reply status is ServerChangeStatus.OK; any other status is
+     * reported as Status.Failure so that a rejected or timed-out removal is not mistaken for a
+     * successful one.
+     *
+     * @param originalSender the actor that requested the replica removal
+     * @param shardName the name of the shard whose replica was to be removed
+     * @param response the reply received for the RemoveServer request
+     */
+    private void onRemoveServerReply(ActorRef originalSender, String shardName, RemoveServerReply response) {
+        // The operation is over either way, so later replica operations for this shard must be allowed.
+        shardReplicaOperationsInProgress.remove(shardName);
+
+        LOG.debug("{}: Received {} for shard {}", persistenceId(), response, shardName);
+
+        if (response.getStatus() == ServerChangeStatus.OK) {
+            originalSender.tell(new Status.Success(null), self());
+        } else {
+            String msg = String.format("RemoveServer request for shard %s failed with status %s",
+                    shardName, response.getStatus());
+            LOG.debug("{}: {}", persistenceId(), msg);
+            originalSender.tell(new Status.Failure(new RuntimeException(msg)), self());
+        }
+    }
+
+    /**
+     * Continues the operation that triggered a primary-shard lookup. The context message stored
+     * in the given object decides what happens next: an AddShardReplica leads to addShard, a
+     * RemoveShardReplica leads to removeShardReplica. Other context messages are not handled.
+     */
+    private void onPrimaryShardFoundContext(PrimaryShardFoundForContext primaryShardFoundContext) {
+        final Object trigger = primaryShardFoundContext.getContextMessage();
+
+        if(trigger instanceof AddShardReplica) {
+            addShard(primaryShardFoundContext.getShardName(),
+                    primaryShardFoundContext.getRemotePrimaryShardFound(), getSender());
+            return;
+        }
+
+        if(trigger instanceof RemoveShardReplica) {
+            removeShardReplica((RemoveShardReplica) trigger, primaryShardFoundContext.getShardName(),
+                    primaryShardFoundContext.getPrimaryPath(), getSender());
+        }
+    }
+
+    /**
+     * Asks the shard leader at primaryPath to remove a member's replica of the given shard.
+     *
+     * If a replica operation for the shard is already marked as in progress, the sender has
+     * already been sent a failure by isShardReplicaOperationInProgress and nothing else happens.
+     * Otherwise the shard is marked as in progress and a RemoveServer message is sent with the ask
+     * pattern. A successful response is wrapped and sent back to this actor with the requester
+     * as the sender; a failed ask is reported directly to the requester as Status.Failure.
+     *
+     * NOTE(review): in the code visible here the failure branch does not clear the entry in
+     * shardReplicaOperationsInProgress, so after a failed ask later replica operations for this
+     * shard would be rejected as "already in progress" - confirm and fix via a message to self.
+     *
+     * @param contextMessage the original request; supplies the member name of the replica
+     * @param shardName the name of the shard
+     * @param primaryPath the actor path of the shard leader
+     * @param sender the actor that requested the removal
+     */
+    private void removeShardReplica(RemoveShardReplica contextMessage, final String shardName, final String primaryPath,
+            final ActorRef sender) {
+        if(isShardReplicaOperationInProgress(shardName, sender)) {
+            return;
+        }
+
+        shardReplicaOperationsInProgress.add(shardName);
+
+        final ShardIdentifier replicaId = getShardIdentifier(contextMessage.getMemberName(), shardName);
+        final DatastoreContext shardContext = newShardDatastoreContextBuilder(shardName).build();
+
+        // Inform the ShardLeader to remove this shard as a replica by sending a RemoveServer message.
+        LOG.debug ("{}: Sending RemoveServer message to peer {} for shard {}", persistenceId(),
+                primaryPath, replicaId);
+
+        final Timeout askTimeout = new Timeout(shardContext.getShardLeaderElectionTimeout().duration());
+        final Future<Object> removeServerFuture = ask(getContext().actorSelection(primaryPath),
+                new RemoveServer(replicaId.toString()), askTimeout);
+
+        removeServerFuture.onComplete(new OnComplete<Object>() {
+            @Override
+            public void onComplete(Throwable failure, Object response) {
+                if (failure == null) {
+                    // Success: hand the reply back to the actor so it is processed as a normal message.
+                    self().tell(new WrappedShardResponse(shardName, response), sender);
+                    return;
+                }
+
+                // Failure: report straight to the requester.
+                final String errorText = String.format("RemoveServer request to leader %s for shard %s failed",
+                        primaryPath, shardName);
+
+                LOG.debug ("{}: {}", persistenceId(), errorText, failure);
+
+                sender.tell(new Status.Failure(new RuntimeException(errorText, failure)), self());
+            }
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
+
+    /**
+     * Handles a ServerRemoved message: drops the shard named by the message's server id from
+     * localShards, sends a PoisonPill to its actor if one exists, and persists the shard list.
+     * If no such local shard is registered, only a debug message is logged.
+     */
+    private void onShardReplicaRemoved(ServerRemoved message) {
+        final ShardIdentifier removedId =
+                new ShardIdentifier.Builder().fromShardIdString(message.getServerId()).build();
+        final String removedShardName = removedId.getShardName();
+
+        final ShardInformation removedInfo = localShards.remove(removedShardName);
+        if(removedInfo == null) {
+            LOG.debug("{} : Shard replica {} is not present in list", persistenceId(), removedId.toString());
+            return;
+        }
+
+        final ActorRef shardActor = removedInfo.getActor();
+        if(shardActor != null) {
+            LOG.debug("{} : Sending PoisonPill to Shard actor {}", persistenceId(), shardActor);
+            shardActor.tell(PoisonPill.getInstance(), self());
+        }
+
+        LOG.debug("{} : Local Shard replica for shard {} has been removed", persistenceId(), removedShardName);
+        persistShardList();
     }
 
     private void onGetSnapshot() {
@@ -240,6 +349,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
 
         byte[] shardManagerSnapshot = null;
+        if(currentSnapshot != null) {
+            shardManagerSnapshot = SerializationUtils.serialize(currentSnapshot);
+        }
+
         ActorRef replyActor = getContext().actorOf(ShardManagerGetSnapshotReplyActor.props(
                 new ArrayList<>(localShards.keySet()), type, shardManagerSnapshot , getSender(), persistenceId(),
                 datastoreContextFactory.getBaseDatastoreContext().getShardInitializationTimeout().duration()));
@@ -250,49 +363,77 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
     }
 
     private void onCreateShard(CreateShard createShard) {
+        // Handles a CreateShard request. The reply is a Status.Success when the shard already
+        // exists locally (carrying an explanatory message) or was created (carrying null), and a
+        // Status.Failure wrapping the exception if creation throws.
+        LOG.debug("{}: onCreateShard: {}", persistenceId(), createShard);
+
         Object reply;
         try {
-            ModuleShardConfiguration moduleShardConfig = createShard.getModuleShardConfig();
-            if(localShards.containsKey(moduleShardConfig.getShardName())) {
-                throw new IllegalStateException(String.format("Shard with name %s already exists",
-                        moduleShardConfig.getShardName()));
+            String shardName = createShard.getModuleShardConfig().getShardName();
+            if(localShards.containsKey(shardName)) {
+                LOG.debug("{}: Shard {} already exists", persistenceId(), shardName);
+                reply = new akka.actor.Status.Success(String.format("Shard with name %s already exists", shardName));
+            } else {
+                doCreateShard(createShard);
+                reply = new akka.actor.Status.Success(null);
             }
+        } catch (Exception e) {
+            LOG.error("{}: onCreateShard failed", persistenceId(), e);
+            reply = new akka.actor.Status.Failure(e);
+        }
 
-            configuration.addModuleShardConfiguration(moduleShardConfig);
-
-            ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), moduleShardConfig.getShardName());
-            Map<String, String> peerAddresses = getPeerAddresses(moduleShardConfig.getShardName()/*,
-                    moduleShardConfig.getShardMemberNames()*/);
+        // Only reply when there is a real sender, i.e. it is neither null nor the dead letters actor.
+        if(getSender() != null && !getContext().system().deadLetters().equals(getSender())) {
+            getSender().tell(reply, getSelf());
+        }
+    }
 
-            LOG.debug("onCreateShard: shardId: {}, memberNames: {}. peerAddresses: {}", shardId,
-                    moduleShardConfig.getShardMemberNames(), peerAddresses);
+    /**
+     * Creates the local bookkeeping for a new shard described by the CreateShard message.
+     *
+     * The module shard configuration is added to the configuration, a ShardInformation is built
+     * and stored in localShards, and the shard is registered with the MBean. The shard actor itself
+     * is only created here when a schema context is already set.
+     */
+    private void doCreateShard(CreateShard createShard) {
+        ModuleShardConfiguration moduleShardConfig = createShard.getModuleShardConfig();
+        String shardName = moduleShardConfig.getShardName();
 
-            DatastoreContext shardDatastoreContext = createShard.getDatastoreContext();
-            if(shardDatastoreContext == null) {
-                shardDatastoreContext = newShardDatastoreContext(moduleShardConfig.getShardName());
-            } else {
-                shardDatastoreContext = DatastoreContext.newBuilderFrom(shardDatastoreContext).shardPeerAddressResolver(
-                        peerAddressResolver).build();
-            }
+        configuration.addModuleShardConfiguration(moduleShardConfig);
 
-            ShardInformation info = new ShardInformation(moduleShardConfig.getShardName(), shardId, peerAddresses,
-                    shardDatastoreContext, createShard.getShardBuilder(), peerAddressResolver);
-            localShards.put(info.getShardName(), info);
+        // Use the context supplied with the request if there is one (with this manager's peer
+        // address resolver applied); otherwise build a context for the shard name.
+        DatastoreContext shardDatastoreContext = createShard.getDatastoreContext();
+        if(shardDatastoreContext == null) {
+            shardDatastoreContext = newShardDatastoreContext(shardName);
+        } else {
+            shardDatastoreContext = DatastoreContext.newBuilderFrom(shardDatastoreContext).shardPeerAddressResolver(
+                    peerAddressResolver).build();
+        }
 
-            mBean.addLocalShard(shardId.toString());
+        ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
 
-            if(schemaContext != null) {
-                info.setActor(newShardActor(schemaContext, info));
-            }
+        boolean shardWasInRecoveredSnapshot = currentSnapshot != null &&
+                currentSnapshot.getShardList().contains(shardName);
 
-            reply = new CreateShardReply();
-        } catch (Exception e) {
-            LOG.error("onCreateShard failed", e);
-            reply = new akka.actor.Status.Failure(e);
+        Map<String, String> peerAddresses;
+        boolean isActiveMember;
+        if(shardWasInRecoveredSnapshot || configuration.getMembersFromShardName(shardName).
+                contains(cluster.getCurrentMemberName())) {
+            peerAddresses = getPeerAddresses(shardName);
+            isActiveMember = true;
+        } else {
+            // The local member is not in the static shard member configuration and the shard did not
+            // previously exist (ie !shardWasInRecoveredSnapshot). In this case we'll create
+            // the shard with no peers and with elections disabled so it stays as follower. A
+            // subsequent AddServer request will be needed to make it an active member.
+            isActiveMember = false;
+            peerAddresses = Collections.emptyMap();
+            shardDatastoreContext = DatastoreContext.newBuilderFrom(shardDatastoreContext).
+                    customRaftPolicyImplementation(DisableElectionsRaftPolicy.class.getName()).build();
         }
 
-        if(getSender() != null && !getContext().system().deadLetters().equals(getSender())) {
-            getSender().tell(reply, getSelf());
+        LOG.debug("{} doCreateShard: shardId: {}, memberNames: {}, peerAddresses: {}, isActiveMember: {}",
+                persistenceId(), shardId, moduleShardConfig.getShardMemberNames(), peerAddresses,
+                isActiveMember);
+
+        ShardInformation info = new ShardInformation(shardName, shardId, peerAddresses,
+                shardDatastoreContext, createShard.getShardBuilder(), peerAddressResolver);
+        info.setActiveMember(isActiveMember);
+        localShards.put(info.getShardName(), info);
+
+        mBean.addLocalShard(shardId.toString());
+
+        // Without a schema context the shard actor is not started here.
+        if(schemaContext != null) {
+            info.setActor(newShardActor(schemaContext, info));
         }
     }
 
@@ -437,17 +578,36 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
     @Override
     protected void handleRecover(Object message) throws Exception {
         if (message instanceof RecoveryCompleted) {
-            LOG.info("Recovery complete : {}", persistenceId());
-
-            // We no longer persist SchemaContext modules so delete all the prior messages from the akka
-            // journal on upgrade from Helium.
-            deleteMessages(lastSequenceNr());
-            createLocalShards();
+            // Recovery has finished: all end-of-recovery work is done in onRecoveryCompleted().
+            onRecoveryCompleted();
         } else if (message instanceof SnapshotOffer) {
-            handleShardRecovery((SnapshotOffer) message);
+            // A snapshot was offered during recovery; its payload is cast to ShardManagerSnapshot
+            // and applied.
+            applyShardManagerSnapshot((ShardManagerSnapshot)((SnapshotOffer) message).snapshot());
         }
+        // Any other recovery message is ignored.
     }
 
+    /**
+     * Runs once recovery has completed: deletes journal messages up to the last sequence number,
+     * optionally applies a ShardManagerSnapshot carried by restoreFromSnapshot, and then creates
+     * the local shards.
+     *
+     * The restore snapshot is only used when no snapshot has been applied yet (currentSnapshot is
+     * null) and the restore data actually contains a serialized shard manager snapshot. A
+     * deserialization error is logged and otherwise ignored, so local shards are still created.
+     */
+    private void onRecoveryCompleted() {
+        LOG.info("Recovery complete : {}", persistenceId());
+
+        // SchemaContext modules are no longer persisted, so purge all prior messages from the akka
+        // journal (relevant on upgrade from Helium).
+        deleteMessages(lastSequenceNr());
+
+        final boolean restorePending = currentSnapshot == null && restoreFromSnapshot != null &&
+                restoreFromSnapshot.getShardManagerSnapshot() != null;
+        if(restorePending) {
+            try(ByteArrayInputStream rawBytes = new ByteArrayInputStream(
+                    restoreFromSnapshot.getShardManagerSnapshot());
+                    ObjectInputStream objectIn = new ObjectInputStream(rawBytes)) {
+                final ShardManagerSnapshot restored = (ShardManagerSnapshot) objectIn.readObject();
+
+                LOG.debug("{}: Deserialized restored ShardManagerSnapshot: {}", persistenceId(), restored);
+
+                applyShardManagerSnapshot(restored);
+            } catch(Exception e) {
+                LOG.error("{}: Error deserializing restored ShardManagerSnapshot", persistenceId(), e);
+            }
+        }
+
+        createLocalShards();
+    }
+
     private void findLocalShard(FindLocalShard message) {
         final ShardInformation shardInformation = localShards.get(message.getShardName());
 
@@ -677,7 +837,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         // First see if there is a local replica for the shard
         final ShardInformation info = localShards.get(shardName);
-        if (info != null) {
+        if (info != null && info.isActiveMember()) {
             sendResponse(info, message.isWaitUntilReady(), true, new Supplier<Object>() {
                 @Override
                 public Object get() {
@@ -697,12 +857,25 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return;
         }
 
+        Collection<String> visitedAddresses;
+        if(message instanceof RemoteFindPrimary) {
+            visitedAddresses = ((RemoteFindPrimary)message).getVisitedAddresses();
+        } else {
+            visitedAddresses = new ArrayList<>();
+        }
+
+        visitedAddresses.add(peerAddressResolver.getShardManagerActorPathBuilder(cluster.getSelfAddress()).toString());
+
         for(String address: peerAddressResolver.getShardManagerPeerActorAddresses()) {
+            if(visitedAddresses.contains(address)) {
+                continue;
+            }
+
             LOG.debug("{}: findPrimary for {} forwarding to remote ShardManager {}", persistenceId(),
                     shardName, address);
 
             getContext().actorSelection(address).forward(new RemoteFindPrimary(shardName,
-                    message.isWaitUntilReady()), getContext());
+                    message.isWaitUntilReady(), visitedAddresses), getContext());
             return;
         }
 
@@ -745,6 +918,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         for(String shardName : memberShardNames){
             ShardIdentifier shardId = getShardIdentifier(memberName, shardName);
+
+            LOG.debug("{}: Creating local shard: {}", persistenceId(), shardId);
+
             Map<String, String> peerAddresses = getPeerAddresses(shardName);
             localShards.put(shardName, new ShardInformation(shardName, shardId, peerAddresses,
                     newShardDatastoreContext(shardName), Shard.builder().restoreFromSnapshot(
@@ -791,7 +967,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     @Override
     public String persistenceId() {
-        return "shard-manager-" + type;
+        // Computed once in the constructor: the shard manager persistence id from the base
+        // datastore context if it is non-null, otherwise "shard-manager-" followed by the type.
+        return persistenceId;
     }
 
     @VisibleForTesting
@@ -799,21 +975,21 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         return mBean;
     }
 
-    private void checkLocalShardExists(final String shardName, final ActorRef sender) {
-        if (localShards.containsKey(shardName)) {
-            String msg = String.format("Local shard %s already exists", shardName);
+    /**
+     * Checks whether a replica operation for the given shard is currently recorded in
+     * shardReplicaOperationsInProgress. If so, the sender is told a Status.Failure wrapping an
+     * IllegalStateException.
+     *
+     * @param shardName the name of the shard to check
+     * @param sender the actor to notify when an operation is already in progress
+     * @return true if an operation is in progress (and the sender has been notified), false otherwise
+     */
+    private boolean isShardReplicaOperationInProgress(final String shardName, final ActorRef sender) {
+        if (shardReplicaOperationsInProgress.contains(shardName)) {
+            String msg = String.format("A shard replica operation for %s is already in progress", shardName);
             LOG.debug ("{}: {}", persistenceId(), msg);
-            sender.tell(new akka.actor.Status.Failure(new IllegalArgumentException(msg)), getSelf());
+            sender.tell(new akka.actor.Status.Failure(new IllegalStateException(msg)), getSelf());
+            return true;
         }
+
+        return false;
     }
 
-    private void onAddShardReplica (AddShardReplica shardReplicaMsg) {
+    private void onAddShardReplica (final AddShardReplica shardReplicaMsg) {
         final String shardName = shardReplicaMsg.getShardName();
 
-        // verify the local shard replica is already available in the controller node
-        LOG.debug ("onAddShardReplica: {}", shardReplicaMsg);
-
-        checkLocalShardExists(shardName, getSender());
+        LOG.debug("{}: onAddShardReplica: {}", persistenceId(), shardReplicaMsg);
 
         // verify the shard with the specified name is present in the cluster configuration
         if (!(this.configuration.isShardConfigured(shardName))) {
@@ -832,66 +1008,63 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return;
         }
 
-        Map<String, String> peerAddresses = getPeerAddresses(shardName);
-        if (peerAddresses.isEmpty()) {
-            String msg = String.format("Cannot add replica for shard %s because no peer is available", shardName);
-            LOG.debug ("{}: {}", persistenceId(), msg);
-            getSender().tell(new akka.actor.Status.Failure(new IllegalStateException(msg)), getSelf());
-            return;
-        }
-
-        Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
-                getShardInitializationTimeout().duration().$times(2));
-
-        final ActorRef sender = getSender();
-        Future<Object> futureObj = ask(getSelf(), new RemoteFindPrimary(shardName, true), findPrimaryTimeout);
-        futureObj.onComplete(new OnComplete<Object>() {
+        findPrimary(shardName, new AutoFindPrimaryFailureResponseHandler(getSender(), shardName, persistenceId(), getSelf()) {
             @Override
-            public void onComplete(Throwable failure, Object response) {
-                if (failure != null) {
-                    LOG.debug ("{}: Received failure from FindPrimary for shard {}", persistenceId(), shardName, failure);
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(
-                        String.format("Failed to find leader for shard %s", shardName), failure)),
-                        getSelf());
-                } else {
-                    if (!(response instanceof RemotePrimaryShardFound)) {
-                        String msg = String.format("Failed to find leader for shard %s: received response: %s",
-                                shardName, response);
-                        LOG.debug ("{}: {}", persistenceId(), msg);
-                        sender.tell(new akka.actor.Status.Failure(new RuntimeException(msg)), getSelf());
-                        return;
-                    }
+            public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
 
-                    RemotePrimaryShardFound message = (RemotePrimaryShardFound)response;
-                    addShard (shardName, message, sender);
-                }
+            @Override
+            public void onLocalPrimaryFound(LocalPrimaryShardFound message) {
+                sendLocalReplicaAlreadyExistsReply(getShardName(), getTargetActor());
             }
-        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+
+        });
+    }
+
+    /**
+     * Tells the given actor that a local replica of the shard already exists by sending a
+     * Status.Failure that wraps an AlreadyExistsException. The message is also logged at debug level.
+     *
+     * @param shardName the name of the shard that already has a local replica
+     * @param sender the actor to notify
+     */
+    private void sendLocalReplicaAlreadyExistsReply(String shardName, ActorRef sender) {
+        final String errorText = String.format("Local shard %s already exists", shardName);
+        LOG.debug ("{}: {}", persistenceId(), errorText);
+
+        final AlreadyExistsException cause = new AlreadyExistsException(errorText);
+        sender.tell(new akka.actor.Status.Failure(cause), getSelf());
+    }
 
     private void addShard(final String shardName, final RemotePrimaryShardFound response, final ActorRef sender) {
-        checkLocalShardExists(shardName, sender);
+        if(isShardReplicaOperationInProgress(shardName, sender)) {
+            return;
+        }
 
-        ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
-        String localShardAddress = peerAddressResolver.getShardActorAddress(shardName, cluster.getCurrentMemberName());
+        shardReplicaOperationsInProgress.add(shardName);
+
+        final ShardInformation shardInfo;
+        final boolean removeShardOnFailure;
+        ShardInformation existingShardInfo = localShards.get(shardName);
+        if(existingShardInfo == null) {
+            removeShardOnFailure = true;
+            ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
+
+            DatastoreContext datastoreContext = newShardDatastoreContextBuilder(shardName).customRaftPolicyImplementation(
+                    DisableElectionsRaftPolicy.class.getName()).build();
 
-        DatastoreContext datastoreContext = newShardDatastoreContextBuilder(shardName).customRaftPolicyImplementation(
-                DisableElectionsRaftPolicy.class.getName()).build();
+            shardInfo = new ShardInformation(shardName, shardId, getPeerAddresses(shardName), datastoreContext,
+                    Shard.builder(), peerAddressResolver);
+            shardInfo.setActiveMember(false);
+            localShards.put(shardName, shardInfo);
+            shardInfo.setActor(newShardActor(schemaContext, shardInfo));
+        } else {
+            removeShardOnFailure = false;
+            shardInfo = existingShardInfo;
+        }
 
-        final ShardInformation shardInfo = new ShardInformation(shardName, shardId,
-                          getPeerAddresses(shardName), datastoreContext,
-                          Shard.builder(), peerAddressResolver);
-        shardInfo.setShardActiveMember(false);
-        localShards.put(shardName, shardInfo);
-        shardInfo.setActor(newShardActor(schemaContext, shardInfo));
+        String localShardAddress = peerAddressResolver.getShardActorAddress(shardName, cluster.getCurrentMemberName());
 
         //inform ShardLeader to add this shard as a replica by sending an AddServer message
         LOG.debug ("{}: Sending AddServer message to peer {} for shard {}", persistenceId(),
-                response.getPrimaryPath(), shardId);
+                response.getPrimaryPath(), shardInfo.getShardId());
 
-        Timeout addServerTimeout = new Timeout(datastoreContext.getShardLeaderElectionTimeout().duration().$times(4));
+        Timeout addServerTimeout = new Timeout(shardInfo.getDatastoreContext().getShardLeaderElectionTimeout().
+                duration());
         Future<Object> futureObj = ask(getContext().actorSelection(response.getPrimaryPath()),
-            new AddServer(shardId.toString(), localShardAddress, true), addServerTimeout);
+            new AddServer(shardInfo.getShardId().toString(), localShardAddress, true), addServerTimeout);
 
         futureObj.onComplete(new OnComplete<Object>() {
             @Override
@@ -900,27 +1073,37 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
                     LOG.debug ("{}: AddServer request to {} for {} failed", persistenceId(),
                             response.getPrimaryPath(), shardName, failure);
 
-                    // Remove the shard
-                    localShards.remove(shardName);
-                    if (shardInfo.getActor() != null) {
-                        shardInfo.getActor().tell(PoisonPill.getInstance(), getSelf());
-                    }
-
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(
-                        String.format("AddServer request to leader %s for shard %s failed",
-                            response.getPrimaryPath(), shardName), failure)), getSelf());
+                    String msg = String.format("AddServer request to leader %s for shard %s failed",
+                            response.getPrimaryPath(), shardName);
+                    self().tell(new ForwardedAddServerFailure(shardName, msg, failure, removeShardOnFailure), sender);
                 } else {
-                    AddServerReply reply = (AddServerReply)addServerResponse;
-                    onAddServerReply(shardName, shardInfo, reply, sender, response.getPrimaryPath());
+                    self().tell(new ForwardedAddServerReply(shardInfo, (AddServerReply)addServerResponse,
+                            response.getPrimaryPath(), removeShardOnFailure), sender);
                 }
             }
-        }, new Dispatchers(context().system().dispatchers()).
-            getDispatcher(Dispatchers.DispatcherType.Client));
-        return;
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
     }
 
-    private void onAddServerReply (String shardName, ShardInformation shardInfo,
-                                   AddServerReply replyMsg, ActorRef sender, String leaderPath) {
+    private void onAddServerFailure(String shardName, String message, Throwable failure, ActorRef sender,
+            boolean removeShardOnFailure) {
+        shardReplicaOperationsInProgress.remove(shardName);
+        // Tear down the local shard when requested; remove() yields null if it is already gone, so guard it.
+        if(removeShardOnFailure) {
+            ShardInformation shardInfo = localShards.remove(shardName);
+            if (shardInfo != null && shardInfo.getActor() != null) {
+                shardInfo.getActor().tell(PoisonPill.getInstance(), getSelf());
+            }
+        }
+        // A null message means 'failure' itself is the exception to report; otherwise it becomes the cause.
+        sender.tell(new akka.actor.Status.Failure(message == null ? failure :
+            new RuntimeException(message, failure)), getSelf());
+    }
+
+    private void onAddServerReply(ShardInformation shardInfo, AddServerReply replyMsg, ActorRef sender,
+            String leaderPath, boolean removeShardOnFailure) {
+        String shardName = shardInfo.getShardName();
+        shardReplicaOperationsInProgress.remove(shardName);
+
         LOG.debug ("{}: Received {} for shard {} from leader {}", persistenceId(), replyMsg, shardName, leaderPath);
 
         if (replyMsg.getStatus() == ServerChangeStatus.OK) {
@@ -928,71 +1111,90 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
             // Make the local shard voting capable
             shardInfo.setDatastoreContext(newShardDatastoreContext(shardName), getSelf());
-            shardInfo.setShardActiveMember(true);
+            shardInfo.setActiveMember(true);
             persistShardList();
 
             mBean.addLocalShard(shardInfo.getShardId().toString());
-            sender.tell(new akka.actor.Status.Success(true), getSelf());
+            sender.tell(new akka.actor.Status.Success(null), getSelf());
+        } else if(replyMsg.getStatus() == ServerChangeStatus.ALREADY_EXISTS) {
+            sendLocalReplicaAlreadyExistsReply(shardName, sender);
         } else {
-            LOG.warn ("{}: Leader failed to add shard replica {} with status {} - removing the local shard",
+            LOG.warn ("{}: Leader failed to add shard replica {} with status {}",
                     persistenceId(), shardName, replyMsg.getStatus());
 
-            //remove the local replica created
-            localShards.remove(shardName);
-            if (shardInfo.getActor() != null) {
-                shardInfo.getActor().tell(PoisonPill.getInstance(), getSelf());
-            }
-            switch (replyMsg.getStatus()) {
-                case TIMEOUT:
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(
-                        String.format("The shard leader %s timed out trying to replicate the initial data to the new shard %s. Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
-                            leaderPath, shardName))), getSelf());
-                    break;
-                case NO_LEADER:
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(String.format(
-                        "There is no shard leader available for shard %s", shardName))), getSelf());
-                    break;
-                default :
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(String.format(
-                        "AddServer request to leader %s for shard %s failed with status %s",
-                        leaderPath, shardName, replyMsg.getStatus()))), getSelf());
-            }
+            Exception failure = getServerChangeException(AddServer.class, replyMsg.getStatus(), leaderPath, shardInfo.getShardId());
+
+            onAddServerFailure(shardName, null, failure, sender, removeShardOnFailure);
         }
     }
 
-    private void onRemoveShardReplica (RemoveShardReplica shardReplicaMsg) {
-        String shardName = shardReplicaMsg.getShardName();
-
-        // verify the local shard replica is available in the controller node
-        if (!localShards.containsKey(shardName)) {
-            String msg = String.format("Local shard %s does not", shardName);
-            LOG.debug ("{}: {}", persistenceId(), msg);
-            getSender().tell(new akka.actor.Status.Failure(new IllegalArgumentException(msg)), getSelf());
-            return;
+    private Exception getServerChangeException(Class<?> serverChange, ServerChangeStatus serverChangeStatus,
+                                               String leaderPath, ShardIdentifier shardId) {
+        Exception failure; // exception matching the non-OK status returned by the shard leader
+        switch (serverChangeStatus) {
+            case TIMEOUT:
+                failure = new TimeoutException(String.format(
+                        "The shard leader %s timed out trying to replicate the initial data to the new shard %s. " +
+                        "Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
+                        leaderPath, shardId.getShardName()));
+                break;
+            case NO_LEADER:
+                failure = createNoShardLeaderException(shardId);
+                break;
+            case NOT_SUPPORTED:
+                failure = new UnsupportedOperationException(String.format("%s request is not supported for shard %s",
+                        serverChange.getSimpleName(), shardId.getShardName()));
+                break;
+            default : // any other status is reported with the status value embedded in the message
+                failure = new RuntimeException(String.format(
+                        "%s request to leader %s for shard %s failed with status %s",
+                        serverChange.getSimpleName(), leaderPath, shardId.getShardName(), serverChangeStatus));
         }
-        // call RemoveShard for the shardName
-        getSender().tell(new akka.actor.Status.Success(true), getSelf());
-        return;
+        return failure;
+    }
+
+    private void onRemoveShardReplica (final RemoveShardReplica shardReplicaMsg) {
+        LOG.debug("{}: onRemoveShardReplica: {}", persistenceId(), shardReplicaMsg);
+        // Locate the shard's primary first; on success the request is re-sent to self paired with the result.
+        findPrimary(shardReplicaMsg.getShardName(), new AutoFindPrimaryFailureResponseHandler(getSender(),
+                shardReplicaMsg.getShardName(), persistenceId(), getSelf()) {
+            @Override
+            public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+
+            @Override
+            public void onLocalPrimaryFound(LocalPrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+        });
     }
 
     private void persistShardList() {
-        List<String> shardList = new ArrayList(localShards.keySet());
+        List<String> shardList = new ArrayList<>(localShards.keySet());
         for (ShardInformation shardInfo : localShards.values()) {
-            if (!shardInfo.isShardActiveMember()) {
+            if (!shardInfo.isActiveMember()) {
                 shardList.remove(shardInfo.getShardName());
             }
         }
         LOG.debug ("{}: persisting the shard list {}", persistenceId(), shardList);
-        saveSnapshot(new ShardManagerSnapshot(shardList));
+        saveSnapshot(updateShardManagerSnapshot(shardList));
     }
 
-    private void handleShardRecovery(SnapshotOffer offer) {
-        LOG.debug ("{}: in handleShardRecovery", persistenceId());
-        ShardManagerSnapshot snapshot = (ShardManagerSnapshot)offer.snapshot();
+    private ShardManagerSnapshot updateShardManagerSnapshot(List<String> shardList) {
+        currentSnapshot = new ShardManagerSnapshot(shardList);
+        return currentSnapshot;
+    }
+
+    private void applyShardManagerSnapshot(ShardManagerSnapshot snapshot) {
+        currentSnapshot = snapshot;
+
+        LOG.debug ("{}: onSnapshotOffer: {}", persistenceId(), currentSnapshot);
+
         String currentMember = cluster.getCurrentMemberName();
         Set<String> configuredShardList =
             new HashSet<>(configuration.getMemberShardNames(currentMember));
-        for (String shard : snapshot.getShardList()) {
+        for (String shard : currentSnapshot.getShardList()) {
             if (!configuredShardList.contains(shard)) {
                 // add the current member as a replica for the shard
                 LOG.debug ("{}: adding shard {}", persistenceId(), shard);
@@ -1008,6 +1210,42 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    private void onSaveSnapshotSuccess(SaveSnapshotSuccess successMessage) {
+        // Everything strictly older than the snapshot that was just saved is eligible for deletion.
+        final long olderThanTimestamp = successMessage.metadata().timestamp() - 1;
+        LOG.debug("{} saved ShardManager snapshot successfully. Deleting the prev snapshot if available", persistenceId());
+        deleteSnapshots(new SnapshotSelectionCriteria(scala.Long.MaxValue(), olderThanTimestamp));
+    }
+
+    private static class ForwardedAddServerReply { // AddServerReply re-sent to self from the ask() callback
+        final ShardInformation shardInfo;
+        final AddServerReply addServerReply;
+        final String leaderPath;
+        final boolean removeShardOnFailure; // passed through to the reply/failure handling unchanged
+
+        ForwardedAddServerReply(ShardInformation shardInfo, AddServerReply addServerReply, String leaderPath,
+                boolean removeShardOnFailure) {
+            this.shardInfo = shardInfo;
+            this.addServerReply = addServerReply;
+            this.leaderPath = leaderPath;
+            this.removeShardOnFailure = removeShardOnFailure;
+        }
+    }
+
+    private static class ForwardedAddServerFailure { // AddServer ask() failure re-sent to self from the callback
+        final String shardName;
+        final String failureMessage;
+        final Throwable failure;
+        final boolean removeShardOnFailure; // passed through to the failure handling unchanged
+
+        ForwardedAddServerFailure(String shardName, String failureMessage, Throwable failure,
+                boolean removeShardOnFailure) {
+            this.shardName = shardName;
+            this.failureMessage = failureMessage;
+            this.failure = failure;
+            this.removeShardOnFailure = removeShardOnFailure;
+        }
+    }
+
     @VisibleForTesting
     protected static class ShardInformation {
         private final ShardIdentifier shardId;
@@ -1031,7 +1269,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         private DatastoreContext datastoreContext;
         private Shard.AbstractBuilder<?, ?> builder;
         private final ShardPeerAddressResolver addressResolver;
-        private boolean shardActiveStatus = true;
+        private boolean isActiveMember = true;
 
         private ShardInformation(String shardName, ShardIdentifier shardId,
                 Map<String, String> initialPeerAddresses, DatastoreContext datastoreContext,
@@ -1056,6 +1294,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return shardName;
         }
 
+        @Nullable
         ActorRef getActor(){
             return actor;
         }
@@ -1221,6 +1460,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         void setLeaderAvailable(boolean leaderAvailable) {
             this.leaderAvailable = leaderAvailable;
+
+            if(leaderAvailable) {
+                notifyOnShardInitializedCallbacks();
+            }
         }
 
         short getLeaderVersion() {
@@ -1231,12 +1474,12 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             this.leaderVersion = leaderVersion;
         }
 
-        void setShardActiveMember(boolean flag) {
-            shardActiveStatus = flag;
+        boolean isActiveMember() {
+            return isActiveMember;
         }
 
-        boolean isShardActiveMember() {
-            return shardActiveStatus;
+        void setActiveMember(boolean isActiveMember) {
+            this.isActiveMember = isActiveMember;
         }
     }
 
@@ -1314,53 +1557,59 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         return new Builder();
     }
 
-    public static class Builder {
+    public static abstract class AbstractBuilder<T extends AbstractBuilder<T>> {
         private ClusterWrapper cluster;
         private Configuration configuration;
         private DatastoreContextFactory datastoreContextFactory;
         private CountDownLatch waitTillReadyCountdownLatch;
         private PrimaryShardInfoFutureCache primaryShardInfoCache;
         private DatastoreSnapshot restoreFromSnapshot;
+
         private volatile boolean sealed;
 
+        @SuppressWarnings("unchecked")
+        private T self() {
+            return (T) this;
+        }
+
         protected void checkSealed() {
             Preconditions.checkState(!sealed, "Builder is already sealed - further modifications are not allowed");
         }
 
-        public Builder cluster(ClusterWrapper cluster) {
+        public T cluster(ClusterWrapper cluster) {
             checkSealed();
             this.cluster = cluster;
-            return this;
+            return self();
         }
 
-        public Builder configuration(Configuration configuration) {
+        public T configuration(Configuration configuration) {
             checkSealed();
             this.configuration = configuration;
-            return this;
+            return self();
         }
 
-        public Builder datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
+        public T datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
             checkSealed();
             this.datastoreContextFactory = datastoreContextFactory;
-            return this;
+            return self();
         }
 
-        public Builder waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
+        public T waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
             checkSealed();
             this.waitTillReadyCountdownLatch = waitTillReadyCountdownLatch;
-            return this;
+            return self();
         }
 
-        public Builder primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
+        public T primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
             checkSealed();
             this.primaryShardInfoCache = primaryShardInfoCache;
-            return this;
+            return self();
         }
 
-        public Builder restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
+        public T restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
             checkSealed();
             this.restoreFromSnapshot = restoreFromSnapshot;
-            return this;
+            return self();
         }
 
         protected void verify() {
@@ -1377,6 +1626,182 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return Props.create(ShardManager.class, this);
         }
     }
+
+    public static class Builder extends AbstractBuilder<Builder> {
+    }
+
+    private void findPrimary(final String shardName, final FindPrimaryResponseHandler handler) {
+        Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
+                getShardInitializationTimeout().duration().$times(2));
+
+        // Ask this actor itself for the shard's primary and route the eventual reply to the matching handler callback.
+        Future<Object> futureObj = ask(getSelf(), new FindPrimary(shardName, true), findPrimaryTimeout);
+        futureObj.onComplete(new OnComplete<Object>() {
+            @Override
+            public void onComplete(Throwable failure, Object response) {
+                if (failure != null) {
+                    handler.onFailure(failure);
+                } else {
+                    if(response instanceof RemotePrimaryShardFound) {
+                        handler.onRemotePrimaryShardFound((RemotePrimaryShardFound) response);
+                    } else if(response instanceof LocalPrimaryShardFound) {
+                        handler.onLocalPrimaryFound((LocalPrimaryShardFound) response);
+                    } else {
+                        handler.onUnknownResponse(response);
+                    }
+                }
+            }
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
+
+    /**
+     * The FindPrimaryResponseHandler provides specific callback methods which are invoked when the response to
+     * a remote or local find primary message is processed.
+     */
+    private static interface FindPrimaryResponseHandler {
+        /**
+         * Invoked when a Failure message is received as a response
+         *
+         * @param failure the cause with which the find primary request failed
+         */
+        void onFailure(Throwable failure);
+
+        /**
+         * Invoked when a RemotePrimaryShardFound response is received
+         *
+         * @param response the RemotePrimaryShardFound message that was received
+         */
+        void onRemotePrimaryShardFound(RemotePrimaryShardFound response);
+
+        /**
+         * Invoked when a LocalPrimaryShardFound response is received
+         * @param response the LocalPrimaryShardFound message that was received
+         */
+        void onLocalPrimaryFound(LocalPrimaryShardFound response);
+
+        /**
+         * Invoked when an unknown response is received. This is another type of failure.
+         *
+         * @param response the reply object that is neither of the two primary-found messages
+         */
+        void onUnknownResponse(Object response);
+    }
+
+    /**
+     * The AutoFindPrimaryFailureResponseHandler handles FindPrimary failures and unknown responses itself by sending
+     * an akka Status.Failure to the targetActor, leaving only the primary-found callbacks to subclasses.
+     */
+    private static abstract class AutoFindPrimaryFailureResponseHandler implements FindPrimaryResponseHandler {
+        private final ActorRef targetActor;
+        private final String shardName;
+        private final String persistenceId;
+        private final ActorRef shardManagerActor;
+
+        /**
+         * @param targetActor The actor to whom the Failure response should be sent when a FindPrimary failure occurs
+         * @param shardName The name of the shard for which the primary replica had to be found
+         * @param persistenceId The persistenceId for the ShardManager
+         * @param shardManagerActor The ShardManager actor which triggered the call to FindPrimary
+         */
+        protected AutoFindPrimaryFailureResponseHandler(ActorRef targetActor, String shardName, String persistenceId, ActorRef shardManagerActor){
+            this.targetActor = Preconditions.checkNotNull(targetActor);
+            this.shardName = Preconditions.checkNotNull(shardName);
+            this.persistenceId = Preconditions.checkNotNull(persistenceId);
+            this.shardManagerActor = Preconditions.checkNotNull(shardManagerActor);
+        }
+
+        public ActorRef getTargetActor() {
+            return targetActor;
+        }
+
+        public String getShardName() {
+            return shardName;
+        }
+
+        @Override
+        public void onFailure(Throwable failure) {
+            LOG.debug ("{}: Received failure from FindPrimary for shard {}", persistenceId, shardName, failure);
+            targetActor.tell(new akka.actor.Status.Failure(new RuntimeException(
+                    String.format("Failed to find leader for shard %s", shardName), failure)), shardManagerActor);
+        }
+
+        @Override
+        public void onUnknownResponse(Object response) {
+            String msg = String.format("Failed to find leader for shard %s: received response: %s",
+                    shardName, response);
+            LOG.debug ("{}: {}", persistenceId, msg);
+            targetActor.tell(new akka.actor.Status.Failure(response instanceof Throwable ? (Throwable) response :
+                    new RuntimeException(msg)), shardManagerActor);
+        }
+    }
+
+
+    /**
+     * The PrimaryShardFoundForContext is a DTO which puts together a message (aka 'Context' message) which needs to be
+     * forwarded to the primary replica of a shard and the message (aka 'PrimaryShardFound' message) that is received
+     * as a successful response to find primary.
+     */
+    private static class PrimaryShardFoundForContext {
+        private final String shardName;
+        private final Object contextMessage;
+        private final RemotePrimaryShardFound remotePrimaryShardFound;
+        private final LocalPrimaryShardFound localPrimaryShardFound;
+
+        public PrimaryShardFoundForContext(@Nonnull String shardName, @Nonnull Object contextMessage,
+                @Nonnull Object primaryFoundMessage) {
+            this.shardName = Preconditions.checkNotNull(shardName);
+            this.contextMessage = Preconditions.checkNotNull(contextMessage);
+            Preconditions.checkNotNull(primaryFoundMessage);
+            this.remotePrimaryShardFound = (primaryFoundMessage instanceof RemotePrimaryShardFound) ?
+                    (RemotePrimaryShardFound) primaryFoundMessage : null;
+            this.localPrimaryShardFound = (primaryFoundMessage instanceof LocalPrimaryShardFound) ?
+                    (LocalPrimaryShardFound) primaryFoundMessage : null;
+        }
+
+        @Nonnull
+        String getPrimaryPath(){
+            if(remotePrimaryShardFound != null) {
+                return remotePrimaryShardFound.getPrimaryPath();
+            }
+            return localPrimaryShardFound.getPrimaryPath(); // NOTE(review): NPE if the message was neither found-type
+        }
+
+        @Nonnull
+        Object getContextMessage() {
+            return contextMessage;
+        }
+
+        @Nullable
+        RemotePrimaryShardFound getRemotePrimaryShardFound() {
+            return remotePrimaryShardFound;
+        }
+
+        @Nonnull
+        String getShardName() {
+            return shardName;
+        }
+    }
+
+    /**
+     * The WrappedShardResponse class pairs a response object from a Shard with the name of that shard.
+     */
+    private static class WrappedShardResponse {
+        private final String shardName;
+        private final Object response;
+
+        private WrappedShardResponse(String shardName, Object response) {
+            this.shardName = shardName;
+            this.response = response;
+        }
+
+        String getShardName() {
+            return shardName;
+        }
+
+        Object getResponse() {
+            return response;
+        }
+    }
 }