Updated code to match new rules
[integration/test.git] / csit / libraries / ipaddr.py
index 71a58b832664011ac4bf4ccf8f165b0d85a97b9c..8f0260750258c02fcd326c1741c14bde1df9708b 100644 (file)
@@ -22,10 +22,12 @@ and networks.
 
 """
 
-__version__ = '2.1.10'
-
 import struct
 
+
+__version__ = '2.1.11'
+
+
 IPV4LENGTH = 32
 IPV6LENGTH = 128
 
@@ -141,7 +143,7 @@ def v6_int_to_packed(address):
     """The binary representation of this address.
 
     Args:
-        address: An integer representation of an IPv4 IP address.
+        address: An integer representation of an IPv6 IP address.
 
     Returns:
         The binary representation of this address.
@@ -234,7 +236,7 @@ def summarize_address_range(first, last):
         raise TypeError('first and last must be IP addresses, not networks')
     if first.version != last.version:
         raise TypeError("%s and %s are not of the same version" % (
-            str(first), str(last)))
+                        str(first), str(last)))
     if first > last:
         raise ValueError('last IP address must be greater than first')
 
@@ -344,17 +346,17 @@ def collapse_address_list(addresses):
         if isinstance(ip, _BaseIP):
             if ips and ips[-1]._version != ip._version:
                 raise TypeError("%s and %s are not of the same version" % (
-                    str(ip), str(ips[-1])))
+                                str(ip), str(ips[-1])))
             ips.append(ip)
         elif ip._prefixlen == ip._max_prefixlen:
             if ips and ips[-1]._version != ip._version:
                 raise TypeError("%s and %s are not of the same version" % (
-                    str(ip), str(ips[-1])))
+                                str(ip), str(ips[-1])))
             ips.append(ip.ip)
         else:
             if nets and nets[-1]._version != ip._version:
                 raise TypeError("%s and %s are not of the same version" % (
-                    str(ip), str(ips[-1])))
+                                str(ip), str(nets[-1])))
             nets.append(ip)
 
     # sort and dedup
@@ -455,8 +457,8 @@ class _BaseIP(_IPAddrBase):
 
     def __eq__(self, other):
         try:
-            return (self._ip == other._ip
-                    and self._version == other._version)
+            return (self._ip == other._ip and
+                    self._version == other._version)
         except AttributeError:
             return NotImplemented
 
