compression_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. import asyncio
  2. import sys
  3. import zlib
  4. from abc import ABC, abstractmethod
  5. from concurrent.futures import Executor
  6. from typing import Any, Final, Optional, Protocol, TypedDict, cast
  7. if sys.version_info >= (3, 12):
  8. from collections.abc import Buffer
  9. else:
  10. from typing import Union
  11. Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
  12. try:
  13. try:
  14. import brotlicffi as brotli
  15. except ImportError:
  16. import brotli
  17. HAS_BROTLI = True
  18. except ImportError: # pragma: no cover
  19. HAS_BROTLI = False
  20. try:
  21. if sys.version_info >= (3, 14):
  22. from compression.zstd import ZstdDecompressor # noqa: I900
  23. else: # TODO(PY314): Remove mentions of backports.zstd across codebase
  24. from backports.zstd import ZstdDecompressor
  25. HAS_ZSTD = True
  26. except ImportError:
  27. HAS_ZSTD = False
  28. MAX_SYNC_CHUNK_SIZE = 4096
  29. DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB
  30. # Unlimited decompression constants - different libraries use different conventions
  31. ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited
  32. ZSTD_MAX_LENGTH_UNLIMITED = -1 # zstd uses -1 to mean unlimited
class ZLibCompressObjProtocol(Protocol):
    """Structural type of a zlib-style compression object, as returned by
    a backend's ``compressobj()`` (e.g. :func:`zlib.compressobj`)."""

    def compress(self, data: Buffer) -> bytes: ...
    def flush(self, mode: int = ..., /) -> bytes: ...
class ZLibDecompressObjProtocol(Protocol):
    """Structural type of a zlib-style decompression object, as returned by
    a backend's ``decompressobj()`` (e.g. :func:`zlib.decompressobj`)."""

    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...

    @property
    def eof(self) -> bool: ...
class ZLibBackendProtocol(Protocol):
    """Structural type of a zlib-compatible backend module.

    The stdlib ``zlib`` module satisfies this protocol; any drop-in
    replacement exposing the same constants and callables may be installed
    via ``set_zlib_backend()``.
    """

    # Flush-mode and tuning constants mirrored from the ``zlib`` module API.
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Optional[Buffer] = ...,
    ) -> ZLibCompressObjProtocol: ...
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...
    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...
class CompressObjArgs(TypedDict, total=False):
    """Optional keyword arguments forwarded to the backend's ``compressobj()``.

    ``total=False``: each key is included only when the caller supplied it.
    """

    wbits: int
    strategy: int
    level: int
  69. class ZLibBackendWrapper:
  70. def __init__(self, _zlib_backend: ZLibBackendProtocol):
  71. self._zlib_backend: ZLibBackendProtocol = _zlib_backend
  72. @property
  73. def name(self) -> str:
  74. return getattr(self._zlib_backend, "__name__", "undefined")
  75. @property
  76. def MAX_WBITS(self) -> int:
  77. return self._zlib_backend.MAX_WBITS
  78. @property
  79. def Z_FULL_FLUSH(self) -> int:
  80. return self._zlib_backend.Z_FULL_FLUSH
  81. @property
  82. def Z_SYNC_FLUSH(self) -> int:
  83. return self._zlib_backend.Z_SYNC_FLUSH
  84. @property
  85. def Z_BEST_SPEED(self) -> int:
  86. return self._zlib_backend.Z_BEST_SPEED
  87. @property
  88. def Z_FINISH(self) -> int:
  89. return self._zlib_backend.Z_FINISH
  90. def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
  91. return self._zlib_backend.compressobj(*args, **kwargs)
  92. def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
  93. return self._zlib_backend.decompressobj(*args, **kwargs)
  94. def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
  95. return self._zlib_backend.compress(data, *args, **kwargs)
  96. def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
  97. return self._zlib_backend.decompress(data, *args, **kwargs)
  98. # Everything not explicitly listed in the Protocol we just pass through
  99. def __getattr__(self, attrname: str) -> Any:
  100. return getattr(self._zlib_backend, attrname)
# Module-wide default backend: the stdlib ``zlib`` wrapped so its target can
# be swapped at runtime through ``set_zlib_backend()``.
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)


