http_writer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. """Http related parsers and protocol."""
  2. import asyncio
  3. import sys
  4. from typing import ( # noqa
  5. TYPE_CHECKING,
  6. Any,
  7. Awaitable,
  8. Callable,
  9. Iterable,
  10. List,
  11. NamedTuple,
  12. Optional,
  13. Union,
  14. )
  15. from multidict import CIMultiDict
  16. from .abc import AbstractStreamWriter
  17. from .base_protocol import BaseProtocol
  18. from .client_exceptions import ClientConnectionResetError
  19. from .compression_utils import ZLibCompressor
  20. from .helpers import NO_EXTENSIONS
  21. __all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
  22. MIN_PAYLOAD_FOR_WRITELINES = 2048
  23. IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
  24. IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
  25. SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
  26. # writelines is not safe for use
  27. # on Python 3.12+ until 3.12.9
  28. # on Python 3.13+ until 3.13.2
  29. # and on older versions it not any faster than write
  30. # CVE-2024-12254: https://github.com/python/cpython/pull/127656
  31. class HttpVersion(NamedTuple):
  32. major: int
  33. minor: int
  34. HttpVersion10 = HttpVersion(1, 0)
  35. HttpVersion11 = HttpVersion(1, 1)
  36. _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
  37. _T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
  38. class StreamWriter(AbstractStreamWriter):
  39. length: Optional[int] = None
  40. chunked: bool = False
  41. _eof: bool = False
  42. _compress: Optional[ZLibCompressor] = None
  43. def __init__(
  44. self,
  45. protocol: BaseProtocol,
  46. loop: asyncio.AbstractEventLoop,
  47. on_chunk_sent: _T_OnChunkSent = None,
  48. on_headers_sent: _T_OnHeadersSent = None,
  49. ) -> None:
  50. self._protocol = protocol
  51. self.loop = loop
  52. self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
  53. self._on_headers_sent: _T_OnHeadersSent = on_headers_sent
  54. self._headers_buf: Optional[bytes] = None
  55. self._headers_written: bool = False
  56. @property
  57. def transport(self) -> Optional[asyncio.Transport]:
  58. return self._protocol.transport
  59. @property
  60. def protocol(self) -> BaseProtocol:
  61. return self._protocol
  62. def enable_chunking(self) -> None:
  63. self.chunked = True
  64. def enable_compression(
  65. self, encoding: str = "deflate", strategy: Optional[int] = None
  66. ) -> None:
  67. self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
  68. def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
  69. size = len(chunk)
  70. self.buffer_size += size
  71. self.output_size += size
  72. transport = self._protocol.transport
  73. if transport is None or transport.is_closing():
  74. raise ClientConnectionResetError("Cannot write to closing transport")
  75. transport.write(chunk)
  76. def _writelines(self, chunks: Iterable[bytes]) -> None:
  77. size = 0
  78. for chunk in chunks:
  79. size += len(chunk)
  80. self.buffer_size += size
  81. self.output_size += size
  82. transport = self._protocol.transport
  83. if transport is None or transport.is_closing():
  84. raise ClientConnectionResetError("Cannot write to closing transport")
  85. if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
  86. transport.write(b"".join(chunks))
  87. else:
  88. transport.writelines(chunks)
  89. def _write_chunked_payload(
  90. self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
  91. ) -> None:
  92. """Write a chunk with proper chunked encoding."""
  93. chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
  94. self._writelines((chunk_len_pre, chunk, b"\r\n"))
  95. def _send_headers_with_payload(
  96. self,
  97. chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
  98. is_eof: bool,
  99. ) -> None:
  100. """Send buffered headers with payload, coalescing into single write."""
  101. # Mark headers as written
  102. self._headers_written = True
  103. headers_buf = self._headers_buf
  104. self._headers_buf = None
  105. if TYPE_CHECKING:
  106. # Safe because callers (write() and write_eof()) only invoke this method
  107. # after checking that self._headers_buf is truthy
  108. assert headers_buf is not None
  109. if not self.chunked:
  110. # Non-chunked: coalesce headers with body
  111. if chunk:
  112. self._writelines((headers_buf, chunk))
  113. else:
  114. self._write(headers_buf)
  115. return
  116. # Coalesce headers with chunked data
  117. if chunk:
  118. chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
  119. if is_eof:
  120. self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
  121. else:
  122. self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n"))
  123. elif is_eof:
  124. self._writelines((headers_buf, b"0\r\n\r\n"))
  125. else:
  126. self._write(headers_buf)
  127. async def write(
  128. self,
  129. chunk: Union[bytes, bytearray, memoryview],
  130. *,
  131. drain: bool = True,
  132. LIMIT: int = 0x10000,
  133. ) -> None:
  134. """
  135. Writes chunk of data to a stream.
  136. write_eof() indicates end of stream.
  137. writer can't be used after write_eof() method being called.
  138. write() return drain future.
  139. """
  140. if self._on_chunk_sent is not None:
  141. await self._on_chunk_sent(chunk)
  142. if isinstance(chunk, memoryview):
  143. if chunk.nbytes != len(chunk):
  144. # just reshape it
  145. chunk = chunk.cast("c")
  146. if self._compress is not None:
  147. chunk = await self._compress.compress(chunk)
  148. if not chunk:
  149. return
  150. if self.length is not None:
  151. chunk_len = len(chunk)
  152. if self.length >= chunk_len:
  153. self.length = self.length - chunk_len
  154. else:
  155. chunk = chunk[: self.length]
  156. self.length = 0
  157. if not chunk:
  158. return
  159. # Handle buffered headers for small payload optimization
  160. if self._headers_buf and not self._headers_written:
  161. self._send_headers_with_payload(chunk, False)
  162. if drain and self.buffer_size > LIMIT:
  163. self.buffer_size = 0
  164. await self.drain()
  165. return
  166. if chunk:
  167. if self.chunked:
  168. self._write_chunked_payload(chunk)
  169. else:
  170. self._write(chunk)
  171. if drain and self.buffer_size > LIMIT:
  172. self.buffer_size = 0
  173. await self.drain()
  174. async def write_headers(
  175. self, status_line: str, headers: "CIMultiDict[str]"
  176. ) -> None:
  177. """Write headers to the stream."""
  178. if self._on_headers_sent is not None:
  179. await self._on_headers_sent(headers)
  180. # status + headers
  181. buf = _serialize_headers(status_line, headers)
  182. self._headers_written = False
  183. self._headers_buf = buf
  184. def send_headers(self) -> None:
  185. """Force sending buffered headers if not already sent."""
  186. if not self._headers_buf or self._headers_written:
  187. return
  188. self._headers_written = True
  189. headers_buf = self._headers_buf
  190. self._headers_buf = None
  191. if TYPE_CHECKING:
  192. # Safe because we only enter this block when self._headers_buf is truthy
  193. assert headers_buf is not None
  194. self._write(headers_buf)
  195. def set_eof(self) -> None:
  196. """Indicate that the message is complete."""
  197. if self._eof:
  198. return
  199. # If headers haven't been sent yet, send them now
  200. # This handles the case where there's no body at all
  201. if self._headers_buf and not self._headers_written:
  202. self._headers_written = True
  203. headers_buf = self._headers_buf
  204. self._headers_buf = None
  205. if TYPE_CHECKING:
  206. # Safe because we only enter this block when self._headers_buf is truthy
  207. assert headers_buf is not None
  208. # Combine headers and chunked EOF marker in a single write
  209. if self.chunked:
  210. self._writelines((headers_buf, b"0\r\n\r\n"))
  211. else:
  212. self._write(headers_buf)
  213. elif self.chunked and self._headers_written:
  214. # Headers already sent, just send the final chunk marker
  215. self._write(b"0\r\n\r\n")
  216. self._eof = True
  217. async def write_eof(self, chunk: bytes = b"") -> None:
  218. if self._eof:
  219. return
  220. if chunk and self._on_chunk_sent is not None:
  221. await self._on_chunk_sent(chunk)
  222. # Handle body/compression
  223. if self._compress:
  224. chunks: List[bytes] = []
  225. chunks_len = 0
  226. if chunk and (compressed_chunk := await self._compress.compress(chunk)):
  227. chunks_len = len(compressed_chunk)
  228. chunks.append(compressed_chunk)
  229. flush_chunk = self._compress.flush()
  230. chunks_len += len(flush_chunk)
  231. chunks.append(flush_chunk)
  232. assert chunks_len
  233. # Send buffered headers with compressed data if not yet sent
  234. if self._headers_buf and not self._headers_written:
  235. self._headers_written = True
  236. headers_buf = self._headers_buf
  237. self._headers_buf = None
  238. if self.chunked:
  239. # Coalesce headers with compressed chunked data
  240. chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
  241. self._writelines(
  242. (headers_buf, chunk_len_pre, *chunks, b"\r\n0\r\n\r\n")
  243. )
  244. else:
  245. # Coalesce headers with compressed data
  246. self._writelines((headers_buf, *chunks))
  247. await self.drain()
  248. self._eof = True
  249. return
  250. # Headers already sent, just write compressed data
  251. if self.chunked:
  252. chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
  253. self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
  254. elif len(chunks) > 1:
  255. self._writelines(chunks)
  256. else:
  257. self._write(chunks[0])
  258. await self.drain()
  259. self._eof = True
  260. return
  261. # No compression - send buffered headers if not yet sent
  262. if self._headers_buf and not self._headers_written:
  263. # Use helper to send headers with payload
  264. self._send_headers_with_payload(chunk, True)
  265. await self.drain()
  266. self._eof = True
  267. return
  268. # Handle remaining body
  269. if self.chunked:
  270. if chunk:
  271. # Write final chunk with EOF marker
  272. self._writelines(
  273. (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n0\r\n\r\n")
  274. )
  275. else:
  276. self._write(b"0\r\n\r\n")
  277. await self.drain()
  278. self._eof = True
  279. return
  280. if chunk:
  281. self._write(chunk)
  282. await self.drain()
  283. self._eof = True
  284. async def drain(self) -> None:
  285. """Flush the write buffer.
  286. The intended use is to write
  287. await w.write(data)
  288. await w.drain()
  289. """
  290. protocol = self._protocol
  291. if protocol.transport is not None and protocol._paused:
  292. await protocol._drain_helper()
  293. def _safe_header(string: str) -> str:
  294. if "\r" in string or "\n" in string:
  295. raise ValueError(
  296. "Newline or carriage return detected in headers. "
  297. "Potential header injection attack."
  298. )
  299. return string
  300. def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
  301. headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
  302. line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
  303. return line.encode("utf-8")
  304. _serialize_headers = _py_serialize_headers
  305. try:
  306. import aiohttp._http_writer as _http_writer # type: ignore[import-not-found]
  307. _c_serialize_headers = _http_writer._serialize_headers
  308. if not NO_EXTENSIONS:
  309. _serialize_headers = _c_serialize_headers
  310. except ImportError:
  311. pass