# ocr_reader.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import cv2
  15. import copy
  16. import numpy as np
  17. import math
  18. import re
  19. import sys
  20. import argparse
  21. import string
  22. from copy import deepcopy
  23. class DetResizeForTest(object):
  24. def __init__(self, **kwargs):
  25. super(DetResizeForTest, self).__init__()
  26. self.resize_type = 0
  27. if 'image_shape' in kwargs:
  28. self.image_shape = kwargs['image_shape']
  29. self.resize_type = 1
  30. elif 'limit_side_len' in kwargs:
  31. self.limit_side_len = kwargs['limit_side_len']
  32. self.limit_type = kwargs.get('limit_type', 'min')
  33. elif 'resize_short' in kwargs:
  34. self.limit_side_len = 736
  35. self.limit_type = 'min'
  36. else:
  37. self.resize_type = 2
  38. self.resize_long = kwargs.get('resize_long', 960)
  39. def __call__(self, data):
  40. img = deepcopy(data)
  41. src_h, src_w, _ = img.shape
  42. if self.resize_type == 0:
  43. img, [ratio_h, ratio_w] = self.resize_image_type0(img)
  44. elif self.resize_type == 2:
  45. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  46. else:
  47. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  48. return img
  49. def resize_image_type1(self, img):
  50. resize_h, resize_w = self.image_shape
  51. ori_h, ori_w = img.shape[:2] # (h, w, c)
  52. ratio_h = float(resize_h) / ori_h
  53. ratio_w = float(resize_w) / ori_w
  54. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  55. return img, [ratio_h, ratio_w]
  56. def resize_image_type0(self, img):
  57. """
  58. resize image to a size multiple of 32 which is required by the network
  59. args:
  60. img(array): array with shape [h, w, c]
  61. return(tuple):
  62. img, (ratio_h, ratio_w)
  63. """
  64. limit_side_len = self.limit_side_len
  65. h, w, _ = img.shape
  66. # limit the max side
  67. if self.limit_type == 'max':
  68. if max(h, w) > limit_side_len:
  69. if h > w:
  70. ratio = float(limit_side_len) / h
  71. else:
  72. ratio = float(limit_side_len) / w
  73. else:
  74. ratio = 1.
  75. else:
  76. if min(h, w) < limit_side_len:
  77. if h < w:
  78. ratio = float(limit_side_len) / h
  79. else:
  80. ratio = float(limit_side_len) / w
  81. else:
  82. ratio = 1.
  83. resize_h = int(h * ratio)
  84. resize_w = int(w * ratio)
  85. resize_h = int(round(resize_h / 32) * 32)
  86. resize_w = int(round(resize_w / 32) * 32)
  87. try:
  88. if int(resize_w) <= 0 or int(resize_h) <= 0:
  89. return None, (None, None)
  90. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  91. except:
  92. print(img.shape, resize_w, resize_h)
  93. sys.exit(0)
  94. ratio_h = resize_h / float(h)
  95. ratio_w = resize_w / float(w)
  96. # return img, np.array([h, w])
  97. return img, [ratio_h, ratio_w]
  98. def resize_image_type2(self, img):
  99. h, w, _ = img.shape
  100. resize_w = w
  101. resize_h = h
  102. # Fix the longer side
  103. if resize_h > resize_w:
  104. ratio = float(self.resize_long) / resize_h
  105. else:
  106. ratio = float(self.resize_long) / resize_w
  107. resize_h = int(resize_h * ratio)
  108. resize_w = int(resize_w * ratio)
  109. max_stride = 128
  110. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  111. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  112. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  113. ratio_h = resize_h / float(h)
  114. ratio_w = resize_w / float(w)
  115. return img, [ratio_h, ratio_w]
  116. class BaseRecLabelDecode(object):
  117. """ Convert between text-label and text-index """
  118. def __init__(self, config):
  119. support_character_type = [
  120. 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
  121. 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
  122. 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
  123. 'ne', 'EN'
  124. ]
  125. character_type = config['character_type']
  126. character_dict_path = config['character_dict_path']
  127. use_space_char = True
  128. assert character_type in support_character_type, "Only {} are supported now but get {}".format(
  129. support_character_type, character_type)
  130. self.beg_str = "sos"
  131. self.end_str = "eos"
  132. if character_type == "en":
  133. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  134. dict_character = list(self.character_str)
  135. elif character_type == "EN_symbol":
  136. # same with ASTER setting (use 94 char).
  137. self.character_str = string.printable[:-6]
  138. dict_character = list(self.character_str)
  139. elif character_type in support_character_type:
  140. self.character_str = ""
  141. assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
  142. character_type)
  143. with open(character_dict_path, "rb") as fin:
  144. lines = fin.readlines()
  145. for line in lines:
  146. line = line.decode('utf-8').strip("\n").strip("\r\n")
  147. self.character_str += line
  148. if use_space_char:
  149. self.character_str += " "
  150. dict_character = list(self.character_str)
  151. else:
  152. raise NotImplementedError
  153. self.character_type = character_type
  154. dict_character = self.add_special_char(dict_character)
  155. self.dict = {}
  156. for i, char in enumerate(dict_character):
  157. self.dict[char] = i
  158. self.character = dict_character
  159. def add_special_char(self, dict_character):
  160. return dict_character
  161. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  162. """ convert text-index into text-label. """
  163. result_list = []
  164. ignored_tokens = self.get_ignored_tokens()
  165. batch_size = len(text_index)
  166. for batch_idx in range(batch_size):
  167. char_list = []
  168. conf_list = []
  169. for idx in range(len(text_index[batch_idx])):
  170. if text_index[batch_idx][idx] in ignored_tokens:
  171. continue
  172. if is_remove_duplicate:
  173. # only for predict
  174. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  175. batch_idx][idx]:
  176. continue
  177. char_list.append(self.character[int(text_index[batch_idx][
  178. idx])])
  179. if text_prob is not None:
  180. conf_list.append(text_prob[batch_idx][idx])
  181. else:
  182. conf_list.append(1)
  183. text = ''.join(char_list)
  184. result_list.append((text, np.mean(conf_list)))
  185. return result_list
  186. def get_ignored_tokens(self):
  187. return [0] # for ctc blank
  188. class CTCLabelDecode(BaseRecLabelDecode):
  189. """ Convert between text-label and text-index """
  190. def __init__(
  191. self,
  192. config,
  193. #character_dict_path=None,
  194. #character_type='ch',
  195. #use_space_char=False,
  196. **kwargs):
  197. super(CTCLabelDecode, self).__init__(config)
  198. def __call__(self, preds, label=None, *args, **kwargs):
  199. preds_idx = preds.argmax(axis=2)
  200. preds_prob = preds.max(axis=2)
  201. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  202. if label is None:
  203. return text
  204. label = self.decode(label)
  205. return text, label
  206. def add_special_char(self, dict_character):
  207. dict_character = ['blank'] + dict_character
  208. return dict_character
  209. class CharacterOps(object):
  210. """ Convert between text-label and text-index """
  211. def __init__(self, config):
  212. self.character_type = config['character_type']
  213. self.loss_type = config['loss_type']
  214. if self.character_type == "en":
  215. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  216. dict_character = list(self.character_str)
  217. elif self.character_type == "ch":
  218. character_dict_path = config['character_dict_path']
  219. self.character_str = ""
  220. with open(character_dict_path, "rb") as fin:
  221. lines = fin.readlines()
  222. for line in lines:
  223. line = line.decode('utf-8').strip("\n").strip("\r\n")
  224. self.character_str += line
  225. dict_character = list(self.character_str)
  226. elif self.character_type == "en_sensitive":
  227. # same with ASTER setting (use 94 char).
  228. self.character_str = string.printable[:-6]
  229. dict_character = list(self.character_str)
  230. else:
  231. self.character_str = None
  232. assert self.character_str is not None, \
  233. "Nonsupport type of the character: {}".format(self.character_str)
  234. self.beg_str = "sos"
  235. self.end_str = "eos"
  236. if self.loss_type == "attention":
  237. dict_character = [self.beg_str, self.end_str] + dict_character
  238. self.dict = {}
  239. for i, char in enumerate(dict_character):
  240. self.dict[char] = i
  241. self.character = dict_character
  242. def encode(self, text):
  243. """convert text-label into text-index.
  244. input:
  245. text: text labels of each image. [batch_size]
  246. output:
  247. text: concatenated text index for CTCLoss.
  248. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  249. length: length of each text. [batch_size]
  250. """
  251. if self.character_type == "en":
  252. text = text.lower()
  253. text_list = []
  254. for char in text:
  255. if char not in self.dict:
  256. continue
  257. text_list.append(self.dict[char])
  258. text = np.array(text_list)
  259. return text
  260. def decode(self, text_index, is_remove_duplicate=False):
  261. """ convert text-index into text-label. """
  262. char_list = []
  263. char_num = self.get_char_num()
  264. if self.loss_type == "attention":
  265. beg_idx = self.get_beg_end_flag_idx("beg")
  266. end_idx = self.get_beg_end_flag_idx("end")
  267. ignored_tokens = [beg_idx, end_idx]
  268. else:
  269. ignored_tokens = [char_num]
  270. for idx in range(len(text_index)):
  271. if text_index[idx] in ignored_tokens:
  272. continue
  273. if is_remove_duplicate:
  274. if idx > 0 and text_index[idx - 1] == text_index[idx]:
  275. continue
  276. char_list.append(self.character[text_index[idx]])
  277. text = ''.join(char_list)
  278. return text
  279. def get_char_num(self):
  280. return len(self.character)
  281. def get_beg_end_flag_idx(self, beg_or_end):
  282. if self.loss_type == "attention":
  283. if beg_or_end == "beg":
  284. idx = np.array(self.dict[self.beg_str])
  285. elif beg_or_end == "end":
  286. idx = np.array(self.dict[self.end_str])
  287. else:
  288. assert False, "Unsupport type %s in get_beg_end_flag_idx"\
  289. % beg_or_end
  290. return idx
  291. else:
  292. err = "error in get_beg_end_flag_idx when using the loss %s"\
  293. % (self.loss_type)
  294. assert False, err
  295. class OCRReader(object):
  296. def __init__(self,
  297. algorithm="CRNN",
  298. image_shape=[3, 32, 320],
  299. char_type="ch",
  300. batch_num=1,
  301. char_dict_path="./ppocr_keys_v1.txt"):
  302. self.rec_image_shape = image_shape
  303. self.character_type = char_type
  304. self.rec_batch_num = batch_num
  305. char_ops_params = {}
  306. char_ops_params["character_type"] = char_type
  307. char_ops_params["character_dict_path"] = char_dict_path
  308. char_ops_params['loss_type'] = 'ctc'
  309. self.char_ops = CharacterOps(char_ops_params)
  310. self.label_ops = CTCLabelDecode(char_ops_params)
  311. def resize_norm_img(self, img, max_wh_ratio):
  312. imgC, imgH, imgW = self.rec_image_shape
  313. if self.character_type == "ch":
  314. imgW = int(32 * max_wh_ratio)
  315. h = img.shape[0]
  316. w = img.shape[1]
  317. ratio = w / float(h)
  318. if math.ceil(imgH * ratio) > imgW:
  319. resized_w = imgW
  320. else:
  321. resized_w = int(math.ceil(imgH * ratio))
  322. resized_image = cv2.resize(img, (resized_w, imgH))
  323. resized_image = resized_image.astype('float32')
  324. resized_image = resized_image.transpose((2, 0, 1)) / 255
  325. resized_image -= 0.5
  326. resized_image /= 0.5
  327. padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
  328. padding_im[:, :, 0:resized_w] = resized_image
  329. return padding_im
  330. def preprocess(self, img_list):
  331. img_num = len(img_list)
  332. norm_img_batch = []
  333. max_wh_ratio = 0
  334. for ino in range(img_num):
  335. h, w = img_list[ino].shape[0:2]
  336. wh_ratio = w * 1.0 / h
  337. max_wh_ratio = max(max_wh_ratio, wh_ratio)
  338. for ino in range(img_num):
  339. norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
  340. norm_img = norm_img[np.newaxis, :]
  341. norm_img_batch.append(norm_img)
  342. norm_img_batch = np.concatenate(norm_img_batch)
  343. norm_img_batch = norm_img_batch.copy()
  344. return norm_img_batch[0]
  345. def postprocess(self, outputs, with_score=False):
  346. preds = outputs["softmax_5.tmp_0"]
  347. try:
  348. preds = preds.numpy()
  349. except:
  350. pass
  351. preds_idx = preds.argmax(axis=2)
  352. preds_prob = preds.max(axis=2)
  353. text = self.label_ops.decode(
  354. preds_idx, preds_prob, is_remove_duplicate=True)
  355. return text