Added level
Also fixed the IP logic, which was badly broken.
This commit is contained in:
		
							parent
							
								
									3197fa1663
								
							
						
					
					
						commit
						03a4042238
					
				
					 3 changed files with 167 additions and 67 deletions
				
			
		
							
								
								
									
										213
									
								
								database.py
									
										
									
									
									
								
							
							
						
						
									
										213
									
								
								database.py
									
										
									
									
									
								
							|  | @ -26,38 +26,50 @@ class Path(): | |||
| 
 | ||||
| 
 | ||||
| class RulePath(Path): | ||||
|     pass | ||||
|     def __str__(self) -> str: | ||||
|         return '(rules)' | ||||
| 
 | ||||
| 
 | ||||
| class DomainPath(Path): | ||||
|     def __init__(self, path: typing.List[str]): | ||||
|         self.path = path | ||||
|     def __init__(self, parts: typing.List[str]): | ||||
|         self.parts = parts | ||||
| 
 | ||||
|     def __str__(self) -> str: | ||||
|         return '?.' + Database.unpack_domain(self) | ||||
| 
 | ||||
| 
 | ||||
| class HostnamePath(DomainPath): | ||||
|     pass | ||||
|     def __str__(self) -> str: | ||||
|         return Database.unpack_domain(self) | ||||
| 
 | ||||
| 
 | ||||
| class ZonePath(DomainPath): | ||||
|     pass | ||||
|     def __str__(self) -> str: | ||||
|         return '*.' + Database.unpack_domain(self) | ||||
| 
 | ||||
| 
 | ||||
| class AsnPath(Path): | ||||
|     def __init__(self, asn: Asn): | ||||
|         self.asn = asn | ||||
| 
 | ||||
|     def __str__(self) -> str: | ||||
|         return Database.unpack_asn(self) | ||||
| 
 | ||||
| 
 | ||||
| class Ip4Path(Path): | ||||
|     def __init__(self, value: int, prefixlen: int): | ||||
|         self.value = value | ||||
|         self.prefixlen = prefixlen | ||||
| 
 | ||||
|     def __str__(self) -> str: | ||||
|         return Database.unpack_ip4network(self) | ||||
| 
 | ||||
| 
 | ||||
| class Match(): | ||||
|     def __init__(self) -> None: | ||||
|         self.updated: int = 0 | ||||
|         self.level: int = 0 | ||||
|         self.source: Path = RulePath() | ||||
|         self.source: typing.Optional[Path] = None | ||||
|         # FP duplicate args | ||||
| 
 | ||||
