Remove RaftActor#trimPersistentData method
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / main / java / org / opendaylight / controller / cluster / raft / RaftActor.java
index 766b80e73dd12c890df3ed493e397a7cd144aab4..b74259d4851153659df0c2866f6323b9234eff06 100644 (file)
@@ -10,8 +10,6 @@ package org.opendaylight.controller.cluster.raft;
 
 import akka.actor.ActorRef;
 import akka.actor.ActorSelection;
-import akka.event.Logging;
-import akka.event.LoggingAdapter;
 import akka.japi.Procedure;
 import akka.persistence.RecoveryCompleted;
 import akka.persistence.SaveSnapshotFailure;
@@ -19,30 +17,40 @@ import akka.persistence.SaveSnapshotSuccess;
 import akka.persistence.SnapshotOffer;
 import akka.persistence.SnapshotSelectionCriteria;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Objects;
 import com.google.common.base.Optional;
 import com.google.common.base.Stopwatch;
-import com.google.protobuf.ByteString;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
 import java.io.Serializable;
+import java.util.Collection;
+import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import org.apache.commons.lang3.time.DurationFormatUtils;
 import org.opendaylight.controller.cluster.DataPersistenceProvider;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedPersistentActor;
+import org.opendaylight.controller.cluster.notifications.LeaderStateChanged;
 import org.opendaylight.controller.cluster.notifications.RoleChanged;
+import org.opendaylight.controller.cluster.raft.base.messages.ApplyJournalEntries;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplySnapshot;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplyState;
 import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshot;
 import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshotReply;
 import org.opendaylight.controller.cluster.raft.base.messages.Replicate;
-import org.opendaylight.controller.cluster.raft.base.messages.SendHeartBeat;
-import org.opendaylight.controller.cluster.raft.base.messages.SendInstallSnapshot;
+import org.opendaylight.controller.cluster.raft.behaviors.AbstractLeader;
 import org.opendaylight.controller.cluster.raft.behaviors.AbstractRaftActorBehavior;
 import org.opendaylight.controller.cluster.raft.behaviors.Follower;
 import org.opendaylight.controller.cluster.raft.behaviors.RaftActorBehavior;
 import org.opendaylight.controller.cluster.raft.client.messages.FindLeader;
 import org.opendaylight.controller.cluster.raft.client.messages.FindLeaderReply;
-import org.opendaylight.controller.cluster.raft.messages.AppendEntriesReply;
+import org.opendaylight.controller.cluster.raft.client.messages.FollowerInfo;
+import org.opendaylight.controller.cluster.raft.client.messages.GetOnDemandRaftState;
+import org.opendaylight.controller.cluster.raft.client.messages.OnDemandRaftState;
 import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload;
-import org.opendaylight.controller.protobuff.messages.cluster.raft.AppendEntriesMessages;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * RaftActor encapsulates a state machine that needs to be kept synchronized
@@ -85,8 +93,18 @@ import org.opendaylight.controller.protobuff.messages.cluster.raft.AppendEntries
  * </ul>
  */
 public abstract class RaftActor extends AbstractUntypedPersistentActor {
-    protected final LoggingAdapter LOG =
-        Logging.getLogger(getContext().system(), this);
+
+    private static final long APPLY_STATE_DELAY_THRESHOLD_IN_NANOS = TimeUnit.MILLISECONDS.toNanos(50L); // 50 millis
+
+    private static final Procedure<ApplyJournalEntries> APPLY_JOURNAL_ENTRIES_PERSIST_CALLBACK =
+            new Procedure<ApplyJournalEntries>() {
+                @Override
+                public void apply(ApplyJournalEntries param) throws Exception {
+                }
+            };
+    private static final String COMMIT_SNAPSHOT = "commit_snapshot";
+
+    protected final Logger LOG = LoggerFactory.getLogger(getClass());
 
     /**
      * The current state determines the current behavior of a RaftActor
@@ -98,19 +116,21 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
      * This context should NOT be passed directly to any other actor it is
      * only to be consumed by the RaftActorBehaviors
      */
-    private final RaftActorContext context;
+    private final RaftActorContextImpl context;
+
+    private final Procedure<Void> createSnapshotProcedure = new CreateSnapshotProcedure();
 
     /**
      * The in-memory journal
      */
     private ReplicatedLogImpl replicatedLog = new ReplicatedLogImpl();
 
-    private CaptureSnapshot captureSnapshot = null;
-
     private Stopwatch recoveryTimer;
 
     private int currentRecoveryBatchCount;
 
+    private final BehaviorStateHolder reusableBehaviorStateHolder = new BehaviorStateHolder();
+
     public RaftActor(String id, Map<String, String> peerAddresses) {
         this(id, peerAddresses, Optional.<ConfigParams>absent());
     }
@@ -127,8 +147,7 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
 
     private void initRecoveryTimer() {
         if(recoveryTimer == null) {
-            recoveryTimer = new Stopwatch();
-            recoveryTimer.start();
+            recoveryTimer = Stopwatch.createStarted();
         }
     }
 
@@ -140,6 +159,19 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         super.preStart();
     }
 
