Refactored for correct retry logic
This commit is contained in:
		
							parent
							
								
									b343893c72
								
							
						
					
					
						commit
						88f0bcc648
					
				
					 1 changed files with 124 additions and 51 deletions
				
			
		|  | @ -1,5 +1,4 @@ | |||
| #!/usr/bin/env python3 | ||||
| # pylint: disable=C0103 | ||||
| 
 | ||||
| """ | ||||
| From a list of subdomains, output only | ||||
|  | @ -7,11 +6,11 @@ the ones resolving to a first-party tracker. | |||
| """ | ||||
| 
 | ||||
| import logging | ||||
| import threading | ||||
| import queue | ||||
| import os | ||||
| import queue | ||||
| import re | ||||
| import sys | ||||
| import threading | ||||
| import typing | ||||
| 
 | ||||
| import coloredlogs | ||||
|  | @ -22,31 +21,43 @@ import progressbar | |||
| import regexes | ||||
| 
 | ||||
# Timeout (seconds) for a single DNS query.
# NOTE(review): not referenced anywhere in the visible code — presumably
# meant to be applied to the resolver; confirm it is actually used.
DNS_TIMEOUT = 60.0
# Number of Worker threads resolving subdomains concurrently.
NUMBER_THREADS = 512
# Maximum resolution attempts per subdomain before giving up.
NUMBER_TRIES = 10
 | ||||
