Fix intermittent failures in FollowerTest
controller.git: opendaylight/md-sal/sal-akka-raft/src/test/java/org/opendaylight/controller/cluster/raft/utils/InMemoryJournal.java
/*
 * Copyright (c) 2015 Brocade Communications Systems, Inc. and others.  All rights reserved.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License v1.0 which accompanies this distribution,
 * and is available at http://www.eclipse.org/legal/epl-v10.html
 */
package org.opendaylight.controller.cluster.raft.utils;

import akka.dispatch.Futures;
import akka.persistence.AtomicWrite;
import akka.persistence.PersistentImpl;
import akka.persistence.PersistentRepr;
import akka.persistence.journal.japi.AsyncWriteJournal;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.Uninterruptibles;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.apache.commons.lang.SerializationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.Future;

/**
 * An akka AsyncWriteJournal implementation that stores data in memory. This is intended for testing.
 *
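 * <p>A typical test interaction might look like the following sketch (the persistence id and the
 * ReplicatedLogEntry payload type are illustrative, not something this class requires):
 *
 * <pre>{@code
 *     // Expect one journal write for "member-1" before proceeding.
 *     InMemoryJournal.addWriteMessagesCompleteLatch("member-1", 1);
 *
 *     // ... drive the actor under test so that it persists an entry ...
 *
 *     // Block (up to 5 seconds) until the write reaches the journal, then inspect it.
 *     InMemoryJournal.waitForWriteMessagesComplete("member-1");
 *     List<ReplicatedLogEntry> entries = InMemoryJournal.get("member-1", ReplicatedLogEntry.class);
 * }</pre>
 *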
 * @author Thomas Pantelis
 */
public class InMemoryJournal extends AsyncWriteJournal {

    private static class WriteMessagesComplete {
        final CountDownLatch latch;
        final Class<?> ofType;

        WriteMessagesComplete(int count, Class<?> ofType) {
            this.latch = new CountDownLatch(count);
            this.ofType = ofType;
        }
    }

    static final Logger LOG = LoggerFactory.getLogger(InMemoryJournal.class);

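    // Static state shared between test code and the journal actor, keyed by persistenceId. Each
    // per-persistenceId journal map is guarded by synchronizing on the map instance itself.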
    private static final Map<String, Map<Long, Object>> JOURNALS = new ConcurrentHashMap<>();

    private static final Map<String, CountDownLatch> DELETE_MESSAGES_COMPLETE_LATCHES = new ConcurrentHashMap<>();

    private static final Map<String, WriteMessagesComplete> WRITE_MESSAGES_COMPLETE = new ConcurrentHashMap<>();

    private static final Map<String, CountDownLatch> BLOCK_READ_MESSAGES_LATCHES = new ConcurrentHashMap<>();

    private static Object deserialize(Object data) {
        return data instanceof byte[] ? SerializationUtils.deserialize((byte[])data) : data;
    }

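    // Adds an entry directly to the in-memory journal. Serializable payloads are stored in
    // serialized form so that reads exercise the same deserialization path as real persistence.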
    public static void addEntry(String persistenceId, long sequenceNr, Object data) {
        Map<Long, Object> journal = JOURNALS.computeIfAbsent(persistenceId, k -> Maps.newLinkedHashMap());

        synchronized (journal) {
            journal.put(sequenceNr, data instanceof Serializable
                    ? SerializationUtils.serialize((Serializable) data) : data);
        }
    }

    public static void clear() {
        JOURNALS.clear();
        DELETE_MESSAGES_COMPLETE_LATCHES.clear();
        WRITE_MESSAGES_COMPLETE.clear();
        BLOCK_READ_MESSAGES_LATCHES.clear();
    }

    @SuppressWarnings("unchecked")
    public static <T> List<T> get(String persistenceId, Class<T> type) {
        Map<Long, Object> journalMap = JOURNALS.get(persistenceId);
        if (journalMap == null) {
            return Collections.<T>emptyList();
        }

        synchronized (journalMap) {
            List<T> journal = new ArrayList<>(journalMap.size());
            for (Object entry: journalMap.values()) {
                Object data = deserialize(entry);
                if (type.isInstance(data)) {
                    journal.add((T) data);
                }
            }

            return journal;
        }
    }

    public static Map<Long, Object> get(String persistenceId) {
        Map<Long, Object> journalMap = JOURNALS.get(persistenceId);
        return journalMap != null ? journalMap : Collections.<Long, Object>emptyMap();
    }

    public static void dumpJournal(String persistenceId) {
        StringBuilder builder = new StringBuilder(String.format("Journal log for %s:", persistenceId));
        Map<Long, Object> journalMap = JOURNALS.get(persistenceId);
        if (journalMap != null) {
            synchronized (journalMap) {
                for (Map.Entry<Long, Object> e: journalMap.entrySet()) {
                    builder.append("\n    ").append(e.getKey()).append(" = ").append(deserialize(e.getValue()));
                }
            }
        }

        LOG.info(builder.toString());
    }

    public static void waitForDeleteMessagesComplete(String persistenceId) {
        if (!Uninterruptibles.awaitUninterruptibly(DELETE_MESSAGES_COMPLETE_LATCHES.get(persistenceId),
                5, TimeUnit.SECONDS)) {
            throw new AssertionError("Delete messages did not complete");
        }
    }

    public static void waitForWriteMessagesComplete(String persistenceId) {
        if (!Uninterruptibles.awaitUninterruptibly(WRITE_MESSAGES_COMPLETE.get(persistenceId).latch,
                5, TimeUnit.SECONDS)) {
            throw new AssertionError("Journal write messages did not complete");
        }
    }

    public static void addDeleteMessagesCompleteLatch(String persistenceId) {
        DELETE_MESSAGES_COMPLETE_LATCHES.put(persistenceId, new CountDownLatch(1));
    }

    public static void addWriteMessagesCompleteLatch(String persistenceId, int count) {
        WRITE_MESSAGES_COMPLETE.put(persistenceId, new WriteMessagesComplete(count, null));
    }

    public static void addWriteMessagesCompleteLatch(String persistenceId, int count, Class<?> ofType) {
        WRITE_MESSAGES_COMPLETE.put(persistenceId, new WriteMessagesComplete(count, ofType));
    }

    public static void addBlockReadMessagesLatch(String persistenceId, CountDownLatch latch) {
        BLOCK_READ_MESSAGES_LATCHES.put(persistenceId, latch);
    }

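    // Replays stored entries in the requested sequence number range. If a test registered a latch via
    // addBlockReadMessagesLatch, replay blocks until that latch is counted down.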
    @Override
    public Future<Void> doAsyncReplayMessages(final String persistenceId, final long fromSequenceNr,
            final long toSequenceNr, final long max, final Consumer<PersistentRepr> replayCallback) {
        LOG.trace("doAsyncReplayMessages for {}: fromSequenceNr: {}, toSequenceNr: {}", persistenceId,
                fromSequenceNr, toSequenceNr);
        return Futures.future(() -> {
            CountDownLatch blockLatch = BLOCK_READ_MESSAGES_LATCHES.remove(persistenceId);
            if (blockLatch != null) {
                Uninterruptibles.awaitUninterruptibly(blockLatch);
            }

            Map<Long, Object> journal = JOURNALS.get(persistenceId);
            if (journal == null) {
                return null;
            }

            synchronized (journal) {
                int count = 0;
                for (Map.Entry<Long,Object> entry : journal.entrySet()) {
                    if (++count <= max && entry.getKey() >= fromSequenceNr && entry.getKey() <= toSequenceNr) {
                        PersistentRepr persistentMessage =
                                new PersistentImpl(deserialize(entry.getValue()), entry.getKey(), persistenceId,
                                        null, false, null, null);
                        replayCallback.accept(persistentMessage);
                    }
                }
            }

            return null;
        }, context().dispatcher());
    }

    @Override
    public Future<Long> doAsyncReadHighestSequenceNr(String persistenceId, long fromSequenceNr) {
        LOG.trace("doAsyncReadHighestSequenceNr for {}: fromSequenceNr: {}", persistenceId, fromSequenceNr);

        // Akka calls this during recovery.
        Map<Long, Object> journal = JOURNALS.get(persistenceId);
        if (journal == null) {
            return Futures.successful(fromSequenceNr);
        }

        synchronized (journal) {
            long highest = -1;
            for (Long seqNr : journal.keySet()) {
                if (seqNr.longValue() >= fromSequenceNr && seqNr.longValue() > highest) {
                    highest = seqNr.longValue();
                }
            }

            return Futures.successful(highest);
        }
    }

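    // Stores each written payload via addEntry and counts down any write-complete latch registered for
    // the persistenceId, optionally only for payloads of the expected type.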
    @Override
    public Future<Iterable<Optional<Exception>>> doAsyncWriteMessages(final Iterable<AtomicWrite> messages) {
        return Futures.future(() -> {
            for (AtomicWrite write : messages) {
                // Copy to array - workaround for eclipse "ambiguous method" errors for toIterator, toIterable etc
                PersistentRepr[] array = new PersistentRepr[write.payload().size()];
                write.payload().copyToArray(array);
                for (PersistentRepr repr: array) {
                    LOG.trace("doAsyncWriteMessages: id: {}: seqNr: {}, payload: {}", repr.persistenceId(),
                        repr.sequenceNr(), repr.payload());

                    addEntry(repr.persistenceId(), repr.sequenceNr(), repr.payload());

                    WriteMessagesComplete complete = WRITE_MESSAGES_COMPLETE.get(repr.persistenceId());
                    if (complete != null) {
                        if (complete.ofType == null || complete.ofType.equals(repr.payload().getClass())) {
                            complete.latch.countDown();
                        }
                    }
                }
            }

            return Collections.emptyList();
        }, context().dispatcher());
    }

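    // Deletes entries up to and including toSequenceNr and counts down any delete-complete latch
    // registered for the persistenceId.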
    @Override
    public Future<Void> doAsyncDeleteMessagesTo(String persistenceId, long toSequenceNr) {
        LOG.trace("doAsyncDeleteMessagesTo: {}", toSequenceNr);
        Map<Long, Object> journal = JOURNALS.get(persistenceId);
        if (journal != null) {
            synchronized (journal) {
                Iterator<Long> iter = journal.keySet().iterator();
                while (iter.hasNext()) {
                    Long num = iter.next();
                    if (num <= toSequenceNr) {
                        iter.remove();
                    }
                }
            }
        }

        CountDownLatch latch = DELETE_MESSAGES_COMPLETE_LATCHES.get(persistenceId);
        if (latch != null) {
            latch.countDown();
        }

        return Futures.successful(null);
    }
}