multipart.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152
  1. import base64
  2. import binascii
  3. import json
  4. import re
  5. import sys
  6. import uuid
  7. import warnings
  8. from collections import deque
  9. from collections.abc import Mapping, Sequence
  10. from types import TracebackType
  11. from typing import (
  12. TYPE_CHECKING,
  13. Any,
  14. Deque,
  15. Dict,
  16. Iterator,
  17. List,
  18. Optional,
  19. Tuple,
  20. Type,
  21. Union,
  22. cast,
  23. )
  24. from urllib.parse import parse_qsl, unquote, urlencode
  25. from multidict import CIMultiDict, CIMultiDictProxy
  26. from .abc import AbstractStreamWriter
  27. from .compression_utils import (
  28. DEFAULT_MAX_DECOMPRESS_SIZE,
  29. ZLibCompressor,
  30. ZLibDecompressor,
  31. )
  32. from .hdrs import (
  33. CONTENT_DISPOSITION,
  34. CONTENT_ENCODING,
  35. CONTENT_LENGTH,
  36. CONTENT_TRANSFER_ENCODING,
  37. CONTENT_TYPE,
  38. )
  39. from .helpers import CHAR, TOKEN, parse_mimetype, reify
  40. from .http import HeadersParser
  41. from .log import internal_logger
  42. from .payload import (
  43. JsonPayload,
  44. LookupError,
  45. Order,
  46. Payload,
  47. StringPayload,
  48. get_payload,
  49. payload_type,
  50. )
  51. from .streams import StreamReader
  52. if sys.version_info >= (3, 11):
  53. from typing import Self
  54. else:
  55. from typing import TypeVar
  56. Self = TypeVar("Self", bound="BodyPartReader")
