| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289 |
- // Copyright (c) Meta Platforms, Inc. and affiliates.
- // All rights reserved.
- // This source code is licensed under the license found in the
- // LICENSE file in the root directory of this source tree.
- // adapted from https://github.com/zsef123/Connected_components_PyTorch
- // with license found in the LICENSE_cctorch file in the root directory.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/script.h>
#include <limits>
#include <vector>
// 2D connected-component labeling.
// Thread-block shape (BLOCK_COLS x BLOCK_ROWS threads) used by every kernel
// launch in this file.
#define BLOCK_ROWS 16
#define BLOCK_COLS 16

// Union-find connected-component labeling of binary uint8 images with
// 8-connectivity (diagonal neighbors count as connected).
//
// The labeling kernels (init_labeling, merge, compression, final_labeling)
// run one thread per 2x2 pixel block. The thread with global index (x, y)
// owns rows 2y..2y+1 and cols 2x..2x+1. The union-find parent of the whole
// block is stored at the block's top-left pixel. The counting kernels
// (init_counting, final_counting) run one thread per pixel. Every kernel
// receives a pointer to a single row-major H x W image (idx = row * W + col).
// Every kernel bounds-checks against W and H, so surplus threads are harmless.
namespace cc2d {

// Returns bit `pos` of `bitmap` (0 or 1).
template <typename T>
__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
  return (bitmap >> pos) & 1;
}

// Follows parent links in `s_buf` from `n` until reaching a root (an entry
// that points to itself) and returns that root.
__device__ int32_t find(const int32_t* s_buf, int32_t n) {
  while (s_buf[n] != n)
    n = s_buf[n];
  return n;
}

// Like find(), but also rewrites the starting entry s_buf[id] to every
// ancestor it visits. On return, s_buf[id] points directly at the root.
__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
  const int32_t id = n;
  while (s_buf[n] != n) {
    n = s_buf[n];
    s_buf[id] = n;
  }
  return n;
}

// Merges the trees containing `a` and `b`. The larger root is pointed at the
// smaller one with atomicMin. If the atomic observes that the target is no
// longer a root (another thread re-parented it concurrently), the roots are
// looked up again and the merge retried.
__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
  bool done;
  do {
    a = find(s_buf, a);
    b = find(s_buf, b);
    if (a < b) {
      int32_t old = atomicMin(s_buf + b, a);
      done = (old == b);
      b = old;
    } else if (b < a) {
      int32_t old = atomicMin(s_buf + a, b);
      done = (old == a);
      a = old;
    } else
      done = true;
  } while (!done);
}

// Makes every 2x2 block its own root: label[top-left] = top-left index.
// The other three pixels of each block are not written.
__global__ void
init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  const uint32_t idx = row * W + col;
  if (row < H && col < W)
    label[idx] = idx;
}

// Unions the 2x2 block at (row, col) with its top-left, top, top-right and
// left neighbor blocks whenever a foreground pixel of this block is
// 8-adjacent to a foreground pixel of that neighbor. Adjacencies to the
// bottom/right are handled by those blocks' own threads.
//
// P is a 16-bit mask over the 4x4 window rows row-1..row+2, cols
// col-1..col+2. Bit 4 * r + c stands for window cell (r, c). Each foreground
// pixel of the block sets the bits of its 3x3 neighborhood (0x777 shifted to
// the pixel's position). The bottom-right pixel is skipped because none of
// its neighbors lie in the four blocks checked below. Window cells outside
// the image are then masked off, so every img[] read guarded by hasBit()
// stays in bounds.
__global__ void
merge(const uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  const uint32_t idx = row * W + col;
  if (row >= H || col >= W)
    return;

  uint32_t P = 0;
  if (img[idx])
    P |= 0x777; // pixel (row, col)
  if (row + 1 < H && img[idx + W])
    P |= 0x777 << 4; // pixel (row + 1, col)
  if (col + 1 < W && img[idx + 1])
    P |= 0x777 << 1; // pixel (row, col + 1)

  // Clear window cells that fall outside the image.
  if (col == 0)
    P &= 0xEEEE; // clear column col - 1
  if (col + 1 >= W)
    P &= 0x3333; // clear columns col + 1 and col + 2
  else if (col + 2 >= W)
    P &= 0x7777; // clear column col + 2
  if (row == 0)
    P &= 0xFFF0; // clear row row - 1
  if (row + 1 >= H)
    P &= 0xFF; // clear rows row + 1 and row + 2

  if (P > 0) {
    // Top-left block, via pixel (row - 1, col - 1).
    if (hasBit(P, 0) && img[idx - W - 1]) {
      union_(label, idx, idx - 2 * W - 2);
    }
    // Top block, via pixel (row - 1, col) or (row - 1, col + 1).
    if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
      union_(label, idx, idx - 2 * W);
    // Top-right block, via pixel (row - 1, col + 2).
    if (hasBit(P, 3) && img[idx + 2 - W])
      union_(label, idx, idx - 2 * W + 2);
    // Left block, via pixel (row, col - 1) or (row + 1, col - 1).
    if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
      union_(label, idx, idx - 2);
  }
}

