Bug 4823: Notify findPrimary callbacks on ReachableMember event
[controller.git] / opendaylight / md-sal / sal-distributed-datastore / src / main / java / org / opendaylight / controller / cluster / datastore / ShardManager.java
index b5e1ca4f1a5b3a3793de1e63646f88d593bb627f..a91109c64b3fa5aac674677b95d2a31c83477fbd 100644 (file)
@@ -16,6 +16,7 @@ import akka.actor.Cancellable;
 import akka.actor.OneForOneStrategy;
 import akka.actor.PoisonPill;
 import akka.actor.Props;
+import akka.actor.Status;
 import akka.actor.SupervisorStrategy;
 import akka.cluster.ClusterEvent;
 import akka.dispatch.OnComplete;
@@ -24,6 +25,7 @@ import akka.persistence.RecoveryCompleted;
 import akka.persistence.SaveSnapshotFailure;
 import akka.persistence.SaveSnapshotSuccess;
 import akka.persistence.SnapshotOffer;
+import akka.persistence.SnapshotSelectionCriteria;
 import akka.serialization.Serialization;
 import akka.util.Timeout;
 import com.google.common.annotations.VisibleForTesting;
@@ -33,6 +35,8 @@ import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
 import com.google.common.base.Supplier;
 import com.google.common.collect.Sets;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -48,6 +52,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
+import org.apache.commons.lang3.SerializationUtils;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedPersistentActorWithMetering;
 import org.opendaylight.controller.cluster.datastore.config.Configuration;
 import org.opendaylight.controller.cluster.datastore.config.ModuleShardConfiguration;
@@ -86,6 +91,8 @@ import org.opendaylight.controller.cluster.raft.base.messages.SwitchBehavior;
 import org.opendaylight.controller.cluster.raft.client.messages.GetSnapshot;
 import org.opendaylight.controller.cluster.raft.messages.AddServer;
 import org.opendaylight.controller.cluster.raft.messages.AddServerReply;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServer;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServerReply;
 import org.opendaylight.controller.cluster.raft.messages.ServerChangeStatus;
 import org.opendaylight.controller.cluster.raft.messages.ServerRemoved;
 import org.opendaylight.controller.cluster.raft.policy.DisableElectionsRaftPolicy;
@@ -139,18 +146,20 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     private DatastoreSnapshot restoreFromSnapshot;
 
+    private ShardManagerSnapshot currentSnapshot;
+
     private final Set<String> shardReplicaOperationsInProgress = new HashSet<>();
 
     private final String persistenceId;
 
     /**
      */
