# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from paddle.optimizer.lr import LRScheduler


class CyclicalCosineDecay(LRScheduler):
    """
    Cyclical cosine learning rate decay
    A learning rate which can be referred in https://arxiv.org/pdf/2012.12645.pdf

    Inside every window of ``cycle`` epochs the learning rate follows half a
    cosine wave that starts at ``learning_rate`` and decreases towards
    ``eta_min``; at the start of the next window it restarts from
    ``learning_rate``.

    Args:
        learning_rate(float): learning rate at the start of each cycle
        T_max(int): maximum epoch num. Kept as ``self.T_max`` for reference;
            the schedule itself only depends on ``cycle``.
        cycle(int): period of the cosine decay, must be positive
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        eta_min(float): minimum learning rate during training
        verbose(bool): whether to print learning rate for each epoch

    Raises:
        ValueError: if ``cycle`` is not positive.
    """

    def __init__(self,
                 learning_rate,
                 T_max,
                 cycle=1,
                 last_epoch=-1,
                 eta_min=0.0,
                 verbose=False):
        if cycle <= 0:
            # A non-positive period would otherwise surface later as a
            # ZeroDivisionError (or a reversed schedule) inside get_lr().
            raise ValueError(
                "Expected positive cycle, but got {}".format(cycle))
        # These must exist before the base constructor runs: paddle's
        # LRScheduler.__init__ performs an initial step(), which calls
        # get_lr(). With last_epoch != -1 (resumed training) get_lr() reads
        # self.cycle / self.eta_min immediately.
        self.T_max = T_max
        self.cycle = cycle
        self.eta_min = eta_min
        super(CyclicalCosineDecay, self).__init__(learning_rate, last_epoch,
                                                  verbose)

    def get_lr(self):
        """Return the learning rate for the current ``self.last_epoch``."""
        if self.last_epoch == 0:
            # Return the base value itself so the first rate is exact.
            return self.base_lr
        # Position inside the current cycle, in [0, cycle).
        relative_epoch = self.last_epoch % self.cycle
        cosine_factor = 1 + math.cos(math.pi * relative_epoch / self.cycle)
        return self.eta_min + 0.5 * (self.base_lr - self.eta_min) * cosine_factor


class OneCycleDecay(LRScheduler):
    """
    One Cycle learning rate decay
    A learning rate which can be referred in https://arxiv.org/abs/1708.07120
    Code refered in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR

    The schedule is stepped once per training step (not per epoch) and spans
    ``epochs * steps_per_epoch`` steps in total.

    Args:
        max_lr(float): peak learning rate of the cycle
        epochs(int): number of epochs, must be a positive integer. Required
            even though the signature default is None.
        steps_per_epoch(int): number of steps per epoch, must be a positive
            integer. Required even though the signature default is None.
        pct_start(float): fraction of the total steps spent increasing the
            learning rate; must be a float in [0, 1]. Default: 0.3
        anneal_strategy(str): 'cos' or 'linear'. Default: 'cos'
        div_factor(float): initial learning rate is ``max_lr / div_factor``.
            Default: 25.
        final_div_factor(float): minimum learning rate is
            ``initial_lr / final_div_factor``. Default: 1e4
        three_phase(bool): if True the schedule is warm-up
            (initial_lr -> max_lr), a decay back to initial_lr, then a decay
            to min_lr; otherwise warm-up followed by one decay from max_lr
            to min_lr. Default: False
        last_epoch(int): index of the last step. Default: -1
        verbose(bool): forwarded to the base scheduler. Default: False

    Raises:
        ValueError: if ``epochs``, ``steps_per_epoch``, ``pct_start`` or
            ``anneal_strategy`` is invalid.
    """

    def __init__(self,
                 max_lr,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):

        # Validate total_steps. The type is checked first so that None (the
        # signature default) produces the intended ValueError instead of a
        # TypeError from comparing None with 0.
        if not isinstance(epochs, int) or epochs <= 0:
            raise ValueError(
                "Expected positive integer epochs, but got {}".format(epochs))
        if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
            raise ValueError(
                "Expected positive integer steps_per_epoch, but got {}".format(
                    steps_per_epoch))
        self.total_steps = epochs * steps_per_epoch

        # Validate pct_start before it is used to lay out the phases.
        if not isinstance(pct_start, float) or pct_start < 0 or pct_start > 1:
            raise ValueError(
                "Expected float between 0 and 1 pct_start, but got {}".format(
                    pct_start))

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(
                "anneal_strategy must by one of 'cos' or 'linear', instead got {}".
                format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        self.max_lr = max_lr
        self.initial_lr = self.max_lr / div_factor
        self.min_lr = self.initial_lr / final_div_factor

        # Each phase anneals from start_lr to end_lr and finishes at end_step
        # (inclusive, 0-based step index).
        warmup_phase = {
            'end_step': float(pct_start * self.total_steps) - 1,
            'start_lr': self.initial_lr,
            'end_lr': self.max_lr,
        }
        if three_phase:
            self._schedule_phases = [
                warmup_phase,
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': self.max_lr,
                    'end_lr': self.initial_lr,
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.min_lr,
                },
            ]
        else:
            self._schedule_phases = [
                warmup_phase,
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.max_lr,
                    'end_lr': self.min_lr,
                },
            ]

        # All state read by get_lr() is set above, before the base
        # constructor is invoked.
        super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        """Return the learning rate for the current step ``self.last_epoch``.

        Raises:
            ValueError: if the scheduler was stepped past ``total_steps``.
        """
        computed_lr = 0.0
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(
                "Tried to step {} times. The specified number of total steps is {}"
                .format(step_num + 1, self.total_steps))

        start_step = 0
        last_phase = len(self._schedule_phases) - 1
        for i, phase in enumerate(self._schedule_phases):
            end_step = phase['end_step']
            # The last phase also absorbs any step beyond its end_step.
            if step_num <= end_step or i == last_phase:
                span = end_step - start_step
                # A zero-length phase has nothing to interpolate; treat it as
                # already finished instead of dividing by zero.
                pct = (step_num - start_step) / span if span else 1.0
                computed_lr = self.anneal_func(phase['start_lr'],
                                               phase['end_lr'], pct)
                break
            start_step = end_step

        return computed_lr


