# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""
import copy
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def subsequent_mask(size):
    """Generate a square causal mask for a sequence of the given size.

    Internally, future positions are filled with float('-inf') and the rest
    with 0.0; the returned tensor is a boolean mask of shape [1, size, size]
    that is True at the positions a token may attend to (on or below the
    diagonal).
    """
    mask = paddle.ones([1, size, size], dtype='float32')
    mask_inf = paddle.triu(
        paddle.full(
            shape=[1, size, size], dtype='float32', fill_value='-inf'),
        diagonal=1)
    mask = mask + mask_inf
    padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype))
    return padding_mask
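
# For example, subsequent_mask(3) evaluates to the boolean tensor
# [[[True, False, False],
#   [True, True,  False],
#   [True, True,  True ]]],
# i.e. position i may attend only to positions 0..i.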


def clones(module, N):
    """Produce N identical (deep-copied) instances of `module` in a LayerList."""
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])


def masked_fill(x, mask, value):
    """Return a copy of `x` where positions with `mask` True are set to `value`."""
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)


def attention(query, key, value, mask=None, dropout=None, attention_map=None):
    """Compute scaled dot-product attention; return (context, attention weights)."""
    d_k = query.shape[-1]
    scores = paddle.matmul(query,
                           paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)
    if mask is not None:
        scores = masked_fill(scores, mask == 0, float('-inf'))
    p_attn = F.softmax(scores, axis=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return paddle.matmul(p_attn, value), p_attn
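
# Shape walk-through (as used by MultiHeadedAttention below): query, key and
# value arrive as [batch, heads, seq_len, d_k], so `scores` has shape
# [batch, heads, q_len, k_len]; the optional mask is broadcast against it,
# zero/False entries are pushed to -inf before the softmax, and the returned
# context tensor has the same shape as `query`.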


class MultiHeadedAttention(nn.Layer):
    def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
        self.compress_attention = compress_attention
        self.compress_attention_linear = nn.Linear(h, 1)

    def forward(self, query, key, value, mask=None, attention_map=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.shape[0]

        # Project the inputs and split them into h heads:
        # [batch, seq_len, d_model] -> [batch, h, seq_len, d_k]
        query, key, value = \
            [paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0, 2, 1, 3])
             for l, x in zip(self.linears, (query, key, value))]

        x, attention_map = attention(
            query,
            key,
            value,
            mask=mask,
            dropout=self.dropout,
            attention_map=attention_map)

        # Concatenate the heads back to [batch, seq_len, d_model] and apply
        # the final output projection.
        x = paddle.reshape(
            paddle.transpose(x, [0, 2, 1, 3]),
            [nbatches, -1, self.h * self.d_k])
        return self.linears[-1](x), attention_map


class ResNet(nn.Layer):
    def __init__(self, num_in, block, layers):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(64, use_global_stats=True)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2D((2, 2), (2, 2))

        self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(128, use_global_stats=True)
        self.relu2 = nn.ReLU()

        self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer1 = self._make_layer(block, 128, 256, layers[0])
        self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1)
        self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True)
        self.layer1_relu = nn.ReLU()

        self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer2 = self._make_layer(block, 256, 256, layers[1])
        self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1)
        self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True)
        self.layer2_relu = nn.ReLU()

        self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer3 = self._make_layer(block, 256, 512, layers[2])
        self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1)
        self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True)
        self.layer3_relu = nn.ReLU()

        self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer4 = self._make_layer(block, 512, 512, layers[3])
        self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1)
        self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True)
        self.layer4_conv2_relu = nn.ReLU()

    def _make_layer(self, block, inplanes, planes, blocks):
        if inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2D(inplanes, planes, 3, 1, 1),
                nn.BatchNorm2D(
                    planes, use_global_stats=True), )
        else:
            downsample = None
        layers = []
        layers.append(block(inplanes, planes, downsample))
        for i in range(1, blocks):
            layers.append(block(planes, planes, downsample=None))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.layer1_pool(x)
        x = self.layer1(x)
        x = self.layer1_conv(x)
        x = self.layer1_bn(x)
        x = self.layer1_relu(x)

        x = self.layer2(x)
        x = self.layer2_conv(x)
        x = self.layer2_bn(x)
        x = self.layer2_relu(x)

        x = self.layer3(x)
        x = self.layer3_conv(x)
        x = self.layer3_bn(x)
        x = self.layer3_relu(x)

        x = self.layer4(x)
        x = self.layer4_conv2(x)
        x = self.layer4_conv2_bn(x)
        x = self.layer4_conv2_relu(x)
        return x


class Bottleneck(nn.Layer):
    def __init__(self, input_dim):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
        self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
        self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class PositionalEncoding(nn.Layer):
    "Implement the PE function."

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = paddle.unsqueeze(pe, 0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x)


class PositionwiseFeedForward(nn.Layer):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class Generator(nn.Layer):
    "Standard linear generation step; returns unnormalized logits."

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.proj(x)
        return out


class Embeddings(nn.Layer):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        embed = self.lut(x) * math.sqrt(self.d_model)
        return embed


class LayerNorm(nn.Layer):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = self.create_parameter(
            shape=[features],
            default_initializer=paddle.nn.initializer.Constant(1.0))
        self.b_2 = self.create_parameter(
            shape=[features],
            default_initializer=paddle.nn.initializer.Constant(0.0))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class Decoder(nn.Layer):
    def __init__(self):
        super(Decoder, self).__init__()
        self.mask_multihead = MultiHeadedAttention(
            h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm1 = LayerNorm(1024)

        self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm2 = LayerNorm(1024)

        self.pff = PositionwiseFeedForward(1024, 2048)
        self.mul_layernorm3 = LayerNorm(1024)

    def forward(self, text, conv_feature, attention_map=None):
        # Masked self-attention over the text embeddings.
        text_max_length = text.shape[1]
        mask = subsequent_mask(text_max_length)
        result = text
        result = self.mul_layernorm1(result + self.mask_multihead(
            text, text, text, mask=mask)[0])

        # Cross-attention between text queries and the flattened image
        # feature map: [b, c, h, w] -> [b, h * w, c].
        b, c, h, w = conv_feature.shape
        conv_feature = paddle.transpose(
            conv_feature.reshape([b, c, h * w]), [0, 2, 1])
        word_image_align, attention_map = self.multihead(
            result,
            conv_feature,
            conv_feature,
            mask=None,
            attention_map=attention_map)
        result = self.mul_layernorm2(result + word_image_align)

        # Position-wise feed-forward with residual connection.
        result = self.mul_layernorm3(result + self.pff(result))
        return result, attention_map


class BasicBlock(nn.Layer):
    def __init__(self, inplanes, planes, downsample):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2D(
            inplanes, planes, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(
            planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out += residual
        out = self.relu(out)
        return out


class Encoder(nn.Layer):
    def __init__(self):
        super(Encoder, self).__init__()
        self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])

    def forward(self, input):
        conv_result = self.cnn(input)
        return conv_result


class Transformer(nn.Layer):
    def __init__(self, in_channels=1, alphabet='0123456789'):
        super(Transformer, self).__init__()
        self.alphabet = alphabet
        word_n_class = self.get_alphabet_len()
        self.embedding_word_with_upperword = Embeddings(512, word_n_class)
        self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.generator_word_with_upperword = Generator(1024, word_n_class)

        # Xavier-normal initialization for every multi-dimensional parameter;
        # the initializer instance is called on the parameter so that it is
        # actually applied.
        for p in self.parameters():
            if p.dim() > 1:
                nn.initializer.XavierNormal()(p)
    def get_alphabet_len(self):
        return len(self.alphabet)

    def forward(self, image, text_length, text_input, attention_map=None):
        # Convert RGB inputs to a single luminance channel.
        if image.shape[1] == 3:
            R = image[:, 0:1, :, :]
            G = image[:, 1:2, :, :]
            B = image[:, 2:3, :, :]
            image = 0.299 * R + 0.587 * G + 0.114 * B

        conv_feature = self.encoder(image)  # batch, 1024, 8, 32

        max_length = max(text_length)
        text_input = text_input[:, :max_length]
        text_embedding = self.embedding_word_with_upperword(
            text_input)  # batch, text_max_length, 512
        postion_embedding = self.pe(
            paddle.zeros(text_embedding.shape))  # batch, text_max_length, 512
        text_input_with_pe = paddle.concat([text_embedding, postion_embedding],
                                           2)  # batch, text_max_length, 1024
        batch, seq_len, _ = text_input_with_pe.shape

        text_input_with_pe, word_attention_map = self.decoder(
            text_input_with_pe, conv_feature)
        word_decoder_result = self.generator_word_with_upperword(
            text_input_with_pe)

        if self.training:
            # Pack the valid timesteps of every sample into one
            # [sum(text_length), n_class] tensor for the loss.
            total_length = paddle.sum(text_length)
            probs_res = paddle.zeros([total_length, self.get_alphabet_len()])
            start = 0
            for index, length in enumerate(text_length):
                length = int(length.numpy())
                probs_res[start:start + length, :] = word_decoder_result[
                    index, 0:0 + length, :]
                start = start + length
            return probs_res, word_attention_map, None
        else:
            return word_decoder_result
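

# A minimal, hypothetical smoke test for this module, not part of the upstream
# file. It assumes a 32 x 128 grayscale input (consistent with the
# "batch, 1024, 8, 32" comment in Transformer.forward), the default 10-symbol
# alphabet, and eval mode, in which forward returns the raw per-step logits.
if __name__ == "__main__":
    paddle.seed(0)
    model = Transformer(in_channels=1, alphabet='0123456789')
    model.eval()

    image = paddle.rand([2, 1, 32, 128])  # batch of 2 grayscale crops
    text_input = paddle.randint(0, 10, shape=[2, 8], dtype='int64')  # token ids
    text_length = paddle.to_tensor([8, 8], dtype='int64')  # per-sample lengths

    with paddle.no_grad():
        logits = model(image, text_length, text_input)
    print(logits.shape)  # expected: [2, 8, 10] -> [batch, max_length, alphabet size]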