// Path-compresses each block so that label[top-left] holds its root
// directly. Expected to run after merge.
__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  const uint32_t idx = row * W + col;
  if (row < H && col < W)
    find_n_compress(label, idx);
}

// Expands block labels to pixels. Each foreground pixel of the block gets
// root + 1 and each background pixel gets 0, so labels are 1-based root
// indices (unique per component, not consecutive). Expects label[top-left]
// to already hold the root, i.e. to run after compression. Each thread reads
// only its own block's top-left entry and writes only its own block's pixels.
__global__ void final_labeling(
    const uint8_t* img,
    int32_t* label,
    const int32_t W,
    const int32_t H) {
  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  const uint32_t idx = row * W + col;
  if (row >= H || col >= W)
    return;
  int32_t y = label[idx] + 1;
  if (img[idx])
    label[idx] = y;
  else
    label[idx] = 0;
  if (col + 1 < W) {
    if (img[idx + 1])
      label[idx + 1] = y;
    else
      label[idx + 1] = 0;
    if (row + 1 < H) {
      if (img[idx + W + 1])
        label[idx + W + 1] = y;
      else
        label[idx + W + 1] = 0;
    }
  }
  if (row + 1 < H) {
    if (img[idx + W])
      label[idx + W] = y;
    else
      label[idx + W] = 0;
  }
}

// One thread per pixel. Adds 1 to count_init[label - 1] for every foreground
// pixel, which accumulates the size of each component at its root index.
// count_init is accumulated with atomicAdd and so must start zeroed.
__global__ void init_counting(
    const int32_t* label,
    int32_t* count_init,
    const int32_t W,
    const int32_t H) {
  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
  const uint32_t idx = row * W + col;
  if (row >= H || col >= W)
    return;
  int32_t y = label[idx];
  if (y > 0) {
    int32_t count_idx = y - 1;
    atomicAdd(count_init + count_idx, 1);
  }
}

