copy_paste.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import cv2
  16. import random
  17. import numpy as np
  18. from PIL import Image
  19. from shapely.geometry import Polygon
  20. from ppocr.data.imaug.iaa_augment import IaaAugment
  21. from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
  22. from tools.infer.utility import get_rotate_crop_image
  23. class CopyPaste(object):
  24. def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
  25. self.ext_data_num = 1
  26. self.objects_paste_ratio = objects_paste_ratio
  27. self.limit_paste = limit_paste
  28. augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
  29. self.aug = IaaAugment(augmenter_args)
  30. def __call__(self, data):
  31. point_num = data['polys'].shape[1]
  32. src_img = data['image']
  33. src_polys = data['polys'].tolist()
  34. src_texts = data['texts']
  35. src_ignores = data['ignore_tags'].tolist()
  36. ext_data = data['ext_data'][0]
  37. ext_image = ext_data['image']
  38. ext_polys = ext_data['polys']
  39. ext_texts = ext_data['texts']
  40. ext_ignores = ext_data['ignore_tags']
  41. indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
  42. select_num = max(
  43. 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
  44. random.shuffle(indexs)
  45. select_idxs = indexs[:select_num]
  46. select_polys = ext_polys[select_idxs]
  47. select_ignores = ext_ignores[select_idxs]
  48. src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
  49. ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
  50. src_img = Image.fromarray(src_img).convert('RGBA')
  51. for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
  52. box_img = get_rotate_crop_image(ext_image, poly)
  53. src_img, box = self.paste_img(src_img, box_img, src_polys)
  54. if box is not None:
  55. box = box.tolist()
  56. for _ in range(len(box), point_num):
  57. box.append(box[-1])
  58. src_polys.append(box)
  59. src_texts.append(ext_texts[idx])
  60. src_ignores.append(tag)
  61. src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
  62. h, w = src_img.shape[:2]
  63. src_polys = np.array(src_polys)
  64. src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
  65. src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
  66. data['image'] = src_img
  67. data['polys'] = src_polys
  68. data['texts'] = src_texts
  69. data['ignore_tags'] = np.array(src_ignores)
  70. return data
  71. def paste_img(self, src_img, box_img, src_polys):
  72. box_img_pil = Image.fromarray(box_img).convert('RGBA')
  73. src_w, src_h = src_img.size
  74. box_w, box_h = box_img_pil.size
  75. angle = np.random.randint(0, 360)
  76. box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
  77. box = rotate_bbox(box_img, box, angle)[0]
  78. box_img_pil = box_img_pil.rotate(angle, expand=1)
  79. box_w, box_h = box_img_pil.width, box_img_pil.height
  80. if src_w - box_w < 0 or src_h - box_h < 0:
  81. return src_img, None
  82. paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
  83. src_h - box_h)
  84. if paste_x is None:
  85. return src_img, None
  86. box[:, 0] += paste_x
  87. box[:, 1] += paste_y
  88. r, g, b, A = box_img_pil.split()
  89. src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
  90. return src_img, box
  91. def select_coord(self, src_polys, box, endx, endy):
  92. if self.limit_paste:
  93. xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
  94. ), box[:, 0].max(), box[:, 1].max()
  95. for _ in range(50):
  96. paste_x = random.randint(0, endx)
  97. paste_y = random.randint(0, endy)
  98. xmin1 = xmin + paste_x
  99. xmax1 = xmax + paste_x
  100. ymin1 = ymin + paste_y
  101. ymax1 = ymax + paste_y
  102. num_poly_in_rect = 0
  103. for poly in src_polys:
  104. if not is_poly_outside_rect(poly, xmin1, ymin1,
  105. xmax1 - xmin1, ymax1 - ymin1):
  106. num_poly_in_rect += 1
  107. break
  108. if num_poly_in_rect == 0:
  109. return paste_x, paste_y
  110. return None, None
  111. else:
  112. paste_x = random.randint(0, endx)
  113. paste_y = random.randint(0, endy)
  114. return paste_x, paste_y
  115. def get_union(pD, pG):
  116. return Polygon(pD).union(Polygon(pG)).area
  117. def get_intersection_over_union(pD, pG):
  118. return get_intersection(pD, pG) / get_union(pD, pG)
  119. def get_intersection(pD, pG):
  120. return Polygon(pD).intersection(Polygon(pG)).area
  121. def rotate_bbox(img, text_polys, angle, scale=1):
  122. """
  123. from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
  124. Args:
  125. img: np.ndarray
  126. text_polys: np.ndarray N*4*2
  127. angle: int
  128. scale: int
  129. Returns:
  130. """
  131. w = img.shape[1]
  132. h = img.shape[0]
  133. rangle = np.deg2rad(angle)
  134. nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
  135. nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
  136. rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
  137. rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
  138. rot_mat[0, 2] += rot_move[0]
  139. rot_mat[1, 2] += rot_move[1]
  140. # ---------------------- rotate box ----------------------
  141. rot_text_polys = list()
  142. for bbox in text_polys:
  143. point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
  144. point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
  145. point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
  146. point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
  147. rot_text_polys.append([point1, point2, point3, point4])
  148. return np.array(rot_text_polys, dtype=np.float32)