# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """
- This code is refer from:
- https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
- """
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
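

# BasicBlock is the standard two-convolution residual block from ResNet-18/34.
# When a gcb_config is supplied for a stage, the block output is additionally
# refined by a MultiAspectGCAttention context block (defined below) before the
# shortcut is added.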
class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 gcb_config=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2D(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(
            planes, planes, kernel_size=3, stride=1, padding=1,
            bias_attr=False)
        self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
        self.downsample = downsample
        self.stride = stride
        self.gcb_config = gcb_config

        if self.gcb_config is not None:
            gcb_ratio = gcb_config['ratio']
            gcb_headers = gcb_config['headers']
            att_scale = gcb_config['att_scale']
            fusion_type = gcb_config['fusion_type']
            self.context_block = MultiAspectGCAttention(
                inplanes=planes,
                ratio=gcb_ratio,
                headers=gcb_headers,
                att_scale=att_scale,
                fusion_type=fusion_type)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.gcb_config is not None:
            out = self.context_block(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
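

# Helper: the context block is enabled only for stages whose entry in
# gcb_config['layers'] is truthy; all other stages get a plain BasicBlock.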
def get_gcb_config(gcb_config, layer):
    if gcb_config is None or not gcb_config['layers'][layer]:
        return None
    else:
        return gcb_config
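

# TableResNetExtra is the TableMASTER backbone: a ResNet-style trunk with
# extra 3x3 conv + BN + ReLU groups between the residual stages and three
# 2x2 max-pool downsamplings, returning multi-scale feature maps for the
# table-structure head.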
class TableResNetExtra(nn.Layer):
    def __init__(self, layers, in_channels=3, gcb_config=None):
        assert len(layers) >= 4

        super(TableResNetExtra, self).__init__()
        self.inplanes = 128
        self.conv1 = nn.Conv2D(
            in_channels,
            64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(64)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2D(
            64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn2 = nn.BatchNorm2D(128)
        self.relu2 = nn.ReLU()

        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)

        self.layer1 = self._make_layer(
            BasicBlock,
            256,
            layers[0],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 0))

        self.conv3 = nn.Conv2D(
            256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn3 = nn.BatchNorm2D(256)
        self.relu3 = nn.ReLU()

        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)

        self.layer2 = self._make_layer(
            BasicBlock,
            256,
            layers[1],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 1))

        self.conv4 = nn.Conv2D(
            256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn4 = nn.BatchNorm2D(256)
        self.relu4 = nn.ReLU()

        self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)

        self.layer3 = self._make_layer(
            BasicBlock,
            512,
            layers[2],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 2))

        self.conv5 = nn.Conv2D(
            512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn5 = nn.BatchNorm2D(512)
        self.relu5 = nn.ReLU()

        self.layer4 = self._make_layer(
            BasicBlock,
            512,
            layers[3],
            stride=1,
            gcb_config=get_gcb_config(gcb_config, 3))

        self.conv6 = nn.Conv2D(
            512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn6 = nn.BatchNorm2D(512)
        self.relu6 = nn.ReLU()

        self.out_channels = [256, 256, 512]

    def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm2D(planes * block.expansion))

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                gcb_config=gcb_config))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
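
    # forward collects three feature maps: after layer1 + conv3 at 1/2 of the
    # input resolution (256 channels), after layer2 + conv4 at 1/4 (256
    # channels), and after layer4 + conv6 at 1/8 (512 channels), matching
    # self.out_channels above.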
    def forward(self, x):
        f = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.maxpool1(x)
        x = self.layer1(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        f.append(x)

        x = self.maxpool2(x)
        x = self.layer2(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        f.append(x)

        x = self.maxpool3(x)
        x = self.layer3(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)

        x = self.layer4(x)
        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu6(x)
        f.append(x)

        return f
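

# MultiAspectGCAttention is a multi-head variant of the GCNet global-context
# block used by MASTER/TableMASTER: each head pools one global descriptor over
# all spatial positions and fuses it back into the feature map by addition,
# multiplication, or concatenation.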
class MultiAspectGCAttention(nn.Layer):
    def __init__(self,
                 inplanes,
                 ratio,
                 headers,
                 pooling_type='att',
                 att_scale=False,
                 fusion_type='channel_add'):
        super(MultiAspectGCAttention, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
        # inplanes must be divisible by headers
        assert inplanes % headers == 0 and inplanes >= 8

        self.headers = headers
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_type = fusion_type
        self.att_scale = att_scale
        self.single_header_inplanes = int(inplanes / headers)

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2D(
                self.single_header_inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(axis=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2D(1)

        if fusion_type == 'channel_add':
            self.channel_add_conv = nn.Sequential(
                nn.Conv2D(
                    self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(),
                nn.Conv2D(
                    self.planes, self.inplanes, kernel_size=1))
        elif fusion_type == 'channel_concat':
            self.channel_concat_conv = nn.Sequential(
                nn.Conv2D(
                    self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(),
                nn.Conv2D(
                    self.planes, self.inplanes, kernel_size=1))
            # extra 1x1 conv to mix the concatenated channels back down
            self.cat_conv = nn.Conv2D(
                2 * self.inplanes, self.inplanes, kernel_size=1)
        elif fusion_type == 'channel_mul':
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2D(
                    self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(),
                nn.Conv2D(
                    self.planes, self.inplanes, kernel_size=1))
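
    # spatial_pool computes, per head, softmax attention weights over all
    # H * W positions from a 1x1 conv and returns the attention-weighted sum
    # of the features: one [N, C, 1, 1] global context descriptor. With
    # pooling_type='avg' it reduces to global average pooling.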
    def spatial_pool(self, x):
        batch, channel, height, width = x.shape
        if self.pooling_type == 'att':
            # split channels into heads: [N * headers, C', H, W], C = headers * C'
            x = x.reshape([
                batch * self.headers, self.single_header_inplanes, height,
                width
            ])
            input_x = x

            # [N * headers, C', H * W]
            input_x = input_x.reshape([
                batch * self.headers, self.single_header_inplanes,
                height * width
            ])

            # [N * headers, 1, C', H * W]
            input_x = input_x.unsqueeze(1)
            # [N * headers, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N * headers, 1, H * W]
            context_mask = context_mask.reshape(
                [batch * self.headers, 1, height * width])

            # scale variance
            if self.att_scale and self.headers > 1:
                context_mask = context_mask / math.sqrt(
                    self.single_header_inplanes)

            # [N * headers, 1, H * W]
            context_mask = self.softmax(context_mask)

            # [N * headers, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N * headers, 1, C', 1] = [N * headers, 1, C', H * W] x [N * headers, 1, H * W, 1]
            context = paddle.matmul(input_x, context_mask)

            # [N, headers * C', 1, 1]
            context = context.reshape(
                [batch, self.headers * self.single_header_inplanes, 1, 1])
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context
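
    # forward fuses the pooled context back into x according to fusion_type:
    # 'channel_mul' gates channels with sigmoid(transform(context)),
    # 'channel_add' adds transform(context), and 'channel_concat' concatenates
    # the broadcast context and mixes it with a 1x1 conv + LayerNorm + ReLU.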
    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.fusion_type == 'channel_mul':
            # [N, C, 1, 1]
            channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        elif self.fusion_type == 'channel_add':
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        else:
            # [N, C, 1, 1]
            channel_concat_term = self.channel_concat_conv(context)

            # broadcast the context over H x W and concatenate along channels
            _, _, H, W = out.shape
            out = paddle.concat(
                [out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
            out = self.cat_conv(out)
            out = F.layer_norm(out, [self.inplanes, H, W])
            out = F.relu(out)

        return out
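

# A minimal usage sketch (an addition for illustration, not part of the
# original file). The layer counts [1, 2, 5, 3] and the gcb_config values
# below follow the TableMASTER configuration shipped with PaddleOCR; treat
# them as an assumed example rather than the only valid setting.
if __name__ == "__main__":
    gcb_config = {
        "ratio": 0.0625,  # bottleneck ratio for the context transform
        "headers": 1,  # number of attention heads
        "att_scale": False,  # scale attention logits by sqrt(C') if True
        "fusion_type": "channel_add",  # how context is fused back in
        "layers": [False, True, True, True],  # stages that get attention
    }
    model = TableResNetExtra(
        [1, 2, 5, 3], in_channels=3, gcb_config=gcb_config)
    feats = model(paddle.rand([1, 3, 480, 480]))
    for feat in feats:
        # Expected: [1, 256, 240, 240], [1, 256, 120, 120], [1, 512, 60, 60]
        print(feat.shape)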