Change InMemoryJournal to serialize/deserialize the data
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / test / java / org / opendaylight / controller / cluster / raft / utils / InMemoryJournal.java
1 /*
2  * Copyright (c) 2015 Brocade Communications Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.controller.cluster.raft.utils;
9
10 import akka.dispatch.Futures;
11 import akka.japi.Procedure;
12 import akka.persistence.PersistentConfirmation;
13 import akka.persistence.PersistentId;
14 import akka.persistence.PersistentImpl;
15 import akka.persistence.PersistentRepr;
16 import akka.persistence.journal.japi.AsyncWriteJournal;
17 import com.google.common.collect.Maps;
18 import com.google.common.util.concurrent.Uninterruptibles;
19 import java.io.Serializable;
20 import java.util.ArrayList;
21 import java.util.Collections;
22 import java.util.Iterator;
23 import java.util.List;
24 import java.util.Map;
25 import java.util.concurrent.Callable;
26 import java.util.concurrent.ConcurrentHashMap;
27 import java.util.concurrent.CountDownLatch;
28 import java.util.concurrent.TimeUnit;
29 import org.apache.commons.lang.SerializationUtils;
30 import org.slf4j.Logger;
31 import org.slf4j.LoggerFactory;
32 import scala.concurrent.Future;
33
34 /**
35  * An akka AsyncWriteJournal implementation that stores data in memory. This is intended for testing.
36  *
37  * @author Thomas Pantelis
38  */
39 public class InMemoryJournal extends AsyncWriteJournal {
40
41     private static class WriteMessagesComplete {
42         final CountDownLatch latch;
43         final Class<?> ofType;
44
45         public WriteMessagesComplete(int count, Class<?> ofType) {
46             this.latch = new CountDownLatch(count);
47             this.ofType = ofType;
48         }
49     }
50
51     static final Logger LOG = LoggerFactory.getLogger(InMemoryJournal.class);
52
53     private static final Map<String, Map<Long, Object>> journals = new ConcurrentHashMap<>();
54
55     private static final Map<String, CountDownLatch> deleteMessagesCompleteLatches = new ConcurrentHashMap<>();
56
57     private static final Map<String, WriteMessagesComplete> writeMessagesComplete = new ConcurrentHashMap<>();
58
59     private static final Map<String, CountDownLatch> blockReadMessagesLatches = new ConcurrentHashMap<>();
60
61     private static Object deserialize(Object data) {
62         return data instanceof byte[] ? SerializationUtils.deserialize((byte[])data) : data;
63     }
64
65     public static void addEntry(String persistenceId, long sequenceNr, Object data) {
66         Map<Long, Object> journal = journals.get(persistenceId);
67         if(journal == null) {
68             journal = Maps.newLinkedHashMap();
69             journals.put(persistenceId, journal);
70         }
71
72         synchronized (journal) {
73             journal.put(sequenceNr, data instanceof Serializable ?
74                     SerializationUtils.serialize((Serializable) data) : data);
75         }
76     }
77
78     public static void clear() {
79         journals.clear();
80     }
81
82     @SuppressWarnings("unchecked")
83     public static <T> List<T> get(String persistenceId, Class<T> type) {
84         Map<Long, Object> journalMap = journals.get(persistenceId);
85         if(journalMap == null) {
86             return Collections.<T>emptyList();
87         }
88
89         synchronized (journalMap) {
90             List<T> journal = new ArrayList<>(journalMap.size());
91             for(Object entry: journalMap.values()) {
92                 Object data = deserialize(entry);
93                 if(type.isInstance(data)) {
94                     journal.add((T) data);
95                 }
96             }
97
98             return journal;
99         }
100     }
101
102     public static Map<Long, Object> get(String persistenceId) {
103         Map<Long, Object> journalMap = journals.get(persistenceId);
104         return journalMap != null ? journalMap : Collections.<Long, Object>emptyMap();
105     }
106
107     public static void dumpJournal(String persistenceId) {
108         StringBuilder builder = new StringBuilder(String.format("Journal log for %s:", persistenceId));
109         Map<Long, Object> journalMap = journals.get(persistenceId);
110         if(journalMap != null) {
111             synchronized (journalMap) {
112                 for(Map.Entry<Long, Object> e: journalMap.entrySet()) {
113                     builder.append("\n    ").append(e.getKey()).append(" = ").append(e.getValue());
114                 }
115             }
116         }
117
118         LOG.info(builder.toString());
119     }
120
121     public static void waitForDeleteMessagesComplete(String persistenceId) {
122         if(!Uninterruptibles.awaitUninterruptibly(deleteMessagesCompleteLatches.get(persistenceId), 5, TimeUnit.SECONDS)) {
123             throw new AssertionError("Delete messages did not complete");
124         }
125     }
126
127     public static void waitForWriteMessagesComplete(String persistenceId) {
128         if(!Uninterruptibles.awaitUninterruptibly(writeMessagesComplete.get(persistenceId).latch, 5, TimeUnit.SECONDS)) {
129             throw new AssertionError("Journal write messages did not complete");
130         }
131     }
132
133     public static void addDeleteMessagesCompleteLatch(String persistenceId) {
134         deleteMessagesCompleteLatches.put(persistenceId, new CountDownLatch(1));
135     }
136
137     public static void addWriteMessagesCompleteLatch(String persistenceId, int count) {
138         writeMessagesComplete.put(persistenceId, new WriteMessagesComplete(count, null));
139     }
140
141     public static void addWriteMessagesCompleteLatch(String persistenceId, int count, Class<?> ofType) {
142         writeMessagesComplete.put(persistenceId, new WriteMessagesComplete(count, ofType));
143     }
144
145     public static void addBlockReadMessagesLatch(String persistenceId, CountDownLatch latch) {
146         blockReadMessagesLatches.put(persistenceId, latch);
147     }
148
149     @Override
150     public Future<Void> doAsyncReplayMessages(final String persistenceId, final long fromSequenceNr,
151             final long toSequenceNr, final long max, final Procedure<PersistentRepr> replayCallback) {
152         return Futures.future(new Callable<Void>() {
153             @Override
154             public Void call() throws Exception {
155                 CountDownLatch blockLatch = blockReadMessagesLatches.remove(persistenceId);
156                 if(blockLatch != null) {
157                     Uninterruptibles.awaitUninterruptibly(blockLatch);
158                 }
159
160                 Map<Long, Object> journal = journals.get(persistenceId);
161                 if (journal == null) {
162                     return null;
163                 }
164
165                 synchronized (journal) {
166                     int count = 0;
167                     for (Map.Entry<Long,Object> entry : journal.entrySet()) {
168                         if (++count <= max && entry.getKey() >= fromSequenceNr && entry.getKey() <= toSequenceNr) {
169                             PersistentRepr persistentMessage =
170                                     new PersistentImpl(deserialize(entry.getValue()), entry.getKey(), persistenceId,
171                                             false, null, null);
172                             replayCallback.apply(persistentMessage);
173                         }
174                     }
175                 }
176
177                 return null;
178             }
179         }, context().dispatcher());
180     }
181
182     @Override
183     public Future<Long> doAsyncReadHighestSequenceNr(String persistenceId, long fromSequenceNr) {
184         // Akka calls this during recovery.
185         Map<Long, Object> journal = journals.get(persistenceId);
186         if(journal == null) {
187             return Futures.successful(fromSequenceNr);
188         }
189
190         synchronized (journal) {
191             long highest = -1;
192             for (Long seqNr : journal.keySet()) {
193                 if(seqNr.longValue() >= fromSequenceNr && seqNr.longValue() > highest) {
194                     highest = seqNr.longValue();
195                 }
196             }
197
198             return Futures.successful(highest);
199         }
200     }
201
202     @Override
203     public Future<Void> doAsyncWriteMessages(final Iterable<PersistentRepr> messages) {
204         return Futures.future(new Callable<Void>() {
205             @Override
206             public Void call() throws Exception {
207                 for (PersistentRepr repr : messages) {
208                     LOG.trace("doAsyncWriteMessages: id: {}: seqNr: {}, payload: {}", repr.persistenceId(),
209                             repr.sequenceNr(), repr.payload());
210
211                     addEntry(repr.persistenceId(), repr.sequenceNr(), repr.payload());
212
213                     WriteMessagesComplete complete = writeMessagesComplete.get(repr.persistenceId());
214                     if(complete != null) {
215                         if(complete.ofType == null || complete.ofType.equals(repr.payload().getClass())) {
216                             complete.latch.countDown();
217                         }
218                     }
219                 }
220
221                 return null;
222             }
223         }, context().dispatcher());
224     }
225
226     @Override
227     public Future<Void> doAsyncWriteConfirmations(Iterable<PersistentConfirmation> confirmations) {
228         return Futures.successful(null);
229     }
230
231     @Override
232     public Future<Void> doAsyncDeleteMessages(Iterable<PersistentId> messageIds, boolean permanent) {
233         return Futures.successful(null);
234     }
235
236     @Override
237     public Future<Void> doAsyncDeleteMessagesTo(String persistenceId, long toSequenceNr, boolean permanent) {
238         LOG.trace("doAsyncDeleteMessagesTo: {}", toSequenceNr);
239         Map<Long, Object> journal = journals.get(persistenceId);
240         if(journal != null) {
241             synchronized (journal) {
242                 Iterator<Long> iter = journal.keySet().iterator();
243                 while(iter.hasNext()) {
244                     Long n = iter.next();
245                     if(n <= toSequenceNr) {
246                         iter.remove();
247                     }
248                 }
249             }
250         }
251
252         CountDownLatch latch = deleteMessagesCompleteLatches.get(persistenceId);
253         if(latch != null) {
254             latch.countDown();
255         }
256
257         return Futures.successful(null);
258     }
259 }