Made references work
This commit is contained in:
parent
03a4042238
commit
c3bf102289
150
database.py
150
database.py
|
@ -70,19 +70,9 @@ class Match():
|
|||
self.updated: int = 0
|
||||
self.level: int = 0
|
||||
self.source: typing.Optional[Path] = None
|
||||
self.references: int = 0
|
||||
# FP dupplicate args
|
||||
|
||||
def set(self,
|
||||
updated: int,
|
||||
level: int,
|
||||
source: Path,
|
||||
) -> None:
|
||||
if updated > self.updated or level > self.level:
|
||||
self.updated = updated
|
||||
self.level = level
|
||||
self.source = source
|
||||
# FP dupplicate function
|
||||
|
||||
def active(self) -> bool:
|
||||
return self.updated > 0
|
||||
|
||||
|
@ -143,7 +133,7 @@ class Profiler():
|
|||
|
||||
|
||||
class Database(Profiler):
|
||||
VERSION = 13
|
||||
VERSION = 14
|
||||
PATH = "blocking.p"
|
||||
|
||||
def initialize(self) -> None:
|
||||
|
@ -268,6 +258,24 @@ class Database(Profiler):
|
|||
else:
|
||||
raise ValueError
|
||||
|
||||
def exec_each_asn(self,
|
||||
callback: MatchCallable,
|
||||
arg: typing.Any = None,
|
||||
) -> typing.Any:
|
||||
for asn in self.asns:
|
||||
match = self.asns[asn]
|
||||
if match.active():
|
||||
c = callback(
|
||||
AsnPath(asn),
|
||||
match,
|
||||
arg
|
||||
)
|
||||
try:
|
||||
yield from c
|
||||
except TypeError: # not iterable
|
||||
pass
|
||||
|
||||
|
||||
def exec_each_domain(self,
|
||||
callback: MatchCallable,
|
||||
arg: typing.Any = None,
|
||||
|
@ -277,17 +285,25 @@ class Database(Profiler):
|
|||
_dic = _dic or self.domtree
|
||||
_par = _par or DomainPath([])
|
||||
if _dic.match_hostname.active():
|
||||
yield from callback(
|
||||
c = callback(
|
||||
HostnamePath(_par.parts),
|
||||
_dic.match_hostname,
|
||||
arg
|
||||
)
|
||||
try:
|
||||
yield from c
|
||||
except TypeError: # not iterable
|
||||
pass
|
||||
if _dic.match_zone.active():
|
||||
yield from callback(
|
||||
c = callback(
|
||||
ZonePath(_par.parts),
|
||||
_dic.match_zone,
|
||||
arg
|
||||
)
|
||||
try:
|
||||
yield from c
|
||||
except TypeError: # not iterable
|
||||
pass
|
||||
for part in _dic.children:
|
||||
dic = _dic.children[part]
|
||||
yield from self.exec_each_domain(
|
||||
|
@ -306,11 +322,15 @@ class Database(Profiler):
|
|||
_dic = _dic or self.ip4tree
|
||||
_par = _par or Ip4Path(0, 0)
|
||||
if _dic.active():
|
||||
yield from callback(
|
||||
c = callback(
|
||||
_par,
|
||||
_dic,
|
||||
arg
|
||||
)
|
||||
try:
|
||||
yield from c
|
||||
except TypeError: # not iterable
|
||||
pass
|
||||
|
||||
# 0
|
||||
pref = _par.prefixlen + 1
|
||||
|
@ -341,17 +361,35 @@ class Database(Profiler):
|
|||
) -> typing.Any:
|
||||
yield from self.exec_each_domain(callback)
|
||||
yield from self.exec_each_ip4(callback)
|
||||
# TODO ASN
|
||||
yield from self.exec_each_asn(callback)
|
||||
|
||||
def update_references(self) -> None:
|
||||
raise NotImplementedError
|
||||
# Should be correctly calculated normally,
|
||||
# keeping this just in case
|
||||
def reset_references_cb(path: Path,
|
||||
match: Match, _: typing.Any
|
||||
) -> None:
|
||||
match.references = 0
|
||||
for _ in self.exec_each(reset_references_cb, None):
|
||||
pass
|
||||
|
||||
def increment_references_cb(path: Path,
|
||||
match: Match, _: typing.Any
|
||||
) -> None:
|
||||
if match.source:
|
||||
source = self.get_match(match.source)
|
||||
source.references += 1
|
||||
for _ in self.exec_each(increment_references_cb, None):
|
||||
pass
|
||||
|
||||
def prune(self, before: int, base_only: bool = False) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def explain(self, path: Path) -> str:
|
||||
string = str(path)
|
||||
match = self.get_match(path)
|
||||
string = f'{path}'
|
||||
if not isinstance(path, RulePath):
|
||||
string += f' #{match.references}'
|
||||
if match.source:
|
||||
string += f' ← {self.explain(match.source)}'
|
||||
return string
|
||||
|
@ -361,13 +399,16 @@ class Database(Profiler):
|
|||
end_chain_only: bool = False,
|
||||
explain: bool = False,
|
||||
) -> typing.Iterable[str]:
|
||||
if first_party_only or end_chain_only:
|
||||
if first_party_only:
|
||||
raise NotImplementedError
|
||||
|
||||
def export_cb(path: Path, match: Match, _: typing.Any
|
||||
) -> typing.Iterable[str]:
|
||||
assert isinstance(path, DomainPath)
|
||||
if isinstance(path, HostnamePath):
|
||||
if not isinstance(path, HostnamePath):
|
||||
return
|
||||
if end_chain_only and match.references > 0:
|
||||
return
|
||||
if explain:
|
||||
yield self.explain(path)
|
||||
else:
|
||||
|
@ -437,9 +478,22 @@ class Database(Profiler):
|
|||
self.enter_step('get_ip4_yield')
|
||||
yield ip4
|
||||
|
||||
def list_asn(self) -> typing.Iterable[AsnPath]:
|
||||
for asn in self.asns:
|
||||
yield AsnPath(asn)
|
||||
def set_match(self,
|
||||
match: Match,
|
||||
updated: int,
|
||||
source: Path,
|
||||
) -> None:
|
||||
new_source = self.get_match(source)
|
||||
new_level = new_source.level + 1
|
||||
if updated > match.updated or new_level > match.level:
|
||||
if match.source:
|
||||
old_source = self.get_match(match.source)
|
||||
old_source.references -= 1
|
||||
match.updated = updated
|
||||
match.level = new_level
|
||||
match.source = source
|
||||
new_source.references += 1
|
||||
# FP dupplicate function
|
||||
|
||||
def _set_domain(self,
|
||||
hostname: bool,
|
||||
|
@ -451,30 +505,23 @@ class Database(Profiler):
|
|||
if is_first_party:
|
||||
raise NotImplementedError
|
||||
domain = self.pack_domain(domain_str)
|
||||
self.enter_step('set_domain_src')
|
||||
if source is None:
|
||||
level = 0
|
||||
source = RulePath()
|
||||
else:
|
||||
match = self.get_match(source)
|
||||
level = match.level + 1
|
||||
self.enter_step('set_domain_brws')
|
||||
dic = self.domtree
|
||||
for part in domain.parts:
|
||||
if dic.match_zone.active():
|
||||
# Refuse to add domain whose zone is already matching
|
||||
return
|
||||
if part not in dic.children:
|
||||
dic.children[part] = DomainTreeNode()
|
||||
dic = dic.children[part]
|
||||
if dic.match_zone.active():
|
||||
# Refuse to add domain whose zone is already matching
|
||||
return
|
||||
if hostname:
|
||||
match = dic.match_hostname
|
||||
else:
|
||||
match = dic.match_zone
|
||||
match.set(
|
||||
self.set_match(
|
||||
match,
|
||||
updated,
|
||||
level,
|
||||
source,
|
||||
source or RulePath(),
|
||||
)
|
||||
|
||||
def set_hostname(self,
|
||||
|
@ -495,22 +542,16 @@ class Database(Profiler):
|
|||
self.enter_step('set_asn')
|
||||
if is_first_party:
|
||||
raise NotImplementedError
|
||||
if source is None:
|
||||
level = 0
|
||||
source = RulePath()
|
||||
else:
|
||||
match = self.get_match(source)
|
||||
level = match.level + 1
|
||||
path = self.pack_asn(asn_str)
|
||||
if path.asn in self.asns:
|
||||
match = self.asns[path.asn]
|
||||
else:
|
||||
match = AsnNode()
|
||||
self.asns[path.asn] = match
|
||||
match.set(
|
||||
self.set_match(
|
||||
match,
|
||||
updated,
|
||||
level,
|
||||
source,
|
||||
source or RulePath(),
|
||||
)
|
||||
|
||||
def _set_ip4(self,
|
||||
|
@ -520,20 +561,10 @@ class Database(Profiler):
|
|||
source: Path = None) -> None:
|
||||
if is_first_party:
|
||||
raise NotImplementedError
|
||||
self.enter_step('set_ip4_src')
|
||||
if source is None:
|
||||
level = 0
|
||||
source = RulePath()
|
||||
else:
|
||||
match = self.get_match(source)
|
||||
level = match.level + 1
|
||||
self.enter_step('set_ip4_brws')
|
||||
dic = self.ip4tree
|
||||
for i in range(31, 31-ip4.prefixlen, -1):
|
||||
bit = (ip4.value >> i) & 0b1
|
||||
if dic.active():
|
||||
# Refuse to add ip4* whose network is already matching
|
||||
return
|
||||
next_dic = dic.one if bit else dic.zero
|
||||
if next_dic is None:
|
||||
next_dic = IpTreeNode()
|
||||
|
@ -542,10 +573,13 @@ class Database(Profiler):
|
|||
else:
|
||||
dic.zero = next_dic
|
||||
dic = next_dic
|
||||
dic.set(
|
||||
if dic.active():
|
||||
# Refuse to add ip4* whose network is already matching
|
||||
return
|
||||
self.set_match(
|
||||
dic,
|
||||
updated,
|
||||
level,
|
||||
source,
|
||||
source or RulePath(),
|
||||
)
|
||||
|
||||
def set_ip4address(self,
|
||||
|
|
|
@ -32,7 +32,10 @@ if __name__ == '__main__':
|
|||
|
||||
DB = database.Database()
|
||||
|
||||
for path in DB.list_asn():
|
||||
def add_ranges(path: database.Path,
|
||||
match: database.Match,
|
||||
_: typing.Any) -> None:
|
||||
assert isinstance(path, database.AsnPath)
|
||||
asn_str = database.Database.unpack_asn(path)
|
||||
DB.enter_step('asn_get_ranges')
|
||||
for prefix in get_ranges(asn_str):
|
||||
|
@ -49,4 +52,7 @@ if __name__ == '__main__':
|
|||
else:
|
||||
log.error('Unknown prefix version: %s', prefix)
|
||||
|
||||
for _ in DB.exec_each_asn(add_ranges, None):
|
||||
pass
|
||||
|
||||
DB.save()
|
||||
|
|
Loading…
Reference in a new issue