Bug-1903:On recovery all replicated log entries should not be applied to state
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / test / java / org / opendaylight / controller / cluster / raft / RaftActorTest.java
index ff0ffeb271b55b38455901d03cb31871620b078c..998c198756d191df038de6f6cc75a8091fd0d1f1 100644 (file)
@@ -2,27 +2,65 @@ package org.opendaylight.controller.cluster.raft;
 
 import akka.actor.ActorRef;
 import akka.actor.ActorSystem;
+import akka.actor.PoisonPill;
 import akka.actor.Props;
 import akka.event.Logging;
 import akka.japi.Creator;
 import akka.testkit.JavaTestKit;
+import akka.testkit.TestActorRef;
+import com.google.protobuf.ByteString;
+import org.junit.After;
 import org.junit.Test;
+import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
 import org.opendaylight.controller.cluster.raft.client.messages.FindLeader;
 import org.opendaylight.controller.cluster.raft.client.messages.FindLeaderReply;
-
+import org.opendaylight.controller.cluster.raft.utils.MockAkkaJournal;
+import org.opendaylight.controller.cluster.raft.utils.MockSnapshotStore;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 
+import static junit.framework.Assert.assertTrue;
 import static junit.framework.TestCase.assertEquals;
 
 public class RaftActorTest extends AbstractActorTest {
 
 
+    @After
+    public void tearDown() {
+        MockAkkaJournal.clearJournal();
+        MockSnapshotStore.setMockSnapshot(null);
+    }
+
     public static class MockRaftActor extends RaftActor {
 
+        private boolean applySnapshotCalled = false;
+        private List<Object> state;
+
         public MockRaftActor(String id,
             Map<String, String> peerAddresses) {
             super(id, peerAddresses);
+            state = new ArrayList<>();
+        }
+
+        public RaftActorContext getRaftActorContext() {
+            return context;
+        }
+
+        public boolean isApplySnapshotCalled() {
+            return applySnapshotCalled;
+        }
+
+        public List<Object> getState() {
+            return state;
         }
 
         public static Props props(final String id, final Map<String, String> peerAddresses){
@@ -34,17 +72,26 @@ public class RaftActorTest extends AbstractActorTest {
             });
         }
 
-        @Override protected void applyState(ActorRef clientActor,
-            String identifier,
-            Object data) {
+        @Override protected void applyState(ActorRef clientActor, String identifier, Object data) {
+            state.add(data);
         }
 
-        @Override protected Object createSnapshot() {
+        @Override protected void createSnapshot() {
             throw new UnsupportedOperationException("createSnapshot");
         }
 
-        @Override protected void applySnapshot(Object snapshot) {
-            throw new UnsupportedOperationException("applySnapshot");
+        @Override protected void applySnapshot(ByteString snapshot) {
+            applySnapshotCalled = true;
+            try {
+                Object data = toObject(snapshot);
+                if (data instanceof List) {
+                    state.addAll((List) data);
+                }
+            } catch (ClassNotFoundException e) {
+                e.printStackTrace();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
         }
 
         @Override protected void onStateChanged() {
@@ -54,6 +101,26 @@ public class RaftActorTest extends AbstractActorTest {
             return this.getId();
         }
 
+        private Object toObject(ByteString bs) throws ClassNotFoundException, IOException {
+            Object obj = null;
+            ByteArrayInputStream bis = null;
+            ObjectInputStream ois = null;
+            try {
+                bis = new ByteArrayInputStream(bs.toByteArray());
+                ois = new ObjectInputStream(bis);
+                obj = ois.readObject();
+            } finally {
+                if (bis != null) {
+                    bis.close();
+                }
+                if (ois != null) {
+                    ois.close();
+                }
+            }
+            return obj;
+        }
+
+
     }
 
 
@@ -133,5 +200,109 @@ public class RaftActorTest extends AbstractActorTest {
         kit.findLeader(kit.getRaftActor().path().toString());
     }
 
+    @Test
+    public void testRaftActorRecovery() {
+        new JavaTestKit(getSystem()) {{
+            new Within(duration("1 seconds")) {
+                protected void run() {
+
+                    String persistenceId = "follower10";
+
+                    ActorRef followerActor = getSystem().actorOf(
+                        MockRaftActor.props(persistenceId, Collections.EMPTY_MAP), persistenceId);
+
+                    List<ReplicatedLogEntry> snapshotUnappliedEntries = new ArrayList<>();
+                    ReplicatedLogEntry entry1 = new MockRaftActorContext.MockReplicatedLogEntry(1, 4, new MockRaftActorContext.MockPayload("E"));
+                    snapshotUnappliedEntries.add(entry1);
+
+                    int lastAppliedDuringSnapshotCapture = 3;
+                    int lastIndexDuringSnapshotCapture = 4;
+
+                    ByteString snapshotBytes = null;
+                    try {
+                        // 4 messages as part of snapshot, which are applied to state
+                        snapshotBytes  = fromObject(Arrays.asList(new MockRaftActorContext.MockPayload("A"),
+                            new MockRaftActorContext.MockPayload("B"),
+                            new MockRaftActorContext.MockPayload("C"),
+                            new MockRaftActorContext.MockPayload("D")));
+                    } catch (Exception e) {
+                        e.printStackTrace();
+                    }
+                    Snapshot snapshot = Snapshot.create(snapshotBytes.toByteArray(),
+                        snapshotUnappliedEntries, lastIndexDuringSnapshotCapture, 1 ,
+                        lastAppliedDuringSnapshotCapture, 1);
+                    MockSnapshotStore.setMockSnapshot(snapshot);
+                    MockSnapshotStore.setPersistenceId(persistenceId);
+
+                    // add more entries after snapshot is taken
+                    List<ReplicatedLogEntry> entries = new ArrayList<>();
+                    ReplicatedLogEntry entry2 = new MockRaftActorContext.MockReplicatedLogEntry(1, 5, new MockRaftActorContext.MockPayload("F"));
+                    ReplicatedLogEntry entry3 = new MockRaftActorContext.MockReplicatedLogEntry(1, 6, new MockRaftActorContext.MockPayload("G"));
+                    ReplicatedLogEntry entry4 = new MockRaftActorContext.MockReplicatedLogEntry(1, 7, new MockRaftActorContext.MockPayload("H"));
+                    entries.add(entry2);
+                    entries.add(entry3);
+                    entries.add(entry4);
+
+                    int lastAppliedToState = 5;
+                    int lastIndex = 7;
+
+                    MockAkkaJournal.addToJournal(5, entry2);
+                    // 2 entries are applied to state besides the 4 entries in snapshot
+                    MockAkkaJournal.addToJournal(6, new ApplyLogEntries(lastAppliedToState));
+                    MockAkkaJournal.addToJournal(7, entry3);
+                    MockAkkaJournal.addToJournal(8, entry4);
+
+                    // kill the actor
+                    followerActor.tell(PoisonPill.getInstance(), null);
+
+                    try {
+                        // give some time for actor to die
+                        Thread.sleep(200);
+                    } catch (InterruptedException e) {
+                        e.printStackTrace();
+                    }
+
+                    //reinstate the actor
+                    TestActorRef<MockRaftActor> ref = TestActorRef.create(getSystem(),
+                        MockRaftActor.props(persistenceId, Collections.EMPTY_MAP));
 
+                    try {
+                        //give some time for snapshot offer to get called.
+                        Thread.sleep(200);
+                    } catch (InterruptedException e) {
+                        e.printStackTrace();
+                    }
+
+                    RaftActorContext context = ref.underlyingActor().getRaftActorContext();
+                    assertEquals(snapshotUnappliedEntries.size() + entries.size(), context.getReplicatedLog().size());
+                    assertEquals(lastIndex, context.getReplicatedLog().lastIndex());
+                    assertEquals(lastAppliedToState, context.getLastApplied());
+                    assertEquals(lastAppliedToState, context.getCommitIndex());
+                    assertTrue(ref.underlyingActor().isApplySnapshotCalled());
+                    assertEquals(6, ref.underlyingActor().getState().size());
+                }
+            };
+        }};
+
+    }
+
+    private ByteString fromObject(Object snapshot) throws Exception {
+        ByteArrayOutputStream b = null;
+        ObjectOutputStream o = null;
+        try {
+            b = new ByteArrayOutputStream();
+            o = new ObjectOutputStream(b);
+            o.writeObject(snapshot);
+            byte[] snapshotBytes = b.toByteArray();
+            return ByteString.copyFrom(snapshotBytes);
+        } finally {
+            if (o != null) {
+                o.flush();
+                o.close();
+            }
+            if (b != null) {
+                b.close();
+            }
+        }
+    }
 }