+    @Override
+    public void postStop() {
+        if(currentBehavior != null) {
+            try {
+                currentBehavior.close();
+            } catch (Exception e) {
+                LOG.debug("{}: Error closing behavior {}", persistenceId(), currentBehavior.state());
+            }
+        }
+
+        super.postStop();
+    }
+
     @Override
     public void handleRecover(Object message) {
         if(persistence().isRecoveryApplicable()) {
@@ -148,7 +180,10 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
             } else if (message instanceof ReplicatedLogEntry) {
                 onRecoveredJournalLogEntry((ReplicatedLogEntry) message);
             } else if (message instanceof ApplyLogEntries) {
-                onRecoveredApplyLogEntries((ApplyLogEntries) message);
+                // Handle this message for backwards compatibility with pre-Lithium versions.
+                onRecoveredApplyLogEntries(((ApplyLogEntries) message).getToIndex());
+            } else if (message instanceof ApplyJournalEntries) {
+                onRecoveredApplyLogEntries(((ApplyJournalEntries) message).getToIndex());
             } else if (message instanceof DeleteEntries) {
                 replicatedLog.removeFrom(((DeleteEntries) message).getFromIndex());
             } else if (message instanceof UpdateElectionTerm) {
@@ -191,8 +226,7 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         context.setLastApplied(snapshot.getLastAppliedIndex());
         context.setCommitIndex(snapshot.getLastAppliedIndex());
 
-        Stopwatch timer = new Stopwatch();
-        timer.start();
+        Stopwatch timer = Stopwatch.createStarted();
 
         // Apply the snapshot to the actors state
         applyRecoverySnapshot(snapshot.getState());
@@ -211,18 +245,18 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         replicatedLog.append(logEntry);
     }
 
-    private void onRecoveredApplyLogEntries(ApplyLogEntries ale) {
+    private void onRecoveredApplyLogEntries(long toIndex) {
         if(LOG.isDebugEnabled()) {
             LOG.debug("{}: Received ApplyLogEntries for recovery, applying to state: {} to {}",
-                    persistenceId(), context.getLastApplied() + 1, ale.getToIndex());
+                    persistenceId(), context.getLastApplied() + 1, toIndex);
         }
 
-        for (long i = context.getLastApplied() + 1; i <= ale.getToIndex(); i++) {
+        for (long i = context.getLastApplied() + 1; i <= toIndex; i++) {
             batchRecoveredLogEntry(replicatedLog.get(i));
         }
 
-        context.setLastApplied(ale.getToIndex());
-        context.setCommitIndex(ale.getToIndex());
+        context.setLastApplied(toIndex);
+        context.setCommitIndex(toIndex);
     }
 
     private void batchRecoveredLogEntry(ReplicatedLogEntry logEntry) {
@@ -275,15 +309,21 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
     }
 
     protected void changeCurrentBehavior(RaftActorBehavior newBehavior){
-        RaftActorBehavior oldBehavior = currentBehavior;
+        reusableBehaviorStateHolder.init(currentBehavior);
         currentBehavior = newBehavior;
-        handleBehaviorChange(oldBehavior, currentBehavior);
+        handleBehaviorChange(reusableBehaviorStateHolder, currentBehavior);
     }
 
     @Override public void handleCommand(Object message) {
         if (message instanceof ApplyState){
             ApplyState applyState = (ApplyState) message;
 
+            long elapsedTime = (System.nanoTime() - applyState.getStartTime());
+            if(elapsedTime >= APPLY_STATE_DELAY_THRESHOLD_IN_NANOS){
+                LOG.warn("ApplyState took more time than expected. Elapsed Time = {} ms ApplyState = {}",
+                        TimeUnit.NANOSECONDS.toMillis(elapsedTime), applyState);
+            }
+
             if(LOG.isDebugEnabled()) {
                 LOG.debug("{}: Applying state for log index {} data {}",
                     persistenceId(), applyState.getReplicatedLogEntry().getIndex(),
@@ -293,16 +333,13 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
             applyState(applyState.getClientActor(), applyState.getIdentifier(),
                 applyState.getReplicatedLogEntry().getData());
 
-        } else if (message instanceof ApplyLogEntries){
-            ApplyLogEntries ale = (ApplyLogEntries) message;
+        } else if (message instanceof ApplyJournalEntries){
+            ApplyJournalEntries applyEntries = (ApplyJournalEntries) message;
             if(LOG.isDebugEnabled()) {
-                LOG.debug("{}: Persisting ApplyLogEntries with index={}", persistenceId(), ale.getToIndex());
+                LOG.debug("{}: Persisting ApplyLogEntries with index={}", persistenceId(), applyEntries.getToIndex());
             }
-            persistence().persist(new ApplyLogEntries(ale.getToIndex()), new Procedure<ApplyLogEntries>() {
-                @Override
-                public void apply(ApplyLogEntries param) throws Exception {
-                }
-            });
+
+            persistence().persist(applyEntries, APPLY_JOURNAL_ENTRIES_PERSIST_CALLBACK);
 
         } else if(message instanceof ApplySnapshot ) {
             Snapshot snapshot = ((ApplySnapshot) message).getSnapshot();
@@ -338,59 +375,98 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         } else if (message instanceof SaveSnapshotFailure) {
             SaveSnapshotFailure saveSnapshotFailure = (SaveSnapshotFailure) message;
 
-            LOG.error(saveSnapshotFailure.cause(), "{}: SaveSnapshotFailure received for snapshot Cause:",
-                    persistenceId());
+            LOG.error("{}: SaveSnapshotFailure received for snapshot Cause:",
+                    persistenceId(), saveSnapshotFailure.cause());
 
-            context.getReplicatedLog().snapshotRollback();
-
-            LOG.info("{}: Replicated Log rollbacked. Snapshot will be attempted in the next cycle." +
-                "snapshotIndex:{}, snapshotTerm:{}, log-size:{}", persistenceId(),
-                context.getReplicatedLog().getSnapshotIndex(),
-                context.getReplicatedLog().getSnapshotTerm(),
-                context.getReplicatedLog().size());
+            context.getSnapshotManager().rollback();
 
         } else if (message instanceof CaptureSnapshot) {
-            LOG.info("{}: CaptureSnapshot received by actor", persistenceId());
+            LOG.debug("{}: CaptureSnapshot received by actor: {}", persistenceId(), message);
 
-            if(captureSnapshot == null) {
-                captureSnapshot = (CaptureSnapshot)message;
-                createSnapshot();
-            }
+            context.getSnapshotManager().create(createSnapshotProcedure);
 
-        } else if (message instanceof CaptureSnapshotReply){
+        } else if (message instanceof CaptureSnapshotReply) {
             handleCaptureSnapshotReply(((CaptureSnapshotReply) message).getSnapshot());
-
+        } else if(message instanceof GetOnDemandRaftState) {
+            onGetOnDemandRaftStats();
+        } else if (message.equals(COMMIT_SNAPSHOT)) {
+            commitSnapshot(-1);
         } else {
-            if (!(message instanceof AppendEntriesMessages.AppendEntries)
-                && !(message instanceof AppendEntriesReply) && !(message instanceof SendHeartBeat)) {
-                if(LOG.isDebugEnabled()) {
-                    LOG.debug("{}: onReceiveCommand: message: {}", persistenceId(), message.getClass());
-                }
-            }
+            reusableBehaviorStateHolder.init(currentBehavior);
 
-            RaftActorBehavior oldBehavior = currentBehavior;
             currentBehavior = currentBehavior.handleMessage(getSender(), message);
 
-            handleBehaviorChange(oldBehavior, currentBehavior);
+            handleBehaviorChange(reusableBehaviorStateHolder, currentBehavior);
+        }
+    }
+
+    private void onGetOnDemandRaftStats() {
+        // Debugging message to retrieve raft stats.
+
+        OnDemandRaftState.Builder builder = OnDemandRaftState.builder()
+                .commitIndex(context.getCommitIndex())
+                .currentTerm(context.getTermInformation().getCurrentTerm())
+                .inMemoryJournalDataSize(replicatedLog.dataSize())
+                .inMemoryJournalLogSize(replicatedLog.size())
+                .isSnapshotCaptureInitiated(context.getSnapshotManager().isCapturing())
+                .lastApplied(context.getLastApplied())
+                .lastIndex(replicatedLog.lastIndex())
+                .lastTerm(replicatedLog.lastTerm())
+                .leader(getLeaderId())
+                .raftState(currentBehavior.state().toString())
+                .replicatedToAllIndex(currentBehavior.getReplicatedToAllIndex())
+                .snapshotIndex(replicatedLog.getSnapshotIndex())
+                .snapshotTerm(replicatedLog.getSnapshotTerm())
+                .votedFor(context.getTermInformation().getVotedFor())
+                .peerAddresses(ImmutableMap.copyOf(context.getPeerAddresses()));
+
+        ReplicatedLogEntry lastLogEntry = getLastLogEntry();
+        if (lastLogEntry != null) {
+            builder.lastLogIndex(lastLogEntry.getIndex());
+            builder.lastLogTerm(lastLogEntry.getTerm());
+        }
+
+        if(currentBehavior instanceof AbstractLeader) {
+            AbstractLeader leader = (AbstractLeader)currentBehavior;
+            Collection<String> followerIds = leader.getFollowerIds();
+            List<FollowerInfo> followerInfoList = Lists.newArrayListWithCapacity(followerIds.size());
+            for(String id: followerIds) {
+                final FollowerLogInformation info = leader.getFollower(id);
+                followerInfoList.add(new FollowerInfo(id, info.getNextIndex(), info.getMatchIndex(),
+                        info.isFollowerActive(), DurationFormatUtils.formatDurationHMS(info.timeSinceLastActivity())));
+            }
+
+            builder.followerInfoList(followerInfoList);
         }
+
+        sender().tell(builder.build(), self());
+
     }
 
-    private void handleBehaviorChange(RaftActorBehavior oldBehavior, RaftActorBehavior currentBehavior) {
+    private void handleBehaviorChange(BehaviorStateHolder oldBehaviorState, RaftActorBehavior currentBehavior) {
+        RaftActorBehavior oldBehavior = oldBehaviorState.getBehavior();
+
         if (oldBehavior != currentBehavior){
             onStateChanged();
         }
 
-        String oldBehaviorLeaderId = oldBehavior == null? null : oldBehavior.getLeaderId();
-        String oldBehaviorState = oldBehavior == null? null : oldBehavior.state().name();
+        String oldBehaviorLeaderId = oldBehavior == null ? null : oldBehaviorState.getLeaderId();
+        String oldBehaviorStateName = oldBehavior == null ? null : oldBehavior.state().name();
 
         // it can happen that the state has not changed but the leader has changed.
-        onLeaderChanged(oldBehaviorLeaderId, currentBehavior.getLeaderId());
+        Optional<ActorRef> roleChangeNotifier = getRoleChangeNotifier();
+        if(!Objects.equal(oldBehaviorLeaderId, currentBehavior.getLeaderId())) {
+            if(roleChangeNotifier.isPresent()) {
+                roleChangeNotifier.get().tell(new LeaderStateChanged(getId(), currentBehavior.getLeaderId()), getSelf());
+            }
+
+            onLeaderChanged(oldBehaviorLeaderId, currentBehavior.getLeaderId());
+        }
 
-        if (getRoleChangeNotifier().isPresent() &&
+        if (roleChangeNotifier.isPresent() &&
                 (oldBehavior == null || (oldBehavior.state() != currentBehavior.state()))) {
-            getRoleChangeNotifier().get().tell(
-                    new RoleChanged(getId(), oldBehaviorState , currentBehavior.state().name()),
-                    getSelf());
+            roleChangeNotifier.get().tell(new RoleChanged(getId(), oldBehaviorStateName ,
+                    currentBehavior.state().name()), getSelf());
         }
     }
 
@@ -427,19 +503,12 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
                             // Apply the state immediately
                             applyState(clientActor, identifier, data);
 
-                            // Send a ApplyLogEntries message so that we write the fact that we applied
+                            // Send a ApplyJournalEntries message so that we write the fact that we applied
                             // the state to durable storage
-                            self().tell(new ApplyLogEntries((int) replicatedLogEntry.getIndex()), self());
-
-                            // Check if the "real" snapshot capture has been initiated. If no then do the fake snapshot
-                            if(!context.isSnapshotCaptureInitiated()){
-                                raftContext.getReplicatedLog().snapshotPreCommit(raftContext.getLastApplied(),
-                                        raftContext.getTermInformation().getCurrentTerm());
-                                raftContext.getReplicatedLog().snapshotCommit();
-                            } else {
-                                LOG.debug("{}: Skipping fake snapshotting for {} because real snapshotting is in progress",
-                                        persistenceId(), getId());
-                            }
+                            self().tell(new ApplyJournalEntries(replicatedLogEntry.getIndex()), self());
+
+                            context.getSnapshotManager().trimLog(context.getLastApplied(), currentBehavior);
+
                         } else if (clientActor != null) {
                             // Send message for replication
                             currentBehavior.handleMessage(getSelf(),
@@ -514,6 +583,10 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         return context;
     }
 
+    protected void updateConfigParams(ConfigParams configParams) {
+        context.setConfigParams(configParams);
+    }
+
     /**
      * setPeerAddress sets the address of a known peer at a later time.
      * <p>
@@ -533,10 +606,7 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
     }
 
     protected void commitSnapshot(long sequenceNumber) {
-        context.getReplicatedLog().snapshotCommit();
-
-        // TODO: Not sure if we want to be this aggressive with trimming stuff
-        trimPersistentData(sequenceNumber);
+        context.getSnapshotManager().commit(persistence(), sequenceNumber);
     }
 
     /**
@@ -576,7 +646,7 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
     /**
      * This method is called during recovery to reconstruct the state of the actor.
      *
-     * @param snapshot A snapshot of the state of the actor
+     * @param snapshotBytes A snapshot of the state of the actor
      */
     protected abstract void applyRecoverySnapshot(byte[] snapshotBytes);
 
@@ -628,17 +698,6 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
 
     protected void onLeaderChanged(String oldLeader, String newLeader){};
 
-    private void trimPersistentData(long sequenceNumber) {
-        // Trim akka snapshots
-        // FIXME : Not sure how exactly the SnapshotSelectionCriteria is applied
-        // For now guessing that it is ANDed.
-        persistence().deleteSnapshots(new SnapshotSelectionCriteria(
-            sequenceNumber - context.getConfigParams().getSnapshotBatchCount(), 43200000));
-
-        // Trim akka journal
-        persistence().deleteMessages(sequenceNumber);
-    }
-
     private String getLeaderAddress(){
         if(isLeader()){
             return getSelf().path().toString();
@@ -657,39 +716,13 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
     }
 
     private void handleCaptureSnapshotReply(byte[] snapshotBytes) {
-        LOG.info("{}: CaptureSnapshotReply received by actor: snapshot size {}", persistenceId(), snapshotBytes.length);
-
-        // create a snapshot object from the state provided and save it
-        // when snapshot is saved async, SaveSnapshotSuccess is raised.
-
-        Snapshot sn = Snapshot.create(snapshotBytes,
-            context.getReplicatedLog().getFrom(captureSnapshot.getLastAppliedIndex() + 1),
-            captureSnapshot.getLastIndex(), captureSnapshot.getLastTerm(),
-            captureSnapshot.getLastAppliedIndex(), captureSnapshot.getLastAppliedTerm());
-
-        persistence().saveSnapshot(sn);
-
-        LOG.info("{}: Persisting of snapshot done:{}", persistenceId(), sn.getLogMessage());
-
-        //be greedy and remove entries from in-mem journal which are in the snapshot
-        // and update snapshotIndex and snapshotTerm without waiting for the success,
-
-        context.getReplicatedLog().snapshotPreCommit(
-            captureSnapshot.getLastAppliedIndex(),
-            captureSnapshot.getLastAppliedTerm());
+        LOG.debug("{}: CaptureSnapshotReply received by actor: snapshot size {}", persistenceId(), snapshotBytes.length);
 
-        LOG.info("{}: Removed in-memory snapshotted entries, adjusted snaphsotIndex:{} " +
-            "and term:{}", persistenceId(), captureSnapshot.getLastAppliedIndex(),
-            captureSnapshot.getLastAppliedTerm());
-
-        if (isLeader() && captureSnapshot.isInstallSnapshotInitiated()) {
-            // this would be call straight to the leader and won't initiate in serialization
-            currentBehavior.handleMessage(getSelf(), new SendInstallSnapshot(
-                    ByteString.copyFrom(snapshotBytes)));
-        }
+        context.getSnapshotManager().persist(persistence(), snapshotBytes, currentBehavior, getTotalMemory());
+    }
 
-        captureSnapshot = null;
-        context.setSnapshotCaptureInitiated(false);
+    protected long getTotalMemory() {
+        return Runtime.getRuntime().totalMemory();
     }
 
     protected boolean hasFollowers(){
@@ -697,9 +730,9 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
     }
 
     private class ReplicatedLogImpl extends AbstractReplicatedLogImpl {
-
         private static final int DATA_SIZE_DIVIDER = 5;
-        private long dataSizeSinceLastSnapshot = 0;
+        private long dataSizeSinceLastSnapshot = 0L;
+
 
         public ReplicatedLogImpl(Snapshot snapshot) {
             super(snapshot.getLastAppliedIndex(), snapshot.getLastAppliedTerm(),
@@ -720,13 +753,14 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
             // FIXME: Maybe this should be done after the command is saved
             journal.subList(adjustedIndex , journal.size()).clear();
 
-            persistence().persist(new DeleteEntries(adjustedIndex), new Procedure<DeleteEntries>(){
+            persistence().persist(new DeleteEntries(adjustedIndex), new Procedure<DeleteEntries>() {
 
-                @Override public void apply(DeleteEntries param)
-                    throws Exception {
+                @Override
+                public void apply(DeleteEntries param)
+                        throws Exception {
                     //FIXME : Doing nothing for now
                     dataSize = 0;
-                    for(ReplicatedLogEntry entry : journal){
+                    for (ReplicatedLogEntry entry : journal) {
                         dataSize += entry.size();
                     }
                 }
@@ -738,11 +772,6 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
             appendAndPersist(replicatedLogEntry, null);
         }
 
-        @Override
-        public int dataSize() {
-            return dataSize;
-        }
-
         public void appendAndPersist(
             final ReplicatedLogEntry replicatedLogEntry,
             final Procedure<ReplicatedLogEntry> callback)  {
@@ -769,9 +798,8 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
                         long dataSizeForCheck = dataSize;
 
                         dataSizeSinceLastSnapshot += logEntrySize;
-                        long journalSize = lastIndex()+1;
 
-                        if(!hasFollowers()) {
+                        if (!hasFollowers()) {
                             // When we do not have followers we do not maintain an in-memory log
                             // due to this the journalSize will never become anything close to the
                             // snapshot batch count. In fact will mostly be 1.
@@ -785,47 +813,23 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
                             // as if we were maintaining a real snapshot
                             dataSizeForCheck = dataSizeSinceLastSnapshot / DATA_SIZE_DIVIDER;
                         }
+                        long journalSize = replicatedLogEntry.getIndex() + 1;
+                        long dataThreshold = getTotalMemory() *
+                                context.getConfigParams().getSnapshotDataThresholdPercentage() / 100;
 
-                        long dataThreshold = Runtime.getRuntime().totalMemory() *
-                                getRaftActorContext().getConfigParams().getSnapshotDataThresholdPercentage() / 100;
-
-                        // when a snaphsot is being taken, captureSnapshot != null
-                        if (!context.isSnapshotCaptureInitiated() &&
-                                ( journalSize % context.getConfigParams().getSnapshotBatchCount() == 0 ||
-                                        dataSizeForCheck > dataThreshold)) {
-
-                            dataSizeSinceLastSnapshot = 0;
+                        if ((journalSize % context.getConfigParams().getSnapshotBatchCount() == 0
+                                || dataSizeForCheck > dataThreshold)) {
 
-                            LOG.info("{}: Initiating Snapshot Capture..", persistenceId());
-                            long lastAppliedIndex = -1;
-                            long lastAppliedTerm = -1;
-
-                            ReplicatedLogEntry lastAppliedEntry = get(context.getLastApplied());
-                            if (!hasFollowers()) {
-                                lastAppliedIndex = replicatedLogEntry.getIndex();
-                                lastAppliedTerm = replicatedLogEntry.getTerm();
-                            } else if (lastAppliedEntry != null) {
-                                lastAppliedIndex = lastAppliedEntry.getIndex();
-                                lastAppliedTerm = lastAppliedEntry.getTerm();
-                            }
+                            boolean started = context.getSnapshotManager().capture(replicatedLogEntry,
+                                    currentBehavior.getReplicatedToAllIndex());
 
-                            if(LOG.isDebugEnabled()) {
-                                LOG.debug("{}: Snapshot Capture logSize: {}", persistenceId(), journal.size());
-                                LOG.debug("{}: Snapshot Capture lastApplied:{} ",
-                                        persistenceId(), context.getLastApplied());
-                                LOG.debug("{}: Snapshot Capture lastAppliedIndex:{}", persistenceId(),
-                                        lastAppliedIndex);
-                                LOG.debug("{}: Snapshot Capture lastAppliedTerm:{}", persistenceId(),
-                                        lastAppliedTerm);
+                            if(started){
+                                dataSizeSinceLastSnapshot = 0;
                             }
 
-                            // send a CaptureSnapshot to self to make the expensive operation async.
-                            getSelf().tell(new CaptureSnapshot(
-                                lastIndex(), lastTerm(), lastAppliedIndex, lastAppliedTerm),
-                                null);
-                            context.setSnapshotCaptureInitiated(true);
                         }
-                        if(callback != null){
+
+                        if (callback != null){
                             callback.apply(replicatedLogEntry);
                         }
                     }
@@ -928,7 +932,18 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         @Override
         public void saveSnapshot(Object o) {
             // Make saving Snapshot successful
-            commitSnapshot(-1L);
+            // Committing the snapshot here would end up calling commit in the creating state which would
+            // be a state violation. That's why now we send a message to commit the snapshot.
+            self().tell(COMMIT_SNAPSHOT, self());
+        }
+    }
+
+
+    private class CreateSnapshotProcedure implements Procedure<Void> {
+
+        @Override
+        public void apply(Void aVoid) throws Exception {
+            createSnapshot();
         }
     }
 
@@ -941,4 +956,21 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         return currentBehavior;
     }
 
+    private static class BehaviorStateHolder {
+        private RaftActorBehavior behavior;
+        private String leaderId;
+
+        void init(RaftActorBehavior behavior) {
+            this.behavior = behavior;
+            this.leaderId = behavior != null ? behavior.getLeaderId() : null;
+        }
+
+        RaftActorBehavior getBehavior() {
+            return behavior;
+        }
+
+        String getLeaderId() {
+            return leaderId;
+        }
+    }
 }