# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""
import copy
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def subsequent_mask(size):
    """Build a causal attention mask that hides future positions.

    Args:
        size: Sequence length.

    Returns:
        A boolean tensor of shape ``[1, size, size]``. Entry ``[0, i, j]`` is
        ``True`` when ``j <= i`` (position ``i`` may attend to ``j``) and
        ``False`` in the strictly upper triangle (future positions).
    """
    # The lower triangle, diagonal included, marks the visible positions.
    visible = paddle.tril(paddle.ones([1, size, size], dtype='float32'))
    return visible.astype('bool')


def clones(module, N):
    """Return a ``LayerList`` holding ``N`` independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.LayerList(copies)


def masked_fill(x, mask, value):
    """Return a copy of ``x`` with ``value`` written wherever ``mask`` is true.

    Args:
        x: Source tensor.
        mask: Boolean condition tensor passed to ``paddle.where``.
        value: Scalar written at the selected positions.

    Returns:
        A tensor that equals ``value`` where ``mask`` holds and ``x`` elsewhere.
    """
    replacement = paddle.full(shape=x.shape, fill_value=value, dtype=x.dtype)
    filled = paddle.where(mask, replacement, x)
    return filled


def attention(query, key, value, mask=None, dropout=None, attention_map=None):
    """Compute scaled dot-product attention.

    Args:
        query: Query tensor; its last axis size is used as the scaling factor.
        key: Key tensor; must be 4-D because its last two axes are swapped
            with a rank-4 permutation.
        value: Value tensor, multiplied by the attention weights.
        mask: Optional mask; positions where ``mask == 0`` receive a score of
            ``-inf`` before the softmax.
        dropout: Optional callable applied to the attention weights.
        attention_map: Accepted for interface compatibility; not used here.

    Returns:
        A tuple ``(output, weights)`` where ``weights`` are the softmax
        attention weights (after ``dropout`` when one is given) and
        ``output`` is ``weights @ value``.
    """
    scale = math.sqrt(query.shape[-1])
    key_t = paddle.transpose(key, [0, 1, 3, 2])
    scores = paddle.matmul(query, key_t) / scale

    if mask is not None:
        scores = masked_fill(scores, mask == 0, float('-inf'))

    weights = F.softmax(scores, axis=-1)
    if dropout is not None:
        weights = dropout(weights)

    output = paddle.matmul(weights, value)
    return output, weights


class MultiHeadedAttention(nn.Layer):
    """Multi-head attention built on top of ``attention``.

    Four ``d_model x d_model`` linear layers are held in ``self.linears``:
    the first three project query, key and value, the last one projects the
    concatenated heads back to ``d_model``.

    Args:
        h: Number of heads; must divide ``d_model``.
        d_model: Model (feature) dimension.
        dropout: Dropout probability applied to the attention weights.
        compress_attention: Stored on the instance; not read in this class.
    """

    def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h  # per-head feature size
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
        self.compress_attention = compress_attention
        # Created unconditionally; forward() below does not call it.
        self.compress_attention_linear = nn.Linear(h, 1)

    def forward(self, query, key, value, mask=None, attention_map=None):
        """Run multi-head attention.

        Args:
            query, key, value: Tensors whose first axis is the batch and whose
                last axis has size ``d_model``.
            mask: Optional mask; a head axis is inserted at position 1 so the
                same mask is used for every head.
            attention_map: Forwarded to ``attention``.

        Returns:
            A tuple ``(output, attention_map)``: the output-projected result
            reshaped to ``[batch, -1, h * d_k]``, and the attention weights
            returned by ``attention``.
        """
        batch_size = query.shape[0]
        head_mask = None if mask is None else mask.unsqueeze(1)

        def split_heads(projection, tensor):
            # [batch, seq, d_model] -> [batch, h, seq, d_k]
            heads = projection(tensor).reshape(
                [batch_size, -1, self.h, self.d_k])
            return paddle.transpose(heads, [0, 2, 1, 3])

        q_proj, k_proj, v_proj, out_proj = self.linears
        query = split_heads(q_proj, query)
        key = split_heads(k_proj, key)
        value = split_heads(v_proj, value)

        context, attention_map = attention(
            query,
            key,
            value,
            mask=head_mask,
            dropout=self.dropout,
            attention_map=attention_map)

        # [batch, h, seq, d_k] -> [batch, seq, h * d_k]
        merged = paddle.transpose(context, [0, 2, 1, 3])
        merged = paddle.reshape(merged, [batch_size, -1, self.h * self.d_k])

        return out_proj(merged), attention_map


class ResNet(nn.Layer):
    """Convolutional feature extractor: a two-conv stem and four residual stages.

    Every convolution is 3x3 with stride 1 and padding 1, so the spatial size
    only changes at the max-pool layers. ``forward`` applies just two of
    them (``pool`` and ``layer1_pool``); ``layer2_pool``, ``layer3_pool`` and
    ``layer4_pool`` are constructed but never called. The output has 1024
    channels. All batch-norm layers are created with
    ``use_global_stats=True``.

    Args:
        num_in: Number of input channels.
        block: Residual block class, called as
            ``block(inplanes, planes, downsample)``.
        layers: Indexable with four ints giving the block count of stages
            1-4 (each stage always contains at least one block).
    """

    def __init__(self, num_in, block, layers):
        super(ResNet, self).__init__()

        def conv3x3(cin, cout):
            return nn.Conv2D(cin, cout, kernel_size=3, stride=1, padding=1)

        def batch_norm(channels):
            return nn.BatchNorm2D(channels, use_global_stats=True)

        def halving_pool():
            return nn.MaxPool2D((2, 2), (2, 2))

        # Stem: num_in -> 64 -> 128 channels.
        self.conv1 = conv3x3(num_in, 64)
        self.bn1 = batch_norm(64)
        self.relu1 = nn.ReLU()
        self.pool = halving_pool()

        self.conv2 = conv3x3(64, 128)
        self.bn2 = batch_norm(128)
        self.relu2 = nn.ReLU()

        # Stage 1: 128 -> 256 channels.
        self.layer1_pool = halving_pool()
        self.layer1 = self._make_layer(block, 128, 256, layers[0])
        self.layer1_conv = conv3x3(256, 256)
        self.layer1_bn = batch_norm(256)
        self.layer1_relu = nn.ReLU()

        # Stage 2: 256 -> 256 channels (pool is not applied in forward).
        self.layer2_pool = halving_pool()
        self.layer2 = self._make_layer(block, 256, 256, layers[1])
        self.layer2_conv = conv3x3(256, 256)
        self.layer2_bn = batch_norm(256)
        self.layer2_relu = nn.ReLU()

        # Stage 3: 256 -> 512 channels (pool is not applied in forward).
        self.layer3_pool = halving_pool()
        self.layer3 = self._make_layer(block, 256, 512, layers[2])
        self.layer3_conv = conv3x3(512, 512)
        self.layer3_bn = batch_norm(512)
        self.layer3_relu = nn.ReLU()

        # Stage 4: 512 -> 1024 channels (pool is not applied in forward).
        self.layer4_pool = halving_pool()
        self.layer4 = self._make_layer(block, 512, 512, layers[3])
        self.layer4_conv2 = conv3x3(512, 1024)
        self.layer4_conv2_bn = batch_norm(1024)
        self.layer4_conv2_relu = nn.ReLU()

    def _make_layer(self, block, inplanes, planes, blocks):
        """Stack residual blocks into one stage.

        The first block maps ``inplanes`` to ``planes`` and receives a
        conv + batch-norm projection for its residual path when the channel
        counts differ. The remaining ``blocks - 1`` blocks keep ``planes``
        channels and get no projection.
        """
        shortcut = None
        if inplanes != planes:
            shortcut = nn.Sequential(
                nn.Conv2D(inplanes, planes, 3, 1, 1),
                nn.BatchNorm2D(
                    planes, use_global_stats=True))

        stage = [block(inplanes, planes, shortcut)]
        stage.extend(
            block(planes, planes, downsample=None) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the 1024-channel feature map for the image batch ``x``."""
        # Stem.
        x = self.pool(self.relu1(self.bn1(self.conv1(x))))
        x = self.relu2(self.bn2(self.conv2(x)))

        # Stage 1 is the only stage that pools its input.
        x = self.layer1(self.layer1_pool(x))
        x = self.layer1_relu(self.layer1_bn(self.layer1_conv(x)))

        x = self.layer2(x)
        x = self.layer2_relu(self.layer2_bn(self.layer2_conv(x)))

        x = self.layer3(x)
        x = self.layer3_relu(self.layer3_bn(self.layer3_conv(x)))

        x = self.layer4(x)
        x = self.layer4_conv2_relu(
            self.layer4_conv2_bn(self.layer4_conv2(x)))

        return x


