BUG-5626: do not allow overriding of RaftActor.handleCommand()
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / main / java / org / opendaylight / controller / cluster / raft / RaftActor.java
index 65254f2d6277c34c9773570e5938d8a005ab3015..47c8db6006544b13c11b42f7abdb370590fd3b8b 100644 (file)
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) 2014 Cisco Systems, Inc. and others.  All rights reserved.
+ * Copyright (c) 2015 Brocade Communications Systems, Inc. and others.  All rights reserved.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
@@ -10,34 +11,48 @@ package org.opendaylight.controller.cluster.raft;
 
 import akka.actor.ActorRef;
 import akka.actor.ActorSelection;
+import akka.actor.PoisonPill;
 import akka.japi.Procedure;
-import akka.persistence.RecoveryCompleted;
-import akka.persistence.SaveSnapshotFailure;
-import akka.persistence.SaveSnapshotSuccess;
-import akka.persistence.SnapshotOffer;
-import akka.persistence.SnapshotSelectionCriteria;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Optional;
-import com.google.common.base.Stopwatch;
-import com.google.protobuf.ByteString;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Verify;
+import com.google.common.collect.Lists;
 import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import org.apache.commons.lang3.time.DurationFormatUtils;
 import org.opendaylight.controller.cluster.DataPersistenceProvider;
+import org.opendaylight.controller.cluster.DelegatingPersistentDataProvider;
+import org.opendaylight.controller.cluster.NonPersistentDataProvider;
+import org.opendaylight.controller.cluster.PersistentDataProvider;
 import org.opendaylight.controller.cluster.common.actor.AbstractUntypedPersistentActor;
+import org.opendaylight.controller.cluster.notifications.LeaderStateChanged;
 import org.opendaylight.controller.cluster.notifications.RoleChanged;
-import org.opendaylight.controller.cluster.raft.base.messages.ApplyLogEntries;
-import org.opendaylight.controller.cluster.raft.base.messages.ApplySnapshot;
+import org.opendaylight.controller.cluster.raft.base.messages.ApplyJournalEntries;
 import org.opendaylight.controller.cluster.raft.base.messages.ApplyState;
-import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshot;
-import org.opendaylight.controller.cluster.raft.base.messages.CaptureSnapshotReply;
+import org.opendaylight.controller.cluster.raft.base.messages.InitiateCaptureSnapshot;
+import org.opendaylight.controller.cluster.raft.base.messages.LeaderTransitioning;
 import org.opendaylight.controller.cluster.raft.base.messages.Replicate;
-import org.opendaylight.controller.cluster.raft.base.messages.SendInstallSnapshot;
+import org.opendaylight.controller.cluster.raft.base.messages.SwitchBehavior;
+import org.opendaylight.controller.cluster.raft.behaviors.AbstractLeader;
 import org.opendaylight.controller.cluster.raft.behaviors.AbstractRaftActorBehavior;
 import org.opendaylight.controller.cluster.raft.behaviors.Follower;
 import org.opendaylight.controller.cluster.raft.behaviors.RaftActorBehavior;
 import org.opendaylight.controller.cluster.raft.client.messages.FindLeader;
 import org.opendaylight.controller.cluster.raft.client.messages.FindLeaderReply;
+import org.opendaylight.controller.cluster.raft.client.messages.FollowerInfo;
+import org.opendaylight.controller.cluster.raft.client.messages.GetOnDemandRaftState;
+import org.opendaylight.controller.cluster.raft.client.messages.OnDemandRaftState;
+import org.opendaylight.controller.cluster.raft.client.messages.Shutdown;
 import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload;
+import org.opendaylight.yangtools.concepts.Immutable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -82,49 +97,47 @@ import org.slf4j.LoggerFactory;
  * </ul>
  */
 public abstract class RaftActor extends AbstractUntypedPersistentActor {
-    protected final Logger LOG = LoggerFactory.getLogger(getClass());
 
-    /**
-     * The current state determines the current behavior of a RaftActor
-     * A Raft Actor always starts off in the Follower State
-     */
-    private RaftActorBehavior currentBehavior;
+    private static final long APPLY_STATE_DELAY_THRESHOLD_IN_NANOS = TimeUnit.MILLISECONDS.toNanos(50L); // 50 millis
+
+    protected final Logger LOG = LoggerFactory.getLogger(getClass());
 
     /**
      * This context should NOT be passed directly to any other actor it is
      * only to be consumed by the RaftActorBehaviors
      */
-    private final RaftActorContext context;
+    private final RaftActorContextImpl context;
 
-    /**
-     * The in-memory journal
-     */
-    private ReplicatedLogImpl replicatedLog = new ReplicatedLogImpl();
+    private final DelegatingPersistentDataProvider delegatingPersistenceProvider;
 
-    private CaptureSnapshot captureSnapshot = null;
+    private final PersistentDataProvider persistentProvider;
 
-    private Stopwatch recoveryTimer;
+    private final BehaviorStateTracker behaviorStateTracker = new BehaviorStateTracker();
 
-    private int currentRecoveryBatchCount;
+    private RaftActorRecoverySupport raftRecovery;
 
-    public RaftActor(String id, Map<String, String> peerAddresses) {
-        this(id, peerAddresses, Optional.<ConfigParams>absent());
-    }
+    private RaftActorSnapshotMessageSupport snapshotSupport;
+
+    private RaftActorServerConfigurationSupport serverConfigurationSupport;
+
+    private RaftActorLeadershipTransferCohort leadershipTransferInProgress;
+
+    private boolean shuttingDown;
 
     public RaftActor(String id, Map<String, String> peerAddresses,
-         Optional<ConfigParams> configParams) {
+         Optional<ConfigParams> configParams, short payloadVersion) {
+
+        persistentProvider = new PersistentDataProvider(this);
+        delegatingPersistenceProvider = new RaftActorDelegatingPersistentDataProvider(null, persistentProvider);
 
         context = new RaftActorContextImpl(this.getSelf(),
-            this.getContext(), id, new ElectionTermImpl(),
-            -1, -1, replicatedLog, peerAddresses,
+            this.getContext(), id, new ElectionTermImpl(persistentProvider, id, LOG),
+            -1, -1, peerAddresses,
             (configParams.isPresent() ? configParams.get(): new DefaultConfigParamsImpl()),
-            LOG);
-    }
+            delegatingPersistenceProvider, LOG);
 
-    private void initRecoveryTimer() {
-        if(recoveryTimer == null) {
-            recoveryTimer = Stopwatch.createStarted();
-        }
+        context.setPayloadVersion(payloadVersion);
+        context.setReplicatedLog(ReplicatedLogImpl.newInstance(context));
     }
 
     @Override
@@ -133,151 +146,96 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
                 context.getConfigParams().getJournalRecoveryLogBatchSize());
 
         super.preStart();
+
+        snapshotSupport = newRaftActorSnapshotMessageSupport();
+        serverConfigurationSupport = new RaftActorServerConfigurationSupport(this);
     }
 
     @Override
