resolver.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import asyncio
  2. import socket
  3. import weakref
  4. from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union
  5. from .abc import AbstractResolver, ResolveResult
  __all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")

  # aiodns is an optional dependency. AsyncResolver only becomes the default
  # resolver when aiodns imports and its DNSResolver offers getaddrinfo().
  try:
      import aiodns
  except ImportError:  # pragma: no cover
      aiodns = None  # type: ignore[assignment]
      aiodns_default = False
  else:
      aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")

  # Flags recorded on every ResolveResult: host and port in the results are
  # already numeric (presumably consumed by a later connect-time lookup).
  _NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICSERV | socket.AI_NUMERICHOST
  # getnameinfo() flags used to turn link-local IPv6 sockaddrs back into
  # numeric host/service strings.
  _NAME_SOCKET_FLAGS = socket.NI_NUMERICSERV | socket.NI_NUMERICHOST

  # AI_ADDRCONFIG, clipped to AI_MASK on platforms that define it; masking
  # with ~0 (all bits set) leaves the flag unchanged everywhere else.
  _AI_ADDRCONFIG = socket.AI_ADDRCONFIG & getattr(socket, "AI_MASK", ~0)
  class ThreadedResolver(AbstractResolver):
      """Threaded resolver.

      Uses an Executor for synchronous getaddrinfo() calls.
      concurrent.futures.ThreadPoolExecutor is used by default.
      """

      def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
          """Bind the resolver to *loop*.

          Args:
              loop: Event loop whose getaddrinfo()/getnameinfo() are used.
                  Defaults to asyncio.get_running_loop(), so constructing
                  without a loop must happen while a loop is running.
          """
          self._loop = loop or asyncio.get_running_loop()

      async def resolve(
          self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
      ) -> List[ResolveResult]:
          """Resolve *host* into numeric addresses via the loop's getaddrinfo().

          Args:
              host: Name to resolve; copied into each result's ``hostname``.
              port: Port passed through to getaddrinfo().
              family: Address family requested from getaddrinfo().

          Returns:
              One ResolveResult per usable stream-socket address. IPv6 entries
              whose sockaddr is truncated are skipped, so the list may be empty.

          Raises:
              Exceptions from loop.getaddrinfo()/loop.getnameinfo() propagate
              unchanged (presumably socket.gaierror / OSError).
          """
          infos = await self._loop.getaddrinfo(
              host,
              port,
              type=socket.SOCK_STREAM,
              family=family,
              flags=_AI_ADDRCONFIG,
          )

          hosts: List[ResolveResult] = []
          # Distinct names so the loop does not shadow the *family*/*port*
          # parameters.
          for addr_family, _, proto, _, address in infos:
              if addr_family == socket.AF_INET6:
                  # A full IPv6 sockaddr is (host, port, flowinfo, scope_id).
                  # Anything shorter means IPv6 is unsupported by this Python
                  # build or disabled on the host. Requiring 4 items also
                  # guards the address[3] lookup below (the old "< 3" check
                  # let a 3-item tuple through to an IndexError).
                  if len(address) < 4:
                      continue
                  if address[3]:
                      # This is essential for link-local IPv6 addresses.
                      # LL IPv6 is a VERY rare case. Strictly speaking, we should
                      # use getnameinfo() unconditionally, but performance makes
                      # sense.
                      resolved_host, service = await self._loop.getnameinfo(
                          address, _NAME_SOCKET_FLAGS
                      )
                      resolved_port = int(service)
                  else:
                      resolved_host, resolved_port = address[:2]
              else:  # IPv4
                  assert addr_family == socket.AF_INET
                  resolved_host, resolved_port = address  # type: ignore[misc]

              hosts.append(
                  ResolveResult(
                      hostname=host,
                      host=resolved_host,
                      port=resolved_port,
                      family=addr_family,
                      proto=proto,
                      flags=_NUMERIC_SOCKET_FLAGS,
                  )
              )

          return hosts

      async def close(self) -> None:
          """Do nothing: this resolver holds no resources beyond the loop reference."""
          pass
  class AsyncResolver(AbstractResolver):
      """Use the `aiodns` package to make asynchronous DNS lookups"""

      def __init__(
          self,
          loop: Optional[asyncio.AbstractEventLoop] = None,
          *args: Any,
          **kwargs: Any,
      ) -> None:
          """Create the resolver.

          Args:
              loop: Event loop to use; defaults to asyncio.get_running_loop().
              *args, **kwargs: Forwarded to aiodns.DNSResolver. When any are
                  given, this instance gets its own dedicated DNSResolver;
                  otherwise it shares one per-loop resolver via
                  _DNSResolverManager.

          Raises:
              RuntimeError: if the aiodns package is not installed.
          """
          if aiodns is None:
              raise RuntimeError("Resolver requires aiodns library")

          self._loop = loop or asyncio.get_running_loop()
          self._manager: Optional[_DNSResolverManager] = None
          if args or kwargs:
              # Custom args: a dedicated aiodns.DNSResolver owned (and later
              # cancelled) by this instance alone.
              self._resolver = aiodns.DNSResolver(*args, **kwargs)
          else:
              # Default args: share the per-loop resolver from the manager.
              self._manager = _DNSResolverManager()
              self._resolver = self._manager.get_resolver(self, self._loop)

          # resolve() relies on DNSResolver.getaddrinfo(), which older aiodns
          # releases lack; fall back to query()-based lookups for those. The
          # check applies to dedicated and shared resolvers alike.
          if not hasattr(self._resolver, "getaddrinfo"):
              self.resolve = self._resolve_with_query  # type: ignore

      @staticmethod
      def _dns_error_message(exc: BaseException) -> str:
          """Extract the human-readable message from an aiodns DNSError.

          The message is expected at ``args[1]`` (presumably ``args[0]`` is the
          c-ares error code). Errors carrying fewer arguments get a generic
          message instead of raising IndexError while being converted.
          """
          return exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed"

      async def resolve(
          self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
      ) -> List[ResolveResult]:
          """Resolve *host* with aiodns getaddrinfo().

          Args:
              host: Name to resolve; copied into each result's ``hostname``.
              port: Port passed to getaddrinfo().
              family: Address family requested from getaddrinfo().

          Returns:
              One ResolveResult per returned node (never empty).

          Raises:
              OSError: if aiodns reports a DNS error or no address is returned.
          """
          try:
              resp = await self._resolver.getaddrinfo(
                  host,
                  port=port,
                  type=socket.SOCK_STREAM,
                  family=family,
                  flags=_AI_ADDRCONFIG,
              )
          except aiodns.error.DNSError as exc:
              raise OSError(None, self._dns_error_message(exc)) from exc

          hosts: List[ResolveResult] = []
          for node in resp.nodes:
              address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
              node_family = node.family
              if node_family == socket.AF_INET6:
                  if len(address) > 3 and address[3]:
                      # This is essential for link-local IPv6 addresses.
                      # LL IPv6 is a VERY rare case. Strictly speaking, we should
                      # use getnameinfo() unconditionally, but performance makes
                      # sense.
                      result = await self._resolver.getnameinfo(
                          (address[0].decode("ascii"), *address[1:]),
                          _NAME_SOCKET_FLAGS,
                      )
                      resolved_host = result.node
                  else:
                      resolved_host = address[0].decode("ascii")
              else:  # IPv4
                  assert node_family == socket.AF_INET
                  resolved_host = address[0].decode("ascii")

              hosts.append(
                  ResolveResult(
                      hostname=host,
                      host=resolved_host,
                      # Always the node's own port. The link-local branch used
                      # to leave `port` unassigned, reusing a stale value.
                      port=address[1],
                      family=node_family,
                      proto=0,
                      flags=_NUMERIC_SOCKET_FLAGS,
                  )
              )

          if not hosts:
              raise OSError(None, "DNS lookup failed")
          return hosts

      async def _resolve_with_query(
          self, host: str, port: int = 0, family: int = socket.AF_INET
      ) -> List[Dict[str, Any]]:
          """Fallback lookup using DNSResolver.query() (A or AAAA records).

          Used when the aiodns resolver has no getaddrinfo(). The returned port
          is the requested *port*, since plain DNS records carry none.

          Raises:
              OSError: if aiodns reports a DNS error or no record is returned.
          """
          qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"
          try:
              resp = await self._resolver.query(host, qtype)
          except aiodns.error.DNSError as exc:
              raise OSError(None, self._dns_error_message(exc)) from exc

          hosts = [
              {
                  "hostname": host,
                  "host": rr.host,
                  "port": port,
                  "family": family,
                  "proto": 0,
                  "flags": socket.AI_NUMERICHOST,
              }
              for rr in resp
          ]

          if not hosts:
              raise OSError(None, "DNS lookup failed")
          return hosts

      async def close(self) -> None:
          """Release the underlying aiodns resolver.

          A shared resolver is handed back to the manager, which cancels it once
          its last client is released; a dedicated resolver is cancelled here.
          """
          if self._manager:
              # Release the resolver from the manager if using the shared resolver
              self._manager.release_resolver(self, self._loop)
              self._manager = None  # Clear reference to manager
              self._resolver = None  # type: ignore[assignment] # Clear reference to resolver
              return
          # Otherwise cancel our dedicated resolver
          if self._resolver is not None:
              self._resolver.cancel()
          self._resolver = None  # type: ignore[assignment] # Clear reference
  class _DNSResolverManager:
      """Process-wide registry of shared aiodns.DNSResolver objects.

      Keeps one default-argument aiodns.DNSResolver per event loop, plus the set
      of AsyncResolver clients currently using it. When the last client of a
      loop is released, that loop's resolver is cancelled and forgotten.
      """

      _instance: Optional["_DNSResolverManager"] = None

      def __new__(cls) -> "_DNSResolverManager":
          # Singleton: only the very first construction allocates and
          # initialises; every later call hands back the same object.
          if cls._instance is not None:
              return cls._instance
          cls._instance = created = super().__new__(cls)
          created._init()
          return created

      def _init(self) -> None:
          # loop -> (shared resolver, weak set of clients). Weak keys let an
          # abandoned event loop be garbage collected along with its entry.
          self._loop_data: weakref.WeakKeyDictionary[
              asyncio.AbstractEventLoop,
              tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
          ] = weakref.WeakKeyDictionary()

      def get_resolver(
          self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
      ) -> "aiodns.DNSResolver":
          """Return the shared resolver for *loop*, creating it on first use.

          Args:
              client: The AsyncResolver asking for the resolver; it is recorded
                  so the resolver stays alive while the client uses it.
              loop: The event loop the resolver is bound to.
          """
          entry = self._loop_data.get(loop)
          if entry is None:
              entry = (aiodns.DNSResolver(loop=loop), weakref.WeakSet())
              self._loop_data[loop] = entry
          shared, clients = entry
          clients.add(client)
          return shared

      def release_resolver(
          self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
      ) -> None:
          """Unregister *client*; cancel the loop's resolver if it was the last.

          Args:
              client: The AsyncResolver being closed.
              loop: The event loop whose resolver the client was using.
          """
          entry = self._loop_data.get(loop)
          if entry is None:
              return
          shared, clients = entry
          clients.discard(client)
          if clients:
              # Other clients on this loop still rely on the resolver.
              return
          if shared is not None:
              shared.cancel()
          del self._loop_data[loop]
  # Pick the aiodns-backed resolver when aiodns is installed and supports
  # getaddrinfo(); otherwise fall back to the executor-based resolver.
  _DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
  DefaultResolver: _DefaultType
  if aiodns_default:
      DefaultResolver = AsyncResolver
  else:
      DefaultResolver = ThreadedResolver