Merge "Fix potential issue with transaction timeouts"
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / test / java / org / opendaylight / controller / cluster / raft / RaftActorTest.java
1 package org.opendaylight.controller.cluster.raft;
2
3 import static org.junit.Assert.assertEquals;
4 import static org.junit.Assert.assertNotEquals;
5 import static org.junit.Assert.assertNotNull;
6 import static org.junit.Assert.assertNull;
7 import static org.junit.Assert.assertTrue;
8 import static org.mockito.Matchers.any;
9 import static org.mockito.Matchers.anyObject;
10 import static org.mockito.Matchers.eq;
11 import static org.mockito.Mockito.doReturn;
12 import static org.mockito.Mockito.mock;
13 import static org.mockito.Mockito.times;
14 import static org.mockito.Mockito.verify;
15 import akka.actor.ActorRef;
16 import akka.actor.ActorSystem;
17 import akka.actor.PoisonPill;
18 import akka.actor.Props;
19 import akka.actor.Terminated;
20 import akka.japi.Creator;
21 import akka.japi.Procedure;
22 import akka.pattern.Patterns;
23 import akka.persistence.RecoveryCompleted;
24 import akka.persistence.SaveSnapshotFailure;
25 import akka.persistence.SaveSnapshotSuccess;
26 import akka.persistence.SnapshotMetadata;
27 import akka.persistence.SnapshotOffer;
28 import akka.persistence.SnapshotSelectionCriteria;
29 import akka.testkit.JavaTestKit;
30 import akka.testkit.TestActorRef;
31 import akka.util.Timeout;
32 import com.google.common.base.Optional;
33 import com.google.common.collect.Lists;
34 import com.google.common.util.concurrent.Uninterruptibles;
35 import com.google.protobuf.ByteString;
36 import java.io.ByteArrayInputStream;
37 import java.io.ByteArrayOutputStream;
38 import java.io.IOException;
39 import java.io.ObjectInputStream;
40 import java.io.ObjectOutputStream;
41 import java.util.ArrayList;
42 import java.util.Arrays;
43 import java.util.Collections;
44 import java.util.List;
45 import java.util.Map;
46 import java.util.concurrent.CountDownLatch;
47 import java.util.concurrent.TimeUnit;
48 import java.util.concurrent.TimeoutException;
49 import org.junit.After;
50 import org.junit.Assert;
51 import org.junit.Test;
52 import org.opendaylight.controller.cluster.DataPersistenceProvider;
53 import org.opendaylight.controller.cluster.datastore.DataPersistenceProviderMonitor;
54 import org.opendaylight.controller.cluster.notifications.RoleChanged;
55 import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
56 import org.opendaylight.controller.cluster.raft.base.messages.ApplySnapshot;
57 import org.opendaylight.controller.cluster.raft.base.messages.ApplyState;
58 import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshot;
59 import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshotReply;
60 import org.opendaylight.controller.cluster.raft.behaviors.Follower;
61 import org.opendaylight.controller.cluster.raft.behaviors.Leader;
62 import org.opendaylight.controller.cluster.raft.client.messages.FindLeader;
63 import org.opendaylight.controller.cluster.raft.client.messages.FindLeaderReply;
64 import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload;
65 import org.opendaylight.controller.cluster.raft.utils.MessageCollectorActor;
66 import org.opendaylight.controller.cluster.raft.utils.MockAkkaJournal;
67 import org.opendaylight.controller.cluster.raft.utils.MockSnapshotStore;
68 import scala.concurrent.Await;
69 import scala.concurrent.Future;
70 import scala.concurrent.duration.Duration;
71 import scala.concurrent.duration.FiniteDuration;
72
73 public class RaftActorTest extends AbstractActorTest {
74
75
76     @After
77     public void tearDown() {
78         MockAkkaJournal.clearJournal();
79         MockSnapshotStore.setMockSnapshot(null);
80     }
81
82     public static class MockRaftActor extends RaftActor {
83
84         private final DataPersistenceProvider dataPersistenceProvider;
85         private final RaftActor delegate;
86         private final CountDownLatch recoveryComplete = new CountDownLatch(1);
87         private final List<Object> state;
88         private ActorRef roleChangeNotifier;
89
90         public static final class MockRaftActorCreator implements Creator<MockRaftActor> {
91             private static final long serialVersionUID = 1L;
92             private final Map<String, String> peerAddresses;
93             private final String id;
94             private final Optional<ConfigParams> config;
95             private final DataPersistenceProvider dataPersistenceProvider;
96             private final ActorRef roleChangeNotifier;
97
98             private MockRaftActorCreator(Map<String, String> peerAddresses, String id,
99                 Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider,
100                 ActorRef roleChangeNotifier) {
101                 this.peerAddresses = peerAddresses;
102                 this.id = id;
103                 this.config = config;
104                 this.dataPersistenceProvider = dataPersistenceProvider;
105                 this.roleChangeNotifier = roleChangeNotifier;
106             }
107
108             @Override
109             public MockRaftActor create() throws Exception {
110                 MockRaftActor mockRaftActor = new MockRaftActor(id, peerAddresses, config,
111                     dataPersistenceProvider);
112                 mockRaftActor.roleChangeNotifier = this.roleChangeNotifier;
113                 return mockRaftActor;
114             }
115         }
116
117         public MockRaftActor(String id, Map<String, String> peerAddresses, Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider) {
118             super(id, peerAddresses, config);
119             state = new ArrayList<>();
120             this.delegate = mock(RaftActor.class);
121             if(dataPersistenceProvider == null){
122                 this.dataPersistenceProvider = new PersistentDataProvider();
123             } else {
124                 this.dataPersistenceProvider = dataPersistenceProvider;
125             }
126         }
127
128         public void waitForRecoveryComplete() {
129             try {
130                 assertEquals("Recovery complete", true, recoveryComplete.await(5,  TimeUnit.SECONDS));
131             } catch (InterruptedException e) {
132                 e.printStackTrace();
133             }
134         }
135
136         public List<Object> getState() {
137             return state;
138         }
139
140         public static Props props(final String id, final Map<String, String> peerAddresses,
141                 Optional<ConfigParams> config){
142             return Props.create(new MockRaftActorCreator(peerAddresses, id, config, null, null));
143         }
144
145         public static Props props(final String id, final Map<String, String> peerAddresses,
146                                   Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider){
147             return Props.create(new MockRaftActorCreator(peerAddresses, id, config, dataPersistenceProvider, null));
148         }
149
150         public static Props props(final String id, final Map<String, String> peerAddresses,
151             Optional<ConfigParams> config, ActorRef roleChangeNotifier){
152             return Props.create(new MockRaftActorCreator(peerAddresses, id, config, null, roleChangeNotifier));
153         }
154
155         @Override protected void applyState(ActorRef clientActor, String identifier, Object data) {
156             delegate.applyState(clientActor, identifier, data);
157             LOG.info("applyState called");
158         }
159
160         @Override
161         protected void startLogRecoveryBatch(int maxBatchSize) {
162         }
163
164         @Override
165         protected void appendRecoveredLogEntry(Payload data) {
166             state.add(data);
167         }
168
169         @Override
170         protected void applyCurrentLogRecoveryBatch() {
171         }
172
173         @Override
174         protected void onRecoveryComplete() {
175             delegate.onRecoveryComplete();
176             recoveryComplete.countDown();
177         }
178
179         @Override
180         protected void applyRecoverySnapshot(byte[] bytes) {
181             delegate.applyRecoverySnapshot(bytes);
182             try {
183                 Object data = toObject(bytes);
184                 if (data instanceof List) {
185                     state.addAll((List<?>) data);
186                 }
187             } catch (Exception e) {
188                 e.printStackTrace();
189             }
190         }
191
192         @Override protected void createSnapshot() {
193             delegate.createSnapshot();
194         }
195
196         @Override protected void applySnapshot(byte [] snapshot) {
197             delegate.applySnapshot(snapshot);
198         }
199
200         @Override protected void onStateChanged() {
201             delegate.onStateChanged();
202         }
203
204         @Override
205         protected DataPersistenceProvider persistence() {
206             return this.dataPersistenceProvider;
207         }
208
209         @Override
210         protected Optional<ActorRef> getRoleChangeNotifier() {
211             return Optional.fromNullable(roleChangeNotifier);
212         }
213
214         @Override public String persistenceId() {
215             return this.getId();
216         }
217
218         private Object toObject(byte[] bs) throws ClassNotFoundException, IOException {
219             Object obj = null;
220             ByteArrayInputStream bis = null;
221             ObjectInputStream ois = null;
222             try {
223                 bis = new ByteArrayInputStream(bs);
224                 ois = new ObjectInputStream(bis);
225                 obj = ois.readObject();
226             } finally {
227                 if (bis != null) {
228                     bis.close();
229                 }
230                 if (ois != null) {
231                     ois.close();
232                 }
233             }
234             return obj;
235         }
236
237         public ReplicatedLog getReplicatedLog(){
238             return this.getRaftActorContext().getReplicatedLog();
239         }
240
241     }
242
243
244     private static class RaftActorTestKit extends JavaTestKit {
245         private final ActorRef raftActor;
246
247         public RaftActorTestKit(ActorSystem actorSystem, String actorName) {
248             super(actorSystem);
249
250             raftActor = this.getSystem().actorOf(MockRaftActor.props(actorName,
251                     Collections.<String,String>emptyMap(), Optional.<ConfigParams>absent()), actorName);
252
253         }
254
255
256         public ActorRef getRaftActor() {
257             return raftActor;
258         }
259
260         public boolean waitForLogMessage(final Class<?> logEventClass, String message){
261             // Wait for a specific log message to show up
262             return
263                 new JavaTestKit.EventFilter<Boolean>(logEventClass
264                 ) {
265                     @Override
266                     protected Boolean run() {
267                         return true;
268                     }
269                 }.from(raftActor.path().toString())
270                     .message(message)
271                     .occurrences(1).exec();
272
273
274         }
275
276         protected void waitUntilLeader(){
277             waitUntilLeader(raftActor);
278         }
279
280         protected void waitUntilLeader(ActorRef actorRef) {
281             FiniteDuration duration = Duration.create(100, TimeUnit.MILLISECONDS);
282             for(int i = 0; i < 20 * 5; i++) {
283                 Future<Object> future = Patterns.ask(actorRef, new FindLeader(), new Timeout(duration));
284                 try {
285                     FindLeaderReply resp = (FindLeaderReply) Await.result(future, duration);
286                     if(resp.getLeaderActor() != null) {
287                         return;
288                     }
289                 } catch(TimeoutException e) {
290                 } catch(Exception e) {
291                     System.err.println("FindLeader threw ex");
292                     e.printStackTrace();
293                 }
294
295
296                 Uninterruptibles.sleepUninterruptibly(50, TimeUnit.MILLISECONDS);
297             }
298
299             Assert.fail("Leader not found for actorRef " + actorRef.path());
300         }
301
302     }
303
304
305     @Test
306     public void testConstruction() {
307         new RaftActorTestKit(getSystem(), "testConstruction").waitUntilLeader();
308     }
309
310     @Test
311     public void testFindLeaderWhenLeaderIsSelf(){
312         RaftActorTestKit kit = new RaftActorTestKit(getSystem(), "testFindLeader");
313         kit.waitUntilLeader();
314     }
315
316     @Test
317     public void testRaftActorRecovery() throws Exception {
318         new JavaTestKit(getSystem()) {{
319             String persistenceId = "follower10";
320
321             DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
322             // Set the heartbeat interval high to essentially disable election otherwise the test
323             // may fail if the actor is switched to Leader and the commitIndex is set to the last
324             // log entry.
325             config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
326
327             ActorRef followerActor = getSystem().actorOf(MockRaftActor.props(persistenceId,
328                     Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config)), persistenceId);
329
330             watch(followerActor);
331
332             List<ReplicatedLogEntry> snapshotUnappliedEntries = new ArrayList<>();
333             ReplicatedLogEntry entry1 = new MockRaftActorContext.MockReplicatedLogEntry(1, 4,
334                     new MockRaftActorContext.MockPayload("E"));
335             snapshotUnappliedEntries.add(entry1);
336
337             int lastAppliedDuringSnapshotCapture = 3;
338             int lastIndexDuringSnapshotCapture = 4;
339
340                 // 4 messages as part of snapshot, which are applied to state
341             ByteString snapshotBytes  = fromObject(Arrays.asList(
342                         new MockRaftActorContext.MockPayload("A"),
343                         new MockRaftActorContext.MockPayload("B"),
344                         new MockRaftActorContext.MockPayload("C"),
345                         new MockRaftActorContext.MockPayload("D")));
346
347             Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(),
348                     snapshotUnappliedEntries, lastIndexDuringSnapshotCapture, 1 ,
349                     lastAppliedDuringSnapshotCapture, 1);
350             MockSnapshotStore.setMockSnapshot(snapshot);
351             MockSnapshotStore.setPersistenceId(persistenceId);
352
353             // add more entries after snapshot is taken
354             List<ReplicatedLogEntry> entries = new ArrayList<>();
355             ReplicatedLogEntry entry2 = new MockRaftActorContext.MockReplicatedLogEntry(1, 5,
356                     new MockRaftActorContext.MockPayload("F"));
357             ReplicatedLogEntry entry3 = new MockRaftActorContext.MockReplicatedLogEntry(1, 6,
358                     new MockRaftActorContext.MockPayload("G"));
359             ReplicatedLogEntry entry4 = new MockRaftActorContext.MockReplicatedLogEntry(1, 7,
360                     new MockRaftActorContext.MockPayload("H"));
361             entries.add(entry2);
362             entries.add(entry3);
363             entries.add(entry4);
364
365             int lastAppliedToState = 5;
366             int lastIndex = 7;
367
368             MockAkkaJournal.addToJournal(5, entry2);
369             // 2 entries are applied to state besides the 4 entries in snapshot
370             MockAkkaJournal.addToJournal(6, new ApplyLogEntries(lastAppliedToState));
371             MockAkkaJournal.addToJournal(7, entry3);
372             MockAkkaJournal.addToJournal(8, entry4);
373
374             // kill the actor
375             followerActor.tell(PoisonPill.getInstance(), null);
376             expectMsgClass(duration("5 seconds"), Terminated.class);
377
378             unwatch(followerActor);
379
380             //reinstate the actor
381             TestActorRef<MockRaftActor> ref = TestActorRef.create(getSystem(),
382                     MockRaftActor.props(persistenceId, Collections.<String,String>emptyMap(),
383                             Optional.<ConfigParams>of(config)));
384
385             ref.underlyingActor().waitForRecoveryComplete();
386
387             RaftActorContext context = ref.underlyingActor().getRaftActorContext();
388             assertEquals("Journal log size", snapshotUnappliedEntries.size() + entries.size(),
389                     context.getReplicatedLog().size());
390             assertEquals("Last index", lastIndex, context.getReplicatedLog().lastIndex());
391             assertEquals("Last applied", lastAppliedToState, context.getLastApplied());
392             assertEquals("Commit index", lastAppliedToState, context.getCommitIndex());
393             assertEquals("Recovered state size", 6, ref.underlyingActor().getState().size());
394         }};
395     }
396
/**
 * This test verifies that when recovery is applicable (typically when persistence is true) the RaftActor does
 * process recovery messages.
 *
 * <p>No DataPersistenceProvider is passed to {@code MockRaftActor.props} here, so the actor uses a
 * {@code PersistentDataProvider}. After Akka's own recovery has finished, recovery messages are fed
 * directly to {@code onReceiveRecover} and each effect is asserted:
 * <ul>
 * <li>the snapshot is handed to {@code applyRecoverySnapshot};</li>
 * <li>log entries are appended;</li>
 * <li>{@code ApplyLogEntries} sets the commit index;</li>
 * <li>{@code DeleteEntries} trims the log;</li>
 * <li>{@code UpdateElectionTerm} updates the term information.</li>
 * </ul>
 *
 * @throws Exception if any step of the test fails with an exception
 */

@Test
public void testHandleRecoveryWhenDataPersistenceRecoveryApplicable() throws Exception {
    new JavaTestKit(getSystem()) {
        {
            String persistenceId = "testHandleRecoveryWhenDataPersistenceRecoveryApplicable";

            DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();

            // Very long heartbeat interval so that elections do not run during the test.
            config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));

            TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
                    Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config)), persistenceId);

            MockRaftActor mockRaftActor = mockActorRef.underlyingActor();

            // Wait for akka's recovery to complete so it doesn't interfere.
            mockRaftActor.waitForRecoveryComplete();

            ByteString snapshotBytes  = fromObject(Arrays.asList(
                    new MockRaftActorContext.MockPayload("A"),
                    new MockRaftActorContext.MockPayload("B"),
                    new MockRaftActorContext.MockPayload("C"),
                    new MockRaftActorContext.MockPayload("D")));

            // No unapplied entries; 3 is passed in the same argument positions that
            // testRaftActorRecovery uses for the last index and the last applied index.
            Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(),
                    Lists.<ReplicatedLogEntry>newArrayList(), 3, 1 ,3, 1);

            mockRaftActor.onReceiveRecover(new SnapshotOffer(new SnapshotMetadata(persistenceId, 100, 100), snapshot));

            // The snapshot state bytes must have been forwarded unchanged.
            verify(mockRaftActor.delegate).applyRecoverySnapshot(eq(snapshotBytes.toByteArray()));

            // Each recovered entry should be appended to the replicated log.
            mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(0, 1, new MockRaftActorContext.MockPayload("A")));

            ReplicatedLog replicatedLog = mockRaftActor.getReplicatedLog();

            assertEquals("add replicated log entry", 1, replicatedLog.size());

            mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(1, 1, new MockRaftActorContext.MockPayload("A")));

            assertEquals("add replicated log entry", 2, replicatedLog.size());

            mockRaftActor.onReceiveRecover(new ApplyLogEntries(1));

            assertEquals("commit index 1", 1, mockRaftActor.getRaftActorContext().getCommitIndex());

            // The snapshot had 4 items + we added 2 more items during the test
            // We start removing from 5 and we should get 1 item in the replicated log
            mockRaftActor.onReceiveRecover(new RaftActor.DeleteEntries(5));

            assertEquals("remove log entries", 1, replicatedLog.size());

            mockRaftActor.onReceiveRecover(new RaftActor.UpdateElectionTerm(10, "foobar"));

            assertEquals("election term", 10, mockRaftActor.getRaftActorContext().getTermInformation().getCurrentTerm());
            assertEquals("voted for", "foobar", mockRaftActor.getRaftActorContext().getTermInformation().getVotedFor());

            // Signal the end of recovery with a Mockito stand-in for RecoveryCompleted.
            mockRaftActor.onReceiveRecover(mock(RecoveryCompleted.class));

            mockActorRef.tell(PoisonPill.getInstance(), getRef());

        }};
}
466
/**
 * This test verifies that when recovery is not applicable (typically when persistence is false) the RaftActor does
 * not process recovery messages.
 *
 * <p>After Akka's own recovery has finished, the same recovery messages as in the "applicable" test
 * are fed to {@code onReceiveRecover}. The test asserts that none of them takes effect:
 * <ul>
 * <li>{@code applyRecoverySnapshot} is never called;</li>
 * <li>the replicated log stays empty;</li>
 * <li>the commit index stays -1;</li>
 * <li>the election term and vote are not updated.</li>
 * </ul>
 *
 * <p>NOTE(review): this relies on {@code DataPersistenceProviderMonitor} reporting recovery as not
 * applicable; confirm in that class.
 *
 * @throws Exception if any step of the test fails with an exception
 */
@Test
public void testHandleRecoveryWhenDataPersistenceRecoveryNotApplicable() throws Exception {
    new JavaTestKit(getSystem()) {
        {
            String persistenceId = "testHandleRecoveryWhenDataPersistenceRecoveryNotApplicable";

            DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();

            // Very long heartbeat interval so that elections do not run during the test.
            config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));

            TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
                    Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), new DataPersistenceProviderMonitor()), persistenceId);

            MockRaftActor mockRaftActor = mockActorRef.underlyingActor();

            // Wait for akka's recovery to complete so it doesn't interfere.
            mockRaftActor.waitForRecoveryComplete();

            ByteString snapshotBytes  = fromObject(Arrays.asList(
                    new MockRaftActorContext.MockPayload("A"),
                    new MockRaftActorContext.MockPayload("B"),
                    new MockRaftActorContext.MockPayload("C"),
                    new MockRaftActorContext.MockPayload("D")));

            Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(),
                    Lists.<ReplicatedLogEntry>newArrayList(), 3, 1 ,3, 1);

            mockRaftActor.onReceiveRecover(new SnapshotOffer(new SnapshotMetadata(persistenceId, 100, 100), snapshot));

            // The snapshot must be ignored entirely.
            verify(mockRaftActor.delegate, times(0)).applyRecoverySnapshot(any(byte[].class));

            mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(0, 1, new MockRaftActorContext.MockPayload("A")));

            ReplicatedLog replicatedLog = mockRaftActor.getReplicatedLog();

            // Recovered entries must not be appended.
            assertEquals("add replicated log entry", 0, replicatedLog.size());

            mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(1, 1, new MockRaftActorContext.MockPayload("A")));

            assertEquals("add replicated log entry", 0, replicatedLog.size());

            mockRaftActor.onReceiveRecover(new ApplyLogEntries(1));

            // Commit index is expected to remain at its initial value.
            assertEquals("commit index -1", -1, mockRaftActor.getRaftActorContext().getCommitIndex());

            mockRaftActor.onReceiveRecover(new RaftActor.DeleteEntries(2));

            assertEquals("remove log entries", 0, replicatedLog.size());

            mockRaftActor.onReceiveRecover(new RaftActor.UpdateElectionTerm(10, "foobar"));

            // Term information must not pick up the recovered values.
            assertNotEquals("election term", 10, mockRaftActor.getRaftActorContext().getTermInformation().getCurrentTerm());
            assertNotEquals("voted for", "foobar", mockRaftActor.getRaftActorContext().getTermInformation().getVotedFor());

            // Signal the end of recovery with a Mockito stand-in for RecoveryCompleted.
            mockRaftActor.onReceiveRecover(mock(RecoveryCompleted.class));

            mockActorRef.tell(PoisonPill.getInstance(), getRef());
        }};
}
532
533
534     @Test
535     public void testUpdatingElectionTermCallsDataPersistence() throws Exception {
536         new JavaTestKit(getSystem()) {
537             {
538                 String persistenceId = "testUpdatingElectionTermCallsDataPersistence";
539
540                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
541
542                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
543
544                 CountDownLatch persistLatch = new CountDownLatch(1);
545                 DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor();
546                 dataPersistenceProviderMonitor.setPersistLatch(persistLatch);
547
548                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
549                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProviderMonitor), persistenceId);
550
551                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
552
553                 mockRaftActor.getRaftActorContext().getTermInformation().updateAndPersist(10, "foobar");
554
555                 assertEquals("Persist called", true, persistLatch.await(5, TimeUnit.SECONDS));
556
557                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
558
559             }
560         };
561     }
562
563     @Test
564     public void testAddingReplicatedLogEntryCallsDataPersistence() throws Exception {
565         new JavaTestKit(getSystem()) {
566             {
567                 String persistenceId = "testAddingReplicatedLogEntryCallsDataPersistence";
568
569                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
570
571                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
572
573                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
574
575                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
576                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
577
578                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
579
580                 MockRaftActorContext.MockReplicatedLogEntry logEntry = new MockRaftActorContext.MockReplicatedLogEntry(10, 10, mock(Payload.class));
581
582                 mockRaftActor.getRaftActorContext().getReplicatedLog().appendAndPersist(logEntry);
583
584                 verify(dataPersistenceProvider).persist(eq(logEntry), any(Procedure.class));
585
586                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
587
588             }
589         };
590     }
591
592     @Test
593     public void testRemovingReplicatedLogEntryCallsDataPersistence() throws Exception {
594         new JavaTestKit(getSystem()) {
595             {
596                 String persistenceId = "testRemovingReplicatedLogEntryCallsDataPersistence";
597
598                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
599
600                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
601
602                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
603
604                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
605                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
606
607                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
608
609                 mockRaftActor.getReplicatedLog().appendAndPersist(new MockRaftActorContext.MockReplicatedLogEntry(1, 0, mock(Payload.class)));
610
611                 mockRaftActor.getRaftActorContext().getReplicatedLog().removeFromAndPersist(0);
612
613                 verify(dataPersistenceProvider, times(2)).persist(anyObject(), any(Procedure.class));
614
615                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
616
617             }
618         };
619     }
620
621     @Test
622     public void testApplyLogEntriesCallsDataPersistence() throws Exception {
623         new JavaTestKit(getSystem()) {
624             {
625                 String persistenceId = "testApplyLogEntriesCallsDataPersistence";
626
627                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
628
629                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
630
631                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
632
633                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
634                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
635
636                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
637
638                 mockRaftActor.onReceiveCommand(new ApplyLogEntries(10));
639
640                 verify(dataPersistenceProvider, times(1)).persist(anyObject(), any(Procedure.class));
641
642                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
643
644             }
645         };
646     }
647
648     @Test
649     public void testCaptureSnapshotReplyCallsDataPersistence() throws Exception {
650         new JavaTestKit(getSystem()) {
651             {
652                 String persistenceId = "testCaptureSnapshotReplyCallsDataPersistence";
653
654                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
655
656                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
657
658                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
659
660                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(),
661                     MockRaftActor.props(persistenceId,Collections.<String,String>emptyMap(),
662                         Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
663
664                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
665
666                 ByteString snapshotBytes  = fromObject(Arrays.asList(
667                         new MockRaftActorContext.MockPayload("A"),
668                         new MockRaftActorContext.MockPayload("B"),
669                         new MockRaftActorContext.MockPayload("C"),
670                         new MockRaftActorContext.MockPayload("D")));
671
672                 mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1,1,-1,1));
673
674                 RaftActorContext raftActorContext = mockRaftActor.getRaftActorContext();
675
676                 mockRaftActor.setCurrentBehavior(new Leader(raftActorContext));
677
678                 mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes.toByteArray()));
679
680                 verify(dataPersistenceProvider).saveSnapshot(anyObject());
681
682                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
683
684             }
685         };
686     }
687
688     @Test
689     public void testSaveSnapshotSuccessCallsDataPersistence() throws Exception {
690         new JavaTestKit(getSystem()) {
691             {
692                 String persistenceId = "testSaveSnapshotSuccessCallsDataPersistence";
693
694                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
695
696                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
697
698                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
699
700                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
701                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
702
703                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
704
705                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,0, mock(Payload.class)));
706                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,1, mock(Payload.class)));
707                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,2, mock(Payload.class)));
708                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,3, mock(Payload.class)));
709                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,4, mock(Payload.class)));
710
711                 ByteString snapshotBytes = fromObject(Arrays.asList(
712                         new MockRaftActorContext.MockPayload("A"),
713                         new MockRaftActorContext.MockPayload("B"),
714                         new MockRaftActorContext.MockPayload("C"),
715                         new MockRaftActorContext.MockPayload("D")));
716
717                 RaftActorContext raftActorContext = mockRaftActor.getRaftActorContext();
718                 mockRaftActor.setCurrentBehavior(new Follower(raftActorContext));
719
720                 mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1, 1, 2, 1));
721
722                 verify(mockRaftActor.delegate).createSnapshot();
723
724                 mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes.toByteArray()));
725
726                 mockRaftActor.onReceiveCommand(new SaveSnapshotSuccess(new SnapshotMetadata("foo", 100, 100)));
727
728                 verify(dataPersistenceProvider).deleteSnapshots(any(SnapshotSelectionCriteria.class));
729
730                 verify(dataPersistenceProvider).deleteMessages(100);
731
732                 assertEquals(2, mockRaftActor.getReplicatedLog().size());
733
734                 assertNotNull(mockRaftActor.getReplicatedLog().get(3));
735                 assertNotNull(mockRaftActor.getReplicatedLog().get(4));
736
737                 // Index 2 will not be in the log because it was removed due to snapshotting
738                 assertNull(mockRaftActor.getReplicatedLog().get(2));
739
740                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
741
742             }
743         };
744     }
745
746     @Test
747     public void testApplyState() throws Exception {
748
749         new JavaTestKit(getSystem()) {
750             {
751                 String persistenceId = "testApplyState";
752
753                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
754
755                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
756
757                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
758
759                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
760                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
761
762                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
763
764                 ReplicatedLogEntry entry = new MockRaftActorContext.MockReplicatedLogEntry(1, 5,
765                         new MockRaftActorContext.MockPayload("F"));
766
767                 mockRaftActor.onReceiveCommand(new ApplyState(mockActorRef, "apply-state", entry));
768
769                 verify(mockRaftActor.delegate).applyState(eq(mockActorRef), eq("apply-state"), anyObject());
770
771                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
772
773             }
774         };
775     }
776
777     @Test
778     public void testApplySnapshot() throws Exception {
779         new JavaTestKit(getSystem()) {
780             {
781                 String persistenceId = "testApplySnapshot";
782
783                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
784
785                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
786
787                 DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor();
788
789                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
790                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProviderMonitor), persistenceId);
791
792                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
793
794                 ReplicatedLog oldReplicatedLog = mockRaftActor.getReplicatedLog();
795
796                 oldReplicatedLog.append(new MockRaftActorContext.MockReplicatedLogEntry(1,0,mock(Payload.class)));
797                 oldReplicatedLog.append(new MockRaftActorContext.MockReplicatedLogEntry(1,1,mock(Payload.class)));
798                 oldReplicatedLog.append(
799                     new MockRaftActorContext.MockReplicatedLogEntry(1, 2,
800                         mock(Payload.class)));
801
802                 ByteString snapshotBytes = fromObject(Arrays.asList(
803                     new MockRaftActorContext.MockPayload("A"),
804                     new MockRaftActorContext.MockPayload("B"),
805                     new MockRaftActorContext.MockPayload("C"),
806                     new MockRaftActorContext.MockPayload("D")));
807
808                 Snapshot snapshot = mock(Snapshot.class);
809
810                 doReturn(snapshotBytes.toByteArray()).when(snapshot).getState();
811
812                 doReturn(3L).when(snapshot).getLastAppliedIndex();
813
814                 mockRaftActor.onReceiveCommand(new ApplySnapshot(snapshot));
815
816                 verify(mockRaftActor.delegate).applySnapshot(eq(snapshot.getState()));
817
818                 assertTrue("The replicatedLog should have changed",
819                     oldReplicatedLog != mockRaftActor.getReplicatedLog());
820
821                 assertEquals("lastApplied should be same as in the snapshot",
822                     (Long) 3L, mockRaftActor.getLastApplied());
823
824                 assertEquals(0, mockRaftActor.getReplicatedLog().size());
825
826                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
827
828             }
829         };
830     }
831
832     @Test
833     public void testSaveSnapshotFailure() throws Exception {
834         new JavaTestKit(getSystem()) {
835             {
836                 String persistenceId = "testSaveSnapshotFailure";
837
838                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
839
840                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
841
842                 DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor();
843
844                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
845                         Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), dataPersistenceProviderMonitor), persistenceId);
846
847                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
848
849                 ByteString snapshotBytes  = fromObject(Arrays.asList(
850                         new MockRaftActorContext.MockPayload("A"),
851                         new MockRaftActorContext.MockPayload("B"),
852                         new MockRaftActorContext.MockPayload("C"),
853                         new MockRaftActorContext.MockPayload("D")));
854
855                 RaftActorContext raftActorContext = mockRaftActor.getRaftActorContext();
856
857                 mockRaftActor.setCurrentBehavior(new Leader(raftActorContext));
858
859                 mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1,1,-1,1));
860
861                 mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes.toByteArray()));
862
863                 mockRaftActor.onReceiveCommand(new SaveSnapshotFailure(new SnapshotMetadata("foobar", 10L, 1234L),
864                         new Exception()));
865
866                 assertEquals("Snapshot index should not have advanced because save snapshot failed", -1,
867                         mockRaftActor.getReplicatedLog().getSnapshotIndex());
868
869                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
870
871             }
872         };
873     }
874
875     @Test
876     public void testRaftRoleChangeNotifier() throws Exception {
877         new JavaTestKit(getSystem()) {{
878             ActorRef notifierActor = getSystem().actorOf(Props.create(MessageCollectorActor.class));
879             DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
880             String id = "testRaftRoleChangeNotifier";
881
882             TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(id,
883                 Collections.<String,String>emptyMap(), Optional.<ConfigParams>of(config), notifierActor), id);
884
885             // sleeping for a minimum of 2 seconds, if it spans more its fine.
886             Uninterruptibles.sleepUninterruptibly(2, TimeUnit.SECONDS);
887
888             List<Object> matches =  MessageCollectorActor.getAllMatching(notifierActor, RoleChanged.class);
889             assertNotNull(matches);
890             assertEquals(3, matches.size());
891
892             // check if the notifier got a role change from null to Follower
893             RoleChanged raftRoleChanged = (RoleChanged) matches.get(0);
894             assertEquals(id, raftRoleChanged.getMemberId());
895             assertNull(raftRoleChanged.getOldRole());
896             assertEquals(RaftState.Follower.name(), raftRoleChanged.getNewRole());
897
898             // check if the notifier got a role change from Follower to Candidate
899             raftRoleChanged = (RoleChanged) matches.get(1);
900             assertEquals(id, raftRoleChanged.getMemberId());
901             assertEquals(RaftState.Follower.name(), raftRoleChanged.getOldRole());
902             assertEquals(RaftState.Candidate.name(), raftRoleChanged.getNewRole());
903
904             // check if the notifier got a role change from Candidate to Leader
905             raftRoleChanged = (RoleChanged) matches.get(2);
906             assertEquals(id, raftRoleChanged.getMemberId());
907             assertEquals(RaftState.Candidate.name(), raftRoleChanged.getOldRole());
908             assertEquals(RaftState.Leader.name(), raftRoleChanged.getNewRole());
909         }};
910     }
911
912     private ByteString fromObject(Object snapshot) throws Exception {
913         ByteArrayOutputStream b = null;
914         ObjectOutputStream o = null;
915         try {
916             b = new ByteArrayOutputStream();
917             o = new ObjectOutputStream(b);
918             o.writeObject(snapshot);
919             byte[] snapshotBytes = b.toByteArray();
920             return ByteString.copyFrom(snapshotBytes);
921         } finally {
922             if (o != null) {
923                 o.flush();
924                 o.close();
925             }
926             if (b != null) {
927                 b.close();
928             }
929         }
930     }
931 }