├── .dockerignore ├── .gitignore ├── .pip.conf ├── .travis.yml ├── Dockerfile ├── LICENSE ├── README.md ├── category.py ├── compat ├── flask_server.py ├── flask_server.spec ├── grpc.proto ├── grpc_pb2.py ├── grpc_pb2_grpc.py ├── grpc_server.py ├── grpc_server.spec └── sanic_server.py ├── config.py ├── constants.py ├── demo.py ├── deploy.conf.py ├── event_handler.py ├── event_loop.py ├── graph_session.py ├── interface.py ├── middleware ├── __init__.py ├── constructor ├── impl │ ├── color_extractor.py │ ├── color_filter.py │ ├── corp_to_multi.py │ ├── gif_frames.py │ └── rgb_filter.py └── resource │ ├── __init__.py │ └── color_filter.py ├── package.py ├── predict.py ├── pretreatment.py ├── requirements.txt ├── resource ├── VERSION ├── favorite.ico └── icon.ico ├── sdk ├── __init__.py ├── onnx │ ├── __init__.py │ ├── requirements.txt │ └── sdk.py ├── pb │ ├── __init__.py │ ├── requirements.txt │ └── sdk.py └── tflite │ ├── __init__.py │ ├── requirements.txt │ └── sdk.py ├── signature.py ├── test.py ├── tornado_server.py ├── tornado_server.spec ├── tornado_server_gpu.spec └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # .dockerignore 文件中指定在传递给 docker引擎 时需要忽略掉的文件或文件夹。 2 | 3 | .git/ 4 | .idea/ 5 | patchca_ubuntu/ 6 | __pycache__/ 7 | .DS_Store 8 | venv/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | # C extensions 8 | *.so 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | # PyInstaller 28 | # Usually these 
files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | # Installer logs 32 | pip-log.txt 33 | pip-delete-this-directory.txt 34 | # Unit test / coverage reports 35 | htmlcov/ 36 | .tox/ 37 | .coverage 38 | .coverage.* 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | *,cover 43 | .hypothesis/ 44 | # Translations 45 | *.mo 46 | *.pot 47 | # Django stuff: 48 | *.log 49 | local_settings.py 50 | # Flask stuff: 51 | instance/ 52 | .webassets-cache 53 | # Scrapy stuff: 54 | .scrapy 55 | # Sphinx documentation 56 | docs/_build/ 57 | # PyBuilder 58 | target/ 59 | # Jupyter Notebook 60 | .ipynb_checkpoints 61 | # pyenv 62 | .python-version 63 | # celery beat schedule file 64 | celerybeat-schedule 65 | # SageMath parsed files 66 | *.sage.py 67 | # dotenv 68 | .env 69 | # virtualenv 70 | .venv 71 | venv/ 72 | ENV/ 73 | # Spyder project settings 74 | .spyderproject 75 | # Rope project settings 76 | .ropeproject 77 | # jetbrains IDE项目配置文件 78 | .idea/ 79 | # image file 80 | *.jpg 81 | *.png 82 | *.gif 83 | # linux系统生成的文件 84 | *.pid 85 | nohup.out 86 | graph/* 87 | model/* 88 | config.yaml -------------------------------------------------------------------------------- /.pip.conf: -------------------------------------------------------------------------------- 1 | [global] 2 | index-url = https://pypi.doubanio.com/simple/ 3 | trusted-host = pypi.doubanio.com -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - "3.6" 5 | 6 | install: 7 | - pip install -r requirements.txt 8 | 9 | script: 10 | python test.py -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6.8-stretch as 
builder 2 | 3 | ADD . /app/ 4 | 5 | WORKDIR /app/ 6 | 7 | COPY requirements.txt /app/ 8 | 9 | # timezone 10 | ENV TZ=Asia/Shanghai 11 | 12 | RUN pip install --no-cache-dir --upgrade pip \ 13 | && pip install --no-cache-dir -r requirements.txt 14 | 15 | 16 | ENTRYPOINT ["python3", "tornado_server.py"] 17 | EXPOSE 19952 18 | # run command: 19 | # docker run -d -p 19952:19952 [image:tag] 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The Star And Thank Author License (SATA) 2 | 3 | Copyright (c) 2018 kelromz(kerlomz@gmail.com) 4 | 5 | Project Url: https://github.com/kerlomz/captcha_platform 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | And wait, the most important, you shall star/+1/like the project(s) in project url 18 | section above first, and then thank the author(s) in Copyright section. 19 | 20 | Here are some suggested ways: 21 | 22 | - Email the authors a thank-you letter, and make friends with him/her/them. 23 | - Report bugs or issues. 24 | - Tell friends what a wonderful project this is. 25 | - And, sure, you can just express thanks in your mind without telling the world. 
26 | 27 | Contributors of this project by forking have the option to add his/her name and 28 | forked project url at copyright and project url sections, but shall not delete 29 | or modify anything else in these two sections. 30 | 31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 37 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![Build Status](https://travis-ci.org/kerlomz/captcha_platform.svg?branch=master)](https://travis-ci.org/kerlomz/captcha_platform) 3 | 4 | # Project Introduction 5 | This project uses CNN+BLSTM+CTC to recognize verification codes (captchas). 6 | This project is only for deploying models. If you need to train a model, please go to https://github.com/kerlomz/captcha_trainer 7 | 8 | # Notes 9 | 1. The default requirements.txt installs the CPU version. To switch to the GPU version, change "TensorFlow" to "TensorFlow-GPU" in requirements.txt; the GPU version also requires the matching CUDA and cuDNN to be installed. 10 | 2. demo.py: An example of how to call the prediction method. 11 | 3. The model folder is used to store model configuration files such as model.yaml. 12 | 4. The graph folder is used to store compiled models such as model.pb. 13 | 5. The deployment service automatically loads all the models in the model configuration.
When a new model configuration is added, the corresponding compiled model in the graph folder is loaded automatically, so when adding a model, copy its compiled model to the graph path first and then add the model configuration. 14 | 15 | 16 | # Start 17 | 1. Install a Python 3.9 environment (with pip) 18 | 2. Install virtualenv ```pip3 install virtualenv``` 19 | 3. Create a separate virtual environment for the project: 20 | ```bash 21 | virtualenv -p /usr/bin/python3 venv # venv is the name of the virtual environment. 22 | cd venv/ # venv is the name of the virtual environment. 23 | source bin/activate # to activate the current virtual environment. 24 | cd captcha_platform # captcha_platform is the project path. 25 | ``` 26 | 4. ```pip install -r requirements.txt``` 27 | 5. Place your trained model.yaml in the model folder and your model.pb in the graph folder (create the folders if they do not exist). 28 | 6. Deploy as follows. 29 | 30 | ## 1. HTTP Version 31 | 1. Linux 32 | Deploy (Linux/Mac): 33 | 34 | Port: 19952 35 | ``` 36 | python tornado_server.py 37 | ``` 38 | 39 | 2. Windows 40 | Deploy (Windows): 41 | ``` 42 | python xxx_server.py 43 | ``` 44 | 45 | 3. Request 46 | 47 | |Request URI | Content-Type | Payload Type | Method | 48 | | ----------- | ---------------- | -------- | -------- | 49 | | http://localhost:[Bind-port]/captcha/v1 | application/json | JSON | POST | 50 | 51 | | Parameter | Required | Type | Description | 52 | | ---------- | ---- | ------ | ------------------------ | 53 | | image | Yes | String | Base64-encoded image binary stream | 54 | | model_name | No | String | Model name, as bound in the yaml configuration | 55 | 56 | 57 | The request is in JSON format, e.g.: {"image": "base64 encoded image binary stream"} 58 | 59 | 4.
Response 60 | 61 | | Parameter Name | Type | Description | 62 | | ------- | ------ | ------------------ | 63 | | message | String | Recognition result or error message | 64 | | code | Integer | Status code | 65 | | success | Boolean | Whether the request succeeded | 66 | 67 | The response is in JSON format, e.g.: {"message": "xxxx", "code": 0, "success": true} 68 | 69 | 70 | ## 2. gRPC Version 71 | Deploy: 72 | ``` 73 | python3 grpc_server.py 74 | ``` 75 | Port: 50054 76 | 77 | 78 | # Regenerate the gRPC code 79 | python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. ./grpc.proto 80 | 81 | 82 | # Directory Structure 83 | 84 | - captcha_platform 85 | - grpc_server.py 86 | - flask_server.py 87 | - tornado_server.py 88 | - sanic_server.py 89 | - demo.py 90 | - config.yaml 91 | - model 92 | - model-1.yaml 93 | - model-2.yaml 94 | - ... 95 | - graph 96 | - Model-1.pb 97 | - ... 98 | 99 | # Model Management 100 | 1. **Load a model** 101 | - Put the trained pb model in the graph folder. 102 | - Put the trained yaml model configuration file in the model folder. 103 | 2. **Unload a model** 104 | - Delete the corresponding yaml configuration file in the model folder. 105 | - Delete the corresponding pb model file in the graph folder. 106 | 3. **Update a model** 107 | - Put the trained pb model in the graph folder. 108 | - Put a yaml configuration file whose "Version" is greater than the current version in the model folder. 109 | - Delete the old models and configurations. 110 | 111 | # License 112 | This project uses the SATA License (Star And Thank Author License), so you must star this project before using it. Read the license carefully. 113 | 114 | # Introduction 115 | https://www.jianshu.com/p/80ef04b16efc 116 | 117 | # Donate 118 | Thank you very much for your support of my project.
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
# File: compat/flask_server.py
"""Flask/gevent HTTP front-end for captcha prediction (legacy ``compat`` server).

Registers two POST routes whose paths come from ``system_config.route_map``
(``AuthHandler`` requires a request signature, ``NoAuthHandler`` does not).
Both decode the base64 ``image`` field of the JSON body, select a model
interface by ``model_name`` or by the size of the first decoded image, and
return the prediction as a JSON string.
"""
import time
import optparse
import threading
from flask import Flask, abort, json, jsonify, request
from flask_caching import Cache
from gevent.pywsgi import WSGIServer
from geventwebsocket.handler import WebSocketHandler
from config import Config
from utils import ImageUtils
from constants import Response

from interface import InterfaceManager
from signature import Signature, ServerType
from watchdog.observers import Observer
from event_handler import FileEventHandler
from middleware import *
# The order cannot be changed, it must be before the flask.

app = Flask(__name__)
cache = Cache(app, config={'CACHE_TYPE': 'simple'})

# These paths are resolved against the current working directory; presumably
# the server is started from inside ``compat/`` -- confirm.
conf_path = '../config.yaml'
model_path = '../model'
graph_path = '../graph'


system_config = Config(conf_path=conf_path, model_path=model_path, graph_path=graph_path)
sign = Signature(ServerType.FLASK, system_config)
_except = Response(system_config.response_def_map)
# Handler class name (e.g. 'AuthHandler') -> configured URL route.
route_map = {i['Class']: i['Route'] for i in system_config.route_map}
sign.set_auth([{'accessKey': system_config.access_key, 'secretKey': system_config.secret_key}])
logger = system_config.logger
interface_manager = InterfaceManager()
image_utils = ImageUtils(system_config)


@app.after_request
def after_request(response):
    """Add a permissive CORS header (any origin) to every response."""
    response.headers['Access-Control-Allow-Origin'] = '*'
    return response


# NOTE: the error handlers below return their JSON body without an explicit
# status, so Flask sends HTTP 200; the error is reported only through the
# ``code`` / ``success`` fields of the body.
@app.errorhandler(400)
def bad_request(error=None):
    """Render a 400 Bad Request as JSON (falls back to 400 if ``error`` is None)."""
    message = "Bad Request"
    return jsonify(message=message, code=getattr(error, 'code', 400), success=False)


@app.errorhandler(500)
def server_error(error=None):
    """Render an internal server error as JSON."""
    message = 'Internal Server Error'
    return jsonify(message=message, code=500, success=False)


@app.errorhandler(404)
def not_found(error=None):
    """Render a 404 Not Found as JSON (falls back to 404 if ``error`` is None)."""
    message = '404 Not Found'
    return jsonify(message=message, code=getattr(error, 'code', 404), success=False)


@app.errorhandler(403)
def permission_denied(error=None):
    """Render a 403 Forbidden as JSON (falls back to 403 if ``error`` is None)."""
    message = 'Forbidden'
    return jsonify(message=message, code=getattr(error, 'code', 403), success=False)


@app.route(route_map['AuthHandler'], methods=['POST'])
@sign.signature_required  # This decorator is required for certification.
def auth_request():
    """Prediction endpoint guarded by ``sign.signature_required``."""
    return common_request()


@app.route(route_map['NoAuthHandler'], methods=['POST'])
def no_auth_request():
    """Prediction endpoint without signature verification."""
    return common_request()


def common_request():
    """Predict the captcha carried in the current request's JSON body.

    Shared by the authenticated and unauthenticated routes.

    JSON fields read:
        image (required): base64-encoded image data, decoded by
            ``image_utils.get_bytes_batch``.
        model_name: name of the model interface to use; when absent the
            interface is chosen by the size ("WxH") of the first image.
        output_split: passed to ``interface.predict_batch`` as the split
            character; defaults to ``interface.model_conf.output_split``.
        need_color: when truthy, used as a key into ``color_map`` and each
            image is passed through ``color_extract.separate_color`` first.

    :return: ``(json_string, 200)`` for decoding/selection/prediction results,
        or a bare JSON string (code -999) when no model is deployed.
        A body that is not a JSON object with an ``image`` field aborts with 400.
    """
    start_time = time.time()
    # silent=True yields None (instead of raising) for a missing, non-JSON or
    # malformed body, so all such requests end up in the JSON 400 handler.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'image' not in payload:
        abort(400)

    if interface_manager.total == 0:
        logger.info('There is currently no model deployment and services are not available.')
        return json.dumps({"message": "", "success": False, "code": -999})

    bytes_batch, response = image_utils.get_bytes_batch(payload['image'])

    if not bytes_batch:
        logger.error('Name[{}] - Response[{}] - {} ms'.format(
            payload.get('model_name'), response,
            (time.time() - start_time) * 1000)
        )
        return json.dumps(response), 200

    image_sample = bytes_batch[0]
    image_size = ImageUtils.size_of_image(image_sample)
    size_string = "{}x{}".format(image_size[0], image_size[1])

    if 'model_name' in payload:
        interface = interface_manager.get_by_name(payload['model_name'])
    else:
        interface = interface_manager.get_by_size(size_string)

    # No interface matched the requested name / image size. Report it
    # (code 999, as the gRPC server does) instead of crashing with an
    # AttributeError on ``interface.model_conf`` below.
    if not interface:
        logger.info('Service is not ready!')
        return json.dumps({"message": "", "success": False, "code": 999}), 200

    if 'output_split' in payload:
        split_char = payload['output_split']
    else:
        split_char = interface.model_conf.output_split

    if payload.get('need_color'):
        bytes_batch = [color_extract.separate_color(_, color_map[payload['need_color']]) for _ in bytes_batch]

    image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)

    if not image_batch:
        logger.error('[{}] - Size[{}] - Name[{}] - Response[{}] - {} ms'.format(
            interface.name, size_string, payload.get('model_name'), response,
            (time.time() - start_time) * 1000)
        )
        return json.dumps(response), 200

    result = interface.predict_batch(image_batch, split_char)
    logger.info('[{}] - Size[{}] - Name[{}] - Predict Result[{}] - {} ms'.format(
        interface.name,
        size_string,
        payload.get('model_name'),
        result,
        (time.time() - start_time) * 1000
    ))
    response['message'] = result
    return json.dumps(response), 200


def event_loop():
    """Watch the model configuration path with watchdog until interrupted.

    File events under ``event_handler.model_conf_path`` are dispatched to
    ``FileEventHandler`` (presumably to (re)load models into
    ``interface_manager`` -- see event_handler.py). Blocks forever; it is
    started on a daemon thread below.
    """
    event = threading.Event()
    observer = Observer()
    event_handler = FileEventHandler(system_config, model_path, interface_manager)
    observer.schedule(event_handler, event_handler.model_conf_path, True)
    observer.start()
    try:
        while True:
            event.wait(1)
    except KeyboardInterrupt:
        observer.stop()
    observer.join()


threading.Thread(target=event_loop, daemon=True).start()