@@ -481,10 +483,10 @@ class _BaseIP(_IPAddrBase):
     def __lt__(self, other):
         if self._version != other._version:
             raise TypeError('%s and %s are not of the same version' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if not isinstance(other, _BaseIP):
             raise TypeError('%s and %s are not of the same type' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if self._ip != other._ip:
             return self._ip < other._ip
         return False
@@ -492,10 +494,10 @@ class _BaseIP(_IPAddrBase):
     def __gt__(self, other):
         if self._version != other._version:
             raise TypeError('%s and %s are not of the same version' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if not isinstance(other, _BaseIP):
             raise TypeError('%s and %s are not of the same type' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if self._ip != other._ip:
             return self._ip > other._ip
         return False
@@ -580,10 +582,10 @@ class _BaseNet(_IPAddrBase):
     def __lt__(self, other):
         if self._version != other._version:
             raise TypeError('%s and %s are not of the same version' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if not isinstance(other, _BaseNet):
             raise TypeError('%s and %s are not of the same type' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if self.network != other.network:
             return self.network < other.network
         if self.netmask != other.netmask:
@@ -593,10 +595,10 @@ class _BaseNet(_IPAddrBase):
     def __gt__(self, other):
         if self._version != other._version:
             raise TypeError('%s and %s are not of the same version' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if not isinstance(other, _BaseNet):
             raise TypeError('%s and %s are not of the same type' % (
-                str(self), str(other)))
+                            str(self), str(other)))
         if self.network != other.network:
             return self.network > other.network
         if self.netmask != other.netmask:
@@ -617,13 +619,13 @@ class _BaseNet(_IPAddrBase):
 
     def __eq__(self, other):
         try:
-            return (self._version == other._version
-                    and self.network == other.network
-                    and int(self.netmask) == int(other.netmask))
+            return (self._version == other._version and
+                    self.network == other.network and
+                    int(self.netmask) == int(other.netmask))
         except AttributeError:
             if isinstance(other, _BaseIP):
-                return (self._version == other._version
-                        and self._ip == other._ip)
+                return (self._version == other._version and
+                        self._ip == other._ip)
 
     def __ne__(self, other):
         eq = self.__eq__(other)
@@ -769,18 +771,18 @@ class _BaseNet(_IPAddrBase):
                 s1, s2 = s2.subnet()
             else:
                 # If we got here, there's a bug somewhere.
-                assert True is False, ('Error performing exclusion: '
-                                       's1: %s s2: %s other: %s' %
-                                       (str(s1), str(s2), str(other)))
+                assert False, ('Error performing exclusion: '
+                               's1: %s s2: %s other: %s' %
+                               (str(s1), str(s2), str(other)))
         if s1 == other:
             ret_addrs.append(s2)
         elif s2 == other:
             ret_addrs.append(s1)
         else:
             # If we got here, there's a bug somewhere.
-            assert True is False, ('Error performing exclusion: '
-                                   's1: %s s2: %s other: %s' %
-                                   (str(s1), str(s2), str(other)))
+            assert False, ('Error performing exclusion: '
+                           's1: %s s2: %s other: %s' %
+                           (str(s1), str(s2), str(other)))
 
         return sorted(ret_addrs, key=_BaseNet._get_networks_key)
 
@@ -845,8 +847,8 @@ class _BaseNet(_IPAddrBase):
         """
         return (self._version, self.network, self.netmask)
 
-    def _ip_int_from_prefix(self, prefixlen=None):
-        """Turn the prefix length netmask into a int for comparison.
+    def _ip_int_from_prefix(self, prefixlen):
+        """Turn the prefix length into a bitwise netmask.
 
         Args:
             prefixlen: An integer, the prefix length.
@@ -855,42 +857,90 @@ class _BaseNet(_IPAddrBase):
             An integer.
 
         """
-        if not prefixlen and prefixlen != 0:
-            prefixlen = self._prefixlen
         return self._ALL_ONES ^ (self._ALL_ONES >> prefixlen)
 
-    def _prefix_from_ip_int(self, ip_int, mask=32):
-        """Return prefix length from the decimal netmask.
+    def _prefix_from_ip_int(self, ip_int):
+        """Return prefix length from a bitwise netmask.
 
         Args:
-            ip_int: An integer, the IP address.
-            mask: The netmask.  Defaults to 32.
+            ip_int: An integer, the netmask in expanded bitwise format.
 
         Returns:
             An integer, the prefix length.
 
+        Raises:
+            NetmaskValueError: If the input is not a valid netmask.
+
         """
-        while mask:
-            if ip_int & 1 == 1:
+        prefixlen = self._max_prefixlen
+        while prefixlen:
+            if ip_int & 1:
                 break
             ip_int >>= 1
-            mask -= 1
+            prefixlen -= 1
 
-        return mask
+        if ip_int == (1 << prefixlen) - 1:
+            return prefixlen
+        else:
+            raise NetmaskValueError('Bit pattern does not match /1*0*/')
 
-    def _ip_string_from_prefix(self, prefixlen=None):
-        """Turn a prefix length into a dotted decimal string.
+    def _prefix_from_prefix_string(self, prefixlen_str):
+        """Turn a prefix length string into an integer.
 
         Args:
-            prefixlen: An integer, the netmask prefix length.
+            prefixlen_str: A decimal string containing the prefix length.
 
         Returns:
-            A string, the dotted decimal netmask string.
+            The prefix length as an integer.
+
+        Raises:
+            NetmaskValueError: If the input is malformed or out of range.
 
         """
-        if not prefixlen:
-            prefixlen = self._prefixlen
-        return self._string_from_ip_int(self._ip_int_from_prefix(prefixlen))
+        try:
+            if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str):
+                raise ValueError
+            prefixlen = int(prefixlen_str)
+            if not (0 <= prefixlen <= self._max_prefixlen):
+                raise ValueError
+        except ValueError:
+            raise NetmaskValueError('%s is not a valid prefix length' %
+                                    prefixlen_str)
+        return prefixlen
+
+    def _prefix_from_ip_string(self, ip_str):
+        """Turn a netmask/hostmask string into a prefix length.
+
+        Args:
+            ip_str: A netmask or hostmask, formatted as an IP address.
+
+        Returns:
+            The prefix length as an integer.
+
+        Raises:
+            NetmaskValueError: If the input is not a netmask or hostmask.
+
+        """
+        # Parse the netmask/hostmask like an IP address.
+        try:
+            ip_int = self._ip_int_from_string(ip_str)
+        except AddressValueError:
+            raise NetmaskValueError('%s is not a valid netmask' % ip_str)
+
+        # Try matching a netmask (this would be /1*0*/ as a bitwise regexp).
+        # Note that the two ambiguous cases (all-ones and all-zeroes) are
+        # treated as netmasks.
+        try:
+            return self._prefix_from_ip_int(ip_int)
+        except NetmaskValueError:
+            pass
+
+        # Invert the bits, and try matching a /0+1+/ hostmask instead.
+        ip_int ^= self._ALL_ONES
+        try:
+            return self._prefix_from_ip_int(ip_int)
+        except NetmaskValueError:
+            raise NetmaskValueError('%s is not a valid netmask' % ip_str)
 
     def iter_subnets(self, prefixlen_diff=1, new_prefix=None):
         """The subnets which join to make the current subnet.
@@ -933,7 +983,7 @@ class _BaseNet(_IPAddrBase):
             raise ValueError('prefix length diff must be > 0')
         new_prefixlen = self._prefixlen + prefixlen_diff
 
-        if not self._is_valid_netmask(str(new_prefixlen)):
+        if new_prefixlen > self._max_prefixlen:
             raise ValueError(
                 'prefix length diff %d is invalid for netblock %s' % (
                     new_prefixlen, str(self)))
@@ -1232,9 +1282,6 @@ class IPv4Network(_BaseV4, _BaseNet):
 
     """
 
-    # the valid octets for host and netmasks. only useful for IPv4.
-    _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
-
     def __init__(self, address, strict=False):
         """Instantiate a new IPv4 network object.
 
@@ -1297,30 +1344,18 @@ class IPv4Network(_BaseV4, _BaseNet):
         self.ip = IPv4Address(self._ip)
 
         if len(addr) == 2:
-            mask = addr[1].split('.')
-            if len(mask) == 4:
-                # We have dotted decimal netmask.
-                if self._is_valid_netmask(addr[1]):
-                    self.netmask = IPv4Address(self._ip_int_from_string(
-                        addr[1]))
-                elif self._is_hostmask(addr[1]):
-                    self.netmask = IPv4Address(
-                        self._ip_int_from_string(addr[1]) ^ self._ALL_ONES)
-                else:
-                    raise NetmaskValueError('%s is not a valid netmask' % addr[1])
-
-                self._prefixlen = self._prefix_from_ip_int(int(self.netmask))
-            else:
-                # We have a netmask in prefix length form.
-                if not self._is_valid_netmask(addr[1]):
-                    raise NetmaskValueError(addr[1])
-                self._prefixlen = int(addr[1])
-                self.netmask = IPv4Address(self._ip_int_from_prefix(
-                    self._prefixlen))
+            try:
+                # Check for a netmask in prefix length form.
+                self._prefixlen = self._prefix_from_prefix_string(addr[1])
+            except NetmaskValueError:
+                # Check for a netmask or hostmask in dotted-quad form.
+                # This may raise NetmaskValueError.
+                self._prefixlen = self._prefix_from_ip_string(addr[1])
         else:
             self._prefixlen = self._max_prefixlen
-            self.netmask = IPv4Address(self._ip_int_from_prefix(
-                self._prefixlen))
+
+        self.netmask = IPv4Address(self._ip_int_from_prefix(self._prefixlen))
+
         if strict:
             if self.ip != self.network:
                 raise ValueError('%s has host bits set' %
@@ -1328,57 +1363,18 @@ class IPv4Network(_BaseV4, _BaseNet):
         if self._prefixlen == (self._max_prefixlen - 1):
             self.iterhosts = self.__iter__
 
-    def _is_hostmask(self, ip_str):
-        """Test if the IP string is a hostmask (rather than a netmask).
+    # backwards compatibility
+    def IsRFC1918(self):
+        return self.is_private
 
-        Args:
-            ip_str: A string, the potential hostmask.
+    def IsMulticast(self):
+        return self.is_multicast
 
-        Returns:
-            A boolean, True if the IP string is a hostmask.
+    def IsLoopback(self):
+        return self.is_loopback
 
-        """
-        bits = ip_str.split('.')
-        try:
-            parts = [int(x) for x in bits if int(x) in self._valid_mask_octets]
-        except ValueError:
-            return False
-        if len(parts) != len(bits):
-            return False
-        if parts[0] < parts[-1]:
-            return True
-        return False
-
-    def _is_valid_netmask(self, netmask):
-        """Verify that the netmask is valid.
-
-        Args:
-            netmask: A string, either a prefix or dotted decimal
-              netmask.
-
-        Returns:
-            A boolean, True if the prefix represents a valid IPv4
-            netmask.
-
-        """
-        mask = netmask.split('.')
-        if len(mask) == 4:
-            if [x for x in mask if int(x) not in self._valid_mask_octets]:
-                return False
-            if [y for idx, y in enumerate(mask) if idx > 0 and y > mask[idx - 1]]:
-                return False
-            return True
-        try:
-            netmask = int(netmask)
-        except ValueError:
-            return False
-        return 0 <= netmask <= self._max_prefixlen
-
-    # backwards compatibility
-    IsRFC1918 = lambda self: self.is_private
-    IsMulticast = lambda self: self.is_multicast
-    IsLoopback = lambda self: self.is_loopback
-    IsLinkLocal = lambda self: self.is_link_local
+    def IsLinkLocal(self):
+        return self.is_link_local
 
 
 class _BaseV6(object):
@@ -1493,6 +1489,8 @@ class _BaseV6(object):
         # Whitelist the characters, since int() allows a lot of bizarre stuff.
         if not self._HEX_DIGITS.issuperset(hextet_str):
             raise ValueError
+        if len(hextet_str) > 4:
+            raise ValueError
         hextet_int = int(hextet_str, 16)
         if hextet_int > 0xFFFF:
             raise ValueError
@@ -1794,6 +1792,7 @@ class IPv6Address(_BaseV6, _BaseIP):
 
 
 class IPv6Network(_BaseV6, _BaseNet):
+
     """This class represents and manipulates 128-bit IPv6 networks.
 
     Attributes: [examples for IPv6('2001:658:22A:CAFE:200::1/64')]
@@ -1860,10 +1859,8 @@ class IPv6Network(_BaseV6, _BaseNet):
         self.ip = IPv6Address(self._ip)
 
         if len(addr) == 2:
-            if self._is_valid_netmask(addr[1]):
-                self._prefixlen = int(addr[1])
-            else:
-                raise NetmaskValueError(addr[1])
+            # This may raise NetmaskValueError
+            self._prefixlen = self._prefix_from_prefix_string(addr[1])
         else:
             self._prefixlen = self._max_prefixlen
 
@@ -1876,23 +1873,6 @@ class IPv6Network(_BaseV6, _BaseNet):
         if self._prefixlen == (self._max_prefixlen - 1):
             self.iterhosts = self.__iter__
 
-    def _is_valid_netmask(self, prefixlen):
-        """Verify that the netmask/prefixlen is valid.
-
-        Args:
-            prefixlen: A string, the netmask in prefix length format.
-
-        Returns:
-            A boolean, True if the prefix represents a valid IPv6
-            netmask.
-
-        """
-        try:
-            prefixlen = int(prefixlen)
-        except ValueError:
-            return False
-        return 0 <= prefixlen <= self._max_prefixlen
-
     @property
     def with_netmask(self):
         return self.with_prefixlen