Workflow: Multiprocess

Still trying.
It's better than multithread though.

Merge branch 'newworkflow' into newworkflow_threaded
This commit is contained in:
Geoffrey Frogeye 2019-12-14 17:27:46 +01:00
commit 189deeb559
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
15 changed files with 512 additions and 279 deletions

2
.gitignore vendored
View file

@ -3,5 +3,3 @@
*.db-journal *.db-journal
nameservers nameservers
nameservers.head nameservers.head
*.o
*.so

View file

@ -12,6 +12,7 @@ import logging
import argparse import argparse
import coloredlogs import coloredlogs
import ipaddress import ipaddress
import math
coloredlogs.install( coloredlogs.install(
level='DEBUG', level='DEBUG',
@ -22,40 +23,44 @@ DbValue = typing.Union[None, int, float, str, bytes]
class Database(): class Database():
VERSION = 3 VERSION = 5
PATH = "blocking.db" PATH = "blocking.db"
def open(self) -> None: def open(self) -> None:
mode = 'rwc' if self.write else 'ro' mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}' uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True) self.conn = sqlite3.connect(uri, uri=True)
self.cursor = self.conn.cursor() cursor = self.conn.cursor()
self.execute("PRAGMA foreign_keys = ON") cursor.execute("PRAGMA foreign_keys = ON")
# self.conn.create_function("prepare_ip4address", 1, self.conn.create_function("unpack_asn", 1,
# Database.prepare_ip4address, self.unpack_asn,
# deterministic=True) deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1, self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1], lambda s: s[:-1][::-1],
deterministic=True) deterministic=True)
self.conn.create_function("format_zone", 1,
def execute(self, cmd: str, args: typing.Union[ lambda s: '*' + s[::-1],
typing.Tuple[DbValue, ...], deterministic=True)
typing.Dict[str, DbValue]] = None) -> None:
# self.log.debug(cmd)
# self.log.debug(args)
self.cursor.execute(cmd, args or tuple())
def get_meta(self, key: str) -> typing.Optional[int]: def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try: try:
self.execute("SELECT value FROM meta WHERE key=?", (key,)) cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError: except sqlite3.OperationalError:
return None return None
for ver, in self.cursor: for ver, in cursor:
return ver return ver
return None return None
def set_meta(self, key: str, val: int) -> None: def set_meta(self, key: str, val: int) -> None:
self.execute("INSERT INTO meta VALUES (?, ?) " cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO " "ON CONFLICT (key) DO "
"UPDATE set value=?", "UPDATE set value=?",
(key, val, val)) (key, val, val))
@ -76,8 +81,9 @@ class Database():
os.unlink(self.PATH) os.unlink(self.PATH)
self.open() self.open()
self.log.info("Creating database version %d.", self.VERSION) self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema: with open("database_schema.sql", 'r') as db_schema:
self.cursor.executescript(db_schema.read()) cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION) self.set_meta('version', self.VERSION)
self.conn.commit() self.conn.commit()
@ -98,13 +104,6 @@ class Database():
version) version)
self.initialize() self.initialize()
updated = self.get_meta('updated')
if updated is None:
self.execute('SELECT max(updated) FROM rules')
data = self.cursor.fetchone()
updated, = data
self.updated = updated or 1
def enter_step(self, name: str) -> None: def enter_step(self, name: str) -> None:
now = time.perf_counter() now = time.perf_counter()
try: try:
@ -126,24 +125,32 @@ class Database():
self.log.debug(f"{'total':<20}: " self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})") f"{total:9.2f} s ({1:7.2%})")
def prepare_hostname(self, hostname: str) -> str: @staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.' return hostname[::-1] + '.'
def prepare_zone(self, zone: str) -> str: @staticmethod
return self.prepare_hostname(zone) def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
@staticmethod @staticmethod
def prepare_asn(asn: str) -> int: def pack_asn(asn: str) -> int:
asn = asn.upper() asn = asn.upper()
if asn.startswith('AS'): if asn.startswith('AS'):
asn = asn[2:] asn = asn[2:]
return int(asn) return int(asn)
@staticmethod @staticmethod
def prepare_ip4address(address: str) -> int: def unpack_asn(asn: int) -> str:
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0 total = 0
for i, octet in enumerate(address.split('.')): for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8 total += int(octet) << (3-i)*8
if total > 0xFFFFFFFF:
raise ValueError
return total return total
# return '{:02x}{:02x}{:02x}{:02x}'.format( # return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')]) # *[int(c) for c in address.split('.')])
@ -158,34 +165,78 @@ class Database():
# packed = ipaddress.ip_address(address).packed # packed = ipaddress.ip_address(address).packed
# return packed # return packed
def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]: @staticmethod
# def prepare_ip4network(network: str) -> str: def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network) net = ipaddress.ip_network(network)
mini = self.prepare_ip4address(net.network_address.exploded) mini = Database.pack_ip4address(net.network_address.exploded)
maxi = self.prepare_ip4address(net.broadcast_address.exploded) maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed # mini = net.network_address.packed
# maxi = net.broadcast_address.packed # maxi = net.broadcast_address.packed
return mini, maxi return mini, maxi
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen] # return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
def expire(self) -> None: @staticmethod
self.enter_step('expire') def unpack_ip4network(mini: int, maxi: int) -> str:
self.updated += 1 addr = Database.unpack_ip4address(mini)
self.set_meta('updated', self.updated) prefixlen = 32-int(math.log2(maxi-mini+1))
return f'{addr}/{prefixlen}'
def update_references(self) -> None: def update_references(self) -> None:
self.enter_step('update_refs') self.enter_step('update_refs')
self.execute('UPDATE rules AS r SET refs=' cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules ' '(SELECT count(*) FROM rules '
'WHERE source=r.id)') 'WHERE source=r.id)')
def prune(self) -> None: def prune(self, before: int, base_only: bool = False) -> None:
self.enter_step('prune') self.enter_step('prune')
self.execute('DELETE FROM rules WHERE updated<?', (self.updated,)) cursor = self.conn.cursor()
cmd = 'DELETE FROM rules WHERE updated<?'
if base_only:
cmd += ' AND level=0'
cursor.execute(cmd, (before,))
def export(self, first_party_only: bool = False, def explain(self, entry: int) -> str:
end_chain_only: bool = False) -> typing.Iterable[str]: # Format current
command = 'SELECT unpack_domain(val) FROM rules ' \ string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry' 'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list() restrictions: typing.List[str] = list()
if first_party_only: if first_party_only:
@ -194,16 +245,40 @@ class Database():
restrictions.append('rules.refs = 0') restrictions.append('rules.refs = 0')
if restrictions: if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions) command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC' command += ' ORDER BY unpack_domain(val) ASC'
self.execute(command) cursor = self.conn.cursor()
for val, in self.cursor: cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val yield val
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
return ', '.join(counts)
def get_domain(self, domain: str) -> typing.Iterable[int]: def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare') self.enter_step('get_domain_prepare')
domain_prep = self.prepare_hostname(domain) domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select') self.enter_step('get_domain_select')
self.execute( cursor.execute(
'SELECT null, entry FROM hostname ' 'SELECT null, entry FROM hostname '
'WHERE val=:d ' 'WHERE val=:d '
'UNION ' 'UNION '
@ -214,22 +289,41 @@ class Database():
')', ')',
{'d': domain_prep} {'d': domain_prep}
) )
for val, entry in self.cursor: for val, entry in cursor:
self.enter_step('get_domain_confirm') self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)): if not (val is None or domain_prep.startswith(val)):
continue continue
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield entry yield entry
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]: def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare') self.enter_step('get_ip4_prepare')
try: try:
address_prep = self.prepare_ip4address(address) address_prep = self.pack_ip4address(address)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address) self.log.error("Invalid ip4address: %s", address)
return return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select') self.enter_step('get_ip4_select')
self.execute( cursor.execute(
'SELECT entry FROM ip4address ' 'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address ' # 'SELECT null, entry FROM ip4address '
'WHERE val=:a ' 'WHERE val=:a '
@ -244,7 +338,7 @@ class Database():
'WHERE :a BETWEEN mini AND maxi ', 'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep} {'a': address_prep}
) )
for val, entry in self.cursor: for entry, in cursor:
# self.enter_step('get_ip4_confirm') # self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)): # if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end # # PERF startswith but from the end
@ -252,11 +346,29 @@ class Database():
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield entry yield entry
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]: def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select') self.enter_step('list_asn_select')
self.enter_step('get_domain_select') cursor.execute('SELECT val, entry FROM asn')
self.execute('SELECT val, entry FROM asn') for val, entry in cursor:
for val, entry in self.cursor:
yield f'AS{val}', entry yield f'AS{val}', entry
def _set_generic(self, def _set_generic(self,
@ -264,6 +376,7 @@ class Database():
select_query: str, select_query: str,
insert_query: str, insert_query: str,
prep: typing.Dict[str, DbValue], prep: typing.Dict[str, DbValue],
updated: int,
is_first_party: bool = False, is_first_party: bool = False,
source: int = None, source: int = None,
) -> None: ) -> None:
@ -271,34 +384,36 @@ class Database():
# here abstraction > performaces # here abstraction > performaces
# Fields based on the source # Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None: if source is None:
first_party = int(is_first_party) first_party = int(is_first_party)
level = 0 level = 0
else: else:
self.enter_step(f'set_{table}_source') self.enter_step(f'set_{table}_source')
self.execute( cursor.execute(
'SELECT first_party, level FROM rules ' 'SELECT first_party, level FROM rules '
'WHERE id=?', 'WHERE id=?',
(source,) (source,)
) )
first_party, level = self.cursor.fetchone() first_party, level = cursor.fetchone()
level += 1 level += 1
self.enter_step(f'set_{table}_select') self.enter_step(f'set_{table}_select')
self.execute(select_query, prep) cursor.execute(select_query, prep)
rules_prep = { rules_prep: typing.Dict[str, DbValue] = {
"source": source, "source": source,
"updated": self.updated, "updated": updated,
"first_party": first_party, "first_party": first_party,
"level": level, "level": level,
} }
# If the entry already exists # If the entry already exists
for entry, in self.cursor: # only one for entry, in cursor: # only one
self.enter_step(f'set_{table}_update') self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry rules_prep['entry'] = entry
self.execute( cursor.execute(
'UPDATE rules SET ' 'UPDATE rules SET '
'source=:source, updated=:updated, ' 'source=:source, updated=:updated, '
'first_party=:first_party, level=:level ' 'first_party=:first_party, level=:level '
@ -314,23 +429,18 @@ class Database():
# If it does not exist # If it does not exist
if source is not None:
self.enter_step(f'set_{table}_incsrc')
self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
(source,))
self.enter_step(f'set_{table}_insert') self.enter_step(f'set_{table}_insert')
self.execute( cursor.execute(
'INSERT INTO rules ' 'INSERT INTO rules '
'(source, updated, first_party, refs, level) ' '(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, 0, :level) ', 'VALUES (:source, :updated, :first_party, :level) ',
rules_prep rules_prep
) )
self.execute('SELECT id FROM rules WHERE rowid=?', cursor.execute('SELECT id FROM rules WHERE rowid=?',
(self.cursor.lastrowid,)) (cursor.lastrowid,))
for entry, in self.cursor: # only one for entry, in cursor: # only one
prep['entry'] = entry prep['entry'] = entry
self.execute(insert_query, prep) cursor.execute(insert_query, prep)
return return
assert False assert False
@ -338,7 +448,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare') self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = { prep: typing.Dict[str, DbValue] = {
'val': self.prepare_hostname(hostname), 'val': self.pack_hostname(hostname),
} }
self._set_generic( self._set_generic(
'hostname', 'hostname',
@ -353,7 +463,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare') self.enter_step('set_asn_prepare')
try: try:
asn_prep = self.prepare_asn(asn) asn_prep = self.pack_asn(asn)
except ValueError: except ValueError:
self.log.error("Invalid asn: %s", asn) self.log.error("Invalid asn: %s", asn)
return return
@ -371,10 +481,9 @@ class Database():
def set_ip4address(self, ip4address: str, def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
# TODO Do not add if already in ip4network
self.enter_step('set_ip4add_prepare') self.enter_step('set_ip4add_prepare')
try: try:
ip4address_prep = self.prepare_ip4address(ip4address) ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address) self.log.error("Invalid ip4address: %s", ip4address)
return return
@ -394,7 +503,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare') self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = { prep: typing.Dict[str, DbValue] = {
'val': self.prepare_zone(zone), 'val': self.pack_zone(zone),
} }
self._set_generic( self._set_generic(
'zone', 'zone',
@ -409,7 +518,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare') self.enter_step('set_ip4net_prepare')
try: try:
ip4network_prep = self.prepare_ip4network(ip4network) ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network) self.log.error("Invalid ip4network: %s", ip4network)
return return
@ -439,8 +548,12 @@ if __name__ == '__main__':
'-p', '--prune', action='store_true', '-p', '--prune', action='store_true',
help="Remove old entries from database") help="Remove old entries from database")
parser.add_argument( parser.add_argument(
'-e', '--expire', action='store_true', '-b', '--prune-base', action='store_true',
help="Set the whole database as an old source") help="TODO")
parser.add_argument(
'-s', '--prune-before', type=int,
default=(int(time.time()) - 60*60*24*31*6),
help="TODO")
parser.add_argument( parser.add_argument(
'-r', '--references', action='store_true', '-r', '--references', action='store_true',
help="Update the reference count") help="Update the reference count")
@ -451,10 +564,8 @@ if __name__ == '__main__':
if args.initialize: if args.initialize:
DB.initialize() DB.initialize()
if args.prune: if args.prune:
DB.prune() DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.expire: if args.references:
DB.expire()
if args.references and not args.prune:
DB.update_references() DB.update_references()
DB.close() DB.close()

View file

@ -10,30 +10,37 @@ CREATE TABLE rules (
level INTEGER, -- Level of recursion to the root source rule (used for source priority) level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn ( CREATE TABLE asn (
val INTEGER PRIMARY KEY, val INTEGER PRIMARY KEY,
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX asn_entry ON asn (entry); -- for explainations
CREATE TABLE hostname ( CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone) val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX hostname_entry ON hostname (entry); -- for explainations
CREATE TABLE zone ( CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching) val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX zone_entry ON zone (entry); -- for explainations
CREATE TABLE ip4address ( CREATE TABLE ip4address (
val INTEGER PRIMARY KEY, val INTEGER PRIMARY KEY,
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explainations
CREATE TABLE ip4network ( CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY, -- val TEXT PRIMARY KEY,
@ -43,6 +50,7 @@ CREATE TABLE ip4network (
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi); CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explainations
-- Store various things -- Store various things
CREATE TABLE meta ( CREATE TABLE meta (

View file

@ -19,12 +19,31 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'-e', '--end-chain', action='store_true', '-e', '--end-chain', action='store_true',
help="TODO") help="TODO")
parser.add_argument(
'-x', '--explain', action='store_true',
help="TODO")
parser.add_argument(
'-r', '--rules', action='store_true',
help="TODO")
parser.add_argument(
'-c', '--count', action='store_true',
help="TODO")
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database()
for domain in DB.export(first_party_only=args.first_party, if args.rules:
end_chain_only=args.end_chain): if not args.count:
raise NotImplementedError
print(DB.count_rules(first_party_only=args.first_party))
else:
if args.count:
raise NotImplementedError
for domain in DB.export(
first_party_only=args.first_party,
end_chain_only=args.end_chain,
explain=args.explain,
):
print(domain, file=args.output) print(domain, file=args.output)
DB.close() DB.close()

53
feed_asn.py Executable file
View file

@ -0,0 +1,53 @@
#!/usr/bin/env python3
import database
import argparse
import requests
import typing
import ipaddress
import logging
import time
IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
def get_ranges(asn: str) -> typing.Iterable[str]:
req = requests.get(
'https://stat.ripe.net/data/as-routing-consistency/data.json',
params={'resource': asn}
)
data = req.json()
for pref in data['data']['prefixes']:
yield pref['prefix']
if __name__ == '__main__':
log = logging.getLogger('feed_asn')
# Parsing arguments
parser = argparse.ArgumentParser(
description="TODO")
args = parser.parse_args()
DB = database.Database()
DBW = database.Database(write=True)
for asn, entry in DB.list_asn():
DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4:
DBW.set_ip4network(
prefix,
source=entry,
updated=int(time.time())
)
log.info('Added %s from %s (id=%s)', prefix, asn, entry)
elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix)
else:
log.error('Unknown prefix version: %s', prefix)
DB.close()
DBW.close()

176
feed_dns.py Normal file → Executable file
View file

@ -1,23 +1,43 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import database
import argparse import argparse
import sys import database
import json
import logging import logging
import threading import sys
import queue
import typing import typing
import multiprocessing
NUMBER_THREADS = 8 NUMBER_THREADS = 2
BLOCK_SIZE = 100
# select, confirm, write
FUNCTION_MAP: typing.Any = {
'a': (
database.Database.get_ip4,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
'cname': (
database.Database.get_domain,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
'ptr': (
database.Database.get_domain,
database.Database.get_ip4_in_network,
database.Database.set_ip4address,
),
}
class Worker(threading.Thread): class Reader(multiprocessing.Process):
def __init__(self, def __init__(self,
lines_queue: queue.Queue, lines_queue: multiprocessing.Queue,
write_queue: queue.Queue, write_queue: multiprocessing.Queue,
index: int = 0): index: int = 0):
super(Worker, self).__init__() super(Reader, self).__init__()
self.log = logging.getLogger(f'worker{index:03d}') self.log = logging.getLogger(f'rd{index:03d}')
self.lines_queue = lines_queue self.lines_queue = lines_queue
self.write_queue = write_queue self.write_queue = write_queue
self.index = index self.index = index
@ -25,45 +45,51 @@ class Worker(threading.Thread):
def run(self) -> None: def run(self) -> None:
self.db = database.Database(write=False) self.db = database.Database(write=False)
self.db.log = logging.getLogger(f'db{self.index:03d}') self.db.log = logging.getLogger(f'db{self.index:03d}')
self.db.enter_step('wait_line') self.db.enter_step('line_wait')
line: str block: typing.List[str]
for line in iter(self.lines_queue.get, None):
self.db.enter_step('feed_json_parse')
# split = line.split(b'"')
split = line.split('"')
try: try:
name = split[7] for block in iter(self.lines_queue.get, None):
dtype = split[11] for line in block:
value = split[15] dtype, updated, name, value = line
except IndexError:
log.error("Invalid JSON: %s", line)
continue
# DB.enter_step('feed_json_assert')
# data = json.loads(line)
# assert dtype == data['type']
# assert name == data['name']
# assert value == data['value']
self.db.enter_step('feed_switch') self.db.enter_step('feed_switch')
if dtype == 'a': select, confirm, write = FUNCTION_MAP[dtype]
for rule in self.db.get_ip4(value): for rule in select(self.db, value):
if not any(confirm(self.db, name)):
self.db.enter_step('wait_put') self.db.enter_step('wait_put')
self.write_queue.put( self.write_queue.put((write, name, updated))
(database.Database.set_hostname, name, rule)) self.db.enter_step('line_wait')
elif dtype == 'cname': except KeyboardInterrupt:
for rule in self.db.get_domain(value): self.log.error('Interrupted')
self.db.enter_step('wait_put')
self.write_queue.put( self.db.enter_step('end')
(database.Database.set_hostname, name, rule)) self.db.close()
elif dtype == 'ptr':
for rule in self.db.get_domain(value):
self.db.enter_step('wait_put') class Writer(multiprocessing.Process):
self.write_queue.put( def __init__(self,
(database.Database.set_ip4address, name, rule)) write_queue: multiprocessing.Queue,
self.db.enter_step('wait_line') ):
super(Writer, self).__init__()
self.log = logging.getLogger(f'wr ')
self.write_queue = write_queue
def run(self) -> None:
self.db = database.Database(write=True)
self.db.log = logging.getLogger(f'dbw ')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
fun: typing.Callable
name: str
updated: int
for fun, name, updated in iter(self.write_queue.get, None):
self.db.enter_step('exec')
fun(self.db, name, updated)
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end') self.db.enter_step('end')
self.write_queue.put(None)
self.db.close() self.db.close()
@ -80,42 +106,52 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
DB = database.Database(write=False) # Not needed, just for timing DB = database.Database(write=False) # Not needed, just for timing
DB.log = logging.getLogger('dbf') DB.log = logging.getLogger('db ')
DBW = database.Database(write=True)
DBW.log = logging.getLogger('dbw')
lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS) lines_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS) write_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
def fill_lines_queue() -> None: DB.enter_step('proc_create')
readers: typing.List[Reader] = list()
for w in range(NUMBER_THREADS):
readers.append(Reader(lines_queue, write_queue, w))
writer = Writer(write_queue)
DB.enter_step('proc_start')
for reader in readers:
reader.start()
writer.start()
try:
block: typing.List[str] = list()
DB.enter_step('iowait') DB.enter_step('iowait')
for line in args.input: for line in args.input:
DB.enter_step('block_append')
DB.enter_step('feed_json_parse')
data = json.loads(line)
line = (data['type'],
int(data['timestamp']),
data['name'],
data['value'])
block.append(line)
if len(block) >= BLOCK_SIZE:
DB.enter_step('wait_put') DB.enter_step('wait_put')
lines_queue.put(line) lines_queue.put(block)
block = list()
DB.enter_step('iowait') DB.enter_step('iowait')
DB.enter_step('wait_put')
lines_queue.put(block)
DB.enter_step('end_put') DB.enter_step('end_put')
for _ in range(NUMBER_THREADS): for _ in range(NUMBER_THREADS):
lines_queue.put(None) lines_queue.put(None)
write_queue.put(None)
for w in range(NUMBER_THREADS): DB.enter_step('proc_join')
Worker(lines_queue, write_queue, w).start() for reader in readers:
reader.join()
writer.join()
except KeyboardInterrupt:
log.error('Interrupted')
threading.Thread(target=fill_lines_queue).start()
for _ in range(NUMBER_THREADS):
fun: typing.Callable
name: str
source: int
DBW.enter_step('wait_fun')
for fun, name, source in iter(write_queue.get, None):
DBW.enter_step('exec_fun')
fun(DBW, name, source=source)
DBW.enter_step('commit')
DBW.conn.commit()
DBW.enter_step('wait_fun')
DBW.enter_step('end')
DBW.close()
DB.close() DB.close()

View file

@ -3,6 +3,7 @@
import database import database
import argparse import argparse
import sys import sys
import time
FUNCTION_MAP = { FUNCTION_MAP = {
'zone': database.Database.set_zone, 'zone': database.Database.set_zone,
@ -32,6 +33,10 @@ if __name__ == '__main__':
fun = FUNCTION_MAP[args.type] fun = FUNCTION_MAP[args.type]
for rule in args.input: for rule in args.input:
fun(DB, rule.strip(), is_first_party=args.first_party) fun(DB,
rule.strip(),
is_first_party=args.first_party,
updated=int(time.time()),
)
DB.close() DB.close()

View file

@ -18,7 +18,7 @@ log "Retrieving rules…"
rm -f rules*/*.cache.* rm -f rules*/*.cache.*
dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt
# From firebog.net Tracking & Telemetry Lists # From firebog.net Tracking & Telemetry Lists
dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list # dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
# dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list # dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list
# False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net # False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net
dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt

View file

@ -4,6 +4,12 @@ function log() {
echo -e "\033[33m$@\033[0m" echo -e "\033[33m$@\033[0m"
} }
log "Pruning old data…"
./database.py --prune
log "Recounting references…"
./database.py --references
log "Exporting lists…" log "Exporting lists…"
./export.py --first-party --output dist/firstparty-trackers.txt ./export.py --first-party --output dist/firstparty-trackers.txt
./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt ./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt
@ -11,6 +17,8 @@ log "Exporting lists…"
./export.py --end-chain --output dist/multiparty-only-trackers.txt ./export.py --end-chain --output dist/multiparty-only-trackers.txt
log "Generating hosts lists…" log "Generating hosts lists…"
./export.py --rules --count --first-party > temp/count_rules_firstparty.txt
./export.py --rules --count > temp/count_rules_multiparty.txt
function generate_hosts { function generate_hosts {
basename="$1" basename="$1"
description="$2" description="$2"
@ -36,15 +44,16 @@ function generate_hosts {
echo "#" echo "#"
echo "# Generation date: $(date -Isec)" echo "# Generation date: $(date -Isec)"
echo "# Generation software: eulaurarien $(git describe --tags)" echo "# Generation software: eulaurarien $(git describe --tags)"
echo "# Number of source websites: TODO" echo "# Number of source websites: $(wc -l temp/all_websites.list | cut -d' ' -f1)"
echo "# Number of source subdomains: TODO" echo "# Number of source subdomains: $(wc -l temp/all_subdomains.list | cut -d' ' -f1)"
echo "# Number of source DNS records: ~2M + $(wc -l temp/all_resolved.json | cut -d' ' -f1)"
echo "#" echo "#"
echo "# Number of known first-party trackers: TODO" echo "# Known first-party trackers: $(cat temp/count_rules_firstparty.txt)"
echo "# Number of first-party subdomains: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)" echo "# Number of first-party hostnames: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/firstparty-only-trackers.txt | cut -d' ' -f1)" echo "# … excluding redirected: $(wc -l dist/firstparty-only-trackers.txt | cut -d' ' -f1)"
echo "#" echo "#"
echo "# Number of known multi-party trackers: TODO" echo "# Known multi-party trackers: $(cat temp/count_rules_multiparty.txt)"
echo "# Number of multi-party subdomains: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)" echo "# Number of multi-party hostnames: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/multiparty-only-trackers.txt | cut -d' ' -f1)" echo "# … excluding redirected: $(wc -l dist/multiparty-only-trackers.txt | cut -d' ' -f1)"
echo echo
sed 's|^|0.0.0.0 |' "dist/$basename.txt" sed 's|^|0.0.0.0 |' "dist/$basename.txt"

View file

@ -5,6 +5,7 @@ function log() {
} }
log "Importing rules…" log "Importing rules…"
BEFORE="$(date +%s)"
cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
@ -17,3 +18,5 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as
./feed_asn.py ./feed_asn.py
log "Pruning old rules…"
./database.py --prune --prune-before "$BEFORE" --prune-base

36
json_to_csv.py Executable file
View file

@ -0,0 +1,36 @@
#!/usr/bin/env python3
import argparse
import sys
import logging
import json
import csv
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('json_to_csv')
parser = argparse.ArgumentParser(
description="TODO")
parser.add_argument(
# '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
parser.add_argument(
# '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="TODO")
args = parser.parse_args()
writer = csv.writer(args.output)
for line in args.input:
data = json.loads(line)
try:
writer.writerow([
data['type'][0], # First letter, will need to do something special for AAAA
data['timestamp'],
data['name'],
data['value']])
except (KeyError, json.decoder.JSONDecodeError):
log.error('Could not parse line: %s', line)
pass

View file

@ -4,18 +4,16 @@ function log() {
echo -e "\033[33m$@\033[0m" echo -e "\033[33m$@\033[0m"
} }
log "Preparing database…" ./fetch_resources.sh
./database.py --expire
./import_rules.sh ./import_rules.sh
# TODO Fetch 'em # TODO Fetch 'em
log "Reading PTR records…" log "Reading PTR records…"
pv ptr.json.gz | gunzip | ./feed_dns.py pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Reading A records…" log "Reading A records…"
pv a.json.gz | gunzip | ./feed_dns.py pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Reading CNAME records…" log "Reading CNAME records…"
pv cname.json.gz | gunzip | ./feed_dns.py pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Pruning old data…" log "Pruning old data…"
./database.py --prune ./database.py --prune

View file

@ -1,21 +0,0 @@
#!/usr/bin/env python3
"""
List of regex matching first-party trackers.
"""
# Syntax: https://docs.python.org/3/library/re.html#regular-expression-syntax
REGEXES = [
r'^.+\.eulerian\.net\.$', # Eulerian
r'^.+\.criteo\.com\.$', # Criteo
r'^.+\.dnsdelegation\.io\.$', # Criteo
r'^.+\.keyade\.com\.$', # Keyade
r'^.+\.omtrdc\.net\.$', # Adobe Experience Cloud
r'^.+\.bp01\.net\.$', # NP6
r'^.+\.ati-host\.net\.$', # Xiti (AT Internet)
r'^.+\.at-o\.net\.$', # Xiti (AT Internet)
r'^.+\.edgkey\.net\.$', # Edgekey (Akamai)
r'^.+\.akaimaiedge\.net\.$', # Edgekey (Akamai)
r'^.+\.storetail\.io\.$', # Storetail (Criteo)
]

View file

@ -12,22 +12,15 @@ import queue
import sys import sys
import threading import threading
import typing import typing
import csv import time
import coloredlogs import coloredlogs
import dns.exception import dns.exception
import dns.resolver import dns.resolver
import progressbar
DNS_TIMEOUT = 5.0 DNS_TIMEOUT = 5.0
NUMBER_THREADS = 512
NUMBER_TRIES = 5 NUMBER_TRIES = 5
# TODO All the domains don't get treated,
# so it leaves with 4-5 subdomains not resolved
glob = None
class Worker(threading.Thread): class Worker(threading.Thread):
""" """
@ -60,7 +53,7 @@ class Worker(threading.Thread):
def resolve_subdomain(self, subdomain: str) -> typing.Optional[ def resolve_subdomain(self, subdomain: str) -> typing.Optional[
typing.List[ typing.List[
str dns.rrset.RRset
] ]
]: ]:
""" """
@ -93,18 +86,7 @@ class Worker(threading.Thread):
except dns.name.EmptyLabel: except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain) self.log.warning("Empty label for %s", subdomain)
return None return None
resolved = list() return query.response.answer
last = len(query.response.answer) - 1
for a, answer in enumerate(query.response.answer):
if answer.rdtype == dns.rdatatype.CNAME:
assert a < last
resolved.append(answer.items[0].to_text()[:-1])
elif answer.rdtype == dns.rdatatype.A:
assert a == last
resolved.append(answer.items[0].address)
else:
assert False
return resolved
def run(self) -> None: def run(self) -> None:
self.log.info("Started") self.log.info("Started")
@ -124,7 +106,6 @@ class Worker(threading.Thread):
self.log.error("Gave up on %s", subdomain) self.log.error("Gave up on %s", subdomain)
resolved = [] resolved = []
resolved.insert(0, subdomain)
assert isinstance(resolved, list) assert isinstance(resolved, list)
self.orchestrator.results_queue.put(resolved) self.orchestrator.results_queue.put(resolved)
@ -150,15 +131,17 @@ class Orchestrator():
def __init__(self, subdomains: typing.Iterable[str], def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None, nameservers: typing.List[str] = None,
nb_workers: int = 1,
): ):
self.log = logging.getLogger('orchestrator') self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains self.subdomains = subdomains
self.nb_workers = nb_workers
# Use interal resolver by default # Use interal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue( self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=NUMBER_THREADS) maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue() self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue() self.nameservers_queue: queue.Queue = queue.Queue()
@ -179,16 +162,31 @@ class Orchestrator():
self.log.info("Finished reading subdomains") self.log.info("Finished reading subdomains")
# Send sentinel to each worker # Send sentinel to each worker
# sentinel = None ~= EOF # sentinel = None ~= EOF
for _ in range(NUMBER_THREADS): for _ in range(self.nb_workers):
self.subdomains_queue.put(None) self.subdomains_queue.put(None)
def run(self) -> typing.Iterable[typing.List[str]]: @staticmethod
def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
if rrset.rdtype == dns.rdatatype.CNAME:
dtype = 'c'
elif rrset.rdtype == dns.rdatatype.A:
dtype = 'a'
else:
raise NotImplementedError
name = rrset.name.to_text()[:-1]
for item in rrset.items:
value = item.to_text()
if rrset.rdtype == dns.rdatatype.CNAME:
value = value[:-1]
yield f'{dtype},{int(time.time())},{name},{value}\n'
def run(self) -> typing.Iterable[str]:
""" """
Yield the results. Yield the results.
""" """
# Create workers # Create workers
self.log.info("Creating workers") self.log.info("Creating workers")
for i in range(NUMBER_THREADS): for i in range(self.nb_workers):
Worker(self, i).start() Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue) fill_thread = threading.Thread(target=self.fill_subdomain_queue)
@ -196,10 +194,11 @@ class Orchestrator():
# Wait for one sentinel per worker # Wait for one sentinel per worker
# In the meantime output results # In the meantime output results
for _ in range(NUMBER_THREADS): for _ in range(self.nb_workers):
result: typing.List[str] resolved: typing.List[dns.rrset.RRset]
for result in iter(self.results_queue.get, None): for resolved in iter(self.results_queue.get, None):
yield result for rrset in resolved:
yield from self.format_rrset(rrset)
self.log.info("Waiting for reader thread") self.log.info("Waiting for reader thread")
fill_thread.join() fill_thread.join()
@ -214,11 +213,9 @@ def main() -> None:
the last CNAME resolved and the IP adress it resolves to. the last CNAME resolved and the IP adress it resolves to.
Takes as an input a filename (or nothing, for stdin), Takes as an input a filename (or nothing, for stdin),
and as an output a filename (or nothing, for stdout). and as an output a filename (or nothing, for stdout).
The input must be a subdomain per line, the output is a comma-sep The input must be a subdomain per line, the output is a TODO
file with the columns source CNAME and A.
Use the file `nameservers` as the list of nameservers Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults. to use, or else it will use the system defaults.
Also shows a nice progressbar.
""" """
# Initialization # Initialization
@ -236,28 +233,14 @@ def main() -> None:
parser.add_argument( parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout, '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains") help="Outptut file with DNS chains")
# parser.add_argument( parser.add_argument(
# '-n', '--nameserver', type=argparse.FileType('r'), '-n', '--nameservers', default='nameservers',
# default='nameservers', help="File with one nameserver per line") help="File with one nameserver per line")
# parser.add_argument( parser.add_argument(
# '-j', '--workers', type=int, default=512, '-j', '--workers', type=int, default=512,
# help="Number of threads to use") help="Number of threads to use")
args = parser.parse_args() args = parser.parse_args()
# Progress bar
widgets = [
progressbar.Percentage(),
' ', progressbar.SimpleProgress(),
' ', progressbar.Bar(),
' ', progressbar.Timer(),
' ', progressbar.AdaptiveTransferSpeed(unit='req'),
' ', progressbar.AdaptiveETA(),
]
progress = progressbar.ProgressBar(widgets=widgets)
if args.input.seekable():
progress.max_value = len(args.input.readlines())
args.input.seek(0)
# Cleaning input # Cleaning input
iterator = iter(args.input) iterator = iter(args.input)
iterator = map(str.strip, iterator) iterator = map(str.strip, iterator)
@ -265,19 +248,16 @@ def main() -> None:
# Reading nameservers # Reading nameservers
servers: typing.List[str] = list() servers: typing.List[str] = list()
if os.path.isfile('nameservers'): if os.path.isfile(args.nameservers):
servers = open('nameservers').readlines() servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers))) servers = list(filter(None, map(str.strip, servers)))
writer = csv.writer(args.output) for resolved in Orchestrator(
iterator,
progress.start() servers,
global glob nb_workers=args.workers
glob = Orchestrator(iterator, servers) ).run():
for resolved in glob.run(): args.output.write(resolved)
progress.update(progress.value + 1)
writer.writerow(resolved)
progress.finish()
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -4,11 +4,9 @@ function log() {
echo -e "\033[33m$@\033[0m" echo -e "\033[33m$@\033[0m"
} }
# Resolve the CNAME chain of all the known subdomains for later analysis log "Compiling locally known subdomain…"
log "Compiling subdomain lists..."
pv subdomains/*.list | sort -u > temp/all_subdomains.list
# Sort by last character to utilize the DNS server caching mechanism # Sort by last character to utilize the DNS server caching mechanism
pv temp/all_subdomains.list | rev | sort | rev > temp/all_subdomains_reversort.list pv subdomains/*.list | sed 's/\r$//' | rev | sort -u | rev > temp/all_subdomains.list
./resolve_subdomains.py --input temp/all_subdomains_reversort.list --output temp/all_resolved.csv log "Resolving locally known subdomain…"
sort -u temp/all_resolved.csv > temp/all_resolved_sorted.csv pv temp/all_subdomains.list | ./resolve_subdomains.py --output temp/all_resolved.csv