diff --git a/database.py b/database.py index 2d970e3..12569d3 100644 --- a/database.py +++ b/database.py @@ -20,7 +20,7 @@ PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6') RulePath = typing.Union[None] Asn = int DomainPath = typing.List[str] -Ip4Path = typing.List[int] +Ip4Path = typing.Tuple[int, int] # value, prefixlen Ip6Path = typing.List[int] Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path] TypedPath = typing.Tuple[PathType, Path] @@ -139,33 +139,33 @@ class Database(Profiler): @staticmethod def pack_ip4address(address: str) -> Ip4Path: - addr: Ip4Path = [0] * 32 - octets = [int(octet) for octet in address.split('.')] - for b in range(32): - if (octets[b//8] >> b % 8) & 0b1: - addr[b] = 1 - return addr + addr = 0 + for split in address.split('.'): + addr = addr << 4 + int(split) + return (addr, 32) @staticmethod def unpack_ip4address(address: Ip4Path) -> str: + addr, prefixlen = address + assert prefixlen == 32 + octets: typing.List[int] = list() octets = [0] * 4 - for b, bit in enumerate(address): - octets[b//8] = (octets[b//8] << 1) + bit + for o in reversed(range(4)): + octets[o] = addr & 0xFF + addr >>= 8 return '.'.join(map(str, octets)) @staticmethod def pack_ip4network(network: str) -> Ip4Path: address, prefixlen_str = network.split('/') prefixlen = int(prefixlen_str) - return Database.pack_ip4address(address)[:prefixlen] + addr, _ = Database.pack_ip4address(address) + return (addr, prefixlen) @staticmethod def unpack_ip4network(network: Ip4Path) -> str: - address = network.copy() - prefixlen = len(network) - for _ in range(32-prefixlen): - address.append(0) - addr = Database.unpack_ip4address(address) + address, prefixlen = network + addr = Database.unpack_ip4address((address, 32)) return f'{addr}/{prefixlen}' def update_references(self) -> None: @@ -224,20 +224,19 @@ class Database(Profiler): def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]: self.enter_step('get_ip4_pack') - ip4 = self.pack_ip4address(ip4_str) + ip4, prefixlen = self.pack_ip4address(ip4_str) self.enter_step('get_ip4_brws') dic = self.ip4tree - depth = 0 - for part in ip4: + for i in reversed(range(prefixlen)): + part = (ip4 >> i) & 0b1 if dic.match: self.enter_step('get_ip4_yield') - yield (PathType.Ip4, ip4[:depth]) + yield (PathType.Ip4, (ip4, 32-i)) self.enter_step('get_ip4_brws') next_dic = dic.children[part] if next_dic is None: return dic = next_dic - depth += 1 if dic.match: self.enter_step('get_ip4_yield') yield (PathType.Ip4, ip4) @@ -307,10 +306,11 @@ class Database(Profiler): self.enter_step('set_ip4add_pack') if is_first_party or source: raise NotImplementedError + ip4, prefixlen = self.pack_ip4address(ip4address_str) self.enter_step('set_ip4add_brws') - ip4address = self.pack_ip4address(ip4address_str) dic = self.ip4tree - for part in ip4address: + for i in reversed(range(prefixlen)): + part = (ip4 >> i) & 0b1 if dic.match: # Refuse to add ip4address whose network is already matching return @@ -330,9 +330,10 @@ class Database(Profiler): if is_first_party or source: raise NotImplementedError self.enter_step('set_ip4net_brws') - ip4network = self.pack_ip4network(ip4network_str) + ip4, prefixlen = self.pack_ip4network(ip4network_str) dic = self.ip4tree - for part in ip4network: + for i in reversed(range(prefixlen)): + part = (ip4 >> i) & 0b1 if dic.match: # Refuse to add ip4network whose parent network # is already matching