connected_components.py
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import torch
import triton
import triton.language as tl


@triton.jit
def _any_combine(a, b):
    return a | b


@triton.jit
def tl_any(a, dim=0):
    return tl.reduce(a, dim, _any_combine)


# ==============================================================================
# ## Phase 1: Initialization Kernel
# ==============================================================================
# Each foreground pixel (value != 0) gets a unique label equal to its
# linear index. Background pixels (value == 0) get the sentinel label -1.
# Note that the indexing runs across batch boundaries for simplicity
# (i.e., the first pixel of image 1 gets label H*W, and so on).
@triton.jit
def _init_labels_kernel(
    input_ptr, labels_ptr, numel: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    input_values = tl.load(input_ptr + offsets, mask=mask, other=0)
    indices = tl.where(input_values != 0, offsets, -1)
    tl.store(labels_ptr + offsets, indices, mask=mask)
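

# For intuition, a minimal pure-PyTorch sketch of the same initialization.
# `_init_labels_reference` is a hypothetical helper shown only for
# illustration; it is not used by the kernels.
def _init_labels_reference(input_tensor: torch.Tensor) -> torch.Tensor:
    # Foreground pixels receive their global linear index, background gets -1.
    flat = input_tensor.reshape(-1)
    idx = torch.arange(flat.numel(), dtype=torch.int32, device=flat.device)
    return torch.where(flat != 0, idx, torch.full_like(idx, -1)).view_as(input_tensor)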


# ==============================================================================
# ## Phase 2: Local Merging
# ==============================================================================
# Each pixel tries to merge with its 8-connected neighbors if they have the
# same value. Only four of the eight directions (left, up, up-left, down-left)
# are processed explicitly; the symmetric directions are covered by the
# neighboring pixels' own lanes. Merging uses a disjoint-set union.
@triton.jit
def find(labels_ptr, indices, mask):
    current_pids = indices
    # 'is_done' tracks lanes that have finished their work.
    # A lane is initially "done" if it's not active (mask is False).
    is_done = ~mask
    # Loop as long as there is at least one lane that is NOT done.
    while tl_any(~is_done):
        # The work_mask is for lanes that are still active and seeking their root.
        work_mask = ~is_done
        parents = tl.load(labels_ptr + current_pids, mask=work_mask, other=-1)
        # A lane is now done if its parent is itself (it's a root)
        # or if it hits a -1 sentinel (a safe exit condition).
        is_root = parents == current_pids
        is_sentinel = parents == -1
        is_done |= is_root | is_sentinel
        # For lanes that are not yet done, step to the parent and keep traversing.
        current_pids = tl.where(is_done, current_pids, parents)
    # We could add the following line to do path compression, but experimentally it's slower:
    # tl.atomic_min(labels_ptr + indices, current_pids, mask=mask)
    return current_pids
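

# A sequential pure-Python analogue of `find` (an illustrative sketch only,
# not used by the kernels): follow parent pointers until reaching a root
# (parent == itself) or a node whose parent is the -1 sentinel. The Triton
# version above does the same per lane, using masks instead of early exits.
def _find_reference(labels, i):
    while labels[i] != i and labels[i] != -1:
        i = labels[i]
    return i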


@triton.jit
def union(labels_ptr, a, b, process_mask):
    # This function implements a disjoint-set union.
    # As an invariant, the root of a component is always its smallest id; this
    # makes concurrent merges via atomic_min well defined.
    # The invariant alone is not sufficient, however. Suppose two threads want
    # to do union(0, 2) and union(1, 2) at the same time. With a naive
    # atomic_min, 0 and 1 compete to become the new parent of 2, and min(0, 1)
    # wins. But 1 still needs to be merged into the new {0, 2} component.
    # To guarantee that merge also happens, we detect whether the atomic update
    # succeeded and retry until it does.
    current_a = a
    current_b = b
    final_root = a
    # Track which lanes have completed their union; inactive lanes start as done.
    done_mask = ~process_mask
    while tl_any(~done_mask):
        # Define the mask for lanes that still need work in this iteration.
        work_mask = process_mask & ~done_mask
        # Find the roots for the current a and b values in the active lanes.
        root_a = find(labels_ptr, current_a, work_mask)
        tl.debug_barrier()
        root_b = find(labels_ptr, current_b, work_mask)
        # Merge logic:
        # if the roots are already the same, the sets are already merged. Mark as done.
        are_equal = root_a == root_b
        final_root = tl.where(are_equal & work_mask & ~done_mask, root_a, final_root)
        done_mask |= are_equal & work_mask
        # Define masks for the two merge cases (a < b or b < a).
        a_is_smaller = root_a < root_b
        # Case 1: root_a < root_b. Attempt to set parent[root_b] = root_a.
        merge_mask_a_smaller = work_mask & a_is_smaller & ~are_equal
        ptr_b = labels_ptr + root_b
        old_val_b = tl.atomic_min(ptr_b, root_a, mask=merge_mask_a_smaller)
        # A lane is done if its atomic op succeeded (the old value was what we expected).
        success_b = old_val_b == root_b
        final_root = tl.where(success_b & work_mask & ~done_mask, root_a, final_root)
        done_mask |= success_b & merge_mask_a_smaller
        # *** Crucial retry logic ***
        # If the update failed (old_val_b != root_b), another thread interfered.
        # Point `current_b` at the newly observed parent (`old_val_b`) and retry
        # in the next loop iteration.
        current_b = tl.where(success_b | ~merge_mask_a_smaller, current_b, old_val_b)
        # Case 2: root_b < root_a. Attempt to set parent[root_a] = root_b.
        merge_mask_b_smaller = work_mask & ~a_is_smaller & ~are_equal
        ptr_a = labels_ptr + root_a
        old_val_a = tl.atomic_min(ptr_a, root_b, mask=merge_mask_b_smaller)
        success_a = old_val_a == root_a
        final_root = tl.where(success_a & work_mask & ~done_mask, root_b, final_root)
        done_mask |= success_a & merge_mask_b_smaller
        # *** Crucial retry logic ***
        # Similarly, update `current_a` if the atomic operation failed.
        current_a = tl.where(success_a | ~merge_mask_b_smaller, current_a, old_val_a)
    return final_root
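

# A sequential sketch of the same union, built on the `_find_reference` sketch
# above (illustrative only; the helper name is ours). Run single-threaded, the
# atomic_min always succeeds, so the kernel's retry loop collapses to a single
# attempt; the kernel must retry because a concurrent lane can re-parent the
# larger root between its find() and its atomic write.
def _union_reference(labels, a, b):
    ra = _find_reference(labels, a)
    rb = _find_reference(labels, b)
    if ra == rb:
        return ra
    lo, hi = (ra, rb) if ra < rb else (rb, ra)
    labels[hi] = lo  # in the kernel: tl.atomic_min, retried on interference
    return lo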


@triton.jit
def _merge_helper(
    input_ptr,
    labels_ptr,
    base_offset,
    offsets_h,
    offsets_w,
    mask_2d,
    valid_current,
    current_values,
    current_labels,
    H,
    W,
    dx: tl.constexpr,
    dy: tl.constexpr,
):
    # Helper function to merge each pixel with its neighbor at offset (dx, dy),
    # where dx shifts columns (the W axis) and dy shifts rows (the H axis).
    neighbor_h = offsets_h + dy
    neighbor_w = offsets_w + dx
    # Proper bounds checking: all four bounds must be satisfied.
    mask_n = (
        mask_2d
        & (neighbor_h[:, None] >= 0)
        & (neighbor_h[:, None] < H)
        & (neighbor_w[None, :] >= 0)
        & (neighbor_w[None, :] < W)
    )
    offsets_neighbor = neighbor_h[:, None] * W + neighbor_w[None, :]
    neighbor_values = tl.load(
        input_ptr + base_offset + offsets_neighbor, mask=mask_n, other=-1
    )
    mask_n = tl.ravel(mask_n)
    neighbor_labels = tl.load(
        labels_ptr + tl.ravel(base_offset + offsets_neighbor), mask=mask_n, other=-1
    )
    to_merge = (
        mask_n & (neighbor_labels != -1) & tl.ravel(current_values == neighbor_values)
    )
    valid_write = valid_current & to_merge
    # Returns new parents for the pixels that were merged (otherwise keeps current labels).
    return tl.where(
        valid_write,
        union(labels_ptr, current_labels, neighbor_labels, valid_write),
        current_labels,
    )


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_H": 4, "BLOCK_SIZE_W": 16}, num_stages=1, num_warps=2
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 4, "BLOCK_SIZE_W": 32}, num_stages=2, num_warps=4
        ),
    ],
    key=["H", "W"],
    restore_value=["labels_ptr"],
)
@triton.jit
def _local_prop_kernel(
    labels_ptr,
    input_ptr,
    H: tl.constexpr,
    W: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_W: tl.constexpr,
):
    # This is the meat of Phase 2: local merging.
    # It is launched with a 2D grid:
    # - dim 0: batch index
    # - dim 1: block index over the HxW image (2D tiling)
    pid_b = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Calculate offsets for the core block.
    offsets_h = (pid_hw // tl.cdiv(W, BLOCK_SIZE_W)) * BLOCK_SIZE_H + tl.arange(
        0, BLOCK_SIZE_H
    )
    offsets_w = (pid_hw % tl.cdiv(W, BLOCK_SIZE_W)) * BLOCK_SIZE_W + tl.arange(
        0, BLOCK_SIZE_W
    )
    base_offset = pid_b * H * W
    offsets_2d = offsets_h[:, None] * W + offsets_w[None, :]
    mask_2d = (offsets_h[:, None] < H) & (offsets_w[None, :] < W)
    mask_1d = tl.ravel(mask_2d)
    # Load the current labels for the block - these are parent pointers.
    current_labels = tl.load(
        labels_ptr + tl.ravel(base_offset + offsets_2d), mask=mask_1d, other=-1
    )
    current_values = tl.load(
        input_ptr + base_offset + offsets_2d, mask=mask_2d, other=-1
    )
    valid_current = mask_1d & (current_labels != -1)
    # Horizontal merge (left neighbor)
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        -1,
        0,
    )
    # Vertical merge (up neighbor)
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        0,
        -1,
    )
    # Diagonal merges (up-left and down-left); the remaining four directions
    # are handled symmetrically by the neighboring pixels' own lanes.
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        -1,
        -1,
    )
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        -1,
        1,
    )
    # This performs some path compression, in a lightweight but beneficial way.
    tl.atomic_min(
        labels_ptr + tl.ravel(base_offset + offsets_2d), current_labels, mask=mask_1d
    )
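

# For intuition, the tile decoding done above, in plain Python (an illustrative
# helper with a hypothetical name; the ceil-division mirrors tl.cdiv).
def _tile_origin(pid_hw, W, BLOCK_SIZE_H, BLOCK_SIZE_W):
    tiles_per_row = -(-W // BLOCK_SIZE_W)  # ceil(W / BLOCK_SIZE_W)
    h0 = (pid_hw // tiles_per_row) * BLOCK_SIZE_H  # top row of this tile
    w0 = (pid_hw % tiles_per_row) * BLOCK_SIZE_W  # leftmost column of this tile
    return h0, w0
# e.g. with W=100, BLOCK_SIZE_H=4, BLOCK_SIZE_W=32: _tile_origin(5, 100, 4, 32) == (4, 32)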


# ==============================================================================
# ## Phase 3: Pointer Jumping Kernel
# ==============================================================================
# This kernel performs pointer jumping so that every pixel points directly to
# its root label. Each lane loops until convergence.
@triton.jit
def _pointer_jump_kernel(
    labels_in_ptr, labels_out_ptr, numel: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """
    Pointer jumping kernel with double buffering to avoid race conditions.
    Reads from labels_in_ptr and writes to labels_out_ptr.
    """
    # This kernel is launched with a 1D grid and does not care about batching explicitly.
    # By construction, the labels are global indices across the batch, and we never perform
    # cross-batch merges, so this is safe.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    # Load current labels from the input buffer.
    current_labels = tl.load(labels_in_ptr + offsets, mask=mask, other=-1)
    valid_mask = mask & (current_labels != -1)
    # A mask to track which lanes have reached their root.
    done_mask = ~valid_mask
    while tl_any(~(done_mask | ~valid_mask)):
        parent_labels = tl.load(
            labels_in_ptr + current_labels, mask=valid_mask, other=-1
        )
        are_equal = current_labels == parent_labels
        done_mask |= are_equal & valid_mask
        current_labels = tl.where(
            ~done_mask, tl.minimum(current_labels, parent_labels), current_labels
        )
    # Write to the output buffer (safe because we never read from it).
    tl.store(labels_out_ptr + offsets, current_labels, mask=mask)
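

# A dense-PyTorch sketch of the same pointer jumping (illustrative only; the
# helper name is ours). It assumes, as the earlier phases guarantee, that
# foreground labels point to foreground labels; it repeatedly replaces each
# label by its parent until a fixed point is reached.
def _pointer_jump_reference(labels: torch.Tensor) -> torch.Tensor:
    flat = labels.reshape(-1).clone()
    valid = flat != -1
    while True:
        parents = flat[flat.clamp(min=0).long()]  # clamp keeps -1 entries indexable
        nxt = torch.where(valid, torch.minimum(flat, parents), flat)
        if torch.equal(nxt, flat):
            return nxt.view_as(labels)
        flat = nxt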


# ==============================================================================
# ## Phase 4: Kernels for Computing Component Sizes
# ==============================================================================
# Step 4.1: Count occurrences of each root label using atomic adds.
@triton.jit
def _count_labels_kernel(labels_ptr, sizes_ptr, numel, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    # Load the final, converged labels.
    labels = tl.load(labels_ptr + offsets, mask=mask, other=-1)
    valid_mask = mask & (labels != -1)
    # Atomically increment the counter for each label. This builds a histogram.
    tl.atomic_add(sizes_ptr + labels, 1, mask=valid_mask)


# Step 4.2: Broadcast the computed sizes back to the output tensor.
@triton.jit
def _broadcast_sizes_kernel(
    labels_ptr, sizes_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    # Load the final labels.
    labels = tl.load(labels_ptr + offsets, mask=mask, other=-1)
    valid_mask = mask & (labels != -1)
    # Look up the size for each label from the histogram.
    component_sizes = tl.load(sizes_ptr + labels, mask=valid_mask, other=0)
    # Write the size to the final output tensor. Background pixels get size 0.
    tl.store(out_ptr + offsets, component_sizes, mask=mask)
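

# The same two steps in plain PyTorch (an illustrative sketch; the helper name
# is ours): bincount builds the per-root histogram, and a gather broadcasts
# each root's count back to its pixels.
def _component_sizes_reference(labels: torch.Tensor) -> torch.Tensor:
    flat = labels.reshape(-1).long()
    valid = flat != -1
    sizes = torch.bincount(flat[valid], minlength=flat.numel())
    out = torch.zeros_like(flat)
    out[valid] = sizes[flat[valid]]
    return out.view_as(labels)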


def connected_components_triton(input_tensor: torch.Tensor):
    """
    Computes connected components labeling on a batch of 2D integer tensors using Triton.
    Args:
        input_tensor (torch.Tensor): A (B, H, W) or (B, 1, H, W) integer (or bool)
            tensor. Non-zero values are considered foreground.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - An output tensor with the same shape as the input, holding one
              positive label per component (the root's linear index + 1).
              Background is 0.
            - A tensor of the same shape with the size of the connected
              component for each pixel (0 for background).
    """
    assert input_tensor.is_cuda and input_tensor.is_contiguous(), (
        "Input tensor must be a contiguous CUDA tensor."
    )
    out_shape = input_tensor.shape
    if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
        input_tensor = input_tensor.squeeze(1)
    else:
        assert input_tensor.dim() == 3, (
            "Input tensor must be (B, H, W) or (B, 1, H, W)."
        )
    B, H, W = input_tensor.shape
    numel = B * H * W
    device = input_tensor.device
    # --- Allocate Tensors ---
    labels = torch.empty_like(input_tensor, dtype=torch.int32)
    output = torch.empty_like(input_tensor, dtype=torch.int32)
    # --- Phase 1 ---
    BLOCK_SIZE = 256
    grid_init = (triton.cdiv(numel, BLOCK_SIZE),)
    _init_labels_kernel[grid_init](
        input_tensor,
        labels,
        numel,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # --- Phase 2 ---
    grid_local_prop = lambda meta: (
        B,
        triton.cdiv(H, meta["BLOCK_SIZE_H"]) * triton.cdiv(W, meta["BLOCK_SIZE_W"]),
    )
    _local_prop_kernel[grid_local_prop](labels, input_tensor, H, W)
    # --- Phase 3 ---
    BLOCK_SIZE = 256
    grid_jump = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
    _pointer_jump_kernel[grid_jump](labels, output, numel, BLOCK_SIZE=BLOCK_SIZE)
    # --- Phase 4 ---
    # Allocate a tensor to store the final output sizes.
    component_sizes_out = torch.empty_like(input_tensor, dtype=torch.int32)
    # Allocate a temporary 1D tensor to act as the histogram.
    # Its size is numel because labels can be as large as numel - 1.
    sizes_histogram = torch.zeros(numel, dtype=torch.int32, device=device)
    # 4.1: Count the occurrences of each label.
    grid_count = (triton.cdiv(numel, BLOCK_SIZE),)
    _count_labels_kernel[grid_count](
        output, sizes_histogram, numel, BLOCK_SIZE=BLOCK_SIZE
    )
    # 4.2: Broadcast the counts to the final output tensor.
    grid_broadcast = (triton.cdiv(numel, BLOCK_SIZE),)
    _broadcast_sizes_kernel[grid_broadcast](
        output, sizes_histogram, component_sizes_out, numel, BLOCK_SIZE=BLOCK_SIZE
    )
    return output.view(out_shape) + 1, component_sizes_out.view(out_shape)
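

# Minimal usage sketch (assumes a CUDA device and a working Triton install):
if __name__ == "__main__":
    x = torch.zeros(1, 5, 5, dtype=torch.int32, device="cuda")
    x[0, 0, 0:2] = 1  # a 2-pixel component in the top-left corner
    x[0, 3:5, 3:5] = 1  # a 4-pixel component in the bottom-right corner
    labels, sizes = connected_components_triton(x)
    print(labels)  # background is 0; each component shares one positive label
    print(sizes)  # 2 for the first component's pixels, 4 for the second's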