# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Transforms and data augmentation for both image + bbox.
"""

import logging
import random
from typing import Iterable

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from torchvision.transforms import InterpolationMode

from training.utils.data_utils import VideoDatapoint


def hflip(datapoint, index):
    datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)

    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            obj.segment = F.hflip(obj.segment)

    return datapoint


def get_size_with_aspect_ratio(image_size, size, max_size=None):
    w, h = image_size
    if max_size is not None:
        min_original_size = float(min((w, h)))
        max_original_size = float(max((w, h)))
        if max_original_size / min_original_size * size > max_size:
            size = max_size * min_original_size / max_original_size

    if (w <= h and w == size) or (h <= w and h == size):
        return (h, w)

    if w < h:
        ow = int(round(size))
        oh = int(round(size * h / w))
    else:
        oh = int(round(size))
        ow = int(round(size * w / h))

    return (oh, ow)


def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        cur_size = (
            datapoint.frames[index].data.size()[-2:][::-1]
            if v2
            else datapoint.frames[index].data.size
        )
        size = get_size(cur_size, size, max_size)

    old_size = (
        datapoint.frames[index].data.size()[-2:][::-1]
        if v2
        else datapoint.frames[index].data.size
    )
    if v2:
        datapoint.frames[index].data = Fv2.resize(
            datapoint.frames[index].data, size, antialias=True
        )
    else:
        datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)

    new_size = (
        datapoint.frames[index].data.size()[-2:][::-1]
        if v2
        else datapoint.frames[index].data.size
    )

    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            obj.segment = F.resize(obj.segment[None, None], size).squeeze()

    h, w = size
    datapoint.frames[index].size = (h, w)
    return datapoint


def pad(datapoint, index, padding, v2=False):
    old_h, old_w = datapoint.frames[index].size
    h, w = old_h, old_w
    if len(padding) == 2:
        # assumes that we only pad on the right and bottom sides
        datapoint.frames[index].data = F.pad(
            datapoint.frames[index].data, (0, 0, padding[0], padding[1])
        )
        h += padding[1]
        w += padding[0]
    else:
        # left, top, right, bottom
        datapoint.frames[index].data = F.pad(
            datapoint.frames[index].data,
            (padding[0], padding[1], padding[2], padding[3]),
        )
        h += padding[1] + padding[3]
        w += padding[0] + padding[2]

    datapoint.frames[index].size = (h, w)

    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            if v2:
                if len(padding) == 2:
                    obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = Fv2.pad(obj.segment, tuple(padding))
            else:
                if len(padding) == 2:
                    obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = F.pad(obj.segment, tuple(padding))
    return datapoint


class RandomHorizontalFlip:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for i in range(len(datapoint.frames)):
                    datapoint = hflip(datapoint, i)
            return datapoint
        for i in range(len(datapoint.frames)):
            if random.random() < self.p:
                datapoint = hflip(datapoint, i)
        return datapoint


class RandomResizeAPI:
    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            size = random.choice(self.sizes)
            for i in range(len(datapoint.frames)):
                datapoint = resize(
                    datapoint, i, size, self.max_size, square=self.square, v2=self.v2
                )
            return datapoint
        for i in range(len(datapoint.frames)):
            size = random.choice(self.sizes)
            datapoint = resize(
                datapoint, i, size, self.max_size, square=self.square, v2=self.v2
            )
        return datapoint


class ToTensorAPI:
    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for img in datapoint.frames:
            if self.v2:
                img.data = Fv2.to_image_tensor(img.data)
            else:
                img.data = F.to_tensor(img.data)
        return datapoint


class NormalizeAPI:
    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for img in datapoint.frames:
            if self.v2:
                img.data = Fv2.convert_image_dtype(img.data, torch.float32)
                img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
            else:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)

        return datapoint


class ComposeAPI:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        for t in self.transforms:
            datapoint = t(datapoint, **kwargs)
        return datapoint

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string


class RandomGrayscale:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for img in datapoint.frames:
                    img.data = self.Grayscale(img.data)
            return datapoint
        for img in datapoint.frames:
            if random.random() < self.p:
                img.data = self.Grayscale(img.data)
        return datapoint


class ColorJitter:
    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = (
            brightness
            if isinstance(brightness, list)
            else [max(0, 1 - brightness), 1 + brightness]
        )
        self.contrast = (
            contrast
            if isinstance(contrast, list)
            else [max(0, 1 - contrast), 1 + contrast]
        )
        self.saturation = (
            saturation
            if isinstance(saturation, list)
            else [max(0, 1 - saturation), 1 + saturation]
        )
        self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            # Sample the color jitter parameters once and reuse them for all frames
            (
                fn_idx,
                brightness_factor,
                contrast_factor,
                saturation_factor,
                hue_factor,
            ) = T.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )
        for img in datapoint.frames:
            if not self.consistent_transform:
                # Sample new color jitter parameters for every frame
                (
                    fn_idx,
                    brightness_factor,
                    contrast_factor,
                    saturation_factor,
                    hue_factor,
                ) = T.ColorJitter.get_params(
                    self.brightness, self.contrast, self.saturation, self.hue
                )
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img.data = F.adjust_brightness(img.data, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img.data = F.adjust_contrast(img.data, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img.data = F.adjust_saturation(img.data, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img.data = F.adjust_hue(img.data, hue_factor)
        return datapoint


class RandomAffine:
    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, the same random affine is applied to all frames and masks.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = (
            shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives

        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res

        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def transform_datapoint(self, datapoint: VideoDatapoint):
        _, height, width = F.get_dimensions(datapoint.frames[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            # Create a random affine transformation shared by all frames and masks
            affine_params = T.RandomAffine.get_params(
                degrees=self.degrees,
                translate=self.translate,
                scale_ranges=self.scale,
                shears=self.shear,
                img_size=img_size,
            )

        for img_idx, img in enumerate(datapoint.frames):
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            if not self.consistent_transform:
                # If not consistent, create new affine params for every frame & mask pair
                affine_params = T.RandomAffine.get_params(
                    degrees=self.degrees,
                    translate=self.translate,
                    scale_ranges=self.scale,
                    shears=self.shear,
                    img_size=img_size,
                )

            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]]))
                else:
                    transformed_mask = F.affine(
                        this_masks[i],
                        *affine_params,
                        interpolation=InterpolationMode.NEAREST,
                        fill=0.0,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # The object is no longer visible in the first frame of the video;
                        # abort this tentative so the caller can retry (or fall back to the
                        # untransformed datapoint)
                        return None
                    transformed_masks.append(transformed_mask.squeeze())

            for i in range(len(img.objects)):
                img.objects[i].segment = transformed_masks[i]

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint


def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    # Step 1: downsize the images and paste them into a mosaic
    image_data = datapoint.frames[index].data
    is_pil = isinstance(image_data, PILImage.Image)
    if is_pil:
        H_im = image_data.height
        W_im = image_data.width
        image_data_output = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im = image_data.size(-2)
        W_im = image_data.size(-1)
        image_data_output = torch.zeros_like(image_data)

    downsize_cache = {}
    for grid_y in range(grid_h):
        for grid_x in range(grid_w):
            y_offset_b = grid_y * H_im // grid_h
            x_offset_b = grid_x * W_im // grid_w
            y_offset_e = (grid_y + 1) * H_im // grid_h
            x_offset_e = (grid_x + 1) * W_im // grid_w
            H_im_downsize = y_offset_e - y_offset_b
            W_im_downsize = x_offset_e - x_offset_b

            if (H_im_downsize, W_im_downsize) in downsize_cache:
                image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
            else:
                image_data_downsize = F.resize(
                    image_data,
                    size=(H_im_downsize, W_im_downsize),
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
                downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
            if should_hflip[grid_y, grid_x].item():
                image_data_downsize = F.hflip(image_data_downsize)

            if is_pil:
                image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
            else:
                image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
                    image_data_downsize
                )

    datapoint.frames[index].data = image_data_output

    # Step 2: downsize the masks and paste them into the target grid of the mosaic
    for obj in datapoint.frames[index].objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        segment_output = torch.zeros_like(obj.segment)

        target_y_offset_b = target_grid_y * H_im // grid_h
        target_x_offset_b = target_grid_x * W_im // grid_w
        target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
        target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
        target_H_im_downsize = target_y_offset_e - target_y_offset_b
        target_W_im_downsize = target_x_offset_e - target_x_offset_b

        segment_downsize = F.resize(
            obj.segment[None, None],
            size=(target_H_im_downsize, target_W_im_downsize),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]

        segment_output[
            target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
        ] = segment_downsize
        obj.segment = segment_output

    return datapoint


class RandomMosaicVideoAPI:
    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint

        # select a random location to place the target mask in the mosaic
        target_grid_y = random.randint(0, self.grid_h - 1)
        target_grid_x = random.randint(0, self.grid_w - 1)

        # whether to flip each grid in the mosaic horizontally
        if self.use_random_hflip:
            should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
        else:
            should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)

        for i in range(len(datapoint.frames)):
            datapoint = random_mosaic_frame(
                datapoint,
                i,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_grid_y,
                target_grid_x=target_grid_x,
                should_hflip=should_hflip,
            )

        return datapoint
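

# A minimal usage sketch (illustrative only, not part of the original module): it chains
# a few of the transforms defined above with ComposeAPI on a VideoDatapoint. The sizes,
# jitter strengths, and normalization statistics below are assumptions chosen for the
# example, not values prescribed by this file.
#
#     train_transforms = ComposeAPI(
#         transforms=[
#             RandomHorizontalFlip(consistent_transform=True),
#             RandomResizeAPI(sizes=512, consistent_transform=True, square=True),
#             ColorJitter(
#                 consistent_transform=True,
#                 brightness=0.1,
#                 contrast=0.03,
#                 saturation=0.03,
#                 hue=None,
#             ),
#             ToTensorAPI(),
#             NormalizeAPI(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#         ]
#     )
#     augmented = train_transforms(video_datapoint)  # video_datapoint: VideoDatapoint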