90aff5b72616e9d85abfb3786f28aa5f6345bcbe
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / test / java / org / opendaylight / controller / cluster / raft / utils / InMemoryJournal.java
1 /*
2  * Copyright (c) 2015 Brocade Communications Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.controller.cluster.raft.utils;
9
10 import akka.dispatch.Futures;
11 import akka.persistence.AtomicWrite;
12 import akka.persistence.PersistentImpl;
13 import akka.persistence.PersistentRepr;
14 import akka.persistence.journal.japi.AsyncWriteJournal;
15 import com.google.common.util.concurrent.Uninterruptibles;
16 import java.io.Serializable;
17 import java.util.ArrayList;
18 import java.util.Collections;
19 import java.util.LinkedHashMap;
20 import java.util.List;
21 import java.util.Map;
22 import java.util.Optional;
23 import java.util.concurrent.ConcurrentHashMap;
24 import java.util.concurrent.CountDownLatch;
25 import java.util.concurrent.TimeUnit;
26 import java.util.function.Consumer;
27 import org.apache.commons.lang.SerializationUtils;
28 import org.slf4j.Logger;
29 import org.slf4j.LoggerFactory;
30 import scala.concurrent.Future;
31
32 /**
33  * An akka AsyncWriteJournal implementation that stores data in memory. This is intended for testing.
34  *
35  * @author Thomas Pantelis
36  */
37 public class InMemoryJournal extends AsyncWriteJournal {
38
39     private static class WriteMessagesComplete {
40         final CountDownLatch latch;
41         final Class<?> ofType;
42
43         WriteMessagesComplete(int count, Class<?> ofType) {
44             this.latch = new CountDownLatch(count);
45             this.ofType = ofType;
46         }
47     }
48
49     static final Logger LOG = LoggerFactory.getLogger(InMemoryJournal.class);
50
51     private static final Map<String, Map<Long, Object>> JOURNALS = new ConcurrentHashMap<>();
52
53     private static final Map<String, CountDownLatch> DELETE_MESSAGES_COMPLETE_LATCHES = new ConcurrentHashMap<>();
54
55     private static final Map<String, WriteMessagesComplete> WRITE_MESSAGES_COMPLETE = new ConcurrentHashMap<>();
56
57     private static final Map<String, CountDownLatch> BLOCK_READ_MESSAGES_LATCHES = new ConcurrentHashMap<>();
58
59     private static Object deserialize(Object data) {
60         return data instanceof byte[] ? SerializationUtils.deserialize((byte[])data) : data;
61     }
62
63     public static void addEntry(String persistenceId, long sequenceNr, Object data) {
64         Map<Long, Object> journal = JOURNALS.computeIfAbsent(persistenceId, k -> new LinkedHashMap<>());
65
66         synchronized (journal) {
67             journal.put(sequenceNr, data instanceof Serializable
68                     ? SerializationUtils.serialize((Serializable) data) : data);
69         }
70     }
71
72     public static void clear() {
73         JOURNALS.clear();
74         DELETE_MESSAGES_COMPLETE_LATCHES.clear();
75         WRITE_MESSAGES_COMPLETE.clear();
76         BLOCK_READ_MESSAGES_LATCHES.clear();
77     }
78
79     @SuppressWarnings("unchecked")
80     public static <T> List<T> get(String persistenceId, Class<T> type) {
81         Map<Long, Object> journalMap = JOURNALS.get(persistenceId);
82         if (journalMap == null) {
83             return Collections.<T>emptyList();
84         }
85
86         synchronized (journalMap) {
87             List<T> journal = new ArrayList<>(journalMap.size());
88             for (Object entry: journalMap.values()) {
89                 Object data = deserialize(entry);
90                 if (type.isInstance(data)) {
91                     journal.add((T) data);
92                 }
93             }
94
95             return journal;
96         }
97     }
98
99     public static Map<Long, Object> get(String persistenceId) {
100         Map<Long, Object> journalMap = JOURNALS.get(persistenceId);
101         return journalMap != null ? journalMap : Collections.<Long, Object>emptyMap();
102     }
103
104     public static void dumpJournal(String persistenceId) {
105         StringBuilder builder = new StringBuilder(String.format("Journal log for %s:", persistenceId));
106         Map<Long, Object> journalMap = JOURNALS.get(persistenceId);
107         if (journalMap != null) {
108             synchronized (journalMap) {
109                 for (Map.Entry<Long, Object> e: journalMap.entrySet()) {
110                     builder.append("\n    ").append(e.getKey()).append(" = ").append(deserialize(e.getValue()));
111                 }
112             }
113         }
114
115         LOG.info(builder.toString());
116     }
117
118     public static void waitForDeleteMessagesComplete(String persistenceId) {
119         if (!Uninterruptibles.awaitUninterruptibly(DELETE_MESSAGES_COMPLETE_LATCHES.get(persistenceId),
120                 5, TimeUnit.SECONDS)) {
121             throw new AssertionError("Delete messages did not complete");
122         }
123     }
124
125     public static void waitForWriteMessagesComplete(String persistenceId) {
126         if (!Uninterruptibles.awaitUninterruptibly(WRITE_MESSAGES_COMPLETE.get(persistenceId).latch,
127                 5, TimeUnit.SECONDS)) {
128             throw new AssertionError("Journal write messages did not complete");
129         }
130     }
131
132     public static void addDeleteMessagesCompleteLatch(String persistenceId) {
133         DELETE_MESSAGES_COMPLETE_LATCHES.put(persistenceId, new CountDownLatch(1));
134     }
135
136     public static void addWriteMessagesCompleteLatch(String persistenceId, int count) {
137         WRITE_MESSAGES_COMPLETE.put(persistenceId, new WriteMessagesComplete(count, null));
138     }
139
140     public static void addWriteMessagesCompleteLatch(String persistenceId, int count, Class<?> ofType) {
141         WRITE_MESSAGES_COMPLETE.put(persistenceId, new WriteMessagesComplete(count, ofType));
142     }
143
144     public static void addBlockReadMessagesLatch(String persistenceId, CountDownLatch latch) {
145         BLOCK_READ_MESSAGES_LATCHES.put(persistenceId, latch);
146     }
147
148     @Override
149     public Future<Void> doAsyncReplayMessages(final String persistenceId, final long fromSequenceNr,
150             final long toSequenceNr, final long max, final Consumer<PersistentRepr> replayCallback) {
151         LOG.trace("doAsyncReplayMessages for {}: fromSequenceNr: {}, toSequenceNr: {}", persistenceId,
152                 fromSequenceNr,toSequenceNr);
153         return Futures.future(() -> {
154             CountDownLatch blockLatch = BLOCK_READ_MESSAGES_LATCHES.remove(persistenceId);
155             if (blockLatch != null) {
156                 Uninterruptibles.awaitUninterruptibly(blockLatch);
157             }
158
159             Map<Long, Object> journal = JOURNALS.get(persistenceId);
160             if (journal == null) {
161                 return null;
162             }
163
164             synchronized (journal) {
165                 int count = 0;
166                 for (Map.Entry<Long,Object> entry : journal.entrySet()) {
167                     if (++count <= max && entry.getKey() >= fromSequenceNr && entry.getKey() <= toSequenceNr) {
168                         PersistentRepr persistentMessage =
169                                 new PersistentImpl(deserialize(entry.getValue()), entry.getKey(), persistenceId,
170                                         null, false, null, null);
171                         replayCallback.accept(persistentMessage);
172                     }
173                 }
174             }
175
176             return null;
177         }, context().dispatcher());
178     }
179
180     @Override
181     public Future<Long> doAsyncReadHighestSequenceNr(String persistenceId, long fromSequenceNr) {
182         LOG.trace("doAsyncReadHighestSequenceNr for {}: fromSequenceNr: {}", persistenceId, fromSequenceNr);
183
184         // Akka calls this during recovery.
185         Map<Long, Object> journal = JOURNALS.get(persistenceId);
186         if (journal == null) {
187             return Futures.successful(fromSequenceNr);
188         }
189
190         synchronized (journal) {
191             long highest = -1;
192             for (Long seqNr : journal.keySet()) {
193                 if (seqNr.longValue() >= fromSequenceNr && seqNr.longValue() > highest) {
194                     highest = seqNr.longValue();
195                 }
196             }
197
198             return Futures.successful(highest);
199         }
200     }
201
202     @Override
203     public Future<Iterable<Optional<Exception>>> doAsyncWriteMessages(final Iterable<AtomicWrite> messages) {
204         return Futures.future(() -> {
205             for (AtomicWrite write : messages) {
206                 // Copy to array - workaround for eclipse "ambiguous method" errors for toIterator, toIterable etc
207                 PersistentRepr[] array = new PersistentRepr[write.payload().size()];
208                 write.payload().copyToArray(array);
209                 for (PersistentRepr repr: array) {
210                     LOG.trace("doAsyncWriteMessages: id: {}: seqNr: {}, payload: {}", repr.persistenceId(),
211                         repr.sequenceNr(), repr.payload());
212
213                     addEntry(repr.persistenceId(), repr.sequenceNr(), repr.payload());
214
215                     WriteMessagesComplete complete = WRITE_MESSAGES_COMPLETE.get(repr.persistenceId());
216                     if (complete != null) {
217                         if (complete.ofType == null || complete.ofType.equals(repr.payload().getClass())) {
218                             complete.latch.countDown();
219                         }
220                     }
221                 }
222             }
223
224             return Collections.emptyList();
225         }, context().dispatcher());
226     }
227
228     @Override
229     public Future<Void> doAsyncDeleteMessagesTo(String persistenceId, long toSequenceNr) {
230         LOG.trace("doAsyncDeleteMessagesTo: {}", toSequenceNr);
231         Map<Long, Object> journal = JOURNALS.get(persistenceId);
232         if (journal != null) {
233             synchronized (journal) {
234                 journal.keySet().removeIf(num -> num <= toSequenceNr);
235             }
236         }
237
238         CountDownLatch latch = DELETE_MESSAGES_COMPLETE_LATCHES.get(persistenceId);
239         if (latch != null) {
240             latch.countDown();
241         }
242
243         return Futures.successful(null);
244     }
245 }