# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """
    Encoder that condenses a backbone feature map into one holistic vector.

    The feature map is max-pooled over its height, fed through a 2-layer
    RNN along the width axis, and the RNN output at the last (valid) step
    is passed through a linear layer.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # Recurrent encoder: GRU when requested, LSTM otherwise
        rnn_cls = nn.GRU if enc_gru else nn.LSTM
        self.rnn_encoder = rnn_cls(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction='bidirectional' if enc_bi_rnn else 'forward')

        # Global feature transformation; a bidirectional RNN doubles the
        # width of its output
        num_directions = int(enc_bi_rnn) + 1
        rnn_out_dim = d_enc * num_directions
        self.linear = nn.Linear(rnn_out_dim, rnn_out_dim)

    def forward(self, feat, img_metas=None):
        """
        Compute the holistic feature of a batch.

        Args:
            feat (Tensor): Backbone feature map, laid out as bsz * C * H * W.
            img_metas (list, optional): Extra per-batch data. The length of
                its first item must equal the batch size; when masking is
                enabled its last item is read as the valid ratios.

        Returns:
            Tensor: Holistic feature, bsz * C.
        """
        valid_ratios = None
        if img_metas is not None:
            assert len(img_metas[0]) == paddle.shape(feat)[0]
            if self.mask:
                valid_ratios = img_metas[-1]

        # Collapse the height axis so each column becomes one time step
        height = feat.shape[2]  # feat: bsz * C * H * W
        pooled = F.max_pool2d(
            feat, kernel_size=(height, 1), stride=1, padding=0)
        # bsz * C * 1 * W -> bsz * C * W -> bsz * W * C
        seq_feat = paddle.transpose(pooled.squeeze(2), perm=[0, 2, 1])
        rnn_out = self.rnn_encoder(seq_feat)[0]  # bsz * T * C

        if valid_ratios is None:
            # No masking: simply use the final time step
            valid_hf = rnn_out[:, -1, :]  # bsz * C
        else:
            # Per sample, take the output at the last non-padded step
            last_steps = []
            T = paddle.shape(rnn_out)[1]
            for i in range(paddle.shape(valid_ratios)[0]):
                scaled_len = paddle.ceil(valid_ratios[i] * T).astype('int32')
                valid_step = paddle.minimum(T, scaled_len) - 1
                last_steps.append(rnn_out[i, valid_step, :])
            valid_hf = paddle.stack(last_steps, axis=0)

        return self.linear(valid_hf)  # bsz * C


