# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Triton kernel for faster and memory efficient sigmoid focal loss"""
import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import libdevice

"""
The sigmoid focal loss is defined as:

    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * ce_loss * ((1 - p_t) ** gamma)

where alpha and gamma are scalar parameters, inputs are the logits, and targets are the float targets.

We implement two versions of the sigmoid focal loss: with and without sum reduction.
The former is implemented with a built-in reduction to avoid materializing the full
point-wise output of the loss. This can help save a bit of peak memory.

The reduced version is implemented using somewhat of a hack. PyTorch's generated kernels
usually do the point-wise operation in a first kernel, and implement the reduction in
another kernel launched on a grid of size 1, where the reduction happens as a for loop
in the Triton kernel. Since we want to fuse those two kernels, that is not a good idea:
we'd have to launch the overall kernel on a grid of size 1, which is obviously inefficient.
On the other hand, typical CUDA algorithms for reduction (e.g. a reduction tree) are hard
to implement in Triton due to the lack of thread sync primitives.
We settle for a version that abuses Triton's atomic_add: we could have all threads simply
add to the same location. In practice, this is not good, since it creates a massive
bottleneck on the semaphore for that single memory location. So instead, we create M
reduction locations, and each thread simply writes to thread_id % M. The Python code can
finally sum over the M reductions.

M = 32 works fine in benchmarking tests. The forward is a tiny bit slower compared to the
non-reduced kernel, but the backward breaks even due to one less memory allocation.
"""
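

# Eager-mode reference of the formula above, kept here as an illustrative sketch for
# sanity-checking the Triton kernels. This helper (and its name) is an addition for
# clarity; it is not used by the kernels below.
def _reference_sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2.0):
    import torch.nn.functional as F  # local import to keep the kernel module lean

    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return alpha_t * ce_loss * ((1 - p_t) ** gamma)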


@triton.jit
def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
    inv_targets = 1 - targets
    # Sigmoid
    sig = tl.sigmoid(inputs)
    # Binary cross entropy with logits
    # In practice, we want the following:
    # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
    # However, the above is not numerically stable.
    # We're also not directly taking a sum here, so the usual log-sum-exp trick doesn't apply.
    # The bce can be reformulated, after algebraic manipulation, to
    # bce_loss = log(1 + exp(-x)) + x * (1 - y)
    # This is still not stable, because for large (-x) the exponential will blow up.
    # We'll use the following alternate formulation:
    # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
    # Let's show that it's equivalent:
    # Case x >= 0: abs(x) = x, max(x, 0) = x,
    #   so we get x - x * y + log(1 + exp(-x)), which is equivalent.
    # Case x < 0: abs(x) = -x, max(x, 0) = 0,
    #   and we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x) * (1 + exp(-x))) = x + log(1 + exp(-x));
    #   plugging it in, we get 0 - x * y + x + log(1 + exp(-x)), which is also equivalent.
    # Note that this is stable because the exponents are now guaranteed to be non-positive.
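    # Worked example of the stability issue (illustrative): for x = -100, y = 1 the naive
    # form log(1 + exp(-x)) + x * (1 - y) requires exp(100), which overflows in fp32,
    # while max(x, 0) - x * y + log(1 + exp(-abs(x))) = 0 + 100 + log(1 + exp(-100)) ~= 100.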
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Modulating factor
    p_t = sig * targets + (1 - sig) * inv_targets
    mod_factor = libdevice.pow(1 - p_t, gamma)
    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Final loss calculation
    return alpha_t * mod_factor * bce_loss


# Non-reduced version
@triton.jit
def sigmoid_focal_loss_fwd_kernel(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)
    # Store result
    tl.store(loss_ptr + offset, final_loss, mask=mask)


# Version with reduction
@triton.jit
def sigmoid_focal_loss_fwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
    REDUCE_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    reduce_loc = pid % REDUCE_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
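    # Out-of-range lanes may hold undefined values after the masked loads; multiplying
    # the element-wise loss by the boolean mask below zeroes them out so they don't
    # contribute to the block-level partial sum.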
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask
    fl = tl.sum(final_loss)
    # Store result
    tl.atomic_add(loss_ptr + reduce_loc, fl)


@triton.jit
def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
    inv_targets = 1 - targets
    # Recompute forward
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Sigmoid
    sig = tl.sigmoid(inputs)
    inv_sig = 1 - sig
    # Modulating factor
    p_t = sig * targets + inv_sig * inv_targets
    tmp = libdevice.pow(1 - p_t, gamma - 1)
    mod_factor = tmp * (1 - p_t)
    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Now computing the derivatives
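    # Chain rule on loss = alpha_t * (1 - p_t)**gamma * bce_loss, with x = inputs, y = targets:
    #   d(bce_loss)/dx = sig - y
    #   d(p_t)/dx      = (2*y - 1) * sig * (1 - sig)     (since p_t = (2*y - 1) * sig + (1 - y))
    #   d((1 - p_t)**gamma)/dx = -gamma * (1 - p_t)**(gamma - 1) * d(p_t)/dx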
    d_pt = (2 * targets - 1) * sig * inv_sig
    d_mod_factor = -gamma * d_pt * tmp
    d_bce_loss = sig - targets
    return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)


@triton.jit
def sigmoid_focal_loss_bwd_kernel(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    grad_out_ptrs = grad_out_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptrs, mask=mask)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)


@triton.jit
def sigmoid_focal_loss_bwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    # The only difference is that the gradient is now a single scalar
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptr)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)


class SigmoidFocalLoss(torch.autograd.Function):
    BLOCK_SIZE = 256

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        n_elements = inputs.numel()
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel[grid](
            inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE
        )
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.view(input_shape)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.view(-1).contiguous()
        grad_output_ptr = grad_output.view(-1).contiguous()
        grad_inputs_ptr = grad_inputs
        assert grad_output.numel() == n_elements
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs_ptr,
            grad_output_ptr,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLoss.BLOCK_SIZE,
        )
        return grad_inputs.view(input_shape), None, None, None


triton_sigmoid_focal_loss = SigmoidFocalLoss.apply
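
# Example usage of the element-wise variant (illustrative; tensor names and shapes are
# made up, and a CUDA device with Triton available is assumed):
#   inputs = torch.randn(8, 1024, device="cuda", requires_grad=True)
#   targets = torch.randint(0, 2, (8, 1024), device="cuda").float()
#   loss = triton_sigmoid_focal_loss(inputs, targets, 0.25, 2.0)  # same shape as inputs
#   loss.sum().backward()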


class SigmoidFocalLossReduced(torch.autograd.Function):
    BLOCK_SIZE = 256
    REDUCE_SIZE = 32

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        n_elements = inputs.numel()
        input_shape = inputs.shape
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        loss = torch.zeros(
            SigmoidFocalLossReduced.REDUCE_SIZE,
            device=inputs.device,
            dtype=torch.float32,
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel_reduce[grid](
            inputs,
            targets,
            loss,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
            SigmoidFocalLossReduced.REDUCE_SIZE,
        )
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.sum()

    @staticmethod
    def backward(ctx, grad_output):
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.reshape(-1).contiguous()
        assert grad_output.numel() == 1
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel_reduce[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs,
            grad_output,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
        )
        return grad_inputs.view(input_shape), None, None, None


triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply
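

# Illustrative self-test (an addition for clarity, not part of the original kernels):
# compares both Triton paths against the eager reference defined near the top of the
# file. Assumes a CUDA device with Triton available; tolerances are a rough guess.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 257, device="cuda", requires_grad=True)
    y = torch.randint(0, 2, (4, 257), device="cuda").float()
    ref = _reference_sigmoid_focal_loss(x, y)
    torch.testing.assert_close(
        triton_sigmoid_focal_loss(x, y, 0.25, 2.0), ref, rtol=1e-4, atol=1e-5
    )
    torch.testing.assert_close(
        triton_sigmoid_focal_loss_reduce(x, y, 0.25, 2.0), ref.sum(), rtol=1e-4, atol=1e-4
    )
    print("Triton sigmoid focal loss matches the eager reference")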