Merge changes I114cbac1,I45c2e7cd
[controller.git] / opendaylight / md-sal / sal-distributed-datastore / src / main / java / org / opendaylight / controller / cluster / datastore / Shard.java
index c509580ac657f920564cc2c4db1a43f91fa7d325..81449c574780705196e28d3dbd738151c2d948cd 100644 (file)
@@ -18,19 +18,15 @@ import akka.serialization.Serialization;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Optional;
 import com.google.common.base.Preconditions;
-import com.google.common.collect.Lists;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import java.io.IOException;
-import java.util.Collection;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import javax.annotation.Nonnull;
-import org.opendaylight.controller.cluster.DataPersistenceProvider;
 import org.opendaylight.controller.cluster.common.actor.CommonConfig;
 import org.opendaylight.controller.cluster.common.actor.MeteringBehavior;
 import org.opendaylight.controller.cluster.datastore.ShardCommitCoordinator.CohortEntry;
@@ -43,44 +39,37 @@ import org.opendaylight.controller.cluster.datastore.jmx.mbeans.shard.ShardStats
 import org.opendaylight.controller.cluster.datastore.messages.AbortTransaction;
 import org.opendaylight.controller.cluster.datastore.messages.AbortTransactionReply;
 import org.opendaylight.controller.cluster.datastore.messages.ActorInitialized;
+import org.opendaylight.controller.cluster.datastore.messages.BatchedModifications;
+import org.opendaylight.controller.cluster.datastore.messages.BatchedModificationsReply;
 import org.opendaylight.controller.cluster.datastore.messages.CanCommitTransaction;
 import org.opendaylight.controller.cluster.datastore.messages.CloseTransactionChain;
 import org.opendaylight.controller.cluster.datastore.messages.CommitTransaction;
 import org.opendaylight.controller.cluster.datastore.messages.CommitTransactionReply;
-import org.opendaylight.controller.cluster.datastore.messages.CreateSnapshot;
 import org.opendaylight.controller.cluster.datastore.messages.CreateTransaction;
 import org.opendaylight.controller.cluster.datastore.messages.CreateTransactionReply;
-import org.opendaylight.controller.cluster.datastore.messages.EnableNotification;
 import org.opendaylight.controller.cluster.datastore.messages.ForwardedReadyTransaction;
 import org.opendaylight.controller.cluster.datastore.messages.PeerAddressResolved;
 import org.opendaylight.controller.cluster.datastore.messages.ReadyTransactionReply;
 import org.opendaylight.controller.cluster.datastore.messages.RegisterChangeListener;
-import org.opendaylight.controller.cluster.datastore.messages.RegisterChangeListenerReply;
+import org.opendaylight.controller.cluster.datastore.messages.RegisterDataTreeChangeListener;
 import org.opendaylight.controller.cluster.datastore.messages.UpdateSchemaContext;
 import org.opendaylight.controller.cluster.datastore.modification.Modification;
 import org.opendaylight.controller.cluster.datastore.modification.ModificationPayload;
 import org.opendaylight.controller.cluster.datastore.modification.MutableCompositeModification;
 import org.opendaylight.controller.cluster.datastore.utils.Dispatchers;
 import org.opendaylight.controller.cluster.datastore.utils.MessageTracker;
-import org.opendaylight.controller.cluster.datastore.utils.SerializationUtils;
+import org.opendaylight.controller.cluster.notifications.RegisterRoleChangeListener;
 import org.opendaylight.controller.cluster.notifications.RoleChangeNotifier;
 import org.opendaylight.controller.cluster.raft.RaftActor;
-import org.opendaylight.controller.cluster.raft.ReplicatedLogEntry;
+import org.opendaylight.controller.cluster.raft.RaftActorRecoveryCohort;
+import org.opendaylight.controller.cluster.raft.RaftActorSnapshotCohort;
+import org.opendaylight.controller.cluster.raft.base.messages.FollowerInitialSyncUpStatus;
 import org.opendaylight.controller.cluster.raft.messages.AppendEntriesReply;
 import org.opendaylight.controller.cluster.raft.protobuff.client.messages.CompositeModificationByteStringPayload;
 import org.opendaylight.controller.cluster.raft.protobuff.client.messages.CompositeModificationPayload;
-import org.opendaylight.controller.cluster.raft.protobuff.client.messages.Payload;
-import org.opendaylight.controller.md.sal.common.api.data.AsyncDataChangeListener;
 import org.opendaylight.controller.md.sal.dom.store.impl.InMemoryDOMDataStore;
 import org.opendaylight.controller.md.sal.dom.store.impl.InMemoryDOMDataStoreFactory;
-import org.opendaylight.controller.sal.core.spi.data.DOMStoreThreePhaseCommitCohort;
-import org.opendaylight.controller.sal.core.spi.data.DOMStoreTransaction;
-import org.opendaylight.controller.sal.core.spi.data.DOMStoreTransactionChain;
-import org.opendaylight.controller.sal.core.spi.data.DOMStoreTransactionFactory;
 import org.opendaylight.controller.sal.core.spi.data.DOMStoreWriteTransaction;
