keypoints.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. from typing import Any, List, Tuple, Union
  4. import numpy as np
  5. import torch
  6. from torch.nn import functional as F
  7. class Keypoints:
  8. """
  9. Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
  10. containing the x,y location and visibility flag of each keypoint. This tensor has shape
  11. (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.
  12. The visibility flag follows the COCO format and must be one of three integers:
  13. * v=0: not labeled (in which case x=y=0)
  14. * v=1: labeled but not visible
  15. * v=2: labeled and visible
  16. """
  17. def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
  18. """
  19. Arguments:
  20. keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
  21. The shape should be (N, K, 3) where N is the number of
  22. instances, and K is the number of keypoints per instance.
  23. """
  24. device = (
  25. keypoints.device
  26. if isinstance(keypoints, torch.Tensor)
  27. else torch.device("cpu")
  28. )
  29. keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
  30. assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
  31. self.tensor = keypoints
  32. def __len__(self) -> int:
  33. return self.tensor.size(0)
  34. def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
  35. return type(self)(self.tensor.to(*args, **kwargs))
  36. @property
  37. def device(self) -> torch.device:
  38. return self.tensor.device
  39. def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
  40. """
  41. Convert keypoint annotations to a heatmap of one-hot labels for training,
  42. as described in :paper:`Mask R-CNN`.
  43. Arguments:
  44. boxes: Nx4 tensor, the boxes to draw the keypoints to
  45. Returns:
  46. heatmaps:
  47. A tensor of shape (N, K), each element is integer spatial label
  48. in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
  49. valid:
  50. A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
  51. """
  52. return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)
  53. def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
  54. """
  55. Create a new `Keypoints` by indexing on this `Keypoints`.
  56. The following usage are allowed:
  57. 1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
  58. 2. `new_kpts = kpts[2:10]`: return a slice of key points.
  59. 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
  60. with `length = len(kpts)`. Nonzero elements in the vector will be selected.
  61. Note that the returned Keypoints might share storage with this Keypoints,
  62. subject to Pytorch's indexing semantics.
  63. """
  64. if isinstance(item, int):
  65. return Keypoints([self.tensor[item]])
  66. return Keypoints(self.tensor[item])
  67. def __repr__(self) -> str:
  68. s = self.__class__.__name__ + "("
  69. s += "num_instances={})".format(len(self.tensor))
  70. return s
  71. @staticmethod
  72. def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
  73. """
  74. Concatenates a list of Keypoints into a single Keypoints
  75. Arguments:
  76. keypoints_list (list[Keypoints])
  77. Returns:
  78. Keypoints: the concatenated Keypoints
  79. """
  80. assert isinstance(keypoints_list, (list, tuple))
  81. assert len(keypoints_list) > 0
  82. assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)
  83. cat_kpts = type(keypoints_list[0])(
  84. torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
  85. )
  86. return cat_kpts
  87. def _keypoints_to_heatmap(
  88. keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
  89. ) -> Tuple[torch.Tensor, torch.Tensor]:
  90. """
  91. Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.
  92. Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
  93. closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
  94. continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
  95. d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
  96. Arguments:
  97. keypoints: tensor of keypoint locations in of shape (N, K, 3).
  98. rois: Nx4 tensor of rois in xyxy format
  99. heatmap_size: integer side length of square heatmap.
  100. Returns:
  101. heatmaps: A tensor of shape (N, K) containing an integer spatial label
  102. in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
  103. valid: A tensor of shape (N, K) containing whether each keypoint is in
  104. the roi or not.
  105. """
  106. if rois.numel() == 0:
  107. return rois.new().long(), rois.new().long()
  108. offset_x = rois[:, 0]
  109. offset_y = rois[:, 1]
  110. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  111. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  112. offset_x = offset_x[:, None]
  113. offset_y = offset_y[:, None]
  114. scale_x = scale_x[:, None]
  115. scale_y = scale_y[:, None]
  116. x = keypoints[..., 0]
  117. y = keypoints[..., 1]
  118. x_boundary_inds = x == rois[:, 2][:, None]
  119. y_boundary_inds = y == rois[:, 3][:, None]
  120. x = (x - offset_x) * scale_x
  121. x = x.floor().long()
  122. y = (y - offset_y) * scale_y
  123. y = y.floor().long()
  124. x[x_boundary_inds] = heatmap_size - 1
  125. y[y_boundary_inds] = heatmap_size - 1
  126. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  127. vis = keypoints[..., 2] > 0
  128. valid = (valid_loc & vis).long()
  129. lin_ind = y * heatmap_size + x
  130. heatmaps = lin_ind * valid
  131. return heatmaps, valid
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    # Clamp to >= 1 so the interpolation output size computed below is never zero.
    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()
    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)
    # Ratio of the exact box size to the integer size used for interpolation;
    # used to map heatmap indices back onto continuous box coordinates.
    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil
    keypoints_idx = torch.arange(num_keypoints, device=maps.device)
    # Each ROI is upsampled to its own (ceil(h), ceil(w)), so ROIs are processed one at a time.
    for i in range(num_rois):
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        # Upsample this ROI's pooled heatmap to (approximately) full box resolution.
        roi_map = F.interpolate(
            maps[[i]], size=outsize, mode="bicubic", align_corners=False
        )
        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W
        # softmax over the spatial region
        # Subtract the per-keypoint max before exponentiating for numerical stability.
        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        tmp_full_resolution = (roi_map - max_score).exp_()
        tmp_pool_resolution = (maps[i] - max_score).exp_()
        # Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
        # so that the scores of objects of different absolute sizes will be more comparable
        roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum(
            (1, 2), keepdim=True
        )
        w = roi_map.shape[2]
        # Argmax over the flattened map; recover (x, y) from the linear index.
        pos = roi_map.view(num_keypoints, -1).argmax(1)
        x_int = pos % w
        y_int = (pos - x_int) // w
        # Sanity check: the argmax of the logits and of the normalized scores coincide
        # (normalization is a positive per-keypoint scaling, so it preserves the argmax).
        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()
        # +0.5: discrete-to-continuous conversion (Heckbert 1990), then rescale from the
        # integer interpolation grid back to exact box size.
        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]
        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
        xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]
    return xy_preds