Generates a host list of first-party trackers for ad-blocking.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

640 lines
20 KiB

3 years ago
  1. #!/usr/bin/env python3
  2. """
  3. Utility functions to interact with the database.
  4. """
  5. import typing
  6. import time
  7. import logging
  8. import coloredlogs
  9. import pickle
  10. coloredlogs.install(
  11. level='DEBUG',
  12. fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
  13. )
  14. Asn = int
  15. Timestamp = int
  16. Level = int
  17. class Path():
  18. # FP add boolean here
  19. pass
  20. class RulePath(Path):
  21. def __str__(self) -> str:
  22. return '(rule)'
  23. class RuleFirstPath(RulePath):
  24. def __str__(self) -> str:
  25. return '(first-party rule)'
  26. class RuleMultiPath(RulePath):
  27. def __str__(self) -> str:
  28. return '(multi-party rule)'
  29. class DomainPath(Path):
  30. def __init__(self, parts: typing.List[str]):
  31. self.parts = parts
  32. def __str__(self) -> str:
  33. return '?.' + Database.unpack_domain(self)
  34. class HostnamePath(DomainPath):
  35. def __str__(self) -> str:
  36. return Database.unpack_domain(self)
  37. class ZonePath(DomainPath):
  38. def __str__(self) -> str:
  39. return '*.' + Database.unpack_domain(self)
  40. class AsnPath(Path):
  41. def __init__(self, asn: Asn):
  42. self.asn = asn
  43. def __str__(self) -> str:
  44. return Database.unpack_asn(self)
  45. class Ip4Path(Path):
  46. def __init__(self, value: int, prefixlen: int):
  47. self.value = value
  48. self.prefixlen = prefixlen
  49. def __str__(self) -> str:
  50. return Database.unpack_ip4network(self)
  51. class Match():
  52. def __init__(self) -> None:
  53. self.source: typing.Optional[Path] = None
  54. self.updated: int = 0
  55. # Cache
  56. self.level: int = 0
  57. self.first_party: bool = False
  58. self.references: int = 0
  59. def active(self, first_party: bool = None) -> bool:
  60. if self.updated == 0 or (first_party and not self.first_party):
  61. return False
  62. return True
  63. class AsnNode(Match):
  64. def __init__(self) -> None:
  65. Match.__init__(self)
  66. self.name = ''
  67. class DomainTreeNode():
  68. def __init__(self) -> None:
  69. self.children: typing.Dict[str, DomainTreeNode] = dict()
  70. self.match_zone = Match()
  71. self.match_hostname = Match()
  72. class IpTreeNode(Match):
  73. def __init__(self) -> None:
  74. Match.__init__(self)
  75. self.zero: typing.Optional[IpTreeNode] = None
  76. self.one: typing.Optional[IpTreeNode] = None
  77. Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
  78. MatchCallable = typing.Callable[[Path,
  79. Match],
  80. typing.Any]
  81. class Profiler():
  82. def __init__(self) -> None:
  83. self.log = logging.getLogger('profiler')
  84. self.time_last = time.perf_counter()
  85. self.time_step = 'init'
  86. self.time_dict: typing.Dict[str, float] = dict()
  87. self.step_dict: typing.Dict[str, int] = dict()
  88. def enter_step(self, name: str) -> None:
  89. now = time.perf_counter()
  90. try:
  91. self.time_dict[self.time_step] += now - self.time_last
  92. self.step_dict[self.time_step] += int(name != self.time_step)
  93. except KeyError:
  94. self.time_dict[self.time_step] = now - self.time_last
  95. self.step_dict[self.time_step] = 1
  96. self.time_step = name
  97. self.time_last = time.perf_counter()
  98. def profile(self) -> None:
  99. self.enter_step('profile')
  100. total = sum(self.time_dict.values())
  101. for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
  102. times = self.step_dict[key]
  103. self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
  104. f"= {secs:9.2f} s ({secs/total:7.2%}) ")
  105. self.log.debug(f"{'total':<20}: "
  106. f"{total:9.2f} s ({1:7.2%})")
  107. class Database(Profiler):
  108. VERSION = 17
  109. PATH = "blocking.p"
  110. def initialize(self) -> None:
  111. self.log.warning(
  112. "Creating database version: %d ",
  113. Database.VERSION)
  114. # Dummy match objects that everything refer to
  115. self.rules: typing.List[Match] = list()
  116. for first_party in (False, True):
  117. m = Match()
  118. m.updated = 1
  119. m.level = 0
  120. m.first_party = first_party
  121. self.rules.append(m)
  122. self.domtree = DomainTreeNode()
  123. self.asns: typing.Dict[Asn, AsnNode] = dict()
  124. self.ip4tree = IpTreeNode()
  125. def load(self) -> None:
  126. self.enter_step('load')
  127. try:
  128. with open(self.PATH, 'rb') as db_fdsec:
  129. version, data = pickle.load(db_fdsec)
  130. if version == Database.VERSION:
  131. self.rules, self.domtree, self.asns, self.ip4tree = data
  132. return
  133. self.log.warning(
  134. "Outdated database version found: %d, "
  135. "it will be rebuilt.",
  136. version)
  137. except (TypeError, AttributeError, EOFError):
  138. self.log.error(
  139. "Corrupt (or heavily outdated) database found, "
  140. "it will be rebuilt.")
  141. except FileNotFoundError:
  142. pass
  143. self.initialize()
  144. def save(self) -> None:
  145. self.enter_step('save')
  146. with open(self.PATH, 'wb') as db_fdsec:
  147. data = self.rules, self.domtree, self.asns, self.ip4tree
  148. pickle.dump((self.VERSION, data), db_fdsec)
  149. self.profile()
  150. def __init__(self) -> None:
  151. Profiler.__init__(self)
  152. self.log = logging.getLogger('db')
  153. self.load()
  154. @staticmethod
  155. def pack_domain(domain: str) -> DomainPath:
  156. return DomainPath(domain.split('.')[::-1])
  157. @staticmethod
  158. def unpack_domain(domain: DomainPath) -> str:
  159. return '.'.join(domain.parts[::-1])
  160. @staticmethod
  161. def pack_asn(asn: str) -> AsnPath:
  162. asn = asn.upper()
  163. if asn.startswith('AS'):
  164. asn = asn[2:]
  165. return AsnPath(int(asn))
  166. @staticmethod
  167. def unpack_asn(asn: AsnPath) -> str:
  168. return f'AS{asn.asn}'
  169. @staticmethod
  170. def pack_ip4address(address: str) -> Ip4Path:
  171. addr = 0
  172. for split in address.split('.'):
  173. addr = (addr << 8) + int(split)
  174. return Ip4Path(addr, 32)
  175. @staticmethod
  176. def unpack_ip4address(address: Ip4Path) -> str:
  177. addr = address.value
  178. assert address.prefixlen == 32
  179. octets: typing.List[int] = list()
  180. octets = [0] * 4
  181. for o in reversed(range(4)):
  182. octets[o] = addr & 0xFF
  183. addr >>= 8
  184. return '.'.join(map(str, octets))
  185. @staticmethod
  186. def pack_ip4network(network: str) -> Ip4Path:
  187. address, prefixlen_str = network.split('/')
  188. prefixlen = int(prefixlen_str)
  189. addr = Database.pack_ip4address(address)
  190. addr.prefixlen = prefixlen
  191. return addr
  192. @staticmethod
  193. def unpack_ip4network(network: Ip4Path) -> str:
  194. addr = network.value
  195. octets: typing.List[int] = list()
  196. octets = [0] * 4
  197. for o in reversed(range(4)):
  198. octets[o] = addr & 0xFF
  199. addr >>= 8
  200. return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
  201. def get_match(self, path: Path) -> Match:
  202. if isinstance(path, RuleMultiPath):
  203. return self.rules[0]
  204. elif isinstance(path, RuleFirstPath):
  205. return self.rules[1]
  206. elif isinstance(path, AsnPath):
  207. return self.asns[path.asn]
  208. elif isinstance(path, DomainPath):
  209. dicd = self.domtree
  210. for part in path.parts:
  211. dicd = dicd.children[part]
  212. if isinstance(path, HostnamePath):
  213. return dicd.match_hostname
  214. elif isinstance(path, ZonePath):
  215. return dicd.match_zone
  216. else:
  217. raise ValueError
  218. elif isinstance(path, Ip4Path):
  219. dici = self.ip4tree
  220. for i in range(31, 31-path.prefixlen, -1):
  221. bit = (path.value >> i) & 0b1
  222. dici_next = dici.one if bit else dici.zero
  223. if not dici_next:
  224. raise IndexError
  225. dici = dici_next
  226. return dici
  227. else:
  228. raise ValueError
  229. def exec_each_asn(self,
  230. callback: MatchCallable,
  231. ) -> typing.Any:
  232. for asn in self.asns:
  233. match = self.asns[asn]
  234. if match.active():
  235. c = callback(
  236. AsnPath(asn),
  237. match,
  238. )
  239. try:
  240. yield from c
  241. except TypeError: # not iterable
  242. pass
  243. def exec_each_domain(self,
  244. callback: MatchCallable,
  245. _dic: DomainTreeNode = None,
  246. _par: DomainPath = None,
  247. ) -> typing.Any:
  248. _dic = _dic or self.domtree
  249. _par = _par or DomainPath([])
  250. if _dic.match_hostname.active():
  251. c = callback(
  252. HostnamePath(_par.parts),
  253. _dic.match_hostname,
  254. )
  255. try:
  256. yield from c
  257. except TypeError: # not iterable
  258. pass
  259. if _dic.match_zone.active():
  260. c = callback(
  261. ZonePath(_par.parts),
  262. _dic.match_zone,
  263. )
  264. try:
  265. yield from c
  266. except TypeError: # not iterable
  267. pass
  268. for part in _dic.children:
  269. dic = _dic.children[part]
  270. yield from self.exec_each_domain(
  271. callback,
  272. _dic=dic,
  273. _par=DomainPath(_par.parts + [part])
  274. )
  275. def exec_each_ip4(self,
  276. callback: MatchCallable,
  277. _dic: IpTreeNode = None,
  278. _par: Ip4Path = None,
  279. ) -> typing.Any:
  280. _dic = _dic or self.ip4tree
  281. _par = _par or Ip4Path(0, 0)
  282. if _dic.active():
  283. c = callback(
  284. _par,
  285. _dic,
  286. )
  287. try:
  288. yield from c
  289. except TypeError: # not iterable
  290. pass
  291. # 0
  292. pref = _par.prefixlen + 1
  293. dic = _dic.zero
  294. if dic:
  295. addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
  296. assert addr0 == _par.value
  297. yield from self.exec_each_ip4(
  298. callback,
  299. _dic=dic,
  300. _par=Ip4Path(addr0, pref)
  301. )
  302. # 1
  303. dic = _dic.one
  304. if dic:
  305. addr1 = _par.value | (1 << (32-pref))
  306. yield from self.exec_each_ip4(
  307. callback,
  308. _dic=dic,
  309. _par=Ip4Path(addr1, pref)
  310. )
  311. def exec_each(self,
  312. callback: MatchCallable,
  313. ) -> typing.Any:
  314. yield from self.exec_each_domain(callback)
  315. yield from self.exec_each_ip4(callback)
  316. yield from self.exec_each_asn(callback)
  317. def update_references(self) -> None:
  318. # Should be correctly calculated normally,
  319. # keeping this just in case
  320. def reset_references_cb(path: Path,
  321. match: Match
  322. ) -> None:
  323. match.references = 0
  324. for _ in self.exec_each(reset_references_cb):
  325. pass
  326. def increment_references_cb(path: Path,
  327. match: Match
  328. ) -> None:
  329. if match.source:
  330. source = self.get_match(match.source)
  331. source.references += 1
  332. for _ in self.exec_each(increment_references_cb):
  333. pass
  334. def prune(self, before: int, base_only: bool = False) -> None:
  335. raise NotImplementedError
  336. def explain(self, path: Path) -> str:
  337. match = self.get_match(path)
  338. if isinstance(match, AsnNode):
  339. string = f'{path} ({match.name}) #{match.references}'
  340. else:
  341. string = f'{path} #{match.references}'
  342. if match.source:
  343. string += f' ← {self.explain(match.source)}'
  344. return string
  345. def export(self,
  346. first_party_only: bool = False,
  347. end_chain_only: bool = False,
  348. explain: bool = False,
  349. ) -> typing.Iterable[str]:
  350. def export_cb(path: Path, match: Match
  351. ) -> typing.Iterable[str]:
  352. assert isinstance(path, DomainPath)
  353. if not isinstance(path, HostnamePath):
  354. return
  355. if first_party_only and not match.first_party:
  356. return
  357. if end_chain_only and match.references > 0:
  358. return
  359. if explain:
  360. yield self.explain(path)
  361. else:
  362. yield self.unpack_domain(path)
  363. yield from self.exec_each_domain(export_cb)
  364. def list_rules(self,
  365. first_party_only: bool = False,
  366. ) -> typing.Iterable[str]:
  367. def list_rules_cb(path: Path, match: Match
  368. ) -> typing.Iterable[str]:
  369. if first_party_only and not match.first_party:
  370. return
  371. if isinstance(path, ZonePath) \
  372. or (isinstance(path, Ip4Path) and path.prefixlen < 32):
  373. # if match.level == 1:
  374. # It should be the latter condition but it is more
  375. # useful when using the former
  376. yield self.explain(path)
  377. yield from self.exec_each(list_rules_cb)
  378. def count_records(self,
  379. first_party_only: bool = False,
  380. rules_only: bool = False,
  381. ) -> str:
  382. memo: typing.Dict[str, int] = dict()
  383. def count_records_cb(path: Path, match: Match) -> None:
  384. if first_party_only and not match.first_party:
  385. return
  386. # if isinstance(path, ZonePath) \
  387. # or (isinstance(path, Ip4Path) and path.prefixlen < 32):
  388. if rules_only and match.level > 1:
  389. return
  390. try:
  391. memo[path.__class__.__name__] += 1
  392. except KeyError:
  393. memo[path.__class__.__name__] = 1
  394. for _ in self.exec_each(count_records_cb):
  395. pass
  396. split: typing.List[str] = list()
  397. for key, value in sorted(memo.items(), key=lambda s: s[0]):
  398. split.append(f'{key[:-4]}: {value}')
  399. return ', '.join(split)
  400. def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
  401. self.enter_step('get_domain_pack')
  402. domain = self.pack_domain(domain_str)
  403. self.enter_step('get_domain_brws')
  404. dic = self.domtree
  405. depth = 0
  406. for part in domain.parts:
  407. if dic.match_zone.active():
  408. self.enter_step('get_domain_yield')
  409. yield ZonePath(domain.parts[:depth])
  410. self.enter_step('get_domain_brws')
  411. if part not in dic.children:
  412. return
  413. dic = dic.children[part]
  414. depth += 1
  415. if dic.match_zone.active():
  416. self.enter_step('get_domain_yield')
  417. yield ZonePath(domain.parts)
  418. if dic.match_hostname.active():
  419. self.enter_step('get_domain_yield')
  420. yield HostnamePath(domain.parts)
  421. def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
  422. self.enter_step('get_ip4_pack')
  423. ip4 = self.pack_ip4address(ip4_str)
  424. self.enter_step('get_ip4_brws')
  425. dic = self.ip4tree
  426. for i in range(31, 31-ip4.prefixlen, -1):
  427. bit = (ip4.value >> i) & 0b1
  428. if dic.active():
  429. self.enter_step('get_ip4_yield')
  430. yield Ip4Path(ip4.value >> (i+1) << (i+1), 31-i)
  431. self.enter_step('get_ip4_brws')
  432. next_dic = dic.one if bit else dic.zero
  433. if next_dic is None:
  434. return
  435. dic = next_dic
  436. if dic.active():
  437. self.enter_step('get_ip4_yield')
  438. yield ip4
  439. def _set_match(self,
  440. match: Match,
  441. updated: int,
  442. source: Path,
  443. source_match: Match = None,
  444. ) -> None:
  445. # source_match is in parameters because most of the time
  446. # its parent function needs it too,
  447. # so it can pass it to save a traversal
  448. source_match = source_match or self.get_match(source)
  449. new_level = source_match.level + 1
  450. if updated > match.updated or new_level < match.level \
  451. or source_match.first_party > match.first_party:
  452. # NOTE FP and level of matches referencing this one
  453. # won't be updated until run or prune
  454. if match.source:
  455. old_source = self.get_match(match.source)
  456. old_source.references -= 1
  457. match.updated = updated
  458. match.level = new_level
  459. match.first_party = source_match.first_party
  460. match.source = source
  461. source_match.references += 1
  462. def _set_domain(self,
  463. hostname: bool,
  464. domain_str: str,
  465. updated: int,
  466. source: Path) -> None:
  467. self.enter_step('set_domain_pack')
  468. domain = self.pack_domain(domain_str)
  469. self.enter_step('set_domain_fp')
  470. source_match = self.get_match(source)
  471. is_first_party = source_match.first_party
  472. self.enter_step('set_domain_brws')
  473. dic = self.domtree
  474. for part in domain.parts:
  475. if part not in dic.children:
  476. dic.children[part] = DomainTreeNode()
  477. dic = dic.children[part]
  478. if dic.match_zone.active(is_first_party):
  479. # Refuse to add domain whose zone is already matching
  480. return
  481. if hostname:
  482. match = dic.match_hostname
  483. else:
  484. match = dic.match_zone
  485. self._set_match(
  486. match,
  487. updated,
  488. source,
  489. source_match=source_match,
  490. )
  491. def set_hostname(self,
  492. *args: typing.Any, **kwargs: typing.Any
  493. ) -> None:
  494. self._set_domain(True, *args, **kwargs)
  495. def set_zone(self,
  496. *args: typing.Any, **kwargs: typing.Any
  497. ) -> None:
  498. self._set_domain(False, *args, **kwargs)
  499. def set_asn(self,
  500. asn_str: str,
  501. updated: int,
  502. source: Path) -> None:
  503. self.enter_step('set_asn')
  504. path = self.pack_asn(asn_str)
  505. if path.asn in self.asns:
  506. match = self.asns[path.asn]
  507. else:
  508. match = AsnNode()
  509. self.asns[path.asn] = match
  510. self._set_match(
  511. match,
  512. updated,
  513. source,
  514. )
  515. def _set_ip4(self,
  516. ip4: Ip4Path,
  517. updated: int,
  518. source: Path) -> None:
  519. self.enter_step('set_ip4_fp')
  520. source_match = self.get_match(source)
  521. is_first_party = source_match.first_party
  522. self.enter_step('set_ip4_brws')
  523. dic = self.ip4tree
  524. for i in range(31, 31-ip4.prefixlen, -1):
  525. bit = (ip4.value >> i) & 0b1
  526. next_dic = dic.one if bit else dic.zero
  527. if next_dic is None:
  528. next_dic = IpTreeNode()
  529. if bit:
  530. dic.one = next_dic
  531. else:
  532. dic.zero = next_dic
  533. dic = next_dic
  534. if dic.active(is_first_party):
  535. # Refuse to add ip4* whose network is already matching
  536. return
  537. self._set_match(
  538. dic,
  539. updated,
  540. source,
  541. source_match=source_match,
  542. )
  543. def set_ip4address(self,
  544. ip4address_str: str,
  545. *args: typing.Any, **kwargs: typing.Any
  546. ) -> None:
  547. self.enter_step('set_ip4add_pack')
  548. ip4 = self.pack_ip4address(ip4address_str)
  549. self._set_ip4(ip4, *args, **kwargs)
  550. def set_ip4network(self,
  551. ip4network_str: str,
  552. *args: typing.Any, **kwargs: typing.Any
  553. ) -> None:
  554. self.enter_step('set_ip4net_pack')
  555. ip4 = self.pack_ip4network(ip4network_str)
  556. self._set_ip4(ip4, *args, **kwargs)