masks.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import copy
  4. import itertools
  5. from typing import Any, Iterator, List, Union
  6. import numpy as np
  7. import pycocotools.mask as mask_util
  8. import torch
  9. from torch import device
  10. from .boxes import Boxes
  11. from .memory import retry_if_cuda_oom
  12. from .roi_align import ROIAlign
  13. def polygon_area(x, y):
  14. # Using the shoelace formula
  15. # https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
  16. return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
  17. def polygons_to_bitmask(
  18. polygons: List[np.ndarray], height: int, width: int
  19. ) -> np.ndarray:
  20. """
  21. Args:
  22. polygons (list[ndarray]): each array has shape (Nx2,)
  23. height, width (int)
  24. Returns:
  25. ndarray: a bool mask of shape (height, width)
  26. """
  27. if len(polygons) == 0:
  28. # COCOAPI does not support empty polygons
  29. return np.zeros((height, width)).astype(bool)
  30. rles = mask_util.frPyObjects(polygons, height, width)
  31. rle = mask_util.merge(rles)
  32. return mask_util.decode(rle).astype(bool)
  33. def rasterize_polygons_within_box(
  34. polygons: List[np.ndarray], box: np.ndarray, mask_size: int
  35. ) -> torch.Tensor:
  36. """
  37. Rasterize the polygons into a mask image and
  38. crop the mask content in the given box.
  39. The cropped mask is resized to (mask_size, mask_size).
  40. This function is used when generating training targets for mask head in Mask R-CNN.
  41. Given original ground-truth masks for an image, new ground-truth mask
  42. training targets in the size of `mask_size x mask_size`
  43. must be provided for each predicted box. This function will be called to
  44. produce such targets.
  45. Args:
  46. polygons (list[ndarray[float]]): a list of polygons, which represents an instance.
  47. box: 4-element numpy array
  48. mask_size (int):
  49. Returns:
  50. Tensor: BoolTensor of shape (mask_size, mask_size)
  51. """
  52. # 1. Shift the polygons w.r.t the boxes
  53. w, h = box[2] - box[0], box[3] - box[1]
  54. polygons = copy.deepcopy(polygons)
  55. for p in polygons:
  56. p[0::2] = p[0::2] - box[0]
  57. p[1::2] = p[1::2] - box[1]
  58. # 2. Rescale the polygons to the new box size
  59. # max() to avoid division by small number
  60. ratio_h = mask_size / max(h, 0.1)
  61. ratio_w = mask_size / max(w, 0.1)
  62. if ratio_h == ratio_w:
  63. for p in polygons:
  64. p *= ratio_h
  65. else:
  66. for p in polygons:
  67. p[0::2] *= ratio_w
  68. p[1::2] *= ratio_h
  69. # 3. Rasterize the polygons with coco api
  70. mask = polygons_to_bitmask(polygons, mask_size, mask_size)
  71. mask = torch.from_numpy(mask)
  72. return mask
  73. class BitMasks:
  74. """
  75. This class stores the segmentation masks for all objects in one image, in
  76. the form of bitmaps.
  77. Attributes:
  78. tensor: bool Tensor of N,H,W, representing N instances in the image.
  79. """
  80. def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
  81. """
  82. Args:
  83. tensor: bool Tensor of N,H,W, representing N instances in the image.
  84. """
  85. if isinstance(tensor, torch.Tensor):
  86. tensor = tensor.to(torch.bool)
  87. else:
  88. tensor = torch.as_tensor(
  89. tensor, dtype=torch.bool, device=torch.device("cpu")
  90. )
  91. assert tensor.dim() == 3, tensor.size()
  92. self.image_size = tensor.shape[1:]
  93. self.tensor = tensor
  94. @torch.jit.unused
  95. def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
  96. return BitMasks(self.tensor.to(*args, **kwargs))
  97. @property
  98. def device(self) -> torch.device:
  99. return self.tensor.device
  100. @torch.jit.unused
  101. def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
  102. """
  103. Returns:
  104. BitMasks: Create a new :class:`BitMasks` by indexing.
  105. The following usage are allowed:
  106. 1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
  107. 2. `new_masks = masks[2:10]`: return a slice of masks.
  108. 3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
  109. with `length = len(masks)`. Nonzero elements in the vector will be selected.
  110. Note that the returned object might share storage with this object,
  111. subject to Pytorch's indexing semantics.
  112. """
  113. if isinstance(item, int):
  114. return BitMasks(self.tensor[item].unsqueeze(0))
  115. m = self.tensor[item]
  116. assert m.dim() == 3, (
  117. "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
  118. item, m.shape
  119. )
  120. )
  121. return BitMasks(m)
  122. @torch.jit.unused
  123. def __iter__(self) -> torch.Tensor:
  124. yield from self.tensor
  125. @torch.jit.unused
  126. def __repr__(self) -> str:
  127. s = self.__class__.__name__ + "("
  128. s += "num_instances={})".format(len(self.tensor))
  129. return s
  130. def __len__(self) -> int:
  131. return self.tensor.shape[0]
  132. def nonempty(self) -> torch.Tensor:
  133. """
  134. Find masks that are non-empty.
  135. Returns:
  136. Tensor: a BoolTensor which represents
  137. whether each mask is empty (False) or non-empty (True).
  138. """
  139. return self.tensor.flatten(1).any(dim=1)
  140. @staticmethod
  141. def from_polygon_masks(
  142. polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]],
  143. height: int,
  144. width: int,
  145. ) -> "BitMasks":
  146. """
  147. Args:
  148. polygon_masks (list[list[ndarray]] or PolygonMasks)
  149. height, width (int)
  150. """
  151. if isinstance(polygon_masks, PolygonMasks):
  152. polygon_masks = polygon_masks.polygons
  153. masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
  154. if len(masks):
  155. return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
  156. else:
  157. return BitMasks(torch.empty(0, height, width, dtype=torch.bool))
  158. @staticmethod
  159. def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
  160. """
  161. Args:
  162. roi_masks:
  163. height, width (int):
  164. """
  165. return roi_masks.to_bitmasks(height, width)
  166. def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
  167. """
  168. Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
  169. This can be used to prepare training targets for Mask R-CNN.
  170. It has less reconstruction error compared to rasterization with polygons.
  171. However we observe no difference in accuracy,
  172. but BitMasks requires more memory to store all the masks.
  173. Args:
  174. boxes (Tensor): Nx4 tensor storing the boxes for each mask
  175. mask_size (int): the size of the rasterized mask.
  176. Returns:
  177. Tensor:
  178. A bool tensor of shape (N, mask_size, mask_size), where
  179. N is the number of predicted boxes for this image.
  180. """
  181. assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
  182. device = self.tensor.device
  183. batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[
  184. :, None
  185. ]
  186. rois = torch.cat([batch_inds, boxes], dim=1) # Nx5
  187. bit_masks = self.tensor.to(dtype=torch.float32)
  188. rois = rois.to(device=device)
  189. output = (
  190. ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
  191. .forward(bit_masks[:, None, :, :], rois)
  192. .squeeze(1)
  193. )
  194. output = output >= 0.5
  195. return output
  196. def get_bounding_boxes(self) -> Boxes:
  197. """
  198. Returns:
  199. Boxes: tight bounding boxes around bitmasks.
  200. If a mask is empty, it's bounding box will be all zero.
  201. """
  202. boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32)
  203. x_any = torch.any(self.tensor, dim=1)
  204. y_any = torch.any(self.tensor, dim=2)
  205. for idx in range(self.tensor.shape[0]):
  206. x = torch.where(x_any[idx, :])[0]
  207. y = torch.where(y_any[idx, :])[0]
  208. if len(x) > 0 and len(y) > 0:
  209. boxes[idx, :] = torch.as_tensor(
  210. [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32
  211. )
  212. return Boxes(boxes)
  213. @staticmethod
  214. def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
  215. """
  216. Concatenates a list of BitMasks into a single BitMasks
  217. Arguments:
  218. bitmasks_list (list[BitMasks])
  219. Returns:
  220. BitMasks: the concatenated BitMasks
  221. """
  222. assert isinstance(bitmasks_list, (list, tuple))
  223. assert len(bitmasks_list) > 0
  224. assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)
  225. cat_bitmasks = type(bitmasks_list[0])(
  226. torch.cat([bm.tensor for bm in bitmasks_list], dim=0)
  227. )
  228. return cat_bitmasks
  229. class PolygonMasks:
  230. """
  231. This class stores the segmentation masks for all objects in one image, in the form of polygons.
  232. Attributes:
  233. polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
  234. """
  235. def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
  236. """
  237. Arguments:
  238. polygons (list[list[np.ndarray]]): The first
  239. level of the list correspond to individual instances,
  240. the second level to all the polygons that compose the
  241. instance, and the third level to the polygon coordinates.
  242. The third level array should have the format of
  243. [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
  244. """
  245. if not isinstance(polygons, list):
  246. raise ValueError(
  247. "Cannot create PolygonMasks: Expect a list of list of polygons per image. "
  248. "Got '{}' instead.".format(type(polygons))
  249. )
  250. def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
  251. # Use float64 for higher precision, because why not?
  252. # Always put polygons on CPU (self.to is a no-op) since they
  253. # are supposed to be small tensors.
  254. # May need to change this assumption if GPU placement becomes useful
  255. if isinstance(t, torch.Tensor):
  256. t = t.cpu().numpy()
  257. return np.asarray(t).astype("float64")
  258. def process_polygons(
  259. polygons_per_instance: List[Union[torch.Tensor, np.ndarray]],
  260. ) -> List[np.ndarray]:
  261. if not isinstance(polygons_per_instance, list):
  262. raise ValueError(
  263. "Cannot create polygons: Expect a list of polygons per instance. "
  264. "Got '{}' instead.".format(type(polygons_per_instance))
  265. )
  266. # transform each polygon to a numpy array
  267. polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
  268. for polygon in polygons_per_instance:
  269. if len(polygon) % 2 != 0 or len(polygon) < 6:
  270. raise ValueError(
  271. f"Cannot create a polygon from {len(polygon)} coordinates."
  272. )
  273. return polygons_per_instance
  274. self.polygons: List[List[np.ndarray]] = [
  275. process_polygons(polygons_per_instance)
  276. for polygons_per_instance in polygons
  277. ]
  278. def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
  279. return self
  280. @property
  281. def device(self) -> torch.device:
  282. return torch.device("cpu")
  283. def get_bounding_boxes(self) -> Boxes:
  284. """
  285. Returns:
  286. Boxes: tight bounding boxes around polygon masks.
  287. """
  288. boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
  289. for idx, polygons_per_instance in enumerate(self.polygons):
  290. minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
  291. maxxy = torch.zeros(2, dtype=torch.float32)
  292. for polygon in polygons_per_instance:
  293. coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
  294. minxy = torch.min(minxy, torch.min(coords, dim=0).values)
  295. maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
  296. boxes[idx, :2] = minxy
  297. boxes[idx, 2:] = maxxy
  298. return Boxes(boxes)
  299. def nonempty(self) -> torch.Tensor:
  300. """
  301. Find masks that are non-empty.
  302. Returns:
  303. Tensor:
  304. a BoolTensor which represents whether each mask is empty (False) or not (True).
  305. """
  306. keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
  307. return torch.from_numpy(np.asarray(keep, dtype=bool))
  308. def __getitem__(
  309. self, item: Union[int, slice, List[int], torch.BoolTensor]
  310. ) -> "PolygonMasks":
  311. """
  312. Support indexing over the instances and return a `PolygonMasks` object.
  313. `item` can be:
  314. 1. An integer. It will return an object with only one instance.
  315. 2. A slice. It will return an object with the selected instances.
  316. 3. A list[int]. It will return an object with the selected instances,
  317. correpsonding to the indices in the list.
  318. 4. A vector mask of type BoolTensor, whose length is num_instances.
  319. It will return an object with the instances whose mask is nonzero.
  320. """
  321. if isinstance(item, int):
  322. selected_polygons = [self.polygons[item]]
  323. elif isinstance(item, slice):
  324. selected_polygons = self.polygons[item]
  325. elif isinstance(item, list):
  326. selected_polygons = [self.polygons[i] for i in item]
  327. elif isinstance(item, torch.Tensor):
  328. # Polygons is a list, so we have to move the indices back to CPU.
  329. if item.dtype == torch.bool:
  330. assert item.dim() == 1, item.shape
  331. item = item.nonzero().squeeze(1).cpu().numpy().tolist()
  332. elif item.dtype in [torch.int32, torch.int64]:
  333. item = item.cpu().numpy().tolist()
  334. else:
  335. raise ValueError(
  336. "Unsupported tensor dtype={} for indexing!".format(item.dtype)
  337. )
  338. selected_polygons = [self.polygons[i] for i in item]
  339. return PolygonMasks(selected_polygons)
  340. def __iter__(self) -> Iterator[List[np.ndarray]]:
  341. """
  342. Yields:
  343. list[ndarray]: the polygons for one instance.
  344. Each Tensor is a float64 vector representing a polygon.
  345. """
  346. return iter(self.polygons)
  347. def __repr__(self) -> str:
  348. s = self.__class__.__name__ + "("
  349. s += "num_instances={})".format(len(self.polygons))
  350. return s
  351. def __len__(self) -> int:
  352. return len(self.polygons)
  353. def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
  354. """
  355. Crop each mask by the given box, and resize results to (mask_size, mask_size).
  356. This can be used to prepare training targets for Mask R-CNN.
  357. Args:
  358. boxes (Tensor): Nx4 tensor storing the boxes for each mask
  359. mask_size (int): the size of the rasterized mask.
  360. Returns:
  361. Tensor: A bool tensor of shape (N, mask_size, mask_size), where
  362. N is the number of predicted boxes for this image.
  363. """
  364. assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
  365. device = boxes.device
  366. # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
  367. # (several small tensors for representing a single instance mask)
  368. boxes = boxes.to(torch.device("cpu"))
  369. results = [
  370. rasterize_polygons_within_box(poly, box.numpy(), mask_size)
  371. for poly, box in zip(self.polygons, boxes)
  372. ]
  373. """
  374. poly: list[list[float]], the polygons for one instance
  375. box: a tensor of shape (4,)
  376. """
  377. if len(results) == 0:
  378. return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
  379. return torch.stack(results, dim=0).to(device=device)
  380. def area(self):
  381. """
  382. Computes area of the mask.
  383. Only works with Polygons, using the shoelace formula:
  384. https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
  385. Returns:
  386. Tensor: a vector, area for each instance
  387. """
  388. area = []
  389. for polygons_per_instance in self.polygons:
  390. area_per_instance = 0
  391. for p in polygons_per_instance:
  392. area_per_instance += polygon_area(p[0::2], p[1::2])
  393. area.append(area_per_instance)
  394. return torch.tensor(area)
  395. @staticmethod
  396. def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
  397. """
  398. Concatenates a list of PolygonMasks into a single PolygonMasks
  399. Arguments:
  400. polymasks_list (list[PolygonMasks])
  401. Returns:
  402. PolygonMasks: the concatenated PolygonMasks
  403. """
  404. assert isinstance(polymasks_list, (list, tuple))
  405. assert len(polymasks_list) > 0
  406. assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)
  407. cat_polymasks = type(polymasks_list[0])(
  408. list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
  409. )
  410. return cat_polymasks
  411. class ROIMasks:
  412. """
  413. Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
  414. full-image bitmask can be obtained by "pasting" the mask on the region defined
  415. by the corresponding ROI box.
  416. """
  417. def __init__(self, tensor: torch.Tensor):
  418. """
  419. Args:
  420. tensor: (N, M, M) mask tensor that defines the mask within each ROI.
  421. """
  422. if tensor.dim() != 3:
  423. raise ValueError("ROIMasks must take a masks of 3 dimension.")
  424. self.tensor = tensor
  425. def to(self, device: torch.device) -> "ROIMasks":
  426. return ROIMasks(self.tensor.to(device))
  427. @property
  428. def device(self) -> device:
  429. return self.tensor.device
  430. def __len__(self):
  431. return self.tensor.shape[0]
  432. def __getitem__(self, item) -> "ROIMasks":
  433. """
  434. Returns:
  435. ROIMasks: Create a new :class:`ROIMasks` by indexing.
  436. The following usage are allowed:
  437. 1. `new_masks = masks[2:10]`: return a slice of masks.
  438. 2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
  439. with `length = len(masks)`. Nonzero elements in the vector will be selected.
  440. Note that the returned object might share storage with this object,
  441. subject to Pytorch's indexing semantics.
  442. """
  443. t = self.tensor[item]
  444. if t.dim() != 3:
  445. raise ValueError(
  446. f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
  447. )
  448. return ROIMasks(t)
  449. @torch.jit.unused
  450. def __repr__(self) -> str:
  451. s = self.__class__.__name__ + "("
  452. s += "num_instances={})".format(len(self.tensor))
  453. return s
  454. @torch.jit.unused
  455. def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
  456. """
  457. Args: see documentation of :func:`paste_masks_in_image`.
  458. """
  459. from detectron2.layers.mask_ops import (
  460. _paste_masks_tensor_shape,
  461. paste_masks_in_image,
  462. )
  463. if torch.jit.is_tracing():
  464. if isinstance(height, torch.Tensor):
  465. paste_func = _paste_masks_tensor_shape
  466. else:
  467. paste_func = paste_masks_in_image
  468. else:
  469. paste_func = retry_if_cuda_oom(paste_masks_in_image)
  470. bitmasks = paste_func(
  471. self.tensor, boxes.tensor, (height, width), threshold=threshold
  472. )
  473. return BitMasks(bitmasks)