basic_for_api.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

"""
Transforms and data augmentation for both image + bbox.
"""

import logging
import numbers
import random
from collections.abc import Sequence
from typing import Iterable

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.transforms import InterpolationMode
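
# Note on the expected data layout, inferred from how the transforms below use the
# datapoint (the authoritative definition is sam3.train.data.sam3_image_dataset.Datapoint):
#   - datapoint.images[i].data    : PIL image or CHW tensor
#   - datapoint.images[i].size    : (height, width) of the current frame
#   - datapoint.images[i].objects : each has .bbox (XYXY, absolute), .segment
#                                   (HxW mask or None) and .area
#   - datapoint.find_queries      : each has .image_id and optional .input_bbox,
#                                   .input_points, .semantic_target
# Boxes stay in absolute XYXY coordinates throughout and are only converted to
# normalized CxCyWH at the end, in NormalizeAPI.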


def crop(
    datapoint,
    index,
    region,
    v2=False,
    check_validity=True,
    check_input_validity=True,
    recompute_box_from_mask=False,
):
    if v2:
        rtop, rleft, rheight, rwidth = (int(round(r)) for r in region)
        datapoint.images[index].data = Fv2.crop(
            datapoint.images[index].data,
            top=rtop,
            left=rleft,
            height=rheight,
            width=rwidth,
        )
    else:
        datapoint.images[index].data = F.crop(datapoint.images[index].data, *region)
    i, j, h, w = region
    # should we do something wrt the original size?
    datapoint.images[index].size = (h, w)

    for obj in datapoint.images[index].objects:
        # crop the mask
        if obj.segment is not None:
            obj.segment = F.crop(obj.segment, int(i), int(j), int(h), int(w))
        # crop the bounding box
        if recompute_box_from_mask and obj.segment is not None:
            # here the boxes are still in XYXY format with absolute coordinates (they are
            # converted to CxCyWH with relative coordinates in basic_for_api.NormalizeAPI)
            obj.bbox, obj.area = get_bbox_xyxy_abs_coords_from_mask(obj.segment)
        else:
            if recompute_box_from_mask and obj.segment is None and obj.area > 0:
                logging.warning(
                    "Cannot recompute bounding box from mask since `obj.segment` is None. "
                    "Falling back to directly cropping from the input bounding box."
                )
            boxes = obj.bbox.view(1, 4)
            max_size = torch.as_tensor([w, h], dtype=torch.float32)
            cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
            cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
            cropped_boxes = cropped_boxes.clamp(min=0)
            obj.area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
            obj.bbox = cropped_boxes.reshape(-1, 4)

    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            query.semantic_target = F.crop(
                query.semantic_target, int(i), int(j), int(h), int(w)
            )
        if query.image_id == index and query.input_bbox is not None:
            boxes = query.input_bbox
            max_size = torch.as_tensor([w, h], dtype=torch.float32)
            cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
            cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
            cropped_boxes = cropped_boxes.clamp(min=0)
            # cur_area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
            # if check_input_validity:
            #     assert (
            #         (cur_area > 0).all().item()
            #     ), "Some input box got cropped out by the crop transform"
            query.input_bbox = cropped_boxes.reshape(-1, 4)
        if query.image_id == index and query.input_points is not None:
            print(
                "Warning! Point cropping with this function may lead to unexpected results"
            )
            points = query.input_points
            # Unlike right-lower box edges, which are exclusive, the
            # point must be in [0, length-1], hence the -1
            max_size = torch.as_tensor([w, h], dtype=torch.float32) - 1
            cropped_points = points - torch.as_tensor([j, i, 0], dtype=torch.float32)
            cropped_points[:, :, :2] = torch.min(cropped_points[:, :, :2], max_size)
            cropped_points[:, :, :2] = cropped_points[:, :, :2].clamp(min=0)
            query.input_points = cropped_points

    if check_validity:
        # Check that all boxes are still valid
        for obj in datapoint.images[index].objects:
            assert obj.area > 0, "Box {} has no area".format(obj.bbox)
    return datapoint


def hflip(datapoint, index):
    datapoint.images[index].data = F.hflip(datapoint.images[index].data)
    w, h = datapoint.images[index].data.size
    for obj in datapoint.images[index].objects:
        boxes = obj.bbox.view(1, 4)
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
            [-1, 1, -1, 1]
        ) + torch.as_tensor([w, 0, w, 0])
        obj.bbox = boxes
        if obj.segment is not None:
            obj.segment = F.hflip(obj.segment)
    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            query.semantic_target = F.hflip(query.semantic_target)
        if query.image_id == index and query.input_bbox is not None:
            boxes = query.input_bbox
            boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
                [-1, 1, -1, 1]
            ) + torch.as_tensor([w, 0, w, 0])
            query.input_bbox = boxes
        if query.image_id == index and query.input_points is not None:
            points = query.input_points
            points = points * torch.as_tensor([-1, 1, 1]) + torch.as_tensor([w, 0, 0])
            query.input_points = points
    return datapoint
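
# The box flipping above maps [x0, y0, x1, y1] -> [w - x1, y0, w - x0, y1]; e.g. with
# w = 100, the box [10, 20, 30, 40] becomes [70, 20, 90, 40].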


def get_size_with_aspect_ratio(image_size, size, max_size=None):
    w, h = image_size
    if max_size is not None:
        min_original_size = float(min((w, h)))
        max_original_size = float(max((w, h)))
        if max_original_size / min_original_size * size > max_size:
            size = max_size * min_original_size / max_original_size
    if (w <= h and w == size) or (h <= w and h == size):
        return (h, w)
    if w < h:
        ow = int(round(size))
        oh = int(round(size * h / w))
    else:
        oh = int(round(size))
        ow = int(round(size * w / h))
    return (oh, ow)
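
# Example: get_size_with_aspect_ratio((640, 480), size=800, max_size=1333) keeps the
# aspect ratio and returns (oh, ow) = (800, 1067): the shorter side (h = 480) is
# scaled to 800, and the longer side follows since 640 / 480 * 800 ~= 1067 does not
# exceed max_size.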


def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    # size can be min_size (scalar) or (w, h) tuple
    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        cur_size = (
            datapoint.images[index].data.size()[-2:][::-1]
            if v2
            else datapoint.images[index].data.size
        )
        size = get_size(cur_size, size, max_size)
    old_size = (
        datapoint.images[index].data.size()[-2:][::-1]
        if v2
        else datapoint.images[index].data.size
    )
    if v2:
        datapoint.images[index].data = Fv2.resize(
            datapoint.images[index].data, size, antialias=True
        )
    else:
        datapoint.images[index].data = F.resize(datapoint.images[index].data, size)
    new_size = (
        datapoint.images[index].data.size()[-2:][::-1]
        if v2
        else datapoint.images[index].data.size
    )
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, old_size))
    ratio_width, ratio_height = ratios

    for obj in datapoint.images[index].objects:
        boxes = obj.bbox.view(1, 4)
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        obj.bbox = scaled_boxes
        obj.area *= ratio_width * ratio_height
        if obj.segment is not None:
            obj.segment = F.resize(obj.segment[None, None], size).squeeze()

    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            query.semantic_target = F.resize(
                query.semantic_target[None, None], size
            ).squeeze()
        if query.image_id == index and query.input_bbox is not None:
            boxes = query.input_bbox
            scaled_boxes = boxes * torch.as_tensor(
                [ratio_width, ratio_height, ratio_width, ratio_height],
                dtype=torch.float32,
            )
            query.input_bbox = scaled_boxes
        if query.image_id == index and query.input_points is not None:
            points = query.input_points
            scaled_points = points * torch.as_tensor(
                [ratio_width, ratio_height, 1],
                dtype=torch.float32,
            )
            query.input_points = scaled_points

    h, w = size
    datapoint.images[index].size = (h, w)
    return datapoint


def pad(datapoint, index, padding, v2=False):
    old_h, old_w = datapoint.images[index].size
    h, w = old_h, old_w
    if len(padding) == 2:
        # assumes that we only pad on the bottom right corners
        if v2:
            datapoint.images[index].data = Fv2.pad(
                datapoint.images[index].data, (0, 0, padding[0], padding[1])
            )
        else:
            datapoint.images[index].data = F.pad(
                datapoint.images[index].data, (0, 0, padding[0], padding[1])
            )
        h += padding[1]
        w += padding[0]
    else:
        if v2:
            # left, top, right, bottom
            datapoint.images[index].data = Fv2.pad(
                datapoint.images[index].data,
                (padding[0], padding[1], padding[2], padding[3]),
            )
        else:
            # left, top, right, bottom
            datapoint.images[index].data = F.pad(
                datapoint.images[index].data,
                (padding[0], padding[1], padding[2], padding[3]),
            )
        h += padding[1] + padding[3]
        w += padding[0] + padding[2]
    datapoint.images[index].size = (h, w)

    for obj in datapoint.images[index].objects:
        if len(padding) != 2:
            obj.bbox += torch.as_tensor(
                [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
            )
        if obj.segment is not None:
            if v2:
                if len(padding) == 2:
                    obj.segment = Fv2.pad(
                        obj.segment[None], (0, 0, padding[0], padding[1])
                    ).squeeze(0)
                else:
                    obj.segment = Fv2.pad(obj.segment[None], tuple(padding)).squeeze(0)
            else:
                if len(padding) == 2:
                    obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = F.pad(obj.segment, tuple(padding))

    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            if v2:
                if len(padding) == 2:
                    query.semantic_target = Fv2.pad(
                        query.semantic_target[None, None],
                        (0, 0, padding[0], padding[1]),
                    ).squeeze()
                else:
                    query.semantic_target = Fv2.pad(
                        query.semantic_target[None, None], tuple(padding)
                    ).squeeze()
            else:
                if len(padding) == 2:
                    query.semantic_target = F.pad(
                        query.semantic_target[None, None],
                        (0, 0, padding[0], padding[1]),
                    ).squeeze()
                else:
                    query.semantic_target = F.pad(
                        query.semantic_target[None, None], tuple(padding)
                    ).squeeze()
        if query.image_id == index and query.input_bbox is not None:
            if len(padding) != 2:
                query.input_bbox += torch.as_tensor(
                    [padding[0], padding[1], padding[0], padding[1]],
                    dtype=torch.float32,
                )
        if query.image_id == index and query.input_points is not None:
            if len(padding) != 2:
                query.input_points += torch.as_tensor(
                    [padding[0], padding[1], 0], dtype=torch.float32
                )
    return datapoint
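
# Padding conventions used above: a 2-tuple (pad_right, pad_bottom) pads only the
# bottom/right edges, so boxes and points need no shift; a 4-tuple
# (left, top, right, bottom) also shifts boxes/points by (left, top).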


class RandomSizeCropAPI:
    def __init__(
        self,
        min_size: int,
        max_size: int,
        respect_boxes: bool,
        consistent_transform: bool,
        respect_input_boxes: bool = True,
        v2: bool = False,
        recompute_box_from_mask: bool = False,
    ):
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes  # if True we can't crop a box out
        self.respect_input_boxes = respect_input_boxes
        self.consistent_transform = consistent_transform
        self.v2 = v2
        self.recompute_box_from_mask = recompute_box_from_mask

    def _sample_no_respect_boxes(self, img):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        return T.RandomCrop.get_params(img, (h, w))

    def _sample_respect_boxes(self, img, boxes, points, min_box_size=10.0):
        """
        Assure that no box or point is dropped via cropping, though portions
        of boxes may be removed.
        """
        if len(boxes) == 0 and len(points) == 0:
            return self._sample_no_respect_boxes(img)
        if self.v2:
            img_height, img_width = img.size()[-2:]
        else:
            img_width, img_height = img.size
        minW, minH, maxW, maxH = (
            min(img_width, self.min_size),
            min(img_height, self.min_size),
            min(img_width, self.max_size),
            min(img_height, self.max_size),
        )
        # The crop box must extend one pixel beyond points to the bottom/right
        # to assure the exclusive box contains the points.
        minX = (
            torch.cat([boxes[:, 0] + min_box_size, points[:, 0] + 1], dim=0)
            .max()
            .item()
        )
        minY = (
            torch.cat([boxes[:, 1] + min_box_size, points[:, 1] + 1], dim=0)
            .max()
            .item()
        )
        minX = min(img_width, minX)
        minY = min(img_height, minY)
        maxX = torch.cat([boxes[:, 2] - min_box_size, points[:, 0]], dim=0).min().item()
        maxY = torch.cat([boxes[:, 3] - min_box_size, points[:, 1]], dim=0).min().item()
        maxX = max(0.0, maxX)
        maxY = max(0.0, maxY)
        minW = max(minW, minX - maxX)
        minH = max(minH, minY - maxY)
        w = random.uniform(minW, max(minW, maxW))
        h = random.uniform(minH, max(minH, maxH))
        if minX > maxX:
            # i = random.uniform(max(0, minX - w + 1), max(maxX, max(0, minX - w + 1)))
            i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w)))
        else:
            i = random.uniform(
                max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1))
            )
        if minY > maxY:
            # j = random.uniform(max(0, minY - h + 1), max(maxY, max(0, minY - h + 1)))
            j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h)))
        else:
            j = random.uniform(
                max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
            )
        return [j, i, h, w]

    def __call__(self, datapoint, **kwargs):
        if self.respect_boxes or self.respect_input_boxes:
            if self.consistent_transform:
                # Check that all the images are the same size
                w, h = datapoint.images[0].data.size
                for img in datapoint.images:
                    assert img.data.size == (w, h)
                all_boxes = []
                # Getting all boxes in all the images
                if self.respect_boxes:
                    all_boxes += [
                        obj.bbox.view(-1, 4)
                        for img in datapoint.images
                        for obj in img.objects
                    ]
                # Get all the boxes in the find queries
                if self.respect_input_boxes:
                    all_boxes += [
                        q.input_bbox.view(-1, 4)
                        for q in datapoint.find_queries
                        if q.input_bbox is not None
                    ]
                if all_boxes:
                    all_boxes = torch.cat(all_boxes, 0)
                else:
                    all_boxes = torch.empty(0, 4)
                all_points = [
                    q.input_points.view(-1, 3)[:, :2]
                    for q in datapoint.find_queries
                    if q.input_points is not None
                ]
                if all_points:
                    all_points = torch.cat(all_points, 0)
                else:
                    all_points = torch.empty(0, 2)
                crop_param = self._sample_respect_boxes(
                    datapoint.images[0].data, all_boxes, all_points
                )
                for i in range(len(datapoint.images)):
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint
            else:
                for i in range(len(datapoint.images)):
                    all_boxes = []
                    # Get all boxes in the current image
                    if self.respect_boxes:
                        all_boxes += [
                            obj.bbox.view(-1, 4) for obj in datapoint.images[i].objects
                        ]
                    # Get all the boxes in the find queries that correspond to this image
                    if self.respect_input_boxes:
                        all_boxes += [
                            q.input_bbox.view(-1, 4)
                            for q in datapoint.find_queries
                            if q.image_id == i and q.input_bbox is not None
                        ]
                    if all_boxes:
                        all_boxes = torch.cat(all_boxes, 0)
                    else:
                        all_boxes = torch.empty(0, 4)
                    all_points = [
                        q.input_points.view(-1, 3)[:, :2]
                        for q in datapoint.find_queries
                        if q.input_points is not None
                    ]
                    if all_points:
                        all_points = torch.cat(all_points, 0)
                    else:
                        all_points = torch.empty(0, 2)
                    crop_param = self._sample_respect_boxes(
                        datapoint.images[i].data, all_boxes, all_points
                    )
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint
        else:
            if self.consistent_transform:
                # Check that all the images are the same size
                w, h = datapoint.images[0].data.size
                for img in datapoint.images:
                    assert img.data.size == (w, h)
                crop_param = self._sample_no_respect_boxes(datapoint.images[0].data)
                for i in range(len(datapoint.images)):
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint
            else:
                for i in range(len(datapoint.images)):
                    crop_param = self._sample_no_respect_boxes(datapoint.images[i].data)
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint


class CenterCropAPI:
    def __init__(self, size, consistent_transform, recompute_box_from_mask=False):
        self.size = size
        self.consistent_transform = consistent_transform
        self.recompute_box_from_mask = recompute_box_from_mask

    def _sample_crop(self, image_width, image_height):
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.0))
        crop_left = int(round((image_width - crop_width) / 2.0))
        return crop_top, crop_left, crop_height, crop_width

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # Check that all the images are the same size
            w, h = datapoint.images[0].data.size
            for img in datapoint.images:
                assert img.size == (w, h)
            crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h)
            for i in range(len(datapoint.images)):
                datapoint = crop(
                    datapoint,
                    i,
                    (crop_top, crop_left, crop_height, crop_width),
                    recompute_box_from_mask=self.recompute_box_from_mask,
                )
            return datapoint
        for i in range(len(datapoint.images)):
            w, h = datapoint.images[i].data.size
            crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h)
            datapoint = crop(
                datapoint,
                i,
                (crop_top, crop_left, crop_height, crop_width),
                recompute_box_from_mask=self.recompute_box_from_mask,
            )
        return datapoint


class RandomHorizontalFlip:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for i in range(len(datapoint.images)):
                    datapoint = hflip(datapoint, i)
            return datapoint
        for i in range(len(datapoint.images)):
            if random.random() < self.p:
                datapoint = hflip(datapoint, i)
        return datapoint


class RandomResizeAPI:
    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            size = random.choice(self.sizes)
            for i in range(len(datapoint.images)):
                datapoint = resize(
                    datapoint, i, size, self.max_size, square=self.square, v2=self.v2
                )
            return datapoint
        for i in range(len(datapoint.images)):
            size = random.choice(self.sizes)
            datapoint = resize(
                datapoint, i, size, self.max_size, square=self.square, v2=self.v2
            )
        return datapoint


class ScheduledRandomResizeAPI(RandomResizeAPI):
    def __init__(self, size_scheduler, consistent_transform, square=False):
        self.size_scheduler = size_scheduler
        # Just a meaningful init value for super
        params = self.size_scheduler(epoch_num=0)
        sizes, max_size = params["sizes"], params["max_size"]
        super().__init__(sizes, consistent_transform, max_size=max_size, square=square)

    def __call__(self, datapoint, **kwargs):
        assert "epoch" in kwargs, "Param scheduler needs to know the current epoch"
        params = self.size_scheduler(kwargs["epoch"])
        sizes, max_size = params["sizes"], params["max_size"]
        self.sizes = sizes
        self.max_size = max_size
        datapoint = super(ScheduledRandomResizeAPI, self).__call__(datapoint, **kwargs)
        return datapoint


class RandomPadAPI:
    def __init__(self, max_pad, consistent_transform):
        self.max_pad = max_pad
        self.consistent_transform = consistent_transform

    def _sample_pad(self):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad_x, pad_y

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            pad_x, pad_y = self._sample_pad()
            for i in range(len(datapoint.images)):
                datapoint = pad(datapoint, i, (pad_x, pad_y))
            return datapoint
        for i in range(len(datapoint.images)):
            pad_x, pad_y = self._sample_pad()
            datapoint = pad(datapoint, i, (pad_x, pad_y))
        return datapoint


class PadToSizeAPI:
    def __init__(self, size, consistent_transform, bottom_right=False, v2=False):
        self.size = size
        self.consistent_transform = consistent_transform
        self.v2 = v2
        self.bottom_right = bottom_right

    def _sample_pad(self, w, h):
        pad_x = self.size - w
        pad_y = self.size - h
        assert pad_x >= 0 and pad_y >= 0
        pad_left = random.randint(0, pad_x)
        pad_right = pad_x - pad_left
        pad_top = random.randint(0, pad_y)
        pad_bottom = pad_y - pad_top
        return pad_left, pad_top, pad_right, pad_bottom

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # Check that all the images are the same size
            w, h = datapoint.images[0].data.size
            for img in datapoint.images:
                assert img.size == (w, h)
            if self.bottom_right:
                pad_right = self.size - w
                pad_bottom = self.size - h
                padding = (pad_right, pad_bottom)
            else:
                padding = self._sample_pad(w, h)
            for i in range(len(datapoint.images)):
                datapoint = pad(datapoint, i, padding, v2=self.v2)
            return datapoint
        for i, img in enumerate(datapoint.images):
            w, h = img.data.size
            if self.bottom_right:
                pad_right = self.size - w
                pad_bottom = self.size - h
                padding = (pad_right, pad_bottom)
            else:
                padding = self._sample_pad(w, h)
            datapoint = pad(datapoint, i, padding, v2=self.v2)
        return datapoint


class RandomMosaicVideoAPI:
    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint
        # select a random location to place the target mask in the mosaic
        target_grid_y = random.randint(0, self.grid_h - 1)
        target_grid_x = random.randint(0, self.grid_w - 1)
        # whether to flip each grid in the mosaic horizontally
        if self.use_random_hflip:
            should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
        else:
            should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
        for i in range(len(datapoint.images)):
            datapoint = random_mosaic_frame(
                datapoint,
                i,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_grid_y,
                target_grid_x=target_grid_x,
                should_hflip=should_hflip,
            )
        return datapoint


def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    # Step 1: downsize the images and paste them into a mosaic
    image_data = datapoint.images[index].data
    is_pil = isinstance(image_data, PILImage.Image)
    if is_pil:
        H_im = image_data.height
        W_im = image_data.width
        image_data_output = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im = image_data.size(-2)
        W_im = image_data.size(-1)
        image_data_output = torch.zeros_like(image_data)

    downsize_cache = {}
    for grid_y in range(grid_h):
        for grid_x in range(grid_w):
            y_offset_b = grid_y * H_im // grid_h
            x_offset_b = grid_x * W_im // grid_w
            y_offset_e = (grid_y + 1) * H_im // grid_h
            x_offset_e = (grid_x + 1) * W_im // grid_w
            H_im_downsize = y_offset_e - y_offset_b
            W_im_downsize = x_offset_e - x_offset_b

            if (H_im_downsize, W_im_downsize) in downsize_cache:
                image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
            else:
                image_data_downsize = F.resize(
                    image_data,
                    size=(H_im_downsize, W_im_downsize),
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
                downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
            if should_hflip[grid_y, grid_x].item():
                image_data_downsize = F.hflip(image_data_downsize)

            if is_pil:
                image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
            else:
                image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
                    image_data_downsize
                )

    datapoint.images[index].data = image_data_output

    # Step 2: downsize the masks and paste them into the target grid of the mosaic
    # (note that we don't scale input/target boxes since they are not used in TA)
    for obj in datapoint.images[index].objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        segment_output = torch.zeros_like(obj.segment)

        target_y_offset_b = target_grid_y * H_im // grid_h
        target_x_offset_b = target_grid_x * W_im // grid_w
        target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
        target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
        target_H_im_downsize = target_y_offset_e - target_y_offset_b
        target_W_im_downsize = target_x_offset_e - target_x_offset_b

        segment_downsize = F.resize(
            obj.segment[None, None],
            size=(target_H_im_downsize, target_W_im_downsize),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]

        segment_output[
            target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
        ] = segment_downsize
        obj.segment = segment_output

    return datapoint


class ScheduledPadToSizeAPI(PadToSizeAPI):
    def __init__(self, size_scheduler, consistent_transform):
        self.size_scheduler = size_scheduler
        size = self.size_scheduler(epoch_num=0)["sizes"]
        super().__init__(size, consistent_transform)

    def __call__(self, datapoint, **kwargs):
        assert "epoch" in kwargs, "Param scheduler needs to know the current epoch"
        params = self.size_scheduler(kwargs["epoch"])
        self.size = params["resolution"]
        return super(ScheduledPadToSizeAPI, self).__call__(datapoint, **kwargs)


class IdentityAPI:
    def __call__(self, datapoint, **kwargs):
        return datapoint


class RandomSelectAPI:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2.
    """

    def __init__(self, transforms1=None, transforms2=None, p=0.5):
        self.transforms1 = transforms1 or IdentityAPI()
        self.transforms2 = transforms2 or IdentityAPI()
        self.p = p

    def __call__(self, datapoint, **kwargs):
        if random.random() < self.p:
            return self.transforms1(datapoint, **kwargs)
        return self.transforms2(datapoint, **kwargs)


class ToTensorAPI:
    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            if self.v2:
                img.data = Fv2.to_image_tensor(img.data)
                # img.data = Fv2.to_dtype(img.data, torch.uint8, scale=True)
                # img.data = Fv2.convert_image_dtype(img.data, torch.uint8)
            else:
                img.data = F.to_tensor(img.data)
        return datapoint


class NormalizeAPI:
    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            if self.v2:
                img.data = Fv2.convert_image_dtype(img.data, torch.float32)
                img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
            else:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)
            for obj in img.objects:
                boxes = obj.bbox
                cur_h, cur_w = img.data.shape[-2:]
                boxes = box_xyxy_to_cxcywh(boxes)
                boxes = boxes / torch.tensor(
                    [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32
                )
                obj.bbox = boxes

        for query in datapoint.find_queries:
            if query.input_bbox is not None:
                boxes = query.input_bbox
                cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:]
                boxes = box_xyxy_to_cxcywh(boxes)
                boxes = boxes / torch.tensor(
                    [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32
                )
                query.input_bbox = boxes
            if query.input_points is not None:
                points = query.input_points
                cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:]
                points = points / torch.tensor([cur_w, cur_h, 1.0], dtype=torch.float32)
                query.input_points = points
        return datapoint
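
# Example of the box conversion above: on a 100x200 (w x h) image, the XYXY box
# [10, 20, 30, 60] becomes CxCyWH [20, 40, 20, 40] and, after dividing by
# [w, h, w, h], the normalized box [0.2, 0.2, 0.2, 0.2].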


class ComposeAPI:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        for t in self.transforms:
            datapoint = t(datapoint, **kwargs)
        return datapoint

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
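
# A minimal usage sketch of the transforms in this file (illustrative only; the
# mean/std values below are the usual ImageNet statistics, an assumption rather
# than something this module prescribes):
#
#   transforms = ComposeAPI(
#       [
#           RandomHorizontalFlip(consistent_transform=True, p=0.5),
#           RandomResizeAPI(sizes=[800], consistent_transform=True, max_size=1333),
#           ToTensorAPI(),
#           NormalizeAPI(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#       ]
#   )
#   datapoint = transforms(datapoint, epoch=0)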


class RandomGrayscale:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: Datapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for img in datapoint.images:
                    img.data = self.Grayscale(img.data)
            return datapoint
        for img in datapoint.images:
            if random.random() < self.p:
                img.data = self.Grayscale(img.data)
        return datapoint


class ColorJitter:
    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = (
            brightness
            if isinstance(brightness, list)
            else [max(0, 1 - brightness), 1 + brightness]
        )
        self.contrast = (
            contrast
            if isinstance(contrast, list)
            else [max(0, 1 - contrast), 1 + contrast]
        )
        self.saturation = (
            saturation
            if isinstance(saturation, list)
            else [max(0, 1 - saturation), 1 + saturation]
        )
        self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])

    def __call__(self, datapoint: Datapoint, **kwargs):
        if self.consistent_transform:
            # Sample the color jitter params once for all images
            (
                fn_idx,
                brightness_factor,
                contrast_factor,
                saturation_factor,
                hue_factor,
            ) = T.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )
        for img in datapoint.images:
            if not self.consistent_transform:
                # Sample new color jitter params for each image
                (
                    fn_idx,
                    brightness_factor,
                    contrast_factor,
                    saturation_factor,
                    hue_factor,
                ) = T.ColorJitter.get_params(
                    self.brightness, self.contrast, self.saturation, self.hue
                )
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img.data = F.adjust_brightness(img.data, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img.data = F.adjust_contrast(img.data, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img.data = F.adjust_saturation(img.data, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img.data = F.adjust_hue(img.data, hue_factor)
        return datapoint


class RandomAffine:
    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, the same random affine is applied to all frames and masks.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = (
            shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: Datapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res
        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def transform_datapoint(self, datapoint: Datapoint):
        _, height, width = F.get_dimensions(datapoint.images[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            # Create a random affine transformation
            affine_params = T.RandomAffine.get_params(
                degrees=self.degrees,
                translate=self.translate,
                scale_ranges=self.scale,
                shears=self.shear,
                img_size=img_size,
            )

        for img_idx, img in enumerate(datapoint.images):
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            if not self.consistent_transform:
                # If not consistent, create new random affine params for every frame & mask pair
                affine_params = T.RandomAffine.get_params(
                    degrees=self.degrees,
                    translate=self.translate,
                    scale_ranges=self.scale,
                    shears=self.shear,
                    img_size=img_size,
                )

            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]]))
                else:
                    transformed_mask = F.affine(
                        this_masks[i],
                        *affine_params,
                        interpolation=InterpolationMode.NEAREST,
                        fill=0.0,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # We are dealing with a video and the object is not visible in the first frame
                        # Return the datapoint without transformation
                        return None
                    transformed_bbox = masks_to_boxes(transformed_mask)
                    transformed_bboxes.append(transformed_bbox)
                    transformed_masks.append(transformed_mask.squeeze())

            for i in range(len(img.objects)):
                img.objects[i].bbox = transformed_bboxes[i]
                img.objects[i].segment = transformed_masks[i]

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint


class RandomResizedCrop:
    def __init__(
        self,
        consistent_transform,
        size,
        scale=None,
        ratio=None,
        log_warning=True,
        num_tentatives=4,
        keep_aspect_ratio=False,
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, the same random resized crop is applied to all frames and masks.
        """
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
        elif len(size) != 2:
            raise ValueError("Please provide only two dimensions (h, w) for size.")
        else:
            self.size = size
        self.scale = scale if scale is not None else (0.08, 1.0)
        self.ratio = ratio if ratio is not None else (3.0 / 4.0, 4.0 / 3.0)
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        self.keep_aspect_ratio = keep_aspect_ratio

    def __call__(self, datapoint: Datapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res
        if self.log_warning:
            logging.warning(
                f"Skip RandomResizedCrop for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def transform_datapoint(self, datapoint: Datapoint):
        if self.keep_aspect_ratio:
            original_size = datapoint.images[0].size
            original_ratio = original_size[1] / original_size[0]
            ratio = [r * original_ratio for r in self.ratio]
        else:
            ratio = self.ratio

        if self.consistent_transform:
            # Create a random crop transformation
            crop_params = T.RandomResizedCrop.get_params(
                img=datapoint.images[0].data,
                scale=self.scale,
                ratio=ratio,
            )

        for img_idx, img in enumerate(datapoint.images):
            if not self.consistent_transform:
                # Create a random crop transformation
                crop_params = T.RandomResizedCrop.get_params(
                    img=img.data,
                    scale=self.scale,
                    ratio=ratio,
                )
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]]))
                else:
                    transformed_mask = F.resized_crop(
                        this_masks[i],
                        *crop_params,
                        size=self.size,
                        interpolation=InterpolationMode.NEAREST,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # We are dealing with a video and the object is not visible in the first frame
                        # Return the datapoint without transformation
                        return None
                    transformed_masks.append(transformed_mask.squeeze())
                    transformed_bbox = masks_to_boxes(transformed_mask)
                    transformed_bboxes.append(transformed_bbox)

            # Set the new boxes and masks if all transformed masks and boxes are good.
            for i in range(len(img.objects)):
                img.objects[i].bbox = transformed_bboxes[i]
                img.objects[i].segment = transformed_masks[i]

            img.data = F.resized_crop(
                img.data,
                *crop_params,
                size=self.size,
                interpolation=InterpolationMode.BILINEAR,
            )
        return datapoint


class ResizeToMaxIfAbove:
    # Resize datapoint images if one of their sides is larger than max_size
    def __init__(
        self,
        max_size=None,
    ):
        self.max_size = max_size

    def __call__(self, datapoint: Datapoint, **kwargs):
        _, height, width = F.get_dimensions(datapoint.images[0].data)
        if height <= self.max_size and width <= self.max_size:
            # The original frames are small enough
            return datapoint
        elif height >= width:
            new_height = self.max_size
            new_width = int(round(self.max_size * width / height))
        else:
            new_height = int(round(self.max_size * height / width))
            new_width = self.max_size
        size = new_height, new_width

        for index in range(len(datapoint.images)):
            datapoint.images[index].data = F.resize(datapoint.images[index].data, size)
            for obj in datapoint.images[index].objects:
                obj.segment = F.resize(
                    obj.segment[None, None],
                    size,
                    interpolation=InterpolationMode.NEAREST,
                ).squeeze()
            h, w = size
            datapoint.images[index].size = (h, w)
        return datapoint


def get_bbox_xyxy_abs_coords_from_mask(mask):
    """Get the bounding box (XYXY format w/ absolute coordinates) of a binary mask."""
    assert mask.dim() == 2
    rows = torch.any(mask, dim=1)
    cols = torch.any(mask, dim=0)
    row_inds = rows.nonzero().view(-1)
    col_inds = cols.nonzero().view(-1)
    if row_inds.numel() == 0:
        # mask is empty
        bbox = torch.zeros(1, 4, dtype=torch.float32)
        bbox_area = 0.0
    else:
        ymin, ymax = row_inds.min(), row_inds.max()
        xmin, xmax = col_inds.min(), col_inds.max()
        bbox = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32).view(1, 4)
        bbox_area = float((ymax - ymin) * (xmax - xmin))
    return bbox, bbox_area
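
# Example: for a 4x5 uint8 mask with mask[1:3, 2:4] = 1, this returns
# bbox = [[2., 1., 3., 2.]] and area 1.0. Note that both edges here are inclusive
# pixel indices, unlike the exclusive right/bottom box edges used in crop().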


class MotionBlur:
    def __init__(self, kernel_size=5, consistent_transform=True, p=0.5):
        assert kernel_size % 2 == 1, "Kernel size must be odd."
        self.kernel_size = kernel_size
        self.consistent_transform = consistent_transform
        self.p = p

    def __call__(self, datapoint: Datapoint, **kwargs):
        if random.random() >= self.p:
            return datapoint
        if self.consistent_transform:
            # Generate a single motion blur kernel for all images
            kernel = self._generate_motion_blur_kernel()
        for img in datapoint.images:
            if not self.consistent_transform:
                # Generate a new motion blur kernel for each image
                kernel = self._generate_motion_blur_kernel()
            img.data = self._apply_motion_blur(img.data, kernel)
        return datapoint

    def _generate_motion_blur_kernel(self):
        kernel = torch.zeros((self.kernel_size, self.kernel_size))
        direction = random.choice(["horizontal", "vertical", "diagonal"])
        if direction == "horizontal":
            kernel[self.kernel_size // 2, :] = 1.0
        elif direction == "vertical":
            kernel[:, self.kernel_size // 2] = 1.0
        elif direction == "diagonal":
            for i in range(self.kernel_size):
                kernel[i, i] = 1.0
        kernel /= kernel.sum()
        return kernel

    def _apply_motion_blur(self, image, kernel):
        if isinstance(image, PILImage.Image):
            image = F.to_tensor(image)
        channels = image.shape[0]
        kernel = kernel.to(image.device).unsqueeze(0).unsqueeze(0)
        blurred_image = torch.nn.functional.conv2d(
            image.unsqueeze(0),
            kernel.repeat(channels, 1, 1, 1),
            padding=self.kernel_size // 2,
            groups=channels,
        )
        return F.to_pil_image(blurred_image.squeeze(0))


class LargeScaleJitter:
    def __init__(
        self,
        scale_range=(0.1, 2.0),
        aspect_ratio_range=(0.75, 1.33),
        crop_size=(640, 640),
        consistent_transform=True,
        p=0.5,
    ):
        """
        Args:
            scale_range (tuple): Range of scaling factors (min_scale, max_scale).
            aspect_ratio_range (tuple): Range of aspect ratios (min_aspect_ratio, max_aspect_ratio).
            crop_size (tuple): Target size of the cropped region (width, height).
            consistent_transform (bool): Whether to apply the same transformation across all frames.
            p (float): Probability of applying the transformation.
        """
        self.scale_range = scale_range
        self.aspect_ratio_range = aspect_ratio_range
        self.crop_size = crop_size
        self.consistent_transform = consistent_transform
        self.p = p

    def __call__(self, datapoint: Datapoint, **kwargs):
        if random.random() >= self.p:
            return datapoint

        # Sample a single scale factor and aspect ratio for all frames
        log_ratio = torch.log(torch.tensor(self.aspect_ratio_range))
        scale_factor = torch.empty(1).uniform_(*self.scale_range).item()
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        for idx, img in enumerate(datapoint.images):
            if not self.consistent_transform:
                # Sample a new scale factor and aspect ratio for each frame
                log_ratio = torch.log(torch.tensor(self.aspect_ratio_range))
                scale_factor = torch.empty(1).uniform_(*self.scale_range).item()
                aspect_ratio = torch.exp(
                    torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
                ).item()

            # Compute the dimensions of the jittered crop
            original_width, original_height = img.data.size
            target_area = original_width * original_height * scale_factor
            crop_width = int(round((target_area * aspect_ratio) ** 0.5))
            crop_height = int(round((target_area / aspect_ratio) ** 0.5))

            # Randomly select the top-left corner of the crop
            crop_x = random.randint(0, max(0, original_width - crop_width))
            crop_y = random.randint(0, max(0, original_height - crop_height))

            # Extract the cropped region; crop() expects (top, left, height, width)
            datapoint = crop(datapoint, idx, (crop_y, crop_x, crop_height, crop_width))
            # Resize the cropped region to the target crop size
            datapoint = resize(datapoint, idx, self.crop_size)
        return datapoint