Bug 4564: Implement restore from snapshot in RaftActor 80/29280/7
authorTom Pantelis <tpanteli@brocade.com>
Wed, 4 Nov 2015 07:09:45 +0000 (02:09 -0500)
committerGerrit Code Review <gerrit@opendaylight.org>
Fri, 13 Nov 2015 06:57:09 +0000 (06:57 +0000)
The restore snapshot is supplied by the derived actor's
RaftActorRecoveryCohort. If one exists the the RaftActorRecoverySupport
desrializes and applies the snapshot.

I also add a Builder to MockRaftActor to make it easier to pass
additional params.

Change-Id: Ib52b24331038ed48221cc27086fa3cceafe39fcf
Signed-off-by: Tom Pantelis <tpanteli@brocade.com>
12 files changed:
opendaylight/md-sal/sal-akka-raft-example/src/main/java/org/opendaylight/controller/cluster/example/ExampleActor.java
opendaylight/md-sal/sal-akka-raft/src/main/java/org/opendaylight/controller/cluster/raft/RaftActorRecoveryCohort.java
opendaylight/md-sal/sal-akka-raft/src/main/java/org/opendaylight/controller/cluster/raft/RaftActorRecoverySupport.java
opendaylight/md-sal/sal-akka-raft/src/main/java/org/opendaylight/controller/cluster/raft/SnapshotManager.java
opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/AbstractRaftActorIntegrationTest.java
opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/MockRaftActor.java
opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorRecoverySupportTest.java
opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorServerConfigurationSupportTest.java
opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorTest.java
opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorTestKit.java
opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/utils/InMemorySnapshotStore.java
opendaylight/md-sal/sal-distributed-datastore/src/main/java/org/opendaylight/controller/cluster/datastore/ShardRecoveryCoordinator.java

index 3ce8364fcea603647d794bef1c92a6d29def76fa..9b7c2e8d4b0f3af9d522a55616d86a82f72a9cda 100644 (file)
@@ -226,4 +226,9 @@ public class ExampleActor extends RaftActor implements RaftActorRecoveryCohort,
     protected RaftActorSnapshotCohort getRaftActorSnapshotCohort() {
         return this;
     }
+
+    @Override
+    public byte[] getRestoreFromSnapshot() {
+        return null;
+    }
 }
index a9f00aa80bcb93621d6f1ef3fadb4695dbe3b8c5..30e27e17fe4d483f872ae9e71dc6ae289b65cf83 100644 (file)
@@ -7,6 +7,7 @@
  */
 package org.opendaylight.controller.cluster.raft;
 
+import javax.annotation.Nullable;
 import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload;
 
 /**
@@ -42,4 +43,12 @@ public interface RaftActorRecoveryCohort {
      * log entries. This method is called after {@link #appendRecoveredLogEntry}.
      */
     void applyCurrentLogRecoveryBatch();
+
+    /**
+     * Returns the state snapshot to restore from on recovery.
+     *
+     * @return the snapshot bytes or null if there's no snapshot to restore
+     */
+    @Nullable
+    byte[] getRestoreFromSnapshot();
 }
index 05405dc6dfc2457d9a33a5340a619ba8990064f3..0a37ef7a466725eb963aa6253b71c04b217deb10 100644 (file)
@@ -11,10 +11,13 @@ import akka.persistence.RecoveryCompleted;
 import akka.persistence.SnapshotOffer;
 import akka.persistence.SnapshotSelectionCriteria;
 import com.google.common.base.Stopwatch;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
 import org.opendaylight.controller.cluster.DataPersistenceProvider;
 import org.opendaylight.controller.cluster.PersistentDataProvider;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplyJournalEntries;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
+import org.opendaylight.controller.cluster.raft.base.messages.ApplySnapshot;
 import org.opendaylight.controller.cluster.raft.base.messages.DeleteEntries;
 import org.opendaylight.controller.cluster.raft.base.messages.UpdateElectionTerm;
 import org.opendaylight.controller.cluster.raft.behaviors.RaftActorBehavior;
@@ -32,6 +35,7 @@ class RaftActorRecoverySupport {
 
     private int currentRecoveryBatchCount;
     private boolean dataRecoveredWithPersistenceDisabled;
+    private boolean anyDataRecovered;
 
     private Stopwatch recoveryTimer;
     private final Logger log;
@@ -45,7 +49,9 @@ class RaftActorRecoverySupport {
     }
 
     boolean handleRecoveryMessage(Object message, PersistentDataProvider persistentProvider) {
-        log.trace("handleRecoveryMessage: {}", message);
+        log.trace("{}: handleRecoveryMessage: {}", context.getId(), message);
+
+        anyDataRecovered = anyDataRecovered || !(message instanceof RecoveryCompleted);
 
         boolean recoveryComplete = false;
         DataPersistenceProvider persistence = context.getPersistenceProvider();
@@ -74,6 +80,7 @@ class RaftActorRecoverySupport {
                 replicatedLog().removeFrom(((org.opendaylight.controller.cluster.raft.RaftActor.DeleteEntries) message).getFromIndex());
             } else if (message instanceof RecoveryCompleted) {
                 onRecoveryCompletedMessage();
+                possiblyRestoreFromSnapshot();
                 recoveryComplete = true;
             }
         } else if (message instanceof RecoveryCompleted) {
@@ -94,6 +101,8 @@ class RaftActorRecoverySupport {
                 context.getTermInformation().updateAndPersist(context.getTermInformation().getCurrentTerm(),
                         context.getTermInformation().getVotedFor());
             }
+
+            possiblyRestoreFromSnapshot();
         } else {
             boolean isServerConfigPayload = false;
             if(message instanceof ReplicatedLogEntry){
@@ -112,6 +121,29 @@ class RaftActorRecoverySupport {
         return recoveryComplete;
     }
 
+    private void possiblyRestoreFromSnapshot() {
+        byte[] restoreFromSnapshot = cohort.getRestoreFromSnapshot();
+        if(restoreFromSnapshot == null) {
+            return;
+        }
+
+        if(anyDataRecovered) {
+            log.warn("{}: The provided restore snapshot was not applied because the persistence store is not empty",
+                    context.getId());
+            return;
+        }
+
+        try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(restoreFromSnapshot))) {
+            Snapshot snapshot = (Snapshot) ois.readObject();
+
+            log.debug("{}: Deserialized restore snapshot: {}", context.getId(), snapshot);
+
+            context.getSnapshotManager().apply(new ApplySnapshot(snapshot));
+        } catch(Exception e) {
+            log.error("{}: Error deserializing snapshot restore", context.getId(), e);
+        }
+    }
+
     private ReplicatedLog replicatedLog() {
         return context.getReplicatedLog();
     }
@@ -181,7 +213,7 @@ class RaftActorRecoverySupport {
                 batchRecoveredLogEntry(logEntry);
             } else {
                 // Shouldn't happen but cover it anyway.
-                log.error("Log entry not found for index {}", i);
+                log.error("{}: Log entry not found for index {}", context.getId(), i);
                 break;
             }
         }