-    public void handleRecover(Object message) {
-        if(persistence().isRecoveryApplicable()) {
-            if (message instanceof SnapshotOffer) {
-                onRecoveredSnapshot((SnapshotOffer) message);
-            } else if (message instanceof ReplicatedLogEntry) {
-                onRecoveredJournalLogEntry((ReplicatedLogEntry) message);
-            } else if (message instanceof ApplyLogEntries) {
-                onRecoveredApplyLogEntries((ApplyLogEntries) message);
-            } else if (message instanceof DeleteEntries) {
-                replicatedLog.removeFrom(((DeleteEntries) message).getFromIndex());
-            } else if (message instanceof UpdateElectionTerm) {
-                context.getTermInformation().update(((UpdateElectionTerm) message).getCurrentTerm(),
-                        ((UpdateElectionTerm) message).getVotedFor());
-            } else if (message instanceof RecoveryCompleted) {
-                onRecoveryCompletedMessage();
-            }
-        } else {
-            if (message instanceof RecoveryCompleted) {
-                // Delete all the messages from the akka journal so that we do not end up with consistency issues
-                // Note I am not using the dataPersistenceProvider and directly using the akka api here
-                deleteMessages(lastSequenceNr());
-
-                // Delete all the akka snapshots as they will not be needed
-                deleteSnapshots(new SnapshotSelectionCriteria(scala.Long.MaxValue(), scala.Long.MaxValue()));
-
-                onRecoveryComplete();
-
-                initializeBehavior();
-            }
-        }
+    public void postStop() {
+        context.close();
+        super.postStop();
     }
 
-    private void onRecoveredSnapshot(SnapshotOffer offer) {
-        if(LOG.isDebugEnabled()) {
-            LOG.debug("{}: SnapshotOffer called..", persistenceId());
+    @Override
+    protected void handleRecover(Object message) {
+        if(raftRecovery == null) {
+            raftRecovery = newRaftActorRecoverySupport();
         }
 
-        initRecoveryTimer();
-
-        Snapshot snapshot = (Snapshot) offer.snapshot();
-
-        // Create a replicated log with the snapshot information
-        // The replicated log can be used later on to retrieve this snapshot
-        // when we need to install it on a peer
-        replicatedLog = new ReplicatedLogImpl(snapshot);
-
-        context.setReplicatedLog(replicatedLog);
-        context.setLastApplied(snapshot.getLastAppliedIndex());
-        context.setCommitIndex(snapshot.getLastAppliedIndex());
+        boolean recoveryComplete = raftRecovery.handleRecoveryMessage(message, persistentProvider);
+        if(recoveryComplete) {
+            onRecoveryComplete();
 
-        Stopwatch timer = Stopwatch.createStarted();
+            initializeBehavior();
 
-        // Apply the snapshot to the actors state
-        applyRecoverySnapshot(snapshot.getState());
+            raftRecovery = null;
 
-        timer.stop();
-        LOG.info("Recovery snapshot applied for {} in {}: snapshotIndex={}, snapshotTerm={}, journal-size=" +
-                replicatedLog.size(), persistenceId(), timer.toString(),
-                replicatedLog.getSnapshotIndex(), replicatedLog.getSnapshotTerm());
-    }
-
-    private void onRecoveredJournalLogEntry(ReplicatedLogEntry logEntry) {
-        if(LOG.isDebugEnabled()) {
-            LOG.debug("{}: Received ReplicatedLogEntry for recovery: {}", persistenceId(), logEntry.getIndex());
+            if (context.getReplicatedLog().size() > 0) {
+                self().tell(new InitiateCaptureSnapshot(), self());
+                LOG.info("{}: Snapshot capture initiated after recovery", persistenceId());
+            } else {
+                LOG.info("{}: Snapshot capture NOT initiated after recovery, journal empty", persistenceId());
+            }
         }
-
-        replicatedLog.append(logEntry);
     }
 
-    private void onRecoveredApplyLogEntries(ApplyLogEntries ale) {
-        if(LOG.isDebugEnabled()) {
-            LOG.debug("{}: Received ApplyLogEntries for recovery, applying to state: {} to {}",
-                    persistenceId(), context.getLastApplied() + 1, ale.getToIndex());
-        }
-
-        for (long i = context.getLastApplied() + 1; i <= ale.getToIndex(); i++) {
-            batchRecoveredLogEntry(replicatedLog.get(i));
-        }
-
-        context.setLastApplied(ale.getToIndex());
-        context.setCommitIndex(ale.getToIndex());
+    protected RaftActorRecoverySupport newRaftActorRecoverySupport() {
+        return new RaftActorRecoverySupport(context, getRaftActorRecoveryCohort());
     }
 
-    private void batchRecoveredLogEntry(ReplicatedLogEntry logEntry) {
-        initRecoveryTimer();
+    @VisibleForTesting
+    void initializeBehavior(){
+        changeCurrentBehavior(new Follower(context));
+    }
 
-        int batchSize = context.getConfigParams().getJournalRecoveryLogBatchSize();
-        if(currentRecoveryBatchCount == 0) {
-            startLogRecoveryBatch(batchSize);
+    @VisibleForTesting
+    protected void changeCurrentBehavior(RaftActorBehavior newBehavior) {
+        final RaftActorBehavior currentBehavior = getCurrentBehavior();
+        if (currentBehavior != null) {
+            try {
+                currentBehavior.close();
+            } catch (Exception e) {
+                LOG.warn("{}: Error closing behavior {}", persistence(), currentBehavior, e);
+            }
         }
 
-        appendRecoveredLogEntry(logEntry.getData());
-
-        if(++currentRecoveryBatchCount >= batchSize) {
-            endCurrentLogRecoveryBatch();
-        }
+        final BehaviorState state = behaviorStateTracker.capture(currentBehavior);
+        setCurrentBehavior(newBehavior);
+        handleBehaviorChange(state, newBehavior);
     }
 
-    private void endCurrentLogRecoveryBatch() {
-        applyCurrentLogRecoveryBatch();
-        currentRecoveryBatchCount = 0;
+    /**
+     * Method exposed for subclasses to plug-in their logic. This method is invoked by {@link #handleCommand(Object)}
+     * for messages which are not handled by this class. Subclasses overriding this class should fall back to this
+     * implementation for messages which they do not handle
+     *
+     * @param message Incoming command message
+     */
+    protected void handleNonRaftCommand(final Object message) {
+        unhandled(message);
     }
 
-    private void onRecoveryCompletedMessage() {
-        if(currentRecoveryBatchCount > 0) {
-            endCurrentLogRecoveryBatch();
+    /**
+     * @deprecated This method is not final for testing purposes. DO NOT OVERRIDE IT, override
+     * {@link #handleNonRaftCommand(Object)} instead.
+     */
+    @Deprecated
+    @Override
+    // FIXME: make this method final once our unit tests do not need to override it
+    protected void handleCommand(final Object message) {
+        if (serverConfigurationSupport.handleMessage(message, getSender())) {
+            return;
         }
-
-        onRecoveryComplete();
-
-        String recoveryTime = "";
-        if(recoveryTimer != null) {
-            recoveryTimer.stop();
-            recoveryTime = " in " + recoveryTimer.toString();
-            recoveryTimer = null;
+        if (snapshotSupport.handleSnapshotMessage(message, getSender())) {
+            return;
         }
 
-        LOG.info(
-            "Recovery completed" + recoveryTime + " - Switching actor to Follower - " +
-                "Persistence Id =  " + persistenceId() +
-                " Last index in log={}, snapshotIndex={}, snapshotTerm={}, " +
-                "journal-size={}",
-            replicatedLog.lastIndex(), replicatedLog.getSnapshotIndex(),
-            replicatedLog.getSnapshotTerm(), replicatedLog.size());
-
-        initializeBehavior();
-    }
-
-    protected void initializeBehavior(){
-        changeCurrentBehavior(new Follower(context));
-    }
-
-    protected void changeCurrentBehavior(RaftActorBehavior newBehavior){
-        RaftActorBehavior oldBehavior = currentBehavior;
-        currentBehavior = newBehavior;
-        handleBehaviorChange(oldBehavior, currentBehavior);
-    }
-
-    @Override public void handleCommand(Object message) {
-        if (message instanceof ApplyState){
+        if (message instanceof ApplyState) {
             ApplyState applyState = (ApplyState) message;
 
+            long startTime = System.nanoTime();
+
             if(LOG.isDebugEnabled()) {
                 LOG.debug("{}: Applying state for log index {} data {}",
                     persistenceId(), applyState.getReplicatedLogEntry().getIndex(),
@@ -287,100 +245,265 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
             applyState(applyState.getClientActor(), applyState.getIdentifier(),
                 applyState.getReplicatedLogEntry().getData());
 
-        } else if (message instanceof ApplyLogEntries){
-            ApplyLogEntries ale = (ApplyLogEntries) message;
-            if(LOG.isDebugEnabled()) {
-                LOG.debug("{}: Persisting ApplyLogEntries with index={}", persistenceId(), ale.getToIndex());
+            long elapsedTime = System.nanoTime() - startTime;
+            if(elapsedTime >= APPLY_STATE_DELAY_THRESHOLD_IN_NANOS){
+                LOG.debug("ApplyState took more time than expected. Elapsed Time = {} ms ApplyState = {}",
+                        TimeUnit.NANOSECONDS.toMillis(elapsedTime), applyState);
             }
-            persistence().persist(new ApplyLogEntries(ale.getToIndex()), new Procedure<ApplyLogEntries>() {
-                @Override
-                public void apply(ApplyLogEntries param) throws Exception {
-                }
-            });
 
-        } else if(message instanceof ApplySnapshot ) {
-            Snapshot snapshot = ((ApplySnapshot) message).getSnapshot();
+            if (!hasFollowers()) {
+                // for single node, the capture should happen after the apply state
+                // as we delete messages from the persistent journal which have made it to the snapshot
+                // capturing the snapshot before applying makes the persistent journal and snapshot out of sync
+                // and recovery shows data missing
+                context.getReplicatedLog().captureSnapshotIfReady(applyState.getReplicatedLogEntry());
 
-            if(LOG.isDebugEnabled()) {
-                LOG.debug("{}: ApplySnapshot called on Follower Actor " +
-                        "snapshotIndex:{}, snapshotTerm:{}", persistenceId(), snapshot.getLastAppliedIndex(),
-                    snapshot.getLastAppliedTerm()
-                );
+                context.getSnapshotManager().trimLog(context.getLastApplied());
             }
 
-            applySnapshot(snapshot.getState());
+        } else if (message instanceof ApplyJournalEntries) {
+            ApplyJournalEntries applyEntries = (ApplyJournalEntries) message;
+            if(LOG.isDebugEnabled()) {
+                LOG.debug("{}: Persisting ApplyJournalEntries with index={}", persistenceId(), applyEntries.getToIndex());
+            }
 
-            //clears the followers log, sets the snapshot index to ensure adjusted-index works
-            replicatedLog = new ReplicatedLogImpl(snapshot);
-            context.setReplicatedLog(replicatedLog);
-            context.setLastApplied(snapshot.getLastAppliedIndex());
+            persistence().persist(applyEntries, NoopProcedure.instance());
 
         } else if (message instanceof FindLeader) {
             getSender().tell(
                 new FindLeaderReply(getLeaderAddress()),
                 getSelf()
             );
+        } else if(message instanceof GetOnDemandRaftState) {
+            onGetOnDemandRaftStats();
+        } else if(message instanceof InitiateCaptureSnapshot) {
+            captureSnapshot();
+        } else if(message instanceof SwitchBehavior) {
+            switchBehavior(((SwitchBehavior) message));
+        } else if(message instanceof LeaderTransitioning) {
+            onLeaderTransitioning();
+        } else if(message instanceof Shutdown) {
+            onShutDown();
+        } else if(message instanceof Runnable) {
+            ((Runnable)message).run();
+        } else {
+            // Processing the message may affect the state, hence we need to capture it
+            final RaftActorBehavior currentBehavior = getCurrentBehavior();
+            final BehaviorState state = behaviorStateTracker.capture(currentBehavior);
+
+            // A behavior indicates that it processed the change by returning a reference to the next behavior
+            // to be used. A null return indicates it has not processed the message and we should be passing it to
+            // the subclass for handling.
+            final RaftActorBehavior nextBehavior = currentBehavior.handleMessage(getSender(), message);
+            if (nextBehavior != null) {
+                switchBehavior(state, nextBehavior);
+            } else {
+                handleNonRaftCommand(message);
+            }
+        }
+    }
 
-        } else if (message instanceof SaveSnapshotSuccess) {
-            SaveSnapshotSuccess success = (SaveSnapshotSuccess) message;
-            LOG.info("{}: SaveSnapshotSuccess received for snapshot", persistenceId());
+    private void initiateLeadershipTransfer(final RaftActorLeadershipTransferCohort.OnComplete onComplete) {
+        LOG.debug("{}: Initiating leader transfer", persistenceId());
 
-            long sequenceNumber = success.metadata().sequenceNr();
+        if(leadershipTransferInProgress == null) {
+            leadershipTransferInProgress = new RaftActorLeadershipTransferCohort(this);
+            leadershipTransferInProgress.addOnComplete(new RaftActorLeadershipTransferCohort.OnComplete() {
+                @Override
+                public void onSuccess(ActorRef raftActorRef) {
+                    leadershipTransferInProgress = null;
+                }
+
+                @Override
+                public void onFailure(ActorRef raftActorRef) {
+                    leadershipTransferInProgress = null;
+                }
+            });
+
+            leadershipTransferInProgress.addOnComplete(onComplete);
+            leadershipTransferInProgress.init();
+        } else {
+            LOG.debug("{}: prior leader transfer in progress - adding callback", persistenceId());
+            leadershipTransferInProgress.addOnComplete(onComplete);
+        }
+    }
 
-            commitSnapshot(sequenceNumber);
+    private void onShutDown() {
+        LOG.debug("{}: onShutDown", persistenceId());
 
-        } else if (message instanceof SaveSnapshotFailure) {
-            SaveSnapshotFailure saveSnapshotFailure = (SaveSnapshotFailure) message;
+        if(shuttingDown) {
+            return;
+        }
 
-            LOG.error("{}: SaveSnapshotFailure received for snapshot Cause:",
-                    persistenceId(), saveSnapshotFailure.cause());
+        shuttingDown = true;
 
-            context.getReplicatedLog().snapshotRollback();
+        final RaftActorBehavior currentBehavior = context.getCurrentBehavior();
+        if (currentBehavior.state() != RaftState.Leader) {
+            // For non-leaders shutdown is a no-op
+            self().tell(PoisonPill.getInstance(), self());
+            return;
+        }
+
+        if (context.hasFollowers()) {
+            initiateLeadershipTransfer(new RaftActorLeadershipTransferCohort.OnComplete() {
+                @Override
+                public void onSuccess(ActorRef raftActorRef) {
+                    LOG.debug("{}: leader transfer succeeded - sending PoisonPill", persistenceId());
+                    raftActorRef.tell(PoisonPill.getInstance(), raftActorRef);
+                }
+
+                @Override
+                public void onFailure(ActorRef raftActorRef) {
+                    LOG.debug("{}: leader transfer failed - sending PoisonPill", persistenceId());
+                    raftActorRef.tell(PoisonPill.getInstance(), raftActorRef);
+                }
+            });
+        } else {
+            pauseLeader(new TimedRunnable(context.getConfigParams().getElectionTimeOutInterval(), this) {
+                @Override
+                protected void doRun() {
+                    self().tell(PoisonPill.getInstance(), self());
+                }
 
-            LOG.info("{}: Replicated Log rollbacked. Snapshot will be attempted in the next cycle." +
-                "snapshotIndex:{}, snapshotTerm:{}, log-size:{}", persistenceId(),
-                context.getReplicatedLog().getSnapshotIndex(),
-                context.getReplicatedLog().getSnapshotTerm(),
-                context.getReplicatedLog().size());
+                @Override
+                protected void doCancel() {
+                    self().tell(PoisonPill.getInstance(), self());
+                }
+            });
+        }
+    }
 
-        } else if (message instanceof CaptureSnapshot) {
-            LOG.info("{}: CaptureSnapshot received by actor", persistenceId());
+    private void onLeaderTransitioning() {
+        LOG.debug("{}: onLeaderTransitioning", persistenceId());
+        Optional<ActorRef> roleChangeNotifier = getRoleChangeNotifier();
+        if(getRaftState() == RaftState.Follower && roleChangeNotifier.isPresent()) {
+            roleChangeNotifier.get().tell(newLeaderStateChanged(getId(), null,
+                getCurrentBehavior().getLeaderPayloadVersion()), getSelf());
+        }
+    }
 
-            if(captureSnapshot == null) {
-                captureSnapshot = (CaptureSnapshot)message;
-                createSnapshot();
+    private void switchBehavior(SwitchBehavior message) {
+        if(!getRaftActorContext().getRaftPolicy().automaticElectionsEnabled()) {
+            RaftState newState = message.getNewState();
+            if( newState == RaftState.Leader || newState == RaftState.Follower) {
+                switchBehavior(behaviorStateTracker.capture(getCurrentBehavior()),
+                    AbstractRaftActorBehavior.createBehavior(context, message.getNewState()));
+                getRaftActorContext().getTermInformation().updateAndPersist(message.getNewTerm(), "");
+            } else {
+                LOG.warn("Switching to behavior : {} - not supported", newState);
             }
+        }
+    }
 
-        } else if (message instanceof CaptureSnapshotReply){
-            handleCaptureSnapshotReply(((CaptureSnapshotReply) message).getSnapshot());
+    private void switchBehavior(final BehaviorState oldBehaviorState, final RaftActorBehavior nextBehavior) {
+        setCurrentBehavior(nextBehavior);
+        handleBehaviorChange(oldBehaviorState, nextBehavior);
+    }
 
-        } else {
-            RaftActorBehavior oldBehavior = currentBehavior;
-            currentBehavior = currentBehavior.handleMessage(getSender(), message);
+    @VisibleForTesting
+    RaftActorSnapshotMessageSupport newRaftActorSnapshotMessageSupport() {
+        return new RaftActorSnapshotMessageSupport(context, getRaftActorSnapshotCohort());
+    }
+
+    private void onGetOnDemandRaftStats() {
+        // Debugging message to retrieve raft stats.
+
+        Map<String, String> peerAddresses = new HashMap<>();
+        for(String peerId: context.getPeerIds()) {
+            peerAddresses.put(peerId, context.getPeerAddress(peerId));
+        }
+
+        final RaftActorBehavior currentBehavior = context.getCurrentBehavior();
+        OnDemandRaftState.Builder builder = OnDemandRaftState.builder()
+                .commitIndex(context.getCommitIndex())
+                .currentTerm(context.getTermInformation().getCurrentTerm())
+                .inMemoryJournalDataSize(replicatedLog().dataSize())
+                .inMemoryJournalLogSize(replicatedLog().size())
+                .isSnapshotCaptureInitiated(context.getSnapshotManager().isCapturing())
+                .lastApplied(context.getLastApplied())
+                .lastIndex(replicatedLog().lastIndex())
+                .lastTerm(replicatedLog().lastTerm())
+                .leader(getLeaderId())
+                .raftState(currentBehavior.state().toString())
+                .replicatedToAllIndex(currentBehavior.getReplicatedToAllIndex())
+                .snapshotIndex(replicatedLog().getSnapshotIndex())
+                .snapshotTerm(replicatedLog().getSnapshotTerm())
+                .votedFor(context.getTermInformation().getVotedFor())
+                .peerAddresses(peerAddresses)
+                .customRaftPolicyClassName(context.getConfigParams().getCustomRaftPolicyImplementationClass());
+
+        ReplicatedLogEntry lastLogEntry = replicatedLog().last();
+        if (lastLogEntry != null) {
+            builder.lastLogIndex(lastLogEntry.getIndex());
+            builder.lastLogTerm(lastLogEntry.getTerm());
+        }
+
+        if(getCurrentBehavior() instanceof AbstractLeader) {
+            AbstractLeader leader = (AbstractLeader)getCurrentBehavior();
+            Collection<String> followerIds = leader.getFollowerIds();
+            List<FollowerInfo> followerInfoList = Lists.newArrayListWithCapacity(followerIds.size());
+            for(String id: followerIds) {
+                final FollowerLogInformation info = leader.getFollower(id);
+                followerInfoList.add(new FollowerInfo(id, info.getNextIndex(), info.getMatchIndex(),
+                        info.isFollowerActive(), DurationFormatUtils.formatDurationHMS(info.timeSinceLastActivity())));
+            }
 
-            handleBehaviorChange(oldBehavior, currentBehavior);
+            builder.followerInfoList(followerInfoList);
         }
+
+        sender().tell(builder.build(), self());
+
     }
 
-    private void handleBehaviorChange(RaftActorBehavior oldBehavior, RaftActorBehavior currentBehavior) {
+    private void handleBehaviorChange(BehaviorState oldBehaviorState, RaftActorBehavior currentBehavior) {
+        RaftActorBehavior oldBehavior = oldBehaviorState.getBehavior();
+
         if (oldBehavior != currentBehavior){
             onStateChanged();
         }
 
-        String oldBehaviorLeaderId = oldBehavior == null? null : oldBehavior.getLeaderId();
-        String oldBehaviorState = oldBehavior == null? null : oldBehavior.state().name();
+        String lastValidLeaderId = oldBehavior == null ? null : oldBehaviorState.getLastValidLeaderId();
+        String oldBehaviorStateName = oldBehavior == null ? null : oldBehavior.state().name();
 
         // it can happen that the state has not changed but the leader has changed.
-        onLeaderChanged(oldBehaviorLeaderId, currentBehavior.getLeaderId());
+        Optional<ActorRef> roleChangeNotifier = getRoleChangeNotifier();
+        if(!Objects.equals(lastValidLeaderId, currentBehavior.getLeaderId()) ||
+           oldBehaviorState.getLeaderPayloadVersion() != currentBehavior.getLeaderPayloadVersion()) {
+            if(roleChangeNotifier.isPresent()) {
+                roleChangeNotifier.get().tell(newLeaderStateChanged(getId(), currentBehavior.getLeaderId(),
+                        currentBehavior.getLeaderPayloadVersion()), getSelf());
+            }
 
-        if (getRoleChangeNotifier().isPresent() &&
+            onLeaderChanged(lastValidLeaderId, currentBehavior.getLeaderId());
+
+            if(leadershipTransferInProgress != null) {
+                leadershipTransferInProgress.onNewLeader(currentBehavior.getLeaderId());
+            }
+
+            serverConfigurationSupport.onNewLeader(currentBehavior.getLeaderId());
+        }
+
+        if (roleChangeNotifier.isPresent() &&
                 (oldBehavior == null || (oldBehavior.state() != currentBehavior.state()))) {
-            getRoleChangeNotifier().get().tell(
-                    new RoleChanged(getId(), oldBehaviorState , currentBehavior.state().name()),
-                    getSelf());
+            roleChangeNotifier.get().tell(new RoleChanged(getId(), oldBehaviorStateName ,
+                    currentBehavior.state().name()), getSelf());
         }
     }
 
+    protected LeaderStateChanged newLeaderStateChanged(String memberId, String leaderId, short leaderPayloadVersion) {
+        return new LeaderStateChanged(memberId, leaderId, leaderPayloadVersion);
+    }
+
+    @Override
+    public long snapshotSequenceNr() {
+        // When we do a snapshot capture, we also capture and save the sequence-number of the persistent journal,
+        // so that we can delete the persistent journal based on the saved sequence-number
+        // However , when akka replays the journal during recovery, it replays it from the sequence number when the snapshot
+        // was saved and not the number we saved.
+        // We would want to override it , by asking akka to use the last-sequence number known to us.
+        return context.getSnapshotManager().getLastSequenceNumber();
+    }
+
     /**
      * When a derived RaftActor needs to persist something it must call
      * persistData.
@@ -402,46 +525,49 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
 
         final RaftActorContext raftContext = getRaftActorContext();
 
-        replicatedLog
-                .appendAndPersist(replicatedLogEntry, new Procedure<ReplicatedLogEntry>() {
-                    @Override
-                    public void apply(ReplicatedLogEntry replicatedLogEntry) throws Exception {
-                        if(!hasFollowers()){
-                            // Increment the Commit Index and the Last Applied values
-                            raftContext.setCommitIndex(replicatedLogEntry.getIndex());
-                            raftContext.setLastApplied(replicatedLogEntry.getIndex());
-
-                            // Apply the state immediately
-                            applyState(clientActor, identifier, data);
-
-                            // Send a ApplyLogEntries message so that we write the fact that we applied
-                            // the state to durable storage
-                            self().tell(new ApplyLogEntries((int) replicatedLogEntry.getIndex()), self());
-
-                            // Check if the "real" snapshot capture has been initiated. If no then do the fake snapshot
-                            if(!context.isSnapshotCaptureInitiated()){
-                                raftContext.getReplicatedLog().snapshotPreCommit(raftContext.getLastApplied(),
-                                        raftContext.getTermInformation().getCurrentTerm());
-                                raftContext.getReplicatedLog().snapshotCommit();
-                            } else {
-                                LOG.debug("{}: Skipping fake snapshotting for {} because real snapshotting is in progress",
-                                        persistenceId(), getId());
-                            }
-                        } else if (clientActor != null) {
-                            // Send message for replication
-                            currentBehavior.handleMessage(getSelf(),
-                                    new Replicate(clientActor, identifier,
-                                            replicatedLogEntry)
-                            );
-                        }
+        replicatedLog().appendAndPersist(replicatedLogEntry, new Procedure<ReplicatedLogEntry>() {
+            @Override
+            public void apply(ReplicatedLogEntry replicatedLogEntry) {
+                if (!hasFollowers()){
+                    // Increment the Commit Index and the Last Applied values
+                    raftContext.setCommitIndex(replicatedLogEntry.getIndex());
+                    raftContext.setLastApplied(replicatedLogEntry.getIndex());
 
-                    }
-                });    }
+                    // Apply the state immediately.
+                    self().tell(new ApplyState(clientActor, identifier, replicatedLogEntry), self());
+
+                    // Send a ApplyJournalEntries message so that we write the fact that we applied
+                    // the state to durable storage
+                    self().tell(new ApplyJournalEntries(replicatedLogEntry.getIndex()), self());
+
+                } else if (clientActor != null) {
+                    context.getReplicatedLog().captureSnapshotIfReady(replicatedLogEntry);
+
+                    // Send message for replication
+                    getCurrentBehavior().handleMessage(getSelf(),
+                            new Replicate(clientActor, identifier, replicatedLogEntry));
+                }
+            }
+        });
+    }
+
+    private ReplicatedLog replicatedLog() {
+        return context.getReplicatedLog();
+    }
 
     protected String getId() {
         return context.getId();
     }
 
+    @VisibleForTesting
+    void setCurrentBehavior(RaftActorBehavior behavior) {
+        context.setCurrentBehavior(behavior);
+    }
+
+    protected RaftActorBehavior getCurrentBehavior() {
+        return context.getCurrentBehavior();
+    }
+
     /**
      * Derived actors can call the isLeader method to check if the current
      * RaftActor is the Leader or not
@@ -449,7 +575,16 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
      * @return true it this RaftActor is a Leader false otherwise
      */
     protected boolean isLeader() {
-        return context.getId().equals(currentBehavior.getLeaderId());
+        return context.getId().equals(getCurrentBehavior().getLeaderId());
+    }
+
+    protected final boolean isLeaderActive() {
+        return getRaftState() != RaftState.IsolatedLeader && !shuttingDown &&
+                !isLeadershipTransferInProgress();
+    }
+
+    private boolean isLeadershipTransferInProgress() {
+        return leadershipTransferInProgress != null && leadershipTransferInProgress.isTransferring();
     }
 
     /**
@@ -473,32 +608,85 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
      *
      * @return the current leader's id
      */
-    protected String getLeaderId(){
-        return currentBehavior.getLeaderId();
+    protected final String getLeaderId(){
+        return getCurrentBehavior().getLeaderId();
     }
 
-    protected RaftState getRaftState() {
-        return currentBehavior.state();
-    }
-
-    protected ReplicatedLogEntry getLastLogEntry() {
-        return replicatedLog.last();
+    @VisibleForTesting
+    protected final RaftState getRaftState() {
+        return getCurrentBehavior().state();
     }
 
     protected Long getCurrentTerm(){
         return context.getTermInformation().getCurrentTerm();
     }
 
-    protected Long getCommitIndex(){
-        return context.getCommitIndex();
+    protected RaftActorContext getRaftActorContext() {
+        return context;
+    }
+
+    protected void updateConfigParams(ConfigParams configParams) {
+
+        // obtain the RaftPolicy for oldConfigParams and the updated one.
+        String oldRaftPolicy = context.getConfigParams().
+            getCustomRaftPolicyImplementationClass();
+        String newRaftPolicy = configParams.
+            getCustomRaftPolicyImplementationClass();
+
+        LOG.debug("{}: RaftPolicy used with prev.config {}, RaftPolicy used with newConfig {}", persistenceId(),
+            oldRaftPolicy, newRaftPolicy);
+        context.setConfigParams(configParams);
+        if (!Objects.equals(oldRaftPolicy, newRaftPolicy)) {
+            // The RaftPolicy was modified. If the current behavior is Follower then re-initialize to Follower
+            // but transfer the previous leaderId so it doesn't immediately try to schedule an election. This
+            // avoids potential disruption. Otherwise, switch to Follower normally.
+            RaftActorBehavior behavior = getCurrentBehavior();
+            if (behavior != null && behavior.state() == RaftState.Follower) {
+                String previousLeaderId = behavior.getLeaderId();
+                short previousLeaderPayloadVersion = behavior.getLeaderPayloadVersion();
+
+                LOG.debug("{}: Re-initializing to Follower with previous leaderId {}", persistenceId(), previousLeaderId);
+
+                changeCurrentBehavior(new Follower(context, previousLeaderId, previousLeaderPayloadVersion));
+            } else {
+                initializeBehavior();
+            }
+        }
     }
 
-    protected Long getLastApplied(){
-        return context.getLastApplied();
+    public final DataPersistenceProvider persistence() {
+        return delegatingPersistenceProvider.getDelegate();
     }
 
-    protected RaftActorContext getRaftActorContext() {
-        return context;
+    public void setPersistence(DataPersistenceProvider provider) {
+        delegatingPersistenceProvider.setDelegate(provider);
+    }
+
+    protected void setPersistence(boolean persistent) {
+        if(persistent) {
+            setPersistence(new PersistentDataProvider(this));
+        } else {
+            setPersistence(new NonPersistentDataProvider() {
+                /**
+                 * The way snapshotting works is,
+                 * <ol>
+                 * <li> RaftActor calls createSnapshot on the Shard
+                 * <li> Shard sends a CaptureSnapshotReply and RaftActor then calls saveSnapshot
+                 * <li> When saveSnapshot is invoked on the akka-persistence API it uses the SnapshotStore to save
+                 * the snapshot. The SnapshotStore sends SaveSnapshotSuccess or SaveSnapshotFailure. When the
+                 * RaftActor gets SaveSnapshot success it commits the snapshot to the in-memory journal. This
+                 * commitSnapshot is mimicking what is done in SaveSnapshotSuccess.
+                 * </ol>
+                 */
+                @Override
+                public void saveSnapshot(Object o) {
+                    // Make saving Snapshot successful
+                    // Committing the snapshot here would end up calling commit in the creating state which would
+                    // be a state violation. That's why now we send a message to commit the snapshot.
+                    self().tell(RaftActorSnapshotMessageSupport.COMMIT_SNAPSHOT, self());
+                }
+            });
+        }
     }
 
     /**
@@ -519,13 +707,6 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         context.setPeerAddress(peerId, peerAddress);
     }
 
-    protected void commitSnapshot(long sequenceNumber) {
-        context.getReplicatedLog().snapshotCommit();
-
-        // TODO: Not sure if we want to be this aggressive with trimming stuff
-        trimPersistentData(sequenceNumber);
-    }
-
     /**
      * The applyState method will be called by the RaftActor when some data
      * needs to be applied to the actor's state
@@ -547,31 +728,10 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         Object data);
 
     /**
-     * This method is called during recovery at the start of a batch of state entries. Derived
-     * classes should perform any initialization needed to start a batch.
-     */
-    protected abstract void startLogRecoveryBatch(int maxBatchSize);
-
-    /**
-     * This method is called during recovery to append state data to the current batch. This method
-     * is called 1 or more times after {@link #startLogRecoveryBatch}.
-     *
-     * @param data the state data
-     */
-    protected abstract void appendRecoveredLogEntry(Payload data);
-
-    /**
-     * This method is called during recovery to reconstruct the state of the actor.
-     *
-     * @param snapshotBytes A snapshot of the state of the actor
-     */
-    protected abstract void applyRecoverySnapshot(byte[] snapshotBytes);
-
-    /**
-     * This method is called during recovery at the end of a batch to apply the current batched
-     * log entries. This method is called after {@link #appendRecoveredLogEntry}.
+     * Returns the RaftActorRecoveryCohort to participate in persistence recovery.
      */
-    protected abstract void applyCurrentLogRecoveryBatch();
+    @Nonnull
+    protected abstract RaftActorRecoveryCohort getRaftActorRecoveryCohort();
 
     /**
      * This method is called when recovery is complete.
@@ -579,24 +739,10 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
     protected abstract void onRecoveryComplete();
 
     /**
-     * This method will be called by the RaftActor when a snapshot needs to be
-     * created. The derived actor should respond with its current state.
-     * <p/>
-     * During recovery the state that is returned by the derived actor will
-     * be passed back to it by calling the applySnapshot  method
-     *
-     * @return The current state of the actor
+     * Returns the RaftActorSnapshotCohort to participate in persistence recovery.
      */
-    protected abstract void createSnapshot();
-
-    /**
-     * This method can be called at any other point during normal
-     * operations when the derived actor is out of sync with it's peers
-     * and the only way to bring it in sync is by applying a snapshot
-     *
-     * @param snapshotBytes A snapshot of the state of the actor
-     */
-    protected abstract void applySnapshot(byte[] snapshotBytes);
+    @Nonnull
+    protected abstract RaftActorSnapshotCohort getRaftActorSnapshotCohort();
 
     /**
      * This method will be called by the RaftActor when the state of the
@@ -605,32 +751,36 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
      */
     protected abstract void onStateChanged();
 
-    protected abstract DataPersistenceProvider persistence();
-
     /**
      * Notifier Actor for this RaftActor to notify when a role change happens
      * @return ActorRef - ActorRef of the notifier or Optional.absent if none.
      */
     protected abstract Optional<ActorRef> getRoleChangeNotifier();
 
-    protected void onLeaderChanged(String oldLeader, String newLeader){};
+    /**
+     * This method is called prior to operations such as leadership transfer and actor shutdown when the leader
+     * must pause or stop its duties. This method allows derived classes to gracefully pause or finish current
+     * work prior to performing the operation. On completion of any work, the run method must be called on the
+     * given Runnable to proceed with the given operation. <b>Important:</b> the run method must be called on
+     * this actor's thread dispatcher as as it modifies internal state.
+     * <p>
+     * The default implementation immediately runs the operation.
+     *
+     * @param operation the operation to run
+     */
+    protected void pauseLeader(Runnable operation) {
+        operation.run();
+    }
 
-    private void trimPersistentData(long sequenceNumber) {
-        // Trim akka snapshots
-        // FIXME : Not sure how exactly the SnapshotSelectionCriteria is applied
-        // For now guessing that it is ANDed.
-        persistence().deleteSnapshots(new SnapshotSelectionCriteria(
-            sequenceNumber - context.getConfigParams().getSnapshotBatchCount(), 43200000));
+    protected void onLeaderChanged(String oldLeader, String newLeader) {
 
-        // Trim akka journal
-        persistence().deleteMessages(sequenceNumber);
-    }
+    };
 
     private String getLeaderAddress(){
         if(isLeader()){
             return getSelf().path().toString();
         }
-        String leaderId = currentBehavior.getLeaderId();
+        String leaderId = getLeaderId();
         if (leaderId == null) {
             return null;
         }
@@ -643,197 +793,61 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         return peerAddress;
     }
 
-    private void handleCaptureSnapshotReply(byte[] snapshotBytes) {
-        LOG.info("{}: CaptureSnapshotReply received by actor: snapshot size {}", persistenceId(), snapshotBytes.length);
-
-        // create a snapshot object from the state provided and save it
-        // when snapshot is saved async, SaveSnapshotSuccess is raised.
-
-        Snapshot sn = Snapshot.create(snapshotBytes,
-            context.getReplicatedLog().getFrom(captureSnapshot.getLastAppliedIndex() + 1),
-            captureSnapshot.getLastIndex(), captureSnapshot.getLastTerm(),
-            captureSnapshot.getLastAppliedIndex(), captureSnapshot.getLastAppliedTerm());
-
-        persistence().saveSnapshot(sn);
-
-        LOG.info("{}: Persisting of snapshot done:{}", persistenceId(), sn.getLogMessage());
-
-        long dataThreshold = Runtime.getRuntime().totalMemory() *
-                getRaftActorContext().getConfigParams().getSnapshotDataThresholdPercentage() / 100;
-        if (context.getReplicatedLog().dataSize() > dataThreshold) {
-            // if memory is less, clear the log based on lastApplied.
-            // this could/should only happen if one of the followers is down
-            // as normally we keep removing from the log when its replicated to all.
-            context.getReplicatedLog().snapshotPreCommit(captureSnapshot.getLastAppliedIndex(),
-                    captureSnapshot.getLastAppliedTerm());
-
-        } else {
-            // clear the log based on replicatedToAllIndex
-            context.getReplicatedLog().snapshotPreCommit(captureSnapshot.getReplicatedToAllIndex(),
-                    captureSnapshot.getReplicatedToAllTerm());
-        }
-        getCurrentBehavior().setReplicatedToAllIndex(captureSnapshot.getReplicatedToAllIndex());
-
-        LOG.info("{}: Removed in-memory snapshotted entries, adjusted snaphsotIndex:{} " +
-            "and term:{}", persistenceId(), captureSnapshot.getLastAppliedIndex(),
-            captureSnapshot.getLastAppliedTerm());
-
-        if (isLeader() && captureSnapshot.isInstallSnapshotInitiated()) {
-            // this would be call straight to the leader and won't initiate in serialization
-            currentBehavior.handleMessage(getSelf(), new SendInstallSnapshot(
-                    ByteString.copyFrom(snapshotBytes)));
-        }
-
-        captureSnapshot = null;
-        context.setSnapshotCaptureInitiated(false);
-    }
-
     protected boolean hasFollowers(){
-        return getRaftActorContext().getPeerAddresses().keySet().size() > 0;
+        return getRaftActorContext().hasFollowers();
     }
 
-    private class ReplicatedLogImpl extends AbstractReplicatedLogImpl {
+    private void captureSnapshot() {
+        SnapshotManager snapshotManager = context.getSnapshotManager();
 
-        private static final int DATA_SIZE_DIVIDER = 5;
-        private long dataSizeSinceLastSnapshot = 0;
+        if (!snapshotManager.isCapturing()) {
+            final long idx = getCurrentBehavior().getReplicatedToAllIndex();
+            LOG.debug("Take a snapshot of current state. lastReplicatedLog is {} and replicatedToAllIndex is {}",
+                replicatedLog().last(), idx);
 
-        public ReplicatedLogImpl(Snapshot snapshot) {
-            super(snapshot.getLastAppliedIndex(), snapshot.getLastAppliedTerm(),
-                snapshot.getUnAppliedEntries());
+            snapshotManager.capture(replicatedLog().last(), idx);
         }
+    }
 
-        public ReplicatedLogImpl() {
-            super();
-        }
-
-        @Override public void removeFromAndPersist(long logEntryIndex) {
-            int adjustedIndex = adjustedIndex(logEntryIndex);
-
-            if (adjustedIndex < 0) {
-                return;
-            }
-
-            // FIXME: Maybe this should be done after the command is saved
-            journal.subList(adjustedIndex , journal.size()).clear();
-
-            persistence().persist(new DeleteEntries(adjustedIndex), new Procedure<DeleteEntries>() {
-
+    /**
+     * Switch this member to non-voting status. This is a no-op for all behaviors except when we are the leader,
+     * in which case we need to step down.
+     */
+    void becomeNonVoting() {
+        if (isLeader()) {
+            initiateLeadershipTransfer(new RaftActorLeadershipTransferCohort.OnComplete() {
                 @Override
-                public void apply(DeleteEntries param)
-                        throws Exception {
-                    //FIXME : Doing nothing for now
-                    dataSize = 0;
-                    for (ReplicatedLogEntry entry : journal) {
-                        dataSize += entry.size();
-                    }
+                public void onSuccess(ActorRef raftActorRef) {
+                    LOG.debug("{}: leader transfer succeeded after change to non-voting", persistenceId());
+                    ensureFollowerState();
                 }
-            });
-        }
 
-        @Override public void appendAndPersist(
-            final ReplicatedLogEntry replicatedLogEntry) {
-            appendAndPersist(replicatedLogEntry, null);
-        }
-
-        public void appendAndPersist(
-            final ReplicatedLogEntry replicatedLogEntry,
-            final Procedure<ReplicatedLogEntry> callback)  {
-
-            if(LOG.isDebugEnabled()) {
-                LOG.debug("{}: Append log entry and persist {} ", persistenceId(), replicatedLogEntry);
-            }
+                @Override
+                public void onFailure(ActorRef raftActorRef) {
+                    LOG.debug("{}: leader transfer failed after change to non-voting", persistenceId());
+                    ensureFollowerState();
+                }
 
-            // FIXME : By adding the replicated log entry to the in-memory journal we are not truly ensuring durability of the logs
-            journal.add(replicatedLogEntry);
-
-            // When persisting events with persist it is guaranteed that the
-            // persistent actor will not receive further commands between the
-            // persist call and the execution(s) of the associated event
-            // handler. This also holds for multiple persist calls in context
-            // of a single command.
-            persistence().persist(replicatedLogEntry,
-                new Procedure<ReplicatedLogEntry>() {
-                    @Override
-                    public void apply(ReplicatedLogEntry evt) throws Exception {
-                        int logEntrySize = replicatedLogEntry.size();
-
-                        dataSize += logEntrySize;
-                        long dataSizeForCheck = dataSize;
-
-                        dataSizeSinceLastSnapshot += logEntrySize;
-                        long journalSize = lastIndex() + 1;
-
-                        if(!hasFollowers()) {
-                            // When we do not have followers we do not maintain an in-memory log
-                            // due to this the journalSize will never become anything close to the
-                            // snapshot batch count. In fact will mostly be 1.
-                            // Similarly since the journal's dataSize depends on the entries in the
-                            // journal the journal's dataSize will never reach a value close to the
-                            // memory threshold.
-                            // By maintaining the dataSize outside the journal we are tracking essentially
-                            // what we have written to the disk however since we no longer are in
-                            // need of doing a snapshot just for the sake of freeing up memory we adjust
-                            // the real size of data by the DATA_SIZE_DIVIDER so that we do not snapshot as often
-                            // as if we were maintaining a real snapshot
-                            dataSizeForCheck = dataSizeSinceLastSnapshot / DATA_SIZE_DIVIDER;
-                        }
-
-                        long dataThreshold = Runtime.getRuntime().totalMemory() *
-                                getRaftActorContext().getConfigParams().getSnapshotDataThresholdPercentage() / 100;
-
-                        // when a snaphsot is being taken, captureSnapshot != null
-                        if (!context.isSnapshotCaptureInitiated() &&
-                                ( journalSize % context.getConfigParams().getSnapshotBatchCount() == 0 ||
-                                        dataSizeForCheck > dataThreshold)) {
-
-                            dataSizeSinceLastSnapshot = 0;
-
-                            LOG.info("{}: Initiating Snapshot Capture, journalSize = {}, dataSizeForCheck = {}," +
-                                " dataThreshold = {}", persistenceId(), journalSize, dataSizeForCheck, dataThreshold);
-
-                            long lastAppliedIndex = -1;
-                            long lastAppliedTerm = -1;
-
-                            ReplicatedLogEntry lastAppliedEntry = get(context.getLastApplied());
-                            if (!hasFollowers()) {
-                                lastAppliedIndex = replicatedLogEntry.getIndex();
-                                lastAppliedTerm = replicatedLogEntry.getTerm();
-                            } else if (lastAppliedEntry != null) {
-                                lastAppliedIndex = lastAppliedEntry.getIndex();
-                                lastAppliedTerm = lastAppliedEntry.getTerm();
-                            }
-
-                            if(LOG.isDebugEnabled()) {
-                                LOG.debug("{}: Snapshot Capture logSize: {}", persistenceId(), journal.size());
-                                LOG.debug("{}: Snapshot Capture lastApplied:{} ",
-                                        persistenceId(), context.getLastApplied());
-                                LOG.debug("{}: Snapshot Capture lastAppliedIndex:{}", persistenceId(),
-                                        lastAppliedIndex);
-                                LOG.debug("{}: Snapshot Capture lastAppliedTerm:{}", persistenceId(),
-                                        lastAppliedTerm);
-                            }
-
-                            // send a CaptureSnapshot to self to make the expensive operation async.
-                            long replicatedToAllIndex = getCurrentBehavior().getReplicatedToAllIndex();
-                            ReplicatedLogEntry replicatedToAllEntry = context.getReplicatedLog().get(replicatedToAllIndex);
-                            getSelf().tell(new CaptureSnapshot(lastIndex(), lastTerm(), lastAppliedIndex, lastAppliedTerm,
-                                (replicatedToAllEntry != null ? replicatedToAllEntry.getIndex() : -1),
-                                (replicatedToAllEntry != null ? replicatedToAllEntry.getTerm() : -1)),
-                                null);
-                            context.setSnapshotCaptureInitiated(true);
-                        }
-                        if (callback != null){
-                            callback.apply(replicatedLogEntry);
-                        }
+                private void ensureFollowerState() {
+                    // Whether or not leadership transfer succeeded, we have to step down as leader and
+                    // switch to Follower so ensure that.
+                    if (getRaftState() != RaftState.Follower) {
+                        initializeBehavior();
                     }
                 }
-            );
+            });
         }
-
     }
 
+    /**
+     * @deprecated Deprecated in favor of {@link org.opendaylight.controller.cluster.raft.base.messages.DeleteEntries}
+     *             whose type for fromIndex is long instead of int. This class was kept for backwards
+     *             compatibility with Helium.
+     */
+    // Suppressing this warning as we can't set serialVersionUID to maintain backwards compatibility.
+    @SuppressWarnings("serial")
+    @Deprecated
     static class DeleteEntries implements Serializable {
-        private static final long serialVersionUID = 1L;
         private final int fromIndex;
 
         public DeleteEntries(int fromIndex) {
@@ -845,48 +859,14 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         }
     }
 
-
-    private class ElectionTermImpl implements ElectionTerm {
-        /**
-         * Identifier of the actor whose election term information this is
-         */
-        private long currentTerm = 0;
-        private String votedFor = null;
-
-        @Override
-        public long getCurrentTerm() {
-            return currentTerm;
-        }
-
-        @Override
-        public String getVotedFor() {
-            return votedFor;
-        }
-
-        @Override public void update(long currentTerm, String votedFor) {
-            if(LOG.isDebugEnabled()) {
-                LOG.debug("{}: Set currentTerm={}, votedFor={}", persistenceId(), currentTerm, votedFor);
-            }
-            this.currentTerm = currentTerm;
-            this.votedFor = votedFor;
-        }
-
-        @Override
-        public void updateAndPersist(long currentTerm, String votedFor){
-            update(currentTerm, votedFor);
-            // FIXME : Maybe first persist then update the state
-            persistence().persist(new UpdateElectionTerm(this.currentTerm, this.votedFor), new Procedure<UpdateElectionTerm>(){
-
-                @Override public void apply(UpdateElectionTerm param)
-                    throws Exception {
-
-                }
-            });
-        }
-    }
-
+    /**
+     * @deprecated Deprecated in favor of non-inner class {@link org.opendaylight.controller.cluster.raft.base.messages.UpdateElectionTerm}
+     *             which has serialVersionUID set. This class was kept for backwards compatibility with Helium.
+     */
+    // Suppressing this warning as we can't set serialVersionUID to maintain backwards compatibility.
+    @SuppressWarnings("serial")
+    @Deprecated
     static class UpdateElectionTerm implements Serializable {
-        private static final long serialVersionUID = 1L;
         private final long currentTerm;
         private final String votedFor;
 
@@ -904,38 +884,88 @@ public abstract class RaftActor extends AbstractUntypedPersistentActor {
         }
     }
 
-    protected class NonPersistentRaftDataProvider extends NonPersistentDataProvider {
+    /**
+     * A point-in-time capture of {@link RaftActorBehavior} state critical for transitioning between behaviors.
+     */
+    private static abstract class BehaviorState implements Immutable {
+        @Nullable abstract RaftActorBehavior getBehavior();
+        @Nullable abstract String getLastValidLeaderId();
+        @Nullable abstract short getLeaderPayloadVersion();
+    }
+
+    /**
+     * A {@link BehaviorState} corresponding to non-null {@link RaftActorBehavior} state.
+     */
+    private static final class SimpleBehaviorState extends BehaviorState {
+        private final RaftActorBehavior behavior;
+        private final String lastValidLeaderId;
+        private final short leaderPayloadVersion;
 
-        public NonPersistentRaftDataProvider(){
+        SimpleBehaviorState(final String lastValidLeaderId, final RaftActorBehavior behavior) {
+            this.lastValidLeaderId = lastValidLeaderId;
+            this.behavior = Preconditions.checkNotNull(behavior);
+            this.leaderPayloadVersion = behavior.getLeaderPayloadVersion();
+        }
 
+        @Override
+        RaftActorBehavior getBehavior() {
+            return behavior;
         }
 
-        /**
-         * The way snapshotting works is,
-         * <ol>
-         * <li> RaftActor calls createSnapshot on the Shard
-         * <li> Shard sends a CaptureSnapshotReply and RaftActor then calls saveSnapshot
-         * <li> When saveSnapshot is invoked on the akka-persistence API it uses the SnapshotStore to save the snapshot.
-         * The SnapshotStore sends SaveSnapshotSuccess or SaveSnapshotFailure. When the RaftActor gets SaveSnapshot
-         * success it commits the snapshot to the in-memory journal. This commitSnapshot is mimicking what is done
-         * in SaveSnapshotSuccess.
-         * </ol>
-         * @param o
-         */
         @Override
-        public void saveSnapshot(Object o) {
-            // Make saving Snapshot successful
-            commitSnapshot(-1L);
+        String getLastValidLeaderId() {
+            return lastValidLeaderId;
         }
-    }
 
-    @VisibleForTesting
-    void setCurrentBehavior(AbstractRaftActorBehavior behavior) {
-        currentBehavior = behavior;
+        @Override
+        short getLeaderPayloadVersion() {
+            return leaderPayloadVersion;
+        }
     }
 
-    protected RaftActorBehavior getCurrentBehavior() {
-        return currentBehavior;
+    /**
+     * Class tracking behavior-related information, which we need to keep around and pass across behavior switches.
+     * An instance is created for each RaftActor. It has two functions:
+     * - it keeps track of the last leader ID we have encountered since we have been created
+     * - it creates state capture needed to transition from one behavior to the next
+     */
+    private static final class BehaviorStateTracker {
+        /**
+         * A {@link BehaviorState} corresponding to null {@link RaftActorBehavior} state. Since null behavior is only
+         * allowed before we receive the first message, we know the leader ID to be null.
+         */
+        private static final BehaviorState NULL_BEHAVIOR_STATE = new BehaviorState() {
+            @Override
+            RaftActorBehavior getBehavior() {
+                return null;
+            }
+
+            @Override
+            String getLastValidLeaderId() {
+                return null;
+            }
+
+            @Override
+            short getLeaderPayloadVersion() {
+                return -1;
+            }
+        };
+
+        private String lastValidLeaderId;
+
+        BehaviorState capture(final RaftActorBehavior behavior) {
+            if (behavior == null) {
+                Verify.verify(lastValidLeaderId == null, "Null behavior with non-null last leader");
+                return NULL_BEHAVIOR_STATE;
+            }
+
+            final String leaderId = behavior.getLeaderId();
+            if (leaderId != null) {
+                lastValidLeaderId = leaderId;
+            }
+
+            return new SimpleBehaviorState(lastValidLeaderId, behavior);
+        }
     }
 
 }