masks_ops.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Utilities for mask manipulation."""
import numpy as np
import torch
from pycocotools import mask as mask_util


def instance_masks_to_semantic_masks(
    instance_masks: torch.Tensor, num_instances: torch.Tensor
) -> torch.Tensor:
    """Convert instance masks to semantic masks.

    Accepts a collapsed batch of instance masks (i.e. all instance masks
    concatenated in a single tensor) and the number of instances in each image
    of the batch. Returns a mask with the same spatial dimensions as the input
    instance masks, where for each batch element the semantic mask is the union
    of all the instance masks in that batch element. If a batch element has no
    instances (i.e. num_instances[i] == 0), the corresponding semantic mask is
    a tensor of zeros.

    Args:
        instance_masks (torch.Tensor): A tensor of shape (N, H, W) where N is
            the total number of instances in the batch.
        num_instances (torch.Tensor): A tensor of shape (B,) where B is the
            batch size, containing the number of instances in each image.

    Returns:
        torch.Tensor: A tensor of shape (B, H, W) where B is the batch size and
            H, W are the spatial dimensions of the input instance masks.
    """
    masks_per_query = torch.split(instance_masks, num_instances.tolist())
    return torch.stack([torch.any(masks, dim=0) for masks in masks_per_query], dim=0)
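
# Usage sketch (shapes and values below are illustrative assumptions):
#     masks = torch.zeros(3, 8, 8, dtype=torch.bool)   # 3 instance masks in total
#     masks[0, :4] = True; masks[1, 4:] = True; masks[2, :, :4] = True
#     num = torch.tensor([2, 0, 1])                    # image 1 has no instances
#     instance_masks_to_semantic_masks(masks, num)     # (3, 8, 8); mask 1 is all False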


def mask_intersection(masks1, masks2, block_size=16):
    """Compute pairwise intersection areas of two sets of boolean masks.

    The computation is done block_size masks at a time to keep the peak memory
    of the broadcasted product bounded.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    result = torch.zeros(
        masks1.shape[0], masks2.shape[0], device=masks1.device, dtype=torch.long
    )
    for i in range(0, masks1.shape[0], block_size):
        for j in range(0, masks2.shape[0], block_size):
            # Broadcast a (block, 1, H, W) x (1, block, H, W) product, then
            # count the overlapping pixels of every pair in the block.
            intersection = (
                (masks1[i : i + block_size, None] * masks2[None, j : j + block_size])
                .flatten(-2)
                .sum(-1)
            )
            result[i : i + block_size, j : j + block_size] = intersection
    return result
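
# Minimal sketch (illustrative masks): two 2-pixel masks overlapping in 1 pixel.
#     a = torch.zeros(1, 4, 4, dtype=torch.bool); a[0, 0, :2] = True
#     b = torch.zeros(1, 4, 4, dtype=torch.bool); b[0, 0, 1:3] = True
#     mask_intersection(a, b)  # tensor([[1]])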


def mask_iom(masks1, masks2):
    """Intersection-over-minimum: like IoU, except the denominator is the area
    of the smaller of the two masks.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    # Intersections are computed block-wise to bound peak memory.
    intersection = mask_intersection(masks1, masks2)
    area1 = masks1.flatten(-2).sum(-1)
    area2 = masks2.flatten(-2).sum(-1)
    min_area = torch.min(area1[:, None], area2[None, :])
    return intersection / (min_area + 1e-8)
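
# Illustrative sketch: a small mask fully contained in a larger one scores ~1.0
# under IoM even though its IoU would only be 0.5:
#     big = torch.zeros(1, 4, 4, dtype=torch.bool); big[0, 0, :4] = True
#     small = torch.zeros(1, 4, 4, dtype=torch.bool); small[0, 0, :2] = True
#     mask_iom(big, small)  # ~tensor([[1.0]])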


def compute_boundary(seg):
    """Return a 1-pixel-wide boundary of the given mask.

    Adapted from
    https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L148
    """
    assert seg.ndim >= 2
    # Shifted copies of the mask: east, south, and south-east neighbors.
    e = torch.zeros_like(seg)
    s = torch.zeros_like(seg)
    se = torch.zeros_like(seg)
    e[..., :, :-1] = seg[..., :, 1:]
    s[..., :-1, :] = seg[..., 1:, :]
    se[..., :-1, :-1] = seg[..., 1:, 1:]
    # A pixel is on the boundary if it differs from any of these neighbors.
    b = (seg ^ e) | (seg ^ s) | (seg ^ se)
    # The last row/column only have east/south neighbors, respectively.
    b[..., -1, :] = seg[..., -1, :] ^ e[..., -1, :]
    b[..., :, -1] = seg[..., :, -1] ^ s[..., :, -1]
    b[..., -1, -1] = 0
    return b
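
# Illustrative sketch: pixels whose east/south/south-east neighbors all agree
# are dropped, leaving a 1-pixel-wide transition contour around the blob:
#     seg = torch.zeros(8, 8, dtype=torch.bool); seg[2:6, 2:6] = True
#     compute_boundary(seg)  # thin ring; the blob interior is all False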


def dilation(mask, kernel_size):
    """Implements the dilation operation with a square kernel.

    If the input is on CPU, we call the cv2 version. Otherwise, we implement it
    using a pair of separable convolutions.
    """
    assert mask.ndim == 3
    kernel_size = int(kernel_size)
    assert kernel_size % 2 == 1, (
        f"Dilation expects an odd kernel size, got {kernel_size}"
    )
    if mask.is_cuda:
        # Separable dilation: convolve with a (k, 1) column of ones, then with
        # its (1, k) transpose; any strictly positive response means at least
        # one set pixel fell under the k x k kernel.
        m = mask.unsqueeze(1).to(torch.float16)
        k = torch.ones(1, 1, kernel_size, 1, dtype=m.dtype, device=m.device)
        result = torch.nn.functional.conv2d(m, k, padding="same")
        result = torch.nn.functional.conv2d(result, k.transpose(-1, -2), padding="same")
        return result.view_as(mask) > 0
    all_masks = mask.view(-1, mask.size(-2), mask.size(-1)).numpy().astype(np.uint8)
    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    import cv2  # Imported lazily so the GPU path does not require OpenCV.

    processed = [torch.from_numpy(cv2.dilate(m, kernel)) for m in all_masks]
    return torch.stack(processed).view_as(mask).to(mask)
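
# Illustrative sketch (CPU path shown): a single pixel grows to a 3x3 blob.
#     m = torch.zeros(1, 5, 5, dtype=torch.bool); m[0, 2, 2] = True
#     dilation(m, 3)  # True on the 3x3 neighborhood of (2, 2)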


def compute_F_measure(
    gt_boundary_rle, gt_dilated_boundary_rle, dt_boundary_rle, dt_dilated_boundary_rle
):
    """Adapted from
    https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L207

    Assumes the boundaries and dilated boundaries have already been computed
    and converted to RLE.
    """
    # A boundary pixel counts as matched if it falls inside the other mask's
    # dilated boundary (merge with intersect=True computes that intersection).
    gt_match = mask_util.merge([gt_boundary_rle, dt_dilated_boundary_rle], True)
    dt_match = mask_util.merge([dt_boundary_rle, gt_dilated_boundary_rle], True)
    n_dt = mask_util.area(dt_boundary_rle)
    n_gt = mask_util.area(gt_boundary_rle)
    # Compute precision and recall
    if n_dt == 0 and n_gt > 0:
        precision = 1
        recall = 0
    elif n_dt > 0 and n_gt == 0:
        precision = 0
        recall = 1
    elif n_dt == 0 and n_gt == 0:
        precision = 1
        recall = 1
    else:
        precision = mask_util.area(dt_match) / float(n_dt)
        recall = mask_util.area(gt_match) / float(n_gt)
    # Compute the F measure as the harmonic mean of precision and recall
    if precision + recall == 0:
        f_val = 0
    else:
        f_val = 2 * precision * recall / (precision + recall)
    return f_val
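
# Typical pipeline, as a sketch under assumed (H, W) bool inputs gt_mask/dt_mask:
# boundaries are extracted, thickened, RLE-encoded, then scored.
#     gt_b, dt_b = compute_boundary(gt_mask), compute_boundary(dt_mask)
#     gt_d = dilation(gt_b[None], 3)[0]  # dilation expects (N, H, W)
#     dt_d = dilation(dt_b[None], 3)[0]
#     f = compute_F_measure(*[robust_rle_encode(x[None])[0]
#                             for x in (gt_b, gt_d, dt_b, dt_d)])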


@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
    """Encode a collection of masks in RLE format.

    This function emulates the behavior of the COCO API's encode function, but
    is executed partially on the GPU for faster execution.

    Args:
        orig_mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
        return_areas (bool): If True, add the areas of the masks as a part of
            the RLE output dict under the "area" key. Default is False.

    Returns:
        list[dict]: One COCO-style RLE dict per mask.
    """
    assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
    assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"
    if orig_mask.numel() == 0:
        return []
    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order.
    mask = orig_mask.transpose(1, 2)
    # Flatten the mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where each mask changes value. The extra first and last
    # columns mark the start of the first run and the end of the last one.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    _, change_indices = torch.where(differences)
    try:
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError:
        # Out of GPU memory: retry the reduction on CPU.
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)
    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on GPU, in a flattened format
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        # Convert absolute change positions into run lengths by differencing.
        change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]
    # Now we can split the RLEs of each batch item and convert them to strings.
    # No more GPU at this point.
    change_indices = change_indices.tolist()
    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)
    return batch_rles
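
# Round-trip sketch (illustrative): the output is decodable by pycocotools.
#     m = torch.zeros(1, 4, 4, dtype=torch.bool); m[0, 1:3, 1:3] = True
#     rle = rle_encode(m, return_areas=True)[0]  # rle["area"] == 4
#     bool((torch.from_numpy(mask_util.decode(rle)).bool() == m[0]).all())  # True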


def robust_rle_encode(masks):
    """Encode a collection of masks in RLE format.

    Tries the GPU version first and falls back to the CPU version if it fails.
    """
    assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
    assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"
    try:
        return rle_encode(masks)
    except RuntimeError:
        # Fall back to the reference COCO implementation on CPU.
        masks = masks.cpu().numpy()
        rles = [
            mask_util.encode(
                np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F")
            )[0]
            for mask in masks
        ]
        for rle in rles:
            rle["counts"] = rle["counts"].decode("utf-8")
        return rles
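
# Usage sketch (random masks are an illustrative assumption): the caller does
# not need to know which of the two paths produced the result.
#     masks = torch.rand(4, 32, 32) > 0.5
#     rles = robust_rle_encode(masks)
#     [r["size"] for r in rles]  # [[32, 32]] * 4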


def ann_to_rle(segm, im_info):
    """Convert a segmentation, which can be polygons or uncompressed RLE, to RLE.

    Args:
        segm (list | dict): The segmentation, either a list of polygons or an
            (uncompressed or compressed) RLE dict.
        im_info (dict): Image info with "height" and "width" keys.

    Returns:
        dict: The segmentation as a compressed RLE dict.
    """
    h, w = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # polygon -- a single object might consist of multiple parts
        # we merge all parts into one mask rle code
        rles = mask_util.frPyObjects(segm, h, w)
        rle = mask_util.merge(rles)
    elif isinstance(segm["counts"], list):
        # uncompressed RLE
        rle = mask_util.frPyObjects(segm, h, w)
    else:
        # already compressed RLE
        rle = segm
    return rle
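
# Usage sketch (the polygon below is an illustrative assumption):
#     poly = [[10.0, 10.0, 40.0, 10.0, 40.0, 40.0, 10.0, 40.0]]  # a 30x30 square
#     rle = ann_to_rle(poly, {"height": 64, "width": 64})
#     mask_util.area(rle)  # ~900 pixels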