Map identities to proper objects
[mdsal.git] / binding / mdsal-binding-runtime-api / src / main / java / org / opendaylight / mdsal / binding / runtime / api / AbstractBindingRuntimeContext.java
index dbed269c4af0d895127351403becc4af7433da88..3fed0a6a7677d9d5fd5840252153555fbbb1676a 100644 (file)
@@ -8,14 +8,19 @@
 package org.opendaylight.mdsal.binding.runtime.api;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
 
 import com.google.common.annotations.Beta;
+import com.google.common.base.Throwables;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
+import java.util.concurrent.ExecutionException;
+import org.eclipse.jdt.annotation.NonNull;
 import org.opendaylight.mdsal.binding.model.api.JavaTypeName;
 import org.opendaylight.yangtools.yang.binding.Action;
 import org.opendaylight.yangtools.yang.binding.Augmentation;
+import org.opendaylight.yangtools.yang.binding.BaseIdentity;
 import org.opendaylight.yangtools.yang.binding.Notification;
 import org.opendaylight.yangtools.yang.binding.RpcInput;
 import org.opendaylight.yangtools.yang.binding.RpcOutput;
@@ -28,16 +33,18 @@ import org.opendaylight.yangtools.yang.model.api.stmt.SchemaNodeIdentifier.Absol
  */
 @Beta
 public abstract class AbstractBindingRuntimeContext implements BindingRuntimeContext {
-    private final LoadingCache<QName, Class<?>> identityClasses = CacheBuilder.newBuilder().weakValues().build(
-        new CacheLoader<QName, Class<?>>() {
+    private final LoadingCache<@NonNull QName, @NonNull Class<? extends BaseIdentity>> identityClasses =
+        CacheBuilder.newBuilder().weakValues().build(new CacheLoader<>() {
             @Override
-            public Class<?> load(final QName key) {
+            public Class<? extends BaseIdentity> load(final QName key) {
                 final var type = getTypes().findIdentity(key).orElseThrow(
                     () -> new IllegalArgumentException("Supplied QName " + key + " is not a valid identity"));
                 try {
-                    return loadClass(type.getIdentifier());
-                } catch (final ClassNotFoundException e) {
+                    return loadClass(type.getIdentifier()).asSubclass(BaseIdentity.class);
+                } catch (ClassNotFoundException e) {
                     throw new IllegalArgumentException("Required class " + type + " was not found.", e);
+                } catch (ClassCastException e) {
+                    throw new IllegalArgumentException(key + " resolves to a non-identity class", e);
                 }
             }
         });
@@ -79,8 +86,13 @@ public abstract class AbstractBindingRuntimeContext implements BindingRuntimeCon
     }
 
     @Override
-    public final Class<?> getIdentityClass(final QName input) {
-        return identityClasses.getUnchecked(input);
+    public final Class<? extends BaseIdentity> getIdentityClass(final QName input) {
+        try {
+            return identityClasses.get(requireNonNull(input));
+        } catch (ExecutionException e) {
+            Throwables.throwIfUnchecked(e.getCause());
+            throw new IllegalStateException("Unexpected error looking up " + input, e);
+        }
     }
 
     @Override