X-Git-Url: https://git.opendaylight.org/gerrit/gitweb?a=blobdiff_plain;ds=sidebyside;f=opendaylight%2Fmd-sal%2Fsal-akka-raft%2Fsrc%2Ftest%2Fjava%2Forg%2Fopendaylight%2Fcontroller%2Fcluster%2Fraft%2FRaftActorTest.java;h=fd9912244a620c6bc060cb465dd0da51449ce47d;hb=f5a373c5378af41f62a2c36ced4046fbdb77e00b;hp=998c198756d191df038de6f6cc75a8091fd0d1f1;hpb=6e2f13caf52fe4f4af8aa7c0e8ffd475d4e15fba;p=controller.git diff --git a/opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorTest.java b/opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorTest.java index 998c198756..fd9912244a 100644 --- a/opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorTest.java +++ b/opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/RaftActorTest.java @@ -4,18 +4,31 @@ import akka.actor.ActorRef; import akka.actor.ActorSystem; import akka.actor.PoisonPill; import akka.actor.Props; +import akka.actor.Terminated; import akka.event.Logging; import akka.japi.Creator; +import akka.persistence.RecoveryCompleted; +import akka.persistence.SaveSnapshotSuccess; +import akka.persistence.SnapshotMetadata; +import akka.persistence.SnapshotOffer; import akka.testkit.JavaTestKit; import akka.testkit.TestActorRef; +import com.google.common.base.Optional; +import com.google.common.collect.Lists; import com.google.protobuf.ByteString; import org.junit.After; import org.junit.Test; +import org.opendaylight.controller.cluster.DataPersistenceProvider; +import org.opendaylight.controller.cluster.datastore.DataPersistenceProviderMonitor; import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries; +import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshot; +import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshotReply; import org.opendaylight.controller.cluster.raft.client.messages.FindLeader; import org.opendaylight.controller.cluster.raft.client.messages.FindLeaderReply; +import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload; import org.opendaylight.controller.cluster.raft.utils.MockAkkaJournal; import org.opendaylight.controller.cluster.raft.utils.MockSnapshotStore; +import scala.concurrent.duration.FiniteDuration; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -27,9 +40,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; -import static junit.framework.Assert.assertTrue; -import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.mockito.Mockito.mock; public class RaftActorTest extends AbstractActorTest { @@ -42,61 +58,121 @@ public class RaftActorTest extends AbstractActorTest { public static class MockRaftActor extends RaftActor { - private boolean applySnapshotCalled = false; - private List state; + private final DataPersistenceProvider dataPersistenceProvider; - public MockRaftActor(String id, - Map peerAddresses) { - super(id, peerAddresses); + public static final class MockRaftActorCreator implements Creator { + private final Map peerAddresses; + private final String id; + private final Optional config; + private final DataPersistenceProvider dataPersistenceProvider; + + private MockRaftActorCreator(Map peerAddresses, String id, + Optional config, DataPersistenceProvider dataPersistenceProvider) { + this.peerAddresses = peerAddresses; + this.id = id; + this.config = config; + this.dataPersistenceProvider = dataPersistenceProvider; + } + + @Override + public MockRaftActor create() throws Exception { + return new MockRaftActor(id, peerAddresses, config, dataPersistenceProvider); + } + } + + private final CountDownLatch recoveryComplete = new CountDownLatch(1); + private final CountDownLatch applyRecoverySnapshot = new CountDownLatch(1); + private final CountDownLatch applyStateLatch = new CountDownLatch(1); + + private final List state; + + public MockRaftActor(String id, Map peerAddresses, Optional config, DataPersistenceProvider dataPersistenceProvider) { + super(id, peerAddresses, config); state = new ArrayList<>(); + if(dataPersistenceProvider == null){ + this.dataPersistenceProvider = new PersistentDataProvider(); + } else { + this.dataPersistenceProvider = dataPersistenceProvider; + } } - public RaftActorContext getRaftActorContext() { - return context; + public void waitForRecoveryComplete() { + try { + assertEquals("Recovery complete", true, recoveryComplete.await(5, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + e.printStackTrace(); + } } - public boolean isApplySnapshotCalled() { - return applySnapshotCalled; + public CountDownLatch getApplyRecoverySnapshotLatch(){ + return applyRecoverySnapshot; } public List getState() { return state; } - public static Props props(final String id, final Map peerAddresses){ - return Props.create(new Creator(){ + public static Props props(final String id, final Map peerAddresses, + Optional config){ + return Props.create(new MockRaftActorCreator(peerAddresses, id, config, null)); + } - @Override public MockRaftActor create() throws Exception { - return new MockRaftActor(id, peerAddresses); - } - }); + public static Props props(final String id, final Map peerAddresses, + Optional config, DataPersistenceProvider dataPersistenceProvider){ + return Props.create(new MockRaftActorCreator(peerAddresses, id, config, dataPersistenceProvider)); } + @Override protected void applyState(ActorRef clientActor, String identifier, Object data) { + applyStateLatch.countDown(); + } + + @Override + protected void startLogRecoveryBatch(int maxBatchSize) { + } + + @Override + protected void appendRecoveredLogEntry(Payload data) { state.add(data); } - @Override protected void createSnapshot() { - throw new UnsupportedOperationException("createSnapshot"); + @Override + protected void applyCurrentLogRecoveryBatch() { } - @Override protected void applySnapshot(ByteString snapshot) { - applySnapshotCalled = true; + @Override + protected void onRecoveryComplete() { + recoveryComplete.countDown(); + } + + @Override + protected void applyRecoverySnapshot(ByteString snapshot) { + applyRecoverySnapshot.countDown(); try { Object data = toObject(snapshot); + System.out.println("!!!!!applyRecoverySnapshot: "+data); if (data instanceof List) { state.addAll((List) data); } - } catch (ClassNotFoundException e) { - e.printStackTrace(); - } catch (IOException e) { + } catch (Exception e) { e.printStackTrace(); } } + @Override protected void createSnapshot() { + } + + @Override protected void applySnapshot(ByteString snapshot) { + } + @Override protected void onStateChanged() { } + @Override + protected DataPersistenceProvider persistence() { + return this.dataPersistenceProvider; + } + @Override public String persistenceId() { return this.getId(); } @@ -120,6 +196,9 @@ public class RaftActorTest extends AbstractActorTest { return obj; } + public ReplicatedLog getReplicatedLog(){ + return this.getRaftActorContext().getReplicatedLog(); + } } @@ -130,9 +209,8 @@ public class RaftActorTest extends AbstractActorTest { public RaftActorTestKit(ActorSystem actorSystem, String actorName) { super(actorSystem); - raftActor = this.getSystem() - .actorOf(MockRaftActor.props(actorName, - Collections.EMPTY_MAP), actorName); + raftActor = this.getSystem().actorOf(MockRaftActor.props(actorName, + Collections.EMPTY_MAP, Optional.absent()), actorName); } @@ -142,48 +220,27 @@ public class RaftActorTest extends AbstractActorTest { return new JavaTestKit.EventFilter(Logging.Info.class ) { + @Override protected Boolean run() { return true; } }.from(raftActor.path().toString()) - .message("Switching from state Candidate to Leader") + .message("Switching from behavior Candidate to Leader") .occurrences(1).exec(); } public void findLeader(final String expectedLeader){ + raftActor.tell(new FindLeader(), getRef()); - - new Within(duration("1 seconds")) { - protected void run() { - - raftActor.tell(new FindLeader(), getRef()); - - String s = new ExpectMsg(duration("1 seconds"), - "findLeader") { - // do not put code outside this method, will run afterwards - protected String match(Object in) { - if (in instanceof FindLeaderReply) { - return ((FindLeaderReply) in).getLeaderActor(); - } else { - throw noMatch(); - } - } - }.get();// this extracts the received message - - assertEquals(expectedLeader, s); - - } - - - }; + FindLeaderReply reply = expectMsgClass(duration("5 seconds"), FindLeaderReply.class); + assertEquals("getLeaderActor", expectedLeader, reply.getLeaderActor()); } public ActorRef getRaftActor() { return raftActor; } - } @@ -201,89 +258,423 @@ public class RaftActorTest extends AbstractActorTest { } @Test - public void testRaftActorRecovery() { + public void testRaftActorRecovery() throws Exception { new JavaTestKit(getSystem()) {{ - new Within(duration("1 seconds")) { - protected void run() { - - String persistenceId = "follower10"; - - ActorRef followerActor = getSystem().actorOf( - MockRaftActor.props(persistenceId, Collections.EMPTY_MAP), persistenceId); - - List snapshotUnappliedEntries = new ArrayList<>(); - ReplicatedLogEntry entry1 = new MockRaftActorContext.MockReplicatedLogEntry(1, 4, new MockRaftActorContext.MockPayload("E")); - snapshotUnappliedEntries.add(entry1); - - int lastAppliedDuringSnapshotCapture = 3; - int lastIndexDuringSnapshotCapture = 4; - - ByteString snapshotBytes = null; - try { - // 4 messages as part of snapshot, which are applied to state - snapshotBytes = fromObject(Arrays.asList(new MockRaftActorContext.MockPayload("A"), - new MockRaftActorContext.MockPayload("B"), - new MockRaftActorContext.MockPayload("C"), - new MockRaftActorContext.MockPayload("D"))); - } catch (Exception e) { - e.printStackTrace(); - } - Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(), - snapshotUnappliedEntries, lastIndexDuringSnapshotCapture, 1 , - lastAppliedDuringSnapshotCapture, 1); - MockSnapshotStore.setMockSnapshot(snapshot); - MockSnapshotStore.setPersistenceId(persistenceId); - - // add more entries after snapshot is taken - List entries = new ArrayList<>(); - ReplicatedLogEntry entry2 = new MockRaftActorContext.MockReplicatedLogEntry(1, 5, new MockRaftActorContext.MockPayload("F")); - ReplicatedLogEntry entry3 = new MockRaftActorContext.MockReplicatedLogEntry(1, 6, new MockRaftActorContext.MockPayload("G")); - ReplicatedLogEntry entry4 = new MockRaftActorContext.MockReplicatedLogEntry(1, 7, new MockRaftActorContext.MockPayload("H")); - entries.add(entry2); - entries.add(entry3); - entries.add(entry4); - - int lastAppliedToState = 5; - int lastIndex = 7; - - MockAkkaJournal.addToJournal(5, entry2); - // 2 entries are applied to state besides the 4 entries in snapshot - MockAkkaJournal.addToJournal(6, new ApplyLogEntries(lastAppliedToState)); - MockAkkaJournal.addToJournal(7, entry3); - MockAkkaJournal.addToJournal(8, entry4); - - // kill the actor - followerActor.tell(PoisonPill.getInstance(), null); - - try { - // give some time for actor to die - Thread.sleep(200); - } catch (InterruptedException e) { - e.printStackTrace(); - } + String persistenceId = "follower10"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + // Set the heartbeat interval high to essentially disable election otherwise the test + // may fail if the actor is switched to Leader and the commitIndex is set to the last + // log entry. + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + ActorRef followerActor = getSystem().actorOf(MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config)), persistenceId); + + watch(followerActor); + + List snapshotUnappliedEntries = new ArrayList<>(); + ReplicatedLogEntry entry1 = new MockRaftActorContext.MockReplicatedLogEntry(1, 4, + new MockRaftActorContext.MockPayload("E")); + snapshotUnappliedEntries.add(entry1); + + int lastAppliedDuringSnapshotCapture = 3; + int lastIndexDuringSnapshotCapture = 4; + + // 4 messages as part of snapshot, which are applied to state + ByteString snapshotBytes = fromObject(Arrays.asList( + new MockRaftActorContext.MockPayload("A"), + new MockRaftActorContext.MockPayload("B"), + new MockRaftActorContext.MockPayload("C"), + new MockRaftActorContext.MockPayload("D"))); + + Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(), + snapshotUnappliedEntries, lastIndexDuringSnapshotCapture, 1 , + lastAppliedDuringSnapshotCapture, 1); + MockSnapshotStore.setMockSnapshot(snapshot); + MockSnapshotStore.setPersistenceId(persistenceId); + + // add more entries after snapshot is taken + List entries = new ArrayList<>(); + ReplicatedLogEntry entry2 = new MockRaftActorContext.MockReplicatedLogEntry(1, 5, + new MockRaftActorContext.MockPayload("F")); + ReplicatedLogEntry entry3 = new MockRaftActorContext.MockReplicatedLogEntry(1, 6, + new MockRaftActorContext.MockPayload("G")); + ReplicatedLogEntry entry4 = new MockRaftActorContext.MockReplicatedLogEntry(1, 7, + new MockRaftActorContext.MockPayload("H")); + entries.add(entry2); + entries.add(entry3); + entries.add(entry4); + + int lastAppliedToState = 5; + int lastIndex = 7; + + MockAkkaJournal.addToJournal(5, entry2); + // 2 entries are applied to state besides the 4 entries in snapshot + MockAkkaJournal.addToJournal(6, new ApplyLogEntries(lastAppliedToState)); + MockAkkaJournal.addToJournal(7, entry3); + MockAkkaJournal.addToJournal(8, entry4); + + // kill the actor + followerActor.tell(PoisonPill.getInstance(), null); + expectMsgClass(duration("5 seconds"), Terminated.class); + + unwatch(followerActor); + + //reinstate the actor + TestActorRef ref = TestActorRef.create(getSystem(), + MockRaftActor.props(persistenceId, Collections.EMPTY_MAP, + Optional.of(config))); + + ref.underlyingActor().waitForRecoveryComplete(); + + RaftActorContext context = ref.underlyingActor().getRaftActorContext(); + assertEquals("Journal log size", snapshotUnappliedEntries.size() + entries.size(), + context.getReplicatedLog().size()); + assertEquals("Last index", lastIndex, context.getReplicatedLog().lastIndex()); + assertEquals("Last applied", lastAppliedToState, context.getLastApplied()); + assertEquals("Commit index", lastAppliedToState, context.getCommitIndex()); + assertEquals("Recovered state size", 6, ref.underlyingActor().getState().size()); + }}; + } - //reinstate the actor - TestActorRef ref = TestActorRef.create(getSystem(), - MockRaftActor.props(persistenceId, Collections.EMPTY_MAP)); + /** + * This test verifies that when recovery is applicable (typically when persistence is true) the RaftActor does + * process recovery messages + * + * @throws Exception + */ - try { - //give some time for snapshot offer to get called. - Thread.sleep(200); - } catch (InterruptedException e) { - e.printStackTrace(); - } + @Test + public void testHandleRecoveryWhenDataPersistenceRecoveryApplicable() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testHandleRecoveryWhenDataPersistenceRecoveryApplicable"; - RaftActorContext context = ref.underlyingActor().getRaftActorContext(); - assertEquals(snapshotUnappliedEntries.size() + entries.size(), context.getReplicatedLog().size()); - assertEquals(lastIndex, context.getReplicatedLog().lastIndex()); - assertEquals(lastAppliedToState, context.getLastApplied()); - assertEquals(lastAppliedToState, context.getCommitIndex()); - assertTrue(ref.underlyingActor().isApplySnapshotCalled()); - assertEquals(6, ref.underlyingActor().getState().size()); - } - }; - }}; + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config)), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + // Wait for akka's recovery to complete so it doesn't interfere. + mockRaftActor.waitForRecoveryComplete(); + + ByteString snapshotBytes = fromObject(Arrays.asList( + new MockRaftActorContext.MockPayload("A"), + new MockRaftActorContext.MockPayload("B"), + new MockRaftActorContext.MockPayload("C"), + new MockRaftActorContext.MockPayload("D"))); + + Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(), + Lists.newArrayList(), 3, 1 ,3, 1); + + mockRaftActor.onReceiveRecover(new SnapshotOffer(new SnapshotMetadata(persistenceId, 100, 100), snapshot)); + + CountDownLatch applyRecoverySnapshotLatch = mockRaftActor.getApplyRecoverySnapshotLatch(); + + assertEquals("apply recovery snapshot", true, applyRecoverySnapshotLatch.await(5, TimeUnit.SECONDS)); + + mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(0, 1, new MockRaftActorContext.MockPayload("A"))); + + ReplicatedLog replicatedLog = mockRaftActor.getReplicatedLog(); + + assertEquals("add replicated log entry", 1, replicatedLog.size()); + + mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(1, 1, new MockRaftActorContext.MockPayload("A"))); + + assertEquals("add replicated log entry", 2, replicatedLog.size()); + + mockRaftActor.onReceiveRecover(new ApplyLogEntries(1)); + + assertEquals("commit index 1", 1, mockRaftActor.getRaftActorContext().getCommitIndex()); + + // The snapshot had 4 items + we added 2 more items during the test + // We start removing from 5 and we should get 1 item in the replicated log + mockRaftActor.onReceiveRecover(new RaftActor.DeleteEntries(5)); + + assertEquals("remove log entries", 1, replicatedLog.size()); + + mockRaftActor.onReceiveRecover(new RaftActor.UpdateElectionTerm(10, "foobar")); + + assertEquals("election term", 10, mockRaftActor.getRaftActorContext().getTermInformation().getCurrentTerm()); + assertEquals("voted for", "foobar", mockRaftActor.getRaftActorContext().getTermInformation().getVotedFor()); + + mockRaftActor.onReceiveRecover(mock(RecoveryCompleted.class)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + + }}; + } + + /** + * This test verifies that when recovery is not applicable (typically when persistence is false) the RaftActor does + * not process recovery messages + * + * @throws Exception + */ + @Test + public void testHandleRecoveryWhenDataPersistenceRecoveryNotApplicable() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testHandleRecoveryWhenDataPersistenceRecoveryNotApplicable"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config), new DataPersistenceProviderMonitor()), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + // Wait for akka's recovery to complete so it doesn't interfere. + mockRaftActor.waitForRecoveryComplete(); + + ByteString snapshotBytes = fromObject(Arrays.asList( + new MockRaftActorContext.MockPayload("A"), + new MockRaftActorContext.MockPayload("B"), + new MockRaftActorContext.MockPayload("C"), + new MockRaftActorContext.MockPayload("D"))); + + Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(), + Lists.newArrayList(), 3, 1 ,3, 1); + + mockRaftActor.onReceiveRecover(new SnapshotOffer(new SnapshotMetadata(persistenceId, 100, 100), snapshot)); + + CountDownLatch applyRecoverySnapshotLatch = mockRaftActor.getApplyRecoverySnapshotLatch(); + + assertEquals("apply recovery snapshot", false, applyRecoverySnapshotLatch.await(1, TimeUnit.SECONDS)); + + mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(0, 1, new MockRaftActorContext.MockPayload("A"))); + + ReplicatedLog replicatedLog = mockRaftActor.getReplicatedLog(); + + assertEquals("add replicated log entry", 0, replicatedLog.size()); + + mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(1, 1, new MockRaftActorContext.MockPayload("A"))); + + assertEquals("add replicated log entry", 0, replicatedLog.size()); + + mockRaftActor.onReceiveRecover(new ApplyLogEntries(1)); + assertEquals("commit index -1", -1, mockRaftActor.getRaftActorContext().getCommitIndex()); + + mockRaftActor.onReceiveRecover(new RaftActor.DeleteEntries(2)); + + assertEquals("remove log entries", 0, replicatedLog.size()); + + mockRaftActor.onReceiveRecover(new RaftActor.UpdateElectionTerm(10, "foobar")); + + assertNotEquals("election term", 10, mockRaftActor.getRaftActorContext().getTermInformation().getCurrentTerm()); + assertNotEquals("voted for", "foobar", mockRaftActor.getRaftActorContext().getTermInformation().getVotedFor()); + + mockRaftActor.onReceiveRecover(mock(RecoveryCompleted.class)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + }}; + } + + + @Test + public void testUpdatingElectionTermCallsDataPersistence() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testUpdatingElectionTermCallsDataPersistence"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + CountDownLatch persistLatch = new CountDownLatch(1); + DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor(); + dataPersistenceProviderMonitor.setPersistLatch(persistLatch); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config), dataPersistenceProviderMonitor), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + mockRaftActor.getRaftActorContext().getTermInformation().updateAndPersist(10, "foobar"); + + assertEquals("Persist called", true, persistLatch.await(5, TimeUnit.SECONDS)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + + } + }; + } + + @Test + public void testAddingReplicatedLogEntryCallsDataPersistence() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testAddingReplicatedLogEntryCallsDataPersistence"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + CountDownLatch persistLatch = new CountDownLatch(1); + DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor(); + dataPersistenceProviderMonitor.setPersistLatch(persistLatch); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config), dataPersistenceProviderMonitor), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + mockRaftActor.getRaftActorContext().getReplicatedLog().appendAndPersist(new MockRaftActorContext.MockReplicatedLogEntry(10, 10, mock(Payload.class))); + + assertEquals("Persist called", true, persistLatch.await(5, TimeUnit.SECONDS)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + + } + }; + } + + @Test + public void testRemovingReplicatedLogEntryCallsDataPersistence() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testRemovingReplicatedLogEntryCallsDataPersistence"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + CountDownLatch persistLatch = new CountDownLatch(2); + DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor(); + dataPersistenceProviderMonitor.setPersistLatch(persistLatch); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config), dataPersistenceProviderMonitor), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + mockRaftActor.getReplicatedLog().appendAndPersist(new MockRaftActorContext.MockReplicatedLogEntry(1, 0, mock(Payload.class))); + + mockRaftActor.getRaftActorContext().getReplicatedLog().removeFromAndPersist(0); + + assertEquals("Persist called", true, persistLatch.await(5, TimeUnit.SECONDS)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + + } + }; + } + + @Test + public void testApplyLogEntriesCallsDataPersistence() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testApplyLogEntriesCallsDataPersistence"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + CountDownLatch persistLatch = new CountDownLatch(1); + DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor(); + dataPersistenceProviderMonitor.setPersistLatch(persistLatch); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config), dataPersistenceProviderMonitor), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + mockRaftActor.onReceiveCommand(new ApplyLogEntries(10)); + + assertEquals("Persist called", true, persistLatch.await(5, TimeUnit.SECONDS)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + + } + }; + } + + @Test + public void testCaptureSnapshotReplyCallsDataPersistence() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testCaptureSnapshotReplyCallsDataPersistence"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + CountDownLatch persistLatch = new CountDownLatch(1); + DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor(); + dataPersistenceProviderMonitor.setSaveSnapshotLatch(persistLatch); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config), dataPersistenceProviderMonitor), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + ByteString snapshotBytes = fromObject(Arrays.asList( + new MockRaftActorContext.MockPayload("A"), + new MockRaftActorContext.MockPayload("B"), + new MockRaftActorContext.MockPayload("C"), + new MockRaftActorContext.MockPayload("D"))); + + mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1,1,-1,1)); + + mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes)); + + assertEquals("Save Snapshot called", true, persistLatch.await(5, TimeUnit.SECONDS)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + + } + }; + } + + @Test + public void testSaveSnapshotSuccessCallsDataPersistence() throws Exception { + new JavaTestKit(getSystem()) { + { + String persistenceId = "testSaveSnapshotSuccessCallsDataPersistence"; + + DefaultConfigParamsImpl config = new DefaultConfigParamsImpl(); + + config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS)); + + CountDownLatch deleteMessagesLatch = new CountDownLatch(1); + CountDownLatch deleteSnapshotsLatch = new CountDownLatch(1); + DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor(); + dataPersistenceProviderMonitor.setDeleteMessagesLatch(deleteMessagesLatch); + dataPersistenceProviderMonitor.setDeleteSnapshotsLatch(deleteSnapshotsLatch); + + TestActorRef mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId, + Collections.EMPTY_MAP, Optional.of(config), dataPersistenceProviderMonitor), persistenceId); + + MockRaftActor mockRaftActor = mockActorRef.underlyingActor(); + + ByteString snapshotBytes = fromObject(Arrays.asList( + new MockRaftActorContext.MockPayload("A"), + new MockRaftActorContext.MockPayload("B"), + new MockRaftActorContext.MockPayload("C"), + new MockRaftActorContext.MockPayload("D"))); + + mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1,1,-1,1)); + + mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes)); + + mockRaftActor.onReceiveCommand(new SaveSnapshotSuccess(new SnapshotMetadata("foo", 100, 100))); + + assertEquals("Delete Messages called", true, deleteMessagesLatch.await(5, TimeUnit.SECONDS)); + + assertEquals("Delete Snapshots called", true, deleteSnapshotsLatch.await(5, TimeUnit.SECONDS)); + + mockActorRef.tell(PoisonPill.getInstance(), getRef()); + + } + }; } private ByteString fromObject(Object snapshot) throws Exception {