| class DnsResolver(threading.Thread): | ||||
| class Worker(threading.Thread): | ||||
|     """ | ||||
|     Worker process for a DNS resolver. | ||||
|     Will resolve DNS to match first-party subdomains. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, | ||||
|                  in_queue: queue.Queue, | ||||
|                  out_queue: queue.Queue, | ||||
|                  server: str): | ||||
|         super(DnsResolver, self).__init__() | ||||
|         self.log = logging.getLogger(server) | ||||
|     def change_nameserver(self) -> None: | ||||
|         """ | ||||
|         Assign a this worker another nameserver from the queue. | ||||
|         """ | ||||
|         server = None | ||||
|         while server is None: | ||||
|             try: | ||||
|                 server = self.orchestrator.nameservers_queue.get(block=False) | ||||
|             except queue.Empty: | ||||
|                 self.orchestrator.refill_nameservers_queue() | ||||
|         self.log.debug("Using nameserver: %s", server) | ||||
|         self.resolver.nameservers = [server] | ||||
| 
 | ||||
|         self.in_queue = in_queue | ||||
|         self.out_queue = out_queue | ||||
|     def __init__(self, | ||||
|                  orchestrator: 'Orchestrator', | ||||
|                  index: int = 0): | ||||
|         super(Worker, self).__init__() | ||||
|         self.log = logging.getLogger(f'worker{index:03d}') | ||||
|         self.orchestrator = orchestrator | ||||
| 
 | ||||
|         self.resolver = dns.resolver.Resolver() | ||||
|         self.resolver.nameservers = [server] | ||||
|         self.change_nameserver() | ||||
| 
 | ||||
|     def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]: | ||||
|         """ | ||||
|         Indicates if the subdomain redirects to a first-party tracker. | ||||
|         Returns None if the nameserver was unable to satisfy the request. | ||||
|         """ | ||||
|         # TODO Look at the whole chain rather than the last one | ||||
|         # TODO Also match the ASN of the IP (caching the ASN subnetworks will do) | ||||
|  | @ -58,16 +69,20 @@ class DnsResolver(threading.Thread): | |||
|             return False | ||||
|         except dns.resolver.YXDOMAIN: | ||||
|             self.log.warning("Query name too long for %s", subdomain) | ||||
|             return False | ||||
|             return None | ||||
|         except dns.resolver.NoNameservers: | ||||
|             # NOTE Most of the time this error message means that the domain | ||||
|             # does not exists, but sometimes it means the that the server | ||||
|             # itself is broken. So we count on the retry logic. | ||||
|             self.log.warning("All nameservers broken for %s", subdomain) | ||||
|             return None | ||||
|         except dns.exception.Timeout: | ||||
|             # NOTE Same as above | ||||
|             self.log.warning("Timeout for %s", subdomain) | ||||
|             return None | ||||
|         except dns.name.EmptyLabel: | ||||
|             self.log.warning("Empty label for %s", subdomain) | ||||
|             return False | ||||
|             return None | ||||
|         canonical = query.canonical_name.to_text() | ||||
|         for regex in regexes.REGEXES: | ||||
|             if re.match(regex, canonical): | ||||
|  | @ -76,58 +91,112 @@ class DnsResolver(threading.Thread): | |||
| 
 | ||||
|     def run(self) -> None: | ||||
|         self.log.info("Started") | ||||
|         for subdomain in iter(self.in_queue.get, None): | ||||
|             matching = self.is_subdomain_matching(subdomain) | ||||
|         for subdomain in iter(self.orchestrator.subdomains_queue.get, None): | ||||
| 
 | ||||
|             # If issue, retry | ||||
|             for _ in range(NUMBER_TRIES): | ||||
|                 matching = self.is_subdomain_matching(subdomain) | ||||
|                 if matching is not None: | ||||
|                     break | ||||
| 
 | ||||
|             # If it wasn't found after multiple tries | ||||
|             if matching is None: | ||||
|                 # matching = False | ||||
|                 self.in_queue.put(subdomain) | ||||
|                 continue | ||||
|                 self.log.error("Gave up on %s", subdomain) | ||||
|                 matching = False | ||||
| 
 | ||||
|             result = (subdomain, matching) | ||||
|             # self.log.debug("%s", result) | ||||
|             self.out_queue.put(result) | ||||
|         self.out_queue.put(None) | ||||
|             self.orchestrator.results_queue.put(result) | ||||
| 
 | ||||
|         self.orchestrator.results_queue.put(None) | ||||
|         self.log.info("Stopped") | ||||
| 
 | ||||
| 
 | ||||
| def get_matching_subdomains(subdomains: typing.Iterable[str], | ||||
|                             nameservers: typing.List[str] = None, | ||||
|                             ) -> typing.Iterable[typing.Tuple[str, bool]]: | ||||
|     subdomains_queue: queue.Queue = queue.Queue() | ||||
|     results_queue: queue.Queue = queue.Queue() | ||||
| class Orchestrator(): | ||||
|     """ | ||||
|     Orchestrator of the different DnsResolver threads. | ||||
|     Orchestrator of the different Worker threads. | ||||
|     """ | ||||
| 
 | ||||
|     def refill_nameservers_queue(self) -> None: | ||||
|         """ | ||||
|         Re-fill the given nameservers into the nameservers queue. | ||||
|         Done every-time the queue is empty, making it | ||||
|         basically looping and infinite. | ||||
|         """ | ||||
|         # Might be in a race condition but that's probably fine | ||||
|         for nameserver in self.nameservers: | ||||
|             self.nameservers_queue.put(nameserver) | ||||
|         self.log.info("Refilled nameserver queue") | ||||
| 
 | ||||
|     def __init__(self, subdomains: typing.Iterable[str], | ||||
|                  nameservers: typing.List[str] = None): | ||||
|         self.log = logging.getLogger('orchestrator') | ||||
|         self.subdomains = subdomains | ||||
| 
 | ||||
|         # Use interal resolver by default | ||||
|     servers = nameservers or dns.resolver.Resolver().nameservers | ||||
|         self.nameservers = nameservers or dns.resolver.Resolver().nameservers | ||||
| 
 | ||||
|     # Create workers | ||||
|     for server in servers: | ||||
|         DnsResolver(subdomains_queue, results_queue, server).start() | ||||
|         self.subdomains_queue: queue.Queue = queue.Queue( | ||||
|             maxsize=NUMBER_THREADS) | ||||
|         self.results_queue: queue.Queue = queue.Queue() | ||||
|         self.nameservers_queue: queue.Queue = queue.Queue() | ||||
| 
 | ||||
|         self.refill_nameservers_queue() | ||||
| 
 | ||||
|     def fill_subdomain_queue(self) -> None: | ||||
|         """ | ||||
|         Read the subdomains in input and put them into the queue. | ||||
|         Done in a thread so we can both: | ||||
|         - yield the results as they come | ||||
|         - not store all the subdomains at once | ||||
|         """ | ||||
|         self.log.info("Started reading subdomains") | ||||
|         # Send data to workers | ||||
|     for subdomain in subdomains: | ||||
|         subdomains_queue.put(subdomain) | ||||
|         for subdomain in self.subdomains: | ||||
|             self.subdomains_queue.put(subdomain) | ||||
| 
 | ||||
|         self.log.info("Finished reading subdomains") | ||||
|         # Send sentinel to each worker | ||||
|         # sentinel = None ~= EOF | ||||
|     for _ in servers: | ||||
|         subdomains_queue.put(None) | ||||
|         for _ in range(NUMBER_THREADS): | ||||
|             self.subdomains_queue.put(None) | ||||
| 
 | ||||
|     def run(self) -> typing.Iterable[typing.Tuple[str, bool]]: | ||||
|         """ | ||||
|         Yield the results. | ||||
|         """ | ||||
|         # Create workers | ||||
|         self.log.info("Creating workers") | ||||
|         for i in range(NUMBER_THREADS): | ||||
|             Worker(self, i).start() | ||||
| 
 | ||||
|         fill_thread = threading.Thread(target=self.fill_subdomain_queue) | ||||
|         fill_thread.start() | ||||
| 
 | ||||
|         # Wait for one sentinel per worker | ||||
|         # In the meantime output results | ||||
|     for _ in servers: | ||||
|         for result in iter(results_queue.get, None): | ||||
|         for _ in range(NUMBER_THREADS): | ||||
|             for result in iter(self.results_queue.get, None): | ||||
|                 yield result | ||||
| 
 | ||||
|         self.log.info("Waiting for reader thread") | ||||
|         fill_thread.join() | ||||
| 
 | ||||
|         self.log.info("Done!") | ||||
| 
 | ||||
| 
 | ||||
| def main() -> None: | ||||
|     """ | ||||
|     Main function when used directly. | ||||
|     Takes as an input a filename (or nothing, for stdin) | ||||
|     that will be read and the ones that are a tracker | ||||
|     will be outputed on stdout. | ||||
|     Use the file `nameservers` as the list of nameservers | ||||
|     to use, or else it will use the system defaults. | ||||
|     Also shows a nice progressbar. | ||||
|     """ | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     coloredlogs.install( | ||||
|         level='DEBUG', | ||||
|         fmt='%(asctime)s %(name)s[%(process)d] %(levelname)s %(message)s' | ||||
|         fmt='%(asctime)s %(name)s %(levelname)s %(message)s' | ||||
|     ) | ||||
| 
 | ||||
|     # Progress bar | ||||
|  | @ -164,8 +233,12 @@ if __name__ == '__main__': | |||
|         servers = list(filter(None, map(str.strip, servers))) | ||||
| 
 | ||||
|     progress.start() | ||||
|     for subdomain, matching in get_matching_subdomains(iterator, servers): | ||||
|     for subdomain, matching in Orchestrator(iterator, servers).run(): | ||||
|         progress.update(progress.value + 1) | ||||
|         if matching: | ||||
|             print(subdomain) | ||||
|     progress.finish() | ||||
| 
 | ||||
| 
 | ||||
# Script entry point: run main() only when executed directly.
if __name__ == '__main__':
    main()
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue