common.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. from typing import Type
  4. import torch
  5. import torch.nn as nn
  6. class MLPBlock(nn.Module):
  7. def __init__(
  8. self,
  9. embedding_dim: int,
  10. mlp_dim: int,
  11. act: Type[nn.Module] = nn.GELU,
  12. ) -> None:
  13. super().__init__()
  14. self.lin1 = nn.Linear(embedding_dim, mlp_dim)
  15. self.lin2 = nn.Linear(mlp_dim, embedding_dim)
  16. self.act = act()
  17. def forward(self, x: torch.Tensor) -> torch.Tensor:
  18. return self.lin2(self.act(self.lin1(x)))
  19. # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
  20. # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
  21. class LayerNorm2d(nn.Module):
  22. def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
  23. super().__init__()
  24. self.weight = nn.Parameter(torch.ones(num_channels))
  25. self.bias = nn.Parameter(torch.zeros(num_channels))
  26. self.eps = eps
  27. def forward(self, x: torch.Tensor) -> torch.Tensor:
  28. u = x.mean(1, keepdim=True)
  29. s = (x - u).pow(2).mean(1, keepdim=True)
  30. x = (x - u) / torch.sqrt(s + self.eps)
  31. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  32. return x