Compare commits


No commits in common. "a0e68f08487e333c39b5056ed24eb925cb3ff3c5" and "5023b85d7ca802f2908526b0e48ac5245aa457df" have entirely different histories.

13 changed files with 897 additions and 724 deletions

.gitignore vendored (5 changes)

@@ -1,4 +1,7 @@
 *.log
-*.p
+*.db
+*.db-journal
 nameservers
 nameservers.head
+*.o
+*.so

database.py Normal file → Executable file (831 changes)

@@ -4,115 +4,111 @@
 Utility functions to interact with the database.
 """
 
+import sqlite3
 import typing
 import time
+import os
 import logging
+import argparse
 import coloredlogs
-import pickle
+import ipaddress
+import math
 
 coloredlogs.install(
     level='DEBUG',
     fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
 )
 
-Asn = int
-Timestamp = int
-Level = int
-
-class Path():
-    # FP add boolean here
-    pass
-
-class RulePath(Path):
-    pass
-
-class DomainPath(Path):
-    def __init__(self, path: typing.List[str]):
-        self.path = path
-
-class HostnamePath(DomainPath):
-    pass
-
-class ZonePath(DomainPath):
-    pass
-
-class AsnPath(Path):
-    def __init__(self, asn: Asn):
-        self.asn = asn
-
-class Ip4Path(Path):
-    def __init__(self, value: int, prefixlen: int):
-        self.value = value
-        self.prefixlen = prefixlen
-
-class Match():
-    def __init__(self) -> None:
-        self.updated: int = 0
-        self.level: int = 0
-        self.source: Path = RulePath()
-
-    # FP dupplicate args
-    def set(self,
-            updated: int,
-            level: int,
-            source: Path,
-            ) -> None:
-        if updated > self.updated or level > self.level:
-            self.updated = updated
-            self.level = level
-            self.source = source
-
-    # FP dupplicate function
-    def active(self) -> bool:
-        return self.updated > 0
-
-class AsnNode(Match):
-    pass
-
-class DomainTreeNode():
-    def __init__(self) -> None:
-        self.children: typing.Dict[str, DomainTreeNode] = dict()
-        self.match_zone = Match()
-        self.match_hostname = Match()
-
-class IpTreeNode():
-    def __init__(self) -> None:
-        self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None]
-        self.match = Match()
-
-Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
-NodeCallable = typing.Callable[[Path,
-                                Node,
-                                typing.Optional[typing.Any]],
-                               typing.Any]
+DbValue = typing.Union[None, int, float, str, bytes]
+
+class Database():
+    VERSION = 5
+    PATH = "blocking.db"
+
+    def open(self) -> None:
+        mode = 'rwc' if self.write else 'ro'
+        uri = f'file:{self.PATH}?mode={mode}'
+        self.conn = sqlite3.connect(uri, uri=True)
+        cursor = self.conn.cursor()
+        cursor.execute("PRAGMA foreign_keys = ON")
+        self.conn.create_function("unpack_asn", 1,
+                                  self.unpack_asn,
+                                  deterministic=True)
+        self.conn.create_function("unpack_ip4address", 1,
+                                  self.unpack_ip4address,
+                                  deterministic=True)
+        self.conn.create_function("unpack_ip4network", 2,
+                                  self.unpack_ip4network,
+                                  deterministic=True)
+        self.conn.create_function("unpack_domain", 1,
+                                  lambda s: s[:-1][::-1],
+                                  deterministic=True)
+        self.conn.create_function("format_zone", 1,
+                                  lambda s: '*' + s[::-1],
+                                  deterministic=True)
+
+    def get_meta(self, key: str) -> typing.Optional[int]:
+        cursor = self.conn.cursor()
+        try:
+            cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
+        except sqlite3.OperationalError:
+            return None
+        for ver, in cursor:
+            return ver
+        return None
+
+    def set_meta(self, key: str, val: int) -> None:
+        cursor = self.conn.cursor()
+        cursor.execute("INSERT INTO meta VALUES (?, ?) "
+                       "ON CONFLICT (key) DO "
+                       "UPDATE set value=?",
+                       (key, val, val))
+
+    def close(self) -> None:
+        self.enter_step('close_commit')
+        self.conn.commit()
+        self.enter_step('close')
+        self.conn.close()
+        self.profile()
+
+    def initialize(self) -> None:
+        self.close()
+        self.enter_step('initialize')
+        if not self.write:
+            self.log.error("Cannot initialize in read-only mode.")
+            raise
+        os.unlink(self.PATH)
+        self.open()
+        self.log.info("Creating database version %d.", self.VERSION)
+        cursor = self.conn.cursor()
+        with open("database_schema.sql", 'r') as db_schema:
+            cursor.executescript(db_schema.read())
+        self.set_meta('version', self.VERSION)
+        self.conn.commit()
 
-class Profiler():
-    def __init__(self) -> None:
-        self.log = logging.getLogger('profiler')
+    def __init__(self, write: bool = False) -> None:
+        self.log = logging.getLogger('db')
         self.time_last = time.perf_counter()
         self.time_step = 'init'
         self.time_dict: typing.Dict[str, float] = dict()
         self.step_dict: typing.Dict[str, int] = dict()
+        self.write = write
+
+        self.open()
+        version = self.get_meta('version')
+        if version != self.VERSION:
+            if version is not None:
+                self.log.warning(
+                    "Outdated database version: %d found, will be rebuilt.",
+                    version)
+            self.initialize()
 
     def enter_step(self, name: str) -> None:
-        return
         now = time.perf_counter()
         try:
             self.time_dict[self.time_step] += now - self.time_last
-            self.step_dict[self.time_step] += int(name != self.time_step)
+            self.step_dict[self.time_step] += 1
         except KeyError:
             self.time_dict[self.time_step] = now - self.time_last
             self.step_dict[self.time_step] = 1
@@ -129,334 +125,435 @@ class Profiler():
         self.log.debug(f"{'total':<20}: "
                        f"{total:9.2f} s ({1:7.2%})")
 
-class Database(Profiler):
-    VERSION = 10
-    PATH = "blocking.p"
-
-    def initialize(self) -> None:
-        self.log.warning(
-            "Creating database version: %d ",
-            Database.VERSION)
-        self.domtree = DomainTreeNode()
-        self.asns: typing.Dict[Asn, AsnNode] = dict()
-        self.ip4tree = IpTreeNode()
-
-    def load(self) -> None:
-        self.enter_step('load')
-        try:
-            with open(self.PATH, 'rb') as db_fdsec:
-                version, data = pickle.load(db_fdsec)
-            if version == Database.VERSION:
-                self.domtree, self.asns, self.ip4tree = data
-                return
-            self.log.warning(
-                "Outdated database version found: %d, "
-                "it will be rebuilt.",
-                version)
-        except (TypeError, AttributeError, EOFError):
-            self.log.error(
-                "Corrupt (or heavily outdated) database found, "
-                "it will be rebuilt.")
-        except FileNotFoundError:
-            pass
-        self.initialize()
-
-    def save(self) -> None:
-        self.enter_step('save')
-        with open(self.PATH, 'wb') as db_fdsec:
-            data = self.domtree, self.asns, self.ip4tree
-            pickle.dump((self.VERSION, data), db_fdsec)
-        self.profile()
-
-    def __init__(self) -> None:
-        Profiler.__init__(self)
-        self.log = logging.getLogger('db')
-        self.load()
+    @staticmethod
+    def pack_hostname(hostname: str) -> str:
+        return hostname[::-1] + '.'
 
     @staticmethod
-    def pack_domain(domain: str) -> DomainPath:
-        return DomainPath(domain.split('.')[::-1])
+    def pack_zone(zone: str) -> str:
+        return Database.pack_hostname(zone)
 
     @staticmethod
-    def unpack_domain(domain: DomainPath) -> str:
-        return '.'.join(domain.path[::-1])
-
-    @staticmethod
-    def pack_asn(asn: str) -> AsnPath:
+    def pack_asn(asn: str) -> int:
         asn = asn.upper()
         if asn.startswith('AS'):
             asn = asn[2:]
-        return AsnPath(int(asn))
+        return int(asn)
 
     @staticmethod
-    def unpack_asn(asn: AsnPath) -> str:
-        return f'AS{asn.asn}'
+    def unpack_asn(asn: int) -> str:
+        return f'AS{asn}'
 
     @staticmethod
-    def pack_ip4address(address: str) -> Ip4Path:
-        addr = 0
-        for split in address.split('.'):
-            addr = (addr << 8) + int(split)
-        return Ip4Path(addr, 32)
-        # return '{:02x}{:02x}{:02x}{:02x}'.format(
-        #     *[int(c) for c in address.split('.')])
-        # return base64.b16encode(packed).decode()
-        # return '{:08b}{:08b}{:08b}{:08b}'.format(
-        #     *[int(c) for c in address.split('.')])
-        # carg = ctypes.c_wchar_p(address)
-        # ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
-        # if ret != 0:
-        #     raise ValueError
-        # return self.accel_ip4_buf.value
-        # packed = ipaddress.ip_address(address).packed
-        # return packed
+    def pack_ip4address(address: str) -> int:
+        total = 0
+        for i, octet in enumerate(address.split('.')):
+            total += int(octet) << (3-i)*8
+        return total
 
     @staticmethod
-    def unpack_ip4address(address: Ip4Path) -> str:
-        addr = address.value
-        assert address.prefixlen == 32
-        octets: typing.List[int] = list()
-        octets = [0] * 4
-        for o in reversed(range(4)):
-            octets[o] = addr & 0xFF
-            addr >>= 8
-        return '.'.join(map(str, octets))
+    def unpack_ip4address(address: int) -> str:
+        return '.'.join(str((address >> (i * 8)) & 0xFF)
+                        for i in reversed(range(4)))
 
     @staticmethod
-    def pack_ip4network(network: str) -> Ip4Path:
-        address, prefixlen_str = network.split('/')
-        prefixlen = int(prefixlen_str)
-        addr = Database.pack_ip4address(address)
-        addr.prefixlen = prefixlen
-        return addr
+    def pack_ip4network(network: str) -> typing.Tuple[int, int]:
+        # def pack_ip4network(network: str) -> str:
+        net = ipaddress.ip_network(network)
+        mini = Database.pack_ip4address(net.network_address.exploded)
+        maxi = Database.pack_ip4address(net.broadcast_address.exploded)
+        # mini = net.network_address.packed
+        # maxi = net.broadcast_address.packed
+        return mini, maxi
+        # return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
 
     @staticmethod
-    def unpack_ip4network(network: Ip4Path) -> str:
-        addr = network.value
-        octets: typing.List[int] = list()
-        octets = [0] * 4
-        for o in reversed(range(4)):
-            octets[o] = addr & 0xFF
-            addr >>= 8
-        return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
-
-    def exec_each_domain(self,
-                         callback: NodeCallable,
-                         arg: typing.Any = None,
-                         _dic: DomainTreeNode = None,
-                         _par: DomainPath = None,
-                         ) -> typing.Any:
-        _dic = _dic or self.domtree
-        _par = _par or DomainPath([])
-        yield from callback(_par, _dic, arg)
-        for part in _dic.children:
-            dic = _dic.children[part]
-            yield from self.exec_each_domain(
-                callback,
-                arg,
-                _dic=dic,
-                _par=DomainPath(_par.path + [part])
-            )
-
-    def exec_each_ip4(self,
-                      callback: NodeCallable,
-                      arg: typing.Any = None,
-                      _dic: IpTreeNode = None,
-                      _par: Ip4Path = None,
-                      ) -> typing.Any:
-        _dic = _dic or self.ip4tree
-        _par = _par or Ip4Path(0, 0)
-        callback(_par, _dic, arg)
-        # 0
-        dic = _dic.children[0]
-        if dic:
-            addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen)))
-            assert addr0 == _par.value
-            yield from self.exec_each_ip4(
-                callback,
-                arg,
-                _dic=dic,
-                _par=Ip4Path(addr0, _par.prefixlen+1)
-            )
-        # 1
-        dic = _dic.children[1]
-        if dic:
-            addr1 = _par.value | (1 << (32-_par.prefixlen))
-            yield from self.exec_each_ip4(
-                callback,
-                arg,
-                _dic=dic,
-                _par=Ip4Path(addr1, _par.prefixlen+1)
-            )
-
-    def exec_each(self,
-                  callback: NodeCallable,
-                  arg: typing.Any = None,
-                  ) -> typing.Any:
-        yield from self.exec_each_domain(callback)
-        yield from self.exec_each_ip4(callback)
+    def unpack_ip4network(mini: int, maxi: int) -> str:
+        addr = Database.unpack_ip4address(mini)
+        prefixlen = 32-int(math.log2(maxi-mini+1))
+        return f'{addr}/{prefixlen}'
 
     def update_references(self) -> None:
-        raise NotImplementedError
+        self.enter_step('update_refs')
+        cursor = self.conn.cursor()
+        cursor.execute('UPDATE rules AS r SET refs='
+                       '(SELECT count(*) FROM rules '
+                       'WHERE source=r.id)')
 
-    def prune(self, before: int, base_only: bool = False) -> None:
-        raise NotImplementedError
+    def prune(self, before: int) -> None:
+        self.enter_step('prune')
+        cursor = self.conn.cursor()
+        cursor.execute('DELETE FROM rules WHERE updated<?', (before,))
 
     def explain(self, entry: int) -> str:
-        raise NotImplementedError
+        # Format current
+        string = '???'
+        cursor = self.conn.cursor()
+        cursor.execute(
+            'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
+            'UNION '
+            'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
+            'UNION '
+            'SELECT format_zone(val) FROM zone WHERE entry=:entry '
+            'UNION '
+            'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
+            'UNION '
+            'SELECT unpack_ip4network(mini, maxi) '
+            'FROM ip4network WHERE entry=:entry ',
+            {"entry": entry}
+        )
+        for val, in cursor:  # only one
+            string = str(val)
+        string += f' #{entry}'
+        # Add source if any
+        cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
+        for source, in cursor:
+            if source:
+                string += f'{self.explain(source)}'
+        return string
 
     def export(self,
                first_party_only: bool = False,
                end_chain_only: bool = False,
               explain: bool = False,
               ) -> typing.Iterable[str]:
-        if first_party_only or end_chain_only or explain:
-            raise NotImplementedError
-
-        def export_cb(path: Path, node: Node, _: typing.Any
-                      ) -> typing.Iterable[str]:
-            assert isinstance(path, DomainPath)
-            assert isinstance(node, DomainTreeNode)
-            if node.match_hostname:
-                a = self.unpack_domain(path)
-                yield a
-
-        yield from self.exec_each_domain(export_cb, None)
+        selection = 'entry' if explain else 'unpack_domain(val)'
+        command = f'SELECT {selection} FROM rules ' \
+            'INNER JOIN hostname ON rules.id = hostname.entry'
+        restrictions: typing.List[str] = list()
+        if first_party_only:
+            restrictions.append('rules.first_party = 1')
+        if end_chain_only:
+            restrictions.append('rules.refs = 0')
+        if restrictions:
+            command += ' WHERE ' + ' AND '.join(restrictions)
+        if not explain:
+            command += ' ORDER BY unpack_domain(val) ASC'
+        cursor = self.conn.cursor()
+        cursor.execute(command)
+        for val, in cursor:
+            if explain:
+                yield self.explain(val)
+            else:
+                yield val
 
     def count_rules(self,
                     first_party_only: bool = False,
                     ) -> str:
-        raise NotImplementedError
+        counts: typing.List[str] = list()
+        cursor = self.conn.cursor()
+        for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
+            command = f'SELECT count(*) FROM rules ' \
+                f'INNER JOIN {table} ON rules.id = {table}.entry ' \
+                'WHERE rules.level = 0'
+            if first_party_only:
+                command += ' AND first_party=1'
+            cursor.execute(command)
+            count, = cursor.fetchone()
+            if count > 0:
+                counts.append(f'{table}: {count}')
+        return ', '.join(counts)
 
-    def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
-        self.enter_step('get_domain_pack')
-        domain = self.pack_domain(domain_str)
-        self.enter_step('get_domain_brws')
-        dic = self.domtree
-        depth = 0
-        for part in domain.path:
-            if dic.match_zone.active():
-                self.enter_step('get_domain_yield')
-                yield ZonePath(domain.path[:depth])
-                self.enter_step('get_domain_brws')
-            if part not in dic.children:
-                return
-            dic = dic.children[part]
-            depth += 1
-        if dic.match_zone.active():
-            self.enter_step('get_domain_yield')
-            yield ZonePath(domain.path)
-        if dic.match_hostname.active():
-            self.enter_step('get_domain_yield')
-            yield HostnamePath(domain.path)
-
-    def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
-        self.enter_step('get_ip4_pack')
-        ip4 = self.pack_ip4address(ip4_str)
-        self.enter_step('get_ip4_brws')
-        dic = self.ip4tree
-        for i in reversed(range(ip4.prefixlen)):
-            part = (ip4.value >> i) & 0b1
-            if dic.match.active():
-                self.enter_step('get_ip4_yield')
-                yield Ip4Path(ip4.value, 32-i)
-                self.enter_step('get_ip4_brws')
-            next_dic = dic.children[part]
-            if next_dic is None:
-                return
-            dic = next_dic
-        if dic.match.active():
-            self.enter_step('get_ip4_yield')
-            yield ip4
-
-    def list_asn(self) -> typing.Iterable[AsnPath]:
-        for asn in self.asns:
-            yield AsnPath(asn)
+    def get_domain(self, domain: str) -> typing.Iterable[int]:
+        self.enter_step('get_domain_prepare')
+        domain_prep = self.pack_hostname(domain)
+        cursor = self.conn.cursor()
+        self.enter_step('get_domain_select')
+        cursor.execute(
+            'SELECT null, entry FROM hostname '
+            'WHERE val=:d '
+            'UNION '
+            'SELECT * FROM ('
+            'SELECT val, entry FROM zone '
+            'WHERE val<=:d '
+            'ORDER BY val DESC LIMIT 1'
+            ')',
+            {'d': domain_prep}
+        )
+        for val, entry in cursor:
+            self.enter_step('get_domain_confirm')
+            if not (val is None or domain_prep.startswith(val)):
+                continue
+            self.enter_step('get_domain_yield')
+            yield entry
+
+    def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
+        self.enter_step('get_domainiz_prepare')
+        domain_prep = self.pack_hostname(domain)
+        cursor = self.conn.cursor()
+        self.enter_step('get_domainiz_select')
+        cursor.execute(
+            'SELECT val, entry FROM zone '
+            'WHERE val<=:d '
+            'ORDER BY val DESC LIMIT 1',
+            {'d': domain_prep}
+        )
+        for val, entry in cursor:
+            self.enter_step('get_domainiz_confirm')
+            if not (val is None or domain_prep.startswith(val)):
+                continue
+            self.enter_step('get_domainiz_yield')
+            yield entry
+
+    def get_ip4(self, address: str) -> typing.Iterable[int]:
+        self.enter_step('get_ip4_prepare')
+        try:
+            address_prep = self.pack_ip4address(address)
+        except (ValueError, IndexError):
+            self.log.error("Invalid ip4address: %s", address)
+            return
+        cursor = self.conn.cursor()
+        self.enter_step('get_ip4_select')
+        cursor.execute(
+            'SELECT entry FROM ip4address '
+            # 'SELECT null, entry FROM ip4address '
+            'WHERE val=:a '
+            'UNION '
+            # 'SELECT * FROM ('
+            # 'SELECT val, entry FROM ip4network '
+            # 'WHERE val<=:a '
+            # 'AND instr(:a, val) > 0 '
+            # 'ORDER BY val DESC'
+            # ')'
+            'SELECT entry FROM ip4network '
+            'WHERE :a BETWEEN mini AND maxi ',
+            {'a': address_prep}
+        )
+        for entry, in cursor:
+            # self.enter_step('get_ip4_confirm')
+            # if not (val is None or val.startswith(address_prep)):
+            #     # PERF startswith but from the end
+            #     continue
+            self.enter_step('get_ip4_yield')
+            yield entry
+
+    def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
+        self.enter_step('get_ip4in_prepare')
+        try:
+            address_prep = self.pack_ip4address(address)
+        except (ValueError, IndexError):
+            self.log.error("Invalid ip4address: %s", address)
+            return
+        cursor = self.conn.cursor()
+        self.enter_step('get_ip4in_select')
+        cursor.execute(
+            'SELECT entry FROM ip4network '
+            'WHERE :a BETWEEN mini AND maxi ',
+            {'a': address_prep}
+        )
+        for entry, in cursor:
+            self.enter_step('get_ip4in_yield')
+            yield entry
+
+    def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
+        cursor = self.conn.cursor()
+        self.enter_step('list_asn_select')
+        cursor.execute('SELECT val, entry FROM asn')
+        for val, entry in cursor:
+            yield f'AS{val}', entry
 
-    def _set_domain(self,
-                    hostname: bool,
-                    domain_str: str,
-                    updated: int,
-                    is_first_party: bool = None,
-                    source: Path = None) -> None:
-        self.enter_step('set_domain_pack')
-        if is_first_party:
-            raise NotImplementedError
-        domain = self.pack_domain(domain_str)
-        self.enter_step('set_domain_brws')
-        dic = self.domtree
-        for part in domain.path:
-            if dic.match_zone.active():
-                # Refuse to add domain whose zone is already matching
-                return
-            if part not in dic.children:
-                dic.children[part] = DomainTreeNode()
-            dic = dic.children[part]
-        if hostname:
-            match = dic.match_hostname
-        else:
-            match = dic.match_zone
-        match.set(
-            updated,
-            0,  # TODO Level
-            source or RulePath(),
-        )
-
-    def set_hostname(self,
-                     *args: typing.Any, **kwargs: typing.Any
-                     ) -> None:
-        self._set_domain(True, *args, **kwargs)
-
-    def set_zone(self,
-                 *args: typing.Any, **kwargs: typing.Any
-                 ) -> None:
-        self._set_domain(False, *args, **kwargs)
-
-    def set_asn(self,
-                asn_str: str,
-                updated: int,
-                is_first_party: bool = None,
-                source: Path = None) -> None:
-        self.enter_step('set_asn')
-        if is_first_party:
-            raise NotImplementedError
-        path = self.pack_asn(asn_str)
-        match = AsnNode()
-        match.set(
-            updated,
-            0,
-            source or RulePath()
-        )
-        self.asns[path.asn] = match
-
-    def _set_ip4(self,
-                 ip4: Ip4Path,
-                 updated: int,
-                 is_first_party: bool = None,
-                 source: Path = None) -> None:
-        if is_first_party:
-            raise NotImplementedError
-        dic = self.ip4tree
-        for i in reversed(range(ip4.prefixlen)):
-            part = (ip4.value >> i) & 0b1
-            if dic.match.active():
-                # Refuse to add ip4* whose network is already matching
-                return
-            next_dic = dic.children[part]
-            if next_dic is None:
-                next_dic = IpTreeNode()
-                dic.children[part] = next_dic
-            dic = next_dic
-        dic.match.set(
-            updated,
-            0,  # TODO Level
-            source or RulePath(),
-        )
-
-    def set_ip4address(self,
-                       ip4address_str: str,
-                       *args: typing.Any, **kwargs: typing.Any
-                       ) -> None:
-        self.enter_step('set_ip4add_pack')
-        ip4 = self.pack_ip4address(ip4address_str)
-        self.enter_step('set_ip4add_brws')
-        self._set_ip4(ip4, *args, **kwargs)
-
-    def set_ip4network(self,
-                       ip4network_str: str,
-                       *args: typing.Any, **kwargs: typing.Any
-                       ) -> None:
-        self.enter_step('set_ip4net_pack')
-        ip4 = self.pack_ip4network(ip4network_str)
-        self.enter_step('set_ip4net_brws')
-        self._set_ip4(ip4, *args, **kwargs)
+    def _set_generic(self,
+                     table: str,
+                     select_query: str,
+                     insert_query: str,
+                     prep: typing.Dict[str, DbValue],
+                     updated: int,
+                     is_first_party: bool = False,
+                     source: int = None,
+                     ) -> None:
+        # Since this isn't the bulk of the processing,
+        # here abstraction > performaces
+        # Fields based on the source
+        self.enter_step(f'set_{table}_prepare')
+        cursor = self.conn.cursor()
+        if source is None:
+            first_party = int(is_first_party)
+            level = 0
+        else:
+            self.enter_step(f'set_{table}_source')
+            cursor.execute(
+                'SELECT first_party, level FROM rules '
+                'WHERE id=?',
+                (source,)
+            )
+            first_party, level = cursor.fetchone()
+            level += 1
+        self.enter_step(f'set_{table}_select')
+        cursor.execute(select_query, prep)
+
+        rules_prep: typing.Dict[str, DbValue] = {
+            "source": source,
+            "updated": updated,
+            "first_party": first_party,
+            "level": level,
+        }
+
+        # If the entry already exists
+        for entry, in cursor:  # only one
+            self.enter_step(f'set_{table}_update')
+            rules_prep['entry'] = entry
+            cursor.execute(
+                'UPDATE rules SET '
+                'source=:source, updated=:updated, '
+                'first_party=:first_party, level=:level '
+                'WHERE id=:entry AND (updated<:updated OR '
+                'first_party<:first_party OR level<:level)',
+                rules_prep
+            )
+            # Only update if any of the following:
+            # - the entry is outdataed
+            # - the entry was not a first_party but this is
+            # - this is closer to the original rule
+            return
+
+        # If it does not exist
+        self.enter_step(f'set_{table}_insert')
+        cursor.execute(
+            'INSERT INTO rules '
+            '(source, updated, first_party, level) '
+            'VALUES (:source, :updated, :first_party, :level) ',
+            rules_prep
+        )
+        cursor.execute('SELECT id FROM rules WHERE rowid=?',
+                       (cursor.lastrowid,))
+        for entry, in cursor:  # only one
+            prep['entry'] = entry
+            cursor.execute(insert_query, prep)
+            return
+        assert False
+
+    def set_hostname(self, hostname: str,
+                     *args: typing.Any, **kwargs: typing.Any) -> None:
+        self.enter_step('set_hostname_prepare')
+        prep: typing.Dict[str, DbValue] = {
+            'val': self.pack_hostname(hostname),
+        }
+        self._set_generic(
+            'hostname',
+            'SELECT entry FROM hostname WHERE val=:val',
+            'INSERT INTO hostname (val, entry) '
+            'VALUES (:val, :entry)',
+            prep,
+            *args, **kwargs
+        )
+
+    def set_asn(self, asn: str,
+                *args: typing.Any, **kwargs: typing.Any) -> None:
+        self.enter_step('set_asn_prepare')
+        try:
+            asn_prep = self.pack_asn(asn)
+        except ValueError:
+            self.log.error("Invalid asn: %s", asn)
+            return
+        prep: typing.Dict[str, DbValue] = {
+            'val': asn_prep,
+        }
+        self._set_generic(
+            'asn',
+            'SELECT entry FROM asn WHERE val=:val',
+            'INSERT INTO asn (val, entry) '
+            'VALUES (:val, :entry)',
+            prep,
+            *args, **kwargs
+        )
+
+    def set_ip4address(self, ip4address: str,
+                       *args: typing.Any, **kwargs: typing.Any) -> None:
+        self.enter_step('set_ip4add_prepare')
+        try:
+            ip4address_prep = self.pack_ip4address(ip4address)
+        except (ValueError, IndexError):
+            self.log.error("Invalid ip4address: %s", ip4address)
+            return
+        prep: typing.Dict[str, DbValue] = {
+            'val': ip4address_prep,
+        }
+        self._set_generic(
+            'ip4add',
+            'SELECT entry FROM ip4address WHERE val=:val',
+            'INSERT INTO ip4address (val, entry) '
+            'VALUES (:val, :entry)',
+            prep,
+            *args, **kwargs
+        )
+
+    def set_zone(self, zone: str,
+                 *args: typing.Any, **kwargs: typing.Any) -> None:
+        self.enter_step('set_zone_prepare')
+        prep: typing.Dict[str, DbValue] = {
+            'val': self.pack_zone(zone),
+        }
+        self._set_generic(
+            'zone',
+            'SELECT entry FROM zone WHERE val=:val',
+            'INSERT INTO zone (val, entry) '
+            'VALUES (:val, :entry)',
+            prep,
+            *args, **kwargs
+        )
+
+    def set_ip4network(self, ip4network: str,
+                       *args: typing.Any, **kwargs: typing.Any) -> None:
+        self.enter_step('set_ip4net_prepare')
+        try:
+            ip4network_prep = self.pack_ip4network(ip4network)
+        except (ValueError, IndexError):
+            self.log.error("Invalid ip4network: %s", ip4network)
+            return
+        prep: typing.Dict[str, DbValue] = {
+            'mini': ip4network_prep[0],
+            'maxi': ip4network_prep[1],
+        }
+        self._set_generic(
+            'ip4net',
+            'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
+            'INSERT INTO ip4network (mini, maxi, entry) '
+            'VALUES (:mini, :maxi, :entry)',
+            prep,
+            *args, **kwargs
+        )
+
+if __name__ == '__main__':
+
+    # Parsing arguments
+    parser = argparse.ArgumentParser(
+        description="Database operations")
+    parser.add_argument(
+        '-i', '--initialize', action='store_true',
+        help="Reconstruct the whole database")
+    parser.add_argument(
+        '-p', '--prune', action='store_true',
+        help="Remove old (+6 months) entries from database")
+    parser.add_argument(
+        '-r', '--references', action='store_true',
+        help="Update the reference count")
+    args = parser.parse_args()
+
+    DB = Database(write=True)
+
+    if args.initialize:
+        DB.initialize()
+    if args.prune:
+        DB.prune(before=int(time.time()) - 60*60*24*31*6)
+    if args.references and not args.prune:
+        DB.update_references()
+
+    DB.close()
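
For reference, a minimal standalone sketch of the integer encoding the new version relies on, assuming the big-endian octet packing of pack_ip4address and the mini/maxi range encoding consumed by unpack_ip4network (function bodies mirror the diff above; the sample addresses are invented):

import math


def pack_ip4address(address: str) -> int:
    # "192.0.2.1" -> 0xC0000201: each octet lands at its big-endian position
    total = 0
    for i, octet in enumerate(address.split('.')):
        total += int(octet) << (3 - i) * 8
    return total


def unpack_ip4network(mini: int, maxi: int) -> str:
    # A /p network spans 2**(32-p) addresses, so the prefix length
    # can be recovered from the size of the [mini, maxi] range.
    prefixlen = 32 - int(math.log2(maxi - mini + 1))
    addr = '.'.join(str((mini >> (i * 8)) & 0xFF) for i in reversed(range(4)))
    return f'{addr}/{prefixlen}'


assert pack_ip4address('192.0.2.1') == 0xC0000201
assert unpack_ip4network(pack_ip4address('10.0.0.0'),
                         pack_ip4address('10.0.0.255')) == '10.0.0.0/24'

Storing networks as a [mini, maxi] integer pair is what lets get_ip4 match an address with a plain BETWEEN query instead of prefix arithmetic.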

database_schema.sql Normal file (59 additions)

@@ -0,0 +1,59 @@
-- Remember to increment DB_VERSION
-- in database.py on changes to this file
CREATE TABLE rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source INTEGER, -- The rule this one is based on
updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes)
first_party INTEGER, -- 1: this blocks a first party for sure, 0: maybe
refs INTEGER, -- Number of entries issued from this one
level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX asn_entry ON asn (entry); -- for explainations
CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX hostname_entry ON hostname (entry); -- for explainations
CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX zone_entry ON zone (entry); -- for explainations
CREATE TABLE ip4address (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explainations
CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY,
mini INTEGER,
maxi INTEGER,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explainations
-- Store various things
CREATE TABLE meta (
key TEXT PRIMARY KEY,
value integer
);
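
As a rough illustration of how these tables fit together (a sketch only: the schema is the file above, assumed to be in the working directory, but the sample rule chain and values are made up):

import sqlite3

conn = sqlite3.connect(':memory:')
with open('database_schema.sql') as db_schema:
    conn.executescript(db_schema.read())
cur = conn.cursor()

# A root rule (level 0) and a hostname rule derived from it (level 1),
# chained through the rules.source self-reference.
cur.execute('INSERT INTO rules (source, updated, first_party, refs, level) '
            'VALUES (NULL, 1, 1, 1, 0)')
root = cur.lastrowid
cur.execute('INSERT INTO rules (source, updated, first_party, refs, level) '
            'VALUES (?, 1, 1, 0, 1)', (root,))
cur.execute('INSERT INTO hostname (val, entry) VALUES (?, ?)',
            ('moc.rekcart.', cur.lastrowid))  # "tracker.com" rev'd, plus dot

# Walk from the hostname back to its source rule.
cur.execute('SELECT h.val, r.level, r.source FROM hostname h '
            'JOIN rules r ON r.id = h.entry')
print(cur.fetchall())  # [('moc.rekcart.', 1, 1)]

Reversing hostnames before storage means zone matching becomes a string-prefix comparison, which is why the zone lookup in database.py can use an ordered "val<=:d ... LIMIT 1" query.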

@@ -45,3 +45,5 @@ if __name__ == '__main__':
             explain=args.explain,
             ):
         print(domain, file=args.output)
+
+    DB.close()

@@ -31,22 +31,23 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     DB = database.Database()
+    DBW = database.Database(write=True)
 
-    for path in DB.list_asn():
-        asn_str = database.Database.unpack_asn(path)
+    for asn, entry in DB.list_asn():
         DB.enter_step('asn_get_ranges')
-        for prefix in get_ranges(asn_str):
+        for prefix in get_ranges(asn):
             parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
             if parsed_prefix.version == 4:
-                DB.set_ip4network(
+                DBW.set_ip4network(
                     prefix,
-                    source=path,
+                    source=entry,
                     updated=int(time.time())
                 )
-                log.info('Added %s from %s (%s)', prefix, asn_str, path)
+                log.info('Added %s from %s (id=%s)', prefix, asn, entry)
             elif parsed_prefix.version == 6:
                 log.warning('Unimplemented prefix version: %s', prefix)
             else:
                 log.error('Unknown prefix version: %s', prefix)
 
-    DB.save()
+    DB.close()
+    DBW.close()

@@ -1,147 +0,0 @@
#!/usr/bin/env python3
import argparse
import database
import logging
import sys
import typing
import enum
RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
Record = typing.Tuple[RecordType, int, str, str]
# select, write
FUNCTION_MAP: typing.Any = {
RecordType.A: (
database.Database.get_ip4,
database.Database.set_hostname,
),
RecordType.CNAME: (
database.Database.get_domain,
database.Database.set_hostname,
),
RecordType.PTR: (
database.Database.get_domain,
database.Database.set_ip4address,
),
}
class Parser():
def __init__(self, buf: typing.Any) -> None:
self.buf = buf
self.log = logging.getLogger('parser')
self.db = database.Database()
def end(self) -> None:
self.db.save()
def register(self,
rtype: RecordType,
updated: int,
name: str,
value: str
) -> None:
self.db.enter_step('register')
select, write = FUNCTION_MAP[rtype]
for source in select(self.db, value):
# write(self.db, name, updated, source=source)
write(self.db, name, updated)
def consume(self) -> None:
raise NotImplementedError
class Rapid7Parser(Parser):
TYPES = {
'a': RecordType.A,
'aaaa': RecordType.AAAA,
'cname': RecordType.CNAME,
'ptr': RecordType.PTR,
}
def consume(self) -> None:
data = dict()
for line in self.buf:
self.db.enter_step('parse_rapid7')
split = line.split('"')
for k in range(1, 14, 4):
key = split[k]
val = split[k+2]
data[key] = val
self.register(
Rapid7Parser.TYPES[data['type']],
int(data['timestamp']),
data['name'],
data['value']
)
class DnsMassParser(Parser):
# dnsmass --output Snrql
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
TYPES = {
'A': (RecordType.A, -1, None),
'AAAA': (RecordType.AAAA, -1, None),
'CNAME': (RecordType.CNAME, -1, -1),
}
def consume(self) -> None:
self.db.enter_step('parse_dnsmass')
timestamp = 0
header = True
for line in self.buf:
line = line[:-1]
if not line:
header = True
continue
split = line.split(' ')
try:
if header:
timestamp = int(split[1])
header = False
else:
dtype, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]]
self.register(
dtype,
timestamp,
split[0][:name_offset],
split[2][:value_offset],
)
self.db.enter_step('parse_dnsmass')
except KeyError:
continue
PARSERS = {
'rapid7': Rapid7Parser,
'dnsmass': DnsMassParser,
}
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('feed_dns')
args_parser = argparse.ArgumentParser(
description="TODO")
args_parser.add_argument(
'parser',
choices=PARSERS.keys(),
help="TODO")
args_parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
args = args_parser.parse_args()
parser = PARSERS[args.parser](args.input)
try:
parser.consume()
except KeyboardInterrupt:
pass
parser.end()

@@ -1,202 +1,64 @@
 #!/usr/bin/env python3
 
-import argparse
 import database
-import logging
+import argparse
 import sys
-import typing
-import multiprocessing
-import enum
-
-Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str]
-
-# select, write
-FUNCTION_MAP: typing.Any = {
-    'a': (
-        database.Database.get_ip4,
-        database.Database.set_hostname,
-    ),
-    'cname': (
-        database.Database.get_domain,
-        database.Database.set_hostname,
-    ),
-    'ptr': (
-        database.Database.get_domain,
-        database.Database.set_ip4address,
-    ),
-}
-
-class Writer(multiprocessing.Process):
-    def __init__(self,
-                 recs_queue: multiprocessing.Queue,
-                 index: int = 0):
-        super(Writer, self).__init__()
-        self.log = logging.getLogger(f'wr')
-        self.recs_queue = recs_queue
-
-    def run(self) -> None:
-        self.db = database.Database()
-        self.db.log = logging.getLogger(f'wr')
-        self.db.enter_step('block_wait')
-        block: typing.List[Record]
-        for block in iter(self.recs_queue.get, None):
-            record: Record
-            for record in block:
-                select, write, updated, name, value = record
-                self.db.enter_step('feed_switch')
-                try:
-                    for source in select(self.db, value):
-                        # write(self.db, name, updated, source=source)
-                        write(self.db, name, updated)
-                except ValueError:
-                    self.log.exception("Cannot execute: %s", record)
-            self.db.enter_step('block_wait')
-        self.db.enter_step('end')
-        self.db.save()
-
-class Parser():
-    def __init__(self,
-                 buf: typing.Any,
-                 recs_queue: multiprocessing.Queue,
-                 block_size: int,
-                 ):
-        super(Parser, self).__init__()
-        self.buf = buf
-        self.log = logging.getLogger('pr')
-        self.recs_queue = recs_queue
-        self.block: typing.List[Record] = list()
-        self.block_size = block_size
-        self.prof = database.Profiler()
-        self.prof.log = logging.getLogger('pr')
-
-    def register(self, record: Record) -> None:
-        self.prof.enter_step('register')
-        self.block.append(record)
-        if len(self.block) >= self.block_size:
-            self.prof.enter_step('put_block')
-            self.recs_queue.put(self.block)
-            self.block = list()
-
-    def run(self) -> None:
-        self.consume()
-        self.recs_queue.put(self.block)
-        self.prof.profile()
-
-    def consume(self) -> None:
-        raise NotImplementedError
-
-class Rapid7Parser(Parser):
-    def consume(self) -> None:
-        data = dict()
-        for line in self.buf:
-            self.prof.enter_step('parse_rapid7')
-            split = line.split('"')
-            try:
-                for k in range(1, 14, 4):
-                    key = split[k]
-                    val = split[k+2]
-                    data[key] = val
-                select, writer = FUNCTION_MAP[data['type']]
-                record = (
-                    select,
-                    writer,
-                    int(data['timestamp']),
-                    data['name'],
-                    data['value']
-                )
-            except IndexError:
-                self.log.exception("Cannot parse: %s", line)
-            self.register(record)
-
-class DnsMassParser(Parser):
-    # dnsmass --output Snrql
-    # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
-    TYPES = {
-        'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None),
-        # 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None),
-        'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1),
-    }
-
-    def consume(self) -> None:
-        self.prof.enter_step('parse_dnsmass')
-        timestamp = 0
-        header = True
-        for line in self.buf:
-            line = line[:-1]
-            if not line:
-                header = True
-                continue
-            split = line.split(' ')
-            try:
-                if header:
-                    timestamp = int(split[1])
-                    header = False
-                else:
-                    select, write, name_offset, value_offset = \
-                        DnsMassParser.TYPES[split[1]]
-                    record = (
-                        select,
-                        write,
-                        timestamp,
-                        split[0][:name_offset],
-                        split[2][:value_offset],
-                    )
-                    self.register(record)
-                    self.prof.enter_step('parse_dnsmass')
-            except KeyError:
-                continue
-
-PARSERS = {
-    'rapid7': Rapid7Parser,
-    'dnsmass': DnsMassParser,
-}
+import logging
+import csv
+import json
 
 if __name__ == '__main__':
 
     # Parsing arguments
     log = logging.getLogger('feed_dns')
-    args_parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
         description="TODO")
-    args_parser.add_argument(
-        'parser',
-        choices=PARSERS.keys(),
-        help="TODO")
-    args_parser.add_argument(
+    parser.add_argument(
+        # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
         '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
         help="TODO")
-    args_parser.add_argument(
-        '-j', '--workers', type=int, default=4,
-        help="TODO")
-    args_parser.add_argument(
-        '-b', '--block-size', type=int, default=100,
-        help="TODO")
-    args_parser.add_argument(
-        '-q', '--queue-size', type=int, default=10,
-        help="TODO")
-    args = args_parser.parse_args()
-
-    recs_queue: multiprocessing.Queue = multiprocessing.Queue(
-        maxsize=args.queue_size)
-
-    writer = Writer(recs_queue)
-    writer.start()
-
-    parser = PARSERS[args.parser](args.input, recs_queue, args.block_size)
-    parser.run()
-
-    recs_queue.put(None)
-    writer.join()
+    args = parser.parse_args()
+
+    DB = database.Database(write=True)
+
+    try:
+        DB.enter_step('iowait')
+        for row in csv.reader(args.input):
+            # for line in args.input:
+            DB.enter_step('feed_csv_parse')
+            dtype, timestamp, name, value = row
+            # DB.enter_step('feed_json_parse')
+            # data = json.loads(line)
+            # dtype = data['type'][0]
+            # # timestamp = data['timestamp']
+            # name = data['name']
+            # value = data['value']
+            DB.enter_step('feed_switch')
+            if dtype == 'a':
+                for rule in DB.get_ip4(value):
+                    if not list(DB.get_domain_in_zone(name)):
+                        DB.set_hostname(name, source=rule,
+                                        updated=int(timestamp))
+                        # updated=int(data['timestamp']))
+            elif dtype == 'c':
+                for rule in DB.get_domain(value):
+                    if not list(DB.get_domain_in_zone(name)):
+                        DB.set_hostname(name, source=rule,
+                                        updated=int(timestamp))
+                        # updated=int(data['timestamp']))
+            elif dtype == 'p':
+                for rule in DB.get_domain(value):
+                    if not list(DB.get_ip4_in_network(name)):
+                        DB.set_ip4address(name, source=rule,
+                                          updated=int(timestamp))
+                        # updated=int(data['timestamp']))
+            else:
+                raise NotImplementedError(f'Type: {dtype}')
+            DB.enter_step('iowait')
+    except KeyboardInterrupt:
+        log.warning("Interupted.")
+        pass
+
+    DB.close()
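
With the JSON parsing commented out, the rewritten script consumes the four-field CSV rows that json_to_csv.py emits. A hedged example of the input stream it expects, one record per line with the type letters 'a', 'c' and 'p' handled above (all names and timestamps invented):

a,1573012345,subdomain.example.com,203.0.113.7
c,1573012345,track.example.com,cdn.tracker.example
p,1573012345,203.0.113.7,host.example.com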

@@ -28,15 +28,15 @@ if __name__ == '__main__':
         help="The input only comes from verified first-party sources")
     args = parser.parse_args()
 
-    DB = database.Database()
+    DB = database.Database(write=True)
 
     fun = FUNCTION_MAP[args.type]
 
     for rule in args.input:
         fun(DB,
             rule.strip(),
-            # is_first_party=args.first_party,
+            is_first_party=args.first_party,
             updated=int(time.time()),
             )
 
-    DB.save()
+    DB.close()

@@ -18,7 +18,7 @@ log "Retrieving rules…"
 rm -f rules*/*.cache.*
 dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt
 # From firebog.net Tracking & Telemetry Lists
-# dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
+dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
 # dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list
 # False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net
 dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt
@@ -51,4 +51,3 @@ then
 else
     mv temp/cisco-umbrella_popularity.fresh.list subdomains/cisco-umbrella_popularity.cache.list
 fi
-dl https://www.orwell1984.today/cname/eulerian.net.txt subdomains/orwell-eulerian-cname-list.cache.list

@@ -5,12 +5,11 @@ function log() {
 }
 
 log "Importing rules…"
 
-BEFORE="$(date +%s)"
-
-# cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
-# cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
-# cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
-# cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
-# cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
+cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
+cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
+cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
+cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
+cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
 
 cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party
 cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party
@@ -18,5 +17,3 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
 
 ./feed_asn.py
 
-log "Pruning old rules…"
-./db.py --prune --prune-before "$BEFORE" --prune-base

json_to_csv.py Executable file (36 additions)

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
import argparse
import sys
import logging
import json
import csv
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('json_to_csv')
parser = argparse.ArgumentParser(
description="TODO")
parser.add_argument(
# '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
parser.add_argument(
# '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="TODO")
args = parser.parse_args()
writer = csv.writer(args.output)
for line in args.input:
data = json.loads(line)
try:
writer.writerow([
data['type'][0],
data['timestamp'],
data['name'],
data['value']])
except IndexError:
log.error('Could not parse line: %s', line)
pass
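
For example, a single Rapid7-style JSON input line (made up here, but using the four keys the script reads) is reduced to the first letter of its type plus the other three fields:

{"timestamp": "1573012345", "name": "track.example.com", "type": "cname", "value": "cdn.tracker.example"}

becomes:

c,1573012345,track.example.com,cdn.tracker.example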

@@ -9,11 +9,11 @@ function log() {
 # TODO Fetch 'em
 log "Reading PTR records…"
-pv ptr.json.gz | gunzip | ./feed_dns.py
+pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
 log "Reading A records…"
-pv a.json.gz | gunzip | ./feed_dns.py
+pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
 log "Reading CNAME records…"
-pv cname.json.gz | gunzip | ./feed_dns.py
+pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
 
 log "Pruning old data…"
 ./database.py --prune

resolve_subdomains.py Executable file (264 additions)

@@ -0,0 +1,264 @@
#!/usr/bin/env python3
"""
From a list of subdomains, output only
the ones resolving to a first-party tracker.
"""
import argparse
import logging
import os
import queue
import sys
import threading
import typing
import time
import coloredlogs
import dns.exception
import dns.resolver
DNS_TIMEOUT = 5.0
NUMBER_TRIES = 5
class Worker(threading.Thread):
"""
Worker process for a DNS resolver.
Will resolve DNS to match first-party subdomains.
"""
def change_nameserver(self) -> None:
"""
Assign this worker another nameserver from the queue.
"""
server = None
while server is None:
try:
server = self.orchestrator.nameservers_queue.get(block=False)
except queue.Empty:
self.orchestrator.refill_nameservers_queue()
self.log.info("Using nameserver: %s", server)
self.resolver.nameservers = [server]
def __init__(self,
orchestrator: 'Orchestrator',
index: int = 0):
super(Worker, self).__init__()
self.log = logging.getLogger(f'worker{index:03d}')
self.orchestrator = orchestrator
self.resolver = dns.resolver.Resolver()
self.change_nameserver()
def resolve_subdomain(self, subdomain: str) -> typing.Optional[
typing.List[
dns.rrset.RRset
]
]:
"""
Returns the resolution chain of the subdomain to an A record,
including any intermediary CNAME.
The last element is an IP address.
Returns None if the nameserver was unable to satisfy the request.
Returns [] if the request points to nothing.
"""
self.log.debug("Querying %s", subdomain)
try:
query = self.resolver.query(subdomain, 'A', lifetime=DNS_TIMEOUT)
except dns.resolver.NXDOMAIN:
return []
except dns.resolver.NoAnswer:
return []
except dns.resolver.YXDOMAIN:
self.log.warning("Query name too long for %s", subdomain)
return None
except dns.resolver.NoNameservers:
# NOTE Most of the time this error message means that the domain
# does not exist, but sometimes it means that the server
# itself is broken. So we count on the retry logic.
self.log.warning("All nameservers broken for %s", subdomain)
return None
except dns.exception.Timeout:
# NOTE Same as above
self.log.warning("Timeout for %s", subdomain)
return None
except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain)
return None
return query.response.answer
def run(self) -> None:
self.log.info("Started")
subdomain: str
for subdomain in iter(self.orchestrator.subdomains_queue.get, None):
for _ in range(NUMBER_TRIES):
resolved = self.resolve_subdomain(subdomain)
# Retry with another nameserver if error
if resolved is None:
self.change_nameserver()
else:
break
# If it wasn't found after multiple tries
if resolved is None:
self.log.error("Gave up on %s", subdomain)
resolved = []
assert isinstance(resolved, list)
self.orchestrator.results_queue.put(resolved)
self.orchestrator.results_queue.put(None)
self.log.info("Stopped")
class Orchestrator():
"""
Orchestrator of the different Worker threads.
"""
def refill_nameservers_queue(self) -> None:
"""
Re-fill the given nameservers into the nameservers queue.
Done every time the queue is empty, making it
basically looping and infinite.
"""
# Might be in a race condition but that's probably fine
for nameserver in self.nameservers:
self.nameservers_queue.put(nameserver)
self.log.info("Refilled nameserver queue")
def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None,
nb_workers: int = 1,
):
self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains
self.nb_workers = nb_workers
# Use internal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue()
self.refill_nameservers_queue()
def fill_subdomain_queue(self) -> None:
"""
Read the subdomains in input and put them into the queue.
Done in a thread so we can both:
- yield the results as they come
- not store all the subdomains at once
"""
self.log.info("Started reading subdomains")
# Send data to workers
for subdomain in self.subdomains:
self.subdomains_queue.put(subdomain)
self.log.info("Finished reading subdomains")
# Send sentinel to each worker
# sentinel = None ~= EOF
for _ in range(self.nb_workers):
self.subdomains_queue.put(None)
@staticmethod
def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
if rrset.rdtype == dns.rdatatype.CNAME:
dtype = 'c'
elif rrset.rdtype == dns.rdatatype.A:
dtype = 'a'
else:
raise NotImplementedError
name = rrset.name.to_text()[:-1]
for item in rrset.items:
value = item.to_text()
if rrset.rdtype == dns.rdatatype.CNAME:
value = value[:-1]
yield f'{dtype},{int(time.time())},{name},{value}\n'
def run(self) -> typing.Iterable[str]:
"""
Yield the results.
"""
# Create workers
self.log.info("Creating workers")
for i in range(self.nb_workers):
Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue)
fill_thread.start()
# Wait for one sentinel per worker
# In the meantime output results
for _ in range(self.nb_workers):
resolved: typing.List[dns.rrset.RRset]
for resolved in iter(self.results_queue.get, None):
for rrset in resolved:
yield from self.format_rrset(rrset)
self.log.info("Waiting for reader thread")
fill_thread.join()
self.log.info("Done!")
def main() -> None:
"""
Main function when used directly.
Read the subdomains provided and output it,
the last CNAME resolved and the IP address it resolves to.
Takes as an input a filename (or nothing, for stdin),
and as an output a filename (or nothing, for stdout).
The input must be a subdomain per line, the output is a TODO
Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults.
"""
# Initialization
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
# Parsing arguments
parser = argparse.ArgumentParser(
description="Massively resolves subdomains and store them in a file.")
parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="Input file with one subdomain per line")
parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains")
parser.add_argument(
'-n', '--nameservers', default='nameservers',
help="File with one nameserver per line")
parser.add_argument(
'-j', '--workers', type=int, default=512,
help="Number of threads to use")
args = parser.parse_args()
# Cleaning input
iterator = iter(args.input)
iterator = map(str.strip, iterator)
iterator = filter(None, iterator)
# Reading nameservers
servers: typing.List[str] = list()
if os.path.isfile(args.nameservers):
servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers)))
for resolved in Orchestrator(
iterator,
servers,
nb_workers=args.workers
).run():
args.output.write(resolved)
if __name__ == '__main__':
main()