Added level

Also fixed the IP logic, because it was really messed up
This commit is contained in:
Geoffrey Frogeye 2019-12-16 09:31:29 +01:00
parent 3197fa1663
commit 03a4042238
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
3 changed files with 167 additions and 67 deletions

View file

@ -26,38 +26,50 @@ class Path():
class RulePath(Path): class RulePath(Path):
pass def __str__(self) -> str:
return '(rules)'
class DomainPath(Path): class DomainPath(Path):
def __init__(self, path: typing.List[str]): def __init__(self, parts: typing.List[str]):
self.path = path self.parts = parts
def __str__(self) -> str:
return '?.' + Database.unpack_domain(self)
class HostnamePath(DomainPath): class HostnamePath(DomainPath):
pass def __str__(self) -> str:
return Database.unpack_domain(self)
class ZonePath(DomainPath): class ZonePath(DomainPath):
pass def __str__(self) -> str:
return '*.' + Database.unpack_domain(self)
class AsnPath(Path): class AsnPath(Path):
def __init__(self, asn: Asn): def __init__(self, asn: Asn):
self.asn = asn self.asn = asn
def __str__(self) -> str:
return Database.unpack_asn(self)
class Ip4Path(Path): class Ip4Path(Path):
def __init__(self, value: int, prefixlen: int): def __init__(self, value: int, prefixlen: int):
self.value = value self.value = value
self.prefixlen = prefixlen self.prefixlen = prefixlen
def __str__(self) -> str:
return Database.unpack_ip4network(self)
class Match(): class Match():
def __init__(self) -> None: def __init__(self) -> None:
self.updated: int = 0 self.updated: int = 0
self.level: int = 0 self.level: int = 0
self.source: Path = RulePath() self.source: typing.Optional[Path] = None
# FP dupplicate args # FP dupplicate args
def set(self, def set(self,
@ -86,18 +98,18 @@ class DomainTreeNode():
self.match_hostname = Match() self.match_hostname = Match()
class IpTreeNode(): class IpTreeNode(Match):
def __init__(self) -> None: def __init__(self) -> None:
Match.__init__(self)
self.zero: typing.Optional[IpTreeNode] = None self.zero: typing.Optional[IpTreeNode] = None
self.one: typing.Optional[IpTreeNode] = None self.one: typing.Optional[IpTreeNode] = None
self.match = Match()
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode] Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
NodeCallable = typing.Callable[[Path, MatchCallable = typing.Callable[[Path,
Node, Match,
typing.Optional[typing.Any]], typing.Optional[typing.Any]],
typing.Any] typing.Any]
class Profiler(): class Profiler():
@ -109,7 +121,6 @@ class Profiler():
self.step_dict: typing.Dict[str, int] = dict() self.step_dict: typing.Dict[str, int] = dict()
def enter_step(self, name: str) -> None: def enter_step(self, name: str) -> None:
return
now = time.perf_counter() now = time.perf_counter()
try: try:
self.time_dict[self.time_step] += now - self.time_last self.time_dict[self.time_step] += now - self.time_last
@ -132,7 +143,7 @@ class Profiler():
class Database(Profiler): class Database(Profiler):
VERSION = 11 VERSION = 13
PATH = "blocking.p" PATH = "blocking.p"
def initialize(self) -> None: def initialize(self) -> None:
@ -181,7 +192,7 @@ class Database(Profiler):
@staticmethod @staticmethod
def unpack_domain(domain: DomainPath) -> str: def unpack_domain(domain: DomainPath) -> str:
return '.'.join(domain.path[::-1]) return '.'.join(domain.parts[::-1])
@staticmethod @staticmethod
def pack_asn(asn: str) -> AsnPath: def pack_asn(asn: str) -> AsnPath:
@ -230,62 +241,107 @@ class Database(Profiler):
addr >>= 8 addr >>= 8
return '.'.join(map(str, octets)) + '/' + str(network.prefixlen) return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
def get_match(self, path: Path) -> Match:
if isinstance(path, RulePath):
return Match()
elif isinstance(path, AsnPath):
return self.asns[path.asn]
elif isinstance(path, DomainPath):
dicd = self.domtree
for part in path.parts:
dicd = dicd.children[part]
if isinstance(path, HostnamePath):
return dicd.match_hostname
elif isinstance(path, ZonePath):
return dicd.match_zone
else:
raise ValueError
elif isinstance(path, Ip4Path):
dici = self.ip4tree
for i in range(31, 31-path.prefixlen, -1):
bit = (path.value >> i) & 0b1
dici_next = dici.one if bit else dici.zero
if not dici_next:
raise IndexError
dici = dici_next
return dici
else:
raise ValueError
def exec_each_domain(self, def exec_each_domain(self,
callback: NodeCallable, callback: MatchCallable,
arg: typing.Any = None, arg: typing.Any = None,
_dic: DomainTreeNode = None, _dic: DomainTreeNode = None,
_par: DomainPath = None, _par: DomainPath = None,
) -> typing.Any: ) -> typing.Any:
_dic = _dic or self.domtree _dic = _dic or self.domtree
_par = _par or DomainPath([]) _par = _par or DomainPath([])
yield from callback(_par, _dic, arg) if _dic.match_hostname.active():
yield from callback(
HostnamePath(_par.parts),
_dic.match_hostname,
arg
)
if _dic.match_zone.active():
yield from callback(
ZonePath(_par.parts),
_dic.match_zone,
arg
)
for part in _dic.children: for part in _dic.children:
dic = _dic.children[part] dic = _dic.children[part]
yield from self.exec_each_domain( yield from self.exec_each_domain(
callback, callback,
arg, arg,
_dic=dic, _dic=dic,
_par=DomainPath(_par.path + [part]) _par=DomainPath(_par.parts + [part])
) )
def exec_each_ip4(self, def exec_each_ip4(self,
callback: NodeCallable, callback: MatchCallable,
arg: typing.Any = None, arg: typing.Any = None,
_dic: IpTreeNode = None, _dic: IpTreeNode = None,
_par: Ip4Path = None, _par: Ip4Path = None,
) -> typing.Any: ) -> typing.Any:
_dic = _dic or self.ip4tree _dic = _dic or self.ip4tree
_par = _par or Ip4Path(0, 0) _par = _par or Ip4Path(0, 0)
callback(_par, _dic, arg) if _dic.active():
yield from callback(
_par,
_dic,
arg
)
# 0 # 0
pref = _par.prefixlen + 1
dic = _dic.zero dic = _dic.zero
if dic: if dic:
addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen))) addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
assert addr0 == _par.value assert addr0 == _par.value
yield from self.exec_each_ip4( yield from self.exec_each_ip4(
callback, callback,
arg, arg,
_dic=dic, _dic=dic,
_par=Ip4Path(addr0, _par.prefixlen+1) _par=Ip4Path(addr0, pref)
) )
# 1 # 1
dic = _dic.one dic = _dic.one
if dic: if dic:
addr1 = _par.value | (1 << (32-_par.prefixlen)) addr1 = _par.value | (1 << (32-pref))
yield from self.exec_each_ip4( yield from self.exec_each_ip4(
callback, callback,
arg, arg,
_dic=dic, _dic=dic,
_par=Ip4Path(addr1, _par.prefixlen+1) _par=Ip4Path(addr1, pref)
) )
def exec_each(self, def exec_each(self,
callback: NodeCallable, callback: MatchCallable,
arg: typing.Any = None, arg: typing.Any = None,
) -> typing.Any: ) -> typing.Any:
yield from self.exec_each_domain(callback) yield from self.exec_each_domain(callback)
yield from self.exec_each_ip4(callback) yield from self.exec_each_ip4(callback)
# TODO ASN
def update_references(self) -> None: def update_references(self) -> None:
raise NotImplementedError raise NotImplementedError
@ -293,27 +349,47 @@ class Database(Profiler):
def prune(self, before: int, base_only: bool = False) -> None: def prune(self, before: int, base_only: bool = False) -> None:
raise NotImplementedError raise NotImplementedError
def explain(self, entry: int) -> str: def explain(self, path: Path) -> str:
raise NotImplementedError string = str(path)
match = self.get_match(path)
if match.source:
string += f'{self.explain(match.source)}'
return string
def export(self, def export(self,
first_party_only: bool = False, first_party_only: bool = False,
end_chain_only: bool = False, end_chain_only: bool = False,
explain: bool = False, explain: bool = False,
) -> typing.Iterable[str]: ) -> typing.Iterable[str]:
if first_party_only or end_chain_only or explain: if first_party_only or end_chain_only:
raise NotImplementedError raise NotImplementedError
def export_cb(path: Path, node: Node, _: typing.Any def export_cb(path: Path, match: Match, _: typing.Any
) -> typing.Iterable[str]: ) -> typing.Iterable[str]:
assert isinstance(path, DomainPath) assert isinstance(path, DomainPath)
assert isinstance(node, DomainTreeNode) if isinstance(path, HostnamePath):
if node.match_hostname: if explain:
a = self.unpack_domain(path) yield self.explain(path)
yield a else:
yield self.unpack_domain(path)
yield from self.exec_each_domain(export_cb, None) yield from self.exec_each_domain(export_cb, None)
def list_rules(self,
first_party_only: bool = False,
) -> typing.Iterable[str]:
if first_party_only:
raise NotImplementedError
def list_rules_cb(path: Path, match: Match, _: typing.Any
) -> typing.Iterable[str]:
if isinstance(path, ZonePath) \
or (isinstance(path, Ip4Path) and path.prefixlen < 32):
# if match.level == 0:
yield self.explain(path)
yield from self.exec_each(list_rules_cb, None)
def count_rules(self, def count_rules(self,
first_party_only: bool = False, first_party_only: bool = False,
) -> str: ) -> str:
@ -325,10 +401,10 @@ class Database(Profiler):
self.enter_step('get_domain_brws') self.enter_step('get_domain_brws')
dic = self.domtree dic = self.domtree
depth = 0 depth = 0
for part in domain.path: for part in domain.parts:
if dic.match_zone.active(): if dic.match_zone.active():
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield ZonePath(domain.path[:depth]) yield ZonePath(domain.parts[:depth])
self.enter_step('get_domain_brws') self.enter_step('get_domain_brws')
if part not in dic.children: if part not in dic.children:
return return
@ -336,27 +412,28 @@ class Database(Profiler):
depth += 1 depth += 1
if dic.match_zone.active(): if dic.match_zone.active():
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield ZonePath(domain.path) yield ZonePath(domain.parts)
if dic.match_hostname.active(): if dic.match_hostname.active():
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield HostnamePath(domain.path) yield HostnamePath(domain.parts)
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]: def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
self.enter_step('get_ip4_pack') self.enter_step('get_ip4_pack')
ip4 = self.pack_ip4address(ip4_str) ip4 = self.pack_ip4address(ip4_str)
self.enter_step('get_ip4_brws') self.enter_step('get_ip4_brws')
dic = self.ip4tree dic = self.ip4tree
for i in reversed(range(ip4.prefixlen)): for i in range(31, 31-ip4.prefixlen, -1):
part = (ip4.value >> i) & 0b1 bit = (ip4.value >> i) & 0b1
if dic.match.active(): if dic.active():
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield Ip4Path(ip4.value, 32-i) a = Ip4Path(ip4.value >> (i+1) << (i+1), 31-i)
self.enter_step('get_ip4_brws') yield a
next_dic = dic.one if part else dic.zero self.enter_step('get_ip4_brws')
next_dic = dic.one if bit else dic.zero
if next_dic is None: if next_dic is None:
return return
dic = next_dic dic = next_dic
if dic.match.active(): if dic.active():
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield ip4 yield ip4
@ -374,9 +451,16 @@ class Database(Profiler):
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
domain = self.pack_domain(domain_str) domain = self.pack_domain(domain_str)
self.enter_step('set_domain_src')
if source is None:
level = 0
source = RulePath()
else:
match = self.get_match(source)
level = match.level + 1
self.enter_step('set_domain_brws') self.enter_step('set_domain_brws')
dic = self.domtree dic = self.domtree
for part in domain.path: for part in domain.parts:
if dic.match_zone.active(): if dic.match_zone.active():
# Refuse to add domain whose zone is already matching # Refuse to add domain whose zone is already matching
return return
@ -389,8 +473,8 @@ class Database(Profiler):
match = dic.match_zone match = dic.match_zone
match.set( match.set(
updated, updated,
0, # TODO Level level,
source or RulePath(), source,
) )
def set_hostname(self, def set_hostname(self,
@ -411,14 +495,23 @@ class Database(Profiler):
self.enter_step('set_asn') self.enter_step('set_asn')
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
if source is None:
level = 0
source = RulePath()
else:
match = self.get_match(source)
level = match.level + 1
path = self.pack_asn(asn_str) path = self.pack_asn(asn_str)
match = AsnNode() if path.asn in self.asns:
match = self.asns[path.asn]
else:
match = AsnNode()
self.asns[path.asn] = match
match.set( match.set(
updated, updated,
0, level,
source or RulePath() source,
) )
self.asns[path.asn] = match
def _set_ip4(self, def _set_ip4(self,
ip4: Ip4Path, ip4: Ip4Path,
@ -427,24 +520,32 @@ class Database(Profiler):
source: Path = None) -> None: source: Path = None) -> None:
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
self.enter_step('set_ip4_src')
if source is None:
level = 0
source = RulePath()
else:
match = self.get_match(source)
level = match.level + 1
self.enter_step('set_ip4_brws')
dic = self.ip4tree dic = self.ip4tree
for i in reversed(range(ip4.prefixlen)): for i in range(31, 31-ip4.prefixlen, -1):
part = (ip4.value >> i) & 0b1 bit = (ip4.value >> i) & 0b1
if dic.match.active(): if dic.active():
# Refuse to add ip4* whose network is already matching # Refuse to add ip4* whose network is already matching
return return
next_dic = dic.one if part else dic.zero next_dic = dic.one if bit else dic.zero
if next_dic is None: if next_dic is None:
next_dic = IpTreeNode() next_dic = IpTreeNode()
if part: if bit:
dic.one = next_dic dic.one = next_dic
else: else:
dic.zero = next_dic dic.zero = next_dic
dic = next_dic dic = next_dic
dic.match.set( dic.set(
updated, updated,
0, # TODO Level level,
source or RulePath(), source,
) )
def set_ip4address(self, def set_ip4address(self,
@ -453,7 +554,6 @@ class Database(Profiler):
) -> None: ) -> None:
self.enter_step('set_ip4add_pack') self.enter_step('set_ip4add_pack')
ip4 = self.pack_ip4address(ip4address_str) ip4 = self.pack_ip4address(ip4address_str)
self.enter_step('set_ip4add_brws')
self._set_ip4(ip4, *args, **kwargs) self._set_ip4(ip4, *args, **kwargs)
def set_ip4network(self, def set_ip4network(self,
@ -462,5 +562,4 @@ class Database(Profiler):
) -> None: ) -> None:
self.enter_step('set_ip4net_pack') self.enter_step('set_ip4net_pack')
ip4 = self.pack_ip4network(ip4network_str) ip4 = self.pack_ip4network(ip4network_str)
self.enter_step('set_ip4net_brws')
self._set_ip4(ip4, *args, **kwargs) self._set_ip4(ip4, *args, **kwargs)

View file

@ -33,9 +33,11 @@ if __name__ == '__main__':
DB = database.Database() DB = database.Database()
if args.rules: if args.rules:
if not args.count: if args.count:
raise NotImplementedError print(DB.count_rules(first_party_only=args.first_party))
print(DB.count_rules(first_party_only=args.first_party)) else:
for line in DB.list_rules():
print(line)
else: else:
if args.count: if args.count:
raise NotImplementedError raise NotImplementedError

View file

@ -51,8 +51,7 @@ class Writer(multiprocessing.Process):
try: try:
for source in select(self.db, value): for source in select(self.db, value):
# write(self.db, name, updated, source=source) write(self.db, name, updated, source=source)
write(self.db, name, updated)
except ValueError: except ValueError:
self.log.exception("Cannot execute: %s", record) self.log.exception("Cannot execute: %s", record)