# tsrn.py
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is refer from:
  16. https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py
  17. """
  18. import math
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddle import nn
  22. from collections import OrderedDict
  23. import sys
  24. import numpy as np
  25. import warnings
  26. import math, copy
  27. import cv2
  28. warnings.filterwarnings("ignore")
  29. from .tps_spatial_transformer import TPSSpatialTransformer
  30. from .stn import STN as STN_model
  31. from ppocr.modeling.heads.sr_rensnet_transformer import Transformer
class TSRN(nn.Layer):
    """Text Super-Resolution Network (TSRN) from text-gestalt.

    Architecture built in ``__init__``:
      * ``block1``                  : 9x9 conv + PReLU entry block.
      * ``block2 .. block{srb_nums+1}`` : sequential recurrent residual blocks.
      * ``block{srb_nums+2}``       : 3x3 conv + BN fusion block.
      * ``block{srb_nums+3}``       : ``log2(scale_factor)`` pixel-shuffle
                                      upsample blocks + 9x9 output conv.
    Optionally a TPS spatial transformer (STN) rectifies the input first.
    A frozen recognition ``Transformer`` produces predictions / attention
    maps that the training losses consume.
    """

    def __init__(self,
                 in_channels,
                 scale_factor=2,
                 width=128,
                 height=32,
                 STN=False,
                 srb_nums=5,
                 mask=False,
                 hidden_units=32,
                 infer_mode=False,
                 **kwargs):
        super(TSRN, self).__init__()
        # RGB input; one extra (mask) channel when mask=True.
        in_planes = 3
        if mask:
            in_planes = 4
        # scale_factor must be a power of two: one 2x upsample stage per level.
        assert math.log(scale_factor, 2) % 1 == 0
        upsample_block_num = int(math.log(scale_factor, 2))
        self.block1 = nn.Sequential(
            nn.Conv2D(
                in_planes, 2 * hidden_units, kernel_size=9, padding=4),
            nn.PReLU())
        self.srb_nums = srb_nums
        # Blocks 2 .. srb_nums+1: the sequential residual blocks (SRBs).
        for i in range(srb_nums):
            setattr(self, 'block%d' % (i + 2),
                    RecurrentResidualBlock(2 * hidden_units))
        # Block srb_nums+2: conv+BN applied before the long skip from block1.
        setattr(
            self,
            'block%d' % (srb_nums + 2),
            nn.Sequential(
                nn.Conv2D(
                    2 * hidden_units,
                    2 * hidden_units,
                    kernel_size=3,
                    padding=1),
                nn.BatchNorm2D(2 * hidden_units)))

        # Block srb_nums+3: upsampling stages followed by a projection back
        # down to in_planes channels.
        block_ = [
            UpsampleBLock(2 * hidden_units, 2)
            for _ in range(upsample_block_num)
        ]
        block_.append(
            nn.Conv2D(
                2 * hidden_units, in_planes, kernel_size=9, padding=4))
        setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
        self.tps_inputsize = [height // scale_factor, width // scale_factor]
        tps_outputsize = [height // scale_factor, width // scale_factor]
        num_control_points = 20
        tps_margins = [0.05, 0.05]
        self.stn = STN
        if self.stn:
            self.tps = TPSSpatialTransformer(
                output_image_size=tuple(tps_outputsize),
                num_control_points=num_control_points,
                margins=tuple(tps_margins))
            self.stn_head = STN_model(
                in_channels=in_planes,
                num_ctrlpoints=num_control_points,
                activation='none')
        self.out_channels = in_channels
        # Frozen recognizer: used only to derive supervision signals, so its
        # parameters are excluded from the optimizer.
        self.r34_transformer = Transformer()
        for param in self.r34_transformer.parameters():
            param.trainable = False
        self.infer_mode = infer_mode

    def forward(self, x):
        """Run super-resolution (and, in training, the frozen recognizer).

        In infer_mode ``x`` is a single LR image tensor; otherwise ``x`` is
        indexable as (lr_img, hr_img, ...) and, when ``self.training``,
        additionally as x[2]=length, x[3]=input_tensor for the recognizer
        — TODO confirm the exact tuple layout against the dataloader.
        Returns a dict of named tensors consumed by the loss/postprocess.
        """
        output = {}
        if self.infer_mode:
            output["lr_img"] = x
            y = x
        else:
            output["lr_img"] = x[0]
            output["hr_img"] = x[1]
            y = x[0]
        # Rectify the LR input with the TPS-STN (training only).
        if self.stn and self.training:
            _, ctrl_points_x = self.stn_head(y)
            y, _ = self.tps(y, ctrl_points_x)
        # Feed each block the previous block's output, keyed by index string.
        block = {'1': self.block1(y)}
        for i in range(self.srb_nums + 1):
            block[str(i + 2)] = getattr(self,
                                        'block%d' % (i + 2))(block[str(i + 1)])

        # Long skip connection: block1 output is added before upsampling.
        block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
            ((block['1'] + block[str(self.srb_nums + 2)]))
        sr_img = paddle.tanh(block[str(self.srb_nums + 3)])

        output["sr_img"] = sr_img

        if self.training:
            hr_img = x[1]
            length = x[2]
            input_tensor = x[3]

            # add transformer
            sr_pred, word_attention_map_pred, _ = self.r34_transformer(
                sr_img, length, input_tensor)

            hr_pred, word_attention_map_gt, _ = self.r34_transformer(
                hr_img, length, input_tensor)

            output["hr_img"] = hr_img
            output["hr_pred"] = hr_pred
            output["word_attention_map_gt"] = word_attention_map_gt
            output["sr_pred"] = sr_pred
            output["word_attention_map_pred"] = word_attention_map_pred

        return output
  130. class RecurrentResidualBlock(nn.Layer):
  131. def __init__(self, channels):
  132. super(RecurrentResidualBlock, self).__init__()
  133. self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
  134. self.bn1 = nn.BatchNorm2D(channels)
  135. self.gru1 = GruBlock(channels, channels)
  136. self.prelu = mish()
  137. self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
  138. self.bn2 = nn.BatchNorm2D(channels)
  139. self.gru2 = GruBlock(channels, channels)
  140. def forward(self, x):
  141. residual = self.conv1(x)
  142. residual = self.bn1(residual)
  143. residual = self.prelu(residual)
  144. residual = self.conv2(residual)
  145. residual = self.bn2(residual)
  146. residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose(
  147. [0, 1, 3, 2])
  148. return self.gru2(x + residual)
  149. class UpsampleBLock(nn.Layer):
  150. def __init__(self, in_channels, up_scale):
  151. super(UpsampleBLock, self).__init__()
  152. self.conv = nn.Conv2D(
  153. in_channels, in_channels * up_scale**2, kernel_size=3, padding=1)
  154. self.pixel_shuffle = nn.PixelShuffle(up_scale)
  155. self.prelu = mish()
  156. def forward(self, x):
  157. x = self.conv(x)
  158. x = self.pixel_shuffle(x)
  159. x = self.prelu(x)
  160. return x
  161. class mish(nn.Layer):
  162. def __init__(self, ):
  163. super(mish, self).__init__()
  164. self.activated = True
  165. def forward(self, x):
  166. if self.activated:
  167. x = x * (paddle.tanh(F.softplus(x)))
  168. return x
  169. class GruBlock(nn.Layer):
  170. def __init__(self, in_channels, out_channels):
  171. super(GruBlock, self).__init__()
  172. assert out_channels % 2 == 0
  173. self.conv1 = nn.Conv2D(
  174. in_channels, out_channels, kernel_size=1, padding=0)
  175. self.gru = nn.GRU(out_channels,
  176. out_channels // 2,
  177. direction='bidirectional')
  178. def forward(self, x):
  179. # x: b, c, w, h
  180. x = self.conv1(x)
  181. x = x.transpose([0, 2, 3, 1]) # b, w, h, c
  182. batch_size, w, h, c = x.shape
  183. x = x.reshape([-1, h, c]) # b*w, h, c
  184. x, _ = self.gru(x)
  185. x = x.reshape([-1, w, h, c])
  186. x = x.transpose([0, 3, 1, 2])
  187. return x