rec_svtrnet.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from paddle import ParamAttr
  15. from paddle.nn.initializer import KaimingNormal
  16. import numpy as np
  17. import paddle
  18. import paddle.nn as nn
  19. from paddle.nn.initializer import TruncatedNormal, Constant, Normal
  20. trunc_normal_ = TruncatedNormal(std=.02)
  21. normal_ = Normal
  22. zeros_ = Constant(value=0.)
  23. ones_ = Constant(value=1.)
  24. def drop_path(x, drop_prob=0., training=False):
  25. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  26. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  27. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  28. """
  29. if drop_prob == 0. or not training:
  30. return x
  31. keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
  32. shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
  33. random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
  34. random_tensor = paddle.floor(random_tensor) # binarize
  35. output = x.divide(keep_prob) * random_tensor
  36. return output
  37. class ConvBNLayer(nn.Layer):
  38. def __init__(self,
  39. in_channels,
  40. out_channels,
  41. kernel_size=3,
  42. stride=1,
  43. padding=0,
  44. bias_attr=False,
  45. groups=1,
  46. act=nn.GELU):
  47. super().__init__()
  48. self.conv = nn.Conv2D(
  49. in_channels=in_channels,
  50. out_channels=out_channels,
  51. kernel_size=kernel_size,
  52. stride=stride,
  53. padding=padding,
  54. groups=groups,
  55. weight_attr=paddle.ParamAttr(
  56. initializer=nn.initializer.KaimingUniform()),
  57. bias_attr=bias_attr)
  58. self.norm = nn.BatchNorm2D(out_channels)
  59. self.act = act()
  60. def forward(self, inputs):
  61. out = self.conv(inputs)
  62. out = self.norm(out)
  63. out = self.act(out)
  64. return out
  65. class DropPath(nn.Layer):
  66. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  67. """
  68. def __init__(self, drop_prob=None):
  69. super(DropPath, self).__init__()
  70. self.drop_prob = drop_prob
  71. def forward(self, x):
  72. return drop_path(x, self.drop_prob, self.training)
  73. class Identity(nn.Layer):
  74. def __init__(self):
  75. super(Identity, self).__init__()
  76. def forward(self, input):
  77. return input
  78. class Mlp(nn.Layer):
  79. def __init__(self,
  80. in_features,
  81. hidden_features=None,
  82. out_features=None,
  83. act_layer=nn.GELU,
  84. drop=0.):
  85. super().__init__()
  86. out_features = out_features or in_features
  87. hidden_features = hidden_features or in_features
  88. self.fc1 = nn.Linear(in_features, hidden_features)
  89. self.act = act_layer()
  90. self.fc2 = nn.Linear(hidden_features, out_features)
  91. self.drop = nn.Dropout(drop)
  92. def forward(self, x):
  93. x = self.fc1(x)
  94. x = self.act(x)
  95. x = self.drop(x)
  96. x = self.fc2(x)
  97. x = self.drop(x)
  98. return x
  99. class ConvMixer(nn.Layer):
  100. def __init__(
  101. self,
  102. dim,
  103. num_heads=8,
  104. HW=[8, 25],
  105. local_k=[3, 3], ):
  106. super().__init__()
  107. self.HW = HW
  108. self.dim = dim
  109. self.local_mixer = nn.Conv2D(
  110. dim,
  111. dim,
  112. local_k,
  113. 1, [local_k[0] // 2, local_k[1] // 2],
  114. groups=num_heads,
  115. weight_attr=ParamAttr(initializer=KaimingNormal()))
  116. def forward(self, x):
  117. h = self.HW[0]
  118. w = self.HW[1]
  119. x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
  120. x = self.local_mixer(x)
  121. x = x.flatten(2).transpose([0, 2, 1])
  122. return x
  123. class Attention(nn.Layer):
  124. def __init__(self,
  125. dim,
  126. num_heads=8,
  127. mixer='Global',
  128. HW=None,
  129. local_k=[7, 11],
  130. qkv_bias=False,
  131. qk_scale=None,
  132. attn_drop=0.,
  133. proj_drop=0.):
  134. super().__init__()
  135. self.num_heads = num_heads
  136. head_dim = dim // num_heads
  137. self.scale = qk_scale or head_dim**-0.5
  138. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  139. self.attn_drop = nn.Dropout(attn_drop)
  140. self.proj = nn.Linear(dim, dim)
  141. self.proj_drop = nn.Dropout(proj_drop)
  142. self.HW = HW
  143. if HW is not None:
  144. H = HW[0]
  145. W = HW[1]
  146. self.N = H * W
  147. self.C = dim
  148. if mixer == 'Local' and HW is not None:
  149. hk = local_k[0]
  150. wk = local_k[1]
  151. mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype='float32')
  152. for h in range(0, H):
  153. for w in range(0, W):
  154. mask[h * W + w, h:h + hk, w:w + wk] = 0.
  155. mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
  156. 2].flatten(1)
  157. mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32')
  158. mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf)
  159. self.mask = mask.unsqueeze([0, 1])
  160. self.mixer = mixer
  161. def forward(self, x):
  162. if self.HW is not None:
  163. N = self.N
  164. C = self.C
  165. else:
  166. _, N, C = x.shape
  167. qkv = self.qkv(x).reshape((0, N, 3, self.num_heads, C //
  168. self.num_heads)).transpose((2, 0, 3, 1, 4))
  169. q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
  170. attn = (q.matmul(k.transpose((0, 1, 3, 2))))
  171. if self.mixer == 'Local':
  172. attn += self.mask
  173. attn = nn.functional.softmax(attn, axis=-1)
  174. attn = self.attn_drop(attn)
  175. x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, N, C))
  176. x = self.proj(x)
  177. x = self.proj_drop(x)
  178. return x
  179. class Block(nn.Layer):
  180. def __init__(self,
  181. dim,
  182. num_heads,
  183. mixer='Global',
  184. local_mixer=[7, 11],
  185. HW=None,
  186. mlp_ratio=4.,
  187. qkv_bias=False,
  188. qk_scale=None,
  189. drop=0.,
  190. attn_drop=0.,
  191. drop_path=0.,
  192. act_layer=nn.GELU,
  193. norm_layer='nn.LayerNorm',
  194. epsilon=1e-6,
  195. prenorm=True):
  196. super().__init__()
  197. if isinstance(norm_layer, str):
  198. self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
  199. else:
  200. self.norm1 = norm_layer(dim)
  201. if mixer == 'Global' or mixer == 'Local':
  202. self.mixer = Attention(
  203. dim,
  204. num_heads=num_heads,
  205. mixer=mixer,
  206. HW=HW,
  207. local_k=local_mixer,
  208. qkv_bias=qkv_bias,
  209. qk_scale=qk_scale,
  210. attn_drop=attn_drop,
  211. proj_drop=drop)
  212. elif mixer == 'Conv':
  213. self.mixer = ConvMixer(
  214. dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
  215. else:
  216. raise TypeError("The mixer must be one of [Global, Local, Conv]")
  217. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  218. if isinstance(norm_layer, str):
  219. self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
  220. else:
  221. self.norm2 = norm_layer(dim)
  222. mlp_hidden_dim = int(dim * mlp_ratio)
  223. self.mlp_ratio = mlp_ratio
  224. self.mlp = Mlp(in_features=dim,
  225. hidden_features=mlp_hidden_dim,
  226. act_layer=act_layer,
  227. drop=drop)
  228. self.prenorm = prenorm
  229. def forward(self, x):
  230. if self.prenorm:
  231. x = self.norm1(x + self.drop_path(self.mixer(x)))
  232. x = self.norm2(x + self.drop_path(self.mlp(x)))
  233. else:
  234. x = x + self.drop_path(self.mixer(self.norm1(x)))
  235. x = x + self.drop_path(self.mlp(self.norm2(x)))
  236. return x
  237. class PatchEmbed(nn.Layer):
  238. """ Image to Patch Embedding
  239. """
  240. def __init__(self,
  241. img_size=[32, 100],
  242. in_channels=3,
  243. embed_dim=768,
  244. sub_num=2,
  245. patch_size=[4, 4],
  246. mode='pope'):
  247. super().__init__()
  248. num_patches = (img_size[1] // (2 ** sub_num)) * \
  249. (img_size[0] // (2 ** sub_num))
  250. self.img_size = img_size
  251. self.num_patches = num_patches
  252. self.embed_dim = embed_dim
  253. self.norm = None
  254. if mode == 'pope':
  255. if sub_num == 2:
  256. self.proj = nn.Sequential(
  257. ConvBNLayer(
  258. in_channels=in_channels,
  259. out_channels=embed_dim // 2,
  260. kernel_size=3,
  261. stride=2,
  262. padding=1,
  263. act=nn.GELU,
  264. bias_attr=None),
  265. ConvBNLayer(
  266. in_channels=embed_dim // 2,
  267. out_channels=embed_dim,
  268. kernel_size=3,
  269. stride=2,
  270. padding=1,
  271. act=nn.GELU,
  272. bias_attr=None))
  273. if sub_num == 3:
  274. self.proj = nn.Sequential(
  275. ConvBNLayer(
  276. in_channels=in_channels,
  277. out_channels=embed_dim // 4,
  278. kernel_size=3,
  279. stride=2,
  280. padding=1,
  281. act=nn.GELU,
  282. bias_attr=None),
  283. ConvBNLayer(
  284. in_channels=embed_dim // 4,
  285. out_channels=embed_dim // 2,
  286. kernel_size=3,
  287. stride=2,
  288. padding=1,
  289. act=nn.GELU,
  290. bias_attr=None),
  291. ConvBNLayer(
  292. in_channels=embed_dim // 2,
  293. out_channels=embed_dim,
  294. kernel_size=3,
  295. stride=2,
  296. padding=1,
  297. act=nn.GELU,
  298. bias_attr=None))
  299. elif mode == 'linear':
  300. self.proj = nn.Conv2D(
  301. 1, embed_dim, kernel_size=patch_size, stride=patch_size)
  302. self.num_patches = img_size[0] // patch_size[0] * img_size[
  303. 1] // patch_size[1]
  304. def forward(self, x):
  305. B, C, H, W = x.shape
  306. assert H == self.img_size[0] and W == self.img_size[1], \
  307. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  308. x = self.proj(x).flatten(2).transpose((0, 2, 1))
  309. return x
  310. class SubSample(nn.Layer):
  311. def __init__(self,
  312. in_channels,
  313. out_channels,
  314. types='Pool',
  315. stride=[2, 1],
  316. sub_norm='nn.LayerNorm',
  317. act=None):
  318. super().__init__()
  319. self.types = types
  320. if types == 'Pool':
  321. self.avgpool = nn.AvgPool2D(
  322. kernel_size=[3, 5], stride=stride, padding=[1, 2])
  323. self.maxpool = nn.MaxPool2D(
  324. kernel_size=[3, 5], stride=stride, padding=[1, 2])
  325. self.proj = nn.Linear(in_channels, out_channels)
  326. else:
  327. self.conv = nn.Conv2D(
  328. in_channels,
  329. out_channels,
  330. kernel_size=3,
  331. stride=stride,
  332. padding=1,
  333. weight_attr=ParamAttr(initializer=KaimingNormal()))
  334. self.norm = eval(sub_norm)(out_channels)
  335. if act is not None:
  336. self.act = act()
  337. else:
  338. self.act = None
  339. def forward(self, x):
  340. if self.types == 'Pool':
  341. x1 = self.avgpool(x)
  342. x2 = self.maxpool(x)
  343. x = (x1 + x2) * 0.5
  344. out = self.proj(x.flatten(2).transpose((0, 2, 1)))
  345. else:
  346. x = self.conv(x)
  347. out = x.flatten(2).transpose((0, 2, 1))
  348. out = self.norm(out)
  349. if self.act is not None:
  350. out = self.act(out)
  351. return out
  352. class SVTRNet(nn.Layer):
  353. def __init__(
  354. self,
  355. img_size=[32, 100],
  356. in_channels=3,
  357. embed_dim=[64, 128, 256],
  358. depth=[3, 6, 3],
  359. num_heads=[2, 4, 8],
  360. mixer=['Local'] * 6 + ['Global'] *
  361. 6, # Local atten, Global atten, Conv
  362. local_mixer=[[7, 11], [7, 11], [7, 11]],
  363. patch_merging='Conv', # Conv, Pool, None
  364. mlp_ratio=4,
  365. qkv_bias=True,
  366. qk_scale=None,
  367. drop_rate=0.,
  368. last_drop=0.1,
  369. attn_drop_rate=0.,
  370. drop_path_rate=0.1,
  371. norm_layer='nn.LayerNorm',
  372. sub_norm='nn.LayerNorm',
  373. epsilon=1e-6,
  374. out_channels=192,
  375. out_char_num=25,
  376. block_unit='Block',
  377. act='nn.GELU',
  378. last_stage=True,
  379. sub_num=2,
  380. prenorm=True,
  381. use_lenhead=False,
  382. **kwargs):
  383. super().__init__()
  384. self.img_size = img_size
  385. self.embed_dim = embed_dim
  386. self.out_channels = out_channels
  387. self.prenorm = prenorm
  388. patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
  389. self.patch_embed = PatchEmbed(
  390. img_size=img_size,
  391. in_channels=in_channels,
  392. embed_dim=embed_dim[0],
  393. sub_num=sub_num)
  394. num_patches = self.patch_embed.num_patches
  395. self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
  396. self.pos_embed = self.create_parameter(
  397. shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
  398. self.add_parameter("pos_embed", self.pos_embed)
  399. self.pos_drop = nn.Dropout(p=drop_rate)
  400. Block_unit = eval(block_unit)
  401. dpr = np.linspace(0, drop_path_rate, sum(depth))
  402. self.blocks1 = nn.LayerList([
  403. Block_unit(
  404. dim=embed_dim[0],
  405. num_heads=num_heads[0],
  406. mixer=mixer[0:depth[0]][i],
  407. HW=self.HW,
  408. local_mixer=local_mixer[0],
  409. mlp_ratio=mlp_ratio,
  410. qkv_bias=qkv_bias,
  411. qk_scale=qk_scale,
  412. drop=drop_rate,
  413. act_layer=eval(act),
  414. attn_drop=attn_drop_rate,
  415. drop_path=dpr[0:depth[0]][i],
  416. norm_layer=norm_layer,
  417. epsilon=epsilon,
  418. prenorm=prenorm) for i in range(depth[0])
  419. ])
  420. if patch_merging is not None:
  421. self.sub_sample1 = SubSample(
  422. embed_dim[0],
  423. embed_dim[1],
  424. sub_norm=sub_norm,
  425. stride=[2, 1],
  426. types=patch_merging)
  427. HW = [self.HW[0] // 2, self.HW[1]]
  428. else:
  429. HW = self.HW
  430. self.patch_merging = patch_merging
  431. self.blocks2 = nn.LayerList([
  432. Block_unit(
  433. dim=embed_dim[1],
  434. num_heads=num_heads[1],
  435. mixer=mixer[depth[0]:depth[0] + depth[1]][i],
  436. HW=HW,
  437. local_mixer=local_mixer[1],
  438. mlp_ratio=mlp_ratio,
  439. qkv_bias=qkv_bias,
  440. qk_scale=qk_scale,
  441. drop=drop_rate,
  442. act_layer=eval(act),
  443. attn_drop=attn_drop_rate,
  444. drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
  445. norm_layer=norm_layer,
  446. epsilon=epsilon,
  447. prenorm=prenorm) for i in range(depth[1])
  448. ])
  449. if patch_merging is not None:
  450. self.sub_sample2 = SubSample(
  451. embed_dim[1],
  452. embed_dim[2],
  453. sub_norm=sub_norm,
  454. stride=[2, 1],
  455. types=patch_merging)
  456. HW = [self.HW[0] // 4, self.HW[1]]
  457. else:
  458. HW = self.HW
  459. self.blocks3 = nn.LayerList([
  460. Block_unit(
  461. dim=embed_dim[2],
  462. num_heads=num_heads[2],
  463. mixer=mixer[depth[0] + depth[1]:][i],
  464. HW=HW,
  465. local_mixer=local_mixer[2],
  466. mlp_ratio=mlp_ratio,
  467. qkv_bias=qkv_bias,
  468. qk_scale=qk_scale,
  469. drop=drop_rate,
  470. act_layer=eval(act),
  471. attn_drop=attn_drop_rate,
  472. drop_path=dpr[depth[0] + depth[1]:][i],
  473. norm_layer=norm_layer,
  474. epsilon=epsilon,
  475. prenorm=prenorm) for i in range(depth[2])
  476. ])
  477. self.last_stage = last_stage
  478. if last_stage:
  479. self.avg_pool = nn.AdaptiveAvgPool2D([1, out_char_num])
  480. self.last_conv = nn.Conv2D(
  481. in_channels=embed_dim[2],
  482. out_channels=self.out_channels,
  483. kernel_size=1,
  484. stride=1,
  485. padding=0,
  486. bias_attr=False)
  487. self.hardswish = nn.Hardswish()
  488. self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
  489. if not prenorm:
  490. self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
  491. self.use_lenhead = use_lenhead
  492. if use_lenhead:
  493. self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
  494. self.hardswish_len = nn.Hardswish()
  495. self.dropout_len = nn.Dropout(
  496. p=last_drop, mode="downscale_in_infer")
  497. trunc_normal_(self.pos_embed)
  498. self.apply(self._init_weights)
  499. def _init_weights(self, m):
  500. if isinstance(m, nn.Linear):
  501. trunc_normal_(m.weight)
  502. if isinstance(m, nn.Linear) and m.bias is not None:
  503. zeros_(m.bias)
  504. elif isinstance(m, nn.LayerNorm):
  505. zeros_(m.bias)
  506. ones_(m.weight)
  507. def forward_features(self, x):
  508. x = self.patch_embed(x)
  509. x = x + self.pos_embed
  510. x = self.pos_drop(x)
  511. for blk in self.blocks1:
  512. x = blk(x)
  513. if self.patch_merging is not None:
  514. x = self.sub_sample1(
  515. x.transpose([0, 2, 1]).reshape(
  516. [0, self.embed_dim[0], self.HW[0], self.HW[1]]))
  517. for blk in self.blocks2:
  518. x = blk(x)
  519. if self.patch_merging is not None:
  520. x = self.sub_sample2(
  521. x.transpose([0, 2, 1]).reshape(
  522. [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
  523. for blk in self.blocks3:
  524. x = blk(x)
  525. if not self.prenorm:
  526. x = self.norm(x)
  527. return x
  528. def forward(self, x):
  529. x = self.forward_features(x)
  530. if self.use_lenhead:
  531. len_x = self.len_conv(x.mean(1))
  532. len_x = self.dropout_len(self.hardswish_len(len_x))
  533. if self.last_stage:
  534. if self.patch_merging is not None:
  535. h = self.HW[0] // 4
  536. else:
  537. h = self.HW[0]
  538. x = self.avg_pool(
  539. x.transpose([0, 2, 1]).reshape(
  540. [0, self.embed_dim[2], h, self.HW[1]]))
  541. x = self.last_conv(x)
  542. x = self.hardswish(x)
  543. x = self.dropout(x)
  544. if self.use_lenhead:
  545. return x, len_x
  546. return x