rec_nrtr_head.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import paddle
  16. from paddle import nn
  17. import paddle.nn.functional as F
  18. from paddle.nn import LayerList
  19. from paddle.nn import Dropout, Linear, LayerNorm
  20. import numpy as np
  21. from ppocr.modeling.backbones.rec_svtrnet import Mlp, zeros_, ones_
  22. from paddle.nn.initializer import XavierNormal as xavier_normal_
  23. class Transformer(nn.Layer):
  24. """A transformer model. User is able to modify the attributes as needed. The architechture
  25. is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
  26. Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
  27. Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
  28. Processing Systems, pages 6000-6010.
  29. Args:
  30. d_model: the number of expected features in the encoder/decoder inputs (default=512).
  31. nhead: the number of heads in the multiheadattention models (default=8).
  32. num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
  33. num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
  34. dim_feedforward: the dimension of the feedforward network model (default=2048).
  35. dropout: the dropout value (default=0.1).
  36. custom_encoder: custom encoder (default=None).
  37. custom_decoder: custom decoder (default=None).
  38. """
  39. def __init__(self,
  40. d_model=512,
  41. nhead=8,
  42. num_encoder_layers=6,
  43. beam_size=0,
  44. num_decoder_layers=6,
  45. max_len=25,
  46. dim_feedforward=1024,
  47. attention_dropout_rate=0.0,
  48. residual_dropout_rate=0.1,
  49. in_channels=0,
  50. out_channels=0,
  51. scale_embedding=True):
  52. super(Transformer, self).__init__()
  53. self.out_channels = out_channels + 1
  54. self.max_len = max_len
  55. self.embedding = Embeddings(
  56. d_model=d_model,
  57. vocab=self.out_channels,
  58. padding_idx=0,
  59. scale_embedding=scale_embedding)
  60. self.positional_encoding = PositionalEncoding(
  61. dropout=residual_dropout_rate, dim=d_model)
  62. if num_encoder_layers > 0:
  63. self.encoder = nn.LayerList([
  64. TransformerBlock(
  65. d_model,
  66. nhead,
  67. dim_feedforward,
  68. attention_dropout_rate,
  69. residual_dropout_rate,
  70. with_self_attn=True,
  71. with_cross_attn=False) for i in range(num_encoder_layers)
  72. ])
  73. else:
  74. self.encoder = None
  75. self.decoder = nn.LayerList([
  76. TransformerBlock(
  77. d_model,
  78. nhead,
  79. dim_feedforward,
  80. attention_dropout_rate,
  81. residual_dropout_rate,
  82. with_self_attn=True,
  83. with_cross_attn=True) for i in range(num_decoder_layers)
  84. ])
  85. self.beam_size = beam_size
  86. self.d_model = d_model
  87. self.nhead = nhead
  88. self.tgt_word_prj = nn.Linear(
  89. d_model, self.out_channels, bias_attr=False)
  90. w0 = np.random.normal(0.0, d_model**-0.5,
  91. (d_model, self.out_channels)).astype(np.float32)
  92. self.tgt_word_prj.weight.set_value(w0)
  93. self.apply(self._init_weights)
  94. def _init_weights(self, m):
  95. if isinstance(m, nn.Linear):
  96. xavier_normal_(m.weight)
  97. if m.bias is not None:
  98. zeros_(m.bias)
  99. def forward_train(self, src, tgt):
  100. tgt = tgt[:, :-1]
  101. tgt = self.embedding(tgt)
  102. tgt = self.positional_encoding(tgt)
  103. tgt_mask = self.generate_square_subsequent_mask(tgt.shape[1])
  104. if self.encoder is not None:
  105. src = self.positional_encoding(src)
  106. for encoder_layer in self.encoder:
  107. src = encoder_layer(src)
  108. memory = src # B N C
  109. else:
  110. memory = src # B N C
  111. for decoder_layer in self.decoder:
  112. tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
  113. output = tgt
  114. logit = self.tgt_word_prj(output)
  115. return logit
  116. def forward(self, src, targets=None):
  117. """Take in and process masked source/target sequences.
  118. Args:
  119. src: the sequence to the encoder (required).
  120. tgt: the sequence to the decoder (required).
  121. Shape:
  122. - src: :math:`(B, sN, C)`.
  123. - tgt: :math:`(B, tN, C)`.
  124. Examples:
  125. >>> output = transformer_model(src, tgt)
  126. """
  127. if self.training:
  128. max_len = targets[1].max()
  129. tgt = targets[0][:, :2 + max_len]
  130. return self.forward_train(src, tgt)
  131. else:
  132. if self.beam_size > 0:
  133. return self.forward_beam(src)
  134. else:
  135. return self.forward_test(src)
  136. def forward_test(self, src):
  137. bs = paddle.shape(src)[0]
  138. if self.encoder is not None:
  139. src = self.positional_encoding(src)
  140. for encoder_layer in self.encoder:
  141. src = encoder_layer(src)
  142. memory = src # B N C
  143. else:
  144. memory = src
  145. dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
  146. dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
  147. for len_dec_seq in range(1, paddle.to_tensor(self.max_len)):
  148. dec_seq_embed = self.embedding(dec_seq)
  149. dec_seq_embed = self.positional_encoding(dec_seq_embed)
  150. tgt_mask = self.generate_square_subsequent_mask(
  151. paddle.shape(dec_seq_embed)[1])
  152. tgt = dec_seq_embed
  153. for decoder_layer in self.decoder:
  154. tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
  155. dec_output = tgt
  156. dec_output = dec_output[:, -1, :]
  157. word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=-1)
  158. preds_idx = paddle.argmax(word_prob, axis=-1)
  159. if paddle.equal_all(
  160. preds_idx,
  161. paddle.full(
  162. paddle.shape(preds_idx), 3, dtype='int64')):
  163. break
  164. preds_prob = paddle.max(word_prob, axis=-1)
  165. dec_seq = paddle.concat(
  166. [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
  167. dec_prob = paddle.concat(
  168. [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
  169. return [dec_seq, dec_prob]
  170. def forward_beam(self, images):
  171. """ Translation work in one batch """
  172. def get_inst_idx_to_tensor_position_map(inst_idx_list):
  173. """ Indicate the position of an instance in a tensor. """
  174. return {
  175. inst_idx: tensor_position
  176. for tensor_position, inst_idx in enumerate(inst_idx_list)
  177. }
  178. def collect_active_part(beamed_tensor, curr_active_inst_idx,
  179. n_prev_active_inst, n_bm):
  180. """ Collect tensor parts associated to active instances. """
  181. beamed_tensor_shape = paddle.shape(beamed_tensor)
  182. n_curr_active_inst = len(curr_active_inst_idx)
  183. new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
  184. beamed_tensor_shape[2])
  185. beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
  186. beamed_tensor = beamed_tensor.index_select(
  187. curr_active_inst_idx, axis=0)
  188. beamed_tensor = beamed_tensor.reshape(new_shape)
  189. return beamed_tensor
  190. def collate_active_info(src_enc, inst_idx_to_position_map,
  191. active_inst_idx_list):
  192. # Sentences which are still active are collected,
  193. # so the decoder will not run on completed sentences.
  194. n_prev_active_inst = len(inst_idx_to_position_map)
  195. active_inst_idx = [
  196. inst_idx_to_position_map[k] for k in active_inst_idx_list
  197. ]
  198. active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
  199. active_src_enc = collect_active_part(
  200. src_enc.transpose([1, 0, 2]), active_inst_idx,
  201. n_prev_active_inst, n_bm).transpose([1, 0, 2])
  202. active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
  203. active_inst_idx_list)
  204. return active_src_enc, active_inst_idx_to_position_map
  205. def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
  206. inst_idx_to_position_map, n_bm):
  207. """ Decode and update beam status, and then return active beam idx """
  208. def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
  209. dec_partial_seq = [
  210. b.get_current_state() for b in inst_dec_beams if not b.done
  211. ]
  212. dec_partial_seq = paddle.stack(dec_partial_seq)
  213. dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
  214. return dec_partial_seq
  215. def predict_word(dec_seq, enc_output, n_active_inst, n_bm):
  216. dec_seq = self.embedding(dec_seq)
  217. dec_seq = self.positional_encoding(dec_seq)
  218. tgt_mask = self.generate_square_subsequent_mask(
  219. paddle.shape(dec_seq)[1])
  220. tgt = dec_seq
  221. for decoder_layer in self.decoder:
  222. tgt = decoder_layer(tgt, enc_output, self_mask=tgt_mask)
  223. dec_output = tgt
  224. dec_output = dec_output[:,
  225. -1, :] # Pick the last step: (bh * bm) * d_h
  226. word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
  227. word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
  228. return word_prob
  229. def collect_active_inst_idx_list(inst_beams, word_prob,
  230. inst_idx_to_position_map):
  231. active_inst_idx_list = []
  232. for inst_idx, inst_position in inst_idx_to_position_map.items():
  233. is_inst_complete = inst_beams[inst_idx].advance(word_prob[
  234. inst_position])
  235. if not is_inst_complete:
  236. active_inst_idx_list += [inst_idx]
  237. return active_inst_idx_list
  238. n_active_inst = len(inst_idx_to_position_map)
  239. dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
  240. word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm)
  241. # Update the beam with predicted word prob information and collect incomplete instances
  242. active_inst_idx_list = collect_active_inst_idx_list(
  243. inst_dec_beams, word_prob, inst_idx_to_position_map)
  244. return active_inst_idx_list
  245. def collect_hypothesis_and_scores(inst_dec_beams, n_best):
  246. all_hyp, all_scores = [], []
  247. for inst_idx in range(len(inst_dec_beams)):
  248. scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
  249. all_scores += [scores[:n_best]]
  250. hyps = [
  251. inst_dec_beams[inst_idx].get_hypothesis(i)
  252. for i in tail_idxs[:n_best]
  253. ]
  254. all_hyp += [hyps]
  255. return all_hyp, all_scores
  256. with paddle.no_grad():
  257. #-- Encode
  258. if self.encoder is not None:
  259. src = self.positional_encoding(images)
  260. src_enc = self.encoder(src)
  261. else:
  262. src_enc = images
  263. n_bm = self.beam_size
  264. src_shape = paddle.shape(src_enc)
  265. inst_dec_beams = [Beam(n_bm) for _ in range(1)]
  266. active_inst_idx_list = list(range(1))
  267. # Repeat data for beam search
  268. src_enc = paddle.tile(src_enc, [1, n_bm, 1])
  269. inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
  270. active_inst_idx_list)
  271. # Decode
  272. for len_dec_seq in range(1, paddle.to_tensor(self.max_len)):
  273. src_enc_copy = src_enc.clone()
  274. active_inst_idx_list = beam_decode_step(
  275. inst_dec_beams, len_dec_seq, src_enc_copy,
  276. inst_idx_to_position_map, n_bm)
  277. if not active_inst_idx_list:
  278. break # all instances have finished their path to <EOS>
  279. src_enc, inst_idx_to_position_map = collate_active_info(
  280. src_enc_copy, inst_idx_to_position_map,
  281. active_inst_idx_list)
  282. batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
  283. 1)
  284. result_hyp = []
  285. hyp_scores = []
  286. for bs_hyp, score in zip(batch_hyp, batch_scores):
  287. l = len(bs_hyp[0])
  288. bs_hyp_pad = bs_hyp[0] + [3] * (25 - l)
  289. result_hyp.append(bs_hyp_pad)
  290. score = float(score) / l
  291. hyp_score = [score for _ in range(25)]
  292. hyp_scores.append(hyp_score)
  293. return [
  294. paddle.to_tensor(
  295. np.array(result_hyp), dtype=paddle.int64),
  296. paddle.to_tensor(hyp_scores)
  297. ]
  298. def generate_square_subsequent_mask(self, sz):
  299. """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
  300. Unmasked positions are filled with float(0.0).
  301. """
  302. mask = paddle.zeros([sz, sz], dtype='float32')
  303. mask_inf = paddle.triu(
  304. paddle.full(
  305. shape=[sz, sz], dtype='float32', fill_value='-inf'),
  306. diagonal=1)
  307. mask = mask + mask_inf
  308. return mask.unsqueeze([0, 1])
  309. class MultiheadAttention(nn.Layer):
  310. """Allows the model to jointly attend to information
  311. from different representation subspaces.
  312. See reference: Attention Is All You Need
  313. .. math::
  314. \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
  315. \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
  316. Args:
  317. embed_dim: total dimension of the model
  318. num_heads: parallel attention layers, or heads
  319. """
  320. def __init__(self, embed_dim, num_heads, dropout=0., self_attn=False):
  321. super(MultiheadAttention, self).__init__()
  322. self.embed_dim = embed_dim
  323. self.num_heads = num_heads
  324. # self.dropout = dropout
  325. self.head_dim = embed_dim // num_heads
  326. assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
  327. self.scale = self.head_dim**-0.5
  328. self.self_attn = self_attn
  329. if self_attn:
  330. self.qkv = nn.Linear(embed_dim, embed_dim * 3)
  331. else:
  332. self.q = nn.Linear(embed_dim, embed_dim)
  333. self.kv = nn.Linear(embed_dim, embed_dim * 2)
  334. self.attn_drop = nn.Dropout(dropout)
  335. self.out_proj = nn.Linear(embed_dim, embed_dim)
  336. def forward(self, query, key=None, attn_mask=None):
  337. qN = query.shape[1]
  338. if self.self_attn:
  339. qkv = self.qkv(query).reshape(
  340. (0, qN, 3, self.num_heads, self.head_dim)).transpose(
  341. (2, 0, 3, 1, 4))
  342. q, k, v = qkv[0], qkv[1], qkv[2]
  343. else:
  344. kN = key.shape[1]
  345. q = self.q(query).reshape(
  346. [0, qN, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
  347. kv = self.kv(key).reshape(
  348. (0, kN, 2, self.num_heads, self.head_dim)).transpose(
  349. (2, 0, 3, 1, 4))
  350. k, v = kv[0], kv[1]
  351. attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
  352. if attn_mask is not None:
  353. attn += attn_mask
  354. attn = F.softmax(attn, axis=-1)
  355. attn = self.attn_drop(attn)
  356. x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape(
  357. (0, qN, self.embed_dim))
  358. x = self.out_proj(x)
  359. return x
  360. class TransformerBlock(nn.Layer):
  361. def __init__(self,
  362. d_model,
  363. nhead,
  364. dim_feedforward=2048,
  365. attention_dropout_rate=0.0,
  366. residual_dropout_rate=0.1,
  367. with_self_attn=True,
  368. with_cross_attn=False,
  369. epsilon=1e-5):
  370. super(TransformerBlock, self).__init__()
  371. self.with_self_attn = with_self_attn
  372. if with_self_attn:
  373. self.self_attn = MultiheadAttention(
  374. d_model,
  375. nhead,
  376. dropout=attention_dropout_rate,
  377. self_attn=with_self_attn)
  378. self.norm1 = LayerNorm(d_model, epsilon=epsilon)
  379. self.dropout1 = Dropout(residual_dropout_rate)
  380. self.with_cross_attn = with_cross_attn
  381. if with_cross_attn:
  382. self.cross_attn = MultiheadAttention( #for self_attn of encoder or cross_attn of decoder
  383. d_model,
  384. nhead,
  385. dropout=attention_dropout_rate)
  386. self.norm2 = LayerNorm(d_model, epsilon=epsilon)
  387. self.dropout2 = Dropout(residual_dropout_rate)
  388. self.mlp = Mlp(in_features=d_model,
  389. hidden_features=dim_feedforward,
  390. act_layer=nn.ReLU,
  391. drop=residual_dropout_rate)
  392. self.norm3 = LayerNorm(d_model, epsilon=epsilon)
  393. self.dropout3 = Dropout(residual_dropout_rate)
  394. def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
  395. if self.with_self_attn:
  396. tgt1 = self.self_attn(tgt, attn_mask=self_mask)
  397. tgt = self.norm1(tgt + self.dropout1(tgt1))
  398. if self.with_cross_attn:
  399. tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
  400. tgt = self.norm2(tgt + self.dropout2(tgt2))
  401. tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
  402. return tgt
  403. class PositionalEncoding(nn.Layer):
  404. """Inject some information about the relative or absolute position of the tokens
  405. in the sequence. The positional encodings have the same dimension as
  406. the embeddings, so that the two can be summed. Here, we use sine and cosine
  407. functions of different frequencies.
  408. .. math::
  409. \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
  410. \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
  411. \text{where pos is the word position and i is the embed idx)
  412. Args:
  413. d_model: the embed dim (required).
  414. dropout: the dropout value (default=0.1).
  415. max_len: the max. length of the incoming sequence (default=5000).
  416. Examples:
  417. >>> pos_encoder = PositionalEncoding(d_model)
  418. """
  419. def __init__(self, dropout, dim, max_len=5000):
  420. super(PositionalEncoding, self).__init__()
  421. self.dropout = nn.Dropout(p=dropout)
  422. pe = paddle.zeros([max_len, dim])
  423. position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
  424. div_term = paddle.exp(
  425. paddle.arange(0, dim, 2).astype('float32') *
  426. (-math.log(10000.0) / dim))
  427. pe[:, 0::2] = paddle.sin(position * div_term)
  428. pe[:, 1::2] = paddle.cos(position * div_term)
  429. pe = paddle.unsqueeze(pe, 0)
  430. pe = paddle.transpose(pe, [1, 0, 2])
  431. self.register_buffer('pe', pe)
  432. def forward(self, x):
  433. """Inputs of forward function
  434. Args:
  435. x: the sequence fed to the positional encoder model (required).
  436. Shape:
  437. x: [sequence length, batch size, embed dim]
  438. output: [sequence length, batch size, embed dim]
  439. Examples:
  440. >>> output = pos_encoder(x)
  441. """
  442. x = x.transpose([1, 0, 2])
  443. x = x + self.pe[:paddle.shape(x)[0], :]
  444. return self.dropout(x).transpose([1, 0, 2])
  445. class PositionalEncoding_2d(nn.Layer):
  446. """Inject some information about the relative or absolute position of the tokens
  447. in the sequence. The positional encodings have the same dimension as
  448. the embeddings, so that the two can be summed. Here, we use sine and cosine
  449. functions of different frequencies.
  450. .. math::
  451. \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
  452. \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
  453. \text{where pos is the word position and i is the embed idx)
  454. Args:
  455. d_model: the embed dim (required).
  456. dropout: the dropout value (default=0.1).
  457. max_len: the max. length of the incoming sequence (default=5000).
  458. Examples:
  459. >>> pos_encoder = PositionalEncoding(d_model)
  460. """
  461. def __init__(self, dropout, dim, max_len=5000):
  462. super(PositionalEncoding_2d, self).__init__()
  463. self.dropout = nn.Dropout(p=dropout)
  464. pe = paddle.zeros([max_len, dim])
  465. position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
  466. div_term = paddle.exp(
  467. paddle.arange(0, dim, 2).astype('float32') *
  468. (-math.log(10000.0) / dim))
  469. pe[:, 0::2] = paddle.sin(position * div_term)
  470. pe[:, 1::2] = paddle.cos(position * div_term)
  471. pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
  472. self.register_buffer('pe', pe)
  473. self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
  474. self.linear1 = nn.Linear(dim, dim)
  475. self.linear1.weight.data.fill_(1.)
  476. self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
  477. self.linear2 = nn.Linear(dim, dim)
  478. self.linear2.weight.data.fill_(1.)
  479. def forward(self, x):
  480. """Inputs of forward function
  481. Args:
  482. x: the sequence fed to the positional encoder model (required).
  483. Shape:
  484. x: [sequence length, batch size, embed dim]
  485. output: [sequence length, batch size, embed dim]
  486. Examples:
  487. >>> output = pos_encoder(x)
  488. """
  489. w_pe = self.pe[:paddle.shape(x)[-1], :]
  490. w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
  491. w_pe = w_pe * w1
  492. w_pe = paddle.transpose(w_pe, [1, 2, 0])
  493. w_pe = paddle.unsqueeze(w_pe, 2)
  494. h_pe = self.pe[:paddle.shape(x).shape[-2], :]
  495. w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
  496. h_pe = h_pe * w2
  497. h_pe = paddle.transpose(h_pe, [1, 2, 0])
  498. h_pe = paddle.unsqueeze(h_pe, 3)
  499. x = x + w_pe + h_pe
  500. x = paddle.transpose(
  501. paddle.reshape(x,
  502. [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
  503. [2, 0, 1])
  504. return self.dropout(x)
  505. class Embeddings(nn.Layer):
  506. def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
  507. super(Embeddings, self).__init__()
  508. self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
  509. w0 = np.random.normal(0.0, d_model**-0.5,
  510. (vocab, d_model)).astype(np.float32)
  511. self.embedding.weight.set_value(w0)
  512. self.d_model = d_model
  513. self.scale_embedding = scale_embedding
  514. def forward(self, x):
  515. if self.scale_embedding:
  516. x = self.embedding(x)
  517. return x * math.sqrt(self.d_model)
  518. return self.embedding(x)
  519. class Beam():
  520. """ Beam search """
  521. def __init__(self, size, device=False):
  522. self.size = size
  523. self._done = False
  524. # The score for each translation on the beam.
  525. self.scores = paddle.zeros((size, ), dtype=paddle.float32)
  526. self.all_scores = []
  527. # The backpointers at each time-step.
  528. self.prev_ks = []
  529. # The outputs at each time-step.
  530. self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
  531. self.next_ys[0][0] = 2
  532. def get_current_state(self):
  533. "Get the outputs for the current timestep."
  534. return self.get_tentative_hypothesis()
  535. def get_current_origin(self):
  536. "Get the backpointers for the current timestep."
  537. return self.prev_ks[-1]
  538. @property
  539. def done(self):
  540. return self._done
  541. def advance(self, word_prob):
  542. "Update beam status and check if finished or not."
  543. num_words = word_prob.shape[1]
  544. # Sum the previous scores.
  545. if len(self.prev_ks) > 0:
  546. beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
  547. else:
  548. beam_lk = word_prob[0]
  549. flat_beam_lk = beam_lk.reshape([-1])
  550. best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
  551. True) # 1st sort
  552. self.all_scores.append(self.scores)
  553. self.scores = best_scores
  554. # bestScoresId is flattened as a (beam x word) array,
  555. # so we need to calculate which word and beam each score came from
  556. prev_k = best_scores_id // num_words
  557. self.prev_ks.append(prev_k)
  558. self.next_ys.append(best_scores_id - prev_k * num_words)
  559. # End condition is when top-of-beam is EOS.
  560. if self.next_ys[-1][0] == 3:
  561. self._done = True
  562. self.all_scores.append(self.scores)
  563. return self._done
  564. def sort_scores(self):
  565. "Sort the scores."
  566. return self.scores, paddle.to_tensor(
  567. [i for i in range(int(self.scores.shape[0]))], dtype='int32')
  568. def get_the_best_score_and_idx(self):
  569. "Get the score of the best in the beam."
  570. scores, ids = self.sort_scores()
  571. return scores[1], ids[1]
  572. def get_tentative_hypothesis(self):
  573. "Get the decoded sequence for the current timestep."
  574. if len(self.next_ys) == 1:
  575. dec_seq = self.next_ys[0].unsqueeze(1)
  576. else:
  577. _, keys = self.sort_scores()
  578. hyps = [self.get_hypothesis(k) for k in keys]
  579. hyps = [[2] + h for h in hyps]
  580. dec_seq = paddle.to_tensor(hyps, dtype='int64')
  581. return dec_seq
  582. def get_hypothesis(self, k):
  583. """ Walk back to construct the full hypothesis. """
  584. hyp = []
  585. for j in range(len(self.prev_ks) - 1, -1, -1):
  586. hyp.append(self.next_ys[j + 1][k])
  587. k = self.prev_ks[j][k]
  588. return list(map(lambda x: x.item(), hyp[::-1]))