rec_visionlan_head.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/wangyuxin87/VisionLAN
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierNormal
import numpy as np


class PositionalEncoding(nn.Layer):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        self.register_buffer(
            'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''

        def get_position_angle_vec(position):
            return [
                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
                for hid_j in range(d_hid)
            ]

        sinusoid_table = np.array(
            [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
        sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
        return sinusoid_table

    def forward(self, x):
        return x + self.pos_table[:, :x.shape[1]].clone().detach()


class ScaledDotProductAttention(nn.Layer):
    "Scaled Dot-Product Attention"

    def __init__(self, temperature, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(axis=2)

    def forward(self, q, k, v, mask=None):
        k = paddle.transpose(k, perm=[0, 2, 1])
        attn = paddle.bmm(q, k)
        attn = attn / self.temperature
        # NOTE: every caller in this file passes mask=None, so the branch
        # below is effectively unused.
        if mask is not None:
            attn = attn.masked_fill(mask, -1e9)
            if mask.dim() == 3:
                mask = paddle.unsqueeze(mask, axis=1)
            elif mask.dim() == 2:
                mask = paddle.unsqueeze(mask, axis=1)
                mask = paddle.unsqueeze(mask, axis=1)
            repeat_times = [
                attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
            ]
            mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
            attn[mask == 0] = -1e9
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = paddle.bmm(attn, v)
        return output


class MultiHeadAttention(nn.Layer):
    "Multi-Head Attention module"

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Linear(
            d_model,
            n_head * d_k,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
        self.w_ks = nn.Linear(
            d_model,
            n_head * d_k,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
        self.w_vs = nn.Linear(
            d_model,
            n_head * d_v,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
        self.attention = ScaledDotProductAttention(
            temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(
            n_head * d_v,
            d_model,
            weight_attr=ParamAttr(initializer=XavierNormal()))
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.shape
        sz_b, len_k, _ = k.shape
        sz_b, len_v, _ = v.shape
        residual = q
        q = self.w_qs(q)
        q = paddle.reshape(
            q, shape=[-1, len_q, n_head, d_k])  # 4*21*512 -> 4*21*8*64
        k = self.w_ks(k)
        k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
        v = self.w_vs(v)
        v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
        q = paddle.transpose(q, perm=[2, 0, 1, 3])
        q = paddle.reshape(q, shape=[-1, len_q, d_k])  # (n*b) x lq x dk
        k = paddle.transpose(k, perm=[2, 0, 1, 3])
        k = paddle.reshape(k, shape=[-1, len_k, d_k])  # (n*b) x lk x dk
        v = paddle.transpose(v, perm=[2, 0, 1, 3])
        v = paddle.reshape(v, shape=[-1, len_v, d_v])  # (n*b) x lv x dv
        mask = paddle.tile(
            mask,
            [n_head, 1, 1]) if mask is not None else None  # (n*b) x .. x ..
        output = self.attention(q, k, v, mask=mask)
        output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
        output = paddle.transpose(output, perm=[1, 2, 0, 3])
        output = paddle.reshape(
            output, shape=[-1, len_q, n_head * d_v])  # b x lq x (n*dv)
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output


class PositionwiseFeedForward(nn.Layer):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv1D(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1D(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = paddle.transpose(x, perm=[0, 2, 1])
        x = self.w_2(F.relu(self.w_1(x)))
        x = paddle.transpose(x, perm=[0, 2, 1])
        x = self.dropout(x)
        x = self.layer_norm(x + residual)
        return x


class EncoderLayer(nn.Layer):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output


class Transformer_Encoder(nn.Layer):
    def __init__(self,
                 n_layers=2,
                 n_head=8,
                 d_word_vec=512,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=2048,
                 dropout=0.1,
                 n_position=256):
        super(Transformer_Encoder, self).__init__()
        self.position_enc = PositionalEncoding(
            d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.LayerList([
            EncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)

    def forward(self, enc_output, src_mask, return_attns=False):
        enc_output = self.dropout(
            self.position_enc(enc_output))  # position embedding
        for enc_layer in self.layer_stack:
            enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
        enc_output = self.layer_norm(enc_output)
        return enc_output


class PP_layer(nn.Layer):
    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super(PP_layer, self).__init__()
        self.character_len = N_max_character
        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.we = nn.Linear(n_dim, N_max_character)
        self.active = nn.Tanh()
        self.softmax = nn.Softmax(axis=2)

    def forward(self, enc_output):
        # enc_output: b,256,512
        reading_order = paddle.arange(self.character_len, dtype='int64')
        reading_order = reading_order.unsqueeze(0).expand(
            [enc_output.shape[0], self.character_len])  # (S,) -> (B, S)
        reading_order = self.f0_embedding(reading_order)  # b,25,512
        # calculate attention
        reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
        t = self.w0(reading_order)  # b,512,256
        t = self.active(
            paddle.transpose(
                t, perm=[0, 2, 1]) + self.wv(enc_output))  # b,256,512
        t = self.we(t)  # b,256,25
        t = self.softmax(paddle.transpose(t, perm=[0, 2, 1]))  # b,25,256
        g_output = paddle.bmm(t, enc_output)  # b,25,512
        return g_output


class Prediction(nn.Layer):
    def __init__(self,
                 n_dim=512,
                 n_position=256,
                 N_max_character=25,
                 n_class=37):
        super(Prediction, self).__init__()
        self.pp = PP_layer(
            n_dim=n_dim,
            N_max_character=N_max_character,
            n_position=n_position)
        self.pp_share = PP_layer(
            n_dim=n_dim,
            N_max_character=N_max_character,
            n_position=n_position)
        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
        self.w_share = nn.Linear(n_dim, n_class)  # output layer
        self.nclass = n_class

    def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
                use_mlm=True):
        if train_mode:
            if not use_mlm:
                g_output = self.pp(cnn_feature)  # b,25,512
                g_output = self.w_vrm(g_output)
                f_res = 0
                f_sub = 0
                return g_output, f_res, f_sub
            g_output = self.pp(cnn_feature)  # b,25,512
            f_res = self.pp_share(f_res)
            f_sub = self.pp_share(f_sub)
            g_output = self.w_vrm(g_output)
            f_res = self.w_share(f_res)
            f_sub = self.w_share(f_sub)
            return g_output, f_res, f_sub
        else:
            g_output = self.pp(cnn_feature)  # b,25,512
            g_output = self.w_vrm(g_output)
            return g_output


class MLM(nn.Layer):
    "Architecture of MLM"

    def __init__(self, n_dim=512, n_position=256, max_text_length=25):
        super(MLM, self).__init__()
        self.MLM_SequenceModeling_mask = Transformer_Encoder(
            n_layers=2, n_position=n_position)
        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
            n_layers=1, n_position=n_position)
        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
        self.w0_linear = nn.Linear(1, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.active = nn.Tanh()
        self.we = nn.Linear(n_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, label_pos):
        # transformer unit for generating mask_c
        feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
        # position embedding layer
        label_pos = paddle.to_tensor(label_pos, dtype='int64')
        pos_emb = self.pos_embedding(label_pos)
        pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
        pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
        # fuse position embedding with features V & generate mask_c
        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
        att_map_sub = self.we(att_map_sub)  # b,256,1
        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
        att_map_sub = self.sigmoid(att_map_sub)  # b,1,256
        # WCL
        ## generate inputs for WCL
        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
        f_res = x * (1 - att_map_sub)  # second path with the remaining string
        f_sub = x * att_map_sub  # first path with the occluded character
        ## transformer units in WCL
        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
        return f_res, f_sub, att_map_sub


def trans_1d_2d(x):
    b, w_h, c = x.shape  # b, 256, 512
    x = paddle.transpose(x, perm=[0, 2, 1])
    x = paddle.reshape(x, [-1, c, 32, 8])
    x = paddle.transpose(x, perm=[0, 1, 3, 2])  # [b, c, 8, 32]
    return x


class MLM_VRM(nn.Layer):
    """
    MLM+VRM; MLM is only used in training.
    ratio controls how many samples in a batch are occluded.
    The test-time pipeline of VisionLAN is very concise: only a backbone +
    sequence modeling (transformer units) + a prediction layer (PP layer).

    inputs:
        x: input feature (backbone output)
        label_pos: character index
        training_step: LF or LA process
    outputs:
        text_pre: prediction of VRM
        test_rem: prediction of the remaining string in MLM
        text_mas: prediction of the occluded character in MLM
        mask_c_show: visualization of mask_c
    """

    def __init__(self,
                 n_layers=3,
                 n_position=256,
                 n_dim=512,
                 max_text_length=25,
                 nclass=37):
        super(MLM_VRM, self).__init__()
        self.MLM = MLM(n_dim=n_dim,
                       n_position=n_position,
                       max_text_length=max_text_length)
        self.SequenceModeling = Transformer_Encoder(
            n_layers=n_layers, n_position=n_position)
        self.Prediction = Prediction(
            n_dim=n_dim,
            n_position=n_position,
            N_max_character=max_text_length + 1,  # 1 eos + 25 characters
            n_class=nclass)
        self.nclass = nclass
        self.max_text_length = max_text_length

    def forward(self, x, label_pos, training_step, train_mode=False):
        b, c, h, w = x.shape
        nT = self.max_text_length
        x = paddle.transpose(x, perm=[0, 1, 3, 2])
        x = paddle.reshape(x, [-1, c, h * w])
        x = paddle.transpose(x, perm=[0, 2, 1])
        if train_mode:
            if training_step == 'LF_1':
                f_res = 0
                f_sub = 0
                x = self.SequenceModeling(x, src_mask=None)
                text_pre, test_rem, text_mas = self.Prediction(
                    x, f_res, f_sub, train_mode=True, use_mlm=False)
                # LF_1 has no MLM outputs; repeat text_pre to keep the
                # return signature consistent
                return text_pre, text_pre, text_pre, text_pre
            elif training_step == 'LF_2':
                # MLM
                f_res, f_sub, mask_c = self.MLM(x, label_pos)
                x = self.SequenceModeling(x, src_mask=None)
                text_pre, test_rem, text_mas = self.Prediction(
                    x, f_res, f_sub, train_mode=True)
                mask_c_show = trans_1d_2d(mask_c)
                return text_pre, test_rem, text_mas, mask_c_show
            elif training_step == 'LA':
                # MLM
                f_res, f_sub, mask_c = self.MLM(x, label_pos)
                ## use mask_c (1 for the occluded character, 0 for the
                ## remaining characters) to occlude the input;
                ## ratio controls how many samples in a batch are occluded
                character_mask = paddle.zeros_like(mask_c)
                ratio = b // 2
                if ratio >= 1:
                    with paddle.no_grad():
                        character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
                else:
                    character_mask = mask_c
                x = x * (1 - character_mask)
                # VRM
                ## transformer unit for VRM
                x = self.SequenceModeling(x, src_mask=None)
                ## prediction layer for MLM and VRM
                text_pre, test_rem, text_mas = self.Prediction(
                    x, f_res, f_sub, train_mode=True)
                mask_c_show = trans_1d_2d(mask_c)
                return text_pre, test_rem, text_mas, mask_c_show
            else:
                raise NotImplementedError
        else:  # testing stage: only VRM is used, MLM is skipped
            f_res = 0
            f_sub = 0
            contextual_feature = self.SequenceModeling(x, src_mask=None)
            text_pre = self.Prediction(
                contextual_feature,
                f_res,
                f_sub,
                train_mode=False,
                use_mlm=False)
            text_pre = paddle.transpose(
                text_pre, perm=[1, 0, 2])  # (26, b, 37)
            return text_pre, x


class VLHead(nn.Layer):
    """
    Architecture of VisionLAN
    """

    def __init__(self,
                 in_channels,
                 out_channels=36,
                 n_layers=3,
                 n_position=256,
                 n_dim=512,
                 max_text_length=25,
                 training_step='LA'):
        super(VLHead, self).__init__()
        self.MLM_VRM = MLM_VRM(
            n_layers=n_layers,
            n_position=n_position,
            n_dim=n_dim,
            max_text_length=max_text_length,
            nclass=out_channels + 1)
        self.training_step = training_step

    def forward(self, feat, targets=None):
        if self.training:
            label_pos = targets[-2]
            text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
                feat, label_pos, self.training_step, train_mode=True)
            return text_pre, test_rem, text_mas, mask_map
        else:
            text_pre, x = self.MLM_VRM(
                feat, targets, self.training_step, train_mode=False)
            return text_pre, x
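

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream file): runs an
# inference-mode forward pass of VLHead on a dummy feature map. The shape
# [2, 512, 8, 32] is an assumption chosen so that h * w matches the default
# n_position = 256; real backbone outputs may differ.
if __name__ == '__main__':
    head = VLHead(in_channels=512, out_channels=36)
    head.eval()  # sets self.training = False, so only the VRM path runs
    feat = paddle.rand([2, 512, 8, 32])  # dummy backbone feature map
    with paddle.no_grad():
        text_pre, x = head(feat)
    print(text_pre.shape)  # [26, 2, 37]: (max_text_length + 1, batch, n_class)
    print(x.shape)  # [2, 256, 512]: flattened visual features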