|     def set(self, | ||||
|  | @ -86,16 +98,16 @@ class DomainTreeNode(): | |||
|         self.match_hostname = Match() | ||||
| 
 | ||||
| 
 | ||||
| class IpTreeNode(): | ||||
| class IpTreeNode(Match): | ||||
|     def __init__(self) -> None: | ||||
|         Match.__init__(self) | ||||
|         self.zero: typing.Optional[IpTreeNode] = None | ||||
|         self.one: typing.Optional[IpTreeNode] = None | ||||
|         self.match = Match() | ||||
| 
 | ||||
| 
 | ||||
| Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode] | ||||
| NodeCallable = typing.Callable[[Path, | ||||
|                                 Node, | ||||
| MatchCallable = typing.Callable[[Path, | ||||
|                                  Match, | ||||
|                                  typing.Optional[typing.Any]], | ||||
|                                 typing.Any] | ||||
| 
 | ||||
|  | @ -109,7 +121,6 @@ class Profiler(): | |||
|         self.step_dict: typing.Dict[str, int] = dict() | ||||
| 
 | ||||
|     def enter_step(self, name: str) -> None: | ||||
|         return | ||||
|         now = time.perf_counter() | ||||
|         try: | ||||
|             self.time_dict[self.time_step] += now - self.time_last | ||||
|  | @ -132,7 +143,7 @@ class Profiler(): | |||
| 
 | ||||
| 
 | ||||
| class Database(Profiler): | ||||
|     VERSION = 11 | ||||
|     VERSION = 13 | ||||
|     PATH = "blocking.p" | ||||
| 
 | ||||
|     def initialize(self) -> None: | ||||
|  | @ -181,7 +192,7 @@ class Database(Profiler): | |||
| 
 | ||||
|     @staticmethod | ||||
|     def unpack_domain(domain: DomainPath) -> str: | ||||
|         return '.'.join(domain.path[::-1]) | ||||
|         return '.'.join(domain.parts[::-1]) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def pack_asn(asn: str) -> AsnPath: | ||||
|  | @ -230,62 +241,107 @@ class Database(Profiler): | |||
|             addr >>= 8 | ||||
|         return '.'.join(map(str, octets)) + '/' + str(network.prefixlen) | ||||
| 
 | ||||
|     def get_match(self, path: Path) -> Match: | ||||
|         if isinstance(path, RulePath): | ||||
|             return Match() | ||||
|         elif isinstance(path, AsnPath): | ||||
|             return self.asns[path.asn] | ||||
|         elif isinstance(path, DomainPath): | ||||
|             dicd = self.domtree | ||||
|             for part in path.parts: | ||||
|                 dicd = dicd.children[part] | ||||
|             if isinstance(path, HostnamePath): | ||||
|                 return dicd.match_hostname | ||||
|             elif isinstance(path, ZonePath): | ||||
|                 return dicd.match_zone | ||||
|             else: | ||||
|                 raise ValueError | ||||
|         elif isinstance(path, Ip4Path): | ||||
|             dici = self.ip4tree | ||||
|             for i in range(31, 31-path.prefixlen, -1): | ||||
|                 bit = (path.value >> i) & 0b1 | ||||
|                 dici_next = dici.one if bit else dici.zero | ||||
|                 if not dici_next: | ||||
|                     raise IndexError | ||||
|                 dici = dici_next | ||||
|             return dici | ||||
|         else: | ||||
|             raise ValueError | ||||
| 
 | ||||
|     def exec_each_domain(self, | ||||
|                          callback: NodeCallable, | ||||
|                          callback: MatchCallable, | ||||
|                          arg: typing.Any = None, | ||||
|                          _dic: DomainTreeNode = None, | ||||
|                          _par: DomainPath = None, | ||||
|                          ) -> typing.Any: | ||||
|         _dic = _dic or self.domtree | ||||
|         _par = _par or DomainPath([]) | ||||
|         yield from callback(_par, _dic, arg) | ||||
|         if _dic.match_hostname.active(): | ||||
|             yield from callback( | ||||
|                 HostnamePath(_par.parts), | ||||
|                 _dic.match_hostname, | ||||
|                 arg | ||||
|             ) | ||||
|         if _dic.match_zone.active(): | ||||
|             yield from callback( | ||||
|                 ZonePath(_par.parts), | ||||
|                 _dic.match_zone, | ||||
|                 arg | ||||
|             ) | ||||
|         for part in _dic.children: | ||||
|             dic = _dic.children[part] | ||||
|             yield from self.exec_each_domain( | ||||
|                 callback, | ||||
|                 arg, | ||||
|                 _dic=dic, | ||||
|                 _par=DomainPath(_par.path + [part]) | ||||
|                 _par=DomainPath(_par.parts + [part]) | ||||
|             ) | ||||
| 
 | ||||
|     def exec_each_ip4(self, | ||||
|                       callback: NodeCallable, | ||||
|                       callback: MatchCallable, | ||||
|                       arg: typing.Any = None, | ||||
|                       _dic: IpTreeNode = None, | ||||
|                       _par: Ip4Path = None, | ||||
|                       ) -> typing.Any: | ||||
|         _dic = _dic or self.ip4tree | ||||
|         _par = _par or Ip4Path(0, 0) | ||||
|         callback(_par, _dic, arg) | ||||
|         if _dic.active(): | ||||
|             yield from callback( | ||||
|                 _par, | ||||
|                 _dic, | ||||
|                 arg | ||||
|             ) | ||||
| 
 | ||||
|         # 0 | ||||
|         pref = _par.prefixlen + 1 | ||||
|         dic = _dic.zero | ||||
|         if dic: | ||||
|             addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen))) | ||||
|             addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref))) | ||||
|             assert addr0 == _par.value | ||||
|             yield from self.exec_each_ip4( | ||||
|                 callback, | ||||
|                 arg, | ||||
|                 _dic=dic, | ||||
|                 _par=Ip4Path(addr0, _par.prefixlen+1) | ||||
|                 _par=Ip4Path(addr0, pref) | ||||
|             ) | ||||
|         # 1 | ||||
|         dic = _dic.one | ||||
|         if dic: | ||||
|             addr1 = _par.value | (1 << (32-_par.prefixlen)) | ||||
|             addr1 = _par.value | (1 << (32-pref)) | ||||
|             yield from self.exec_each_ip4( | ||||
|                 callback, | ||||
|                 arg, | ||||
|                 _dic=dic, | ||||
|                 _par=Ip4Path(addr1, _par.prefixlen+1) | ||||
|                 _par=Ip4Path(addr1, pref) | ||||
|             ) | ||||
| 
 | ||||
