| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- from torch import nn
- from torchvision.ops import roi_align
- # NOTE: torchvision's RoIAlign has a different default aligned=False
class ROIAlign(nn.Module):
    """Module wrapper around ``torchvision.ops.roi_align``.

    Stores the pooling configuration at construction time and defaults to
    ``aligned=True``, whereas torchvision's own RoIAlign defaults to
    ``aligned=False``.
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
        """
        Args:
            output_size (tuple): h, w
            spatial_scale (float): factor by which the input boxes are scaled
            sampling_ratio (int): number of input samples taken for each
                output sample. 0 means sample densely.
            aligned (bool): False selects the legacy Detectron behavior;
                True aligns the results more precisely.

        Note:
            What ``aligned=True`` means:

            In our pixel model, a continuous coordinate c has the two
            neighboring pixel indices floor(c - 0.5) and ceil(c - 0.5).
            For instance, c=1.3 has neighbors with discrete indices [0] and
            [1], which are sampled from the underlying signal at continuous
            coordinates 0.5 and 1.5. The original roi_align (aligned=False)
            does not subtract the 0.5 when computing the neighboring
            indices, so its bilinear interpolation uses pixels that are
            slightly misaligned relative to our pixel model.

            With ``aligned=True`` the ROI is first scaled appropriately and
            then shifted by -0.5 before roi_align is called, which yields
            the correct neighbors; see detectron2/tests/test_roi_align.py
            for verification.

            The difference does not affect model performance when ROIAlign
            is used together with conv layers.
        """
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

        # Imported here so the version is checked each time the module is
        # constructed.
        from torchvision import __version__ as tv_version

        major_minor = tuple(map(int, tv_version.split(".")[:2]))
        # https://github.com/pytorch/vision/pull/2438
        assert major_minor >= (0, 7), "Require torchvision >= 0.7"

    def forward(self, input, rois):
        """
        Args:
            input: NCHW images
            rois: Bx5 boxes. Column 0 is the index into N; the remaining
                4 columns are xyxy.

        Returns:
            Whatever ``roi_align`` returns for the given input, boxes and
            the configuration stored on this module.
        """
        assert rois.dim() == 2 and rois.size(1) == 5

        # Quantized inputs are converted back to a regular tensor first.
        features = input.dequantize() if input.is_quantized else input
        # Boxes are cast to the dtype of the (possibly dequantized) features.
        boxes = rois.to(dtype=features.dtype)

        return roi_align(
            features,
            boxes,
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )

    def __repr__(self):
        """Return ``ClassName(field=value, ...)`` for the stored settings."""
        field_names = ("output_size", "spatial_scale", "sampling_ratio", "aligned")
        rendered = ", ".join(
            f"{field}={getattr(self, field)!s}" for field in field_names
        )
        return f"{self.__class__.__name__}({rendered})"
|