masks_ops.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import torch
  4. def masks_to_boxes(masks: torch.Tensor, obj_ids: list[int]):
  5. with torch.autograd.profiler.record_function("perflib: masks_to_boxes"):
  6. # Sanity check based on callsite for replacement
  7. assert masks.shape[0] == len(obj_ids)
  8. assert masks.dim() == 3
  9. # Based on torchvision masks_to_boxes
  10. if masks.numel() == 0:
  11. return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
  12. N, H, W = masks.shape
  13. device = masks.device
  14. y = torch.arange(H, device=device).view(1, H)
  15. x = torch.arange(W, device=device).view(1, W)
  16. masks_with_obj = masks != 0 # N, H, W
  17. masks_with_obj_x = masks_with_obj.amax(
  18. dim=1
  19. ) # N, H (which columns have objects)
  20. masks_with_obj_y = masks_with_obj.amax(dim=2) # N, W (which rows have objects)
  21. masks_without_obj_x = ~masks_with_obj_x
  22. masks_without_obj_y = ~masks_with_obj_y
  23. bounding_boxes_0 = torch.amin(
  24. (masks_without_obj_x * W) + (masks_with_obj_x * x), dim=1
  25. )
  26. bounding_boxes_1 = torch.amin(
  27. (masks_without_obj_y * H) + (masks_with_obj_y * y), dim=1
  28. )
  29. bounding_boxes_2 = torch.amax(masks_with_obj_x * x, dim=1)
  30. bounding_boxes_3 = torch.amax(masks_with_obj_y * y, dim=1)
  31. bounding_boxes = torch.stack(
  32. [bounding_boxes_0, bounding_boxes_1, bounding_boxes_2, bounding_boxes_3],
  33. dim=1,
  34. ).to(dtype=torch.float)
  35. assert bounding_boxes.shape == (N, 4)
  36. assert bounding_boxes.device == masks.device
  37. assert bounding_boxes.dtype == torch.float
  38. return bounding_boxes
  39. def mask_iou(pred_masks: torch.Tensor, gt_masks: torch.Tensor) -> torch.Tensor:
  40. """
  41. Compute the IoU (Intersection over Union) between predicted masks and ground truth masks.
  42. Args:
  43. - pred_masks: (N, H, W) bool Tensor, containing binary predicted segmentation masks
  44. - gt_masks: (M, H, W) bool Tensor, containing binary ground truth segmentation masks
  45. Returns:
  46. - ious: (N, M) float Tensor, containing IoUs for each pair of predicted and ground truth masks
  47. """
  48. assert pred_masks.dtype == gt_masks.dtype == torch.bool
  49. N, H, W = pred_masks.shape
  50. M, _, _ = gt_masks.shape
  51. # Flatten masks: (N, 1, H*W) and (1, M, H*W)
  52. pred_flat = pred_masks.view(N, 1, H * W)
  53. gt_flat = gt_masks.view(1, M, H * W)
  54. # Compute intersection and union: (N, M)
  55. intersection = (pred_flat & gt_flat).sum(dim=2).float()
  56. union = (pred_flat | gt_flat).sum(dim=2).float()
  57. ious = intersection / union.clamp(min=1)
  58. return ious # shape: (N, M)