import akka.persistence.snapshot.japi.SnapshotStore;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.Uninterruptibles;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.Future;
static final Logger LOG = LoggerFactory.getLogger(InMemorySnapshotStore.class);
private static Map<String, List<StoredSnapshot>> snapshots = new ConcurrentHashMap<>();
+ private static final Map<String, CountDownLatch> snapshotSavedLatches = new ConcurrentHashMap<>();
public static void addSnapshot(String persistentId, Object snapshot) {
List<StoredSnapshot> snapshotList = snapshots.get(persistentId);
snapshots.clear();
}
+ public static void addSnapshotSavedLatch(String persistenceId) {
+ snapshotSavedLatches.put(persistenceId, new CountDownLatch(1));
+ }
+
+ public static <T> T waitForSavedSnapshot(String persistenceId, Class<T> type) {
+ if(!Uninterruptibles.awaitUninterruptibly(snapshotSavedLatches.get(persistenceId), 5, TimeUnit.SECONDS)) {
+ throw new AssertionError("Snapshot was not saved");
+ }
+
+ return getSnapshots(persistenceId, type).get(0);
+ }
+
@Override
public Future<Option<SelectedSnapshot>> doLoadAsync(String s,
SnapshotSelectionCriteria snapshotSelectionCriteria) {
snapshotList.add(new StoredSnapshot(snapshotMetadata, o));
}
+ CountDownLatch latch = snapshotSavedLatches.get(snapshotMetadata.persistenceId());
+ if(latch != null) {
+ latch.countDown();
+ }
+
return Futures.successful(null);
}