kie_unet_sdmgr.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. import numpy as np
  20. import cv2
  21. __all__ = ["Kie_backbone"]
  class Encoder(nn.Layer):
      """Down-sampling stage: two 3x3 conv + BatchNorm(relu) units and a max pool.

      Args:
          num_channels: number of channels of the incoming feature map.
          num_filters: number of channels produced by both convolutions.
      """

      def __init__(self, num_channels, num_filters):
          super(Encoder, self).__init__()
          # Both convolutions share the same geometry; the bias is disabled
          # because each convolution is immediately followed by BatchNorm.
          conv_cfg = dict(kernel_size=3, stride=1, padding=1, bias_attr=False)

          self.conv1 = nn.Conv2D(num_channels, num_filters, **conv_cfg)
          self.bn1 = nn.BatchNorm(num_filters, act='relu')

          self.conv2 = nn.Conv2D(num_filters, num_filters, **conv_cfg)
          self.bn2 = nn.BatchNorm(num_filters, act='relu')

          self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

      def forward(self, inputs):
          """Return ``(features, pooled_features)`` for ``inputs``.

          ``features`` is the output of the second conv/BN unit before
          pooling; ``pooled_features`` is the same tensor after the
          stride-2 max pool.
          """
          features = self.bn1(self.conv1(inputs))
          features = self.bn2(self.conv2(features))
          return features, self.pool(features)
  class Decoder(nn.Layer):
      """Up-sampling stage that fuses a deeper feature map with a skip feature.

      Args:
          num_channels: channels of the deeper input; also the channel count
              expected after concatenating the skip feature with the
              projected, up-sampled input (``conv1`` consumes that many).
          num_filters: channels of the 1x1 projection and of the output.
      """

      def __init__(self, num_channels, num_filters):
          super(Decoder, self).__init__()
          # Layers are created in the same order as before (conv1, bn1, conv2,
          # bn2, conv0, bn0) so parameter registration order is unchanged.
          fuse_cfg = dict(kernel_size=3, stride=1, padding=1, bias_attr=False)

          self.conv1 = nn.Conv2D(num_channels, num_filters, **fuse_cfg)
          self.bn1 = nn.BatchNorm(num_filters, act='relu')

          self.conv2 = nn.Conv2D(num_filters, num_filters, **fuse_cfg)
          self.bn2 = nn.BatchNorm(num_filters, act='relu')

          # 1x1 projection applied to the deeper input before up-sampling.
          self.conv0 = nn.Conv2D(
              num_channels,
              num_filters,
              kernel_size=1,
              stride=1,
              padding=0,
              bias_attr=False)
          self.bn0 = nn.BatchNorm(num_filters, act='relu')

      def forward(self, inputs_prev, inputs):
          """Fuse ``inputs_prev`` (skip feature) with ``inputs`` (deeper feature).

          ``inputs`` is projected by the 1x1 conv/BN unit, bilinearly
          up-sampled by a factor of 2, concatenated after ``inputs_prev``
          along axis 1, and refined by the two 3x3 conv/BN units.
          """
          upsampled = self.bn0(self.conv0(inputs))
          upsampled = paddle.nn.functional.interpolate(
              upsampled, scale_factor=2, mode='bilinear', align_corners=False)

          fused = paddle.concat([inputs_prev, upsampled], axis=1)
          fused = self.bn1(self.conv1(fused))
          return self.bn2(self.conv2(fused))
  class UNet(nn.Layer):
      """Five-stage encoder / four-stage decoder network with skip connections.

      The network takes a 3-channel input and produces a 16-channel feature
      map (``out_channels``).
      """

      def __init__(self):
          super(UNet, self).__init__()
          # Encoder path: 3 -> 16 -> 32 -> 64 -> 128 -> 256 channels.
          self.down1 = Encoder(num_channels=3, num_filters=16)
          self.down2 = Encoder(num_channels=16, num_filters=32)
          self.down3 = Encoder(num_channels=32, num_filters=64)
          self.down4 = Encoder(num_channels=64, num_filters=128)
          self.down5 = Encoder(num_channels=128, num_filters=256)

          # Decoder path, named after the encoder level it feeds back into.
          self.up1 = Decoder(32, 16)
          self.up2 = Decoder(64, 32)
          self.up3 = Decoder(128, 64)
          self.up4 = Decoder(256, 128)

          self.out_channels = 16

      def forward(self, inputs):
          """Run the encoder and decoder paths and return the final feature map.

          Only the un-pooled output of ``down1`` and the pooled outputs of
          ``down2`` .. ``down5`` are used; the other half of each encoder
          result is discarded.
          """
          # NOTE(review): the 2x up-sampling in each Decoder presumably needs
          # the input height/width to be multiples of 16 for the skip
          # concatenations to line up - confirm against the data pipeline.
          skip1 = self.down1(inputs)[0]
          skip2 = self.down2(skip1)[1]
          skip3 = self.down3(skip2)[1]
          skip4 = self.down4(skip3)[1]
          bottom = self.down5(skip4)[1]

          out = bottom
          stages = (
              (self.up4, skip4),
              (self.up3, skip3),
              (self.up2, skip2),
              (self.up1, skip1),
          )
          for decoder, skip in stages:
              out = decoder(skip, out)
          return out
  class Kie_backbone(nn.Layer):
      """Backbone that extracts one UNet feature vector per ground-truth box.

      Args:
          in_channels: accepted for interface compatibility; it is not used
              here because the UNet is built with a fixed 3-channel input.
          **kwargs: accepted and ignored.
      """

      def __init__(self, in_channels, **kwargs):
          super(Kie_backbone, self).__init__()
          self.out_channels = 16
          self.img_feat = UNet()
          # Collapses the 7x7 RoI-aligned window produced in forward().
          self.maxpool = nn.MaxPool2D(kernel_size=7)

      def bbox2roi(self, bbox_list):
          """Concatenate per-image boxes and record how many each image has.

          Returns:
              ``(rois, rois_num)``: all boxes concatenated along axis 0, and
              an int32 tensor holding the box count of every image.
          """
          per_image = list(bbox_list)
          counts = [boxes.shape[0] for boxes in per_image]
          rois = paddle.concat(per_image, 0)
          rois_num = paddle.to_tensor(counts, dtype='int32')
          return rois, rois_num

      def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
          """Crop the image batch and strip per-sample padding from the labels.

          ``tag[i]`` supplies ``(num, recoder_len)`` for sample ``i``: ``num``
          bounds the first two axes of ``relations[i]`` and the first axis of
          ``texts[i]`` / ``gt_bboxes[i]``, while ``recoder_len`` bounds the
          second axis of ``texts[i]``.

          Returns:
              ``(img, relations, texts, gt_bboxes)`` where ``img`` is a tensor
              cropped to the largest height/width listed in ``img_size`` and
              the other three are lists of float32 tensors, one per sample.
          """
          # Slicing is done on host-side numpy copies of the inputs.
          img_np = img.numpy()
          relations_np = relations.numpy()
          texts_np = texts.numpy()
          boxes_np = gt_bboxes.numpy()
          tag_rows = tag.numpy().tolist()
          size_np = img_size.numpy()

          max_h = int(np.max(size_np[:, 0]))
          max_w = int(np.max(size_np[:, 1]))
          cropped_img = paddle.to_tensor(img_np[:, :, :max_h, :max_w])

          relations_out, texts_out, boxes_out = [], [], []
          for idx, row in enumerate(tag_rows):
              num, recoder_len = row[0], row[1]
              relations_out.append(
                  paddle.to_tensor(
                      relations_np[idx, :num, :num, :], dtype='float32'))
              texts_out.append(
                  paddle.to_tensor(
                      texts_np[idx, :num, :recoder_len], dtype='float32'))
              boxes_out.append(
                  paddle.to_tensor(
                      boxes_np[idx, :num, ...], dtype='float32'))
          return cropped_img, relations_out, texts_out, boxes_out

      def forward(self, inputs):
          """Compute per-box visual features for a batch.

          ``inputs`` is indexed positionally: 0 = image, 1 = relations,
          2 = texts, 3 = ground-truth boxes, 5 = tag, -1 = image sizes.

          Returns:
              ``[relations, texts, feats]`` where the first two are the
              per-sample lists from ``pre_process`` and ``feats`` holds one
              pooled feature vector per box.
          """
          img = inputs[0]
          relations = inputs[1]
          texts = inputs[2]
          gt_bboxes = inputs[3]
          tag = inputs[5]
          img_size = inputs[-1]

          img, relations, texts, gt_bboxes = self.pre_process(
              img, relations, texts, gt_bboxes, tag, img_size)

          feature_map = self.img_feat(img)
          boxes, rois_num = self.bbox2roi(gt_bboxes)
          roi_feats = paddle.vision.ops.roi_align(
              feature_map,
              boxes,
              spatial_scale=1.0,
              output_size=7,
              boxes_num=rois_num)
          # Max-pool the 7x7 window, then drop the two trailing spatial axes.
          pooled = self.maxpool(roi_feats)
          feats = pooled.squeeze(-1).squeeze(-1)
          return [relations, texts, feats]