Migrate more ThreadLocals
[aaa.git] / aaa-shiro / impl / src / main / java / org / opendaylight / aaa / shiro / realm / TokenAuthRealm.java
index 4a6bcd57f9da7085eae6395014bca8302fe57913..bb8048dd9446844b9f65de014ad7f08e8d992372 100644 (file)
@@ -7,6 +7,7 @@
  */
 package org.opendaylight.aaa.shiro.realm;
 
+import static com.google.common.base.Verify.verifyNotNull;
 import static java.util.Objects.requireNonNull;
 
 import com.google.common.base.Strings;
@@ -20,6 +21,7 @@ import org.apache.shiro.authz.AuthorizationInfo;
 import org.apache.shiro.authz.SimpleAuthorizationInfo;
 import org.apache.shiro.realm.AuthorizingRealm;
 import org.apache.shiro.subject.PrincipalCollection;
+import org.eclipse.jdt.annotation.Nullable;
 import org.opendaylight.aaa.api.Authentication;
 import org.opendaylight.aaa.api.AuthenticationService;
 import org.opendaylight.aaa.api.TokenAuth;
@@ -28,8 +30,8 @@ import org.opendaylight.aaa.api.shiro.principal.ODLPrincipal;
 import org.opendaylight.aaa.shiro.principal.ODLPrincipalImpl;
 import org.opendaylight.aaa.shiro.realm.util.TokenUtils;
 import org.opendaylight.aaa.shiro.realm.util.http.header.HeaderUtils;
-import org.opendaylight.aaa.shiro.web.env.ThreadLocals;
 import org.opendaylight.aaa.tokenauthrealm.auth.TokenAuthenticators;
+import org.opendaylight.yangtools.concepts.Registration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -39,18 +41,46 @@ import org.slf4j.LoggerFactory;
  */
 public class TokenAuthRealm extends AuthorizingRealm {
     private static final Logger LOG = LoggerFactory.getLogger(TokenAuthRealm.class);
+    private static final ThreadLocal<TokenAuthenticators> AUTHENICATORS_TL = new ThreadLocal<>();
+    private static final ThreadLocal<AuthenticationService> AUTH_SERVICE_TL = new ThreadLocal<>();
+    private static final ThreadLocal<TokenStore> TOKEN_STORE_TL = new ThreadLocal<>();
 
-    private final AuthenticationService authenticationService;
+    private final TokenAuthenticators authenticators;
+    private final AuthenticationService authService;
     private final TokenStore tokenStore;
-    private final TokenAuthenticators tokenAuthenticators;
 
     public TokenAuthRealm() {
-        authenticationService = requireNonNull(ThreadLocals.AUTH_SETVICE_TL.get());
-        tokenStore = ThreadLocals.TOKEN_STORE_TL.get();
-        tokenAuthenticators = requireNonNull(ThreadLocals.TOKEN_AUTHENICATORS_TL.get());
+        this(verifyLoad(AUTH_SERVICE_TL), verifyLoad(AUTHENICATORS_TL), TOKEN_STORE_TL.get());
+    }
+
+    public TokenAuthRealm(final AuthenticationService authService, final TokenAuthenticators authenticators) {
+        this(authService, authenticators, null);
+    }
+
+    public TokenAuthRealm(final AuthenticationService authService, final TokenAuthenticators authenticators,
+            final @Nullable TokenStore tokenStore) {
+        this.authService = requireNonNull(authService);
+        this.authenticators = requireNonNull(authenticators);
+        this.tokenStore = tokenStore;
         super.setName("TokenAuthRealm");
     }
 
+    public static Registration prepareForLoad(final AuthenticationService authService,
+            final TokenAuthenticators authenticators, final @Nullable TokenStore tokenStore) {
+        AUTH_SERVICE_TL.set(requireNonNull(authService));
+        AUTHENICATORS_TL.set(requireNonNull(authenticators));
+        TOKEN_STORE_TL.set(tokenStore);
+        return () -> {
+            AUTH_SERVICE_TL.remove();
+            AUTHENICATORS_TL.remove();
+            TOKEN_STORE_TL.remove();
+        };
+    }
+
+    private static <T> T verifyLoad(final ThreadLocal<T> threadLocal) {
+        return verifyNotNull(threadLocal.get(), "TokenAuthRealm loading not prepared");
+    }
+
     /**
      * {@inheritDoc}
      *
@@ -101,13 +131,13 @@ public class TokenAuthRealm extends AuthorizingRealm {
             // iterate over <code>TokenAuth</code> implementations and
             // attempt to
             // authentication with each one
-            for (TokenAuth ta : tokenAuthenticators.getTokenAuthCollection()) {
+            for (TokenAuth ta : authenticators.getTokenAuthCollection()) {
                 try {
                     LOG.debug("Authentication attempt using {}", ta.getClass().getName());
                     final Authentication auth = ta.validate(headers);
                     if (auth != null) {
                         LOG.debug("Authentication attempt successful");
-                        authenticationService.set(auth);
+                        authService.set(auth);
                         final ODLPrincipal odlPrincipal = ODLPrincipalImpl.createODLPrincipal(auth);
                         return new SimpleAuthenticationInfo(odlPrincipal, password.toCharArray(), getName());
                     }
@@ -142,7 +172,7 @@ public class TokenAuthRealm extends AuthorizingRealm {
         if (auth == null) {
             throw new AuthenticationException("Could not validate the token " + token);
         }
-        authenticationService.set(auth);
+        authService.set(auth);
         return auth;
     }
 }