io_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

import contextlib
import os
import queue
import re
import time
from threading import Condition, get_ident, Lock, Thread

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from sam3.logger import get_logger
from tqdm import tqdm

logger = get_logger(__name__)

IS_MAIN_PROCESS = os.getenv("IS_MAIN_PROCESS", "1") == "1"
RANK = int(os.getenv("RANK", "0"))

IMAGE_EXTS = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]
VIDEO_EXTS = [".mp4", ".mov", ".avi", ".mkv", ".webm"]


def load_resource_as_video_frames(
    resource_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
    async_loading_frames=False,
    video_loader_type="cv2",
):
    """
    Load video frames from either a video or an image (as a single-frame video).
    Alternatively, if the input is a list of PIL images, convert them into the same
    tensor format.
    """
    if isinstance(resource_path, list):
        img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
        img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
        assert all(isinstance(img_pil, Image.Image) for img_pil in resource_path)
        assert len(resource_path) > 0, "expecting at least one image in the list"
        # PIL's Image.size is (width, height), so unpack it accordingly
        orig_width, orig_height = resource_path[0].size
        images = []
        for img_pil in resource_path:
            img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
            assert img_np.dtype == np.uint8, "np.uint8 is expected for JPEG images"
            img_np = img_np / 255.0
            img = torch.from_numpy(img_np).permute(2, 0, 1)
            # float16 precision should be sufficient for image tensor storage
            img = img.to(dtype=torch.float16)
            # normalize by mean and std
            img -= img_mean
            img /= img_std
            images.append(img)
        images = torch.stack(images)
        if not offload_video_to_cpu:
            images = images.cuda()
        return images, orig_height, orig_width

    is_image = (
        isinstance(resource_path, str)
        and os.path.splitext(resource_path)[-1].lower() in IMAGE_EXTS
    )
    if is_image:
        return load_image_as_single_frame_video(
            image_path=resource_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
        )
    else:
        return load_video_frames(
            video_path=resource_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            async_loading_frames=async_loading_frames,
            video_loader_type=video_loader_type,
        )
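
# A minimal usage sketch (illustrative only; "video.mp4" and the sizes below are
# made-up example values):
#
#   frames, orig_h, orig_w = load_resource_as_video_frames(
#       resource_path="video.mp4",
#       image_size=1024,
#       offload_video_to_cpu=True,  # keep frames on CPU so no GPU is required
#   )
#   # frames: (T, 3, image_size, image_size) float16 tensor, normalized by mean/std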


def load_image_as_single_frame_video(
    image_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
):
    """Load an image as a single-frame video."""
    images, image_height, image_width = _load_img_as_tensor(image_path, image_size)
    images = images.unsqueeze(0).half()
    img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
    if not offload_video_to_cpu:
        images = images.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()
    # normalize by mean and std
    images -= img_mean
    images /= img_std
    return images, image_height, image_width


def load_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
    async_loading_frames=False,
    video_loader_type="cv2",
):
    """
    Load the video frames from video_path. The frames are resized to image_size as
    expected by the model and are loaded to GPU if offload_video_to_cpu=False. This
    is used by the demo.
    """
    assert isinstance(video_path, str)
    if video_path.startswith("<load-dummy-video"):
        # Check for the pattern <load-dummy-video-N> where N is an integer
        match = re.match(r"<load-dummy-video-(\d+)>", video_path)
        num_frames = int(match.group(1)) if match else 60
        return load_dummy_video(image_size, offload_video_to_cpu, num_frames=num_frames)
    elif os.path.isdir(video_path):
        return load_video_frames_from_image_folder(
            image_folder=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            async_loading_frames=async_loading_frames,
        )
    elif os.path.splitext(video_path)[-1].lower() in VIDEO_EXTS:
        return load_video_frames_from_video_file(
            video_path=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            async_loading_frames=async_loading_frames,
            video_loader_type=video_loader_type,
        )
    else:
        raise NotImplementedError("Only video files and image folders are supported")
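
# A minimal usage sketch of the "<load-dummy-video-N>" pattern handled above
# (illustrative only; the argument values are arbitrary examples):
#
#   frames, h, w = load_video_frames(
#       "<load-dummy-video-8>",      # 8 random frames, no file on disk needed
#       image_size=512,
#       offload_video_to_cpu=True,   # keep on CPU so no GPU is required
#   )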


def load_video_frames_from_image_folder(
    image_folder,
    image_size,
    offload_video_to_cpu,
    img_mean,
    img_std,
    async_loading_frames,
):
    """
    Load the video frames from a directory of image files ("<frame_index>.<img_ext>" format)
    """
    frame_names = [
        p
        for p in os.listdir(image_folder)
        if os.path.splitext(p)[-1].lower() in IMAGE_EXTS
    ]
    try:
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    except ValueError:
        # fall back to lexicographic sort if the format is not "<frame_index>.<img_ext>"
        logger.warning(
            f'frame names are not in "<frame_index>.<img_ext>" format: {frame_names[:5]=}, '
            f"falling back to lexicographic sort."
        )
        frame_names.sort()
    num_frames = len(frame_names)
    if num_frames == 0:
        raise RuntimeError(f"no images found in {image_folder}")
    img_paths = [os.path.join(image_folder, frame_name) for frame_name in frame_names]
    img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]

    if async_loading_frames:
        lazy_images = AsyncImageFrameLoader(
            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
        )
        return lazy_images, lazy_images.video_height, lazy_images.video_width

    # float16 precision should be sufficient for image tensor storage
    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float16)
    video_height, video_width = None, None
    for n, img_path in enumerate(
        tqdm(img_paths, desc=f"frame loading (image folder) [rank={RANK}]")
    ):
        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
    if not offload_video_to_cpu:
        images = images.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()
    # normalize by mean and std
    images -= img_mean
    images /= img_std
    return images, video_height, video_width


def load_video_frames_from_video_file(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean,
    img_std,
    async_loading_frames,
    gpu_acceleration=False,
    gpu_device=None,
    video_loader_type="cv2",
):
    """Load the video frames from a video file."""
    if video_loader_type == "cv2":
        return load_video_frames_from_video_file_using_cv2(
            video_path=video_path,
            image_size=image_size,
            img_mean=img_mean,
            img_std=img_std,
            offload_video_to_cpu=offload_video_to_cpu,
        )
    elif video_loader_type == "torchcodec":
        logger.info("Using torchcodec to load video file")
        lazy_images = AsyncVideoFileLoaderWithTorchCodec(
            video_path=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            gpu_acceleration=gpu_acceleration,
            gpu_device=gpu_device,
        )
        # The `AsyncVideoFileLoaderWithTorchCodec` class always loads the videos asynchronously,
        # so we just wait for its loading thread to finish if async_loading_frames=False.
        if not async_loading_frames:
            async_thread = lazy_images.thread
            if async_thread is not None:
                async_thread.join()
        return lazy_images, lazy_images.video_height, lazy_images.video_width
    else:
        raise RuntimeError("video_loader_type must be either 'cv2' or 'torchcodec'")


def load_video_frames_from_video_file_using_cv2(
    video_path: str,
    image_size: int,
    img_mean: tuple = (0.5, 0.5, 0.5),
    img_std: tuple = (0.5, 0.5, 0.5),
    offload_video_to_cpu: bool = False,
) -> tuple:
    """
    Load a video from a path and convert it into a normalized tensor with the
    specified preprocessing.

    Args:
        video_path: Path to the video file
        image_size: Target size for square frames (height and width)
        img_mean: Normalization mean (RGB)
        img_std: Normalization standard deviation (RGB)
        offload_video_to_cpu: If False, move the video tensor to GPU

    Returns:
        A tuple of (video_tensor, original_height, original_width), where
        video_tensor is the preprocessed video in shape (T, C, H, W) with
        float16 dtype.
    """
    import cv2  # delay OpenCV import to avoid an unnecessary dependency

    # Initialize video capture
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")
    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    num_frames = num_frames if num_frames > 0 else None

    frames = []
    pbar = tqdm(desc=f"frame loading (OpenCV) [rank={RANK}]", total=num_frames)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Convert BGR to RGB and resize
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_resized = cv2.resize(
            frame_rgb, (image_size, image_size), interpolation=cv2.INTER_CUBIC
        )
        frames.append(frame_resized)
        pbar.update(1)
    cap.release()
    pbar.close()

    # Convert to tensor; scale to [0, 1] so the mean/std normalization below matches
    # the other loaders in this file
    frames_np = np.stack(frames, axis=0).astype(np.float32) / 255.0  # (T, H, W, C)
    video_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2)  # (T, C, H, W)
    # float16 precision should be sufficient for image tensor storage
    video_tensor = video_tensor.half()
    img_mean = torch.tensor(img_mean, dtype=torch.float16).view(1, 3, 1, 1)
    img_std = torch.tensor(img_std, dtype=torch.float16).view(1, 3, 1, 1)
    if not offload_video_to_cpu:
        video_tensor = video_tensor.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()
    # normalize by mean and std
    video_tensor -= img_mean
    video_tensor /= img_std
    return video_tensor, original_height, original_width
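
# Note on the normalization used throughout this file: with img_mean = img_std = 0.5,
# a pixel value p in [0, 255] is first scaled to p / 255 in [0, 1] and then mapped to
# (p / 255 - 0.5) / 0.5, i.e. into [-1, 1] (for example, p = 255 -> 1.0, p = 0 -> -1.0).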


def load_dummy_video(image_size, offload_video_to_cpu, num_frames=60):
    """
    Load a dummy video with random frames for testing and compilation warmup purposes.
    """
    video_height, video_width = 480, 640  # dummy original video sizes
    images = torch.randn(num_frames, 3, image_size, image_size, dtype=torch.float16)
    if not offload_video_to_cpu:
        images = images.cuda()
    return images, video_height, video_width


def _load_img_as_tensor(img_path, image_size):
    """Load and resize an image and convert it into a PyTorch tensor."""
    img = Image.open(img_path).convert("RGB")
    orig_width, orig_height = img.width, img.height
    img = TF.resize(img, size=(image_size, image_size))
    img = TF.to_tensor(img)
    return img, orig_height, orig_width


class AsyncImageFrameLoader:
    """
    A list of video frames to be loaded asynchronously without blocking session start.
    """

    def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
        self.img_paths = img_paths
        self.image_size = image_size
        self.offload_video_to_cpu = offload_video_to_cpu
        self.img_mean = img_mean
        self.img_std = img_std
        # items in `self.images` will be loaded asynchronously
        self.images = [None] * len(img_paths)
        # catch and raise any exceptions in the async loading thread
        self.exception = None
        # video_height and video_width will be filled when loading the first image
        self.video_height = None
        self.video_width = None
        # load the first frame to fill video_height and video_width and also
        # to cache it (since it's most likely where the user will click)
        self.__getitem__(0)

        # load the rest of the frames asynchronously without blocking the session start
        def _load_frames():
            try:
                for n in tqdm(
                    range(len(self.images)),
                    desc=f"frame loading (image folder) [rank={RANK}]",
                ):
                    self.__getitem__(n)
            except Exception as e:
                self.exception = e

        self.thread = Thread(target=_load_frames, daemon=True)
        self.thread.start()

    def __getitem__(self, index):
        if self.exception is not None:
            raise RuntimeError("Failure in frame loading thread") from self.exception
        img = self.images[index]
        if img is not None:
            return img
        img, video_height, video_width = _load_img_as_tensor(
            self.img_paths[index], self.image_size
        )
        self.video_height = video_height
        self.video_width = video_width
        # float16 precision should be sufficient for image tensor storage
        img = img.to(dtype=torch.float16)
        # normalize by mean and std
        img -= self.img_mean
        img /= self.img_std
        if not self.offload_video_to_cpu:
            img = img.cuda()
        self.images[index] = img
        return img

    def __len__(self):
        return len(self.images)
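
# A minimal usage sketch (illustrative only; the paths and sizes are made-up examples):
#
#   loader = AsyncImageFrameLoader(
#       img_paths=["frames/0.jpg", "frames/1.jpg"],
#       image_size=512,
#       offload_video_to_cpu=True,
#       img_mean=torch.tensor([0.5, 0.5, 0.5], dtype=torch.float16)[:, None, None],
#       img_std=torch.tensor([0.5, 0.5, 0.5], dtype=torch.float16)[:, None, None],
#   )
#   frame0 = loader[0]  # cached during __init__; other indices are loaded on demand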


class TorchCodecDecoder:
    """
    A wrapper to support GPU device and num_threads in TorchCodec decoder,
    which are not supported by `torchcodec.decoders.SimpleVideoDecoder` yet.
    """

    def __init__(self, source, dimension_order="NCHW", device="cpu", num_threads=1):
        from torchcodec import _core as core

        self._source = source  # hold a reference to the source to prevent it from GC
        if isinstance(source, str):
            self._decoder = core.create_from_file(source, "exact")
        elif isinstance(source, bytes):
            self._decoder = core.create_from_bytes(source, "exact")
        else:
            raise TypeError(f"Unknown source type: {type(source)}.")
        assert dimension_order in ("NCHW", "NHWC")
        device_string = str(device)
        core.scan_all_streams_to_update_metadata(self._decoder)
        core.add_video_stream(
            self._decoder,
            dimension_order=dimension_order,
            device=device_string,
            num_threads=(1 if "cuda" in device_string else num_threads),
        )
        video_metadata = core.get_container_metadata(self._decoder)
        best_stream_index = video_metadata.best_video_stream_index
        assert best_stream_index is not None
        self.metadata = video_metadata.streams[best_stream_index]
        assert self.metadata.num_frames_from_content is not None
        self._num_frames = self.metadata.num_frames_from_content

    def __len__(self) -> int:
        return self._num_frames

    def __getitem__(self, key: int):
        from torchcodec import _core as core

        if key < 0:
            key += self._num_frames
        if key >= self._num_frames or key < 0:
            raise IndexError(
                f"Index {key} is out of bounds; length is {self._num_frames}"
            )
        frame_data, *_ = core.get_frame_at_index(
            self._decoder,
            frame_index=key,
        )
        return frame_data


class FIFOLock:
    """A lock that ensures FIFO ordering of lock acquisitions."""

    def __init__(self):
        self._lock = Lock()
        self._waiters = queue.Queue()
        self._condition = Condition()

    def acquire(self):
        ident = get_ident()
        with self._condition:
            self._waiters.put(ident)
            while self._waiters.queue[0] != ident or not self._lock.acquire(
                blocking=False
            ):
                self._condition.wait()
            # got the lock and it's our turn

    def release(self):
        with self._condition:
            self._lock.release()
            self._waiters.get()
            self._condition.notify_all()

    def __enter__(self):
        self.acquire()

    def __exit__(self, t, v, tb):
        self.release()
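
# A minimal usage sketch (illustrative only): the lock is used as a context manager,
# and waiting threads acquire it in first-come-first-served order.
#
#   fifo_lock = FIFOLock()
#   with fifo_lock:
#       ...  # critical section touching the (non-thread-safe) decoder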


class AsyncVideoFileLoaderWithTorchCodec:
    """
    Loading frames from video files asynchronously without blocking session start.
    Unlike `AsyncVideoFileLoader`, this class uses PyTorch's official TorchCodec library
    for video decoding, which is more efficient and supports more video formats.
    """

    def __init__(
        self,
        video_path,
        image_size,
        offload_video_to_cpu,
        img_mean,
        img_std,
        gpu_acceleration=True,
        gpu_device=None,
        use_rand_seek_in_loading=False,
    ):
        # Check and possibly infer the output device (and also get its GPU id when applicable)
        assert gpu_device is None or gpu_device.type == "cuda"
        gpu_id = (
            gpu_device.index
            if gpu_device is not None and gpu_device.index is not None
            else torch.cuda.current_device()
        )
        if offload_video_to_cpu:
            out_device = torch.device("cpu")
        else:
            out_device = torch.device("cuda") if gpu_device is None else gpu_device
        self.out_device = out_device
        self.gpu_acceleration = gpu_acceleration
        self.gpu_id = gpu_id
        self.image_size = image_size
        self.offload_video_to_cpu = offload_video_to_cpu
        if not isinstance(img_mean, torch.Tensor):
            img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
        self.img_mean = img_mean
        if not isinstance(img_std, torch.Tensor):
            img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
        self.img_std = img_std
        if gpu_acceleration:
            self.img_mean = self.img_mean.to(f"cuda:{self.gpu_id}")
            self.img_std = self.img_std.to(f"cuda:{self.gpu_id}")
            decoder_option = {"device": f"cuda:{self.gpu_id}"}
        else:
            self.img_mean = self.img_mean.cpu()
            self.img_std = self.img_std.cpu()
            decoder_option = {"num_threads": 1}  # use a single thread to save memory
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.async_reader = TorchCodecDecoder(video_path, **decoder_option)
        # `num_frames_from_content` is the true number of frames in the video content
        # from the scan operation (rather than from the metadata, which could be wrong)
        self.num_frames = self.async_reader.metadata.num_frames_from_content
        self.video_height = self.async_reader.metadata.height
        self.video_width = self.async_reader.metadata.width
        # items in `self.images` will be loaded asynchronously
        self.images_loaded = [False] * self.num_frames
        self.images = torch.zeros(
            self.num_frames,
            3,
            self.image_size,
            self.image_size,
            dtype=torch.float16,
            device=self.out_device,
        )
        # catch and raise any exceptions in the async loading thread
        self.exception = None
        self.use_rand_seek_in_loading = use_rand_seek_in_loading
        self.rand_seek_idx_queue = queue.Queue()
        # use a lock to avoid race conditions between concurrent accesses to the torchcodec
        # libs (which are not thread-safe); the lock is replaced with a nullcontext
        # when the video is fully loaded
        self.torchcodec_access_lock = FIFOLock()
        self._start_video_loading()

    def _load_one_frame(self, idx):
        frame_resized = self._transform_frame(self.async_reader[idx])
        return frame_resized

    @torch.inference_mode()
    def _start_video_loading(self):
        desc = f"frame loading (TorchCodec w/ {'GPU' if self.gpu_acceleration else 'CPU'}) [rank={RANK}]"
        pbar = tqdm(desc=desc, total=self.num_frames)
        self.num_loaded_frames = 0
        # load the first frame synchronously to cache it before the session is opened
        idx = self.num_loaded_frames
        self.images[idx] = self._load_one_frame(idx)
        self.images_loaded[idx] = True
        self.num_loaded_frames += 1
        pbar.update(n=1)
        self.all_frames_loaded = self.num_loaded_frames == self.num_frames

        # load the frames asynchronously without blocking the session start
        def _load_frames():
            finished = self.all_frames_loaded
            chunk_size = 16
            while not finished:
                # asynchronously load `chunk_size` frames each time we acquire the lock
                with self.torchcodec_access_lock, torch.inference_mode():
                    for _ in range(chunk_size):
                        try:
                            idx = self.num_loaded_frames
                            self.images[idx] = self._load_one_frame(idx)
                            self.images_loaded[idx] = True
                            self.num_loaded_frames += 1
                            pbar.update(n=1)
                            if self.num_loaded_frames >= self.num_frames:
                                finished = True
                                break
                        except Exception as e:
                            self.exception = e
                            raise
                    # also read the frame that is being randomly seeked to
                    while True:
                        try:
                            idx = self.rand_seek_idx_queue.get_nowait()
                            if not self.images_loaded[idx]:
                                self.images[idx] = self._load_one_frame(idx)
                                self.images_loaded[idx] = True
                        except queue.Empty:
                            break
                        except Exception as e:
                            self.exception = e
                            raise

            # finished -- check whether we have loaded the total number of frames
            if self.num_loaded_frames != self.num_frames:
                raise RuntimeError(
                    f"There are {self.num_frames} frames in the video, but only "
                    f"{self.num_loaded_frames} frames can be loaded successfully."
                )
            else:
                self.all_frames_loaded = True
                pbar.close()
                with self.torchcodec_access_lock:
                    import gc

                    # all frames have been loaded, so we can release the readers and free their memory
                    # also remove pbar and thread (which shouldn't be a part of session saving)
                    reader = self.async_reader
                    if reader is not None:
                        reader._source = None
                    self.async_reader = None
                    self.pbar = None
                    self.thread = None
                    self.rand_seek_idx_queue = None
                    gc.collect()
                    # remove the lock (replace it with nullcontext) when the video is fully loaded
                    self.torchcodec_access_lock = contextlib.nullcontext()

        self.thread = Thread(target=_load_frames, daemon=True)
        self.thread.start()

    def _transform_frame(self, frame):
        frame = frame.clone()  # make a copy to avoid modifying the original frame bytes
        frame = frame.float()  # convert to float32 before interpolation
        frame_resized = F.interpolate(
            frame[None, :],
            size=(self.image_size, self.image_size),
            mode="bicubic",
            align_corners=False,
        )[0]
        # float16 precision should be sufficient for image tensor storage
        frame_resized = frame_resized.half()  # uint8 -> float16
        frame_resized /= 255
        frame_resized -= self.img_mean
        frame_resized /= self.img_std
        if self.offload_video_to_cpu:
            frame_resized = frame_resized.cpu()
        elif frame_resized.device != self.out_device:
            frame_resized = frame_resized.to(device=self.out_device, non_blocking=True)
        return frame_resized

    def __getitem__(self, index):
        if self.exception is not None:
            raise RuntimeError("Failure in frame loading thread") from self.exception
        max_tries = 1200
        for _ in range(max_tries):
            # use a lock to avoid race conditions between concurrent accesses to the torchcodec
            # libs (which are not thread-safe); the lock is replaced with a nullcontext
            # when the video is fully loaded
            with self.torchcodec_access_lock:
                if self.images_loaded[index]:
                    return self.images[index]
                if self.use_rand_seek_in_loading:
                    # async loading hasn't reached this frame yet, so we request it to be
                    # loaded out of order (it will be picked up by the _load_frames thread
                    # and written to self.images[index])
                    self.rand_seek_idx_queue.put(index)
            time.sleep(0.1)
        raise RuntimeError(f"Failed to load frame {index} after {max_tries} tries")

    def __len__(self):
        return len(self.images)

    def __getstate__(self):
        """
        Remove a few attributes during pickling, so that this async video loader can be
        saved and loaded as a part of the model session.
        """
        # wait for async video loading to finish before pickling
        async_thread = self.thread
        if async_thread is not None:
            async_thread.join()
        # release a few objects that cannot be pickled
        reader = self.async_reader
        if reader is not None:
            reader._source = None
        self.async_reader = None
        self.pbar = None
        self.thread = None
        self.rand_seek_idx_queue = None
        self.torchcodec_access_lock = contextlib.nullcontext()
        return self.__dict__.copy()
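

# A minimal smoke-test sketch (illustrative only): it exercises `load_dummy_video`,
# so it needs no video file and, with offload_video_to_cpu=True, no GPU either.
if __name__ == "__main__":
    frames, height, width = load_dummy_video(
        image_size=128, offload_video_to_cpu=True, num_frames=4
    )
    print(f"dummy video tensor: {tuple(frames.shape)}, original size: {height}x{width}")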