e2e_pg_head.py 7.9 KB


  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import math
  18. import paddle
  19. from paddle import nn
  20. import paddle.nn.functional as F
  21. from paddle import ParamAttr
  22. class ConvBNLayer(nn.Layer):
  23. def __init__(self,
  24. in_channels,
  25. out_channels,
  26. kernel_size,
  27. stride,
  28. padding,
  29. groups=1,
  30. if_act=True,
  31. act=None,
  32. name=None):
  33. super(ConvBNLayer, self).__init__()
  34. self.if_act = if_act
  35. self.act = act
  36. self.conv = nn.Conv2D(
  37. in_channels=in_channels,
  38. out_channels=out_channels,
  39. kernel_size=kernel_size,
  40. stride=stride,
  41. padding=padding,
  42. groups=groups,
  43. weight_attr=ParamAttr(name=name + '_weights'),
  44. bias_attr=False)
  45. self.bn = nn.BatchNorm(
  46. num_channels=out_channels,
  47. act=act,
  48. param_attr=ParamAttr(name="bn_" + name + "_scale"),
  49. bias_attr=ParamAttr(name="bn_" + name + "_offset"),
  50. moving_mean_name="bn_" + name + "_mean",
  51. moving_variance_name="bn_" + name + "_variance",
  52. use_global_stats=False)
  53. def forward(self, x):
  54. x = self.conv(x)
  55. x = self.bn(x)
  56. return x
  57. class PGHead(nn.Layer):
  58. """
  59. """
  60. def __init__(self,
  61. in_channels,
  62. character_dict_path='ppocr/utils/ic15_dict.txt',
  63. **kwargs):
  64. super(PGHead, self).__init__()
  65. # get character_length
  66. with open(character_dict_path, "rb") as fin:
  67. lines = fin.readlines()
  68. character_length = len(lines) + 1
  69. self.conv_f_score1 = ConvBNLayer(
  70. in_channels=in_channels,
  71. out_channels=64,
  72. kernel_size=1,
  73. stride=1,
  74. padding=0,
  75. act='relu',
  76. name="conv_f_score{}".format(1))
  77. self.conv_f_score2 = ConvBNLayer(
  78. in_channels=64,
  79. out_channels=64,
  80. kernel_size=3,
  81. stride=1,
  82. padding=1,
  83. act='relu',
  84. name="conv_f_score{}".format(2))
  85. self.conv_f_score3 = ConvBNLayer(
  86. in_channels=64,
  87. out_channels=128,
  88. kernel_size=1,
  89. stride=1,
  90. padding=0,
  91. act='relu',
  92. name="conv_f_score{}".format(3))
  93. self.conv1 = nn.Conv2D(
  94. in_channels=128,
  95. out_channels=1,
  96. kernel_size=3,
  97. stride=1,
  98. padding=1,
  99. groups=1,
  100. weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
  101. bias_attr=False)
  102. self.conv_f_boder1 = ConvBNLayer(
  103. in_channels=in_channels,
  104. out_channels=64,
  105. kernel_size=1,
  106. stride=1,
  107. padding=0,
  108. act='relu',
  109. name="conv_f_boder{}".format(1))
  110. self.conv_f_boder2 = ConvBNLayer(
  111. in_channels=64,
  112. out_channels=64,
  113. kernel_size=3,
  114. stride=1,
  115. padding=1,
  116. act='relu',
  117. name="conv_f_boder{}".format(2))
  118. self.conv_f_boder3 = ConvBNLayer(
  119. in_channels=64,
  120. out_channels=128,
  121. kernel_size=1,
  122. stride=1,
  123. padding=0,
  124. act='relu',
  125. name="conv_f_boder{}".format(3))
  126. self.conv2 = nn.Conv2D(
  127. in_channels=128,
  128. out_channels=4,
  129. kernel_size=3,
  130. stride=1,
  131. padding=1,
  132. groups=1,
  133. weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
  134. bias_attr=False)
  135. self.conv_f_char1 = ConvBNLayer(
  136. in_channels=in_channels,
  137. out_channels=128,
  138. kernel_size=1,
  139. stride=1,
  140. padding=0,
  141. act='relu',
  142. name="conv_f_char{}".format(1))
  143. self.conv_f_char2 = ConvBNLayer(
  144. in_channels=128,
  145. out_channels=128,
  146. kernel_size=3,
  147. stride=1,
  148. padding=1,
  149. act='relu',
  150. name="conv_f_char{}".format(2))
  151. self.conv_f_char3 = ConvBNLayer(
  152. in_channels=128,
  153. out_channels=256,
  154. kernel_size=1,
  155. stride=1,
  156. padding=0,
  157. act='relu',
  158. name="conv_f_char{}".format(3))
  159. self.conv_f_char4 = ConvBNLayer(
  160. in_channels=256,
  161. out_channels=256,
  162. kernel_size=3,
  163. stride=1,
  164. padding=1,
  165. act='relu',
  166. name="conv_f_char{}".format(4))
  167. self.conv_f_char5 = ConvBNLayer(
  168. in_channels=256,
  169. out_channels=256,
  170. kernel_size=1,
  171. stride=1,
  172. padding=0,
  173. act='relu',
  174. name="conv_f_char{}".format(5))
  175. self.conv3 = nn.Conv2D(
  176. in_channels=256,
  177. out_channels=character_length,
  178. kernel_size=3,
  179. stride=1,
  180. padding=1,
  181. groups=1,
  182. weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
  183. bias_attr=False)
  184. self.conv_f_direc1 = ConvBNLayer(
  185. in_channels=in_channels,
  186. out_channels=64,
  187. kernel_size=1,
  188. stride=1,
  189. padding=0,
  190. act='relu',
  191. name="conv_f_direc{}".format(1))
  192. self.conv_f_direc2 = ConvBNLayer(
  193. in_channels=64,
  194. out_channels=64,
  195. kernel_size=3,
  196. stride=1,
  197. padding=1,
  198. act='relu',
  199. name="conv_f_direc{}".format(2))
  200. self.conv_f_direc3 = ConvBNLayer(
  201. in_channels=64,
  202. out_channels=128,
  203. kernel_size=1,
  204. stride=1,
  205. padding=0,
  206. act='relu',
  207. name="conv_f_direc{}".format(3))
  208. self.conv4 = nn.Conv2D(
  209. in_channels=128,
  210. out_channels=2,
  211. kernel_size=3,
  212. stride=1,
  213. padding=1,
  214. groups=1,
  215. weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
  216. bias_attr=False)
  217. def forward(self, x, targets=None):
  218. f_score = self.conv_f_score1(x)
  219. f_score = self.conv_f_score2(f_score)
  220. f_score = self.conv_f_score3(f_score)
  221. f_score = self.conv1(f_score)
  222. f_score = F.sigmoid(f_score)
  223. # f_border
  224. f_border = self.conv_f_boder1(x)
  225. f_border = self.conv_f_boder2(f_border)
  226. f_border = self.conv_f_boder3(f_border)
  227. f_border = self.conv2(f_border)
  228. f_char = self.conv_f_char1(x)
  229. f_char = self.conv_f_char2(f_char)
  230. f_char = self.conv_f_char3(f_char)
  231. f_char = self.conv_f_char4(f_char)
  232. f_char = self.conv_f_char5(f_char)
  233. f_char = self.conv3(f_char)
  234. f_direction = self.conv_f_direc1(x)
  235. f_direction = self.conv_f_direc2(f_direction)
  236. f_direction = self.conv_f_direc3(f_direction)
  237. f_direction = self.conv4(f_direction)
  238. predicts = {}
  239. predicts['f_score'] = f_score
  240. predicts['f_border'] = f_border
  241. predicts['f_char'] = f_char
  242. predicts['f_direction'] = f_direction
  243. return predicts