if __name__ == "__main__":

    parser = optparse.OptionParser()
    parser.add_option('-p', '--port', type="int", default=19951, dest="port")

    opt, args = parser.parse_args()
    server_port = opt.port

    server_host = "0.0.0.0"

    logger.info('Running on http://{}:{}/ '.format(server_host, server_port))
    server = WSGIServer((server_host, server_port), app, handler_class=WebSocketHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        server.stop()
-------------------------------------------------------------------------------- /compat/flask_server.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python -*- 2 | # Used to package as a single executable 3 | # This is a configuration file 4 | 5 | block_cipher = pyi_crypto.PyiBlockCipher(key='kerlomz&coriander') 6 | 7 | added_files = [('resource/icon.ico', 'resource')] 8 | 9 | a = Analysis(['flask_server.py'], 10 | pathex=['.'], 11 | binaries=[], 12 | datas=added_files, 13 | hiddenimports=[], 14 | hookspath=[], 15 | runtime_hooks=[], 16 | excludes=[], 17 | win_no_prefer_redirects=False, 18 | win_private_assemblies=False, 19 | cipher=block_cipher) 20 | pyz = PYZ(a.pure, a.zipped_data, 21 | cipher=block_cipher) 22 | exe = EXE(pyz, 23 | a.scripts, 24 | a.binaries, 25 | a.zipfiles, 26 | a.datas, 27 | name='captcha_platform_flask', 28 | debug=False, 29 | strip=False, 30 | upx=True, 31 | runtime_tmpdir=None, 32 | console=True, 33 | icon='resource/icon.ico') 34 | -------------------------------------------------------------------------------- /compat/grpc.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | service Predict { 4 | rpc predict (PredictRequest) returns (PredictResult) {} 5 | } 6 | 7 | message PredictRequest { 8 | string image = 1; 9 | string split_char = 2; 10 | string model_name = 3; 11 | string model_type = 4; 12 | string model_site = 5; 13 | string need_color = 6; 14 | } 15 | 16 | message PredictResult { 17 | string result = 1; 18 | int32 code = 2; 19 | bool success = 3; 20 | } -------------------------------------------------------------------------------- /compat/grpc_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: grpc.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='grpc.proto', 19 | package='', 20 | syntax='proto3', 21 | serialized_options=None, 22 | serialized_pb=_b('\n\ngrpc.proto\"\x83\x01\n\x0ePredictRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x12\n\nsplit_char\x18\x02 \x01(\t\x12\x12\n\nmodel_name\x18\x03 \x01(\t\x12\x12\n\nmodel_type\x18\x04 \x01(\t\x12\x12\n\nmodel_site\x18\x05 \x01(\t\x12\x12\n\nneed_color\x18\x06 \x01(\t\">\n\rPredictResult\x12\x0e\n\x06result\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\x12\x0f\n\x07success\x18\x03 \x01(\x08\x32\x37\n\x07Predict\x12,\n\x07predict\x12\x0f.PredictRequest\x1a\x0e.PredictResult\"\x00\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _PREDICTREQUEST = _descriptor.Descriptor( 29 | name='PredictRequest', 30 | full_name='PredictRequest', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='image', full_name='PredictRequest.image', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | serialized_options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='split_char', full_name='PredictRequest.split_char', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 
47 | is_extension=False, extension_scope=None, 48 | serialized_options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='model_name', full_name='PredictRequest.model_name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | serialized_options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='model_type', full_name='PredictRequest.model_type', index=3, 58 | number=4, type=9, cpp_type=9, label=1, 59 | has_default_value=False, default_value=_b("").decode('utf-8'), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | serialized_options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='model_site', full_name='PredictRequest.model_site', index=4, 65 | number=5, type=9, cpp_type=9, label=1, 66 | has_default_value=False, default_value=_b("").decode('utf-8'), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | serialized_options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='need_color', full_name='PredictRequest.need_color', index=5, 72 | number=6, type=9, cpp_type=9, label=1, 73 | has_default_value=False, default_value=_b("").decode('utf-8'), 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | serialized_options=None, file=DESCRIPTOR), 77 | ], 78 | extensions=[ 79 | ], 80 | nested_types=[], 81 | enum_types=[ 82 | ], 83 | serialized_options=None, 84 | is_extendable=False, 85 | syntax='proto3', 86 | extension_ranges=[], 87 | oneofs=[ 88 | ], 89 | serialized_start=15, 90 | serialized_end=146, 91 | ) 92 | 93 | 94 | _PREDICTRESULT = _descriptor.Descriptor( 95 | name='PredictResult', 96 | full_name='PredictResult', 97 | filename=None, 98 
| file=DESCRIPTOR, 99 | containing_type=None, 100 | fields=[ 101 | _descriptor.FieldDescriptor( 102 | name='result', full_name='PredictResult.result', index=0, 103 | number=1, type=9, cpp_type=9, label=1, 104 | has_default_value=False, default_value=_b("").decode('utf-8'), 105 | message_type=None, enum_type=None, containing_type=None, 106 | is_extension=False, extension_scope=None, 107 | serialized_options=None, file=DESCRIPTOR), 108 | _descriptor.FieldDescriptor( 109 | name='code', full_name='PredictResult.code', index=1, 110 | number=2, type=5, cpp_type=1, label=1, 111 | has_default_value=False, default_value=0, 112 | message_type=None, enum_type=None, containing_type=None, 113 | is_extension=False, extension_scope=None, 114 | serialized_options=None, file=DESCRIPTOR), 115 | _descriptor.FieldDescriptor( 116 | name='success', full_name='PredictResult.success', index=2, 117 | number=3, type=8, cpp_type=7, label=1, 118 | has_default_value=False, default_value=False, 119 | message_type=None, enum_type=None, containing_type=None, 120 | is_extension=False, extension_scope=None, 121 | serialized_options=None, file=DESCRIPTOR), 122 | ], 123 | extensions=[ 124 | ], 125 | nested_types=[], 126 | enum_types=[ 127 | ], 128 | serialized_options=None, 129 | is_extendable=False, 130 | syntax='proto3', 131 | extension_ranges=[], 132 | oneofs=[ 133 | ], 134 | serialized_start=148, 135 | serialized_end=210, 136 | ) 137 | 138 | DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST 139 | DESCRIPTOR.message_types_by_name['PredictResult'] = _PREDICTRESULT 140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 141 | 142 | PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), dict( 143 | DESCRIPTOR = _PREDICTREQUEST, 144 | __module__ = 'grpc_pb2' 145 | # @@protoc_insertion_point(class_scope:PredictRequest) 146 | )) 147 | _sym_db.RegisterMessage(PredictRequest) 148 | 149 | PredictResult = 
_reflection.GeneratedProtocolMessageType('PredictResult', (_message.Message,), dict( 150 | DESCRIPTOR = _PREDICTRESULT, 151 | __module__ = 'grpc_pb2' 152 | # @@protoc_insertion_point(class_scope:PredictResult) 153 | )) 154 | _sym_db.RegisterMessage(PredictResult) 155 | 156 | 157 | 158 | _PREDICT = _descriptor.ServiceDescriptor( 159 | name='Predict', 160 | full_name='Predict', 161 | file=DESCRIPTOR, 162 | index=0, 163 | serialized_options=None, 164 | serialized_start=212, 165 | serialized_end=267, 166 | methods=[ 167 | _descriptor.MethodDescriptor( 168 | name='predict', 169 | full_name='Predict.predict', 170 | index=0, 171 | containing_service=None, 172 | input_type=_PREDICTREQUEST, 173 | output_type=_PREDICTRESULT, 174 | serialized_options=None, 175 | ), 176 | ]) 177 | _sym_db.RegisterServiceDescriptor(_PREDICT) 178 | 179 | DESCRIPTOR.services_by_name['Predict'] = _PREDICT 180 | 181 | # @@protoc_insertion_point(module_scope) 182 | -------------------------------------------------------------------------------- /compat/grpc_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from compat import grpc_pb2 as grpc__pb2 5 | 6 | 7 | class PredictStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.predict = channel.unary_unary( 18 | '/Predict/predict', 19 | request_serializer=grpc__pb2.PredictRequest.SerializeToString, 20 | response_deserializer=grpc__pb2.PredictResult.FromString, 21 | ) 22 | 23 | 24 | class PredictServicer(object): 25 | # missing associated documentation comment in .proto file 26 | pass 27 | 28 | def predict(self, request, context): 29 | # missing associated documentation comment in .proto file 30 | pass 31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 32 | context.set_details('Method not implemented!') 33 | raise NotImplementedError('Method not implemented!') 34 | 35 | 36 | def add_PredictServicer_to_server(servicer, server): 37 | rpc_method_handlers = { 38 | 'predict': grpc.unary_unary_rpc_method_handler( 39 | servicer.predict, 40 | request_deserializer=grpc__pb2.PredictRequest.FromString, 41 | response_serializer=grpc__pb2.PredictResult.SerializeToString, 42 | ), 43 | } 44 | generic_handler = grpc.method_handlers_generic_handler( 45 | 'Predict', rpc_method_handlers) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | -------------------------------------------------------------------------------- /compat/grpc_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | import time 6 | import threading 7 | from concurrent import futures 8 | 9 | import grpc 10 | from compat import grpc_pb2_grpc, grpc_pb2 11 | import optparse 12 | from utils import ImageUtils 13 | from interface import InterfaceManager 14 | from config import Config 15 | from middleware import * 16 | from event_loop import event_loop 17 | 18 | _ONE_DAY_IN_SECONDS = 60 * 60 * 24 19 | 20 | 21 | class Predict(grpc_pb2_grpc.PredictServicer): 22 | 23 | def __init__(self, **kwargs): 24 | super().__init__(**kwargs) 25 | self.image_utils = ImageUtils(system_config) 26 | 27 | def predict(self, request, context): 28 | start_time = 
time.time() 29 | bytes_batch, status = self.image_utils.get_bytes_batch(request.image) 30 | 31 | if interface_manager.total == 0: 32 | logger.info('There is currently no model deployment and services are not available.') 33 | return {"result": "", "success": False, "code": -999} 34 | 35 | if not bytes_batch: 36 | return grpc_pb2.PredictResult(result="", success=status['success'], code=status['code']) 37 | 38 | image_sample = bytes_batch[0] 39 | image_size = ImageUtils.size_of_image(image_sample) 40 | size_string = "{}x{}".format(image_size[0], image_size[1]) 41 | if request.model_name: 42 | interface = interface_manager.get_by_name(request.model_name) 43 | else: 44 | interface = interface_manager.get_by_size(size_string) 45 | if not interface: 46 | logger.info('Service is not ready!') 47 | return {"result": "", "success": False, "code": 999} 48 | 49 | if request.need_color: 50 | bytes_batch = [color_extract.separate_color(_, color_map[request.need_color]) for _ in bytes_batch] 51 | 52 | image_batch, status = ImageUtils.get_image_batch(interface.model_conf, bytes_batch) 53 | 54 | if not image_batch: 55 | return grpc_pb2.PredictResult(result="", success=status['success'], code=status['code']) 56 | 57 | result = interface.predict_batch(image_batch, request.split_char) 58 | logger.info('[{}] - Size[{}] - Type[{}] - Site[{}] - Predict Result[{}] - {} ms'.format( 59 | interface.name, 60 | size_string, 61 | request.model_type, 62 | request.model_site, 63 | result, 64 | (time.time() - start_time) * 1000 65 | )) 66 | return grpc_pb2.PredictResult(result=result, success=status['success'], code=status['code']) 67 | 68 | 69 | def serve(): 70 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) 71 | grpc_pb2_grpc.add_PredictServicer_to_server(Predict(), server) 72 | server.add_insecure_port('[::]:50054') 73 | server.start() 74 | try: 75 | while True: 76 | time.sleep(_ONE_DAY_IN_SECONDS) 77 | except KeyboardInterrupt: 78 | server.stop(0) 79 | 80 | 81 | if __name__ 
== '__main__': 82 | parser = optparse.OptionParser() 83 | parser.add_option('-p', '--port', type="int", default=50054, dest="port") 84 | parser.add_option('-c', '--config', type="str", default='./config.yaml', dest="config") 85 | parser.add_option('-m', '--model_path', type="str", default='model', dest="model_path") 86 | parser.add_option('-g', '--graph_path', type="str", default='graph', dest="graph_path") 87 | opt, args = parser.parse_args() 88 | server_port = opt.port 89 | conf_path = opt.config 90 | model_path = opt.model_path 91 | graph_path = opt.graph_path 92 | system_config = Config(conf_path=conf_path, model_path=model_path, graph_path=graph_path) 93 | interface_manager = InterfaceManager() 94 | threading.Thread(target=lambda: event_loop(system_config, model_path, interface_manager)).start() 95 | 96 | logger = system_config.logger 97 | server_host = "0.0.0.0" 98 | 99 | logger.info('Running on http://{}:{}/ '.format(server_host, server_port)) 100 | serve() 101 | -------------------------------------------------------------------------------- /compat/grpc_server.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python -*- 2 | # Used to package as a single executable 3 | # This is a configuration file 4 | 5 | block_cipher = pyi_crypto.PyiBlockCipher(key='kerlomz&coriander') 6 | 7 | added_files = [('resource/icon.ico', 'resource')] 8 | 9 | a = Analysis(['grpc_server.py'], 10 | pathex=['.'], 11 | binaries=[], 12 | datas=added_files, 13 | hiddenimports=[], 14 | hookspath=[], 15 | runtime_hooks=[], 16 | excludes=[], 17 | win_no_prefer_redirects=False, 18 | win_private_assemblies=False, 19 | cipher=block_cipher) 20 | pyz = PYZ(a.pure, a.zipped_data, 21 | cipher=block_cipher) 22 | exe = EXE(pyz, 23 | a.scripts, 24 | a.binaries, 25 | a.zipfiles, 26 | a.datas, 27 | name='captcha_platform_grpc', 28 | debug=False, 29 | strip=False, 30 | upx=True, 31 | runtime_tmpdir=None, 32 | console=True, 33 | 
icon='resource/icon.ico') 34 | -------------------------------------------------------------------------------- /compat/sanic_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import time 5 | import optparse 6 | import threading 7 | from config import Config 8 | from utils import ImageUtils 9 | from interface import InterfaceManager 10 | from watchdog.observers import Observer 11 | from event_handler import FileEventHandler 12 | from sanic import Sanic 13 | from sanic.response import json 14 | from signature import Signature, ServerType 15 | from middleware import * 16 | from event_loop import event_loop 17 | 18 | app = Sanic() 19 | sign = Signature(ServerType.SANIC) 20 | parser = optparse.OptionParser() 21 | 22 | conf_path = '../config.yaml' 23 | model_path = '../model' 24 | graph_path = '../graph' 25 | 26 | system_config = Config(conf_path=conf_path, model_path=model_path, graph_path=graph_path) 27 | sign.set_auth([{'accessKey': system_config.access_key, 'secretKey': system_config.secret_key}]) 28 | logger = system_config.logger 29 | interface_manager = InterfaceManager() 30 | threading.Thread(target=lambda: event_loop(system_config, model_path, interface_manager)).start() 31 | 32 | image_utils = ImageUtils(system_config) 33 | 34 | 35 | @app.route('/captcha/auth/v2', methods=['POST']) 36 | @sign.signature_required # This decorator is required for certification. 
37 | def auth_request(request): 38 | return common_request(request) 39 | 40 | 41 | @app.route('/captcha/v1', methods=['POST']) 42 | def no_auth_request(request): 43 | return common_request(request) 44 | 45 | 46 | def common_request(request): 47 | """ 48 | Shared captcha prediction handler used by both the authenticated and the unauthenticated route 49 | :return: sanic JSON response (None when the body has no 'image' field) 50 | """ 51 | start_time = time.time() 52 | if not request.json or 'image' not in request.json: 53 | print(request.json) 54 | return 55 | 56 | if interface_manager.total == 0: 57 | logger.info('There is currently no model deployment and services are not available.') 58 | return json({"message": "", "success": False, "code": -999}) 59 | 60 | bytes_batch, response = image_utils.get_bytes_batch(request.json['image']) 61 | 62 | if not bytes_batch: 63 | logger.error('Type[{}] - Site[{}] - Response[{}] - {} ms'.format( 64 | request.json.get('model_type'), request.json.get('model_site'), response, 65 | (time.time() - start_time) * 1000) 66 | ) 67 | return json(response) 68 | 69 | image_sample = bytes_batch[0] 70 | image_size = ImageUtils.size_of_image(image_sample) 71 | size_string = "{}x{}".format(image_size[0], image_size[1]) 72 | 73 | if 'model_name' in request.json: 74 | interface = interface_manager.get_by_name(request.json['model_name']) 75 | else: 76 | interface = interface_manager.get_by_size(size_string) 77 | if not interface: return json({"message": "", "success": False, "code": 999}) # no model matched: avoid AttributeError on interface.model_conf below 78 | split_char = request.json['split_char'] if 'split_char' in request.json else interface.model_conf.split_char 79 | 80 | if 'need_color' in request.json and request.json['need_color']: 81 | bytes_batch = [color_extract.separate_color(_, color_map[request.json['need_color']]) for _ in bytes_batch] 82 | 83 | image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch) 84 | 85 | if not image_batch: 86 | logger.error('[{}] - Size[{}] - Name[{}] - Response[{}] - {} ms'.format( 87 | interface.name, size_string, request.json.get('model_name'), response, 88 | (time.time() - start_time) * 1000) 89 | ) 90 | return
json(response) 91 | 92 | result = interface.predict_batch(image_batch, split_char) 93 | logger.info('[{}] - Size[{}] - Predict Result[{}] - {} ms'.format( 94 | interface.name, 95 | size_string, 96 | result, 97 | (time.time() - start_time) * 1000 98 | )) 99 | response['message'] = result 100 | return json(response) 101 | 102 | 103 | if __name__ == "__main__": 104 | 105 | parser.add_option('-p', '--port', type="int", default=19953, dest="port") 106 | 107 | opt, args = parser.parse_args() 108 | server_port = opt.port 109 | 110 | 111 | 112 | server_host = "0.0.0.0" 113 | 114 | logger.info('Running on http://{}:{}/ '.format(server_host, server_port)) 115 | app.run(host=server_host, port=server_port) 116 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import os 5 | import sys 6 | import uuid 7 | import json 8 | import yaml 9 | import hashlib 10 | import logging 11 | import logging.handlers 12 | from category import * 13 | from constants import SystemConfig, ModelField, ModelScene 14 | 15 | MODEL_SCENE_MAP = { 16 | 'Classification': ModelScene.Classification 17 | } 18 | 19 | MODEL_FIELD_MAP = { 20 | 'Image': ModelField.Image, 21 | 'Text': ModelField.Text 22 | } 23 | 24 | BLACKLIST_PATH = "blacklist.json" 25 | WHITELIST_PATH = "whitelist.json" 26 | 27 | 28 | def resource_path(relative_path): 29 | try: 30 | # PyInstaller creates a temp folder and stores path in _MEIPASS 31 | base_path = sys._MEIPASS 32 | except AttributeError: 33 | base_path = os.path.abspath(".") 34 | return os.path.join(base_path, relative_path) 35 | 36 | 37 | def get_version(): 38 | version_file_path = resource_path("VERSION") 39 | 40 | if not os.path.exists(version_file_path): 41 | return "NULL" 42 | 43 | with open(version_file_path, "r", encoding="utf8") as f: 44 | return 
"".join(f.readlines()).strip() 45 | 46 | 47 | def get_default(src, default): 48 | return src if src else default 49 | 50 | 51 | def get_dict_fill(src: dict, default: dict): 52 | if not src: 53 | return dict(default) # hand out a copy so callers never share the module-level default 54 | new_dict = dict(default) # copy first: update() below must not mutate the shared default dict 55 | new_dict.update(src) 56 | return new_dict 57 | 58 | 59 | def blacklist() -> list: 60 | if not os.path.exists(BLACKLIST_PATH): 61 | return [] 62 | try: 63 | with open(BLACKLIST_PATH, "r", encoding="utf8") as f_blacklist: 64 | result = json.loads("".join(f_blacklist.readlines())) 65 | return result 66 | except Exception as e: 67 | print(e) 68 | return [] 69 | 70 | 71 | def whitelist() -> list: 72 | if not os.path.exists(WHITELIST_PATH): 73 | return ["127.0.0.1", "localhost"] 74 | try: 75 | with open(WHITELIST_PATH, "r", encoding="utf8") as f_whitelist: 76 | result = json.loads("".join(f_whitelist.readlines())) 77 | return result 78 | except Exception as e: 79 | print(e) 80 | return ["127.0.0.1", "localhost"] 81 | 82 | 83 | def set_blacklist(ip): 84 | try: 85 | old_blacklist = blacklist() 86 | old_blacklist.append(ip) 87 | with open(BLACKLIST_PATH, "w+", encoding="utf8") as f_blacklist: 88 | f_blacklist.write(json.dumps(old_blacklist, ensure_ascii=False, indent=2)) 89 | except Exception as e: 90 | print(e) 91 | 92 | 93 | class Config(object): 94 | def __init__(self, conf_path: str, graph_path: str = None, model_path: str = None): 95 | self.model_path = model_path 96 | self.conf_path = conf_path 97 | self.graph_path = graph_path 98 | self.sys_cf = self.read_conf 99 | self.access_key = None 100 | self.secret_key = None 101 | self.default_model = self.sys_cf['System']['DefaultModel'] 102 | self.default_port = self.sys_cf['System'].get('DefaultPort') 103 | if not self.default_port: 104 | self.default_port = 19952 105 | self.split_flag = self.sys_cf['System']['SplitFlag'] 106 | self.split_flag = self.split_flag if isinstance(self.split_flag, bytes) else SystemConfig.split_flag 107 | self.route_map =
get_default(self.sys_cf.get('RouteMap'), SystemConfig.default_route) 108 | self.log_path = "logs" 109 | self.request_def_map = get_default(self.sys_cf.get('RequestDef'), SystemConfig.default_config['RequestDef']) 110 | self.response_def_map = get_default(self.sys_cf.get('ResponseDef'), SystemConfig.default_config['ResponseDef']) 111 | self.save_path = self.sys_cf['System'].get("SavePath") 112 | self.request_count_interval = get_default( 113 | src=self.sys_cf['System'].get("RequestCountInterval"), 114 | default=60 * 60 * 24 115 | ) 116 | self.g_request_count_interval = get_default( 117 | src=self.sys_cf['System'].get("GlobalRequestCountInterval"), 118 | default=60 * 60 * 24 119 | ) 120 | self.request_limit = get_default(self.sys_cf['System'].get("RequestLimit"), -1) 121 | self.global_request_limit = get_default(self.sys_cf['System'].get("GlobalRequestLimit"), -1) 122 | self.exceeded_msg = get_default( 123 | src=self.sys_cf['System'].get("ExceededMessage"), 124 | default=SystemConfig.default_config['System'].get('ExceededMessage') 125 | ) 126 | self.illegal_time_msg = get_default( 127 | src=self.sys_cf['System'].get("IllegalTimeMessage"), 128 | default=SystemConfig.default_config['System'].get('IllegalTimeMessage') 129 | ) 130 | 131 | self.request_size_limit: dict = get_default( 132 | src=self.sys_cf['System'].get('RequestSizeLimit'), 133 | default={} 134 | ) 135 | self.blacklist_trigger_times = get_default(self.sys_cf['System'].get("BlacklistTriggerTimes"), -1) 136 | 137 | self.use_whitelist: dict = get_default( 138 | src=self.sys_cf['System'].get('Whitelist'), 139 | default=False 140 | ) 141 | 142 | self.error_message = get_dict_fill( 143 | self.sys_cf['System'].get('ErrorMessage'), SystemConfig.default_config['System']['ErrorMessage'] 144 | ) 145 | self.logger_tag = get_default(self.sys_cf['System'].get('LoggerTag'), "coriander") 146 | self.without_logger = self.sys_cf['System'].get('WithoutLogger') 147 | self.without_logger = self.without_logger if 
self.without_logger is not None else False 148 | self.logger = logging.getLogger(self.logger_tag) 149 | self.use_default_authorization = False 150 | self.authorization = None 151 | self.init_logger() 152 | self.assignment() 153 | 154 | def init_logger(self): 155 | self.logger.setLevel(logging.INFO) 156 | 157 | if not os.path.exists(self.model_path): 158 | os.makedirs(self.model_path) 159 | if not os.path.exists(self.graph_path): 160 | os.makedirs(self.graph_path) 161 | 162 | self.logger.propagate = False 163 | 164 | if not self.without_logger: 165 | if not os.path.exists(self.log_path): 166 | os.makedirs(self.log_path) 167 | file_handler = logging.handlers.TimedRotatingFileHandler( 168 | '{}/{}.log'.format(self.log_path, "captcha_platform"), 169 | when="MIDNIGHT", 170 | interval=1, 171 | backupCount=180, 172 | encoding='utf-8' 173 | ) 174 | stream_handler = logging.StreamHandler() 175 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 176 | file_handler.setFormatter(formatter) 177 | stream_handler.setFormatter(formatter) 178 | self.logger.addHandler(file_handler) 179 | self.logger.addHandler(stream_handler) 180 | 181 | def assignment(self): 182 | # ---AUTHORIZATION START--- 183 | mac_address = hex(uuid.getnode())[2:] 184 | self.use_default_authorization = False 185 | self.authorization = self.sys_cf.get('Security') 186 | if not self.authorization or not self.authorization.get('AccessKey') or not self.authorization.get('SecretKey'): 187 | self.use_default_authorization = True 188 | model_name_md5 = hashlib.md5( 189 | "{}".format(self.default_model).encode('utf8')).hexdigest() 190 | self.authorization = { 191 | 'AccessKey': model_name_md5[0: 16], 192 | 'SecretKey': hashlib.md5("{}{}".format(model_name_md5, mac_address).encode('utf8')).hexdigest() 193 | } 194 | self.access_key = self.authorization['AccessKey'] 195 | self.secret_key = self.authorization['SecretKey'] 196 | # ---AUTHORIZATION END--- 197 | 198 | @property 199 | def 
read_conf(self): 200 | if not os.path.exists(self.conf_path): 201 | with open(self.conf_path, 'w', encoding="utf-8") as sys_fp: 202 | sys_fp.write(yaml.safe_dump(SystemConfig.default_config)) 203 | return SystemConfig.default_config 204 | with open(self.conf_path, 'r', encoding="utf-8") as sys_fp: 205 | sys_stream = sys_fp.read() 206 | return yaml.load(sys_stream, Loader=yaml.SafeLoader) 207 | 208 | 209 | class Model(object): 210 | 211 | def __init__(self, conf: Config, model_conf_path: str): 212 | self.conf = conf 213 | self.logger = self.conf.logger 214 | self.graph_path = conf.graph_path 215 | self.model_path = conf.model_path 216 | self.model_conf_path = model_conf_path 217 | self.model_conf_demo = 'model_demo.yaml' 218 | self.verify() 219 | 220 | def verify(self): 221 | if not os.path.exists(self.model_conf_path): 222 | raise Exception( 223 | 'Configuration File "{}" No Found. ' 224 | 'If it is used for the first time, please copy one from {} as {}'.format( 225 | self.model_conf_path, 226 | self.model_conf_demo, 227 | self.model_path 228 | ) 229 | ) 230 | 231 | if not os.path.exists(self.model_path): 232 | os.makedirs(self.model_path) 233 | raise Exception( 234 | 'For the first time, please put the trained model in the model directory.' 
235 | ) 236 | 237 | def category_extract(self, param): 238 | if isinstance(param, list): 239 | return param 240 | if isinstance(param, str): 241 | if param in SIMPLE_CATEGORY_MODEL.keys(): 242 | return SIMPLE_CATEGORY_MODEL.get(param) 243 | self.logger.error( 244 | "Category set configuration error, customized category set should be list type" 245 | ) 246 | return None 247 | 248 | @property 249 | def model_conf(self) -> dict: 250 | with open(self.model_conf_path, 'r', encoding="utf-8") as sys_fp: 251 | sys_stream = sys_fp.read() 252 | return yaml.load(sys_stream, Loader=yaml.SafeLoader) 253 | 254 | 255 | class ModelConfig(Model): 256 | model_exists: bool = False 257 | 258 | def __init__(self, conf: Config, model_conf_path: str): 259 | super().__init__(conf=conf, model_conf_path=model_conf_path) 260 | 261 | self.conf = conf 262 | 263 | """MODEL""" 264 | self.model_root: dict = self.model_conf['Model'] 265 | self.model_name: str = self.model_root.get('ModelName') 266 | self.model_version: float = self.model_root.get('Version') 267 | self.model_version = self.model_version if self.model_version else 1.0 268 | self.model_field_param: str = self.model_root.get('ModelField') 269 | self.model_field: ModelField = ModelConfig.param_convert( 270 | source=self.model_field_param, 271 | param_map=MODEL_FIELD_MAP, 272 | text="Current model field ({model_field}) is not supported".format(model_field=self.model_field_param), 273 | code=50002 274 | ) 275 | 276 | self.model_scene_param: str = self.model_root.get('ModelScene') 277 | 278 | self.model_scene: ModelScene = ModelConfig.param_convert( 279 | source=self.model_scene_param, 280 | param_map=MODEL_SCENE_MAP, 281 | text="Current model scene ({model_scene}) is not supported".format(model_scene=self.model_scene_param), 282 | code=50001 283 | ) 284 | 285 | """SYSTEM""" 286 | self.checkpoint_tag = 'checkpoint' 287 | self.system_root: dict = self.model_conf['System'] 288 | self.memory_usage: float = self.system_root.get('MemoryUsage') 
289 | 290 | """FIELD PARAM - IMAGE""" 291 | self.field_root: dict = self.model_conf['FieldParam'] 292 | self.category_param = self.field_root.get('Category') 293 | self.category_value = self.category_extract(self.category_param) 294 | if self.category_value is None: 295 | raise Exception( 296 | "The category set type does not exist, there is no category set named {}".format(self.category_param), 297 | ) 298 | self.category: list = SPACE_TOKEN + self.category_value 299 | self.category_num: int = len(self.category) 300 | self.image_channel: int = self.field_root.get('ImageChannel') 301 | self.image_width: int = self.field_root.get('ImageWidth') 302 | self.image_height: int = self.field_root.get('ImageHeight') 303 | self.max_label_num: int = self.field_root.get('MaxLabelNum') 304 | self.min_label_num: int = self.get_var(self.field_root, 'MinLabelNum', self.max_label_num) 305 | self.resize: list = self.field_root.get('Resize') 306 | self.output_split = self.field_root.get('OutputSplit') 307 | self.output_split = self.output_split if self.output_split else "" 308 | self.corp_params = self.field_root.get('CorpParams') 309 | self.output_coord = self.field_root.get('OutputCoord') 310 | self.batch_model = self.field_root.get('BatchModel') 311 | self.external_model = self.field_root.get('ExternalModelForCorp') 312 | self.category_split = self.field_root.get('CategorySplit') 313 | 314 | """PRETREATMENT""" 315 | self.pretreatment_root = self.model_conf.get('Pretreatment') 316 | self.pre_binaryzation = self.get_var(self.pretreatment_root, 'Binaryzation', -1) 317 | self.pre_replace_transparent = self.get_var(self.pretreatment_root, 'ReplaceTransparent', True) 318 | self.pre_horizontal_stitching = self.get_var(self.pretreatment_root, 'HorizontalStitching', False) 319 | self.pre_concat_frames = self.get_var(self.pretreatment_root, 'ConcatFrames', -1) 320 | self.pre_blend_frames = self.get_var(self.pretreatment_root, 'BlendFrames', -1) 321 | self.pre_freq_frames = 
self.get_var(self.pretreatment_root, 'FreqFrames', -1) 322 | self.exec_map = self.get_var(self.pretreatment_root, 'ExecuteMap', None) 323 | 324 | """COMPILE_MODEL""" 325 | self.compile_model_path = os.path.join(self.graph_path, '{}.pb'.format(self.model_name)) 326 | if not os.path.exists(self.compile_model_path): 327 | if not os.path.exists(self.graph_path): 328 | os.makedirs(self.graph_path) 329 | self.logger.error( 330 | '{} not found, please put the trained model in the graph directory.'.format(self.compile_model_path) 331 | ) 332 | else: 333 | self.model_exists = True 334 | 335 | @staticmethod 336 | def param_convert(source, param_map: dict, text, code, default=None): 337 | if source is None: 338 | return default 339 | if source not in param_map.keys(): 340 | raise Exception(text) 341 | return param_map[source] 342 | 343 | def size_match(self, size_str): 344 | return size_str == self.size_string 345 | 346 | @staticmethod 347 | def get_var(src: dict, name: str, default=None): 348 | if not src or name not in src: 349 | return default 350 | return src.get(name) 351 | 352 | @property 353 | def size_string(self): 354 | return "{}x{}".format(self.image_width, self.image_height) 355 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | from enum import Enum, unique 5 | 6 | 7 | @unique 8 | class ModelScene(Enum): 9 | """模型场景枚举""" 10 | Classification = 'Classification' 11 | 12 | 13 | @unique 14 | class ModelField(Enum): 15 | """模型类别枚举""" 16 | Image = 'Image' 17 | Text = 'Text' 18 | 19 | 20 | class SystemConfig: 21 | split_flag = b'\x99\x99\x99\x00\xff\xff\xff\x00\x99\x99\x99' 22 | default_route = [ 23 | { 24 | "Class": "AuthHandler", 25 | "Route": "/captcha/auth/v2" 26 | }, 27 | { 28 | "Class": "NoAuthHandler", 29 | "Route": "/captcha/v1" 30 | }, 31 | { 32 | "Class": 
"SimpleHandler", 33 | "Route": "/captcha/v3" 34 | }, 35 | { 36 | "Class": "HeartBeatHandler", 37 | "Route": "/check_backend_active.html" 38 | }, 39 | { 40 | "Class": "HeartBeatHandler", 41 | "Route": "/verification" 42 | }, 43 | { 44 | "Class": "HeartBeatHandler", 45 | "Route": "/" 46 | }, 47 | { 48 | "Class": "ServiceHandler", 49 | "Route": "/service/info" 50 | }, 51 | { 52 | "Class": "FileHandler", 53 | "Route": "/service/logs/(.*)", 54 | "Param": {"path": "logs"} 55 | }, 56 | { 57 | "Class": "BaseHandler", 58 | "Route": ".*" 59 | } 60 | ] 61 | default_config = { 62 | "System": { 63 | "DefaultModel": "default", 64 | "SplitFlag": b'\x99\x99\x99\x00\xff\xff999999.........99999\xff\x00\x99\x99\x99', 65 | "SavePath": "", 66 | "RequestCountInterval": 86400, 67 | "GlobalRequestCountInterval": 86400, 68 | "RequestLimit": -1, 69 | "GlobalRequestLimit": -1, 70 | "WithoutLogger": False, 71 | "RequestSizeLimit": {}, 72 | "DefaultPort": 19952, 73 | "IllegalTimeMessage": "Illegal access time, please request in open hours.", 74 | "ExceededMessage": "The maximum number of requests has been exceeded.", 75 | "BlacklistTriggerTimes": -1, 76 | "Whitelist": False, 77 | "ErrorMessage": { 78 | 400: "Bad Request", 79 | 401: "Unicode Decode Error", 80 | 403: "Forbidden", 81 | 404: "404 Not Found", 82 | 405: "Method Not Allowed", 83 | 500: "Internal Server Error" 84 | } 85 | }, 86 | "RouteMap": default_route, 87 | "Security": { 88 | "AccessKey": "", 89 | "SecretKey": "" 90 | }, 91 | "RequestDef": { 92 | "InputData": "image", 93 | "ModelName": "model_name", 94 | }, 95 | "ResponseDef": { 96 | "Message": "message", 97 | "StatusCode": "code", 98 | "StatusBool": "success", 99 | "Uid": "uid", 100 | } 101 | } 102 | 103 | 104 | class ServerType(str): 105 | FLASK = 19951 106 | TORNADO = 19952 107 | SANIC = 19953 108 | 109 | 110 | class Response: 111 | 112 | def __init__(self, def_map: dict): 113 | # SIGN 114 | self.INVALID_PUBLIC_PARAMS = dict(Message='Invalid Public Params', StatusCode=400001,
StatusBool=False) 115 | self.UNKNOWN_SERVER_ERROR = dict(Message='Unknown Server Error', StatusCode=400002, StatusBool=False) 116 | self.INVALID_TIMESTAMP = dict(Message='Invalid Timestamp', StatusCode=400004, StatusBool=False) 117 | self.INVALID_ACCESS_KEY = dict(Message='Invalid Access Key', StatusCode=400005, StatusBool=False) 118 | self.INVALID_QUERY_STRING = dict(Message='Invalid Query String', StatusCode=400006, StatusBool=False) 119 | 120 | # SERVER 121 | self.SUCCESS = dict(Message=None, StatusCode=000000, StatusBool=True) 122 | self.INVALID_IMAGE_FORMAT = dict(Message='Invalid Image Format', StatusCode=500001, StatusBool=False) 123 | self.INVALID_BASE64_STRING = dict(Message='Invalid Base64 String', StatusCode=500002, StatusBool=False) 124 | self.IMAGE_DAMAGE = dict(Message='Image Damage', StatusCode=500003, StatusBool=False) 125 | self.IMAGE_SIZE_NOT_MATCH_GRAPH = dict(Message='Image Size Not Match Graph Value', StatusCode=500004, StatusBool=False) 126 | 127 | self.INVALID_PUBLIC_PARAMS = self.parse(self.INVALID_PUBLIC_PARAMS, def_map) 128 | self.UNKNOWN_SERVER_ERROR = self.parse(self.UNKNOWN_SERVER_ERROR, def_map) 129 | self.INVALID_TIMESTAMP = self.parse(self.INVALID_TIMESTAMP, def_map) 130 | self.INVALID_ACCESS_KEY = self.parse(self.INVALID_ACCESS_KEY, def_map) 131 | self.INVALID_QUERY_STRING = self.parse(self.INVALID_QUERY_STRING, def_map) 132 | 133 | self.SUCCESS = self.parse(self.SUCCESS, def_map) 134 | self.INVALID_IMAGE_FORMAT = self.parse(self.INVALID_IMAGE_FORMAT, def_map) 135 | self.INVALID_BASE64_STRING = self.parse(self.INVALID_BASE64_STRING, def_map) 136 | self.IMAGE_DAMAGE = self.parse(self.IMAGE_DAMAGE, def_map) 137 | self.IMAGE_SIZE_NOT_MATCH_GRAPH = self.parse(self.IMAGE_SIZE_NOT_MATCH_GRAPH, def_map) 138 | 139 | def find_message(self, _code): 140 | e = [value for value in vars(self).values()] 141 | _t = [i['message'] for i in e if i['code'] == _code] 142 | return _t[0] if _t else None 143 | 144 | def find(self, _code): 145 | e = [value 
for value in vars(self).values()] 146 | _t = [i for i in e if i['code'] == _code] 147 | return _t[0] if _t else None 148 | 149 | def all_code(self): 150 | return [i['message'] for i in [value for value in vars(self).values()]] 151 | 152 | @staticmethod 153 | def parse(src: dict, target_map: dict): 154 | return {target_map[k]: v for k, v in src.items()} 155 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import io 5 | import os 6 | import base64 7 | import datetime 8 | import hashlib 9 | import time 10 | import numpy as np 11 | import cv2 12 | from config import Config 13 | from requests import Session, post, get 14 | from PIL import Image as PilImage 15 | from constants import ServerType 16 | 17 | # DEFAULT_HOST = "63.211.111.82" 18 | # DEFAULT_HOST = "39.100.71.103" 19 | # DEFAULT_HOST = "120.79.233.49" 20 | # DEFAULT_HOST = "47.52.203.228" 21 | DEFAULT_HOST = "192.168.50.152" 22 | # DEFAULT_HOST = "127.0.0.1" 23 | 24 | 25 | def _image(_path, model_type=None, model_site=None, need_color=None, fpath=None): 26 | with open(_path, "rb") as f: 27 | img_bytes = f.read() 28 | # data_stream = io.BytesIO(img_bytes) 29 | # pil_image = PilImage.open(data_stream) 30 | # size = pil_image.size 31 | # im = np.array(pil_image) 32 | # im = im[3:size[1] - 3, 3:size[0] - 3] 33 | # img_bytes = bytearray(cv2.imencode('.png', im)[1]) 34 | 35 | b64 = base64.b64encode(img_bytes).decode() 36 | return { 37 | 'image': b64, 38 | 'model_type': model_type, 39 | 'model_site': model_site, 40 | 'need_color': need_color, 41 | 'path': fpath 42 | } 43 | 44 | 45 | class Auth(object): 46 | 47 | def __init__(self, host: str, server_type: ServerType, access_key=None, secret_key=None, port=None): 48 | self._conf = Config(conf_path="config.yaml") 49 | self._url = 
'http://{}:{}/captcha/auth/v2'.format(host, port if port else server_type) 50 | self._access_key = access_key if access_key else self._conf.access_key 51 | self._secret_key = secret_key if secret_key else self._conf.secret_key 52 | self.true_count = 0 53 | self.total_count = 0 54 | 55 | def sign(self, args): 56 | """ MD5 signature 57 | @param args: All query parameters (public and private) requested in addition to signature 58 | { 59 | 'image': 'base64 encoded text', 60 | 'accessKey': 'C180130204197838', 61 | 'timestamp': 1536682949, 62 | 'sign': 'F641778AE4F93DAF5CCE3E43A674C34E' 63 | } 64 | The sign is the uppercase MD5 hex digest of the key-sorted "k=v" pairs joined by "&" with the secret key appended, e.g. "accessKey=your_access_key&image=base64_encoded_text&timestamp=current_timestamp&your_secret_key" 65 | """ 66 | if "sign" in args: 67 | args.pop("sign") 68 | query_string = '&'.join(['{}={}'.format(k, v) for (k, v) in sorted(args.items())]) 69 | query_string = '&'.join([query_string, self._secret_key]) 70 | return hashlib.md5(query_string.encode('utf-8')).hexdigest().upper() 71 | 72 | def make_json(self, params): 73 | if not isinstance(params, dict): 74 | raise TypeError("params is not a dict") 75 | # Get the current timestamp 76 | timestamp = int(time.mktime(datetime.datetime.now().timetuple())) 77 | # Set public parameters 78 | params.update(accessKey=self._access_key, timestamp=timestamp) 79 | params.update(sign=self.sign(params)) 80 | return params 81 | 82 | def request(self, params): 83 | params = dict(params, **self.make_json(params)) 84 | return post(self._url, json=params).json() 85 | 86 | def local_iter(self, image_list: dict): 87 | for k, v in image_list.items(): 88 | code = self.request(v).get('message') 89 | _true = str(code).lower() == str(k).lower() 90 | if _true: 91 | self.true_count += 1 92 | self.total_count += 1 93 | print('result: {}, label: {}, flag: {}, acc_rate: {}'.format(code, k, _true, 94 | self.true_count / self.total_count)) 95 | 96 | 97 | class NoAuth(object): 98 | def __init__(self, host: str, server_type: ServerType, port=None,
url=None): 99 | self._url = 'http://{}:{}/captcha/v1'.format(host, port if port else server_type) 100 | self._url = self._url if not url else url 101 | self.true_count = 0 102 | self.total_count = 0 103 | 104 | def request(self, params): 105 | import json 106 | # print(params) 107 | # print(params['fpath']) 108 | # print(json.dumps(params)) 109 | # return post(self._url, data=base64.b64decode(params.get("image").encode())).json() 110 | return post(self._url, json=params).json() 111 | 112 | def local_iter(self, image_list: dict): 113 | for k, v in image_list.items(): 114 | try: 115 | code = self.request(v).get('message') 116 | _true = str(code).lower() == str(k).lower() 117 | if _true: 118 | self.true_count += 1 119 | self.total_count += 1 120 | print('result: {}, label: {}, flag: {}, acc_rate: {}, {}'.format( 121 | code, k, _true, self.true_count / self.total_count, v.get('path') 122 | )) 123 | except Exception as e: 124 | print(e) 125 | 126 | def press_testing(self, image_list: dict, model_type=None, model_site=None): 127 | from multiprocessing.pool import ThreadPool 128 | pool = ThreadPool(500) 129 | for k, v in image_list.items(): 130 | pool.apply_async( 131 | self.request, ({"image": v.get('image'), "model_type": model_type, "model_site": model_site},)) # pass the callable + args so the request runs in the pool, not inline 132 | pool.close() 133 | pool.join() 134 | print(self.true_count / len(image_list)) 135 | 136 | 137 | class GoogleRPC(object): 138 | 139 | def __init__(self, host: str): 140 | self._url = '{}:50054'.format(host) 141 | self.true_count = 0 142 | self.total_count = 0 143 | 144 | def request(self, image, println=False, value=None, model_type=None, model_site=None, need_color=None): 145 | 146 | import grpc 147 | import grpc_pb2 148 | import grpc_pb2_grpc 149 | channel = grpc.insecure_channel(self._url) 150 | stub = grpc_pb2_grpc.PredictStub(channel) 151 | response = stub.predict(grpc_pb2.PredictRequest( 152 | image=image, split_char=',', model_type=model_type, model_site=model_site, need_color=need_color 153 | ))
if println and value: 155 | _true = str(response.result).lower() == str(value).lower() 156 | if _true: 157 | self.true_count += 1 158 | print("result: {}, label: {}, flag: {}".format(response.result, value, _true)) 159 | return {"message": response.result, "code": response.code, "success": response.success} 160 | 161 | def local_iter(self, image_list: dict, model_type=None, model_site=None): 162 | for k, v in image_list.items(): 163 | code = self.request(v.get('image'), model_type=model_type, model_site=model_site, 164 | need_color=v.get('need_color')).get('message') 165 | _true = str(code).lower() == str(k).lower() 166 | if _true: 167 | self.true_count += 1 168 | self.total_count += 1 169 | print('result: {}, label: {}, flag: {}, acc_rate: {}'.format( 170 | code, k, _true, self.true_count / self.total_count 171 | )) 172 | 173 | def remote_iter(self, url: str, save_path: str = None, num=100, model_type=None, model_site=None): 174 | if not os.path.exists(save_path): 175 | os.makedirs(save_path) 176 | sess = Session() 177 | sess.verify = False 178 | for i in range(num): 179 | img_bytes = sess.get(url).content 180 | img_b64 = base64.b64encode(img_bytes).decode() 181 | code = self.request(img_b64, model_type=model_type, model_site=model_site).get('message') 182 | with open("{}/{}_{}.jpg".format(save_path, code, hashlib.md5(img_bytes).hexdigest()), "wb") as f: 183 | f.write(img_bytes) 184 | 185 | print('result: {}'.format( 186 | code, 187 | )) 188 | 189 | def press_testing(self, image_list: dict, model_type=None, model_site=None): 190 | from multiprocessing.pool import ThreadPool 191 | pool = ThreadPool(500) 192 | for k, v in image_list.items(): 193 | pool.apply_async(self.request, (v.get('image'), True, k), dict(model_type=model_type, model_site=model_site)) # schedule the call in the pool instead of running it inline 194 | pool.close() 195 | pool.join() 196 | print(self.true_count / len(image_list)) 197 | 198 | 199 | if __name__ == '__main__': 200 | # # Here you can replace it with a web request to get images in real time.
201 | # with open(r"D:\***.jpg", "rb") as f: 202 | # img_bytes = f.read() 203 | 204 | # # Here is the code for the network request. 205 | # # Replace your own captcha url for testing. 206 | # # sess = Session() 207 | # # sess.headers = { 208 | # # 'user-agent': 'Chrome' 209 | # # } 210 | # # img_bytes = sess.get("http://***.com/captcha").content 211 | # 212 | # # Open the image for human eye comparison, 213 | # # preview whether the recognition result is consistent. 214 | # data_stream = io.BytesIO(img_bytes) 215 | # pil_image = PilImage.open(data_stream) 216 | # pil_image.show() 217 | # api_params = { 218 | # 'image': base64.b64encode(img_bytes).decode(), 219 | # } 220 | # print(api_params) 221 | # for i in range(1): 222 | # Tornado API with authentication 223 | # resp = Auth(DEFAULT_HOST, ServerType.TORNADO).request(api_params) 224 | # print(resp) 225 | 226 | # Flask API with authentication 227 | # resp = Auth(DEFAULT_HOST, ServerType.FLASK).request(api_params) 228 | # print(resp) 229 | 230 | # Tornado API without authentication 231 | # resp = NoAuth(DEFAULT_HOST, ServerType.TORNADO).request(api_params) 232 | # print(resp) 233 | 234 | # Flask API without authentication 235 | # resp = NoAuth(DEFAULT_HOST, ServerType.FLASK).request(api_params) 236 | # print(resp) 237 | 238 | # API by gRPC - The fastest way. 239 | # If you want to identify multiple verification codes continuously, please do like this: 240 | # resp = GoogleRPC(DEFAULT_HOST).request(base64.b64encode(img_bytes+b'\x00\xff\xff\xff\x00'+img_bytes).decode()) 241 | # b'\x00\xff\xff\xff\x00' is the split_flag defined in config.py 242 | # resp = GoogleRPC(DEFAULT_HOST).request(base64.b64encode(img_bytes).decode()) 243 | # print(resp) 244 | # pass 245 | 246 | # API by gRPC - The fastest way, Local batch version, only for self testing. 
247 | path = r"C:\Users\kerlomz\Desktop\New folder (6)" 248 | path_list = os.listdir(path) 249 | import random 250 | 251 | # random.shuffle(path_list) 252 | print(path_list) 253 | batch = { 254 | _path.split('_')[0].lower(): _image( 255 | os.path.join(path, _path), 256 | model_type=None, 257 | model_site=None, 258 | need_color=None, 259 | fpath=_path 260 | ) 261 | for i, _path in enumerate(path_list) 262 | if i < 10000 263 | } 264 | print(batch) 265 | NoAuth(DEFAULT_HOST, ServerType.TORNADO, port=19952).local_iter(batch) 266 | # NoAuth(DEFAULT_HOST, ServerType.FLASK).local_iter(batch) 267 | # NoAuth(DEFAULT_HOST, ServerType.SANIC).local_iter(batch) 268 | # GoogleRPC(DEFAULT_HOST).local_iter(batch, model_site=None, model_type=None) 269 | # GoogleRPC(DEFAULT_HOST).press_testing(batch, model_site=None, model_type=None) 270 | # GoogleRPC(DEFAULT_HOST).remote_iter("https://pbank.cqrcb.com:9080/perbank/VerifyImage?update=0.8746844661116633", r"D:\test12", 100, model_site='80x24', model_type=None) 271 | -------------------------------------------------------------------------------- /deploy.conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | # Gunicorn deploy file. 
5 | 6 | import multiprocessing 7 | 8 | bind = '0.0.0.0:19951' 9 | workers = multiprocessing.cpu_count() * 2 + 1 10 | backlog = 2048 11 | # worker_class = "gevent" 12 | debug = False 13 | daemon = True 14 | proc_name = 'gunicorn.pid' 15 | pidfile = 'debug.log' 16 | errorlog = 'error.log' 17 | accesslog = 'access.log' 18 | loglevel = 'info' 19 | timeout = 10 20 | -------------------------------------------------------------------------------- /event_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import os 5 | import time 6 | from watchdog.events import * 7 | from config import ModelConfig, Config 8 | from graph_session import GraphSession 9 | from interface import InterfaceManager, Interface 10 | from utils import PathUtils 11 | 12 | 13 | class FileEventHandler(FileSystemEventHandler): 14 | def __init__(self, conf: Config, model_conf_path: str, interface_manager: InterfaceManager): 15 | FileSystemEventHandler.__init__(self) 16 | self.conf = conf 17 | self.logger = self.conf.logger 18 | self.name_map = {} 19 | self.model_conf_path = model_conf_path 20 | self.interface_manager = interface_manager 21 | self.init() 22 | 23 | def init(self): 24 | model_list = os.listdir(self.model_conf_path) 25 | model_list = [os.path.join(self.model_conf_path, i) for i in model_list if i.endswith("yaml")] 26 | for model in model_list: 27 | self._add(model, is_first=True) 28 | if self.interface_manager.total == 0: 29 | self.logger.info( 30 | "\n - Number of interfaces: {}" 31 | "\n - There is currently no model deployment" 32 | "\n - Services are not available" 33 | "\n[ Please check the graph and model path whether the pb file and yaml file are placed. 
]".format( 34 | self.interface_manager.total, 35 | )) 36 | else: 37 | self.logger.info( 38 | "\n - Number of interfaces: {}" 39 | "\n - Current online interface: \n\t - {}" 40 | "\n - The default Interface is: {}".format( 41 | self.interface_manager.total, 42 | "\n\t - ".join(["[{}]".format(v) for k, v in self.name_map.items()]), 43 | self.interface_manager.default_name 44 | )) 45 | 46 | def _add(self, src_path, is_first=False, count=0): 47 | try: 48 | model_path = str(src_path) 49 | path_exists = os.path.exists(model_path) 50 | if not path_exists and count > 0: 51 | self.logger.error("{} not found, retry attempt is terminated.".format(model_path)) 52 | return 53 | if 'model_demo.yaml' in model_path: 54 | self.logger.warning( 55 | "\n-------------------------------------------------------------------\n" 56 | "- Found that the model_demo.yaml file exists, \n" 57 | "- the loading is automatically ignored. \n" 58 | "- If it is used for the first time, \n" 59 | "- please copy it as a template. \n" 60 | "- and do not use the reserved character \"model_demo.yaml\" as the file name." 61 | "\n-------------------------------------------------------------------" 62 | ) 63 | return 64 | if model_path.endswith("yaml"): 65 | model_conf = ModelConfig(self.conf, model_path) 66 | inner_name = model_conf.model_name 67 | inner_size = model_conf.size_string 68 | inner_key = PathUtils.get_file_name(model_path) 69 | for k, v in self.name_map.items(): 70 | if inner_size in v: 71 | self.logger.warning( 72 | "\n-------------------------------------------------------------------\n" 73 | "- The current model {} is the same size [{}] as the loaded model {}. \n" 74 | "- Only one of the smart calls can be called. \n" 75 | "- If you want to refer to one of them, \n" 76 | "- please use the model key or model type to find it." 
77 | "\n-------------------------------------------------------------------".format( 78 | inner_key, inner_size, k 79 | ) 80 | ) 81 | break 82 | 83 | inner_value = model_conf.model_name 84 | graph_session = GraphSession(model_conf) 85 | if graph_session.loaded: 86 | interface = Interface(graph_session) 87 | if inner_name == self.conf.default_model: 88 | self.interface_manager.set_default(interface) 89 | else: 90 | self.interface_manager.add(interface) 91 | self.logger.info("{} a new model: {} ({})".format( 92 | "Inited" if is_first else "Added", inner_value, inner_key 93 | )) 94 | self.name_map[inner_key] = inner_value 95 | if src_path in self.interface_manager.invalid_group: 96 | self.interface_manager.invalid_group.pop(src_path) 97 | else: 98 | self.interface_manager.report(src_path) 99 | if count < 12 and not is_first: 100 | time.sleep(5) 101 | return self._add(src_path, is_first=is_first, count=count+1) 102 | 103 | except Exception as e: 104 | self.interface_manager.report(src_path) 105 | self.logger.error(e) 106 | 107 | def delete(self, src_path): 108 | try: 109 | model_path = str(src_path) 110 | if model_path.endswith("yaml"): 111 | inner_key = PathUtils.get_file_name(model_path) 112 | graph_name = self.name_map.get(inner_key) 113 | self.interface_manager.remove_by_name(graph_name) 114 | self.name_map.pop(inner_key) 115 | self.logger.info("Unload the model: {} ({})".format(graph_name, inner_key)) 116 | except Exception as e: 117 | self.logger.error("Config File [{}] does not exist.".format(str(e).replace("'", ""))) 118 | 119 | def on_created(self, event): 120 | if event.is_directory: 121 | self.logger.info("directory created:{0}".format(event.src_path)) 122 | else: 123 | model_path = str(event.src_path) 124 | self._add(model_path) 125 | self.logger.info( 126 | "\n - Number of interfaces: {}" 127 | "\n - Current online interface: \n\t - {}" 128 | "\n - The default Interface is: {}".format( 129 | len(self.interface_manager.group), 130 | "\n\t - 
".join(["[{}]".format(v) for k, v in self.name_map.items()]), 131 | self.interface_manager.default_name 132 | )) 133 | 134 | def on_deleted(self, event): 135 | if event.is_directory: 136 | self.logger.info("directory deleted:{0}".format(event.src_path)) 137 | else: 138 | model_path = str(event.src_path) 139 | if model_path in self.interface_manager.invalid_group: 140 | self.interface_manager.invalid_group.pop(model_path) 141 | inner_key = PathUtils.get_file_name(model_path) 142 | if inner_key in self.name_map: 143 | self.delete(model_path) 144 | self.logger.info( 145 | "\n - Number of interfaces: {}" 146 | "\n - Current online interface: \n\t - {}" 147 | "\n - The default Interface is: {}".format( 148 | len(self.interface_manager.group), 149 | "\n\t - ".join(["[{}]".format(v) for k, v in self.name_map.items()]), 150 | self.interface_manager.default_name 151 | )) 152 | 153 | 154 | if __name__ == "__main__": 155 | pass 156 | # import time 157 | # from watchdog.observers import Observer 158 | # observer = Observer() 159 | # interface_manager = InterfaceManager() 160 | # event_handler = FileEventHandler("", interface_manager) 161 | # observer.schedule(event_handler, event_handler.model_conf_path, True) 162 | # observer.start() 163 | # try: 164 | # while True: 165 | # time.sleep(1) 166 | # except KeyboardInterrupt: 167 | # observer.stop() 168 | # observer.join() 169 | -------------------------------------------------------------------------------- /event_loop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import time 5 | from watchdog.observers import Observer 6 | from event_handler import FileEventHandler 7 | 8 | 9 | def event_loop(system_config, model_path, interface_manager): 10 | observer = Observer() 11 | event_handler = FileEventHandler(system_config, model_path, interface_manager) 12 | observer.schedule(event_handler, 
event_handler.model_conf_path, True) 13 | observer.start() 14 | try: 15 | while True: 16 | time.sleep(1) 17 | except KeyboardInterrupt: 18 | observer.stop() 19 | observer.join() -------------------------------------------------------------------------------- /graph_session.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import os 5 | import tensorflow as tf 6 | tf.compat.v1.disable_v2_behavior() 7 | from tensorflow.python.framework.errors_impl import NotFoundError 8 | from config import ModelConfig 9 | 10 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit' 11 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 12 | 13 | 14 | class GraphSession(object): 15 | def __init__(self, model_conf: ModelConfig): 16 | self.model_conf = model_conf 17 | self.logger = self.model_conf.logger 18 | self.size_str = self.model_conf.size_string 19 | self.model_name = self.model_conf.model_name 20 | self.graph_name = self.model_conf.model_name 21 | self.version = self.model_conf.model_version 22 | self.graph = tf.compat.v1.Graph() 23 | self.sess = tf.compat.v1.Session( 24 | graph=self.graph, 25 | config=tf.compat.v1.ConfigProto( 26 | 27 | # allow_soft_placement=True, 28 | # log_device_placement=True, 29 | gpu_options=tf.compat.v1.GPUOptions( 30 | # allocator_type='BFC', 31 | allow_growth=True, # it will cause fragmentation. 32 | # per_process_gpu_memory_fraction=self.model_conf.device_usage 33 | per_process_gpu_memory_fraction=0.1 34 | ) 35 | ) 36 | ) 37 | self.graph_def = self.graph.as_graph_def() 38 | self.loaded = self.load_model() 39 | 40 | def load_model(self): 41 | # Here is for debugging, positioning error source use. 
42 | # with self.graph.as_default(): 43 | # saver = tf.train.import_meta_graph('graph/***.meta') 44 | # saver.restore(self.sess, tf.train.latest_checkpoint('graph')) 45 | if not self.model_conf.model_exists: 46 | self.destroy() 47 | return False 48 | try: 49 | with tf.io.gfile.GFile(self.model_conf.compile_model_path, "rb") as f: 50 | graph_def_file = f.read() 51 | self.graph_def.ParseFromString(graph_def_file) 52 | with self.graph.as_default(): 53 | self.sess.run(tf.compat.v1.global_variables_initializer()) 54 | _ = tf.import_graph_def(self.graph_def, name="") 55 | 56 | self.logger.info('TensorFlow Session {} Loaded.'.format(self.model_conf.model_name)) 57 | return True 58 | except NotFoundError: 59 | self.logger.error('The system cannot find the model specified.') 60 | self.destroy() 61 | return False 62 | 63 | @property 64 | def session(self): 65 | return self.sess 66 | 67 | def destroy(self): 68 | self.sess.close() 69 | del self.sess 70 | -------------------------------------------------------------------------------- /interface.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import os 5 | import time 6 | from graph_session import GraphSession 7 | from predict import predict_func 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 10 | 11 | 12 | class Interface(object): 13 | 14 | def __init__(self, graph_session: GraphSession): 15 | self.graph_sess = graph_session 16 | self.model_conf = graph_session.model_conf 17 | self.size_str = self.model_conf.size_string 18 | self.graph_name = self.graph_sess.graph_name 19 | self.version = self.graph_sess.version 20 | self.model_category = self.model_conf.category_param 21 | if self.graph_sess.loaded: 22 | self.sess = self.graph_sess.session 23 | self.dense_decoded = self.sess.graph.get_tensor_by_name("dense_decoded:0") 24 | self.x = self.sess.graph.get_tensor_by_name('input:0') 25 | self.sess.graph.finalize() 
26 | 27 | @property 28 | def name(self): 29 | return self.graph_name 30 | 31 | @property 32 | def size(self): 33 | return self.size_str 34 | 35 | def destroy(self): 36 | self.graph_sess.destroy() 37 | 38 | def predict_batch(self, image_batch, output_split=None): 39 | predict_text = predict_func( 40 | image_batch, 41 | self.sess, 42 | self.dense_decoded, 43 | self.x, 44 | self.model_conf, 45 | output_split 46 | ) 47 | return predict_text 48 | 49 | 50 | class InterfaceManager(object): 51 | 52 | def __init__(self, interface: Interface = None): 53 | self.group = [] 54 | self.invalid_group = {} 55 | self.set_default(interface) 56 | 57 | def add(self, interface: Interface): 58 | if interface in self.group: 59 | return 60 | self.group.append(interface) 61 | 62 | def remove(self, interface: Interface): 63 | if interface in self.group: 64 | interface.destroy() 65 | self.group.remove(interface) 66 | 67 | def report(self, model): 68 | self.invalid_group[model] = {"create_time": time.asctime(time.localtime(time.time()))} 69 | 70 | def remove_by_name(self, graph_name): 71 | interface = self.get_by_name(graph_name, False) 72 | self.remove(interface) 73 | 74 | def get_by_size(self, size: str, return_default=True): 75 | 76 | match_ids = [i for i in range(len(self.group)) if self.group[i].size_str == size] 77 | if not match_ids: 78 | return self.default if return_default else None 79 | else: 80 | ver = [self.group[i].version for i in match_ids] 81 | return self.group[match_ids[ver.index(max(ver))]] 82 | 83 | def get_by_name(self, key: str, return_default=True): 84 | for interface in self.group: 85 | if interface.name == key: 86 | 87 | return interface 88 | return self.default if return_default else None 89 | 90 | @property 91 | def default(self): 92 | return self.group[0] if len(self.group) > 0 else None 93 | 94 | @property 95 | def default_name(self): 96 | _default = self.default 97 | if not _default: 98 | return 99 | return _default.graph_name 100 | 101 | @property 102 | def 
total(self): 103 | return len(self.group) 104 | 105 | @property 106 | def online_names(self): 107 | return [i.name for i in self.group] 108 | 109 | def set_default(self, interface: Interface): 110 | if not interface: 111 | return 112 | self.group.insert(0, interface) 113 | -------------------------------------------------------------------------------- /middleware/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | # from middleware.impl import color_filter 6 | from middleware.impl import corp_to_multi 7 | from middleware.impl import gif_frames 8 | # color_extract = color_filter.ColorFilter() 9 | # color_map = color_filter.color_map 10 | 11 | -------------------------------------------------------------------------------- /middleware/constructor/color_extractor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | from tensorflow.python.framework.graph_util import convert_variables_to_constants 6 | 7 | black = tf.constant([[0, 0, 0]], dtype=tf.int32) 8 | red = tf.constant([[0, 0, 255]], dtype=tf.int32) 9 | yellow = tf.constant([[0, 255, 255]], dtype=tf.int32) 10 | blue = tf.constant([[255, 0, 0]], dtype=tf.int32) 11 | green = tf.constant([[0, 255, 0]], dtype=tf.int32) 12 | white = tf.constant([[255, 255, 255]], dtype=tf.int32) 13 | 14 | 15 | def k_means(data, target_color, bg_color1, bg_color2, alpha=1.0): 16 | def get_distance(point): 17 | sum_squares = tf.cast(tf.reduce_sum(tf.abs(tf.subtract(data, point)), axis=2, keep_dims=True), tf.float32) 18 | return sum_squares 19 | 20 | alpha_value = tf.constant(alpha, dtype=tf.float32) 21 | # target_color black:0, red:1, blue:2, yellow:3, green:4, color_1:5, color_2:6 22 | black_distance = get_distance(black) 23 | red_distance = get_distance(red) 24 | if 
target_color == 1: 25 | red_distance = tf.multiply(red_distance, alpha_value) 26 | blue_distance = get_distance(blue) 27 | if target_color == 2: 28 | blue_distance = tf.multiply(blue_distance, alpha_value) 29 | yellow_distance = get_distance(yellow) 30 | if target_color == 3: 31 | yellow_distance = tf.multiply(yellow_distance, alpha_value) 32 | white_distance = get_distance(yellow) 33 | if target_color == 7: 34 | white_distance = tf.multiply(white_distance, alpha_value) 35 | 36 | green_distance = get_distance(green) 37 | c_1_distance = get_distance(bg_color1) 38 | c_2_distance = get_distance(bg_color2) 39 | 40 | distances = tf.concat([ 41 | black_distance, 42 | red_distance, 43 | blue_distance, 44 | yellow_distance, 45 | green_distance, 46 | c_1_distance, 47 | c_2_distance, 48 | white_distance 49 | ], axis=-1) 50 | 51 | clusters = tf.cast(tf.argmin(distances, axis=-1), tf.int32) 52 | 53 | mask = tf.equal(clusters, target_color) 54 | mask = tf.cast(mask, tf.int32) 55 | 56 | return mask * 255 57 | 58 | 59 | def filter_img(img, target_color, alpha=0.9): 60 | # background color1 61 | color_1 = img[0, 0, :] 62 | color_1 = tf.reshape(color_1, [1, 3]) 63 | color_1 = tf.cast(color_1, dtype=tf.int32) 64 | 65 | # background color2 66 | color_2 = img[6, 6, :] 67 | color_2 = tf.reshape(color_2, [1, 3]) 68 | color_2 = tf.cast(color_2, dtype=tf.int32) 69 | 70 | filtered_img = k_means(img_holder, target_color, color_1, color_2, alpha) 71 | filtered_img = tf.expand_dims(filtered_img, axis=0) 72 | filtered_img = tf.expand_dims(filtered_img, axis=-1) 73 | filtered_img = tf.squeeze(filtered_img, name="filtered") 74 | return filtered_img 75 | 76 | 77 | def compile_graph(): 78 | 79 | with sess.graph.as_default(): 80 | input_graph_def = sess.graph.as_graph_def() 81 | 82 | output_graph_def = convert_variables_to_constants( 83 | sess, 84 | input_graph_def, 85 | output_node_names=['filtered'] 86 | ) 87 | 88 | last_compile_model_path = "color_extractor.pb" 89 | with 
tf.gfile.FastGFile(last_compile_model_path, mode='wb') as gf: 90 | # gf.write(output_graph_def.SerializeToString()) 91 | print(output_graph_def.SerializeToString()) 92 | 93 | 94 | if __name__ == "__main__": 95 | 96 | sess = tf.Session() 97 | img_holder = tf.placeholder(dtype=tf.int32, name="img_holder") 98 | color = tf.placeholder(dtype=tf.int32, name="target_color") 99 | filtered = filter_img(img_holder, color, alpha=0.8) 100 | 101 | compile_graph() 102 | 103 | -------------------------------------------------------------------------------- /middleware/impl/color_extractor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import io 5 | import cv2 6 | import time 7 | import PIL.Image as PilImage 8 | import numpy as np 9 | import tensorflow as tf 10 | from enum import Enum, unique 11 | from distutils.version import StrictVersion 12 | 13 | 14 | @unique 15 | class TargetColor(Enum): 16 | Black = 0 17 | Red = 1 18 | Blue = 2 19 | Yellow = 3 20 | Green = 4 21 | White = 7 22 | 23 | 24 | color_map = { 25 | 'black': TargetColor.Black, 26 | 'red': TargetColor.Red, 27 | 'blue': TargetColor.Blue, 28 | 'yellow': TargetColor.Yellow, 29 | 'green': TargetColor.Green, 30 | 'white': TargetColor.White 31 | } 32 | 33 | 34 | class ColorExtract: 35 | 36 | def __init__(self): 37 | self.model_raw_v1_14 = 
b'\nB\n\x05Const\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_1\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\x00\x00\x00\x00\xff\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_2\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\xff\x00\x00\x00\xff\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_3\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_4\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\n5\n\nimg_holder\x12\x0bPlaceholder*\r\n\x05shape\x12\x04:\x02\x18\x01*\x0b\n\x05dtype\x12\x020\x03\n7\n\x0ctarget_color\x12\x0bPlaceholder*\r\n\x05shape\x12\x04:\x02\x18\x01*\x0b\n\x05dtype\x12\x020\x03\nL\n\x13strided_slice/stack\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nN\n\x15strided_slice/stack_1\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x01\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nN\n\x15strided_slice/stack_2\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\n\xe6\x01\n\rstrided_slice\x12\x0cStridedSlice\x1a\nimg_holder\x1a\x13strided_slice/stack\x1a\x15strided_slice/stack_1\x1a\x15strided_slice/stack_2*\x07\n\x01T\x12\x020\x03*\x0b\n\x05Index\x12\x020\x03*\x16\n\x10shrink_axis_mask\x12\x02\x18\x03*\x10\n\nbegin_mask\x12\x02\x18\x04*\x13\n\rellipsis_mask\x12\x02\x18\x00*\x13\
n\rnew_axis_mask\x12\x02\x18\x00*\x0e\n\x08end_mask\x12\x02\x18\x04\nB\n\rReshape/shape\x12\x05Const*\x1d\n\x05value\x12\x14B\x12\x08\x03\x12\x04\x12\x02\x08\x02"\x08\x01\x00\x00\x00\x03\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nG\n\x07Reshape\x12\x07Reshape\x1a\rstrided_slice\x1a\rReshape/shape*\x07\n\x01T\x12\x020\x03*\x0c\n\x06Tshape\x12\x020\x03\nN\n\x15strided_slice_1/stack\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x06\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nP\n\x17strided_slice_1/stack_1\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x07\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nP\n\x17strided_slice_1/stack_2\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\n\xee\x01\n\x0fstrided_slice_1\x12\x0cStridedSlice\x1a\nimg_holder\x1a\x15strided_slice_1/stack\x1a\x17strided_slice_1/stack_1\x1a\x17strided_slice_1/stack_2*\x07\n\x01T\x12\x020\x03*\x0b\n\x05Index\x12\x020\x03*\x16\n\x10shrink_axis_mask\x12\x02\x18\x03*\x10\n\nbegin_mask\x12\x02\x18\x04*\x13\n\rellipsis_mask\x12\x02\x18\x00*\x13\n\rnew_axis_mask\x12\x02\x18\x00*\x0e\n\x08end_mask\x12\x02\x18\x04\nD\n\x0fReshape_1/shape\x12\x05Const*\x1d\n\x05value\x12\x14B\x12\x08\x03\x12\x04\x12\x02\x08\x02"\x08\x01\x00\x00\x00\x03\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nM\n\tReshape_1\x12\x07Reshape\x1a\x0fstrided_slice_1\x1a\x0fReshape_1/shape*\x07\n\x01T\x12\x020\x03*\x0c\n\x06Tshape\x12\x020\x03\n&\n\x03Sub\x12\x03Sub\x1a\nimg_holder\x1a\x05Const*\x07\n\x01T\x12\x020\x03\n\x18\n\x03Abs\x12\x03Abs\x1a\x03Sub*\x07\n\x01T\x12\x020\x03\n?\n\x15Sum/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nL\n\x03Sum\x12\x03Sum\x1a\x03Abs\x1a\x15Sum/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12
\x020\x03\n9\n\x04Cast\x12\x04Cast\x1a\x03Sum*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_1\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_1*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_1\x12\x03Abs\x1a\x05Sub_1*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_1/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_1\x12\x03Sum\x1a\x05Abs_1\x1a\x17Sum_1/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_1\x12\x04Cast\x1a\x05Sum_1*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_2\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_3*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_2\x12\x03Abs\x1a\x05Sub_2*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_2/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_2\x12\x03Sum\x1a\x05Abs_2\x1a\x17Sum_2/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_2\x12\x04Cast\x1a\x05Sum_2*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_3\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_2*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_3\x12\x03Abs\x1a\x05Sub_3*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_3/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_3\x12\x03Sum\x1a\x05Abs_3\x1a\x17Sum_3/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_3\x12\x04Cast\x1a\x05Sum_3*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_4\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_2*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_4\x12\x03Abs\x1a\x05Sub_4*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_4/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03
\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_4\x12\x03Sum\x1a\x05Abs_4\x1a\x17Sum_4/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_4\x12\x04Cast\x1a\x05Sum_4*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_5\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_4*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_5\x12\x03Abs\x1a\x05Sub_5*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_5/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_5\x12\x03Sum\x1a\x05Abs_5\x1a\x17Sum_5/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_5\x12\x04Cast\x1a\x05Sum_5*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_6\x12\x03Sub\x1a\nimg_holder\x1a\x07Reshape*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_6\x12\x03Abs\x1a\x05Sub_6*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_6/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_6\x12\x03Sum\x1a\x05Abs_6\x1a\x17Sum_6/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_6\x12\x04Cast\x1a\x05Sum_6*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n,\n\x05Sub_7\x12\x03Sub\x1a\nimg_holder\x1a\tReshape_1*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_7\x12\x03Abs\x1a\x05Sub_7*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_7/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_7\x12\x03Sum\x1a\x05Abs_7\x1a\x17Sum_7/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_7\x12\x04Cast\x1a\x05Sum_7*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n>\n\x0bconcat/axis\x12\x05Const*\
x1b\n\x05value\x12\x12B\x10\x08\x03\x12\x00:\n\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01*\x0b\n\x05dtype\x12\x020\x03\n{\n\x06concat\x12\x08ConcatV2\x1a\x04Cast\x1a\x06Cast_1\x1a\x06Cast_2\x1a\x06Cast_3\x1a\x06Cast_5\x1a\x06Cast_6\x1a\x06Cast_7\x1a\x06Cast_4\x1a\x0bconcat/axis*\n\n\x04Tidx\x12\x020\x03*\x07\n\x01T\x12\x020\x01*\x07\n\x01N\x12\x02\x18\x08\nC\n\x10ArgMin/dimension\x12\x05Const*\x1b\n\x05value\x12\x12B\x10\x08\x03\x12\x00:\n\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01*\x0b\n\x05dtype\x12\x020\x03\nR\n\x06ArgMin\x12\x06ArgMin\x1a\x06concat\x1a\x10ArgMin/dimension*\n\n\x04Tidx\x12\x020\x03*\x07\n\x01T\x12\x020\x01*\x11\n\x0boutput_type\x12\x020\t\n>\n\x06Cast_8\x12\x04Cast\x1a\x06ArgMin*\n\n\x04SrcT\x12\x020\t*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x03\n-\n\x05Equal\x12\x05Equal\x1a\x06Cast_8\x1a\x0ctarget_color*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_9\x12\x04Cast\x1a\x05Equal*\n\n\x04SrcT\x12\x020\n*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x03\n0\n\x05mul/y\x12\x05Const*\x13\n\x05value\x12\nB\x08\x08\x03\x12\x00:\x02\xff\x01*\x0b\n\x05dtype\x12\x020\x03\n"\n\x03mul\x12\x03Mul\x1a\x06Cast_9\x1a\x05mul/y*\x07\n\x01T\x12\x020\x03\n8\n\x0eExpandDims/dim\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x00*\x0b\n\x05dtype\x12\x020\x03\nB\n\nExpandDims\x12\nExpandDims\x1a\x03mul\x1a\x0eExpandDims/dim*\n\n\x04Tdim\x12\x020\x03*\x07\n\x01T\x12\x020\x03\nC\n\x10ExpandDims_1/dim\x12\x05Const*\x1b\n\x05value\x12\x12B\x10\x08\x03\x12\x00:\n\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01*\x0b\n\x05dtype\x12\x020\x03\nM\n\x0cExpandDims_1\x12\nExpandDims\x1a\nExpandDims\x1a\x10ExpandDims_1/dim*\n\n\x04Tdim\x12\x020\x03*\x07\n\x01T\x12\x020\x03\n>\n\x08filtered\x12\x07Squeeze\x1a\x0cExpandDims_1*\x12\n\x0csqueeze_dims\x12\x02\n\x00*\x07\n\x01T\x12\x020\x03\x12\x00' 38 | self.color_graph = tf.Graph() 39 | self.color_sess = tf.compat.v1.Session( 40 | graph=self.color_graph, 41 | config=tf.compat.v1.ConfigProto( 42 | allow_soft_placement=True, 
43 | # log_device_placement=True, 44 | gpu_options=tf.compat.v1.GPUOptions( 45 | # allow_growth=True, # it will cause fragmentation. 46 | per_process_gpu_memory_fraction=0.1 47 | )) 48 | ) 49 | self.color_graph_def = self.color_graph.as_graph_def() 50 | self.load_model() 51 | self.img_holder = self.color_sess.graph.get_tensor_by_name("img_holder:0") 52 | self.target_color = self.color_sess.graph.get_tensor_by_name("target_color:0") 53 | self.filtered = self.color_sess.graph.get_tensor_by_name("filtered:0") 54 | self.color_graph.finalize() 55 | 56 | def load_model(self): 57 | raw = self.model_raw_v1_14 58 | self.color_graph_def.ParseFromString(raw) 59 | with self.color_graph.as_default(): 60 | self.color_sess.run(tf.compat.v1.global_variables_initializer()) 61 | _ = tf.import_graph_def(self.color_graph_def, name="") 62 | 63 | def separate_color(self, image_bytes, color: TargetColor): 64 | # image = np.asarray(bytearray(image_bytes), dtype="uint8") 65 | # image = cv2.imdecode(image, -1) 66 | image = np.array(PilImage.open(io.BytesIO(image_bytes))) 67 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 68 | mask = self.color_sess.run(self.filtered, {self.img_holder: image, self.target_color: color.value}) 69 | mask = bytearray(cv2.imencode('.png', mask)[1]) 70 | return mask 71 | 72 | 73 | if __name__ == '__main__': 74 | 75 | pass 76 | # import os 77 | # source_dir = r'E:\***' 78 | # target_dir = r'E:\***' 79 | # if not os.path.exists(target_dir): 80 | # os.makedirs(target_dir) 81 | # 82 | # source_names = os.listdir(source_dir) 83 | # color_extract = ColorExtract() 84 | # st = time.time() 85 | # for i, name in enumerate(source_names): 86 | # img_path = os.path.join(source_dir, name) 87 | # if i % 100 == 0: 88 | # print(i) 89 | # with open(img_path, "rb") as f: 90 | # b = f.read() 91 | # result = color_extract.separate_color(b, color_map['red']) 92 | # target_path = os.path.join(target_dir, name) 93 | # with open(target_path, "wb") as f: 94 | # f.write(result) 95 | # 96 | 
# print('completed {}'.format(time.time() - st)) 97 | -------------------------------------------------------------------------------- /middleware/impl/color_filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import io 5 | import cv2 6 | import time 7 | import PIL.Image as PilImage 8 | import numpy as np 9 | import onnxruntime as ort 10 | from enum import Enum, unique 11 | from distutils.version import StrictVersion 12 | from middleware.resource import color_model 13 | 14 | 15 | @unique 16 | class TargetColor(Enum): 17 | Red = 1 18 | Blue = 2 19 | Yellow = 3 20 | Black = 4 21 | 22 | 23 | color_map = { 24 | 'black': TargetColor.Black, 25 | 'red': TargetColor.Red, 26 | 'blue': TargetColor.Blue, 27 | 'yellow': TargetColor.Yellow, 28 | } 29 | 30 | 31 | class ColorFilter: 32 | 33 | def __init__(self): 34 | self.model_onnx = color_model 35 | self.sess = ort.InferenceSession(self.model_onnx) 36 | 37 | def predict_color(self, image_batch, color: TargetColor): 38 | dense_decoded_code = self.sess.run(["dense_decoded:0"], input_feed={ 39 | "input:0": image_batch, 40 | }) 41 | result = dense_decoded_code[0][0].tolist() 42 | return [i for i, c in enumerate(result) if c == color.value] 43 | 44 | 45 | if __name__ == '__main__': 46 | 47 | pass 48 | # import os 49 | # source_dir = r'E:\***' 50 | # target_dir = r'E:\***' 51 | # if not os.path.exists(target_dir): 52 | # os.makedirs(target_dir) 53 | # 54 | # source_names = os.listdir(source_dir) 55 | # color_extract = ColorExtract() 56 | # st = time.time() 57 | # for i, name in enumerate(source_names): 58 | # img_path = os.path.join(source_dir, name) 59 | # if i % 100 == 0: 60 | # print(i) 61 | # with open(img_path, "rb") as f: 62 | # b = f.read() 63 | # result = color_extract.separate_color(b, color_map['red']) 64 | # target_path = os.path.join(target_dir, name) 65 | # with open(target_path, "wb") as f: 66 | # 
f.write(result) 67 | # 68 | # print('completed {}'.format(time.time() - st)) 69 | -------------------------------------------------------------------------------- /middleware/impl/corp_to_multi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import io 5 | import cv2 6 | import PIL.Image as Pil_Image 7 | import numpy as np 8 | 9 | 10 | def coord_calc(param, is_range=True, is_integer=True): 11 | 12 | result_group = [] 13 | start_h = param['start_pos'][1] 14 | end_h = start_h + param['corp_size'][1] 15 | for row in range(param['corp_num'][1]): 16 | start_w = param['start_pos'][0] 17 | end_w = start_w + param['corp_size'][0] 18 | for col in range(param['corp_num'][0]): 19 | pos_range = [[start_w, end_w], [start_h, end_h]] 20 | t = lambda x: int(x) if is_integer else x 21 | pos_center = [t((start_w + end_w)/2), t((start_h + end_h)/2)] 22 | result_group.append(pos_range if is_range else pos_center) 23 | start_w = end_w + param['interval_size'][0] 24 | end_w = start_w + param['corp_size'][0] 25 | start_h = end_h + param['interval_size'][1] 26 | end_h = start_h + param['corp_size'][1] 27 | return result_group 28 | 29 | 30 | def parse_multi_img(image_bytes, param_group): 31 | img_bytes = image_bytes[0] 32 | image_arr = np.array(Pil_Image.open(io.BytesIO(img_bytes)).convert('RGB')) 33 | if len(image_arr.shape) == 3: 34 | image_arr = cv2.cvtColor(image_arr, cv2.COLOR_BGR2RGB) 35 | # image_arr = np.fromstring(img_bytes, np.uint8) 36 | # print(image_arr.shape) 37 | image_arr = image_arr.swapaxes(0, 1) 38 | group = [] 39 | for p in param_group: 40 | pos_ranges = coord_calc(p, True, True) 41 | for pos_range in pos_ranges: 42 | corp_arr = image_arr[pos_range[0][0]: pos_range[0][1], pos_range[1][0]: pos_range[1][1]] 43 | corp_arr = cv2.imencode('.png', np.swapaxes(corp_arr, 0, 1))[1] 44 | corp_bytes = bytes(bytearray(corp_arr)) 45 | group.append(corp_bytes) 46 | 
return group 47 | 48 | 49 | def get_coordinate(label: str, param_group, title_index=None): 50 | if title_index is None: 51 | title_index = [0] 52 | param = param_group[-1] 53 | coord_map = coord_calc(param, is_range=False, is_integer=True) 54 | index_group = get_pair_index(label=label, title_index=title_index) 55 | return [coord_map[i] for i in index_group] 56 | 57 | 58 | def get_pair_index(label: str, title_index=None): 59 | if title_index is None: 60 | title_index = [0] 61 | max_index = max(title_index) 62 | 63 | label_group = label.split(',') 64 | titles = [label_group[i] for i in title_index] 65 | 66 | index_group = [] 67 | for title in titles: 68 | for i, item in enumerate(label_group[max_index+1:]): 69 | if item == title: 70 | index_group.append(i) 71 | index_group = [i for i in index_group] 72 | return index_group 73 | 74 | 75 | if __name__ == '__main__': 76 | import os 77 | import hashlib 78 | root_dir = r"H:\Task\Trains\d111_Trains" 79 | target_dir = r"F:\1q2" 80 | if not os.path.exists(target_dir): 81 | os.makedirs(target_dir) 82 | _param_group = [ 83 | { 84 | "start_pos": [20, 50], 85 | "interval_size": [20, 20], 86 | "corp_num": [4, 2], 87 | "corp_size": [60, 60] 88 | } 89 | ] 90 | for name in os.listdir(root_dir): 91 | path = os.path.join(root_dir, name) 92 | with open(path, "rb") as f: 93 | file_bytes = [f.read()] 94 | group = parse_multi_img(file_bytes, _param_group) 95 | for b in group: 96 | tag = hashlib.md5(b).hexdigest() 97 | p = os.path.join(target_dir, "{}.png".format(tag)) 98 | with open(p, "wb") as f: 99 | f.write(b) -------------------------------------------------------------------------------- /middleware/impl/gif_frames.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import io 5 | import cv2 6 | import numpy as np 7 | from itertools import groupby 8 | from PIL import ImageSequence, Image 9 | 10 | 11 | def 
split_frames(image_obj, need_frame=None): 12 | image_seq = ImageSequence.all_frames(image_obj) 13 | image_arr_last = [np.asarray(image_seq[-1])] if -1 in need_frame and len(need_frame) > 1 else [] 14 | image_arr = [np.asarray(item) for i, item in enumerate(image_seq) if (i in need_frame or need_frame == [-1])] 15 | image_arr += image_arr_last 16 | return image_arr 17 | 18 | 19 | def concat_arr(img_arr): 20 | if len(img_arr) < 2: 21 | return img_arr[0] 22 | all_slice = img_arr[0] 23 | for im_slice in img_arr[1:]: 24 | all_slice = np.concatenate((all_slice, im_slice), axis=1) 25 | return all_slice 26 | 27 | 28 | def numpy_to_bytes(numpy_arr): 29 | cv_img = cv2.imencode('.png', numpy_arr)[1] 30 | img_bytes = bytes(bytearray(cv_img)) 31 | return img_bytes 32 | 33 | 34 | def concat_frames(image_obj, need_frame=None): 35 | if not need_frame: 36 | need_frame = [0] 37 | img_arr = split_frames(image_obj, need_frame) 38 | img_arr = concat_arr(img_arr) 39 | return img_arr 40 | 41 | 42 | def blend_arr(img_arr): 43 | if len(img_arr) < 2: 44 | return img_arr[0] 45 | all_slice = img_arr[0] 46 | for im_slice in img_arr[1:]: 47 | all_slice = cv2.addWeighted(all_slice, 0.5, im_slice, 0.5, 0) 48 | # all_slice = cv2.equalizeHist(all_slice) 49 | return all_slice 50 | 51 | 52 | def blend_frame(image_obj, need_frame=None): 53 | if not need_frame: 54 | need_frame = [-1] 55 | img_arr = split_frames(image_obj, need_frame) 56 | img_arr = blend_arr(img_arr) 57 | if len(img_arr.shape) > 2 and img_arr.shape[2] == 3: 58 | img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2GRAY) 59 | img_arr = cv2.equalizeHist(img_arr) 60 | return img_arr 61 | 62 | 63 | def all_frames(image_obj): 64 | if isinstance(image_obj, list): 65 | image_obj = image_obj[0] 66 | stream = io.BytesIO(image_obj) 67 | pil_image = Image.open(stream) 68 | image_seq = ImageSequence.all_frames(pil_image) 69 | array_seq = [np.asarray(im.convert("RGB")) for im in image_seq] 70 | # [1::2] 71 | bytes_arr = [cv2.imencode('.png', img_arr)[1] 
for img_arr in array_seq] 72 | split_flag = b'\x99\x99\x99\x00\xff\xff999999.........99999\xff\x00\x99\x99\x99' 73 | return split_flag.join(bytes_arr).split(split_flag) 74 | 75 | 76 | def get_continuity_max(src: list): 77 | if not src: 78 | return "" 79 | elem_cont_len = lambda x: max(len(list(g)) for k, g in groupby(src) if k == x) 80 | target_list = [elem_cont_len(i) for i in src] 81 | target_index = target_list.index(max(target_list)) 82 | return src[target_index] 83 | 84 | 85 | if __name__ == "__main__": 86 | pass 87 | -------------------------------------------------------------------------------- /middleware/impl/rgb_filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import numpy as np 5 | import cv2 6 | 7 | 8 | def rgb_filter(image_obj, need_rgb): 9 | low_rgb = np.array([i-15 for i in need_rgb]) 10 | high_rgb = np.array([i+15 for i in need_rgb]) 11 | mask = cv2.inRange(image_obj, lowerb=low_rgb, upperb=high_rgb) 12 | mask = cv2.bitwise_not(mask) 13 | mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) 14 | # img_bytes = cv2.imencode('.png', mask)[1] 15 | return mask 16 | 17 | 18 | if __name__ == '__main__': 19 | pass -------------------------------------------------------------------------------- /middleware/resource/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | from middleware.resource.color_filter import model_onnx 5 | 6 | color_model = model_onnx 7 | if __name__ == '__main__': 8 | pass -------------------------------------------------------------------------------- /package.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import os 5 | import cv2 6 | import time 7 | import stat 8 | import socket 9 | 
import paramiko 10 | import platform 11 | import distutils 12 | import tensorflow as tf 13 | tf.compat.v1.disable_v2_behavior() 14 | from enum import Enum, unique 15 | from utils import SystemUtils 16 | from config import resource_path 17 | 18 | from PyInstaller.__main__ import run, logger 19 | """ Used to package as a single executable """ 20 | 21 | if platform.system() == 'Linux': 22 | if distutils.distutils_path.endswith('__init__.py'): 23 | distutils.distutils_path = os.path.dirname(distutils.distutils_path) 24 | 25 | with open("./resource/VERSION", "w", encoding="utf8") as f: 26 | today = time.strftime("%Y%m%d", time.localtime(time.time())) 27 | f.write(today) 28 | 29 | 30 | @unique 31 | class Version(Enum): 32 | CPU = 'CPU' 33 | GPU = 'GPU' 34 | 35 | 36 | if __name__ == '__main__': 37 | 38 | ver = Version.CPU 39 | 40 | upload = False 41 | server_ip = "" 42 | username = "" 43 | password = "" 44 | model_dir = "model" 45 | graph_dir = "graph" 46 | 47 | if ver == Version.GPU: 48 | opts = ['tornado_server_gpu.spec', '--distpath=dist'] 49 | else: 50 | opts = ['tornado_server.spec', '--distpath=dist'] 51 | run(opts) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | from config import ModelConfig 5 | 6 | 7 | def decode_maps(categories): 8 | return {index: category for index, category in enumerate(categories, 0)} 9 | 10 | 11 | def predict_func(image_batch, _sess, dense_decoded, op_input, model: ModelConfig, output_split=None): 12 | 13 | output_split = model.output_split if output_split is None else output_split 14 | 15 | category_split = model.category_split if model.category_split else "" 16 | 17 | dense_decoded_code = _sess.run(dense_decoded, feed_dict={ 18 | op_input: image_batch, 19 | }) 20 | decoded_expression = [] 21 | for item in dense_decoded_code: 22 | 
expression = [] 23 | 24 | for i in item: 25 | if i == -1 or i == model.category_num: 26 | expression.append("") 27 | else: 28 | expression.append(decode_maps(model.category)[i]) 29 | decoded_expression.append(category_split.join(expression)) 30 | return output_split.join(decoded_expression) if len(decoded_expression) > 1 else decoded_expression[0] 31 | -------------------------------------------------------------------------------- /pretreatment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import cv2 5 | 6 | 7 | class Pretreatment(object): 8 | 9 | def __init__(self, origin): 10 | self.origin = origin 11 | 12 | def get(self): 13 | return self.origin 14 | 15 | def binarization(self, value, modify=False): 16 | ret, _binarization = cv2.threshold(self.origin, value, 255, cv2.THRESH_BINARY) 17 | if modify: 18 | self.origin = _binarization 19 | return _binarization 20 | 21 | 22 | def preprocessing(image, binaryzation=-1): 23 | pretreatment = Pretreatment(image) 24 | if binaryzation > 0: 25 | pretreatment.binarization(binaryzation, True) 26 | return pretreatment.get() 27 | 28 | 29 | def preprocessing_by_func(exec_map, key, src_arr): 30 | if not exec_map: 31 | return src_arr 32 | target_arr = cv2.cvtColor(src_arr, cv2.COLOR_RGB2BGR) 33 | for sentence in exec_map.get(key): 34 | if sentence.startswith("@@"): 35 | target_arr = eval(sentence[2:]) 36 | elif sentence.startswith("$$"): 37 | exec(sentence[2:]) 38 | return cv2.cvtColor(target_arr, cv2.COLOR_BGR2RGB) 39 | 40 | 41 | if __name__ == '__main__': 42 | pass 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | gevent 3 | Flask-Caching 4 | gevent-websocket 5 | tf-nightly 6 | pillow 7 | opencv-python-headless 8 | numpy 9 | grpcio 10 | grpcio_tools 11 | requests 12 
| pyyaml 13 | tornado 14 | watchdog 15 | pyinstaller 16 | sanic 17 | paramiko 18 | APScheduler 19 | requests -------------------------------------------------------------------------------- /resource/VERSION: -------------------------------------------------------------------------------- 1 | 20211203 -------------------------------------------------------------------------------- /resource/favorite.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_platform/f7d719bd1239a987996e266bd7fe35c96003b378/resource/favorite.ico -------------------------------------------------------------------------------- /resource/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_platform/f7d719bd1239a987996e266bd7fe35c96003b378/resource/icon.ico -------------------------------------------------------------------------------- /sdk/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz -------------------------------------------------------------------------------- /sdk/onnx/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz -------------------------------------------------------------------------------- /sdk/onnx/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pillow 3 | opencv-python==3.4.5.20 4 | pyyaml>=3.13 5 | onnxruntime -------------------------------------------------------------------------------- /sdk/onnx/sdk.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import io 5 | import os 6 | import 
pickle 7 | import cv2 8 | import time 9 | import yaml 10 | import binascii 11 | import numpy as np 12 | import PIL.Image as PIL_Image 13 | from enum import Enum, unique 14 | import onnxruntime as ort 15 | 16 | SPACE_TOKEN = [''] 17 | NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 18 | ALPHA_UPPER = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 19 | 'V', 'W', 'X', 'Y', 'Z'] 20 | ALPHA_LOWER = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 21 | 'v', 'w', 'x', 'y', 'z'] 22 | ARITHMETIC = ['(', ')', '+', '-', '×', '÷', '='] 23 | CHINESE_3500 = [ 24 | '一', '乙', '二', '十', '丁', '厂', '七', '卜', '人', '入', '八', '九', '几', '儿', '了', '力', '乃', '刀', '又', '三', 25 | '于', '干', '亏', '士', '工', '土', '才', '寸', '下', '大', '丈', '与', '万', '上', '小', '口', '巾', '山', '千', '乞', 26 | '川', '亿', '个', '勺', '久', '凡', '及', '夕', '丸', '么', '广', '亡', '门', '义', '之', '尸', '弓', '己', '已', '子', 27 | '卫', '也', '女', '飞', '刃', '习', '叉', '马', '乡', '丰', '王', '井', '开', '夫', '天', '无', '元', '专', '云', '扎', 28 | '艺', '木', '五', '支', '厅', '不', '太', '犬', '区', '历', '尤', '友', '匹', '车', '巨', '牙', '屯', '比', '互', '切', 29 | '瓦', '止', '少', '日', '中', '冈', '贝', '内', '水', '见', '午', '牛', '手', '毛', '气', '升', '长', '仁', '什', '片', 30 | '仆', '化', '仇', '币', '仍', '仅', '斤', '爪', '反', '介', '父', '从', '今', '凶', '分', '乏', '公', '仓', '月', '氏', 31 | '勿', '欠', '风', '丹', '匀', '乌', '凤', '勾', '文', '六', '方', '火', '为', '斗', '忆', '订', '计', '户', '认', '心', 32 | '尺', '引', '丑', '巴', '孔', '队', '办', '以', '允', '予', '劝', '双', '书', '幻', '玉', '刊', '示', '末', '未', '击', 33 | '打', '巧', '正', '扑', '扒', '功', '扔', '去', '甘', '世', '古', '节', '本', '术', '可', '丙', '左', '厉', '右', '石', 34 | '布', '龙', '平', '灭', '轧', '东', '卡', '北', '占', '业', '旧', '帅', '归', '且', '旦', '目', '叶', '甲', '申', '叮', 35 | '电', '号', '田', '由', '史', '只', '央', '兄', '叼', '叫', '另', '叨', '叹', '四', '生', '失', '禾', '丘', '付', '仗', 36 | '代', '仙', '们', '仪', '白', '仔', '他', '斥', '瓜', '乎', '丛', 
'令', '用', '甩', '印', '乐', '句', '匆', '册', '犯', 37 | '外', '处', '冬', '鸟', '务', '包', '饥', '主', '市', '立', '闪', '兰', '半', '汁', '汇', '头', '汉', '宁', '穴', '它', 38 | '讨', '写', '让', '礼', '训', '必', '议', '讯', '记', '永', '司', '尼', '民', '出', '辽', '奶', '奴', '加', '召', '皮', 39 | '边', '发', '孕', '圣', '对', '台', '矛', '纠', '母', '幼', '丝', '式', '刑', '动', '扛', '寺', '吉', '扣', '考', '托', 40 | '老', '执', '巩', '圾', '扩', '扫', '地', '扬', '场', '耳', '共', '芒', '亚', '芝', '朽', '朴', '机', '权', '过', '臣', 41 | '再', '协', '西', '压', '厌', '在', '有', '百', '存', '而', '页', '匠', '夸', '夺', '灰', '达', '列', '死', '成', '夹', 42 | '轨', '邪', '划', '迈', '毕', '至', '此', '贞', '师', '尘', '尖', '劣', '光', '当', '早', '吐', '吓', '虫', '曲', '团', 43 | '同', '吊', '吃', '因', '吸', '吗', '屿', '帆', '岁', '回', '岂', '刚', '则', '肉', '网', '年', '朱', '先', '丢', '舌', 44 | '竹', '迁', '乔', '伟', '传', '乒', '乓', '休', '伍', '伏', '优', '伐', '延', '件', '任', '伤', '价', '份', '华', '仰', 45 | '仿', '伙', '伪', '自', '血', '向', '似', '后', '行', '舟', '全', '会', '杀', '合', '兆', '企', '众', '爷', '伞', '创', 46 | '肌', '朵', '杂', '危', '旬', '旨', '负', '各', '名', '多', '争', '色', '壮', '冲', '冰', '庄', '庆', '亦', '刘', '齐', 47 | '交', '次', '衣', '产', '决', '充', '妄', '闭', '问', '闯', '羊', '并', '关', '米', '灯', '州', '汗', '污', '江', '池', 48 | '汤', '忙', '兴', '宇', '守', '宅', '字', '安', '讲', '军', '许', '论', '农', '讽', '设', '访', '寻', '那', '迅', '尽', 49 | '导', '异', '孙', '阵', '阳', '收', '阶', '阴', '防', '奸', '如', '妇', '好', '她', '妈', '戏', '羽', '观', '欢', '买', 50 | '红', '纤', '级', '约', '纪', '驰', '巡', '寿', '弄', '麦', '形', '进', '戒', '吞', '远', '违', '运', '扶', '抚', '坛', 51 | '技', '坏', '扰', '拒', '找', '批', '扯', '址', '走', '抄', '坝', '贡', '攻', '赤', '折', '抓', '扮', '抢', '孝', '均', 52 | '抛', '投', '坟', '抗', '坑', '坊', '抖', '护', '壳', '志', '扭', '块', '声', '把', '报', '却', '劫', '芽', '花', '芹', 53 | '芬', '苍', '芳', '严', '芦', '劳', '克', '苏', '杆', '杠', '杜', '材', '村', '杏', '极', '李', '杨', '求', '更', '束', 54 | '豆', '两', '丽', '医', '辰', '励', '否', '还', '歼', '来', '连', '步', '坚', '旱', '盯', '呈', '时', '吴', '助', '县', 55 | '里', '呆', '园', '旷', '围', '呀', '吨', '足', '邮', '男', '困', '吵', 
'串', '员', '听', '吩', '吹', '呜', '吧', '吼', 56 | '别', '岗', '帐', '财', '针', '钉', '告', '我', '乱', '利', '秃', '秀', '私', '每', '兵', '估', '体', '何', '但', '伸', 57 | '作', '伯', '伶', '佣', '低', '你', '住', '位', '伴', '身', '皂', '佛', '近', '彻', '役', '返', '余', '希', '坐', '谷', 58 | '妥', '含', '邻', '岔', '肝', '肚', '肠', '龟', '免', '狂', '犹', '角', '删', '条', '卵', '岛', '迎', '饭', '饮', '系', 59 | '言', '冻', '状', '亩', '况', '床', '库', '疗', '应', '冷', '这', '序', '辛', '弃', '冶', '忘', '闲', '间', '闷', '判', 60 | '灶', '灿', '弟', '汪', '沙', '汽', '沃', '泛', '沟', '没', '沈', '沉', '怀', '忧', '快', '完', '宋', '宏', '牢', '究', 61 | '穷', '灾', '良', '证', '启', '评', '补', '初', '社', '识', '诉', '诊', '词', '译', '君', '灵', '即', '层', '尿', '尾', 62 | '迟', '局', '改', '张', '忌', '际', '陆', '阿', '陈', '阻', '附', '妙', '妖', '妨', '努', '忍', '劲', '鸡', '驱', '纯', 63 | '纱', '纳', '纲', '驳', '纵', '纷', '纸', '纹', '纺', '驴', '纽', '奉', '玩', '环', '武', '青', '责', '现', '表', '规', 64 | '抹', '拢', '拔', '拣', '担', '坦', '押', '抽', '拐', '拖', '拍', '者', '顶', '拆', '拥', '抵', '拘', '势', '抱', '垃', 65 | '拉', '拦', '拌', '幸', '招', '坡', '披', '拨', '择', '抬', '其', '取', '苦', '若', '茂', '苹', '苗', '英', '范', '直', 66 | '茄', '茎', '茅', '林', '枝', '杯', '柜', '析', '板', '松', '枪', '构', '杰', '述', '枕', '丧', '或', '画', '卧', '事', 67 | '刺', '枣', '雨', '卖', '矿', '码', '厕', '奔', '奇', '奋', '态', '欧', '垄', '妻', '轰', '顷', '转', '斩', '轮', '软', 68 | '到', '非', '叔', '肯', '齿', '些', '虎', '虏', '肾', '贤', '尚', '旺', '具', '果', '味', '昆', '国', '昌', '畅', '明', 69 | '易', '昂', '典', '固', '忠', '咐', '呼', '鸣', '咏', '呢', '岸', '岩', '帖', '罗', '帜', '岭', '凯', '败', '贩', '购', 70 | '图', '钓', '制', '知', '垂', '牧', '物', '乖', '刮', '秆', '和', '季', '委', '佳', '侍', '供', '使', '例', '版', '侄', 71 | '侦', '侧', '凭', '侨', '佩', '货', '依', '的', '迫', '质', '欣', '征', '往', '爬', '彼', '径', '所', '舍', '金', '命', 72 | '斧', '爸', '采', '受', '乳', '贪', '念', '贫', '肤', '肺', '肢', '肿', '胀', '朋', '股', '肥', '服', '胁', '周', '昏', 73 | '鱼', '兔', '狐', '忽', '狗', '备', '饰', '饱', '饲', '变', '京', '享', '店', '夜', '庙', '府', '底', '剂', '郊', '废', 74 | '净', '盲', '放', '刻', '育', '闸', '闹', '郑', '券', '卷', '单', '炒', '炊', 
'炕', '炎', '炉', '沫', '浅', '法', '泄', 75 | '河', '沾', '泪', '油', '泊', '沿', '泡', '注', '泻', '泳', '泥', '沸', '波', '泼', '泽', '治', '怖', '性', '怕', '怜', 76 | '怪', '学', '宝', '宗', '定', '宜', '审', '宙', '官', '空', '帘', '实', '试', '郎', '诗', '肩', '房', '诚', '衬', '衫', 77 | '视', '话', '诞', '询', '该', '详', '建', '肃', '录', '隶', '居', '届', '刷', '屈', '弦', '承', '孟', '孤', '陕', '降', 78 | '限', '妹', '姑', '姐', '姓', '始', '驾', '参', '艰', '线', '练', '组', '细', '驶', '织', '终', '驻', '驼', '绍', '经', 79 | '贯', '奏', '春', '帮', '珍', '玻', '毒', '型', '挂', '封', '持', '项', '垮', '挎', '城', '挠', '政', '赴', '赵', '挡', 80 | '挺', '括', '拴', '拾', '挑', '指', '垫', '挣', '挤', '拼', '挖', '按', '挥', '挪', '某', '甚', '革', '荐', '巷', '带', 81 | '草', '茧', '茶', '荒', '茫', '荡', '荣', '故', '胡', '南', '药', '标', '枯', '柄', '栋', '相', '查', '柏', '柳', '柱', 82 | '柿', '栏', '树', '要', '咸', '威', '歪', '研', '砖', '厘', '厚', '砌', '砍', '面', '耐', '耍', '牵', '残', '殃', '轻', 83 | '鸦', '皆', '背', '战', '点', '临', '览', '竖', '省', '削', '尝', '是', '盼', '眨', '哄', '显', '哑', '冒', '映', '星', 84 | '昨', '畏', '趴', '胃', '贵', '界', '虹', '虾', '蚁', '思', '蚂', '虽', '品', '咽', '骂', '哗', '咱', '响', '哈', '咬', 85 | '咳', '哪', '炭', '峡', '罚', '贱', '贴', '骨', '钞', '钟', '钢', '钥', '钩', '卸', '缸', '拜', '看', '矩', '怎', '牲', 86 | '选', '适', '秒', '香', '种', '秋', '科', '重', '复', '竿', '段', '便', '俩', '贷', '顺', '修', '保', '促', '侮', '俭', 87 | '俗', '俘', '信', '皇', '泉', '鬼', '侵', '追', '俊', '盾', '待', '律', '很', '须', '叙', '剑', '逃', '食', '盆', '胆', 88 | '胜', '胞', '胖', '脉', '勉', '狭', '狮', '独', '狡', '狱', '狠', '贸', '怨', '急', '饶', '蚀', '饺', '饼', '弯', '将', 89 | '奖', '哀', '亭', '亮', '度', '迹', '庭', '疮', '疯', '疫', '疤', '姿', '亲', '音', '帝', '施', '闻', '阀', '阁', '差', 90 | '养', '美', '姜', '叛', '送', '类', '迷', '前', '首', '逆', '总', '炼', '炸', '炮', '烂', '剃', '洁', '洪', '洒', '浇', 91 | '浊', '洞', '测', '洗', '活', '派', '洽', '染', '济', '洋', '洲', '浑', '浓', '津', '恒', '恢', '恰', '恼', '恨', '举', 92 | '觉', '宣', '室', '宫', '宪', '突', '穿', '窃', '客', '冠', '语', '扁', '袄', '祖', '神', '祝', '误', '诱', '说', '诵', 93 | '垦', '退', '既', '屋', '昼', '费', '陡', '眉', '孩', '除', '险', '院', '娃', '姥', 
'姨', '姻', '娇', '怒', '架', '贺', 94 | '盈', '勇', '怠', '柔', '垒', '绑', '绒', '结', '绕', '骄', '绘', '给', '络', '骆', '绝', '绞', '统', '耕', '耗', '艳', 95 | '泰', '珠', '班', '素', '蚕', '顽', '盏', '匪', '捞', '栽', '捕', '振', '载', '赶', '起', '盐', '捎', '捏', '埋', '捉', 96 | '捆', '捐', '损', '都', '哲', '逝', '捡', '换', '挽', '热', '恐', '壶', '挨', '耻', '耽', '恭', '莲', '莫', '荷', '获', 97 | '晋', '恶', '真', '框', '桂', '档', '桐', '株', '桥', '桃', '格', '校', '核', '样', '根', '索', '哥', '速', '逗', '栗', 98 | '配', '翅', '辱', '唇', '夏', '础', '破', '原', '套', '逐', '烈', '殊', '顾', '轿', '较', '顿', '毙', '致', '柴', '桌', 99 | '虑', '监', '紧', '党', '晒', '眠', '晓', '鸭', '晃', '晌', '晕', '蚊', '哨', '哭', '恩', '唤', '啊', '唉', '罢', '峰', 100 | '圆', '贼', '贿', '钱', '钳', '钻', '铁', '铃', '铅', '缺', '氧', '特', '牺', '造', '乘', '敌', '秤', '租', '积', '秧', 101 | '秩', '称', '秘', '透', '笔', '笑', '笋', '债', '借', '值', '倚', '倾', '倒', '倘', '俱', '倡', '候', '俯', '倍', '倦', 102 | '健', '臭', '射', '躬', '息', '徒', '徐', '舰', '舱', '般', '航', '途', '拿', '爹', '爱', '颂', '翁', '脆', '脂', '胸', 103 | '胳', '脏', '胶', '脑', '狸', '狼', '逢', '留', '皱', '饿', '恋', '桨', '浆', '衰', '高', '席', '准', '座', '脊', '症', 104 | '病', '疾', '疼', '疲', '效', '离', '唐', '资', '凉', '站', '剖', '竞', '部', '旁', '旅', '畜', '阅', '羞', '瓶', '拳', 105 | '粉', '料', '益', '兼', '烤', '烘', '烦', '烧', '烛', '烟', '递', '涛', '浙', '涝', '酒', '涉', '消', '浩', '海', '涂', 106 | '浴', '浮', '流', '润', '浪', '浸', '涨', '烫', '涌', '悟', '悄', '悔', '悦', '害', '宽', '家', '宵', '宴', '宾', '窄', 107 | '容', '宰', '案', '请', '朗', '诸', '读', '扇', '袜', '袖', '袍', '被', '祥', '课', '谁', '调', '冤', '谅', '谈', '谊', 108 | '剥', '恳', '展', '剧', '屑', '弱', '陵', '陶', '陷', '陪', '娱', '娘', '通', '能', '难', '预', '桑', '绢', '绣', '验', 109 | '继', '球', '理', '捧', '堵', '描', '域', '掩', '捷', '排', '掉', '堆', '推', '掀', '授', '教', '掏', '掠', '培', '接', 110 | '控', '探', '据', '掘', '职', '基', '著', '勒', '黄', '萌', '萝', '菌', '菜', '萄', '菊', '萍', '菠', '营', '械', '梦', 111 | '梢', '梅', '检', '梳', '梯', '桶', '救', '副', '票', '戚', '爽', '聋', '袭', '盛', '雪', '辅', '辆', '虚', '雀', '堂', 112 | '常', '匙', '晨', '睁', '眯', '眼', '悬', '野', '啦', '晚', '啄', '距', 
'跃', '略', '蛇', '累', '唱', '患', '唯', '崖', 113 | '崭', '崇', '圈', '铜', '铲', '银', '甜', '梨', '犁', '移', '笨', '笼', '笛', '符', '第', '敏', '做', '袋', '悠', '偿', 114 | '偶', '偷', '您', '售', '停', '偏', '假', '得', '衔', '盘', '船', '斜', '盒', '鸽', '悉', '欲', '彩', '领', '脚', '脖', 115 | '脸', '脱', '象', '够', '猜', '猪', '猎', '猫', '猛', '馅', '馆', '凑', '减', '毫', '麻', '痒', '痕', '廊', '康', '庸', 116 | '鹿', '盗', '章', '竟', '商', '族', '旋', '望', '率', '着', '盖', '粘', '粗', '粒', '断', '剪', '兽', '清', '添', '淋', 117 | '淹', '渠', '渐', '混', '渔', '淘', '液', '淡', '深', '婆', '梁', '渗', '情', '惜', '惭', '悼', '惧', '惕', '惊', '惨', 118 | '惯', '寇', '寄', '宿', '窑', '密', '谋', '谎', '祸', '谜', '逮', '敢', '屠', '弹', '随', '蛋', '隆', '隐', '婚', '婶', 119 | '颈', '绩', '绪', '续', '骑', '绳', '维', '绵', '绸', '绿', '琴', '斑', '替', '款', '堪', '搭', '塔', '越', '趁', '趋', 120 | '超', '提', '堤', '博', '揭', '喜', '插', '揪', '搜', '煮', '援', '裁', '搁', '搂', '搅', '握', '揉', '斯', '期', '欺', 121 | '联', '散', '惹', '葬', '葛', '董', '葡', '敬', '葱', '落', '朝', '辜', '葵', '棒', '棋', '植', '森', '椅', '椒', '棵', 122 | '棍', '棉', '棚', '棕', '惠', '惑', '逼', '厨', '厦', '硬', '确', '雁', '殖', '裂', '雄', '暂', '雅', '辈', '悲', '紫', 123 | '辉', '敞', '赏', '掌', '晴', '暑', '最', '量', '喷', '晶', '喇', '遇', '喊', '景', '践', '跌', '跑', '遗', '蛙', '蛛', 124 | '蜓', '喝', '喂', '喘', '喉', '幅', '帽', '赌', '赔', '黑', '铸', '铺', '链', '销', '锁', '锄', '锅', '锈', '锋', '锐', 125 | '短', '智', '毯', '鹅', '剩', '稍', '程', '稀', '税', '筐', '等', '筑', '策', '筛', '筒', '答', '筋', '筝', '傲', '傅', 126 | '牌', '堡', '集', '焦', '傍', '储', '奥', '街', '惩', '御', '循', '艇', '舒', '番', '释', '禽', '腊', '脾', '腔', '鲁', 127 | '猾', '猴', '然', '馋', '装', '蛮', '就', '痛', '童', '阔', '善', '羡', '普', '粪', '尊', '道', '曾', '焰', '港', '湖', 128 | '渣', '湿', '温', '渴', '滑', '湾', '渡', '游', '滋', '溉', '愤', '慌', '惰', '愧', '愉', '慨', '割', '寒', '富', '窜', 129 | '窝', '窗', '遍', '裕', '裤', '裙', '谢', '谣', '谦', '属', '屡', '强', '粥', '疏', '隔', '隙', '絮', '嫂', '登', '缎', 130 | '缓', '编', '骗', '缘', '瑞', '魂', '肆', '摄', '摸', '填', '搏', '塌', '鼓', '摆', '携', '搬', '摇', '搞', '塘', '摊', 131 | '蒜', '勤', '鹊', '蓝', '墓', '幕', '蓬', '蓄', '蒙', 
'蒸', '献', '禁', '楚', '想', '槐', '榆', '楼', '概', '赖', '酬', 132 | '感', '碍', '碑', '碎', '碰', '碗', '碌', '雷', '零', '雾', '雹', '输', '督', '龄', '鉴', '睛', '睡', '睬', '鄙', '愚', 133 | '暖', '盟', '歇', '暗', '照', '跨', '跳', '跪', '路', '跟', '遣', '蛾', '蜂', '嗓', '置', '罪', '罩', '错', '锡', '锣', 134 | '锤', '锦', '键', '锯', '矮', '辞', '稠', '愁', '筹', '签', '简', '毁', '舅', '鼠', '催', '傻', '像', '躲', '微', '愈', 135 | '遥', '腰', '腥', '腹', '腾', '腿', '触', '解', '酱', '痰', '廉', '新', '韵', '意', '粮', '数', '煎', '塑', '慈', '煤', 136 | '煌', '满', '漠', '源', '滤', '滥', '滔', '溪', '溜', '滚', '滨', '粱', '滩', '慎', '誉', '塞', '谨', '福', '群', '殿', 137 | '辟', '障', '嫌', '嫁', '叠', '缝', '缠', '静', '碧', '璃', '墙', '撇', '嘉', '摧', '截', '誓', '境', '摘', '摔', '聚', 138 | '蔽', '慕', '暮', '蔑', '模', '榴', '榜', '榨', '歌', '遭', '酷', '酿', '酸', '磁', '愿', '需', '弊', '裳', '颗', '嗽', 139 | '蜻', '蜡', '蝇', '蜘', '赚', '锹', '锻', '舞', '稳', '算', '箩', '管', '僚', '鼻', '魄', '貌', '膜', '膊', '膀', '鲜', 140 | '疑', '馒', '裹', '敲', '豪', '膏', '遮', '腐', '瘦', '辣', '竭', '端', '旗', '精', '歉', '熄', '熔', '漆', '漂', '漫', 141 | '滴', '演', '漏', '慢', '寨', '赛', '察', '蜜', '谱', '嫩', '翠', '熊', '凳', '骡', '缩', '慧', '撕', '撒', '趣', '趟', 142 | '撑', '播', '撞', '撤', '增', '聪', '鞋', '蕉', '蔬', '横', '槽', '樱', '橡', '飘', '醋', '醉', '震', '霉', '瞒', '题', 143 | '暴', '瞎', '影', '踢', '踏', '踩', '踪', '蝶', '蝴', '嘱', '墨', '镇', '靠', '稻', '黎', '稿', '稼', '箱', '箭', '篇', 144 | '僵', '躺', '僻', '德', '艘', '膝', '膛', '熟', '摩', '颜', '毅', '糊', '遵', '潜', '潮', '懂', '额', '慰', '劈', '操', 145 | '燕', '薯', '薪', '薄', '颠', '橘', '整', '融', '醒', '餐', '嘴', '蹄', '器', '赠', '默', '镜', '赞', '篮', '邀', '衡', 146 | '膨', '雕', '磨', '凝', '辨', '辩', '糖', '糕', '燃', '澡', '激', '懒', '壁', '避', '缴', '戴', '擦', '鞠', '藏', '霜', 147 | '霞', '瞧', '蹈', '螺', '穗', '繁', '辫', '赢', '糟', '糠', '燥', '臂', '翼', '骤', '鞭', '覆', '蹦', '镰', '翻', '鹰', 148 | '警', '攀', '蹲', '颤', '瓣', '爆', '疆', '壤', '耀', '躁', '嚼', '嚷', '籍', '魔', '灌', '蠢', '霸', '露', '囊', '罐', 149 | '匕', '刁', '丐', '歹', '戈', '夭', '仑', '讥', '冗', '邓', '艾', '夯', '凸', '卢', '叭', '叽', '皿', '凹', '囚', '矢', 150 | '乍', '尔', '冯', '玄', '邦', '迂', 
'邢', '芋', '芍', '吏', '夷', '吁', '吕', '吆', '屹', '廷', '迄', '臼', '仲', '伦', 151 | '伊', '肋', '旭', '匈', '凫', '妆', '亥', '汛', '讳', '讶', '讹', '讼', '诀', '弛', '阱', '驮', '驯', '纫', '玖', '玛', 152 | '韧', '抠', '扼', '汞', '扳', '抡', '坎', '坞', '抑', '拟', '抒', '芙', '芜', '苇', '芥', '芯', '芭', '杖', '杉', '巫', 153 | '杈', '甫', '匣', '轩', '卤', '肖', '吱', '吠', '呕', '呐', '吟', '呛', '吻', '吭', '邑', '囤', '吮', '岖', '牡', '佑', 154 | '佃', '伺', '囱', '肛', '肘', '甸', '狈', '鸠', '彤', '灸', '刨', '庇', '吝', '庐', '闰', '兑', '灼', '沐', '沛', '汰', 155 | '沥', '沦', '汹', '沧', '沪', '忱', '诅', '诈', '罕', '屁', '坠', '妓', '姊', '妒', '纬', '玫', '卦', '坷', '坯', '拓', 156 | '坪', '坤', '拄', '拧', '拂', '拙', '拇', '拗', '茉', '昔', '苛', '苫', '苟', '苞', '茁', '苔', '枉', '枢', '枚', '枫', 157 | '杭', '郁', '矾', '奈', '奄', '殴', '歧', '卓', '昙', '哎', '咕', '呵', '咙', '呻', '咒', '咆', '咖', '帕', '账', '贬', 158 | '贮', '氛', '秉', '岳', '侠', '侥', '侣', '侈', '卑', '刽', '刹', '肴', '觅', '忿', '瓮', '肮', '肪', '狞', '庞', '疟', 159 | '疙', '疚', '卒', '氓', '炬', '沽', '沮', '泣', '泞', '泌', '沼', '怔', '怯', '宠', '宛', '衩', '祈', '诡', '帚', '屉', 160 | '弧', '弥', '陋', '陌', '函', '姆', '虱', '叁', '绅', '驹', '绊', '绎', '契', '贰', '玷', '玲', '珊', '拭', '拷', '拱', 161 | '挟', '垢', '垛', '拯', '荆', '茸', '茬', '荚', '茵', '茴', '荞', '荠', '荤', '荧', '荔', '栈', '柑', '栅', '柠', '枷', 162 | '勃', '柬', '砂', '泵', '砚', '鸥', '轴', '韭', '虐', '昧', '盹', '咧', '昵', '昭', '盅', '勋', '哆', '咪', '哟', '幽', 163 | '钙', '钝', '钠', '钦', '钧', '钮', '毡', '氢', '秕', '俏', '俄', '俐', '侯', '徊', '衍', '胚', '胧', '胎', '狰', '饵', 164 | '峦', '奕', '咨', '飒', '闺', '闽', '籽', '娄', '烁', '炫', '洼', '柒', '涎', '洛', '恃', '恍', '恬', '恤', '宦', '诫', 165 | '诬', '祠', '诲', '屏', '屎', '逊', '陨', '姚', '娜', '蚤', '骇', '耘', '耙', '秦', '匿', '埂', '捂', '捍', '袁', '捌', 166 | '挫', '挚', '捣', '捅', '埃', '耿', '聂', '荸', '莽', '莱', '莉', '莹', '莺', '梆', '栖', '桦', '栓', '桅', '桩', '贾', 167 | '酌', '砸', '砰', '砾', '殉', '逞', '哮', '唠', '哺', '剔', '蚌', '蚜', '畔', '蚣', '蚪', '蚓', '哩', '圃', '鸯', '唁', 168 | '哼', '唆', '峭', '唧', '峻', '赂', '赃', '钾', '铆', '氨', '秫', '笆', '俺', '赁', '倔', '殷', '耸', '舀', '豺', '豹', 169 | '颁', '胯', '胰', 
'脐', '脓', '逛', '卿', '鸵', '鸳', '馁', '凌', '凄', '衷', '郭', '斋', '疹', '紊', '瓷', '羔', '烙', 170 | '浦', '涡', '涣', '涤', '涧', '涕', '涩', '悍', '悯', '窍', '诺', '诽', '袒', '谆', '祟', '恕', '娩', '骏', '琐', '麸', 171 | '琉', '琅', '措', '捺', '捶', '赦', '埠', '捻', '掐', '掂', '掖', '掷', '掸', '掺', '勘', '聊', '娶', '菱', '菲', '萎', 172 | '菩', '萤', '乾', '萧', '萨', '菇', '彬', '梗', '梧', '梭', '曹', '酝', '酗', '厢', '硅', '硕', '奢', '盔', '匾', '颅', 173 | '彪', '眶', '晤', '曼', '晦', '冕', '啡', '畦', '趾', '啃', '蛆', '蚯', '蛉', '蛀', '唬', '唾', '啤', '啥', '啸', '崎', 174 | '逻', '崔', '崩', '婴', '赊', '铐', '铛', '铝', '铡', '铣', '铭', '矫', '秸', '秽', '笙', '笤', '偎', '傀', '躯', '兜', 175 | '衅', '徘', '徙', '舶', '舷', '舵', '敛', '翎', '脯', '逸', '凰', '猖', '祭', '烹', '庶', '庵', '痊', '阎', '阐', '眷', 176 | '焊', '焕', '鸿', '涯', '淑', '淌', '淮', '淆', '渊', '淫', '淳', '淤', '淀', '涮', '涵', '惦', '悴', '惋', '寂', '窒', 177 | '谍', '谐', '裆', '袱', '祷', '谒', '谓', '谚', '尉', '堕', '隅', '婉', '颇', '绰', '绷', '综', '绽', '缀', '巢', '琳', 178 | '琢', '琼', '揍', '堰', '揩', '揽', '揖', '彭', '揣', '搀', '搓', '壹', '搔', '葫', '募', '蒋', '蒂', '韩', '棱', '椰', 179 | '焚', '椎', '棺', '榔', '椭', '粟', '棘', '酣', '酥', '硝', '硫', '颊', '雳', '翘', '凿', '棠', '晰', '鼎', '喳', '遏', 180 | '晾', '畴', '跋', '跛', '蛔', '蜒', '蛤', '鹃', '喻', '啼', '喧', '嵌', '赋', '赎', '赐', '锉', '锌', '甥', '掰', '氮', 181 | '氯', '黍', '筏', '牍', '粤', '逾', '腌', '腋', '腕', '猩', '猬', '惫', '敦', '痘', '痢', '痪', '竣', '翔', '奠', '遂', 182 | '焙', '滞', '湘', '渤', '渺', '溃', '溅', '湃', '愕', '惶', '寓', '窖', '窘', '雇', '谤', '犀', '隘', '媒', '媚', '婿', 183 | '缅', '缆', '缔', '缕', '骚', '瑟', '鹉', '瑰', '搪', '聘', '斟', '靴', '靶', '蓖', '蒿', '蒲', '蓉', '楔', '椿', '楷', 184 | '榄', '楞', '楣', '酪', '碘', '硼', '碉', '辐', '辑', '频', '睹', '睦', '瞄', '嗜', '嗦', '暇', '畸', '跷', '跺', '蜈', 185 | '蜗', '蜕', '蛹', '嗅', '嗡', '嗤', '署', '蜀', '幌', '锚', '锥', '锨', '锭', '锰', '稚', '颓', '筷', '魁', '衙', '腻', 186 | '腮', '腺', '鹏', '肄', '猿', '颖', '煞', '雏', '馍', '馏', '禀', '痹', '廓', '痴', '靖', '誊', '漓', '溢', '溯', '溶', 187 | '滓', '溺', '寞', '窥', '窟', '寝', '褂', '裸', '谬', '媳', '嫉', '缚', '缤', '剿', '赘', '熬', '赫', '蔫', '摹', '蔓', 188 | 
'蔗', '蔼', '熙', '蔚', '兢', '榛', '榕', '酵', '碟', '碴', '碱', '碳', '辕', '辖', '雌', '墅', '嘁', '踊', '蝉', '嘀',
189 |     '幔', '镀', '舔', '熏', '箍', '箕', '箫', '舆', '僧', '孵', '瘩', '瘟', '彰', '粹', '漱', '漩', '漾', '慷', '寡', '寥',
190 |     '谭', '褐', '褪', '隧', '嫡', '缨', '撵', '撩', '撮', '撬', '擒', '墩', '撰', '鞍', '蕊', '蕴', '樊', '樟', '橄', '敷',
191 |     '豌', '醇', '磕', '磅', '碾', '憋', '嘶', '嘲', '嘹', '蝠', '蝎', '蝌', '蝗', '蝙', '嘿', '幢', '镊', '镐', '稽', '篓',
192 |     '膘', '鲤', '鲫', '褒', '瘪', '瘤', '瘫', '凛', '澎', '潭', '潦', '澳', '潘', '澈', '澜', '澄', '憔', '懊', '憎', '翩',
193 |     '褥', '谴', '鹤', '憨', '履', '嬉', '豫', '缭', '撼', '擂', '擅', '蕾', '薛', '薇', '擎', '翰', '噩', '橱', '橙', '瓢',
194 |     '蟥', '霍', '霎', '辙', '冀', '踱', '蹂', '蟆', '螃', '螟', '噪', '鹦', '黔', '穆', '篡', '篷', '篙', '篱', '儒', '膳',
195 |     '鲸', '瘾', '瘸', '糙', '燎', '濒', '憾', '懈', '窿', '缰', '壕', '藐', '檬', '檐', '檩', '檀', '礁', '磷', '了', '瞬',  # NOTE(review): '了' is also in the first row of CHINESE_3500; kept because list positions are class indices and removing it would shift every later label.
196 |     '瞳', '瞪', '曙', '蹋', '蟋', '蟀', '嚎', '赡', '镣', '魏', '簇', '儡', '徽', '爵', '朦', '臊', '鳄', '糜', '癌', '懦',
197 |     '豁', '臀', '藕', '藤', '瞻', '嚣', '鳍', '癞', '瀑', '襟', '璧', '戳', '攒', '孽', '蘑', '藻', '鳖', '蹭', '蹬', '簸',
198 |     '簿', '蟹', '靡', '癣', '羹', '鬓', '攘', '蠕', '巍', '鳞', '糯', '譬', '霹', '躏', '髓', '蘸', '镶', '瓤', '矗', '圳',
199 |     '珏', '蕙', '旻', '涅', '攸', '嘛', '醪', '缪', '噗', '瞨', '靳', '帷', '徨',
200 | ]
201 |
202 | FLOAT = ['.']  # extra symbol used by the FLOAT preset below
203 |
204 | SIMPLE_CATEGORY_MODEL = dict(  # preset name -> category list (looked up by ModelConfig.category_extract)
205 |     NUMERIC=NUMBER,
206 |     ALPHANUMERIC=NUMBER + ALPHA_LOWER + ALPHA_UPPER,
207 |     ALPHANUMERIC_LOWER=NUMBER + ALPHA_LOWER,
208 |     ALPHANUMERIC_UPPER=NUMBER + ALPHA_UPPER,
209 |     ALPHABET_LOWER=ALPHA_LOWER,
210 |     ALPHABET_UPPER=ALPHA_UPPER,
211 |     ALPHABET=ALPHA_LOWER + ALPHA_UPPER,
212 |     ARITHMETIC=NUMBER + ARITHMETIC,
213 |     FLOAT=NUMBER + FLOAT,
214 |     CHS_3500=CHINESE_3500,
215 |     ALPHANUMERIC_MIX_CHS_3500_LOWER=NUMBER + ALPHA_LOWER + CHINESE_3500
216 | )
217 |
218 |
219 | def encode_maps(source):  # label -> index; for a label listed twice the last index wins
220 |     return {category: i for i, category in enumerate(source, 0)}
221 |
222 |
223 | @unique
224 | class ModelScene(Enum):
225 |     """Model scene enum."""
226 |     Classification = 'Classification'
227 |
228
| 229 | @unique 230 | class ModelField(Enum): 231 | """模型类别枚举""" 232 | Image = 'Image' 233 | Text = 'Text' 234 | 235 | 236 | MODEL_SCENE_MAP = { 237 | 'Classification': ModelScene.Classification 238 | } 239 | 240 | MODEL_FIELD_MAP = { 241 | 'Image': ModelField.Image, 242 | 'Text': ModelField.Text 243 | } 244 | 245 | 246 | class ModelConfig(object): 247 | 248 | @staticmethod 249 | def category_extract(param): 250 | if isinstance(param, list): 251 | return param 252 | if isinstance(param, str): 253 | if param in SIMPLE_CATEGORY_MODEL.keys(): 254 | return SIMPLE_CATEGORY_MODEL.get(param) 255 | raise ValueError( 256 | "Category set configuration error, customized category set should be list type" 257 | ) 258 | 259 | @property 260 | def model_conf(self) -> dict: 261 | if self.model_content: 262 | return self.model_content 263 | with open(self.model_conf_path, 'r', encoding="utf-8") as sys_fp: 264 | sys_stream = sys_fp.read() 265 | return yaml.load(sys_stream, Loader=yaml.SafeLoader) 266 | 267 | def __init__(self, model_conf_path=None, model_content=None): 268 | self.model_content = model_content 269 | self.model_path = model_conf_path 270 | self.graph_path = os.path.dirname(self.model_path) if model_conf_path else "" 271 | self.model_conf_path = model_conf_path 272 | self.model_conf_demo = 'model_demo.yaml' 273 | 274 | """MODEL""" 275 | self.model_root: dict = self.model_conf['Model'] 276 | self.model_name: str = self.model_root.get('ModelName') 277 | self.model_version: float = self.model_root.get('Version') 278 | self.model_version = self.model_version if self.model_version else 1.0 279 | self.model_field_param: str = self.model_root.get('ModelField') 280 | self.model_field: ModelField = self.param_convert( 281 | source=self.model_field_param, 282 | param_map=MODEL_FIELD_MAP, 283 | text="Current model field ({model_field}) is not supported".format(model_field=self.model_field_param), 284 | code=50002 285 | ) 286 | 287 | self.model_scene_param: str = 
self.model_root.get('ModelScene') 288 | 289 | self.model_scene: ModelScene = self.param_convert( 290 | source=self.model_scene_param, 291 | param_map=MODEL_SCENE_MAP, 292 | text="Current model scene ({model_scene}) is not supported".format(model_scene=self.model_scene_param), 293 | code=50001 294 | ) 295 | 296 | """SYSTEM""" 297 | self.checkpoint_tag = 'checkpoint' 298 | self.system_root: dict = self.model_conf['System'] 299 | self.memory_usage: float = self.system_root.get('MemoryUsage') 300 | 301 | """FIELD PARAM - IMAGE""" 302 | self.field_root: dict = self.model_conf['FieldParam'] 303 | self.category_param = self.field_root.get('Category') 304 | self.category_value = self.category_extract(self.category_param) 305 | if self.category_value is None: 306 | raise Exception( 307 | "The category set type does not exist, there is no category set named {}".format(self.category_param), 308 | ) 309 | self.category: list = SPACE_TOKEN + self.category_value 310 | self.category_num: int = len(self.category) 311 | self.image_channel: int = self.field_root.get('ImageChannel') 312 | self.image_width: int = self.field_root.get('ImageWidth') 313 | self.image_height: int = self.field_root.get('ImageHeight') 314 | self.resize: list = self.field_root.get('Resize') 315 | self.output_split = self.field_root.get('OutputSplit') 316 | self.output_split = self.output_split if self.output_split else "" 317 | self.category_split = self.field_root.get('CategorySplit') 318 | self.corp_params = self.field_root.get('CorpParams') 319 | self.output_coord = self.field_root.get('OutputCoord') 320 | self.batch_model = self.field_root.get('BatchModel') 321 | 322 | """PRETREATMENT""" 323 | self.pretreatment_root = self.model_conf.get('Pretreatment') 324 | self.pre_binaryzation = self.get_var(self.pretreatment_root, 'Binaryzation', -1) 325 | self.pre_replace_transparent = self.get_var(self.pretreatment_root, 'ReplaceTransparent', True) 326 | self.pre_horizontal_stitching = 
self.get_var(self.pretreatment_root, 'HorizontalStitching', False) 327 | self.pre_concat_frames = self.get_var(self.pretreatment_root, 'ConcatFrames', -1) 328 | self.pre_blend_frames = self.get_var(self.pretreatment_root, 'BlendFrames', -1) 329 | self.exec_map = self.pretreatment_root.get('ExecuteMap') 330 | """COMPILE_MODEL""" 331 | if self.graph_path: 332 | self.compile_model_path = os.path.join(self.graph_path, '{}.onnx'.format(self.model_name)) 333 | if not os.path.exists(self.compile_model_path): 334 | if not os.path.exists(self.graph_path): 335 | os.makedirs(self.graph_path) 336 | raise ValueError( 337 | '{} not found, please put the trained model in the current directory.'.format(self.compile_model_path) 338 | ) 339 | else: 340 | self.model_exists = True 341 | else: 342 | self.model_exists = True if self.model_content else False 343 | self.compile_model_path = "" 344 | 345 | @staticmethod 346 | def param_convert(source, param_map: dict, text, code, default=None): 347 | if source is None: 348 | return default 349 | if source not in param_map.keys(): 350 | raise Exception(text) 351 | return param_map[source] 352 | 353 | def size_match(self, size_str): 354 | return size_str == self.size_string 355 | 356 | @staticmethod 357 | def get_var(src: dict, name: str, default=None): 358 | if not src: 359 | return default 360 | return src.get(name) 361 | 362 | @property 363 | def size_string(self): 364 | return "{}x{}".format(self.image_width, self.image_height) 365 | 366 | 367 | class Model(object): 368 | model_conf: ModelConfig 369 | graph_bytes: object = None 370 | 371 | def __init__(self, conf_path: str, source_bytes: bytes = None, key=None): 372 | if conf_path: 373 | self.model_conf = ModelConfig(model_conf_path=conf_path) 374 | # self.graph_bytes = self.model_conf.compile_model_path 375 | if source_bytes: 376 | model_conf, self.graph_bytes = self.parse_model(source_bytes, key) 377 | self.model_conf = ModelConfig(model_content=model_conf) 378 | 379 | @staticmethod 
380 | def parse_model(source_bytes: bytes, key=None): 381 | split_tag = b'-#||#-' 382 | 383 | if not key: 384 | key = [b"_____" + i.encode("utf8") + b"_____" for i in "&coriander"] 385 | if isinstance(key, str): 386 | key = [b"_____" + i.encode("utf8") + b"_____" for i in key] 387 | key_len_int = len(key) 388 | model_bytes_list = [] 389 | graph_bytes_list = [] 390 | slice_index = source_bytes.index(key[0]) 391 | split_tag_len = len(split_tag) 392 | slice_0 = source_bytes[0: slice_index].split(split_tag) 393 | model_slice_len = len(slice_0[1]) 394 | graph_slice_len = len(slice_0[0]) 395 | slice_len = split_tag_len + model_slice_len + graph_slice_len 396 | 397 | for i in range(key_len_int - 1): 398 | slice_index = source_bytes.index(key[i]) 399 | print(slice_index, slice_index - slice_len) 400 | slices = source_bytes[slice_index - slice_len: slice_index].split(split_tag) 401 | model_bytes_list.append(slices[1]) 402 | graph_bytes_list.append(slices[0]) 403 | slices = source_bytes.split(key[-2])[1][:-len(key[-1])].split(split_tag) 404 | 405 | model_bytes_list.append(slices[1]) 406 | graph_bytes_list.append(slices[0]) 407 | model_bytes = b"".join(model_bytes_list) 408 | model_conf: dict = pickle.loads(model_bytes) 409 | graph_bytes: bytes = b"".join(graph_bytes_list) 410 | return model_conf, graph_bytes 411 | 412 | 413 | class GraphSession(object): 414 | def __init__(self, model: Model): 415 | self.model_conf = model.model_conf 416 | self.size_str = self.model_conf.size_string 417 | self.model_name = self.model_conf.model_name 418 | self.graph_name = self.model_conf.model_name 419 | self.version = self.model_conf.model_version 420 | self.graph_bytes = model.graph_bytes 421 | self.sess = ort.InferenceSession( 422 | self.model_conf.compile_model_path if not model.graph_bytes else model.graph_bytes 423 | ) 424 | 425 | 426 | class Interface(object): 427 | 428 | def __init__(self, graph_session: GraphSession): 429 | self.graph_sess = graph_session 430 | self.model_conf = 
graph_session.model_conf 431 | self.size_str = self.model_conf.size_string 432 | self.graph_name = self.graph_sess.graph_name 433 | self.version = self.graph_sess.version 434 | self.model_category = self.model_conf.category 435 | if self.graph_sess.sess: 436 | self.sess = self.graph_sess.sess 437 | 438 | @property 439 | def name(self): 440 | return self.graph_name 441 | 442 | @property 443 | def size(self): 444 | return self.size_str 445 | 446 | def predict_batch(self, image_batch, output_split=None): 447 | predict_text = self.predict_func( 448 | image_batch, 449 | self.sess, 450 | self.model_conf, 451 | output_split 452 | ) 453 | return predict_text 454 | 455 | @staticmethod 456 | def decode_maps(categories): 457 | return {index: category for index, category in enumerate(categories, 0)} 458 | 459 | def predict_func(self, image_batch, _sess, model: ModelConfig, output_split=None): 460 | if isinstance(image_batch, list): 461 | image_batch = np.asarray(image_batch) 462 | 463 | output_split = model.output_split if output_split is None else output_split 464 | 465 | category_split = model.category_split if model.category_split else "" 466 | 467 | dense_decoded_code = _sess.run(["dense_decoded:0"], input_feed={ 468 | "input:0": image_batch, 469 | }) 470 | decoded_expression = [] 471 | 472 | for item in dense_decoded_code[0]: 473 | expression = [] 474 | for i in item: 475 | if i == -1 or i == model.category_num: 476 | expression.append("") 477 | else: 478 | expression.append(self.decode_maps(model.category)[i]) 479 | 480 | decoded_expression.append(category_split.join(expression)) 481 | return output_split.join(decoded_expression) if len(decoded_expression) > 1 else decoded_expression[0] 482 | 483 | 484 | class Pretreatment(object): 485 | 486 | def __init__(self, origin): 487 | self.origin = origin 488 | 489 | def get(self): 490 | return self.origin 491 | 492 | def binarization(self, value, modify=False): 493 | ret, _binarization = cv2.threshold(self.origin, value, 255, 
cv2.THRESH_BINARY) 494 | if modify: 495 | self.origin = _binarization 496 | return _binarization 497 | 498 | @staticmethod 499 | def preprocessing(image, binaryzation=-1): 500 | pretreatment = Pretreatment(image) 501 | if binaryzation > 0: 502 | pretreatment.binarization(binaryzation, True) 503 | return pretreatment.get() 504 | 505 | @staticmethod 506 | def preprocessing_by_func(exec_map, key, src_arr): 507 | if not exec_map: 508 | return src_arr 509 | target_arr = cv2.cvtColor(src_arr, cv2.COLOR_RGB2BGR) 510 | for sentence in exec_map.get(key): 511 | if sentence.startswith("@@"): 512 | target_arr = eval(sentence[2:]) 513 | elif sentence.startswith("$$"): 514 | exec(sentence[2:]) 515 | return cv2.cvtColor(target_arr, cv2.COLOR_BGR2RGB) 516 | 517 | 518 | class ImageUtils(object): 519 | 520 | @staticmethod 521 | def get_bytes_batch(image_bytes): 522 | try: 523 | bytes_batch = [image_bytes] 524 | except binascii.Error: 525 | return None, "INVALID_BASE64_STRING" 526 | what_img = [ImageUtils.test_image(i) for i in bytes_batch] 527 | if None in what_img: 528 | return None, "INVALID_IMAGE_FORMAT" 529 | return bytes_batch, "SUCCESS" 530 | 531 | @staticmethod 532 | def get_image_batch(model: ModelConfig, bytes_batch, param_key=None): 533 | # Note that there are two return objects here. 
534 | # 1.image_batch, 2.response 535 | 536 | def load_image(image_bytes): 537 | data_stream = io.BytesIO(image_bytes) 538 | pil_image = PIL_Image.open(data_stream) 539 | 540 | gif_handle = model.pre_concat_frames != -1 or model.pre_blend_frames != -1 541 | 542 | if pil_image.mode == 'P' and not gif_handle: 543 | pil_image = pil_image.convert('RGB') 544 | 545 | rgb = pil_image.split() 546 | size = pil_image.size 547 | 548 | if (len(rgb) > 3 and model.pre_replace_transparent) and not gif_handle: 549 | background = PIL_Image.new('RGB', pil_image.size, (255, 255, 255)) 550 | background.paste(pil_image, (0, 0, size[0], size[1]), pil_image) 551 | pil_image = background 552 | 553 | im = np.asarray(pil_image) 554 | 555 | im = Pretreatment.preprocessing_by_func( 556 | exec_map=model.exec_map, 557 | key=param_key, 558 | src_arr=im 559 | ) 560 | 561 | if model.image_channel == 1 and len(im.shape) == 3: 562 | im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) 563 | 564 | im = Pretreatment.preprocessing( 565 | image=im, 566 | binaryzation=model.pre_binaryzation, 567 | ) 568 | 569 | if model.pre_horizontal_stitching: 570 | up_slice = im[0: int(size[1] / 2), 0: size[0]] 571 | down_slice = im[int(size[1] / 2): size[1], 0: size[0]] 572 | im = np.concatenate((up_slice, down_slice), axis=1) 573 | 574 | image = im.astype(np.float32) 575 | if model.resize[0] == -1: 576 | ratio = model.resize[1] / size[1] 577 | resize_width = int(ratio * size[0]) 578 | image = cv2.resize(image, (resize_width, model.resize[1])) 579 | else: 580 | image = cv2.resize(image, (model.resize[0], model.resize[1])) 581 | image = image.swapaxes(0, 1) 582 | return (image[:, :, np.newaxis] if model.image_channel == 1 else image[:, :]) / 255. 
583 | 584 | try: 585 | image_batch = [load_image(i) for i in bytes_batch] 586 | return image_batch, "SUCCESS" 587 | except OSError: 588 | return None, "IMAGE_DAMAGE" 589 | except ValueError as _e: 590 | print(_e) 591 | return None, "IMAGE_SIZE_NOT_MATCH_GRAPH" 592 | 593 | @staticmethod 594 | def size_of_image(image_bytes: bytes): 595 | _null_size = tuple((-1, -1)) 596 | try: 597 | data_stream = io.BytesIO(image_bytes) 598 | size = PIL_Image.open(data_stream).size 599 | return size 600 | except OSError: 601 | return _null_size 602 | except ValueError: 603 | return _null_size 604 | 605 | @staticmethod 606 | def test_image(h): 607 | """JPEG""" 608 | if h[:3] == b"\xff\xd8\xff": 609 | return 'jpeg' 610 | """PNG""" 611 | if h[:8] == b"\211PNG\r\n\032\n": 612 | return 'png' 613 | """GIF ('87 and '89 variants)""" 614 | if h[:6] in (b'GIF87a', b'GIF89a'): 615 | return 'gif' 616 | """TIFF (can be in Motorola or Intel byte order)""" 617 | if h[:2] in (b'MM', b'II'): 618 | return 'tiff' 619 | if h[:2] == b'BM': 620 | return 'bmp' 621 | """SGI image library""" 622 | if h[:2] == b'\001\332': 623 | return 'rgb' 624 | """PBM (portable bitmap)""" 625 | if len(h) >= 3 and \ 626 | h[0] == b'P' and h[1] in b'14' and h[2] in b' \t\n\r': 627 | return 'pbm' 628 | """PGM (portable graymap)""" 629 | if len(h) >= 3 and \ 630 | h[0] == b'P' and h[1] in b'25' and h[2] in b' \t\n\r': 631 | return 'pgm' 632 | """PPM (portable pixmap)""" 633 | if len(h) >= 3 and h[0] == b'P' and h[1] in b'36' and h[2] in b' \t\n\r': 634 | return 'ppm' 635 | """Sun raster file""" 636 | if h[:4] == b'\x59\xA6\x6A\x95': 637 | return 'rast' 638 | """X bitmap (X10 or X11)""" 639 | s = b'#define ' 640 | if h[:len(s)] == s: 641 | return 'xbm' 642 | return None 643 | 644 | 645 | class SDK(object): 646 | 647 | def __init__(self, conf_path=None, model_entity: bytes = None): 648 | if not conf_path and not model_entity: 649 | raise ValueError('One of parameters conf_path and model_entity must be filled') 650 | model = 
Model(conf_path=conf_path, source_bytes=model_entity) 651 | self.model_conf = model.model_conf 652 | self.graph_session = GraphSession(model) 653 | self.interface = Interface(self.graph_session) 654 | 655 | def predict(self, image_bytes, param_key=None): 656 | bytes_batch, message = ImageUtils.get_bytes_batch(image_bytes) 657 | if not bytes_batch: 658 | raise ValueError(message) 659 | image_batch, message = ImageUtils.get_image_batch(self.model_conf, bytes_batch, param_key=param_key) 660 | if not image_batch: 661 | raise ValueError(message) 662 | result = self.interface.predict_batch(image_batch, None) 663 | return result 664 | 665 | 666 | if __name__ == '__main__': 667 | # FROM PATH 668 | # sdk = SDK(conf_path=r"model.yaml") 669 | # with open(r"H:\TrainSet\1541187040676.jpg", "rb") as f: 670 | # b = f.read() 671 | # for i in [b] * 1000: 672 | # t1 = time.time() 673 | # print(sdk.predict(b), (time.time() - t1) * 1000) 674 | 675 | # FROM BYTES 676 | with open(r"model.pl", "rb") as f: 677 | b = f.read() 678 | sdk = SDK(model_entity=b) 679 | with open(r"1540868881850.jpg", "rb") as f: 680 | b = f.read() 681 | for i in [b] * 1000: 682 | t1 = time.time() 683 | print(sdk.predict(b), (time.time() - t1) * 1000) 684 | -------------------------------------------------------------------------------- /sdk/pb/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz -------------------------------------------------------------------------------- /sdk/pb/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.14 2 | numpy 3 | pillow 4 | opencv-python==3.4.5.20 5 | pyyaml>=3.13 -------------------------------------------------------------------------------- /sdk/tflite/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 
-*- 3 | # Author: kerlomz -------------------------------------------------------------------------------- /sdk/tflite/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.14 2 | numpy 3 | pillow 4 | opencv-python==3.4.5.20 5 | pyyaml>=3.13 -------------------------------------------------------------------------------- /signature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | from functools import wraps 5 | from constants import ServerType 6 | from utils import * 7 | 8 | 9 | class InvalidUsage(Exception): 10 | 11 | def __init__(self, message, code=None): 12 | Exception.__init__(self) 13 | self.message = message 14 | self.success = False 15 | self.code = code 16 | 17 | def to_dict(self): 18 | rv = {'code': self.code, 'message': self.message, 'success': self.success} 19 | return rv 20 | 21 | 22 | class Signature(object): 23 | """ api signature authentication """ 24 | 25 | def __init__(self, server_type: ServerType, conf: Config): 26 | self.conf = conf 27 | self._except = Response(self.conf.response_def_map) 28 | self._auth = [] 29 | self._timestamp_expiration = 120 30 | self.request = None 31 | self.type = server_type 32 | 33 | def set_auth(self, auth): 34 | self._auth = auth 35 | 36 | def _check_req_timestamp(self, req_timestamp): 37 | """ Check the timestamp 38 | @pram req_timestamp str,int: Timestamp in the request parameter (10 digits) 39 | """ 40 | if len(str(req_timestamp)) == 10: 41 | req_timestamp = int(req_timestamp) 42 | now_timestamp = SignUtils.timestamp() 43 | if now_timestamp - self._timestamp_expiration <= req_timestamp <= now_timestamp + self._timestamp_expiration: 44 | return True 45 | return False 46 | 47 | def _check_req_access_key(self, req_access_key): 48 | """ Check the access_key in the request parameter 49 | @pram req_access_key str: access key in the request 
parameter 50 | """ 51 | if req_access_key in [i['accessKey'] for i in self._auth if "accessKey" in i]: 52 | return True 53 | return False 54 | 55 | def _get_secret_key(self, access_key): 56 | """ Obtain the corresponding secret_key according to access_key 57 | @pram access_key str: access key in the request parameter 58 | """ 59 | secret_keys = [i['secretKey'] for i in self._auth if i.get('accessKey') == access_key] 60 | return "" if not secret_keys else secret_keys[0] 61 | 62 | def _sign(self, args): 63 | """ MD5 signature 64 | @param args: All query parameters (public and private) requested in addition to signature 65 | """ 66 | if "sign" in args: 67 | args.pop("sign") 68 | access_key = args["accessKey"] 69 | query_string = '&'.join(['{}={}'.format(k, v) for (k, v) in sorted(args.items())]) 70 | query_string = '&'.join([query_string, self._get_secret_key(access_key)]) 71 | return SignUtils.md5(query_string).upper() 72 | 73 | def _verification(self, req_params, tornado_handler=None): 74 | """ Verify that the request is valid 75 | @param req_params: All query parameters requested (public and private) 76 | """ 77 | try: 78 | req_signature = req_params["sign"] 79 | req_timestamp = req_params["timestamp"] 80 | req_access_key = req_params["accessKey"] 81 | except KeyError: 82 | raise InvalidUsage(**self._except.INVALID_PUBLIC_PARAMS) 83 | except Exception: 84 | raise InvalidUsage(**self._except.UNKNOWN_SERVER_ERROR) 85 | else: 86 | if self.type == ServerType.FLASK or self.type == ServerType.SANIC: 87 | from flask.app import HTTPException, json 88 | # NO.1 Check the timestamp 89 | if not self._check_req_timestamp(req_timestamp): 90 | raise HTTPException(response=json.jsonify(self._except.INVALID_TIMESTAMP)) 91 | # NO.2 Check the access_id 92 | if not self._check_req_access_key(req_access_key): 93 | raise HTTPException(response=json.jsonify(self._except.INVALID_ACCESS_KEY)) 94 | # NO.3 Check the sign 95 | if req_signature == self._sign(req_params): 96 | return True 97 | 
else: 98 | raise HTTPException(response=json.jsonify(self._except.INVALID_QUERY_STRING)) 99 | elif self.type == ServerType.TORNADO: 100 | from tornado.web import HTTPError 101 | # NO.1 Check the timestamp 102 | if not self._check_req_timestamp(req_timestamp): 103 | return tornado_handler.write_error(self._except.INVALID_TIMESTAMP['code']) 104 | # NO.2 Check the access_id 105 | if not self._check_req_access_key(req_access_key): 106 | return tornado_handler.write_error(self._except.INVALID_ACCESS_KEY['code']) 107 | # NO.3 Check the sign 108 | if req_signature == self._sign(req_params): 109 | return True 110 | else: 111 | return tornado_handler.write_error(self._except.INVALID_QUERY_STRING['code']) 112 | raise Exception('Unknown Server Type') 113 | 114 | def signature_required(self, f): 115 | @wraps(f) 116 | def decorated_function(*args, **kwargs): 117 | if self.type == ServerType.FLASK: 118 | from flask import request 119 | params = request.json 120 | elif self.type == ServerType.TORNADO: 121 | from tornado.escape import json_decode 122 | params = json_decode(args[0].request.body) 123 | elif self.type == ServerType.SANIC: 124 | params = args[0].json 125 | else: 126 | raise UserWarning('Illegal type, the current version is not supported at this time.') 127 | result = self._verification(params, args[0] if self.type == ServerType.TORNADO else None) 128 | if result is True: 129 | return f(*args, **kwargs) 130 | return decorated_function 131 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class SquareTest(tf.test.TestCase): 5 | 6 | def testSquare(self): 7 | with self.test_session(): 8 | x = tf.square([2, 3]) 9 | self.assertAllEqual(x.eval(), [4, 9]) 10 | 11 | 12 | if __name__ == '__main__': 13 | tf.test.main() 14 | -------------------------------------------------------------------------------- 
/tornado_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import os 5 | import uuid 6 | import time 7 | import json 8 | import platform 9 | import numpy as np 10 | import asyncio 11 | import hashlib 12 | import optparse 13 | import threading 14 | import tornado.ioloop 15 | import tornado.log 16 | import tornado.gen 17 | import tornado.httpserver 18 | import tornado.options 19 | from pytz import utc 20 | from apscheduler.triggers.interval import IntervalTrigger 21 | from apscheduler.schedulers.background import BackgroundScheduler 22 | from tornado.web import RequestHandler 23 | from constants import Response 24 | from json.decoder import JSONDecodeError 25 | from tornado.escape import json_decode, json_encode 26 | from interface import InterfaceManager, Interface 27 | from config import Config, blacklist, set_blacklist, whitelist, get_version 28 | from utils import ImageUtils, ParamUtils, Arithmetic 29 | from signature import Signature, ServerType 30 | from tornado.concurrent import run_on_executor 31 | from concurrent.futures import ThreadPoolExecutor 32 | from middleware import * 33 | from event_loop import event_loop 34 | 35 | tornado.options.define('ip_blacklist', default=list(), type=list) 36 | tornado.options.define('ip_whitelist', default=list(), type=list) 37 | tornado.options.define('ip_risk_times', default=dict(), type=dict) 38 | tornado.options.define('request_count', default=dict(), type=dict) 39 | tornado.options.define('global_request_count', default=0, type=int) 40 | model_path = "model" 41 | system_config = Config(conf_path="config.yaml", model_path=model_path, graph_path="graph") 42 | sign = Signature(ServerType.TORNADO, system_config) 43 | arithmetic = Arithmetic() 44 | semaphore = asyncio.Semaphore(500) 45 | 46 | scheduler = BackgroundScheduler(timezone='Asia/Shanghai') 47 | 48 | 49 | class BaseHandler(RequestHandler): 50 | 
51 | def __init__(self, application, request, **kwargs): 52 | super().__init__(application, request, **kwargs) 53 | self.exception = Response(system_config.response_def_map) 54 | self.executor = ThreadPoolExecutor(workers) 55 | self.image_utils = ImageUtils(system_config) 56 | 57 | @property 58 | def request_incr(self): 59 | if self.request.remote_ip not in tornado.options.options.request_count: 60 | tornado.options.options.request_count[self.request.remote_ip] = 1 61 | else: 62 | tornado.options.options.request_count[self.request.remote_ip] += 1 63 | return tornado.options.options.request_count[self.request.remote_ip] 64 | 65 | def request_desc(self): 66 | if self.request.remote_ip not in tornado.options.options.request_count: 67 | return 68 | else: 69 | tornado.options.options.request_count[self.request.remote_ip] -= 1 70 | 71 | @property 72 | def global_request_incr(self): 73 | tornado.options.options.global_request_count += 1 74 | return tornado.options.options.global_request_count 75 | 76 | @staticmethod 77 | def global_request_desc(): 78 | tornado.options.options.global_request_count -= 1 79 | 80 | @staticmethod 81 | def risk_ip_count(ip): 82 | if ip not in tornado.options.options.ip_risk_times: 83 | tornado.options.options.ip_risk_times[ip] = 1 84 | else: 85 | tornado.options.options.ip_risk_times[ip] += 1 86 | 87 | @staticmethod 88 | def risk_ip(ip): 89 | return tornado.options.options.ip_risk_times[ip] 90 | 91 | def data_received(self, chunk): 92 | pass 93 | 94 | def parse_param(self): 95 | try: 96 | data = json_decode(self.request.body) 97 | except JSONDecodeError: 98 | data = self.request.body_arguments 99 | except UnicodeDecodeError: 100 | raise tornado.web.HTTPError(401) 101 | if not data: 102 | raise tornado.web.HTTPError(400) 103 | return data 104 | 105 | def write_error(self, code, **kw): 106 | err_resp = dict(StatusCode=code, Message=system_config.error_message[code], StatusBool=False) 107 | if code in system_config.error_message: 108 | code_dict = 
Response.parse(err_resp, system_config.response_def_map) 109 | else: 110 | code_dict = self.exception.find(code) 111 | return self.finish(json.dumps(code_dict, ensure_ascii=False)) 112 | 113 | 114 | class NoAuthHandler(BaseHandler): 115 | uid_key: str = system_config.response_def_map['Uid'] 116 | message_key: str = system_config.response_def_map['Message'] 117 | status_bool_key = system_config.response_def_map['StatusBool'] 118 | status_code_key = system_config.response_def_map['StatusCode'] 119 | 120 | @staticmethod 121 | def save_image(uid, label, image_bytes): 122 | if system_config.save_path: 123 | if not os.path.exists(system_config.save_path): 124 | os.makedirs(system_config.save_path) 125 | save_name = "{}_{}.png".format(label, uid) 126 | with open(os.path.join(system_config.save_path, save_name), "wb") as f: 127 | f.write(image_bytes) 128 | 129 | @run_on_executor 130 | def predict(self, interface: Interface, image_batch, split_char): 131 | result = interface.predict_batch(image_batch, split_char) 132 | if 'ARITHMETIC' in interface.model_category: 133 | if '=' in result or '+' in result or '-' in result or '×' in result or '÷' in result: 134 | result = result.replace("×", "*").replace("÷", "/") 135 | result = str(int(arithmetic.calc(result))) 136 | return result 137 | 138 | @staticmethod 139 | def match_blacklist(ip: str): 140 | for black_ip in tornado.options.options.ip_blacklist: 141 | if ip.startswith(black_ip): 142 | return True 143 | return False 144 | 145 | @staticmethod 146 | def match_whitelist(ip: str): 147 | for white_ip in tornado.options.options.ip_whitelist: 148 | if ip.startswith(white_ip): 149 | return True 150 | return False 151 | 152 | async def options(self): 153 | self.set_status(204) 154 | return self.finish() 155 | 156 | @tornado.gen.coroutine 157 | def post(self): 158 | uid = str(uuid.uuid1()) 159 | start_time = time.time() 160 | data = self.parse_param() 161 | request_def_map = system_config.request_def_map 162 | input_data_key = 
request_def_map['InputData'] 163 | model_name_key = request_def_map['ModelName'] 164 | if input_data_key not in data.keys(): 165 | raise tornado.web.HTTPError(400) 166 | 167 | model_name = ParamUtils.filter(data.get(model_name_key)) 168 | output_split = ParamUtils.filter(data.get('output_split')) 169 | need_color = ParamUtils.filter(data.get('need_color')) 170 | param_key = ParamUtils.filter(data.get('param_key')) 171 | extract_rgb = ParamUtils.filter(data.get('extract_rgb')) 172 | 173 | request_incr = self.request_incr 174 | global_count = self.global_request_incr 175 | request_count = " - Count[{}]".format(request_incr) 176 | log_params = " - ParamKey[{}]".format(param_key) if param_key else "" 177 | log_params += " - NeedColor[{}]".format(need_color) if need_color else "" 178 | 179 | if interface_manager.total == 0: 180 | self.request_desc() 181 | self.global_request_desc() 182 | logger.info('There is currently no model deployment and services are not available.') 183 | return self.finish(json_encode( 184 | {self.uid_key: uid, self.message_key: "", self.status_bool_key: False, self.status_code_key: -999} 185 | )) 186 | bytes_batch, response = self.image_utils.get_bytes_batch(data[input_data_key]) 187 | 188 | if not (opt.low_hour == -1 or opt.up_hour == -1) and not (opt.low_hour <= time.localtime().tm_hour <= opt.up_hour): 189 | logger.info("[{}] - [{} {}] | - Response[{}] - {} ms".format( 190 | uid, self.request.remote_ip, self.request.uri, "Not in open time.", 191 | (time.time() - start_time) * 1000) 192 | ) 193 | return self.finish(json.dumps({ 194 | self.uid_key: uid, 195 | self.message_key: system_config.illegal_time_msg.format(opt.low_hour, opt.up_hour), 196 | self.status_bool_key: False, 197 | self.status_code_key: -250 198 | }, ensure_ascii=False)) 199 | 200 | if not bytes_batch: 201 | logger.error('[{}] - [{} {}] | - Response[{}] - {} ms'.format( 202 | uid, self.request.remote_ip, self.request.uri, response, 203 | (time.time() - start_time) * 1000) 204 | 
) 205 | return self.finish(json_encode(response)) 206 | 207 | image_sample = bytes_batch[0] 208 | image_size = ImageUtils.size_of_image(image_sample) 209 | size_string = "{}x{}".format(image_size[0], image_size[1]) 210 | if system_config.request_size_limit and size_string not in system_config.request_size_limit: 211 | self.request_desc() 212 | self.global_request_desc() 213 | logger.info('[{}] - [{} {}] | Size[{}] - [{}][{}] - Error[{}] - {} ms'.format( 214 | uid, self.request.remote_ip, self.request.uri, size_string, global_count, log_params, 215 | "Image size is invalid.", 216 | round((time.time() - start_time) * 1000)) 217 | ) 218 | msg = system_config.request_size_limit.get("msg") 219 | msg = msg if msg else "The size of the picture is wrong. " \ 220 | "Only the original image is supported. " \ 221 | "Please do not take a screenshot!" 222 | return self.finish(json.dumps({ 223 | self.uid_key: uid, 224 | self.message_key: msg, 225 | self.status_bool_key: False, 226 | self.status_code_key: -250 227 | }, ensure_ascii=False)) 228 | 229 | if system_config.use_whitelist: 230 | assert_whitelist = self.match_whitelist(self.request.remote_ip) 231 | if not assert_whitelist: 232 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format( 233 | uid, self.request.remote_ip, self.request.uri, size_string, request_count, log_params, 234 | "Whitelist limit", 235 | round((time.time() - start_time) * 1000)) 236 | ) 237 | return self.finish(json.dumps({ 238 | self.uid_key: uid, 239 | self.message_key: "Only allow IP access in the whitelist", 240 | self.status_bool_key: False, 241 | self.status_code_key: -111 242 | }, ensure_ascii=False)) 243 | 244 | if global_request_limit != -1 and global_count > global_request_limit: 245 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format( 246 | uid, self.request.remote_ip, self.request.uri, size_string, global_count, log_params, 247 | "Maximum number of requests exceeded (G)", 248 | round((time.time() - 
start_time) * 1000)) 249 | ) 250 | return self.finish(json.dumps({ 251 | self.uid_key: uid, 252 | self.message_key: system_config.exceeded_msg, 253 | self.status_bool_key: False, 254 | self.status_code_key: -555 255 | }, ensure_ascii=False)) 256 | 257 | assert_blacklist = self.match_blacklist(self.request.remote_ip) 258 | if assert_blacklist: 259 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format( 260 | uid, self.request.remote_ip, self.request.uri, size_string, request_count, log_params, 261 | "The ip is on the risk blacklist (IP)", 262 | round((time.time() - start_time) * 1000)) 263 | ) 264 | return self.finish(json.dumps({ 265 | self.uid_key: uid, 266 | self.message_key: system_config.exceeded_msg, 267 | self.status_bool_key: False, 268 | self.status_code_key: -110 269 | }, ensure_ascii=False)) 270 | if request_limit != -1 and request_incr > request_limit: 271 | self.risk_ip_count(self.request.remote_ip) 272 | assert_blacklist_trigger = system_config.blacklist_trigger_times != -1 273 | if self.risk_ip( 274 | self.request.remote_ip) > system_config.blacklist_trigger_times and assert_blacklist_trigger: 275 | if self.request.remote_ip not in blacklist(): 276 | set_blacklist(self.request.remote_ip) 277 | update_blacklist() 278 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format( 279 | uid, self.request.remote_ip, self.request.uri, size_string, request_count, log_params, 280 | "Maximum number of requests exceeded (IP)", 281 | round((time.time() - start_time) * 1000)) 282 | ) 283 | return self.finish(json.dumps({ 284 | self.uid_key: uid, 285 | self.message_key: system_config.exceeded_msg, 286 | self.status_bool_key: False, 287 | self.status_code_key: -444 288 | }, ensure_ascii=False)) 289 | if model_name_key in data and data[model_name_key]: 290 | interface: Interface = interface_manager.get_by_name(model_name) 291 | else: 292 | interface: Interface = interface_manager.get_by_size(size_string) 293 | if not interface: 294 | 
self.request_desc() 295 | self.global_request_desc() 296 | logger.info('Service is not ready!') 297 | return self.finish(json_encode( 298 | {self.uid_key: uid, self.message_key: "", self.status_bool_key: False, self.status_code_key: 999} 299 | )) 300 | 301 | output_split = output_split if 'output_split' in data else interface.model_conf.output_split 302 | 303 | if interface.model_conf.corp_params: 304 | bytes_batch = corp_to_multi.parse_multi_img(bytes_batch, interface.model_conf.corp_params) 305 | if interface.model_conf.pre_freq_frames != -1: 306 | bytes_batch = gif_frames.all_frames(bytes_batch) 307 | exec_map = interface.model_conf.exec_map 308 | if exec_map and len(exec_map.keys()) > 1 and not param_key: 309 | self.request_desc() 310 | self.global_request_desc() 311 | logger.info('[{}] - [{} {}] | [{}] - Size[{}]{}{} - Error[{}] - {} ms'.format( 312 | uid, self.request.remote_ip, self.request.uri, interface.name, size_string, request_count, log_params, 313 | "The model is missing the param_key parameter because the model is configured with ExecuteMap.", 314 | round((time.time() - start_time) * 1000)) 315 | ) 316 | return self.finish(json_encode( 317 | { 318 | self.uid_key: uid, 319 | self.message_key: "Missing the parameter [param_key].", 320 | self.status_bool_key: False, 321 | self.status_code_key: 474 322 | } 323 | )) 324 | elif exec_map and param_key and param_key not in exec_map: 325 | self.request_desc() 326 | self.global_request_desc() 327 | logger.info('[{}] - [{} {}] | [{}] - Size[{}]{}{} - Error[{}] - {} ms'.format( 328 | uid, self.request.remote_ip, self.request.uri, interface.name, size_string, request_count, log_params, 329 | "The param_key parameter is not support in the model.", 330 | round((time.time() - start_time) * 1000)) 331 | ) 332 | return self.finish(json_encode( 333 | { 334 | self.uid_key: uid, 335 | self.message_key: "Not support the parameter [param_key].", 336 | self.status_bool_key: False, 337 | self.status_code_key: 474 338 | } 339 
| )) 340 | elif exec_map and len(exec_map.keys()) == 1: 341 | param_key = list(interface.model_conf.exec_map.keys())[0] 342 | 343 | if interface.model_conf.external_model and interface.model_conf.corp_params: 344 | result = [] 345 | len_of_result = [] 346 | pre_corp_num = 0 347 | for corp_param in interface.model_conf.corp_params: 348 | corp_size = corp_param['corp_size'] 349 | corp_num_list = corp_param['corp_num'] 350 | corp_num = corp_num_list[0] * corp_num_list[1] 351 | sub_bytes_batch = bytes_batch[pre_corp_num: pre_corp_num + corp_num] 352 | pre_corp_num = corp_num 353 | size_string = "{}x{}".format(corp_size[0], corp_size[1]) 354 | 355 | sub_interface = interface_manager.get_by_size(size_string) 356 | 357 | image_batch, response = ImageUtils.get_image_batch( 358 | sub_interface.model_conf, sub_bytes_batch, param_key=param_key 359 | ) 360 | 361 | text = yield self.predict( 362 | sub_interface, image_batch, output_split, size_string, start_time, log_params, request_count, 363 | uid=uid 364 | ) 365 | result.append(text) 366 | len_of_result.append(len(result[0].split(sub_interface.model_conf.category_split))) 367 | 368 | response[self.message_key] = interface.model_conf.output_split.join(result) 369 | if interface.model_conf.corp_params and interface.model_conf.output_coord: 370 | # final_result = auxiliary_result + "," + response[self.message_key] 371 | # if auxiliary_result else response[self.message_key] 372 | final_result = response[self.message_key] 373 | response[self.message_key] = corp_to_multi.get_coordinate( 374 | label=final_result, 375 | param_group=interface.model_conf.corp_params, 376 | title_index=[i for i in range(len_of_result[0])] 377 | ) 378 | return self.finish(json.dumps(response, ensure_ascii=False).replace("= len(i) >= interface.model_conf.min_label_num 403 | ] 404 | predict_result = gif_frames.get_continuity_max(predict_result) 405 | # if need_color: 406 | # # only support six label and size [90x35]. 
407 | # color_batch = np.resize(image_batch[0], (90, 35, 3)) 408 | # need_index = color_extract.predict_color(image_batch=[color_batch], color=color_map[need_color]) 409 | # predict_result = "".join([v for i, v in enumerate(predict_result) if i in need_index]) 410 | 411 | uid_str = "[{}] - ".format(uid) 412 | logger.info('{}[{} {}] | [{}] - Size[{}]{}{} - Predict[{}] - {} ms'.format( 413 | uid_str, self.request.remote_ip, self.request.uri, interface.name, size_string, request_count, log_params, 414 | predict_result, 415 | round((time.time() - start_time) * 1000)) 416 | ) 417 | response[self.message_key] = predict_result 418 | response[self.uid_key] = uid 419 | self.executor.submit(self.save_image, uid, response[self.message_key], bytes_batch[0]) 420 | if interface.model_conf.corp_params and interface.model_conf.output_coord: 421 | # final_result = auxiliary_result + "," + response[self.message_key] 422 | # if auxiliary_result else response[self.message_key] 423 | final_result = response[self.message_key] 424 | response[self.message_key] = corp_to_multi.get_coordinate( 425 | label=final_result, 426 | param_group=interface.model_conf.corp_params, 427 | title_index=[0] 428 | ) 429 | return self.finish(json.dumps(response, ensure_ascii=False).replace(" 1: 477 | logger.info('[{}] - [{} {}] | [{}] - Size[{}] - Error[{}] - {} ms'.format( 478 | uid, self.request.remote_ip, self.request.uri, interface.name, size_string, 479 | "The model is configured with ExecuteMap, but the api do not support this param.", 480 | round((time.time() - start_time) * 1000)) 481 | ) 482 | return self.finish(json_encode( 483 | { 484 | self.message_key: "the api do not support [ExecuteMap].", 485 | self.status_bool_key: False, 486 | self.status_code_key: 474 487 | } 488 | )) 489 | elif exec_map and len(exec_map.keys()) == 1: 490 | param_key = list(interface.model_conf.exec_map.keys())[0] 491 | 492 | image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, 
param_key=param_key) 493 | 494 | if not image_batch: 495 | logger.error('[{}] - [{}] | [{}] - Size[{}] - Response[{}] - {} ms'.format( 496 | uid, self.request.remote_ip, interface.name, size_string, response, 497 | (time.time() - start_time) * 1000) 498 | ) 499 | return self.finish(json_encode(response)) 500 | 501 | result = interface.predict_batch(image_batch, None) 502 | logger.info('[{}] - [{}] | [{}] - Size[{}] - Predict[{}] - {} ms'.format( 503 | uid, self.request.remote_ip, interface.name, size_string, result, (time.time() - start_time) * 1000) 504 | ) 505 | response[self.uid_key] = uid 506 | response[self.message_key] = result 507 | return self.write(json.dumps(response, ensure_ascii=False).replace("'.format(server_host, server_port)) 598 | app = make_app(system_config.route_map) 599 | http_server = tornado.httpserver.HTTPServer(app) 600 | http_server.bind(server_port, server_host) 601 | http_server.start(1) 602 | tornado.ioloop.IOLoop.instance().start() 603 | -------------------------------------------------------------------------------- /tornado_server.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python -*- 2 | # Used to package as a single executable 3 | # This is a configuration file 4 | import cv2 5 | import os 6 | from PyInstaller.utils.hooks import collect_all 7 | 8 | 9 | block_cipher = None 10 | 11 | binaries, hiddenimports = [], ['numpy.core._dtype_ctypes'] 12 | tmp_ret = collect_all('tzdata') 13 | added_files = [('resource/icon.ico', 'resource'), ('resource/favorite.ico', '.'), ('resource/VERSION', 'astor'), ('resource/VERSION', '.')] 14 | added_files += tmp_ret[0]; binaries += tmp_ret[1] 15 | hiddenimports += tmp_ret[2] 16 | 17 | a = Analysis(['tornado_server.py'], 18 | pathex=['.', os.path.join(os.path.dirname(cv2.__file__), 'config-3.9')], 19 | binaries=binaries, 20 | datas=added_files, 21 | hiddenimports=hiddenimports, 22 | hookspath=[], 23 | runtime_hooks=[], 24 | excludes=[], 25 | 
win_no_prefer_redirects=False, 26 | win_private_assemblies=False, 27 | cipher=block_cipher) 28 | pyz = PYZ(a.pure, a.zipped_data, 29 | cipher=block_cipher) 30 | exe = EXE(pyz, 31 | a.scripts, 32 | a.binaries, 33 | a.zipfiles, 34 | a.datas, 35 | [], 36 | name='captcha_platform_tornado', 37 | debug=False, 38 | strip=False, 39 | upx=True, 40 | runtime_tmpdir=None, 41 | console=True, 42 | icon='resource/icon.ico') 43 | -------------------------------------------------------------------------------- /tornado_server_gpu.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python -*- 2 | # Used to package as a single executable 3 | # This is a configuration file 4 | 5 | block_cipher = None 6 | 7 | added_files = [('resource/icon.ico', 'resource'), ('resource/VERSION', 'astor')] 8 | 9 | a = Analysis(['tornado_server.py'], 10 | pathex=['.'], 11 | binaries=[], 12 | datas=added_files, 13 | hiddenimports=['numpy.core._dtype_ctypes'], 14 | hookspath=[], 15 | runtime_hooks=[], 16 | excludes=[], 17 | win_no_prefer_redirects=False, 18 | win_private_assemblies=False, 19 | cipher=block_cipher) 20 | pyz = PYZ(a.pure, a.zipped_data, 21 | cipher=block_cipher) 22 | exe = EXE(pyz, 23 | a.scripts, 24 | [], 25 | exclude_binaries=True, 26 | name='captcha_platform_tornado_gpu', 27 | debug=False, 28 | bootloader_ignore_signals=False, 29 | strip=False, 30 | upx=True, 31 | console=True, 32 | icon='resource/icon.ico') 33 | coll = COLLECT(exe, 34 | a.binaries, 35 | a.zipfiles, 36 | a.datas, 37 | strip=False, 38 | upx=True, 39 | upx_exclude=[], 40 | name='captcha_platform_tornado_gpu') 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import io 5 | import re 6 | import os 7 | import cv2 8 | import time 9 | import base64 10 | 
import functools
import binascii
import datetime
import hashlib
import numpy as np
import tensorflow as tf
from PIL import Image as PIL_Image
from constants import Response, SystemConfig
from pretreatment import preprocessing, preprocessing_by_func
from config import ModelConfig, Config
from middleware.impl.gif_frames import concat_frames, blend_frame
from middleware.impl.rgb_filter import rgb_filter


class Arithmetic(object):
    """Evaluate simple arithmetic expressions (``+ - * /`` and parentheses)."""

    @staticmethod
    def _number_to_str(value):
        """Render *value* in plain positional notation (never ``1e-05``).

        ``str(float)`` switches to scientific notation for very small or very
        large magnitudes; the digit regexes in :meth:`calc` would then split
        ``"1e-05"`` into the bogus terms ``1`` and ``-05``.  The shortest
        round-tripping digits are kept, so re-parsing yields the same float.
        """
        return np.format_float_positional(float(value), trim='-')

    def calc(self, formula):
        """Evaluate *formula* and return the result.

        Spaces are ignored.  The innermost parenthesised group is evaluated
        first (recursively) and substituted back.  Then ``*`` / ``/`` are
        applied left to right.  Finally the remaining signed terms are summed.

        :param formula: expression text such as ``"(1+2)*3"``.
        :return: the value as a float (``0`` if the text contains no numbers).
        :raises ValueError: if a ``*`` / ``/`` operator lacks an operand
            (e.g. ``"3*"``), or a term is not a valid number.
        :raises ZeroDivisionError: on division by zero.
        """
        formula = formula.replace(' ', '')
        match_brackets = re.search(r'\([^()]+\)', formula)
        if match_brackets:
            bracket = match_brackets.group()
            calc_result = self.calc(bracket.strip("(,)"))
            # Identical groups have identical values, so replacing all of them is safe.
            formula = formula.replace(bracket, self._number_to_str(calc_result))
            return self.calc(formula)

        # Collapse doubled signs: "a--b" -> "a+b", "a+-b" -> "a-b", etc.
        formula = formula.replace('--', '+').replace('++', '+').replace('-+', '-').replace('+-', '-')
        while re.search(r"[*/]", formula):
            get_formula = re.search(r"[.\d]+[*/]+[-]?[.\d]+", formula)
            if get_formula is None:
                # An operator with a missing operand; without this check the
                # loop would spin forever because the formula never changes.
                raise ValueError("Malformed arithmetic formula: {!r}".format(formula))
            get_formula_str = get_formula.group()
            if get_formula_str.count("*"):
                formula_list = get_formula_str.split("*")
                ret = float(formula_list[0]) * float(formula_list[1])
            else:
                formula_list = get_formula_str.split("/")
                ret = float(formula_list[0]) / float(formula_list[1])
            # Splice the result into exactly the matched span.  str.replace()
            # would also rewrite look-alike substrings, e.g. the "12*3" inside
            # "112*34".
            start, end = get_formula.span()
            formula = formula[:start] + self._number_to_str(ret) + formula[end:]
            formula = formula.replace('--', '+').replace('++', '+')
        return sum(float(num) for num in re.findall(r'[-]?[.\d]+', formula))


class ParamUtils(object):
    """Normalise request parameter values."""

    @staticmethod
    def filter(param):
        """Unwrap a ``[b'value', ...]`` style parameter.

        If *param* is a non-empty list whose first item is ``bytes``, return
        that first item decoded (UTF-8).  Otherwise return *param* unchanged.
        The list-of-bytes shape is presumably what the web framework delivers
        for form/query arguments; verify against callers.
        """
        if isinstance(param, list) and len(param) > 0 and isinstance(param[0], bytes):
            return param[0].decode()
        return param


class SignUtils(object):
    """Hashing and timestamp helpers (presumably used for request signatures)."""

    @staticmethod
    def md5(text):
        """Return the hex MD5 digest of *text* encoded as UTF-8."""
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    @staticmethod
    def timestamp():
        """Return the current local time as integer seconds since the epoch."""
        return int(time.mktime(datetime.datetime.now().timetuple()))


class PathUtils(object):
    """Path helpers that accept both ``/`` and ``\\`` separators."""

    @staticmethod
    def get_file_name(path: str):
        """Return the last component of *path*.

        Splits on ``/`` and ``\\`` at the same time, so mixed paths such as
        ``C:\\data/img\\a.png`` yield ``a.png``.  A trailing separator yields
        an empty string; a path without separators is returned as-is.
        """
        return re.split(r'[\\/]', path)[-1]


class ImageUtils(object):
    """Decode request image payloads and turn them into model input arrays."""

    def __init__(self, conf: Config):
        # Service configuration; provides response_def_map and split_flag.
        self.conf = conf

    def get_bytes_batch(self, base64_or_bytes):
        """Decode a request's image field into a list of encoded image bytes.

        Accepted inputs:

        * ``bytes``: raw (not base64) image bytes, possibly several images
          joined by ``conf.split_flag``;
        * ``list``: base64 strings, each optionally prefixed with
          ``data:image/...;base64,``.  If the list contains no strings, its
          ``bytes`` items are base64-decoded instead.  Items are not split;
        * ``str``: one base64 string whose decoded bytes may contain several
          images joined by ``conf.split_flag``.

        :return: ``(bytes_batch, response.SUCCESS)`` on success, otherwise
            ``(None, error_response)``:

            * ``INVALID_BASE64_STRING`` for undecodable base64 or an
              unsupported input type;
            * ``INVALID_IMAGE_FORMAT`` when no image was found or an item has
              an unrecognised header.
        """
        response = Response(self.conf.response_def_map)

        def strip_data_uri(text):
            # Remove a leading "data:image/...;base64," prefix (str input).
            return re.sub("data:image/.+?base64,", "", text, count=1) if ',' in text else text

        def strip_data_uri_bytes(raw):
            # Same as strip_data_uri, for bytes input.
            return re.sub(b"data:image/.+?base64,", b"", raw, count=1) if b',' in raw else raw

        try:
            if isinstance(base64_or_bytes, bytes):
                if self.conf.split_flag in base64_or_bytes:
                    bytes_batch = base64_or_bytes.split(self.conf.split_flag)
                else:
                    bytes_batch = [base64_or_bytes]
            elif isinstance(base64_or_bytes, list):
                bytes_batch = [
                    base64.b64decode(strip_data_uri(item).encode('utf-8'))
                    for item in base64_or_bytes if isinstance(item, str)
                ]
                if not bytes_batch:
                    bytes_batch = [
                        base64.b64decode(strip_data_uri_bytes(item))
                        for item in base64_or_bytes if isinstance(item, bytes)
                    ]
            elif isinstance(base64_or_bytes, str):
                base64_or_bytes = strip_data_uri(base64_or_bytes)
                bytes_batch = base64.b64decode(base64_or_bytes.encode('utf-8')).split(self.conf.split_flag)
            else:
                # e.g. a JSON number or null; previously this raised TypeError/AttributeError.
                return None, response.INVALID_BASE64_STRING
        except binascii.Error:
            return None, response.INVALID_BASE64_STRING

        # An empty batch used to be reported as SUCCESS, even though callers
        # treat an empty batch as a failure.
        if not bytes_batch or any(ImageUtils.test_image(item) is None for item in bytes_batch):
            return None, response.INVALID_IMAGE_FORMAT
        return bytes_batch, response.SUCCESS

    @staticmethod
    def get_image_batch(model: ModelConfig, bytes_batch, param_key=None, extract_rgb: list = None):
        """Convert encoded images into normalised float arrays for *model*.

        Each image goes through, in order:

        1. optional palette/alpha flattening;
        2. GIF frame concat/blend;
        3. RGB filtering (``extract_rgb``);
        4. ``preprocessing_by_func`` (selected by ``param_key`` in ``model.exec_map``);
        5. grayscale conversion;
        6. ``preprocessing`` (binarisation);
        7. optional horizontal stitching;
        8. resizing to ``model.resize``.

        The first two axes are then swapped and values are scaled by ``1/255``.

        :return: ``(image_batch, response.SUCCESS)``, or
            ``(None, response.IMAGE_DAMAGE)`` on ``OSError``, or
            ``(None, response.IMAGE_SIZE_NOT_MATCH_GRAPH)`` on ``ValueError``.
        """
        response = Response(model.conf.response_def_map)

        def load_image(image_bytes: bytes):
            data_stream = io.BytesIO(image_bytes)
            pil_image = PIL_Image.open(data_stream)

            # GIF frame handling needs the original image, so the conversions
            # below are skipped when either frame option is enabled.
            gif_handle = model.pre_concat_frames != -1 or model.pre_blend_frames != -1

            if pil_image.mode == 'P' and not gif_handle:
                pil_image = pil_image.convert('RGB')

            rgb = pil_image.split()
            size = pil_image.size

            if (len(rgb) > 3 and model.pre_replace_transparent) and not gif_handle:
                # More than three bands (presumably RGBA): paste onto a white
                # background using the image itself as the mask.
                background = PIL_Image.new('RGB', pil_image.size, (255, 255, 255))
                try:
                    background.paste(pil_image, (0, 0, size[0], size[1]), pil_image)
                    pil_image = background
                except Exception:
                    pil_image = pil_image.convert('RGB')

            if len(pil_image.split()) > 3 and model.image_channel == 3:
                pil_image = pil_image.convert('RGB')

            if model.pre_concat_frames != -1:
                im = concat_frames(pil_image, model.pre_concat_frames)
            elif model.pre_blend_frames != -1:
                im = blend_frame(pil_image, model.pre_blend_frames)
            else:
                im = np.asarray(pil_image)

            if extract_rgb:
                im = rgb_filter(im, extract_rgb)

            im = preprocessing_by_func(
                exec_map=model.exec_map,
                key=param_key,
                src_arr=im
            )

            if model.image_channel == 1 and len(im.shape) == 3:
                im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)

            im = preprocessing(
                image=im,
                binaryzation=model.pre_binaryzation,
            )

            if model.pre_horizontal_stitching:
                # Cut into top and bottom halves and lay them side by side.
                up_slice = im[0: int(size[1] / 2), 0: size[0]]
                down_slice = im[int(size[1] / 2): size[1], 0: size[0]]
                im = np.concatenate((up_slice, down_slice), axis=1)

            image = im.astype(np.float32)
            if model.resize[0] == -1:
                # Width -1: scale to the target height, keeping the aspect ratio.
                ratio = model.resize[1] / size[1]
                resize_width = int(ratio * size[0])
                image = cv2.resize(image, (resize_width, model.resize[1]))
            else:
                image = cv2.resize(image, (model.resize[0], model.resize[1]))
            # Swap the first two axes (rows/columns) before feeding the model.
            image = image.swapaxes(0, 1)
            return (image[:, :, np.newaxis] if model.image_channel == 1 else image[:, :]) / 255.

        try:
            image_batch = [load_image(i) for i in bytes_batch]
            return image_batch, response.SUCCESS
        except OSError:
            return None, response.IMAGE_DAMAGE
        except ValueError as _e:
            print(_e)
            return None, response.IMAGE_SIZE_NOT_MATCH_GRAPH

    @staticmethod
    def size_of_image(image_bytes: bytes):
        """Return ``(width, height)`` of an encoded image, or ``(-1, -1)`` if PIL cannot read it."""
        try:
            return PIL_Image.open(io.BytesIO(image_bytes)).size
        except (OSError, ValueError):
            return -1, -1

    @staticmethod
    def test_image(h):
        """Identify an image format from its leading bytes.

        Header checks follow the stdlib ``imghdr`` tests.

        :param h: the (start of the) encoded image as ``bytes``.
        :return: a format name such as ``'jpeg'`` or ``'png'``, or ``None`` if
            unrecognised.
        """
        # JPEG
        if h[:3] == b"\xff\xd8\xff":
            return 'jpeg'
        # PNG
        if h[:8] == b"\211PNG\r\n\032\n":
            return 'png'
        # GIF ('87 and '89 variants)
        if h[:6] in (b'GIF87a', b'GIF89a'):
            return 'gif'
        # TIFF (Motorola or Intel byte order)
        if h[:2] in (b'MM', b'II'):
            return 'tiff'
        # BMP
        if h[:2] == b'BM':
            return 'bmp'
        # SGI image library
        if h[:2] == b'\001\332':
            return 'rgb'
        # Netpbm family.  Indexing bytes yields an int, so compare with
        # ord(b'P'); the old `h[0] == b'P'` test could never be true.
        is_netpbm = len(h) >= 3 and h[0] == ord(b'P') and h[2] in b' \t\n\r'
        # PBM (portable bitmap)
        if is_netpbm and h[1] in b'14':
            return 'pbm'
        # PGM (portable graymap)
        if is_netpbm and h[1] in b'25':
            return 'pgm'
        # PPM (portable pixmap)
        if is_netpbm and h[1] in b'36':
            return 'ppm'
        # Sun raster file
        if h[:4] == b'\x59\xA6\x6A\x95':
            return 'rast'
        # X bitmap (X10 or X11)
        s = b'#define '
        if h[:len(s)] == s:
            return 'xbm'
        return None
class SystemUtils(object):
    """Timestamp formatting and remote (SFTP) directory maintenance helpers."""

    @staticmethod
    def datetime(origin=None, microseconds=None):
        """Format a Unix timestamp as local time ``'%Y-%m-%d %H:%M:%S.%f'``.

        :param origin: seconds since the epoch.  ``None`` means "now"; an
            explicit ``0`` is honoured (a truthiness test used to replace it
            with the current time).
        :param microseconds: optional offset added before formatting.
        :return: the formatted local date-time string.
        """
        now = origin if origin is not None else time.time()
        moment = datetime.datetime.fromtimestamp(now)
        if microseconds:
            moment += datetime.timedelta(microseconds=microseconds)
        return moment.strftime('%Y-%m-%d %H:%M:%S.%f')

    @staticmethod
    def isdir(sftp, path):
        """Return True if remote *path* is a directory, False if it is not or ``stat`` fails."""
        from stat import S_ISDIR
        try:
            return S_ISDIR(sftp.stat(path).st_mode)
        except IOError:
            return False

    @staticmethod
    def empty(sftp, path):
        """Ensure remote directory *path* exists and delete every file beneath it.

        Sub-directories are recursed into and emptied, but are not removed.

        :param sftp: client exposing ``stat``/``mkdir``/``listdir``/``remove``
            (presumably a ``paramiko.SFTPClient``).
        :param path: remote directory path.
        """
        import posixpath

        if not SystemUtils.isdir(sftp, path):
            sftp.mkdir(path)

        for name in sftp.listdir(path=path):
            # Remote SFTP paths always use "/"; os.path.join would insert "\"
            # when this runs on a Windows host.
            child = posixpath.join(path, name)
            if SystemUtils.isdir(sftp, child):
                SystemUtils.empty(sftp, child)
            else:
                sftp.remove(child)