web_ws.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. import asyncio
  2. import base64
  3. import binascii
  4. import hashlib
  5. import json
  6. import sys
  7. from typing import Any, Final, Iterable, Optional, Tuple, Union, cast
  8. import attr
  9. from multidict import CIMultiDict
  10. from . import hdrs
  11. from ._websocket.reader import WebSocketDataQueue
  12. from ._websocket.writer import DEFAULT_LIMIT
  13. from .abc import AbstractStreamWriter
  14. from .client_exceptions import WSMessageTypeError
  15. from .helpers import calculate_timeout_when, set_exception, set_result
  16. from .http import (
  17. WS_CLOSED_MESSAGE,
  18. WS_CLOSING_MESSAGE,
  19. WS_KEY,
  20. WebSocketError,
  21. WebSocketReader,
  22. WebSocketWriter,
  23. WSCloseCode,
  24. WSMessage,
  25. WSMsgType as WSMsgType,
  26. ws_ext_gen,
  27. ws_ext_parse,
  28. )
  29. from .http_websocket import _INTERNAL_RECEIVE_TYPES
  30. from .log import ws_logger
  31. from .streams import EofStream
  32. from .typedefs import JSONDecoder, JSONEncoder
  33. from .web_exceptions import HTTPBadRequest, HTTPException
  34. from .web_request import BaseRequest
  35. from .web_response import StreamResponse
  36. if sys.version_info >= (3, 11):
  37. import asyncio as async_timeout
  38. else:
  39. import async_timeout
  40. __all__ = (
  41. "WebSocketResponse",
  42. "WebSocketReady",
  43. "WSMsgType",
  44. )
  45. THRESHOLD_CONNLOST_ACCESS: Final[int] = 5
  46. @attr.s(auto_attribs=True, frozen=True, slots=True)
  47. class WebSocketReady:
  48. ok: bool
  49. protocol: Optional[str]
  50. def __bool__(self) -> bool:
  51. return self.ok
  52. class WebSocketResponse(StreamResponse):
  53. _length_check: bool = False
  54. _ws_protocol: Optional[str] = None
  55. _writer: Optional[WebSocketWriter] = None
  56. _reader: Optional[WebSocketDataQueue] = None
  57. _closed: bool = False
  58. _closing: bool = False
  59. _conn_lost: int = 0
  60. _close_code: Optional[int] = None
  61. _loop: Optional[asyncio.AbstractEventLoop] = None
  62. _waiting: bool = False
  63. _close_wait: Optional[asyncio.Future[None]] = None
  64. _exception: Optional[BaseException] = None
  65. _heartbeat_when: float = 0.0
  66. _heartbeat_cb: Optional[asyncio.TimerHandle] = None
  67. _pong_response_cb: Optional[asyncio.TimerHandle] = None
  68. _ping_task: Optional[asyncio.Task[None]] = None
  69. def __init__(
  70. self,
  71. *,
  72. timeout: float = 10.0,
  73. receive_timeout: Optional[float] = None,
  74. autoclose: bool = True,
  75. autoping: bool = True,
  76. heartbeat: Optional[float] = None,
  77. protocols: Iterable[str] = (),
  78. compress: bool = True,
  79. max_msg_size: int = 4 * 1024 * 1024,
  80. writer_limit: int = DEFAULT_LIMIT,
  81. ) -> None:
  82. super().__init__(status=101)
  83. self._protocols = protocols
  84. self._timeout = timeout
  85. self._receive_timeout = receive_timeout
  86. self._autoclose = autoclose
  87. self._autoping = autoping
  88. self._heartbeat = heartbeat
  89. if heartbeat is not None:
  90. self._pong_heartbeat = heartbeat / 2.0
  91. self._compress: Union[bool, int] = compress
  92. self._max_msg_size = max_msg_size
  93. self._writer_limit = writer_limit
  94. def _cancel_heartbeat(self) -> None:
  95. self._cancel_pong_response_cb()
  96. if self._heartbeat_cb is not None:
  97. self._heartbeat_cb.cancel()
  98. self._heartbeat_cb = None
  99. if self._ping_task is not None:
  100. self._ping_task.cancel()
  101. self._ping_task = None
  102. def _cancel_pong_response_cb(self) -> None:
  103. if self._pong_response_cb is not None:
  104. self._pong_response_cb.cancel()
  105. self._pong_response_cb = None
  106. def _reset_heartbeat(self) -> None:
  107. if self._heartbeat is None:
  108. return
  109. self._cancel_pong_response_cb()
  110. req = self._req
  111. timeout_ceil_threshold = (
  112. req._protocol._timeout_ceil_threshold if req is not None else 5
  113. )
  114. loop = self._loop
  115. assert loop is not None
  116. now = loop.time()
  117. when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
  118. self._heartbeat_when = when
  119. if self._heartbeat_cb is None:
  120. # We do not cancel the previous heartbeat_cb here because
  121. # it generates a significant amount of TimerHandle churn
  122. # which causes asyncio to rebuild the heap frequently.
  123. # Instead _send_heartbeat() will reschedule the next
  124. # heartbeat if it fires too early.
  125. self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
  126. def _send_heartbeat(self) -> None:
  127. self._heartbeat_cb = None
  128. loop = self._loop
  129. assert loop is not None and self._writer is not None
  130. now = loop.time()
  131. if now < self._heartbeat_when:
  132. # Heartbeat fired too early, reschedule
  133. self._heartbeat_cb = loop.call_at(
  134. self._heartbeat_when, self._send_heartbeat
  135. )
  136. return
  137. req = self._req
  138. timeout_ceil_threshold = (
  139. req._protocol._timeout_ceil_threshold if req is not None else 5
  140. )
  141. when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
  142. self._cancel_pong_response_cb()
  143. self._pong_response_cb = loop.call_at(when, self._pong_not_received)
  144. coro = self._writer.send_frame(b"", WSMsgType.PING)
  145. if sys.version_info >= (3, 12):
  146. # Optimization for Python 3.12, try to send the ping
  147. # immediately to avoid having to schedule
  148. # the task on the event loop.
  149. ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
  150. else:
  151. ping_task = loop.create_task(coro)
  152. if not ping_task.done():
  153. self._ping_task = ping_task
  154. ping_task.add_done_callback(self._ping_task_done)
  155. else:
  156. self._ping_task_done(ping_task)
  157. def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
  158. """Callback for when the ping task completes."""
  159. if not task.cancelled() and (exc := task.exception()):
  160. self._handle_ping_pong_exception(exc)
  161. self._ping_task = None
  162. def _pong_not_received(self) -> None:
  163. if self._req is not None and self._req.transport is not None:
  164. self._handle_ping_pong_exception(
  165. asyncio.TimeoutError(
  166. f"No PONG received after {self._pong_heartbeat} seconds"
  167. )
  168. )
  169. def _handle_ping_pong_exception(self, exc: BaseException) -> None:
  170. """Handle exceptions raised during ping/pong processing."""
  171. if self._closed:
  172. return
  173. self._set_closed()
  174. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  175. self._exception = exc
  176. if self._waiting and not self._closing and self._reader is not None:
  177. self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
  178. def _set_closed(self) -> None:
  179. """Set the connection to closed.
  180. Cancel any heartbeat timers and set the closed flag.
  181. """
  182. self._closed = True
  183. self._cancel_heartbeat()
  184. async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
  185. # make pre-check to don't hide it by do_handshake() exceptions
  186. if self._payload_writer is not None:
  187. return self._payload_writer
  188. protocol, writer = self._pre_start(request)
  189. payload_writer = await super().prepare(request)
  190. assert payload_writer is not None
  191. self._post_start(request, protocol, writer)
  192. await payload_writer.drain()
  193. return payload_writer
  194. def _handshake(
  195. self, request: BaseRequest
  196. ) -> Tuple["CIMultiDict[str]", Optional[str], int, bool]:
  197. headers = request.headers
  198. if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
  199. raise HTTPBadRequest(
  200. text=(
  201. "No WebSocket UPGRADE hdr: {}\n Can "
  202. '"Upgrade" only to "WebSocket".'
  203. ).format(headers.get(hdrs.UPGRADE))
  204. )
  205. if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
  206. raise HTTPBadRequest(
  207. text="No CONNECTION upgrade hdr: {}".format(
  208. headers.get(hdrs.CONNECTION)
  209. )
  210. )
  211. # find common sub-protocol between client and server
  212. protocol: Optional[str] = None
  213. if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
  214. req_protocols = [
  215. str(proto.strip())
  216. for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
  217. ]
  218. for proto in req_protocols:
  219. if proto in self._protocols:
  220. protocol = proto
  221. break
  222. else:
  223. # No overlap found: Return no protocol as per spec
  224. ws_logger.warning(
  225. "%s: Client protocols %r don’t overlap server-known ones %r",
  226. request.remote,
  227. req_protocols,
  228. self._protocols,
  229. )
  230. # check supported version
  231. version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
  232. if version not in ("13", "8", "7"):
  233. raise HTTPBadRequest(text=f"Unsupported version: {version}")
  234. # check client handshake for validity
  235. key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
  236. try:
  237. if not key or len(base64.b64decode(key)) != 16:
  238. raise HTTPBadRequest(text=f"Handshake error: {key!r}")
  239. except binascii.Error:
  240. raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None
  241. accept_val = base64.b64encode(
  242. hashlib.sha1(key.encode() + WS_KEY).digest()
  243. ).decode()
  244. response_headers = CIMultiDict(
  245. {
  246. hdrs.UPGRADE: "websocket",
  247. hdrs.CONNECTION: "upgrade",
  248. hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
  249. }
  250. )
  251. notakeover = False
  252. compress = 0
  253. if self._compress:
  254. extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
  255. # Server side always get return with no exception.
  256. # If something happened, just drop compress extension
  257. compress, notakeover = ws_ext_parse(extensions, isserver=True)
  258. if compress:
  259. enabledext = ws_ext_gen(
  260. compress=compress, isserver=True, server_notakeover=notakeover
  261. )
  262. response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext
  263. if protocol:
  264. response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
  265. return (
  266. response_headers,
  267. protocol,
  268. compress,
  269. notakeover,
  270. )
  271. def _pre_start(self, request: BaseRequest) -> Tuple[Optional[str], WebSocketWriter]:
  272. self._loop = request._loop
  273. headers, protocol, compress, notakeover = self._handshake(request)
  274. self.set_status(101)
  275. self.headers.update(headers)
  276. self.force_close()
  277. self._compress = compress
  278. transport = request._protocol.transport
  279. assert transport is not None
  280. writer = WebSocketWriter(
  281. request._protocol,
  282. transport,
  283. compress=compress,
  284. notakeover=notakeover,
  285. limit=self._writer_limit,
  286. )
  287. return protocol, writer
  288. def _post_start(
  289. self, request: BaseRequest, protocol: Optional[str], writer: WebSocketWriter
  290. ) -> None:
  291. self._ws_protocol = protocol
  292. self._writer = writer
  293. self._reset_heartbeat()
  294. loop = self._loop
  295. assert loop is not None
  296. self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
  297. request.protocol.set_parser(
  298. WebSocketReader(
  299. self._reader, self._max_msg_size, compress=bool(self._compress)
  300. )
  301. )
  302. # disable HTTP keepalive for WebSocket
  303. request.protocol.keep_alive(False)
  304. def can_prepare(self, request: BaseRequest) -> WebSocketReady:
  305. if self._writer is not None:
  306. raise RuntimeError("Already started")
  307. try:
  308. _, protocol, _, _ = self._handshake(request)
  309. except HTTPException:
  310. return WebSocketReady(False, None)
  311. else:
  312. return WebSocketReady(True, protocol)
  313. @property
  314. def prepared(self) -> bool:
  315. return self._writer is not None
  316. @property
  317. def closed(self) -> bool:
  318. return self._closed
  319. @property
  320. def close_code(self) -> Optional[int]:
  321. return self._close_code
  322. @property
  323. def ws_protocol(self) -> Optional[str]:
  324. return self._ws_protocol
  325. @property
  326. def compress(self) -> Union[int, bool]:
  327. return self._compress
  328. def get_extra_info(self, name: str, default: Any = None) -> Any:
  329. """Get optional transport information.
  330. If no value associated with ``name`` is found, ``default`` is returned.
  331. """
  332. writer = self._writer
  333. if writer is None:
  334. return default
  335. transport = writer.transport
  336. if transport is None:
  337. return default
  338. return transport.get_extra_info(name, default)
  339. def exception(self) -> Optional[BaseException]:
  340. return self._exception
  341. async def ping(self, message: bytes = b"") -> None:
  342. if self._writer is None:
  343. raise RuntimeError("Call .prepare() first")
  344. await self._writer.send_frame(message, WSMsgType.PING)
  345. async def pong(self, message: bytes = b"") -> None:
  346. # unsolicited pong
  347. if self._writer is None:
  348. raise RuntimeError("Call .prepare() first")
  349. await self._writer.send_frame(message, WSMsgType.PONG)
  350. async def send_frame(
  351. self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
  352. ) -> None:
  353. """Send a frame over the websocket."""
  354. if self._writer is None:
  355. raise RuntimeError("Call .prepare() first")
  356. await self._writer.send_frame(message, opcode, compress)
  357. async def send_str(self, data: str, compress: Optional[int] = None) -> None:
  358. if self._writer is None:
  359. raise RuntimeError("Call .prepare() first")
  360. if not isinstance(data, str):
  361. raise TypeError("data argument must be str (%r)" % type(data))
  362. await self._writer.send_frame(
  363. data.encode("utf-8"), WSMsgType.TEXT, compress=compress
  364. )
  365. async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
  366. if self._writer is None:
  367. raise RuntimeError("Call .prepare() first")
  368. if not isinstance(data, (bytes, bytearray, memoryview)):
  369. raise TypeError("data argument must be byte-ish (%r)" % type(data))
  370. await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)
  371. async def send_json(
  372. self,
  373. data: Any,
  374. compress: Optional[int] = None,
  375. *,
  376. dumps: JSONEncoder = json.dumps,
  377. ) -> None:
  378. await self.send_str(dumps(data), compress=compress)
  379. async def write_eof(self) -> None: # type: ignore[override]
  380. if self._eof_sent:
  381. return
  382. if self._payload_writer is None:
  383. raise RuntimeError("Response has not been started")
  384. await self.close()
  385. self._eof_sent = True
  386. async def close(
  387. self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
  388. ) -> bool:
  389. """Close websocket connection."""
  390. if self._writer is None:
  391. raise RuntimeError("Call .prepare() first")
  392. if self._closed:
  393. return False
  394. self._set_closed()
  395. try:
  396. await self._writer.close(code, message)
  397. writer = self._payload_writer
  398. assert writer is not None
  399. if drain:
  400. await writer.drain()
  401. except (asyncio.CancelledError, asyncio.TimeoutError):
  402. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  403. raise
  404. except Exception as exc:
  405. self._exception = exc
  406. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  407. return True
  408. reader = self._reader
  409. assert reader is not None
  410. # we need to break `receive()` cycle before we can call
  411. # `reader.read()` as `close()` may be called from different task
  412. if self._waiting:
  413. assert self._loop is not None
  414. assert self._close_wait is None
  415. self._close_wait = self._loop.create_future()
  416. reader.feed_data(WS_CLOSING_MESSAGE, 0)
  417. await self._close_wait
  418. if self._closing:
  419. self._close_transport()
  420. return True
  421. try:
  422. async with async_timeout.timeout(self._timeout):
  423. while True:
  424. msg = await reader.read()
  425. if msg.type is WSMsgType.CLOSE:
  426. self._set_code_close_transport(msg.data)
  427. return True
  428. except asyncio.CancelledError:
  429. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  430. raise
  431. except Exception as exc:
  432. self._exception = exc
  433. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  434. return True
  435. def _set_closing(self, code: WSCloseCode) -> None:
  436. """Set the close code and mark the connection as closing."""
  437. self._closing = True
  438. self._close_code = code
  439. self._cancel_heartbeat()
  440. def _set_code_close_transport(self, code: WSCloseCode) -> None:
  441. """Set the close code and close the transport."""
  442. self._close_code = code
  443. self._close_transport()
  444. def _close_transport(self) -> None:
  445. """Close the transport."""
  446. if self._req is not None and self._req.transport is not None:
  447. self._req.transport.close()
  448. async def receive(self, timeout: Optional[float] = None) -> WSMessage:
  449. if self._reader is None:
  450. raise RuntimeError("Call .prepare() first")
  451. receive_timeout = timeout or self._receive_timeout
  452. while True:
  453. if self._waiting:
  454. raise RuntimeError("Concurrent call to receive() is not allowed")
  455. if self._closed:
  456. self._conn_lost += 1
  457. if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
  458. raise RuntimeError("WebSocket connection is closed.")
  459. return WS_CLOSED_MESSAGE
  460. elif self._closing:
  461. return WS_CLOSING_MESSAGE
  462. try:
  463. self._waiting = True
  464. try:
  465. if receive_timeout:
  466. # Entering the context manager and creating
  467. # Timeout() object can take almost 50% of the
  468. # run time in this loop so we avoid it if
  469. # there is no read timeout.
  470. async with async_timeout.timeout(receive_timeout):
  471. msg = await self._reader.read()
  472. else:
  473. msg = await self._reader.read()
  474. self._reset_heartbeat()
  475. finally:
  476. self._waiting = False
  477. if self._close_wait:
  478. set_result(self._close_wait, None)
  479. except asyncio.TimeoutError:
  480. raise
  481. except EofStream:
  482. self._close_code = WSCloseCode.OK
  483. await self.close()
  484. return WSMessage(WSMsgType.CLOSED, None, None)
  485. except WebSocketError as exc:
  486. self._close_code = exc.code
  487. await self.close(code=exc.code)
  488. return WSMessage(WSMsgType.ERROR, exc, None)
  489. except Exception as exc:
  490. self._exception = exc
  491. self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
  492. await self.close()
  493. return WSMessage(WSMsgType.ERROR, exc, None)
  494. if msg.type not in _INTERNAL_RECEIVE_TYPES:
  495. # If its not a close/closing/ping/pong message
  496. # we can return it immediately
  497. return msg
  498. if msg.type is WSMsgType.CLOSE:
  499. self._set_closing(msg.data)
  500. # Could be closed while awaiting reader.
  501. if not self._closed and self._autoclose:
  502. # The client is likely going to close the
  503. # connection out from under us so we do not
  504. # want to drain any pending writes as it will
  505. # likely result writing to a broken pipe.
  506. await self.close(drain=False)
  507. elif msg.type is WSMsgType.CLOSING:
  508. self._set_closing(WSCloseCode.OK)
  509. elif msg.type is WSMsgType.PING and self._autoping:
  510. await self.pong(msg.data)
  511. continue
  512. elif msg.type is WSMsgType.PONG and self._autoping:
  513. continue
  514. return msg
  515. async def receive_str(self, *, timeout: Optional[float] = None) -> str:
  516. msg = await self.receive(timeout)
  517. if msg.type is not WSMsgType.TEXT:
  518. raise WSMessageTypeError(
  519. f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
  520. )
  521. return cast(str, msg.data)
  522. async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
  523. msg = await self.receive(timeout)
  524. if msg.type is not WSMsgType.BINARY:
  525. raise WSMessageTypeError(
  526. f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
  527. )
  528. return cast(bytes, msg.data)
  529. async def receive_json(
  530. self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
  531. ) -> Any:
  532. data = await self.receive_str(timeout=timeout)
  533. return loads(data)
  534. async def write(self, data: bytes) -> None:
  535. raise RuntimeError("Cannot call .write() for websocket")
  536. def __aiter__(self) -> "WebSocketResponse":
  537. return self
  538. async def __anext__(self) -> WSMessage:
  539. msg = await self.receive()
  540. if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
  541. raise StopAsyncIteration
  542. return msg
  543. def _cancel(self, exc: BaseException) -> None:
  544. # web_protocol calls this from connection_lost
  545. # or when the server is shutting down.
  546. self._closing = True
  547. self._cancel_heartbeat()
  548. if self._reader is not None:
  549. set_exception(self._reader, exc)