Refactored for correct retry logic
This commit is contained in:
		
							parent
							
								
									b343893c72
								
							
						
					
					
						commit
						88f0bcc648
					
				
					 1 changed files with 124 additions and 51 deletions
				
			
		|  | @ -1,5 +1,4 @@ | ||||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||||
| # pylint: disable=C0103 |  | ||||||
| 
 | 
 | ||||||
| """ | """ | ||||||
| From a list of subdomains, output only | From a list of subdomains, output only | ||||||
|  | @ -7,11 +6,11 @@ the ones resolving to a first-party tracker. | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| import logging | import logging | ||||||
| import threading |  | ||||||
| import queue |  | ||||||
| import os | import os | ||||||
|  | import queue | ||||||
| import re | import re | ||||||
| import sys | import sys | ||||||
|  | import threading | ||||||
| import typing | import typing | ||||||
| 
 | 
 | ||||||
| import coloredlogs | import coloredlogs | ||||||
|  | @ -22,31 +21,43 @@ import progressbar | ||||||
| import regexes | import regexes | ||||||
| 
 | 
 | ||||||
| DNS_TIMEOUT = 60.0 | DNS_TIMEOUT = 60.0 | ||||||
|  | NUMBER_THREADS = 512 | ||||||
|  | NUMBER_TRIES = 10 | ||||||
| 
 | 
 | ||||||
| # TODO Try again does not work because sentinel get through first :/ |  | ||||||
| 
 | 
 | ||||||
| class DnsResolver(threading.Thread): | class Worker(threading.Thread): | ||||||
|     """ |     """ | ||||||
|     Worker process for a DNS resolver. |     Worker process for a DNS resolver. | ||||||
|     Will resolve DNS to match first-party subdomains. |     Will resolve DNS to match first-party subdomains. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, |     def change_nameserver(self) -> None: | ||||||
|                  in_queue: queue.Queue, |         """ | ||||||
|                  out_queue: queue.Queue, |         Assign this worker another nameserver from the queue. | ||||||
|                  server: str): |         """ | ||||||
|         super(DnsResolver, self).__init__() |         server = None | ||||||
|         self.log = logging.getLogger(server) |         while server is None: | ||||||
|  |             try: | ||||||
|  |                 server = self.orchestrator.nameservers_queue.get(block=False) | ||||||
|  |             except queue.Empty: | ||||||
|  |                 self.orchestrator.refill_nameservers_queue() | ||||||
|  |         self.log.debug("Using nameserver: %s", server) | ||||||
|  |         self.resolver.nameservers = [server] | ||||||
| 
 | 
 | ||||||
|         self.in_queue = in_queue |     def __init__(self, | ||||||
|         self.out_queue = out_queue |                  orchestrator: 'Orchestrator', | ||||||
|  |                  index: int = 0): | ||||||
|  |         super(Worker, self).__init__() | ||||||
|  |         self.log = logging.getLogger(f'worker{index:03d}') | ||||||
|  |         self.orchestrator = orchestrator | ||||||
| 
 | 
 | ||||||
|         self.resolver = dns.resolver.Resolver() |         self.resolver = dns.resolver.Resolver() | ||||||
|         self.resolver.nameservers = [server] |         self.change_nameserver() | ||||||
| 
 | 
 | ||||||
