Workflow: Various optimisations and fixes

I forgot to close this one earlier, so:
Closes #7
This commit is contained in:
Geoffrey Frogeye 2019-12-13 18:00:00 +01:00
parent f3eedcba22
commit ab7ef609dd
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
8 changed files with 214 additions and 117 deletions

View file

@ -12,6 +12,7 @@ import logging
import argparse import argparse
import coloredlogs import coloredlogs
import ipaddress import ipaddress
import math
coloredlogs.install( coloredlogs.install(
level='DEBUG', level='DEBUG',
@ -22,40 +23,44 @@ DbValue = typing.Union[None, int, float, str, bytes]
class Database(): class Database():
VERSION = 3 VERSION = 4
PATH = "blocking.db" PATH = "blocking.db"
def open(self) -> None: def open(self) -> None:
mode = 'rwc' if self.write else 'ro' mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}' uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True) self.conn = sqlite3.connect(uri, uri=True)
self.cursor = self.conn.cursor() cursor = self.conn.cursor()
self.execute("PRAGMA foreign_keys = ON") cursor.execute("PRAGMA foreign_keys = ON")
# self.conn.create_function("prepare_ip4address", 1, self.conn.create_function("unpack_asn", 1,
# Database.prepare_ip4address, self.unpack_asn,
# deterministic=True) deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1, self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1], lambda s: s[:-1][::-1],
deterministic=True) deterministic=True)
self.conn.create_function("format_zone", 1,
def execute(self, cmd: str, args: typing.Union[ lambda s: '*' + s[::-1],
typing.Tuple[DbValue, ...], deterministic=True)
typing.Dict[str, DbValue]] = None) -> None:
# self.log.debug(cmd)
# self.log.debug(args)
self.cursor.execute(cmd, args or tuple())
def get_meta(self, key: str) -> typing.Optional[int]: def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try: try:
self.execute("SELECT value FROM meta WHERE key=?", (key,)) cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError: except sqlite3.OperationalError:
return None return None
for ver, in self.cursor: for ver, in cursor:
return ver return ver
return None return None
def set_meta(self, key: str, val: int) -> None: def set_meta(self, key: str, val: int) -> None:
self.execute("INSERT INTO meta VALUES (?, ?) " cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO " "ON CONFLICT (key) DO "
"UPDATE set value=?", "UPDATE set value=?",
(key, val, val)) (key, val, val))
@ -76,8 +81,9 @@ class Database():
os.unlink(self.PATH) os.unlink(self.PATH)
self.open() self.open()
self.log.info("Creating database version %d.", self.VERSION) self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema: with open("database_schema.sql", 'r') as db_schema:
self.cursor.executescript(db_schema.read()) cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION) self.set_meta('version', self.VERSION)
self.conn.commit() self.conn.commit()
@ -119,21 +125,27 @@ class Database():
self.log.debug(f"{'total':<20}: " self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})") f"{total:9.2f} s ({1:7.2%})")
def prepare_hostname(self, hostname: str) -> str: @staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.' return hostname[::-1] + '.'
def prepare_zone(self, zone: str) -> str: @staticmethod
return self.prepare_hostname(zone) def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
@staticmethod @staticmethod
def prepare_asn(asn: str) -> int: def pack_asn(asn: str) -> int:
asn = asn.upper() asn = asn.upper()
if asn.startswith('AS'): if asn.startswith('AS'):
asn = asn[2:] asn = asn[2:]
return int(asn) return int(asn)
@staticmethod @staticmethod
def prepare_ip4address(address: str) -> int: def unpack_asn(asn: int) -> str:
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0 total = 0
for i, octet in enumerate(address.split('.')): for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8 total += int(octet) << (3-i)*8
@ -151,29 +163,75 @@ class Database():
# packed = ipaddress.ip_address(address).packed # packed = ipaddress.ip_address(address).packed
# return packed # return packed
def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]: @staticmethod
# def prepare_ip4network(network: str) -> str: def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network) net = ipaddress.ip_network(network)
mini = self.prepare_ip4address(net.network_address.exploded) mini = Database.pack_ip4address(net.network_address.exploded)
maxi = self.prepare_ip4address(net.broadcast_address.exploded) maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed # mini = net.network_address.packed
# maxi = net.broadcast_address.packed # maxi = net.broadcast_address.packed
return mini, maxi return mini, maxi
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen] # return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
@staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str:
addr = Database.unpack_ip4address(mini)
prefixlen = 32-int(math.log2(maxi-mini+1))
return f'{addr}/{prefixlen}'
def update_references(self) -> None: def update_references(self) -> None:
self.enter_step('update_refs') self.enter_step('update_refs')
self.execute('UPDATE rules AS r SET refs=' cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules ' '(SELECT count(*) FROM rules '
'WHERE source=r.id)') 'WHERE source=r.id)')
def prune(self, before: int) -> None: def prune(self, before: int) -> None:
self.enter_step('prune') self.enter_step('prune')
self.execute('DELETE FROM rules WHERE updated<?', (before,)) cursor = self.conn.cursor()
cursor.execute('DELETE FROM rules WHERE updated<?', (before,))
def export(self, first_party_only: bool = False, def explain(self, entry: int) -> str:
end_chain_only: bool = False) -> typing.Iterable[str]: # Format current
command = 'SELECT unpack_domain(val) FROM rules ' \ string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry' 'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list() restrictions: typing.List[str] = list()
if first_party_only: if first_party_only:
@ -182,16 +240,22 @@ class Database():
restrictions.append('rules.refs = 0') restrictions.append('rules.refs = 0')
if restrictions: if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions) command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC' command += ' ORDER BY unpack_domain(val) ASC'
self.execute(command) cursor = self.conn.cursor()
for val, in self.cursor: cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val yield val
def get_domain(self, domain: str) -> typing.Iterable[int]: def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare') self.enter_step('get_domain_prepare')
domain_prep = self.prepare_hostname(domain) domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select') self.enter_step('get_domain_select')
self.execute( cursor.execute(
'SELECT null, entry FROM hostname ' 'SELECT null, entry FROM hostname '
'WHERE val=:d ' 'WHERE val=:d '
'UNION ' 'UNION '
@ -202,22 +266,41 @@ class Database():
')', ')',
{'d': domain_prep} {'d': domain_prep}
) )
for val, entry in self.cursor: for val, entry in cursor:
self.enter_step('get_domain_confirm') self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)): if not (val is None or domain_prep.startswith(val)):
continue continue
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield entry yield entry
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]: def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare') self.enter_step('get_ip4_prepare')
try: try:
address_prep = self.prepare_ip4address(address) address_prep = self.pack_ip4address(address)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address) self.log.error("Invalid ip4address: %s", address)
return return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select') self.enter_step('get_ip4_select')
self.execute( cursor.execute(
'SELECT entry FROM ip4address ' 'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address ' # 'SELECT null, entry FROM ip4address '
'WHERE val=:a ' 'WHERE val=:a '
@ -232,7 +315,7 @@ class Database():
'WHERE :a BETWEEN mini AND maxi ', 'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep} {'a': address_prep}
) )
for val, entry in self.cursor: for entry, in cursor:
# self.enter_step('get_ip4_confirm') # self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)): # if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end # # PERF startswith but from the end
@ -240,11 +323,29 @@ class Database():
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield entry yield entry
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]: def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select') self.enter_step('list_asn_select')
self.enter_step('get_domain_select') cursor.execute('SELECT val, entry FROM asn')
self.execute('SELECT val, entry FROM asn') for val, entry in cursor:
for val, entry in self.cursor:
yield f'AS{val}', entry yield f'AS{val}', entry
def _set_generic(self, def _set_generic(self,
@ -260,21 +361,23 @@ class Database():
# here abstraction > performaces # here abstraction > performaces
# Fields based on the source # Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None: if source is None:
first_party = int(is_first_party) first_party = int(is_first_party)
level = 0 level = 0
else: else:
self.enter_step(f'set_{table}_source') self.enter_step(f'set_{table}_source')
self.execute( cursor.execute(
'SELECT first_party, level FROM rules ' 'SELECT first_party, level FROM rules '
'WHERE id=?', 'WHERE id=?',
(source,) (source,)
) )
first_party, level = self.cursor.fetchone() first_party, level = cursor.fetchone()
level += 1 level += 1
self.enter_step(f'set_{table}_select') self.enter_step(f'set_{table}_select')
self.execute(select_query, prep) cursor.execute(select_query, prep)
rules_prep: typing.Dict[str, DbValue] = { rules_prep: typing.Dict[str, DbValue] = {
"source": source, "source": source,
@ -284,10 +387,10 @@ class Database():
} }
# If the entry already exists # If the entry already exists
for entry, in self.cursor: # only one for entry, in cursor: # only one
self.enter_step(f'set_{table}_update') self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry rules_prep['entry'] = entry
self.execute( cursor.execute(
'UPDATE rules SET ' 'UPDATE rules SET '
'source=:source, updated=:updated, ' 'source=:source, updated=:updated, '
'first_party=:first_party, level=:level ' 'first_party=:first_party, level=:level '
@ -303,23 +406,18 @@ class Database():
# If it does not exist # If it does not exist
if source is not None:
self.enter_step(f'set_{table}_incsrc')
self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
(source,))
self.enter_step(f'set_{table}_insert') self.enter_step(f'set_{table}_insert')
self.execute( cursor.execute(
'INSERT INTO rules ' 'INSERT INTO rules '
'(source, updated, first_party, refs, level) ' '(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, 0, :level) ', 'VALUES (:source, :updated, :first_party, :level) ',
rules_prep rules_prep
) )
self.execute('SELECT id FROM rules WHERE rowid=?', cursor.execute('SELECT id FROM rules WHERE rowid=?',
(self.cursor.lastrowid,)) (cursor.lastrowid,))
for entry, in self.cursor: # only one for entry, in cursor: # only one
prep['entry'] = entry prep['entry'] = entry
self.execute(insert_query, prep) cursor.execute(insert_query, prep)
return return
assert False assert False
@ -327,7 +425,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare') self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = { prep: typing.Dict[str, DbValue] = {
'val': self.prepare_hostname(hostname), 'val': self.pack_hostname(hostname),
} }
self._set_generic( self._set_generic(
'hostname', 'hostname',
@ -342,7 +440,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare') self.enter_step('set_asn_prepare')
try: try:
asn_prep = self.prepare_asn(asn) asn_prep = self.pack_asn(asn)
except ValueError: except ValueError:
self.log.error("Invalid asn: %s", asn) self.log.error("Invalid asn: %s", asn)
return return
@ -360,10 +458,9 @@ class Database():
def set_ip4address(self, ip4address: str, def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
# TODO Do not add if already in ip4network
self.enter_step('set_ip4add_prepare') self.enter_step('set_ip4add_prepare')
try: try:
ip4address_prep = self.prepare_ip4address(ip4address) ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address) self.log.error("Invalid ip4address: %s", ip4address)
return return
@ -383,7 +480,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare') self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = { prep: typing.Dict[str, DbValue] = {
'val': self.prepare_zone(zone), 'val': self.pack_zone(zone),
} }
self._set_generic( self._set_generic(
'zone', 'zone',
@ -398,7 +495,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare') self.enter_step('set_ip4net_prepare')
try: try:
ip4network_prep = self.prepare_ip4network(ip4network) ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network) self.log.error("Invalid ip4network: %s", ip4network)
return return

View file

@ -10,30 +10,35 @@ CREATE TABLE rules (
level INTEGER, -- Level of recursion to the root source rule (used for source priority) level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE TABLE asn ( CREATE TABLE asn (
val INTEGER PRIMARY KEY, val INTEGER PRIMARY KEY,
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX asn_entry ON asn (entry); -- for explainations
CREATE TABLE hostname ( CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone) val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX hostname_entry ON hostname (entry); -- for explainations
CREATE TABLE zone ( CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching) val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX zone_entry ON zone (entry); -- for explainations
CREATE TABLE ip4address ( CREATE TABLE ip4address (
val INTEGER PRIMARY KEY, val INTEGER PRIMARY KEY,
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explainations
CREATE TABLE ip4network ( CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY, -- val TEXT PRIMARY KEY,
@ -43,6 +48,7 @@ CREATE TABLE ip4network (
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi); CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explainations
-- Store various things -- Store various things
CREATE TABLE meta ( CREATE TABLE meta (

View file

@ -19,12 +19,18 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'-e', '--end-chain', action='store_true', '-e', '--end-chain', action='store_true',
help="TODO") help="TODO")
parser.add_argument(
'-x', '--explain', action='store_true',
help="TODO")
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database()
for domain in DB.export(first_party_only=args.first_party, for domain in DB.export(
end_chain_only=args.end_chain): first_party_only=args.first_party,
end_chain_only=args.end_chain,
explain=args.explain,
):
print(domain, file=args.output) print(domain, file=args.output)
DB.close() DB.close()

View file

@ -50,3 +50,4 @@ if __name__ == '__main__':
log.error('Unknown prefix version: %s', prefix) log.error('Unknown prefix version: %s', prefix)
DB.close() DB.close()
DBW.close()

View file

@ -44,12 +44,15 @@ if __name__ == '__main__':
DB.enter_step('feed_switch') DB.enter_step('feed_switch')
if dtype == 'a': if dtype == 'a':
for rule in DB.get_ip4(value): for rule in DB.get_ip4(value):
if not list(DB.get_domain_in_zone(name)):
DB.set_hostname(name, source=rule, updated=updated) DB.set_hostname(name, source=rule, updated=updated)
elif dtype == 'cname': elif dtype == 'cname':
for rule in DB.get_domain(value): for rule in DB.get_domain(value):
if not list(DB.get_domain_in_zone(name)):
DB.set_hostname(name, source=rule, updated=updated) DB.set_hostname(name, source=rule, updated=updated)
elif dtype == 'ptr': elif dtype == 'ptr':
for rule in DB.get_domain(value): for rule in DB.get_domain(value):
if not list(DB.get_ip4_in_network(name)):
DB.set_ip4address(name, source=rule, updated=updated) DB.set_ip4address(name, source=rule, updated=updated)
DB.enter_step('iowait') DB.enter_step('iowait')
except KeyboardInterrupt: except KeyboardInterrupt:

View file

@ -4,6 +4,9 @@ function log() {
echo -e "\033[33m$@\033[0m" echo -e "\033[33m$@\033[0m"
} }
log "Recounting references…"
./database.py --references
log "Exporting lists…" log "Exporting lists…"
./export.py --first-party --output dist/firstparty-trackers.txt ./export.py --first-party --output dist/firstparty-trackers.txt
./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt ./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt

View file

@ -1,21 +0,0 @@
#!/usr/bin/env python3
"""
List of regex matching first-party trackers.
"""
# Syntax: https://docs.python.org/3/library/re.html#regular-expression-syntax
REGEXES = [
r'^.+\.eulerian\.net\.$', # Eulerian
r'^.+\.criteo\.com\.$', # Criteo
r'^.+\.dnsdelegation\.io\.$', # Criteo
r'^.+\.keyade\.com\.$', # Keyade
r'^.+\.omtrdc\.net\.$', # Adobe Experience Cloud
r'^.+\.bp01\.net\.$', # NP6
r'^.+\.ati-host\.net\.$', # Xiti (AT Internet)
r'^.+\.at-o\.net\.$', # Xiti (AT Internet)
r'^.+\.edgkey\.net\.$', # Edgekey (Akamai)
r'^.+\.akaimaiedge\.net\.$', # Edgekey (Akamai)
r'^.+\.storetail\.io\.$', # Storetail (Criteo)
]

View file

@ -19,12 +19,8 @@ import dns.exception
import dns.resolver import dns.resolver
DNS_TIMEOUT = 5.0 DNS_TIMEOUT = 5.0
NUMBER_THREADS = 512
NUMBER_TRIES = 5 NUMBER_TRIES = 5
# TODO All the domains don't get treated,
# so it leaves with 4-5 subdomains not resolved
class Worker(threading.Thread): class Worker(threading.Thread):
""" """
@ -135,15 +131,17 @@ class Orchestrator():
def __init__(self, subdomains: typing.Iterable[str], def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None, nameservers: typing.List[str] = None,
nb_workers: int = 1,
): ):
self.log = logging.getLogger('orchestrator') self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains self.subdomains = subdomains
self.nb_workers = nb_workers
# Use interal resolver by default # Use interal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue( self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=NUMBER_THREADS) maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue() self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue() self.nameservers_queue: queue.Queue = queue.Queue()
@ -164,7 +162,7 @@ class Orchestrator():
self.log.info("Finished reading subdomains") self.log.info("Finished reading subdomains")
# Send sentinel to each worker # Send sentinel to each worker
# sentinel = None ~= EOF # sentinel = None ~= EOF
for _ in range(NUMBER_THREADS): for _ in range(self.nb_workers):
self.subdomains_queue.put(None) self.subdomains_queue.put(None)
@staticmethod @staticmethod
@ -189,7 +187,7 @@ class Orchestrator():
""" """
# Create workers # Create workers
self.log.info("Creating workers") self.log.info("Creating workers")
for i in range(NUMBER_THREADS): for i in range(self.nb_workers):
Worker(self, i).start() Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue) fill_thread = threading.Thread(target=self.fill_subdomain_queue)
@ -197,7 +195,7 @@ class Orchestrator():
# Wait for one sentinel per worker # Wait for one sentinel per worker
# In the meantime output results # In the meantime output results
for _ in range(NUMBER_THREADS): for _ in range(self.nb_workers):
resolved: typing.List[dns.rrset.RRset] resolved: typing.List[dns.rrset.RRset]
for resolved in iter(self.results_queue.get, None): for resolved in iter(self.results_queue.get, None):
for rrset in resolved: for rrset in resolved:
@ -223,7 +221,7 @@ def main() -> None:
# Initialization # Initialization
coloredlogs.install( coloredlogs.install(
# level='DEBUG', level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s' fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
) )
@ -236,12 +234,12 @@ def main() -> None:
parser.add_argument( parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout, '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains") help="Outptut file with DNS chains")
# parser.add_argument( parser.add_argument(
# '-n', '--nameserver', type=argparse.FileType('r'), '-n', '--nameservers', default='nameservers',
# default='nameservers', help="File with one nameserver per line") help="File with one nameserver per line")
# parser.add_argument( parser.add_argument(
# '-j', '--workers', type=int, default=512, '-j', '--workers', type=int, default=512,
# help="Number of threads to use") help="Number of threads to use")
args = parser.parse_args() args = parser.parse_args()
# Cleaning input # Cleaning input
@ -251,11 +249,15 @@ def main() -> None:
# Reading nameservers # Reading nameservers
servers: typing.List[str] = list() servers: typing.List[str] = list()
if os.path.isfile('nameservers'): if os.path.isfile(args.nameservers):
servers = open('nameservers').readlines() servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers))) servers = list(filter(None, map(str.strip, servers)))
for resolved in Orchestrator(iterator, servers).run(): for resolved in Orchestrator(
iterator,
servers,
nb_workers=args.workers
).run():
args.output.write(resolved) args.output.write(resolved)