|     def exec_each(self, | ||||
|                   callback: NodeCallable, | ||||
|                   callback: MatchCallable, | ||||
|                   arg: typing.Any = None, | ||||
|                   ) -> typing.Any: | ||||
|         yield from self.exec_each_domain(callback) | ||||
|         yield from self.exec_each_ip4(callback) | ||||
|         # TODO ASN | ||||
| 
 | ||||
|     def update_references(self) -> None: | ||||
|         raise NotImplementedError | ||||
|  | @ -293,27 +349,47 @@ class Database(Profiler): | |||
|     def prune(self, before: int, base_only: bool = False) -> None: | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def explain(self, entry: int) -> str: | ||||
|         raise NotImplementedError | ||||
|     def explain(self, path: Path) -> str: | ||||
|         string = str(path) | ||||
|         match = self.get_match(path) | ||||
|         if match.source: | ||||
|             string += f' ← {self.explain(match.source)}' | ||||
|         return string | ||||
| 
 | ||||
|     def export(self, | ||||
|                first_party_only: bool = False, | ||||
|                end_chain_only: bool = False, | ||||
|                explain: bool = False, | ||||
|                ) -> typing.Iterable[str]: | ||||
|         if first_party_only or end_chain_only or explain: | ||||
|         if first_party_only or end_chain_only: | ||||
|             raise NotImplementedError | ||||
| 
 | ||||
