Made references work

This commit is contained in:
Geoffrey Frogeye 2019-12-16 14:18:03 +01:00
parent 03a4042238
commit c3bf102289
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
2 changed files with 103 additions and 63 deletions

View file

@ -70,19 +70,9 @@ class Match():
self.updated: int = 0 self.updated: int = 0
self.level: int = 0 self.level: int = 0
self.source: typing.Optional[Path] = None self.source: typing.Optional[Path] = None
self.references: int = 0
# FP dupplicate args # FP dupplicate args
def set(self,
updated: int,
level: int,
source: Path,
) -> None:
if updated > self.updated or level > self.level:
self.updated = updated
self.level = level
self.source = source
# FP dupplicate function
def active(self) -> bool: def active(self) -> bool:
return self.updated > 0 return self.updated > 0
@ -143,7 +133,7 @@ class Profiler():
class Database(Profiler): class Database(Profiler):
VERSION = 13 VERSION = 14
PATH = "blocking.p" PATH = "blocking.p"
def initialize(self) -> None: def initialize(self) -> None:
@ -268,6 +258,24 @@ class Database(Profiler):
else: else:
raise ValueError raise ValueError
def exec_each_asn(self,
callback: MatchCallable,
arg: typing.Any = None,
) -> typing.Any:
for asn in self.asns:
match = self.asns[asn]
if match.active():
c = callback(
AsnPath(asn),
match,
arg
)
try:
yield from c
except TypeError: # not iterable
pass
def exec_each_domain(self, def exec_each_domain(self,
callback: MatchCallable, callback: MatchCallable,
arg: typing.Any = None, arg: typing.Any = None,
@ -277,17 +285,25 @@ class Database(Profiler):
_dic = _dic or self.domtree _dic = _dic or self.domtree
_par = _par or DomainPath([]) _par = _par or DomainPath([])
if _dic.match_hostname.active(): if _dic.match_hostname.active():
yield from callback( c = callback(
HostnamePath(_par.parts), HostnamePath(_par.parts),
_dic.match_hostname, _dic.match_hostname,
arg arg
) )
try:
yield from c
except TypeError: # not iterable
pass
if _dic.match_zone.active(): if _dic.match_zone.active():
yield from callback( c = callback(
ZonePath(_par.parts), ZonePath(_par.parts),
_dic.match_zone, _dic.match_zone,
arg arg
) )
try:
yield from c
except TypeError: # not iterable
pass
for part in _dic.children: for part in _dic.children:
dic = _dic.children[part] dic = _dic.children[part]
yield from self.exec_each_domain( yield from self.exec_each_domain(
@ -306,11 +322,15 @@ class Database(Profiler):
_dic = _dic or self.ip4tree _dic = _dic or self.ip4tree
_par = _par or Ip4Path(0, 0) _par = _par or Ip4Path(0, 0)
if _dic.active(): if _dic.active():
yield from callback( c = callback(
_par, _par,
_dic, _dic,
arg arg
) )
try:
yield from c
except TypeError: # not iterable
pass
# 0 # 0
pref = _par.prefixlen + 1 pref = _par.prefixlen + 1
@ -341,17 +361,35 @@ class Database(Profiler):
) -> typing.Any: ) -> typing.Any:
yield from self.exec_each_domain(callback) yield from self.exec_each_domain(callback)
yield from self.exec_each_ip4(callback) yield from self.exec_each_ip4(callback)
# TODO ASN yield from self.exec_each_asn(callback)
def update_references(self) -> None: def update_references(self) -> None:
raise NotImplementedError # Should be correctly calculated normally,
# keeping this just in case
def reset_references_cb(path: Path,
match: Match, _: typing.Any
) -> None:
match.references = 0
for _ in self.exec_each(reset_references_cb, None):
pass
def increment_references_cb(path: Path,
match: Match, _: typing.Any
) -> None:
if match.source:
source = self.get_match(match.source)
source.references += 1
for _ in self.exec_each(increment_references_cb, None):
pass
def prune(self, before: int, base_only: bool = False) -> None: def prune(self, before: int, base_only: bool = False) -> None:
raise NotImplementedError raise NotImplementedError
def explain(self, path: Path) -> str: def explain(self, path: Path) -> str:
string = str(path)
match = self.get_match(path) match = self.get_match(path)
string = f'{path}'
if not isinstance(path, RulePath):
string += f' #{match.references}'
if match.source: if match.source:
string += f'{self.explain(match.source)}' string += f'{self.explain(match.source)}'
return string return string
@ -361,17 +399,20 @@ class Database(Profiler):
end_chain_only: bool = False, end_chain_only: bool = False,
explain: bool = False, explain: bool = False,
) -> typing.Iterable[str]: ) -> typing.Iterable[str]:
if first_party_only or end_chain_only: if first_party_only:
raise NotImplementedError raise NotImplementedError
def export_cb(path: Path, match: Match, _: typing.Any def export_cb(path: Path, match: Match, _: typing.Any
) -> typing.Iterable[str]: ) -> typing.Iterable[str]:
assert isinstance(path, DomainPath) assert isinstance(path, DomainPath)
if isinstance(path, HostnamePath): if not isinstance(path, HostnamePath):
if explain: return
yield self.explain(path) if end_chain_only and match.references > 0:
else: return
yield self.unpack_domain(path) if explain:
yield self.explain(path)
else:
yield self.unpack_domain(path)
yield from self.exec_each_domain(export_cb, None) yield from self.exec_each_domain(export_cb, None)
@ -437,9 +478,22 @@ class Database(Profiler):
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield ip4 yield ip4
def list_asn(self) -> typing.Iterable[AsnPath]: def set_match(self,
for asn in self.asns: match: Match,
yield AsnPath(asn) updated: int,
source: Path,
) -> None:
new_source = self.get_match(source)
new_level = new_source.level + 1
if updated > match.updated or new_level > match.level:
if match.source:
old_source = self.get_match(match.source)
old_source.references -= 1
match.updated = updated
match.level = new_level
match.source = source
new_source.references += 1
# FP dupplicate function
def _set_domain(self, def _set_domain(self,
hostname: bool, hostname: bool,
@ -451,30 +505,23 @@ class Database(Profiler):
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
domain = self.pack_domain(domain_str) domain = self.pack_domain(domain_str)
self.enter_step('set_domain_src')
if source is None:
level = 0
source = RulePath()
else:
match = self.get_match(source)
level = match.level + 1
self.enter_step('set_domain_brws') self.enter_step('set_domain_brws')
dic = self.domtree dic = self.domtree
for part in domain.parts: for part in domain.parts:
if dic.match_zone.active():
# Refuse to add domain whose zone is already matching
return
if part not in dic.children: if part not in dic.children:
dic.children[part] = DomainTreeNode() dic.children[part] = DomainTreeNode()
dic = dic.children[part] dic = dic.children[part]
if dic.match_zone.active():
# Refuse to add domain whose zone is already matching
return
if hostname: if hostname:
match = dic.match_hostname match = dic.match_hostname
else: else:
match = dic.match_zone match = dic.match_zone
match.set( self.set_match(
match,
updated, updated,
level, source or RulePath(),
source,
) )
def set_hostname(self, def set_hostname(self,
@ -495,22 +542,16 @@ class Database(Profiler):
self.enter_step('set_asn') self.enter_step('set_asn')
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
if source is None:
level = 0
source = RulePath()
else:
match = self.get_match(source)
level = match.level + 1
path = self.pack_asn(asn_str) path = self.pack_asn(asn_str)
if path.asn in self.asns: if path.asn in self.asns:
match = self.asns[path.asn] match = self.asns[path.asn]
else: else:
match = AsnNode() match = AsnNode()
self.asns[path.asn] = match self.asns[path.asn] = match
match.set( self.set_match(
match,
updated, updated,
level, source or RulePath(),
source,
) )
def _set_ip4(self, def _set_ip4(self,
@ -520,20 +561,10 @@ class Database(Profiler):
source: Path = None) -> None: source: Path = None) -> None:
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
self.enter_step('set_ip4_src')
if source is None:
level = 0
source = RulePath()
else:
match = self.get_match(source)
level = match.level + 1
self.enter_step('set_ip4_brws') self.enter_step('set_ip4_brws')
dic = self.ip4tree dic = self.ip4tree
for i in range(31, 31-ip4.prefixlen, -1): for i in range(31, 31-ip4.prefixlen, -1):
bit = (ip4.value >> i) & 0b1 bit = (ip4.value >> i) & 0b1
if dic.active():
# Refuse to add ip4* whose network is already matching
return
next_dic = dic.one if bit else dic.zero next_dic = dic.one if bit else dic.zero
if next_dic is None: if next_dic is None:
next_dic = IpTreeNode() next_dic = IpTreeNode()
@ -542,10 +573,13 @@ class Database(Profiler):
else: else:
dic.zero = next_dic dic.zero = next_dic
dic = next_dic dic = next_dic
dic.set( if dic.active():
# Refuse to add ip4* whose network is already matching
return
self.set_match(
dic,
updated, updated,
level, source or RulePath(),
source,
) )
def set_ip4address(self, def set_ip4address(self,

View file

@ -32,7 +32,10 @@ if __name__ == '__main__':
DB = database.Database() DB = database.Database()
for path in DB.list_asn(): def add_ranges(path: database.Path,
match: database.Match,
_: typing.Any) -> None:
assert isinstance(path, database.AsnPath)
asn_str = database.Database.unpack_asn(path) asn_str = database.Database.unpack_asn(path)
DB.enter_step('asn_get_ranges') DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn_str): for prefix in get_ranges(asn_str):
@ -49,4 +52,7 @@ if __name__ == '__main__':
else: else:
log.error('Unknown prefix version: %s', prefix) log.error('Unknown prefix version: %s', prefix)
for _ in DB.exec_each_asn(add_ranges, None):
pass
DB.save() DB.save()