server.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. from venv import logger
  2. from flask import request, Flask, request
  3. from shutil import copy
  4. import os
  5. from typing import Dict, List, Union
  6. from urllib.request import urlretrieve
  7. from pathlib import Path
  8. from paddleocr import PaddleOCR
  9. from hashlib import md5
  10. import time
  11. import base64
  12. import io
  13. import requests
  14. import json
  15. ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False)
  16. app = Flask(__name__)
  17. img_dir = Path('imgs')
  18. img_dir.mkdir(exist_ok=True)
  19. ocr_cache_dict = dict()
  20. ocr = PaddleOCR(use_gpu=False, use_angle_cls=True, lang="ch", rec=True)
  21. def debugg_log(content: str):
  22. with open("log.txt","a") as file:
  23. file.write(content)
  24. def get_dict_from_request() -> dict:
  25. """
  26. get json data from request as much as possible
  27. Returns
  28. -------
  29. dict
  30. request data in dict format
  31. """
  32. json = {**request.args}
  33. if request.json:
  34. json = {**json, **request.json}
  35. if request.form:
  36. json = {**json, **request.form.to_dict()}
  37. return json
  38. def download_image(img_url: str) -> str:
  39. """
  40. download image or copy image to local from url
  41. Parameters
  42. ----------
  43. img_url : str
  44. url of image to be downloaded
  45. Returns
  46. -------
  47. str
  48. local file path of image
  49. Notes
  50. -----
  51. if download failed, empty string `''` will be returned
  52. """
  53. d = md5(str(img_url).encode()).hexdigest()
  54. file_name = f'{img_dir}/{d}.jpg'
  55. # NOTE: insecurity
  56. # # copy from local file system in the running container
  57. # if Path(img_url).exists():
  58. # copy(img_url, file_name)
  59. if Path(file_name).exists():
  60. return file_name
  61. # download from internet
  62. try:
  63. urlretrieve(img_url, file_name)
  64. return file_name
  65. except:
  66. return ''
  67. def base64_to_file(s: Union[str, bytes]) -> str:
  68. """
  69. decode base64 string or bytes and save to local file system
  70. Parameters
  71. ----------
  72. s : Union[str, bytes]
  73. base64 string or bytes
  74. Returns
  75. -------
  76. str
  77. local file path of base64 data
  78. """
  79. d = md5(str(s).encode()).hexdigest()
  80. file_name = f'{img_dir}/{d}.jpg'
  81. if Path(file_name).exists():
  82. return file_name
  83. if isinstance(s, str):
  84. b = base64.decodebytes(s.encode())
  85. elif isinstance(s, bytes):
  86. b = base64.decodebytes(s)
  87. else:
  88. return ''
  89. with open(file_name, 'wb') as f:
  90. f.write(b)
  91. return file_name
  92. @app.route('/api/ocr_hello', methods=['POST'])
  93. def get_content() -> str:
  94. return 'hello'
  95. @app.route('/api/ocr_extract', methods=['POST'])
  96. def ocr_extract() -> None:
  97. img = request.files.get('file')
  98. print(img)
  99. # 将img保存到当前目录下
  100. img_path = img.filename
  101. content = ''
  102. # json_data = request.get_json()
  103. # path = json_data['path']
  104. if not img_path:
  105. return {
  106. 'success': False,
  107. 'time_cost': 155,
  108. 'results': '解析失败',
  109. 'msg': '解析失败'
  110. }
  111. img.save(img_path)
  112. st = time.perf_counter()
  113. result = ocr.ocr(img_path, cls=True)
  114. app.logger.info(f'leng is {len(result)}')
  115. if (len(result) == 1):
  116. for idx in range(len(result[0])):
  117. res = result[0][idx]
  118. content += res[1][0]
  119. else:
  120. for idx in range(len(result)):
  121. res = result[idx]
  122. content += res[1][0]
  123. return {
  124. 'success': True,
  125. 'time_cost': format(time.perf_counter() - st),
  126. 'results': content,
  127. 'msg': '解析完毕'
  128. }
  129. def timer(func):
  130. """装饰器:打印函数耗时"""
  131. def decorated(*args, **kwargs):
  132. st = time.perf_counter()
  133. ret = func(*args, **kwargs)
  134. print('time cost: {} seconds'.format(time.perf_counter() - st))
  135. return ret
  136. return decorated
  137. @app.route('/api/ocr_dec', methods=['POST'])
  138. def ocr_text_extract() -> None:
  139. """
  140. ocr web api that accept image url, image path and base64 data of image
  141. """
  142. app.logger.info(f'receive request')
  143. print('receive request 2')
  144. st = time.time()
  145. json = get_dict_from_request()
  146. app.logger.info(f'request dict is {json}')
  147. img_url: str = json.get('img_url')
  148. app.logger.info(f'request url is {img_url}')
  149. base64_data: str = json.get('img_base64')
  150. img_path = ''
  151. if img_url:
  152. img_path = download_image(img_url)
  153. elif base64_data:
  154. img_path = base64_to_file(base64_data)
  155. if not img_path:
  156. et = time.time()
  157. return {
  158. 'success': False,
  159. 'time_cost': et-st,
  160. 'results': [],
  161. 'msg': 'maybe img_url or img_base64 is wrong'
  162. }
  163. results = ocr_cache_dict.get(img_path)
  164. if not results:
  165. ocr_result_list = ocr.ocr(img_path)
  166. et = time.time()
  167. if ocr_result_list is None:
  168. ocr_result_list = []
  169. os.remove(img_path)
  170. else:
  171. # make sure float32 can be JSON serializable
  172. ocr_result_list: list = eval(str(ocr_result_list))
  173. results: List[Dict] = []
  174. for each in ocr_result_list:
  175. item = {
  176. 'confidence': each[-1][1],
  177. 'text': each[-1][0],
  178. 'text_region': each[:-1]
  179. }
  180. results.append(item)
  181. ocr_cache_dict[img_path] = results
  182. et = time.time()
  183. return {
  184. 'success': True,
  185. 'time_cost': et-st,
  186. 'results': results,
  187. 'msg': ''
  188. }
  189. def get_file_from_url(url_file: str, target_path: str):
  190. send_headers = {
  191. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/61.0.3163.100 Safari/537.36",
  192. "Connection": "keep-alive",
  193. "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
  194. "Accept-Language": "zh-CN,zh;q=0.8"
  195. }
  196. req = requests.get(url_file, headers=send_headers, timeout=10)
  197. bytes_io = io.BytesIO(req.content)
  198. (_, file_name) = os.path.split(url_file)
  199. if not os.path.exists(target_path):
  200. os.mkdir(target_path)
  201. target_path = target_path + file_name
  202. with open(target_path, 'wb') as file:
  203. file.write(bytes_io.getvalue())
  204. time.sleep(0.1)
  205. return target_path
  206. if __name__ == '__main__':
  207. port = os.environ.get('FLASK_PORT', '')
  208. if port.isalnum() and int(port) > 0:
  209. port = int(port)
  210. else:
  211. port = 5000
  212. app.run(host='0.0.0.0', port=port, debug=True)