tell-based - reconnect on leader change 29/78029/5
authorTom Pantelis <tompantelis@gmail.com>
Wed, 21 Nov 2018 20:38:05 +0000 (15:38 -0500)
committerRobert Varga <nite@hq.sk>
Tue, 27 Nov 2018 14:44:11 +0000 (14:44 +0000)
The ShardManager is the aggregation point for shards so we need
to propagate shard leader change events etc to the ClientActorBehavior
to initiate a refresh of the backend info. The ModuleShardBackendResolver
sends a new message, RegisterForShardAvailabilityChanges, to the
ShardManager actor with a Consumer callback that is notified by the
ShardManager when events affecting shard leader availability occur.
The ModuleShardBackendResolver then propagates the event notification
to callbacks registered via a new notifyWhenBackendInfoIsStale method
exposed via the BackendInfoResolver interface, which the
ClientActorBehavior calls.

JIRA: CONTROLLER-1873
Change-Id: I9dbcabf5a75b195c811a22dd522115d329e5dc4b
Signed-off-by: Tom Pantelis <tompantelis@gmail.com>
12 files changed:
opendaylight/md-sal/cds-access-client/src/main/java/org/opendaylight/controller/cluster/access/client/AbstractClientConnection.java
opendaylight/md-sal/cds-access-client/src/main/java/org/opendaylight/controller/cluster/access/client/BackendInfoResolver.java
opendaylight/md-sal/cds-access-client/src/main/java/org/opendaylight/controller/cluster/access/client/ClientActorBehavior.java
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/databroker/actors/dds/AbstractDataStoreClientBehavior.java
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/databroker/actors/dds/AbstractShardBackendResolver.java
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/databroker/actors/dds/DistributedDataStoreClientBehavior.java
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/databroker/actors/dds/ModuleShardBackendResolver.java
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/shardmanager/RegisterForShardAvailabilityChanges.java [new file with mode: 0644]
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/shardmanager/ShardInformation.java
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/shardmanager/ShardManager.java
opendaylight/md-sal/sal-distributed-datastore/src/test/java/org/opendaylight/controller/cluster/databroker/actors/dds/ModuleShardBackendResolverTest.java
opendaylight/md-sal/sal-distributed-datastore/src/test/java/org/opendaylight/controller/cluster/datastore/shardmanager/ShardManagerTest.java

index af66369271c34c8eba953d6fc6ecf5c27637aa51..9fd75cc439b68499e13c463fbd419ae6b2d4704b 100644 (file)
@@ -422,7 +422,7 @@ public abstract class AbstractClientConnection<T extends BackendInfo> {
         context.executeInActor(current -> {
             final double time = beenOpen * 1.0 / 1_000_000_000;
             entry.complete(entry.getRequest().toRequestFailure(
-                new RequestTimeoutException("Timed out after " + time + "seconds")));
+                new RequestTimeoutException("Timed out after " + time + " seconds")));
             return current;
         });
     }
index 4ece691d898c45169a4783491d19e5fd2877bdaf..3c6e093bfb570e21512ac9cf5bb8ee0a1a805abd 100644 (file)
@@ -9,7 +9,9 @@ package org.opendaylight.controller.cluster.access.client;
 
 import akka.actor.ActorRef;
 import java.util.concurrent.CompletionStage;
+import java.util.function.Consumer;
 import javax.annotation.Nonnull;
+import org.opendaylight.yangtools.concepts.Registration;
 
 /**
  * Caching resolver which resolves a cookie to a leader {@link ActorRef}. This class needs to be specialized by the
@@ -26,7 +28,7 @@ import javax.annotation.Nonnull;
  *
  * @author Robert Varga
  */
-public abstract class BackendInfoResolver<T extends BackendInfo> {
+public abstract class BackendInfoResolver<T extends BackendInfo> implements AutoCloseable {
     /**
      * Request resolution of a particular backend identified by a cookie. This request can be satisfied from the cache.
      *
@@ -46,4 +48,18 @@ public abstract class BackendInfoResolver<T extends BackendInfo> {
      */
     @Nonnull
     public abstract CompletionStage<? extends T> refreshBackendInfo(@Nonnull Long cookie, @Nonnull T staleInfo);
+
+    /**
+     * Registers a callback to be notified when BackendInfo that may have been previously obtained is now stale and
+     * should be refreshed.
+     *
+     * @param callback the callback that takes the backend cookie whose BackendInfo is now stale.
+     * @return a Registration
+     */
+    @Nonnull
+    public abstract Registration notifyWhenBackendInfoIsStale(Consumer<Long> callback);
+
+    @Override
+    public void close() {
+    }
 }
index fa2e3b76d8a038497d57efd7e344498862717d06..ddf1dc190b8ed1a4ea3c70b2a1386a078f426053 100644 (file)
@@ -37,6 +37,7 @@ import org.opendaylight.controller.cluster.io.FileBackedOutputStreamFactory;
 import org.opendaylight.controller.cluster.messaging.MessageAssembler;
 import org.opendaylight.yangtools.concepts.Identifiable;
 import org.opendaylight.yangtools.concepts.Identifier;
+import org.opendaylight.yangtools.concepts.Registration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.concurrent.duration.FiniteDuration;
@@ -83,6 +84,7 @@ public abstract class ClientActorBehavior<T extends BackendInfo> extends
     private final InversibleLock connectionsLock = new InversibleLock();
     private final BackendInfoResolver<T> resolver;
     private final MessageAssembler responseMessageAssembler;
+    private final Registration staleBackendInfoReg;
 
     protected ClientActorBehavior(@Nonnull final ClientActorContext context,
             @Nonnull final BackendInfoResolver<T> resolver) {
@@ -94,6 +96,17 @@ public abstract class ClientActorBehavior<T extends BackendInfo> extends
                 .fileBackedStreamFactory(new FileBackedOutputStreamFactory(config.getFileBackedStreamingThreshold(),
                         config.getTempFileDirectory()))
                 .assembledMessageCallback((message, sender) -> context.self().tell(message, sender)).build();
+
+        staleBackendInfoReg = resolver.notifyWhenBackendInfoIsStale(shard -> {
+            context().executeInActor(behavior -> {
+                LOG.debug("BackendInfo for shard {} is now stale", shard);
+                final AbstractClientConnection<T> conn = connections.get(shard);
+                if (conn instanceof ConnectedClientConnection) {
+                    conn.reconnect(this, new BackendStaleException(shard));
+                }
+                return behavior;
+            });
+        });
     }
 
     @Override
@@ -104,7 +117,9 @@ public abstract class ClientActorBehavior<T extends BackendInfo> extends
 
     @Override
     public void close() {
+        super.close();
         responseMessageAssembler.close();
+        staleBackendInfoReg.close();
     }
 
     /**
@@ -437,4 +452,17 @@ public abstract class ClientActorBehavior<T extends BackendInfo> extends
             return behavior;
         }));
     }
+
+    private static class BackendStaleException extends RequestException {
+        private static final long serialVersionUID = 1L;
+
+        BackendStaleException(final Long shard) {
+            super("Backend for shard " + shard + " is stale");
+        }
+
+        @Override
+        public boolean isRetriable() {
+            return false;
+        }
+    }
 }
index 77db4e6c0804073392950c122d3f0cfa8911daee..4f91cb27fae151a26ba5c49dbd8ba12498dc8238 100644 (file)
@@ -218,7 +218,8 @@ abstract class AbstractDataStoreClientBehavior extends ClientActorBehavior<Shard
     }
 
     @Override
-    public final void close() {
+    public void close() {
+        super.close();
         context().executeInActor(this::shutdown);
     }
 
index faffe0a3764985f5439e13867928a1ee2a1f2511..d2b10c1b7067a038e90d910b16a0749ff4b76212 100644 (file)
@@ -11,11 +11,14 @@ import akka.actor.ActorRef;
 import akka.util.Timeout;
 import com.google.common.base.Preconditions;
 import com.google.common.primitives.UnsignedLong;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy;
@@ -33,6 +36,7 @@ import org.opendaylight.controller.cluster.datastore.exceptions.NotInitializedEx
 import org.opendaylight.controller.cluster.datastore.exceptions.PrimaryNotFoundException;
 import org.opendaylight.controller.cluster.datastore.messages.PrimaryShardInfo;
 import org.opendaylight.controller.cluster.datastore.utils.ActorContext;
+import org.opendaylight.yangtools.concepts.Registration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.Function1;
@@ -85,6 +89,7 @@ abstract class AbstractShardBackendResolver extends BackendInfoResolver<ShardBac
     private final AtomicLong nextSessionId = new AtomicLong();
     private final Function1<ActorRef, ?> connectFunction;
     private final ActorContext actorContext;
+    private final Set<Consumer<Long>> staleBackendInfoCallbacks = ConcurrentHashMap.newKeySet();
 
     // FIXME: we really need just ActorContext.findPrimaryShardAsync()
     AbstractShardBackendResolver(final ClientIdentifier clientId, final ActorContext actorContext) {
@@ -93,6 +98,20 @@ abstract class AbstractShardBackendResolver extends BackendInfoResolver<ShardBac
             ABIVersion.current()));
     }
 
+    @Override
+    public Registration notifyWhenBackendInfoIsStale(final Consumer<Long> callback) {
+        staleBackendInfoCallbacks.add(callback);
+        return () -> staleBackendInfoCallbacks.remove(callback);
+    }
+
+    protected void notifyStaleBackendInfoCallbacks(Long cookie) {
+        staleBackendInfoCallbacks.forEach(callback -> callback.accept(cookie));
+    }
+
+    protected ActorContext actorContext() {
+        return actorContext;
+    }
+
     protected final void flushCache(final String shardName) {
         actorContext.getPrimaryShardInfoCache().remove(shardName);
     }
index bc393a4c0f9cdffe61e0a49525edebcef92a24e1..792b5b31c1cadef2ce9c3c6575e9b9963cf36686 100644 (file)
@@ -34,4 +34,10 @@ final class DistributedDataStoreClientBehavior extends AbstractDataStoreClientBe
     Long resolveShardForPath(final YangInstanceIdentifier path) {
         return pathToShard.apply(path);
     }
+
+    @Override
+    public void close() {
+        super.close();
+        resolver().close();
+    }
 }
index b79a5ab88ec8af82e4ded27b937de4044f48e117..004df590fe2f2c7fe0c9d75a4d9a2cfd9975c918 100644 (file)
@@ -7,22 +7,30 @@
  */
 package org.opendaylight.controller.cluster.databroker.actors.dds;
 
-import com.google.common.base.Preconditions;
+import static akka.pattern.Patterns.ask;
+
+import akka.dispatch.ExecutionContexts;
+import akka.dispatch.OnComplete;
+import akka.util.Timeout;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.ImmutableBiMap;
 import com.google.common.collect.ImmutableBiMap.Builder;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.TimeUnit;
 import javax.annotation.concurrent.GuardedBy;
 import javax.annotation.concurrent.ThreadSafe;
 import org.opendaylight.controller.cluster.access.client.BackendInfoResolver;
 import org.opendaylight.controller.cluster.access.concepts.ClientIdentifier;
+import org.opendaylight.controller.cluster.datastore.shardmanager.RegisterForShardAvailabilityChanges;
 import org.opendaylight.controller.cluster.datastore.shardstrategy.DefaultShardStrategy;
 import org.opendaylight.controller.cluster.datastore.utils.ActorContext;
+import org.opendaylight.yangtools.concepts.Registration;
 import org.opendaylight.yangtools.yang.data.api.YangInstanceIdentifier;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.concurrent.Future;
 
 /**
  * {@link BackendInfoResolver} implementation for static shard configuration based on ShardManager. Each string-named
@@ -36,7 +44,8 @@ final class ModuleShardBackendResolver extends AbstractShardBackendResolver {
     private static final Logger LOG = LoggerFactory.getLogger(ModuleShardBackendResolver.class);
 
     private final ConcurrentMap<Long, ShardState> backends = new ConcurrentHashMap<>();
-    private final ActorContext actorContext;
+
+    private final Future<Registration> shardAvailabilityChangesRegFuture;
 
     @GuardedBy("this")
     private long nextShard = 1;
@@ -46,11 +55,35 @@ final class ModuleShardBackendResolver extends AbstractShardBackendResolver {
     // FIXME: we really need just ActorContext.findPrimaryShardAsync()
     ModuleShardBackendResolver(final ClientIdentifier clientId, final ActorContext actorContext) {
         super(clientId, actorContext);
-        this.actorContext = Preconditions.checkNotNull(actorContext);
+
+        shardAvailabilityChangesRegFuture = ask(actorContext.getShardManager(), new RegisterForShardAvailabilityChanges(
+            this::onShardAvailabilityChange), Timeout.apply(60, TimeUnit.MINUTES))
+                .map(reply -> (Registration)reply, ExecutionContexts.global());
+
+        shardAvailabilityChangesRegFuture.onComplete(new OnComplete<Registration>() {
+            @Override
+            public void onComplete(Throwable failure, Registration reply) {
+                if (failure != null) {
+                    LOG.error("RegisterForShardAvailabilityChanges failed", failure);
+                }
+            }
+        }, ExecutionContexts.global());
+    }
+
+    private void onShardAvailabilityChange(String shardName) {
+        LOG.debug("onShardAvailabilityChange for {}", shardName);
+
+        Long cookie = shards.get(shardName);
+        if (cookie == null) {
+            LOG.debug("No shard cookie found for {}", shardName);
+            return;
+        }
+
+        notifyStaleBackendInfoCallbacks(cookie);
     }
 
     Long resolveShardForPath(final YangInstanceIdentifier path) {
-        final String shardName = actorContext.getShardStrategyFactory().getStrategy(path).findShard(path);
+        final String shardName = actorContext().getShardStrategyFactory().getStrategy(path).findShard(path);
         Long cookie = shards.get(shardName);
         if (cookie == null) {
             synchronized (this) {
@@ -69,7 +102,6 @@ final class ModuleShardBackendResolver extends AbstractShardBackendResolver {
         return cookie;
     }
 
-
     @Override
     public CompletionStage<ShardBackendInfo> getBackendInfo(final Long cookie) {
         /*
@@ -135,4 +167,14 @@ final class ModuleShardBackendResolver extends AbstractShardBackendResolver {
 
         return getBackendInfo(cookie);
     }
+
+    @Override
+    public void close() {
+        shardAvailabilityChangesRegFuture.onComplete(new OnComplete<Registration>() {
+            @Override
+            public void onComplete(Throwable failure, Registration reply) {
+                reply.close();
+            }
+        }, ExecutionContexts.global());
+    }
 }
diff --git a/opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/shardmanager/RegisterForShardAvailabilityChanges.java b/opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/shardmanager/RegisterForShardAvailabilityChanges.java
new file mode 100644 (file)
index 0000000..de58667
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2018 Red Hat, Inc. and others.  All rights reserved.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License v1.0 which accompanies this distribution,
+ * and is available at http://www.eclipse.org/legal/epl-v10.html
+ */
+package org.opendaylight.controller.cluster.datastore.shardmanager;
+
+import static java.util.Objects.requireNonNull;
+
+import java.util.function.Consumer;
+
+/**
+ * Local ShardManager message to register a callback to be notified of shard availability changes. The reply to
+ * this message is a {@link org.opendaylight.yangtools.concepts.Registration} instance wrapped in a
+ * {@link akka.actor.Status.Success}.
+ *
+ * @author Thomas Pantelis
+ */
+public class RegisterForShardAvailabilityChanges {
+    private final Consumer<String> callback;
+
+    public RegisterForShardAvailabilityChanges(Consumer<String> callback) {
+        this.callback = requireNonNull(callback);
+    }
+
+    public Consumer<String> getCallback() {
+        return callback;
+    }
+}
index b2ac96e1641e81d9a7dd3430601840db2bdb0c4d..ef7e4b0cfa7ae396abee28c0c7143665eff5bb70 100644 (file)
@@ -279,4 +279,13 @@ final class ShardInformation {
     void setSchemaContext(final SchemaContext schemaContext) {
         schemaContextProvider.set(Preconditions.checkNotNull(schemaContext));
     }
+
+    @Override
+    public String toString() {
+        return "ShardInformation [shardId=" + shardId + ", leaderAvailable=" + leaderAvailable + ", actorInitialized="
+                + actorInitialized + ", followerSyncStatus=" + followerSyncStatus + ", role=" + role + ", leaderId="
+                + leaderId + ", activeMember=" + activeMember + "]";
+    }
+
+
 }
index d3fb58e9d15e66e5c25ca40ce60c546cfb6d866e..621b037341c8a183027c9f57190270ff50166fa2 100644 (file)
@@ -115,6 +115,7 @@ import org.opendaylight.controller.cluster.sharding.messages.PrefixShardRemoved;
 import org.opendaylight.controller.md.sal.dom.api.DOMDataTreeChangeListener;
 import org.opendaylight.mdsal.dom.api.DOMDataTreeIdentifier;
 import org.opendaylight.yangtools.concepts.ListenerRegistration;
+import org.opendaylight.yangtools.concepts.Registration;
 import org.opendaylight.yangtools.yang.data.api.YangInstanceIdentifier;
 import org.opendaylight.yangtools.yang.model.api.SchemaContext;
 import org.slf4j.Logger;
@@ -171,6 +172,8 @@ class ShardManager extends AbstractUntypedPersistentActorWithMetering {
 
     private final Map<String, CompositeOnComplete<Boolean>> shardActorsStopping = new HashMap<>();
 
+    private final Set<Consumer<String>> shardAvailabilityCallbacks = new HashSet<>();
+
     private final String persistenceId;
     private final AbstractDataStore dataStore;
 
@@ -301,6 +304,8 @@ class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             onGetShardRole((GetShardRole) message);
         } else if (message instanceof RunnableMessage) {
             ((RunnableMessage)message).run();
+        } else if (message instanceof RegisterForShardAvailabilityChanges) {
+            onRegisterForShardAvailabilityChanges((RegisterForShardAvailabilityChanges)message);
         } else if (message instanceof DeleteSnapshotsFailure) {
             LOG.warn("{}: Failed to delete prior snapshots", persistenceId(),
                     ((DeleteSnapshotsFailure) message).cause());
@@ -315,6 +320,16 @@ class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    private void onRegisterForShardAvailabilityChanges(RegisterForShardAvailabilityChanges message) {
+        LOG.debug("{}: onRegisterForShardAvailabilityChanges: {}", persistenceId(), message);
+
+        final Consumer<String> callback = message.getCallback();
+        shardAvailabilityCallbacks.add(callback);
+
+        getSender().tell(new Status.Success((Registration)
+            () -> executeInSelf(() -> shardAvailabilityCallbacks.remove(callback))), self());
+    }
+
     private void onGetShardRole(final GetShardRole message) {
         LOG.debug("{}: onGetShardRole for shard: {}", persistenceId(), message.getName());
 
@@ -763,6 +778,8 @@ class ShardManager extends AbstractUntypedPersistentActorWithMetering {
             shardInformation.setLeaderVersion(leaderStateChanged.getLeaderPayloadVersion());
             if (shardInformation.setLeaderId(leaderStateChanged.getLeaderId())) {
                 primaryShardInfoCache.remove(shardInformation.getShardName());
+
+                notifyShardAvailabilityCallbacks(shardInformation);
             }
 
             checkReady();
@@ -771,6 +788,10 @@ class ShardManager extends AbstractUntypedPersistentActorWithMetering {
         }
     }
 
+    private void notifyShardAvailabilityCallbacks(ShardInformation shardInformation) {
+        shardAvailabilityCallbacks.forEach(callback -> callback.accept(shardInformation.getShardName()));
+    }
+
     private void onShardNotInitializedTimeout(final ShardNotInitializedTimeout message) {
         ShardInformation shardInfo = message.getShardInfo();
 
@@ -927,7 +948,7 @@ class ShardManager extends AbstractUntypedPersistentActorWithMetering {
                 }
 
                 LOG.debug("{}: Scheduling {} ms timer to wait for shard {}", persistenceId(), timeout.toMillis(),
-                        shardInformation.getShardName());
+                        shardInformation);
 
                 Cancellable timeoutSchedule = getContext().system().scheduler().scheduleOnce(
                         timeout, getSelf(),
@@ -1051,6 +1072,8 @@ class ShardManager extends AbstractUntypedPersistentActorWithMetering {
                 info.setLeaderAvailable(false);
 
                 primaryShardInfoCache.remove(info.getShardName());
+
+                notifyShardAvailabilityCallbacks(info);
             }
 
             info.peerDown(memberName, getShardIdentifier(memberName, info.getShardName()).toString(), getSelf());
index bedd4a9283ec4d4710559e7f971418021f2964b7..8c0db7e392f6667994add0edc0b27799a47b2ed1 100644 (file)
@@ -8,16 +8,24 @@
 package org.opendaylight.controller.cluster.databroker.actors.dds;
 
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 import akka.actor.ActorRef;
 import akka.actor.ActorSelection;
 import akka.actor.ActorSystem;
+import akka.actor.Status;
 import akka.testkit.TestProbe;
 import akka.testkit.javadsl.TestKit;
+import com.google.common.util.concurrent.Uninterruptibles;
 import java.util.Collections;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -33,10 +41,13 @@ import org.opendaylight.controller.cluster.access.concepts.FrontendType;
 import org.opendaylight.controller.cluster.access.concepts.MemberName;
 import org.opendaylight.controller.cluster.access.concepts.RuntimeRequestException;
 import org.opendaylight.controller.cluster.datastore.messages.PrimaryShardInfo;
+import org.opendaylight.controller.cluster.datastore.shardmanager.RegisterForShardAvailabilityChanges;
+import org.opendaylight.controller.cluster.datastore.shardstrategy.DefaultShardStrategy;
 import org.opendaylight.controller.cluster.datastore.shardstrategy.ShardStrategy;
 import org.opendaylight.controller.cluster.datastore.shardstrategy.ShardStrategyFactory;
 import org.opendaylight.controller.cluster.datastore.utils.ActorContext;
 import org.opendaylight.controller.cluster.datastore.utils.PrimaryShardInfoFutureCache;
+import org.opendaylight.yangtools.concepts.Registration;
 import org.opendaylight.yangtools.yang.data.api.YangInstanceIdentifier;
 import org.opendaylight.yangtools.yang.data.api.schema.tree.DataTree;
 import scala.concurrent.Promise;
@@ -51,6 +62,7 @@ public class ModuleShardBackendResolverTest {
     private ActorSystem system;
     private ModuleShardBackendResolver moduleShardBackendResolver;
     private TestProbe contextProbe;
+    private TestProbe shardManagerProbe;
 
     @Mock
     private ShardStrategyFactory shardStrategyFactory;
@@ -64,7 +76,12 @@ public class ModuleShardBackendResolverTest {
         MockitoAnnotations.initMocks(this);
         system = ActorSystem.apply();
         contextProbe = new TestProbe(system, "context");
+
+        shardManagerProbe = new TestProbe(system, "ShardManager");
+
         final ActorContext actorContext = createActorContextMock(system, contextProbe.ref());
+        when(actorContext.getShardManager()).thenReturn(shardManagerProbe.ref());
+
         moduleShardBackendResolver = new ModuleShardBackendResolver(CLIENT_ID, actorContext);
         when(actorContext.getShardStrategyFactory()).thenReturn(shardStrategyFactory);
         when(shardStrategyFactory.getStrategy(YangInstanceIdentifier.EMPTY)).thenReturn(shardStrategy);
@@ -79,7 +96,7 @@ public class ModuleShardBackendResolverTest {
 
     @Test
     public void testResolveShardForPathNonNullCookie() {
-        when(shardStrategy.findShard(YangInstanceIdentifier.EMPTY)).thenReturn("default");
+        when(shardStrategy.findShard(YangInstanceIdentifier.EMPTY)).thenReturn(DefaultShardStrategy.DEFAULT_SHARD);
         final Long cookie = moduleShardBackendResolver.resolveShardForPath(YangInstanceIdentifier.EMPTY);
         Assert.assertEquals(0L, cookie.longValue());
     }
@@ -103,7 +120,7 @@ public class ModuleShardBackendResolverTest {
         final ShardBackendInfo shardBackendInfo = TestUtils.getWithTimeout(stage.toCompletableFuture());
         Assert.assertEquals(0L, shardBackendInfo.getCookie().longValue());
         Assert.assertEquals(dataTree, shardBackendInfo.getDataTree().get());
-        Assert.assertEquals("default", shardBackendInfo.getShardName());
+        Assert.assertEquals(DefaultShardStrategy.DEFAULT_SHARD, shardBackendInfo.getShardName());
     }
 
     @Test
@@ -145,13 +162,36 @@ public class ModuleShardBackendResolverTest {
         Assert.assertEquals(refreshedBackendProbe.ref(), refreshedBackendInfo.getActor());
     }
 
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testNotifyWhenBackendInfoIsStale() {
+        final RegisterForShardAvailabilityChanges regMessage =
+                shardManagerProbe.expectMsgClass(RegisterForShardAvailabilityChanges.class);
+        Registration mockReg = mock(Registration.class);
+        shardManagerProbe.reply(new Status.Success(mockReg));
+
+        Consumer<Long> mockCallback = mock(Consumer.class);
+        final Registration callbackReg = moduleShardBackendResolver.notifyWhenBackendInfoIsStale(mockCallback);
+
+        regMessage.getCallback().accept(DefaultShardStrategy.DEFAULT_SHARD);
+        verify(mockCallback, timeout(5000)).accept(Long.valueOf(0));
+
+        reset(mockCallback);
+        callbackReg.close();
+
+        regMessage.getCallback().accept(DefaultShardStrategy.DEFAULT_SHARD);
+        Uninterruptibles.sleepUninterruptibly(500, TimeUnit.MILLISECONDS);
+        verifyNoMoreInteractions(mockCallback);
+    }
+
     private static ActorContext createActorContextMock(final ActorSystem system, final ActorRef actor) {
         final ActorContext mock = mock(ActorContext.class);
         final Promise<PrimaryShardInfo> promise = new scala.concurrent.impl.Promise.DefaultPromise<>();
         final ActorSelection selection = system.actorSelection(actor.path());
         final PrimaryShardInfo shardInfo = new PrimaryShardInfo(selection, (short) 0);
         promise.success(shardInfo);
-        when(mock.findPrimaryShardAsync("default")).thenReturn(promise.future());
+        when(mock.findPrimaryShardAsync(DefaultShardStrategy.DEFAULT_SHARD)).thenReturn(promise.future());
+        when(mock.getClientDispatcher()).thenReturn(system.dispatchers().defaultGlobalDispatcher());
         return mock;
     }
 }
index 961f32db0478eb583fc5232ee96f1483317f133e..4eea96d752d2c3e4013f095c697e123ca62857b8 100644 (file)
@@ -14,12 +14,15 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.anyString;
+import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 import akka.actor.ActorRef;
 import akka.actor.ActorSystem;
@@ -59,6 +62,7 @@ import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.function.Consumer;
 import java.util.stream.Collectors;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -123,12 +127,14 @@ import org.opendaylight.controller.cluster.raft.policy.DisableElectionsRaftPolic
 import org.opendaylight.controller.cluster.raft.utils.InMemorySnapshotStore;
 import org.opendaylight.controller.cluster.raft.utils.MessageCollectorActor;
 import org.opendaylight.controller.md.cluster.datastore.model.TestModel;
+import org.opendaylight.yangtools.concepts.Registration;
 import org.opendaylight.yangtools.yang.data.api.schema.tree.DataTree;
 import org.opendaylight.yangtools.yang.model.api.SchemaContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.concurrent.Await;
 import scala.concurrent.Future;
+import scala.concurrent.duration.Duration;
 import scala.concurrent.duration.FiniteDuration;
 
 public class ShardManagerTest extends AbstractShardManagerTest {
@@ -2056,6 +2062,61 @@ public class ShardManagerTest extends AbstractShardManagerTest {
         assertTrue("Failure resposnse", resp.cause() instanceof NoShardLeaderException);
     }
 
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testRegisterForShardLeaderChanges() {
+        LOG.info("testRegisterForShardLeaderChanges starting");
+
+        final String memberId1 = "member-1-shard-default-" + shardMrgIDSuffix;
+        final String memberId2 = "member-2-shard-default-" + shardMrgIDSuffix;
+        final TestKit kit = new TestKit(getSystem());
+        final ActorRef shardManager = actorFactory.createActor(newPropsShardMgrWithMockShardActor());
+
+        shardManager.tell(new UpdateSchemaContext(TEST_SCHEMA_CONTEXT), kit.getRef());
+        shardManager.tell(new ActorInitialized(), mockShardActor);
+
+        final Consumer<String> mockCallback = mock(Consumer.class);
+        shardManager.tell(new RegisterForShardAvailabilityChanges(mockCallback), kit.getRef());
+
+        final Success reply = kit.expectMsgClass(Duration.apply(5, TimeUnit.SECONDS), Success.class);
+        final Registration reg = (Registration) reply.status();
+
+        final DataTree mockDataTree = mock(DataTree.class);
+        shardManager.tell(new ShardLeaderStateChanged(memberId1, memberId1, mockDataTree,
+            DataStoreVersions.CURRENT_VERSION), mockShardActor);
+
+        verify(mockCallback, timeout(5000)).accept("default");
+
+        reset(mockCallback);
+        shardManager.tell(new ShardLeaderStateChanged(memberId1, memberId1, mockDataTree,
+                DataStoreVersions.CURRENT_VERSION), mockShardActor);
+
+        Uninterruptibles.sleepUninterruptibly(500, TimeUnit.MILLISECONDS);
+        verifyNoMoreInteractions(mockCallback);
+
+        shardManager.tell(new ShardLeaderStateChanged(memberId1, null, mockDataTree,
+                DataStoreVersions.CURRENT_VERSION), mockShardActor);
+
+        verify(mockCallback, timeout(5000)).accept("default");
+
+        reset(mockCallback);
+        shardManager.tell(new ShardLeaderStateChanged(memberId1, memberId2, mockDataTree,
+                DataStoreVersions.CURRENT_VERSION), mockShardActor);
+
+        verify(mockCallback, timeout(5000)).accept("default");
+
+        reset(mockCallback);
+        reg.close();
+
+        shardManager.tell(new ShardLeaderStateChanged(memberId1, memberId1, mockDataTree,
+                DataStoreVersions.CURRENT_VERSION), mockShardActor);
+
+        Uninterruptibles.sleepUninterruptibly(500, TimeUnit.MILLISECONDS);
+        verifyNoMoreInteractions(mockCallback);
+
+        LOG.info("testRegisterForShardLeaderChanges ending");
+    }
+
     public static class TestShardManager extends ShardManager {
         private final CountDownLatch recoveryComplete = new CountDownLatch(1);
         private final CountDownLatch snapshotPersist = new CountDownLatch(1);