Refactored for correct retry logic
This commit is contained in:
parent
b343893c72
commit
88f0bcc648
|
@ -1,5 +1,4 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# pylint: disable=C0103
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
From a list of subdomains, output only
|
From a list of subdomains, output only
|
||||||
|
@ -7,11 +6,11 @@ the ones resolving to a first-party tracker.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
import queue
|
|
||||||
import os
|
import os
|
||||||
|
import queue
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import coloredlogs
|
import coloredlogs
|
||||||
|
@ -22,31 +21,43 @@ import progressbar
|
||||||
import regexes
|
import regexes
|
||||||
|
|
||||||
DNS_TIMEOUT = 60.0
|
DNS_TIMEOUT = 60.0
|
||||||
|
NUMBER_THREADS = 512
|
||||||
|
NUMBER_TRIES = 10
|
||||||
|
|
||||||
# TODO Try again does not work because sentinel get through first :/
|
|
||||||
|
|
||||||
class DnsResolver(threading.Thread):
|
class Worker(threading.Thread):
|
||||||
"""
|
"""
|
||||||
Worker process for a DNS resolver.
|
Worker process for a DNS resolver.
|
||||||
Will resolve DNS to match first-party subdomains.
|
Will resolve DNS to match first-party subdomains.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def change_nameserver(self) -> None:
|
||||||
in_queue: queue.Queue,
|
"""
|
||||||
out_queue: queue.Queue,
|
Assign this worker another nameserver from the queue.
|
||||||
server: str):
|
"""
|
||||||
super(DnsResolver, self).__init__()
|
server = None
|
||||||
self.log = logging.getLogger(server)
|
while server is None:
|
||||||
|
try:
|
||||||
|
server = self.orchestrator.nameservers_queue.get(block=False)
|
||||||
|
except queue.Empty:
|
||||||
|
self.orchestrator.refill_nameservers_queue()
|
||||||
|
self.log.debug("Using nameserver: %s", server)
|
||||||
|
self.resolver.nameservers = [server]
|
||||||
|
|
||||||
self.in_queue = in_queue
|
def __init__(self,
|
||||||
self.out_queue = out_queue
|
orchestrator: 'Orchestrator',
|
||||||
|
index: int = 0):
|
||||||
|
super(Worker, self).__init__()
|
||||||
|
self.log = logging.getLogger(f'worker{index:03d}')
|
||||||
|
self.orchestrator = orchestrator
|
||||||
|
|
||||||
self.resolver = dns.resolver.Resolver()
|
self.resolver = dns.resolver.Resolver()
|
||||||
self.resolver.nameservers = [server]
|
self.change_nameserver()
|
||||||
|
|
||||||
def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]:
|
def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Indicates if the subdomain redirects to a first-party tracker.
|
Indicates if the subdomain redirects to a first-party tracker.
|
||||||
|
Returns None if the nameserver was unable to satisfy the request.
|
||||||
"""
|
"""
|
||||||
# TODO Look at the whole chain rather than the last one
|
# TODO Look at the whole chain rather than the last one
|
||||||
# TODO Also match the ASN of the IP (caching the ASN subnetworks will do)
|
# TODO Also match the ASN of the IP (caching the ASN subnetworks will do)
|
||||||
|
@ -58,16 +69,20 @@ class DnsResolver(threading.Thread):
|
||||||
return False
|
return False
|
||||||
except dns.resolver.YXDOMAIN:
|
except dns.resolver.YXDOMAIN:
|
||||||
self.log.warning("Query name too long for %s", subdomain)
|
self.log.warning("Query name too long for %s", subdomain)
|
||||||
return False
|
return None
|
||||||
except dns.resolver.NoNameservers:
|
except dns.resolver.NoNameservers:
|
||||||
|
# NOTE Most of the time this error message means that the domain
|
||||||
|
# does not exist, but sometimes it means that the server
|
||||||
|
# itself is broken. So we count on the retry logic.
|
||||||
self.log.warning("All nameservers broken for %s", subdomain)
|
self.log.warning("All nameservers broken for %s", subdomain)
|
||||||
return None
|
return None
|
||||||
except dns.exception.Timeout:
|
except dns.exception.Timeout:
|
||||||
|
# NOTE Same as above
|
||||||
self.log.warning("Timeout for %s", subdomain)
|
self.log.warning("Timeout for %s", subdomain)
|
||||||
return None
|
return None
|
||||||
except dns.name.EmptyLabel:
|
except dns.name.EmptyLabel:
|
||||||
self.log.warning("Empty label for %s", subdomain)
|
self.log.warning("Empty label for %s", subdomain)
|
||||||
return False
|
return None
|
||||||
canonical = query.canonical_name.to_text()
|
canonical = query.canonical_name.to_text()
|
||||||
for regex in regexes.REGEXES:
|
for regex in regexes.REGEXES:
|
||||||
if re.match(regex, canonical):
|
if re.match(regex, canonical):
|
||||||
|
@ -76,58 +91,112 @@ class DnsResolver(threading.Thread):
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.log.info("Started")
|
self.log.info("Started")
|
||||||
for subdomain in iter(self.in_queue.get, None):
|
for subdomain in iter(self.orchestrator.subdomains_queue.get, None):
|
||||||
matching = self.is_subdomain_matching(subdomain)
|
|
||||||
|
|
||||||
# If issue, retry
|
for _ in range(NUMBER_TRIES):
|
||||||
|
matching = self.is_subdomain_matching(subdomain)
|
||||||
|
if matching is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
# If it wasn't found after multiple tries
|
||||||
if matching is None:
|
if matching is None:
|
||||||
# matching = False
|
self.log.error("Gave up on %s", subdomain)
|
||||||
self.in_queue.put(subdomain)
|
matching = False
|
||||||
continue
|
|
||||||
|
|
||||||
result = (subdomain, matching)
|
result = (subdomain, matching)
|
||||||
# self.log.debug("%s", result)
|
self.orchestrator.results_queue.put(result)
|
||||||
self.out_queue.put(result)
|
|
||||||
self.out_queue.put(None)
|
self.orchestrator.results_queue.put(None)
|
||||||
self.log.info("Stopped")
|
self.log.info("Stopped")
|
||||||
|
|
||||||
|
|
||||||
def get_matching_subdomains(subdomains: typing.Iterable[str],
|
class Orchestrator():
|
||||||
nameservers: typing.List[str] = None,
|
|
||||||
) -> typing.Iterable[typing.Tuple[str, bool]]:
|
|
||||||
subdomains_queue: queue.Queue = queue.Queue()
|
|
||||||
results_queue: queue.Queue = queue.Queue()
|
|
||||||
"""
|
"""
|
||||||
Orchestrator of the different DnsResolver threads.
|
Orchestrator of the different Worker threads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def refill_nameservers_queue(self) -> None:
|
||||||
|
"""
|
||||||
|
Re-fill the given nameservers into the nameservers queue.
|
||||||
|
Done every time the queue is empty, making it
|
||||||
|
basically looping and infinite.
|
||||||
|
"""
|
||||||
|
# Might be in a race condition but that's probably fine
|
||||||
|
for nameserver in self.nameservers:
|
||||||
|
self.nameservers_queue.put(nameserver)
|
||||||
|
self.log.info("Refilled nameserver queue")
|
||||||
|
|
||||||
|
def __init__(self, subdomains: typing.Iterable[str],
|
||||||
|
nameservers: typing.List[str] = None):
|
||||||
|
self.log = logging.getLogger('orchestrator')
|
||||||
|
self.subdomains = subdomains
|
||||||
|
|
||||||
# Use internal resolver by default
|
# Use internal resolver by default
|
||||||
servers = nameservers or dns.resolver.Resolver().nameservers
|
self.nameservers = nameservers or dns.resolver.Resolver().nameservers
|
||||||
|
|
||||||
# Create workers
|
self.subdomains_queue: queue.Queue = queue.Queue(
|
||||||
for server in servers:
|
maxsize=NUMBER_THREADS)
|
||||||
DnsResolver(subdomains_queue, results_queue, server).start()
|
self.results_queue: queue.Queue = queue.Queue()
|
||||||
|
self.nameservers_queue: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
self.refill_nameservers_queue()
|
||||||
|
|
||||||
|
def fill_subdomain_queue(self) -> None:
|
||||||
|
"""
|
||||||
|
Read the subdomains in input and put them into the queue.
|
||||||
|
Done in a thread so we can both:
|
||||||
|
- yield the results as they come
|
||||||
|
- not store all the subdomains at once
|
||||||
|
"""
|
||||||
|
self.log.info("Started reading subdomains")
|
||||||
# Send data to workers
|
# Send data to workers
|
||||||
for subdomain in subdomains:
|
for subdomain in self.subdomains:
|
||||||
subdomains_queue.put(subdomain)
|
self.subdomains_queue.put(subdomain)
|
||||||
|
|
||||||
|
self.log.info("Finished reading subdomains")
|
||||||
# Send sentinel to each worker
|
# Send sentinel to each worker
|
||||||
# sentinel = None ~= EOF
|
# sentinel = None ~= EOF
|
||||||
for _ in servers:
|
for _ in range(NUMBER_THREADS):
|
||||||
subdomains_queue.put(None)
|
self.subdomains_queue.put(None)
|
||||||
|
|
||||||
|
def run(self) -> typing.Iterable[typing.Tuple[str, bool]]:
|
||||||
|
"""
|
||||||
|
Yield the results.
|
||||||
|
"""
|
||||||
|
# Create workers
|
||||||
|
self.log.info("Creating workers")
|
||||||
|
for i in range(NUMBER_THREADS):
|
||||||
|
Worker(self, i).start()
|
||||||
|
|
||||||
|
fill_thread = threading.Thread(target=self.fill_subdomain_queue)
|
||||||
|
fill_thread.start()
|
||||||
|
|
||||||
# Wait for one sentinel per worker
|
# Wait for one sentinel per worker
|
||||||
# In the meantime output results
|
# In the meantime output results
|
||||||
for _ in servers:
|
for _ in range(NUMBER_THREADS):
|
||||||
for result in iter(results_queue.get, None):
|
for result in iter(self.results_queue.get, None):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
self.log.info("Waiting for reader thread")
|
||||||
|
fill_thread.join()
|
||||||
|
|
||||||
|
self.log.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""
|
||||||
|
Main function when used directly.
|
||||||
|
Takes as an input a filename (or nothing, for stdin)
|
||||||
|
that will be read and the ones that are a tracker
|
||||||
|
will be output on stdout.
|
||||||
|
Use the file `nameservers` as the list of nameservers
|
||||||
|
to use, or else it will use the system defaults.
|
||||||
|
Also shows a nice progressbar.
|
||||||
|
"""
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
coloredlogs.install(
|
coloredlogs.install(
|
||||||
level='DEBUG',
|
level='DEBUG',
|
||||||
fmt='%(asctime)s %(name)s[%(process)d] %(levelname)s %(message)s'
|
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Progress bar
|
# Progress bar
|
||||||
|
@ -164,8 +233,12 @@ if __name__ == '__main__':
|
||||||
servers = list(filter(None, map(str.strip, servers)))
|
servers = list(filter(None, map(str.strip, servers)))
|
||||||
|
|
||||||
progress.start()
|
progress.start()
|
||||||
for subdomain, matching in get_matching_subdomains(iterator, servers):
|
for subdomain, matching in Orchestrator(iterator, servers).run():
|
||||||
progress.update(progress.value + 1)
|
progress.update(progress.value + 1)
|
||||||
if matching:
|
if matching:
|
||||||
print(subdomain)
|
print(subdomain)
|
||||||
progress.finish()
|
progress.finish()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
Loading…
Reference in a new issue