Fix non-generic references to List
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / test / java / org / opendaylight / controller / cluster / raft / RaftActorTest.java
1 package org.opendaylight.controller.cluster.raft;
2
3 import akka.actor.ActorRef;
4 import akka.actor.ActorSystem;
5 import akka.actor.PoisonPill;
6 import akka.actor.Props;
7 import akka.actor.Terminated;
8 import akka.japi.Creator;
9 import akka.japi.Procedure;
10 import akka.pattern.Patterns;
11 import akka.persistence.RecoveryCompleted;
12 import akka.persistence.SaveSnapshotFailure;
13 import akka.persistence.SaveSnapshotSuccess;
14 import akka.persistence.SnapshotMetadata;
15 import akka.persistence.SnapshotOffer;
16 import akka.persistence.SnapshotSelectionCriteria;
17 import akka.testkit.JavaTestKit;
18 import akka.testkit.TestActorRef;
19 import akka.util.Timeout;
20 import com.google.common.base.Optional;
21 import com.google.common.collect.Lists;
22 import com.google.common.util.concurrent.Uninterruptibles;
23 import com.google.protobuf.ByteString;
24 import org.junit.After;
25 import org.junit.Assert;
26 import org.junit.Test;
27 import org.opendaylight.controller.cluster.DataPersistenceProvider;
28 import org.opendaylight.controller.cluster.datastore.DataPersistenceProviderMonitor;
29 import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
30 import org.opendaylight.controller.cluster.raft.base.messages.ApplySnapshot;
31 import org.opendaylight.controller.cluster.raft.base.messages.ApplyState;
32 import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshot;
33 import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshotReply;
34 import org.opendaylight.controller.cluster.raft.behaviors.Follower;
35 import org.opendaylight.controller.cluster.raft.behaviors.Leader;
36 import org.opendaylight.controller.cluster.raft.client.messages.FindLeader;
37 import org.opendaylight.controller.cluster.raft.client.messages.FindLeaderReply;
38 import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload;
39 import org.opendaylight.controller.cluster.raft.utils.MockAkkaJournal;
40 import org.opendaylight.controller.cluster.raft.utils.MockSnapshotStore;
41 import scala.concurrent.Await;
42 import scala.concurrent.Future;
43 import scala.concurrent.duration.Duration;
44 import scala.concurrent.duration.FiniteDuration;
45
46 import java.io.ByteArrayInputStream;
47 import java.io.ByteArrayOutputStream;
48 import java.io.IOException;
49 import java.io.ObjectInputStream;
50 import java.io.ObjectOutputStream;
51 import java.util.ArrayList;
52 import java.util.Arrays;
53 import java.util.Collections;
54 import java.util.List;
55 import java.util.Map;
56 import java.util.concurrent.CountDownLatch;
57 import java.util.concurrent.TimeUnit;
58 import java.util.concurrent.TimeoutException;
59
60 import static org.junit.Assert.assertEquals;
61 import static org.junit.Assert.assertNotEquals;
62 import static org.junit.Assert.assertNotNull;
63 import static org.junit.Assert.assertNull;
64 import static org.junit.Assert.assertTrue;
65 import static org.mockito.Matchers.any;
66 import static org.mockito.Matchers.anyObject;
67 import static org.mockito.Matchers.eq;
68 import static org.mockito.Mockito.doReturn;
69 import static org.mockito.Mockito.mock;
70 import static org.mockito.Mockito.times;
71 import static org.mockito.Mockito.verify;
72
73 public class RaftActorTest extends AbstractActorTest {
74
75
76     @After
77     public void tearDown() {
78         MockAkkaJournal.clearJournal();
79         MockSnapshotStore.setMockSnapshot(null);
80     }
81
82     public static class MockRaftActor extends RaftActor {
83
84         private final DataPersistenceProvider dataPersistenceProvider;
85         private final RaftActor delegate;
86
87         public static final class MockRaftActorCreator implements Creator<MockRaftActor> {
88             private final Map<String, String> peerAddresses;
89             private final String id;
90             private final Optional<ConfigParams> config;
91             private final DataPersistenceProvider dataPersistenceProvider;
92
93             private MockRaftActorCreator(Map<String, String> peerAddresses, String id,
94                     Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider) {
95                 this.peerAddresses = peerAddresses;
96                 this.id = id;
97                 this.config = config;
98                 this.dataPersistenceProvider = dataPersistenceProvider;
99             }
100
101             @Override
102             public MockRaftActor create() throws Exception {
103                 return new MockRaftActor(id, peerAddresses, config, dataPersistenceProvider);
104             }
105         }
106
107         private final CountDownLatch recoveryComplete = new CountDownLatch(1);
108
109         private final List<Object> state;
110
111         public MockRaftActor(String id, Map<String, String> peerAddresses, Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider) {
112             super(id, peerAddresses, config);
113             state = new ArrayList<>();
114             this.delegate = mock(RaftActor.class);
115             if(dataPersistenceProvider == null){
116                 this.dataPersistenceProvider = new PersistentDataProvider();
117             } else {
118                 this.dataPersistenceProvider = dataPersistenceProvider;
119             }
120         }
121
122         public void waitForRecoveryComplete() {
123             try {
124                 assertEquals("Recovery complete", true, recoveryComplete.await(5,  TimeUnit.SECONDS));
125             } catch (InterruptedException e) {
126                 e.printStackTrace();
127             }
128         }
129
130         public List<Object> getState() {
131             return state;
132         }
133
134         public static Props props(final String id, final Map<String, String> peerAddresses,
135                 Optional<ConfigParams> config){
136             return Props.create(new MockRaftActorCreator(peerAddresses, id, config, null));
137         }
138
139         public static Props props(final String id, final Map<String, String> peerAddresses,
140                                   Optional<ConfigParams> config, DataPersistenceProvider dataPersistenceProvider){
141             return Props.create(new MockRaftActorCreator(peerAddresses, id, config, dataPersistenceProvider));
142         }
143
144
145         @Override protected void applyState(ActorRef clientActor, String identifier, Object data) {
146             delegate.applyState(clientActor, identifier, data);
147             LOG.info("applyState called");
148         }
149
150
151
152
153         @Override
154         protected void startLogRecoveryBatch(int maxBatchSize) {
155         }
156
157         @Override
158         protected void appendRecoveredLogEntry(Payload data) {
159             state.add(data);
160         }
161
162         @Override
163         protected void applyCurrentLogRecoveryBatch() {
164         }
165
166         @Override
167         protected void onRecoveryComplete() {
168             delegate.onRecoveryComplete();
169             recoveryComplete.countDown();
170         }
171
172         @Override
173         protected void applyRecoverySnapshot(ByteString snapshot) {
174             delegate.applyRecoverySnapshot(snapshot);
175             try {
176                 Object data = toObject(snapshot);
177                 System.out.println("!!!!!applyRecoverySnapshot: "+data);
178                 if (data instanceof List) {
179                     state.addAll((List<?>) data);
180                 }
181             } catch (Exception e) {
182                 e.printStackTrace();
183             }
184         }
185
186         @Override protected void createSnapshot() {
187             delegate.createSnapshot();
188         }
189
190         @Override protected void applySnapshot(ByteString snapshot) {
191             delegate.applySnapshot(snapshot);
192         }
193
194         @Override protected void onStateChanged() {
195             delegate.onStateChanged();
196         }
197
198         @Override
199         protected DataPersistenceProvider persistence() {
200             return this.dataPersistenceProvider;
201         }
202
203         @Override public String persistenceId() {
204             return this.getId();
205         }
206
207         private Object toObject(ByteString bs) throws ClassNotFoundException, IOException {
208             Object obj = null;
209             ByteArrayInputStream bis = null;
210             ObjectInputStream ois = null;
211             try {
212                 bis = new ByteArrayInputStream(bs.toByteArray());
213                 ois = new ObjectInputStream(bis);
214                 obj = ois.readObject();
215             } finally {
216                 if (bis != null) {
217                     bis.close();
218                 }
219                 if (ois != null) {
220                     ois.close();
221                 }
222             }
223             return obj;
224         }
225
226         public ReplicatedLog getReplicatedLog(){
227             return this.getRaftActorContext().getReplicatedLog();
228         }
229
230     }
231
232
233     private static class RaftActorTestKit extends JavaTestKit {
234         private final ActorRef raftActor;
235
236         public RaftActorTestKit(ActorSystem actorSystem, String actorName) {
237             super(actorSystem);
238
239             raftActor = this.getSystem().actorOf(MockRaftActor.props(actorName,
240                     Collections.EMPTY_MAP, Optional.<ConfigParams>absent()), actorName);
241
242         }
243
244
245         public ActorRef getRaftActor() {
246             return raftActor;
247         }
248
249         public boolean waitForLogMessage(final Class<?> logEventClass, String message){
250             // Wait for a specific log message to show up
251             return
252                 new JavaTestKit.EventFilter<Boolean>(logEventClass
253                 ) {
254                     @Override
255                     protected Boolean run() {
256                         return true;
257                     }
258                 }.from(raftActor.path().toString())
259                     .message(message)
260                     .occurrences(1).exec();
261
262
263         }
264
265         protected void waitUntilLeader(){
266             waitUntilLeader(raftActor);
267         }
268
269         protected void waitUntilLeader(ActorRef actorRef) {
270             FiniteDuration duration = Duration.create(100, TimeUnit.MILLISECONDS);
271             for(int i = 0; i < 20 * 5; i++) {
272                 Future<Object> future = Patterns.ask(actorRef, new FindLeader(), new Timeout(duration));
273                 try {
274                     FindLeaderReply resp = (FindLeaderReply) Await.result(future, duration);
275                     if(resp.getLeaderActor() != null) {
276                         return;
277                     }
278                 } catch(TimeoutException e) {
279                 } catch(Exception e) {
280                     System.err.println("FindLeader threw ex");
281                     e.printStackTrace();
282                 }
283
284
285                 Uninterruptibles.sleepUninterruptibly(50, TimeUnit.MILLISECONDS);
286             }
287
288             Assert.fail("Leader not found for actorRef " + actorRef.path());
289         }
290
291     }
292
293
294     @Test
295     public void testConstruction() {
296         new RaftActorTestKit(getSystem(), "testConstruction").waitUntilLeader();
297     }
298
299     @Test
300     public void testFindLeaderWhenLeaderIsSelf(){
301         RaftActorTestKit kit = new RaftActorTestKit(getSystem(), "testFindLeader");
302         kit.waitUntilLeader();
303     }
304
305     @Test
306     public void testRaftActorRecovery() throws Exception {
307         new JavaTestKit(getSystem()) {{
308             String persistenceId = "follower10";
309
310             DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
311             // Set the heartbeat interval high to essentially disable election otherwise the test
312             // may fail if the actor is switched to Leader and the commitIndex is set to the last
313             // log entry.
314             config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
315
316             ActorRef followerActor = getSystem().actorOf(MockRaftActor.props(persistenceId,
317                     Collections.EMPTY_MAP, Optional.<ConfigParams>of(config)), persistenceId);
318
319             watch(followerActor);
320
321             List<ReplicatedLogEntry> snapshotUnappliedEntries = new ArrayList<>();
322             ReplicatedLogEntry entry1 = new MockRaftActorContext.MockReplicatedLogEntry(1, 4,
323                     new MockRaftActorContext.MockPayload("E"));
324             snapshotUnappliedEntries.add(entry1);
325
326             int lastAppliedDuringSnapshotCapture = 3;
327             int lastIndexDuringSnapshotCapture = 4;
328
329                 // 4 messages as part of snapshot, which are applied to state
330             ByteString snapshotBytes  = fromObject(Arrays.asList(
331                         new MockRaftActorContext.MockPayload("A"),
332                         new MockRaftActorContext.MockPayload("B"),
333                         new MockRaftActorContext.MockPayload("C"),
334                         new MockRaftActorContext.MockPayload("D")));
335
336             Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(),
337                     snapshotUnappliedEntries, lastIndexDuringSnapshotCapture, 1 ,
338                     lastAppliedDuringSnapshotCapture, 1);
339             MockSnapshotStore.setMockSnapshot(snapshot);
340             MockSnapshotStore.setPersistenceId(persistenceId);
341
342             // add more entries after snapshot is taken
343             List<ReplicatedLogEntry> entries = new ArrayList<>();
344             ReplicatedLogEntry entry2 = new MockRaftActorContext.MockReplicatedLogEntry(1, 5,
345                     new MockRaftActorContext.MockPayload("F"));
346             ReplicatedLogEntry entry3 = new MockRaftActorContext.MockReplicatedLogEntry(1, 6,
347                     new MockRaftActorContext.MockPayload("G"));
348             ReplicatedLogEntry entry4 = new MockRaftActorContext.MockReplicatedLogEntry(1, 7,
349                     new MockRaftActorContext.MockPayload("H"));
350             entries.add(entry2);
351             entries.add(entry3);
352             entries.add(entry4);
353
354             int lastAppliedToState = 5;
355             int lastIndex = 7;
356
357             MockAkkaJournal.addToJournal(5, entry2);
358             // 2 entries are applied to state besides the 4 entries in snapshot
359             MockAkkaJournal.addToJournal(6, new ApplyLogEntries(lastAppliedToState));
360             MockAkkaJournal.addToJournal(7, entry3);
361             MockAkkaJournal.addToJournal(8, entry4);
362
363             // kill the actor
364             followerActor.tell(PoisonPill.getInstance(), null);
365             expectMsgClass(duration("5 seconds"), Terminated.class);
366
367             unwatch(followerActor);
368
369             //reinstate the actor
370             TestActorRef<MockRaftActor> ref = TestActorRef.create(getSystem(),
371                     MockRaftActor.props(persistenceId, Collections.EMPTY_MAP,
372                             Optional.<ConfigParams>of(config)));
373
374             ref.underlyingActor().waitForRecoveryComplete();
375
376             RaftActorContext context = ref.underlyingActor().getRaftActorContext();
377             assertEquals("Journal log size", snapshotUnappliedEntries.size() + entries.size(),
378                     context.getReplicatedLog().size());
379             assertEquals("Last index", lastIndex, context.getReplicatedLog().lastIndex());
380             assertEquals("Last applied", lastAppliedToState, context.getLastApplied());
381             assertEquals("Commit index", lastAppliedToState, context.getCommitIndex());
382             assertEquals("Recovered state size", 6, ref.underlyingActor().getState().size());
383         }};
384     }
385
386     /**
387      * This test verifies that when recovery is applicable (typically when persistence is true) the RaftActor does
388      * process recovery messages
389      *
390      * @throws Exception
391      */
392
393     @Test
394     public void testHandleRecoveryWhenDataPersistenceRecoveryApplicable() throws Exception {
395         new JavaTestKit(getSystem()) {
396             {
397                 String persistenceId = "testHandleRecoveryWhenDataPersistenceRecoveryApplicable";
398
399                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
400
401                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
402
403                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
404                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config)), persistenceId);
405
406                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
407
408                 // Wait for akka's recovery to complete so it doesn't interfere.
409                 mockRaftActor.waitForRecoveryComplete();
410
411                 ByteString snapshotBytes  = fromObject(Arrays.asList(
412                         new MockRaftActorContext.MockPayload("A"),
413                         new MockRaftActorContext.MockPayload("B"),
414                         new MockRaftActorContext.MockPayload("C"),
415                         new MockRaftActorContext.MockPayload("D")));
416
417                 Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(),
418                         Lists.<ReplicatedLogEntry>newArrayList(), 3, 1 ,3, 1);
419
420                 mockRaftActor.onReceiveRecover(new SnapshotOffer(new SnapshotMetadata(persistenceId, 100, 100), snapshot));
421
422                 verify(mockRaftActor.delegate).applyRecoverySnapshot(eq(snapshotBytes));
423
424                 mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(0, 1, new MockRaftActorContext.MockPayload("A")));
425
426                 ReplicatedLog replicatedLog = mockRaftActor.getReplicatedLog();
427
428                 assertEquals("add replicated log entry", 1, replicatedLog.size());
429
430                 mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(1, 1, new MockRaftActorContext.MockPayload("A")));
431
432                 assertEquals("add replicated log entry", 2, replicatedLog.size());
433
434                 mockRaftActor.onReceiveRecover(new ApplyLogEntries(1));
435
436                 assertEquals("commit index 1", 1, mockRaftActor.getRaftActorContext().getCommitIndex());
437
438                 // The snapshot had 4 items + we added 2 more items during the test
439                 // We start removing from 5 and we should get 1 item in the replicated log
440                 mockRaftActor.onReceiveRecover(new RaftActor.DeleteEntries(5));
441
442                 assertEquals("remove log entries", 1, replicatedLog.size());
443
444                 mockRaftActor.onReceiveRecover(new RaftActor.UpdateElectionTerm(10, "foobar"));
445
446                 assertEquals("election term", 10, mockRaftActor.getRaftActorContext().getTermInformation().getCurrentTerm());
447                 assertEquals("voted for", "foobar", mockRaftActor.getRaftActorContext().getTermInformation().getVotedFor());
448
449                 mockRaftActor.onReceiveRecover(mock(RecoveryCompleted.class));
450
451                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
452
453             }};
454     }
455
456     /**
457      * This test verifies that when recovery is not applicable (typically when persistence is false) the RaftActor does
458      * not process recovery messages
459      *
460      * @throws Exception
461      */
462     @Test
463     public void testHandleRecoveryWhenDataPersistenceRecoveryNotApplicable() throws Exception {
464         new JavaTestKit(getSystem()) {
465             {
466                 String persistenceId = "testHandleRecoveryWhenDataPersistenceRecoveryNotApplicable";
467
468                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
469
470                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
471
472                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
473                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), new DataPersistenceProviderMonitor()), persistenceId);
474
475                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
476
477                 // Wait for akka's recovery to complete so it doesn't interfere.
478                 mockRaftActor.waitForRecoveryComplete();
479
480                 ByteString snapshotBytes  = fromObject(Arrays.asList(
481                         new MockRaftActorContext.MockPayload("A"),
482                         new MockRaftActorContext.MockPayload("B"),
483                         new MockRaftActorContext.MockPayload("C"),
484                         new MockRaftActorContext.MockPayload("D")));
485
486                 Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(),
487                         Lists.<ReplicatedLogEntry>newArrayList(), 3, 1 ,3, 1);
488
489                 mockRaftActor.onReceiveRecover(new SnapshotOffer(new SnapshotMetadata(persistenceId, 100, 100), snapshot));
490
491                 verify(mockRaftActor.delegate, times(0)).applyRecoverySnapshot(any(ByteString.class));
492
493                 mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(0, 1, new MockRaftActorContext.MockPayload("A")));
494
495                 ReplicatedLog replicatedLog = mockRaftActor.getReplicatedLog();
496
497                 assertEquals("add replicated log entry", 0, replicatedLog.size());
498
499                 mockRaftActor.onReceiveRecover(new ReplicatedLogImplEntry(1, 1, new MockRaftActorContext.MockPayload("A")));
500
501                 assertEquals("add replicated log entry", 0, replicatedLog.size());
502
503                 mockRaftActor.onReceiveRecover(new ApplyLogEntries(1));
504
505                 assertEquals("commit index -1", -1, mockRaftActor.getRaftActorContext().getCommitIndex());
506
507                 mockRaftActor.onReceiveRecover(new RaftActor.DeleteEntries(2));
508
509                 assertEquals("remove log entries", 0, replicatedLog.size());
510
511                 mockRaftActor.onReceiveRecover(new RaftActor.UpdateElectionTerm(10, "foobar"));
512
513                 assertNotEquals("election term", 10, mockRaftActor.getRaftActorContext().getTermInformation().getCurrentTerm());
514                 assertNotEquals("voted for", "foobar", mockRaftActor.getRaftActorContext().getTermInformation().getVotedFor());
515
516                 mockRaftActor.onReceiveRecover(mock(RecoveryCompleted.class));
517
518                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
519             }};
520     }
521
522
523     @Test
524     public void testUpdatingElectionTermCallsDataPersistence() throws Exception {
525         new JavaTestKit(getSystem()) {
526             {
527                 String persistenceId = "testUpdatingElectionTermCallsDataPersistence";
528
529                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
530
531                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
532
533                 CountDownLatch persistLatch = new CountDownLatch(1);
534                 DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor();
535                 dataPersistenceProviderMonitor.setPersistLatch(persistLatch);
536
537                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
538                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProviderMonitor), persistenceId);
539
540                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
541
542                 mockRaftActor.getRaftActorContext().getTermInformation().updateAndPersist(10, "foobar");
543
544                 assertEquals("Persist called", true, persistLatch.await(5, TimeUnit.SECONDS));
545
546                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
547
548             }
549         };
550     }
551
552     @Test
553     public void testAddingReplicatedLogEntryCallsDataPersistence() throws Exception {
554         new JavaTestKit(getSystem()) {
555             {
556                 String persistenceId = "testAddingReplicatedLogEntryCallsDataPersistence";
557
558                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
559
560                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
561
562                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
563
564                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
565                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
566
567                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
568
569                 MockRaftActorContext.MockReplicatedLogEntry logEntry = new MockRaftActorContext.MockReplicatedLogEntry(10, 10, mock(Payload.class));
570
571                 mockRaftActor.getRaftActorContext().getReplicatedLog().appendAndPersist(logEntry);
572
573                 verify(dataPersistenceProvider).persist(eq(logEntry), any(Procedure.class));
574
575                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
576
577             }
578         };
579     }
580
581     @Test
582     public void testRemovingReplicatedLogEntryCallsDataPersistence() throws Exception {
583         new JavaTestKit(getSystem()) {
584             {
585                 String persistenceId = "testRemovingReplicatedLogEntryCallsDataPersistence";
586
587                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
588
589                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
590
591                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
592
593                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
594                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
595
596                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
597
598                 mockRaftActor.getReplicatedLog().appendAndPersist(new MockRaftActorContext.MockReplicatedLogEntry(1, 0, mock(Payload.class)));
599
600                 mockRaftActor.getRaftActorContext().getReplicatedLog().removeFromAndPersist(0);
601
602                 verify(dataPersistenceProvider, times(2)).persist(anyObject(), any(Procedure.class));
603
604                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
605
606             }
607         };
608     }
609
610     @Test
611     public void testApplyLogEntriesCallsDataPersistence() throws Exception {
612         new JavaTestKit(getSystem()) {
613             {
614                 String persistenceId = "testApplyLogEntriesCallsDataPersistence";
615
616                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
617
618                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
619
620                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
621
622                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
623                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
624
625                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
626
627                 mockRaftActor.onReceiveCommand(new ApplyLogEntries(10));
628
629                 verify(dataPersistenceProvider, times(1)).persist(anyObject(), any(Procedure.class));
630
631                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
632
633             }
634         };
635     }
636
637     @Test
638     public void testCaptureSnapshotReplyCallsDataPersistence() throws Exception {
639         new JavaTestKit(getSystem()) {
640             {
641                 String persistenceId = "testCaptureSnapshotReplyCallsDataPersistence";
642
643                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
644
645                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
646
647                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
648
649                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(),
650                     MockRaftActor.props(persistenceId,Collections.EMPTY_MAP,
651                         Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
652
653                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
654
655                 ByteString snapshotBytes  = fromObject(Arrays.asList(
656                         new MockRaftActorContext.MockPayload("A"),
657                         new MockRaftActorContext.MockPayload("B"),
658                         new MockRaftActorContext.MockPayload("C"),
659                         new MockRaftActorContext.MockPayload("D")));
660
661                 mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1,1,-1,1));
662
663                 RaftActorContext raftActorContext = mockRaftActor.getRaftActorContext();
664
665                 mockRaftActor.setCurrentBehavior(new Leader(raftActorContext));
666
667                 mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes));
668
669                 verify(dataPersistenceProvider).saveSnapshot(anyObject());
670
671                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
672
673             }
674         };
675     }
676
677     @Test
678     public void testSaveSnapshotSuccessCallsDataPersistence() throws Exception {
679         new JavaTestKit(getSystem()) {
680             {
681                 String persistenceId = "testSaveSnapshotSuccessCallsDataPersistence";
682
683                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
684
685                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
686
687                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
688
689                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
690                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
691
692                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
693
694                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,0, mock(Payload.class)));
695                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,1, mock(Payload.class)));
696                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,2, mock(Payload.class)));
697                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,3, mock(Payload.class)));
698                 mockRaftActor.getReplicatedLog().append(new MockRaftActorContext.MockReplicatedLogEntry(1,4, mock(Payload.class)));
699
700                 ByteString snapshotBytes = fromObject(Arrays.asList(
701                         new MockRaftActorContext.MockPayload("A"),
702                         new MockRaftActorContext.MockPayload("B"),
703                         new MockRaftActorContext.MockPayload("C"),
704                         new MockRaftActorContext.MockPayload("D")));
705
706                 RaftActorContext raftActorContext = mockRaftActor.getRaftActorContext();
707                 mockRaftActor.setCurrentBehavior(new Follower(raftActorContext));
708
709                 mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1, 1, 2, 1));
710
711                 verify(mockRaftActor.delegate).createSnapshot();
712
713                 mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes));
714
715                 mockRaftActor.onReceiveCommand(new SaveSnapshotSuccess(new SnapshotMetadata("foo", 100, 100)));
716
717                 verify(dataPersistenceProvider).deleteSnapshots(any(SnapshotSelectionCriteria.class));
718
719                 verify(dataPersistenceProvider).deleteMessages(100);
720
721                 assertEquals(2, mockRaftActor.getReplicatedLog().size());
722
723                 assertNotNull(mockRaftActor.getReplicatedLog().get(3));
724                 assertNotNull(mockRaftActor.getReplicatedLog().get(4));
725
726                 // Index 2 will not be in the log because it was removed due to snapshotting
727                 assertNull(mockRaftActor.getReplicatedLog().get(2));
728
729                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
730
731             }
732         };
733     }
734
735     @Test
736     public void testApplyState() throws Exception {
737
738         new JavaTestKit(getSystem()) {
739             {
740                 String persistenceId = "testApplyState";
741
742                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
743
744                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
745
746                 DataPersistenceProvider dataPersistenceProvider = mock(DataPersistenceProvider.class);
747
748                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
749                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProvider), persistenceId);
750
751                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
752
753                 ReplicatedLogEntry entry = new MockRaftActorContext.MockReplicatedLogEntry(1, 5,
754                         new MockRaftActorContext.MockPayload("F"));
755
756                 mockRaftActor.onReceiveCommand(new ApplyState(mockActorRef, "apply-state", entry));
757
758                 verify(mockRaftActor.delegate).applyState(eq(mockActorRef), eq("apply-state"), anyObject());
759
760                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
761
762             }
763         };
764     }
765
766     @Test
767     public void testApplySnapshot() throws Exception {
768         new JavaTestKit(getSystem()) {
769             {
770                 String persistenceId = "testApplySnapshot";
771
772                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
773
774                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
775
776                 DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor();
777
778                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
779                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProviderMonitor), persistenceId);
780
781                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
782
783                 ReplicatedLog oldReplicatedLog = mockRaftActor.getReplicatedLog();
784
785                 oldReplicatedLog.append(new MockRaftActorContext.MockReplicatedLogEntry(1,0,mock(Payload.class)));
786                 oldReplicatedLog.append(new MockRaftActorContext.MockReplicatedLogEntry(1,1,mock(Payload.class)));
787                 oldReplicatedLog.append(
788                     new MockRaftActorContext.MockReplicatedLogEntry(1, 2,
789                         mock(Payload.class)));
790
791                 ByteString snapshotBytes = fromObject(Arrays.asList(
792                     new MockRaftActorContext.MockPayload("A"),
793                     new MockRaftActorContext.MockPayload("B"),
794                     new MockRaftActorContext.MockPayload("C"),
795                     new MockRaftActorContext.MockPayload("D")));
796
797                 Snapshot snapshot = mock(Snapshot.class);
798
799                 doReturn(snapshotBytes.toByteArray()).when(snapshot).getState();
800
801                 doReturn(3L).when(snapshot).getLastAppliedIndex();
802
803                 mockRaftActor.onReceiveCommand(new ApplySnapshot(snapshot));
804
805                 verify(mockRaftActor.delegate).applySnapshot(eq(snapshotBytes));
806
807                 assertTrue("The replicatedLog should have changed",
808                     oldReplicatedLog != mockRaftActor.getReplicatedLog());
809
810                 assertEquals("lastApplied should be same as in the snapshot",
811                     (Long) 3L, mockRaftActor.getLastApplied());
812
813                 assertEquals(0, mockRaftActor.getReplicatedLog().size());
814
815                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
816
817             }
818         };
819     }
820
821     @Test
822     public void testSaveSnapshotFailure() throws Exception {
823         new JavaTestKit(getSystem()) {
824             {
825                 String persistenceId = "testSaveSnapshotFailure";
826
827                 DefaultConfigParamsImpl config = new DefaultConfigParamsImpl();
828
829                 config.setHeartBeatInterval(new FiniteDuration(1, TimeUnit.DAYS));
830
831                 DataPersistenceProviderMonitor dataPersistenceProviderMonitor = new DataPersistenceProviderMonitor();
832
833                 TestActorRef<MockRaftActor> mockActorRef = TestActorRef.create(getSystem(), MockRaftActor.props(persistenceId,
834                         Collections.EMPTY_MAP, Optional.<ConfigParams>of(config), dataPersistenceProviderMonitor), persistenceId);
835
836                 MockRaftActor mockRaftActor = mockActorRef.underlyingActor();
837
838                 ByteString snapshotBytes  = fromObject(Arrays.asList(
839                         new MockRaftActorContext.MockPayload("A"),
840                         new MockRaftActorContext.MockPayload("B"),
841                         new MockRaftActorContext.MockPayload("C"),
842                         new MockRaftActorContext.MockPayload("D")));
843
844                 RaftActorContext raftActorContext = mockRaftActor.getRaftActorContext();
845
846                 mockRaftActor.setCurrentBehavior(new Leader(raftActorContext));
847
848                 mockRaftActor.onReceiveCommand(new CaptureSnapshot(-1,1,-1,1));
849
850                 mockRaftActor.onReceiveCommand(new CaptureSnapshotReply(snapshotBytes));
851
852                 mockRaftActor.onReceiveCommand(new SaveSnapshotFailure(new SnapshotMetadata("foobar", 10L, 1234L),
853                         new Exception()));
854
855                 assertEquals("Snapshot index should not have advanced because save snapshot failed", -1,
856                         mockRaftActor.getReplicatedLog().getSnapshotIndex());
857
858                 mockActorRef.tell(PoisonPill.getInstance(), getRef());
859
860             }
861         };
862     }
863
864     private ByteString fromObject(Object snapshot) throws Exception {
865         ByteArrayOutputStream b = null;
866         ObjectOutputStream o = null;
867         try {
868             b = new ByteArrayOutputStream();
869             o = new ObjectOutputStream(b);
870             o.writeObject(snapshot);
871             byte[] snapshotBytes = b.toByteArray();
872             return ByteString.copyFrom(snapshotBytes);
873         } finally {
874             if (o != null) {
875                 o.flush();
876                 o.close();
877             }
878             if (b != null) {
879                 b.close();
880             }
881         }
882     }
883 }