roi_align.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
  3. from torch import nn
  4. from torchvision.ops import roi_align
# NOTE: torchvision's RoIAlign defaults to aligned=False, whereas ROIAlign below defaults to aligned=True.
  6. class ROIAlign(nn.Module):
  7. def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
  8. """
  9. Args:
  10. output_size (tuple): h, w
  11. spatial_scale (float): scale the input boxes by this number
  12. sampling_ratio (int): number of inputs samples to take for each output
  13. sample. 0 to take samples densely.
  14. aligned (bool): if False, use the legacy implementation in
  15. Detectron. If True, align the results more perfectly.
  16. Note:
  17. The meaning of aligned=True:
  18. Given a continuous coordinate c, its two neighboring pixel indices (in our
  19. pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example,
  20. c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled
  21. from the underlying signal at continuous coordinates 0.5 and 1.5). But the original
  22. roi_align (aligned=False) does not subtract the 0.5 when computing neighboring
  23. pixel indices and therefore it uses pixels with a slightly incorrect alignment
  24. (relative to our pixel model) when performing bilinear interpolation.
  25. With `aligned=True`,
  26. we first appropriately scale the ROI and then shift it by -0.5
  27. prior to calling roi_align. This produces the correct neighbors; see
  28. detectron2/tests/test_roi_align.py for verification.
  29. The difference does not make a difference to the model's performance if
  30. ROIAlign is used together with conv layers.
  31. """
  32. super().__init__()
  33. self.output_size = output_size
  34. self.spatial_scale = spatial_scale
  35. self.sampling_ratio = sampling_ratio
  36. self.aligned = aligned
  37. from torchvision import __version__
  38. version = tuple(int(x) for x in __version__.split(".")[:2])
  39. # https://github.com/pytorch/vision/pull/2438
  40. assert version >= (0, 7), "Require torchvision >= 0.7"
  41. def forward(self, input, rois):
  42. """
  43. Args:
  44. input: NCHW images
  45. rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
  46. """
  47. assert rois.dim() == 2 and rois.size(1) == 5
  48. if input.is_quantized:
  49. input = input.dequantize()
  50. return roi_align(
  51. input,
  52. rois.to(dtype=input.dtype),
  53. self.output_size,
  54. self.spatial_scale,
  55. self.sampling_ratio,
  56. self.aligned,
  57. )
  58. def __repr__(self):
  59. tmpstr = self.__class__.__name__ + "("
  60. tmpstr += "output_size=" + str(self.output_size)
  61. tmpstr += ", spatial_scale=" + str(self.spatial_scale)
  62. tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
  63. tmpstr += ", aligned=" + str(self.aligned)
  64. tmpstr += ")"
  65. return tmpstr