|     def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]: |     def is_subdomain_matching(self, subdomain: str) -> typing.Optional[bool]: | ||||||
|         """ |         """ | ||||||
|         Indicates if the subdomain redirects to a first-party tracker. |         Indicates if the subdomain redirects to a first-party tracker. | ||||||
|  |         Returns None if the nameserver was unable to satisfy the request. | ||||||
|         """ |         """ | ||||||
|         # TODO Look at the whole chain rather than the last one |         # TODO Look at the whole chain rather than the last one | ||||||
|         # TODO Also match the ASN of the IP (caching the ASN subnetworks will do) |         # TODO Also match the ASN of the IP (caching the ASN subnetworks will do) | ||||||
|  | @ -58,16 +69,20 @@ class DnsResolver(threading.Thread): | ||||||
|             return False |             return False | ||||||
|         except dns.resolver.YXDOMAIN: |         except dns.resolver.YXDOMAIN: | ||||||
|             self.log.warning("Query name too long for %s", subdomain) |             self.log.warning("Query name too long for %s", subdomain) | ||||||
|             return False |             return None | ||||||
|         except dns.resolver.NoNameservers: |         except dns.resolver.NoNameservers: | ||||||
|  |             # NOTE Most of the time this error message means that the domain | ||||||
|  |             # does not exist, but sometimes it means that the server | ||||||
|  |             # itself is broken. So we count on the retry logic. | ||||||
|             self.log.warning("All nameservers broken for %s", subdomain) |             self.log.warning("All nameservers broken for %s", subdomain) | ||||||
|             return None |             return None | ||||||
|         except dns.exception.Timeout: |         except dns.exception.Timeout: | ||||||
|  |             # NOTE Same as above | ||||||
|             self.log.warning("Timeout for %s", subdomain) |             self.log.warning("Timeout for %s", subdomain) | ||||||
|             return None |             return None | ||||||
|         except dns.name.EmptyLabel: |         except dns.name.EmptyLabel: | ||||||
|             self.log.warning("Empty label for %s", subdomain) |             self.log.warning("Empty label for %s", subdomain) | ||||||
|             return False |             return None | ||||||
|         canonical = query.canonical_name.to_text() |         canonical = query.canonical_name.to_text() | ||||||
|         for regex in regexes.REGEXES: |         for regex in regexes.REGEXES: | ||||||
|             if re.match(regex, canonical): |             if re.match(regex, canonical): | ||||||
|  | @ -76,58 +91,112 @@ class DnsResolver(threading.Thread): | ||||||
| 
 | 
 | ||||||
|     def run(self) -> None: |     def run(self) -> None: | ||||||
|         self.log.info("Started") |         self.log.info("Started") | ||||||
|         for subdomain in iter(self.in_queue.get, None): |         for subdomain in iter(self.orchestrator.subdomains_queue.get, None): | ||||||
|             matching = self.is_subdomain_matching(subdomain) |  | ||||||
| 
 | 
 | ||||||
|             # If issue, retry |             for _ in range(NUMBER_TRIES): | ||||||
|  |                 matching = self.is_subdomain_matching(subdomain) | ||||||
|  |                 if matching is not None: | ||||||
|  |                     break | ||||||
|  | 
 | ||||||
|  |             # If it wasn't found after multiple tries | ||||||
|             if matching is None: |             if matching is None: | ||||||
|                 # matching = False |                 self.log.error("Gave up on %s", subdomain) | ||||||
|                 self.in_queue.put(subdomain) |                 matching = False | ||||||
|                 continue |  | ||||||
| 
 | 
 | ||||||
|             result = (subdomain, matching) |             result = (subdomain, matching) | ||||||
|             # self.log.debug("%s", result) |             self.orchestrator.results_queue.put(result) | ||||||
|             self.out_queue.put(result) | 
 | ||||||
|         self.out_queue.put(None) |         self.orchestrator.results_queue.put(None) | ||||||
|         self.log.info("Stopped") |         self.log.info("Stopped") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_matching_subdomains(subdomains: typing.Iterable[str], | class Orchestrator(): | ||||||
|                             nameservers: typing.List[str] = None, |  | ||||||
|                             ) -> typing.Iterable[typing.Tuple[str, bool]]: |  | ||||||
|     subdomains_queue: queue.Queue = queue.Queue() |  | ||||||
|     results_queue: queue.Queue = queue.Queue() |  | ||||||
|     """ |     """ | ||||||
|     Orchestrator of the different DnsResolver threads. |     Orchestrator of the different Worker threads. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     # Use internal resolver by default |     def refill_nameservers_queue(self) -> None: | ||||||
|     servers = nameservers or dns.resolver.Resolver().nameservers |         """ | ||||||
|  |         Re-fill the given nameservers into the nameservers queue. | ||||||
|  |         Done every time the queue is empty, making it | ||||||
|  |         basically looping and infinite. | ||||||
|  |         """ | ||||||
|  |         # Might be in a race condition but that's probably fine | ||||||
|  |         for nameserver in self.nameservers: | ||||||
|  |             self.nameservers_queue.put(nameserver) | ||||||
|  |         self.log.info("Refilled nameserver queue") | ||||||
| 
 | 
 | ||||||
|     # Create workers |     def __init__(self, subdomains: typing.Iterable[str], | ||||||
|     for server in servers: |                  nameservers: typing.List[str] = None): | ||||||
|         DnsResolver(subdomains_queue, results_queue, server).start() |         self.log = logging.getLogger('orchestrator') | ||||||
|  |         self.subdomains = subdomains | ||||||
| 
 | 
 | ||||||
