# data_misc.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. """
  4. Misc functions, including distributed helpers.
  5. """
  6. import collections
  7. import re
  8. from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
  9. from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
  10. import torch
# Collation convention: fields of the batched dataclasses below are either an
# already-collated torch.Tensor or a plain Python list that is later converted
# to a tensor by `convert_my_tensors`.
MyTensor = Union[torch.Tensor, List[Any]]
  12. def interpolate(
  13. input, size=None, scale_factor=None, mode="nearest", align_corners=None
  14. ):
  15. # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
  16. """
  17. Equivalent to nn.functional.interpolate, but with support for empty channel sizes.
  18. """
  19. if input.numel() > 0:
  20. return torch.nn.functional.interpolate(
  21. input, size, scale_factor, mode, align_corners
  22. )
  23. assert input.shape[0] != 0 or input.shape[1] != 0, (
  24. "At least one of the two first dimensions must be non zero"
  25. )
  26. if input.shape[1] == 0:
  27. # Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim
  28. return torch.nn.functional.interpolate(
  29. input.transpose(0, 1), size, scale_factor, mode, align_corners
  30. ).transpose(0, 1)
  31. # empty batch dimension is now supported in pytorch
  32. return torch.nn.functional.interpolate(
  33. input, size, scale_factor, mode, align_corners
  34. )
@dataclass
class BatchedPointer:
    """Batched pointer data linking queries to stages and objects.

    Each ``MyTensor`` field typically starts life as a Python list and is
    converted to a ``torch.Tensor`` by ``convert_my_tensors``, which reads the
    target dtype from the matching ``<field>__type`` class attribute. Those
    ``__type`` attributes are deliberately unannotated so they remain plain
    class attributes rather than dataclass fields.

    NOTE(review): per-field meanings below are inferred from the names —
    confirm against the data pipeline that builds these.
    """

    # Stage each pointer refers to.
    stage_ids: MyTensor
    stage_ids__type = torch.long
    # Query within that stage.
    query_ids: MyTensor
    query_ids__type = torch.long
    # Object each pointer refers to.
    object_ids: MyTensor
    object_ids__type = torch.long
    # Presumably marks valid (non-padding) pointer slots — verify.
    ptr_mask: MyTensor
    ptr_mask__type = torch.bool
    # Categorical type of each pointer.
    ptr_types: MyTensor
    ptr_types__type = torch.long
@dataclass
class FindStage:
    """One batched "find" query stage: image/text references plus optional
    geometric prompts (boxes and points).

    ``MyTensor`` fields are converted to tensors by ``convert_my_tensors``
    using the dtype named in the companion ``<field>__type`` class attribute;
    ``input_boxes`` and ``input_boxes_label`` are stacked along dim 1 there,
    the remaining list-of-tensor fields along dim 0.
    """

    # Index of the image each query runs against (inferred from name).
    img_ids: MyTensor
    img_ids__type = torch.long
    # Index of the text prompt for each query (inferred from name).
    text_ids: MyTensor
    text_ids__type = torch.long
    # Box prompts; the mask marks valid entries, the label their class.
    input_boxes: MyTensor
    input_boxes__type = torch.float
    input_boxes_mask: MyTensor
    input_boxes_mask__type = torch.bool
    input_boxes_label: MyTensor
    input_boxes_label__type = torch.long
    # Point prompts and their validity mask.
    input_points: MyTensor
    input_points__type = torch.float
    input_points_mask: MyTensor
    input_points_mask__type = torch.bool
    # We track the object ids referred to by this query.
    # This is beneficial for tracking in videos without the need for pointers.
    object_ids: Optional[List[List]] = None  # List of objects per query
@dataclass
class BatchedFindTarget:
    """Ground-truth targets for a batch of find queries.

    ``MyTensor`` fields remain Python lists until ``convert_my_tensors``
    stacks/converts them to tensors of the dtype given by ``<field>__type``.
    """

    # The number of boxes in each find query
    num_boxes: MyTensor
    num_boxes__type = torch.long
    # Target boxes in normalized CxCywh format
    boxes: MyTensor
    boxes__type = torch.float
    # Target boxes in normalized CxCywh format but in padded representation
    # as used in BinaryHungarianMatcherV2 (unlike the packed ones in `boxes`)
    boxes_padded: MyTensor
    boxes_padded__type = torch.float
    # For hybrid matching, we repeat the boxes
    repeated_boxes: MyTensor
    repeated_boxes__type = torch.float
    # Target Segmentation masks
    segments: Optional[MyTensor]
    segments__type = torch.bool
    # Target Semantic Segmentation masks
    semantic_segments: Optional[MyTensor]
    semantic_segments__type = torch.bool
    # Presumably flags which entries of `segments` are valid — verify.
    is_valid_segment: Optional[MyTensor]
    is_valid_segment__type = torch.bool
    # Whether annotations are exhaustive for each query
    is_exhaustive: MyTensor
    is_exhaustive__type = torch.bool
    # The object id for each ground-truth box, in both packed and padded representations
    object_ids: MyTensor
    object_ids__type = torch.long
    object_ids_padded: MyTensor
    object_ids_padded__type = torch.long
@dataclass
class BatchedInferenceMetadata:
    """All metadata required to post-process a find stage.

    ``MyTensor`` fields follow the same ``convert_my_tensors`` /
    ``<field>__type`` dtype convention as the other batched dataclasses.
    """

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    coco_image_id: MyTensor
    coco_image_id__type = torch.long
    # id in the original dataset, such that we can use the original evaluator
    original_image_id: MyTensor
    original_image_id__type = torch.long
    # Original category id (if we want to use the original evaluator)
    original_category_id: MyTensor
    original_category_id__type = torch.int
    # Size of the raw image (height, width)
    original_size: MyTensor
    original_size__type = torch.long
    # id of the object in the media (track_id for a video)
    object_id: MyTensor
    object_id__type = torch.long
    # index of the frame in the media (0 in the case of a single-frame media)
    frame_index: MyTensor
    frame_index__type = torch.long
    # Adding for relations inference
    # get_text_input: List[Optional[str]]
    # Adding for TA conditional inference
    is_conditioning_only: List[Optional[bool]]
@dataclass
class BatchedDatapoint:
    """A fully collated batch: images, text, find queries, targets, metadata."""

    # Batched image tensor.
    img_batch: torch.Tensor
    # Text prompts used by the find stages.
    find_text_batch: List[str]
    # One entry per find stage in the batch.
    find_inputs: List[FindStage]
    # Targets/metadata — presumably aligned index-wise with `find_inputs`; verify.
    find_targets: List[BatchedFindTarget]
    find_metadatas: List[BatchedInferenceMetadata]
    # Optional un-transformed images (e.g. for visualization).
    raw_images: Optional[List[Any]] = None
  130. def convert_my_tensors(obj):
  131. def is_optional_field(field) -> bool:
  132. return get_origin(field) is Union and type(None) in get_args(field)
  133. for field in fields(obj):
  134. if is_dataclass(getattr(obj, field.name)):
  135. convert_my_tensors(getattr(obj, field.name))
  136. continue
  137. field_type = field.type
  138. if is_optional_field(field.type):
  139. field_type = Union[get_args(field.type)[:-1]] # Get the Optional field type
  140. if field_type != MyTensor or getattr(obj, field.name) is None:
  141. continue
  142. elif len(getattr(obj, field.name)) and isinstance(
  143. getattr(obj, field.name)[0], torch.Tensor
  144. ):
  145. stack_dim = 0
  146. if field.name in [
  147. "input_boxes",
  148. "input_boxes_label",
  149. ]:
  150. stack_dim = 1
  151. setattr(
  152. obj,
  153. field.name,
  154. torch.stack(getattr(obj, field.name), dim=stack_dim).to(
  155. getattr(obj, field.name + "__type")
  156. ),
  157. )
  158. else:
  159. setattr(
  160. obj,
  161. field.name,
  162. torch.as_tensor(
  163. getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
  164. ),
  165. )
  166. return obj