collator.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

from dataclasses import fields, is_dataclass
from typing import Any, get_args, get_origin, List, Union

import torch

from sam3.model.data_misc import (
    BatchedDatapoint,
    BatchedFindTarget,
    BatchedInferenceMetadata,
    FindStage,
)

from .sam3_image_dataset import Datapoint

MyTensor = Union[torch.Tensor, List[Any]]


def convert_my_tensors(obj):
    """Convert every MyTensor field of a (possibly nested) dataclass into a
    torch.Tensor, using the dtype stored in the companion "<name>__type" field."""

    def is_optional_field(field) -> bool:
        return get_origin(field) is Union and type(None) in get_args(field)

    for field in fields(obj):
        if is_dataclass(getattr(obj, field.name)):
            convert_my_tensors(getattr(obj, field.name))
            continue
        field_type = field.type
        if is_optional_field(field.type):
            field_type = Union[get_args(field.type)[:-1]]  # Get the Optional field type
        if field_type != MyTensor or getattr(obj, field.name) is None:
            continue
        elif len(getattr(obj, field.name)) and isinstance(
            getattr(obj, field.name)[0], torch.Tensor
        ):
            stack_dim = 0
            if field.name in [
                "input_boxes",
                "input_boxes_label",
            ]:
                stack_dim = 1
            setattr(
                obj,
                field.name,
                torch.stack(getattr(obj, field.name), dim=stack_dim).to(
                    getattr(obj, field.name + "__type")
                ),
            )
        else:
            setattr(
                obj,
                field.name,
                torch.as_tensor(
                    getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
                ),
            )
    return obj
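

# Illustrative sketch (not part of the original collator): convert_my_tensors expects
# each MyTensor field to have a companion "<name>__type" attribute holding the target
# dtype. The _ToyStage dataclass below is hypothetical and only demonstrates the
# list-of-scalars path (torch.as_tensor); lists of tensors go through torch.stack.
def _convert_my_tensors_demo():
    from dataclasses import dataclass, field

    @dataclass
    class _ToyStage:  # hypothetical stand-in for FindStage-like containers
        img_ids: MyTensor = field(default_factory=list)
        img_ids__type: torch.dtype = torch.long

    toy = _ToyStage(img_ids=[0, 0, 1])
    toy = convert_my_tensors(toy)
    assert isinstance(toy.img_ids, torch.Tensor) and toy.img_ids.dtype == torch.long
    return toy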


def packed_to_padded_naive(boxes_packed, num_boxes, fill_value=0):
    """
    Convert a packed tensor of bounding boxes to a padded tensor of bounding
    boxes. Naive implementation using a loop.

    Inputs:
    - boxes_packed: Tensor of shape (N_1 + ... + N_B, 4)
    - num_boxes: Tensor of shape (B,) where num_boxes[i] = N_i

    Returns:
    - boxes_padded: Tensor of shape (B, N_max, 4) where N_max = max_i N_i
    """
    B = num_boxes.shape[0]
    Ns = num_boxes.tolist()
    boxes_padded = boxes_packed.new_zeros(B, max(Ns), *boxes_packed.shape[1:])
    if fill_value != 0:
        boxes_padded[...] = fill_value
    prev_idx = 0
    for i in range(B):
        next_idx = prev_idx + Ns[i]
        boxes_padded[i, : Ns[i]] = boxes_packed[prev_idx:next_idx]
        prev_idx = next_idx
    return boxes_padded
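

# Worked example (illustrative, not part of the original file): two queries with 2 and
# 1 target boxes are packed into a (3, 4) tensor; the padded result has shape
# (2, 2, 4), and the unused slot of the second row stays at fill_value.
def _packed_to_padded_demo():
    boxes_packed = torch.arange(12, dtype=torch.float32).view(3, 4)
    num_boxes = torch.tensor([2, 1])
    boxes_padded = packed_to_padded_naive(boxes_packed, num_boxes, fill_value=-1)
    assert boxes_padded.shape == (2, 2, 4)
    assert torch.equal(boxes_padded[1, 1], torch.full((4,), -1.0))
    return boxes_padded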


def pad_tensor_list_to_longest(
    tensors: List[torch.Tensor], dim=0, pad_val=0
) -> List[torch.Tensor]:
    """Pad every tensor in the list along `dim` (with `pad_val`) up to the longest
    length in the list. Edits the list in-place and also returns it."""
    if not tensors:
        return tensors
    pad_len = max(t.shape[dim] for t in tensors)
    for i in range(len(tensors)):
        n_dims = len(tensors[i].shape)
        n_right_dims = (n_dims - 1) - (n_dims + dim) % n_dims
        n_pad = pad_len - tensors[i].shape[dim]
        pad_tuple = tuple([0] * 2 * n_right_dims + [0, n_pad])
        tensors[i] = torch.nn.functional.pad(tensors[i], pad_tuple, value=pad_val)
    return tensors
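

# Illustrative sketch (not part of the original file): point prompts of different
# lengths are padded along dim=0 up to the longest one, while their masks are padded
# with 1s so downstream code can tell real entries from padding.
def _pad_tensor_list_demo(embedding_dim: int = 257):
    points = [torch.zeros(2, embedding_dim), torch.zeros(5, embedding_dim)]
    masks = [torch.zeros(2), torch.zeros(5)]
    points = pad_tensor_list_to_longest(points, dim=0, pad_val=0)
    masks = pad_tensor_list_to_longest(masks, dim=0, pad_val=1)
    assert all(p.shape == (5, embedding_dim) for p in points)
    assert masks[0][2:].eq(1).all()  # padded tail of the shorter mask is all 1s
    return points, masks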


def collate_fn_api_with_chunking(
    batch,
    num_chunks,
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """Split the batch into `num_chunks` interleaved chunks and collate each chunk
    with `collate_fn_api`, returning the list of collated chunks."""
    assert num_chunks >= 1, "num_chunks must be >= 1"
    # split the batch into num_chunks chunks
    batch_chunks = [batch[i::num_chunks] for i in range(num_chunks)]
    # collate each chunk
    collated_chunks = [
        collate_fn_api(
            chunk,
            dict_key,
            with_seg_masks,
            input_points_embedding_dim,
            repeats,
            load_image_in_fp16,
        )
        for chunk in batch_chunks
    ]
    return collated_chunks
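

# Usage sketch (illustrative; the dataset argument and parameter values are
# assumptions, not taken from this repository). The chunked collate function is
# typically bound with functools.partial so that a DataLoader yields a list of
# collated chunks per batch, e.g. for accumulating gradients over sub-batches.
def _build_chunked_loader(dataset, batch_size: int = 8, num_chunks: int = 2):
    from functools import partial

    from torch.utils.data import DataLoader

    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=partial(
            collate_fn_api_with_chunking,
            num_chunks=num_chunks,
            dict_key="find",  # hypothetical key name
            with_seg_masks=True,
        ),
    )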


def collate_fn_api(
    batch: List[Datapoint],
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """Collate a list of Datapoint objects into a dict mapping `dict_key` to a
    BatchedDatapoint, grouping find queries by their processing stage."""
    # img_batch = torch.stack(sum([[img.data for img in v.images] for v in batch], []))
    img_batch = []
    text_batch = []
    raw_images = None
    num_stages = (
        max(q.query_processing_order for data in batch for q in data.find_queries) + 1
    )
    stages = [
        FindStage(
            img_ids=[],
            text_ids=[],
            input_boxes=[],
            input_boxes_label=[],
            input_boxes_mask=[],
            input_points=[],
            input_points_mask=[],
            object_ids=[],
        )
        for _ in range(num_stages)
    ]
    find_targets = [
        BatchedFindTarget(
            num_boxes=[],
            boxes=[],
            boxes_padded=[],
            is_exhaustive=[],
            segments=[],
            semantic_segments=[],
            is_valid_segment=[],
            repeated_boxes=[],
            object_ids=[],
            object_ids_padded=[],
        )
        for _ in range(num_stages)
    ]
    find_metadatas = [
        BatchedInferenceMetadata(
            coco_image_id=[],
            original_size=[],
            object_id=[],
            frame_index=[],
            original_image_id=[],
            original_category_id=[],
            is_conditioning_only=[],
        )
        for _ in range(num_stages)
    ]
    offset_img_id = 0
    offset_query_id = [0 for _ in range(num_stages)]
    for i, data in enumerate(batch):
        img_batch.extend([img.data for img in data.images])
        if data.raw_images is not None:
            if raw_images is None:
                raw_images = []
            raw_images.extend(data.raw_images)
        # Convert query_id indexing within a datapoint to query_id indexing within a stage
        datapoint_query_id_2_stage_query_id = []
        for q in data.find_queries:
            stage_id = q.query_processing_order
            datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id])
            offset_query_id[stage_id] += 1
        for j, q in enumerate(data.find_queries):
            stage_id = q.query_processing_order
            stages[stage_id].img_ids.append(q.image_id + offset_img_id)
            if q.query_text not in text_batch:
                text_batch.append(q.query_text)
            stages[stage_id].text_ids.append(text_batch.index(q.query_text))
            assert q.inference_metadata is not None, (
                "inference_metadata must be provided when FindQueryLoaded is created."
            )
            for f in fields(q.inference_metadata):
                getattr(find_metadatas[stage_id], f.name).append(
                    getattr(q.inference_metadata, f.name)
                )
            if q.input_bbox is not None:
                assert q.input_bbox.numel() % 4 == 0
                assert q.input_bbox_label is not None
                nb_boxes = q.input_bbox.numel() // 4
                assert len(q.input_bbox_label) == nb_boxes
                stages[stage_id].input_boxes.append(q.input_bbox.view(nb_boxes, 4))
                stages[stage_id].input_boxes_label.append(
                    q.input_bbox_label.view(nb_boxes)
                )
                stages[stage_id].input_boxes_mask.append(
                    torch.zeros(nb_boxes, dtype=torch.bool)
                )
            else:
                stages[stage_id].input_boxes.append(torch.zeros(0, 4))
                stages[stage_id].input_boxes_label.append(
                    torch.zeros(0, dtype=torch.bool)
                )
                stages[stage_id].input_boxes_mask.append(
                    torch.ones(0, dtype=torch.bool)
                )
            if q.input_points is not None:
                stages[stage_id].input_points.append(
                    q.input_points.squeeze(0)  # Strip a trivial batch index
                )
                # All masks will be padded up to the longest length
                # with 1s before final conversion to batched tensors
                stages[stage_id].input_points_mask.append(
                    torch.zeros(q.input_points.shape[1])
                )
            else:
                stages[stage_id].input_points.append(
                    torch.empty(0, input_points_embedding_dim)
                )
                stages[stage_id].input_points_mask.append(torch.empty(0))
            current_out_boxes = []
            current_out_object_ids = []
            # Set the object ids referred to by this query
            stages[stage_id].object_ids.append(q.object_ids_output)
            for object_id in q.object_ids_output:
                current_out_boxes.append(
                    data.images[q.image_id].objects[object_id].bbox
                )
                current_out_object_ids.append(object_id)
            find_targets[stage_id].boxes.extend(current_out_boxes)
            find_targets[stage_id].object_ids.extend(current_out_object_ids)
            if repeats > 0:
                for _ in range(repeats):
                    find_targets[stage_id].repeated_boxes.extend(current_out_boxes)
            find_targets[stage_id].num_boxes.append(len(current_out_boxes))
            find_targets[stage_id].is_exhaustive.append(q.is_exhaustive)
            if with_seg_masks:
                current_seg_mask = []
                current_is_valid_segment = []
                for object_id in q.object_ids_output:
                    seg_mask = data.images[q.image_id].objects[object_id].segment
                    if seg_mask is not None:
                        current_seg_mask.append(seg_mask)
                        current_is_valid_segment.append(1)
                    else:
                        dummy_mask = torch.zeros(
                            data.images[q.image_id].data.shape[-2:], dtype=torch.bool
                        )
                        current_seg_mask.append(dummy_mask)
                        current_is_valid_segment.append(0)
                find_targets[stage_id].segments.extend(current_seg_mask)
                find_targets[stage_id].is_valid_segment.extend(current_is_valid_segment)
            else:
                # We are not loading segmentation masks
                find_targets[stage_id].segments = None
                find_targets[stage_id].is_valid_segment = None
            if q.semantic_target is not None:
                find_targets[stage_id].semantic_segments.append(q.semantic_target)
        offset_img_id += len(data.images)

    # Pad input points to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_points = pad_tensor_list_to_longest(
            stages[i].input_points, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_points_mask = pad_tensor_list_to_longest(
            stages[i].input_points_mask, dim=0, pad_val=1
        )
    # Pad input boxes to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_boxes = pad_tensor_list_to_longest(
            stages[i].input_boxes, dim=0, pad_val=0
        )
        stages[i].input_boxes_label = pad_tensor_list_to_longest(
            stages[i].input_boxes_label, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_boxes_mask = pad_tensor_list_to_longest(
            stages[i].input_boxes_mask, dim=0, pad_val=1
        )
    # Convert to tensors
    for i in range(len(stages)):
        stages[i] = convert_my_tensors(stages[i])
        find_targets[i] = convert_my_tensors(find_targets[i])
        find_metadatas[i] = convert_my_tensors(find_metadatas[i])
        # get padded representation for the boxes
        find_targets[i].boxes_padded = packed_to_padded_naive(
            find_targets[i].boxes.view(-1, 4), find_targets[i].num_boxes
        )
        find_targets[i].object_ids_padded = packed_to_padded_naive(
            find_targets[i].object_ids, find_targets[i].num_boxes, fill_value=-1
        )
    # Finalize the image batch
    # check sizes
    for img in img_batch[1:]:
        assert img.shape == img_batch[0].shape, "All images must have the same size"
    image_batch = torch.stack(img_batch)
    if load_image_in_fp16:
        # Optionally, cast the image tensors to fp16, which helps save GPU memory on
        # long videos with thousands of frames (where image tensors could be several GBs)
        image_batch = image_batch.half()
    return {
        dict_key: BatchedDatapoint(
            img_batch=image_batch,
            find_text_batch=text_batch,
            find_inputs=stages,
            find_targets=find_targets,
            find_metadatas=find_metadatas,
            raw_images=raw_images,
        )
    }
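

# Illustrative helper (not part of the original file; assumes BatchedDatapoint,
# FindStage, and BatchedFindTarget expose their constructor arguments as attributes).
# Summarizes the collated structure: one FindStage / BatchedFindTarget per processing
# stage, plus a single stacked image batch and a deduplicated text batch.
def _describe_collated_batch(collated: dict, dict_key: str):
    bd = collated[dict_key]
    print("images:", tuple(bd.img_batch.shape))  # e.g. (num_images, C, H, W)
    print("unique query texts:", len(bd.find_text_batch))
    for s, (stage, target) in enumerate(zip(bd.find_inputs, bd.find_targets)):
        print(
            f"stage {s}: {len(stage.img_ids)} queries, "
            f"{int(target.num_boxes.sum())} target boxes"
        )
    return bd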