|
@@ -0,0 +1,249 @@
|
|
|
+from venv import logger
|
|
|
+from flask import request, Flask, request
|
|
|
+from shutil import copy
|
|
|
+import os
|
|
|
+from typing import Dict, List, Union
|
|
|
+from urllib.request import urlretrieve
|
|
|
+from pathlib import Path
|
|
|
+from paddleocr import PaddleOCR
|
|
|
+from hashlib import md5
|
|
|
+import time
|
|
|
+import base64
|
|
|
+import io
|
|
|
+import requests
|
|
|
+import json
|
|
|
+
|
|
|
# Flask application object that the route decorators below register on.
app = Flask(__name__)

# Directory that holds downloaded / decoded images; created at import time.
img_dir = Path('imgs')
img_dir.mkdir(exist_ok=True)

# In-process cache: local image path -> list of formatted OCR result dicts.
# NOTE(review): the cache is never evicted, so it grows with every new image.
ocr_cache_dict = dict()

# Single shared OCR engine.  The engine used to be constructed twice in a row
# (the first instance was immediately overwritten by this one), which loaded
# the models twice at import time for no benefit; only the effective, final
# construction is kept.
ocr = PaddleOCR(use_gpu=False, use_angle_cls=True, lang="ch", rec=True)
|
|
|
+
|
|
|
def debugg_log(content: str):
    """
    append debug text to ``log.txt`` in the current working directory

    Parameters
    ----------
    content : str
        text to append; it is written exactly as given (no newline or
        timestamp is added)
    """
    log_file = Path("log.txt")
    with log_file.open("a") as handle:
        handle.write(content)
|
|
|
+
|
|
|
+
|
|
|
def get_dict_from_request() -> dict:
    """
    get request parameters from every source as much as possible

    Query-string arguments are read first, then the JSON body, then the
    form fields; when the same key appears more than once the later source
    wins.

    Returns
    -------
    dict
        request data in dict format (empty when the request carries no
        parameters)

    Notes
    -----
    The JSON body is read with ``get_json(silent=True)`` so that a request
    whose content type is not JSON, or whose body cannot be parsed, is
    treated as having no JSON data instead of aborting with an HTTP error.
    A JSON body that is not an object (for example a list) is ignored.
    """
    # `to_dict()` flattens the multi-dict the same way the form is flattened
    # below (first value per key).
    data = request.args.to_dict()

    body = request.get_json(silent=True)
    if isinstance(body, dict):
        data.update(body)

    if request.form:
        data.update(request.form.to_dict())
    return data
|
|
|
+
|
|
|
+
|
|
|
def download_image(img_url: str) -> str:
    """
    download image from an http(s) url into the local image directory

    Parameters
    ----------
    img_url : str
        url of image to be downloaded
    Returns
    -------
    str
        local file path of image
    Notes
    -----
    if download failed, or the url is not an http/https url, empty string
    `''` will be returned

    The file name is the md5 of the url, so a url that was already
    downloaded is served from disk without a new request.
    """
    from http.client import HTTPException
    from urllib.parse import urlparse

    # SECURITY: `urlretrieve` also understands `file://` (and other schemes),
    # which would let a client make the server copy its own local files.
    # Copying local paths was deliberately disabled, so only real web urls
    # are accepted here.
    try:
        scheme = urlparse(str(img_url)).scheme.lower()
    except ValueError:
        return ''
    if scheme not in ('http', 'https'):
        return ''

    digest = md5(str(img_url).encode()).hexdigest()
    file_name = f'{img_dir}/{digest}.jpg'
    target = Path(file_name)
    if target.exists():
        return file_name

    # download from internet
    try:
        urlretrieve(img_url, file_name)
    except (OSError, ValueError, HTTPException):
        # Best effort by contract: report failure with ''.  A truncated file
        # must not stay behind, otherwise the cache check above would hand it
        # out on the next request for the same url.
        if target.exists():
            target.unlink()
        return ''
    return file_name
|
|
|
+
|
|
|
+
|
|
|
def base64_to_file(s: Union[str, bytes]) -> str:
    """
    decode base64 string or bytes and save to local file system

    Parameters
    ----------
    s : Union[str, bytes]
        base64 string or bytes; a leading data-URI header such as
        ``data:image/jpeg;base64,`` is accepted and skipped
    Returns
    -------
    str
        local file path of base64 data
    Notes
    -----
    empty string `''` is returned when `s` is neither str nor bytes, when it
    is not valid base64, or when it decodes to nothing

    The file name is the md5 of the input, so identical input is decoded and
    written only once.
    """
    import binascii

    if isinstance(s, str):
        raw = s.encode()
    elif isinstance(s, bytes):
        raw = s
    else:
        return ''

    digest = md5(str(s).encode()).hexdigest()
    file_name = f'{img_dir}/{digest}.jpg'
    if Path(file_name).exists():
        return file_name

    # The decoder silently skips characters outside the base64 alphabet, so
    # the letters of a data-URI header would otherwise be decoded as leading
    # garbage bytes.  Keep only the payload after the first comma.
    if raw[:5].lower() == b'data:':
        _, comma, payload = raw.partition(b',')
        if comma:
            raw = payload

    try:
        decoded = base64.decodebytes(raw)
    except binascii.Error:
        return ''
    if not decoded:
        return ''

    with open(file_name, 'wb') as f:
        f.write(decoded)
    return file_name
