pytest_plugin.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. import asyncio
  2. import contextlib
  3. import inspect
  4. import warnings
  5. from typing import (
  6. Any,
  7. Awaitable,
  8. Callable,
  9. Dict,
  10. Iterator,
  11. Optional,
  12. Protocol,
  13. Union,
  14. overload,
  15. )
  16. import pytest
  17. from .test_utils import (
  18. BaseTestServer,
  19. RawTestServer,
  20. TestClient,
  21. TestServer,
  22. loop_context,
  23. setup_test_loop,
  24. teardown_test_loop,
  25. unused_port as _unused_port,
  26. )
  27. from .web import Application, BaseRequest, Request
  28. from .web_protocol import _RequestHandler
  29. try:
  30. import uvloop
  31. except ImportError: # pragma: no cover
  32. uvloop = None # type: ignore[assignment]
  33. class AiohttpClient(Protocol):
  34. @overload
  35. async def __call__(
  36. self,
  37. __param: Application,
  38. *,
  39. server_kwargs: Optional[Dict[str, Any]] = None,
  40. **kwargs: Any,
  41. ) -> TestClient[Request, Application]: ...
  42. @overload
  43. async def __call__(
  44. self,
  45. __param: BaseTestServer,
  46. *,
  47. server_kwargs: Optional[Dict[str, Any]] = None,
  48. **kwargs: Any,
  49. ) -> TestClient[BaseRequest, None]: ...
  50. class AiohttpServer(Protocol):
  51. def __call__(
  52. self, app: Application, *, port: Optional[int] = None, **kwargs: Any
  53. ) -> Awaitable[TestServer]: ...
  54. class AiohttpRawServer(Protocol):
  55. def __call__(
  56. self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
  57. ) -> Awaitable[RawTestServer]: ...
  58. def pytest_addoption(parser): # type: ignore[no-untyped-def]
  59. parser.addoption(
  60. "--aiohttp-fast",
  61. action="store_true",
  62. default=False,
  63. help="run tests faster by disabling extra checks",
  64. )
  65. parser.addoption(
  66. "--aiohttp-loop",
  67. action="store",
  68. default="pyloop",
  69. help="run tests with specific loop: pyloop, uvloop or all",
  70. )
  71. parser.addoption(
  72. "--aiohttp-enable-loop-debug",
  73. action="store_true",
  74. default=False,
  75. help="enable event loop debug mode",
  76. )
  77. def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def]
  78. """Set up pytest fixture.
  79. Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.
  80. """
  81. func = fixturedef.func
  82. if inspect.isasyncgenfunction(func):
  83. # async generator fixture
  84. is_async_gen = True
  85. elif inspect.iscoroutinefunction(func):
  86. # regular async fixture
  87. is_async_gen = False
  88. else:
  89. # not an async fixture, nothing to do
  90. return
  91. strip_request = False
  92. if "request" not in fixturedef.argnames:
  93. fixturedef.argnames += ("request",)
  94. strip_request = True
  95. def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
  96. request = kwargs["request"]
  97. if strip_request:
  98. del kwargs["request"]
  99. # if neither the fixture nor the test use the 'loop' fixture,
  100. # 'getfixturevalue' will fail because the test is not parameterized
  101. # (this can be removed someday if 'loop' is no longer parameterized)
  102. if "loop" not in request.fixturenames:
  103. raise Exception(
  104. "Asynchronous fixtures must depend on the 'loop' fixture or "
  105. "be used in tests depending from it."
  106. )
  107. _loop = request.getfixturevalue("loop")
  108. if is_async_gen:
  109. # for async generators, we need to advance the generator once,
  110. # then advance it again in a finalizer
  111. gen = func(*args, **kwargs)
  112. def finalizer(): # type: ignore[no-untyped-def]
  113. try:
  114. return _loop.run_until_complete(gen.__anext__())
  115. except StopAsyncIteration:
  116. pass
  117. request.addfinalizer(finalizer)
  118. return _loop.run_until_complete(gen.__anext__())
  119. else:
  120. return _loop.run_until_complete(func(*args, **kwargs))
  121. fixturedef.func = wrapper
  122. @pytest.fixture
  123. def fast(request): # type: ignore[no-untyped-def]
  124. """--fast config option"""
  125. return request.config.getoption("--aiohttp-fast")
  126. @pytest.fixture
  127. def loop_debug(request): # type: ignore[no-untyped-def]
  128. """--enable-loop-debug config option"""
  129. return request.config.getoption("--aiohttp-enable-loop-debug")
  130. @contextlib.contextmanager
  131. def _runtime_warning_context(): # type: ignore[no-untyped-def]
  132. """Context manager which checks for RuntimeWarnings.
  133. This exists specifically to
  134. avoid "coroutine 'X' was never awaited" warnings being missed.
  135. If RuntimeWarnings occur in the context a RuntimeError is raised.
  136. """
  137. with warnings.catch_warnings(record=True) as _warnings:
  138. yield
  139. rw = [
  140. "{w.filename}:{w.lineno}:{w.message}".format(w=w)
  141. for w in _warnings
  142. if w.category == RuntimeWarning
  143. ]
  144. if rw:
  145. raise RuntimeError(
  146. "{} Runtime Warning{},\n{}".format(
  147. len(rw), "" if len(rw) == 1 else "s", "\n".join(rw)
  148. )
  149. )
  150. @contextlib.contextmanager
  151. def _passthrough_loop_context(loop, fast=False): # type: ignore[no-untyped-def]
  152. """Passthrough loop context.
  153. Sets up and tears down a loop unless one is passed in via the loop
  154. argument when it's passed straight through.
  155. """
  156. if loop:
  157. # loop already exists, pass it straight through
  158. yield loop
  159. else:
  160. # this shadows loop_context's standard behavior
  161. loop = setup_test_loop()
  162. yield loop
  163. teardown_test_loop(loop, fast=fast)
  164. def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def]
  165. """Fix pytest collecting for coroutines."""
  166. if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj):
  167. return list(collector._genfunctions(name, obj))
  168. def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def]
  169. """Run coroutines in an event loop instead of a normal function call."""
  170. fast = pyfuncitem.config.getoption("--aiohttp-fast")
  171. if inspect.iscoroutinefunction(pyfuncitem.function):
  172. existing_loop = (
  173. pyfuncitem.funcargs.get("proactor_loop")
  174. or pyfuncitem.funcargs.get("selector_loop")
  175. or pyfuncitem.funcargs.get("uvloop_loop")
  176. or pyfuncitem.funcargs.get("loop", None)
  177. )
  178. with _runtime_warning_context():
  179. with _passthrough_loop_context(existing_loop, fast=fast) as _loop:
  180. testargs = {
  181. arg: pyfuncitem.funcargs[arg]
  182. for arg in pyfuncitem._fixtureinfo.argnames
  183. }
  184. _loop.run_until_complete(pyfuncitem.obj(**testargs))
  185. return True
  186. def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def]
  187. if "loop_factory" not in metafunc.fixturenames:
  188. return
  189. loops = metafunc.config.option.aiohttp_loop
  190. avail_factories: dict[str, Callable[[], asyncio.AbstractEventLoop]]
  191. avail_factories = {"pyloop": asyncio.new_event_loop}
  192. if uvloop is not None: # pragma: no cover
  193. avail_factories["uvloop"] = uvloop.new_event_loop
  194. if loops == "all":
  195. loops = "pyloop,uvloop?"
  196. factories = {} # type: ignore[var-annotated]
  197. for name in loops.split(","):
  198. required = not name.endswith("?")
  199. name = name.strip(" ?")
  200. if name not in avail_factories: # pragma: no cover
  201. if required:
  202. raise ValueError(
  203. "Unknown loop '%s', available loops: %s"
  204. % (name, list(factories.keys()))
  205. )
  206. else:
  207. continue
  208. factories[name] = avail_factories[name]
  209. metafunc.parametrize(
  210. "loop_factory", list(factories.values()), ids=list(factories.keys())
  211. )
  212. @pytest.fixture
  213. def loop(
  214. loop_factory: Callable[[], asyncio.AbstractEventLoop],
  215. fast: bool,
  216. loop_debug: bool,
  217. ) -> Iterator[asyncio.AbstractEventLoop]:
  218. """Return an instance of the event loop."""
  219. with loop_context(loop_factory, fast=fast) as _loop:
  220. if loop_debug:
  221. _loop.set_debug(True) # pragma: no cover
  222. asyncio.set_event_loop(_loop)
  223. yield _loop
  224. @pytest.fixture
  225. def proactor_loop() -> Iterator[asyncio.AbstractEventLoop]:
  226. factory = asyncio.ProactorEventLoop # type: ignore[attr-defined]
  227. with loop_context(factory) as _loop:
  228. asyncio.set_event_loop(_loop)
  229. yield _loop
  230. @pytest.fixture
  231. def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]:
  232. warnings.warn(
  233. "Deprecated, use aiohttp_unused_port fixture instead",
  234. DeprecationWarning,
  235. stacklevel=2,
  236. )
  237. return aiohttp_unused_port
  238. @pytest.fixture
  239. def aiohttp_unused_port() -> Callable[[], int]:
  240. """Return a port that is unused on the current host."""
  241. return _unused_port
  242. @pytest.fixture
  243. def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
  244. """Factory to create a TestServer instance, given an app.
  245. aiohttp_server(app, **kwargs)
  246. """
  247. servers = []
  248. async def go(
  249. app: Application,
  250. *,
  251. host: str = "127.0.0.1",
  252. port: Optional[int] = None,
  253. **kwargs: Any,
  254. ) -> TestServer:
  255. server = TestServer(app, host=host, port=port)
  256. await server.start_server(loop=loop, **kwargs)
  257. servers.append(server)
  258. return server
  259. yield go
  260. async def finalize() -> None:
  261. while servers:
  262. await servers.pop().close()
  263. loop.run_until_complete(finalize())
  264. @pytest.fixture
  265. def test_server(aiohttp_server): # type: ignore[no-untyped-def] # pragma: no cover
  266. warnings.warn(
  267. "Deprecated, use aiohttp_server fixture instead",
  268. DeprecationWarning,
  269. stacklevel=2,
  270. )
  271. return aiohttp_server
  272. @pytest.fixture
  273. def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]:
  274. """Factory to create a RawTestServer instance, given a web handler.
  275. aiohttp_raw_server(handler, **kwargs)
  276. """
  277. servers = []
  278. async def go(
  279. handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
  280. ) -> RawTestServer:
  281. server = RawTestServer(handler, port=port)
  282. await server.start_server(loop=loop, **kwargs)
  283. servers.append(server)
  284. return server
  285. yield go
  286. async def finalize() -> None:
  287. while servers:
  288. await servers.pop().close()
  289. loop.run_until_complete(finalize())
  290. @pytest.fixture
  291. def raw_test_server( # type: ignore[no-untyped-def] # pragma: no cover
  292. aiohttp_raw_server,
  293. ):
  294. warnings.warn(
  295. "Deprecated, use aiohttp_raw_server fixture instead",
  296. DeprecationWarning,
  297. stacklevel=2,
  298. )
  299. return aiohttp_raw_server
  300. @pytest.fixture
  301. def aiohttp_client(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpClient]:
  302. """Factory to create a TestClient instance.
  303. aiohttp_client(app, **kwargs)
  304. aiohttp_client(server, **kwargs)
  305. aiohttp_client(raw_server, **kwargs)
  306. """
  307. clients = []
  308. @overload
  309. async def go(
  310. __param: Application,
  311. *,
  312. server_kwargs: Optional[Dict[str, Any]] = None,
  313. **kwargs: Any,
  314. ) -> TestClient[Request, Application]: ...
  315. @overload
  316. async def go(
  317. __param: BaseTestServer,
  318. *,
  319. server_kwargs: Optional[Dict[str, Any]] = None,
  320. **kwargs: Any,
  321. ) -> TestClient[BaseRequest, None]: ...
  322. async def go(
  323. __param: Union[Application, BaseTestServer],
  324. *args: Any,
  325. server_kwargs: Optional[Dict[str, Any]] = None,
  326. **kwargs: Any,
  327. ) -> TestClient[Any, Any]:
  328. if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
  329. __param, (Application, BaseTestServer)
  330. ):
  331. __param = __param(loop, *args, **kwargs)
  332. kwargs = {}
  333. else:
  334. assert not args, "args should be empty"
  335. if isinstance(__param, Application):
  336. server_kwargs = server_kwargs or {}
  337. server = TestServer(__param, loop=loop, **server_kwargs)
  338. client = TestClient(server, loop=loop, **kwargs)
  339. elif isinstance(__param, BaseTestServer):
  340. client = TestClient(__param, loop=loop, **kwargs)
  341. else:
  342. raise ValueError("Unknown argument type: %r" % type(__param))
  343. await client.start_server()
  344. clients.append(client)
  345. return client
  346. yield go
  347. async def finalize() -> None:
  348. while clients:
  349. await clients.pop().close()
  350. loop.run_until_complete(finalize())
  351. @pytest.fixture
  352. def test_client(aiohttp_client): # type: ignore[no-untyped-def] # pragma: no cover
  353. warnings.warn(
  354. "Deprecated, use aiohttp_client fixture instead",
  355. DeprecationWarning,
  356. stacklevel=2,
  357. )
  358. return aiohttp_client