Refactored for correct retry logic

This commit is contained in:
Geoffrey Frogeye 2019-11-14 15:03:20 +01:00
parent b343893c72
commit 88f0bcc648

View file

@ -1,5 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# pylint: disable=C0103
""" """
From a list of subdomains, output only From a list of subdomains, output only
@ -7,11 +6,11 @@ the ones resolving to a first-party tracker.
""" """
import logging import logging
import threading
import queue
import os import os
import queue
import re import re
import sys import sys
import threading
import typing import typing
import coloredlogs import coloredlogs
@ -22,31 +21,43 @@ import progressbar
import regexes import regexes
DNS_TIMEOUT = 60.0 DNS_TIMEOUT = 60.0
NUMBER_THREADS = 512
NUMBER_TRIES = 10
# TODO Try again does not work because sentinel get through first :/
class DnsResolver(threading.Thread): class Worker(threading.Thread):
""" """
Worker process for a DNS resolver. Worker process for a DNS resolver.
Will resolve DNS to match first-party subdomains. Will resolve DNS to match first-party subdomains.
""" """
def __init__(self, def change_nameserver(self) -> None:
in_queue: queue.Queue, """
out_queue: queue.Queue, Assign a this worker another nameserver from the queue.
server: str): """
super(DnsResolver, self).__init__() server = None
self.log = logging.getLogger(server) while server is None:
try:
server = self.orchestrator.nameservers_queue.get(block=False)
except queue.Empty:
self.orchestrator.refill_nameservers_queue()
self.log.debug("Using nameserver: %s", server)
self.resolver.nameservers = [server]
self.in_queue = in_queue def __init__(self,
self.out_queue = out_queue orchestrator: 'Orchestrator',
index: int = 0):
super(Worker, self).__init__()
self.log = logging.getLogger(f'worker{index:03d}')
self.orchestrator = orchestrator
self.resolver = dns.resolver.Resolver() self.resolver = dns.resolver.Resolver()
self.resolver.nameservers = [server] self.change_nameserver()
def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]: def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]:
""" """
Indicates if the subdomain redirects to a first-party tracker. Indicates if the subdomain redirects to a first-party tracker.
Returns None if the nameserver was unable to satisfy the request.
""" """
# TODO Look at the whole chain rather than the last one # TODO Look at the whole chain rather than the last one
# TODO Also match the ASN of the IP (caching the ASN subnetworks will do) # TODO Also match the ASN of the IP (caching the ASN subnetworks will do)
@ -58,16 +69,20 @@ class DnsResolver(threading.Thread):
return False return False
except dns.resolver.YXDOMAIN: except dns.resolver.YXDOMAIN:
self.log.warning("Query name too long for %s", subdomain) self.log.warning("Query name too long for %s", subdomain)
return False return None
except dns.resolver.NoNameservers: except dns.resolver.NoNameservers:
# NOTE Most of the time this error message means that the domain
# does not exists, but sometimes it means the that the server
# itself is broken. So we count on the retry logic.
self.log.warning("All nameservers broken for %s", subdomain) self.log.warning("All nameservers broken for %s", subdomain)
return None return None
except dns.exception.Timeout: except dns.exception.Timeout:
# NOTE Same as above
self.log.warning("Timeout for %s", subdomain) self.log.warning("Timeout for %s", subdomain)
return None return None
except dns.name.EmptyLabel: except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain) self.log.warning("Empty label for %s", subdomain)
return False return None
canonical = query.canonical_name.to_text() canonical = query.canonical_name.to_text()
for regex in regexes.REGEXES: for regex in regexes.REGEXES:
if re.match(regex, canonical): if re.match(regex, canonical):
@ -76,58 +91,112 @@ class DnsResolver(threading.Thread):
def run(self) -> None: def run(self) -> None:
self.log.info("Started") self.log.info("Started")
for subdomain in iter(self.in_queue.get, None): for subdomain in iter(self.orchestrator.subdomains_queue.get, None):
matching = self.is_subdomain_matching(subdomain)
# If issue, retry for _ in range(NUMBER_TRIES):
matching = self.is_subdomain_matching(subdomain)
if matching is not None:
break
# If it wasn't found after multiple tries
if matching is None: if matching is None:
# matching = False self.log.error("Gave up on %s", subdomain)
self.in_queue.put(subdomain) matching = False
continue
result = (subdomain, matching) result = (subdomain, matching)
# self.log.debug("%s", result) self.orchestrator.results_queue.put(result)
self.out_queue.put(result)
self.out_queue.put(None) self.orchestrator.results_queue.put(None)
self.log.info("Stopped") self.log.info("Stopped")
def get_matching_subdomains(subdomains: typing.Iterable[str], class Orchestrator():
nameservers: typing.List[str] = None,
) -> typing.Iterable[typing.Tuple[str, bool]]:
subdomains_queue: queue.Queue = queue.Queue()
results_queue: queue.Queue = queue.Queue()
""" """
Orchestrator of the different DnsResolver threads. Orchestrator of the different Worker threads.
""" """
def refill_nameservers_queue(self) -> None:
"""
Re-fill the given nameservers into the nameservers queue.
Done every-time the queue is empty, making it
basically looping and infinite.
"""
# Might be in a race condition but that's probably fine
for nameserver in self.nameservers:
self.nameservers_queue.put(nameserver)
self.log.info("Refilled nameserver queue")
def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None):
self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains
# Use interal resolver by default # Use interal resolver by default
servers = nameservers or dns.resolver.Resolver().nameservers self.nameservers = nameservers or dns.resolver.Resolver().nameservers
# Create workers self.subdomains_queue: queue.Queue = queue.Queue(
for server in servers: maxsize=NUMBER_THREADS)
DnsResolver(subdomains_queue, results_queue, server).start() self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue()
self.refill_nameservers_queue()
def fill_subdomain_queue(self) -> None:
"""
Read the subdomains in input and put them into the queue.
Done in a thread so we can both:
- yield the results as they come
- not store all the subdomains at once
"""
self.log.info("Started reading subdomains")
# Send data to workers # Send data to workers
for subdomain in subdomains: for subdomain in self.subdomains:
subdomains_queue.put(subdomain) self.subdomains_queue.put(subdomain)
self.log.info("Finished reading subdomains")
# Send sentinel to each worker # Send sentinel to each worker
# sentinel = None ~= EOF # sentinel = None ~= EOF
for _ in servers: for _ in range(NUMBER_THREADS):
subdomains_queue.put(None) self.subdomains_queue.put(None)
def run(self) -> typing.Iterable[typing.Tuple[str, bool]]:
"""
Yield the results.
"""
# Create workers
self.log.info("Creating workers")
for i in range(NUMBER_THREADS):
Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue)
fill_thread.start()
# Wait for one sentinel per worker # Wait for one sentinel per worker
# In the meantime output results # In the meantime output results
for _ in servers: for _ in range(NUMBER_THREADS):
for result in iter(results_queue.get, None): for result in iter(self.results_queue.get, None):
yield result yield result
self.log.info("Waiting for reader thread")
fill_thread.join()
self.log.info("Done!")
def main() -> None:
"""
Main function when used directly.
Takes as an input a filename (or nothing, for stdin)
that will be read and the ones that are a tracker
will be outputed on stdout.
Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults.
Also shows a nice progressbar.
"""
if __name__ == '__main__':
coloredlogs.install( coloredlogs.install(
level='DEBUG', level='DEBUG',
fmt='%(asctime)s %(name)s[%(process)d] %(levelname)s %(message)s' fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
) )
# Progress bar # Progress bar
@ -164,8 +233,12 @@ if __name__ == '__main__':
servers = list(filter(None, map(str.strip, servers))) servers = list(filter(None, map(str.strip, servers)))
progress.start() progress.start()
for subdomain, matching in get_matching_subdomains(iterator, servers): for subdomain, matching in Orchestrator(iterator, servers).run():
progress.update(progress.value + 1) progress.update(progress.value + 1)
if matching: if matching:
print(subdomain) print(subdomain)
progress.finish() progress.finish()
if __name__ == '__main__':
main()