-    protected ShardManager(Builder builder) {
+    protected ShardManager(AbstractBuilder<?> builder) {
 
         this.cluster = builder.cluster;
         this.configuration = builder.configuration;
         this.datastoreContextFactory = builder.datastoreContextFactory;
-        this.type = builder.datastoreContextFactory.getBaseDatastoreContext().getDataStoreType();
+        this.type = builder.datastoreContextFactory.getBaseDatastoreContext().getDataStoreName();
         this.shardDispatcherPath =
                 new Dispatchers(context().system().dispatchers()).getDispatcherPath(Dispatchers.DispatcherType.Shard);
         this.waitTillReadyCountdownLatch = builder.waitTillReadyCountdownLatch;
@@ -226,14 +235,16 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         } else if(message instanceof PrimaryShardFoundForContext) {
             PrimaryShardFoundForContext primaryShardFoundContext = (PrimaryShardFoundForContext)message;
             onPrimaryShardFoundContext(primaryShardFoundContext);
-        } else if(message instanceof RemoveShardReplica){
-            onRemoveShardReplica((RemoveShardReplica)message);
+        } else if(message instanceof RemoveShardReplica) {
+            onRemoveShardReplica((RemoveShardReplica) message);
+        } else if(message instanceof WrappedShardResponse){
+            onWrappedShardResponse((WrappedShardResponse) message);
         } else if(message instanceof GetSnapshot) {
             onGetSnapshot();
         } else if(message instanceof ServerRemoved){
             onShardReplicaRemoved((ServerRemoved) message);
         } else if (message instanceof SaveSnapshotSuccess) {
-            LOG.debug("{} saved ShardManager snapshot successfully", persistenceId());
+            onSaveSnapshotSuccess((SaveSnapshotSuccess)message);
         } else if (message instanceof SaveSnapshotFailure) {
             LOG.error("{}: SaveSnapshotFailure received for saving snapshot of shards",
                     persistenceId(), ((SaveSnapshotFailure) message).cause());
@@ -242,12 +253,67 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    private void onWrappedShardResponse(WrappedShardResponse message) {
+        if (message.getResponse() instanceof RemoveServerReply) {
+            onRemoveServerReply(getSender(), message.getShardName(), (RemoveServerReply) message.getResponse());
+        }
+    }
+
+    private void onRemoveServerReply(ActorRef originalSender, String shardName, RemoveServerReply response) {
+        shardReplicaOperationsInProgress.remove(shardName);
+        originalSender.tell(new Status.Success(null), self());
+    }
+
     private void onPrimaryShardFoundContext(PrimaryShardFoundForContext primaryShardFoundContext) {
         if(primaryShardFoundContext.getContextMessage() instanceof AddShardReplica) {
-            addShard(primaryShardFoundContext.shardName, primaryShardFoundContext.getRemotePrimaryShardFound(), getSender());
+            addShard(primaryShardFoundContext.getShardName(), primaryShardFoundContext.getRemotePrimaryShardFound(),
+                    getSender());
+        } else if(primaryShardFoundContext.getContextMessage() instanceof RemoveShardReplica){
+            removeShardReplica((RemoveShardReplica) primaryShardFoundContext.getContextMessage(),
+                    primaryShardFoundContext.getShardName(), primaryShardFoundContext.getPrimaryPath(), getSender());
         }
     }
 
+    private void removeShardReplica(RemoveShardReplica contextMessage, final String shardName, final String primaryPath,
+            final ActorRef sender) {
+        if(isShardReplicaOperationInProgress(shardName, sender)) {
+            return;
+        }
+
+        shardReplicaOperationsInProgress.add(shardName);
+
+        final ShardIdentifier shardId = getShardIdentifier(contextMessage.getMemberName(), shardName);
+
+        final DatastoreContext datastoreContext = newShardDatastoreContextBuilder(shardName).build();
+
+        //inform ShardLeader to remove this shard as a replica by sending an RemoveServer message
+        LOG.debug ("{}: Sending RemoveServer message to peer {} for shard {}", persistenceId(),
+                primaryPath, shardId);
+
+        Timeout removeServerTimeout = new Timeout(datastoreContext.getShardLeaderElectionTimeout().
+                duration());
+        Future<Object> futureObj = ask(getContext().actorSelection(primaryPath),
+                new RemoveServer(shardId.toString()), removeServerTimeout);
+
+        futureObj.onComplete(new OnComplete<Object>() {
+            @Override
+            public void onComplete(Throwable failure, Object response) {
+                if (failure != null) {
+                    String msg = String.format("RemoveServer request to leader %s for shard %s failed",
+                            primaryPath, shardName);
+
+                    LOG.debug ("{}: {}", persistenceId(), msg, failure);
+
+                    // FAILURE
+                    sender.tell(new Status.Failure(new RuntimeException(msg, failure)), self());
+                } else {
+                    // SUCCESS
+                    self().tell(new WrappedShardResponse(shardName, response), sender);
+                }
+            }
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
+
     private void onShardReplicaRemoved(ServerRemoved message) {
         final ShardIdentifier shardId = new ShardIdentifier.Builder().fromShardIdString(message.getServerId()).build();
         final ShardInformation shardInformation = localShards.remove(shardId.getShardName());
@@ -283,6 +349,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
 
         byte[] shardManagerSnapshot = null;
+        if(currentSnapshot != null) {
+            shardManagerSnapshot = SerializationUtils.serialize(currentSnapshot);
+        }
+
         ActorRef replyActor = getContext().actorOf(ShardManagerGetSnapshotReplyActor.props(
                 new ArrayList<>(localShards.keySet()), type, shardManagerSnapshot , getSender(), persistenceId(),
                 datastoreContextFactory.getBaseDatastoreContext().getShardInitializationTimeout().duration()));
@@ -293,17 +363,20 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
     }
 
     private void onCreateShard(CreateShard createShard) {
+        LOG.debug("{}: onCreateShard: {}", persistenceId(), createShard);
+
         Object reply;
         try {
             String shardName = createShard.getModuleShardConfig().getShardName();
             if(localShards.containsKey(shardName)) {
+                LOG.debug("{}: Shard {} already exists", persistenceId(), shardName);
                 reply = new akka.actor.Status.Success(String.format("Shard with name %s already exists", shardName));
             } else {
                 doCreateShard(createShard);
                 reply = new akka.actor.Status.Success(null);
             }
         } catch (Exception e) {
-            LOG.error("onCreateShard failed", e);
+            LOG.error("{}: onCreateShard failed", persistenceId(), e);
             reply = new akka.actor.Status.Failure(e);
         }
 
@@ -328,13 +401,18 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
 
+        boolean shardWasInRecoveredSnapshot = currentSnapshot != null &&
+                currentSnapshot.getShardList().contains(shardName);
+
         Map<String, String> peerAddresses;
         boolean isActiveMember;
-        if(configuration.getMembersFromShardName(shardName).contains(cluster.getCurrentMemberName())) {
+        if(shardWasInRecoveredSnapshot || configuration.getMembersFromShardName(shardName).
+                contains(cluster.getCurrentMemberName())) {
             peerAddresses = getPeerAddresses(shardName);
             isActiveMember = true;
         } else {
-            // The local member is not in the given shard member configuration. In this case we'll create
+            // The local member is not in the static shard member configuration and the shard did not
+            // previously exist (ie !shardWasInRecoveredSnapshot). In this case we'll create
             // the shard with no peers and with elections disabled so it stays as follower. A
             // subsequent AddServer request will be needed to make it an active member.
             isActiveMember = false;
@@ -343,8 +421,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
                     customRaftPolicyImplementation(DisableElectionsRaftPolicy.class.getName()).build();
         }
 
-        LOG.debug("onCreateShard: shardId: {}, memberNames: {}. peerAddresses: {}", shardId,
-                moduleShardConfig.getShardMemberNames(), peerAddresses);
+        LOG.debug("{} doCreateShard: shardId: {}, memberNames: {}, peerAddresses: {}, isActiveMember: {}",
+                persistenceId(), shardId, moduleShardConfig.getShardMemberNames(), peerAddresses,
+                isActiveMember);
 
         ShardInformation info = new ShardInformation(shardName, shardId, peerAddresses,
                 shardDatastoreContext, createShard.getShardBuilder(), peerAddressResolver);
@@ -499,15 +578,34 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
     @Override
     protected void handleRecover(Object message) throws Exception {
         if (message instanceof RecoveryCompleted) {
-            LOG.info("Recovery complete : {}", persistenceId());
-
-            // We no longer persist SchemaContext modules so delete all the prior messages from the akka
-            // journal on upgrade from Helium.
-            deleteMessages(lastSequenceNr());
-            createLocalShards();
+            onRecoveryCompleted();
         } else if (message instanceof SnapshotOffer) {
-            handleShardRecovery((SnapshotOffer) message);
+            applyShardManagerSnapshot((ShardManagerSnapshot)((SnapshotOffer) message).snapshot());
+        }
+    }
+
+    private void onRecoveryCompleted() {
+        LOG.info("Recovery complete : {}", persistenceId());
+
+        // We no longer persist SchemaContext modules so delete all the prior messages from the akka
+        // journal on upgrade from Helium.
+        deleteMessages(lastSequenceNr());
+
+        if(currentSnapshot == null && restoreFromSnapshot != null &&
+                restoreFromSnapshot.getShardManagerSnapshot() != null) {
+            try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(
+                    restoreFromSnapshot.getShardManagerSnapshot()))) {
+                ShardManagerSnapshot snapshot = (ShardManagerSnapshot) ois.readObject();
+
+                LOG.debug("{}: Deserialized restored ShardManagerSnapshot: {}", persistenceId(), snapshot);
+
+                applyShardManagerSnapshot(snapshot);
+            } catch(Exception e) {
+                LOG.error("{}: Error deserializing restored ShardManagerSnapshot", persistenceId(), e);
+            }
         }
+
+        createLocalShards();
     }
 
     private void findLocalShard(FindLocalShard message) {
@@ -759,12 +857,25 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return;
         }
 
+        Collection<String> visitedAddresses;
+        if(message instanceof RemoteFindPrimary) {
+            visitedAddresses = ((RemoteFindPrimary)message).getVisitedAddresses();
+        } else {
+            visitedAddresses = new ArrayList<>();
+        }
+
+        visitedAddresses.add(peerAddressResolver.getShardManagerActorPathBuilder(cluster.getSelfAddress()).toString());
+
         for(String address: peerAddressResolver.getShardManagerPeerActorAddresses()) {
+            if(visitedAddresses.contains(address)) {
+                continue;
+            }
+
             LOG.debug("{}: findPrimary for {} forwarding to remote ShardManager {}", persistenceId(),
                     shardName, address);
 
             getContext().actorSelection(address).forward(new RemoteFindPrimary(shardName,
-                    message.isWaitUntilReady()), getContext());
+                    message.isWaitUntilReady(), visitedAddresses), getContext());
             return;
         }
 
@@ -807,6 +918,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         for(String shardName : memberShardNames){
             ShardIdentifier shardId = getShardIdentifier(memberName, shardName);
+
+            LOG.debug("{}: Creating local shard: {}", persistenceId(), shardId);
+
             Map<String, String> peerAddresses = getPeerAddresses(shardName);
             localShards.put(shardName, new ShardInformation(shardName, shardId, peerAddresses,
                     newShardDatastoreContext(shardName), Shard.builder().restoreFromSnapshot(
@@ -898,7 +1012,6 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             @Override
             public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
                 getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
-
             }
 
             @Override
@@ -1009,40 +1122,52 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             LOG.warn ("{}: Leader failed to add shard replica {} with status {}",
                     persistenceId(), shardName, replyMsg.getStatus());
 
-            Exception failure;
-            switch (replyMsg.getStatus()) {
-                case TIMEOUT:
-                    failure = new TimeoutException(String.format(
-                            "The shard leader %s timed out trying to replicate the initial data to the new shard %s." +
-                            "Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
-                            leaderPath, shardName));
-                    break;
-                case NO_LEADER:
-                    failure = createNoShardLeaderException(shardInfo.getShardId());
-                    break;
-                default :
-                    failure = new RuntimeException(String.format(
-                            "AddServer request to leader %s for shard %s failed with status %s",
-                            leaderPath, shardName, replyMsg.getStatus()));
-            }
+            Exception failure = getServerChangeException(AddServer.class, replyMsg.getStatus(), leaderPath, shardInfo.getShardId());
 
             onAddServerFailure(shardName, null, failure, sender, removeShardOnFailure);
         }
     }
 
-    private void onRemoveShardReplica (RemoveShardReplica shardReplicaMsg) {
-        String shardName = shardReplicaMsg.getShardName();
-
-        // verify the local shard replica is available in the controller node
-        if (!localShards.containsKey(shardName)) {
-            String msg = String.format("Local shard %s does not", shardName);
-            LOG.debug ("{}: {}", persistenceId(), msg);
-            getSender().tell(new akka.actor.Status.Failure(new IllegalArgumentException(msg)), getSelf());
-            return;
+    private Exception getServerChangeException(Class<?> serverChange, ServerChangeStatus serverChangeStatus,
+                                               String leaderPath, ShardIdentifier shardId) {
+        Exception failure;
+        switch (serverChangeStatus) {
+            case TIMEOUT:
+                failure = new TimeoutException(String.format(
+                        "The shard leader %s timed out trying to replicate the initial data to the new shard %s." +
+                        "Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
+                        leaderPath, shardId.getShardName()));
+                break;
+            case NO_LEADER:
+                failure = createNoShardLeaderException(shardId);
+                break;
+            case NOT_SUPPORTED:
+                failure = new UnsupportedOperationException(String.format("%s request is not supported for shard %s",
+                        serverChange.getSimpleName(), shardId.getShardName()));
+                break;
+            default :
+                failure = new RuntimeException(String.format(
+                        "%s request to leader %s for shard %s failed with status %s",
+                        serverChange.getSimpleName(), leaderPath, shardId.getShardName(), serverChangeStatus));
         }
-        // call RemoveShard for the shardName
-        getSender().tell(new akka.actor.Status.Success(true), getSelf());
-        return;
+        return failure;
+    }
+
+    private void onRemoveShardReplica (final RemoveShardReplica shardReplicaMsg) {
+        LOG.debug("{}: onRemoveShardReplica: {}", persistenceId(), shardReplicaMsg);
+
+        findPrimary(shardReplicaMsg.getShardName(), new AutoFindPrimaryFailureResponseHandler(getSender(),
+                shardReplicaMsg.getShardName(), persistenceId(), getSelf()) {
+            @Override
+            public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+
+            @Override
+            public void onLocalPrimaryFound(LocalPrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+        });
     }
 
     private void persistShardList() {
@@ -1053,16 +1178,23 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             }
         }
         LOG.debug ("{}: persisting the shard list {}", persistenceId(), shardList);
-        saveSnapshot(new ShardManagerSnapshot(shardList));
+        saveSnapshot(updateShardManagerSnapshot(shardList));
+    }
+
+    private ShardManagerSnapshot updateShardManagerSnapshot(List<String> shardList) {
+        currentSnapshot = new ShardManagerSnapshot(shardList);
+        return currentSnapshot;
     }
 
-    private void handleShardRecovery(SnapshotOffer offer) {
-        LOG.debug ("{}: in handleShardRecovery", persistenceId());
-        ShardManagerSnapshot snapshot = (ShardManagerSnapshot)offer.snapshot();
+    private void applyShardManagerSnapshot(ShardManagerSnapshot snapshot) {
+        currentSnapshot = snapshot;
+
+        LOG.debug ("{}: onSnapshotOffer: {}", persistenceId(), currentSnapshot);
+
         String currentMember = cluster.getCurrentMemberName();
         Set<String> configuredShardList =
             new HashSet<>(configuration.getMemberShardNames(currentMember));
-        for (String shard : snapshot.getShardList()) {
+        for (String shard : currentSnapshot.getShardList()) {
             if (!configuredShardList.contains(shard)) {
                 // add the current member as a replica for the shard
                 LOG.debug ("{}: adding shard {}", persistenceId(), shard);
@@ -1078,6 +1210,12 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    private void onSaveSnapshotSuccess (SaveSnapshotSuccess successMessage) {
+        LOG.debug ("{} saved ShardManager snapshot successfully. Deleting the prev snapshot if available",
+            persistenceId());
+        deleteSnapshots(new SnapshotSelectionCriteria(scala.Long.MaxValue(), (successMessage.metadata().timestamp() - 1)));
+    }
+
     private static class ForwardedAddServerReply {
         ShardInformation shardInfo;
         AddServerReply addServerReply;
@@ -1322,6 +1460,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         void setLeaderAvailable(boolean leaderAvailable) {
             this.leaderAvailable = leaderAvailable;
+
+            if(leaderAvailable) {
+                notifyOnShardInitializedCallbacks();
+            }
         }
 
         short getLeaderVersion() {
@@ -1415,53 +1557,59 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         return new Builder();
     }
 
-    public static class Builder {
+    public static abstract class AbstractBuilder<T extends AbstractBuilder<T>> {
         private ClusterWrapper cluster;
         private Configuration configuration;
         private DatastoreContextFactory datastoreContextFactory;
         private CountDownLatch waitTillReadyCountdownLatch;
         private PrimaryShardInfoFutureCache primaryShardInfoCache;
         private DatastoreSnapshot restoreFromSnapshot;
+
         private volatile boolean sealed;
 
+        @SuppressWarnings("unchecked")
+        private T self() {
+            return (T) this;
+        }
+
         protected void checkSealed() {
             Preconditions.checkState(!sealed, "Builder is already sealed - further modifications are not allowed");
         }
 
-        public Builder cluster(ClusterWrapper cluster) {
+        public T cluster(ClusterWrapper cluster) {
             checkSealed();
             this.cluster = cluster;
-            return this;
+            return self();
         }
 
-        public Builder configuration(Configuration configuration) {
+        public T configuration(Configuration configuration) {
             checkSealed();
             this.configuration = configuration;
-            return this;
+            return self();
         }
 
-        public Builder datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
+        public T datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
             checkSealed();
             this.datastoreContextFactory = datastoreContextFactory;
-            return this;
+            return self();
         }
 
-        public Builder waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
+        public T waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
             checkSealed();
             this.waitTillReadyCountdownLatch = waitTillReadyCountdownLatch;
-            return this;
+            return self();
         }
 
-        public Builder primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
+        public T primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
             checkSealed();
             this.primaryShardInfoCache = primaryShardInfoCache;
-            return this;
+            return self();
         }
 
-        public Builder restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
+        public T restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
             checkSealed();
             this.restoreFromSnapshot = restoreFromSnapshot;
-            return this;
+            return self();
         }
 
         protected void verify() {
@@ -1479,6 +1627,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    public static class Builder extends AbstractBuilder<Builder> {
+    }
+
     private void findPrimary(final String shardName, final FindPrimaryResponseHandler handler) {
         Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
                 getShardInitializationTimeout().duration().$times(2));
@@ -1567,10 +1718,6 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return shardName;
         }
 
-        public ActorRef getShardManagerActor() {
-            return shardManagerActor;
-        }
-
         @Override
         public void onFailure(Throwable failure) {
             LOG.debug ("{}: Received failure from FindPrimary for shard {}", persistenceId, shardName, failure);
@@ -1600,45 +1747,60 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         private final RemotePrimaryShardFound remotePrimaryShardFound;
         private final LocalPrimaryShardFound localPrimaryShardFound;
 
-        public PrimaryShardFoundForContext(@Nonnull String shardName, @Nonnull Object contextMessage, @Nonnull Object primaryFoundMessage) {
+        public PrimaryShardFoundForContext(@Nonnull String shardName, @Nonnull Object contextMessage,
+                @Nonnull Object primaryFoundMessage) {
             this.shardName = Preconditions.checkNotNull(shardName);
             this.contextMessage = Preconditions.checkNotNull(contextMessage);
             Preconditions.checkNotNull(primaryFoundMessage);
-            this.remotePrimaryShardFound = (primaryFoundMessage instanceof RemotePrimaryShardFound) ? (RemotePrimaryShardFound) primaryFoundMessage : null;
-            this.localPrimaryShardFound = (primaryFoundMessage instanceof LocalPrimaryShardFound) ? (LocalPrimaryShardFound) primaryFoundMessage : null;
+            this.remotePrimaryShardFound = (primaryFoundMessage instanceof RemotePrimaryShardFound) ?
+                    (RemotePrimaryShardFound) primaryFoundMessage : null;
+            this.localPrimaryShardFound = (primaryFoundMessage instanceof LocalPrimaryShardFound) ?
+                    (LocalPrimaryShardFound) primaryFoundMessage : null;
         }
 
         @Nonnull
-        public String getPrimaryPath(){
-            if(remotePrimaryShardFound != null){
+        String getPrimaryPath(){
+            if(remotePrimaryShardFound != null) {
                 return remotePrimaryShardFound.getPrimaryPath();
             }
             return localPrimaryShardFound.getPrimaryPath();
         }
 
         @Nonnull
-        public Object getContextMessage() {
+        Object getContextMessage() {
             return contextMessage;
         }
 
         @Nullable
-        public RemotePrimaryShardFound getRemotePrimaryShardFound(){
+        RemotePrimaryShardFound getRemotePrimaryShardFound() {
             return remotePrimaryShardFound;
         }
 
-        @Nullable
-        public LocalPrimaryShardFound getLocalPrimaryShardFound(){
-            return localPrimaryShardFound;
+        @Nonnull
+        String getShardName() {
+            return shardName;
         }
+    }
+
+    /**
+     * The WrappedShardResponse class wraps a response from a Shard.
+     */
+    private static class WrappedShardResponse {
+        private final String shardName;
+        private final Object response;
 
-        boolean isPrimaryLocal(){
-            return (remotePrimaryShardFound == null);
+        private WrappedShardResponse(String shardName, Object response) {
+            this.shardName = shardName;
+            this.response = response;
         }
 
-        @Nonnull
-        public String getShardName() {
+        String getShardName() {
             return shardName;
         }
+
+        Object getResponse() {
+            return response;
+        }
     }
 }