|         def export_cb(path: Path, node: Node, _: typing.Any | ||||
|         def export_cb(path: Path, match: Match, _: typing.Any | ||||
|                       ) -> typing.Iterable[str]: | ||||
|             assert isinstance(path, DomainPath) | ||||
|             assert isinstance(node, DomainTreeNode) | ||||
|             if node.match_hostname: | ||||
|                 a = self.unpack_domain(path) | ||||
|                 yield a | ||||
|             if isinstance(path, HostnamePath): | ||||
|                 if explain: | ||||
|                     yield self.explain(path) | ||||
|                 else: | ||||
|                     yield self.unpack_domain(path) | ||||
| 
 | ||||
|         yield from self.exec_each_domain(export_cb, None) | ||||
| 
 | ||||
|     def list_rules(self, | ||||
|                    first_party_only: bool = False, | ||||
|                    ) -> typing.Iterable[str]: | ||||
|         if first_party_only: | ||||
|             raise NotImplementedError | ||||
| 
 | ||||
|         def list_rules_cb(path: Path, match: Match, _: typing.Any | ||||
|                           ) -> typing.Iterable[str]: | ||||
|             if isinstance(path, ZonePath) \ | ||||
|                     or (isinstance(path, Ip4Path) and path.prefixlen < 32): | ||||
|                 # if match.level == 0: | ||||
|                 yield self.explain(path) | ||||
| 
 | ||||
|         yield from self.exec_each(list_rules_cb, None) | ||||
| 
 | ||||
|     def count_rules(self, | ||||
|                     first_party_only: bool = False, | ||||
|                     ) -> str: | ||||
|  | @ -325,10 +401,10 @@ class Database(Profiler): | |||
|         self.enter_step('get_domain_brws') | ||||
|         dic = self.domtree | ||||
|         depth = 0 | ||||
|         for part in domain.path: | ||||
|         for part in domain.parts: | ||||
|             if dic.match_zone.active(): | ||||
|                 self.enter_step('get_domain_yield') | ||||
|                 yield ZonePath(domain.path[:depth]) | ||||
|                 yield ZonePath(domain.parts[:depth]) | ||||
|             self.enter_step('get_domain_brws') | ||||
|             if part not in dic.children: | ||||
|                 return | ||||
|  | @ -336,27 +412,28 @@ class Database(Profiler): | |||
|             depth += 1 | ||||
|         if dic.match_zone.active(): | ||||
|             self.enter_step('get_domain_yield') | ||||
|             yield ZonePath(domain.path) | ||||
|             yield ZonePath(domain.parts) | ||||
|         if dic.match_hostname.active(): | ||||
|             self.enter_step('get_domain_yield') | ||||
|             yield HostnamePath(domain.path) | ||||
|             yield HostnamePath(domain.parts) | ||||
| 
 | ||||
|     def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]: | ||||
|         self.enter_step('get_ip4_pack') | ||||
|         ip4 = self.pack_ip4address(ip4_str) | ||||
|         self.enter_step('get_ip4_brws') | ||||
|         dic = self.ip4tree | ||||
|         for i in reversed(range(ip4.prefixlen)): | ||||
|             part = (ip4.value >> i) & 0b1 | ||||
|             if dic.match.active(): | ||||
|         for i in range(31, 31-ip4.prefixlen, -1): | ||||
|             bit = (ip4.value >> i) & 0b1 | ||||
|             if dic.active(): | ||||
|                 self.enter_step('get_ip4_yield') | ||||
|                 yield Ip4Path(ip4.value, 32-i) | ||||
|                 a = Ip4Path(ip4.value >> (i+1) << (i+1), 31-i) | ||||
|                 yield a | ||||
|                 self.enter_step('get_ip4_brws') | ||||
|             next_dic = dic.one if part else dic.zero | ||||
|             next_dic = dic.one if bit else dic.zero | ||||
|             if next_dic is None: | ||||
|                 return | ||||
|             dic = next_dic | ||||
|         if dic.match.active(): | ||||
|         if dic.active(): | ||||
|             self.enter_step('get_ip4_yield') | ||||
|             yield ip4 | ||||
| 
 | ||||
|  | @ -374,9 +451,16 @@ class Database(Profiler): | |||
|         if is_first_party: | ||||
|             raise NotImplementedError | ||||
|         domain = self.pack_domain(domain_str) | ||||
|         self.enter_step('set_domain_src') | ||||
|         if source is None: | ||||
|             level = 0 | ||||
|             source = RulePath() | ||||
|         else: | ||||
|             match = self.get_match(source) | ||||
|             level = match.level + 1 | ||||
|         self.enter_step('set_domain_brws') | ||||
|         dic = self.domtree | ||||
|         for part in domain.path: | ||||
|         for part in domain.parts: | ||||
|             if dic.match_zone.active(): | ||||
|                 # Refuse to add domain whose zone is already matching | ||||
|                 return | ||||
|  | @ -389,8 +473,8 @@ class Database(Profiler): | |||
|             match = dic.match_zone | ||||
|         match.set( | ||||
|             updated, | ||||
|             0,  # TODO Level | ||||
|             source or RulePath(), | ||||
|             level, | ||||
|             source, | ||||
|         ) | ||||
| 
 | ||||
