writer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. """WebSocket protocol versions 13 and 8."""
  2. import asyncio
  3. import random
  4. import sys
  5. from functools import partial
  6. from typing import Final, Optional, Set, Union
  7. from ..base_protocol import BaseProtocol
  8. from ..client_exceptions import ClientConnectionResetError
  9. from ..compression_utils import ZLibBackend, ZLibCompressor
  10. from .helpers import (
  11. MASK_LEN,
  12. MSG_SIZE,
  13. PACK_CLOSE_CODE,
  14. PACK_LEN1,
  15. PACK_LEN2,
  16. PACK_LEN3,
  17. PACK_RANDBITS,
  18. websocket_mask,
  19. )
  20. from .models import WS_DEFLATE_TRAILING, WSMsgType
  21. DEFAULT_LIMIT: Final[int] = 2**16
  22. # WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
  23. # Control frames (ping, pong, close) are never compressed
  24. WS_CONTROL_FRAME_OPCODE: Final[int] = 8
  25. # For websockets, keeping latency low is extremely important as implementations
  26. # generally expect to be able to send and receive messages quickly. We use a
  27. # larger chunk size to reduce the number of executor calls and avoid task
  28. # creation overhead, since both are significant sources of latency when chunks
  29. # are small. A size of 16KiB was chosen as a balance between avoiding task
  30. # overhead and not blocking the event loop too long with synchronous compression.
  31. WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024
  32. class WebSocketWriter:
  33. """WebSocket writer.
  34. The writer is responsible for sending messages to the client. It is
  35. created by the protocol when a connection is established. The writer
  36. should avoid implementing any application logic and should only be
  37. concerned with the low-level details of the WebSocket protocol.
  38. """
  39. def __init__(
  40. self,
  41. protocol: BaseProtocol,
  42. transport: asyncio.Transport,
  43. *,
  44. use_mask: bool = False,
  45. limit: int = DEFAULT_LIMIT,
  46. random: random.Random = random.Random(),
  47. compress: int = 0,
  48. notakeover: bool = False,
  49. ) -> None:
  50. """Initialize a WebSocket writer."""
  51. self.protocol = protocol
  52. self.transport = transport
  53. self.use_mask = use_mask
  54. self.get_random_bits = partial(random.getrandbits, 32)
  55. self.compress = compress
  56. self.notakeover = notakeover
  57. self._closing = False
  58. self._limit = limit
  59. self._output_size = 0
  60. self._compressobj: Optional[ZLibCompressor] = None
  61. self._send_lock = asyncio.Lock()
  62. self._background_tasks: Set[asyncio.Task[None]] = set()
  63. async def send_frame(
  64. self, message: bytes, opcode: int, compress: Optional[int] = None
  65. ) -> None:
  66. """Send a frame over the websocket with message as its payload."""
  67. if self._closing and not (opcode & WSMsgType.CLOSE):
  68. raise ClientConnectionResetError("Cannot write to closing transport")
  69. if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
  70. # Non-compressed frames don't need lock or shield
  71. self._write_websocket_frame(message, opcode, 0)
  72. elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
  73. # Small compressed payloads - compress synchronously in event loop
  74. # We need the lock even though sync compression has no await points.
  75. # This prevents small frames from interleaving with large frames that
  76. # compress in the executor, avoiding compressor state corruption.
  77. async with self._send_lock:
  78. self._send_compressed_frame_sync(message, opcode, compress)
  79. else:
  80. # Large compressed frames need shield to prevent corruption
  81. # For large compressed frames, the entire compress+send
  82. # operation must be atomic. If cancelled after compression but
  83. # before send, the compressor state would be advanced but data
  84. # not sent, corrupting subsequent frames.
  85. # Create a task to shield from cancellation
  86. # The lock is acquired inside the shielded task so the entire
  87. # operation (lock + compress + send) completes atomically.
  88. # Use eager_start on Python 3.12+ to avoid scheduling overhead
  89. loop = asyncio.get_running_loop()
  90. coro = self._send_compressed_frame_async_locked(message, opcode, compress)
  91. if sys.version_info >= (3, 12):
  92. send_task = asyncio.Task(coro, loop=loop, eager_start=True)
  93. else:
  94. send_task = loop.create_task(coro)
  95. # Keep a strong reference to prevent garbage collection
  96. self._background_tasks.add(send_task)
  97. send_task.add_done_callback(self._background_tasks.discard)
  98. await asyncio.shield(send_task)
  99. # It is safe to return control to the event loop when using compression
  100. # after this point as we have already sent or buffered all the data.
  101. # Once we have written output_size up to the limit, we call the
  102. # drain helper which waits for the transport to be ready to accept
  103. # more data. This is a flow control mechanism to prevent the buffer
  104. # from growing too large. The drain helper will return right away
  105. # if the writer is not paused.
  106. if self._output_size > self._limit:
  107. self._output_size = 0
  108. if self.protocol._paused:
  109. await self.protocol._drain_helper()
  110. def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
  111. """
  112. Write a websocket frame to the transport.
  113. This method handles frame header construction, masking, and writing to transport.
  114. It does not handle compression or flow control - those are the responsibility
  115. of the caller.
  116. """
  117. msg_length = len(message)
  118. use_mask = self.use_mask
  119. mask_bit = 0x80 if use_mask else 0
  120. # Depending on the message length, the header is assembled differently.
  121. # The first byte is reserved for the opcode and the RSV bits.
  122. first_byte = 0x80 | rsv | opcode
  123. if msg_length < 126:
  124. header = PACK_LEN1(first_byte, msg_length | mask_bit)
  125. header_len = 2
  126. elif msg_length < 65536:
  127. header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
  128. header_len = 4
  129. else:
  130. header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
  131. header_len = 10
  132. if self.transport.is_closing():
  133. raise ClientConnectionResetError("Cannot write to closing transport")
  134. # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
  135. # If we are using a mask, we need to generate it randomly
  136. # and apply it to the message before sending it. A mask is
  137. # a 32-bit value that is applied to the message using a
  138. # bitwise XOR operation. It is used to prevent certain types
  139. # of attacks on the websocket protocol. The mask is only used
  140. # when aiohttp is acting as a client. Servers do not use a mask.
  141. if use_mask:
  142. mask = PACK_RANDBITS(self.get_random_bits())
  143. message = bytearray(message)
  144. websocket_mask(mask, message)
  145. self.transport.write(header + mask + message)
  146. self._output_size += MASK_LEN
  147. elif msg_length > MSG_SIZE:
  148. self.transport.write(header)
  149. self.transport.write(message)
  150. else:
  151. self.transport.write(header + message)
  152. self._output_size += header_len + msg_length
  153. def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor:
  154. """Get or create a compressor object for the given compression level."""
  155. if compress:
  156. # Do not set self._compress if compressing is for this frame
  157. return ZLibCompressor(
  158. level=ZLibBackend.Z_BEST_SPEED,
  159. wbits=-compress,
  160. max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
  161. )
  162. if not self._compressobj:
  163. self._compressobj = ZLibCompressor(
  164. level=ZLibBackend.Z_BEST_SPEED,
  165. wbits=-self.compress,
  166. max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
  167. )
  168. return self._compressobj
  169. def _send_compressed_frame_sync(
  170. self, message: bytes, opcode: int, compress: Optional[int]
  171. ) -> None:
  172. """
  173. Synchronous send for small compressed frames.
  174. This is used for small compressed payloads that compress synchronously in the event loop.
  175. Since there are no await points, this is inherently cancellation-safe.
  176. """
  177. # RSV are the reserved bits in the frame header. They are used to
  178. # indicate that the frame is using an extension.
  179. # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
  180. compressobj = self._get_compressor(compress)
  181. # (0x40) RSV1 is set for compressed frames
  182. # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
  183. self._write_websocket_frame(
  184. (
  185. compressobj.compress_sync(message)
  186. + compressobj.flush(
  187. ZLibBackend.Z_FULL_FLUSH
  188. if self.notakeover
  189. else ZLibBackend.Z_SYNC_FLUSH
  190. )
  191. ).removesuffix(WS_DEFLATE_TRAILING),
  192. opcode,
  193. 0x40,
  194. )
  195. async def _send_compressed_frame_async_locked(
  196. self, message: bytes, opcode: int, compress: Optional[int]
  197. ) -> None:
  198. """
  199. Async send for large compressed frames with lock.
  200. Acquires the lock and compresses large payloads asynchronously in
  201. the executor. The lock is held for the entire operation to ensure
  202. the compressor state is not corrupted by concurrent sends.
  203. MUST be run shielded from cancellation. If cancelled after
  204. compression but before sending, the compressor state would be
  205. advanced but data not sent, corrupting subsequent frames.
  206. """
  207. async with self._send_lock:
  208. # RSV are the reserved bits in the frame header. They are used to
  209. # indicate that the frame is using an extension.
  210. # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
  211. compressobj = self._get_compressor(compress)
  212. # (0x40) RSV1 is set for compressed frames
  213. # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
  214. self._write_websocket_frame(
  215. (
  216. await compressobj.compress(message)
  217. + compressobj.flush(
  218. ZLibBackend.Z_FULL_FLUSH
  219. if self.notakeover
  220. else ZLibBackend.Z_SYNC_FLUSH
  221. )
  222. ).removesuffix(WS_DEFLATE_TRAILING),
  223. opcode,
  224. 0x40,
  225. )
  226. async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
  227. """Close the websocket, sending the specified code and message."""
  228. if isinstance(message, str):
  229. message = message.encode("utf-8")
  230. try:
  231. await self.send_frame(
  232. PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
  233. )
  234. finally:
  235. self._closing = True