Added level
Also fixed IP logic because this was real messed up
This commit is contained in:
parent
3197fa1663
commit
03a4042238
213
database.py
213
database.py
|
@ -26,38 +26,50 @@ class Path():
|
|||
|
||||
|
||||
class RulePath(Path):
|
||||
pass
|
||||
def __str__(self) -> str:
|
||||
return '(rules)'
|
||||
|
||||
|
||||
class DomainPath(Path):
|
||||
def __init__(self, path: typing.List[str]):
|
||||
self.path = path
|
||||
def __init__(self, parts: typing.List[str]):
|
||||
self.parts = parts
|
||||
|
||||
def __str__(self) -> str:
|
||||
return '?.' + Database.unpack_domain(self)
|
||||
|
||||
|
||||
class HostnamePath(DomainPath):
|
||||
pass
|
||||
def __str__(self) -> str:
|
||||
return Database.unpack_domain(self)
|
||||
|
||||
|
||||
class ZonePath(DomainPath):
|
||||
pass
|
||||
def __str__(self) -> str:
|
||||
return '*.' + Database.unpack_domain(self)
|
||||
|
||||
|
||||
class AsnPath(Path):
|
||||
def __init__(self, asn: Asn):
|
||||
self.asn = asn
|
||||
|
||||
def __str__(self) -> str:
|
||||
return Database.unpack_asn(self)
|
||||
|
||||
|
||||
class Ip4Path(Path):
|
||||
def __init__(self, value: int, prefixlen: int):
|
||||
self.value = value
|
||||
self.prefixlen = prefixlen
|
||||
|
||||
def __str__(self) -> str:
|
||||
return Database.unpack_ip4network(self)
|
||||
|
||||
|
||||
class Match():
|
||||
def __init__(self) -> None:
|
||||
self.updated: int = 0
|
||||
self.level: int = 0
|
||||
self.source: Path = RulePath()
|
||||
self.source: typing.Optional[Path] = None
|
||||
# FP dupplicate args
|
||||
|
||||
def set(self,
|
||||
|
@ -86,16 +98,16 @@ class DomainTreeNode():
|
|||
self.match_hostname = Match()
|
||||
|
||||
|
||||
class IpTreeNode():
|
||||
class IpTreeNode(Match):
|
||||
def __init__(self) -> None:
|
||||
Match.__init__(self)
|
||||
self.zero: typing.Optional[IpTreeNode] = None
|
||||
self.one: typing.Optional[IpTreeNode] = None
|
||||
self.match = Match()
|
||||
|
||||
|
||||
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
|
||||
NodeCallable = typing.Callable[[Path,
|
||||
Node,
|
||||
MatchCallable = typing.Callable[[Path,
|
||||
Match,
|
||||
typing.Optional[typing.Any]],
|
||||
typing.Any]
|
||||
|
||||
|
@ -109,7 +121,6 @@ class Profiler():
|
|||
self.step_dict: typing.Dict[str, int] = dict()
|
||||
|
||||
def enter_step(self, name: str) -> None:
|
||||
return
|
||||
now = time.perf_counter()
|
||||
try:
|
||||
self.time_dict[self.time_step] += now - self.time_last
|
||||
|
@ -132,7 +143,7 @@ class Profiler():
|
|||
|
||||
|
||||
class Database(Profiler):
|
||||
VERSION = 11
|
||||
VERSION = 13
|
||||
PATH = "blocking.p"
|
||||
|
||||
def initialize(self) -> None:
|
||||
|
@ -181,7 +192,7 @@ class Database(Profiler):
|
|||
|
||||
@staticmethod
|
||||
def unpack_domain(domain: DomainPath) -> str:
|
||||
return '.'.join(domain.path[::-1])
|
||||
return '.'.join(domain.parts[::-1])
|
||||
|
||||
@staticmethod
|
||||
def pack_asn(asn: str) -> AsnPath:
|
||||
|
@ -230,62 +241,107 @@ class Database(Profiler):
|
|||
addr >>= 8
|
||||
return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
|
||||
|
||||
def get_match(self, path: Path) -> Match:
|
||||
if isinstance(path, RulePath):
|
||||
return Match()
|
||||
elif isinstance(path, AsnPath):
|
||||
return self.asns[path.asn]
|
||||
elif isinstance(path, DomainPath):
|
||||
dicd = self.domtree
|
||||
for part in path.parts:
|
||||
dicd = dicd.children[part]
|
||||
if isinstance(path, HostnamePath):
|
||||
return dicd.match_hostname
|
||||
elif isinstance(path, ZonePath):
|
||||
return dicd.match_zone
|
||||
else:
|
||||
raise ValueError
|
||||
elif isinstance(path, Ip4Path):
|
||||
dici = self.ip4tree
|
||||
for i in range(31, 31-path.prefixlen, -1):
|
||||
bit = (path.value >> i) & 0b1
|
||||
dici_next = dici.one if bit else dici.zero
|
||||
if not dici_next:
|
||||
raise IndexError
|
||||
dici = dici_next
|
||||
return dici
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def exec_each_domain(self,
|
||||
callback: NodeCallable,
|
||||
callback: MatchCallable,
|
||||
arg: typing.Any = None,
|
||||
_dic: DomainTreeNode = None,
|
||||
_par: DomainPath = None,
|
||||
) -> typing.Any:
|
||||
_dic = _dic or self.domtree
|
||||
_par = _par or DomainPath([])
|
||||
yield from callback(_par, _dic, arg)
|
||||
if _dic.match_hostname.active():
|
||||
yield from callback(
|
||||
HostnamePath(_par.parts),
|
||||
_dic.match_hostname,
|
||||
arg
|
||||
)
|
||||
if _dic.match_zone.active():
|
||||
yield from callback(
|
||||
ZonePath(_par.parts),
|
||||
_dic.match_zone,
|
||||
arg
|
||||
)
|
||||
for part in _dic.children:
|
||||
dic = _dic.children[part]
|
||||
yield from self.exec_each_domain(
|
||||
callback,
|
||||
arg,
|
||||
_dic=dic,
|
||||
_par=DomainPath(_par.path + [part])
|
||||
_par=DomainPath(_par.parts + [part])
|
||||
)
|
||||
|
||||
def exec_each_ip4(self,
|
||||
callback: NodeCallable,
|
||||
callback: MatchCallable,
|
||||
arg: typing.Any = None,
|
||||
_dic: IpTreeNode = None,
|
||||
_par: Ip4Path = None,
|
||||
) -> typing.Any:
|
||||
_dic = _dic or self.ip4tree
|
||||
_par = _par or Ip4Path(0, 0)
|
||||
callback(_par, _dic, arg)
|
||||
if _dic.active():
|
||||
yield from callback(
|
||||
_par,
|
||||
_dic,
|
||||
arg
|
||||
)
|
||||
|
||||
# 0
|
||||
pref = _par.prefixlen + 1
|
||||
dic = _dic.zero
|
||||
if dic:
|
||||
addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen)))
|
||||
addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
|
||||
assert addr0 == _par.value
|
||||
yield from self.exec_each_ip4(
|
||||
callback,
|
||||
arg,
|
||||
_dic=dic,
|
||||
_par=Ip4Path(addr0, _par.prefixlen+1)
|
||||
_par=Ip4Path(addr0, pref)
|
||||
)
|
||||
# 1
|
||||
dic = _dic.one
|
||||
if dic:
|
||||
addr1 = _par.value | (1 << (32-_par.prefixlen))
|
||||
addr1 = _par.value | (1 << (32-pref))
|
||||
yield from self.exec_each_ip4(
|
||||
callback,
|
||||
arg,
|
||||
_dic=dic,
|
||||
_par=Ip4Path(addr1, _par.prefixlen+1)
|
||||
_par=Ip4Path(addr1, pref)
|
||||
)
|
||||
|
||||
def exec_each(self,
|
||||
callback: NodeCallable,
|
||||
callback: MatchCallable,
|
||||
arg: typing.Any = None,
|
||||
) -> typing.Any:
|
||||
yield from self.exec_each_domain(callback)
|
||||
yield from self.exec_each_ip4(callback)
|
||||
# TODO ASN
|
||||
|
||||
def update_references(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
@ -293,27 +349,47 @@ class Database(Profiler):
|
|||
def prune(self, before: int, base_only: bool = False) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def explain(self, entry: int) -> str:
|
||||
raise NotImplementedError
|
||||
def explain(self, path: Path) -> str:
|
||||
string = str(path)
|
||||
match = self.get_match(path)
|
||||
if match.source:
|
||||
string += f' ← {self.explain(match.source)}'
|
||||
return string
|
||||
|
||||
def export(self,
|
||||
first_party_only: bool = False,
|
||||
end_chain_only: bool = False,
|
||||
explain: bool = False,
|
||||
) -> typing.Iterable[str]:
|
||||
if first_party_only or end_chain_only or explain:
|
||||
if first_party_only or end_chain_only:
|
||||
raise NotImplementedError
|
||||
|
||||
def export_cb(path: Path, node: Node, _: typing.Any
|
||||
def export_cb(path: Path, match: Match, _: typing.Any
|
||||
) -> typing.Iterable[str]:
|
||||
assert isinstance(path, DomainPath)
|
||||
assert isinstance(node, DomainTreeNode)
|
||||
if node.match_hostname:
|
||||
a = self.unpack_domain(path)
|
||||
yield a
|
||||
if isinstance(path, HostnamePath):
|
||||
if explain:
|
||||
yield self.explain(path)
|
||||
else:
|
||||
yield self.unpack_domain(path)
|
||||
|
||||
yield from self.exec_each_domain(export_cb, None)
|
||||
|
||||
def list_rules(self,
|
||||
first_party_only: bool = False,
|
||||
) -> typing.Iterable[str]:
|
||||
if first_party_only:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_rules_cb(path: Path, match: Match, _: typing.Any
|
||||
) -> typing.Iterable[str]:
|
||||
if isinstance(path, ZonePath) \
|
||||
or (isinstance(path, Ip4Path) and path.prefixlen < 32):
|
||||
# if match.level == 0:
|
||||
yield self.explain(path)
|
||||
|
||||
yield from self.exec_each(list_rules_cb, None)
|
||||
|
||||
def count_rules(self,
|
||||
first_party_only: bool = False,
|
||||
) -> str:
|
||||
|
@ -325,10 +401,10 @@ class Database(Profiler):
|
|||
self.enter_step('get_domain_brws')
|
||||
dic = self.domtree
|
||||
depth = 0
|
||||
for part in domain.path:
|
||||
for part in domain.parts:
|
||||
if dic.match_zone.active():
|
||||
self.enter_step('get_domain_yield')
|
||||
yield ZonePath(domain.path[:depth])
|
||||
yield ZonePath(domain.parts[:depth])
|
||||
self.enter_step('get_domain_brws')
|
||||
if part not in dic.children:
|
||||
return
|
||||
|
@ -336,27 +412,28 @@ class Database(Profiler):
|
|||
depth += 1
|
||||
if dic.match_zone.active():
|
||||
self.enter_step('get_domain_yield')
|
||||
yield ZonePath(domain.path)
|
||||
yield ZonePath(domain.parts)
|
||||
if dic.match_hostname.active():
|
||||
self.enter_step('get_domain_yield')
|
||||
yield HostnamePath(domain.path)
|
||||
yield HostnamePath(domain.parts)
|
||||
|
||||
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
|
||||
self.enter_step('get_ip4_pack')
|
||||
ip4 = self.pack_ip4address(ip4_str)
|
||||
self.enter_step('get_ip4_brws')
|
||||
dic = self.ip4tree
|
||||
for i in reversed(range(ip4.prefixlen)):
|
||||
part = (ip4.value >> i) & 0b1
|
||||
if dic.match.active():
|
||||
for i in range(31, 31-ip4.prefixlen, -1):
|
||||
bit = (ip4.value >> i) & 0b1
|
||||
if dic.active():
|
||||
self.enter_step('get_ip4_yield')
|
||||
yield Ip4Path(ip4.value, 32-i)
|
||||
a = Ip4Path(ip4.value >> (i+1) << (i+1), 31-i)
|
||||
yield a
|
||||
self.enter_step('get_ip4_brws')
|
||||
next_dic = dic.one if part else dic.zero
|
||||
next_dic = dic.one if bit else dic.zero
|
||||
if next_dic is None:
|
||||
return
|
||||
dic = next_dic
|
||||
if dic.match.active():
|
||||
if dic.active():
|
||||
self.enter_step('get_ip4_yield')
|
||||
yield ip4
|
||||
|
||||
|
@ -374,9 +451,16 @@ class Database(Profiler):
|
|||
if is_first_party:
|
||||
raise NotImplementedError
|
||||
domain = self.pack_domain(domain_str)
|
||||
self.enter_step('set_domain_src')
|
||||
if source is None:
|
||||
level = 0
|
||||
source = RulePath()
|
||||
else:
|
||||
match = self.get_match(source)
|
||||
level = match.level + 1
|
||||
self.enter_step('set_domain_brws')
|
||||
dic = self.domtree
|
||||
for part in domain.path:
|
||||
for part in domain.parts:
|
||||
if dic.match_zone.active():
|
||||
# Refuse to add domain whose zone is already matching
|
||||
return
|
||||
|
@ -389,8 +473,8 @@ class Database(Profiler):
|
|||
match = dic.match_zone
|
||||
match.set(
|
||||
updated,
|
||||
0, # TODO Level
|
||||
source or RulePath(),
|
||||
level,
|
||||
source,
|
||||
)
|
||||
|
||||
def set_hostname(self,
|
||||
|
@ -411,14 +495,23 @@ class Database(Profiler):
|
|||
self.enter_step('set_asn')
|
||||
if is_first_party:
|
||||
raise NotImplementedError
|
||||
if source is None:
|
||||
level = 0
|
||||
source = RulePath()
|
||||
else:
|
||||
match = self.get_match(source)
|
||||
level = match.level + 1
|
||||
path = self.pack_asn(asn_str)
|
||||
if path.asn in self.asns:
|
||||
match = self.asns[path.asn]
|
||||
else:
|
||||
match = AsnNode()
|
||||
self.asns[path.asn] = match
|
||||
match.set(
|
||||
updated,
|
||||
0,
|
||||
source or RulePath()
|
||||
level,
|
||||
source,
|
||||
)
|
||||
self.asns[path.asn] = match
|
||||
|
||||
def _set_ip4(self,
|
||||
ip4: Ip4Path,
|
||||
|
@ -427,24 +520,32 @@ class Database(Profiler):
|
|||
source: Path = None) -> None:
|
||||
if is_first_party:
|
||||
raise NotImplementedError
|
||||
self.enter_step('set_ip4_src')
|
||||
if source is None:
|
||||
level = 0
|
||||
source = RulePath()
|
||||
else:
|
||||
match = self.get_match(source)
|
||||
level = match.level + 1
|
||||
self.enter_step('set_ip4_brws')
|
||||
dic = self.ip4tree
|
||||
for i in reversed(range(ip4.prefixlen)):
|
||||
part = (ip4.value >> i) & 0b1
|
||||
if dic.match.active():
|
||||
for i in range(31, 31-ip4.prefixlen, -1):
|
||||
bit = (ip4.value >> i) & 0b1
|
||||
if dic.active():
|
||||
# Refuse to add ip4* whose network is already matching
|
||||
return
|
||||
next_dic = dic.one if part else dic.zero
|
||||
next_dic = dic.one if bit else dic.zero
|
||||
if next_dic is None:
|
||||
next_dic = IpTreeNode()
|
||||
if part:
|
||||
if bit:
|
||||
dic.one = next_dic
|
||||
else:
|
||||
dic.zero = next_dic
|
||||
dic = next_dic
|
||||
dic.match.set(
|
||||
dic.set(
|
||||
updated,
|
||||
0, # TODO Level
|
||||
source or RulePath(),
|
||||
level,
|
||||
source,
|
||||
)
|
||||
|
||||
def set_ip4address(self,
|
||||
|
@ -453,7 +554,6 @@ class Database(Profiler):
|
|||
) -> None:
|
||||
self.enter_step('set_ip4add_pack')
|
||||
ip4 = self.pack_ip4address(ip4address_str)
|
||||
self.enter_step('set_ip4add_brws')
|
||||
self._set_ip4(ip4, *args, **kwargs)
|
||||
|
||||
def set_ip4network(self,
|
||||
|
@ -462,5 +562,4 @@ class Database(Profiler):
|
|||
) -> None:
|
||||
self.enter_step('set_ip4net_pack')
|
||||
ip4 = self.pack_ip4network(ip4network_str)
|
||||
self.enter_step('set_ip4net_brws')
|
||||
self._set_ip4(ip4, *args, **kwargs)
|
||||
|
|
|
@ -33,9 +33,11 @@ if __name__ == '__main__':
|
|||
DB = database.Database()
|
||||
|
||||
if args.rules:
|
||||
if not args.count:
|
||||
raise NotImplementedError
|
||||
if args.count:
|
||||
print(DB.count_rules(first_party_only=args.first_party))
|
||||
else:
|
||||
for line in DB.list_rules():
|
||||
print(line)
|
||||
else:
|
||||
if args.count:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -51,8 +51,7 @@ class Writer(multiprocessing.Process):
|
|||
|
||||
try:
|
||||
for source in select(self.db, value):
|
||||
# write(self.db, name, updated, source=source)
|
||||
write(self.db, name, updated)
|
||||
write(self.db, name, updated, source=source)
|
||||
except ValueError:
|
||||
self.log.exception("Cannot execute: %s", record)
|
||||
|
||||
|
|
Loading…
Reference in a new issue