Guard against null augmentations
[mdsal.git] / binding / mdsal-binding-dom-codec / src / main / java / org / opendaylight / mdsal / binding / dom / codec / impl / UnionTypeCodec.java
index 6d35f58986f5a68e5e5418cc8cd88bd77a9897a8..2e27ad294514e18cbe19c5eab9da9b5c9e310c1d 100644 (file)
@@ -8,54 +8,63 @@
 package org.opendaylight.mdsal.binding.dom.codec.impl;
 
 import static com.google.common.base.Verify.verify;
+import static com.google.common.base.Verify.verifyNotNull;
+import static java.util.Objects.requireNonNull;
 
 import com.google.common.collect.ImmutableSet;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
-import java.util.concurrent.Callable;
-import org.opendaylight.mdsal.binding.model.api.GeneratedType;
+import org.opendaylight.mdsal.binding.model.api.GeneratedTransferObject;
+import org.opendaylight.mdsal.binding.model.api.Type;
 import org.opendaylight.mdsal.binding.runtime.api.RuntimeGeneratedUnion;
-import org.opendaylight.mdsal.binding.spec.naming.BindingMapping;
-import org.opendaylight.yangtools.concepts.IllegalArgumentCodec;
+import org.opendaylight.yangtools.yang.binding.contract.Naming;
 import org.opendaylight.yangtools.yang.model.api.TypeDefinition;
 import org.opendaylight.yangtools.yang.model.api.type.UnionTypeDefinition;
 
-final class UnionTypeCodec extends ReflectionBasedCodec {
+final class UnionTypeCodec implements ValueCodec<Object, Object> {
     private final ImmutableSet<UnionValueOptionContext> typeCodecs;
+    private final Class<?> unionClass;
 
-    private UnionTypeCodec(final Class<?> unionCls,final List<UnionValueOptionContext> codecs) {
-        super(unionCls);
-        typeCodecs = ImmutableSet.copyOf(codecs);
+    private UnionTypeCodec(final Class<?> unionClass, final List<UnionValueOptionContext> typeCodecs) {
+        this.unionClass = requireNonNull(unionClass);
+        // Squashes duplicates
+        this.typeCodecs = ImmutableSet.copyOf(typeCodecs);
     }
 
-    static Callable<UnionTypeCodec> loader(final Class<?> unionCls, final UnionTypeDefinition unionType,
-            final BindingCodecContext bindingCodecContext) {
-        return () -> {
-            final GeneratedType contextType = bindingCodecContext.getRuntimeContext().getTypeWithSchema(unionCls)
-                .getKey();
-            verify(contextType instanceof RuntimeGeneratedUnion, "Unexpected runtime type %s", contextType);
-            final RuntimeGeneratedUnion contextUnion = (RuntimeGeneratedUnion) contextType;
+    static UnionTypeCodec of(final Class<?> unionCls, final UnionTypeDefinition unionType,
+            final BindingCodecContext codecContext) throws Exception {
+        final List<String> unionProperties = extractUnionProperties(codecContext.getRuntimeContext()
+            .getTypeWithSchema(unionCls).javaType());
+        final List<TypeDefinition<?>> unionTypes = unionType.getTypes();
+        verify(unionTypes.size() == unionProperties.size(), "Mismatched union types %s and properties %s",
+            unionTypes, unionProperties);
 
-            final List<TypeDefinition<?>> unionTypes = unionType.getTypes();
-            final List<String> unionProperties = contextUnion.typePropertyNames();
-            verify(unionTypes.size() == unionProperties.size(), "Mismatched union types %s and properties %s",
-                unionTypes, unionProperties);
+        final List<UnionValueOptionContext> values = new ArrayList<>(unionTypes.size());
+        final Iterator<String> it = unionProperties.iterator();
+        for (final TypeDefinition<?> subtype : unionTypes) {
+            final String getterName = Naming.GETTER_PREFIX + Naming.toFirstUpper(it.next());
+            final Method valueGetter = unionCls.getMethod(getterName);
+            final Class<?> valueType = valueGetter.getReturnType();
+            final ValueCodec<Object, Object> codec = codecContext.getCodec(valueType, subtype);
 
-            final List<UnionValueOptionContext> values = new ArrayList<>(unionTypes.size());
-            final Iterator<String> it = unionProperties.iterator();
-            for (final TypeDefinition<?> subtype : unionType.getTypes()) {
-                final String getterName = BindingMapping.GETTER_PREFIX + BindingMapping.toFirstUpper(it.next());
-                final Method valueGetter = unionCls.getMethod(getterName);
-                final Class<?> valueType = valueGetter.getReturnType();
-                final IllegalArgumentCodec<Object, Object> codec = bindingCodecContext.getCodec(valueType, subtype);
+            values.add(new UnionValueOptionContext(unionCls, valueType, valueGetter, codec));
+        }
 
-                values.add(new UnionValueOptionContext(unionCls, valueType, valueGetter, codec));
-            }
+        return new UnionTypeCodec(unionCls, values);
+    }
 
-            return new UnionTypeCodec(unionCls, values);
-        };
+    private static List<String> extractUnionProperties(final Type type) {
+        verify(type instanceof GeneratedTransferObject, "Unexpected runtime type %s", type);
+
+        GeneratedTransferObject gto = (GeneratedTransferObject) type;
+        while (true) {
+            if (gto instanceof RuntimeGeneratedUnion) {
+                return ((RuntimeGeneratedUnion) gto).typePropertyNames();
+            }
+            gto = verifyNotNull(gto.getSuperType(), "Cannot find union type information for %s", type);
+        }
     }
 
     @Override
@@ -68,19 +77,17 @@ final class UnionTypeCodec extends ReflectionBasedCodec {
         }
 
         throw new IllegalArgumentException(String.format("Failed to construct instance of %s for input %s",
-            getTypeClass(), input));
+            unionClass, input));
     }
 
     @Override
     public Object serialize(final Object input) {
-        if (input != null) {
-            for (final UnionValueOptionContext valCtx : typeCodecs) {
-                final Object domValue = valCtx.serialize(input);
-                if (domValue != null) {
-                    return domValue;
-                }
+        for (final UnionValueOptionContext valCtx : typeCodecs) {
+            final Object domValue = valCtx.serialize(input);
+            if (domValue != null) {
+                return domValue;
             }
         }
-        return null;
+        throw new IllegalStateException("No codec matched value " + input);
     }
 }