# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import numpy as np
import pycocotools.mask as mask_utils
import torch
import torchvision.transforms.functional as F
from PIL import Image as PILImage

from sam3.model.box_ops import masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint


class InstanceToSemantic:
    """Convert instance segmentation to semantic segmentation."""

    def __init__(self, delete_instance=True, use_rle=False):
        # `delete_instance`: drop the per-object instance masks once the merged
        # semantic target has been built, to free memory downstream.
        # `use_rle`: keep the merged target as a COCO RLE dict instead of a
        # dense uint8 tensor.
        self.delete_instance = delete_instance
        self.use_rle = use_rle

    def __call__(self, datapoint: Datapoint, **kwargs):
        for fquery in datapoint.find_queries:
            h, w = datapoint.images[fquery.image_id].size
            if self.use_rle:
                all_segs = [
                    datapoint.images[fquery.image_id].objects[obj_id].segment
                    for obj_id in fquery.object_ids_output
                ]
                if len(all_segs) > 0:
                    # We need to double-check that all RLEs have the correct size.
                    # Otherwise cocotools silently falls back to an empty [0, 0] mask.
                    for seg in all_segs:
                        assert seg["size"] == all_segs[0]["size"], (
                            "Instance segments have inconsistent sizes. "
                            f"Found sizes {seg['size']} and {all_segs[0]['size']}"
                        )
                    fquery.semantic_target = mask_utils.merge(all_segs)
                else:
                    # There is no good way to create an empty RLE of the correct size,
                    # so we resort to converting an empty box to RLE.
                    fquery.semantic_target = mask_utils.frPyObjects(
                        np.array([[0, 0, 0, 0]], dtype=np.float64), h, w
                    )[0]
            else:
                # `semantic_target` is uint8 and remains uint8 throughout the transforms
                # (it contains binary 0 and 1 values just like `segment` for each object)
                fquery.semantic_target = torch.zeros((h, w), dtype=torch.uint8)
                for obj_id in fquery.object_ids_output:
                    segment = datapoint.images[fquery.image_id].objects[obj_id].segment
                    if segment is not None:
                        assert (
                            isinstance(segment, torch.Tensor)
                            and segment.dtype == torch.uint8
                        )
                        fquery.semantic_target |= segment

        if self.delete_instance:
            for img in datapoint.images:
                for obj in img.objects:
                    del obj.segment
                    obj.segment = None

        return datapoint
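
# Usage sketch (illustrative, not invoked anywhere in this module): assuming
# `datapoint` is a populated `Datapoint` whose object segments have already been
# decoded into uint8 tensors (e.g. by `DecodeRle` below), the transform can be
# applied directly:
#
#   to_semantic = InstanceToSemantic(delete_instance=True, use_rle=False)
#   datapoint = to_semantic(datapoint)
#   # each find query now carries a dense (h, w) uint8 `semantic_target`
#   # equal to the union of its output objects' instance masks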


class RecomputeBoxesFromMasks:
    """Recompute bounding boxes from masks."""

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            for obj in img.objects:
                # Note: if the mask is empty, the bounding box will be undefined.
                # Such empty targets should be filtered out downstream.
                obj.bbox = masks_to_boxes(obj.segment)
                obj.area = obj.segment.sum().item()
        return datapoint
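
# Usage sketch (illustrative): this transform assumes `obj.segment` is already a
# decoded mask tensor, so in a pipeline it would typically run after `DecodeRle`
# and after any geometric transforms that alter the masks:
#
#   recompute_boxes = RecomputeBoxesFromMasks()
#   datapoint = recompute_boxes(datapoint)
#   # obj.bbox and obj.area are now consistent with the (possibly transformed) masks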


class DecodeRle:
    """Decode RLE-encoded segments into binary mask tensors.

    Implementing this as a transform allows lazy loading: some transforms
    (e.g. query filters) may delete masks, so decoding them up front would be
    wasteful. This transform needs to be called before any kind of geometric
    manipulation.
    """

    def __call__(self, datapoint: Datapoint, **kwargs):
        imgId2size = {}
        warning_shown = False
        for imgId, img in enumerate(datapoint.images):
            if isinstance(img.data, PILImage.Image):
                img_w, img_h = img.data.size
            elif isinstance(img.data, torch.Tensor):
                # Tensor images are (..., H, W)
                img_h, img_w = img.data.shape[-2:]
            else:
                raise RuntimeError(f"Unexpected image type {type(img.data)}")
            imgId2size[imgId] = (img_h, img_w)

            for obj in img.objects:
                if obj.segment is not None and not isinstance(
                    obj.segment, torch.Tensor
                ):
                    if mask_utils.area(obj.segment) == 0:
                        print("Warning, empty mask found, approximating from box")
                        obj.segment = torch.zeros(img_h, img_w, dtype=torch.uint8)
                        x1, y1, x2, y2 = obj.bbox.int().tolist()
                        obj.segment[y1 : max(y2, y1 + 1), x1 : max(x1 + 1, x2)] = 1
                    else:
                        obj.segment = mask_utils.decode(obj.segment)
                        # segment is uint8 and remains uint8 throughout the transforms
                        obj.segment = torch.tensor(obj.segment).to(torch.uint8)

                    if list(obj.segment.shape) != [img_h, img_w]:
                        # Should not happen often, but added as a safeguard
                        if not warning_shown:
                            print(
                                f"Warning: expected instance segmentation size to be "
                                f"{[img_h, img_w]} but found {list(obj.segment.shape)}"
                            )
                            # Print only once per datapoint to avoid spam
                            warning_shown = True
                        obj.segment = F.resize(
                            obj.segment[None], (img_h, img_w)
                        ).squeeze(0)
                    assert list(obj.segment.shape) == [img_h, img_w]

        warning_shown = False
        for query in datapoint.find_queries:
            if query.semantic_target is not None and not isinstance(
                query.semantic_target, torch.Tensor
            ):
                query.semantic_target = mask_utils.decode(query.semantic_target)
                # semantic_target is uint8 and remains uint8 throughout the transforms
                query.semantic_target = torch.tensor(query.semantic_target).to(
                    torch.uint8
                )
                if tuple(query.semantic_target.shape) != imgId2size[query.image_id]:
                    if not warning_shown:
                        print(
                            f"Warning: expected semantic segmentation size to be "
                            f"{imgId2size[query.image_id]} but found "
                            f"{tuple(query.semantic_target.shape)}"
                        )
                        # Print only once per datapoint to avoid spam
                        warning_shown = True
                    query.semantic_target = F.resize(
                        query.semantic_target[None], imgId2size[query.image_id]
                    ).squeeze(0)
                assert tuple(query.semantic_target.shape) == imgId2size[query.image_id]

        return datapoint
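
# End-to-end sketch (illustrative ordering only; the actual training pipeline may
# compose these transforms differently): decode RLEs first, then derive boxes,
# areas, and semantic targets from the decoded masks.
#
#   transforms = [DecodeRle(), RecomputeBoxesFromMasks(), InstanceToSemantic()]
#   for t in transforms:
#       datapoint = t(datapoint)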