
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Triton kernel for the Euclidean distance transform (EDT)"""
import torch
import triton
import triton.language as tl
  7. """
  8. Disclaimer: This implementation is not meant to be extremely efficient. A CUDA kernel would likely be more efficient.
  9. Even in Triton, there may be more suitable algorithms.
  10. The goal of this kernel is to mimic cv2.distanceTransform(input, cv2.DIST_L2, 0).
  11. Recall that the euclidean distance transform (EDT) calculates the L2 distance to the closest zero pixel for each pixel of the source image.
  12. For images of size NxN, the naive algorithm would be to compute pairwise distances between every pair of points, leading to a O(N^4) algorithm, which is obviously impractical.
  13. One can do better using the following approach:
  14. - First, compute the distance to the closest point in the same row. We can write it as Row_EDT[i,j] = min_k (sqrt((k-j)^2) if input[i,k]==0 else +infinity). With a naive implementation, this step has a O(N^3) complexity
  15. - Then, because of triangular inequality, we notice that the EDT for a given location [i,j] is the min of the row EDTs in the same column. EDT[i,j] = min_k Row_EDT[k, j]. This is also O(N^3)
  16. Overall, this algorithm is quite amenable to parallelization, and has a complexity O(N^3). Can we do better?
  17. It turns out that we can leverage the structure of the L2 distance (nice and convex) to find the minimum in a more efficient way.
  18. We follow the algorithm from "Distance Transforms of Sampled Functions" (https://cs.brown.edu/people/pfelzens/papers/dt-final.pdf), which is also what's implemented in opencv
  19. For a single dimension EDT, we can compute the EDT of an arbitrary function F, that we discretize over the grid. Note that for the binary EDT that we're interested in, we can set F(i,j) = 0 if input[i,j]==0 else +infinity
  20. For now, we'll compute the EDT squared, and will take the sqrt only at the very end.
  21. The basic idea is that each point at location i spawns a parabola around itself, with a bias equal to F(i). So specifically, we're looking at the parabola (x - i)^2 + F(i)
  22. When we're looking for the row EDT at location j, we're effectively looking for min_i (x-i)^2 + F(i). In other word we want to find the lowest parabola at location j.
  23. To do this efficiently, we need to maintain the lower envelope of the union of parabolas. This can be constructed on the fly using a sort of stack approach:
  24. - every time we want to add a new parabola, we check if it may be covering the current right-most parabola. If so, then that parabola was useless, so we can pop it from the stack
  25. - repeat until we can't find any more parabola to pop. Then push the new one.
  26. This algorithm runs in O(N) for a single row, so overall O(N^2) when applied to all rows
  27. Similarly as before, we notice that we can decompose the algorithm for rows and columns, leading to an overall run-time of O(N^2)
  28. This algorithm is less suited for to GPUs, since the one-dimensional EDT computation is quite sequential in nature. However, we can parallelize over batch and row dimensions.
  29. In Triton, things are particularly bad at the moment, since there is no support for reading/writing to the local memory at a specific index (a local gather is coming soon, see https://github.com/triton-lang/triton/issues/974, but no mention of writing, ie scatter)
  30. One could emulate these operations with masking, but in initial tests, it proved to be worst than naively reading and writing to the global memory. My guess is that the cache is compensating somewhat for the repeated single-point accesses.
  31. The timing obtained on a H100 for a random batch of masks of dimension 256 x 1024 x 1024 are as follows:
  32. - OpenCV: 1780ms (including round-trip to cpu, but discounting the fact that it introduces a synchronization point)
  33. - triton, O(N^3) algo: 627ms
  34. - triton, O(N^2) algo: 322ms
  35. Overall, despite being quite naive, this implementation is roughly 5.5x faster than the openCV cpu implem
  36. """
@triton.jit
def edt_kernel(inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr):
    # This is a fairly verbatim implementation of the efficient 1D EDT algorithm described above.
    # It can be applied horizontally or vertically, depending on whether we're doing the first or the second stage.
    # It's parallelized across batch+row (or batch+col if horizontal=False).
    # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in Triton.
    batch_id = tl.program_id(axis=0)
    if horizontal:
        row_id = tl.program_id(axis=1)
        block_start = (batch_id * height * width) + row_id * width
        length = width
        stride = 1
    else:
        col_id = tl.program_id(axis=1)
        block_start = (batch_id * height * width) + col_id
        length = height
        stride = width
    # This will be the index of the rightmost parabola in the envelope ("the top of the stack")
    k = 0
    for q in range(1, length):
        # Read the function value at the current location. Note that this is a single-element read, not very efficient
        cur_input = tl.load(inputs_ptr + block_start + (q * stride))
        # location of the parabola on top of the stack
        r = tl.load(v + block_start + (k * stride))
        # associated boundary
        z_k = tl.load(z + block_start + (k * stride))
        # value of the function at the parabola location
        previous_input = tl.load(inputs_ptr + block_start + (r * stride))
        # intersection between the two parabolas
        s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
        # we'll pop as many parabolas as required
        while s <= z_k and k - 1 >= 0:
            k = k - 1
            r = tl.load(v + block_start + (k * stride))
            z_k = tl.load(z + block_start + (k * stride))
            previous_input = tl.load(inputs_ptr + block_start + (r * stride))
            s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
        # Store the new one
        k = k + 1
        tl.store(v + block_start + (k * stride), q)
        tl.store(z + block_start + (k * stride), s)
        if k + 1 < length:
            tl.store(z + block_start + ((k + 1) * stride), 1e9)
    # Last step: walk the envelope to find the min at every location
    k = 0
    for q in range(length):
        # The mask mirrors the loop guard so the load stays in bounds;
        # `other=q` makes the out-of-bounds case fail the `< q` test.
        while (
            k + 1 < length
            and tl.load(
                z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q
            )
            < q
        ):
            k += 1
        r = tl.load(v + block_start + (k * stride))
        d = q - r
        old_value = tl.load(inputs_ptr + block_start + (r * stride))
        tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d)
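
# A plain-Python sketch of the same 1D lower-envelope routine, written for
# readability and usable as a CPU oracle in tests. `edt_1d_reference` is a
# hypothetical name, not part of the original module; `f` holds the sampled
# function values (0 or a large constant standing in for +infinity), and the
# *squared* distances are returned, mirroring the kernel above.
def edt_1d_reference(f):
    # e.g. edt_1d_reference([1e18, 0.0, 1e18, 1e18]) == [1.0, 0.0, 1.0, 4.0]
    n = len(f)
    v = [0] * n  # v[k]: grid location of the k-th parabola in the envelope
    z = [0.0] * n  # z[k]: left boundary of the range where parabola k is lowest
    k = 0
    for q in range(1, n):
        # Intersection of the parabola rooted at q with the one on top of the stack
        s = (f[q] - f[v[k]] + q * q - v[k] * v[k]) / (q - v[k]) / 2
        # Pop every parabola that the new one covers
        while k > 0 and s <= z[k]:
            k -= 1
            s = (f[q] - f[v[k]] + q * q - v[k] * v[k]) / (q - v[k]) / 2
        k += 1
        v[k] = q
        z[k] = s
        if k + 1 < n:
            z[k + 1] = 1e9  # sentinel: the new parabola extends to the right edge
    # Walk the envelope left to right to read off the minima
    out = [0.0] * n
    k = 0
    for q in range(n):
        while k + 1 < n and z[k + 1] < q:
            k += 1
        out[q] = f[v[k]] + (q - v[k]) ** 2
    return out
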
def edt_triton(data: torch.Tensor):
    """
    Computes the Euclidean Distance Transform (EDT) of a batch of binary images.

    Args:
        data: A tensor of shape (B, H, W) representing a batch of binary images.

    Returns:
        A tensor of the same shape as data containing the EDT.
        It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0).
    """
    assert data.dim() == 3
    assert data.is_cuda
    B, H, W = data.shape
    data = data.contiguous()
    # Allocate the "function" tensor. Implicitly the function is 0 if data[i,j]==0 else +infinity
    output = torch.where(data != 0, 1e18, 0.0)
    assert output.is_contiguous()
    # Scratch tensors for the parabola stacks: parabola_loc holds the parabola
    # locations (v in the kernel), parabola_inter the interval boundaries (z).
    # v[0] must be 0 for every stack, hence torch.zeros.
    parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device)
    parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device)
    # For the horizontal pass, each row's stack runs along the last dimension
    parabola_inter[:, :, 0] = -1e18
    parabola_inter[:, :, 1] = 1e18
    # Grid size (number of blocks): one program per (batch, row)
    grid = (B, H)
    # Launch the horizontal pass. We pass a clone as the input so the kernel
    # never reads values it has already overwritten in `output`.
    edt_kernel[grid](
        output.clone(),
        output,
        parabola_loc,
        parabola_inter,
        H,
        W,
        horizontal=True,
    )
    # Reset the parabola stacks. For the vertical pass, each column's stack
    # runs along the row dimension.
    parabola_loc.zero_()
    parabola_inter[:, 0, :] = -1e18
    parabola_inter[:, 1, :] = 1e18
    # One program per (batch, col)
    grid = (B, W)
    edt_kernel[grid](
        output.clone(),
        output,
        parabola_loc,
        parabola_inter,
        H,
        W,
        horizontal=False,
    )
    # Don't forget to take the sqrt at the end
    return output.sqrt()
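

# A minimal usage sketch, assuming OpenCV (cv2) is installed; it checks the
# kernel against cv2.distanceTransform on a small random batch. This check is
# an assumption of this sketch, not part of the original module.
if __name__ == "__main__":
    import cv2

    masks = torch.rand(4, 256, 256, device="cuda") > 0.01
    ours = edt_triton(masks)
    for b in range(masks.shape[0]):
        ref = cv2.distanceTransform(
            masks[b].to(torch.uint8).cpu().numpy(), cv2.DIST_L2, 0
        )
        torch.testing.assert_close(
            ours[b].cpu(), torch.from_numpy(ref), atol=1e-3, rtol=1e-3
        )
    print("edt_triton matches cv2.distanceTransform")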