import tqdm
class LearnScheduleIterator:
    """Iterator over (learn_rate, end_step) pairs parsed from a schedule string.

    The schedule string is either a single number ("0.001") or a
    comma-separated list of rate:step pairs. An end step of -1 means
    "use this rate until max_steps".
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000

        Args:
            learn_rate: schedule string (see class docstring).
            max_steps: end steps are clamped to this value; entries starting
                beyond it are dropped.
            cur_step: entries whose end step is at or before this step are
                skipped (used when resuming training).

        Raises:
            Exception: if the string cannot be parsed or yields no entries.
        """
        pairs = learn_rate.split(',')
        self.rates = []   # parsed (rate, end_step) tuples still ahead of cur_step
        self.it = 0       # current position for __next__
        self.maxit = 0    # number of valid entries in self.rates
        try:
            for pair in pairs:
                if not pair.strip():
                    continue  # tolerate trailing/duplicate commas
                tmp = pair.split(':')
                if len(tmp) == 2:
                    step = int(tmp[1])
                    if step > cur_step:
                        # clamp so we never schedule past max_steps
                        self.rates.append((float(tmp[0]), min(step, max_steps)))
                        self.maxit += 1
                        if step > max_steps:
                            return  # later entries would start beyond max_steps
                    elif step == -1:
                        # -1 means "this rate for the rest of training"
                        self.rates.append((float(tmp[0]), max_steps))
                        self.maxit += 1
                        return
                else:
                    # bare number without ":step": constant rate for all of training
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            assert self.rates
        except (ValueError, AssertionError) as e:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next (learn_rate, end_step) pair; StopIteration when exhausted."""
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration
class LearnRateScheduler:
    """Drives a LearnScheduleIterator and applies its rates to an optimizer."""

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        """
        Args:
            learn_rate: schedule string, passed to LearnScheduleIterator.
            max_steps: total number of training steps.
            cur_step: step to resume from; earlier schedule entries are skipped.
            verbose: when True, announce each rate change.

        Raises:
            Exception: if the schedule string is invalid (from LearnScheduleIterator).
        """
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        (self.learn_rate, self.end_step) = next(self.schedules)
        self.verbose = verbose

        if self.verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False  # True once the schedule is exhausted

    def step(self, step_number):
        """Advance to the next schedule entry once step_number reaches end_step.

        Returns True when learn_rate changed, False otherwise. Sets
        self.finished when the schedule is exhausted.
        """
        if step_number < self.end_step:
            return False

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        """Write the current learn_rate into every optimizer param group
        whenever the schedule advances at step_number."""
        if not self.step(step_number):
            return

        if self.verbose:
            # tqdm.write keeps the message from mangling active progress bars
            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        for pg in optimizer.param_groups:
            pg['lr'] = self.learn_rate