|     # Send data to workers |         # Use internal resolver by default | ||||||
|     for subdomain in subdomains: |         self.nameservers = nameservers or dns.resolver.Resolver().nameservers | ||||||
|         subdomains_queue.put(subdomain) |  | ||||||
| 
 | 
 | ||||||
|     # Send sentinel to each worker |         self.subdomains_queue: queue.Queue = queue.Queue( | ||||||
|     # sentinel = None ~= EOF |             maxsize=NUMBER_THREADS) | ||||||
|     for _ in servers: |         self.results_queue: queue.Queue = queue.Queue() | ||||||
|         subdomains_queue.put(None) |         self.nameservers_queue: queue.Queue = queue.Queue() | ||||||
| 
 | 
 | ||||||
|     # Wait for one sentinel per worker |         self.refill_nameservers_queue() | ||||||
|     # In the meantime output results | 
 | ||||||
|     for _ in servers: |     def fill_subdomain_queue(self) -> None: | ||||||
|         for result in iter(results_queue.get, None): |         """ | ||||||
|             yield result |         Read the subdomains in input and put them into the queue. | ||||||
|  |         Done in a thread so we can both: | ||||||
|  |         - yield the results as they come | ||||||
|  |         - not store all the subdomains at once | ||||||
|  |         """ | ||||||
|  |         self.log.info("Started reading subdomains") | ||||||
|  |         # Send data to workers | ||||||
|  |         for subdomain in self.subdomains: | ||||||
|  |             self.subdomains_queue.put(subdomain) | ||||||
|  | 
 | ||||||
|  |         self.log.info("Finished reading subdomains") | ||||||
|  |         # Send sentinel to each worker | ||||||
|  |         # sentinel = None ~= EOF | ||||||
|  |         for _ in range(NUMBER_THREADS): | ||||||
|  |             self.subdomains_queue.put(None) | ||||||
|  | 
 | ||||||
|  |     def run(self) -> typing.Iterable[typing.Tuple[str, bool]]: | ||||||
|  |         """ | ||||||
|  |         Yield the results. | ||||||
|  |         """ | ||||||
|  |         # Create workers | ||||||
|  |         self.log.info("Creating workers") | ||||||
|  |         for i in range(NUMBER_THREADS): | ||||||
|  |             Worker(self, i).start() | ||||||
|  | 
 | ||||||
|  |         fill_thread = threading.Thread(target=self.fill_subdomain_queue) | ||||||
|  |         fill_thread.start() | ||||||
|  | 
 | ||||||
|  |         # Wait for one sentinel per worker | ||||||
|  |         # In the meantime output results | ||||||
|  |         for _ in range(NUMBER_THREADS): | ||||||
|  |             for result in iter(self.results_queue.get, None): | ||||||
|  |                 yield result | ||||||
|  | 
 | ||||||
|  |         self.log.info("Waiting for reader thread") | ||||||
|  |         fill_thread.join() | ||||||
|  | 
 | ||||||
|  |         self.log.info("Done!") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | def main() -> None: | ||||||
|  |     """ | ||||||
|  |     Main function when used directly. | ||||||
|  |     Takes as an input a filename (or nothing, for stdin) | ||||||
|  |     that will be read, and the ones that are trackers | ||||||
|  |     will be output on stdout. | ||||||
|  |     Use the file `nameservers` as the list of nameservers | ||||||
|  |     to use, or else it will use the system defaults. | ||||||
|  |     Also shows a nice progressbar. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|     coloredlogs.install( |     coloredlogs.install( | ||||||
|         level='DEBUG', |         level='DEBUG', | ||||||
|         fmt='%(asctime)s %(name)s[%(process)d] %(levelname)s %(message)s' |         fmt='%(asctime)s %(name)s %(levelname)s %(message)s' | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     # Progress bar |     # Progress bar | ||||||
|  | @ -164,8 +233,12 @@ if __name__ == '__main__': | ||||||
|         servers = list(filter(None, map(str.strip, servers))) |         servers = list(filter(None, map(str.strip, servers))) | ||||||
| 
 | 
 | ||||||
|     progress.start() |     progress.start() | ||||||
|     for subdomain, matching in get_matching_subdomains(iterator, servers): |     for subdomain, matching in Orchestrator(iterator, servers).run(): | ||||||
|         progress.update(progress.value + 1) |         progress.update(progress.value + 1) | ||||||
|         if matching: |         if matching: | ||||||
|             print(subdomain) |             print(subdomain) | ||||||
|     progress.finish() |     progress.finish() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue