Improve ClassLoaderUtils.loadClassWithTCCL()
[yangtools.git] / common / util / src / main / java / org / opendaylight / yangtools / util / ClassLoaderUtils.java
index 1a8161455c367b7ca08b7c1b1342b9a6b844cfe3..b5064080ad93757a026209fdabcb17005af71da9 100644 (file)
@@ -8,6 +8,7 @@
 package org.opendaylight.yangtools.util;
 
 import static com.google.common.base.Preconditions.checkNotNull;
+
 import com.google.common.base.Joiner;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Splitter;
@@ -30,32 +31,30 @@ public final class ClassLoaderUtils {
     }
 
     /**
-     *
      * Runs {@link Supplier} with provided {@link ClassLoader}.
      *
-     * Invokes supplies function and makes sure that original {@link ClassLoader}
+     * <p>Invokes supplies function and makes sure that original {@link ClassLoader}
      * is context {@link ClassLoader} after execution.
      *
      * @param cls {@link ClassLoader} to be used.
      * @param function Function to be executed.
      * @return Result of supplier invocation.
-     *
      */
     public static <V> V withClassLoader(final ClassLoader cls, final Supplier<V> function) {
         checkNotNull(cls, "Classloader should not be null");
         checkNotNull(function, "Function should not be null");
 
-        final ClassLoader oldCls = Thread.currentThread().getContextClassLoader();
+        final Thread currentThread = Thread.currentThread();
+        final ClassLoader oldCls = currentThread.getContextClassLoader();
         try {
-            Thread.currentThread().setContextClassLoader(cls);
+            currentThread.setContextClassLoader(cls);
             return function.get();
         } finally {
-            Thread.currentThread().setContextClassLoader(oldCls);
+            currentThread.setContextClassLoader(oldCls);
         }
     }
 
     /**
-     *
      * Runs {@link Callable} with provided {@link ClassLoader}.
      *
      * Invokes supplies function and makes sure that original {@link ClassLoader}
@@ -64,36 +63,31 @@ public final class ClassLoaderUtils {
      * @param cls {@link ClassLoader} to be used.
      * @param function Function to be executed.
      * @return Result of callable invocation.
-     *
      */
     public static <V> V withClassLoader(final ClassLoader cls, final Callable<V> function) throws Exception {
         checkNotNull(cls, "Classloader should not be null");
         checkNotNull(function, "Function should not be null");
 
-        final ClassLoader oldCls = Thread.currentThread().getContextClassLoader();
+        final Thread currentThread = Thread.currentThread();
+        final ClassLoader oldCls = currentThread.getContextClassLoader();
         try {
-            Thread.currentThread().setContextClassLoader(cls);
+            currentThread.setContextClassLoader(cls);
             return function.call();
         } finally {
-            Thread.currentThread().setContextClassLoader(oldCls);
+            currentThread.setContextClassLoader(oldCls);
         }
     }
 
-    public static Object construct(final Constructor<? extends Object> constructor, final List<Object> objects)
+    public static Object construct(final Constructor<?> constructor, final List<Object> objects)
             throws InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
         final Object[] initargs = objects.toArray();
         return constructor.newInstance(initargs);
     }
 
     /**
-     *
      * Loads class using this supplied classloader.
      *
-     *
-     * @param cls
      * @param name String name of class.
-     * @return
-     * @throws ClassNotFoundException
      */
     public static Class<?> loadClass(final ClassLoader cls, final String name) throws ClassNotFoundException {
         if ("byte[]".equals(name)) {
@@ -116,9 +110,9 @@ public final class ClassLoaderUtils {
                 final String outerName = Joiner.on(".").join(components.subList(0, length));
                 final String innerName = outerName + "$" + components.get(length);
                 return cls.loadClass(innerName);
-            } else {
-                throw e;
             }
+
+            throw e;
         }
     }
 
@@ -136,33 +130,45 @@ public final class ClassLoaderUtils {
     }
 
     public static Class<?> loadClassWithTCCL(final String name) throws ClassNotFoundException {
-        return loadClass(Thread.currentThread().getContextClassLoader(), name);
+        final Thread thread = Thread.currentThread();
+        final ClassLoader tccl = thread.getContextClassLoader();
+        if (tccl == null) {
+            throw new ClassNotFoundException("Thread " + thread + " does not have a Context Class Loader, cannot load "
+                    + name);
+        }
+        return loadClass(tccl, name);
     }
 
-    public static Class<?> tryToLoadClassWithTCCL(final String fullyQualifiedName) {
+    public static Class<?> tryToLoadClassWithTCCL(final String fullyQualifiedClassName) {
+        final Thread thread = Thread.currentThread();
+        final ClassLoader tccl = thread.getContextClassLoader();
+        if (tccl == null) {
+            LOG.debug("Thread {} does not have a Context Class Loader, not loading class {}", thread,
+                fullyQualifiedClassName);
+            return null;
+        }
+
         try {
-            return loadClassWithTCCL(fullyQualifiedName);
+            return loadClass(tccl, fullyQualifiedClassName);
         } catch (final ClassNotFoundException e) {
-            LOG.debug("Failed to load class {}", fullyQualifiedName, e);
+            LOG.debug("Failed to load class {}", fullyQualifiedClassName, e);
             return null;
         }
     }
 
     public static <S,G,P> Class<P> findFirstGenericArgument(final Class<S> scannedClass, final Class<G> genericType) {
-        return withClassLoader(scannedClass.getClassLoader(), ClassLoaderUtils.<S,G,P>findFirstGenericArgumentTask(scannedClass, genericType));
+        return withClassLoader(scannedClass.getClassLoader(), findFirstGenericArgumentTask(scannedClass, genericType));
     }
 
-    private static <S,G,P> Supplier<Class<P>> findFirstGenericArgumentTask(final Class<S> scannedClass, final Class<G> genericType) {
-        return new Supplier<Class<P>>() {
-            @Override
-            @SuppressWarnings("unchecked")
-            public Class<P> get() {
-                final ParameterizedType augmentationGeneric = findParameterizedType(scannedClass, genericType);
-                if (augmentationGeneric != null) {
-                    return (Class<P>) augmentationGeneric.getActualTypeArguments()[0];
-                }
-                return null;
+    @SuppressWarnings("unchecked")
+    private static <S, G, P> Supplier<Class<P>> findFirstGenericArgumentTask(final Class<S> scannedClass,
+            final Class<G> genericType) {
+        return () -> {
+            final ParameterizedType augmentationGeneric = findParameterizedType(scannedClass, genericType);
+            if (augmentationGeneric != null) {
+                return (Class<P>) augmentationGeneric.getActualTypeArguments()[0];
             }
+            return null;
         };
     }