diff --git a/filter_subdomains.py b/filter_subdomains.py index de70809..c7a0bfc 100755 --- a/filter_subdomains.py +++ b/filter_subdomains.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# pylint: disable=C0103 """ From a list of subdomains, output only @@ -7,11 +6,11 @@ the ones resolving to a first-party tracker. """ import logging -import threading -import queue import os +import queue import re import sys +import threading import typing import coloredlogs @@ -22,31 +21,43 @@ import progressbar import regexes DNS_TIMEOUT = 60.0 +NUMBER_THREADS = 512 +NUMBER_TRIES = 10 -# TODO Try again does not work because sentinel get through first :/ -class DnsResolver(threading.Thread): +class Worker(threading.Thread): """ Worker process for a DNS resolver. Will resolve DNS to match first-party subdomains. """ - def __init__(self, - in_queue: queue.Queue, - out_queue: queue.Queue, - server: str): - super(DnsResolver, self).__init__() - self.log = logging.getLogger(server) + def change_nameserver(self) -> None: + """ + Assign a this worker another nameserver from the queue. + """ + server = None + while server is None: + try: + server = self.orchestrator.nameservers_queue.get(block=False) + except queue.Empty: + self.orchestrator.refill_nameservers_queue() + self.log.debug("Using nameserver: %s", server) + self.resolver.nameservers = [server] - self.in_queue = in_queue - self.out_queue = out_queue + def __init__(self, + orchestrator: 'Orchestrator', + index: int = 0): + super(Worker, self).__init__() + self.log = logging.getLogger(f'worker{index:03d}') + self.orchestrator = orchestrator self.resolver = dns.resolver.Resolver() - self.resolver.nameservers = [server] + self.change_nameserver() def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]: """ Indicates if the subdomain redirects to a first-party tracker. + Returns None if the nameserver was unable to satisfy the request. """ # TODO Look at the whole chain rather than the last one # TODO Also match the ASN of the IP (caching the ASN subnetworks will do) @@ -58,16 +69,20 @@ class DnsResolver(threading.Thread): return False except dns.resolver.YXDOMAIN: self.log.warning("Query name too long for %s", subdomain) - return False + return None except dns.resolver.NoNameservers: + # NOTE Most of the time this error message means that the domain + # does not exists, but sometimes it means the that the server + # itself is broken. So we count on the retry logic. self.log.warning("All nameservers broken for %s", subdomain) return None except dns.exception.Timeout: + # NOTE Same as above self.log.warning("Timeout for %s", subdomain) return None except dns.name.EmptyLabel: self.log.warning("Empty label for %s", subdomain) - return False + return None canonical = query.canonical_name.to_text() for regex in regexes.REGEXES: if re.match(regex, canonical): @@ -76,58 +91,112 @@ class DnsResolver(threading.Thread): def run(self) -> None: self.log.info("Started") - for subdomain in iter(self.in_queue.get, None): - matching = self.is_subdomain_matching(subdomain) + for subdomain in iter(self.orchestrator.subdomains_queue.get, None): - # If issue, retry + for _ in range(NUMBER_TRIES): + matching = self.is_subdomain_matching(subdomain) + if matching is not None: + break + + # If it wasn't found after multiple tries if matching is None: - # matching = False - self.in_queue.put(subdomain) - continue + self.log.error("Gave up on %s", subdomain) + matching = False result = (subdomain, matching) - # self.log.debug("%s", result) - self.out_queue.put(result) - self.out_queue.put(None) + self.orchestrator.results_queue.put(result) + + self.orchestrator.results_queue.put(None) self.log.info("Stopped") -def get_matching_subdomains(subdomains: typing.Iterable[str], - nameservers: typing.List[str] = None, - ) -> typing.Iterable[typing.Tuple[str, bool]]: - subdomains_queue: queue.Queue = queue.Queue() - results_queue: queue.Queue = queue.Queue() +class Orchestrator(): """ - Orchestrator of the different DnsResolver threads. + Orchestrator of the different Worker threads. """ - # Use interal resolver by default - servers = nameservers or dns.resolver.Resolver().nameservers + def refill_nameservers_queue(self) -> None: + """ + Re-fill the given nameservers into the nameservers queue. + Done every-time the queue is empty, making it + basically looping and infinite. + """ + # Might be in a race condition but that's probably fine + for nameserver in self.nameservers: + self.nameservers_queue.put(nameserver) + self.log.info("Refilled nameserver queue") - # Create workers - for server in servers: - DnsResolver(subdomains_queue, results_queue, server).start() + def __init__(self, subdomains: typing.Iterable[str], + nameservers: typing.List[str] = None): + self.log = logging.getLogger('orchestrator') + self.subdomains = subdomains - # Send data to workers - for subdomain in subdomains: - subdomains_queue.put(subdomain) + # Use interal resolver by default + self.nameservers = nameservers or dns.resolver.Resolver().nameservers - # Send sentinel to each worker - # sentinel = None ~= EOF - for _ in servers: - subdomains_queue.put(None) + self.subdomains_queue: queue.Queue = queue.Queue( + maxsize=NUMBER_THREADS) + self.results_queue: queue.Queue = queue.Queue() + self.nameservers_queue: queue.Queue = queue.Queue() - # Wait for one sentinel per worker - # In the meantime output results - for _ in servers: - for result in iter(results_queue.get, None): - yield result + self.refill_nameservers_queue() + + def fill_subdomain_queue(self) -> None: + """ + Read the subdomains in input and put them into the queue. + Done in a thread so we can both: + - yield the results as they come + - not store all the subdomains at once + """ + self.log.info("Started reading subdomains") + # Send data to workers + for subdomain in self.subdomains: + self.subdomains_queue.put(subdomain) + + self.log.info("Finished reading subdomains") + # Send sentinel to each worker + # sentinel = None ~= EOF + for _ in range(NUMBER_THREADS): + self.subdomains_queue.put(None) + + def run(self) -> typing.Iterable[typing.Tuple[str, bool]]: + """ + Yield the results. + """ + # Create workers + self.log.info("Creating workers") + for i in range(NUMBER_THREADS): + Worker(self, i).start() + + fill_thread = threading.Thread(target=self.fill_subdomain_queue) + fill_thread.start() + + # Wait for one sentinel per worker + # In the meantime output results + for _ in range(NUMBER_THREADS): + for result in iter(self.results_queue.get, None): + yield result + + self.log.info("Waiting for reader thread") + fill_thread.join() + + self.log.info("Done!") -if __name__ == '__main__': +def main() -> None: + """ + Main function when used directly. + Takes as an input a filename (or nothing, for stdin) + that will be read and the ones that are a tracker + will be outputed on stdout. + Use the file `nameservers` as the list of nameservers + to use, or else it will use the system defaults. + Also shows a nice progressbar. + """ + coloredlogs.install( level='DEBUG', - fmt='%(asctime)s %(name)s[%(process)d] %(levelname)s %(message)s' + fmt='%(asctime)s %(name)s %(levelname)s %(message)s' ) # Progress bar @@ -164,8 +233,12 @@ if __name__ == '__main__': servers = list(filter(None, map(str.strip, servers))) progress.start() - for subdomain, matching in get_matching_subdomains(iterator, servers): + for subdomain, matching in Orchestrator(iterator, servers).run(): progress.update(progress.value + 1) if matching: print(subdomain) progress.finish() + + +if __name__ == '__main__': + main()