"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import sys
import six
import cv2
import numpy as np
import math
from PIL import Image


class DecodeImage(object):
    """ decode image """

    def __init__(self,
                 img_mode='RGB',
                 channel_first=False,
                 ignore_orientation=False,
                 **kwargs):
        self.img_mode = img_mode
        self.channel_first = channel_first
        self.ignore_orientation = ignore_orientation

    def __call__(self, data):
        img = data['image']
        if six.PY2:
            assert type(img) is str and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        else:
            assert type(img) is bytes and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')
        if self.ignore_orientation:
            img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
                               cv2.IMREAD_COLOR)
        else:
            img = cv2.imdecode(img, 1)
        if img is None:
            return None
        if self.img_mode == 'GRAY':
            # imdecode above always returns a 3-channel BGR image, so collapse
            # it to a single gray channel here
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        elif self.img_mode == 'RGB':
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
            img = img[:, :, ::-1]

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        data['image'] = img
        return data
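
# Example usage of DecodeImage (a minimal sketch; 'sample.jpg' is an assumed
# local file, not part of this module):
#
#     with open('sample.jpg', 'rb') as f:
#         data = {'image': f.read()}
#     data = DecodeImage(img_mode='RGB')(data)  # data['image'] is now an HWC uint8 RGB array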


class NormalizeImage(object):
    """ normalize image, e.g. subtract mean and divide by std """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        img = data['image']
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
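
# Example usage of NormalizeImage (a minimal sketch; assumes data['image'] is an
# HWC uint8 array such as the output of DecodeImage):
#
#     normalize = NormalizeImage(scale=1.0 / 255.0, order='hwc')
#     data = normalize(data)  # float32 image scaled to [0, 1], then (x - mean) / std per channel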


class ToCHWImage(object):
    """ convert hwc image to chw image """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        img = data['image']
        if isinstance(img, Image.Image):
            img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data


class Fasttext(object):
    def __init__(self, path="None", **kwargs):
        import fasttext
        self.fast_model = fasttext.load_model(path)

    def __call__(self, data):
        label = data['label']
        fast_label = self.fast_model[label]
        data['fast_label'] = fast_label
        return data


class KeepKeys(object):
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list


class Pad(object):
    def __init__(self, size=None, size_div=32, **kwargs):
        if size is not None and not isinstance(size, (int, list, tuple)):
            raise TypeError("Type of size is invalid. Now is {}".format(
                type(size)))
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.size_div = size_div

    def __call__(self, data):
        img = data['image']
        img_h, img_w = img.shape[0], img.shape[1]
        if self.size:
            resize_h2, resize_w2 = self.size
            assert (
                img_h < resize_h2 and img_w < resize_w2
            ), '(h, w) of target size should be greater than (img_h, img_w)'
        else:
            # pad each side up to the next multiple of size_div
            resize_h2 = max(
                int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
                self.size_div)
            resize_w2 = max(
                int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
                self.size_div)
        img = cv2.copyMakeBorder(
            img,
            0,
            resize_h2 - img_h,
            0,
            resize_w2 - img_w,
            cv2.BORDER_CONSTANT,
            value=0)
        data['image'] = img
        return data
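
# Example usage of Pad (a minimal sketch with a synthetic array): without an
# explicit size, both sides are padded to the next multiple of size_div with
# zeros on the bottom and right.
#
#     data = {'image': np.zeros((100, 150, 3), dtype=np.uint8)}
#     data = Pad(size_div=32)(data)  # data['image'].shape == (128, 160, 3)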


class Resize(object):
    def __init__(self, size=(640, 640), **kwargs):
        self.size = size

    def resize_image(self, img):
        resize_h, resize_w = self.size
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def __call__(self, data):
        img = data['image']
        if 'polys' in data:
            text_polys = data['polys']

        img_resize, [ratio_h, ratio_w] = self.resize_image(img)
        if 'polys' in data:
            new_boxes = []
            for box in text_polys:
                new_box = []
                for cord in box:
                    new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
                new_boxes.append(new_box)
            data['polys'] = np.array(new_boxes, dtype=np.float32)
        data['image'] = img_resize
        return data


class DetResizeForTest(object):
    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
            if 'keep_ratio' in kwargs:
                self.keep_ratio = kwargs['keep_ratio']
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        img = data['image']
        src_h, src_w, _ = img.shape
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def image_padding(self, im, value=0):
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            N = math.ceil(resize_w / 32)
            resize_w = N * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        resize image to a size that is a multiple of 32, which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # limit the max side
        if self.limit_type == 'max':
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'min':
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('unsupported limit type: {}'.format(self.limit_type))
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except:
            print(img.shape, resize_w, resize_h)
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        h, w, _ = img.shape
        resize_w = w
        resize_h = h

        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]
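
# Example usage of DetResizeForTest (a minimal sketch; the keyword arguments
# select one of the branches in __init__ above, here the limit_side_len /
# limit_type branch, and img is assumed to be an HWC array):
#
#     resize_op = DetResizeForTest(limit_side_len=960, limit_type='max')
#     data = resize_op({'image': img})  # adds data['shape'] = [src_h, src_w, ratio_h, ratio_w]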


class E2EResizeForTest(object):
    def __init__(self, **kwargs):
        super(E2EResizeForTest, self).__init__()
        self.max_side_len = kwargs['max_side_len']
        self.valid_set = kwargs['valid_set']

    def __call__(self, data):
        img = data['image']
        src_h, src_w, _ = img.shape
        if self.valid_set == 'totaltext':
            im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
                img, max_side_len=self.max_side_len)
        else:
            im_resized, (ratio_h, ratio_w) = self.resize_image(
                img, max_side_len=self.max_side_len)
        data['image'] = im_resized
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_for_totaltext(self, im, max_side_len=512):
        h, w, _ = im.shape
        resize_w = w
        resize_h = h
        ratio = 1.25
        if h * ratio > max_side_len:
            ratio = float(max_side_len) / resize_h
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return im, (ratio_h, ratio_w)

    def resize_image(self, im, max_side_len=512):
        """
        resize image to a size that is a multiple of max_stride, which is required by the network
        :param im: the image to be resized
        :param max_side_len: limit of max image size to avoid running out of GPU memory
        :return: the resized image and the resize ratio
        """
        h, w, _ = im.shape

        resize_w = w
        resize_h = h

        # Fix the longer side
        if resize_h > resize_w:
            ratio = float(max_side_len) / resize_h
        else:
            ratio = float(max_side_len) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return im, (ratio_h, ratio_w)


class KieResize(object):
    def __init__(self, **kwargs):
        super(KieResize, self).__init__()
        self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
            'img_scale'][1]

    def __call__(self, data):
        img = data['image']
        points = data['points']
        src_h, src_w, _ = img.shape
        im_resized, scale_factor, [ratio_h, ratio_w], [new_h, new_w] = \
            self.resize_image(img)
        resize_points = self.resize_boxes(img, points, scale_factor)
        data['ori_image'] = img
        data['ori_boxes'] = points
        data['points'] = resize_points
        data['image'] = im_resized
        data['shape'] = np.array([new_h, new_w])
        return data

    def resize_image(self, img):
        norm_img = np.zeros([1024, 1024, 3], dtype='float32')
        scale = [512, 1024]
        h, w = img.shape[:2]
        max_long_edge = max(scale)
        max_short_edge = min(scale)
        scale_factor = min(max_long_edge / max(h, w),
                           max_short_edge / min(h, w))
        resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
            scale_factor) + 0.5)
        max_stride = 32
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(img, (resize_w, resize_h))
        new_h, new_w = im.shape[:2]
        w_scale = new_w / w
        h_scale = new_h / h
        scale_factor = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        norm_img[:new_h, :new_w, :] = im
        return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]

    def resize_boxes(self, im, points, scale_factor):
        points = points * scale_factor
        img_shape = im.shape[:2]
        points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
        points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
        return points


class SRResize(object):
    def __init__(self,
                 imgH=32,
                 imgW=128,
                 down_sample_scale=4,
                 keep_ratio=False,
                 min_ratio=1,
                 mask=False,
                 infer_mode=False,
                 **kwargs):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio
        self.down_sample_scale = down_sample_scale
        self.mask = mask
        self.infer_mode = infer_mode

    def __call__(self, data):
        imgH = self.imgH
        imgW = self.imgW
        images_lr = data["image_lr"]
        transform2 = ResizeNormalize(
            (imgW // self.down_sample_scale, imgH // self.down_sample_scale))
        images_lr = transform2(images_lr)
        data["img_lr"] = images_lr
        if self.infer_mode:
            return data

        images_HR = data["image_hr"]
        label_strs = data["label"]
        transform = ResizeNormalize((imgW, imgH))
        images_HR = transform(images_HR)
        data["img_hr"] = images_HR
        return data


class ResizeNormalize(object):
    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img_numpy = np.array(img).astype("float32")
        img_numpy = img_numpy.transpose((2, 0, 1)) / 255
        return img_numpy


class GrayImageChannelFormat(object):
    """
    format a gray-scale image's channels: (h, w, 3) -> (1, h, w)
    Args:
        inverse: invert the gray image
    """

    def __init__(self, inverse=False, **kwargs):
        self.inverse = inverse

    def __call__(self, data):
        img = data['image']
        img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img_expanded = np.expand_dims(img_single_channel, 0)

        if self.inverse:
            data['image'] = np.abs(img_expanded - 1)
        else:
            data['image'] = img_expanded

        data['src_image'] = img
        return data
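
# A minimal sketch of chaining these operators by hand (the surrounding
# framework normally builds such a list from a config; 'sample.jpg' is an
# assumed local file):
#
#     ops = [DecodeImage(img_mode='BGR'), NormalizeImage(order='hwc'),
#            ToCHWImage(), KeepKeys(keep_keys=['image'])]
#     with open('sample.jpg', 'rb') as f:
#         data = {'image': f.read()}
#     for op in ops:
#         data = op(data)
#     image_chw = data[0]  # KeepKeys returns a list in the requested key order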