| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- """
- Misc functions, including distributed helpers.
- """
- import collections
- import re
- from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
- from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
- import torch
- MyTensor = Union[torch.Tensor, List[Any]]
def interpolate(
    input: torch.Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> torch.Tensor:
    """
    Equivalent to nn.functional.interpolate, but with support for empty channel sizes.

    Args:
        input: Tensor of shape (batch, channels, *spatial). Either the batch
            dimension or the channel dimension may be zero-sized, but not both.
        size: Target spatial size, as in ``torch.nn.functional.interpolate``.
        scale_factor: Multiplier for the spatial size.
        mode: Interpolation algorithm (e.g. "nearest", "bilinear").
        align_corners: See ``torch.nn.functional.interpolate``.

    Returns:
        The interpolated tensor; empty batch/channel dimensions are preserved.
    """
    if input.numel() > 0:
        # Non-empty input: defer entirely to the standard implementation.
        return torch.nn.functional.interpolate(
            input, size, scale_factor, mode, align_corners
        )
    assert input.shape[0] != 0 or input.shape[1] != 0, (
        "At least one of the two first dimensions must be non zero"
    )
    if input.shape[1] == 0:
        # Pytorch doesn't support null dimension on the channel dimension, so we
        # transpose to fake a null batch dim (which is supported), then undo it.
        return torch.nn.functional.interpolate(
            input.transpose(0, 1), size, scale_factor, mode, align_corners
        ).transpose(0, 1)
    # empty batch dimension is now supported in pytorch
    return torch.nn.functional.interpolate(
        input, size, scale_factor, mode, align_corners
    )
@dataclass
class BatchedPointer:
    """Batched pointer annotations.

    Each ``<name>__type`` class attribute records the torch dtype that the
    matching field is converted to by ``convert_my_tensors``; these are plain
    class attributes (no annotation) so they are not dataclass fields.
    """

    # Stage index each pointer refers to.
    stage_ids: MyTensor
    stage_ids__type = torch.long
    # Query index each pointer refers to.
    query_ids: MyTensor
    query_ids__type = torch.long
    # Object index each pointer refers to.
    object_ids: MyTensor
    object_ids__type = torch.long
    # Validity mask over pointers (padding entries are False).
    ptr_mask: MyTensor
    ptr_mask__type = torch.bool
    # Categorical pointer type ids.
    ptr_types: MyTensor
    ptr_types__type = torch.long
@dataclass
class FindStage:
    """A batched "find" stage: image/text references plus box and point prompts.

    Each ``<name>__type`` class attribute records the torch dtype that the
    matching field is converted to by ``convert_my_tensors``.
    """

    # Image index for each query.
    img_ids: MyTensor
    img_ids__type = torch.long
    # Text prompt index for each query.
    text_ids: MyTensor
    text_ids__type = torch.long
    # Box prompts for each query.
    input_boxes: MyTensor
    input_boxes__type = torch.float
    # Validity mask over box prompts.
    input_boxes_mask: MyTensor
    input_boxes_mask__type = torch.bool
    # Label for each box prompt.
    input_boxes_label: MyTensor
    input_boxes_label__type = torch.long
    # Point prompts for each query.
    input_points: MyTensor
    input_points__type = torch.float
    # Validity mask over point prompts.
    input_points_mask: MyTensor
    input_points_mask__type = torch.bool
    # We track the object ids referred to by this query.
    # This is beneficial for tracking in videos without the need for pointers.
    object_ids: Optional[List[List]] = None  # List of objects per query
@dataclass
class BatchedFindTarget:
    """Batched ground-truth targets for the find stage.

    Each ``<name>__type`` class attribute records the torch dtype that the
    matching field is converted to by ``convert_my_tensors``.
    """

    # The number of boxes in each find query
    num_boxes: MyTensor
    num_boxes__type = torch.long
    # Target boxes in normalized CxCywh format
    boxes: MyTensor
    boxes__type = torch.float
    # Target boxes in normalized CxCywh format but in padded representation
    # as used in BinaryHungarianMatcherV2 (unlike the packed ones in `boxes`)
    boxes_padded: MyTensor
    boxes_padded__type = torch.float
    # For hybrid matching, we repeat the boxes
    repeated_boxes: MyTensor
    repeated_boxes__type = torch.float
    # Target Segmentation masks
    segments: Optional[MyTensor]
    segments__type = torch.bool
    # Target Semantic Segmentation masks
    semantic_segments: Optional[MyTensor]
    semantic_segments__type = torch.bool
    # Per-segment validity flags.
    is_valid_segment: Optional[MyTensor]
    is_valid_segment__type = torch.bool
    # Whether annotations are exhaustive for each query
    is_exhaustive: MyTensor
    is_exhaustive__type = torch.bool
    # The object id for each ground-truth box, in both packed and padded representations
    object_ids: MyTensor
    object_ids__type = torch.long
    object_ids_padded: MyTensor
    object_ids_padded__type = torch.long
@dataclass
class BatchedInferenceMetadata:
    """All metadata required to post-process a find stage.

    Each ``<name>__type`` class attribute records the torch dtype that the
    matching field is converted to by ``convert_my_tensors``.
    """

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    coco_image_id: MyTensor
    coco_image_id__type = torch.long
    # id in the original dataset, such that we can use the original evaluator
    original_image_id: MyTensor
    original_image_id__type = torch.long
    # Original category id (if we want to use the original evaluator)
    original_category_id: MyTensor
    original_category_id__type = torch.int
    # Size of the raw image (height, width)
    original_size: MyTensor
    original_size__type = torch.long
    # id of the object in the media (track_id for a video)
    object_id: MyTensor
    object_id__type = torch.long
    # index of the frame in the media (0 in the case of a single-frame media)
    frame_index: MyTensor
    frame_index__type = torch.long
    # Adding for relations inference
    # get_text_input: List[Optional[str]]
    # Adding for TA conditional inference
    is_conditioning_only: List[Optional[bool]]
@dataclass
class BatchedDatapoint:
    """A fully collated datapoint: images, text, stage inputs, targets, metadata."""

    # Batched image tensor.
    img_batch: torch.Tensor
    # Text prompts for the find queries.
    find_text_batch: List[str]
    # One FindStage per stage of the batch.
    find_inputs: List[FindStage]
    # Ground-truth targets aligned with `find_inputs`.
    find_targets: List[BatchedFindTarget]
    # Post-processing metadata aligned with `find_inputs`.
    find_metadatas: List[BatchedInferenceMetadata]
    # Unprocessed source images — presumably kept for visualization/debugging;
    # NOTE(review): confirm against callers.
    raw_images: Optional[List[Any]] = None
def convert_my_tensors(obj):
    """Recursively convert every ``MyTensor`` field of a dataclass into a tensor.

    For each dataclass field annotated ``MyTensor`` (or ``Optional[MyTensor]``),
    the value is converted in place to a torch.Tensor with the dtype given by
    the sibling ``<field>__type`` class attribute. A list of tensors is stacked
    (along dim 1 for the per-query box fields, dim 0 otherwise); any other
    value goes through ``torch.as_tensor``. Nested dataclass values are
    converted recursively; ``None`` values are left untouched.

    Args:
        obj: A dataclass instance (e.g. FindStage, BatchedFindTarget).

    Returns:
        The same object, mutated in place.
    """

    def _strip_optional(annotation):
        # Optional[X] is Union[X, NoneType]. Remove NoneType wherever it
        # appears instead of assuming it is the last Union argument.
        if get_origin(annotation) is Union:
            args = tuple(a for a in get_args(annotation) if a is not type(None))
            if len(args) != len(get_args(annotation)):
                return Union[args]
        return annotation

    # Fields stacked along dim 1 rather than the default dim 0.
    _STACK_DIM_1 = ("input_boxes", "input_boxes_label")

    for f in fields(obj):
        value = getattr(obj, f.name)
        if is_dataclass(value):
            convert_my_tensors(value)
            continue
        if _strip_optional(f.type) != MyTensor or value is None:
            continue
        dtype = getattr(obj, f.name + "__type")
        if len(value) and isinstance(value[0], torch.Tensor):
            stack_dim = 1 if f.name in _STACK_DIM_1 else 0
            setattr(obj, f.name, torch.stack(value, dim=stack_dim).to(dtype))
        else:
            setattr(obj, f.name, torch.as_tensor(value, dtype=dtype))
    return obj
|