class Bottleneck(nn.Layer):
    """Residual block that keeps the channel count at ``input_dim``.

    A 1x1 convolution is followed by a 3x3 convolution (stride 1, padding 1),
    each with batch norm; the block input is added back before the final
    ReLU.

    Args:
        input_dim: Number of input (and output) channels.
    """

    def __init__(self, input_dim):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
        self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
        self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Identity shortcut: the untouched block input.
        out += x
        return self.relu(out)


class PositionalEncoding(nn.Layer):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    A table of shape ``[1, max_len, dim]`` is precomputed and registered as
    the buffer ``pe``: even feature columns hold sines, odd columns cosines.

    Args:
        dropout: Dropout probability applied to the summed output.
        dim: Feature dimension of the encoding. Odd values are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")

        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        # One column per even feature index: [max_len, ceil(dim / 2)].
        angles = position * div_term

        pe = paddle.zeros([max_len, dim])
        pe[:, 0::2] = paddle.sin(angles)
        # There are only dim // 2 odd columns, one fewer than the even ones
        # when dim is odd, so trim the angles to match.
        pe[:, 1::2] = paddle.cos(angles[:, :dim // 2])
        pe = paddle.unsqueeze(pe, 0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.shape[1]`` position encodings to ``x``.

        Args:
            x: Tensor whose axis 1 is the sequence axis and whose last axis
                has size ``dim``.

        Returns:
            ``dropout(x + pe[:, :seq_len])``.
        """
        x = x + self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x)


class PositionwiseFeedForward(nn.Layer):
    """Two-layer feed-forward network: ``w_2(dropout(relu(w_1(x))))``.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")

    def forward(self, x):
        """Expand to ``d_ff``, activate, drop out, and project back to ``d_model``."""
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


class Generator(nn.Layer):
    """Linear projection from decoder features to vocabulary logits.

    No softmax or other activation is applied in ``forward``; the ``relu``
    attribute is created but not used.

    Args:
        d_model: Input feature size.
        vocab: Number of output classes.
    """

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw projection ``proj(x)``."""
        return self.proj(x)


class Embeddings(nn.Layer):
    """Token embedding lookup scaled by ``sqrt(d_model)``.

    Args:
        d_model: Embedding size.
        vocab: Number of entries in the embedding table.
    """

    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the indices in ``x`` and multiply by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale


class LayerNorm(nn.Layer):
    """Layer normalisation over the last axis with a learnable gain and bias.

    The output is ``a_2 * (x - mean) / (std + eps) + b_2``; note that ``eps``
    is added to the standard deviation, not to the variance.

    Args:
        features: Size of the last axis; also the shape of ``a_2`` and ``b_2``.
        eps: Small constant added to the standard deviation.
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        ones_init = paddle.nn.initializer.Constant(1.0)
        self.a_2 = self.create_parameter(
            shape=[features], default_initializer=ones_init)
        zeros_init = paddle.nn.initializer.Constant(0.0)
        self.b_2 = self.create_parameter(
            shape=[features], default_initializer=zeros_init)
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis and apply gain and bias."""
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        scaled = self.a_2 * (x - mean)
        return scaled / (std + self.eps) + self.b_2


class Decoder(nn.Layer):
    """Single transformer decoder layer with fixed sizes.

    It chains causal self-attention over the text, cross-attention from the
    text to the flattened image features, and a feed-forward network. Each
    sub-layer is followed by a residual add and ``LayerNorm``. Model width is
    1024, with 16 heads and a 2048-wide feed-forward hidden layer.
    """

    def __init__(self):
        super(Decoder, self).__init__()

        # Causal self-attention over the text sequence.
        self.mask_multihead = MultiHeadedAttention(
            h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm1 = LayerNorm(1024)

        # Cross-attention from text to image features.
        self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm2 = LayerNorm(1024)

        self.pff = PositionwiseFeedForward(1024, 2048)
        self.mul_layernorm3 = LayerNorm(1024)

    def forward(self, text, conv_feature, attention_map=None):
        """Decode ``text`` against the image features.

        Args:
            text: Text features; axis 1 is the sequence axis.
            conv_feature: 4-D image feature map, unpacked as ``(b, c, h, w)``.
            attention_map: Forwarded to the cross-attention layer.

        Returns:
            A tuple ``(result, attention_map)``: the decoded features and the
            cross-attention weights.
        """
        causal_mask = subsequent_mask(text.shape[1])
        self_attended = self.mask_multihead(
            text, text, text, mask=causal_mask)[0]
        result = self.mul_layernorm1(text + self_attended)

        # Flatten the spatial grid: [b, c, h, w] -> [b, h * w, c].
        batch, channels, height, width = conv_feature.shape
        memory = conv_feature.reshape([batch, channels, height * width])
        memory = paddle.transpose(memory, [0, 2, 1])

        word_image_align, attention_map = self.multihead(
            result, memory, memory, mask=None, attention_map=attention_map)
        result = self.mul_layernorm2(result + word_image_align)

        feed_forward = self.pff(result)
        result = self.mul_layernorm3(result + feed_forward)

        return result, attention_map


class BasicBlock(nn.Layer):
    """Residual block of two 3x3 conv + batch-norm layers.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        downsample: Optional layer applied to the block input before the
            residual addition (used when ``inplanes != planes``), or ``None``
            to add the input unchanged.
    """

    def __init__(self, inplanes, planes, downsample):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2D(
            inplanes, planes, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(
            planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Identity test: None is a singleton, so compare with `is`.
        if self.downsample is not None:
            residual = self.downsample(residual)

        out += residual
        out = self.relu(out)

        return out


class Encoder(nn.Layer):
    """Image encoder: a single-channel ``ResNet`` with stage depths ``[1, 2, 5, 3]``."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])

    def forward(self, input):
        """Return the CNN feature map for the image batch ``input``."""
        return self.cnn(input)


class Transformer(nn.Layer):
    """Image-to-text recognition model: CNN encoder plus one decoder layer.

    Token embeddings (512-d) are concatenated with sinusoidal position
    encodings (512-d) to form the 1024-d decoder input. The decoder attends
    to the encoder feature map, and a linear generator maps the result to
    per-character logits.

    Args:
        in_channels: Accepted for interface compatibility; not used. The
            encoder is always built for single-channel input.
        alphabet: Sequence of symbols; its length is the number of classes.
    """

    def __init__(self, in_channels=1, alphabet='0123456789'):
        super(Transformer, self).__init__()
        self.alphabet = alphabet
        word_n_class = self.get_alphabet_len()
        self.embedding_word_with_upperword = Embeddings(512, word_n_class)
        self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.generator_word_with_upperword = Generator(1024, word_n_class)

        # An initializer must be instantiated and then called on the
        # parameter; `XavierNormal(p)` alone only builds an initializer
        # object (with `p` as its fan_in) and leaves the weights untouched.
        xavier_normal = nn.initializer.XavierNormal()
        for p in self.parameters():
            if p.dim() > 1:
                xavier_normal(p)

    def get_alphabet_len(self):
        """Return the number of symbols in ``self.alphabet``."""
        return len(self.alphabet)

    def forward(self, image, text_length, text_input, attention_map=None):
        """Predict character logits for a batch of images.

        Args:
            image: Image batch with channels on axis 1. Three-channel input
                is collapsed to one channel with a weighted RGB sum.
            text_length: Iterable of per-sample target lengths (tensors).
            text_input: Token indices, ``[batch, length]``; truncated to the
                longest length in ``text_length``.
            attention_map: Accepted for interface compatibility; not used.

        Returns:
            In training mode, a tuple ``(probs_res, word_attention_map, None)``
            where ``probs_res`` packs the first ``text_length[i]`` logit rows
            of every sample back to back, giving shape
            ``[sum(text_length), n_class]``. In eval mode, only the padded
            logits ``[batch, max_length, n_class]``.
        """
        if image.shape[1] == 3:
            # Weighted RGB sum producing a single-channel image.
            R = image[:, 0:1, :, :]
            G = image[:, 1:2, :, :]
            B = image[:, 2:3, :, :]
            image = 0.299 * R + 0.587 * G + 0.114 * B

        conv_feature = self.encoder(image)  # batch, 1024, h, w
        max_length = max(text_length)
        text_input = text_input[:, :max_length]

        text_embedding = self.embedding_word_with_upperword(
            text_input)  # batch, text_max_length, 512
        # Feeding zeros through `pe` yields the bare position encodings.
        position_embedding = self.pe(
            paddle.zeros(text_embedding.shape))  # batch, text_max_length, 512
        text_input_with_pe = paddle.concat(
            [text_embedding, position_embedding],
            2)  # batch, text_max_length, 1024

        text_input_with_pe, word_attention_map = self.decoder(
            text_input_with_pe, conv_feature)

        word_decoder_result = self.generator_word_with_upperword(
            text_input_with_pe)

        if self.training:
            total_length = paddle.sum(text_length)
            probs_res = paddle.zeros([total_length, self.get_alphabet_len()])
            start = 0

            # Pack the valid (unpadded) rows of each sample contiguously.
            for index, length in enumerate(text_length):
                length = int(length.numpy())
                probs_res[start:start + length, :] = word_decoder_result[
                    index, :length, :]

                start = start + length

            return probs_res, word_attention_map, None
        else:
            return word_decoder_result