# Public names exported by ``from <this module> import *``.
__all__ = (
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)
  66. if TYPE_CHECKING:
  67. from .client_reqrep import ClientResponse
  68. class BadContentDispositionHeader(RuntimeWarning):
  69. pass
  70. class BadContentDispositionParam(RuntimeWarning):
  71. pass
  72. def parse_content_disposition(
  73. header: Optional[str],
  74. ) -> Tuple[Optional[str], Dict[str, str]]:
  75. def is_token(string: str) -> bool:
  76. return bool(string) and TOKEN >= set(string)
  77. def is_quoted(string: str) -> bool:
  78. return string[0] == string[-1] == '"'
  79. def is_rfc5987(string: str) -> bool:
  80. return is_token(string) and string.count("'") == 2
  81. def is_extended_param(string: str) -> bool:
  82. return string.endswith("*")
  83. def is_continuous_param(string: str) -> bool:
  84. pos = string.find("*") + 1
  85. if not pos:
  86. return False
  87. substring = string[pos:-1] if string.endswith("*") else string[pos:]
  88. return substring.isdigit()
  89. def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
  90. return re.sub(f"\\\\([{chars}])", "\\1", text)
  91. if not header:
  92. return None, {}
  93. disptype, *parts = header.split(";")
  94. if not is_token(disptype):
  95. warnings.warn(BadContentDispositionHeader(header))
  96. return None, {}
  97. params: Dict[str, str] = {}
  98. while parts:
  99. item = parts.pop(0)
  100. if not item: # To handle trailing semicolons
  101. warnings.warn(BadContentDispositionHeader(header))
  102. continue
  103. if "=" not in item:
  104. warnings.warn(BadContentDispositionHeader(header))
  105. return None, {}
  106. key, value = item.split("=", 1)
  107. key = key.lower().strip()
  108. value = value.lstrip()
  109. if key in params:
  110. warnings.warn(BadContentDispositionHeader(header))
  111. return None, {}
  112. if not is_token(key):
  113. warnings.warn(BadContentDispositionParam(item))
  114. continue
  115. elif is_continuous_param(key):
  116. if is_quoted(value):
  117. value = unescape(value[1:-1])
  118. elif not is_token(value):
  119. warnings.warn(BadContentDispositionParam(item))
  120. continue
  121. elif is_extended_param(key):
  122. if is_rfc5987(value):
  123. encoding, _, value = value.split("'", 2)
  124. encoding = encoding or "utf-8"
  125. else:
  126. warnings.warn(BadContentDispositionParam(item))
  127. continue
  128. try:
  129. value = unquote(value, encoding, "strict")
  130. except UnicodeDecodeError: # pragma: nocover
  131. warnings.warn(BadContentDispositionParam(item))
  132. continue
  133. else:
  134. failed = True
  135. if is_quoted(value):
  136. failed = False
  137. value = unescape(value[1:-1].lstrip("\\/"))
  138. elif is_token(value):
  139. failed = False
  140. elif parts:
  141. # maybe just ; in filename, in any case this is just
  142. # one case fix, for proper fix we need to redesign parser
  143. _value = f"{value};{parts[0]}"
  144. if is_quoted(_value):
  145. parts.pop(0)
  146. value = unescape(_value[1:-1].lstrip("\\/"))
  147. failed = False
  148. if failed:
  149. warnings.warn(BadContentDispositionHeader(header))
  150. return None, {}
  151. params[key] = value
  152. return disptype.lower(), params
  153. def content_disposition_filename(
  154. params: Mapping[str, str], name: str = "filename"
  155. ) -> Optional[str]:
  156. name_suf = "%s*" % name
  157. if not params:
  158. return None
  159. elif name_suf in params:
  160. return params[name_suf]
  161. elif name in params:
  162. return params[name]
  163. else:
  164. parts = []
  165. fnparams = sorted(
  166. (key, value) for key, value in params.items() if key.startswith(name_suf)
  167. )
  168. for num, (key, value) in enumerate(fnparams):
  169. _, tail = key.split("*", 1)
  170. if tail.endswith("*"):
  171. tail = tail[:-1]
  172. if tail == str(num):
  173. parts.append(value)
  174. else:
  175. break
  176. if not parts:
  177. return None
  178. value = "".join(parts)
  179. if "'" in value:
  180. encoding, _, value = value.split("'", 2)
  181. encoding = encoding or "utf-8"
  182. return unquote(value, encoding, "strict")
  183. return value
  184. class MultipartResponseWrapper:
  185. """Wrapper around the MultipartReader.
  186. It takes care about
  187. underlying connection and close it when it needs in.
  188. """
  189. def __init__(
  190. self,
  191. resp: "ClientResponse",
  192. stream: "MultipartReader",
  193. ) -> None:
  194. self.resp = resp
  195. self.stream = stream
  196. def __aiter__(self) -> "MultipartResponseWrapper":
  197. return self
  198. async def __anext__(
  199. self,
  200. ) -> Union["MultipartReader", "BodyPartReader"]:
  201. part = await self.next()
  202. if part is None:
  203. raise StopAsyncIteration
  204. return part
  205. def at_eof(self) -> bool:
  206. """Returns True when all response data had been read."""
  207. return self.resp.content.at_eof()
  208. async def next(
  209. self,
  210. ) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
  211. """Emits next multipart reader object."""
  212. item = await self.stream.next()
  213. if self.stream.at_eof():
  214. await self.release()
  215. return item
  216. async def release(self) -> None:
  217. """Release the connection gracefully.
  218. All remaining content is read to the void.
  219. """
  220. await self.resp.release()
  221. class BodyPartReader:
  222. """Multipart reader for single body part."""
  223. chunk_size = 8192
  224. def __init__(
  225. self,
  226. boundary: bytes,
  227. headers: "CIMultiDictProxy[str]",
  228. content: StreamReader,
  229. *,
  230. subtype: str = "mixed",
  231. default_charset: Optional[str] = None,
  232. max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE,
  233. ) -> None:
  234. self.headers = headers
  235. self._boundary = boundary
  236. self._boundary_len = len(boundary) + 2 # Boundary + \r\n
  237. self._content = content
  238. self._default_charset = default_charset
  239. self._at_eof = False
  240. self._is_form_data = subtype == "form-data"
  241. # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
  242. length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
  243. self._length = int(length) if length is not None else None
  244. self._read_bytes = 0
  245. self._unread: Deque[bytes] = deque()
  246. self._prev_chunk: Optional[bytes] = None
  247. self._content_eof = 0
  248. self._cache: Dict[str, Any] = {}
  249. self._max_decompress_size = max_decompress_size
  250. def __aiter__(self: Self) -> Self:
  251. return self
  252. async def __anext__(self) -> bytes:
  253. part = await self.next()
  254. if part is None:
  255. raise StopAsyncIteration
  256. return part
  257. async def next(self) -> Optional[bytes]:
  258. item = await self.read()
  259. if not item:
  260. return None
  261. return item
  262. async def read(self, *, decode: bool = False) -> bytes:
  263. """Reads body part data.
  264. decode: Decodes data following by encoding
  265. method from Content-Encoding header. If it missed
  266. data remains untouched
  267. """
  268. if self._at_eof:
  269. return b""
  270. data = bytearray()
  271. while not self._at_eof:
  272. data.extend(await self.read_chunk(self.chunk_size))
  273. if decode:
  274. return await self.decode(data)
  275. return data
  276. async def read_chunk(self, size: int = chunk_size) -> bytes:
  277. """Reads body part content chunk of the specified size.
  278. size: chunk size
  279. """
  280. if self._at_eof:
  281. return b""
  282. if self._length:
  283. chunk = await self._read_chunk_from_length(size)
  284. else:
  285. chunk = await self._read_chunk_from_stream(size)
  286. # For the case of base64 data, we must read a fragment of size with a
  287. # remainder of 0 by dividing by 4 for string without symbols \n or \r
  288. encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
  289. if encoding and encoding.lower() == "base64":
  290. stripped_chunk = b"".join(chunk.split())
  291. remainder = len(stripped_chunk) % 4
  292. while remainder != 0 and not self.at_eof():
  293. over_chunk_size = 4 - remainder
  294. over_chunk = b""
  295. if self._prev_chunk:
  296. over_chunk = self._prev_chunk[:over_chunk_size]
  297. self._prev_chunk = self._prev_chunk[len(over_chunk) :]
  298. if len(over_chunk) != over_chunk_size:
  299. over_chunk += await self._content.read(4 - len(over_chunk))
  300. if not over_chunk:
  301. self._at_eof = True
  302. stripped_chunk += b"".join(over_chunk.split())
  303. chunk += over_chunk
  304. remainder = len(stripped_chunk) % 4
  305. self._read_bytes += len(chunk)
  306. if self._read_bytes == self._length:
  307. self._at_eof = True
  308. if self._at_eof and await self._content.readline() != b"\r\n":
  309. raise ValueError("Reader did not read all the data or it is malformed")
  310. return chunk
  311. async def _read_chunk_from_length(self, size: int) -> bytes:
  312. # Reads body part content chunk of the specified size.
  313. # The body part must has Content-Length header with proper value.
  314. assert self._length is not None, "Content-Length required for chunked read"
  315. chunk_size = min(size, self._length - self._read_bytes)
  316. chunk = await self._content.read(chunk_size)
  317. if self._content.at_eof():
  318. self._at_eof = True
  319. return chunk
  320. async def _read_chunk_from_stream(self, size: int) -> bytes:
  321. # Reads content chunk of body part with unknown length.
  322. # The Content-Length header for body part is not necessary.
  323. assert (
  324. size >= self._boundary_len
  325. ), "Chunk size must be greater or equal than boundary length + 2"
  326. first_chunk = self._prev_chunk is None
  327. if first_chunk:
  328. # We need to re-add the CRLF that got removed from headers parsing.
  329. self._prev_chunk = b"\r\n" + await self._content.read(size)
  330. chunk = b""
  331. # content.read() may return less than size, so we need to loop to ensure
  332. # we have enough data to detect the boundary.
  333. while len(chunk) < self._boundary_len:
  334. chunk += await self._content.read(size)
  335. self._content_eof += int(self._content.at_eof())
  336. if self._content_eof > 2:
  337. raise ValueError("Reading after EOF")
  338. if self._content_eof:
  339. break
  340. if len(chunk) > size:
  341. self._content.unread_data(chunk[size:])
  342. chunk = chunk[:size]
  343. assert self._prev_chunk is not None
  344. window = self._prev_chunk + chunk
  345. sub = b"\r\n" + self._boundary
  346. if first_chunk:
  347. idx = window.find(sub)
  348. else:
  349. idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
  350. if idx >= 0:
  351. # pushing boundary back to content
  352. with warnings.catch_warnings():
  353. warnings.filterwarnings("ignore", category=DeprecationWarning)
  354. self._content.unread_data(window[idx:])
  355. self._prev_chunk = self._prev_chunk[:idx]
  356. chunk = window[len(self._prev_chunk) : idx]
  357. if not chunk:
  358. self._at_eof = True
  359. result = self._prev_chunk[2 if first_chunk else 0 :] # Strip initial CRLF
  360. self._prev_chunk = chunk
  361. return result
  362. async def readline(self) -> bytes:
  363. """Reads body part by line by line."""
  364. if self._at_eof:
  365. return b""
  366. if self._unread:
  367. line = self._unread.popleft()
  368. else:
  369. line = await self._content.readline()
  370. if line.startswith(self._boundary):
  371. # the very last boundary may not come with \r\n,
  372. # so set single rules for everyone
  373. sline = line.rstrip(b"\r\n")
  374. boundary = self._boundary
  375. last_boundary = self._boundary + b"--"
  376. # ensure that we read exactly the boundary, not something alike
  377. if sline == boundary or sline == last_boundary:
  378. self._at_eof = True
  379. self._unread.append(line)
  380. return b""
  381. else:
  382. next_line = await self._content.readline()
  383. if next_line.startswith(self._boundary):
  384. line = line[:-2] # strip CRLF but only once
  385. self._unread.append(next_line)
  386. return line
  387. async def release(self) -> None:
  388. """Like read(), but reads all the data to the void."""
  389. if self._at_eof:
  390. return
  391. while not self._at_eof:
  392. await self.read_chunk(self.chunk_size)
  393. async def text(self, *, encoding: Optional[str] = None) -> str:
  394. """Like read(), but assumes that body part contains text data."""
  395. data = await self.read(decode=True)
  396. # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
  397. # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
  398. encoding = encoding or self.get_charset(default="utf-8")
  399. return data.decode(encoding)
  400. async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
  401. """Like read(), but assumes that body parts contains JSON data."""
  402. data = await self.read(decode=True)
  403. if not data:
  404. return None
  405. encoding = encoding or self.get_charset(default="utf-8")
  406. return cast(Dict[str, Any], json.loads(data.decode(encoding)))
  407. async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
  408. """Like read(), but assumes that body parts contain form urlencoded data."""
  409. data = await self.read(decode=True)
  410. if not data:
  411. return []
  412. if encoding is not None:
  413. real_encoding = encoding
  414. else:
  415. real_encoding = self.get_charset(default="utf-8")
  416. try:
  417. decoded_data = data.rstrip().decode(real_encoding)
  418. except UnicodeDecodeError:
  419. raise ValueError("data cannot be decoded with %s encoding" % real_encoding)
  420. return parse_qsl(
  421. decoded_data,
  422. keep_blank_values=True,
  423. encoding=real_encoding,
  424. )
  425. def at_eof(self) -> bool:
  426. """Returns True if the boundary was reached or False otherwise."""
  427. return self._at_eof
  428. async def decode(self, data: bytes) -> bytes:
  429. """Decodes data.
  430. Decoding is done according the specified Content-Encoding
  431. or Content-Transfer-Encoding headers value.
  432. """
  433. if CONTENT_TRANSFER_ENCODING in self.headers:
  434. data = self._decode_content_transfer(data)
  435. # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
  436. if not self._is_form_data and CONTENT_ENCODING in self.headers:
  437. return await self._decode_content(data)
  438. return data
  439. async def _decode_content(self, data: bytes) -> bytes:
  440. encoding = self.headers.get(CONTENT_ENCODING, "").lower()
  441. if encoding == "identity":
  442. return data
  443. if encoding in {"deflate", "gzip"}:
  444. return await ZLibDecompressor(
  445. encoding=encoding,
  446. suppress_deflate_header=True,
  447. ).decompress(data, max_length=self._max_decompress_size)
  448. raise RuntimeError(f"unknown content encoding: {encoding}")
  449. def _decode_content_transfer(self, data: bytes) -> bytes:
  450. encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
  451. if encoding == "base64":
  452. return base64.b64decode(data)
  453. elif encoding == "quoted-printable":
  454. return binascii.a2b_qp(data)
  455. elif encoding in ("binary", "8bit", "7bit"):
  456. return data
  457. else:
  458. raise RuntimeError(f"unknown content transfer encoding: {encoding}")
  459. def get_charset(self, default: str) -> str:
  460. """Returns charset parameter from Content-Type header or default."""
  461. ctype = self.headers.get(CONTENT_TYPE, "")
  462. mimetype = parse_mimetype(ctype)
  463. return mimetype.parameters.get("charset", self._default_charset or default)
  464. @reify
  465. def name(self) -> Optional[str]:
  466. """Returns name specified in Content-Disposition header.
  467. If the header is missing or malformed, returns None.
  468. """
  469. _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
  470. return content_disposition_filename(params, "name")
  471. @reify
  472. def filename(self) -> Optional[str]:
  473. """Returns filename specified in Content-Disposition header.
  474. Returns None if the header is missing or malformed.
  475. """
  476. _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
  477. return content_disposition_filename(params, "filename")
  478. @payload_type(BodyPartReader, order=Order.try_first)
  479. class BodyPartReaderPayload(Payload):
  480. _value: BodyPartReader
  481. # _autoclose = False (inherited) - Streaming reader that may have resources
  482. def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
  483. super().__init__(value, *args, **kwargs)
  484. params: Dict[str, str] = {}
  485. if value.name is not None:
  486. params["name"] = value.name
  487. if value.filename is not None:
  488. params["filename"] = value.filename
  489. if params:
  490. self.set_content_disposition("attachment", True, **params)
  491. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  492. raise TypeError("Unable to decode.")
  493. async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
  494. """Raises TypeError as body parts should be consumed via write().
  495. This is intentional: BodyPartReader payloads are designed for streaming
  496. large data (potentially gigabytes) and must be consumed only once via
  497. the write() method to avoid memory exhaustion. They cannot be buffered
  498. in memory for reuse.
  499. """
  500. raise TypeError("Unable to read body part as bytes. Use write() to consume.")
  501. async def write(self, writer: AbstractStreamWriter) -> None:
  502. field = self._value
  503. chunk = await field.read_chunk(size=2**16)
  504. while chunk:
  505. await writer.write(await field.decode(chunk))
  506. chunk = await field.read_chunk(size=2**16)
  507. class MultipartReader:
  508. """Multipart body reader."""
  509. #: Response wrapper, used when multipart readers constructs from response.
  510. response_wrapper_cls = MultipartResponseWrapper
  511. #: Multipart reader class, used to handle multipart/* body parts.
  512. #: None points to type(self)
  513. multipart_reader_cls: Optional[Type["MultipartReader"]] = None
  514. #: Body part reader class for non multipart/* content types.
  515. part_reader_cls = BodyPartReader
  516. def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
  517. self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
  518. assert self._mimetype.type == "multipart", "multipart/* content type expected"
  519. if "boundary" not in self._mimetype.parameters:
  520. raise ValueError(
  521. "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
  522. )
  523. self.headers = headers
  524. self._boundary = ("--" + self._get_boundary()).encode()
  525. self._content = content
  526. self._default_charset: Optional[str] = None
  527. self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
  528. self._at_eof = False
  529. self._at_bof = True
  530. self._unread: List[bytes] = []
  531. def __aiter__(self: Self) -> Self:
  532. return self
  533. async def __anext__(
  534. self,
  535. ) -> Optional[Union["MultipartReader", BodyPartReader]]:
  536. part = await self.next()
  537. if part is None:
  538. raise StopAsyncIteration
  539. return part
  540. @classmethod
  541. def from_response(
  542. cls,
  543. response: "ClientResponse",
  544. ) -> MultipartResponseWrapper:
  545. """Constructs reader instance from HTTP response.
  546. :param response: :class:`~aiohttp.client.ClientResponse` instance
  547. """
  548. obj = cls.response_wrapper_cls(
  549. response, cls(response.headers, response.content)
  550. )
  551. return obj
  552. def at_eof(self) -> bool:
  553. """Returns True if the final boundary was reached, false otherwise."""
  554. return self._at_eof
  555. async def next(
  556. self,
  557. ) -> Optional[Union["MultipartReader", BodyPartReader]]:
  558. """Emits the next multipart body part."""
  559. # So, if we're at BOF, we need to skip till the boundary.
  560. if self._at_eof:
  561. return None
  562. await self._maybe_release_last_part()
  563. if self._at_bof:
  564. await self._read_until_first_boundary()
  565. self._at_bof = False
  566. else:
  567. await self._read_boundary()
  568. if self._at_eof: # we just read the last boundary, nothing to do there
  569. return None
  570. part = await self.fetch_next_part()
  571. # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
  572. if (
  573. self._last_part is None
  574. and self._mimetype.subtype == "form-data"
  575. and isinstance(part, BodyPartReader)
  576. ):
  577. _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
  578. if params.get("name") == "_charset_":
  579. # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
  580. # is 19 characters, so 32 should be more than enough for any valid encoding.
  581. charset = await part.read_chunk(32)
  582. if len(charset) > 31:
  583. raise RuntimeError("Invalid default charset")
  584. self._default_charset = charset.strip().decode()
  585. part = await self.fetch_next_part()
  586. self._last_part = part
  587. return self._last_part
  588. async def release(self) -> None:
  589. """Reads all the body parts to the void till the final boundary."""
  590. while not self._at_eof:
  591. item = await self.next()
  592. if item is None:
  593. break
  594. await item.release()
  595. async def fetch_next_part(
  596. self,
  597. ) -> Union["MultipartReader", BodyPartReader]:
  598. """Returns the next body part reader."""
  599. headers = await self._read_headers()
  600. return self._get_part_reader(headers)
  601. def _get_part_reader(
  602. self,
  603. headers: "CIMultiDictProxy[str]",
  604. ) -> Union["MultipartReader", BodyPartReader]:
  605. """Dispatches the response by the `Content-Type` header.
  606. Returns a suitable reader instance.
  607. :param dict headers: Response headers
  608. """
  609. ctype = headers.get(CONTENT_TYPE, "")
  610. mimetype = parse_mimetype(ctype)
  611. if mimetype.type == "multipart":
  612. if self.multipart_reader_cls is None:
  613. return type(self)(headers, self._content)
  614. return self.multipart_reader_cls(headers, self._content)
  615. else:
  616. return self.part_reader_cls(
  617. self._boundary,
  618. headers,
  619. self._content,
  620. subtype=self._mimetype.subtype,
  621. default_charset=self._default_charset,
  622. )
  623. def _get_boundary(self) -> str:
  624. boundary = self._mimetype.parameters["boundary"]
  625. if len(boundary) > 70:
  626. raise ValueError("boundary %r is too long (70 chars max)" % boundary)
  627. return boundary
  628. async def _readline(self) -> bytes:
  629. if self._unread:
  630. return self._unread.pop()
  631. return await self._content.readline()
  632. async def _read_until_first_boundary(self) -> None:
  633. while True:
  634. chunk = await self._readline()
  635. if chunk == b"":
  636. raise ValueError(
  637. "Could not find starting boundary %r" % (self._boundary)
  638. )
  639. chunk = chunk.rstrip()
  640. if chunk == self._boundary:
  641. return
  642. elif chunk == self._boundary + b"--":
  643. self._at_eof = True
  644. return
  645. async def _read_boundary(self) -> None:
  646. chunk = (await self._readline()).rstrip()
  647. if chunk == self._boundary:
  648. pass
  649. elif chunk == self._boundary + b"--":
  650. self._at_eof = True
  651. epilogue = await self._readline()
  652. next_line = await self._readline()
  653. # the epilogue is expected and then either the end of input or the
  654. # parent multipart boundary, if the parent boundary is found then
  655. # it should be marked as unread and handed to the parent for
  656. # processing
  657. if next_line[:2] == b"--":
  658. self._unread.append(next_line)
  659. # otherwise the request is likely missing an epilogue and both
  660. # lines should be passed to the parent for processing
  661. # (this handles the old behavior gracefully)
  662. else:
  663. self._unread.extend([next_line, epilogue])
  664. else:
  665. raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")
  666. async def _read_headers(self) -> "CIMultiDictProxy[str]":
  667. lines = []
  668. while True:
  669. chunk = await self._content.readline()
  670. chunk = chunk.rstrip(b"\r\n")
  671. lines.append(chunk)
  672. if not chunk:
  673. break
  674. parser = HeadersParser()
  675. headers, raw_headers = parser.parse_headers(lines)
  676. return headers
  677. async def _maybe_release_last_part(self) -> None:
  678. """Ensures that the last read body part is read completely."""
  679. if self._last_part is not None:
  680. if not self._last_part.at_eof():
  681. await self._last_part.release()
  682. self._unread.extend(self._last_part._unread)
  683. self._last_part = None