def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Swap the zlib backend used by newly created (de)compressors.

    Compressor/decompressor instances snapshot the backend at construction
    time, so instances created before this call keep their old backend.
    """
    ZLibBackend._zlib_backend = new_zlib_backend
  104. def encoding_to_mode(
  105. encoding: Optional[str] = None,
  106. suppress_deflate_header: bool = False,
  107. ) -> int:
  108. if encoding == "gzip":
  109. return 16 + ZLibBackend.MAX_WBITS
  110. return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS
  111. class DecompressionBaseHandler(ABC):
  112. def __init__(
  113. self,
  114. executor: Optional[Executor] = None,
  115. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  116. ):
  117. """Base class for decompression handlers."""
  118. self._executor = executor
  119. self._max_sync_chunk_size = max_sync_chunk_size
  120. @abstractmethod
  121. def decompress_sync(
  122. self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
  123. ) -> bytes:
  124. """Decompress the given data."""
  125. async def decompress(
  126. self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
  127. ) -> bytes:
  128. """Decompress the given data."""
  129. if (
  130. self._max_sync_chunk_size is not None
  131. and len(data) > self._max_sync_chunk_size
  132. ):
  133. return await asyncio.get_event_loop().run_in_executor(
  134. self._executor, self.decompress_sync, data, max_length
  135. )
  136. return self.decompress_sync(data, max_length)
  137. class ZLibCompressor:
  138. def __init__(
  139. self,
  140. encoding: Optional[str] = None,
  141. suppress_deflate_header: bool = False,
  142. level: Optional[int] = None,
  143. wbits: Optional[int] = None,
  144. strategy: Optional[int] = None,
  145. executor: Optional[Executor] = None,
  146. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  147. ):
  148. self._executor = executor
  149. self._max_sync_chunk_size = max_sync_chunk_size
  150. self._mode = (
  151. encoding_to_mode(encoding, suppress_deflate_header)
  152. if wbits is None
  153. else wbits
  154. )
  155. self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
  156. kwargs: CompressObjArgs = {}
  157. kwargs["wbits"] = self._mode
  158. if strategy is not None:
  159. kwargs["strategy"] = strategy
  160. if level is not None:
  161. kwargs["level"] = level
  162. self._compressor = self._zlib_backend.compressobj(**kwargs)
  163. def compress_sync(self, data: bytes) -> bytes:
  164. return self._compressor.compress(data)
  165. async def compress(self, data: bytes) -> bytes:
  166. """Compress the data and returned the compressed bytes.
  167. Note that flush() must be called after the last call to compress()
  168. If the data size is large than the max_sync_chunk_size, the compression
  169. will be done in the executor. Otherwise, the compression will be done
  170. in the event loop.
  171. **WARNING: This method is NOT cancellation-safe when used with flush().**
  172. If this operation is cancelled, the compressor state may be corrupted.
  173. The connection MUST be closed after cancellation to avoid data corruption
  174. in subsequent compress operations.
  175. For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
  176. compress() + flush() + send operations in a shield and lock to ensure atomicity.
  177. """
  178. # For large payloads, offload compression to executor to avoid blocking event loop
  179. should_use_executor = (
  180. self._max_sync_chunk_size is not None
  181. and len(data) > self._max_sync_chunk_size
  182. )
  183. if should_use_executor:
  184. return await asyncio.get_running_loop().run_in_executor(
  185. self._executor, self._compressor.compress, data
  186. )
  187. return self.compress_sync(data)
  188. def flush(self, mode: Optional[int] = None) -> bytes:
  189. """Flush the compressor synchronously.
  190. **WARNING: This method is NOT cancellation-safe when called after compress().**
  191. The flush() operation accesses shared compressor state. If compress() was
  192. cancelled, calling flush() may result in corrupted data. The connection MUST
  193. be closed after compress() cancellation.
  194. For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
  195. compress() + flush() + send operations in a shield and lock to ensure atomicity.
  196. """
  197. return self._compressor.flush(
  198. mode if mode is not None else self._zlib_backend.Z_FINISH
  199. )
  200. class ZLibDecompressor(DecompressionBaseHandler):
  201. def __init__(
  202. self,
  203. encoding: Optional[str] = None,
  204. suppress_deflate_header: bool = False,
  205. executor: Optional[Executor] = None,
  206. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  207. ):
  208. super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
  209. self._mode = encoding_to_mode(encoding, suppress_deflate_header)
  210. self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
  211. self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
  212. def decompress_sync(
  213. self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
  214. ) -> bytes:
  215. return self._decompressor.decompress(data, max_length)
  216. def flush(self, length: int = 0) -> bytes:
  217. return (
  218. self._decompressor.flush(length)
  219. if length > 0
  220. else self._decompressor.flush()
  221. )
  222. @property
  223. def eof(self) -> bool:
  224. return self._decompressor.eof
  225. class BrotliDecompressor(DecompressionBaseHandler):
  226. # Supports both 'brotlipy' and 'Brotli' packages
  227. # since they share an import name. The top branches
  228. # are for 'brotlipy' and bottom branches for 'Brotli'
  229. def __init__(
  230. self,
  231. executor: Optional[Executor] = None,
  232. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  233. ) -> None:
  234. """Decompress data using the Brotli library."""
  235. if not HAS_BROTLI:
  236. raise RuntimeError(
  237. "The brotli decompression is not available. "
  238. "Please install `Brotli` module"
  239. )
  240. self._obj = brotli.Decompressor()
  241. super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
  242. def decompress_sync(
  243. self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
  244. ) -> bytes:
  245. """Decompress the given data."""
  246. if hasattr(self._obj, "decompress"):
  247. return cast(bytes, self._obj.decompress(data, max_length))
  248. return cast(bytes, self._obj.process(data, max_length))
  249. def flush(self) -> bytes:
  250. """Flush the decompressor."""
  251. if hasattr(self._obj, "flush"):
  252. return cast(bytes, self._obj.flush())
  253. return b""
  254. class ZSTDDecompressor(DecompressionBaseHandler):
  255. def __init__(
  256. self,
  257. executor: Optional[Executor] = None,
  258. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  259. ) -> None:
  260. if not HAS_ZSTD:
  261. raise RuntimeError(
  262. "The zstd decompression is not available. "
  263. "Please install `backports.zstd` module"
  264. )
  265. self._obj = ZstdDecompressor()
  266. super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
  267. def decompress_sync(
  268. self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
  269. ) -> bytes:
  270. # zstd uses -1 for unlimited, while zlib uses 0 for unlimited
  271. # Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited)
  272. zstd_max_length = (
  273. ZSTD_MAX_LENGTH_UNLIMITED
  274. if max_length == ZLIB_MAX_LENGTH_UNLIMITED
  275. else max_length
  276. )
  277. return self._obj.decompress(data, zstd_max_length)
  278. def flush(self) -> bytes:
  279. return b""