+
+ public static Builder builder(DatastoreContext.Builder datastoreContextBuilder) {
+ return new Builder(datastoreContextBuilder);
+ }
+
+ private static class Builder extends AbstractGenericBuilder<Builder, TestShardManager> {
+ private ActorRef shardActor;
+ private final Map<String, ActorRef> shardActors = new HashMap<>();
+
+ Builder(DatastoreContext.Builder datastoreContextBuilder) {
+ super(TestShardManager.class);
+ datastoreContextFactory(newDatastoreContextFactory(datastoreContextBuilder.build()));
+ }
+
+ Builder shardActor(ActorRef shardActor) {
+ this.shardActor = shardActor;
+ return this;
+ }
+
+ Builder addShardActor(String shardName, ActorRef actorRef){
+ shardActors.put(shardName, actorRef);
+ return this;
+ }
+ }
+
+ @Override
+ public void saveSnapshot(Object obj) {
+ snapshot = (ShardManagerSnapshot) obj;
+ snapshotPersist.countDown();
+ }
+
+ void verifySnapshotPersisted(Set<String> shardList) {
+ assertEquals("saveSnapshot invoked", true,
+ Uninterruptibles.awaitUninterruptibly(snapshotPersist, 5, TimeUnit.SECONDS));
+ assertEquals("Shard Persisted", shardList, Sets.newHashSet(snapshot.getShardList()));
+ }
+
+ @Override
+ protected ActorRef newShardActor(SchemaContext schemaContext, ShardInformation info) {
+ if(shardActors.get(info.getShardName()) != null){
+ return shardActors.get(info.getShardName());
+ }
+
+ if(shardActor != null) {
+ return shardActor;
+ }
+
+ return super.newShardActor(schemaContext, info);
+ }
+ }
+
+ private static abstract class AbstractGenericBuilder<T extends AbstractGenericBuilder<T, ?>, C extends ShardManager>
+ extends ShardManager.AbstractBuilder<T> {
+ private final Class<C> shardManagerClass;
+
+ AbstractGenericBuilder(Class<C> shardManagerClass) {
+ this.shardManagerClass = shardManagerClass;
+ cluster(new MockClusterWrapper()).configuration(new MockConfiguration()).
+ waitTillReadyCountdownLatch(ready).primaryShardInfoCache(new PrimaryShardInfoFutureCache());
+ }
+
+ @Override
+ public Props props() {
+ verify();
+ return Props.create(shardManagerClass, this);
+ }
+ }
+
+ private static class GenericBuilder<C extends ShardManager> extends AbstractGenericBuilder<GenericBuilder<C>, C> {
+ GenericBuilder(Class<C> shardManagerClass) {
+ super(shardManagerClass);
+ }
+ }
+
+ private static class DelegatingShardManagerCreator implements Creator<ShardManager> {
+ private static final long serialVersionUID = 1L;
+ private final Creator<ShardManager> delegate;
+
+ public DelegatingShardManagerCreator(Creator<ShardManager> delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public ShardManager create() throws Exception {
+ return delegate.create();
+ }
+ }
+
+ interface MessageInterceptor extends Function<Object, Object> {
+ boolean canIntercept(Object message);
+ }
+
+ private MessageInterceptor newFindPrimaryInterceptor(final ActorRef primaryActor) {
+ return new MessageInterceptor(){
+ @Override
+ public Object apply(Object message) {
+ return new RemotePrimaryShardFound(Serialization.serializedActorPath(primaryActor), (short) 1);
+ }
+
+ @Override
+ public boolean canIntercept(Object message) {
+ return message instanceof FindPrimary;
+ }
+ };
+ }
+
+ private static class MockRespondActor extends MessageCollectorActor {
+ static final String CLEAR_RESPONSE = "clear-response";
+
+ private volatile Object responseMsg;
+
+ @SuppressWarnings("unused")
+ public MockRespondActor() {
+ }
+
+ @SuppressWarnings("unused")
+ public MockRespondActor(Object responseMsg) {
+ this.responseMsg = responseMsg;
+ }
+
+ public void updateResponse(Object response) {
+ responseMsg = response;
+ }
+
+ @Override
+ public void onReceive(Object message) throws Exception {
+ super.onReceive(message);
+ if (message instanceof AddServer) {
+ if (responseMsg != null) {
+ getSender().tell(responseMsg, getSelf());
+ }
+ } if(message.equals(CLEAR_RESPONSE)) {
+ responseMsg = null;
+ }
+ }