class TwoStepCosineDecay(LRScheduler):
    """
    Two-stage cosine learning rate decay.

    While ``last_epoch <= T_max1`` the learning rate is
    ``eta_min + (learning_rate - eta_min) * (1 + cos(pi * last_epoch / T_max1)) / 2``;
    afterwards the same expression is evaluated with ``T_max2`` in place of
    ``T_max1``.

    Args:
        learning_rate(float): initial learning rate
        T_max1(int): cosine half-period used while ``last_epoch <= T_max1``,
            must be a positive integer
        T_max2(int): cosine half-period used once ``last_epoch > T_max1``,
            must be a positive integer
        eta_min(float|int): minimum learning rate. Default: 0
        last_epoch(int): index of the last epoch. Default: -1
        verbose(bool): forwarded to the base scheduler. Default: False

    Raises:
        TypeError: if ``T_max1`` / ``T_max2`` is not an int or ``eta_min`` is
            not a float or int.
        AssertionError: if ``T_max1`` or ``T_max2`` is not positive.
    """

    def __init__(self,
                 learning_rate,
                 T_max1,
                 T_max2,
                 eta_min=0,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(T_max1, int):
            raise TypeError(
                "The type of 'T_max1' in 'TwoStepCosineDecay' must be 'int', but received %s."
                % type(T_max1))
        if not isinstance(T_max2, int):
            raise TypeError(
                "The type of 'T_max2' in 'TwoStepCosineDecay' must be 'int', but received %s."
                % type(T_max2))
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'TwoStepCosineDecay' must be 'float, int', but received %s."
                % type(eta_min))
        # The int checks above already ran, so only the sign is left to verify.
        assert T_max1 > 0, " 'T_max1' must be a positive integer."
        assert T_max2 > 0, " 'T_max2' must be a positive integer."
        # State read by get_lr() is set before the base constructor is invoked.
        self.T_max1 = T_max1
        self.T_max2 = T_max2
        self.eta_min = float(eta_min)
        super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch,
                                                 verbose)

    def get_lr(self):
        """Return the learning rate for the current ``self.last_epoch``.

        The value is computed in closed form rather than incrementally from
        ``self.last_lr``. An incremental cosine update rescales
        ``last_lr - eta_min``, which is exactly zero when the first stage
        ends at ``last_epoch == T_max1``, so it could not leave ``eta_min``
        during the second stage. It also needs ``last_lr`` to hold the
        previous step's value, which is not the case right after construction
        with a non-default ``last_epoch``.
        """
        if self.last_epoch == 0:
            # Return the base value itself so the first rate is exact.
            return self.base_lr
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Return the learning rate as a direct function of ``last_epoch``."""
        # Pick the half-period of the stage the current epoch belongs to.
        if self.last_epoch <= self.T_max1:
            half_period = self.T_max1
        else:
            half_period = self.T_max2
        return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
            math.pi * self.last_epoch / half_period)) / 2