123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- from venv import logger
- from flask import request, Flask, request
- from shutil import copy
- import os
- from typing import Dict, List, Union
- from urllib.request import urlretrieve
- from pathlib import Path
- from paddleocr import PaddleOCR
- from hashlib import md5
- import time
- import base64
- import io
- import requests
- import json
# Build the OCR engine exactly once. The file previously constructed it
# twice, with the second instance silently replacing the first; these are
# the arguments of the instance that was actually used.
ocr = PaddleOCR(use_gpu=False, use_angle_cls=True, lang="ch", rec=True)
app = Flask(__name__)
# Downloaded / decoded images are stored here under MD5-derived names;
# download_image and base64_to_file return an existing file without
# refetching, so this directory also acts as an on-disk cache.
img_dir = Path('imgs')
img_dir.mkdir(exist_ok=True)
# In-memory cache: local image path -> list of OCR result dicts.
# NOTE(review): unbounded; grows with every distinct image processed.
ocr_cache_dict = dict()
def debugg_log(content: str):
    """Append ``content`` to ``log.txt`` in the current working directory.

    The text is written verbatim: no newline, timestamp or separator is
    added, so callers must include their own line breaks.
    """
    with Path("log.txt").open("a") as log_file:
        log_file.write(content)
def get_dict_from_request() -> dict:
    """
    get request parameters as one dict, from as many sources as possible

    Sources are merged in this order, later ones winning on key clashes:
    query string, JSON object body, form body.

    Returns
    -------
    dict
        request data in dict format
    """
    data = {**request.args}
    # In recent Flask/Werkzeug versions `request.json` raises (400/415) when
    # the body is not JSON. That would make form-encoded requests fail
    # before the form branch below runs. `silent=True` yields None instead.
    body = request.get_json(silent=True)
    # Only a JSON object can be merged; arrays/scalars would make `**` fail.
    if isinstance(body, dict):
        data.update(body)
    if request.form:
        data.update(request.form.to_dict())
    return data
def download_image(img_url: str) -> str:
    """
    download image to local from an http(s) url

    The image is stored in ``img_dir`` under the MD5 digest of the URL. A URL
    whose file already exists is returned from disk without a new request.

    Parameters
    ----------
    img_url : str
        http or https url of image to be downloaded

    Returns
    -------
    str
        local file path of image

    Notes
    -----
    if the url scheme is not http/https, or the download failed, empty string
    `''` will be returned
    """
    import contextlib
    import uuid
    from urllib.parse import urlparse

    # SECURITY: `urlretrieve` also handles `file://` (plus `data:`/`ftp:`)
    # URLs. That would let a caller copy arbitrary local files into img_dir,
    # which is the local-copy behaviour deliberately disabled here.
    # NOTE(review): http(s) URLs can still target internal hosts (SSRF);
    # consider a host allow-list if this service is exposed.
    if urlparse(str(img_url)).scheme.lower() not in ('http', 'https'):
        return ''
    d = md5(str(img_url).encode()).hexdigest()
    file_name = f'{img_dir}/{d}.jpg'
    if Path(file_name).exists():
        return file_name
    # Download to a uniquely named side file and rename only on success. An
    # interrupted download then never leaves a truncated file that the
    # existence check above would later treat as a valid image.
    part_name = f'{file_name}.{uuid.uuid4().hex}.part'
    try:
        urlretrieve(img_url, part_name)
        os.replace(part_name, file_name)
    except Exception:  # best-effort: any download failure maps to ''
        with contextlib.suppress(OSError):
            os.remove(part_name)
        return ''
    return file_name
def base64_to_file(s: Union[str, bytes]) -> str:
    """
    decode base64 string or bytes and save to local file system

    The file is stored in ``img_dir`` under the MD5 digest of ``str(s)``. If
    that file already exists, its path is returned and nothing is rewritten.

    Parameters
    ----------
    s : Union[str, bytes]
        base64 string or bytes; a leading ``data:<mime>;base64,`` URL header
        is accepted and stripped

    Returns
    -------
    str
        local file path of base64 data, or empty string `''` when `s` is not
        str/bytes, is not valid base64, or decodes to no data
    """
    if isinstance(s, str):
        raw = s.encode()
    elif isinstance(s, bytes):
        raw = s
    else:
        return ''
    # Browsers commonly send `data:image/png;base64,<payload>`. The lenient
    # decoder discards the header's non-alphabet characters and decodes the
    # rest together with the payload, corrupting the image, so strip it.
    if raw[:5].lower() == b'data:' and b',' in raw:
        raw = raw.split(b',', 1)[1]
    try:
        b = base64.decodebytes(raw)
    except ValueError:  # binascii.Error (e.g. bad padding) subclasses it
        return ''
    if not b:
        return ''
    d = md5(str(s).encode()).hexdigest()
    file_name = f'{img_dir}/{d}.jpg'
    if Path(file_name).exists():
        return file_name
    with open(file_name, 'wb') as f:
        f.write(b)
    return file_name
