| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- import asyncio
- import socket
- import weakref
- from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union
- from .abc import AbstractResolver, ResolveResult
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")

# aiodns is optional.  When it is importable, it is only considered the
# default backend if its DNSResolver exposes getaddrinfo().
try:
    import aiodns
except ImportError:  # pragma: no cover
    aiodns = None  # type: ignore[assignment]
    aiodns_default = False
else:
    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
# Flags stored on every ResolveResult: host and port are already numeric.
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# getnameinfo() flags: ask for numeric host and service strings back.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
# getaddrinfo() flag; on platforms that define AI_MASK, keep only the bits
# that the mask permits.
_AI_ADDRCONFIG = (
    socket.AI_ADDRCONFIG & socket.AI_MASK
    if hasattr(socket, "AI_MASK")
    else socket.AI_ADDRCONFIG
)
class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        # Without an explicit loop, bind to the running one
        # (asyncio.get_running_loop() raises RuntimeError if none is running).
        self._loop = loop or asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* into stream-socket addresses via loop.getaddrinfo().

        Args:
            host: Hostname (or literal address) to resolve.
            port: Port passed to getaddrinfo(); each result's port is taken
                from the returned socket address.
            family: Address family to restrict the lookup to.

        Returns:
            One ``ResolveResult`` per returned address.  IPv6 entries whose
            socket address has fewer than three elements are skipped.
        """
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: List[ResolveResult] = []
        for family, _, proto, _, address in infos:
            if family == socket.AF_INET6:
                if len(address) < 3:
                    # IPv6 is not supported by Python build,
                    # or IPv6 is not enabled in the host
                    continue
                # Only index the scope id when it is actually present; the
                # guard above alone does not ensure address[3] exists.
                if len(address) > 3 and address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but performance makes sense.
                    resolved_host, _port = await self._loop.getnameinfo(
                        address, _NAME_SOCKET_FLAGS
                    )
                    # NI_NUMERICSERV yields the service as a numeric string.
                    port = int(_port)
                else:
                    resolved_host, port = address[:2]
            else:  # IPv4
                assert family == socket.AF_INET
                resolved_host, port = address  # type: ignore[misc]
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=port,
                    family=family,
                    proto=proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return hosts

    async def close(self) -> None:
        """Nothing to release: this resolver holds no resources of its own."""
        pass
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups"""

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the resolver.

        Args:
            loop: Event loop to bind to; defaults to the running loop.
            *args, **kwargs: Forwarded to ``aiodns.DNSResolver``.  If any are
                given, this instance gets its own dedicated resolver; otherwise
                a resolver shared per event loop is borrowed from
                ``_DNSResolverManager``.

        Raises:
            RuntimeError: if the ``aiodns`` library is not available.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = loop or asyncio.get_running_loop()
        self._manager: Optional[_DNSResolverManager] = None
        if args or kwargs:
            # Custom args: this AsyncResolver owns a dedicated
            # aiodns.DNSResolver instance, cancelled in close().
            self._resolver = aiodns.DNSResolver(*args, **kwargs)
        else:
            # Default args: use the shared per-loop resolver from the manager;
            # close() releases it back to the manager.
            self._manager = _DNSResolverManager()
            self._resolver = self._manager.get_resolver(self, self._loop)

        # resolve() calls DNSResolver.getaddrinfo(); aiodns releases without it
        # fall back to DNSResolver.query.  Checked for dedicated and shared
        # resolvers alike (previously the dedicated path skipped this).
        if not hasattr(self._resolver, "getaddrinfo"):
            self.resolve = self._resolve_with_query  # type: ignore

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* using aiodns ``getaddrinfo()``.

        Args:
            host: Hostname to look up.
            port: Port passed to the lookup; each result's port is taken from
                the returned socket address.
            family: Address family to restrict the lookup to.

        Returns:
            A non-empty list of ``ResolveResult`` entries.

        Raises:
            OSError: if aiodns raises ``DNSError`` or no addresses come back.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            # The message is read from args[1]; only index it when it exists,
            # otherwise a short args tuple would raise IndexError here.
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            family = node.family
            if family == socket.AF_INET6:
                if len(address) > 3 and address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but performance makes sense.
                    result = await self._resolver.getnameinfo(
                        (address[0].decode("ascii"), *address[1:]),
                        _NAME_SOCKET_FLAGS,
                    )
                    resolved_host = result.node
                else:
                    resolved_host = address[0].decode("ascii")
                port = address[1]
            else:  # IPv4
                assert family == socket.AF_INET
                resolved_host = address[0].decode("ascii")
                port = address[1]
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=port,
                    family=family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Fallback lookup via ``DNSResolver.query`` (A or AAAA records).

        Used in place of :meth:`resolve` when the aiodns resolver lacks
        ``getaddrinfo()``.  *port* and *family* are copied into each result.

        Raises:
            OSError: if aiodns raises ``DNSError`` or no records come back.
        """
        qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            # Same guarded message extraction as in resolve().
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        hosts = []
        for rr in resp:
            hosts.append(
                {
                    "hostname": host,
                    "host": rr.host,
                    "port": port,
                    "family": family,
                    "proto": 0,
                    "flags": socket.AI_NUMERICHOST,
                }
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Release (shared) or cancel (dedicated) the underlying resolver.

        Safe to call more than once: references are cleared after the first call.
        """
        if self._manager:
            # Release the resolver from the manager if using the shared resolver
            self._manager.release_resolver(self, self._loop)
            self._manager = None  # Clear reference to manager
            self._resolver = None  # type: ignore[assignment] # Clear reference to resolver
            return
        # Otherwise cancel our dedicated resolver
        if self._resolver is not None:
            self._resolver.cancel()
        self._resolver = None  # type: ignore[assignment] # Clear reference