|
|
|
+
|
|
|
+
|
|
|
@app.route('/api/ocr_hello', methods=['POST'])
def get_content() -> str:
    """
    connectivity check endpoint

    Returns
    -------
    str
        the fixed string ``'hello'``
    """
    greeting = 'hello'
    return greeting
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
def _collect_ocr_text(result) -> str:
    """Join the recognised text of every line found in an OCR result."""

    def is_line(item) -> bool:
        # A recognised line looks like `[box, (text, confidence)]`.
        return (
            isinstance(item, (list, tuple)) and len(item) == 2
            and isinstance(item[1], (list, tuple)) and len(item[1]) == 2
            and isinstance(item[1][0], str)
        )

    pieces = []
    for item in result or []:
        if is_line(item):
            # flat layout: the result itself is the list of lines
            pieces.append(item[1][0])
        elif isinstance(item, (list, tuple)):
            # nested layout: one list of lines per image
            pieces.extend(line[1][0] for line in item if is_line(line))
        # `None` entries (nothing recognised for that image) are skipped
    return ''.join(pieces)


@app.route('/api/ocr_extract', methods=['POST'])
def ocr_extract() -> dict:
    """
    ocr web api that accepts an uploaded image (multipart field ``file``)
    and returns all recognised text concatenated into one string

    Returns
    -------
    dict
        ``success``, ``time_cost``, ``results`` (the concatenated text) and
        ``msg``

    Notes
    -----
    The upload is written to a private temporary file that is removed once
    OCR has run.  The client-supplied file name is never used as a path
    (only its extension is kept), so a crafted name cannot overwrite files
    on the server.
    """
    import tempfile

    img = request.files.get('file')
    app.logger.info(f'uploaded file is {img}')

    # A missing `file` part and an empty file name are both a failed upload.
    if img is None or not img.filename:
        return {
            'success': False,
            # NOTE(review): hard-coded placeholder kept for response
            # compatibility; it is not a measured duration.
            'time_cost': 155,
            'results': '解析失败',
            'msg': '解析失败'
        }

    # Keep the extension so the reader can still recognise the file type.
    suffix = Path(img.filename).suffix
    fd, img_path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    try:
        img.save(img_path)
        st = time.perf_counter()
        result = ocr.ocr(img_path, cls=True)
    finally:
        os.remove(img_path)

    app.logger.info(f'leng is {len(result) if result is not None else 0}')
    # Handles both result layouts (flat list of lines, or one list of lines
    # per image) as well as an empty / `None` result.
    content = _collect_ocr_text(result)

    return {
        'success': True,
        # NOTE(review): `format()` turns the duration into a string; kept so
        # existing clients see the same response shape.
        'time_cost': format(time.perf_counter() - st),
        'results': content,
        'msg': '解析完毕'
    }
|
|
|
+
|
|
|
+
|
|
|
def timer(func):
    """
    decorator: print how long each call of the wrapped function takes

    Parameters
    ----------
    func : callable
        function to be timed
    Returns
    -------
    callable
        wrapper that calls `func` with the same arguments, prints the
        elapsed time in seconds and returns the result of `func`

    Notes
    -----
    The wrapper keeps the name and docstring of `func`
    (``functools.wraps``).  Without that every decorated function would be
    called ``decorated``, which breaks Flask routing as soon as two views
    use this decorator, because view names must be unique.
    Nothing is printed when `func` raises.
    """
    from functools import wraps

    @wraps(func)
    def decorated(*args, **kwargs):
        st = time.perf_counter()
        ret = func(*args, **kwargs)
        print('time cost: {} seconds'.format(time.perf_counter() - st))
        return ret

    return decorated
|
|
|
+
|
|
|
+
|
|
|
def _to_jsonable(value):
    """Recursively turn numpy scalars/arrays and tuples into plain builtins."""
    if hasattr(value, 'tolist'):
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(v) for v in value]
    return value