-import org.opendaylight.yangtools.concepts.ListenerRegistration;
-import org.opendaylight.yangtools.yang.data.api.YangInstanceIdentifier;
-import org.opendaylight.yangtools.yang.data.api.schema.NormalizedNode;
 import org.opendaylight.yangtools.yang.model.api.SchemaContext;
 import scala.concurrent.duration.Duration;
 import scala.concurrent.duration.FiniteDuration;
@@ -93,8 +82,6 @@ import scala.concurrent.duration.FiniteDuration;
  */
 public class Shard extends RaftActor {
 
-    private static final YangInstanceIdentifier DATASTORE_ROOT = YangInstanceIdentifier.builder().build();
-
     private static final Object TX_COMMIT_TIMEOUT_CHECK_MESSAGE = "txCommitTimeoutCheck";
 
     @VisibleForTesting
@@ -104,23 +91,12 @@ public class Shard extends RaftActor {
     private final InMemoryDOMDataStore store;
 
     /// The name of this shard
-    private final ShardIdentifier name;
+    private final String name;
 
     private final ShardStats shardMBean;
 
-    private final List<ActorSelection> dataChangeListeners =  Lists.newArrayList();
-
-    private final List<DelayedListenerRegistration> delayedListenerRegistrations =
-                                                                       Lists.newArrayList();
-
     private DatastoreContext datastoreContext;
 
-    private DataPersistenceProvider dataPersistenceProvider;
-
-    private SchemaContext schemaContext;
-
-    private int createSnapshotTransactionCounter;
-
     private final ShardCommitCoordinator commitCoordinator;
 
     private long transactionCommitTimeout;
@@ -131,49 +107,50 @@ public class Shard extends RaftActor {
 
     private final MessageTracker appendEntriesReplyTracker;
 
-    /**
-     * Coordinates persistence recovery on startup.
-     */
-    private ShardRecoveryCoordinator recoveryCoordinator;
-    private List<Object> currentLogRecoveryBatch;
+    private final ReadyTransactionReply READY_TRANSACTION_REPLY = new ReadyTransactionReply(
+            Serialization.serializedActorPath(getSelf()));
 
-    private final Map<String, DOMStoreTransactionChain> transactionChains = new HashMap<>();
+    private final DOMTransactionFactory domTransactionFactory;
 
-    private final String txnDispatcherPath;
+    private final ShardTransactionActorFactory transactionActorFactory;
 
-    protected Shard(final ShardIdentifier name, final Map<ShardIdentifier, String> peerAddresses,
+    private final ShardSnapshotCohort snapshotCohort;
+
+    private final DataTreeChangeListenerSupport treeChangeSupport = new DataTreeChangeListenerSupport(this);
+    private final DataChangeListenerSupport changeSupport = new DataChangeListenerSupport(this);
+
+    protected Shard(final ShardIdentifier name, final Map<String, String> peerAddresses,
             final DatastoreContext datastoreContext, final SchemaContext schemaContext) {
-        super(name.toString(), mapPeerAddresses(peerAddresses),
-                Optional.of(datastoreContext.getShardRaftConfig()));
+        super(name.toString(), new HashMap<>(peerAddresses), Optional.of(datastoreContext.getShardRaftConfig()));
 
-        this.name = name;
+        this.name = name.toString();
         this.datastoreContext = datastoreContext;
-        this.schemaContext = schemaContext;
-        this.dataPersistenceProvider = (datastoreContext.isPersistent())
-                ? new PersistentDataProvider() : new NonPersistentRaftDataProvider();
-        this.txnDispatcherPath = new Dispatchers(context().system().dispatchers())
-                .getDispatcherPath(Dispatchers.DispatcherType.Transaction);
 
+        setPersistence(datastoreContext.isPersistent());
 
         LOG.info("Shard created : {}, persistent : {}", name, datastoreContext.isPersistent());
 
         store = InMemoryDOMDataStoreFactory.create(name.toString(), null,
                 datastoreContext.getDataStoreProperties());
 
-        if(schemaContext != null) {
+        if (schemaContext != null) {
             store.onGlobalContextUpdated(schemaContext);
         }
 
         shardMBean = ShardMBeanFactory.getShardStatsMBean(name.toString(),
                 datastoreContext.getDataStoreMXBeanType());
         shardMBean.setNotificationManager(store.getDataChangeListenerNotificationManager());
+        shardMBean.setShardActor(getSelf());
 
         if (isMetricsCaptureEnabled()) {
             getContext().become(new MeteringBehavior(this));
         }
 
-        commitCoordinator = new ShardCommitCoordinator(TimeUnit.SECONDS.convert(1, TimeUnit.MINUTES),
-                datastoreContext.getShardTransactionCommitQueueCapacity(), LOG, name.toString());
+        domTransactionFactory = new DOMTransactionFactory(store, shardMBean, LOG, this.name);
+
+        commitCoordinator = new ShardCommitCoordinator(domTransactionFactory,
+                TimeUnit.SECONDS.convert(5, TimeUnit.MINUTES),
+                datastoreContext.getShardTransactionCommitQueueCapacity(), self(), LOG, this.name);
 
         setTransactionCommitTimeout();
 
@@ -182,6 +159,12 @@ public class Shard extends RaftActor {
 
         appendEntriesReplyTracker = new MessageTracker(AppendEntriesReply.class,
                 getRaftActorContext().getConfigParams().getIsolatedCheckIntervalInMillis());
+
+        transactionActorFactory = new ShardTransactionActorFactory(domTransactionFactory, datastoreContext,
+                new Dispatchers(context().system().dispatchers()).getDispatcherPath(
+                        Dispatchers.DispatcherType.Transaction), self(), getContext(), shardMBean);
+
+        snapshotCohort = new ShardSnapshotCohort(transactionActorFactory, store, LOG, this.name);
     }
 
     private void setTransactionCommitTimeout() {
@@ -189,20 +172,8 @@ public class Shard extends RaftActor {
                 datastoreContext.getShardTransactionCommitTimeoutInSeconds(), TimeUnit.SECONDS);
     }
 
-    private static Map<String, String> mapPeerAddresses(
-        final Map<ShardIdentifier, String> peerAddresses) {
-        Map<String, String> map = new HashMap<>();
-
-        for (Map.Entry<ShardIdentifier, String> entry : peerAddresses
-            .entrySet()) {
-            map.put(entry.getKey().toString(), entry.getValue());
-        }
-
-        return map;
-    }
-
     public static Props props(final ShardIdentifier name,
-        final Map<ShardIdentifier, String> peerAddresses,
+        final Map<String, String> peerAddresses,
         final DatastoreContext datastoreContext, final SchemaContext schemaContext) {
         Preconditions.checkNotNull(name, "name should not be null");
         Preconditions.checkNotNull(peerAddresses, "peerAddresses should not be null");
@@ -215,7 +186,7 @@ public class Shard extends RaftActor {
     private Optional<ActorRef> createRoleChangeNotifier(String shardId) {
         ActorRef shardRoleChangeNotifier = this.getContext().actorOf(
             RoleChangeNotifier.getProps(shardId), shardId + "-notifier");
-        return Optional.<ActorRef>of(shardRoleChangeNotifier);
+        return Optional.of(shardRoleChangeNotifier);
     }
 
     @Override
@@ -264,20 +235,24 @@ public class Shard extends RaftActor {
         }
 
         try {
-            if (message.getClass().equals(CreateTransaction.SERIALIZABLE_CLASS)) {
+            if (CreateTransaction.SERIALIZABLE_CLASS.isInstance(message)) {
                 handleCreateTransaction(message);
+            } else if (BatchedModifications.class.isInstance(message)) {
+                handleBatchedModifications((BatchedModifications)message);
             } else if (message instanceof ForwardedReadyTransaction) {
                 handleForwardedReadyTransaction((ForwardedReadyTransaction) message);
-            } else if (message.getClass().equals(CanCommitTransaction.SERIALIZABLE_CLASS)) {
+            } else if (CanCommitTransaction.SERIALIZABLE_CLASS.isInstance(message)) {
                 handleCanCommitTransaction(CanCommitTransaction.fromSerializable(message));
-            } else if (message.getClass().equals(CommitTransaction.SERIALIZABLE_CLASS)) {
+            } else if (CommitTransaction.SERIALIZABLE_CLASS.isInstance(message)) {
                 handleCommitTransaction(CommitTransaction.fromSerializable(message));
-            } else if (message.getClass().equals(AbortTransaction.SERIALIZABLE_CLASS)) {
+            } else if (AbortTransaction.SERIALIZABLE_CLASS.isInstance(message)) {
                 handleAbortTransaction(AbortTransaction.fromSerializable(message));
-            } else if (message.getClass().equals(CloseTransactionChain.SERIALIZABLE_CLASS)) {
+            } else if (CloseTransactionChain.SERIALIZABLE_CLASS.isInstance(message)) {
                 closeTransactionChain(CloseTransactionChain.fromSerializable(message));
             } else if (message instanceof RegisterChangeListener) {
-                registerChangeListener((RegisterChangeListener) message);
+                changeSupport.onMessage((RegisterChangeListener) message, isLeader());
+            } else if (message instanceof RegisterDataTreeChangeListener) {
+                treeChangeSupport.onMessage((RegisterDataTreeChangeListener) message, isLeader());
             } else if (message instanceof UpdateSchemaContext) {
                 updateSchemaContext((UpdateSchemaContext) message);
             } else if (message instanceof PeerAddressResolved) {
@@ -288,6 +263,11 @@ public class Shard extends RaftActor {
                 handleTransactionCommitTimeoutCheck();
             } else if(message instanceof DatastoreContext) {
                 onDatastoreContext((DatastoreContext)message);
+            } else if(message instanceof RegisterRoleChangeListener){
+                roleChangeNotifier.get().forward(message, context());
+            } else if (message instanceof FollowerInitialSyncUpStatus){
+                shardMBean.setFollowerInitialSyncStatus(((FollowerInitialSyncUpStatus) message).isInitialSyncDone());
+                context().parent().tell(message, self());
             } else {
                 super.onReceiveCommand(message);
             }
@@ -308,12 +288,10 @@ public class Shard extends RaftActor {
 
         setTransactionCommitTimeout();
 
-        if(datastoreContext.isPersistent() &&
-                dataPersistenceProvider instanceof NonPersistentRaftDataProvider) {
-            dataPersistenceProvider = new PersistentDataProvider();
-        } else if(!datastoreContext.isPersistent() &&
-                dataPersistenceProvider instanceof PersistentDataProvider) {
-            dataPersistenceProvider = new NonPersistentRaftDataProvider();
+        if(datastoreContext.isPersistent() && !persistence().isRecoveryApplicable()) {
+            setPersistence(true);
+        } else if(!datastoreContext.isPersistent() && persistence().isRecoveryApplicable()) {
+            setPersistence(false);
         }
 
         updateConfigParams(datastoreContext.getShardRaftConfig());
@@ -363,9 +341,10 @@ public class Shard extends RaftActor {
             // currently uses a same thread executor anyway.
             cohortEntry.getCohort().preCommit().get();
 
-            // If we do not have any followers and we are not using persistence we can
-            // apply modification to the state immediately
-            if(!hasFollowers() && !persistence().isRecoveryApplicable()){
+            // If we do not have any followers and we are not using persistence
+            // or if cohortEntry has no modifications
+            // we can apply modification to the state immediately
+            if((!hasFollowers() && !persistence().isRecoveryApplicable()) || (!cohortEntry.hasModifications())){
                 applyModificationToState(getSender(), transactionID, cohortEntry.getModification());
             } else {
                 Shard.this.persistData(getSender(), transactionID,
@@ -439,6 +418,51 @@ public class Shard extends RaftActor {
         commitCoordinator.handleCanCommit(canCommit, getSender(), self());
     }
 
+    private void handleBatchedModifications(BatchedModifications batched) {
+        // This message is sent to prepare the modificationsa transaction directly on the Shard as an
+        // optimization to avoid the extra overhead of a separate ShardTransaction actor. On the last
+        // BatchedModifications message, the caller sets the ready flag in the message indicating
+        // modifications are complete. The reply contains the cohort actor path (this actor) for the caller
+        // to initiate the 3-phase commit. This also avoids the overhead of sending an additional
+        // ReadyTransaction message.
+
+        // If we're not the leader then forward to the leader. This is a safety measure - we shouldn't
+        // normally get here if we're not the leader as the front-end (TransactionProxy) should determine
+        // the primary/leader shard. However with timing and caching on the front-end, there's a small
+        // window where it could have a stale leader during leadership transitions.
+        //
+        if(isLeader()) {
+            try {
+                boolean ready = commitCoordinator.handleTransactionModifications(batched);
+                if(ready) {
+                    sender().tell(READY_TRANSACTION_REPLY, self());
+                } else {
+                    sender().tell(new BatchedModificationsReply(batched.getModifications().size()), self());
+                }
+            } catch (Exception e) {
+                LOG.error("{}: Error handling BatchedModifications for Tx {}", persistenceId(),
+                        batched.getTransactionID(), e);
+                getSender().tell(new akka.actor.Status.Failure(e), getSelf());
+            }
+        } else {
+            ActorSelection leader = getLeader();
+            if(leader != null) {
+                // TODO: what if this is not the first batch and leadership changed in between batched messages?
+                // We could check if the commitCoordinator already has a cached entry and forward all the previous
+                // batched modifications.
+                LOG.debug("{}: Forwarding BatchedModifications to leader {}", persistenceId(), leader);
+                leader.forward(batched, getContext());
+            } else {
+                // TODO: rather than throwing an immediate exception, we could schedule a timer to try again to make
+                // it more resilient in case we're in the process of electing a new leader.
+                getSender().tell(new akka.actor.Status.Failure(new NoShardLeaderException(String.format(
+                    "Could not find the leader for shard %s. This typically happens" +
+                    " when the system is coming up or recovering and a leader is being elected. Try again" +
+                    " later.", persistenceId()))), getSelf());
+            }
+        }
+    }
+
     private void handleForwardedReadyTransaction(ForwardedReadyTransaction ready) {
         LOG.debug("{}: Readying transaction {}, client version {}", persistenceId(),
                 ready.getTransactionID(), ready.getTxnClientVersion());
@@ -447,24 +471,29 @@ public class Shard extends RaftActor {
         // commitCoordinator in preparation for the subsequent three phase commit initiated by
         // the front-end.
         commitCoordinator.transactionReady(ready.getTransactionID(), ready.getCohort(),
-                ready.getModification());
+                (MutableCompositeModification) ready.getModification());
 
         // Return our actor path as we'll handle the three phase commit, except if the Tx client
         // version < 1 (Helium-1 version). This means the Tx was initiated by a base Helium version
         // node. In that case, the subsequent 3-phase commit messages won't contain the
         // transactionId so to maintain backwards compatibility, we create a separate cohort actor
         // to provide the compatible behavior.
-        ActorRef replyActorPath = self();
-        if(ready.getTxnClientVersion() < DataStoreVersions.HELIUM_1_VERSION) {
-            LOG.debug("{}: Creating BackwardsCompatibleThreePhaseCommitCohort", persistenceId());
-            replyActorPath = getContext().actorOf(BackwardsCompatibleThreePhaseCommitCohort.props(
-                    ready.getTransactionID()));
-        }
+        if(ready.getTxnClientVersion() < DataStoreVersions.LITHIUM_VERSION) {
+            ActorRef replyActorPath = getSelf();
+            if(ready.getTxnClientVersion() < DataStoreVersions.HELIUM_1_VERSION) {
+                LOG.debug("{}: Creating BackwardsCompatibleThreePhaseCommitCohort", persistenceId());
+                replyActorPath = getContext().actorOf(BackwardsCompatibleThreePhaseCommitCohort.props(
+                        ready.getTransactionID()));
+            }
 
-        ReadyTransactionReply readyTransactionReply = new ReadyTransactionReply(
-                Serialization.serializedActorPath(replyActorPath));
-        getSender().tell(ready.isReturnSerialized() ? readyTransactionReply.toSerializable() :
+            ReadyTransactionReply readyTransactionReply =
+                    new ReadyTransactionReply(Serialization.serializedActorPath(replyActorPath),
+                            ready.getTxnClientVersion());
+            getSender().tell(ready.isReturnSerialized() ? readyTransactionReply.toSerializable() :
                 readyTransactionReply, getSelf());
+        } else {
+            getSender().tell(READY_TRANSACTION_REPLY, getSelf());
+        }
     }
 
     private void handleAbortTransaction(final AbortTransaction abort) {
@@ -520,66 +549,15 @@ public class Shard extends RaftActor {
     }
 
     private void closeTransactionChain(final CloseTransactionChain closeTransactionChain) {
-        DOMStoreTransactionChain chain =
-            transactionChains.remove(closeTransactionChain.getTransactionChainId());
-
-        if(chain != null) {
-            chain.close();
-        }
+        domTransactionFactory.closeTransactionChain(closeTransactionChain.getTransactionChainId());
     }
 
     private ActorRef createTypedTransactionActor(int transactionType,
             ShardTransactionIdentifier transactionId, String transactionChainId,
             short clientVersion ) {
 
-        DOMStoreTransactionFactory factory = store;
-
-        if(!transactionChainId.isEmpty()) {
-            factory = transactionChains.get(transactionChainId);
-            if(factory == null){
-                DOMStoreTransactionChain transactionChain = store.createTransactionChain();
-                transactionChains.put(transactionChainId, transactionChain);
-                factory = transactionChain;
-            }
-        }
-
-        if(this.schemaContext == null) {
-            throw new IllegalStateException("SchemaContext is not set");
-        }
-
-        if (transactionType == TransactionProxy.TransactionType.READ_ONLY.ordinal()) {
-
-            shardMBean.incrementReadOnlyTransactionCount();
-
-            return createShardTransaction(factory.newReadOnlyTransaction(), transactionId, clientVersion);
-
-        } else if (transactionType == TransactionProxy.TransactionType.READ_WRITE.ordinal()) {
-
-            shardMBean.incrementReadWriteTransactionCount();
-
-            return createShardTransaction(factory.newReadWriteTransaction(), transactionId, clientVersion);
-
-        } else if (transactionType == TransactionProxy.TransactionType.WRITE_ONLY.ordinal()) {
-
-            shardMBean.incrementWriteOnlyTransactionCount();
-
-            return createShardTransaction(factory.newWriteOnlyTransaction(), transactionId, clientVersion);
-        } else {
-            throw new IllegalArgumentException(
-                "Shard="+name + ":CreateTransaction message has unidentified transaction type="
-                    + transactionType);
-        }
-    }
-
-    private ActorRef createShardTransaction(DOMStoreTransaction transaction, ShardTransactionIdentifier transactionId,
-                                            short clientVersion){
-        return getContext().actorOf(
-                ShardTransaction.props(transaction, getSelf(),
-                        schemaContext, datastoreContext, shardMBean,
-                        transactionId.getRemoteTransactionId(), clientVersion)
-                        .withDispatcher(txnDispatcherPath),
-                transactionId.toString());
-
+        return transactionActorFactory.newShardTransaction(TransactionProxy.TransactionType.fromInt(transactionType),
+                transactionId, transactionChainId, clientVersion);
     }
 
     private void createTransaction(CreateTransaction createTransaction) {
@@ -598,10 +576,8 @@ public class Shard extends RaftActor {
     private ActorRef createTransaction(int transactionType, String remoteTransactionId,
             String transactionChainId, short clientVersion) {
 
-        ShardTransactionIdentifier transactionId =
-            ShardTransactionIdentifier.builder()
-                .remoteTransactionId(remoteTransactionId)
-                .build();
+
+        ShardTransactionIdentifier transactionId = new ShardTransactionIdentifier(remoteTransactionId);
 
         if(LOG.isDebugEnabled()) {
             LOG.debug("{}: Creating transaction : {} ", persistenceId(), transactionId);
@@ -613,18 +589,11 @@ public class Shard extends RaftActor {
         return transactionActor;
     }
 
-    private void syncCommitTransaction(final DOMStoreWriteTransaction transaction)
-        throws ExecutionException, InterruptedException {
-        DOMStoreThreePhaseCommitCohort commitCohort = transaction.ready();
-        commitCohort.preCommit().get();
-        commitCohort.commit().get();
-    }
-
     private void commitWithNewTransaction(final Modification modification) {
         DOMStoreWriteTransaction tx = store.newWriteOnlyTransaction();
         modification.apply(tx);
         try {
-            syncCommitTransaction(tx);
+            snapshotCohort.syncCommitTransaction(tx);
             shardMBean.incrementCommittedTransactionCount();
             shardMBean.setLastCommittedTransactionTime(System.currentTimeMillis());
         } catch (InterruptedException | ExecutionException e) {
@@ -634,9 +603,7 @@ public class Shard extends RaftActor {
     }
 
     private void updateSchemaContext(final UpdateSchemaContext message) {
-        this.schemaContext = message.getSchemaContext();
         updateSchemaContext(message.getSchemaContext());
-        store.onGlobalContextUpdated(message.getSchemaContext());
     }
 
     @VisibleForTesting
@@ -644,142 +611,24 @@ public class Shard extends RaftActor {
         store.onGlobalContextUpdated(schemaContext);
     }
 
-    private void registerChangeListener(final RegisterChangeListener registerChangeListener) {
-
-        LOG.debug("{}: registerDataChangeListener for {}", persistenceId(), registerChangeListener.getPath());
-
-        ListenerRegistration<AsyncDataChangeListener<YangInstanceIdentifier,
-                                                     NormalizedNode<?, ?>>> registration;
-        if(isLeader()) {
-            registration = doChangeListenerRegistration(registerChangeListener);
-        } else {
-            LOG.debug("{}: Shard is not the leader - delaying registration", persistenceId());
-
-            DelayedListenerRegistration delayedReg =
-                    new DelayedListenerRegistration(registerChangeListener);
-            delayedListenerRegistrations.add(delayedReg);
-            registration = delayedReg;
-        }
-
-        ActorRef listenerRegistration = getContext().actorOf(
-                DataChangeListenerRegistration.props(registration));
-
-        LOG.debug("{}: registerDataChangeListener sending reply, listenerRegistrationPath = {} ",
-                persistenceId(), listenerRegistration.path());
-
-        getSender().tell(new RegisterChangeListenerReply(listenerRegistration.path()), getSelf());
-    }
-
-    private ListenerRegistration<AsyncDataChangeListener<YangInstanceIdentifier,
-                                               NormalizedNode<?, ?>>> doChangeListenerRegistration(
-            final RegisterChangeListener registerChangeListener) {
-
-        ActorSelection dataChangeListenerPath = getContext().system().actorSelection(
-                registerChangeListener.getDataChangeListenerPath());
-
-        // Notify the listener if notifications should be enabled or not
-        // If this shard is the leader then it will enable notifications else
-        // it will not
-        dataChangeListenerPath.tell(new EnableNotification(true), getSelf());
-
-        // Now store a reference to the data change listener so it can be notified
-        // at a later point if notifications should be enabled or disabled
-        dataChangeListeners.add(dataChangeListenerPath);
-
-        AsyncDataChangeListener<YangInstanceIdentifier, NormalizedNode<?, ?>> listener =
-                new DataChangeListenerProxy(dataChangeListenerPath);
-
-        LOG.debug("{}: Registering for path {}", persistenceId(), registerChangeListener.getPath());
-
-        return store.registerChangeListener(registerChangeListener.getPath(), listener,
-                registerChangeListener.getScope());
-    }
-
-    private boolean isMetricsCaptureEnabled(){
+    private boolean isMetricsCaptureEnabled() {
         CommonConfig config = new CommonConfig(getContext().system().settings().config());
         return config.isMetricCaptureEnabled();
     }
 
     @Override
-    protected
-    void startLogRecoveryBatch(final int maxBatchSize) {
-        currentLogRecoveryBatch = Lists.newArrayListWithCapacity(maxBatchSize);
-
-        if(LOG.isDebugEnabled()) {
-            LOG.debug("{}: starting log recovery batch with max size {}", persistenceId(), maxBatchSize);
-        }
+    protected RaftActorSnapshotCohort getRaftActorSnapshotCohort() {
+        return snapshotCohort;
     }
 
     @Override
-    protected void appendRecoveredLogEntry(final Payload data) {
-        if(data instanceof ModificationPayload) {
-            try {
-                currentLogRecoveryBatch.add(((ModificationPayload) data).getModification());
-            } catch (ClassNotFoundException | IOException e) {
-                LOG.error("{}: Error extracting ModificationPayload", persistenceId(), e);
-            }
-        } else if (data instanceof CompositeModificationPayload) {
-            currentLogRecoveryBatch.add(((CompositeModificationPayload) data).getModification());
-        } else if (data instanceof CompositeModificationByteStringPayload) {
-            currentLogRecoveryBatch.add(((CompositeModificationByteStringPayload) data).getModification());
-        } else {
-            LOG.error("{}: Unknown state received {} during recovery", persistenceId(), data);
-        }
-    }
-
-    @Override
-    protected void applyRecoverySnapshot(final byte[] snapshotBytes) {
-        if(recoveryCoordinator == null) {
-            recoveryCoordinator = new ShardRecoveryCoordinator(persistenceId(), schemaContext,
-                    LOG, name.toString());
-        }
-
-        recoveryCoordinator.submit(snapshotBytes, store.newWriteOnlyTransaction());
-
-        if(LOG.isDebugEnabled()) {
-            LOG.debug("{}: submitted recovery sbapshot", persistenceId());
-        }
-    }
-
-    @Override
-    protected void applyCurrentLogRecoveryBatch() {
-        if(recoveryCoordinator == null) {
-            recoveryCoordinator = new ShardRecoveryCoordinator(persistenceId(), schemaContext,
-                    LOG, name.toString());
-        }
-
-        recoveryCoordinator.submit(currentLogRecoveryBatch, store.newWriteOnlyTransaction());
-
-        if(LOG.isDebugEnabled()) {
-            LOG.debug("{}: submitted log recovery batch with size {}", persistenceId(),
-                    currentLogRecoveryBatch.size());
-        }
+    @Nonnull
+    protected RaftActorRecoveryCohort getRaftActorRecoveryCohort() {
+        return new ShardRecoveryCoordinator(store, persistenceId(), LOG);
     }
 
     @Override
     protected void onRecoveryComplete() {
-        if(recoveryCoordinator != null) {
-            Collection<DOMStoreWriteTransaction> txList = recoveryCoordinator.getTransactions();
-
-            if(LOG.isDebugEnabled()) {
-                LOG.debug("{}: recovery complete - committing {} Tx's", persistenceId(), txList.size());
-            }
-
-            for(DOMStoreWriteTransaction tx: txList) {
-                try {
-                    syncCommitTransaction(tx);
-                    shardMBean.incrementCommittedTransactionCount();
-                } catch (InterruptedException | ExecutionException e) {
-                    shardMBean.incrementFailedTransactionsCount();
-                    LOG.error("{}: Failed to commit", persistenceId(), e);
-                }
-            }
-        }
-
-        recoveryCoordinator = null;
-        currentLogRecoveryBatch = null;
-        updateJournalStats();
-
         //notify shard manager
         getContext().parent().tell(new ActorInitialized(), getSelf());
 
@@ -817,9 +666,6 @@ public class Shard extends RaftActor {
                     persistenceId(), data, data.getClass().getClassLoader(),
                     CompositeModificationPayload.class.getClassLoader());
         }
-
-        updateJournalStats();
-
     }
 
     private void applyModificationToState(ActorRef clientActor, String identifier, Object modification) {
@@ -837,122 +683,45 @@ public class Shard extends RaftActor {
         }
     }
 
-    private void updateJournalStats() {
-        ReplicatedLogEntry lastLogEntry = getLastLogEntry();
-
-        if (lastLogEntry != null) {
-            shardMBean.setLastLogIndex(lastLogEntry.getIndex());
-            shardMBean.setLastLogTerm(lastLogEntry.getTerm());
-        }
-
-        shardMBean.setCommitIndex(getCommitIndex());
-        shardMBean.setLastApplied(getLastApplied());
-        shardMBean.setInMemoryJournalDataSize(getRaftActorContext().getReplicatedLog().dataSize());
-    }
-
-    @Override
-    protected void createSnapshot() {
-        // Create a transaction actor. We are really going to treat the transaction as a worker
-        // so that this actor does not get block building the snapshot. THe transaction actor will
-        // after processing the CreateSnapshot message.
-
-        ActorRef createSnapshotTransaction = createTransaction(
-                TransactionProxy.TransactionType.READ_ONLY.ordinal(),
-                "createSnapshot" + ++createSnapshotTransactionCounter, "",
-                DataStoreVersions.CURRENT_VERSION);
-
-        createSnapshotTransaction.tell(CreateSnapshot.INSTANCE, self());
-    }
-
-    @VisibleForTesting
-    @Override
-    protected void applySnapshot(final byte[] snapshotBytes) {
-        // Since this will be done only on Recovery or when this actor is a Follower
-        // we can safely commit everything in here. We not need to worry about event notifications
-        // as they would have already been disabled on the follower
-
-        LOG.info("{}: Applying snapshot", persistenceId());
-        try {
-            DOMStoreWriteTransaction transaction = store.newWriteOnlyTransaction();
-
-            NormalizedNode<?, ?> node = SerializationUtils.deserializeNormalizedNode(snapshotBytes);
-
-            // delete everything first
-            transaction.delete(DATASTORE_ROOT);
-
-            // Add everything from the remote node back
-            transaction.write(DATASTORE_ROOT, node);
-            syncCommitTransaction(transaction);
-        } catch (InterruptedException | ExecutionException e) {
-            LOG.error("{}: An exception occurred when applying snapshot", persistenceId(), e);
-        } finally {
-            LOG.info("{}: Done applying snapshot", persistenceId());
-        }
-    }
-
     @Override
     protected void onStateChanged() {
         boolean isLeader = isLeader();
-        for (ActorSelection dataChangeListener : dataChangeListeners) {
-            dataChangeListener.tell(new EnableNotification(isLeader), getSelf());
-        }
-
-        if(isLeader) {
-            for(DelayedListenerRegistration reg: delayedListenerRegistrations) {
-                if(!reg.isClosed()) {
-                    reg.setDelegate(doChangeListenerRegistration(reg.getRegisterChangeListener()));
-                }
-            }
-
-            delayedListenerRegistrations.clear();
-        }
-
-        shardMBean.setRaftState(getRaftState().name());
-        shardMBean.setCurrentTerm(getCurrentTerm());
+        changeSupport.onLeadershipChange(isLeader);
+        treeChangeSupport.onLeadershipChange(isLeader);
 
         // If this actor is no longer the leader close all the transaction chains
-        if(!isLeader){
-            for(Map.Entry<String, DOMStoreTransactionChain> entry : transactionChains.entrySet()){
-                if(LOG.isDebugEnabled()) {
-                    LOG.debug(
-                        "{}: onStateChanged: Closing transaction chain {} because shard {} is no longer the leader",
-                        persistenceId(), entry.getKey(), getId());
-                }
-                entry.getValue().close();
+        if (!isLeader) {
+            if(LOG.isDebugEnabled()) {
+                LOG.debug(
+                    "{}: onStateChanged: Closing all transaction chains because shard {} is no longer the leader",
+                    persistenceId(), getId());
             }
 
-            transactionChains.clear();
+            domTransactionFactory.closeAllTransactionChains();
         }
     }
 
     @Override
-    protected DataPersistenceProvider persistence() {
-        return dataPersistenceProvider;
-    }
-
-    @Override protected void onLeaderChanged(final String oldLeader, final String newLeader) {
-        shardMBean.setLeader(newLeader);
-    }
-
-    @Override public String persistenceId() {
-        return this.name.toString();
+    public String persistenceId() {
+        return this.name;
     }
 
     @VisibleForTesting
-    DataPersistenceProvider getDataPersistenceProvider() {
-        return dataPersistenceProvider;
+    ShardCommitCoordinator getCommitCoordinator() {
+        return commitCoordinator;
     }
 
+
     private static class ShardCreator implements Creator<Shard> {
 
         private static final long serialVersionUID = 1L;
 
         final ShardIdentifier name;
-        final Map<ShardIdentifier, String> peerAddresses;
+        final Map<String, String> peerAddresses;
         final DatastoreContext datastoreContext;
         final SchemaContext schemaContext;
 
-        ShardCreator(final ShardIdentifier name, final Map<ShardIdentifier, String> peerAddresses,
+        ShardCreator(final ShardIdentifier name, final Map<String, String> peerAddresses,
                 final DatastoreContext datastoreContext, final SchemaContext schemaContext) {
             this.name = name;
             this.peerAddresses = peerAddresses;
@@ -967,7 +736,7 @@ public class Shard extends RaftActor {
     }
 
     @VisibleForTesting
-    InMemoryDOMDataStore getDataStore() {
+    public InMemoryDOMDataStore getDataStore() {
         return store;
     }
 
@@ -975,45 +744,4 @@ public class Shard extends RaftActor {
     ShardStats getShardMBean() {
         return shardMBean;
     }
-
-    private static class DelayedListenerRegistration implements
-        ListenerRegistration<AsyncDataChangeListener<YangInstanceIdentifier, NormalizedNode<?, ?>>> {
-
-        private volatile boolean closed;
-
-        private final RegisterChangeListener registerChangeListener;
-
-        private volatile ListenerRegistration<AsyncDataChangeListener<YangInstanceIdentifier,
-                                                             NormalizedNode<?, ?>>> delegate;
-
-        DelayedListenerRegistration(final RegisterChangeListener registerChangeListener) {
-            this.registerChangeListener = registerChangeListener;
-        }
-
-        void setDelegate( final ListenerRegistration<AsyncDataChangeListener<YangInstanceIdentifier,
-                                            NormalizedNode<?, ?>>> registration) {
-            this.delegate = registration;
-        }
-
-        boolean isClosed() {
-            return closed;
-        }
-
-        RegisterChangeListener getRegisterChangeListener() {
-            return registerChangeListener;
-        }
-
-        @Override
-        public AsyncDataChangeListener<YangInstanceIdentifier, NormalizedNode<?, ?>> getInstance() {
-            return delegate != null ? delegate.getInstance() : null;
-        }
-
-        @Override
-        public void close() {
-            closed = true;
-            if(delegate != null) {
-                delegate.close();
-            }
-        }
-    }
 }