cookiejar.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. import asyncio
  2. import calendar
  3. import contextlib
  4. import datetime
  5. import heapq
  6. import itertools
  7. import os # noqa
  8. import pathlib
  9. import pickle
  10. import re
  11. import time
  12. import warnings
  13. from collections import defaultdict
  14. from collections.abc import Mapping
  15. from http.cookies import BaseCookie, Morsel, SimpleCookie
  16. from typing import (
  17. DefaultDict,
  18. Dict,
  19. Iterable,
  20. Iterator,
  21. List,
  22. Optional,
  23. Set,
  24. Tuple,
  25. Union,
  26. )
  27. from yarl import URL
  28. from ._cookie_helpers import preserve_morsel_with_coded_value
  29. from .abc import AbstractCookieJar, ClearCookiePredicate
  30. from .helpers import is_ip_address
  31. from .typedefs import LooseCookies, PathLike, StrOrURL
  32. __all__ = ("CookieJar", "DummyCookieJar")
  33. CookieItem = Union[str, "Morsel[str]"]
  34. # We cache these string methods here as their use is in performance critical code.
  35. _FORMAT_PATH = "{}/{}".format
  36. _FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
  37. # The minimum number of scheduled cookie expirations before we start cleaning up
  38. # the expiration heap. This is a performance optimization to avoid cleaning up the
  39. # heap too often when there are only a few scheduled expirations.
  40. _MIN_SCHEDULED_COOKIE_EXPIRATION = 100
  41. _SIMPLE_COOKIE = SimpleCookie()
  42. class CookieJar(AbstractCookieJar):
  43. """Implements cookie storage adhering to RFC 6265."""
  44. DATE_TOKENS_RE = re.compile(
  45. r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
  46. r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
  47. )
  48. DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
  49. DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
  50. DATE_MONTH_RE = re.compile(
  51. "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
  52. re.I,
  53. )
  54. DATE_YEAR_RE = re.compile(r"(\d{2,4})")
  55. # calendar.timegm() fails for timestamps after datetime.datetime.max
  56. # Minus one as a loss of precision occurs when timestamp() is called.
  57. MAX_TIME = (
  58. int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
  59. )
  60. try:
  61. calendar.timegm(time.gmtime(MAX_TIME))
  62. except (OSError, ValueError):
  63. # Hit the maximum representable time on Windows
  64. # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
  65. # Throws ValueError on PyPy 3.9, OSError elsewhere
  66. MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
  67. except OverflowError:
  68. # #4515: datetime.max may not be representable on 32-bit platforms
  69. MAX_TIME = 2**31 - 1
  70. # Avoid minuses in the future, 3x faster
  71. SUB_MAX_TIME = MAX_TIME - 1
  72. def __init__(
  73. self,
  74. *,
  75. unsafe: bool = False,
  76. quote_cookie: bool = True,
  77. treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
  78. loop: Optional[asyncio.AbstractEventLoop] = None,
  79. ) -> None:
  80. super().__init__(loop=loop)
  81. self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
  82. SimpleCookie
  83. )
  84. self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
  85. defaultdict(dict)
  86. )
  87. self._host_only_cookies: Set[Tuple[str, str]] = set()
  88. self._unsafe = unsafe
  89. self._quote_cookie = quote_cookie
  90. if treat_as_secure_origin is None:
  91. treat_as_secure_origin = []
  92. elif isinstance(treat_as_secure_origin, URL):
  93. treat_as_secure_origin = [treat_as_secure_origin.origin()]
  94. elif isinstance(treat_as_secure_origin, str):
  95. treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
  96. else:
  97. treat_as_secure_origin = [
  98. URL(url).origin() if isinstance(url, str) else url.origin()
  99. for url in treat_as_secure_origin
  100. ]
  101. self._treat_as_secure_origin = treat_as_secure_origin
  102. self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
  103. self._expirations: Dict[Tuple[str, str, str], float] = {}
  104. @property
  105. def quote_cookie(self) -> bool:
  106. return self._quote_cookie
  107. def save(self, file_path: PathLike) -> None:
  108. file_path = pathlib.Path(file_path)
  109. with file_path.open(mode="wb") as f:
  110. pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
  111. def load(self, file_path: PathLike) -> None:
  112. file_path = pathlib.Path(file_path)
  113. with file_path.open(mode="rb") as f:
  114. self._cookies = pickle.load(f)
  115. def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
  116. if predicate is None:
  117. self._expire_heap.clear()
  118. self._cookies.clear()
  119. self._morsel_cache.clear()
  120. self._host_only_cookies.clear()
  121. self._expirations.clear()
  122. return
  123. now = time.time()
  124. to_del = [
  125. key
  126. for (domain, path), cookie in self._cookies.items()
  127. for name, morsel in cookie.items()
  128. if (
  129. (key := (domain, path, name)) in self._expirations
  130. and self._expirations[key] <= now
  131. )
  132. or predicate(morsel)
  133. ]
  134. if to_del:
  135. self._delete_cookies(to_del)
  136. def clear_domain(self, domain: str) -> None:
  137. self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
  138. def __iter__(self) -> "Iterator[Morsel[str]]":
  139. self._do_expiration()
  140. for val in self._cookies.values():
  141. yield from val.values()
  142. def __len__(self) -> int:
  143. """Return number of cookies.
  144. This function does not iterate self to avoid unnecessary expiration
  145. checks.
  146. """
  147. return sum(len(cookie.values()) for cookie in self._cookies.values())
  148. def _do_expiration(self) -> None:
  149. """Remove expired cookies."""
  150. if not (expire_heap_len := len(self._expire_heap)):
  151. return
  152. # If the expiration heap grows larger than the number expirations
  153. # times two, we clean it up to avoid keeping expired entries in
  154. # the heap and consuming memory. We guard this with a minimum
  155. # threshold to avoid cleaning up the heap too often when there are
  156. # only a few scheduled expirations.
  157. if (
  158. expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
  159. and expire_heap_len > len(self._expirations) * 2
  160. ):
  161. # Remove any expired entries from the expiration heap
  162. # that do not match the expiration time in the expirations
  163. # as it means the cookie has been re-added to the heap
  164. # with a different expiration time.
  165. self._expire_heap = [
  166. entry
  167. for entry in self._expire_heap
  168. if self._expirations.get(entry[1]) == entry[0]
  169. ]
  170. heapq.heapify(self._expire_heap)
  171. now = time.time()
  172. to_del: List[Tuple[str, str, str]] = []
  173. # Find any expired cookies and add them to the to-delete list
  174. while self._expire_heap:
  175. when, cookie_key = self._expire_heap[0]
  176. if when > now:
  177. break
  178. heapq.heappop(self._expire_heap)
  179. # Check if the cookie hasn't been re-added to the heap
  180. # with a different expiration time as it will be removed
  181. # later when it reaches the top of the heap and its
  182. # expiration time is met.
  183. if self._expirations.get(cookie_key) == when:
  184. to_del.append(cookie_key)
  185. if to_del:
  186. self._delete_cookies(to_del)
  187. def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
  188. for domain, path, name in to_del:
  189. self._host_only_cookies.discard((domain, name))
  190. self._cookies[(domain, path)].pop(name, None)
  191. self._morsel_cache[(domain, path)].pop(name, None)
  192. self._expirations.pop((domain, path, name), None)
  193. def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
  194. cookie_key = (domain, path, name)
  195. if self._expirations.get(cookie_key) == when:
  196. # Avoid adding duplicates to the heap
  197. return
  198. heapq.heappush(self._expire_heap, (when, cookie_key))
  199. self._expirations[cookie_key] = when
  200. def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
  201. """Update cookies."""
  202. hostname = response_url.raw_host
  203. if not self._unsafe and is_ip_address(hostname):
  204. # Don't accept cookies from IPs
  205. return
  206. if isinstance(cookies, Mapping):
  207. cookies = cookies.items()
  208. for name, cookie in cookies:
  209. if not isinstance(cookie, Morsel):
  210. tmp = SimpleCookie()
  211. tmp[name] = cookie # type: ignore[assignment]
  212. cookie = tmp[name]
  213. domain = cookie["domain"]
  214. # ignore domains with trailing dots
  215. if domain and domain[-1] == ".":
  216. domain = ""
  217. del cookie["domain"]
  218. if not domain and hostname is not None:
  219. # Set the cookie's domain to the response hostname
  220. # and set its host-only-flag
  221. self._host_only_cookies.add((hostname, name))
  222. domain = cookie["domain"] = hostname
  223. if domain and domain[0] == ".":
  224. # Remove leading dot
  225. domain = domain[1:]
  226. cookie["domain"] = domain
  227. if hostname and not self._is_domain_match(domain, hostname):
  228. # Setting cookies for different domains is not allowed
  229. continue
  230. path = cookie["path"]
  231. if not path or path[0] != "/":
  232. # Set the cookie's path to the response path
  233. path = response_url.path
  234. if not path.startswith("/"):
  235. path = "/"
  236. else:
  237. # Cut everything from the last slash to the end
  238. path = "/" + path[1 : path.rfind("/")]
  239. cookie["path"] = path
  240. path = path.rstrip("/")
  241. if max_age := cookie["max-age"]:
  242. try:
  243. delta_seconds = int(max_age)
  244. max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
  245. self._expire_cookie(max_age_expiration, domain, path, name)
  246. except ValueError:
  247. cookie["max-age"] = ""
  248. elif expires := cookie["expires"]:
  249. if expire_time := self._parse_date(expires):
  250. self._expire_cookie(expire_time, domain, path, name)
  251. else:
  252. cookie["expires"] = ""
  253. key = (domain, path)
  254. if self._cookies[key].get(name) != cookie:
  255. # Don't blow away the cache if the same
  256. # cookie gets set again
  257. self._cookies[key][name] = cookie
  258. self._morsel_cache[key].pop(name, None)
  259. self._do_expiration()
  260. def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
  261. """Returns this jar's cookies filtered by their attributes."""
  262. # We always use BaseCookie now since all
  263. # cookies set on on filtered are fully constructed
  264. # Morsels, not just names and values.
  265. filtered: BaseCookie[str] = BaseCookie()
  266. if not self._cookies:
  267. # Skip do_expiration() if there are no cookies.
  268. return filtered
  269. self._do_expiration()
  270. if not self._cookies:
  271. # Skip rest of function if no non-expired cookies.
  272. return filtered
  273. if type(request_url) is not URL:
  274. warnings.warn(
  275. "filter_cookies expects yarl.URL instances only,"
  276. f"and will stop working in 4.x, got {type(request_url)}",
  277. DeprecationWarning,
  278. stacklevel=2,
  279. )
  280. request_url = URL(request_url)
  281. hostname = request_url.raw_host or ""
  282. is_not_secure = request_url.scheme not in ("https", "wss")
  283. if is_not_secure and self._treat_as_secure_origin:
  284. request_origin = URL()
  285. with contextlib.suppress(ValueError):
  286. request_origin = request_url.origin()
  287. is_not_secure = request_origin not in self._treat_as_secure_origin
  288. # Send shared cookie
  289. key = ("", "")
  290. for c in self._cookies[key].values():
  291. # Check cache first
  292. if c.key in self._morsel_cache[key]:
  293. filtered[c.key] = self._morsel_cache[key][c.key]
  294. continue
  295. # Build and cache the morsel
  296. mrsl_val = self._build_morsel(c)
  297. self._morsel_cache[key][c.key] = mrsl_val
  298. filtered[c.key] = mrsl_val
  299. if is_ip_address(hostname):
  300. if not self._unsafe:
  301. return filtered
  302. domains: Iterable[str] = (hostname,)
  303. else:
  304. # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
  305. domains = itertools.accumulate(
  306. reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
  307. )
  308. # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
  309. paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
  310. # Create every combination of (domain, path) pairs.
  311. pairs = itertools.product(domains, paths)
  312. path_len = len(request_url.path)
  313. # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
  314. for p in pairs:
  315. if p not in self._cookies:
  316. continue
  317. for name, cookie in self._cookies[p].items():
  318. domain = cookie["domain"]
  319. if (domain, name) in self._host_only_cookies and domain != hostname:
  320. continue
  321. # Skip edge case when the cookie has a trailing slash but request doesn't.
  322. if len(cookie["path"]) > path_len:
  323. continue
  324. if is_not_secure and cookie["secure"]:
  325. continue
  326. # We already built the Morsel so reuse it here
  327. if name in self._morsel_cache[p]:
  328. filtered[name] = self._morsel_cache[p][name]
  329. continue
  330. # Build and cache the morsel
  331. mrsl_val = self._build_morsel(cookie)
  332. self._morsel_cache[p][name] = mrsl_val
  333. filtered[name] = mrsl_val
  334. return filtered
  335. def _build_morsel(self, cookie: Morsel[str]) -> Morsel[str]:
  336. """Build a morsel for sending, respecting quote_cookie setting."""
  337. if self._quote_cookie and cookie.coded_value and cookie.coded_value[0] == '"':
  338. return preserve_morsel_with_coded_value(cookie)
  339. morsel: Morsel[str] = Morsel()
  340. if self._quote_cookie:
  341. value, coded_value = _SIMPLE_COOKIE.value_encode(cookie.value)
  342. else:
  343. coded_value = value = cookie.value
  344. # We use __setstate__ instead of the public set() API because it allows us to
  345. # bypass validation and set already validated state. This is more stable than
  346. # setting protected attributes directly and unlikely to change since it would
  347. # break pickling.
  348. morsel.__setstate__({"key": cookie.key, "value": value, "coded_value": coded_value}) # type: ignore[attr-defined]
  349. return morsel
  350. @staticmethod
  351. def _is_domain_match(domain: str, hostname: str) -> bool:
  352. """Implements domain matching adhering to RFC 6265."""
  353. if hostname == domain:
  354. return True
  355. if not hostname.endswith(domain):
  356. return False
  357. non_matching = hostname[: -len(domain)]
  358. if not non_matching.endswith("."):
  359. return False
  360. return not is_ip_address(hostname)
  361. @classmethod
  362. def _parse_date(cls, date_str: str) -> Optional[int]:
  363. """Implements date string parsing adhering to RFC 6265."""
  364. if not date_str:
  365. return None
  366. found_time = False
  367. found_day = False
  368. found_month = False
  369. found_year = False
  370. hour = minute = second = 0
  371. day = 0
  372. month = 0
  373. year = 0
  374. for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
  375. token = token_match.group("token")
  376. if not found_time:
  377. time_match = cls.DATE_HMS_TIME_RE.match(token)
  378. if time_match:
  379. found_time = True
  380. hour, minute, second = (int(s) for s in time_match.groups())
  381. continue
  382. if not found_day:
  383. day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
  384. if day_match:
  385. found_day = True
  386. day = int(day_match.group())
  387. continue
  388. if not found_month:
  389. month_match = cls.DATE_MONTH_RE.match(token)
  390. if month_match:
  391. found_month = True
  392. assert month_match.lastindex is not None
  393. month = month_match.lastindex
  394. continue
  395. if not found_year:
  396. year_match = cls.DATE_YEAR_RE.match(token)
  397. if year_match:
  398. found_year = True
  399. year = int(year_match.group())
  400. if 70 <= year <= 99:
  401. year += 1900
  402. elif 0 <= year <= 69:
  403. year += 2000
  404. if False in (found_day, found_month, found_year, found_time):
  405. return None
  406. if not 1 <= day <= 31:
  407. return None
  408. if year < 1601 or hour > 23 or minute > 59 or second > 59:
  409. return None
  410. return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
  411. class DummyCookieJar(AbstractCookieJar):
  412. """Implements a dummy cookie storage.
  413. It can be used with the ClientSession when no cookie processing is needed.
  414. """
  415. def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  416. super().__init__(loop=loop)
  417. def __iter__(self) -> "Iterator[Morsel[str]]":
  418. while False:
  419. yield None
  420. def __len__(self) -> int:
  421. return 0
  422. @property
  423. def quote_cookie(self) -> bool:
  424. return True
  425. def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
  426. pass
  427. def clear_domain(self, domain: str) -> None:
  428. pass
  429. def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
  430. pass
  431. def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
  432. return SimpleCookie()