def _flatten_ocr_lines(raw) -> list:
    """Return every `[box, (text, confidence)]` line of an OCR result."""

    def is_line(item) -> bool:
        return (
            isinstance(item, (list, tuple)) and len(item) == 2
            and isinstance(item[1], (list, tuple)) and len(item[1]) == 2
            and isinstance(item[1][0], str)
        )

    lines = []
    for item in raw or []:
        if is_line(item):
            # flat layout: the result itself is the list of lines
            lines.append(item)
        elif isinstance(item, (list, tuple)):
            # nested layout: one list of lines per image
            lines.extend(line for line in item if is_line(line))
        # `None` entries (nothing recognised for that image) are skipped
    return lines


@app.route('/api/ocr_dec', methods=['POST'])
def ocr_text_extract() -> dict:
    """
    ocr web api that accepts an image url (``img_url``) or base64 data of an
    image (``img_base64``); when both are given the url is used

    Returns
    -------
    dict
        ``success``, ``time_cost`` (seconds), ``results`` and ``msg``;
        ``results`` is a list with one dict per recognised line holding
        ``confidence``, ``text`` and ``text_region``

    Notes
    -----
    Results are cached per local image path in ``ocr_cache_dict``; an empty
    cached result is recomputed.
    """
    app.logger.info('receive request')
    st = time.time()
    params = get_dict_from_request()
    # Only the keys are logged: the values may contain a whole base64 image.
    app.logger.info(f'request keys are {list(params)}')
    img_url = params.get('img_url')
    app.logger.info(f'request url is {img_url}')
    base64_data = params.get('img_base64')

    img_path = ''
    if img_url:
        img_path = download_image(img_url)
    elif base64_data:
        img_path = base64_to_file(base64_data)
    if not img_path:
        return {
            'success': False,
            'time_cost': time.time() - st,
            'results': [],
            'msg': 'maybe img_url or img_base64 is wrong'
        }

    results = ocr_cache_dict.get(img_path)
    if not results:
        raw = ocr.ocr(img_path)
        if raw is None:
            # No result at all: answer with an empty list (not `None`) and
            # drop the image so a later request fetches/decodes it again.
            results = []
            try:
                os.remove(img_path)
            except OSError as err:
                app.logger.warning(f'could not remove {img_path}: {err}')
        else:
            # Convert numpy values explicitly so the response is JSON
            # serializable; this replaces an `eval(str(...))` round trip.
            results = [
                {
                    'confidence': _to_jsonable(line[-1][1]),
                    'text': line[-1][0],
                    'text_region': _to_jsonable(line[:-1])
                }
                for line in _flatten_ocr_lines(raw)
            ]
            ocr_cache_dict[img_path] = results

    return {
        'success': True,
        'time_cost': time.time() - st,
        'results': results,
        'msg': ''
    }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
def get_file_from_url(url_file: str, target_path: str):
    """
    download a file over http(s) into a directory

    Parameters
    ----------
    url_file : str
        url of the file; the last segment of its path (without the query
        string) becomes the local file name
    target_path : str
        directory to save into; created, including missing parents, when it
        does not exist
    Returns
    -------
    str
        path of the saved file
    Raises
    ------
    requests.RequestException
        when the request fails, times out (10 seconds) or the server answers
        with an error status, so an error page is never saved as the file
    """
    from urllib.parse import urlparse

    send_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/61.0.3163.100 Safari/537.36",
        "Connection": "keep-alive",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
        "Accept-Language": "zh-CN,zh;q=0.8"
    }

    req = requests.get(url_file, headers=send_headers, timeout=10)
    req.raise_for_status()

    file_name = os.path.basename(urlparse(url_file).path)
    if target_path:
        os.makedirs(target_path, exist_ok=True)
    # `os.path.join` inserts the separator when `target_path` has no trailing
    # slash; plain concatenation produced e.g. "imgsa.jpg" for "imgs".
    target_path = os.path.join(target_path, file_name)
    with open(target_path, 'wb') as file:
        file.write(req.content)

    # NOTE(review): purpose of this pause is not visible here (presumably a
    # crude rate limit between downloads) - kept as is.
    time.sleep(0.1)
    return target_path
|
|
|
+
|
|
|
def _parse_port(raw: str, default: int = 5000) -> int:
    """Return `raw` as a TCP port (1-65535), or `default` when it is not one."""
    text = (raw or '').strip()
    # `isdecimal()` instead of `isalnum()`: the latter also accepts letters,
    # so a value such as "abc" used to reach `int()` and crash at start-up.
    if text.isdecimal() and 0 < int(text) <= 65535:
        return int(text)
    return default


if __name__ == '__main__':
    # Port comes from FLASK_PORT; anything that is not a valid port falls
    # back to 5000.
    port = _parse_port(os.environ.get('FLASK_PORT', ''))
    # SECURITY: the server listens on every interface, and debug mode exposes
    # the interactive debugger, which allows running arbitrary code on the
    # host.  Debug mode is therefore opt-in: set FLASK_DEBUG=1 to enable it.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(host='0.0.0.0', port=port, debug=debug)
|