// One thread per pixel. Writes the size of the pixel's component
// (count_init[label - 1]), or 0 for background pixels.
__global__ void final_counting(
    const int32_t* label,
    const int32_t* count_init,
    int32_t* count_final,
    const int32_t W,
    const int32_t H) {
  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
  const uint32_t idx = row * W + col;
  if (row >= H || col >= W)
    return;
  int32_t y = label[idx];
  if (y > 0) {
    int32_t count_idx = y - 1;
    count_final[idx] = count_init[count_idx];
  } else {
    count_final[idx] = 0;
  }
}

} // namespace cc2d
// Labels 8-connected foreground components of a batch of binary masks. For
// every pixel, also reports the size of the component it belongs to.
//
// Args:
//   inputs: CUDA uint8 tensor of shape [N, 1, H, W] with even H and W. Any
//     nonzero value is foreground. A non-contiguous input is first copied to
//     a contiguous buffer.
// Returns:
//   {labels, counts}: two int32 tensors of shape [N, 1, H, W] on the input's
//   device.
//   - labels[p] is 0 for background. Otherwise it is 1 + the in-image linear
//     index of the top-left pixel of the component's root block, so labels
//     are unique per component but not consecutive.
//   - counts[p] is the pixel count of p's component (0 for background).
// Throws c10::Error (RuntimeError in Python) on invalid inputs or when a
// kernel launch fails. Work is enqueued on the current CUDA stream of the
// input's device.
// NOTE: the misspelled name is kept because it is the exported Python API.
std::vector<torch::Tensor> get_connected_componnets(
    const torch::Tensor& inputs) {
  TORCH_CHECK(inputs.is_cuda(), "inputs must be a CUDA tensor");
  TORCH_CHECK(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
  TORCH_CHECK(
      inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
  const int64_t N = inputs.size(0);
  const int64_t C = inputs.size(1);
  const int64_t H = inputs.size(2);
  const int64_t W = inputs.size(3);
  TORCH_CHECK(C == 1, "inputs must be [N, 1, H, W] shape");
  TORCH_CHECK((H % 2) == 0, "height must be an even number");
  TORCH_CHECK((W % 2) == 0, "width must be an even number");
  // Block labels are int32 pixel indices, and the kernels compute pixel
  // indices with 32-bit arithmetic, so a single image must fit in int32.
  TORCH_CHECK(
      H * W <= std::numeric_limits<int32_t>::max(),
      "H * W must not exceed 2^31 - 1");

  // Launch on the input's device even when it is not the current device.
  const c10::cuda::CUDAGuard device_guard(inputs.device());
  // The kernels address pixels as dense, row-major H x W planes.
  const torch::Tensor img = inputs.contiguous();

  // Labels and counts are int32.
  const auto label_options =
      torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
  torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
  // Must start zeroed: init_counting accumulates into it with atomicAdd.
  torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
  torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
  // Nothing to label. Returning early also avoids launching a grid with a
  // zero dimension, which is an invalid launch configuration.
  if (img.numel() == 0) {
    return {labels, counts_final};
  }

  const uint32_t W32 = static_cast<uint32_t>(W);
  const uint32_t H32 = static_cast<uint32_t>(H);
  // Labeling kernels: one thread per 2x2 pixel block.
  const dim3 block(BLOCK_COLS, BLOCK_ROWS);
  const dim3 grid(
      ((W32 + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
      ((H32 + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
  // Counting kernels: one thread per pixel (ceil-division, no spare blocks).
  const dim3 block_count(BLOCK_COLS, BLOCK_ROWS);
  const dim3 grid_count(
      (W32 + BLOCK_COLS - 1) / BLOCK_COLS, (H32 + BLOCK_ROWS - 1) / BLOCK_ROWS);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  uint8_t* const img_ptr = img.data_ptr<uint8_t>();
  int32_t* const labels_ptr = labels.data_ptr<int32_t>();
  int32_t* const counts_init_ptr = counts_init.data_ptr<int32_t>();
  int32_t* const counts_final_ptr = counts_final.data_ptr<int32_t>();
  const int64_t plane = H * W;
  for (int64_t n = 0; n < N; ++n) {
    // 64-bit offset: n * H * W can exceed 2^32 for large batches.
    const int64_t offset = n * plane;
    uint8_t* const img_n = img_ptr + offset;
    int32_t* const labels_n = labels_ptr + offset;
    int32_t* const counts_init_n = counts_init_ptr + offset;
    int32_t* const counts_final_n = counts_final_ptr + offset;

    // Union-find labeling; the kernels run in stream order.
    cc2d::init_labeling<<<grid, block, 0, stream>>>(labels_n, W32, H32);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    cc2d::merge<<<grid, block, 0, stream>>>(img_n, labels_n, W32, H32);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    cc2d::compression<<<grid, block, 0, stream>>>(labels_n, W32, H32);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    cc2d::final_labeling<<<grid, block, 0, stream>>>(
        img_n, labels_n, W32, H32);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Per-pixel component sizes.
    cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
        labels_n, counts_init_n, W32, H32);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
        labels_n, counts_init_n, counts_final_n, W32, H32);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  // Returned values are [labels, counts].
  return {labels, counts_final};
}
// Python bindings for the extension module. The entry point is exported under
// its existing (misspelled) name, which also serves as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  constexpr const char* kEntryPoint = "get_connected_componnets";
  mod.def(kEntryPoint, &get_connected_componnets, kEntryPoint);
}
|