class BaseDecoder(nn.Layer):
    """
    Common interface for SAR decoders.

    Subclasses implement ``forward_train`` and ``forward_test``;
    ``forward`` records the requested mode on the instance and dispatches
    to one of them.

    Args:
        **kwargs: Accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Subclasses read ``self.train_mode`` inside helpers. Give it a value
        # up front (the same default as ``forward``) so that calling
        # ``forward_train`` / ``forward_test`` directly does not fail with
        # AttributeError. ``forward`` still overwrites it on every call.
        self.train_mode = True

    def forward_train(self, feat, out_enc, targets, img_metas):
        """Training-time decoding; must be provided by subclasses."""
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        """Inference-time decoding; must be provided by subclasses."""
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        """
        Dispatch to ``forward_train`` or ``forward_test``.

        Args:
            feat: Passed through to the selected method.
            out_enc: Passed through to the selected method.
            label: Passed to ``forward_train`` only.
            img_metas: Passed through to the selected method.
            train_mode (bool): Selects ``forward_train`` when True,
                ``forward_test`` otherwise. Stored as ``self.train_mode``.

        Returns:
            Whatever the selected method returns.
        """
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """
    Decoder that runs an RNN over the whole decoder input at once and
    applies 2D attention over the backbone feature map for every step.

    Args:
        out_channels (int): Output class number. The start token index is
            taken as ``out_channels - 2`` and the padding index as
            ``out_channels - 1``; the prediction layer emits
            ``out_channels - 1`` scores.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding
            (stored as ``self.max_seq_len``).
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        # The last two class indices are reserved: start token, then padding
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        # A bidirectional RNN doubles the width of its output
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        # (the two "conv1x1" layers are implemented as Linear layers)
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        # NOTE: rebinds the (otherwise unused) ``kwargs`` parameter name
        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        # One fewer output than num_classes, so index num_classes - 1
        # (the padding index) can never be predicted
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        """
        Run the decoder RNN and attend over the feature map for every step.

        Args:
            decoder_input (Tensor): bsz * (seq_len + 1) * emb_dim input to
                the decoder RNN.
            feat (Tensor): Backbone feature map, bsz * C * h * w.
            holistic_feat (Tensor): Encoder output; the callers in this
                class pass it already unsqueezed to bsz * 1 * C.
            valid_ratios (Tensor, optional): Per-sample valid width ratio.
                When given, feature-map columns past the valid width are
                excluded from the attention softmax.

        Returns:
            Tensor: Prediction-layer output for every step,
            bsz * (seq_len + 1) * (num_classes - 1).

        Note:
            Reads ``self.train_mode``, which ``BaseDecoder.forward`` assigns.
        """

        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        # Broadcast-add key and query so every step scores every location
        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))

        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        # NOTE: ``bsz`` is rebound here from the value unpacked from
        # ``attn_query.shape`` above to an element of ``paddle.shape``
        bsz, T, h, w, c = paddle.shape(attn_weight)
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_width = paddle.minimum(
                    w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
                if valid_width < w:
                    # -inf scores become zero weight after the softmax
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        # Softmax jointly over all h * w locations of each step
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w  (c == 1, asserted above)
        # feat: bsz * c * h * w  (here c is the backbone channel count)
        # Weighted sum of the feature map over its spatial axes
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            # Repeat the holistic feature for every decoding step
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        # (last dim is the prediction layer width, i.e. num_classes - 1)
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        Decode with the ground-truth label embeddings as decoder input.

        Args:
            feat (Tensor): Backbone feature map, bsz * C * h * w.
            out_enc (Tensor): Encoder output, bsz * C.
            label (Tensor): Label indices, bsz * seq_len.
            img_metas: [label, valid_ratio]

        Returns:
            Tensor: bsz * seq_len * (num_classes - 1); the output for the
            leading encoder-feature step is dropped.
        '''
        if img_metas is not None:
            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        # The holistic feature is prepended as the first decoder input step
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        '''
        Greedy decoding for ``self.max_seq_len`` steps.

        Args:
            feat (Tensor): Backbone feature map, bsz * C * h * w.
            out_enc (Tensor): Encoder output, bsz * C.
            img_metas: Optional list; the length of its first item must
                equal the batch size, and when masking is enabled its last
                item is read as the valid ratios.

        Returns:
            Tensor: Softmax probabilities,
            bsz * seq_len * (num_classes - 1).
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        # Every step starts out as the start-token embedding; later steps
        # are overwritten with predicted characters inside the loop below
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            # The whole sequence is re-decoded on every iteration; only the
            # output at step i is used
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                # Feed the greedy prediction in as the next step's input
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


class SARHead(nn.Layer):
    """
    Recognition head that chains a SAREncoder with a ParallelSARDecoder.

    Args:
        in_channels (int): Channel count of the incoming feature map; used
            as ``d_model`` of both sub-modules.
        out_channels (int): Output class number handed to the decoder.
        enc_dim (int): Used as ``d_enc`` of both sub-modules.
        max_text_length (int): Maximum decoding length.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of the encoder RNN.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of the decoder RNN.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_k (int): Dim of channels of the attention module.
        pred_dropout (float): Dropout probability of the prediction layer.
        pred_concat (bool): If True, the decoder concatenates glimpse,
            holistic feature and hidden state before prediction.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super().__init__()

        # Settings that encoder and decoder must agree on
        shared_cfg = dict(
            enc_bi_rnn=enc_bi_rnn, d_model=in_channels, d_enc=enc_dim)

        # Encoder is created first, then the decoder
        self.encoder = SAREncoder(
            enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru, **shared_cfg)

        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat,
            **shared_cfg)

    def forward(self, feat, targets=None):
        '''
        Encode the feature map, then decode it.

        Args:
            feat: Backbone feature map.
            targets: Passed to both sub-modules as img_metas,
                i.e. [label, valid_ratio]. In training mode its first item
                is used as the label.

        Returns:
            Decoder output, (bsz, seq_len, num_classes).
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if not self.training:
            # Inference: no label, decoder runs its test path
            return self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)

        # Training: the label is the first item of targets
        label = targets[0]
        return self.decoder(feat, holistic_feat, label, img_metas=targets)