# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import math
from enum import IntEnum, unique
from typing import List, Tuple, Union

import numpy as np
import torch
from torch import device

_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]


@unique
class BoxMode(IntEnum):
    """
    Enum of different ways to represent a box.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating-point coordinates.
    The coordinates are in the range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating-point coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating-point coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(
        box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode"
    ) -> _RawBoxType:
        """
        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)

        Returns:
            The converted box of the same type.
        """
        if from_mode == to_mode:
            return box

        original_type = type(box)
        is_numpy = isinstance(box, np.ndarray)
        single_box = isinstance(box, (list, tuple))
        if single_box:
            assert len(box) == 4 or len(box) == 5, (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            arr = torch.tensor(box)[None, :]
        else:
            # avoid modifying the input box
            if is_numpy:
                arr = torch.from_numpy(np.asarray(box)).clone()
            else:
                arr = box.clone()

        assert to_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ] and from_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"

        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert arr.shape[-1] == 5, (
                "The last dimension of input shape must be 5 for XYWHA format"
            )
            original_dtype = arr.dtype
            arr = arr.double()

            w = arr[:, 2]
            h = arr[:, 3]
            a = arr[:, 4]
            c = torch.abs(torch.cos(a * math.pi / 180.0))
            s = torch.abs(torch.sin(a * math.pi / 180.0))
            # This basically computes the horizontal bounding rectangle of the rotated box
            new_w = c * w + s * h
            new_h = c * h + s * w

            # convert center to top-left corner
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            # bottom-right corner
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h

            arr = arr[:, :4].to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            original_dtype = arr.dtype
            arr = arr.double()
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
            arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
        else:
            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
                arr[:, 2] += arr[:, 0]
                arr[:, 3] += arr[:, 1]
            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
                arr[:, 2] -= arr[:, 0]
                arr[:, 3] -= arr[:, 1]
            else:
                raise NotImplementedError(
                    "Conversion from BoxMode {} to {} is not supported yet".format(
                        from_mode, to_mode
                    )
                )

        if single_box:
            return original_type(arr.flatten().tolist())
        if is_numpy:
            return arr.numpy()
        else:
            return arr
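

# Usage sketch (illustrative example, not part of the original module): BoxMode.convert
# preserves the input container type, so a plain list comes back as a list of floats.
#
#   >>> BoxMode.convert([10.0, 10.0, 20.0, 30.0], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
#   [10.0, 10.0, 30.0, 40.0]
#   >>> BoxMode.convert([10.0, 10.0, 30.0, 40.0], BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
#   [10.0, 10.0, 20.0, 30.0]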


class Boxes:
    """
    This structure stores a list of boxes as a Nx4 torch.Tensor.
    It supports some common methods about boxes
    (`area`, `clip`, `nonempty`, etc),
    and also behaves like a Tensor
    (support indexing, `to(device)`, `.device`, and iteration over all boxes)

    Attributes:
        tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
        """
        if not isinstance(tensor, torch.Tensor):
            tensor = torch.as_tensor(
                tensor, dtype=torch.float32, device=torch.device("cpu")
            )
        else:
            tensor = tensor.to(torch.float32)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()

        self.tensor = tensor

    def clone(self) -> "Boxes":
        """
        Clone the Boxes.

        Returns:
            Boxes
        """
        return Boxes(self.tensor.clone())

    def to(self, device: torch.device):
        # Boxes are assumed float32 and do not support to(dtype)
        return Boxes(self.tensor.to(device=device))

    def area(self) -> torch.Tensor:
        """
        Computes the area of all the boxes.

        Returns:
            torch.Tensor: a vector with areas of each box.
        """
        box = self.tensor
        area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
        return area

    def clip(self, box_size: Tuple[int, int]) -> None:
        """
        Clip (in place) the boxes by limiting x coordinates to the range [0, width]
        and y coordinates to the range [0, height].

        Args:
            box_size (height, width): The clipping box's size.
        """
        assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
        h, w = box_size
        x1 = self.tensor[:, 0].clamp(min=0, max=w)
        y1 = self.tensor[:, 1].clamp(min=0, max=h)
        x2 = self.tensor[:, 2].clamp(min=0, max=w)
        y2 = self.tensor[:, 3].clamp(min=0, max=h)
        self.tensor = torch.stack((x1, y1, x2, y2), dim=-1)

    def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
        """
        Find boxes that are non-empty.
        A box is considered empty if either of its sides is no larger than threshold.

        Returns:
            Tensor:
                a binary vector which represents whether each box is empty
                (False) or non-empty (True).
        """
        box = self.tensor
        widths = box[:, 2] - box[:, 0]
        heights = box[:, 3] - box[:, 1]
        keep = (widths > threshold) & (heights > threshold)
        return keep

    def __getitem__(self, item) -> "Boxes":
        """
        Args:
            item: int, slice, or a BoolTensor

        Returns:
            Boxes: Create a new :class:`Boxes` by indexing.

        The following usages are allowed:

        1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box.
        2. `new_boxes = boxes[2:10]`: return a slice of boxes.
        3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
           with `length = len(boxes)`. Nonzero elements in the vector will be selected.

        Note that the returned Boxes might share storage with this Boxes,
        subject to PyTorch's indexing semantics.
        """
        if isinstance(item, int):
            return Boxes(self.tensor[item].view(1, -1))
        b = self.tensor[item]
        assert b.dim() == 2, (
            "Indexing on Boxes with {} failed to return a matrix!".format(item)
        )
        return Boxes(b)

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def __repr__(self) -> str:
        return "Boxes(" + str(self.tensor) + ")"

    def inside_box(
        self, box_size: Tuple[int, int], boundary_threshold: int = 0
    ) -> torch.Tensor:
        """
        Args:
            box_size (height, width): Size of the reference box.
            boundary_threshold (int): Boxes that extend beyond the reference box
                boundary by more than boundary_threshold are considered "outside".

        Returns:
            a binary vector, indicating whether each box is inside the reference box.
        """
        height, width = box_size
        inds_inside = (
            (self.tensor[..., 0] >= -boundary_threshold)
            & (self.tensor[..., 1] >= -boundary_threshold)
            & (self.tensor[..., 2] < width + boundary_threshold)
            & (self.tensor[..., 3] < height + boundary_threshold)
        )
        return inds_inside

    def get_centers(self) -> torch.Tensor:
        """
        Returns:
            The box centers in a Nx2 array of (x, y).
        """
        return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2

    def scale(self, scale_x: float, scale_y: float) -> None:
        """
        Scale the box with horizontal and vertical scaling factors.
        """
        self.tensor[:, 0::2] *= scale_x
        self.tensor[:, 1::2] *= scale_y

    @classmethod
    def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
        """
        Concatenates a list of Boxes into a single Boxes.

        Arguments:
            boxes_list (list[Boxes])

        Returns:
            Boxes: the concatenated Boxes
        """
        assert isinstance(boxes_list, (list, tuple))
        if len(boxes_list) == 0:
            return cls(torch.empty(0))
        assert all([isinstance(box, Boxes) for box in boxes_list])

        # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
        cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
        return cat_boxes

    @property
    def device(self) -> device:
        return self.tensor.device

    # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
    # https://github.com/pytorch/pytorch/issues/18627
    @torch.jit.unused
    def __iter__(self):
        """
        Yield a box as a Tensor of shape (4,) at a time.
        """
        yield from self.tensor
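

# Usage sketch (illustrative example, not part of the original module): constructing a
# Boxes object and using a few of its methods. Values are cast to float32 on construction.
#
#   >>> b = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 5.0, 25.0]]))
#   >>> b.area().tolist()
#   [100.0, 0.0]
#   >>> b.nonempty().tolist()          # the second box has zero width
#   [True, False]
#   >>> b.clip((8, 8))                 # box_size is (height, width); clamps in place
#   >>> b.tensor[0].tolist()
#   [0.0, 0.0, 8.0, 8.0]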


def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Given two lists of boxes of size N and M,
    compute the intersection area between __all__ N x M pairs of boxes.
    The box order must be (xmin, ymin, xmax, ymax).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: intersection, sized [N,M].
    """
    boxes1, boxes2 = boxes1.tensor, boxes2.tensor
    width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max(
        boxes1[:, None, :2], boxes2[:, :2]
    )  # [N,M,2]

    width_height.clamp_(min=0)  # [N,M,2]
    intersection = width_height.prod(dim=2)  # [N,M]
    return intersection
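

# Sketch (illustrative example, not part of the original module): two axis-aligned
# boxes that overlap in a 5x5 region, so the intersection area is 25.
#
#   >>> a = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0]]))
#   >>> b = Boxes(torch.tensor([[5.0, 5.0, 15.0, 15.0]]))
#   >>> pairwise_intersection(a, b).tolist()
#   [[25.0]]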


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Given two lists of boxes of size N and M, compute the IoU
    (intersection over union) between **all** N x M pairs of boxes.
    The box order must be (xmin, ymin, xmax, ymax).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)

    # handle empty boxes
    iou = torch.where(
        inter > 0,
        inter / (area1[:, None] + area2 - inter),
        torch.zeros(1, dtype=inter.dtype, device=inter.device),
    )
    return iou
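

# Sketch (illustrative example, not part of the original module): a 10x10 box against a
# box covering half of it and a disjoint box. IoU is 50 / (100 + 50 - 50) = 0.5 for the
# first pair and 0 for the disjoint pair.
#
#   >>> a = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0]]))
#   >>> b = Boxes(torch.tensor([[0.0, 0.0, 5.0, 10.0], [20.0, 20.0, 30.0, 30.0]]))
#   >>> pairwise_iou(a, b).tolist()
#   [[0.5, 0.0]]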


def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Similar to :func:`pairwise_iou` but compute the IoA (intersection over boxes2 area).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: IoA, sized [N,M].
    """
    area2 = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)

    # handle empty boxes
    ioa = torch.where(
        inter > 0, inter / area2, torch.zeros(1, dtype=inter.dtype, device=inter.device)
    )
    return ioa
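

# Sketch (illustrative example, not part of the original module): with the same boxes
# `a` and `b` as in the pairwise_iou sketch above, IoA divides by the area of each box
# in boxes2 rather than the union, so the fully contained box scores 1.0.
#
#   >>> pairwise_ioa(a, b).tolist()
#   [[1.0, 0.0]]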


def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes):
    """
    Pairwise distance between N points and M boxes. The distance between a
    point and a box is represented by the distance from the point to 4 edges
    of the box. Distances are all positive when the point is inside the box.

    Args:
        points: Nx2 coordinates. Each row is (x, y)
        boxes: M boxes

    Returns:
        Tensor: distances of size (N, M, 4). The 4 values are distances from
            the point to the left, top, right, bottom of the box.
    """
    x, y = points.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
    x0, y0, x1, y1 = boxes.tensor.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
    return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)
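

# Sketch (illustrative example, not part of the original module): distances from the
# point (3, 4) to the four edges of a 10x10 box anchored at the origin, ordered
# (left, top, right, bottom).
#
#   >>> pts = torch.tensor([[3.0, 4.0]])
#   >>> bxs = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0]]))
#   >>> pairwise_point_box_distance(pts, bxs)[0, 0].tolist()
#   [3.0, 4.0, 7.0, 6.0]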


def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Compute pairwise intersection over union (IoU) of two sets of matched
    boxes that have the same number of boxes.
    Similar to :func:`pairwise_iou`, but computes only the diagonal elements of the matrix.

    Args:
        boxes1 (Boxes): bounding boxes, sized [N,4].
        boxes2 (Boxes): same length as boxes1

    Returns:
        Tensor: iou, sized [N].
    """
    assert len(boxes1) == len(boxes2), (
        "boxlists should have the same number of entries, got {}, {}".format(
            len(boxes1), len(boxes2)
        )
    )
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [N]
    box1, box2 = boxes1.tensor, boxes2.tensor
    lt = torch.max(box1[:, :2], box2[:, :2])  # [N,2]
    rb = torch.min(box1[:, 2:], box2[:, 2:])  # [N,2]
    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]
    iou = inter / (area1 + area2 - inter)  # [N]
    return iou
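

# Sketch (illustrative example, not part of the original module): matched_pairwise_iou
# compares box i of boxes1 only against box i of boxes2, returning a length-N vector
# rather than an N x M matrix.
#
#   >>> b1 = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 2.0, 2.0]]))
#   >>> b2 = Boxes(torch.tensor([[0.0, 0.0, 10.0, 5.0], [0.0, 0.0, 8.0, 8.0]]))
#   >>> matched_pairwise_iou(b1, b2).tolist()
#   [0.5, 0.0625]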