Always persist and recover election term info
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / test / java / org / opendaylight / controller / cluster / raft / utils / InMemoryJournal.java
1 /*
2  * Copyright (c) 2015 Brocade Communications Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.controller.cluster.raft.utils;
9
10 import akka.dispatch.Futures;
11 import akka.japi.Procedure;
12 import akka.persistence.PersistentConfirmation;
13 import akka.persistence.PersistentId;
14 import akka.persistence.PersistentImpl;
15 import akka.persistence.PersistentRepr;
16 import akka.persistence.journal.japi.AsyncWriteJournal;
17 import com.google.common.collect.Maps;
18 import com.google.common.util.concurrent.Uninterruptibles;
19 import java.util.ArrayList;
20 import java.util.Collections;
21 import java.util.Iterator;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.concurrent.Callable;
25 import java.util.concurrent.ConcurrentHashMap;
26 import java.util.concurrent.CountDownLatch;
27 import java.util.concurrent.TimeUnit;
28 import org.slf4j.Logger;
29 import org.slf4j.LoggerFactory;
30 import scala.concurrent.Future;
31
32 /**
33  * An akka AsyncWriteJournal implementation that stores data in memory. This is intended for testing.
34  *
35  * @author Thomas Pantelis
36  */
37 public class InMemoryJournal extends AsyncWriteJournal {
38
39     private static class WriteMessagesComplete {
40         final CountDownLatch latch;
41         final Class<?> ofType;
42
43         public WriteMessagesComplete(int count, Class<?> ofType) {
44             this.latch = new CountDownLatch(count);
45             this.ofType = ofType;
46         }
47     }
48
49     static final Logger LOG = LoggerFactory.getLogger(InMemoryJournal.class);
50
51     private static final Map<String, Map<Long, Object>> journals = new ConcurrentHashMap<>();
52
53     private static final Map<String, CountDownLatch> deleteMessagesCompleteLatches = new ConcurrentHashMap<>();
54
55     private static final Map<String, WriteMessagesComplete> writeMessagesComplete = new ConcurrentHashMap<>();
56
57     private static final Map<String, CountDownLatch> blockReadMessagesLatches = new ConcurrentHashMap<>();
58
59     public static void addEntry(String persistenceId, long sequenceNr, Object data) {
60         Map<Long, Object> journal = journals.get(persistenceId);
61         if(journal == null) {
62             journal = Maps.newLinkedHashMap();
63             journals.put(persistenceId, journal);
64         }
65
66         synchronized (journal) {
67             journal.put(sequenceNr, data);
68         }
69     }
70
71     public static void clear() {
72         journals.clear();
73     }
74
75     @SuppressWarnings("unchecked")
76     public static <T> List<T> get(String persistenceId, Class<T> type) {
77         Map<Long, Object> journalMap = journals.get(persistenceId);
78         if(journalMap == null) {
79             return Collections.<T>emptyList();
80         }
81
82         synchronized (journalMap) {
83             List<T> journal = new ArrayList<>(journalMap.size());
84             for(Object entry: journalMap.values()) {
85                 if(type.isInstance(entry)) {
86                     journal.add((T) entry);
87                 }
88             }
89
90             return journal;
91         }
92     }
93
94     public static Map<Long, Object> get(String persistenceId) {
95         Map<Long, Object> journalMap = journals.get(persistenceId);
96         return journalMap != null ? journalMap : Collections.<Long, Object>emptyMap();
97     }
98
99     public static void dumpJournal(String persistenceId) {
100         StringBuilder builder = new StringBuilder(String.format("Journal log for %s:", persistenceId));
101         Map<Long, Object> journalMap = journals.get(persistenceId);
102         if(journalMap != null) {
103             synchronized (journalMap) {
104                 for(Map.Entry<Long, Object> e: journalMap.entrySet()) {
105                     builder.append("\n    ").append(e.getKey()).append(" = ").append(e.getValue());
106                 }
107             }
108         }
109
110         LOG.info(builder.toString());
111     }
112
113     public static void waitForDeleteMessagesComplete(String persistenceId) {
114         if(!Uninterruptibles.awaitUninterruptibly(deleteMessagesCompleteLatches.get(persistenceId), 5, TimeUnit.SECONDS)) {
115             throw new AssertionError("Delete messages did not complete");
116         }
117     }
118
119     public static void waitForWriteMessagesComplete(String persistenceId) {
120         if(!Uninterruptibles.awaitUninterruptibly(writeMessagesComplete.get(persistenceId).latch, 5, TimeUnit.SECONDS)) {
121             throw new AssertionError("Journal write messages did not complete");
122         }
123     }
124
125     public static void addDeleteMessagesCompleteLatch(String persistenceId) {
126         deleteMessagesCompleteLatches.put(persistenceId, new CountDownLatch(1));
127     }
128
129     public static void addWriteMessagesCompleteLatch(String persistenceId, int count) {
130         writeMessagesComplete.put(persistenceId, new WriteMessagesComplete(count, null));
131     }
132
133     public static void addWriteMessagesCompleteLatch(String persistenceId, int count, Class<?> ofType) {
134         writeMessagesComplete.put(persistenceId, new WriteMessagesComplete(count, ofType));
135     }
136
137     public static void addBlockReadMessagesLatch(String persistenceId, CountDownLatch latch) {
138         blockReadMessagesLatches.put(persistenceId, latch);
139     }
140
141     @Override
142     public Future<Void> doAsyncReplayMessages(final String persistenceId, final long fromSequenceNr,
143             final long toSequenceNr, long max, final Procedure<PersistentRepr> replayCallback) {
144         return Futures.future(new Callable<Void>() {
145             @Override
146             public Void call() throws Exception {
147                 CountDownLatch blockLatch = blockReadMessagesLatches.remove(persistenceId);
148                 if(blockLatch != null) {
149                     Uninterruptibles.awaitUninterruptibly(blockLatch);
150                 }
151
152                 Map<Long, Object> journal = journals.get(persistenceId);
153                 if (journal == null) {
154                     return null;
155                 }
156
157                 synchronized (journal) {
158                     for (Map.Entry<Long,Object> entry : journal.entrySet()) {
159                         if (entry.getKey() >= fromSequenceNr && entry.getKey() <= toSequenceNr) {
160                             PersistentRepr persistentMessage =
161                                     new PersistentImpl(entry.getValue(), entry.getKey(), persistenceId,
162                                             false, null, null);
163                             replayCallback.apply(persistentMessage);
164                         }
165                     }
166                 }
167
168                 return null;
169             }
170         }, context().dispatcher());
171     }
172
173     @Override
174     public Future<Long> doAsyncReadHighestSequenceNr(String persistenceId, long fromSequenceNr) {
175         // Akka calls this during recovery.
176         Map<Long, Object> journal = journals.get(persistenceId);
177         if(journal == null) {
178             return Futures.successful(fromSequenceNr);
179         }
180
181         synchronized (journal) {
182             long highest = -1;
183             for (Long seqNr : journal.keySet()) {
184                 if(seqNr.longValue() >= fromSequenceNr && seqNr.longValue() > highest) {
185                     highest = seqNr.longValue();
186                 }
187             }
188
189             return Futures.successful(highest);
190         }
191     }
192
193     @Override
194     public Future<Void> doAsyncWriteMessages(final Iterable<PersistentRepr> messages) {
195         return Futures.future(new Callable<Void>() {
196             @Override
197             public Void call() throws Exception {
198                 for (PersistentRepr repr : messages) {
199                     Map<Long, Object> journal = journals.get(repr.persistenceId());
200                     if(journal == null) {
201                         journal = Maps.newLinkedHashMap();
202                         journals.put(repr.persistenceId(), journal);
203                     }
204
205                     synchronized (journal) {
206                         LOG.trace("doAsyncWriteMessages: id: {}: seqNr: {}, payload: {}", repr.persistenceId(),
207                                 repr.sequenceNr(), repr.payload());
208                         journal.put(repr.sequenceNr(), repr.payload());
209                     }
210
211                     WriteMessagesComplete complete = writeMessagesComplete.get(repr.persistenceId());
212                     if(complete != null) {
213                         if(complete.ofType == null || complete.ofType.equals(repr.payload().getClass())) {
214                             complete.latch.countDown();
215                         }
216                     }
217                 }
218
219                 return null;
220             }
221         }, context().dispatcher());
222     }
223
224     @Override
225     public Future<Void> doAsyncWriteConfirmations(Iterable<PersistentConfirmation> confirmations) {
226         return Futures.successful(null);
227     }
228
229     @Override
230     public Future<Void> doAsyncDeleteMessages(Iterable<PersistentId> messageIds, boolean permanent) {
231         return Futures.successful(null);
232     }
233
234     @Override
235     public Future<Void> doAsyncDeleteMessagesTo(String persistenceId, long toSequenceNr, boolean permanent) {
236         LOG.trace("doAsyncDeleteMessagesTo: {}", toSequenceNr);
237         Map<Long, Object> journal = journals.get(persistenceId);
238         if(journal != null) {
239             synchronized (journal) {
240                 Iterator<Long> iter = journal.keySet().iterator();
241                 while(iter.hasNext()) {
242                     Long n = iter.next();
243                     if(n <= toSequenceNr) {
244                         iter.remove();
245                     }
246                 }
247             }
248         }
249
250         CountDownLatch latch = deleteMessagesCompleteLatches.get(persistenceId);
251         if(latch != null) {
252             latch.countDown();
253         }
254
255         return Futures.successful(null);
256     }
257 }