# client_proto.py
  1. import asyncio
  2. from contextlib import suppress
  3. from typing import Any, Optional, Tuple, Union
  4. from .base_protocol import BaseProtocol
  5. from .client_exceptions import (
  6. ClientConnectionError,
  7. ClientOSError,
  8. ClientPayloadError,
  9. ServerDisconnectedError,
  10. SocketTimeoutError,
  11. )
  12. from .helpers import (
  13. _EXC_SENTINEL,
  14. EMPTY_BODY_STATUS_CODES,
  15. BaseTimerContext,
  16. set_exception,
  17. set_result,
  18. )
  19. from .http import HttpResponseParser, RawResponseMessage
  20. from .http_exceptions import HttpProcessingError
  21. from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]):
    """Helper class to adapt between Protocol and StreamReader.

    Receives raw bytes from the transport, feeds them through an
    HttpResponseParser (or a custom payload parser after an upgrade,
    e.g. WebSocket) and pushes (message, payload) pairs into the
    DataQueue side of this object for the client to consume.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        BaseProtocol.__init__(self, loop=loop)
        DataQueue.__init__(self, loop)
        # Set when the connection must not be reused for another request.
        self._should_close = False
        # StreamReader for the body of the response currently being read.
        self._payload: Optional[StreamReader] = None
        self._skip_payload = False
        # Custom parser installed via set_parser() after a protocol upgrade
        # (currently always a WebSocket reader — see set_parser's TODO).
        self._payload_parser = None
        self._timer = None
        # Bytes received while no parser was available; replayed once a
        # parser is set (set_parser / set_response_params).
        self._tail = b""
        self._upgraded = False
        self._parser: Optional[HttpResponseParser] = None
        self._read_timeout: Optional[float] = None
        self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
        self._timeout_ceil_threshold: Optional[float] = 5
        # Created lazily by the `closed` property; resolved in
        # connection_lost().  None means nobody asked for it (yet).
        self._closed: Union[None, asyncio.Future[None]] = None
        self._connection_lost_called = False

    @property
    def closed(self) -> Union[None, asyncio.Future[None]]:
        """Future that is set when the connection is closed.

        This property returns a Future that will be completed when the
        connection is closed. The Future is created lazily on first access
        to avoid creating futures that will never be awaited.

        Returns:
            - A Future[None] if the connection is still open or was closed
              after this property was accessed
            - None if connection_lost() was already called before this
              property was ever accessed (indicating no one is waiting for
              the closure)
        """
        if self._closed is None and not self._connection_lost_called:
            self._closed = self._loop.create_future()
        return self._closed

    @property
    def upgraded(self) -> bool:
        return self._upgraded

    @property
    def should_close(self) -> bool:
        # The connection cannot be reused if it was forced closed, still has
        # an unread payload, was upgraded, errored, has a custom parser
        # installed, or has unconsumed buffered/tail data.
        return bool(
            self._should_close
            or (self._payload is not None and not self._payload.is_eof())
            or self._upgraded
            or self._exception is not None
            or self._payload_parser is not None
            or self._buffer
            or self._tail
        )

    def force_close(self) -> None:
        """Mark the connection as non-reusable."""
        self._should_close = True

    def close(self) -> None:
        """Close the transport gracefully and drop internal references."""
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.close()
            self.transport = None
            self._payload = None
        self._drop_timeout()

    def abort(self) -> None:
        """Abort the transport immediately (no flush) and drop references."""
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.abort()
            self.transport = None
            self._payload = None
        self._drop_timeout()

    def is_connected(self) -> bool:
        return self.transport is not None and not self.transport.is_closing()

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Handle transport loss: resolve waiters, flush parsers, propagate errors.

        ``exc`` is None for a clean close, otherwise the underlying
        transport error.
        """
        self._connection_lost_called = True
        self._drop_timeout()
        original_connection_error = exc
        reraised_exc = original_connection_error
        connection_closed_cleanly = original_connection_error is None
        if self._closed is not None:
            # Someone accessed the `closed` property and may be awaiting the
            # future: resolve it with None on a clean close, or with a
            # ClientConnectionError chained to the original transport error.
            if connection_closed_cleanly:
                set_result(self._closed, None)
            else:
                assert original_connection_error is not None
                set_exception(
                    self._closed,
                    ClientConnectionError(
                        f"Connection lost: {original_connection_error !s}",
                    ),
                    original_connection_error,
                )
        if self._payload_parser is not None:
            with suppress(Exception):  # FIXME: log this somehow?
                self._payload_parser.feed_eof()
        uncompleted = None
        if self._parser is not None:
            try:
                uncompleted = self._parser.feed_eof()
            except Exception as underlying_exc:
                # The HTTP parser rejected EOF: the response body was cut
                # short.  Surface this on the payload stream, including the
                # transport error (if any) in the message.
                if self._payload is not None:
                    client_payload_exc_msg = (
                        f"Response payload is not completed: {underlying_exc !r}"
                    )
                    if not connection_closed_cleanly:
                        client_payload_exc_msg = (
                            f"{client_payload_exc_msg !s}. "
                            f"{original_connection_error !r}"
                        )
                    set_exception(
                        self._payload,
                        ClientPayloadError(client_payload_exc_msg),
                        underlying_exc,
                    )
        if not self.is_eof():
            # The data queue still expects messages: translate the raw
            # transport error into a client-level exception for consumers.
            if isinstance(original_connection_error, OSError):
                reraised_exc = ClientOSError(*original_connection_error.args)
            if connection_closed_cleanly:
                reraised_exc = ServerDisconnectedError(uncompleted)
            # assigns self._should_close to True as side effect,
            # we do it anyway below
            underlying_non_eof_exc = (
                _EXC_SENTINEL
                if connection_closed_cleanly
                else original_connection_error
            )
            assert underlying_non_eof_exc is not None
            assert reraised_exc is not None
            self.set_exception(reraised_exc, underlying_non_eof_exc)
        self._should_close = True
        self._parser = None
        self._payload = None
        self._payload_parser = None
        self._reading_paused = False
        super().connection_lost(reraised_exc)

    def eof_received(self) -> None:
        # should call parser.feed_eof() most likely
        self._drop_timeout()

    def pause_reading(self) -> None:
        super().pause_reading()
        # No data can arrive while paused, so the read timeout is suspended.
        self._drop_timeout()

    def resume_reading(self) -> None:
        super().resume_reading()
        self._reschedule_timeout()

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = _EXC_SENTINEL,
    ) -> None:
        """Propagate *exc* to queue consumers and poison the connection."""
        self._should_close = True
        self._drop_timeout()
        super().set_exception(exc, exc_cause)

    def set_parser(self, parser: Any, payload: Any) -> None:
        """Install a custom payload parser (used after a protocol upgrade)."""
        # TODO: actual types are:
        #   parser: WebSocketReader
        #   payload: WebSocketDataQueue
        # but they are not generic enough
        # Need an ABC for both types
        self._payload = payload
        self._payload_parser = parser
        self._drop_timeout()
        # Replay any bytes that arrived before the parser was installed.
        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)

    def set_response_params(
        self,
        *,
        timer: Optional[BaseTimerContext] = None,
        skip_payload: bool = False,
        read_until_eof: bool = False,
        auto_decompress: bool = True,
        read_timeout: Optional[float] = None,
        read_bufsize: int = 2**16,
        timeout_ceil_threshold: float = 5,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
    ) -> None:
        """Configure the HTTP response parser for the next request cycle."""
        self._skip_payload = skip_payload
        self._read_timeout = read_timeout
        self._timeout_ceil_threshold = timeout_ceil_threshold
        self._parser = HttpResponseParser(
            self,
            self._loop,
            read_bufsize,
            timer=timer,
            payload_exception=ClientPayloadError,
            response_with_body=not skip_payload,
            read_until_eof=read_until_eof,
            auto_decompress=auto_decompress,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
        )
        # Replay any bytes buffered while no parser was configured.
        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)

    def _drop_timeout(self) -> None:
        """Cancel the pending read-timeout callback, if any."""
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()
            self._read_timeout_handle = None

    def _reschedule_timeout(self) -> None:
        """Restart the read-timeout clock (called whenever data arrives)."""
        timeout = self._read_timeout
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()
        if timeout:
            self._read_timeout_handle = self._loop.call_later(
                timeout, self._on_read_timeout
            )
        else:
            self._read_timeout_handle = None

    def start_timeout(self) -> None:
        self._reschedule_timeout()

    @property
    def read_timeout(self) -> Optional[float]:
        return self._read_timeout

    @read_timeout.setter
    def read_timeout(self, read_timeout: Optional[float]) -> None:
        self._read_timeout = read_timeout

    def _on_read_timeout(self) -> None:
        """Read timeout fired: fail both the connection and the payload."""
        exc = SocketTimeoutError("Timeout on reading data from socket")
        self.set_exception(exc)
        if self._payload is not None:
            set_exception(self._payload, exc)

    def data_received(self, data: bytes) -> None:
        """Feed raw transport bytes to the active parser.

        Dispatch order: custom payload parser (if installed), then buffering
        while upgraded/unconfigured, then the HTTP response parser.
        """
        self._reschedule_timeout()

        if not data:
            return

        # custom payload parser - currently always WebSocketReader
        if self._payload_parser is not None:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self._payload = None
                self._payload_parser = None

                if tail:
                    # Leftover bytes after the payload ended: re-dispatch.
                    self.data_received(tail)
            return

        if self._upgraded or self._parser is None:
            # i.e. websocket connection, websocket parser is not set yet
            self._tail += data
            return

        # parse http messages
        try:
            messages, upgraded, tail = self._parser.feed_data(data)
        except BaseException as underlying_exc:
            if self.transport is not None:
                # connection.release() could be called BEFORE
                # data_received(), the transport is already
                # closed in this case
                self.transport.close()
            # should_close is True after the call
            if isinstance(underlying_exc, HttpProcessingError):
                exc = HttpProcessingError(
                    code=underlying_exc.code,
                    message=underlying_exc.message,
                    headers=underlying_exc.headers,
                )
            else:
                exc = HttpProcessingError()
            self.set_exception(exc, underlying_exc)
            return

        self._upgraded = upgraded

        payload: Optional[StreamReader] = None
        for message, payload in messages:
            if message.should_close:
                self._should_close = True

            self._payload = payload

            # Responses that must not have a body get EMPTY_PAYLOAD queued
            # instead of the real stream.
            if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
                self.feed_data((message, EMPTY_PAYLOAD), 0)
            else:
                self.feed_data((message, payload), 0)

        if payload is not None:
            # new message(s) was processed
            # register timeout handler unsubscribing
            # either on end-of-stream or immediately for
            # EMPTY_PAYLOAD
            if payload is not EMPTY_PAYLOAD:
                payload.on_eof(self._drop_timeout)
            else:
                self._drop_timeout()

        if upgraded and tail:
            # Bytes that belong to the upgraded protocol, not HTTP.
            self.data_received(tail)