+ public void testWithNoCohorts() throws Exception {
+ ThreePhaseCommitCohortProxy proxy = new ThreePhaseCommitCohortProxy(actorContext,
+ Collections.<CohortInfo>emptyList(), tx);
+
+ verifyCanCommit(proxy.canCommit(), true);
+ verifySuccessfulFuture(proxy.preCommit());
+ verifySuccessfulFuture(proxy.commit());
+ verifyCohortActors();
+ }
+
+ @SuppressWarnings("checkstyle:avoidHidingCauseException")
+ private void propagateExecutionExceptionCause(final ListenableFuture<?> future) throws Exception {
+ try {
+ future.get(5, TimeUnit.SECONDS);
+ fail("Expected ExecutionException");
+ } catch (ExecutionException e) {
+ verifyCohortActors();
+ Throwables.propagateIfPossible(e.getCause(), Exception.class);
+ throw new RuntimeException(e.getCause());
+ }
+ }
+
+ private CohortInfo newCohortInfo(final CohortActor.Builder builder, final short version) {
+ TestActorRef<CohortActor> actor = actorFactory.createTestActor(builder.props()
+ .withDispatcher(Dispatchers.DefaultDispatcherId()), actorFactory.generateActorId("cohort"));
+ cohortActors.add(actor);
+ return new CohortInfo(Futures.successful(getSystem().actorSelection(actor.path())), () -> version);
+ }
+
+ private CohortInfo newCohortInfo(final CohortActor.Builder builder) {
+ return newCohortInfo(builder, CURRENT_VERSION);
+ }
+
+ private static CohortInfo newCohortInfoWithFailedFuture(final Exception failure) {
+ return new CohortInfo(Futures.<ActorSelection>failed(failure), () -> CURRENT_VERSION);
+ }
+
+ private void verifyCohortActors() {
+ for (TestActorRef<CohortActor> actor: cohortActors) {
+ actor.underlyingActor().verify();
+ }
+ }
+
+ @SuppressWarnings("checkstyle:IllegalCatch")
+ private <T> T verifySuccessfulFuture(final ListenableFuture<T> future) throws Exception {
+ try {
+ return future.get(5, TimeUnit.SECONDS);
+ } catch (Exception e) {
+ verifyCohortActors();
+ throw e;
+ }
+ }
+
+ private void verifyCanCommit(final ListenableFuture<Boolean> future, final boolean expected) throws Exception {
+ Boolean actual = verifySuccessfulFuture(future);
+ assertEquals("canCommit", expected, actual);
+ }
+
+ private static class CohortActor extends UntypedActor {
+ private final Builder builder;
+ private final AtomicInteger canCommitCount = new AtomicInteger();
+ private final AtomicInteger commitCount = new AtomicInteger();
+ private final AtomicInteger abortCount = new AtomicInteger();
+ private volatile AssertionError assertionError;
+
+ CohortActor(final Builder builder) {
+ this.builder = builder;
+ }
+
+ @Override
+ public void onReceive(final Object message) {
+ if (CanCommitTransaction.isSerializedType(message)) {
+ canCommitCount.incrementAndGet();
+ onMessage("CanCommitTransaction", message, CanCommitTransaction.fromSerializable(message),
+ builder.expCanCommitType, builder.canCommitReply);
+ } else if (CommitTransaction.isSerializedType(message)) {
+ commitCount.incrementAndGet();
+ onMessage("CommitTransaction", message, CommitTransaction.fromSerializable(message),
+ builder.expCommitType, builder.commitReply);
+ } else if (AbortTransaction.isSerializedType(message)) {
+ abortCount.incrementAndGet();
+ onMessage("AbortTransaction", message, AbortTransaction.fromSerializable(message),
+ builder.expAbortType, builder.abortReply);
+ } else {
+ assertionError = new AssertionError("Unexpected message " + message);
+ }
+ }
+
+ private void onMessage(final String name, final Object rawMessage,
+ final AbstractThreePhaseCommitMessage actualMessage, final Class<?> expType, final Object reply) {
+ try {
+ assertNotNull("Unexpected " + name, expType);
+ assertEquals(name + " type", expType, rawMessage.getClass());
+ assertEquals(name + " transactionId", builder.transactionId, actualMessage.getTransactionId());
+
+ if (reply instanceof Throwable) {
+ getSender().tell(new akka.actor.Status.Failure((Throwable)reply), self());
+ } else {
+ getSender().tell(reply, self());
+ }
+ } catch (AssertionError e) {
+ assertionError = e;
+ }
+ }
+
+ void verify() {
+ if (assertionError != null) {
+ throw assertionError;
+ }
+
+ if (builder.expCanCommitType != null) {
+ assertEquals("CanCommitTransaction count", 1, canCommitCount.get());
+ }
+
+ if (builder.expCommitType != null) {
+ assertEquals("CommitTransaction count", 1, commitCount.get());
+ }
+
+ if (builder.expAbortType != null) {
+ assertEquals("AbortTransaction count", 1, abortCount.get());
+ }
+ }
+
+ static class Builder {
+ private Class<?> expCanCommitType;
+ private Class<?> expCommitType;
+ private Class<?> expAbortType;
+ private Object canCommitReply;
+ private Object commitReply;
+ private Object abortReply;
+ private final TransactionIdentifier transactionId;
+
+ Builder(final TransactionIdentifier transactionId) {
+ this.transactionId = Preconditions.checkNotNull(transactionId);
+ }
+
+ Builder expectCanCommit(final Class<?> newExpCanCommitType, final Object newCanCommitReply) {
+ this.expCanCommitType = newExpCanCommitType;
+ this.canCommitReply = newCanCommitReply;
+ return this;
+ }
+
+ Builder expectCanCommit(final Object newCanCommitReply) {
+ return expectCanCommit(CanCommitTransaction.class, newCanCommitReply);
+ }
+
+ Builder expectCommit(final Class<?> newExpCommitType, final Object newCommitReply) {
+ this.expCommitType = newExpCommitType;
+ this.commitReply = newCommitReply;
+ return this;
+ }
+
+ Builder expectCommit(final Object newCommitReply) {
+ return expectCommit(CommitTransaction.class, newCommitReply);
+ }
+
+ Builder expectAbort(final Class<?> newExpAbortType, final Object newAbortReply) {
+ this.expAbortType = newExpAbortType;
+ this.abortReply = newAbortReply;
+ return this;
+ }
+
+ Builder expectAbort(final Object newAbortReply) {
+ return expectAbort(AbortTransaction.class, newAbortReply);
+ }
+
+ Props props() {
+ return Props.create(CohortActor.class, this);
+ }
+ }