schedulers.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import math
  4. class InverseSquareRootParamScheduler:
  5. def __init__(
  6. self,
  7. base_lr: float,
  8. warmup_steps: int,
  9. cooldown_steps: int,
  10. timescale: int,
  11. ):
  12. self.base_lr = base_lr
  13. self.warmup_steps = warmup_steps
  14. self.cooldown_steps = cooldown_steps
  15. self.timescale = timescale
  16. def __call__(self, step: int, where: float):
  17. lr = self.base_lr
  18. if where > 0:
  19. total_steps = step / where
  20. progress = (step - self.warmup_steps) / float(
  21. total_steps - self.warmup_steps
  22. )
  23. progress = max(min(progress, 1), 0)
  24. else:
  25. progress = 0
  26. total_steps = 1
  27. shift = self.timescale - self.warmup_steps
  28. if self.warmup_steps < step:
  29. lr = lr / math.sqrt((step + shift) / self.timescale)
  30. if self.warmup_steps:
  31. lr = lr * min(1.0, step / self.warmup_steps)
  32. if self.cooldown_steps:
  33. lr = lr * min(1.0, (total_steps - step) / self.cooldown_steps)
  34. return lr