Refactor LeaderTest
[controller.git] / opendaylight / md-sal / sal-akka-raft / src / test / java / org / opendaylight / controller / cluster / raft / utils / MessageCollectorActor.java
index 79c90cf051cc928ac50e563a136c62030809bbc2..448c28e8c8e9f090dacdbab184b7760f85c66893 100644 (file)
@@ -18,6 +18,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import org.junit.Assert;
 import scala.concurrent.Await;
 import scala.concurrent.Future;
 import scala.concurrent.duration.Duration;
@@ -63,29 +64,39 @@ public class MessageCollectorActor extends UntypedActor {
      * @return
      */
     public static <T> T getFirstMatching(ActorRef actor, Class<T> clazz) throws Exception {
-        for(int i = 0; i < 50; i++) {
-            List<Object> allMessages = getAllMessages(actor);
+        List<Object> allMessages = getAllMessages(actor);
 
-            for(Object message : allMessages){
-                if(message.getClass().equals(clazz)){
-                    return (T) message;
-                }
+        for(Object message : allMessages){
+            if(message.getClass().equals(clazz)){
+                return (T) message;
+            }
+        }
+
+        return null;
+    }
+
+    public static <T> T expectFirstMatching(ActorRef actor, Class<T> clazz) throws Exception {
+        for(int i = 0; i < 50; i++) {
+            T message = getFirstMatching(actor, clazz);
+            if(message != null) {
+                return message;
             }
 
             Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
         }
 
+        Assert.fail("Did not receive message of type " + clazz);
         return null;
     }
 
-    public static List<Object> getAllMatching(ActorRef actor, Class<?> clazz) throws Exception {
+    public static <T> List<T> getAllMatching(ActorRef actor, Class<T> clazz) throws Exception {
         List<Object> allMessages = getAllMessages(actor);
 
-        List<Object> output = Lists.newArrayList();
+        List<T> output = Lists.newArrayList();
 
         for(Object message : allMessages){
             if(message.getClass().equals(clazz)){
-                output.add(message);
+                output.add((T) message);
             }
         }