@app.route('/api/ocr_hello', methods=['POST'])
def get_content() -> str:
    """Reply to any POST on ``/api/ocr_hello`` with the fixed text ``hello``.

    Presumably used as a simple connectivity / liveness check.
    """
    reply = 'hello'
    return reply
def _iter_text_lines(result):
    """Yield each ``[box, (text, score)]`` line found in a PaddleOCR result.

    ``result`` may be a flat list of lines or a list of pages, each page a
    list of lines. Falsy entries (``None`` / empty) are skipped, and a
    ``None`` result yields nothing.
    """
    for entry in result or []:
        if not entry:
            continue
        if _is_text_line(entry):
            yield entry
        else:
            yield from (line for line in entry if line)


def _is_text_line(obj) -> bool:
    """Return True when ``obj`` has the ``[box, (text, score)]`` line shape."""
    return (
        len(obj) == 2
        and isinstance(obj[1], (list, tuple))
        and len(obj[1]) == 2
        and isinstance(obj[1][0], str)
    )


@app.route('/api/ocr_extract', methods=['POST'])
def ocr_extract() -> dict:
    """
    OCR an uploaded image and return all recognized text concatenated.

    Expects a multipart upload in the form field ``file``. The upload is
    saved under a random name in ``img_dir`` and deleted after OCR. The
    client-supplied file name is used only for its extension.

    Returns
    -------
    dict
        ``success``, ``time_cost`` (OCR seconds, formatted as a string),
        ``results`` (concatenated text) and ``msg``.
    """
    import contextlib
    import uuid

    img = request.files.get('file')
    app.logger.info(f'upload is {img}')
    # Missing upload or empty file name: report failure instead of crashing
    # with AttributeError on `None.filename`.
    if img is None or not img.filename:
        return {
            'success': False,
            # NOTE(review): fixed placeholder, not a measured duration; kept
            # unchanged for response compatibility.
            'time_cost': 155,
            'results': '解析失败',
            'msg': '解析失败'
        }
    # SECURITY: never save under the client-supplied name. A value such as
    # "../../app.py" would overwrite arbitrary files. Keep only a plain
    # ASCII alphanumeric extension, since the OCR loader may inspect it.
    suffix = Path(img.filename).suffix.lower()
    ext = suffix[1:]
    if not (ext.isascii() and ext.isalnum()):
        suffix = '.jpg'
    img_path = str(img_dir / f'upload_{uuid.uuid4().hex}{suffix}')
    try:
        img.save(img_path)
        st = time.perf_counter()
        result = ocr.ocr(img_path, cls=True)
        app.logger.info(f'leng is {len(result) if result else 0}')
        content = ''.join(line[1][0] for line in _iter_text_lines(result))
    finally:
        # The upload is only needed for this request; remove it so uploads
        # do not pile up on disk.
        with contextlib.suppress(FileNotFoundError):
            os.remove(img_path)
    return {
        'success': True,
        'time_cost': format(time.perf_counter() - st),
        'results': content,
        'msg': '解析完毕'
    }
def timer(func):
    """Decorator: print the wall-clock time each call of ``func`` takes.

    ``functools.wraps`` copies ``__name__``, ``__doc__`` and related metadata
    onto the wrapper. Without it, every decorated function would be named
    ``decorated``, which breaks introspection and name-based registration
    (e.g. Flask endpoints derived from the function name).
    Nothing is printed if ``func`` raises.
    """
    import functools

    @functools.wraps(func)
    def decorated(*args, **kwargs):
        st = time.perf_counter()
        ret = func(*args, **kwargs)
        print('time cost: {} seconds'.format(time.perf_counter() - st))
        return ret
    return decorated
def _to_builtin(obj):
    """Recursively convert an OCR result to plain, JSON-serializable types.

    Lists/tuples are rebuilt with converted items. Numeric scalars
    (NumPy scalars register with the ``numbers`` ABCs) become ``int`` or
    ``float``. Floats go through ``str`` so a float32 score keeps its short
    decimal form (``0.9``), as the previous ``eval(str(...))`` produced.
    Objects with ``tolist()`` (e.g. arrays) are converted through it.
    Anything else is returned unchanged.
    """
    import numbers

    if isinstance(obj, tuple):
        return tuple(_to_builtin(x) for x in obj)
    if isinstance(obj, list):
        return [_to_builtin(x) for x in obj]
    if obj is None or isinstance(obj, (str, bool)):
        return obj
    if isinstance(obj, numbers.Integral):
        return int(obj)
    if isinstance(obj, numbers.Real):
        try:
            return float(str(obj))
        except ValueError:
            return float(obj)
    if hasattr(obj, 'tolist'):
        return _to_builtin(obj.tolist())
    return obj


def _is_result_line(obj) -> bool:
    """Return True when ``obj`` has the ``[box, (text, score)]`` line shape."""
    return (
        len(obj) == 2
        and isinstance(obj[1], (list, tuple))
        and len(obj[1]) == 2
        and isinstance(obj[1][0], str)
    )