# Type of one entry in MultipartWriter._parts.
# NOTE(review): presumably (payload, content-encoding, transfer-encoding) as
# collected by append_payload — confirm; the tuple construction is not visible here.
_Part = Tuple[Payload, str, str]
  685. class MultipartWriter(Payload):
  686. """Multipart body writer."""
  687. _value: None
  688. # _consumed = False (inherited) - Can be encoded multiple times
  689. _autoclose = True # No file handles, just collects parts in memory
  690. def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
  691. boundary = boundary if boundary is not None else uuid.uuid4().hex
  692. # The underlying Payload API demands a str (utf-8), not bytes,
  693. # so we need to ensure we don't lose anything during conversion.
  694. # As a result, require the boundary to be ASCII only.
  695. # In both situations.
  696. try:
  697. self._boundary = boundary.encode("ascii")
  698. except UnicodeEncodeError:
  699. raise ValueError("boundary should contain ASCII only chars") from None
  700. ctype = f"multipart/{subtype}; boundary={self._boundary_value}"
  701. super().__init__(None, content_type=ctype)
  702. self._parts: List[_Part] = []
  703. self._is_form_data = subtype == "form-data"
  704. def __enter__(self) -> "MultipartWriter":
  705. return self
  706. def __exit__(
  707. self,
  708. exc_type: Optional[Type[BaseException]],
  709. exc_val: Optional[BaseException],
  710. exc_tb: Optional[TracebackType],
  711. ) -> None:
  712. pass
  713. def __iter__(self) -> Iterator[_Part]:
  714. return iter(self._parts)
  715. def __len__(self) -> int:
  716. return len(self._parts)
  717. def __bool__(self) -> bool:
  718. return True
  719. _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
  720. _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")
  721. @property
  722. def _boundary_value(self) -> str:
  723. """Wrap boundary parameter value in quotes, if necessary.
  724. Reads self.boundary and returns a unicode string.
  725. """
  726. # Refer to RFCs 7231, 7230, 5234.
  727. #
  728. # parameter = token "=" ( token / quoted-string )
  729. # token = 1*tchar
  730. # quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE
  731. # qdtext = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
  732. # obs-text = %x80-FF
  733. # quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text )
  734. # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
  735. # / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
  736. # / DIGIT / ALPHA
  737. # ; any VCHAR, except delimiters
  738. # VCHAR = %x21-7E
  739. value = self._boundary
  740. if re.match(self._valid_tchar_regex, value):
  741. return value.decode("ascii") # cannot fail
  742. if re.search(self._invalid_qdtext_char_regex, value):
  743. raise ValueError("boundary value contains invalid characters")
  744. # escape %x5C and %x22
  745. quoted_value_content = value.replace(b"\\", b"\\\\")
  746. quoted_value_content = quoted_value_content.replace(b'"', b'\\"')
  747. return '"' + quoted_value_content.decode("ascii") + '"'
  748. @property
  749. def boundary(self) -> str:
  750. return self._boundary.decode("ascii")
  751. def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload:
  752. if headers is None:
  753. headers = CIMultiDict()
  754. if isinstance(obj, Payload):
  755. obj.headers.update(headers)
  756. return self.append_payload(obj)
  757. else:
  758. try:
  759. payload = get_payload(obj, headers=headers)
  760. except LookupError:
  761. raise TypeError("Cannot create payload from %r" % obj)
  762. else:
  763. return self.append_payload(payload)
  764. def append_payload(self, payload: Payload) -> Payload:
  765. """Adds a new body part to multipart writer."""
  766. encoding: Optional[str] = None
  767. te_encoding: Optional[str] = None
  768. if self._is_form_data:
  769. # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
  770. # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
  771. assert (
  772. not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
  773. & payload.headers.keys()
  774. )
  775. # Set default Content-Disposition in case user doesn't create one
  776. if CONTENT_DISPOSITION not in payload.headers:
  777. name = f"section-{len(self._parts)}"
  778. payload.set_content_disposition("form-data", name=name)
  779. else:
  780. # compression
  781. encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
  782. if encoding and encoding not in ("deflate", "gzip", "identity"):
  783. raise RuntimeError(f"unknown content encoding: {encoding}")
  784. if encoding == "identity":
  785. encoding = None
  786. # te encoding
  787. te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
  788. if te_encoding not in ("", "base64", "quoted-printable", "binary"):
  789. raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
  790. if te_encoding == "binary":
  791. te_encoding = None
  792. # size
  793. size = payload.size
  794. if size is not None and not (encoding or te_encoding):
  795. payload.headers[CONTENT_LENGTH] = str(size)
  796. self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type]
  797. return payload
  798. def append_json(
  799. self, obj: Any, headers: Optional[Mapping[str, str]] = None
  800. ) -> Payload:
  801. """Helper to append JSON part."""
  802. if headers is None:
  803. headers = CIMultiDict()
  804. return self.append_payload(JsonPayload(obj, headers=headers))
  805. def append_form(
  806. self,
  807. obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
  808. headers: Optional[Mapping[str, str]] = None,
  809. ) -> Payload:
  810. """Helper to append form urlencoded part."""
  811. assert isinstance(obj, (Sequence, Mapping))
  812. if headers is None:
  813. headers = CIMultiDict()
  814. if isinstance(obj, Mapping):
  815. obj = list(obj.items())
  816. data = urlencode(obj, doseq=True)
  817. return self.append_payload(
  818. StringPayload(
  819. data, headers=headers, content_type="application/x-www-form-urlencoded"
  820. )
  821. )
  822. @property
  823. def size(self) -> Optional[int]:
  824. """Size of the payload."""
  825. total = 0
  826. for part, encoding, te_encoding in self._parts:
  827. part_size = part.size
  828. if encoding or te_encoding or part_size is None:
  829. return None
  830. total += int(
  831. 2
  832. + len(self._boundary)
  833. + 2
  834. + part_size # b'--'+self._boundary+b'\r\n'
  835. + len(part._binary_headers)
  836. + 2 # b'\r\n'
  837. )
  838. total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n'
  839. return total
  840. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  841. """Return string representation of the multipart data.
  842. WARNING: This method may do blocking I/O if parts contain file payloads.
  843. It should not be called in the event loop. Use as_bytes().decode() instead.
  844. """
  845. return "".join(
  846. "--"
  847. + self.boundary
  848. + "\r\n"
  849. + part._binary_headers.decode(encoding, errors)
  850. + part.decode()
  851. for part, _e, _te in self._parts
  852. )
  853. async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
  854. """Return bytes representation of the multipart data.
  855. This method is async-safe and calls as_bytes on underlying payloads.
  856. """
  857. parts: List[bytes] = []
  858. # Process each part
  859. for part, _e, _te in self._parts:
  860. # Add boundary
  861. parts.append(b"--" + self._boundary + b"\r\n")
  862. # Add headers
  863. parts.append(part._binary_headers)
  864. # Add payload content using as_bytes for async safety
  865. part_bytes = await part.as_bytes(encoding, errors)
  866. parts.append(part_bytes)
  867. # Add trailing CRLF
  868. parts.append(b"\r\n")
  869. # Add closing boundary
  870. parts.append(b"--" + self._boundary + b"--\r\n")
  871. return b"".join(parts)
  872. async def write(
  873. self, writer: AbstractStreamWriter, close_boundary: bool = True
  874. ) -> None:
  875. """Write body."""
  876. for part, encoding, te_encoding in self._parts:
  877. if self._is_form_data:
  878. # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
  879. assert CONTENT_DISPOSITION in part.headers
  880. assert "name=" in part.headers[CONTENT_DISPOSITION]
  881. await writer.write(b"--" + self._boundary + b"\r\n")
  882. await writer.write(part._binary_headers)
  883. if encoding or te_encoding:
  884. w = MultipartPayloadWriter(writer)
  885. if encoding:
  886. w.enable_compression(encoding)
  887. if te_encoding:
  888. w.enable_encoding(te_encoding)
  889. await part.write(w) # type: ignore[arg-type]
  890. await w.write_eof()
  891. else:
  892. await part.write(writer)
  893. await writer.write(b"\r\n")
  894. if close_boundary:
  895. await writer.write(b"--" + self._boundary + b"--\r\n")
  896. async def close(self) -> None:
  897. """
  898. Close all part payloads that need explicit closing.
  899. IMPORTANT: This method must not await anything that might not finish
  900. immediately, as it may be called during cleanup/cancellation. Schedule
  901. any long-running operations without awaiting them.
  902. """
  903. if self._consumed:
  904. return
  905. self._consumed = True
  906. # Close all parts that need explicit closing
  907. # We catch and log exceptions to ensure all parts get a chance to close
  908. # we do not use asyncio.gather() here because we are not allowed
  909. # to suspend given we may be called during cleanup
  910. for idx, (part, _, _) in enumerate(self._parts):
  911. if not part.autoclose and not part.consumed:
  912. try:
  913. await part.close()
  914. except Exception as exc:
  915. internal_logger.error(
  916. "Failed to close multipart part %d: %s", idx, exc, exc_info=True
  917. )
  918. class MultipartPayloadWriter:
  919. def __init__(self, writer: AbstractStreamWriter) -> None:
  920. self._writer = writer
  921. self._encoding: Optional[str] = None
  922. self._compress: Optional[ZLibCompressor] = None
  923. self._encoding_buffer: Optional[bytearray] = None
  924. def enable_encoding(self, encoding: str) -> None:
  925. if encoding == "base64":
  926. self._encoding = encoding
  927. self._encoding_buffer = bytearray()
  928. elif encoding == "quoted-printable":
  929. self._encoding = "quoted-printable"
  930. def enable_compression(
  931. self, encoding: str = "deflate", strategy: Optional[int] = None
  932. ) -> None:
  933. self._compress = ZLibCompressor(
  934. encoding=encoding,
  935. suppress_deflate_header=True,
  936. strategy=strategy,
  937. )
  938. async def write_eof(self) -> None:
  939. if self._compress is not None:
  940. chunk = self._compress.flush()
  941. if chunk:
  942. self._compress = None
  943. await self.write(chunk)
  944. if self._encoding == "base64":
  945. if self._encoding_buffer:
  946. await self._writer.write(base64.b64encode(self._encoding_buffer))
  947. async def write(self, chunk: bytes) -> None:
  948. if self._compress is not None:
  949. if chunk:
  950. chunk = await self._compress.compress(chunk)
  951. if not chunk:
  952. return
  953. if self._encoding == "base64":
  954. buf = self._encoding_buffer
  955. assert buf is not None
  956. buf.extend(chunk)
  957. if buf:
  958. div, mod = divmod(len(buf), 3)
  959. enc_chunk, self._encoding_buffer = (buf[: div * 3], buf[div * 3 :])
  960. if enc_chunk:
  961. b64chunk = base64.b64encode(enc_chunk)
  962. await self._writer.write(b64chunk)
  963. elif self._encoding == "quoted-printable":
  964. await self._writer.write(binascii.b2a_qp(chunk))
  965. else:
  966. await self._writer.write(chunk)