rle.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. """Some utilities for RLE encoding that doesn't require downloading the masks to the cpu"""
  4. import numpy as np
  5. import torch
  6. from pycocotools import mask as mask_util
  7. @torch.no_grad()
  8. def rle_encode(orig_mask, return_areas=False):
  9. """Encodes a collection of masks in RLE format
  10. This function emulates the behavior of the COCO API's encode function, but
  11. is executed partially on the GPU for faster execution.
  12. Args:
  13. mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
  14. return_areas (bool): If True, add the areas of the masks as a part of
  15. the RLE output dict under the "area" key. Default is False.
  16. Returns:
  17. str: The RLE encoded masks
  18. """
  19. assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
  20. assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"
  21. if orig_mask.numel() == 0:
  22. return []
  23. # First, transpose the spatial dimensions.
  24. # This is necessary because the COCO API uses Fortran order
  25. mask = orig_mask.transpose(1, 2)
  26. # Flatten the mask
  27. flat_mask = mask.reshape(mask.shape[0], -1)
  28. if return_areas:
  29. mask_areas = flat_mask.sum(-1).tolist()
  30. # Find the indices where the mask changes
  31. differences = torch.ones(
  32. mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
  33. )
  34. differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
  35. differences[:, 0] = flat_mask[:, 0]
  36. _, change_indices = torch.where(differences)
  37. try:
  38. boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
  39. except RuntimeError as _:
  40. boundaries = torch.cumsum(differences.cpu().sum(-1), 0)
  41. change_indices_clone = change_indices.clone()
  42. # First pass computes the RLEs on GPU, in a flatten format
  43. for i in range(mask.shape[0]):
  44. # Get the change indices for this batch item
  45. beg = 0 if i == 0 else boundaries[i - 1].item()
  46. end = boundaries[i].item()
  47. change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]
  48. # Now we can split the RLES of each batch item, and convert them to strings
  49. # No more gpu at this point
  50. change_indices = change_indices.tolist()
  51. batch_rles = []
  52. # Process each mask in the batch separately
  53. for i in range(mask.shape[0]):
  54. beg = 0 if i == 0 else boundaries[i - 1].item()
  55. end = boundaries[i].item()
  56. run_lengths = change_indices[beg:end]
  57. uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
  58. h, w = uncompressed_rle["size"]
  59. rle = mask_util.frPyObjects(uncompressed_rle, h, w)
  60. rle["counts"] = rle["counts"].decode("utf-8")
  61. if return_areas:
  62. rle["area"] = mask_areas[i]
  63. batch_rles.append(rle)
  64. return batch_rles
  65. def robust_rle_encode(masks):
  66. """Encodes a collection of masks in RLE format. Uses the gpu version fist, falls back to the cpu version if it fails"""
  67. assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
  68. assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"
  69. try:
  70. return rle_encode(masks)
  71. except RuntimeError as _:
  72. masks = masks.cpu().numpy()
  73. rles = [
  74. mask_util.encode(
  75. np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F")
  76. )[0]
  77. for mask in masks
  78. ]
  79. for rle in rles:
  80. rle["counts"] = rle["counts"].decode("utf-8")
  81. return rles
  82. def ann_to_rle(segm, im_info):
  83. """Convert annotation which can be polygons, uncompressed RLE to RLE.
  84. Args:
  85. ann (dict) : annotation object
  86. Returns:
  87. ann (rle)
  88. """
  89. h, w = im_info["height"], im_info["width"]
  90. if isinstance(segm, list):
  91. # polygon -- a single object might consist of multiple parts
  92. # we merge all parts into one mask rle code
  93. rles = mask_util.frPyObjects(segm, h, w)
  94. rle = mask_util.merge(rles)
  95. elif isinstance(segm["counts"], list):
  96. # uncompressed RLE
  97. rle = mask_util.frPyObjects(segm, h, w)
  98. else:
  99. # rle
  100. rle = segm
  101. return rle