def _flatten_result_lines(raw) -> list:
    """Return the ``[box, (text, score)]`` lines of a PaddleOCR result.

    ``raw`` may be ``None``, a flat list of lines, or a list of pages where
    each page is a list of lines. Falsy entries (e.g. ``None``) are skipped.
    """
    lines = []
    for entry in raw or []:
        if not entry:
            continue
        if _is_result_line(entry):
            lines.append(entry)
        else:
            lines.extend(line for line in entry if line)
    return lines


@app.route('/api/ocr_dec', methods=['POST'])
def ocr_text_extract() -> dict:
    """
    ocr web api that accepts an image url or base64 data of an image

    Parameters are read from the query string, JSON body or form
    (``img_url`` takes precedence over ``img_base64``). Results are cached
    in memory per local image path.

    Returns
    -------
    dict
        ``success``, ``time_cost`` (seconds, float), ``results`` (list of
        dicts with ``confidence``, ``text``, ``text_region``) and ``msg``.
    """
    import contextlib

    app.logger.info('receive request')
    st = time.time()
    params = get_dict_from_request()
    app.logger.info(f'request dict is {params}')
    img_url: str = params.get('img_url')
    app.logger.info(f'request url is {img_url}')
    base64_data: str = params.get('img_base64')
    img_path = ''
    if img_url:
        img_path = download_image(img_url)
    elif base64_data:
        img_path = base64_to_file(base64_data)
    if not img_path:
        return {
            'success': False,
            'time_cost': time.time() - st,
            'results': [],
            'msg': 'maybe img_url or img_base64 is wrong'
        }
    results = ocr_cache_dict.get(img_path)
    if not results:
        raw = ocr.ocr(img_path)
        if raw is None:
            # Delete the image when OCR returns nothing; presumably so a later
            # request re-fetches/decodes it instead of reusing the disk copy.
            with contextlib.suppress(FileNotFoundError):
                os.remove(img_path)
        # SECURITY: the former `eval(str(result))` round-trip evaluated text
        # derived from untrusted images and breaks on NumPy 2 reprs
        # (`np.float32(...)`); convert types explicitly instead.
        results = [
            {
                'confidence': line[-1][1],
                'text': line[-1][0],
                'text_region': line[:-1],
            }
            for line in map(_to_builtin, _flatten_result_lines(raw))
        ]
        ocr_cache_dict[img_path] = results
    return {
        'success': True,
        'time_cost': time.time() - st,
        'results': results,
        'msg': ''
    }
def get_file_from_url(url_file: str, target_path: str, timeout: float = 10) -> str:
    """
    Download ``url_file`` into the directory ``target_path``.

    Parameters
    ----------
    url_file : str
        URL to fetch.
    target_path : str
        Destination directory, created (with parents) if missing. A trailing
        separator is optional; an empty string means the working directory.
    timeout : float, optional
        Seconds passed to ``requests.get`` (default 10).

    Returns
    -------
    str
        Path of the written file: ``target_path`` joined with the last path
        segment of the URL (query string and fragment excluded). The MD5
        digest of the URL is used when the URL path has no file name.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status (4xx/5xx).
    requests.RequestException
        On connection errors or timeouts.
    """
    from urllib.parse import urlparse

    send_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/61.0.3163.100 Safari/537.36",
        "Connection": "keep-alive",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
        "Accept-Language": "zh-CN,zh;q=0.8"
    }
    resp = requests.get(url_file, headers=send_headers, timeout=timeout)
    # Do not save an HTML error page under an image file name.
    resp.raise_for_status()
    # Use the URL *path* only, so a query string (`?x=1`) does not end up in
    # the file name. Fall back to a digest when the path ends in '/'.
    file_name = urlparse(url_file).path.rsplit('/', 1)[-1]
    if not file_name:
        file_name = md5(str(url_file).encode()).hexdigest()
    if target_path:
        os.makedirs(target_path, exist_ok=True)
    # os.path.join inserts the separator that plain `+` concatenation
    # dropped (which wrote "imgsfoo.jpg" next to, not inside, "imgs").
    target_file = os.path.join(target_path, file_name)
    with open(target_file, 'wb') as file:
        file.write(resp.content)
    # NOTE(review): unexplained pause kept as-is; presumably throttles
    # successive downloads.
    time.sleep(0.1)
    return target_file
if __name__ == '__main__':
    # FLASK_PORT must be a decimal integer in 1..65535; anything else (unset,
    # empty, "abc", "0") falls back to 5000. `str.isalnum` would also accept
    # e.g. "abc", which `int()` rejects, crashing at startup.
    raw_port = os.environ.get('FLASK_PORT', '')
    if raw_port.isdecimal() and 0 < int(raw_port) < 65536:
        port = int(raw_port)
    else:
        port = 5000
    # SECURITY: the Werkzeug interactive debugger lets anyone who can reach
    # the server execute arbitrary code, and this server binds 0.0.0.0.
    # Debug mode is therefore opt-in via FLASK_DEBUG=1/true/yes.
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=port, debug=debug)
|