| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- import logging
- import torch
- try:
- from cc_torch import get_connected_components
- HAS_CC_TORCH = True
- except ImportError:
- logging.debug(
- "cc_torch not found. Consider installing for better performance. Command line:"
- " pip install git+https://github.com/ronghanghu/cc_torch.git"
- )
- HAS_CC_TORCH = False
def connected_components_cpu_single(values: torch.Tensor):
    """
    Label the 8-connected foreground components of a single 2D tensor on CPU.

    Args:
        values (torch.Tensor): An HxW tensor (any device, any dtype). Every
            non-zero element is foreground; zeros are background.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Two int64 CPU tensors of shape HxW.
            - Dense component labels in ``1..num``; background is 0.
            - For each pixel, the size (pixel count) of the component it
              belongs to; background is 0.
    """
    assert values.dim() == 2
    # Local import: scipy is only needed when this CPU path actually runs.
    from scipy import ndimage

    # Binarize first: labelling the raw values would split touching regions
    # holding different non-zero values into separate components, which
    # contradicts the "non-zero is foreground" contract.
    foreground = (values != 0).cpu().numpy()
    # A full 3x3 structuring element gives 8-connectivity, the same as the
    # default (connectivity == ndim) of skimage.measure.label in 2D.
    structure = ndimage.generate_binary_structure(2, 2)
    labels_np, num = ndimage.label(foreground, structure=structure)
    labels = torch.from_numpy(labels_np).to(torch.int64)

    # One histogram pass instead of building a full-image mask per component.
    sizes = torch.bincount(labels.flatten(), minlength=num + 1)
    sizes[0] = 0  # background pixels report a size of 0
    counts = sizes[labels]
    return labels, counts
def connected_components_cpu(input_tensor: torch.Tensor):
    """
    CPU connected-components labeling for a batch of 2D tensors.

    Args:
        input_tensor (torch.Tensor): A (B, H, W) or (B, 1, H, W) tensor. Each
            (H, W) slice is labelled independently by
            ``connected_components_cpu_single``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (labels, counts), both viewed to the
            shape of ``input_tensor`` and moved to its device.

    Raises:
        AssertionError: if the input is neither (B, H, W) nor (B, 1, H, W).
    """
    out_shape = input_tensor.shape
    if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
        input_tensor = input_tensor.squeeze(1)
    else:
        assert input_tensor.dim() == 3, (
            "Input tensor must be (B, H, W) or (B, 1, H, W)."
        )
    device = input_tensor.device

    if input_tensor.shape[0] == 0:
        # torch.stack rejects an empty list, so an empty batch would otherwise
        # raise; return empty results of the expected shape instead.
        empty = torch.zeros(out_shape, dtype=torch.int64, device=device)
        return empty, empty.clone()

    # Iterating a tensor yields its slices along dim 0, i.e. one (H, W) image.
    results = [connected_components_cpu_single(image) for image in input_tensor]
    labels_list, counts_list = zip(*results)
    labels_tensor = torch.stack(labels_list, dim=0).to(device)
    counts_tensor = torch.stack(counts_list, dim=0).to(device)
    return labels_tensor.view(out_shape), counts_tensor.view(out_shape)
def connected_components(input_tensor: torch.Tensor):
    """
    Computes connected components labeling on a batch of 2D tensors, using the best available backend.

    Backend selection: CUDA input uses ``cc_torch`` when it is installed and
    otherwise the Triton kernel from ``sam3.perflib.triton.connected_components``;
    any other input uses ``connected_components_cpu``.

    Args:
        input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero
            values are considered foreground. Bool tensor also accepted.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - A tensor with dense labels. Background is 0.
            - A tensor with the size of the connected component for each pixel.
        A BxHxW input is unsqueezed to Bx1xHxW before dispatch, so the CPU path
        returns Bx1xHxW tensors for either input layout (the GPU backends are
        presumably the same -- TODO confirm).

    Raises:
        AssertionError: if the input is neither BxHxW nor Bx1xHxW.
    """
    if input_tensor.dim() == 3:
        input_tensor = input_tensor.unsqueeze(1)
    assert input_tensor.dim() == 4 and input_tensor.shape[1] == 1, (
        "Input tensor must be (B, H, W) or (B, 1, H, W)."
    )
    if input_tensor.is_cuda:
        if HAS_CC_TORCH:
            # Binarize before narrowing to uint8: a direct cast wraps integers
            # (e.g. 256 -> 0) and truncates fractional floats to 0, which would
            # silently turn non-zero pixels into background.
            return get_connected_components((input_tensor != 0).to(torch.uint8))
        # triton fallback
        from sam3.perflib.triton.connected_components import (
            connected_components_triton,
        )

        return connected_components_triton(input_tensor)
    # CPU fallback
    return connected_components_cpu(input_tensor)
|