# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

"""Triton kernels for a faster and more memory-efficient sigmoid focal loss."""

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import libdevice
- """
- The sigmoid focal loss is defined as:
- prob = inputs.sigmoid()
- ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
- p_t = prob * targets + (1 - prob) * (1 - targets)
- alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
- loss = alpha_t * ce_loss * ((1 - p_t) ** gamma)
- Where alpha and gamma are scalar parameters, inputs are the logits, targets the float targets.
- We implement two versions of the sigmoid focal loss: with and without sum reduction.
- The latter is implemented with built-in reduction to avoid materializing wrt the output of the loss.
- This can help save a bit of peak memory.
- The reduction version is implemented using somewhat of a hack. Pytorch's generated kernels usually do the point-wise operation in a first kernel, and implement the reduction another kernel launched on a grid of size 1, where the reduction happens as a for loop in the triton kernel.
- Since we want to fuse those two kernels, that is not a good idea: we'd have to launch the overall kernel on a grid of size 1, which is obviously inefficient.
- On the other hand, typical CUDA algorithms for reduction (eg reduction tree) are hard to implement in triton due to the lack of thread sync primitives.
- We settle for a version that abuses triton's atomic_add: we can have all threads simply add to the same location.
- In practice, this is not good, since it creates a massive bottleneck on the semaphore for that single memory location. So instead, we create M reduction locations. Each thread will simply write to thread_id%M. The python code can finally sum over the M reductions.
- M = 32 works fine in benchmarking tests. The forward is a tiny bit slower compared to the non-reduced kernel, but the backward breaks even due to one less memory allocation.
- """


@triton.jit
def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
    inv_targets = 1 - targets
    # Sigmoid
    sig = tl.sigmoid(inputs)
    # Binary cross entropy with logits.
    # In practice, we want the following:
    #   bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
    # However, the above is not numerically stable.
    # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply.
    # The bce can be reformulated, after algebraic manipulation, to
    #   bce_loss = log(1 + exp(-x)) + x * (1 - y)
    # This is still not stable, because for large (-x) the exponential will blow up.
    # We'll use the following alternate formulation:
    #   bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
    # Let's show that it's equivalent:
    # Case x >= 0: abs(x) = x, max(x, 0) = x,
    #   so we get x - x * y + log(1 + exp(-x)), which is equivalent.
    # Case x < 0: abs(x) = -x, max(x, 0) = 0,
    #   and log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x) * (1 + exp(-x))) = x + log(1 + exp(-x));
    #   plugging it in, we get 0 - x * y + x + log(1 + exp(-x)), which is also equivalent.
    # Note that this is stable because the exponents are now guaranteed to be at most 0.
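    # As a concrete illustrative example: for x = -100, y = 1, the naive form needs
    # log(sigmoid(-100)), which underflows to log(0) = -inf, while the stable form gives
    # max(-100, 0) - (-100) * 1 + log(1 + exp(-100)) ≈ 0 + 100 + 0 = 100,
    # matching the exact value -log(sigmoid(-100)) ≈ 100.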
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Modulating factor
    p_t = sig * targets + (1 - sig) * inv_targets
    mod_factor = libdevice.pow(1 - p_t, gamma)
    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Final loss calculation
    return alpha_t * mod_factor * bce_loss


# Non-reduced version
@triton.jit
def sigmoid_focal_loss_fwd_kernel(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)
    # Store result
    tl.store(loss_ptr + offset, final_loss, mask=mask)


# Version with reduction
@triton.jit
def sigmoid_focal_loss_fwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
    REDUCE_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    reduce_loc = pid % REDUCE_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
    # Zero out the masked-off lanes before reducing over the block
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask
    fl = tl.sum(final_loss)
    # Accumulate this block's partial sum into one of the REDUCE_SIZE slots
    tl.atomic_add(loss_ptr + reduce_loc, fl)


@triton.jit
def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
    inv_targets = 1 - targets
    # Recompute the forward pass
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Sigmoid
    sig = tl.sigmoid(inputs)
    inv_sig = 1 - sig
    # Modulating factor
    p_t = sig * targets + inv_sig * inv_targets
    tmp = libdevice.pow(1 - p_t, gamma - 1)
    mod_factor = tmp * (1 - p_t)
    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Now computing the derivatives
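    # The loss is alpha_t * (1 - p_t)^gamma * bce, so by the product rule its derivative
    # w.r.t. the logits x is alpha_t * (d_bce * (1 - p_t)^gamma + bce * d[(1 - p_t)^gamma]), with
    #   d_bce = sigmoid(x) - y                                (standard BCE-with-logits gradient)
    #   d_pt  = (2y - 1) * sigmoid(x) * (1 - sigmoid(x))      since p_t = (2y - 1) * sigmoid(x) + (1 - y)
    #   d[(1 - p_t)^gamma] = -gamma * (1 - p_t)^(gamma - 1) * d_pt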
    d_pt = (2 * targets - 1) * sig * inv_sig
    d_mod_factor = -gamma * d_pt * tmp
    d_bce_loss = sig - targets
    return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)


@triton.jit
def sigmoid_focal_loss_bwd_kernel(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    grad_out_ptrs = grad_out_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptrs, mask=mask)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)


@triton.jit
def sigmoid_focal_loss_bwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    # The only difference is that the incoming gradient is now a single scalar
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptr)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)


class SigmoidFocalLoss(torch.autograd.Function):
    BLOCK_SIZE = 256

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        n_elements = inputs.numel()
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel[grid](
            inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE
        )
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.view(input_shape)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.view(-1).contiguous()
        # reshape rather than view: the incoming gradient is not guaranteed to be contiguous
        grad_output_ptr = grad_output.reshape(-1).contiguous()
        grad_inputs_ptr = grad_inputs
        assert grad_output.numel() == n_elements
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs_ptr,
            grad_output_ptr,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLoss.BLOCK_SIZE,
        )
        return grad_inputs.view(input_shape), None, None, None


triton_sigmoid_focal_loss = SigmoidFocalLoss.apply


class SigmoidFocalLossReduced(torch.autograd.Function):
    BLOCK_SIZE = 256
    REDUCE_SIZE = 32

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        n_elements = inputs.numel()
        input_shape = inputs.shape
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        # REDUCE_SIZE partial sums, filled via atomic_add in the kernel and summed below
        loss = torch.zeros(
            SigmoidFocalLossReduced.REDUCE_SIZE,
            device=inputs.device,
            dtype=torch.float32,
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel_reduce[grid](
            inputs,
            targets,
            loss,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
            SigmoidFocalLossReduced.REDUCE_SIZE,
        )
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.sum()

    @staticmethod
    def backward(ctx, grad_output):
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.reshape(-1).contiguous()
        assert grad_output.numel() == 1
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel_reduce[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs,
            grad_output,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
        )
        return grad_inputs.view(input_shape), None, None, None


triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply
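

# A minimal usage sketch (our addition, not part of the library API): it assumes a CUDA device is
# available, since the Triton kernels only run on the GPU, and compares both variants against the
# pure-PyTorch reference defined above.
if __name__ == "__main__":
    if torch.cuda.is_available():
        logits = torch.randn(4, 80, device="cuda", requires_grad=True)
        targets = torch.rand(4, 80, device="cuda")

        # Element-wise loss
        loss = triton_sigmoid_focal_loss(logits, targets, 0.25, 2)
        ref = _reference_sigmoid_focal_loss(logits, targets, 0.25, 2)
        print("fwd max abs diff:", (loss - ref).abs().max().item())

        # Sum-reduced loss, with a backward pass through the fused kernel
        reduced = triton_sigmoid_focal_loss_reduce(logits, targets, 0.25, 2)
        reduced.backward()
        print("reduced loss:", reduced.item(), "grad norm:", logits.grad.norm().item())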