# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from paddle.optimizer.lr import LRScheduler


class CyclicalCosineDecay(LRScheduler):
    def __init__(self,
                 learning_rate,
                 T_max,
                 cycle=1,
                 last_epoch=-1,
                 eta_min=0.0,
                 verbose=False):
        """
        Cyclical cosine learning rate decay.
        A learning rate schedule described in https://arxiv.org/pdf/2012.12645.pdf

        Args:
            learning_rate(float): initial learning rate
            T_max(int): maximum epoch num
            cycle(int): period of the cosine decay, in epochs
            last_epoch(int, optional): The index of the last epoch. Can be set to restart training. Default: -1, meaning the initial learning rate.
            eta_min(float): minimum learning rate during training
            verbose(bool): whether to print the learning rate for each epoch
        """
        super(CyclicalCosineDecay, self).__init__(learning_rate, last_epoch,
                                                  verbose)
        self.cycle = cycle
        self.eta_min = eta_min

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lr
        relative_epoch = self.last_epoch % self.cycle
        lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
            (1 + math.cos(math.pi * relative_epoch / self.cycle))
        return lr
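
# Usage sketch (illustrative values, not from the original file): drive the
# scheduler manually, or pass it as `learning_rate` to a paddle optimizer.
#
#     sched = CyclicalCosineDecay(learning_rate=0.001, T_max=100, cycle=20,
#                                 eta_min=1e-6)
#     for epoch in range(100):
#         lr = sched.get_lr()  # cosine-decays within each 20-epoch cycle
#         sched.step()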


class OneCycleDecay(LRScheduler):
    """
    One Cycle learning rate decay.
    A learning rate schedule described in https://arxiv.org/abs/1708.07120
    Code adapted from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    """
    def __init__(self,
                 max_lr,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):
        # Validate total_steps (check the type first so that a None value
        # raises a clear ValueError instead of a TypeError on comparison)
        if not isinstance(epochs, int) or epochs <= 0:
            raise ValueError(
                "Expected positive integer epochs, but got {}".format(epochs))
        if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
            raise ValueError(
                "Expected positive integer steps_per_epoch, but got {}".format(
                    steps_per_epoch))
        self.total_steps = epochs * steps_per_epoch
        self.max_lr = max_lr
        self.initial_lr = self.max_lr / div_factor
        self.min_lr = self.initial_lr / final_div_factor

        # Validate pct_start before it is used to build the schedule
        if not isinstance(pct_start, float) or pct_start < 0 or pct_start > 1:
            raise ValueError(
                "Expected float between 0 and 1 for pct_start, but got {}".
                format(pct_start))

        if three_phase:
            # Warm up to max_lr, anneal back to initial_lr, then decay to min_lr
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.max_lr,
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': self.max_lr,
                    'end_lr': self.initial_lr,
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.min_lr,
                },
            ]
        else:
            # Warm up to max_lr, then anneal directly down to min_lr
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.max_lr,
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.max_lr,
                    'end_lr': self.min_lr,
                },
            ]

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(
                "anneal_strategy must be one of 'cos' or 'linear', instead got {}".
                format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start
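
    # Hand-checked examples: both anneal functions agree at the endpoints and
    # at the midpoint of a 1.0 -> 0.0 phase:
    #     _annealing_cos(1.0, 0.0, 0.0) == 1.0, _annealing_cos(1.0, 0.0, 1.0) == 0.0
    #     _annealing_cos(1.0, 0.0, 0.5) == 0.5 == _annealing_linear(1.0, 0.0, 0.5)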

    def get_lr(self):
        computed_lr = 0.0
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(
                "Tried to step {} times. The specified number of total steps is {}"
                .format(step_num + 1, self.total_steps))

        # Find the phase the current step falls in and anneal within it
        start_step = 0
        for i, phase in enumerate(self._schedule_phases):
            end_step = phase['end_step']
            if step_num <= end_step or i == len(self._schedule_phases) - 1:
                pct = (step_num - start_step) / (end_step - start_step)
                computed_lr = self.anneal_func(phase['start_lr'],
                                               phase['end_lr'], pct)
                break
            start_step = phase['end_step']

        return computed_lr
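
# Usage sketch (illustrative values, not from the original file): OneCycleDecay
# is stepped once per batch, so total steps = epochs * steps_per_epoch.
#
#     sched = OneCycleDecay(max_lr=0.1, epochs=10, steps_per_epoch=100,
#                           pct_start=0.3, anneal_strategy='cos')
#     for step in range(10 * 100):
#         lr = sched.get_lr()  # warms up to 0.1, then anneals towards min_lr
#         sched.step()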


class TwoStepCosineDecay(LRScheduler):
    def __init__(self,
                 learning_rate,
                 T_max1,
                 T_max2,
                 eta_min=0,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(T_max1, int):
            raise TypeError(
                "The type of 'T_max1' in 'TwoStepCosineDecay' must be 'int', but received %s."
                % type(T_max1))
        if not isinstance(T_max2, int):
            raise TypeError(
                "The type of 'T_max2' in 'TwoStepCosineDecay' must be 'int', but received %s."
                % type(T_max2))
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'TwoStepCosineDecay' must be 'float, int', but received %s."
                % type(eta_min))
        assert T_max1 > 0 and isinstance(
            T_max1, int), " 'T_max1' must be a positive integer."
        assert T_max2 > 0 and isinstance(
            T_max2, int), " 'T_max2' must be a positive integer."
        self.T_max1 = T_max1
        self.T_max2 = T_max2
        self.eta_min = float(eta_min)
        super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch,
                                                 verbose)

    def get_lr(self):
        # Use T_max1 as the cosine period for the first T_max1 epochs, then
        # switch to T_max2 for the remainder of training
        if self.last_epoch <= self.T_max1:
            if self.last_epoch == 0:
                return self.base_lr
            elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0:
                return self.last_lr + (self.base_lr - self.eta_min) * (
                    1 - math.cos(math.pi / self.T_max1)) / 2
            return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / (
                1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)) * (
                    self.last_lr - self.eta_min) + self.eta_min
        else:
            if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0:
                return self.last_lr + (self.base_lr - self.eta_min) * (
                    1 - math.cos(math.pi / self.T_max2)) / 2
            return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / (
                1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)) * (
                    self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        if self.last_epoch <= self.T_max1:
            return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
                math.pi * self.last_epoch / self.T_max1)) / 2
        else:
            return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
                math.pi * self.last_epoch / self.T_max2)) / 2
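
# Usage sketch (illustrative values, not from the original file): the schedule
# follows a cosine curve with period T_max1 for the first T_max1 epochs and
# period T_max2 afterwards.
#
#     sched = TwoStepCosineDecay(learning_rate=0.001, T_max1=200, T_max2=300)
#     for epoch in range(300):
#         lr = sched.get_lr()
#         sched.step()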