# server.py - Flask web API that runs PaddleOCR on uploaded, downloaded or base64-encoded images.

import base64
import io
import json
import os
import time
from functools import wraps
from hashlib import md5
from http.client import HTTPException
from pathlib import Path
from shutil import copy, copyfileobj
from typing import Dict, List, Union
from urllib.parse import urlsplit
from urllib.request import urlopen, urlretrieve
from uuid import uuid4
from venv import logger

import requests
from flask import Flask, request
from paddleocr import PaddleOCR
  16. # ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False)
  17. app = Flask(__name__)
  18. img_dir = Path('imgs')
  19. img_dir.mkdir(exist_ok=True)
  20. ocr_cache_dict = dict()
  21. ocr = PaddleOCR(use_angle_cls=True, lang="ch", rec=True, use_gpu=True)
  22. def debugg_log(content: str):
  23. with open("log.txt","a") as file:
  24. file.write(content)
  25. def get_dict_from_request() -> dict:
  26. """
  27. get json data from request as much as possible
  28. Returns
  29. -------
  30. dict
  31. request data in dict format
  32. """
  33. json = {**request.args}
  34. if request.json:
  35. json = {**json, **request.json}
  36. if request.form:
  37. json = {**json, **request.form.to_dict()}
  38. return json
  39. def download_image(img_url: str) -> str:
  40. """
  41. download image or copy image to local from url
  42. Parameters
  43. ----------
  44. img_url : str
  45. url of image to be downloaded
  46. Returns
  47. -------
  48. str
  49. local file path of image
  50. Notes
  51. -----
  52. if download failed, empty string `''` will be returned
  53. """
  54. d = md5(str(img_url).encode()).hexdigest()
  55. file_name = f'{img_dir}/{d}.jpg'
  56. # NOTE: insecurity
  57. # # copy from local file system in the running container
  58. # if Path(img_url).exists():
  59. # copy(img_url, file_name)
  60. if Path(file_name).exists():
  61. return file_name
  62. # download from internet
  63. try:
  64. urlretrieve(img_url, file_name)
  65. return file_name
  66. except:
  67. return ''
  68. def base64_to_file(s: Union[str, bytes]) -> str:
  69. """
  70. decode base64 string or bytes and save to local file system
  71. Parameters
  72. ----------
  73. s : Union[str, bytes]
  74. base64 string or bytes
  75. Returns
  76. -------
  77. str
  78. local file path of base64 data
  79. """
  80. d = md5(str(s).encode()).hexdigest()
  81. file_name = f'{img_dir}/{d}.jpg'
  82. if Path(file_name).exists():
  83. return file_name
  84. if isinstance(s, str):
  85. b = base64.decodebytes(s.encode())
  86. elif isinstance(s, bytes):
  87. b = base64.decodebytes(s)
  88. else:
  89. return ''
  90. with open(file_name, 'wb') as f:
  91. f.write(b)
  92. return file_name
  93. @app.route('/api/ocr_hello', methods=['POST'])
  94. def get_content() -> str:
  95. return 'hello'
  96. @app.route('/api/ocr_extract', methods=['POST'])
  97. def ocr_extract() -> None:
  98. img = request.files.get('file')
  99. print(img)
  100. # 将img保存到当前目录下
  101. img_path = img.filename
  102. content = ''
  103. # json_data = request.get_json()
  104. # path = json_data['path']
  105. if not img_path:
  106. return {
  107. 'success': False,
  108. 'time_cost': 155,
  109. 'results': '解析失败',
  110. 'msg': '解析失败'
  111. }
  112. img.save(img_path)
  113. st = time.perf_counter()
  114. result = ocr.ocr(img_path, cls=True)
  115. app.logger.info(f'leng is {len(result)}')
  116. if (len(result) == 1):
  117. for idx in range(len(result[0])):
  118. res = result[0][idx]
  119. content += res[1][0]
  120. else:
  121. for idx in range(len(result)):
  122. res = result[idx]
  123. content += res[1][0]
  124. return {
  125. 'success': True,
  126. 'time_cost': format(time.perf_counter() - st),
  127. 'results': content,
  128. 'msg': '解析完毕'
  129. }
  130. def timer(func):
  131. """装饰器:打印函数耗时"""
  132. def decorated(*args, **kwargs):
  133. st = time.perf_counter()
  134. ret = func(*args, **kwargs)
  135. print('time cost: {} seconds'.format(time.perf_counter() - st))
  136. return ret
  137. return decorated
  138. @app.route('/api/ocr_dec', methods=['POST'])
  139. def ocr_text_extract() -> None:
  140. """
  141. ocr web api that accept image url, image path and base64 data of image
  142. """
  143. app.logger.info(f'receive request')
  144. print('receive request 2')
  145. st = time.time()
  146. json = get_dict_from_request()
  147. app.logger.info(f'request dict is {json}')
  148. img_url: str = json.get('img_url')
  149. app.logger.info(f'request url is {img_url}')
  150. base64_data: str = json.get('img_base64')
  151. img_path = ''
  152. if img_url:
  153. img_path = download_image(img_url)
  154. elif base64_data:
  155. img_path = base64_to_file(base64_data)
  156. if not img_path:
  157. et = time.time()
  158. return {
  159. 'success': False,
  160. 'time_cost': et-st,
  161. 'results': [],
  162. 'msg': 'maybe img_url or img_base64 is wrong'
  163. }
  164. results = ocr_cache_dict.get(img_path)
  165. if not results:
  166. ocr_result_list = ocr.ocr(img_path)
  167. et = time.time()
  168. if ocr_result_list is None:
  169. ocr_result_list = []
  170. os.remove(img_path)
  171. else:
  172. # make sure float32 can be JSON serializable
  173. ocr_result_list: list = eval(str(ocr_result_list))
  174. results: List[Dict] = []
  175. for each in ocr_result_list:
  176. item = {
  177. 'confidence': each[-1][1],
  178. 'text': each[-1][0],
  179. 'text_region': each[:-1]
  180. }
  181. results.append(item)
  182. ocr_cache_dict[img_path] = results
  183. et = time.time()
  184. return {
  185. 'success': True,
  186. 'time_cost': et-st,
  187. 'results': results,
  188. 'msg': ''
  189. }
  190. # def get_file_from_url(url_file: str, target_path: str):
  191. # send_headers = {
  192. # "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/61.0.3163.100 Safari/537.36",
  193. # "Connection": "keep-alive",
  194. # "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
  195. # "Accept-Language": "zh-CN,zh;q=0.8"
  196. # }
  197. # req = requests.get(url_file, headers=send_headers, timeout=10)
  198. # bytes_io = io.BytesIO(req.content)
  199. # (_, file_name) = os.path.split(url_file)
  200. # if not os.path.exists(target_path):
  201. # os.mkdir(target_path)
  202. # target_path = target_path + file_name
  203. # with open(target_path, 'wb') as file:
  204. # file.write(bytes_io.getvalue())
  205. # time.sleep(0.1)
  206. # return target_path
  207. if __name__ == '__main__':
  208. # port = os.environ.get('FLASK_PORT', '')
  209. # if port.isalnum() and int(port) > 0:
  210. # port = int(port)
  211. # else:
  212. # port = 5000
  213. # app.run(host='0.0.0.0', port=port, debug=True)
  214. test_dir = Path('img_test')
  215. count = 1
  216. start = time.perf_counter()
  217. for file in test_dir.iterdir():
  218. st = time.perf_counter()
  219. result = ocr.ocr(str(file), cls=True)
  220. print(f'第{count}张图片耗时:{format(time.perf_counter() - st)}')
  221. count += 1
  222. print(f'20张图片总耗时:{format(time.perf_counter() - start)}')