# learn_schedule.py

import tqdm


class LearnScheduleIterator:
    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
        """

        pairs = learn_rate.split(',')
        self.rates = []
        self.it = 0
        self.maxit = 0
        try:
            for pair in pairs:
                if not pair.strip():
                    continue
                tmp = pair.split(':')
                if len(tmp) == 2:
                    step = int(tmp[1])
                    # skip phases that already ended before cur_step; clamp the rest to max_steps
                    if step > cur_step:
                        self.rates.append((float(tmp[0]), min(step, max_steps)))
                        self.maxit += 1
                        if step > max_steps:
                            return
                    elif step == -1:
                        # an end step of -1 means "use this rate for the rest of training"
                        self.rates.append((float(tmp[0]), max_steps))
                        self.maxit += 1
                        return
                else:
                    # a bare number with no ":step" applies until max_steps
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            assert self.rates
        except (ValueError, AssertionError) as e:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e

    def __iter__(self):
        return self

    def __next__(self):
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration
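

# Illustrative examples (added; not part of the original module). Parsing a
# schedule string with LearnScheduleIterator yields (rate, end_step) pairs,
# clamped to max_steps and skipping phases that ended before cur_step:
#
#   >>> list(LearnScheduleIterator("0.001:100, 0.00001:1000, 1e-5:10000", max_steps=5000))
#   [(0.001, 100), (1e-05, 1000), (1e-05, 5000)]
#   >>> list(LearnScheduleIterator("5e-3", max_steps=5000))
#   [(0.005, 5000)]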


class LearnRateScheduler:
    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        (self.learn_rate, self.end_step) = next(self.schedules)
        self.verbose = verbose

        if self.verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def step(self, step_number):
        # returns True only when the schedule advances to a new (rate, end_step) phase
        if step_number < self.end_step:
            return False

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        # when the schedule advances, push the new learning rate into every param group
        if not self.step(step_number):
            return

        if self.verbose:
            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        for pg in optimizer.param_groups:
            pg['lr'] = self.learn_rate
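

# Usage sketch (added for illustration; not part of the original module). A real
# caller would pass an optimizer exposing `param_groups` (e.g. a torch.optim
# optimizer); the stand-in class below only mimics that attribute so the example
# runs without extra dependencies.
if __name__ == "__main__":
    class _FakeOptimizer:
        def __init__(self, lr):
            self.param_groups = [{'lr': lr}]

    optimizer = _FakeOptimizer(lr=0.001)
    scheduler = LearnRateScheduler("0.001:100, 0.00001:1000, 1e-5:10000", max_steps=10000)

    for step_number in range(10000):
        # ... forward/backward/optimizer update for this step would go here ...
        scheduler.apply(optimizer, step_number)  # switches lr at steps 100 and 1000
        if scheduler.finished:
            break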