class _DNSResolverManager:
    """Singleton registry of shared, default-configured aiodns resolvers.

    Each event loop gets at most one shared ``aiodns.DNSResolver``, stored
    together with the set of ``AsyncResolver`` clients currently using it.
    The resolver is cancelled once its last client releases it.
    """

    # Lazily created singleton instance (see __new__).
    _instance: Optional["_DNSResolverManager"] = None

    def __new__(cls) -> "_DNSResolverManager":
        instance = cls._instance
        if instance is None:
            # First construction: publish the instance, then initialise it.
            instance = cls._instance = super().__new__(cls)
            instance._init()
        return instance

    def _init(self) -> None:
        # loop -> (shared resolver, registered clients).  Loops are weakly
        # referenced so they can be garbage collected; clients are kept in a
        # WeakSet so unreferenced clients drop out of the set on their own.
        self._loop_data: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop,
            tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
        ] = weakref.WeakKeyDictionary()

    def get_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> "aiodns.DNSResolver":
        """Return the shared resolver for *loop* and register *client* with it.

        The resolver and its client set are created on the first request for
        a given loop.

        Args:
            client: The AsyncResolver requesting the resolver; tracked so the
                resolver is only cancelled after every client released it.
            loop: The event loop the resolver is bound to.
        """
        if loop in self._loop_data:
            resolver, clients = self._loop_data[loop]
        else:
            resolver = aiodns.DNSResolver(loop=loop)
            clients = weakref.WeakSet()
            self._loop_data[loop] = (resolver, clients)
        clients.add(client)
        return resolver

    def release_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> None:
        """Unregister *client* from the shared resolver of *loop*.

        When no clients remain, the resolver is cancelled and the loop's entry
        is dropped.  Releasing for a loop without an entry does nothing.

        Args:
            client: The AsyncResolver being closed.
            loop: The event loop the resolver was obtained for.
        """
        try:
            resolver, clients = self._loop_data[loop]
        except KeyError:
            return

        clients.discard(client)
        if clients:
            # Other clients still depend on this resolver.
            return

        if resolver is not None:
            resolver.cancel()
        del self._loop_data[loop]
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]

# Use the aiodns-backed resolver only when aiodns_default is true; otherwise
# use the thread-pool based resolver.
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver
|