Bug 4823: Notify findPrimary callbacks on ReachableMember event
[controller.git] / opendaylight / md-sal / sal-distributed-datastore / src / main / java / org / opendaylight / controller / cluster / datastore / ShardManager.java
index 616f56c466bbac02194460dca1d07b4a57b2569f..a91109c64b3fa5aac674677b95d2a31c83477fbd 100644 (file)
@@ -16,6 +16,7 @@ import akka.actor.Cancellable;
 import akka.actor.OneForOneStrategy;
 import akka.actor.PoisonPill;
 import akka.actor.Props;
+import akka.actor.Status;
 import akka.actor.SupervisorStrategy;
 import akka.cluster.ClusterEvent;
 import akka.dispatch.OnComplete;
@@ -24,6 +25,7 @@ import akka.persistence.RecoveryCompleted;
 import akka.persistence.SaveSnapshotFailure;
 import akka.persistence.SaveSnapshotSuccess;
 import akka.persistence.SnapshotOffer;
+import akka.persistence.SnapshotSelectionCriteria;
 import akka.serialization.Serialization;
 import akka.util.Timeout;
 import com.google.common.annotations.VisibleForTesting;
@@ -33,6 +35,8 @@ import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
 import com.google.common.base.Supplier;
 import com.google.common.collect.Sets;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -46,7 +50,9 @@ import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
+import org.apache.commons.lang3.SerializationUtils;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedPersistentActorWithMetering;
 import org.opendaylight.controller.cluster.datastore.config.Configuration;
 import org.opendaylight.controller.cluster.datastore.config.ModuleShardConfiguration;
@@ -85,6 +91,8 @@ import org.opendaylight.controller.cluster.raft.base.messages.SwitchBehavior;
 import org.opendaylight.controller.cluster.raft.client.messages.GetSnapshot;
 import org.opendaylight.controller.cluster.raft.messages.AddServer;
 import org.opendaylight.controller.cluster.raft.messages.AddServerReply;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServer;
+import org.opendaylight.controller.cluster.raft.messages.RemoveServerReply;
 import org.opendaylight.controller.cluster.raft.messages.ServerChangeStatus;
 import org.opendaylight.controller.cluster.raft.messages.ServerRemoved;
 import org.opendaylight.controller.cluster.raft.policy.DisableElectionsRaftPolicy;
@@ -138,18 +146,20 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     private DatastoreSnapshot restoreFromSnapshot;
 
+    private ShardManagerSnapshot currentSnapshot;
+
     private final Set<String> shardReplicaOperationsInProgress = new HashSet<>();
 
     private final String persistenceId;
 
     /**
      */
-    protected ShardManager(Builder builder) {
+    protected ShardManager(AbstractBuilder<?> builder) {
 
         this.cluster = builder.cluster;
         this.configuration = builder.configuration;
         this.datastoreContextFactory = builder.datastoreContextFactory;
-        this.type = builder.datastoreContextFactory.getBaseDatastoreContext().getDataStoreType();
+        this.type = builder.datastoreContextFactory.getBaseDatastoreContext().getDataStoreName();
         this.shardDispatcherPath =
                 new Dispatchers(context().system().dispatchers()).getDispatcherPath(Dispatchers.DispatcherType.Shard);
         this.waitTillReadyCountdownLatch = builder.waitTillReadyCountdownLatch;
@@ -222,25 +232,88 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         } else if(message instanceof ForwardedAddServerFailure) {
             ForwardedAddServerFailure msg = (ForwardedAddServerFailure)message;
             onAddServerFailure(msg.shardName, msg.failureMessage, msg.failure, getSender(), msg.removeShardOnFailure);
-        } else if(message instanceof ForwardedAddServerPrimaryShardFound) {
-            ForwardedAddServerPrimaryShardFound msg = (ForwardedAddServerPrimaryShardFound)message;
-            addShard(msg.shardName, msg.primaryFound, getSender());
-        } else if(message instanceof RemoveShardReplica){
-            onRemoveShardReplica((RemoveShardReplica)message);
+        } else if(message instanceof PrimaryShardFoundForContext) {
+            PrimaryShardFoundForContext primaryShardFoundContext = (PrimaryShardFoundForContext)message;
+            onPrimaryShardFoundContext(primaryShardFoundContext);
+        } else if(message instanceof RemoveShardReplica) {
+            onRemoveShardReplica((RemoveShardReplica) message);
+        } else if(message instanceof WrappedShardResponse){
+            onWrappedShardResponse((WrappedShardResponse) message);
         } else if(message instanceof GetSnapshot) {
             onGetSnapshot();
         } else if(message instanceof ServerRemoved){
             onShardReplicaRemoved((ServerRemoved) message);
         } else if (message instanceof SaveSnapshotSuccess) {
-            LOG.debug("{} saved ShardManager snapshot successfully", persistenceId());
+            onSaveSnapshotSuccess((SaveSnapshotSuccess)message);
         } else if (message instanceof SaveSnapshotFailure) {
-            LOG.error ("{}: SaveSnapshotFailure received for saving snapshot of shards",
-                persistenceId(), ((SaveSnapshotFailure)message).cause());
+            LOG.error("{}: SaveSnapshotFailure received for saving snapshot of shards",
+                    persistenceId(), ((SaveSnapshotFailure) message).cause());
         } else {
             unknownMessage(message);
         }
     }
 
+    private void onWrappedShardResponse(WrappedShardResponse message) {
+        if (message.getResponse() instanceof RemoveServerReply) {
+            onRemoveServerReply(getSender(), message.getShardName(), (RemoveServerReply) message.getResponse());
+        }
+    }
+
+    private void onRemoveServerReply(ActorRef originalSender, String shardName, RemoveServerReply response) {
+        shardReplicaOperationsInProgress.remove(shardName);
+        originalSender.tell(new Status.Success(null), self());
+    }
+
+    private void onPrimaryShardFoundContext(PrimaryShardFoundForContext primaryShardFoundContext) {
+        if(primaryShardFoundContext.getContextMessage() instanceof AddShardReplica) {
+            addShard(primaryShardFoundContext.getShardName(), primaryShardFoundContext.getRemotePrimaryShardFound(),
+                    getSender());
+        } else if(primaryShardFoundContext.getContextMessage() instanceof RemoveShardReplica){
+            removeShardReplica((RemoveShardReplica) primaryShardFoundContext.getContextMessage(),
+                    primaryShardFoundContext.getShardName(), primaryShardFoundContext.getPrimaryPath(), getSender());
+        }
+    }
+
+    private void removeShardReplica(RemoveShardReplica contextMessage, final String shardName, final String primaryPath,
+            final ActorRef sender) {
+        if(isShardReplicaOperationInProgress(shardName, sender)) {
+            return;
+        }
+
+        shardReplicaOperationsInProgress.add(shardName);
+
+        final ShardIdentifier shardId = getShardIdentifier(contextMessage.getMemberName(), shardName);
+
+        final DatastoreContext datastoreContext = newShardDatastoreContextBuilder(shardName).build();
+
+        //inform ShardLeader to remove this shard as a replica by sending an RemoveServer message
+        LOG.debug ("{}: Sending RemoveServer message to peer {} for shard {}", persistenceId(),
+                primaryPath, shardId);
+
+        Timeout removeServerTimeout = new Timeout(datastoreContext.getShardLeaderElectionTimeout().
+                duration());
+        Future<Object> futureObj = ask(getContext().actorSelection(primaryPath),
+                new RemoveServer(shardId.toString()), removeServerTimeout);
+
+        futureObj.onComplete(new OnComplete<Object>() {
+            @Override
+            public void onComplete(Throwable failure, Object response) {
+                if (failure != null) {
+                    String msg = String.format("RemoveServer request to leader %s for shard %s failed",
+                            primaryPath, shardName);
+
+                    LOG.debug ("{}: {}", persistenceId(), msg, failure);
+
+                    // FAILURE
+                    sender.tell(new Status.Failure(new RuntimeException(msg, failure)), self());
+                } else {
+                    // SUCCESS
+                    self().tell(new WrappedShardResponse(shardName, response), sender);
+                }
+            }
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
+
     private void onShardReplicaRemoved(ServerRemoved message) {
         final ShardIdentifier shardId = new ShardIdentifier.Builder().fromShardIdString(message.getServerId()).build();
         final ShardInformation shardInformation = localShards.remove(shardId.getShardName());
@@ -276,6 +349,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
 
         byte[] shardManagerSnapshot = null;
+        if(currentSnapshot != null) {
+            shardManagerSnapshot = SerializationUtils.serialize(currentSnapshot);
+        }
+
         ActorRef replyActor = getContext().actorOf(ShardManagerGetSnapshotReplyActor.props(
                 new ArrayList<>(localShards.keySet()), type, shardManagerSnapshot , getSender(), persistenceId(),
                 datastoreContextFactory.getBaseDatastoreContext().getShardInitializationTimeout().duration()));
@@ -286,17 +363,20 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
     }
 
     private void onCreateShard(CreateShard createShard) {
+        LOG.debug("{}: onCreateShard: {}", persistenceId(), createShard);
+
         Object reply;
         try {
             String shardName = createShard.getModuleShardConfig().getShardName();
             if(localShards.containsKey(shardName)) {
+                LOG.debug("{}: Shard {} already exists", persistenceId(), shardName);
                 reply = new akka.actor.Status.Success(String.format("Shard with name %s already exists", shardName));
             } else {
                 doCreateShard(createShard);
                 reply = new akka.actor.Status.Success(null);
             }
         } catch (Exception e) {
-            LOG.error("onCreateShard failed", e);
+            LOG.error("{}: onCreateShard failed", persistenceId(), e);
             reply = new akka.actor.Status.Failure(e);
         }
 
@@ -321,13 +401,18 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         ShardIdentifier shardId = getShardIdentifier(cluster.getCurrentMemberName(), shardName);
 
+        boolean shardWasInRecoveredSnapshot = currentSnapshot != null &&
+                currentSnapshot.getShardList().contains(shardName);
+
         Map<String, String> peerAddresses;
         boolean isActiveMember;
-        if(configuration.getMembersFromShardName(shardName).contains(cluster.getCurrentMemberName())) {
+        if(shardWasInRecoveredSnapshot || configuration.getMembersFromShardName(shardName).
+                contains(cluster.getCurrentMemberName())) {
             peerAddresses = getPeerAddresses(shardName);
             isActiveMember = true;
         } else {
-            // The local member is not in the given shard member configuration. In this case we'll create
+            // The local member is not in the static shard member configuration and the shard did not
+            // previously exist (ie !shardWasInRecoveredSnapshot). In this case we'll create
             // the shard with no peers and with elections disabled so it stays as follower. A
             // subsequent AddServer request will be needed to make it an active member.
             isActiveMember = false;
@@ -336,8 +421,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
                     customRaftPolicyImplementation(DisableElectionsRaftPolicy.class.getName()).build();
         }
 
-        LOG.debug("onCreateShard: shardId: {}, memberNames: {}. peerAddresses: {}", shardId,
-                moduleShardConfig.getShardMemberNames(), peerAddresses);
+        LOG.debug("{} doCreateShard: shardId: {}, memberNames: {}, peerAddresses: {}, isActiveMember: {}",
+                persistenceId(), shardId, moduleShardConfig.getShardMemberNames(), peerAddresses,
+                isActiveMember);
 
         ShardInformation info = new ShardInformation(shardName, shardId, peerAddresses,
                 shardDatastoreContext, createShard.getShardBuilder(), peerAddressResolver);
@@ -492,15 +578,34 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
     @Override
     protected void handleRecover(Object message) throws Exception {
         if (message instanceof RecoveryCompleted) {
-            LOG.info("Recovery complete : {}", persistenceId());
-
-            // We no longer persist SchemaContext modules so delete all the prior messages from the akka
-            // journal on upgrade from Helium.
-            deleteMessages(lastSequenceNr());
-            createLocalShards();
+            onRecoveryCompleted();
         } else if (message instanceof SnapshotOffer) {
-            handleShardRecovery((SnapshotOffer) message);
+            applyShardManagerSnapshot((ShardManagerSnapshot)((SnapshotOffer) message).snapshot());
+        }
+    }
+
+    private void onRecoveryCompleted() {
+        LOG.info("Recovery complete : {}", persistenceId());
+
+        // We no longer persist SchemaContext modules so delete all the prior messages from the akka
+        // journal on upgrade from Helium.
+        deleteMessages(lastSequenceNr());
+
+        if(currentSnapshot == null && restoreFromSnapshot != null &&
+                restoreFromSnapshot.getShardManagerSnapshot() != null) {
+            try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(
+                    restoreFromSnapshot.getShardManagerSnapshot()))) {
+                ShardManagerSnapshot snapshot = (ShardManagerSnapshot) ois.readObject();
+
+                LOG.debug("{}: Deserialized restored ShardManagerSnapshot: {}", persistenceId(), snapshot);
+
+                applyShardManagerSnapshot(snapshot);
+            } catch(Exception e) {
+                LOG.error("{}: Error deserializing restored ShardManagerSnapshot", persistenceId(), e);
+            }
         }
+
+        createLocalShards();
     }
 
     private void findLocalShard(FindLocalShard message) {
@@ -752,12 +857,25 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return;
         }
 
+        Collection<String> visitedAddresses;
+        if(message instanceof RemoteFindPrimary) {
+            visitedAddresses = ((RemoteFindPrimary)message).getVisitedAddresses();
+        } else {
+            visitedAddresses = new ArrayList<>();
+        }
+
+        visitedAddresses.add(peerAddressResolver.getShardManagerActorPathBuilder(cluster.getSelfAddress()).toString());
+
         for(String address: peerAddressResolver.getShardManagerPeerActorAddresses()) {
+            if(visitedAddresses.contains(address)) {
+                continue;
+            }
+
             LOG.debug("{}: findPrimary for {} forwarding to remote ShardManager {}", persistenceId(),
                     shardName, address);
 
             getContext().actorSelection(address).forward(new RemoteFindPrimary(shardName,
-                    message.isWaitUntilReady()), getContext());
+                    message.isWaitUntilReady(), visitedAddresses), getContext());
             return;
         }
 
@@ -800,6 +918,9 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         for(String shardName : memberShardNames){
             ShardIdentifier shardId = getShardIdentifier(memberName, shardName);
+
+            LOG.debug("{}: Creating local shard: {}", persistenceId(), shardId);
+
             Map<String, String> peerAddresses = getPeerAddresses(shardName);
             localShards.put(shardName, new ShardInformation(shardName, shardId, peerAddresses,
                     newShardDatastoreContext(shardName), Shard.builder().restoreFromSnapshot(
@@ -865,7 +986,7 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         return false;
     }
 
-    private void onAddShardReplica (AddShardReplica shardReplicaMsg) {
+    private void onAddShardReplica (final AddShardReplica shardReplicaMsg) {
         final String shardName = shardReplicaMsg.getShardName();
 
         LOG.debug("{}: onAddShardReplica: {}", persistenceId(), shardReplicaMsg);
@@ -887,34 +1008,18 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return;
         }
 
-        Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
-                getShardInitializationTimeout().duration().$times(2));
+        findPrimary(shardName, new AutoFindPrimaryFailureResponseHandler(getSender(), shardName, persistenceId(), getSelf()) {
+            @Override
+            public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
 
-        final ActorRef sender = getSender();
-        Future<Object> futureObj = ask(getSelf(), new FindPrimary(shardName, true), findPrimaryTimeout);
-        futureObj.onComplete(new OnComplete<Object>() {
             @Override
-            public void onComplete(Throwable failure, Object response) {
-                if (failure != null) {
-                    LOG.debug ("{}: Received failure from FindPrimary for shard {}", persistenceId(), shardName, failure);
-                    sender.tell(new akka.actor.Status.Failure(new RuntimeException(
-                        String.format("Failed to find leader for shard %s", shardName), failure)), getSelf());
-                } else {
-                    if(response instanceof RemotePrimaryShardFound) {
-                        self().tell(new ForwardedAddServerPrimaryShardFound(shardName,
-                                (RemotePrimaryShardFound)response), sender);
-                    } else if(response instanceof LocalPrimaryShardFound) {
-                        sendLocalReplicaAlreadyExistsReply(shardName, sender);
-                    } else {
-                        String msg = String.format("Failed to find leader for shard %s: received response: %s",
-                                shardName, response);
-                        LOG.debug ("{}: {}", persistenceId(), msg);
-                        sender.tell(new akka.actor.Status.Failure(response instanceof Throwable ? (Throwable)response :
-                            new RuntimeException(msg)), getSelf());
-                    }
-                }
+            public void onLocalPrimaryFound(LocalPrimaryShardFound message) {
+                sendLocalReplicaAlreadyExistsReply(getShardName(), getTargetActor());
             }
-        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+
+        });
     }
 
     private void sendLocalReplicaAlreadyExistsReply(String shardName, ActorRef sender) {
@@ -1017,40 +1122,52 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             LOG.warn ("{}: Leader failed to add shard replica {} with status {}",
                     persistenceId(), shardName, replyMsg.getStatus());
 
-            Exception failure;
-            switch (replyMsg.getStatus()) {
-                case TIMEOUT:
-                    failure = new TimeoutException(String.format(
-                            "The shard leader %s timed out trying to replicate the initial data to the new shard %s." +
-                            "Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
-                            leaderPath, shardName));
-                    break;
-                case NO_LEADER:
-                    failure = createNoShardLeaderException(shardInfo.getShardId());
-                    break;
-                default :
-                    failure = new RuntimeException(String.format(
-                            "AddServer request to leader %s for shard %s failed with status %s",
-                            leaderPath, shardName, replyMsg.getStatus()));
-            }
+            Exception failure = getServerChangeException(AddServer.class, replyMsg.getStatus(), leaderPath, shardInfo.getShardId());
 
             onAddServerFailure(shardName, null, failure, sender, removeShardOnFailure);
         }
     }
 
-    private void onRemoveShardReplica (RemoveShardReplica shardReplicaMsg) {
-        String shardName = shardReplicaMsg.getShardName();
-
-        // verify the local shard replica is available in the controller node
-        if (!localShards.containsKey(shardName)) {
-            String msg = String.format("Local shard %s does not", shardName);
-            LOG.debug ("{}: {}", persistenceId(), msg);
-            getSender().tell(new akka.actor.Status.Failure(new IllegalArgumentException(msg)), getSelf());
-            return;
+    private Exception getServerChangeException(Class<?> serverChange, ServerChangeStatus serverChangeStatus,
+                                               String leaderPath, ShardIdentifier shardId) {
+        Exception failure;
+        switch (serverChangeStatus) {
+            case TIMEOUT:
+                failure = new TimeoutException(String.format(
+                        "The shard leader %s timed out trying to replicate the initial data to the new shard %s." +
+                        "Possible causes - there was a problem replicating the data or shard leadership changed while replicating the shard data",
+                        leaderPath, shardId.getShardName()));
+                break;
+            case NO_LEADER:
+                failure = createNoShardLeaderException(shardId);
+                break;
+            case NOT_SUPPORTED:
+                failure = new UnsupportedOperationException(String.format("%s request is not supported for shard %s",
+                        serverChange.getSimpleName(), shardId.getShardName()));
+                break;
+            default :
+                failure = new RuntimeException(String.format(
+                        "%s request to leader %s for shard %s failed with status %s",
+                        serverChange.getSimpleName(), leaderPath, shardId.getShardName(), serverChangeStatus));
         }
-        // call RemoveShard for the shardName
-        getSender().tell(new akka.actor.Status.Success(true), getSelf());
-        return;
+        return failure;
+    }
+
+    private void onRemoveShardReplica (final RemoveShardReplica shardReplicaMsg) {
+        LOG.debug("{}: onRemoveShardReplica: {}", persistenceId(), shardReplicaMsg);
+
+        findPrimary(shardReplicaMsg.getShardName(), new AutoFindPrimaryFailureResponseHandler(getSender(),
+                shardReplicaMsg.getShardName(), persistenceId(), getSelf()) {
+            @Override
+            public void onRemotePrimaryShardFound(RemotePrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+
+            @Override
+            public void onLocalPrimaryFound(LocalPrimaryShardFound response) {
+                getSelf().tell(new PrimaryShardFoundForContext(getShardName(), shardReplicaMsg, response), getTargetActor());
+            }
+        });
     }
 
     private void persistShardList() {
@@ -1061,16 +1178,23 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             }
         }
         LOG.debug ("{}: persisting the shard list {}", persistenceId(), shardList);
-        saveSnapshot(new ShardManagerSnapshot(shardList));
+        saveSnapshot(updateShardManagerSnapshot(shardList));
+    }
+
+    private ShardManagerSnapshot updateShardManagerSnapshot(List<String> shardList) {
+        currentSnapshot = new ShardManagerSnapshot(shardList);
+        return currentSnapshot;
     }
 
-    private void handleShardRecovery(SnapshotOffer offer) {
-        LOG.debug ("{}: in handleShardRecovery", persistenceId());
-        ShardManagerSnapshot snapshot = (ShardManagerSnapshot)offer.snapshot();
+    private void applyShardManagerSnapshot(ShardManagerSnapshot snapshot) {
+        currentSnapshot = snapshot;
+
+        LOG.debug ("{}: onSnapshotOffer: {}", persistenceId(), currentSnapshot);
+
         String currentMember = cluster.getCurrentMemberName();
         Set<String> configuredShardList =
             new HashSet<>(configuration.getMemberShardNames(currentMember));
-        for (String shard : snapshot.getShardList()) {
+        for (String shard : currentSnapshot.getShardList()) {
             if (!configuredShardList.contains(shard)) {
                 // add the current member as a replica for the shard
                 LOG.debug ("{}: adding shard {}", persistenceId(), shard);
@@ -1086,14 +1210,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
-    private static class ForwardedAddServerPrimaryShardFound {
-        String shardName;
-        RemotePrimaryShardFound primaryFound;
-
-        ForwardedAddServerPrimaryShardFound(String shardName, RemotePrimaryShardFound primaryFound) {
-            this.shardName = shardName;
-            this.primaryFound = primaryFound;
-        }
+    private void onSaveSnapshotSuccess (SaveSnapshotSuccess successMessage) {
+        LOG.debug ("{} saved ShardManager snapshot successfully. Deleting the prev snapshot if available",
+            persistenceId());
+        deleteSnapshots(new SnapshotSelectionCriteria(scala.Long.MaxValue(), (successMessage.metadata().timestamp() - 1)));
     }
 
     private static class ForwardedAddServerReply {
@@ -1340,6 +1460,10 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
         void setLeaderAvailable(boolean leaderAvailable) {
             this.leaderAvailable = leaderAvailable;
+
+            if(leaderAvailable) {
+                notifyOnShardInitializedCallbacks();
+            }
         }
 
         short getLeaderVersion() {
@@ -1433,53 +1557,59 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         return new Builder();
     }
 
-    public static class Builder {
+    public static abstract class AbstractBuilder<T extends AbstractBuilder<T>> {
         private ClusterWrapper cluster;
         private Configuration configuration;
         private DatastoreContextFactory datastoreContextFactory;
         private CountDownLatch waitTillReadyCountdownLatch;
         private PrimaryShardInfoFutureCache primaryShardInfoCache;
         private DatastoreSnapshot restoreFromSnapshot;
+
         private volatile boolean sealed;
 
+        @SuppressWarnings("unchecked")
+        private T self() {
+            return (T) this;
+        }
+
         protected void checkSealed() {
             Preconditions.checkState(!sealed, "Builder is already sealed - further modifications are not allowed");
         }
 
-        public Builder cluster(ClusterWrapper cluster) {
+        public T cluster(ClusterWrapper cluster) {
             checkSealed();
             this.cluster = cluster;
-            return this;
+            return self();
         }
 
-        public Builder configuration(Configuration configuration) {
+        public T configuration(Configuration configuration) {
             checkSealed();
             this.configuration = configuration;
-            return this;
+            return self();
         }
 
-        public Builder datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
+        public T datastoreContextFactory(DatastoreContextFactory datastoreContextFactory) {
             checkSealed();
             this.datastoreContextFactory = datastoreContextFactory;
-            return this;
+            return self();
         }
 
-        public Builder waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
+        public T waitTillReadyCountdownLatch(CountDownLatch waitTillReadyCountdownLatch) {
             checkSealed();
             this.waitTillReadyCountdownLatch = waitTillReadyCountdownLatch;
-            return this;
+            return self();
         }
 
-        public Builder primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
+        public T primaryShardInfoCache(PrimaryShardInfoFutureCache primaryShardInfoCache) {
             checkSealed();
             this.primaryShardInfoCache = primaryShardInfoCache;
-            return this;
+            return self();
         }
 
-        public Builder restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
+        public T restoreFromSnapshot(DatastoreSnapshot restoreFromSnapshot) {
             checkSealed();
             this.restoreFromSnapshot = restoreFromSnapshot;
-            return this;
+            return self();
         }
 
         protected void verify() {
@@ -1496,6 +1626,182 @@ public class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             return Props.create(ShardManager.class, this);
         }
     }
+
+    public static class Builder extends AbstractBuilder<Builder> {
+    }
+
+    private void findPrimary(final String shardName, final FindPrimaryResponseHandler handler) {
+        Timeout findPrimaryTimeout = new Timeout(datastoreContextFactory.getBaseDatastoreContext().
+                getShardInitializationTimeout().duration().$times(2));
+
+
+        Future<Object> futureObj = ask(getSelf(), new FindPrimary(shardName, true), findPrimaryTimeout);
+        futureObj.onComplete(new OnComplete<Object>() {
+            @Override
+            public void onComplete(Throwable failure, Object response) {
+                if (failure != null) {
+                    handler.onFailure(failure);
+                } else {
+                    if(response instanceof RemotePrimaryShardFound) {
+                        handler.onRemotePrimaryShardFound((RemotePrimaryShardFound) response);
+                    } else if(response instanceof LocalPrimaryShardFound) {
+                        handler.onLocalPrimaryFound((LocalPrimaryShardFound) response);
+                    } else {
+                        handler.onUnknownResponse(response);
+                    }
+                }
+            }
+        }, new Dispatchers(context().system().dispatchers()).getDispatcher(Dispatchers.DispatcherType.Client));
+    }
+
+    /**
+     * The FindPrimaryResponseHandler provides specific callback methods which are invoked when a response to the
+     * a remote or local find primary message is processed
+     */
+    private static interface FindPrimaryResponseHandler {
+        /**
+         * Invoked when a Failure message is received as a response
+         *
+         * @param failure
+         */
+        void onFailure(Throwable failure);
+
+        /**
+         * Invoked when a RemotePrimaryShardFound response is received
+         *
+         * @param response
+         */
+        void onRemotePrimaryShardFound(RemotePrimaryShardFound response);
+
+        /**
+         * Invoked when a LocalPrimaryShardFound response is received
+         * @param response
+         */
+        void onLocalPrimaryFound(LocalPrimaryShardFound response);
+
+        /**
+         * Invoked when an unknown response is received. This is another type of failure.
+         *
+         * @param response
+         */
+        void onUnknownResponse(Object response);
+    }
+
+    /**
+     * The AutoFindPrimaryFailureResponseHandler automatically processes Failure responses when finding a primary
+     * replica and sends a wrapped Failure response to some targetActor
+     */
+    private static abstract class AutoFindPrimaryFailureResponseHandler implements FindPrimaryResponseHandler {
+        private final ActorRef targetActor;
+        private final String shardName;
+        private final String persistenceId;
+        private final ActorRef shardManagerActor;
+
+        /**
+         * @param targetActor The actor to whom the Failure response should be sent when a FindPrimary failure occurs
+         * @param shardName The name of the shard for which the primary replica had to be found
+         * @param persistenceId The persistenceId for the ShardManager
+         * @param shardManagerActor The ShardManager actor which triggered the call to FindPrimary
+         */
+        protected AutoFindPrimaryFailureResponseHandler(ActorRef targetActor, String shardName, String persistenceId, ActorRef shardManagerActor){
+            this.targetActor = Preconditions.checkNotNull(targetActor);
+            this.shardName = Preconditions.checkNotNull(shardName);
+            this.persistenceId = Preconditions.checkNotNull(persistenceId);
+            this.shardManagerActor = Preconditions.checkNotNull(shardManagerActor);
+        }
+
+        public ActorRef getTargetActor() {
+            return targetActor;
+        }
+
+        public String getShardName() {
+            return shardName;
+        }
+
+        @Override
+        public void onFailure(Throwable failure) {
+            LOG.debug ("{}: Received failure from FindPrimary for shard {}", persistenceId, shardName, failure);
+            targetActor.tell(new akka.actor.Status.Failure(new RuntimeException(
+                    String.format("Failed to find leader for shard %s", shardName), failure)), shardManagerActor);
+        }
+
+        @Override
+        public void onUnknownResponse(Object response) {
+            String msg = String.format("Failed to find leader for shard %s: received response: %s",
+                    shardName, response);
+            LOG.debug ("{}: {}", persistenceId, msg);
+            targetActor.tell(new akka.actor.Status.Failure(response instanceof Throwable ? (Throwable) response :
+                    new RuntimeException(msg)), shardManagerActor);
+        }
+    }
+
+
+    /**
+     * The PrimaryShardFoundForContext is a DTO which puts together a message (aka 'Context' message) which needs to be
+     * forwarded to the primary replica of a shard and the message (aka 'PrimaryShardFound' message) that is received
+     * as a successful response to find primary.
+     */
+    private static class PrimaryShardFoundForContext {
+        private final String shardName;
+        private final Object contextMessage;
+        private final RemotePrimaryShardFound remotePrimaryShardFound;
+        private final LocalPrimaryShardFound localPrimaryShardFound;
+
+        public PrimaryShardFoundForContext(@Nonnull String shardName, @Nonnull Object contextMessage,
+                @Nonnull Object primaryFoundMessage) {
+            this.shardName = Preconditions.checkNotNull(shardName);
+            this.contextMessage = Preconditions.checkNotNull(contextMessage);
+            Preconditions.checkNotNull(primaryFoundMessage);
+            this.remotePrimaryShardFound = (primaryFoundMessage instanceof RemotePrimaryShardFound) ?
+                    (RemotePrimaryShardFound) primaryFoundMessage : null;
+            this.localPrimaryShardFound = (primaryFoundMessage instanceof LocalPrimaryShardFound) ?
+                    (LocalPrimaryShardFound) primaryFoundMessage : null;
+        }
+
+        @Nonnull
+        String getPrimaryPath(){
+            if(remotePrimaryShardFound != null) {
+                return remotePrimaryShardFound.getPrimaryPath();
+            }
+            return localPrimaryShardFound.getPrimaryPath();
+        }
+
+        @Nonnull
+        Object getContextMessage() {
+            return contextMessage;
+        }
+
+        @Nullable
+        RemotePrimaryShardFound getRemotePrimaryShardFound() {
+            return remotePrimaryShardFound;
+        }
+
+        @Nonnull
+        String getShardName() {
+            return shardName;
+        }
+    }
+
+    /**
+     * The WrappedShardResponse class wraps a response from a Shard.
+     */
+    private static class WrappedShardResponse {
+        private final String shardName;
+        private final Object response;
+
+        private WrappedShardResponse(String shardName, Object response) {
+            this.shardName = shardName;
+            this.response = response;
+        }
+
+        String getShardName() {
+            return shardName;
+        }
+
+        Object getResponse() {
+            return response;
+        }
+    }
 }