// connected_components.cu
  1. // Copyright (c) Meta Platforms, Inc. and affiliates.
  2. // All rights reserved.
  3. // This source code is licensed under the license found in the
  4. // LICENSE file in the root directory of this source tree.
  5. // adapted from https://github.com/zsef123/Connected_components_PyTorch
  6. // with license found in the LICENSE_cctorch file in the root directory.
#include <cuda.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <torch/script.h>

#include <vector>
  13. // 2d
  14. #define BLOCK_ROWS 16
  15. #define BLOCK_COLS 16
  16. namespace cc2d {
  17. template <typename T>
  18. __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
  19. return (bitmap >> pos) & 1;
  20. }
  21. __device__ int32_t find(const int32_t* s_buf, int32_t n) {
  22. while (s_buf[n] != n)
  23. n = s_buf[n];
  24. return n;
  25. }
  26. __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
  27. const int32_t id = n;
  28. while (s_buf[n] != n) {
  29. n = s_buf[n];
  30. s_buf[id] = n;
  31. }
  32. return n;
  33. }
// Merge the components containing `a` and `b` (lock-free union).
// Both roots are resolved, then the larger root is linked under the
// smaller one with atomicMin. If the atomic returns a value other than
// the expected root, another thread raced us and relinked the tree; in
// that case we continue from the value the atomic observed and retry.
// Terminates when both roots coincide or an atomicMin landed cleanly.
__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
  bool done;
  do {
    a = find(s_buf, a);
    b = find(s_buf, b);
    if (a < b) {
      // Attach root b under the smaller root a. `old == b` means nobody
      // interfered; otherwise retry with the parent another thread set.
      int32_t old = atomicMin(s_buf + b, a);
      done = (old == b);
      b = old;
    } else if (b < a) {
      // Symmetric case: attach root a under the smaller root b.
      int32_t old = atomicMin(s_buf + a, b);
      done = (old == a);
      a = old;
    } else
      // Same root already: components are merged.
      done = true;
  } while (!done);
}
  51. __global__ void
  52. init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
  53. const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  54. const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  55. const uint32_t idx = row * W + col;
  56. if (row < H && col < W)
  57. label[idx] = idx;
  58. }
