BUG-1119: optimize length and range checks in generated sources.
[yangtools.git] / code-generator / binding-java-api-generator / src / main / java / org / opendaylight / yangtools / sal / java / api / generator / BaseTemplate.xtend
index ababc45775c1109d26845bd472cd22adb951b411..672d4b648e94eea593dbd1d04a8d6415bcefa751 100644 (file)
@@ -15,7 +15,6 @@ import org.opendaylight.yangtools.binding.generator.util.Types
 import com.google.common.base.Splitter
 import org.opendaylight.yangtools.sal.binding.model.api.MethodSignature
 import com.google.common.collect.Range
-import java.util.ArrayList
 import java.util.List
 import org.opendaylight.yangtools.sal.binding.model.api.ConcreteType
 import org.opendaylight.yangtools.sal.binding.model.api.Restrictions
@@ -23,9 +22,11 @@ import org.opendaylight.yangtools.sal.binding.model.api.GeneratedTransferObject
 import java.util.Collection
 import java.util.Arrays
 import java.util.HashMap
+import com.google.common.collect.ImmutableList
+import java.math.BigInteger
+import java.math.BigDecimal
 
 abstract class BaseTemplate {
-
     protected val GeneratedType type;
     protected val Map<String, String> importMap;
     static val paragraphSplitter = Splitter.on("\n\n").omitEmptyStrings();
@@ -170,93 +171,79 @@ abstract class BaseTemplate {
     }
 
     def generateRestrictions(Type type, String paramName, Type returnType) '''
-        «val boolean isArray = returnType.name.contains("[")»
-        «processRestrictions(type, paramName, returnType, isArray)»
-    '''
-
-    def generateRestrictions(GeneratedProperty field, String paramName) '''
-        «val Type type = field.returnType»
-        «IF type instanceof ConcreteType»
-            «processRestrictions(type, paramName, field.returnType, type.name.contains("["))»
-        «ELSEIF type instanceof GeneratedTransferObject»
-            «processRestrictions(type, paramName, field.returnType, isArrayType(type as GeneratedTransferObject))»
-        «ENDIF»
-    '''
-
-
-    private def processRestrictions(Type type, String paramName, Type returnType, boolean isArray) '''
         «val restrictions = type.getRestrictions»
         «IF restrictions !== null»
+            «val boolean isNestedType = !(returnType instanceof ConcreteType)»
             «IF !restrictions.lengthConstraints.empty»
-                «generateLengthRestriction(type, restrictions, paramName, isArray,
-            !(returnType instanceof ConcreteType))»
+                «generateLengthRestriction(returnType, restrictions, paramName, isNestedType)»
             «ENDIF»
-            «IF !restrictions.rangeConstraints.empty &&
-            ("java.lang".equals(returnType.packageName) || "java.math".equals(returnType.packageName))»
-                «generateRangeRestriction(type, returnType, restrictions, paramName,
-            !(returnType instanceof ConcreteType))»
+            «IF !restrictions.rangeConstraints.empty»
+                «generateRangeRestriction(returnType, restrictions, paramName, isNestedType)»
             «ENDIF»
         «ENDIF»
     '''
 
-    def generateLengthRestriction(Type type, Restrictions restrictions, String paramName, boolean isArray,
-        boolean isNestedType) '''
+    def private generateLengthRestriction(Type returnType, Restrictions restrictions, String paramName, boolean isNestedType) '''
+        «val clazz = restrictions.lengthConstraints.iterator.next.min.class»
         if («paramName» != null) {
+            «printLengthConstraint(returnType, clazz, paramName, isNestedType, returnType.name.contains("["))»
             boolean isValidLength = false;
-            «List.importedName»<«Range.importedName»<«Integer.importedName»>> lengthConstraints = new «ArrayList.
-            importedName»<>(); 
-            «FOR r : restrictions.lengthConstraints»
-                lengthConstraints.add(«Range.importedName».closed(«r.min», «r.max»));
-            «ENDFOR»
-            for («Range.importedName»<«Integer.importedName»> r : lengthConstraints) {
-                «IF isArray»
-                    «IF isNestedType»
-                        if (r.contains(«paramName».getValue().length)) {
-                    «ELSE»
-                        if (r.contains(«paramName».length)) {
-                    «ENDIF»
-                «ELSE»
-                    «IF isNestedType»
-                        if (r.contains(«paramName».getValue().length())) {
-                    «ELSE»
-                        if (r.contains(«paramName».length())) {
-                    «ENDIF»
-                «ENDIF»
-                isValidLength = true;
+            for («Range.importedName»<«clazz.importedNumber»> r : «IF isNestedType»«returnType.importedName».«ENDIF»length()) {
+                if (r.contains(_constraint)) {
+                    isValidLength = true;
                 }
             }
             if (!isValidLength) {
-                throw new IllegalArgumentException(String.format("Invalid length: %s, expected: %s.", «paramName», lengthConstraints));
+                throw new IllegalArgumentException(String.format("Invalid length: %s, expected: %s.", «paramName», «IF isNestedType»«returnType.importedName».«ENDIF»length()));
             }
         }
     '''
 
-    def generateRangeRestriction(Type type, Type returnType, Restrictions restrictions, String paramName,
-        boolean isNestedType) '''
-        «val javaType = Class.forName(returnType.fullyQualifiedName)»
+    def private generateRangeRestriction(Type returnType, Restrictions restrictions, String paramName, boolean isNestedType) '''
+        «val clazz = restrictions.rangeConstraints.iterator.next.min.class»
         if («paramName» != null) {
+            «printRangeConstraint(returnType, clazz, paramName, isNestedType)»
             boolean isValidRange = false;
-            «List.importedName»<«Range.importedName»<«javaType.importedName»>> rangeConstraints = new «ArrayList.
-            importedName»<>(); 
-            «FOR r : restrictions.rangeConstraints»
-                rangeConstraints.add(«Range.importedName».closed(new «javaType.importedName»(«r.min.toQuote»), new «javaType.
-            importedName»(«r.max.toQuote»)));
-            «ENDFOR»
-            for («Range.importedName»<«javaType.importedName»> r : rangeConstraints) {
-                «IF isNestedType»
-                    if (r.contains(«paramName».getValue())) {
-                «ELSE»
-                    if (r.contains(«paramName»)) {
-                «ENDIF»
-                isValidRange = true;
+            for («Range.importedName»<«clazz.importedNumber»> r : «IF isNestedType»«returnType.importedName».«ENDIF»range()) {
+                if (r.contains(_constraint)) {
+                    isValidRange = true;
                 }
             }
             if (!isValidRange) {
-                throw new IllegalArgumentException(String.format("Invalid range: %s, expected: %s.", «paramName», rangeConstraints));
+                throw new IllegalArgumentException(String.format("Invalid range: %s, expected: %s.", «paramName», «IF isNestedType»«returnType.importedName».«ENDIF»range()));
             }
         }
     '''
 
+    /**
+     * Print length constraint.
+     * This should always be a BigInteger (only string and binary can have length restriction)
+     */
+    def printLengthConstraint(Type returnType, Class<? extends Number> clazz, String paramName, boolean isNestedType, boolean isArray) '''
+        «clazz.importedNumber» _constraint = «clazz.importedNumber».valueOf(«paramName»«IF isNestedType».getValue()«ENDIF».length«IF !isArray»()«ENDIF»);
+    '''
+
+    def printRangeConstraint(Type returnType, Class<? extends Number> clazz, String paramName, boolean isNestedType) '''
+        «IF clazz.canonicalName.equals(BigDecimal.canonicalName)»
+            «clazz.importedNumber» _constraint = new «clazz.importedNumber»(«paramName»«IF isNestedType».getValue()«ENDIF».toString());
+        «ELSE»
+            «IF isNestedType»
+                «val propReturnType = findProperty(returnType as GeneratedTransferObject, "value").returnType»
+                «IF propReturnType.fullyQualifiedName.equals(BigInteger.canonicalName)»
+                    «clazz.importedNumber» _constraint = «paramName».getValue();
+                «ELSE»
+                    «clazz.importedNumber» _constraint = «clazz.importedNumber».valueOf(«paramName».getValue());
+                «ENDIF»
+            «ELSE»
+                «IF returnType.fullyQualifiedName.equals(BigInteger.canonicalName)»
+                    «clazz.importedNumber» _constraint = «paramName»;
+                «ELSE»
+                    «clazz.importedNumber» _constraint = «clazz.importedNumber».valueOf(«paramName»);
+                «ENDIF»
+            «ENDIF»
+        «ENDIF»
+    '''
+
     def protected generateToString(Collection<GeneratedProperty> properties) '''
         «IF !properties.empty»
             @Override
@@ -305,24 +292,13 @@ abstract class BaseTemplate {
 
     def boolean isArrayType(GeneratedTransferObject type) {
         var isArray = false
-        val GeneratedTransferObject superType = type.findSuperType
-        val GeneratedProperty value = superType.getPropByName("value")
+        val GeneratedProperty value = findProperty(type, "value")
         if (value != null && value.returnType.name.contains("[")) {
             isArray = true
         }
         return isArray
     }
 
-    def GeneratedTransferObject findSuperType(GeneratedTransferObject gto) {
-        var GeneratedTransferObject base = gto
-        var GeneratedTransferObject superType = base.superType
-        while (superType !== null) {
-            base = superType
-            superType = base.superType
-        }
-        return base;
-    }
-
     def String toQuote(Object obj) {
         return "\"" + obj.toString + "\"";
     }
@@ -342,4 +318,111 @@ abstract class BaseTemplate {
         ENDIF
     »'''
 
+    def protected generateLengthMethod(String methodName, Type type, String className, String varName) '''
+        «val Restrictions restrictions = type.restrictions»
+        «IF restrictions != null && !(restrictions.lengthConstraints.empty)»
+            «val numberClass = restrictions.lengthConstraints.iterator.next.min.class»
+            public static «List.importedName»<«Range.importedName»<«numberClass.importedNumber»>> «methodName»() {
+                «IF numberClass.equals(typeof(BigDecimal))»
+                    «lengthMethodBody(restrictions, numberClass, className, varName)»
+                «ELSE»
+                    «lengthMethodBody(restrictions, typeof(BigInteger), className, varName)»
+                «ENDIF»
+            }
+        «ENDIF»
+    '''
+
+    def private lengthMethodBody(Restrictions restrictions, Class<? extends Number> numberClass, String className, String varName) '''
+        if («varName» == null) {
+            synchronized («className».class) {
+                if («varName» == null) {
+                    «ImmutableList.importedName».Builder<«Range.importedName»<«numberClass.importedName»>> builder = «ImmutableList.importedName».builder();
+                    «FOR r : restrictions.lengthConstraints»
+                        builder.add(«Range.importedName».closed(«numericValue(numberClass, r.min)», «numericValue(numberClass, r.max)»));
+                    «ENDFOR»
+                    «varName» = builder.build();
+                }
+            }
+        }
+        return «varName»;
+    '''
+
+    def protected generateRangeMethod(String methodName, Type type, String className, String varName) '''
+        «val Restrictions restrictions = type.restrictions»
+        «IF restrictions != null && !(restrictions.rangeConstraints.empty)»
+            «val numberClass = restrictions.rangeConstraints.iterator.next.min.class»
+            public static «List.importedName»<«Range.importedName»<«numberClass.importedNumber»>> «methodName»() {
+                «IF numberClass.equals(typeof(BigDecimal))»
+                    «rangeMethodBody(restrictions, numberClass, className, varName)»
+                «ELSE»
+                    «rangeMethodBody(restrictions, typeof(BigInteger), className, varName)»
+                «ENDIF»
+            }
+        «ENDIF»
+    '''
+
+    def private rangeMethodBody(Restrictions restrictions, Class<? extends Number> numberClass, String className, String varName) '''
+        if («varName» == null) {
+            synchronized («className».class) {
+                if («varName» == null) {
+                    «ImmutableList.importedName».Builder<«Range.importedName»<«numberClass.importedName»>> builder = «ImmutableList.importedName».builder();
+                    «FOR r : restrictions.rangeConstraints»
+                        builder.add(«Range.importedName».closed(«numericValue(numberClass, r.min)», «numericValue(numberClass, r.max)»));
+                    «ENDFOR»
+                    «varName» = builder.build();
+                }
+            }
+        }
+        return «varName»;
+    '''
+
+    def protected String importedNumber(Class<? extends Number> clazz) {
+        if (clazz.equals(typeof(BigDecimal))) {
+            return BigDecimal.importedName
+        }
+        return BigInteger.importedName
+    }
+
+    def private String numericValue(Class<? extends Number> clazz, Object numberValue) {
+        val number = clazz.importedName;
+        val value = numberValue.toString
+        if (clazz.equals(typeof(BigInteger)) || clazz.equals(typeof(BigDecimal))) {
+            if (value.equals("0")) {
+                return number + ".ZERO"
+            } else if (value.equals("1")) {
+                return number + ".ONE"
+            } else if (value.equals("10")) {
+                return number + ".TEN"
+            } else {
+                try {
+                    val Long longVal = Long.valueOf(value)
+                    return number + ".valueOf(" + longVal + "L)"
+                } catch (NumberFormatException e) {
+                    if (clazz.equals(typeof(BigDecimal))) {
+                        try {
+                            val Double doubleVal = Double.valueOf(value);
+                            return number + ".valueOf(" + doubleVal + ")"
+                        } catch (NumberFormatException e2) {
+                        }
+                    }
+                }
+            }
+        }
+        return "new " + number + "(\"" + value + "\")"
+    }
+
+    def private GeneratedProperty findProperty(GeneratedTransferObject gto, String name) {
+        val props = gto.properties
+        for (prop : props) {
+            if (prop.name.equals(name)) {
+                return prop
+            }
+        }
+        val GeneratedTransferObject parent = gto.superType
+        if (parent != null) {
+            return findProperty(parent, name)
+        }
+        return null
+    }
+
 }