index 9571173175ff220aa3fb5b6b9ba757eee150ffe4..0d0c910298f7cee9975b1d460c762b603a7ff9a1 100644 (file)
@@ -387,12 +387,16 @@ public class SnapshotManager implements SnapshotState {
             if(applySnapshot != null) {
                 try {
                     Snapshot snapshot = applySnapshot.getSnapshot();
-                    applySnapshotProcedure.apply(snapshot.getState());
 
                     //clears the followers log, sets the snapshot index to ensure adjusted-index works
                     context.setReplicatedLog(ReplicatedLogImpl.newInstance(snapshot, context, currentBehavior));
                     context.setLastApplied(snapshot.getLastAppliedIndex());
                     context.setCommitIndex(snapshot.getLastAppliedIndex());
+                    context.getTermInformation().update(snapshot.getElectionTerm(), snapshot.getElectionVotedFor());
+
+                    if(snapshot.getState().length > 0 ) {
+                        applySnapshotProcedure.apply(snapshot.getState());
+                    }
 
                     applySnapshot.getCallback().onSuccess();
                 } catch (Exception e) {
index 7cd893691284b7349e4d09efacd9911fdeae4025..30ead98cb4060edefba42a5f4583113523c109f3 100644 (file)
@@ -17,7 +17,6 @@ import akka.actor.Terminated;
 import akka.dispatch.Dispatchers;
 import akka.testkit.JavaTestKit;
 import akka.testkit.TestActorRef;
-import com.google.common.base.Optional;
 import com.google.common.base.Predicate;
 import com.google.common.base.Supplier;
 import com.google.common.collect.ImmutableMap;
@@ -75,7 +74,7 @@ public abstract class AbstractRaftActorIntegrationTest extends AbstractActorTest
 
         private TestRaftActor(String id, Map<String, String> peerAddresses, ConfigParams config,
                 TestActorRef<MessageCollectorActor> collectorActor) {
-            super(id, peerAddresses, Optional.of(config), null);
+            super(builder().id(id).peerAddresses(peerAddresses).config(config));
             this.collectorActor = collectorActor;
         }
 
index f56638bc8223fa3593c92e67a59a4296068de1a2..38650e834f1eb015a0831a6078dd88bf904a22ba 100644 (file)
@@ -12,13 +12,13 @@ import static org.junit.Assert.assertEquals;
 import static org.mockito.Mockito.mock;
 import akka.actor.ActorRef;
 import akka.actor.Props;
-import akka.japi.Creator;
 import com.google.common.base.Optional;
 import com.google.common.util.concurrent.Uninterruptibles;
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
@@ -37,52 +37,29 @@ public class MockRaftActor extends RaftActor implements RaftActorRecoveryCohort,
     volatile RaftActorSnapshotCohort snapshotCohortDelegate;
     private final CountDownLatch recoveryComplete = new CountDownLatch(1);
     private final List<Object> state;
-    private ActorRef roleChangeNotifier;
+    private final ActorRef roleChangeNotifier;
     protected final CountDownLatch initializeBehaviorComplete = new CountDownLatch(1);
     private RaftActorRecoverySupport raftActorRecoverySupport;
     private RaftActorSnapshotMessageSupport snapshotMessageSupport;
+    private final byte[] restoreFromSnapshot;
+    final CountDownLatch snapshotCommitted = new CountDownLatch(1);
 
-    public static final class MockRaftActorCreator implements Creator<MockRaftActor> {
-        private static final long serialVersionUID = 1L;
-        private final Map<String, String> peerAddresses;
-        private final String id;
-        private final Optional<ConfigParams> config;
-        private final DataPersistenceProvider dataPersistenceProvider;
-        private final ActorRef roleChangeNotifier;
-        private RaftActorSnapshotMessageSupport snapshotMessageSupport;
-
-        private MockRaftActorCreator(Map<String, String> peerAddresses, String id,
-            Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider,
-            ActorRef roleChangeNotifier) {
-            this.peerAddresses = peerAddresses;
-            this.id = id;
-            this.config = config;
-            this.dataPersistenceProvider = dataPersistenceProvider;
-            this.roleChangeNotifier = roleChangeNotifier;
-        }
-
-        @Override
-        public MockRaftActor create() throws Exception {
-            MockRaftActor mockRaftActor = new MockRaftActor(id, peerAddresses, config,
-                dataPersistenceProvider);
-            mockRaftActor.roleChangeNotifier = this.roleChangeNotifier;
-            mockRaftActor.snapshotMessageSupport = snapshotMessageSupport;
-            return mockRaftActor;
-        }
-    }
-
-    public MockRaftActor(String id, Map<String, String> peerAddresses, Optional<ConfigParams> config,
-                         DataPersistenceProvider dataPersistenceProvider) {
-        super(id, peerAddresses, config, PAYLOAD_VERSION);
+    protected MockRaftActor(Builder builder) {
+        super(builder.id, builder.peerAddresses, Optional.fromNullable(builder.config), PAYLOAD_VERSION);
         state = new ArrayList<>();
         this.actorDelegate = mock(RaftActor.class);
         this.recoveryCohortDelegate = mock(RaftActorRecoveryCohort.class);
         this.snapshotCohortDelegate = mock(RaftActorSnapshotCohort.class);
-        if(dataPersistenceProvider == null){
-            setPersistence(true);
+
+        if(builder.dataPersistenceProvider == null){
+            setPersistence(builder.persistent.isPresent() ? builder.persistent.get() : true);
         } else {
-            setPersistence(dataPersistenceProvider);
+            setPersistence(builder.dataPersistenceProvider);
         }
+
+        roleChangeNotifier = builder.roleChangeNotifier;
+        snapshotMessageSupport = builder.snapshotMessageSupport;
+        restoreFromSnapshot = builder.restoreFromSnapshot;
     }
 
     public void setRaftActorRecoverySupport(RaftActorRecoverySupport support) {
@@ -134,33 +111,6 @@ public class MockRaftActor extends RaftActor implements RaftActorRecoveryCohort,
         return state;
     }
 
-    public static Props props(final String id, final Map<String, String> peerAddresses,
-            Optional<ConfigParams> config){
-        return Props.create(new MockRaftActorCreator(peerAddresses, id, config, null, null));
-    }
-
-    public static Props props(final String id, final Map<String, String> peerAddresses,
-            Optional<ConfigParams> config, RaftActorSnapshotMessageSupport snapshotMessageSupport){
-        MockRaftActorCreator creator = new MockRaftActorCreator(peerAddresses, id, config, null, null);
-        creator.snapshotMessageSupport = snapshotMessageSupport;
-        return Props.create(creator);
-    }
-
-    public static Props props(final String id, final Map<String, String> peerAddresses,
-                              Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider){
-        return Props.create(new MockRaftActorCreator(peerAddresses, id, config, dataPersistenceProvider, null));
-    }
-
-    public static Props props(final String id, final Map<String, String> peerAddresses,
-        Optional<ConfigParams> config, ActorRef roleChangeNotifier){
-        return Props.create(new MockRaftActorCreator(peerAddresses, id, config, null, roleChangeNotifier));
-    }
-
-    public static Props props(final String id, final Map<String, String> peerAddresses,
-                              Optional<ConfigParams> config, ActorRef roleChangeNotifier,
-                              DataPersistenceProvider dataPersistenceProvider){
-        return Props.create(new MockRaftActorCreator(peerAddresses, id, config, dataPersistenceProvider, roleChangeNotifier));
-    }
 
     @Override protected void applyState(ActorRef clientActor, String identifier, Object data) {
         actorDelegate.applyState(clientActor, identifier, data);
@@ -231,8 +181,8 @@ public class MockRaftActor extends RaftActor implements RaftActorRecoveryCohort,
     @Override
     public void applySnapshot(byte [] snapshot) {
         LOG.info("{}: applySnapshot called", persistenceId());
-        snapshotCohortDelegate.applySnapshot(snapshot);
         applySnapshotBytes(snapshot);
+        snapshotCohortDelegate.applySnapshot(snapshot);
     }
 
     @Override
@@ -259,6 +209,10 @@ public class MockRaftActor extends RaftActor implements RaftActorRecoveryCohort,
             super.changeCurrentBehavior((RaftActorBehavior)message);
         } else {
             super.handleCommand(message);
+
+            if(RaftActorSnapshotMessageSupport.COMMIT_SNAPSHOT.equals(message)) {
+                snapshotCommitted.countDown();
+            }
         }
     }
 
@@ -284,4 +238,79 @@ public class MockRaftActor extends RaftActor implements RaftActorRecoveryCohort,
     public ReplicatedLog getReplicatedLog(){
         return this.getRaftActorContext().getReplicatedLog();
     }
+
+    @Override
+    public byte[] getRestoreFromSnapshot() {
+        return restoreFromSnapshot;
+    }
+
+    public static Props props(final String id, final Map<String, String> peerAddresses,
+            ConfigParams config){
+        return builder().id(id).peerAddresses(peerAddresses).config(config).props();
+    }
+
+    public static Props props(final String id, final Map<String, String> peerAddresses,
+                              ConfigParams config, DataPersistenceProvider dataPersistenceProvider){
+        return builder().id(id).peerAddresses(peerAddresses).config(config).
+                dataPersistenceProvider(dataPersistenceProvider).props();
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    public static class Builder {
+        private Map<String, String> peerAddresses = Collections.emptyMap();
+        private String id;
+        private ConfigParams config;
+        private DataPersistenceProvider dataPersistenceProvider;
+        private ActorRef roleChangeNotifier;
+        private RaftActorSnapshotMessageSupport snapshotMessageSupport;
+        private byte[] restoreFromSnapshot;
+        private Optional<Boolean> persistent = Optional.absent();
+
+        public Builder id(String id) {
+            this.id = id;
+            return this;
+        }
+
+        public Builder peerAddresses(Map<String, String> peerAddresses) {
+            this.peerAddresses = peerAddresses;
+            return this;
+        }
+
+        public Builder config(ConfigParams config) {
+            this.config = config;
+            return this;
+        }
+
+        public Builder dataPersistenceProvider(DataPersistenceProvider dataPersistenceProvider) {
+            this.dataPersistenceProvider = dataPersistenceProvider;
+            return this;
+        }
+
+        public Builder roleChangeNotifier(ActorRef roleChangeNotifier) {
+            this.roleChangeNotifier = roleChangeNotifier;
+            return this;
+        }
+
+        public Builder snapshotMessageSupport(RaftActorSnapshotMessageSupport snapshotMessageSupport) {
+            this.snapshotMessageSupport = snapshotMessageSupport;
+            return this;
+        }
+
+        public Builder restoreFromSnapshot(byte[] restoreFromSnapshot) {
+            this.restoreFromSnapshot = restoreFromSnapshot;
+            return this;
+        }
+
+        public Builder persistent(Optional<Boolean> persistent) {
+            this.persistent = persistent;
+            return this;
+        }
+
+        public Props props() {
+            return Props.create(MockRaftActor.class, this);
+        }
+    }
 }
index da02e81fa82b3cd67c39785fe5669ca6d7a054c0..ddc8bed42a452767b6229d8996c84e06ac56e1b1 100644 (file)
@@ -240,7 +240,7 @@ public class RaftActorRecoverySupportTest {
         }
 
         inOrder.verify(mockCohort).applyCurrentLogRecoveryBatch();
-
+        inOrder.verify(mockCohort).getRestoreFromSnapshot();
         inOrder.verifyNoMoreInteractions();
     }
 
@@ -248,6 +248,7 @@ public class RaftActorRecoverySupportTest {
     public void testOnRecoveryCompletedWithNoRemainingBatch() {
         sendMessageToSupport(RecoveryCompleted.getInstance(), true);
 
+        verify(mockCohort).getRestoreFromSnapshot();
         verifyNoMoreInteractions(mockCohort);
     }
 
@@ -337,6 +338,7 @@ public class RaftActorRecoverySupportTest {
 
         sendMessageToSupport(RecoveryCompleted.getInstance(), true);
 
+        verify(mockCohort).getRestoreFromSnapshot();
         verifyNoMoreInteractions(mockCohort);
 
         verify(mockPersistentProvider).deleteMessages(10L);
@@ -370,6 +372,7 @@ public class RaftActorRecoverySupportTest {
 
         sendMessageToSupport(RecoveryCompleted.getInstance(), true);
 
+        verify(mockCohort).getRestoreFromSnapshot();
         verifyNoMoreInteractions(mockCohort, mockPersistentProvider);
     }
 
index 16acb410fb169702ee82cc7908f3e0b61abc5eb2..df526d8c52a0773466175b098c585ecd90646e99 100644 (file)
@@ -565,7 +565,7 @@ public class RaftActorServerConfigurationSupportTest extends AbstractActorTest {
 
         TestActorRef<MockRaftActor> noLeaderActor = actorFactory.createTestActor(
                 MockRaftActor.props(LEADER_ID, ImmutableMap.<String,String>of(FOLLOWER_ID, followerActor.path().toString()),
-                        Optional.<ConfigParams>of(configParams), NO_PERSISTENCE).withDispatcher(Dispatchers.DefaultDispatcherId()),
+                        configParams, NO_PERSISTENCE).withDispatcher(Dispatchers.DefaultDispatcherId()),
                 actorFactory.generateActorId(LEADER_ID));
         noLeaderActor.underlyingActor().waitForInitializeBehaviorComplete();
 
@@ -626,7 +626,7 @@ public class RaftActorServerConfigurationSupportTest extends AbstractActorTest {
 
         TestActorRef<MockRaftActor> followerRaftActor = actorFactory.createTestActor(
                 MockRaftActor.props(FOLLOWER_ID, ImmutableMap.<String,String>of(LEADER_ID, leaderActor.path().toString()),
-                        Optional.<ConfigParams>of(configParams), NO_PERSISTENCE).withDispatcher(Dispatchers.DefaultDispatcherId()),
+                        configParams, NO_PERSISTENCE).withDispatcher(Dispatchers.DefaultDispatcherId()),
                 actorFactory.generateActorId(FOLLOWER_ID));
         followerRaftActor.underlyingActor().waitForInitializeBehaviorComplete();
 
@@ -691,7 +691,8 @@ public class RaftActorServerConfigurationSupportTest extends AbstractActorTest {
 
         AbstractMockRaftActor(String id, Map<String, String> peerAddresses, Optional<ConfigParams> config,
                 DataPersistenceProvider dataPersistenceProvider, TestActorRef<MessageCollectorActor> collectorActor) {
-            super(id, peerAddresses, config, dataPersistenceProvider);
+            super(builder().id(id).peerAddresses(peerAddresses).config(config.get()).
+                    dataPersistenceProvider(dataPersistenceProvider));
             this.collectorActor = collectorActor;
         }
 
index 941deb5843e2749a08232ba8b1cf5a55921419fb..c2ee4a26d14e040de00fed6acc0601232ff23cca 100644 (file)
@@ -62,6 +62,7 @@ import org.opendaylight.controller.cluster.NonPersistentDataProvider;
 import org.opendaylight.controller.cluster.PersistentDataProvider;
 import org.opendaylight.controller.cluster.notifications.LeaderStateChanged;
 import org.opendaylight.controller.cluster.notifications.RoleChanged;
+import org.opendaylight.controller.cluster.raft.MockRaftActorContext.MockPayload;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplyJournalEntries;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplySnapshot;
@@ -134,7 +135,7 @@ public class RaftActorTest extends AbstractActorTest {
 
             ImmutableMap<String, String> peerAddresses = ImmutableMap.<String, String>builder().put("member1", "address").build();
             ActorRef followerActor = factory.createActor(MockRaftActor.props(persistenceId,
-                    peerAddresses, Optional.<ConfigParams>of(config)), persistenceId);
+                    peerAddresses, config), persistenceId);
 
             watch(followerActor);
 
@@ -187,7 +188,7 @@ public class RaftActorTest extends AbstractActorTest {
 
             //reinstate the actor
             TestActorRef<MockRaftActor> ref = factory.createTestActor(
-                    MockRaftActor.props(persistenceId, peerAddresses, Optional.<ConfigParams>of(config)));
+                    MockRaftActor.props(persistenceId, peerAddresses, config));
 
             MockRaftActor mockRaftActor = ref.underlyingActor();
 
@@ -221,7 +222,7 @@ public class RaftActorTest extends AbstractActorTest {
 
             TestActorRef<MockRaftActor> ref = factory.createTestActor(MockRaftActor.props(persistenceId,
                     ImmutableMap.<String, String>builder().put("member1", "address").build(),
-                    Optional.<ConfigParams>of(config), new NonPersistentDataProvider()), persistenceId);
+                    config, new NonPersistentDataProvider()), persistenceId);
 
             MockRaftActor mockRaftActor = ref.underlyingActor();
 
@@ -245,7 +246,7 @@ public class RaftActorTest extends AbstractActorTest {
 
             TestActorRef<MockRaftActor> ref = factory.createTestActor(MockRaftActor.props(persistenceId,
                     ImmutableMap.<String, String>builder().put("member1", "address").build(),
-                    Optional.<ConfigParams>of(config), new NonPersistentDataProvider()).
+                    config, new NonPersistentDataProvider()).
                             withDispatcher(Dispatchers.DefaultDispatcherId()), persistenceId);
 
             InMemoryJournal.waitForWriteMessagesComplete(persistenceId);
@@ -257,8 +258,8 @@ public class RaftActorTest extends AbstractActorTest {
 
             config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
             ref = factory.createTestActor(MockRaftActor.props(persistenceId,
-                    ImmutableMap.<String, String>builder().put("member1", "address").build(),
-                    Optional.<ConfigParams>of(config), new NonPersistentDataProvider()).
+                    ImmutableMap.<String, String>builder().put("member1", "address").build(), config,
+                    new NonPersistentDataProvider()).
                             withDispatcher(Dispatchers.DefaultDispatcherId()),
                             factory.generateActorId("follower-"));
 
@@ -284,7 +285,7 @@ public class RaftActorTest extends AbstractActorTest {
         config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
 
         TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(MockRaftActor.props(persistenceId,
-                Collections.<String, String>emptyMap(), Optional.<ConfigParams>of(config)), persistenceId);
+                Collections.<String, String>emptyMap(), config), persistenceId);
 
         MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
 
@@ -342,8 +343,8 @@ public class RaftActorTest extends AbstractActorTest {
 
         RaftActorSnapshotMessageSupport mockSupport = mock(RaftActorSnapshotMessageSupport.class);
 
-        TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(MockRaftActor.props(persistenceId,
-                Collections.<String, String>emptyMap(), Optional.<ConfigParams>of(config), mockSupport), persistenceId);
+        TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(MockRaftActor.builder().id(persistenceId).
+                config(config).snapshotMessageSupport(mockSupport).props());
 
         MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
 
@@ -400,7 +401,7 @@ public class RaftActorTest extends AbstractActorTest {
                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
 
                 TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(MockRaftActor.props(persistenceId,
-                        Collections.<String, String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                        Collections.<String, String>emptyMap(), config, dataPersistenceProvider), persistenceId);
 
                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
 
@@ -431,7 +432,7 @@ public class RaftActorTest extends AbstractActorTest {
                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
 
                 TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(MockRaftActor.props(persistenceId,
-                        Collections.<String, String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                        Collections.<String, String>emptyMap(), config, dataPersistenceProvider), persistenceId);
 
                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
 
@@ -462,9 +463,10 @@ public class RaftActorTest extends AbstractActorTest {
 
             String persistenceId = factory.generateActorId("notifier-");
 
-            TestActorRef<MockRaftActor> raftActorRef = factory.createTestActor(MockRaftActor.props(persistenceId,
-                    Collections.<String, String>emptyMap(), Optional.<ConfigParams>of(config), notifierActor,
-                    new NonPersistentDataProvider()).withDispatcher(Dispatchers.DefaultDispatcherId()), persistenceId);
+            TestActorRef<MockRaftActor> raftActorRef = factory.createTestActor(MockRaftActor.builder().id(persistenceId).
+                    config(config).roleChangeNotifier(notifierActor).dataPersistenceProvider(
+                            new NonPersistentDataProvider()).props().withDispatcher(Dispatchers.DefaultDispatcherId()),
+                            persistenceId);
 
             List<RoleChanged> matches =  MessageCollectorActor.expectMatching(notifierActor, RoleChanged.class, 3);
 
@@ -549,8 +551,9 @@ public class RaftActorTest extends AbstractActorTest {
 
             String persistenceId = factory.generateActorId("notifier-");
 
-            factory.createActor(MockRaftActor.props(persistenceId,
-                    ImmutableMap.of("leader", "fake/path"), Optional.<ConfigParams>of(config), notifierActor), persistenceId);
+            factory.createActor(MockRaftActor.builder().id(persistenceId).
+                    peerAddresses(ImmutableMap.of("leader", "fake/path")).
+                    config(config).roleChangeNotifier(notifierActor).props());
 
             List<RoleChanged> matches =  null;
             for(int i = 0; i < 5000 / heartBeatInterval; i++) {
@@ -600,8 +603,7 @@ public class RaftActorTest extends AbstractActorTest {
                 peerAddresses.put(follower1Id, followerActor1.path().toString());
 
                 TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(
-                        MockRaftActor.props(persistenceId, peerAddresses,
-                                Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                        MockRaftActor.props(persistenceId, peerAddresses, config, dataPersistenceProvider), persistenceId);
 
                 MockRaftActor leaderActor = mockActorRef.underlyingActor();
 
@@ -698,8 +700,7 @@ public class RaftActorTest extends AbstractActorTest {
                 peerAddresses.put(leaderId, leaderActor1.path().toString());
 
                 TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(
-                        MockRaftActor.props(persistenceId, peerAddresses,
-                                Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                        MockRaftActor.props(persistenceId, peerAddresses, config, dataPersistenceProvider), persistenceId);
 
                 MockRaftActor followerActor = mockActorRef.underlyingActor();
                 followerActor.getRaftActorContext().setCommitIndex(4);
@@ -807,8 +808,7 @@ public class RaftActorTest extends AbstractActorTest {
                 peerAddresses.put(follower2Id, followerActor2.path().toString());
 
                 TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(
-                        MockRaftActor.props(persistenceId, peerAddresses,
-                                Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                        MockRaftActor.props(persistenceId, peerAddresses, config, dataPersistenceProvider), persistenceId);
 
                 MockRaftActor leaderActor = mockActorRef.underlyingActor();
                 leaderActor.getRaftActorContext().setCommitIndex(9);
@@ -885,8 +885,7 @@ public class RaftActorTest extends AbstractActorTest {
             Map<String, String> peerAddresses = ImmutableMap.<String, String>builder().put("member1", "address").build();
 
             TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(
-                    MockRaftActor.props(persistenceId, peerAddresses,
-                            Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                    MockRaftActor.props(persistenceId, peerAddresses, config, dataPersistenceProvider), persistenceId);
 
             MockRaftActor leaderActor = mockActorRef.underlyingActor();
             leaderActor.getRaftActorContext().setCommitIndex(3);
@@ -933,8 +932,7 @@ public class RaftActorTest extends AbstractActorTest {
             Map<String, String> peerAddresses = ImmutableMap.<String, String>builder().put("member1", "address").build();
 
             TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(
-                    MockRaftActor.props(persistenceId, peerAddresses,
-                            Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                    MockRaftActor.props(persistenceId, peerAddresses, config, dataPersistenceProvider), persistenceId);
 
             MockRaftActor leaderActor = mockActorRef.underlyingActor();
             leaderActor.getRaftActorContext().setCommitIndex(3);
@@ -984,7 +982,7 @@ public class RaftActorTest extends AbstractActorTest {
                 InMemoryJournal.addEntry(persistenceId, 1, replLogEntry);
 
                 TestActorRef<MockRaftActor> ref = factory.createTestActor(
-                        MockRaftActor.props(persistenceId, peerAddresses, Optional.<ConfigParams>of(config)));
+                        MockRaftActor.props(persistenceId, peerAddresses, config));
 
                 MockRaftActor mockRaftActor = ref.underlyingActor();
 
@@ -1010,8 +1008,7 @@ public class RaftActorTest extends AbstractActorTest {
         Map<String, String> peerAddresses = ImmutableMap.<String, String>builder().build();
 
         TestActorRef<MockRaftActor> mockActorRef = factory.createTestActor(
-                MockRaftActor.props(persistenceId, peerAddresses,
-                        Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
+                MockRaftActor.props(persistenceId, peerAddresses, config, dataPersistenceProvider), persistenceId);
 
         MockRaftActor leaderActor = mockActorRef.underlyingActor();
 
@@ -1067,8 +1064,7 @@ public class RaftActorTest extends AbstractActorTest {
         DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
 
         TestActorRef<MockRaftActor> actorRef = factory.createTestActor(
-                MockRaftActor.props(persistenceId, peerAddresses,
-                    Optional.<ConfigParams>of(emptyConfig), dataPersistenceProvider), persistenceId);
+                MockRaftActor.props(persistenceId, peerAddresses, emptyConfig, dataPersistenceProvider), persistenceId);
         MockRaftActor mockRaftActor = actorRef.underlyingActor();
         mockRaftActor.waitForInitializeBehaviorComplete();
 
@@ -1129,7 +1125,7 @@ public class RaftActorTest extends AbstractActorTest {
                 new MockRaftActorContext.MockPayload("C")));
 
         TestActorRef<MockRaftActor> raftActorRef = factory.createTestActor(MockRaftActor.props(persistenceId,
-                ImmutableMap.<String, String>builder().put("member1", "address").build(), Optional.<ConfigParams>of(config)).
+                ImmutableMap.<String, String>builder().put("member1", "address").build(), config).
                     withDispatcher(Dispatchers.DefaultDispatcherId()), persistenceId);
         MockRaftActor mockRaftActor = raftActorRef.underlyingActor();
 
@@ -1195,4 +1191,118 @@ public class RaftActorTest extends AbstractActorTest {
 
         TEST_LOG.info("testGetSnapshot ending");
     }
+
+    @Test
+    public void testRestoreFromSnapshot() throws Exception {
+        TEST_LOG.info("testRestoreFromSnapshot starting");
+
+        String persistenceId = factory.generateActorId("test-actor-");
+        DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
+        config.setCustomRaftPolicyImplementationClass(DisableElectionsRaftPolicy.class.getName());
+
+        List<ReplicatedLogEntry> snapshotUnappliedEntries = new ArrayList<>();
+        snapshotUnappliedEntries.add(new MockRaftActorContext.MockReplicatedLogEntry(1, 4,
+                new MockRaftActorContext.MockPayload("E")));
+
+        int snapshotLastApplied = 3;
+        int snapshotLastIndex = 4;
+
+        List<MockPayload> state = Arrays.asList(
+                new MockRaftActorContext.MockPayload("A"),
+                new MockRaftActorContext.MockPayload("B"),
+                new MockRaftActorContext.MockPayload("C"),
+                new MockRaftActorContext.MockPayload("D"));
+        ByteString stateBytes = fromObject(state);
+
+        Snapshot snapshot = Snapshot.create(stateBytes.toByteArray(), snapshotUnappliedEntries,
+                snapshotLastIndex, 1, snapshotLastApplied, 1, 1, "member-1");
+
+        InMemorySnapshotStore.addSnapshotSavedLatch(persistenceId);
+
+        TestActorRef<MockRaftActor> raftActorRef = factory.createTestActor(MockRaftActor.builder().id(persistenceId).
+                config(config).restoreFromSnapshot(SerializationUtils.serialize(snapshot)).props().
+                    withDispatcher(Dispatchers.DefaultDispatcherId()), persistenceId);
+        MockRaftActor mockRaftActor = raftActorRef.underlyingActor();
+
+        mockRaftActor.waitForRecoveryComplete();
+
+        Snapshot savedSnapshot = InMemorySnapshotStore.waitForSavedSnapshot(persistenceId, Snapshot.class);
+        assertEquals("getElectionTerm", snapshot.getElectionTerm(), savedSnapshot.getElectionTerm());
+        assertEquals("getElectionVotedFor", snapshot.getElectionVotedFor(), savedSnapshot.getElectionVotedFor());
+        assertEquals("getLastAppliedIndex", snapshot.getLastAppliedIndex(), savedSnapshot.getLastAppliedIndex());
+        assertEquals("getLastAppliedTerm", snapshot.getLastAppliedTerm(), savedSnapshot.getLastAppliedTerm());
+        assertEquals("getLastIndex", snapshot.getLastIndex(), savedSnapshot.getLastIndex());
+        assertEquals("getLastTerm", snapshot.getLastTerm(), savedSnapshot.getLastTerm());
+        assertArrayEquals("getState", snapshot.getState(), savedSnapshot.getState());
+        assertEquals("getUnAppliedEntries", snapshot.getUnAppliedEntries(), savedSnapshot.getUnAppliedEntries());
+
+        verify(mockRaftActor.snapshotCohortDelegate, timeout(5000)).applySnapshot(any(byte[].class));
+
+        RaftActorContext context = mockRaftActor.getRaftActorContext();
+        assertEquals("Journal log size", 1, context.getReplicatedLog().size());
+        assertEquals("Last index", snapshotLastIndex, context.getReplicatedLog().lastIndex());
+        assertEquals("Last applied", snapshotLastApplied, context.getLastApplied());
+        assertEquals("Commit index", snapshotLastApplied, context.getCommitIndex());
+        assertEquals("Recovered state", state, mockRaftActor.getState());
+        assertEquals("Current term", 1L, context.getTermInformation().getCurrentTerm());
+        assertEquals("Voted for", "member-1", context.getTermInformation().getVotedFor());
+
+        // Test with data persistence disabled
+
+        snapshot = Snapshot.create(new byte[0], Collections.<ReplicatedLogEntry>emptyList(),
+                -1, -1, -1, -1, 5, "member-1");
+
+        persistenceId = factory.generateActorId("test-actor-");
+
+        raftActorRef = factory.createTestActor(MockRaftActor.builder().id(persistenceId).
+                config(config).restoreFromSnapshot(SerializationUtils.serialize(snapshot)).
+                persistent(Optional.of(Boolean.FALSE)).props().
+                    withDispatcher(Dispatchers.DefaultDispatcherId()), persistenceId);
+        mockRaftActor = raftActorRef.underlyingActor();
+
+        mockRaftActor.waitForRecoveryComplete();
+        assertEquals("snapshot committed", true,
+                Uninterruptibles.awaitUninterruptibly(mockRaftActor.snapshotCommitted, 5, TimeUnit.SECONDS));
+
+        context = mockRaftActor.getRaftActorContext();
+        assertEquals("Current term", 5L, context.getTermInformation().getCurrentTerm());
+        assertEquals("Voted for", "member-1", context.getTermInformation().getVotedFor());
+
+        TEST_LOG.info("testRestoreFromSnapshot ending");
+    }
+
+    @Test
+    public void testRestoreFromSnapshotWithRecoveredData() throws Exception {
+        TEST_LOG.info("testRestoreFromSnapshotWithRecoveredData starting");
+
+        String persistenceId = factory.generateActorId("test-actor-");
+        DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
+        config.setCustomRaftPolicyImplementationClass(DisableElectionsRaftPolicy.class.getName());
+
+        List<MockPayload> state = Arrays.asList(new MockRaftActorContext.MockPayload("A"));
+        Snapshot snapshot = Snapshot.create(fromObject(state).toByteArray(), Arrays.<ReplicatedLogEntry>asList(),
+                5, 2, 5, 2, 2, "member-1");
+
+        InMemoryJournal.addEntry(persistenceId, 1, new MockRaftActorContext.MockReplicatedLogEntry(1, 0,
+                new MockRaftActorContext.MockPayload("B")));
+
+        TestActorRef<MockRaftActor> raftActorRef = factory.createTestActor(MockRaftActor.builder().id(persistenceId).
+                config(config).restoreFromSnapshot(SerializationUtils.serialize(snapshot)).props().
+                    withDispatcher(Dispatchers.DefaultDispatcherId()), persistenceId);
+        MockRaftActor mockRaftActor = raftActorRef.underlyingActor();
+
+        mockRaftActor.waitForRecoveryComplete();
+
+        verify(mockRaftActor.snapshotCohortDelegate, timeout(500).never()).applySnapshot(any(byte[].class));
+
+        RaftActorContext context = mockRaftActor.getRaftActorContext();
+        assertEquals("Journal log size", 1, context.getReplicatedLog().size());
+        assertEquals("Last index", 0, context.getReplicatedLog().lastIndex());
+        assertEquals("Last applied", -1, context.getLastApplied());
+        assertEquals("Commit index", -1, context.getCommitIndex());
+        assertEquals("Current term", 0, context.getTermInformation().getCurrentTerm());
+        assertEquals("Voted for", null, context.getTermInformation().getVotedFor());
+
+        TEST_LOG.info("testRestoreFromSnapshotWithRecoveredData ending");
+    }
 }
index 3e747e387ee1f881ba0141f1042cd837430b1ae2..815047423426e559b9099150f6c85d31aafe53ce 100644 (file)
@@ -12,9 +12,7 @@ import akka.actor.ActorSystem;
 import akka.pattern.Patterns;
 import akka.testkit.JavaTestKit;
 import akka.util.Timeout;
-import com.google.common.base.Optional;
 import com.google.common.util.concurrent.Uninterruptibles;
-import java.util.Collections;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import org.junit.Assert;
@@ -31,8 +29,7 @@ public class RaftActorTestKit extends JavaTestKit {
     public RaftActorTestKit(ActorSystem actorSystem, String actorName) {
         super(actorSystem);
 
-        raftActor = this.getSystem().actorOf(MockRaftActor.props(actorName,
-                Collections.<String,String>emptyMap(), Optional.<ConfigParams>absent()), actorName);
+        raftActor = this.getSystem().actorOf(MockRaftActor.builder().id(actorName).props(), actorName);
 
     }
 
index 01f337567560ba02f7849c8b920cd5f05b18f949..e16d7949745004ca91dbf20e4d186beb8d84cf50 100644 (file)
@@ -16,12 +16,15 @@ import akka.persistence.SnapshotSelectionCriteria;
 import akka.persistence.snapshot.japi.SnapshotStore;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.Uninterruptibles;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.concurrent.Future;
@@ -36,6 +39,7 @@ public class InMemorySnapshotStore extends SnapshotStore {
     static final Logger LOG = LoggerFactory.getLogger(InMemorySnapshotStore.class);
 
     private static Map<String, List<StoredSnapshot>> snapshots = new ConcurrentHashMap<>();
+    private static final Map<String, CountDownLatch> snapshotSavedLatches = new ConcurrentHashMap<>();
 
     public static void addSnapshot(String persistentId, Object snapshot) {
         List<StoredSnapshot> snapshotList = snapshots.get(persistentId);
@@ -75,6 +79,18 @@ public class InMemorySnapshotStore extends SnapshotStore {
         snapshots.clear();
     }
 
+    public static void addSnapshotSavedLatch(String persistenceId) {
+        snapshotSavedLatches.put(persistenceId, new CountDownLatch(1));
+    }
+
+    public static <T> T waitForSavedSnapshot(String persistenceId, Class<T> type) {
+        if(!Uninterruptibles.awaitUninterruptibly(snapshotSavedLatches.get(persistenceId), 5, TimeUnit.SECONDS)) {
+            throw new AssertionError("Snapshot was not saved");
+        }
+
+        return getSnapshots(persistenceId, type).get(0);
+    }
+
     @Override
     public Future<Option<SelectedSnapshot>> doLoadAsync(String s,
         SnapshotSelectionCriteria snapshotSelectionCriteria) {
@@ -101,6 +117,11 @@ public class InMemorySnapshotStore extends SnapshotStore {
             snapshotList.add(new StoredSnapshot(snapshotMetadata, o));
         }
 
+        CountDownLatch latch = snapshotSavedLatches.get(snapshotMetadata.persistenceId());
+        if(latch != null) {
+            latch.countDown();
+        }
+
         return Futures.successful(null);
     }
 
index f8c1db987912e44e487088abb9dd3d903b0c8fee..87d591dd228286c9f7da385f6d8c42a34b1286c0 100644 (file)
@@ -125,4 +125,10 @@ class ShardRecoveryCoordinator implements RaftActorRecoveryCohort {
             log.error("{}: Failed to apply recovery snapshot", shardName, e);
         }
     }
+
+    @Override
+    public byte[] getRestoreFromSnapshot() {
+        // TODO Auto-generated method stub
+        return null;
+    }
 }