# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
from typing import Callable

import torch
from torch.nn import functional as F
  6. # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
  7. def point_sample(input, point_coords, **kwargs):
  8. """
  9. A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
  10. Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
  11. [0, 1] x [0, 1] square.
  12. Args:
  13. input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
  14. point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
  15. [0, 1] x [0, 1] normalized point coordinates.
  16. Returns:
  17. output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
  18. features for points in `point_coords`. The features are obtained via bilinear
  19. interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
  20. """
  21. add_dim = False
  22. if point_coords.dim() == 3:
  23. add_dim = True
  24. point_coords = point_coords.unsqueeze(2)
  25. normalized_point_coords = 2.0 * point_coords - 1.0 # Normalize to [-1,1]
  26. output = F.grid_sample(input, normalized_point_coords, **kwargs)
  27. if add_dim:
  28. output = output.squeeze(3)
  29. return output
  30. # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
  31. def get_uncertain_point_coords_with_randomness(
  32. logits: torch.Tensor,
  33. uncertainty_func: Callable,
  34. num_points: int,
  35. oversample_ratio: int,
  36. importance_sample_ratio: float,
  37. ) -> torch.Tensor:
  38. """
  39. Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties
  40. are calculated for each point using 'uncertainty_func' function that takes point's logit
  41. prediction as input.
  42. See PointRend paper for details.
  43. Args:
  44. logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
  45. class-specific or class-agnostic prediction.
  46. uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
  47. contains logit predictions for P points and returns their uncertainties as a Tensor of
  48. shape (N, 1, P).
  49. num_points (int): The number of points P to sample.
  50. oversample_ratio (int): Oversampling parameter.
  51. importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling.
  52. Returns:
  53. point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
  54. sampled points.
  55. """
  56. assert oversample_ratio >= 1
  57. assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
  58. num_boxes = logits.shape[0]
  59. num_sampled = int(num_points * oversample_ratio)
  60. point_coords = torch.rand(num_boxes, num_sampled, 2, device=logits.device)
  61. point_logits = point_sample(logits, point_coords, align_corners=False)
  62. # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
  63. # Calculating uncertainties of the predictions first and sampling them for points leads
  64. # to incorrect results.
  65. # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
  66. # two predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
  67. # However, if we calculate uncertainties for the predictions first,
  68. # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
  69. point_uncertainties = uncertainty_func(point_logits)
  70. num_uncertain_points = int(importance_sample_ratio * num_points)
  71. num_random_points = num_points - num_uncertain_points
  72. idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
  73. # Flatten the indices
  74. shift = num_sampled * torch.arange(
  75. num_boxes, dtype=torch.long, device=logits.device
  76. )
  77. idx += shift[:, None]
  78. point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
  79. num_boxes, num_uncertain_points, 2
  80. )
  81. if num_random_points > 0:
  82. point_coords = torch.cat(
  83. [
  84. point_coords,
  85. torch.rand(num_boxes, num_random_points, 2, device=logits.device),
  86. ],
  87. dim=1,
  88. )
  89. return point_coords
  90. # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
  91. def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor:
  92. """
  93. Estimates uncerainty as L1 distance between 0.0 and the logit prediction.
  94. Args:
  95. logits (Tensor): A tensor of shape (R, 1, ...) for class-agnostic
  96. predicted masks
  97. Returns:
  98. scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
  99. the most uncertain locations having the highest uncertainty score.
  100. """
  101. assert logits.shape[1] == 1
  102. return -(torch.abs(logits))