connected_components.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import logging
  4. import torch
  5. try:
  6. from cc_torch import get_connected_components
  7. HAS_CC_TORCH = True
  8. except ImportError:
  9. logging.debug(
  10. "cc_torch not found. Consider installing for better performance. Command line:"
  11. " pip install git+https://github.com/ronghanghu/cc_torch.git"
  12. )
  13. HAS_CC_TORCH = False
  14. def connected_components_cpu_single(values: torch.Tensor):
  15. assert values.dim() == 2
  16. from skimage.measure import label
  17. labels, num = label(values.cpu().numpy(), return_num=True)
  18. labels = torch.from_numpy(labels)
  19. counts = torch.zeros_like(labels)
  20. for i in range(1, num + 1):
  21. cur_mask = labels == i
  22. cur_count = cur_mask.sum()
  23. counts[cur_mask] = cur_count
  24. return labels, counts
  25. def connected_components_cpu(input_tensor: torch.Tensor):
  26. out_shape = input_tensor.shape
  27. if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
  28. input_tensor = input_tensor.squeeze(1)
  29. else:
  30. assert input_tensor.dim() == 3, (
  31. "Input tensor must be (B, H, W) or (B, 1, H, W)."
  32. )
  33. batch_size = input_tensor.shape[0]
  34. labels_list = []
  35. counts_list = []
  36. for b in range(batch_size):
  37. labels, counts = connected_components_cpu_single(input_tensor[b])
  38. labels_list.append(labels)
  39. counts_list.append(counts)
  40. labels_tensor = torch.stack(labels_list, dim=0).to(input_tensor.device)
  41. counts_tensor = torch.stack(counts_list, dim=0).to(input_tensor.device)
  42. return labels_tensor.view(out_shape), counts_tensor.view(out_shape)
  43. def connected_components(input_tensor: torch.Tensor):
  44. """
  45. Computes connected components labeling on a batch of 2D tensors, using the best available backend.
  46. Args:
  47. input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero values are considered foreground. Bool tensor also accepted
  48. Returns:
  49. Tuple[torch.Tensor, torch.Tensor]: Both tensors have the same shape as input_tensor.
  50. - A tensor with dense labels. Background is 0.
  51. - A tensor with the size of the connected component for each pixel.
  52. """
  53. if input_tensor.dim() == 3:
  54. input_tensor = input_tensor.unsqueeze(1)
  55. assert input_tensor.dim() == 4 and input_tensor.shape[1] == 1, (
  56. "Input tensor must be (B, H, W) or (B, 1, H, W)."
  57. )
  58. if input_tensor.is_cuda:
  59. if HAS_CC_TORCH:
  60. return get_connected_components(input_tensor.to(torch.uint8))
  61. else:
  62. # triton fallback
  63. from sam3.perflib.triton.connected_components import (
  64. connected_components_triton,
  65. )
  66. return connected_components_triton(input_tensor)
  67. # CPU fallback
  68. return connected_components_cpu(input_tensor)