// One thread per 2x2 pixel block: union this block's component with the
// neighboring blocks above and to the left (8-connectivity), via the
// union-find buffer `label`.
//
// P is a bitmask of neighbor positions to inspect: 0x777 encodes the
// neighborhood template of one foreground pixel, and shifting it left by
// 1 / 4 moves the template to the block's right / bottom pixel. The
// boundary masks (0xEEEE, 0x3333, 0x7777, 0xFFF0, 0xFF) clear bits whose
// neighbors would fall outside the image.
// NOTE(review): bit layout follows the block-based labeling scheme of the
// upstream Connected_components_PyTorch project — confirm against it
// before altering any mask constant.
__global__ void
merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  const uint32_t idx = row * W + col;
  if (row >= H || col >= W)
    return;
  uint32_t P = 0;
  // Build the neighbor-check mask from this block's foreground pixels.
  if (img[idx])
    P |= 0x777;
  if (row + 1 < H && img[idx + W])
    P |= 0x777 << 4;
  if (col + 1 < W && img[idx + 1])
    P |= 0x777 << 1;
  // Clip the mask at the image borders.
  if (col == 0)
    P &= 0xEEEE;
  if (col + 1 >= W)
    P &= 0x3333;
  else if (col + 2 >= W)
    P &= 0x7777;
  if (row == 0)
    P &= 0xFFF0;
  if (row + 1 >= H)
    P &= 0xFF;
  if (P > 0) {
    // If need check about top-left pixel(if flag the first bit) and hit the
    // top-left pixel
    if (hasBit(P, 0) && img[idx - W - 1]) {
      union_(label, idx, idx - 2 * W - 2); // top left block
    }
    if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
      union_(label, idx, idx - 2 * W); // top bottom block
    if (hasBit(P, 3) && img[idx + 2 - W])
      union_(label, idx, idx - 2 * W + 2); // top right block
    if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
      union_(label, idx, idx - 2); // just left block
  }
}
  97. __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
  98. const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  99. const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  100. const uint32_t idx = row * W + col;
  101. if (row < H && col < W)
  102. find_n_compress(label, idx);
  103. }
  104. __global__ void final_labeling(
  105. const uint8_t* img,
  106. int32_t* label,
  107. const int32_t W,
  108. const int32_t H) {
  109. const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  110. const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  111. const uint32_t idx = row * W + col;
  112. if (row >= H || col >= W)
  113. return;
  114. int32_t y = label[idx] + 1;
  115. if (img[idx])
  116. label[idx] = y;
  117. else
  118. label[idx] = 0;
  119. if (col + 1 < W) {
  120. if (img[idx + 1])
  121. label[idx + 1] = y;
  122. else
  123. label[idx + 1] = 0;
  124. if (row + 1 < H) {
  125. if (img[idx + W + 1])
  126. label[idx + W + 1] = y;
  127. else
  128. label[idx + W + 1] = 0;
  129. }
  130. }
  131. if (row + 1 < H) {
  132. if (img[idx + W])
  133. label[idx + W] = y;
  134. else
  135. label[idx + W] = 0;
  136. }
  137. }
  138. __global__ void init_counting(
  139. const int32_t* label,
  140. int32_t* count_init,
  141. const int32_t W,
  142. const int32_t H) {
  143. const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
  144. const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
  145. const uint32_t idx = row * W + col;
  146. if (row >= H || col >= W)
  147. return;
  148. int32_t y = label[idx];
  149. if (y > 0) {
  150. int32_t count_idx = y - 1;
  151. atomicAdd(count_init + count_idx, 1);
  152. }
  153. }
  154. __global__ void final_counting(
  155. const int32_t* label,
  156. const int32_t* count_init,
  157. int32_t* count_final,
  158. const int32_t W,
  159. const int32_t H) {
  160. const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
  161. const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
  162. const uint32_t idx = row * W + col;
  163. if (row >= H || col >= W)
  164. return;
  165. int32_t y = label[idx];
  166. if (y > 0) {
  167. int32_t count_idx = y - 1;
  168. count_final[idx] = count_init[count_idx];
  169. } else {
  170. count_final[idx] = 0;
  171. }
  172. }
  173. } // namespace cc2d
  174. std::vector<torch::Tensor> get_connected_componnets(
  175. const torch::Tensor& inputs) {
  176. AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
  177. AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
  178. AT_ASSERTM(
  179. inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
  180. const uint32_t N = inputs.size(0);
  181. const uint32_t C = inputs.size(1);
  182. const uint32_t H = inputs.size(2);
  183. const uint32_t W = inputs.size(3);
  184. AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
  185. AT_ASSERTM((H % 2) == 0, "height must be an even number");
  186. AT_ASSERTM((W % 2) == 0, "width must be an even number");
  187. // label must be uint32_t
  188. auto label_options =
  189. torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
  190. torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
  191. torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
  192. torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
  193. dim3 grid = dim3(
  194. ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
  195. ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
  196. dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
  197. dim3 grid_count =
  198. dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
  199. dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
  200. cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  201. for (int n = 0; n < N; n++) {
  202. uint32_t offset = n * H * W;
  203. cc2d::init_labeling<<<grid, block, 0, stream>>>(
  204. labels.data_ptr<int32_t>() + offset, W, H);
  205. cc2d::merge<<<grid, block, 0, stream>>>(
  206. inputs.data_ptr<uint8_t>() + offset,
  207. labels.data_ptr<int32_t>() + offset,
  208. W,
  209. H);
  210. cc2d::compression<<<grid, block, 0, stream>>>(
  211. labels.data_ptr<int32_t>() + offset, W, H);
  212. cc2d::final_labeling<<<grid, block, 0, stream>>>(
  213. inputs.data_ptr<uint8_t>() + offset,
  214. labels.data_ptr<int32_t>() + offset,
  215. W,
  216. H);
  217. // get the counting of each pixel
  218. cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
  219. labels.data_ptr<int32_t>() + offset,
  220. counts_init.data_ptr<int32_t>() + offset,
  221. W,
  222. H);
  223. cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
  224. labels.data_ptr<int32_t>() + offset,
  225. counts_init.data_ptr<int32_t>() + offset,
  226. counts_final.data_ptr<int32_t>() + offset,
  227. W,
  228. H);
  229. }
  230. // returned values are [labels, counts]
  231. std::vector<torch::Tensor> outputs;
  232. outputs.push_back(labels);
  233. outputs.push_back(counts_final);
  234. return outputs;
  235. }
// Expose the batched connected-components entry point to Python.
// NOTE(review): the misspelling "componnets" is kept deliberately —
// renaming it would break existing Python callers.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "get_connected_componnets",
      &get_connected_componnets,
      "get_connected_componnets");
}