segmentation.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import numpy as np
import pycocotools.mask as mask_utils
import torch
import torchvision.transforms.functional as F
from PIL import Image as PILImage

from sam3.model.box_ops import masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint

class InstanceToSemantic(object):
    """Convert instance segmentation to semantic segmentation."""

    def __init__(self, delete_instance=True, use_rle=False):
        self.delete_instance = delete_instance
        self.use_rle = use_rle

    def __call__(self, datapoint: Datapoint, **kwargs):
        for fquery in datapoint.find_queries:
            h, w = datapoint.images[fquery.image_id].size
            if self.use_rle:
                all_segs = [
                    datapoint.images[fquery.image_id].objects[obj_id].segment
                    for obj_id in fquery.object_ids_output
                ]
                if len(all_segs) > 0:
                    # We need to double-check that all RLEs have the correct size;
                    # otherwise cocotools will silently fall back to an empty [0, 0] mask.
                    for seg in all_segs:
                        assert seg["size"] == all_segs[0]["size"], (
                            "Instance segments have inconsistent sizes. "
                            f"Found sizes {seg['size']} and {all_segs[0]['size']}"
                        )
                    fquery.semantic_target = mask_utils.merge(all_segs)
                else:
                    # There is no good way to create an empty RLE of the correct size,
                    # so we resort to converting an empty box to RLE.
                    fquery.semantic_target = mask_utils.frPyObjects(
                        np.array([[0, 0, 0, 0]], dtype=np.float64), h, w
                    )[0]
            else:
                # `semantic_target` is uint8 and remains uint8 throughout the transforms
                # (it contains binary 0 and 1 values, just like `segment` for each object)
                fquery.semantic_target = torch.zeros((h, w), dtype=torch.uint8)
                for obj_id in fquery.object_ids_output:
                    segment = datapoint.images[fquery.image_id].objects[obj_id].segment
                    if segment is not None:
                        assert (
                            isinstance(segment, torch.Tensor)
                            and segment.dtype == torch.uint8
                        )
                        fquery.semantic_target |= segment
        if self.delete_instance:
            for img in datapoint.images:
                for obj in img.objects:
                    del obj.segment
                    obj.segment = None
        return datapoint

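
# The block below is an illustrative sketch, not part of the transform API. With made-up
# toy masks, it shows the two merge paths used by `InstanceToSemantic`: the tensor path
# ORs binary uint8 instance masks into one semantic mask, while the RLE path encodes each
# instance with pycocotools and merges them (the default merge is a union).
def _demo_instance_to_semantic_merge():
    h, w = 4, 6
    inst_a = torch.zeros((h, w), dtype=torch.uint8)
    inst_a[1:3, 1:3] = 1
    inst_b = torch.zeros((h, w), dtype=torch.uint8)
    inst_b[2:4, 3:5] = 1

    # Tensor path: the semantic target is the union of all instance masks.
    semantic = torch.zeros((h, w), dtype=torch.uint8)
    for inst in (inst_a, inst_b):
        semantic |= inst

    # RLE path: pycocotools expects Fortran-ordered uint8 arrays for encoding;
    # decoding the merged RLE recovers the same union mask.
    rles = [mask_utils.encode(np.asfortranarray(m.numpy())) for m in (inst_a, inst_b)]
    merged = mask_utils.merge(rles)
    assert (mask_utils.decode(merged) == semantic.numpy()).all()

    # Empty-target trick used above: converting a degenerate box yields an all-zero
    # RLE that still carries the correct spatial size.
    empty = mask_utils.frPyObjects(np.array([[0, 0, 0, 0]], dtype=np.float64), h, w)[0]
    assert mask_utils.area(empty) == 0 and mask_utils.decode(empty).shape == (h, w)
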
class RecomputeBoxesFromMasks:
    """Recompute bounding boxes from masks."""

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            for obj in img.objects:
                # Note: if the mask is empty, the bounding box will be undefined.
                # The empty targets should be subsequently filtered.
                obj.bbox = masks_to_boxes(obj.segment)
                obj.area = obj.segment.sum().item()
        return datapoint

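
# Illustrative sketch only: `masks_to_boxes` above comes from sam3.model.box_ops, whose
# exact behavior is not shown in this file. The standalone helper below (a hypothetical
# name) just demonstrates the underlying idea of recovering a tight (x1, y1, x2, y2) box
# from a single binary mask; it is not a drop-in replacement.
def _demo_box_from_mask(mask: torch.Tensor) -> torch.Tensor:
    ys, xs = torch.where(mask > 0)
    if ys.numel() == 0:
        # Empty mask: the box is undefined (the caveat noted in the transform above).
        return torch.zeros(4)
    return torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).to(torch.float32)
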
class DecodeRle:
    """Decode RLEs into binary segments.

    Implementing this as a transform allows lazy loading: some transforms (e.g. query
    filters) may delete masks, so decoding them from the beginning would be wasteful.
    This transform needs to be called before any kind of geometric manipulation.
    """

    def __call__(self, datapoint: Datapoint, **kwargs):
        imgId2size = {}
        warning_shown = False
        for imgId, img in enumerate(datapoint.images):
            if isinstance(img.data, PILImage.Image):
                img_w, img_h = img.data.size
            elif isinstance(img.data, torch.Tensor):
                # Tensor images are laid out as (..., H, W)
                img_h, img_w = img.data.shape[-2:]
            else:
                raise RuntimeError(f"Unexpected image type {type(img.data)}")
            imgId2size[imgId] = (img_h, img_w)
            for obj in img.objects:
                if obj.segment is not None and not isinstance(
                    obj.segment, torch.Tensor
                ):
                    if mask_utils.area(obj.segment) == 0:
                        print("Warning: empty mask found, approximating it from the box")
                        obj.segment = torch.zeros(img_h, img_w, dtype=torch.uint8)
                        x1, y1, x2, y2 = obj.bbox.int().tolist()
                        obj.segment[y1 : max(y2, y1 + 1), x1 : max(x1 + 1, x2)] = 1
                    else:
                        obj.segment = mask_utils.decode(obj.segment)
                        # segment is uint8 and remains uint8 throughout the transforms
                        obj.segment = torch.tensor(obj.segment).to(torch.uint8)
                    if list(obj.segment.shape) != [img_h, img_w]:
                        # Should not happen often, but kept as a safety check
                        if not warning_shown:
                            print(
                                f"Warning: expected instance segmentation size to be {[img_h, img_w]} but found {list(obj.segment.shape)}"
                            )
                            # Print only once per datapoint to avoid spam
                            warning_shown = True
                        obj.segment = F.resize(
                            obj.segment[None], (img_h, img_w)
                        ).squeeze(0)
                    assert list(obj.segment.shape) == [img_h, img_w]
        warning_shown = False
        for query in datapoint.find_queries:
            if query.semantic_target is not None and not isinstance(
                query.semantic_target, torch.Tensor
            ):
                query.semantic_target = mask_utils.decode(query.semantic_target)
                # semantic_target is uint8 and remains uint8 throughout the transforms
                query.semantic_target = torch.tensor(query.semantic_target).to(
                    torch.uint8
                )
                if tuple(query.semantic_target.shape) != imgId2size[query.image_id]:
                    if not warning_shown:
                        print(
                            f"Warning: expected semantic segmentation size to be {imgId2size[query.image_id]} but found {tuple(query.semantic_target.shape)}"
                        )
                        # Print only once per datapoint to avoid spam
                        warning_shown = True
                    query.semantic_target = F.resize(
                        query.semantic_target[None], imgId2size[query.image_id]
                    ).squeeze(0)
                assert tuple(query.semantic_target.shape) == imgId2size[query.image_id]
        return datapoint
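
# Illustrative sketch only, with made-up sizes: it mirrors the decode path of `DecodeRle`
# using pycocotools directly. `decode` returns an (H, W) uint8 array; when its shape does
# not match the image, the transform above falls back to resizing.
def _demo_decode_rle():
    img_h, img_w = 8, 10
    mask = np.zeros((img_h, img_w), dtype=np.uint8)
    mask[2:5, 3:7] = 1
    rle = mask_utils.encode(np.asfortranarray(mask))

    decoded = torch.tensor(mask_utils.decode(rle)).to(torch.uint8)
    assert list(decoded.shape) == [img_h, img_w]
    assert decoded.sum().item() == mask_utils.area(rle)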