|     def set_hostname(self, | ||||
|  | @ -411,14 +495,23 @@ class Database(Profiler): | |||
|         self.enter_step('set_asn') | ||||
|         if is_first_party: | ||||
|             raise NotImplementedError | ||||
|         if source is None: | ||||
|             level = 0 | ||||
|             source = RulePath() | ||||
|         else: | ||||
|             match = self.get_match(source) | ||||
|             level = match.level + 1 | ||||
|         path = self.pack_asn(asn_str) | ||||
|         if path.asn in self.asns: | ||||
|             match = self.asns[path.asn] | ||||
|         else: | ||||
|             match = AsnNode() | ||||
|             self.asns[path.asn] = match | ||||
|         match.set( | ||||
|             updated, | ||||
|                 0, | ||||
|                 source or RulePath() | ||||
|             level, | ||||
|             source, | ||||
|         ) | ||||
|         self.asns[path.asn] = match | ||||
| 
 | ||||
|     def _set_ip4(self, | ||||
|                  ip4: Ip4Path, | ||||
|  | @ -427,24 +520,32 @@ class Database(Profiler): | |||
|                  source: Path = None) -> None: | ||||
|         if is_first_party: | ||||
|             raise NotImplementedError | ||||
|         self.enter_step('set_ip4_src') | ||||
|         if source is None: | ||||
|             level = 0 | ||||
|             source = RulePath() | ||||
|         else: | ||||
|             match = self.get_match(source) | ||||
|             level = match.level + 1 | ||||
|         self.enter_step('set_ip4_brws') | ||||
|         dic = self.ip4tree | ||||
|         for i in reversed(range(ip4.prefixlen)): | ||||
|             part = (ip4.value >> i) & 0b1 | ||||
|             if dic.match.active(): | ||||
|         for i in range(31, 31-ip4.prefixlen, -1): | ||||
|             bit = (ip4.value >> i) & 0b1 | ||||
|             if dic.active(): | ||||
|                 # Refuse to add ip4* whose network is already matching | ||||
|                 return | ||||
|             next_dic = dic.one if part else dic.zero | ||||
|             next_dic = dic.one if bit else dic.zero | ||||
|             if next_dic is None: | ||||
|                 next_dic = IpTreeNode() | ||||
|                 if part: | ||||
|                 if bit: | ||||
|                     dic.one = next_dic | ||||
|                 else: | ||||
|                     dic.zero = next_dic | ||||
|             dic = next_dic | ||||
|         dic.match.set( | ||||
|         dic.set( | ||||
|             updated, | ||||
|             0,  # TODO Level | ||||
|             source or RulePath(), | ||||
|             level, | ||||
|             source, | ||||
|         ) | ||||
| 
 | ||||
|     def set_ip4address(self, | ||||
|  | @ -453,7 +554,6 @@ class Database(Profiler): | |||
|                        ) -> None: | ||||
|         self.enter_step('set_ip4add_pack') | ||||
|         ip4 = self.pack_ip4address(ip4address_str) | ||||
|         self.enter_step('set_ip4add_brws') | ||||
|         self._set_ip4(ip4, *args, **kwargs) | ||||
| 
 | ||||
|     def set_ip4network(self, | ||||
|  | @ -462,5 +562,4 @@ class Database(Profiler): | |||
|                        ) -> None: | ||||
|         self.enter_step('set_ip4net_pack') | ||||
|         ip4 = self.pack_ip4network(ip4network_str) | ||||
|         self.enter_step('set_ip4net_brws') | ||||
|         self._set_ip4(ip4, *args, **kwargs) | ||||
|  |  | |||
|  | @ -33,9 +33,11 @@ if __name__ == '__main__': | |||
|     DB = database.Database() | ||||
| 
 | ||||
|     if args.rules: | ||||
|         if not args.count: | ||||
|             raise NotImplementedError | ||||
|         if args.count: | ||||
|             print(DB.count_rules(first_party_only=args.first_party)) | ||||
|         else: | ||||
|             for line in DB.list_rules(): | ||||
|                 print(line) | ||||
|     else: | ||||
|         if args.count: | ||||
|             raise NotImplementedError | ||||
|  |  | |||
|  | @ -51,8 +51,7 @@ class Writer(multiprocessing.Process): | |||
| 
 | ||||
|                 try: | ||||
|                     for source in select(self.db, value): | ||||
|                         # write(self.db, name, updated, source=source) | ||||
|                         write(self.db, name, updated) | ||||
|                         write(self.db, name, updated, source=source) | ||||
|                 except ValueError: | ||||
|                     self.log.exception("Cannot execute: %s", record) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue