From dc44dea50598256d8dce70d783849418bf32f9be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Sun, 8 Dec 2019 01:23:36 +0100 Subject: [PATCH] Optimized IP matching --- filter_subdomains.py | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/filter_subdomains.py b/filter_subdomains.py index 309bb6c..601a031 100755 --- a/filter_subdomains.py +++ b/filter_subdomains.py @@ -15,14 +15,20 @@ import ipaddress # DomainRule = typing.Union[bool, typing.Dict[str, 'DomainRule']] DomainRule = typing.Union[bool, typing.Dict] +# IpRule = typing.Union[bool, typing.Dict[int, 'DomainRule']] +IpRule = typing.Union[bool, typing.Dict] RULES_DICT: DomainRule = dict() -RULES_IP: typing.Set[ipaddress.IPv4Network] = set() +RULES_IP_DICT: IpRule = dict() + + +def get_bits(address: ipaddress.IPv4Address) -> typing.Iterator[int]: + for char in address.packed: + for i in range(7, -1, -1): + yield (char >> i) & 0b1 def subdomain_matching(subdomain: str) -> bool: - if not RULES_DICT: - return False parts = subdomain.split('.') parts.reverse() dic = RULES_DICT @@ -36,12 +42,16 @@ def subdomain_matching(subdomain: str) -> bool: def ip_matching(ip_str: str) -> bool: - if not RULES_IP: - return False ip = ipaddress.ip_address(ip_str) - for net in RULES_IP: - if ip in net: - return True + dic = RULES_IP_DICT + i = 0 + for bit in get_bits(ip): + i += 1 + if isinstance(dic, bool) or bit not in dic: + break + dic = dic[bit] + if isinstance(dic, bool): + return dic return False @@ -78,9 +88,17 @@ def register_rule(subdomain: str) -> None: def register_rule_ip(network: str) -> None: net = ipaddress.ip_network(network) - RULES_IP.add(net) - # If RULES_IP start becoming bigger, - # we might implement a binary tree for performance + ip = net.network_address + dic = RULES_IP_DICT + last_bit = net.prefixlen - 1 + for b, bit in enumerate(get_bits(ip)): + if isinstance(dic, bool): + return + if b == last_bit: + dic[bit] = True + else: + dic.setdefault(bit, dict()) + dic = dic[bit] if __name__ == '__main__':