├── .dockerignore
├── .gitignore
├── .pip.conf
├── .travis.yml
├── Dockerfile
├── LICENSE
├── README.md
├── category.py
├── compat
│   ├── flask_server.py
│   ├── flask_server.spec
│   ├── grpc.proto
│   ├── grpc_pb2.py
│   ├── grpc_pb2_grpc.py
│   ├── grpc_server.py
│   ├── grpc_server.spec
│   └── sanic_server.py
├── config.py
├── constants.py
├── demo.py
├── deploy.conf.py
├── event_handler.py
├── event_loop.py
├── graph_session.py
├── interface.py
├── middleware
│   ├── __init__.py
│   ├── constructor
│   ├── impl
│   │   ├── color_extractor.py
│   │   ├── color_filter.py
│   │   ├── corp_to_multi.py
│   │   ├── gif_frames.py
│   │   └── rgb_filter.py
│   └── resource
│       ├── __init__.py
│       └── color_filter.py
├── package.py
├── predict.py
├── pretreatment.py
├── requirements.txt
├── resource
│   ├── VERSION
│   ├── favorite.ico
│   └── icon.ico
├── sdk
│   ├── __init__.py
│   ├── onnx
│   │   ├── __init__.py
│   │   ├── requirements.txt
│   │   └── sdk.py
│   ├── pb
│   │   ├── __init__.py
│   │   ├── requirements.txt
│   │   └── sdk.py
│   └── tflite
│       ├── __init__.py
│       ├── requirements.txt
│       └── sdk.py
├── signature.py
├── test.py
├── tornado_server.py
├── tornado_server.spec
├── tornado_server_gpu.spec
└── utils.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | # The .dockerignore file specifies files and folders to ignore when the build context is sent to the Docker engine.
2 |
3 | .git/
4 | .idea/
5 | patchca_ubuntu/
6 | __pycache__/
7 | .DS_Store
8 | venv/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | # C extensions
8 | *.so
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | # Installer logs
32 | pip-log.txt
33 | pip-delete-this-directory.txt
34 | # Unit test / coverage reports
35 | htmlcov/
36 | .tox/
37 | .coverage
38 | .coverage.*
39 | .cache
40 | nosetests.xml
41 | coverage.xml
42 | *.cover
43 | .hypothesis/
44 | # Translations
45 | *.mo
46 | *.pot
47 | # Django stuff:
48 | *.log
49 | local_settings.py
50 | # Flask stuff:
51 | instance/
52 | .webassets-cache
53 | # Scrapy stuff:
54 | .scrapy
55 | # Sphinx documentation
56 | docs/_build/
57 | # PyBuilder
58 | target/
59 | # Jupyter Notebook
60 | .ipynb_checkpoints
61 | # pyenv
62 | .python-version
63 | # celery beat schedule file
64 | celerybeat-schedule
65 | # SageMath parsed files
66 | *.sage.py
67 | # dotenv
68 | .env
69 | # virtualenv
70 | .venv
71 | venv/
72 | ENV/
73 | # Spyder project settings
74 | .spyderproject
75 | # Rope project settings
76 | .ropeproject
77 | # JetBrains IDE project configuration files
78 | .idea/
79 | # image file
80 | *.jpg
81 | *.png
82 | *.gif
83 | # Files generated by the Linux system
84 | *.pid
85 | nohup.out
86 | graph/*
87 | model/*
88 | config.yaml
--------------------------------------------------------------------------------
/.pip.conf:
--------------------------------------------------------------------------------
1 | [global]
2 | index-url = https://pypi.doubanio.com/simple/
3 | trusted-host = pypi.doubanio.com
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | python:
4 | - "3.6"
5 |
6 | install:
7 | - pip install -r requirements.txt
8 |
9 | script:
10 | python test.py
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.6.8-stretch as builder
2 |
3 | ADD . /app/
4 |
5 | WORKDIR /app/
6 |
7 | COPY requirements.txt /app/
8 |
9 | # timezone
10 | ENV TZ=Asia/Shanghai
11 |
12 | RUN pip install --no-cache-dir --upgrade pip \
13 | && pip install --no-cache-dir -r requirements.txt
14 |
15 |
16 | ENTRYPOINT ["python3", "tornado_server.py"]
17 | EXPOSE 19952
18 | # run command:
19 | # docker run -d -p 19952:19952 [image:tag]
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The Star And Thank Author License (SATA)
2 |
3 | Copyright (c) 2018 kelromz(kerlomz@gmail.com)
4 |
5 | Project Url: https://github.com/kerlomz/captcha_platform
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in
15 | all copies or substantial portions of the Software.
16 |
17 | And wait, the most important, you shall star/+1/like the project(s) in project url
18 | section above first, and then thank the author(s) in Copyright section.
19 |
20 | Here are some suggested ways:
21 |
22 | - Email the authors a thank-you letter, and make friends with him/her/them.
23 | - Report bugs or issues.
24 | - Tell friends what a wonderful project this is.
25 | - And, sure, you can just express thanks in your mind without telling the world.
26 |
27 | Contributors of this project by forking have the option to add his/her name and
28 | forked project url at copyright and project url sections, but shall not delete
29 | or modify anything else in these two sections.
30 |
31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
37 | THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [Build Status](https://travis-ci.org/kerlomz/captcha_platform)
3 |
4 | # Project Introduction
5 | This project uses CNN+BLSTM+CTC to perform captcha recognition.
6 | This project is for model deployment only. If you need to train a model, please see https://github.com/kerlomz/captcha_trainer
7 |
8 | # Informed
9 | 1. The default requirements.txt installs the CPU version. To switch to the GPU version, change "TensorFlow" to "TensorFlow-GPU" in requirements.txt and install the corresponding CUDA and cuDNN.
10 | 2. demo.py: an example of how to call the prediction service.
11 | 3. The model folder stores model configuration files such as model.yaml.
12 | 4. The graph folder stores compiled models such as model.pb.
13 | 5. The service automatically loads every model defined by the configurations in the model folder. When a new model configuration is added, the corresponding compiled model in the graph folder is loaded automatically, so to add a model, copy the compiled model into the graph folder first, then add the model configuration.
14 |
15 |
16 | # Start
17 | 1. Install Python 3.9 (with pip)
18 | 2. Install virtualenv ```pip3 install virtualenv```
19 | 3. Create a separate virtual environment for the project:
20 | ```bash
21 | virtualenv -p /usr/bin/python3 venv # venv is the name of the virtual environment.
22 | cd venv/ # venv is the name of the virtual environment.
23 | source bin/activate # to activate the current virtual environment.
24 | cd captcha_platform # captcha_platform is the project path.
25 | ```
26 | 4. ```pip install -r requirements.txt```
27 | 5. Place your trained model.yaml in the model folder and your model.pb in the graph folder (create them if they do not exist).
28 | 6. Deploy as follows.
29 |
30 | ## 1. Http Version
31 | 1. Linux
32 | Deploy (Linux/Mac):
33 |
34 | Port: 19952
35 | ```
36 | python tornado_server.py
37 | ```
38 |
39 | 2. Windows
40 | Deploy (Windows):
41 | ```
42 | python xxx_server.py
43 | ```
44 |
45 | 3. Request
46 |
47 | |Request URI | Content-Type | Payload Type | Method |
48 | | ----------- | ---------------- | -------- | -------- |
49 | | http://localhost:[Bind-port]/captcha/v1 | application/json | JSON | POST |
50 |
51 | | Parameter | Required | Type | Description |
52 | | ---------- | ---- | ------ | ------------------------ |
53 | | image | Yes | String | Base64-encoded image bytes |
54 | | model_name | No | String | Model name, as bound in the yaml configuration |
55 |
56 |
57 | The request is in JSON format, like: {"image": "base64 encoded image binary stream"}
58 |
59 | 4. Response
60 |
61 | | Parameter Name | Type | Description |
62 | | ------- | ------ | ------------------ |
63 | | message | String | Recognition result or error message |
64 | | code | Integer | Status code |
65 | | success | Boolean | Whether the request succeeded |
66 |
67 | The response is in JSON format, like: {"message": "xxxx", "code": 0, "success": true} (see the client sketch below).
68 |
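A minimal Python client sketch for the endpoint above (the host, port, and image path below are assumptions; adjust them to your deployment):

```python
import base64
import requests

def recognize(image_path: str, url: str = "http://127.0.0.1:19952/captcha/v1") -> dict:
    # Read the captcha image and base64-encode it, as required by the API.
    with open(image_path, "rb") as f:
        payload = {"image": base64.b64encode(f.read()).decode()}
    # The reply looks like {"message": "<result>", "code": 0, "success": true}.
    return requests.post(url, json=payload).json()

print(recognize("captcha_sample.png"))
```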
69 |
70 | ## 2. G-RPC Version
71 | Deploy:
72 | ```
73 | python3 grpc_server.py
74 | ```
75 | Port: 50054
76 |
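A minimal Python client sketch for the gRPC service defined in compat/grpc.proto (the target address and image path are assumptions; it reuses the generated compat.grpc_pb2 / compat.grpc_pb2_grpc modules):

```python
import base64
import grpc
from compat import grpc_pb2, grpc_pb2_grpc

def predict(image_path: str, target: str = "127.0.0.1:50054") -> str:
    # Base64-encode the image, matching what the server expects.
    with open(image_path, "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()
    stub = grpc_pb2_grpc.PredictStub(grpc.insecure_channel(target))
    # PredictResult carries result / code / success fields.
    response = stub.predict(grpc_pb2.PredictRequest(image=image_b64))
    return response.result

print(predict("captcha_sample.png"))
```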
77 |
78 | # Update G-RPC-CODE
79 | python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. ./grpc.proto
80 |
81 |
82 | # Directory Structure
83 |
84 | - captcha_platform
85 | - grpc_server.py
86 | - flask_server.py
87 | - tornado_server.py
88 | - sanic_server.py
89 | - demo.py
90 | - config.yaml
91 | - model
92 | - model-1.yaml
93 | - model-2.yaml
94 | - ...
95 | - graph
96 | - Model-1.pb
97 | - ...
98 |
99 | # Management Model
100 | 1. **Load a model**
101 | - Put the trained pb model in the graph folder.
102 | - Put the trained yaml model configuration file in the model folder.
103 | 2. **Unload a model**
104 | - Delete the corresponding yaml configuration file in the model folder.
105 | - Delete the corresponding pb model file in the graph folder.
106 | 3. **Update a model**
107 | - Put the trained pb model in the graph folder.
108 | - Put the yaml configuration file with "Version" greater than the current version in the model folder.
109 | - Delete old models and configurations.
110 |
111 | # License
112 | This project uses the SATA License (Star And Thank Author License), so you have to star this project before using it. Read the license carefully.
113 |
114 | # Introduction
115 | https://www.jianshu.com/p/80ef04b16efc
116 |
117 | # Donate
118 | Thank you very much for your support of my project.
--------------------------------------------------------------------------------
/compat/flask_server.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import time
5 | import optparse
6 | import threading
7 | from flask import *
8 | from flask_caching import Cache
9 | from gevent.pywsgi import WSGIServer
10 | from geventwebsocket.handler import WebSocketHandler
11 | from config import Config
12 | from utils import ImageUtils
13 | from constants import Response
14 |
15 | from interface import InterfaceManager
16 | from signature import Signature, ServerType
17 | from watchdog.observers import Observer
18 | from event_handler import FileEventHandler
19 | from middleware import *
20 | # The import order cannot be changed; this import must come before the Flask setup.
21 |
22 | app = Flask(__name__)
23 | cache = Cache(app, config={'CACHE_TYPE': 'simple'})
24 |
25 | conf_path = '../config.yaml'
26 | model_path = '../model'
27 | graph_path = '../graph'
28 |
29 |
30 | system_config = Config(conf_path=conf_path, model_path=model_path, graph_path=graph_path)
31 | sign = Signature(ServerType.FLASK, system_config)
32 | _except = Response(system_config.response_def_map)
33 | route_map = {i['Class']: i['Route'] for i in system_config.route_map}
34 | sign.set_auth([{'accessKey': system_config.access_key, 'secretKey': system_config.secret_key}])
35 | logger = system_config.logger
36 | interface_manager = InterfaceManager()
37 | image_utils = ImageUtils(system_config)
38 |
39 |
40 | @app.after_request
41 | def after_request(response):
42 | response.headers['Access-Control-Allow-Origin'] = '*'
43 | return response
44 |
45 |
46 | @app.errorhandler(400)
47 | def server_error(error=None):
48 | message = "Bad Request"
49 | return jsonify(message=message, code=error.code, success=False)
50 |
51 |
52 | @app.errorhandler(500)
53 | def internal_server_error(error=None):
54 | message = 'Internal Server Error'
55 | return jsonify(message=message, code=500, success=False)
56 |
57 |
58 | @app.errorhandler(404)
59 | def not_found(error=None):
60 | message = '404 Not Found'
61 | return jsonify(message=message, code=error.code, success=False)
62 |
63 |
64 | @app.errorhandler(403)
65 | def permission_denied(error=None):
66 | message = 'Forbidden'
67 | return jsonify(message=message, code=error.code, success=False)
68 |
69 |
70 | @app.route(route_map['AuthHandler'], methods=['POST'])
71 | @sign.signature_required # This decorator is required for certification.
72 | def auth_request():
73 | return common_request()
74 |
75 |
76 | @app.route(route_map['NoAuthHandler'], methods=['POST'])
77 | def no_auth_request():
78 | return common_request()
79 |
80 |
81 | def common_request():
82 | """
83 | Common captcha prediction handler shared by the auth and no-auth routes.
84 | :return:
85 | """
86 | start_time = time.time()
87 | if not request.json or 'image' not in request.json:
88 | abort(400)
89 |
90 | if interface_manager.total == 0:
91 | logger.info('There is currently no model deployment and services are not available.')
92 | return json.dumps({"message": "", "success": False, "code": -999})
93 |
94 | bytes_batch, response = image_utils.get_bytes_batch(request.json['image'])
95 |
96 | if not bytes_batch:
97 | logger.error('Name[{}] - Response[{}] - {} ms'.format(
98 | request.json.get('model_site'), response,
99 | (time.time() - start_time) * 1000)
100 | )
101 | return json.dumps(response), 200
102 |
103 | image_sample = bytes_batch[0]
104 | image_size = ImageUtils.size_of_image(image_sample)
105 | size_string = "{}x{}".format(image_size[0], image_size[1])
106 |
107 | if 'model_name' in request.json:
108 | interface = interface_manager.get_by_name(request.json['model_name'])
109 | else:
110 | interface = interface_manager.get_by_size(size_string)
111 |
112 | split_char = request.json['output_split'] if 'output_split' in request.json else interface.model_conf.output_split
113 |
114 | if 'need_color' in request.json and request.json['need_color']:
115 | bytes_batch = [color_extract.separate_color(_, color_map[request.json['need_color']]) for _ in bytes_batch]
116 |
117 | image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
118 |
119 | if not image_batch:
120 | logger.error('[{}] - Size[{}] - Name[{}] - Response[{}] - {} ms'.format(
121 | interface.name, size_string, request.json.get('model_name'), response,
122 | (time.time() - start_time) * 1000)
123 | )
124 | return json.dumps(response), 200
125 |
126 | result = interface.predict_batch(image_batch, split_char)
127 | logger.info('[{}] - Size[{}] - Name[{}] - Predict Result[{}] - {} ms'.format(
128 | interface.name,
129 | size_string,
130 | request.json.get('model_name'),
131 | result,
132 | (time.time() - start_time) * 1000
133 | ))
134 | response['message'] = result
135 | return json.dumps(response), 200
136 |
137 |
138 | def event_loop():
139 | event = threading.Event()
140 | observer = Observer()
141 | event_handler = FileEventHandler(system_config, model_path, interface_manager)
142 | observer.schedule(event_handler, event_handler.model_conf_path, True)
143 | observer.start()
144 | try:
145 | while True:
146 | event.wait(1)
147 | except KeyboardInterrupt:
148 | observer.stop()
149 | observer.join()
150 |
151 |
152 | threading.Thread(target=event_loop, daemon=True).start()
153 |
154 | if __name__ == "__main__":
155 |
156 | parser = optparse.OptionParser()
157 | parser.add_option('-p', '--port', type="int", default=19951, dest="port")
158 |
159 | opt, args = parser.parse_args()
160 | server_port = opt.port
161 |
162 | server_host = "0.0.0.0"
163 |
164 | logger.info('Running on http://{}:{}/ '.format(server_host, server_port))
165 | server = WSGIServer((server_host, server_port), app, handler_class=WebSocketHandler)
166 | try:
167 | server.serve_forever()
168 | except KeyboardInterrupt:
169 | server.stop()
170 |
--------------------------------------------------------------------------------
/compat/flask_server.spec:
--------------------------------------------------------------------------------
1 | # -*- mode: python -*-
2 | # Used to package as a single executable
3 | # This is a configuration file
4 |
5 | block_cipher = pyi_crypto.PyiBlockCipher(key='kerlomz&coriander')
6 |
7 | added_files = [('resource/icon.ico', 'resource')]
8 |
9 | a = Analysis(['flask_server.py'],
10 | pathex=['.'],
11 | binaries=[],
12 | datas=added_files,
13 | hiddenimports=[],
14 | hookspath=[],
15 | runtime_hooks=[],
16 | excludes=[],
17 | win_no_prefer_redirects=False,
18 | win_private_assemblies=False,
19 | cipher=block_cipher)
20 | pyz = PYZ(a.pure, a.zipped_data,
21 | cipher=block_cipher)
22 | exe = EXE(pyz,
23 | a.scripts,
24 | a.binaries,
25 | a.zipfiles,
26 | a.datas,
27 | name='captcha_platform_flask',
28 | debug=False,
29 | strip=False,
30 | upx=True,
31 | runtime_tmpdir=None,
32 | console=True,
33 | icon='resource/icon.ico')
34 |
--------------------------------------------------------------------------------
/compat/grpc.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | service Predict {
4 | rpc predict (PredictRequest) returns (PredictResult) {}
5 | }
6 |
7 | message PredictRequest {
8 | string image = 1;
9 | string split_char = 2;
10 | string model_name = 3;
11 | string model_type = 4;
12 | string model_site = 5;
13 | string need_color = 6;
14 | }
15 |
16 | message PredictResult {
17 | string result = 1;
18 | int32 code = 2;
19 | bool success = 3;
20 | }
--------------------------------------------------------------------------------
/compat/grpc_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: grpc.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | # @@protoc_insertion_point(imports)
11 |
12 | _sym_db = _symbol_database.Default()
13 |
14 |
15 |
16 |
17 | DESCRIPTOR = _descriptor.FileDescriptor(
18 | name='grpc.proto',
19 | package='',
20 | syntax='proto3',
21 | serialized_options=None,
22 | serialized_pb=_b('\n\ngrpc.proto\"\x83\x01\n\x0ePredictRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x12\n\nsplit_char\x18\x02 \x01(\t\x12\x12\n\nmodel_name\x18\x03 \x01(\t\x12\x12\n\nmodel_type\x18\x04 \x01(\t\x12\x12\n\nmodel_site\x18\x05 \x01(\t\x12\x12\n\nneed_color\x18\x06 \x01(\t\">\n\rPredictResult\x12\x0e\n\x06result\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\x12\x0f\n\x07success\x18\x03 \x01(\x08\x32\x37\n\x07Predict\x12,\n\x07predict\x12\x0f.PredictRequest\x1a\x0e.PredictResult\"\x00\x62\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _PREDICTREQUEST = _descriptor.Descriptor(
29 | name='PredictRequest',
30 | full_name='PredictRequest',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='image', full_name='PredictRequest.image', index=0,
37 | number=1, type=9, cpp_type=9, label=1,
38 | has_default_value=False, default_value=_b("").decode('utf-8'),
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | serialized_options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='split_char', full_name='PredictRequest.split_char', index=1,
44 | number=2, type=9, cpp_type=9, label=1,
45 | has_default_value=False, default_value=_b("").decode('utf-8'),
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | serialized_options=None, file=DESCRIPTOR),
49 | _descriptor.FieldDescriptor(
50 | name='model_name', full_name='PredictRequest.model_name', index=2,
51 | number=3, type=9, cpp_type=9, label=1,
52 | has_default_value=False, default_value=_b("").decode('utf-8'),
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | serialized_options=None, file=DESCRIPTOR),
56 | _descriptor.FieldDescriptor(
57 | name='model_type', full_name='PredictRequest.model_type', index=3,
58 | number=4, type=9, cpp_type=9, label=1,
59 | has_default_value=False, default_value=_b("").decode('utf-8'),
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | serialized_options=None, file=DESCRIPTOR),
63 | _descriptor.FieldDescriptor(
64 | name='model_site', full_name='PredictRequest.model_site', index=4,
65 | number=5, type=9, cpp_type=9, label=1,
66 | has_default_value=False, default_value=_b("").decode('utf-8'),
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | serialized_options=None, file=DESCRIPTOR),
70 | _descriptor.FieldDescriptor(
71 | name='need_color', full_name='PredictRequest.need_color', index=5,
72 | number=6, type=9, cpp_type=9, label=1,
73 | has_default_value=False, default_value=_b("").decode('utf-8'),
74 | message_type=None, enum_type=None, containing_type=None,
75 | is_extension=False, extension_scope=None,
76 | serialized_options=None, file=DESCRIPTOR),
77 | ],
78 | extensions=[
79 | ],
80 | nested_types=[],
81 | enum_types=[
82 | ],
83 | serialized_options=None,
84 | is_extendable=False,
85 | syntax='proto3',
86 | extension_ranges=[],
87 | oneofs=[
88 | ],
89 | serialized_start=15,
90 | serialized_end=146,
91 | )
92 |
93 |
94 | _PREDICTRESULT = _descriptor.Descriptor(
95 | name='PredictResult',
96 | full_name='PredictResult',
97 | filename=None,
98 | file=DESCRIPTOR,
99 | containing_type=None,
100 | fields=[
101 | _descriptor.FieldDescriptor(
102 | name='result', full_name='PredictResult.result', index=0,
103 | number=1, type=9, cpp_type=9, label=1,
104 | has_default_value=False, default_value=_b("").decode('utf-8'),
105 | message_type=None, enum_type=None, containing_type=None,
106 | is_extension=False, extension_scope=None,
107 | serialized_options=None, file=DESCRIPTOR),
108 | _descriptor.FieldDescriptor(
109 | name='code', full_name='PredictResult.code', index=1,
110 | number=2, type=5, cpp_type=1, label=1,
111 | has_default_value=False, default_value=0,
112 | message_type=None, enum_type=None, containing_type=None,
113 | is_extension=False, extension_scope=None,
114 | serialized_options=None, file=DESCRIPTOR),
115 | _descriptor.FieldDescriptor(
116 | name='success', full_name='PredictResult.success', index=2,
117 | number=3, type=8, cpp_type=7, label=1,
118 | has_default_value=False, default_value=False,
119 | message_type=None, enum_type=None, containing_type=None,
120 | is_extension=False, extension_scope=None,
121 | serialized_options=None, file=DESCRIPTOR),
122 | ],
123 | extensions=[
124 | ],
125 | nested_types=[],
126 | enum_types=[
127 | ],
128 | serialized_options=None,
129 | is_extendable=False,
130 | syntax='proto3',
131 | extension_ranges=[],
132 | oneofs=[
133 | ],
134 | serialized_start=148,
135 | serialized_end=210,
136 | )
137 |
138 | DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST
139 | DESCRIPTOR.message_types_by_name['PredictResult'] = _PREDICTRESULT
140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
141 |
142 | PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), dict(
143 | DESCRIPTOR = _PREDICTREQUEST,
144 | __module__ = 'grpc_pb2'
145 | # @@protoc_insertion_point(class_scope:PredictRequest)
146 | ))
147 | _sym_db.RegisterMessage(PredictRequest)
148 |
149 | PredictResult = _reflection.GeneratedProtocolMessageType('PredictResult', (_message.Message,), dict(
150 | DESCRIPTOR = _PREDICTRESULT,
151 | __module__ = 'grpc_pb2'
152 | # @@protoc_insertion_point(class_scope:PredictResult)
153 | ))
154 | _sym_db.RegisterMessage(PredictResult)
155 |
156 |
157 |
158 | _PREDICT = _descriptor.ServiceDescriptor(
159 | name='Predict',
160 | full_name='Predict',
161 | file=DESCRIPTOR,
162 | index=0,
163 | serialized_options=None,
164 | serialized_start=212,
165 | serialized_end=267,
166 | methods=[
167 | _descriptor.MethodDescriptor(
168 | name='predict',
169 | full_name='Predict.predict',
170 | index=0,
171 | containing_service=None,
172 | input_type=_PREDICTREQUEST,
173 | output_type=_PREDICTRESULT,
174 | serialized_options=None,
175 | ),
176 | ])
177 | _sym_db.RegisterServiceDescriptor(_PREDICT)
178 |
179 | DESCRIPTOR.services_by_name['Predict'] = _PREDICT
180 |
181 | # @@protoc_insertion_point(module_scope)
182 |
--------------------------------------------------------------------------------
/compat/grpc_pb2_grpc.py:
--------------------------------------------------------------------------------
1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2 | import grpc
3 |
4 | from compat import grpc_pb2 as grpc__pb2
5 |
6 |
7 | class PredictStub(object):
8 | # missing associated documentation comment in .proto file
9 | pass
10 |
11 | def __init__(self, channel):
12 | """Constructor.
13 |
14 | Args:
15 | channel: A grpc.Channel.
16 | """
17 | self.predict = channel.unary_unary(
18 | '/Predict/predict',
19 | request_serializer=grpc__pb2.PredictRequest.SerializeToString,
20 | response_deserializer=grpc__pb2.PredictResult.FromString,
21 | )
22 |
23 |
24 | class PredictServicer(object):
25 | # missing associated documentation comment in .proto file
26 | pass
27 |
28 | def predict(self, request, context):
29 | # missing associated documentation comment in .proto file
30 | pass
31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED)
32 | context.set_details('Method not implemented!')
33 | raise NotImplementedError('Method not implemented!')
34 |
35 |
36 | def add_PredictServicer_to_server(servicer, server):
37 | rpc_method_handlers = {
38 | 'predict': grpc.unary_unary_rpc_method_handler(
39 | servicer.predict,
40 | request_deserializer=grpc__pb2.PredictRequest.FromString,
41 | response_serializer=grpc__pb2.PredictResult.SerializeToString,
42 | ),
43 | }
44 | generic_handler = grpc.method_handlers_generic_handler(
45 | 'Predict', rpc_method_handlers)
46 | server.add_generic_rpc_handlers((generic_handler,))
47 |
--------------------------------------------------------------------------------
/compat/grpc_server.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 |
5 | import time
6 | import threading
7 | from concurrent import futures
8 |
9 | import grpc
10 | from compat import grpc_pb2_grpc, grpc_pb2
11 | import optparse
12 | from utils import ImageUtils
13 | from interface import InterfaceManager
14 | from config import Config
15 | from middleware import *
16 | from event_loop import event_loop
17 |
18 | _ONE_DAY_IN_SECONDS = 60 * 60 * 24
19 |
20 |
21 | class Predict(grpc_pb2_grpc.PredictServicer):
22 |
23 | def __init__(self, **kwargs):
24 | super().__init__(**kwargs)
25 | self.image_utils = ImageUtils(system_config)
26 |
27 | def predict(self, request, context):
28 | start_time = time.time()
29 | bytes_batch, status = self.image_utils.get_bytes_batch(request.image)
30 |
31 | if interface_manager.total == 0:
32 | logger.info('There is currently no model deployment and services are not available.')
33 | return {"result": "", "success": False, "code": -999}
34 |
35 | if not bytes_batch:
36 | return grpc_pb2.PredictResult(result="", success=status['success'], code=status['code'])
37 |
38 | image_sample = bytes_batch[0]
39 | image_size = ImageUtils.size_of_image(image_sample)
40 | size_string = "{}x{}".format(image_size[0], image_size[1])
41 | if request.model_name:
42 | interface = interface_manager.get_by_name(request.model_name)
43 | else:
44 | interface = interface_manager.get_by_size(size_string)
45 | if not interface:
46 | logger.info('Service is not ready!')
47 | return {"result": "", "success": False, "code": 999}
48 |
49 | if request.need_color:
50 | bytes_batch = [color_extract.separate_color(_, color_map[request.need_color]) for _ in bytes_batch]
51 |
52 | image_batch, status = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
53 |
54 | if not image_batch:
55 | return grpc_pb2.PredictResult(result="", success=status['success'], code=status['code'])
56 |
57 | result = interface.predict_batch(image_batch, request.split_char)
58 | logger.info('[{}] - Size[{}] - Type[{}] - Site[{}] - Predict Result[{}] - {} ms'.format(
59 | interface.name,
60 | size_string,
61 | request.model_type,
62 | request.model_site,
63 | result,
64 | (time.time() - start_time) * 1000
65 | ))
66 | return grpc_pb2.PredictResult(result=result, success=status['success'], code=status['code'])
67 |
68 |
69 | def serve():
70 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
71 | grpc_pb2_grpc.add_PredictServicer_to_server(Predict(), server)
72 | server.add_insecure_port('[::]:50054')
73 | server.start()
74 | try:
75 | while True:
76 | time.sleep(_ONE_DAY_IN_SECONDS)
77 | except KeyboardInterrupt:
78 | server.stop(0)
79 |
80 |
81 | if __name__ == '__main__':
82 | parser = optparse.OptionParser()
83 | parser.add_option('-p', '--port', type="int", default=50054, dest="port")
84 | parser.add_option('-c', '--config', type="str", default='./config.yaml', dest="config")
85 | parser.add_option('-m', '--model_path', type="str", default='model', dest="model_path")
86 | parser.add_option('-g', '--graph_path', type="str", default='graph', dest="graph_path")
87 | opt, args = parser.parse_args()
88 | server_port = opt.port
89 | conf_path = opt.config
90 | model_path = opt.model_path
91 | graph_path = opt.graph_path
92 | system_config = Config(conf_path=conf_path, model_path=model_path, graph_path=graph_path)
93 | interface_manager = InterfaceManager()
94 | threading.Thread(target=lambda: event_loop(system_config, model_path, interface_manager)).start()
95 |
96 | logger = system_config.logger
97 | server_host = "0.0.0.0"
98 |
99 | logger.info('Running on http://{}:{}/ '.format(server_host, server_port))
100 | serve()
101 |
--------------------------------------------------------------------------------
/compat/grpc_server.spec:
--------------------------------------------------------------------------------
1 | # -*- mode: python -*-
2 | # Used to package as a single executable
3 | # This is a configuration file
4 |
5 | block_cipher = pyi_crypto.PyiBlockCipher(key='kerlomz&coriander')
6 |
7 | added_files = [('resource/icon.ico', 'resource')]
8 |
9 | a = Analysis(['grpc_server.py'],
10 | pathex=['.'],
11 | binaries=[],
12 | datas=added_files,
13 | hiddenimports=[],
14 | hookspath=[],
15 | runtime_hooks=[],
16 | excludes=[],
17 | win_no_prefer_redirects=False,
18 | win_private_assemblies=False,
19 | cipher=block_cipher)
20 | pyz = PYZ(a.pure, a.zipped_data,
21 | cipher=block_cipher)
22 | exe = EXE(pyz,
23 | a.scripts,
24 | a.binaries,
25 | a.zipfiles,
26 | a.datas,
27 | name='captcha_platform_grpc',
28 | debug=False,
29 | strip=False,
30 | upx=True,
31 | runtime_tmpdir=None,
32 | console=True,
33 | icon='resource/icon.ico')
34 |
--------------------------------------------------------------------------------
/compat/sanic_server.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import time
5 | import optparse
6 | import threading
7 | from config import Config
8 | from utils import ImageUtils
9 | from interface import InterfaceManager
10 | from watchdog.observers import Observer
11 | from event_handler import FileEventHandler
12 | from sanic import Sanic
13 | from sanic.response import json
14 | from signature import Signature, ServerType
15 | from middleware import *
16 | from event_loop import event_loop
17 |
18 | app = Sanic()
19 | sign = Signature(ServerType.SANIC)
20 | parser = optparse.OptionParser()
21 |
22 | conf_path = '../config.yaml'
23 | model_path = '../model'
24 | graph_path = '../graph'
25 |
26 | system_config = Config(conf_path=conf_path, model_path=model_path, graph_path=graph_path)
27 | sign.set_auth([{'accessKey': system_config.access_key, 'secretKey': system_config.secret_key}])
28 | logger = system_config.logger
29 | interface_manager = InterfaceManager()
30 | threading.Thread(target=lambda: event_loop(system_config, model_path, interface_manager)).start()
31 |
32 | image_utils = ImageUtils(system_config)
33 |
34 |
35 | @app.route('/captcha/auth/v2', methods=['POST'])
36 | @sign.signature_required # This decorator is required for certification.
37 | def auth_request(request):
38 | return common_request(request)
39 |
40 |
41 | @app.route('/captcha/v1', methods=['POST'])
42 | def no_auth_request(request):
43 | return common_request(request)
44 |
45 |
46 | def common_request(request):
47 | """
48 | This api is used for captcha prediction without authentication
49 | :return:
50 | """
51 | start_time = time.time()
52 | if not request.json or 'image' not in request.json:
53 | logger.error('Bad request: missing "image" field')
54 | return json({"message": "Bad Request", "code": 400, "success": False}, status=400)
55 |
56 | if interface_manager.total == 0:
57 | logger.info('There is currently no model deployment and services are not available.')
58 | return json({"message": "", "success": False, "code": -999})
59 |
60 | bytes_batch, response = image_utils.get_bytes_batch(request.json['image'])
61 |
62 | if not bytes_batch:
63 | logger.error('Type[{}] - Site[{}] - Response[{}] - {} ms'.format(
64 | request.json['model_type'], request.json['model_site'], response,
65 | (time.time() - start_time) * 1000)
66 | )
67 | return json(response)
68 |
69 | image_sample = bytes_batch[0]
70 | image_size = ImageUtils.size_of_image(image_sample)
71 | size_string = "{}x{}".format(image_size[0], image_size[1])
72 |
73 | if 'model_name' in request.json:
74 | interface = interface_manager.get_by_name(request.json['model_name'])
75 | else:
76 | interface = interface_manager.get_by_size(size_string)
77 |
78 | split_char = request.json['split_char'] if 'split_char' in request.json else interface.model_conf.split_char
79 |
80 | if 'need_color' in request.json and request.json['need_color']:
81 | bytes_batch = [color_extract.separate_color(_, color_map[request.json['need_color']]) for _ in bytes_batch]
82 |
83 | image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
84 |
85 | if not image_batch:
86 | logger.error('[{}] - Size[{}] - Name[{}] - Response[{}] - {} ms'.format(
87 | interface.name, size_string, request.json.get('model_name'), response,
88 | (time.time() - start_time) * 1000)
89 | )
90 | return json(response)
91 |
92 | result = interface.predict_batch(image_batch, split_char)
93 | logger.info('[{}] - Size[{}] - Predict Result[{}] - {} ms'.format(
94 | interface.name,
95 | size_string,
96 | result,
97 | (time.time() - start_time) * 1000
98 | ))
99 | response['message'] = result
100 | return json(response)
101 |
102 |
103 | if __name__ == "__main__":
104 |
105 | parser.add_option('-p', '--port', type="int", default=19953, dest="port")
106 |
107 | opt, args = parser.parse_args()
108 | server_port = opt.port
109 |
110 |
111 |
112 | server_host = "0.0.0.0"
113 |
114 | logger.info('Running on http://{}:{}/ '.format(server_host, server_port))
115 | app.run(host=server_host, port=server_port)
116 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import os
5 | import sys
6 | import uuid
7 | import json
8 | import yaml
9 | import hashlib
10 | import logging
11 | import logging.handlers
12 | from category import *
13 | from constants import SystemConfig, ModelField, ModelScene
14 |
15 | MODEL_SCENE_MAP = {
16 | 'Classification': ModelScene.Classification
17 | }
18 |
19 | MODEL_FIELD_MAP = {
20 | 'Image': ModelField.Image,
21 | 'Text': ModelField.Text
22 | }
23 |
24 | BLACKLIST_PATH = "blacklist.json"
25 | WHITELIST_PATH = "whitelist.json"
26 |
27 |
28 | def resource_path(relative_path):
29 | try:
30 | # PyInstaller creates a temp folder and stores path in _MEIPASS
31 | base_path = sys._MEIPASS
32 | except AttributeError:
33 | base_path = os.path.abspath(".")
34 | return os.path.join(base_path, relative_path)
35 |
36 |
37 | def get_version():
38 | version_file_path = resource_path("VERSION")
39 |
40 | if not os.path.exists(version_file_path):
41 | return "NULL"
42 |
43 | with open(version_file_path, "r", encoding="utf8") as f:
44 | return "".join(f.readlines()).strip()
45 |
46 |
47 | def get_default(src, default):
48 | return src if src else default
49 |
50 |
51 | def get_dict_fill(src: dict, default: dict):
52 | if not src:
53 | return default
54 | new_dict = default
55 | new_dict.update(src)
56 | return new_dict
57 |
58 |
59 | def blacklist() -> list:
60 | if not os.path.exists(BLACKLIST_PATH):
61 | return []
62 | try:
63 | with open(BLACKLIST_PATH, "r", encoding="utf8") as f_blacklist:
64 | result = json.loads("".join(f_blacklist.readlines()))
65 | return result
66 | except Exception as e:
67 | print(e)
68 | return []
69 |
70 |
71 | def whitelist() -> list:
72 | if not os.path.exists(WHITELIST_PATH):
73 | return ["127.0.0.1", "localhost"]
74 | try:
75 | with open(WHITELIST_PATH, "r", encoding="utf8") as f_whitelist:
76 | result = json.loads("".join(f_whitelist.readlines()))
77 | return result
78 | except Exception as e:
79 | print(e)
80 | return ["127.0.0.1", "localhost"]
81 |
82 |
83 | def set_blacklist(ip):
84 | try:
85 | old_blacklist = blacklist()
86 | old_blacklist.append(ip)
87 | with open(BLACKLIST_PATH, "w+", encoding="utf8") as f_blacklist:
88 | f_blacklist.write(json.dumps(old_blacklist, ensure_ascii=False, indent=2))
89 | except Exception as e:
90 | print(e)
91 |
92 |
93 | class Config(object):
94 | def __init__(self, conf_path: str, graph_path: str = None, model_path: str = None):
95 | self.model_path = model_path
96 | self.conf_path = conf_path
97 | self.graph_path = graph_path
98 | self.sys_cf = self.read_conf
99 | self.access_key = None
100 | self.secret_key = None
101 | self.default_model = self.sys_cf['System']['DefaultModel']
102 | self.default_port = self.sys_cf['System'].get('DefaultPort')
103 | if not self.default_port:
104 | self.default_port = 19952
105 | self.split_flag = self.sys_cf['System']['SplitFlag']
106 | self.split_flag = self.split_flag if isinstance(self.split_flag, bytes) else SystemConfig.split_flag
107 | self.route_map = get_default(self.sys_cf.get('RouteMap'), SystemConfig.default_route)
108 | self.log_path = "logs"
109 | self.request_def_map = get_default(self.sys_cf.get('RequestDef'), SystemConfig.default_config['RequestDef'])
110 | self.response_def_map = get_default(self.sys_cf.get('ResponseDef'), SystemConfig.default_config['ResponseDef'])
111 | self.save_path = self.sys_cf['System'].get("SavePath")
112 | self.request_count_interval = get_default(
113 | src=self.sys_cf['System'].get("RequestCountInterval"),
114 | default=60 * 60 * 24
115 | )
116 | self.g_request_count_interval = get_default(
117 | src=self.sys_cf['System'].get("GlobalRequestCountInterval"),
118 | default=60 * 60 * 24
119 | )
120 | self.request_limit = get_default(self.sys_cf['System'].get("RequestLimit"), -1)
121 | self.global_request_limit = get_default(self.sys_cf['System'].get("GlobalRequestLimit"), -1)
122 | self.exceeded_msg = get_default(
123 | src=self.sys_cf['System'].get("ExceededMessage"),
124 | default=SystemConfig.default_config['System'].get('ExceededMessage')
125 | )
126 | self.illegal_time_msg = get_default(
127 | src=self.sys_cf['System'].get("IllegalTimeMessage"),
128 | default=SystemConfig.default_config['System'].get('IllegalTimeMessage')
129 | )
130 |
131 | self.request_size_limit: dict = get_default(
132 | src=self.sys_cf['System'].get('RequestSizeLimit'),
133 | default={}
134 | )
135 | self.blacklist_trigger_times = get_default(self.sys_cf['System'].get("BlacklistTriggerTimes"), -1)
136 |
137 | self.use_whitelist: dict = get_default(
138 | src=self.sys_cf['System'].get('Whitelist'),
139 | default=False
140 | )
141 |
142 | self.error_message = get_dict_fill(
143 | self.sys_cf['System'].get('ErrorMessage'), SystemConfig.default_config['System']['ErrorMessage']
144 | )
145 | self.logger_tag = get_default(self.sys_cf['System'].get('LoggerTag'), "coriander")
146 | self.without_logger = self.sys_cf['System'].get('WithoutLogger')
147 | self.without_logger = self.without_logger if self.without_logger is not None else False
148 | self.logger = logging.getLogger(self.logger_tag)
149 | self.use_default_authorization = False
150 | self.authorization = None
151 | self.init_logger()
152 | self.assignment()
153 |
154 | def init_logger(self):
155 | self.logger.setLevel(logging.INFO)
156 |
157 | if not os.path.exists(self.model_path):
158 | os.makedirs(self.model_path)
159 | if not os.path.exists(self.graph_path):
160 | os.makedirs(self.graph_path)
161 |
162 | self.logger.propagate = False
163 |
164 | if not self.without_logger:
165 | if not os.path.exists(self.log_path):
166 | os.makedirs(self.log_path)
167 | file_handler = logging.handlers.TimedRotatingFileHandler(
168 | '{}/{}.log'.format(self.log_path, "captcha_platform"),
169 | when="MIDNIGHT",
170 | interval=1,
171 | backupCount=180,
172 | encoding='utf-8'
173 | )
174 | stream_handler = logging.StreamHandler()
175 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
176 | file_handler.setFormatter(formatter)
177 | stream_handler.setFormatter(formatter)
178 | self.logger.addHandler(file_handler)
179 | self.logger.addHandler(stream_handler)
180 |
181 | def assignment(self):
182 | # ---AUTHORIZATION START---
183 | mac_address = hex(uuid.getnode())[2:]
184 | self.use_default_authorization = False
185 | self.authorization = self.sys_cf.get('Security')
186 | if not self.authorization or not self.authorization.get('AccessKey') or not self.authorization.get('SecretKey'):
187 | self.use_default_authorization = True
188 | model_name_md5 = hashlib.md5(
189 | "{}".format(self.default_model).encode('utf8')).hexdigest()
190 | self.authorization = {
191 | 'AccessKey': model_name_md5[0: 16],
192 | 'SecretKey': hashlib.md5("{}{}".format(model_name_md5, mac_address).encode('utf8')).hexdigest()
193 | }
194 | self.access_key = self.authorization['AccessKey']
195 | self.secret_key = self.authorization['SecretKey']
196 | # ---AUTHORIZATION END---
197 |
198 | @property
199 | def read_conf(self):
200 | if not os.path.exists(self.conf_path):
201 | with open(self.conf_path, 'w', encoding="utf-8") as sys_fp:
202 | sys_fp.write(yaml.safe_dump(SystemConfig.default_config))
203 | return SystemConfig.default_config
204 | with open(self.conf_path, 'r', encoding="utf-8") as sys_fp:
205 | sys_stream = sys_fp.read()
206 | return yaml.load(sys_stream, Loader=yaml.SafeLoader)
207 |
208 |
209 | class Model(object):
210 |
211 | def __init__(self, conf: Config, model_conf_path: str):
212 | self.conf = conf
213 | self.logger = self.conf.logger
214 | self.graph_path = conf.graph_path
215 | self.model_path = conf.model_path
216 | self.model_conf_path = model_conf_path
217 | self.model_conf_demo = 'model_demo.yaml'
218 | self.verify()
219 |
220 | def verify(self):
221 | if not os.path.exists(self.model_conf_path):
222 | raise Exception(
223 | 'Configuration file "{}" not found. '
224 | 'If this is the first use, please copy {} into {} as a starting point.'.format(
225 | self.model_conf_path,
226 | self.model_conf_demo,
227 | self.model_path
228 | )
229 | )
230 |
231 | if not os.path.exists(self.model_path):
232 | os.makedirs(self.model_path)
233 | raise Exception(
234 | 'For the first time, please put the trained model in the model directory.'
235 | )
236 |
237 | def category_extract(self, param):
238 | if isinstance(param, list):
239 | return param
240 | if isinstance(param, str):
241 | if param in SIMPLE_CATEGORY_MODEL.keys():
242 | return SIMPLE_CATEGORY_MODEL.get(param)
243 | self.logger.error(
244 | "Category set configuration error, customized category set should be list type"
245 | )
246 | return None
247 |
248 | @property
249 | def model_conf(self) -> dict:
250 | with open(self.model_conf_path, 'r', encoding="utf-8") as sys_fp:
251 | sys_stream = sys_fp.read()
252 | return yaml.load(sys_stream, Loader=yaml.SafeLoader)
253 |
254 |
255 | class ModelConfig(Model):
256 | model_exists: bool = False
257 |
258 | def __init__(self, conf: Config, model_conf_path: str):
259 | super().__init__(conf=conf, model_conf_path=model_conf_path)
260 |
261 | self.conf = conf
262 |
263 | """MODEL"""
264 | self.model_root: dict = self.model_conf['Model']
265 | self.model_name: str = self.model_root.get('ModelName')
266 | self.model_version: float = self.model_root.get('Version')
267 | self.model_version = self.model_version if self.model_version else 1.0
268 | self.model_field_param: str = self.model_root.get('ModelField')
269 | self.model_field: ModelField = ModelConfig.param_convert(
270 | source=self.model_field_param,
271 | param_map=MODEL_FIELD_MAP,
272 | text="Current model field ({model_field}) is not supported".format(model_field=self.model_field_param),
273 | code=50002
274 | )
275 |
276 | self.model_scene_param: str = self.model_root.get('ModelScene')
277 |
278 | self.model_scene: ModelScene = ModelConfig.param_convert(
279 | source=self.model_scene_param,
280 | param_map=MODEL_SCENE_MAP,
281 | text="Current model scene ({model_scene}) is not supported".format(model_scene=self.model_scene_param),
282 | code=50001
283 | )
284 |
285 | """SYSTEM"""
286 | self.checkpoint_tag = 'checkpoint'
287 | self.system_root: dict = self.model_conf['System']
288 | self.memory_usage: float = self.system_root.get('MemoryUsage')
289 |
290 | """FIELD PARAM - IMAGE"""
291 | self.field_root: dict = self.model_conf['FieldParam']
292 | self.category_param = self.field_root.get('Category')
293 | self.category_value = self.category_extract(self.category_param)
294 | if self.category_value is None:
295 | raise Exception(
296 | "The category set type does not exist, there is no category set named {}".format(self.category_param),
297 | )
298 | self.category: list = SPACE_TOKEN + self.category_value
299 | self.category_num: int = len(self.category)
300 | self.image_channel: int = self.field_root.get('ImageChannel')
301 | self.image_width: int = self.field_root.get('ImageWidth')
302 | self.image_height: int = self.field_root.get('ImageHeight')
303 | self.max_label_num: int = self.field_root.get('MaxLabelNum')
304 | self.min_label_num: int = self.get_var(self.field_root, 'MinLabelNum', self.max_label_num)
305 | self.resize: list = self.field_root.get('Resize')
306 | self.output_split = self.field_root.get('OutputSplit')
307 | self.output_split = self.output_split if self.output_split else ""
308 | self.corp_params = self.field_root.get('CorpParams')
309 | self.output_coord = self.field_root.get('OutputCoord')
310 | self.batch_model = self.field_root.get('BatchModel')
311 | self.external_model = self.field_root.get('ExternalModelForCorp')
312 | self.category_split = self.field_root.get('CategorySplit')
313 |
314 | """PRETREATMENT"""
315 | self.pretreatment_root = self.model_conf.get('Pretreatment')
316 | self.pre_binaryzation = self.get_var(self.pretreatment_root, 'Binaryzation', -1)
317 | self.pre_replace_transparent = self.get_var(self.pretreatment_root, 'ReplaceTransparent', True)
318 | self.pre_horizontal_stitching = self.get_var(self.pretreatment_root, 'HorizontalStitching', False)
319 | self.pre_concat_frames = self.get_var(self.pretreatment_root, 'ConcatFrames', -1)
320 | self.pre_blend_frames = self.get_var(self.pretreatment_root, 'BlendFrames', -1)
321 | self.pre_freq_frames = self.get_var(self.pretreatment_root, 'FreqFrames', -1)
322 | self.exec_map = self.get_var(self.pretreatment_root, 'ExecuteMap', None)
323 |
324 | """COMPILE_MODEL"""
325 | self.compile_model_path = os.path.join(self.graph_path, '{}.pb'.format(self.model_name))
326 | if not os.path.exists(self.compile_model_path):
327 | if not os.path.exists(self.graph_path):
328 | os.makedirs(self.graph_path)
329 | self.logger.error(
330 | '{} not found, please put the trained model in the graph directory.'.format(self.compile_model_path)
331 | )
332 | else:
333 | self.model_exists = True
334 |
335 | @staticmethod
336 | def param_convert(source, param_map: dict, text, code, default=None):
337 | if source is None:
338 | return default
339 | if source not in param_map.keys():
340 | raise Exception(text)
341 | return param_map[source]
342 |
343 | def size_match(self, size_str):
344 | return size_str == self.size_string
345 |
346 | @staticmethod
347 | def get_var(src: dict, name: str, default=None):
348 | if not src or name not in src:
349 | return default
350 | return src.get(name)
351 |
352 | @property
353 | def size_string(self):
354 | return "{}x{}".format(self.image_width, self.image_height)
355 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | from enum import Enum, unique
5 |
6 |
7 | @unique
8 | class ModelScene(Enum):
9 | """模型场景枚举"""
10 | Classification = 'Classification'
11 |
12 |
13 | @unique
14 | class ModelField(Enum):
15 | """模型类别枚举"""
16 | Image = 'Image'
17 | Text = 'Text'
18 |
19 |
20 | class SystemConfig:
21 | split_flag = b'\x99\x99\x99\x00\xff\xff\xff\x00\x99\x99\x99'
22 | default_route = [
23 | {
24 | "Class": "AuthHandler",
25 | "Route": "/captcha/auth/v2"
26 | },
27 | {
28 | "Class": "NoAuthHandler",
29 | "Route": "/captcha/v1"
30 | },
31 | {
32 | "Class": "SimpleHandler",
33 | "Route": "/captcha/v3"
34 | },
35 | {
36 | "Class": "HeartBeatHandler",
37 | "Route": "/check_backend_active.html"
38 | },
39 | {
40 | "Class": "HeartBeatHandler",
41 | "Route": "/verification"
42 | },
43 | {
44 | "Class": "HeartBeatHandler",
45 | "Route": "/"
46 | },
47 | {
48 | "Class": "ServiceHandler",
49 | "Route": "/service/info"
50 | },
51 | {
52 | "Class": "FileHandler",
53 | "Route": "/service/logs/(.*)",
54 | "Param": {"path": "logs"}
55 | },
56 | {
57 | "Class": "BaseHandler",
58 | "Route": ".*"
59 | }
60 | ]
61 | default_config = {
62 | "System": {
63 | "DefaultModel": "default",
64 | "SplitFlag": b'\x99\x99\x99\x00\xff\xff999999.........99999\xff\x00\x99\x99\x99',
65 | "SavePath": "",
66 | "RequestCountInterval": 86400,
67 | "GlobalRequestCountInterval": 86400,
68 | "RequestLimit": -1,
69 | "GlobalRequestLimit": -1,
70 | "WithoutLogger": False,
71 | "RequestSizeLimit": {},
72 | "DefaultPort": 19952,
73 | "IllegalTimeMessage": "The maximum number of requests has been exceeded.",
74 | "ExceededMessage": "Illegal access time, please request in open hours.",
75 | "BlacklistTriggerTimes": -1,
76 | "Whitelist": False,
77 | "ErrorMessage": {
78 | 400: "Bad Request",
79 | 401: "Unicode Decode Error",
80 | 403: "Forbidden",
81 | 404: "404 Not Found",
82 | 405: "Method Not Allowed",
83 | 500: "Internal Server Error"
84 | }
85 | },
86 | "RouteMap": default_route,
87 | "Security": {
88 | "AccessKey": "",
89 | "SecretKey": ""
90 | },
91 | "RequestDef": {
92 | "InputData": "image",
93 | "ModelName": "model_name",
94 | },
95 | "ResponseDef": {
96 | "Message": "message",
97 | "StatusCode": "code",
98 | "StatusBool": "success",
99 | "Uid": "uid",
100 | }
101 | }
102 |
103 |
104 | class ServerType(str):
105 | FLASK = 19951
106 | TORNADO = 19952
107 | SANIC = 19953
108 |
109 |
110 | class Response:
111 |
112 | def __init__(self, def_map: dict):
113 | # SIGN
114 | self.INVALID_PUBLIC_PARAMS = dict(Message='Invalid Public Params', StatusCode=400001, StatusBool=False)
115 | self.UNKNOWN_SERVER_ERROR = dict(Message='Unknown Server Error', StatusCode=400002, StatusBool=False)
116 | self.INVALID_TIMESTAMP = dict(Message='Invalid Timestamp', StatusCode=400004, StatusBool=False)
117 | self.INVALID_ACCESS_KEY = dict(Message='Invalid Access Key', StatusCode=400005, StatusBool=False)
118 | self.INVALID_QUERY_STRING = dict(Message='Invalid Query String', StatusCode=400006, StatusBool=False)
119 |
120 | # SERVER
121 | self.SUCCESS = dict(Message=None, StatusCode=000000, StatusBool=True)
122 | self.INVALID_IMAGE_FORMAT = dict(Message='Invalid Image Format', StatusCode=500001, StatusBool=False)
123 | self.INVALID_BASE64_STRING = dict(Message='Invalid Base64 String', StatusCode=500002, StatusBool=False)
124 | self.IMAGE_DAMAGE = dict(Message='Image Damage', StatusCode=500003, StatusBool=False)
125 | self.IMAGE_SIZE_NOT_MATCH_GRAPH = dict(Message='Image Size Not Match Graph Value', StatusCode=500004, StatusBool=False)
126 |
127 | self.INVALID_PUBLIC_PARAMS = self.parse(self.INVALID_PUBLIC_PARAMS, def_map)
128 | self.UNKNOWN_SERVER_ERROR = self.parse(self.UNKNOWN_SERVER_ERROR, def_map)
129 | self.INVALID_TIMESTAMP = self.parse(self.INVALID_TIMESTAMP, def_map)
130 | self.INVALID_ACCESS_KEY = self.parse(self.INVALID_ACCESS_KEY, def_map)
131 | self.INVALID_QUERY_STRING = self.parse(self.INVALID_QUERY_STRING, def_map)
132 |
133 | self.SUCCESS = self.parse(self.SUCCESS, def_map)
134 | self.INVALID_IMAGE_FORMAT = self.parse(self.INVALID_IMAGE_FORMAT, def_map)
135 | self.INVALID_BASE64_STRING = self.parse(self.INVALID_BASE64_STRING, def_map)
136 | self.IMAGE_DAMAGE = self.parse(self.IMAGE_DAMAGE, def_map)
137 | self.IMAGE_SIZE_NOT_MATCH_GRAPH = self.parse(self.IMAGE_SIZE_NOT_MATCH_GRAPH, def_map)
138 |
139 | def find_message(self, _code):
140 | e = [value for value in vars(self).values()]
141 | _t = [i['message'] for i in e if i['code'] == _code]
142 | return _t[0] if _t else None
143 |
144 | def find(self, _code):
145 | e = [value for value in vars(self).values()]
146 | _t = [i for i in e if i['code'] == _code]
147 | return _t[0] if _t else None
148 |
149 | def all_code(self):
150 | return [i['message'] for i in [value for value in vars(self).values()]]
151 |
152 | @staticmethod
153 | def parse(src: dict, target_map: dict):
154 | return {target_map[k]: v for k, v in src.items()}
155 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import io
5 | import os
6 | import base64
7 | import datetime
8 | import hashlib
9 | import time
10 | import numpy as np
11 | import cv2
12 | from config import Config
13 | from requests import Session, post, get
14 | from PIL import Image as PilImage
15 | from constants import ServerType
16 |
17 | # DEFAULT_HOST = "63.211.111.82"
18 | # DEFAULT_HOST = "39.100.71.103"
19 | # DEFAULT_HOST = "120.79.233.49"
20 | # DEFAULT_HOST = "47.52.203.228"
21 | DEFAULT_HOST = "192.168.50.152"
22 | # DEFAULT_HOST = "127.0.0.1"
23 |
24 |
25 | def _image(_path, model_type=None, model_site=None, need_color=None, fpath=None):
26 | with open(_path, "rb") as f:
27 | img_bytes = f.read()
28 | # data_stream = io.BytesIO(img_bytes)
29 | # pil_image = PilImage.open(data_stream)
30 | # size = pil_image.size
31 | # im = np.array(pil_image)
32 | # im = im[3:size[1] - 3, 3:size[0] - 3]
33 | # img_bytes = bytearray(cv2.imencode('.png', im)[1])
34 |
35 | b64 = base64.b64encode(img_bytes).decode()
36 | return {
37 | 'image': b64,
38 | 'model_type': model_type,
39 | 'model_site': model_site,
40 | 'need_color': need_color,
41 | 'path': fpath
42 | }
43 |
44 |
45 | class Auth(object):
46 |
47 | def __init__(self, host: str, server_type: ServerType, access_key=None, secret_key=None, port=None):
48 | self._conf = Config(conf_path="config.yaml")
49 | self._url = 'http://{}:{}/captcha/auth/v2'.format(host, port if port else server_type)
50 | self._access_key = access_key if access_key else self._conf.access_key
51 | self._secret_key = secret_key if secret_key else self._conf.secret_key
52 | self.true_count = 0
53 | self.total_count = 0
54 |
55 | def sign(self, args):
56 | """ MD5 signature
57 |             @param args: All request parameters (public and private) except the signature itself
58 | {
59 | 'image': 'base64 encoded text',
60 | 'accessKey': 'C180130204197838',
61 | 'timestamp': 1536682949,
62 | 'sign': 'F641778AE4F93DAF5CCE3E43A674C34E'
63 | }
64 |         The sign is the uppercase MD5 hex digest of "accessKey=your_access_key&image=base64_encoded_text&timestamp=current_timestamp" with "&your_secret_key" appended
65 | """
66 | if "sign" in args:
67 | args.pop("sign")
68 | query_string = '&'.join(['{}={}'.format(k, v) for (k, v) in sorted(args.items())])
69 | query_string = '&'.join([query_string, self._secret_key])
70 | return hashlib.md5(query_string.encode('utf-8')).hexdigest().upper()
71 |
72 | def make_json(self, params):
73 | if not isinstance(params, dict):
74 | raise TypeError("params is not a dict")
75 | # Get the current timestamp
76 | timestamp = int(time.mktime(datetime.datetime.now().timetuple()))
77 | # Set public parameters
78 | params.update(accessKey=self._access_key, timestamp=timestamp)
79 | params.update(sign=self.sign(params))
80 | return params
81 |
82 | def request(self, params):
83 | params = dict(params, **self.make_json(params))
84 | return post(self._url, json=params).json()
85 |
86 | def local_iter(self, image_list: dict):
87 | for k, v in image_list.items():
88 | code = self.request(v).get('message')
89 | _true = str(code).lower() == str(k).lower()
90 | if _true:
91 | self.true_count += 1
92 | self.total_count += 1
93 |             print('result: {}, label: {}, flag: {}, acc_rate: {}'.format(code, k, _true,
94 |                                                                          self.true_count / self.total_count))
95 |
96 |
97 | class NoAuth(object):
98 | def __init__(self, host: str, server_type: ServerType, port=None, url=None):
99 | self._url = 'http://{}:{}/captcha/v1'.format(host, port if port else server_type)
100 | self._url = self._url if not url else url
101 | self.true_count = 0
102 | self.total_count = 0
103 |
104 | def request(self, params):
105 | import json
106 | # print(params)
107 | # print(params['fpath'])
108 | # print(json.dumps(params))
109 | # return post(self._url, data=base64.b64decode(params.get("image").encode())).json()
110 | return post(self._url, json=params).json()
111 |
112 | def local_iter(self, image_list: dict):
113 | for k, v in image_list.items():
114 | try:
115 | code = self.request(v).get('message')
116 | _true = str(code).lower() == str(k).lower()
117 | if _true:
118 | self.true_count += 1
119 | self.total_count += 1
120 | print('result: {}, label: {}, flag: {}, acc_rate: {}, {}'.format(
121 | code, k, _true, self.true_count / self.total_count, v.get('path')
122 | ))
123 | except Exception as e:
124 | print(e)
125 |
126 | def press_testing(self, image_list: dict, model_type=None, model_site=None):
127 | from multiprocessing.pool import ThreadPool
128 | pool = ThreadPool(500)
129 | for k, v in image_list.items():
130 |             pool.apply_async(
131 |                 self.request, args=({"image": v.get('image'), "model_type": model_type, "model_site": model_site},))
132 | pool.close()
133 | pool.join()
134 | print(self.true_count / len(image_list))
135 |
136 |
137 | class GoogleRPC(object):
138 |
139 | def __init__(self, host: str):
140 | self._url = '{}:50054'.format(host)
141 | self.true_count = 0
142 | self.total_count = 0
143 |
144 | def request(self, image, println=False, value=None, model_type=None, model_site=None, need_color=None):
145 |
146 | import grpc
147 | import grpc_pb2
148 | import grpc_pb2_grpc
149 | channel = grpc.insecure_channel(self._url)
150 | stub = grpc_pb2_grpc.PredictStub(channel)
151 | response = stub.predict(grpc_pb2.PredictRequest(
152 | image=image, split_char=',', model_type=model_type, model_site=model_site, need_color=need_color
153 | ))
154 | if println and value:
155 | _true = str(response.result).lower() == str(value).lower()
156 | if _true:
157 | self.true_count += 1
158 | print("result: {}, label: {}, flag: {}".format(response.result, value, _true))
159 | return {"message": response.result, "code": response.code, "success": response.success}
160 |
161 | def local_iter(self, image_list: dict, model_type=None, model_site=None):
162 | for k, v in image_list.items():
163 | code = self.request(v.get('image'), model_type=model_type, model_site=model_site,
164 | need_color=v.get('need_color')).get('message')
165 | _true = str(code).lower() == str(k).lower()
166 | if _true:
167 | self.true_count += 1
168 | self.total_count += 1
169 | print('result: {}, label: {}, flag: {}, acc_rate: {}'.format(
170 | code, k, _true, self.true_count / self.total_count
171 | ))
172 |
173 | def remote_iter(self, url: str, save_path: str = None, num=100, model_type=None, model_site=None):
174 | if not os.path.exists(save_path):
175 | os.makedirs(save_path)
176 | sess = Session()
177 | sess.verify = False
178 | for i in range(num):
179 | img_bytes = sess.get(url).content
180 | img_b64 = base64.b64encode(img_bytes).decode()
181 | code = self.request(img_b64, model_type=model_type, model_site=model_site).get('message')
182 | with open("{}/{}_{}.jpg".format(save_path, code, hashlib.md5(img_bytes).hexdigest()), "wb") as f:
183 | f.write(img_bytes)
184 |
185 | print('result: {}'.format(
186 | code,
187 | ))
188 |
189 | def press_testing(self, image_list: dict, model_type=None, model_site=None):
190 | from multiprocessing.pool import ThreadPool
191 | pool = ThreadPool(500)
192 | for k, v in image_list.items():
193 |             pool.apply_async(self.request, args=(v.get('image'), True, k), kwds={"model_type": model_type, "model_site": model_site})
194 | pool.close()
195 | pool.join()
196 | print(self.true_count / len(image_list))
197 |
198 |
199 | if __name__ == '__main__':
200 | # # Here you can replace it with a web request to get images in real time.
201 | # with open(r"D:\***.jpg", "rb") as f:
202 | # img_bytes = f.read()
203 |
204 | # # Here is the code for the network request.
205 | # # Replace your own captcha url for testing.
206 | # # sess = Session()
207 | # # sess.headers = {
208 | # # 'user-agent': 'Chrome'
209 | # # }
210 | # # img_bytes = sess.get("http://***.com/captcha").content
211 | #
212 | # # Open the image for human eye comparison,
213 | # # preview whether the recognition result is consistent.
214 | # data_stream = io.BytesIO(img_bytes)
215 | # pil_image = PilImage.open(data_stream)
216 | # pil_image.show()
217 | # api_params = {
218 | # 'image': base64.b64encode(img_bytes).decode(),
219 | # }
220 | # print(api_params)
221 | # for i in range(1):
222 | # Tornado API with authentication
223 | # resp = Auth(DEFAULT_HOST, ServerType.TORNADO).request(api_params)
224 | # print(resp)
225 |
226 | # Flask API with authentication
227 | # resp = Auth(DEFAULT_HOST, ServerType.FLASK).request(api_params)
228 | # print(resp)
229 |
230 | # Tornado API without authentication
231 | # resp = NoAuth(DEFAULT_HOST, ServerType.TORNADO).request(api_params)
232 | # print(resp)
233 |
234 | # Flask API without authentication
235 | # resp = NoAuth(DEFAULT_HOST, ServerType.FLASK).request(api_params)
236 | # print(resp)
237 |
238 | # API by gRPC - The fastest way.
239 |     # If you want to recognize multiple captchas in a single request, do it like this:
240 | # resp = GoogleRPC(DEFAULT_HOST).request(base64.b64encode(img_bytes+b'\x00\xff\xff\xff\x00'+img_bytes).decode())
241 | # b'\x00\xff\xff\xff\x00' is the split_flag defined in config.py
242 | # resp = GoogleRPC(DEFAULT_HOST).request(base64.b64encode(img_bytes).decode())
243 | # print(resp)
244 | # pass
245 |
246 | # API by gRPC - The fastest way, Local batch version, only for self testing.
247 | path = r"C:\Users\kerlomz\Desktop\New folder (6)"
248 | path_list = os.listdir(path)
249 | import random
250 |
251 | # random.shuffle(path_list)
252 | print(path_list)
253 | batch = {
254 | _path.split('_')[0].lower(): _image(
255 | os.path.join(path, _path),
256 | model_type=None,
257 | model_site=None,
258 | need_color=None,
259 | fpath=_path
260 | )
261 | for i, _path in enumerate(path_list)
262 | if i < 10000
263 | }
264 | print(batch)
265 | NoAuth(DEFAULT_HOST, ServerType.TORNADO, port=19952).local_iter(batch)
266 | # NoAuth(DEFAULT_HOST, ServerType.FLASK).local_iter(batch)
267 | # NoAuth(DEFAULT_HOST, ServerType.SANIC).local_iter(batch)
268 | # GoogleRPC(DEFAULT_HOST).local_iter(batch, model_site=None, model_type=None)
269 | # GoogleRPC(DEFAULT_HOST).press_testing(batch, model_site=None, model_type=None)
270 | # GoogleRPC(DEFAULT_HOST).remote_iter("https://pbank.cqrcb.com:9080/perbank/VerifyImage?update=0.8746844661116633", r"D:\test12", 100, model_site='80x24', model_type=None)
271 |
--------------------------------------------------------------------------------
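
For reference, a standalone sketch of the signing scheme implemented by `Auth.sign` and `Auth.make_json` above: add the public parameters, sort all keys, join them as a query string, append the secret key, and take the uppercase MD5 hex digest. The credentials below are placeholders; real values come from config.yaml.

    import time
    import hashlib

    def make_signed_params(params: dict, access_key: str, secret_key: str) -> dict:
        # Add the public parameters (same as Auth.make_json; int(time.time()) is
        # equivalent to the mktime/timetuple round-trip used there).
        params = dict(params, accessKey=access_key, timestamp=int(time.time()))
        # Sort keys, build "k=v&k=v", append the secret key, then MD5 (same as Auth.sign).
        query_string = '&'.join('{}={}'.format(k, v) for k, v in sorted(params.items()))
        query_string = '&'.join([query_string, secret_key])
        params['sign'] = hashlib.md5(query_string.encode('utf-8')).hexdigest().upper()
        return params

    signed = make_signed_params({'image': '<base64 image>'}, '<access_key>', '<secret_key>')
    print(signed['sign'])
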
/deploy.conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | # Gunicorn deploy file.
5 |
6 | import multiprocessing
7 |
8 | bind = '0.0.0.0:19951'
9 | workers = multiprocessing.cpu_count() * 2 + 1
10 | backlog = 2048
11 | # worker_class = "gevent"
12 | debug = False
13 | daemon = True
14 | proc_name = 'gunicorn'
15 | pidfile = 'gunicorn.pid'
16 | errorlog = 'error.log'
17 | accesslog = 'access.log'
18 | loglevel = 'info'
19 | timeout = 10
20 |
--------------------------------------------------------------------------------
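
The `workers` setting above follows the common Gunicorn guideline of `2 * CPU cores + 1`; a quick sanity check in plain Python:

    import multiprocessing

    cpu = multiprocessing.cpu_count()
    workers = cpu * 2 + 1  # same formula as deploy.conf.py
    print("{} CPU cores -> {} worker processes".format(cpu, workers))

The file itself is meant to be passed to Gunicorn with `-c deploy.conf.py`; the exact WSGI module and app name to serve are not shown here.
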
/event_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import os
5 | import time
6 | from watchdog.events import *
7 | from config import ModelConfig, Config
8 | from graph_session import GraphSession
9 | from interface import InterfaceManager, Interface
10 | from utils import PathUtils
11 |
12 |
13 | class FileEventHandler(FileSystemEventHandler):
14 | def __init__(self, conf: Config, model_conf_path: str, interface_manager: InterfaceManager):
15 | FileSystemEventHandler.__init__(self)
16 | self.conf = conf
17 | self.logger = self.conf.logger
18 | self.name_map = {}
19 | self.model_conf_path = model_conf_path
20 | self.interface_manager = interface_manager
21 | self.init()
22 |
23 | def init(self):
24 | model_list = os.listdir(self.model_conf_path)
25 | model_list = [os.path.join(self.model_conf_path, i) for i in model_list if i.endswith("yaml")]
26 | for model in model_list:
27 | self._add(model, is_first=True)
28 | if self.interface_manager.total == 0:
29 | self.logger.info(
30 | "\n - Number of interfaces: {}"
31 | "\n - There is currently no model deployment"
32 | "\n - Services are not available"
33 | "\n[ Please check the graph and model path whether the pb file and yaml file are placed. ]".format(
34 | self.interface_manager.total,
35 | ))
36 | else:
37 | self.logger.info(
38 | "\n - Number of interfaces: {}"
39 | "\n - Current online interface: \n\t - {}"
40 | "\n - The default Interface is: {}".format(
41 | self.interface_manager.total,
42 | "\n\t - ".join(["[{}]".format(v) for k, v in self.name_map.items()]),
43 | self.interface_manager.default_name
44 | ))
45 |
46 | def _add(self, src_path, is_first=False, count=0):
47 | try:
48 | model_path = str(src_path)
49 | path_exists = os.path.exists(model_path)
50 | if not path_exists and count > 0:
51 | self.logger.error("{} not found, retry attempt is terminated.".format(model_path))
52 | return
53 | if 'model_demo.yaml' in model_path:
54 | self.logger.warning(
55 | "\n-------------------------------------------------------------------\n"
56 | "- Found that the model_demo.yaml file exists, \n"
57 | "- the loading is automatically ignored. \n"
58 | "- If it is used for the first time, \n"
59 | "- please copy it as a template. \n"
60 | "- and do not use the reserved character \"model_demo.yaml\" as the file name."
61 | "\n-------------------------------------------------------------------"
62 | )
63 | return
64 | if model_path.endswith("yaml"):
65 | model_conf = ModelConfig(self.conf, model_path)
66 | inner_name = model_conf.model_name
67 | inner_size = model_conf.size_string
68 | inner_key = PathUtils.get_file_name(model_path)
69 | for k, v in self.name_map.items():
70 | if inner_size in v:
71 | self.logger.warning(
72 | "\n-------------------------------------------------------------------\n"
73 | "- The current model {} is the same size [{}] as the loaded model {}. \n"
74 | "- Only one of the smart calls can be called. \n"
75 | "- If you want to refer to one of them, \n"
76 | "- please use the model key or model type to find it."
77 | "\n-------------------------------------------------------------------".format(
78 | inner_key, inner_size, k
79 | )
80 | )
81 | break
82 |
83 | inner_value = model_conf.model_name
84 | graph_session = GraphSession(model_conf)
85 | if graph_session.loaded:
86 | interface = Interface(graph_session)
87 | if inner_name == self.conf.default_model:
88 | self.interface_manager.set_default(interface)
89 | else:
90 | self.interface_manager.add(interface)
91 | self.logger.info("{} a new model: {} ({})".format(
92 | "Inited" if is_first else "Added", inner_value, inner_key
93 | ))
94 | self.name_map[inner_key] = inner_value
95 | if src_path in self.interface_manager.invalid_group:
96 | self.interface_manager.invalid_group.pop(src_path)
97 | else:
98 | self.interface_manager.report(src_path)
99 | if count < 12 and not is_first:
100 | time.sleep(5)
101 | return self._add(src_path, is_first=is_first, count=count+1)
102 |
103 | except Exception as e:
104 | self.interface_manager.report(src_path)
105 | self.logger.error(e)
106 |
107 | def delete(self, src_path):
108 | try:
109 | model_path = str(src_path)
110 | if model_path.endswith("yaml"):
111 | inner_key = PathUtils.get_file_name(model_path)
112 | graph_name = self.name_map.get(inner_key)
113 | self.interface_manager.remove_by_name(graph_name)
114 | self.name_map.pop(inner_key)
115 | self.logger.info("Unload the model: {} ({})".format(graph_name, inner_key))
116 | except Exception as e:
117 | self.logger.error("Config File [{}] does not exist.".format(str(e).replace("'", "")))
118 |
119 | def on_created(self, event):
120 | if event.is_directory:
121 | self.logger.info("directory created:{0}".format(event.src_path))
122 | else:
123 | model_path = str(event.src_path)
124 | self._add(model_path)
125 | self.logger.info(
126 | "\n - Number of interfaces: {}"
127 | "\n - Current online interface: \n\t - {}"
128 | "\n - The default Interface is: {}".format(
129 | len(self.interface_manager.group),
130 | "\n\t - ".join(["[{}]".format(v) for k, v in self.name_map.items()]),
131 | self.interface_manager.default_name
132 | ))
133 |
134 | def on_deleted(self, event):
135 | if event.is_directory:
136 | self.logger.info("directory deleted:{0}".format(event.src_path))
137 | else:
138 | model_path = str(event.src_path)
139 | if model_path in self.interface_manager.invalid_group:
140 | self.interface_manager.invalid_group.pop(model_path)
141 | inner_key = PathUtils.get_file_name(model_path)
142 | if inner_key in self.name_map:
143 | self.delete(model_path)
144 | self.logger.info(
145 | "\n - Number of interfaces: {}"
146 | "\n - Current online interface: \n\t - {}"
147 | "\n - The default Interface is: {}".format(
148 | len(self.interface_manager.group),
149 | "\n\t - ".join(["[{}]".format(v) for k, v in self.name_map.items()]),
150 | self.interface_manager.default_name
151 | ))
152 |
153 |
154 | if __name__ == "__main__":
155 | pass
156 | # import time
157 | # from watchdog.observers import Observer
158 | # observer = Observer()
159 | # interface_manager = InterfaceManager()
160 | # event_handler = FileEventHandler("", interface_manager)
161 | # observer.schedule(event_handler, event_handler.model_conf_path, True)
162 | # observer.start()
163 | # try:
164 | # while True:
165 | # time.sleep(1)
166 | # except KeyboardInterrupt:
167 | # observer.stop()
168 | # observer.join()
169 |
--------------------------------------------------------------------------------
/event_loop.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import time
5 | from watchdog.observers import Observer
6 | from event_handler import FileEventHandler
7 |
8 |
9 | def event_loop(system_config, model_path, interface_manager):
10 | observer = Observer()
11 | event_handler = FileEventHandler(system_config, model_path, interface_manager)
12 | observer.schedule(event_handler, event_handler.model_conf_path, True)
13 | observer.start()
14 | try:
15 | while True:
16 | time.sleep(1)
17 | except KeyboardInterrupt:
18 | observer.stop()
19 | observer.join()
--------------------------------------------------------------------------------
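
A minimal sketch of how `event_loop` would be wired up by a server entry point. The config path and model-config directory below ("config.yaml", "model") are assumptions based on the rest of the project, not values taken from this file.

    # Hypothetical wiring; the real servers build these objects themselves.
    from config import Config
    from interface import InterfaceManager
    from event_loop import event_loop

    if __name__ == '__main__':
        system_config = Config(conf_path="config.yaml")   # assumed config path
        interface_manager = InterfaceManager()
        # Watches the model config directory and hot-(un)loads models via FileEventHandler.
        event_loop(system_config, "model", interface_manager)
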
/graph_session.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import os
5 | import tensorflow as tf
6 | tf.compat.v1.disable_v2_behavior()
7 | from tensorflow.python.framework.errors_impl import NotFoundError
8 | from config import ModelConfig
9 |
10 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
11 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
12 |
13 |
14 | class GraphSession(object):
15 | def __init__(self, model_conf: ModelConfig):
16 | self.model_conf = model_conf
17 | self.logger = self.model_conf.logger
18 | self.size_str = self.model_conf.size_string
19 | self.model_name = self.model_conf.model_name
20 | self.graph_name = self.model_conf.model_name
21 | self.version = self.model_conf.model_version
22 | self.graph = tf.compat.v1.Graph()
23 | self.sess = tf.compat.v1.Session(
24 | graph=self.graph,
25 | config=tf.compat.v1.ConfigProto(
26 |
27 | # allow_soft_placement=True,
28 | # log_device_placement=True,
29 | gpu_options=tf.compat.v1.GPUOptions(
30 | # allocator_type='BFC',
31 | allow_growth=True, # it will cause fragmentation.
32 | # per_process_gpu_memory_fraction=self.model_conf.device_usage
33 | per_process_gpu_memory_fraction=0.1
34 | )
35 | )
36 | )
37 | self.graph_def = self.graph.as_graph_def()
38 | self.loaded = self.load_model()
39 |
40 | def load_model(self):
41 |         # For debugging only: helps locate the source of model-loading errors.
42 | # with self.graph.as_default():
43 | # saver = tf.train.import_meta_graph('graph/***.meta')
44 | # saver.restore(self.sess, tf.train.latest_checkpoint('graph'))
45 | if not self.model_conf.model_exists:
46 | self.destroy()
47 | return False
48 | try:
49 | with tf.io.gfile.GFile(self.model_conf.compile_model_path, "rb") as f:
50 | graph_def_file = f.read()
51 | self.graph_def.ParseFromString(graph_def_file)
52 | with self.graph.as_default():
53 | self.sess.run(tf.compat.v1.global_variables_initializer())
54 | _ = tf.import_graph_def(self.graph_def, name="")
55 |
56 | self.logger.info('TensorFlow Session {} Loaded.'.format(self.model_conf.model_name))
57 | return True
58 | except NotFoundError:
59 | self.logger.error('The system cannot find the model specified.')
60 | self.destroy()
61 | return False
62 |
63 | @property
64 | def session(self):
65 | return self.sess
66 |
67 | def destroy(self):
68 | self.sess.close()
69 | del self.sess
70 |
--------------------------------------------------------------------------------
/interface.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import os
5 | import time
6 | from graph_session import GraphSession
7 | from predict import predict_func
8 |
9 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
10 |
11 |
12 | class Interface(object):
13 |
14 | def __init__(self, graph_session: GraphSession):
15 | self.graph_sess = graph_session
16 | self.model_conf = graph_session.model_conf
17 | self.size_str = self.model_conf.size_string
18 | self.graph_name = self.graph_sess.graph_name
19 | self.version = self.graph_sess.version
20 | self.model_category = self.model_conf.category_param
21 | if self.graph_sess.loaded:
22 | self.sess = self.graph_sess.session
23 | self.dense_decoded = self.sess.graph.get_tensor_by_name("dense_decoded:0")
24 | self.x = self.sess.graph.get_tensor_by_name('input:0')
25 | self.sess.graph.finalize()
26 |
27 | @property
28 | def name(self):
29 | return self.graph_name
30 |
31 | @property
32 | def size(self):
33 | return self.size_str
34 |
35 | def destroy(self):
36 | self.graph_sess.destroy()
37 |
38 | def predict_batch(self, image_batch, output_split=None):
39 | predict_text = predict_func(
40 | image_batch,
41 | self.sess,
42 | self.dense_decoded,
43 | self.x,
44 | self.model_conf,
45 | output_split
46 | )
47 | return predict_text
48 |
49 |
50 | class InterfaceManager(object):
51 |
52 | def __init__(self, interface: Interface = None):
53 | self.group = []
54 | self.invalid_group = {}
55 | self.set_default(interface)
56 |
57 | def add(self, interface: Interface):
58 | if interface in self.group:
59 | return
60 | self.group.append(interface)
61 |
62 | def remove(self, interface: Interface):
63 | if interface in self.group:
64 | interface.destroy()
65 | self.group.remove(interface)
66 |
67 | def report(self, model):
68 | self.invalid_group[model] = {"create_time": time.asctime(time.localtime(time.time()))}
69 |
70 | def remove_by_name(self, graph_name):
71 | interface = self.get_by_name(graph_name, False)
72 | self.remove(interface)
73 |
74 | def get_by_size(self, size: str, return_default=True):
75 |
76 | match_ids = [i for i in range(len(self.group)) if self.group[i].size_str == size]
77 | if not match_ids:
78 | return self.default if return_default else None
79 | else:
80 | ver = [self.group[i].version for i in match_ids]
81 | return self.group[match_ids[ver.index(max(ver))]]
82 |
83 | def get_by_name(self, key: str, return_default=True):
84 | for interface in self.group:
85 | if interface.name == key:
86 |
87 | return interface
88 | return self.default if return_default else None
89 |
90 | @property
91 | def default(self):
92 | return self.group[0] if len(self.group) > 0 else None
93 |
94 | @property
95 | def default_name(self):
96 | _default = self.default
97 | if not _default:
98 | return
99 | return _default.graph_name
100 |
101 | @property
102 | def total(self):
103 | return len(self.group)
104 |
105 | @property
106 | def online_names(self):
107 | return [i.name for i in self.group]
108 |
109 | def set_default(self, interface: Interface):
110 | if not interface:
111 | return
112 | self.group.insert(0, interface)
113 |
--------------------------------------------------------------------------------
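
A sketch of the intended flow from GraphSession through Interface to InterfaceManager, assuming a config.yaml and a model config at "model/demo.yaml" exist (both paths are placeholders):

    from config import Config, ModelConfig
    from graph_session import GraphSession
    from interface import Interface, InterfaceManager

    conf = Config(conf_path="config.yaml")
    model_conf = ModelConfig(conf, "model/demo.yaml")   # placeholder model config path

    manager = InterfaceManager()
    session = GraphSession(model_conf)
    if session.loaded:
        manager.add(Interface(session))

    # Look up an interface (falls back to the default) and run a prediction on a
    # preprocessed image batch, as produced elsewhere in the project.
    interface = manager.get_by_name(manager.default_name)
    # text = interface.predict_batch(image_batch)
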
/middleware/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 |
5 | # from middleware.impl import color_filter
6 | from middleware.impl import corp_to_multi
7 | from middleware.impl import gif_frames
8 | # color_extract = color_filter.ColorFilter()
9 | # color_map = color_filter.color_map
10 |
11 |
--------------------------------------------------------------------------------
/middleware/constructor/color_extractor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import tensorflow as tf
5 | from tensorflow.python.framework.graph_util import convert_variables_to_constants
6 |
7 | black = tf.constant([[0, 0, 0]], dtype=tf.int32)
8 | red = tf.constant([[0, 0, 255]], dtype=tf.int32)
9 | yellow = tf.constant([[0, 255, 255]], dtype=tf.int32)
10 | blue = tf.constant([[255, 0, 0]], dtype=tf.int32)
11 | green = tf.constant([[0, 255, 0]], dtype=tf.int32)
12 | white = tf.constant([[255, 255, 255]], dtype=tf.int32)
13 |
14 |
15 | def k_means(data, target_color, bg_color1, bg_color2, alpha=1.0):
16 | def get_distance(point):
17 | sum_squares = tf.cast(tf.reduce_sum(tf.abs(tf.subtract(data, point)), axis=2, keep_dims=True), tf.float32)
18 | return sum_squares
19 |
20 | alpha_value = tf.constant(alpha, dtype=tf.float32)
21 |     # target_color black:0, red:1, blue:2, yellow:3, green:4, color_1:5, color_2:6, white:7
22 | black_distance = get_distance(black)
23 | red_distance = get_distance(red)
24 | if target_color == 1:
25 | red_distance = tf.multiply(red_distance, alpha_value)
26 | blue_distance = get_distance(blue)
27 | if target_color == 2:
28 | blue_distance = tf.multiply(blue_distance, alpha_value)
29 | yellow_distance = get_distance(yellow)
30 | if target_color == 3:
31 | yellow_distance = tf.multiply(yellow_distance, alpha_value)
32 |     white_distance = get_distance(white)
33 | if target_color == 7:
34 | white_distance = tf.multiply(white_distance, alpha_value)
35 |
36 | green_distance = get_distance(green)
37 | c_1_distance = get_distance(bg_color1)
38 | c_2_distance = get_distance(bg_color2)
39 |
40 | distances = tf.concat([
41 | black_distance,
42 | red_distance,
43 | blue_distance,
44 | yellow_distance,
45 | green_distance,
46 | c_1_distance,
47 | c_2_distance,
48 | white_distance
49 | ], axis=-1)
50 |
51 | clusters = tf.cast(tf.argmin(distances, axis=-1), tf.int32)
52 |
53 | mask = tf.equal(clusters, target_color)
54 | mask = tf.cast(mask, tf.int32)
55 |
56 | return mask * 255
57 |
58 |
59 | def filter_img(img, target_color, alpha=0.9):
60 | # background color1
61 | color_1 = img[0, 0, :]
62 | color_1 = tf.reshape(color_1, [1, 3])
63 | color_1 = tf.cast(color_1, dtype=tf.int32)
64 |
65 | # background color2
66 | color_2 = img[6, 6, :]
67 | color_2 = tf.reshape(color_2, [1, 3])
68 | color_2 = tf.cast(color_2, dtype=tf.int32)
69 |
70 |     filtered_img = k_means(img, target_color, color_1, color_2, alpha)
71 | filtered_img = tf.expand_dims(filtered_img, axis=0)
72 | filtered_img = tf.expand_dims(filtered_img, axis=-1)
73 | filtered_img = tf.squeeze(filtered_img, name="filtered")
74 | return filtered_img
75 |
76 |
77 | def compile_graph():
78 |
79 | with sess.graph.as_default():
80 | input_graph_def = sess.graph.as_graph_def()
81 |
82 | output_graph_def = convert_variables_to_constants(
83 | sess,
84 | input_graph_def,
85 | output_node_names=['filtered']
86 | )
87 |
88 | last_compile_model_path = "color_extractor.pb"
89 | with tf.gfile.FastGFile(last_compile_model_path, mode='wb') as gf:
90 | # gf.write(output_graph_def.SerializeToString())
91 | print(output_graph_def.SerializeToString())
92 |
93 |
94 | if __name__ == "__main__":
95 |
96 | sess = tf.Session()
97 | img_holder = tf.placeholder(dtype=tf.int32, name="img_holder")
98 | color = tf.placeholder(dtype=tf.int32, name="target_color")
99 | filtered = filter_img(img_holder, color, alpha=0.8)
100 |
101 | compile_graph()
102 |
103 |
--------------------------------------------------------------------------------
/middleware/impl/color_extractor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import io
5 | import cv2
6 | import time
7 | import PIL.Image as PilImage
8 | import numpy as np
9 | import tensorflow as tf
10 | from enum import Enum, unique
11 | from distutils.version import StrictVersion
12 |
13 |
14 | @unique
15 | class TargetColor(Enum):
16 | Black = 0
17 | Red = 1
18 | Blue = 2
19 | Yellow = 3
20 | Green = 4
21 | White = 7
22 |
23 |
24 | color_map = {
25 | 'black': TargetColor.Black,
26 | 'red': TargetColor.Red,
27 | 'blue': TargetColor.Blue,
28 | 'yellow': TargetColor.Yellow,
29 | 'green': TargetColor.Green,
30 | 'white': TargetColor.White
31 | }
32 |
33 |
34 | class ColorExtract:
35 |
36 | def __init__(self):
37 | self.model_raw_v1_14 = b'\nB\n\x05Const\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_1\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\x00\x00\x00\x00\xff\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_2\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\xff\x00\x00\x00\xff\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_3\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nD\n\x07Const_4\x12\x05Const*%\n\x05value\x12\x1cB\x1a\x08\x03\x12\x08\x12\x02\x08\x01\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\n5\n\nimg_holder\x12\x0bPlaceholder*\r\n\x05shape\x12\x04:\x02\x18\x01*\x0b\n\x05dtype\x12\x020\x03\n7\n\x0ctarget_color\x12\x0bPlaceholder*\r\n\x05shape\x12\x04:\x02\x18\x01*\x0b\n\x05dtype\x12\x020\x03\nL\n\x13strided_slice/stack\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nN\n\x15strided_slice/stack_1\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x01\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nN\n\x15strided_slice/stack_2\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\n\xe6\x01\n\rstrided_slice\x12\x0cStridedSlice\x1a\nimg_holder\x1a\x13strided_slice/stack\x1a\x15strided_slice/stack_1\x1a\x15strided_slice/stack_2*\x07\n\x01T\x12\x020\x03*\x0b\n\x05Index\x12\x020\x03*\x16\n\x10shrink_axis_mask\x12\x02\x18\x03*\x10\n\nbegin_mask\x12\x02\x18\x04*\x13\n\rellipsis_mask\x12\x02\x18\x00*\x13\n\rnew_axis_mask\x12\x02\x18\x00*\x0e\n\x08end_mask\x12\x02\x18\x04\nB\n\rReshape/shape\x12\x05Const*\x1d\n\x05value\x12\x14B\x12\x08\x03\x12\x04\x12\x02\x08\x02"\x08\x01\x00\x00\x00\x03\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nG\n\x07Reshape\x12\x07Reshape\x1a\rstrided_slice\x1a\rReshape/shape*\x07\n\x01T\x12\x020\x03*\x0c\n\x06Tshape\x12\x020\x03\nN\n\x15strided_slice_1/stack\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x06\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nP\n\x17strided_slice_1/stack_1\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x07\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nP\n\x17strided_slice_1/stack_2\x12\x05Const*!\n\x05value\x12\x18B\x16\x08\x03\x12\x04\x12\x02\x08\x03"\x0c\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\n\xee\x01\n\x0fstrided_slice_1\x12\x0cStridedSlice\x1a\nimg_holder\x1a\x15strided_slice_1/stack\x1a\x17strided_slice_1/stack_1\x1a\x17strided_slice_1/stack_2*\x07\n\x01T\x12\x020\x03*\x0b\n\x05Index\x12\x020\x03*\x16\n\x10shrink_axis_mask\x12\x02\x18\x03*\x10\n\nbegin_mask\x12\x02\x18\x04*\x13\n\rellipsis_mask\x12\x02\x18\x00*\x13\n\rnew_axis_mask\x12\x02\x18\x00*\x0e\n\x08end_mask\x12\x02\x18\x04\nD\n\x0fReshape_1/shape\x12\x05Const*\x1d\n\x05value\x12\x14B\x12\x08\x03\x12\x04\x12\x02\x08\x02"\x08\x01\x00\x00\x00\x03\x00\x00\x00*\x0b\n\x05dtype\x12\x020\x03\nM\n\tReshape_1\x12\x07Reshape\x1a\x0fstrided_slice_1\x1a
\x0fReshape_1/shape*\x07\n\x01T\x12\x020\x03*\x0c\n\x06Tshape\x12\x020\x03\n&\n\x03Sub\x12\x03Sub\x1a\nimg_holder\x1a\x05Const*\x07\n\x01T\x12\x020\x03\n\x18\n\x03Abs\x12\x03Abs\x1a\x03Sub*\x07\n\x01T\x12\x020\x03\n?\n\x15Sum/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nL\n\x03Sum\x12\x03Sum\x1a\x03Abs\x1a\x15Sum/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n9\n\x04Cast\x12\x04Cast\x1a\x03Sum*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_1\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_1*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_1\x12\x03Abs\x1a\x05Sub_1*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_1/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_1\x12\x03Sum\x1a\x05Abs_1\x1a\x17Sum_1/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_1\x12\x04Cast\x1a\x05Sum_1*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_2\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_3*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_2\x12\x03Abs\x1a\x05Sub_2*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_2/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_2\x12\x03Sum\x1a\x05Abs_2\x1a\x17Sum_2/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_2\x12\x04Cast\x1a\x05Sum_2*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_3\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_2*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_3\x12\x03Abs\x1a\x05Sub_3*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_3/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_3\x12\x03Sum\x1a\x05Abs_3\x1a\x17Sum_3/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_3\x12\x04Cast\x1a\x05Sum_3*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_4\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_2*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_4\x12\x03Abs\x1a\x05Sub_4*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_4/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_4\x12\x03Sum\x1a\x05Abs_4\x1a\x17Sum_4/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_4\x12\x04Cast\x1a\x05Sum_4*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_5\x12\x03Sub\x1a\nimg_holder\x1a\x07Const_4*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_5\x12\x03Abs\x1a\x05Sub_5*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_5/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_5\x12\x03Sum\x1a\x05Abs_5\x1a\x17Sum_5/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_5\x12\x04Cast\x1a\x05Sum_5*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n*\n\x05Sub_6\x12\x03Sub\x1a\nimg_holder\x1a\x07Reshape*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_6\x12\x03Abs\x1a\x05Sub_6*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_6/reduction_indices\x12\x05Const*\x12\n\x05value\x12\
tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_6\x12\x03Sum\x1a\x05Abs_6\x1a\x17Sum_6/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_6\x12\x04Cast\x1a\x05Sum_6*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n,\n\x05Sub_7\x12\x03Sub\x1a\nimg_holder\x1a\tReshape_1*\x07\n\x01T\x12\x020\x03\n\x1c\n\x05Abs_7\x12\x03Abs\x1a\x05Sub_7*\x07\n\x01T\x12\x020\x03\nA\n\x17Sum_7/reduction_indices\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x02*\x0b\n\x05dtype\x12\x020\x03\nR\n\x05Sum_7\x12\x03Sum\x1a\x05Abs_7\x1a\x17Sum_7/reduction_indices*\n\n\x04Tidx\x12\x020\x03*\x0f\n\tkeep_dims\x12\x02(\x01*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_7\x12\x04Cast\x1a\x05Sum_7*\n\n\x04SrcT\x12\x020\x03*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x01\n>\n\x0bconcat/axis\x12\x05Const*\x1b\n\x05value\x12\x12B\x10\x08\x03\x12\x00:\n\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01*\x0b\n\x05dtype\x12\x020\x03\n{\n\x06concat\x12\x08ConcatV2\x1a\x04Cast\x1a\x06Cast_1\x1a\x06Cast_2\x1a\x06Cast_3\x1a\x06Cast_5\x1a\x06Cast_6\x1a\x06Cast_7\x1a\x06Cast_4\x1a\x0bconcat/axis*\n\n\x04Tidx\x12\x020\x03*\x07\n\x01T\x12\x020\x01*\x07\n\x01N\x12\x02\x18\x08\nC\n\x10ArgMin/dimension\x12\x05Const*\x1b\n\x05value\x12\x12B\x10\x08\x03\x12\x00:\n\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01*\x0b\n\x05dtype\x12\x020\x03\nR\n\x06ArgMin\x12\x06ArgMin\x1a\x06concat\x1a\x10ArgMin/dimension*\n\n\x04Tidx\x12\x020\x03*\x07\n\x01T\x12\x020\x01*\x11\n\x0boutput_type\x12\x020\t\n>\n\x06Cast_8\x12\x04Cast\x1a\x06ArgMin*\n\n\x04SrcT\x12\x020\t*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x03\n-\n\x05Equal\x12\x05Equal\x1a\x06Cast_8\x1a\x0ctarget_color*\x07\n\x01T\x12\x020\x03\n=\n\x06Cast_9\x12\x04Cast\x1a\x05Equal*\n\n\x04SrcT\x12\x020\n*\x0e\n\x08Truncate\x12\x02(\x00*\n\n\x04DstT\x12\x020\x03\n0\n\x05mul/y\x12\x05Const*\x13\n\x05value\x12\nB\x08\x08\x03\x12\x00:\x02\xff\x01*\x0b\n\x05dtype\x12\x020\x03\n"\n\x03mul\x12\x03Mul\x1a\x06Cast_9\x1a\x05mul/y*\x07\n\x01T\x12\x020\x03\n8\n\x0eExpandDims/dim\x12\x05Const*\x12\n\x05value\x12\tB\x07\x08\x03\x12\x00:\x01\x00*\x0b\n\x05dtype\x12\x020\x03\nB\n\nExpandDims\x12\nExpandDims\x1a\x03mul\x1a\x0eExpandDims/dim*\n\n\x04Tdim\x12\x020\x03*\x07\n\x01T\x12\x020\x03\nC\n\x10ExpandDims_1/dim\x12\x05Const*\x1b\n\x05value\x12\x12B\x10\x08\x03\x12\x00:\n\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01*\x0b\n\x05dtype\x12\x020\x03\nM\n\x0cExpandDims_1\x12\nExpandDims\x1a\nExpandDims\x1a\x10ExpandDims_1/dim*\n\n\x04Tdim\x12\x020\x03*\x07\n\x01T\x12\x020\x03\n>\n\x08filtered\x12\x07Squeeze\x1a\x0cExpandDims_1*\x12\n\x0csqueeze_dims\x12\x02\n\x00*\x07\n\x01T\x12\x020\x03\x12\x00'
38 | self.color_graph = tf.Graph()
39 | self.color_sess = tf.compat.v1.Session(
40 | graph=self.color_graph,
41 | config=tf.compat.v1.ConfigProto(
42 | allow_soft_placement=True,
43 | # log_device_placement=True,
44 | gpu_options=tf.compat.v1.GPUOptions(
45 | # allow_growth=True, # it will cause fragmentation.
46 | per_process_gpu_memory_fraction=0.1
47 | ))
48 | )
49 | self.color_graph_def = self.color_graph.as_graph_def()
50 | self.load_model()
51 | self.img_holder = self.color_sess.graph.get_tensor_by_name("img_holder:0")
52 | self.target_color = self.color_sess.graph.get_tensor_by_name("target_color:0")
53 | self.filtered = self.color_sess.graph.get_tensor_by_name("filtered:0")
54 | self.color_graph.finalize()
55 |
56 | def load_model(self):
57 | raw = self.model_raw_v1_14
58 | self.color_graph_def.ParseFromString(raw)
59 | with self.color_graph.as_default():
60 | self.color_sess.run(tf.compat.v1.global_variables_initializer())
61 | _ = tf.import_graph_def(self.color_graph_def, name="")
62 |
63 | def separate_color(self, image_bytes, color: TargetColor):
64 | # image = np.asarray(bytearray(image_bytes), dtype="uint8")
65 | # image = cv2.imdecode(image, -1)
66 | image = np.array(PilImage.open(io.BytesIO(image_bytes)))
67 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
68 | mask = self.color_sess.run(self.filtered, {self.img_holder: image, self.target_color: color.value})
69 | mask = bytearray(cv2.imencode('.png', mask)[1])
70 | return mask
71 |
72 |
73 | if __name__ == '__main__':
74 |
75 | pass
76 | # import os
77 | # source_dir = r'E:\***'
78 | # target_dir = r'E:\***'
79 | # if not os.path.exists(target_dir):
80 | # os.makedirs(target_dir)
81 | #
82 | # source_names = os.listdir(source_dir)
83 | # color_extract = ColorExtract()
84 | # st = time.time()
85 | # for i, name in enumerate(source_names):
86 | # img_path = os.path.join(source_dir, name)
87 | # if i % 100 == 0:
88 | # print(i)
89 | # with open(img_path, "rb") as f:
90 | # b = f.read()
91 | # result = color_extract.separate_color(b, color_map['red'])
92 | # target_path = os.path.join(target_dir, name)
93 | # with open(target_path, "wb") as f:
94 | # f.write(result)
95 | #
96 | # print('completed {}'.format(time.time() - st))
97 |
--------------------------------------------------------------------------------
/middleware/impl/color_filter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import io
5 | import cv2
6 | import time
7 | import PIL.Image as PilImage
8 | import numpy as np
9 | import onnxruntime as ort
10 | from enum import Enum, unique
11 | from distutils.version import StrictVersion
12 | from middleware.resource import color_model
13 |
14 |
15 | @unique
16 | class TargetColor(Enum):
17 | Red = 1
18 | Blue = 2
19 | Yellow = 3
20 | Black = 4
21 |
22 |
23 | color_map = {
24 | 'black': TargetColor.Black,
25 | 'red': TargetColor.Red,
26 | 'blue': TargetColor.Blue,
27 | 'yellow': TargetColor.Yellow,
28 | }
29 |
30 |
31 | class ColorFilter:
32 |
33 | def __init__(self):
34 | self.model_onnx = color_model
35 | self.sess = ort.InferenceSession(self.model_onnx)
36 |
37 | def predict_color(self, image_batch, color: TargetColor):
38 | dense_decoded_code = self.sess.run(["dense_decoded:0"], input_feed={
39 | "input:0": image_batch,
40 | })
41 | result = dense_decoded_code[0][0].tolist()
42 | return [i for i, c in enumerate(result) if c == color.value]
43 |
44 |
45 | if __name__ == '__main__':
46 |
47 | pass
48 | # import os
49 | # source_dir = r'E:\***'
50 | # target_dir = r'E:\***'
51 | # if not os.path.exists(target_dir):
52 | # os.makedirs(target_dir)
53 | #
54 | # source_names = os.listdir(source_dir)
55 | # color_extract = ColorExtract()
56 | # st = time.time()
57 | # for i, name in enumerate(source_names):
58 | # img_path = os.path.join(source_dir, name)
59 | # if i % 100 == 0:
60 | # print(i)
61 | # with open(img_path, "rb") as f:
62 | # b = f.read()
63 | # result = color_extract.separate_color(b, color_map['red'])
64 | # target_path = os.path.join(target_dir, name)
65 | # with open(target_path, "wb") as f:
66 | # f.write(result)
67 | #
68 | # print('completed {}'.format(time.time() - st))
69 |
--------------------------------------------------------------------------------
/middleware/impl/corp_to_multi.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import io
5 | import cv2
6 | import PIL.Image as Pil_Image
7 | import numpy as np
8 |
9 |
10 | def coord_calc(param, is_range=True, is_integer=True):
11 |
12 | result_group = []
13 | start_h = param['start_pos'][1]
14 | end_h = start_h + param['corp_size'][1]
15 | for row in range(param['corp_num'][1]):
16 | start_w = param['start_pos'][0]
17 | end_w = start_w + param['corp_size'][0]
18 | for col in range(param['corp_num'][0]):
19 | pos_range = [[start_w, end_w], [start_h, end_h]]
20 | t = lambda x: int(x) if is_integer else x
21 | pos_center = [t((start_w + end_w)/2), t((start_h + end_h)/2)]
22 | result_group.append(pos_range if is_range else pos_center)
23 | start_w = end_w + param['interval_size'][0]
24 | end_w = start_w + param['corp_size'][0]
25 | start_h = end_h + param['interval_size'][1]
26 | end_h = start_h + param['corp_size'][1]
27 | return result_group
28 |
29 |
30 | def parse_multi_img(image_bytes, param_group):
31 | img_bytes = image_bytes[0]
32 | image_arr = np.array(Pil_Image.open(io.BytesIO(img_bytes)).convert('RGB'))
33 | if len(image_arr.shape) == 3:
34 | image_arr = cv2.cvtColor(image_arr, cv2.COLOR_BGR2RGB)
35 | # image_arr = np.fromstring(img_bytes, np.uint8)
36 | # print(image_arr.shape)
37 | image_arr = image_arr.swapaxes(0, 1)
38 | group = []
39 | for p in param_group:
40 | pos_ranges = coord_calc(p, True, True)
41 | for pos_range in pos_ranges:
42 | corp_arr = image_arr[pos_range[0][0]: pos_range[0][1], pos_range[1][0]: pos_range[1][1]]
43 | corp_arr = cv2.imencode('.png', np.swapaxes(corp_arr, 0, 1))[1]
44 | corp_bytes = bytes(bytearray(corp_arr))
45 | group.append(corp_bytes)
46 | return group
47 |
48 |
49 | def get_coordinate(label: str, param_group, title_index=None):
50 | if title_index is None:
51 | title_index = [0]
52 | param = param_group[-1]
53 | coord_map = coord_calc(param, is_range=False, is_integer=True)
54 | index_group = get_pair_index(label=label, title_index=title_index)
55 | return [coord_map[i] for i in index_group]
56 |
57 |
58 | def get_pair_index(label: str, title_index=None):
59 | if title_index is None:
60 | title_index = [0]
61 | max_index = max(title_index)
62 |
63 | label_group = label.split(',')
64 | titles = [label_group[i] for i in title_index]
65 |
66 | index_group = []
67 | for title in titles:
68 | for i, item in enumerate(label_group[max_index+1:]):
69 | if item == title:
70 | index_group.append(i)
71 | index_group = [i for i in index_group]
72 | return index_group
73 |
74 |
75 | if __name__ == '__main__':
76 | import os
77 | import hashlib
78 | root_dir = r"H:\Task\Trains\d111_Trains"
79 | target_dir = r"F:\1q2"
80 | if not os.path.exists(target_dir):
81 | os.makedirs(target_dir)
82 | _param_group = [
83 | {
84 | "start_pos": [20, 50],
85 | "interval_size": [20, 20],
86 | "corp_num": [4, 2],
87 | "corp_size": [60, 60]
88 | }
89 | ]
90 | for name in os.listdir(root_dir):
91 | path = os.path.join(root_dir, name)
92 | with open(path, "rb") as f:
93 | file_bytes = [f.read()]
94 | group = parse_multi_img(file_bytes, _param_group)
95 | for b in group:
96 | tag = hashlib.md5(b).hexdigest()
97 | p = os.path.join(target_dir, "{}.png".format(tag))
98 | with open(p, "wb") as f:
99 | f.write(b)
--------------------------------------------------------------------------------
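
A worked example of `coord_calc`: with the parameters below it yields one row of two 10x10 crops separated by a 5-pixel horizontal gap.

    from middleware.impl.corp_to_multi import coord_calc

    param = {
        "start_pos": [0, 0],      # top-left corner of the first crop
        "corp_size": [10, 10],    # width, height of each crop
        "corp_num": [2, 1],       # 2 columns, 1 row
        "interval_size": [5, 5],  # horizontal / vertical gap between crops
    }
    print(coord_calc(param, is_range=True))
    # -> [[[0, 10], [0, 10]], [[15, 25], [0, 10]]]
    print(coord_calc(param, is_range=False))
    # -> [[5, 5], [20, 5]]  (crop centers)
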
/middleware/impl/gif_frames.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import io
5 | import cv2
6 | import numpy as np
7 | from itertools import groupby
8 | from PIL import ImageSequence, Image
9 |
10 |
11 | def split_frames(image_obj, need_frame=None):
12 | image_seq = ImageSequence.all_frames(image_obj)
13 | image_arr_last = [np.asarray(image_seq[-1])] if -1 in need_frame and len(need_frame) > 1 else []
14 | image_arr = [np.asarray(item) for i, item in enumerate(image_seq) if (i in need_frame or need_frame == [-1])]
15 | image_arr += image_arr_last
16 | return image_arr
17 |
18 |
19 | def concat_arr(img_arr):
20 | if len(img_arr) < 2:
21 | return img_arr[0]
22 | all_slice = img_arr[0]
23 | for im_slice in img_arr[1:]:
24 | all_slice = np.concatenate((all_slice, im_slice), axis=1)
25 | return all_slice
26 |
27 |
28 | def numpy_to_bytes(numpy_arr):
29 | cv_img = cv2.imencode('.png', numpy_arr)[1]
30 | img_bytes = bytes(bytearray(cv_img))
31 | return img_bytes
32 |
33 |
34 | def concat_frames(image_obj, need_frame=None):
35 | if not need_frame:
36 | need_frame = [0]
37 | img_arr = split_frames(image_obj, need_frame)
38 | img_arr = concat_arr(img_arr)
39 | return img_arr
40 |
41 |
42 | def blend_arr(img_arr):
43 | if len(img_arr) < 2:
44 | return img_arr[0]
45 | all_slice = img_arr[0]
46 | for im_slice in img_arr[1:]:
47 | all_slice = cv2.addWeighted(all_slice, 0.5, im_slice, 0.5, 0)
48 | # all_slice = cv2.equalizeHist(all_slice)
49 | return all_slice
50 |
51 |
52 | def blend_frame(image_obj, need_frame=None):
53 | if not need_frame:
54 | need_frame = [-1]
55 | img_arr = split_frames(image_obj, need_frame)
56 | img_arr = blend_arr(img_arr)
57 | if len(img_arr.shape) > 2 and img_arr.shape[2] == 3:
58 | img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2GRAY)
59 | img_arr = cv2.equalizeHist(img_arr)
60 | return img_arr
61 |
62 |
63 | def all_frames(image_obj):
64 | if isinstance(image_obj, list):
65 | image_obj = image_obj[0]
66 | stream = io.BytesIO(image_obj)
67 | pil_image = Image.open(stream)
68 | image_seq = ImageSequence.all_frames(pil_image)
69 | array_seq = [np.asarray(im.convert("RGB")) for im in image_seq]
70 | # [1::2]
71 | bytes_arr = [cv2.imencode('.png', img_arr)[1] for img_arr in array_seq]
72 | split_flag = b'\x99\x99\x99\x00\xff\xff999999.........99999\xff\x00\x99\x99\x99'
73 | return split_flag.join(bytes_arr).split(split_flag)
74 |
75 |
76 | def get_continuity_max(src: list):
77 | if not src:
78 | return ""
79 | elem_cont_len = lambda x: max(len(list(g)) for k, g in groupby(src) if k == x)
80 | target_list = [elem_cont_len(i) for i in src]
81 | target_index = target_list.index(max(target_list))
82 | return src[target_index]
83 |
84 |
85 | if __name__ == "__main__":
86 | pass
87 |
--------------------------------------------------------------------------------
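
A worked example of `get_continuity_max`, which returns the element with the longest consecutive run:

    from middleware.impl.gif_frames import get_continuity_max

    print(get_continuity_max([3, 1, 1, 2, 2, 2, 1]))  # -> 2 (run of three 2s wins)
    print(get_continuity_max([]))                     # -> "" for an empty list
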
/middleware/impl/rgb_filter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import numpy as np
5 | import cv2
6 |
7 |
8 | def rgb_filter(image_obj, need_rgb):
9 | low_rgb = np.array([i-15 for i in need_rgb])
10 | high_rgb = np.array([i+15 for i in need_rgb])
11 | mask = cv2.inRange(image_obj, lowerb=low_rgb, upperb=high_rgb)
12 | mask = cv2.bitwise_not(mask)
13 | mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
14 | # img_bytes = cv2.imencode('.png', mask)[1]
15 | return mask
16 |
17 |
18 | if __name__ == '__main__':
19 | pass
--------------------------------------------------------------------------------
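
A hedged usage sketch for `rgb_filter`: it keeps pixels within +/-15 of the requested RGB triple and returns an inverted three-channel mask. The file names below are placeholders.

    import cv2
    from middleware.impl.rgb_filter import rgb_filter

    # Placeholder input; load any captcha image and convert to RGB.
    image = cv2.cvtColor(cv2.imread("sample_captcha.png"), cv2.COLOR_BGR2RGB)
    # Keep only near-red pixels (tolerance of +/-15 per channel, as implemented above).
    mask = rgb_filter(image, need_rgb=[255, 0, 0])
    cv2.imwrite("sample_filtered.png", cv2.cvtColor(mask, cv2.COLOR_RGB2BGR))
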
/middleware/resource/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | from middleware.resource.color_filter import model_onnx
5 |
6 | color_model = model_onnx
7 | if __name__ == '__main__':
8 | pass
--------------------------------------------------------------------------------
/package.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import os
5 | import cv2
6 | import time
7 | import stat
8 | import socket
9 | import paramiko
10 | import platform
11 | import distutils
12 | import tensorflow as tf
13 | tf.compat.v1.disable_v2_behavior()
14 | from enum import Enum, unique
15 | from utils import SystemUtils
16 | from config import resource_path
17 |
18 | from PyInstaller.__main__ import run, logger
19 | """ Used to package as a single executable """
20 |
21 | if platform.system() == 'Linux':
22 | if distutils.distutils_path.endswith('__init__.py'):
23 | distutils.distutils_path = os.path.dirname(distutils.distutils_path)
24 |
25 | with open("./resource/VERSION", "w", encoding="utf8") as f:
26 | today = time.strftime("%Y%m%d", time.localtime(time.time()))
27 | f.write(today)
28 |
29 |
30 | @unique
31 | class Version(Enum):
32 | CPU = 'CPU'
33 | GPU = 'GPU'
34 |
35 |
36 | if __name__ == '__main__':
37 |
38 | ver = Version.CPU
39 |
40 | upload = False
41 | server_ip = ""
42 | username = ""
43 | password = ""
44 | model_dir = "model"
45 | graph_dir = "graph"
46 |
47 | if ver == Version.GPU:
48 | opts = ['tornado_server_gpu.spec', '--distpath=dist']
49 | else:
50 | opts = ['tornado_server.spec', '--distpath=dist']
51 | run(opts)
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | from config import ModelConfig
5 |
6 |
7 | def decode_maps(categories):
8 | return {index: category for index, category in enumerate(categories, 0)}
9 |
10 |
11 | def predict_func(image_batch, _sess, dense_decoded, op_input, model: ModelConfig, output_split=None):
12 |
13 | output_split = model.output_split if output_split is None else output_split
14 |
15 | category_split = model.category_split if model.category_split else ""
16 |
17 | dense_decoded_code = _sess.run(dense_decoded, feed_dict={
18 | op_input: image_batch,
19 | })
20 | decoded_expression = []
21 | for item in dense_decoded_code:
22 | expression = []
23 |
24 | for i in item:
25 | if i == -1 or i == model.category_num:
26 | expression.append("")
27 | else:
28 | expression.append(decode_maps(model.category)[i])
29 | decoded_expression.append(category_split.join(expression))
30 | return output_split.join(decoded_expression) if len(decoded_expression) > 1 else decoded_expression[0]
31 |
--------------------------------------------------------------------------------
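
A worked example of the decoding loop in `predict_func`: indices are mapped back to characters via `decode_maps`, while `-1` padding and the out-of-range `category_num` index are dropped. The category list and decoded indices below are made up for illustration.

    categories = ['a', 'b', 'c', '7']        # stands in for model.category
    category_num = len(categories)           # stands in for model.category_num
    decode_map = {i: c for i, c in enumerate(categories)}   # decode_maps(categories)

    dense_decoded_code = [[0, 2, 3, -1, -1]]  # one decoded sample, padded with -1
    decoded = []
    for item in dense_decoded_code:
        chars = [decode_map[i] for i in item if i != -1 and i != category_num]
        decoded.append(''.join(chars))        # category_split defaults to ""
    print(decoded[0])  # -> "ac7"
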
/pretreatment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import cv2
5 |
6 |
7 | class Pretreatment(object):
8 |
9 | def __init__(self, origin):
10 | self.origin = origin
11 |
12 | def get(self):
13 | return self.origin
14 |
15 | def binarization(self, value, modify=False):
16 | ret, _binarization = cv2.threshold(self.origin, value, 255, cv2.THRESH_BINARY)
17 | if modify:
18 | self.origin = _binarization
19 | return _binarization
20 |
21 |
22 | def preprocessing(image, binaryzation=-1):
23 | pretreatment = Pretreatment(image)
24 | if binaryzation > 0:
25 | pretreatment.binarization(binaryzation, True)
26 | return pretreatment.get()
27 |
28 |
29 | def preprocessing_by_func(exec_map, key, src_arr):
30 | if not exec_map:
31 | return src_arr
32 | target_arr = cv2.cvtColor(src_arr, cv2.COLOR_RGB2BGR)
33 | for sentence in exec_map.get(key):
34 | if sentence.startswith("@@"):
35 | target_arr = eval(sentence[2:])
36 | elif sentence.startswith("$$"):
37 | exec(sentence[2:])
38 | return cv2.cvtColor(target_arr, cv2.COLOR_BGR2RGB)
39 |
40 |
41 | if __name__ == '__main__':
42 | pass
43 |
--------------------------------------------------------------------------------
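
A sketch of the `exec_map` convention consumed by `preprocessing_by_func` above: each sentence prefixed with "@@" is evaluated and its result replaces `target_arr`, while "$$" sentences are executed for side effects only. The map and key below are made up for illustration.

    import numpy as np
    from pretreatment import preprocessing_by_func

    # Hypothetical preprocessing map: binarize, then invert (expressions see target_arr and cv2).
    exec_map = {
        "demo_model": [
            "@@cv2.threshold(target_arr, 127, 255, cv2.THRESH_BINARY)[1]",
            "@@cv2.bitwise_not(target_arr)",
        ]
    }

    src = np.random.randint(0, 255, (60, 160, 3), dtype=np.uint8)  # stand-in RGB image
    out = preprocessing_by_func(exec_map, "demo_model", src)
    print(out.shape)  # -> (60, 160, 3)
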
/requirements.txt:
--------------------------------------------------------------------------------
1 | flask
2 | gevent
3 | Flask-Caching
4 | gevent-websocket
5 | tf-nightly
6 | pillow
7 | opencv-python-headless
8 | numpy
9 | grpcio
10 | grpcio_tools
11 | requests
12 | pyyaml
13 | tornado
14 | watchdog
15 | pyinstaller
16 | sanic
17 | paramiko
18 | APScheduler
--------------------------------------------------------------------------------
/resource/VERSION:
--------------------------------------------------------------------------------
1 | 20211203
--------------------------------------------------------------------------------
/resource/favorite.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kerlomz/captcha_platform/f7d719bd1239a987996e266bd7fe35c96003b378/resource/favorite.ico
--------------------------------------------------------------------------------
/resource/icon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kerlomz/captcha_platform/f7d719bd1239a987996e266bd7fe35c96003b378/resource/icon.ico
--------------------------------------------------------------------------------
/sdk/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
--------------------------------------------------------------------------------
/sdk/onnx/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
--------------------------------------------------------------------------------
/sdk/onnx/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pillow
3 | opencv-python==3.4.5.20
4 | pyyaml>=3.13
5 | onnxruntime
--------------------------------------------------------------------------------
/sdk/onnx/sdk.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import io
5 | import os
6 | import pickle
7 | import cv2
8 | import time
9 | import yaml
10 | import binascii
11 | import numpy as np
12 | import PIL.Image as PIL_Image
13 | from enum import Enum, unique
14 | import onnxruntime as ort
15 |
16 | SPACE_TOKEN = ['']
17 | NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
18 | ALPHA_UPPER = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U',
19 | 'V', 'W', 'X', 'Y', 'Z']
20 | ALPHA_LOWER = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',
21 | 'v', 'w', 'x', 'y', 'z']
22 | ARITHMETIC = ['(', ')', '+', '-', '×', '÷', '=']
23 | CHINESE_3500 = [
24 | '一', '乙', '二', '十', '丁', '厂', '七', '卜', '人', '入', '八', '九', '几', '儿', '了', '力', '乃', '刀', '又', '三',
25 | '于', '干', '亏', '士', '工', '土', '才', '寸', '下', '大', '丈', '与', '万', '上', '小', '口', '巾', '山', '千', '乞',
26 | '川', '亿', '个', '勺', '久', '凡', '及', '夕', '丸', '么', '广', '亡', '门', '义', '之', '尸', '弓', '己', '已', '子',
27 | '卫', '也', '女', '飞', '刃', '习', '叉', '马', '乡', '丰', '王', '井', '开', '夫', '天', '无', '元', '专', '云', '扎',
28 | '艺', '木', '五', '支', '厅', '不', '太', '犬', '区', '历', '尤', '友', '匹', '车', '巨', '牙', '屯', '比', '互', '切',
29 | '瓦', '止', '少', '日', '中', '冈', '贝', '内', '水', '见', '午', '牛', '手', '毛', '气', '升', '长', '仁', '什', '片',
30 | '仆', '化', '仇', '币', '仍', '仅', '斤', '爪', '反', '介', '父', '从', '今', '凶', '分', '乏', '公', '仓', '月', '氏',
31 | '勿', '欠', '风', '丹', '匀', '乌', '凤', '勾', '文', '六', '方', '火', '为', '斗', '忆', '订', '计', '户', '认', '心',
32 | '尺', '引', '丑', '巴', '孔', '队', '办', '以', '允', '予', '劝', '双', '书', '幻', '玉', '刊', '示', '末', '未', '击',
33 | '打', '巧', '正', '扑', '扒', '功', '扔', '去', '甘', '世', '古', '节', '本', '术', '可', '丙', '左', '厉', '右', '石',
34 | '布', '龙', '平', '灭', '轧', '东', '卡', '北', '占', '业', '旧', '帅', '归', '且', '旦', '目', '叶', '甲', '申', '叮',
35 | '电', '号', '田', '由', '史', '只', '央', '兄', '叼', '叫', '另', '叨', '叹', '四', '生', '失', '禾', '丘', '付', '仗',
36 | '代', '仙', '们', '仪', '白', '仔', '他', '斥', '瓜', '乎', '丛', '令', '用', '甩', '印', '乐', '句', '匆', '册', '犯',
37 | '外', '处', '冬', '鸟', '务', '包', '饥', '主', '市', '立', '闪', '兰', '半', '汁', '汇', '头', '汉', '宁', '穴', '它',
38 | '讨', '写', '让', '礼', '训', '必', '议', '讯', '记', '永', '司', '尼', '民', '出', '辽', '奶', '奴', '加', '召', '皮',
39 | '边', '发', '孕', '圣', '对', '台', '矛', '纠', '母', '幼', '丝', '式', '刑', '动', '扛', '寺', '吉', '扣', '考', '托',
40 | '老', '执', '巩', '圾', '扩', '扫', '地', '扬', '场', '耳', '共', '芒', '亚', '芝', '朽', '朴', '机', '权', '过', '臣',
41 | '再', '协', '西', '压', '厌', '在', '有', '百', '存', '而', '页', '匠', '夸', '夺', '灰', '达', '列', '死', '成', '夹',
42 | '轨', '邪', '划', '迈', '毕', '至', '此', '贞', '师', '尘', '尖', '劣', '光', '当', '早', '吐', '吓', '虫', '曲', '团',
43 | '同', '吊', '吃', '因', '吸', '吗', '屿', '帆', '岁', '回', '岂', '刚', '则', '肉', '网', '年', '朱', '先', '丢', '舌',
44 | '竹', '迁', '乔', '伟', '传', '乒', '乓', '休', '伍', '伏', '优', '伐', '延', '件', '任', '伤', '价', '份', '华', '仰',
45 | '仿', '伙', '伪', '自', '血', '向', '似', '后', '行', '舟', '全', '会', '杀', '合', '兆', '企', '众', '爷', '伞', '创',
46 | '肌', '朵', '杂', '危', '旬', '旨', '负', '各', '名', '多', '争', '色', '壮', '冲', '冰', '庄', '庆', '亦', '刘', '齐',
47 | '交', '次', '衣', '产', '决', '充', '妄', '闭', '问', '闯', '羊', '并', '关', '米', '灯', '州', '汗', '污', '江', '池',
48 | '汤', '忙', '兴', '宇', '守', '宅', '字', '安', '讲', '军', '许', '论', '农', '讽', '设', '访', '寻', '那', '迅', '尽',
49 | '导', '异', '孙', '阵', '阳', '收', '阶', '阴', '防', '奸', '如', '妇', '好', '她', '妈', '戏', '羽', '观', '欢', '买',
50 | '红', '纤', '级', '约', '纪', '驰', '巡', '寿', '弄', '麦', '形', '进', '戒', '吞', '远', '违', '运', '扶', '抚', '坛',
51 | '技', '坏', '扰', '拒', '找', '批', '扯', '址', '走', '抄', '坝', '贡', '攻', '赤', '折', '抓', '扮', '抢', '孝', '均',
52 | '抛', '投', '坟', '抗', '坑', '坊', '抖', '护', '壳', '志', '扭', '块', '声', '把', '报', '却', '劫', '芽', '花', '芹',
53 | '芬', '苍', '芳', '严', '芦', '劳', '克', '苏', '杆', '杠', '杜', '材', '村', '杏', '极', '李', '杨', '求', '更', '束',
54 | '豆', '两', '丽', '医', '辰', '励', '否', '还', '歼', '来', '连', '步', '坚', '旱', '盯', '呈', '时', '吴', '助', '县',
55 | '里', '呆', '园', '旷', '围', '呀', '吨', '足', '邮', '男', '困', '吵', '串', '员', '听', '吩', '吹', '呜', '吧', '吼',
56 | '别', '岗', '帐', '财', '针', '钉', '告', '我', '乱', '利', '秃', '秀', '私', '每', '兵', '估', '体', '何', '但', '伸',
57 | '作', '伯', '伶', '佣', '低', '你', '住', '位', '伴', '身', '皂', '佛', '近', '彻', '役', '返', '余', '希', '坐', '谷',
58 | '妥', '含', '邻', '岔', '肝', '肚', '肠', '龟', '免', '狂', '犹', '角', '删', '条', '卵', '岛', '迎', '饭', '饮', '系',
59 | '言', '冻', '状', '亩', '况', '床', '库', '疗', '应', '冷', '这', '序', '辛', '弃', '冶', '忘', '闲', '间', '闷', '判',
60 | '灶', '灿', '弟', '汪', '沙', '汽', '沃', '泛', '沟', '没', '沈', '沉', '怀', '忧', '快', '完', '宋', '宏', '牢', '究',
61 | '穷', '灾', '良', '证', '启', '评', '补', '初', '社', '识', '诉', '诊', '词', '译', '君', '灵', '即', '层', '尿', '尾',
62 | '迟', '局', '改', '张', '忌', '际', '陆', '阿', '陈', '阻', '附', '妙', '妖', '妨', '努', '忍', '劲', '鸡', '驱', '纯',
63 | '纱', '纳', '纲', '驳', '纵', '纷', '纸', '纹', '纺', '驴', '纽', '奉', '玩', '环', '武', '青', '责', '现', '表', '规',
64 | '抹', '拢', '拔', '拣', '担', '坦', '押', '抽', '拐', '拖', '拍', '者', '顶', '拆', '拥', '抵', '拘', '势', '抱', '垃',
65 | '拉', '拦', '拌', '幸', '招', '坡', '披', '拨', '择', '抬', '其', '取', '苦', '若', '茂', '苹', '苗', '英', '范', '直',
66 | '茄', '茎', '茅', '林', '枝', '杯', '柜', '析', '板', '松', '枪', '构', '杰', '述', '枕', '丧', '或', '画', '卧', '事',
67 | '刺', '枣', '雨', '卖', '矿', '码', '厕', '奔', '奇', '奋', '态', '欧', '垄', '妻', '轰', '顷', '转', '斩', '轮', '软',
68 | '到', '非', '叔', '肯', '齿', '些', '虎', '虏', '肾', '贤', '尚', '旺', '具', '果', '味', '昆', '国', '昌', '畅', '明',
69 | '易', '昂', '典', '固', '忠', '咐', '呼', '鸣', '咏', '呢', '岸', '岩', '帖', '罗', '帜', '岭', '凯', '败', '贩', '购',
70 | '图', '钓', '制', '知', '垂', '牧', '物', '乖', '刮', '秆', '和', '季', '委', '佳', '侍', '供', '使', '例', '版', '侄',
71 | '侦', '侧', '凭', '侨', '佩', '货', '依', '的', '迫', '质', '欣', '征', '往', '爬', '彼', '径', '所', '舍', '金', '命',
72 | '斧', '爸', '采', '受', '乳', '贪', '念', '贫', '肤', '肺', '肢', '肿', '胀', '朋', '股', '肥', '服', '胁', '周', '昏',
73 | '鱼', '兔', '狐', '忽', '狗', '备', '饰', '饱', '饲', '变', '京', '享', '店', '夜', '庙', '府', '底', '剂', '郊', '废',
74 | '净', '盲', '放', '刻', '育', '闸', '闹', '郑', '券', '卷', '单', '炒', '炊', '炕', '炎', '炉', '沫', '浅', '法', '泄',
75 | '河', '沾', '泪', '油', '泊', '沿', '泡', '注', '泻', '泳', '泥', '沸', '波', '泼', '泽', '治', '怖', '性', '怕', '怜',
76 | '怪', '学', '宝', '宗', '定', '宜', '审', '宙', '官', '空', '帘', '实', '试', '郎', '诗', '肩', '房', '诚', '衬', '衫',
77 | '视', '话', '诞', '询', '该', '详', '建', '肃', '录', '隶', '居', '届', '刷', '屈', '弦', '承', '孟', '孤', '陕', '降',
78 | '限', '妹', '姑', '姐', '姓', '始', '驾', '参', '艰', '线', '练', '组', '细', '驶', '织', '终', '驻', '驼', '绍', '经',
79 | '贯', '奏', '春', '帮', '珍', '玻', '毒', '型', '挂', '封', '持', '项', '垮', '挎', '城', '挠', '政', '赴', '赵', '挡',
80 | '挺', '括', '拴', '拾', '挑', '指', '垫', '挣', '挤', '拼', '挖', '按', '挥', '挪', '某', '甚', '革', '荐', '巷', '带',
81 | '草', '茧', '茶', '荒', '茫', '荡', '荣', '故', '胡', '南', '药', '标', '枯', '柄', '栋', '相', '查', '柏', '柳', '柱',
82 | '柿', '栏', '树', '要', '咸', '威', '歪', '研', '砖', '厘', '厚', '砌', '砍', '面', '耐', '耍', '牵', '残', '殃', '轻',
83 | '鸦', '皆', '背', '战', '点', '临', '览', '竖', '省', '削', '尝', '是', '盼', '眨', '哄', '显', '哑', '冒', '映', '星',
84 | '昨', '畏', '趴', '胃', '贵', '界', '虹', '虾', '蚁', '思', '蚂', '虽', '品', '咽', '骂', '哗', '咱', '响', '哈', '咬',
85 | '咳', '哪', '炭', '峡', '罚', '贱', '贴', '骨', '钞', '钟', '钢', '钥', '钩', '卸', '缸', '拜', '看', '矩', '怎', '牲',
86 | '选', '适', '秒', '香', '种', '秋', '科', '重', '复', '竿', '段', '便', '俩', '贷', '顺', '修', '保', '促', '侮', '俭',
87 | '俗', '俘', '信', '皇', '泉', '鬼', '侵', '追', '俊', '盾', '待', '律', '很', '须', '叙', '剑', '逃', '食', '盆', '胆',
88 | '胜', '胞', '胖', '脉', '勉', '狭', '狮', '独', '狡', '狱', '狠', '贸', '怨', '急', '饶', '蚀', '饺', '饼', '弯', '将',
89 | '奖', '哀', '亭', '亮', '度', '迹', '庭', '疮', '疯', '疫', '疤', '姿', '亲', '音', '帝', '施', '闻', '阀', '阁', '差',
90 | '养', '美', '姜', '叛', '送', '类', '迷', '前', '首', '逆', '总', '炼', '炸', '炮', '烂', '剃', '洁', '洪', '洒', '浇',
91 | '浊', '洞', '测', '洗', '活', '派', '洽', '染', '济', '洋', '洲', '浑', '浓', '津', '恒', '恢', '恰', '恼', '恨', '举',
92 | '觉', '宣', '室', '宫', '宪', '突', '穿', '窃', '客', '冠', '语', '扁', '袄', '祖', '神', '祝', '误', '诱', '说', '诵',
93 | '垦', '退', '既', '屋', '昼', '费', '陡', '眉', '孩', '除', '险', '院', '娃', '姥', '姨', '姻', '娇', '怒', '架', '贺',
94 | '盈', '勇', '怠', '柔', '垒', '绑', '绒', '结', '绕', '骄', '绘', '给', '络', '骆', '绝', '绞', '统', '耕', '耗', '艳',
95 | '泰', '珠', '班', '素', '蚕', '顽', '盏', '匪', '捞', '栽', '捕', '振', '载', '赶', '起', '盐', '捎', '捏', '埋', '捉',
96 | '捆', '捐', '损', '都', '哲', '逝', '捡', '换', '挽', '热', '恐', '壶', '挨', '耻', '耽', '恭', '莲', '莫', '荷', '获',
97 | '晋', '恶', '真', '框', '桂', '档', '桐', '株', '桥', '桃', '格', '校', '核', '样', '根', '索', '哥', '速', '逗', '栗',
98 | '配', '翅', '辱', '唇', '夏', '础', '破', '原', '套', '逐', '烈', '殊', '顾', '轿', '较', '顿', '毙', '致', '柴', '桌',
99 | '虑', '监', '紧', '党', '晒', '眠', '晓', '鸭', '晃', '晌', '晕', '蚊', '哨', '哭', '恩', '唤', '啊', '唉', '罢', '峰',
100 | '圆', '贼', '贿', '钱', '钳', '钻', '铁', '铃', '铅', '缺', '氧', '特', '牺', '造', '乘', '敌', '秤', '租', '积', '秧',
101 | '秩', '称', '秘', '透', '笔', '笑', '笋', '债', '借', '值', '倚', '倾', '倒', '倘', '俱', '倡', '候', '俯', '倍', '倦',
102 | '健', '臭', '射', '躬', '息', '徒', '徐', '舰', '舱', '般', '航', '途', '拿', '爹', '爱', '颂', '翁', '脆', '脂', '胸',
103 | '胳', '脏', '胶', '脑', '狸', '狼', '逢', '留', '皱', '饿', '恋', '桨', '浆', '衰', '高', '席', '准', '座', '脊', '症',
104 | '病', '疾', '疼', '疲', '效', '离', '唐', '资', '凉', '站', '剖', '竞', '部', '旁', '旅', '畜', '阅', '羞', '瓶', '拳',
105 | '粉', '料', '益', '兼', '烤', '烘', '烦', '烧', '烛', '烟', '递', '涛', '浙', '涝', '酒', '涉', '消', '浩', '海', '涂',
106 | '浴', '浮', '流', '润', '浪', '浸', '涨', '烫', '涌', '悟', '悄', '悔', '悦', '害', '宽', '家', '宵', '宴', '宾', '窄',
107 | '容', '宰', '案', '请', '朗', '诸', '读', '扇', '袜', '袖', '袍', '被', '祥', '课', '谁', '调', '冤', '谅', '谈', '谊',
108 | '剥', '恳', '展', '剧', '屑', '弱', '陵', '陶', '陷', '陪', '娱', '娘', '通', '能', '难', '预', '桑', '绢', '绣', '验',
109 | '继', '球', '理', '捧', '堵', '描', '域', '掩', '捷', '排', '掉', '堆', '推', '掀', '授', '教', '掏', '掠', '培', '接',
110 | '控', '探', '据', '掘', '职', '基', '著', '勒', '黄', '萌', '萝', '菌', '菜', '萄', '菊', '萍', '菠', '营', '械', '梦',
111 | '梢', '梅', '检', '梳', '梯', '桶', '救', '副', '票', '戚', '爽', '聋', '袭', '盛', '雪', '辅', '辆', '虚', '雀', '堂',
112 | '常', '匙', '晨', '睁', '眯', '眼', '悬', '野', '啦', '晚', '啄', '距', '跃', '略', '蛇', '累', '唱', '患', '唯', '崖',
113 | '崭', '崇', '圈', '铜', '铲', '银', '甜', '梨', '犁', '移', '笨', '笼', '笛', '符', '第', '敏', '做', '袋', '悠', '偿',
114 | '偶', '偷', '您', '售', '停', '偏', '假', '得', '衔', '盘', '船', '斜', '盒', '鸽', '悉', '欲', '彩', '领', '脚', '脖',
115 | '脸', '脱', '象', '够', '猜', '猪', '猎', '猫', '猛', '馅', '馆', '凑', '减', '毫', '麻', '痒', '痕', '廊', '康', '庸',
116 | '鹿', '盗', '章', '竟', '商', '族', '旋', '望', '率', '着', '盖', '粘', '粗', '粒', '断', '剪', '兽', '清', '添', '淋',
117 | '淹', '渠', '渐', '混', '渔', '淘', '液', '淡', '深', '婆', '梁', '渗', '情', '惜', '惭', '悼', '惧', '惕', '惊', '惨',
118 | '惯', '寇', '寄', '宿', '窑', '密', '谋', '谎', '祸', '谜', '逮', '敢', '屠', '弹', '随', '蛋', '隆', '隐', '婚', '婶',
119 | '颈', '绩', '绪', '续', '骑', '绳', '维', '绵', '绸', '绿', '琴', '斑', '替', '款', '堪', '搭', '塔', '越', '趁', '趋',
120 | '超', '提', '堤', '博', '揭', '喜', '插', '揪', '搜', '煮', '援', '裁', '搁', '搂', '搅', '握', '揉', '斯', '期', '欺',
121 | '联', '散', '惹', '葬', '葛', '董', '葡', '敬', '葱', '落', '朝', '辜', '葵', '棒', '棋', '植', '森', '椅', '椒', '棵',
122 | '棍', '棉', '棚', '棕', '惠', '惑', '逼', '厨', '厦', '硬', '确', '雁', '殖', '裂', '雄', '暂', '雅', '辈', '悲', '紫',
123 | '辉', '敞', '赏', '掌', '晴', '暑', '最', '量', '喷', '晶', '喇', '遇', '喊', '景', '践', '跌', '跑', '遗', '蛙', '蛛',
124 | '蜓', '喝', '喂', '喘', '喉', '幅', '帽', '赌', '赔', '黑', '铸', '铺', '链', '销', '锁', '锄', '锅', '锈', '锋', '锐',
125 | '短', '智', '毯', '鹅', '剩', '稍', '程', '稀', '税', '筐', '等', '筑', '策', '筛', '筒', '答', '筋', '筝', '傲', '傅',
126 | '牌', '堡', '集', '焦', '傍', '储', '奥', '街', '惩', '御', '循', '艇', '舒', '番', '释', '禽', '腊', '脾', '腔', '鲁',
127 | '猾', '猴', '然', '馋', '装', '蛮', '就', '痛', '童', '阔', '善', '羡', '普', '粪', '尊', '道', '曾', '焰', '港', '湖',
128 | '渣', '湿', '温', '渴', '滑', '湾', '渡', '游', '滋', '溉', '愤', '慌', '惰', '愧', '愉', '慨', '割', '寒', '富', '窜',
129 | '窝', '窗', '遍', '裕', '裤', '裙', '谢', '谣', '谦', '属', '屡', '强', '粥', '疏', '隔', '隙', '絮', '嫂', '登', '缎',
130 | '缓', '编', '骗', '缘', '瑞', '魂', '肆', '摄', '摸', '填', '搏', '塌', '鼓', '摆', '携', '搬', '摇', '搞', '塘', '摊',
131 | '蒜', '勤', '鹊', '蓝', '墓', '幕', '蓬', '蓄', '蒙', '蒸', '献', '禁', '楚', '想', '槐', '榆', '楼', '概', '赖', '酬',
132 | '感', '碍', '碑', '碎', '碰', '碗', '碌', '雷', '零', '雾', '雹', '输', '督', '龄', '鉴', '睛', '睡', '睬', '鄙', '愚',
133 | '暖', '盟', '歇', '暗', '照', '跨', '跳', '跪', '路', '跟', '遣', '蛾', '蜂', '嗓', '置', '罪', '罩', '错', '锡', '锣',
134 | '锤', '锦', '键', '锯', '矮', '辞', '稠', '愁', '筹', '签', '简', '毁', '舅', '鼠', '催', '傻', '像', '躲', '微', '愈',
135 | '遥', '腰', '腥', '腹', '腾', '腿', '触', '解', '酱', '痰', '廉', '新', '韵', '意', '粮', '数', '煎', '塑', '慈', '煤',
136 | '煌', '满', '漠', '源', '滤', '滥', '滔', '溪', '溜', '滚', '滨', '粱', '滩', '慎', '誉', '塞', '谨', '福', '群', '殿',
137 | '辟', '障', '嫌', '嫁', '叠', '缝', '缠', '静', '碧', '璃', '墙', '撇', '嘉', '摧', '截', '誓', '境', '摘', '摔', '聚',
138 | '蔽', '慕', '暮', '蔑', '模', '榴', '榜', '榨', '歌', '遭', '酷', '酿', '酸', '磁', '愿', '需', '弊', '裳', '颗', '嗽',
139 | '蜻', '蜡', '蝇', '蜘', '赚', '锹', '锻', '舞', '稳', '算', '箩', '管', '僚', '鼻', '魄', '貌', '膜', '膊', '膀', '鲜',
140 | '疑', '馒', '裹', '敲', '豪', '膏', '遮', '腐', '瘦', '辣', '竭', '端', '旗', '精', '歉', '熄', '熔', '漆', '漂', '漫',
141 | '滴', '演', '漏', '慢', '寨', '赛', '察', '蜜', '谱', '嫩', '翠', '熊', '凳', '骡', '缩', '慧', '撕', '撒', '趣', '趟',
142 | '撑', '播', '撞', '撤', '增', '聪', '鞋', '蕉', '蔬', '横', '槽', '樱', '橡', '飘', '醋', '醉', '震', '霉', '瞒', '题',
143 | '暴', '瞎', '影', '踢', '踏', '踩', '踪', '蝶', '蝴', '嘱', '墨', '镇', '靠', '稻', '黎', '稿', '稼', '箱', '箭', '篇',
144 | '僵', '躺', '僻', '德', '艘', '膝', '膛', '熟', '摩', '颜', '毅', '糊', '遵', '潜', '潮', '懂', '额', '慰', '劈', '操',
145 | '燕', '薯', '薪', '薄', '颠', '橘', '整', '融', '醒', '餐', '嘴', '蹄', '器', '赠', '默', '镜', '赞', '篮', '邀', '衡',
146 | '膨', '雕', '磨', '凝', '辨', '辩', '糖', '糕', '燃', '澡', '激', '懒', '壁', '避', '缴', '戴', '擦', '鞠', '藏', '霜',
147 | '霞', '瞧', '蹈', '螺', '穗', '繁', '辫', '赢', '糟', '糠', '燥', '臂', '翼', '骤', '鞭', '覆', '蹦', '镰', '翻', '鹰',
148 | '警', '攀', '蹲', '颤', '瓣', '爆', '疆', '壤', '耀', '躁', '嚼', '嚷', '籍', '魔', '灌', '蠢', '霸', '露', '囊', '罐',
149 | '匕', '刁', '丐', '歹', '戈', '夭', '仑', '讥', '冗', '邓', '艾', '夯', '凸', '卢', '叭', '叽', '皿', '凹', '囚', '矢',
150 | '乍', '尔', '冯', '玄', '邦', '迂', '邢', '芋', '芍', '吏', '夷', '吁', '吕', '吆', '屹', '廷', '迄', '臼', '仲', '伦',
151 | '伊', '肋', '旭', '匈', '凫', '妆', '亥', '汛', '讳', '讶', '讹', '讼', '诀', '弛', '阱', '驮', '驯', '纫', '玖', '玛',
152 | '韧', '抠', '扼', '汞', '扳', '抡', '坎', '坞', '抑', '拟', '抒', '芙', '芜', '苇', '芥', '芯', '芭', '杖', '杉', '巫',
153 | '杈', '甫', '匣', '轩', '卤', '肖', '吱', '吠', '呕', '呐', '吟', '呛', '吻', '吭', '邑', '囤', '吮', '岖', '牡', '佑',
154 | '佃', '伺', '囱', '肛', '肘', '甸', '狈', '鸠', '彤', '灸', '刨', '庇', '吝', '庐', '闰', '兑', '灼', '沐', '沛', '汰',
155 | '沥', '沦', '汹', '沧', '沪', '忱', '诅', '诈', '罕', '屁', '坠', '妓', '姊', '妒', '纬', '玫', '卦', '坷', '坯', '拓',
156 | '坪', '坤', '拄', '拧', '拂', '拙', '拇', '拗', '茉', '昔', '苛', '苫', '苟', '苞', '茁', '苔', '枉', '枢', '枚', '枫',
157 | '杭', '郁', '矾', '奈', '奄', '殴', '歧', '卓', '昙', '哎', '咕', '呵', '咙', '呻', '咒', '咆', '咖', '帕', '账', '贬',
158 | '贮', '氛', '秉', '岳', '侠', '侥', '侣', '侈', '卑', '刽', '刹', '肴', '觅', '忿', '瓮', '肮', '肪', '狞', '庞', '疟',
159 | '疙', '疚', '卒', '氓', '炬', '沽', '沮', '泣', '泞', '泌', '沼', '怔', '怯', '宠', '宛', '衩', '祈', '诡', '帚', '屉',
160 | '弧', '弥', '陋', '陌', '函', '姆', '虱', '叁', '绅', '驹', '绊', '绎', '契', '贰', '玷', '玲', '珊', '拭', '拷', '拱',
161 | '挟', '垢', '垛', '拯', '荆', '茸', '茬', '荚', '茵', '茴', '荞', '荠', '荤', '荧', '荔', '栈', '柑', '栅', '柠', '枷',
162 | '勃', '柬', '砂', '泵', '砚', '鸥', '轴', '韭', '虐', '昧', '盹', '咧', '昵', '昭', '盅', '勋', '哆', '咪', '哟', '幽',
163 | '钙', '钝', '钠', '钦', '钧', '钮', '毡', '氢', '秕', '俏', '俄', '俐', '侯', '徊', '衍', '胚', '胧', '胎', '狰', '饵',
164 | '峦', '奕', '咨', '飒', '闺', '闽', '籽', '娄', '烁', '炫', '洼', '柒', '涎', '洛', '恃', '恍', '恬', '恤', '宦', '诫',
165 | '诬', '祠', '诲', '屏', '屎', '逊', '陨', '姚', '娜', '蚤', '骇', '耘', '耙', '秦', '匿', '埂', '捂', '捍', '袁', '捌',
166 | '挫', '挚', '捣', '捅', '埃', '耿', '聂', '荸', '莽', '莱', '莉', '莹', '莺', '梆', '栖', '桦', '栓', '桅', '桩', '贾',
167 | '酌', '砸', '砰', '砾', '殉', '逞', '哮', '唠', '哺', '剔', '蚌', '蚜', '畔', '蚣', '蚪', '蚓', '哩', '圃', '鸯', '唁',
168 | '哼', '唆', '峭', '唧', '峻', '赂', '赃', '钾', '铆', '氨', '秫', '笆', '俺', '赁', '倔', '殷', '耸', '舀', '豺', '豹',
169 | '颁', '胯', '胰', '脐', '脓', '逛', '卿', '鸵', '鸳', '馁', '凌', '凄', '衷', '郭', '斋', '疹', '紊', '瓷', '羔', '烙',
170 | '浦', '涡', '涣', '涤', '涧', '涕', '涩', '悍', '悯', '窍', '诺', '诽', '袒', '谆', '祟', '恕', '娩', '骏', '琐', '麸',
171 | '琉', '琅', '措', '捺', '捶', '赦', '埠', '捻', '掐', '掂', '掖', '掷', '掸', '掺', '勘', '聊', '娶', '菱', '菲', '萎',
172 | '菩', '萤', '乾', '萧', '萨', '菇', '彬', '梗', '梧', '梭', '曹', '酝', '酗', '厢', '硅', '硕', '奢', '盔', '匾', '颅',
173 | '彪', '眶', '晤', '曼', '晦', '冕', '啡', '畦', '趾', '啃', '蛆', '蚯', '蛉', '蛀', '唬', '唾', '啤', '啥', '啸', '崎',
174 | '逻', '崔', '崩', '婴', '赊', '铐', '铛', '铝', '铡', '铣', '铭', '矫', '秸', '秽', '笙', '笤', '偎', '傀', '躯', '兜',
175 | '衅', '徘', '徙', '舶', '舷', '舵', '敛', '翎', '脯', '逸', '凰', '猖', '祭', '烹', '庶', '庵', '痊', '阎', '阐', '眷',
176 | '焊', '焕', '鸿', '涯', '淑', '淌', '淮', '淆', '渊', '淫', '淳', '淤', '淀', '涮', '涵', '惦', '悴', '惋', '寂', '窒',
177 | '谍', '谐', '裆', '袱', '祷', '谒', '谓', '谚', '尉', '堕', '隅', '婉', '颇', '绰', '绷', '综', '绽', '缀', '巢', '琳',
178 | '琢', '琼', '揍', '堰', '揩', '揽', '揖', '彭', '揣', '搀', '搓', '壹', '搔', '葫', '募', '蒋', '蒂', '韩', '棱', '椰',
179 | '焚', '椎', '棺', '榔', '椭', '粟', '棘', '酣', '酥', '硝', '硫', '颊', '雳', '翘', '凿', '棠', '晰', '鼎', '喳', '遏',
180 | '晾', '畴', '跋', '跛', '蛔', '蜒', '蛤', '鹃', '喻', '啼', '喧', '嵌', '赋', '赎', '赐', '锉', '锌', '甥', '掰', '氮',
181 | '氯', '黍', '筏', '牍', '粤', '逾', '腌', '腋', '腕', '猩', '猬', '惫', '敦', '痘', '痢', '痪', '竣', '翔', '奠', '遂',
182 | '焙', '滞', '湘', '渤', '渺', '溃', '溅', '湃', '愕', '惶', '寓', '窖', '窘', '雇', '谤', '犀', '隘', '媒', '媚', '婿',
183 | '缅', '缆', '缔', '缕', '骚', '瑟', '鹉', '瑰', '搪', '聘', '斟', '靴', '靶', '蓖', '蒿', '蒲', '蓉', '楔', '椿', '楷',
184 | '榄', '楞', '楣', '酪', '碘', '硼', '碉', '辐', '辑', '频', '睹', '睦', '瞄', '嗜', '嗦', '暇', '畸', '跷', '跺', '蜈',
185 | '蜗', '蜕', '蛹', '嗅', '嗡', '嗤', '署', '蜀', '幌', '锚', '锥', '锨', '锭', '锰', '稚', '颓', '筷', '魁', '衙', '腻',
186 | '腮', '腺', '鹏', '肄', '猿', '颖', '煞', '雏', '馍', '馏', '禀', '痹', '廓', '痴', '靖', '誊', '漓', '溢', '溯', '溶',
187 | '滓', '溺', '寞', '窥', '窟', '寝', '褂', '裸', '谬', '媳', '嫉', '缚', '缤', '剿', '赘', '熬', '赫', '蔫', '摹', '蔓',
188 | '蔗', '蔼', '熙', '蔚', '兢', '榛', '榕', '酵', '碟', '碴', '碱', '碳', '辕', '辖', '雌', '墅', '嘁', '踊', '蝉', '嘀',
189 | '幔', '镀', '舔', '熏', '箍', '箕', '箫', '舆', '僧', '孵', '瘩', '瘟', '彰', '粹', '漱', '漩', '漾', '慷', '寡', '寥',
190 | '谭', '褐', '褪', '隧', '嫡', '缨', '撵', '撩', '撮', '撬', '擒', '墩', '撰', '鞍', '蕊', '蕴', '樊', '樟', '橄', '敷',
191 | '豌', '醇', '磕', '磅', '碾', '憋', '嘶', '嘲', '嘹', '蝠', '蝎', '蝌', '蝗', '蝙', '嘿', '幢', '镊', '镐', '稽', '篓',
192 | '膘', '鲤', '鲫', '褒', '瘪', '瘤', '瘫', '凛', '澎', '潭', '潦', '澳', '潘', '澈', '澜', '澄', '憔', '懊', '憎', '翩',
193 | '褥', '谴', '鹤', '憨', '履', '嬉', '豫', '缭', '撼', '擂', '擅', '蕾', '薛', '薇', '擎', '翰', '噩', '橱', '橙', '瓢',
194 | '蟥', '霍', '霎', '辙', '冀', '踱', '蹂', '蟆', '螃', '螟', '噪', '鹦', '黔', '穆', '篡', '篷', '篙', '篱', '儒', '膳',
195 | '鲸', '瘾', '瘸', '糙', '燎', '濒', '憾', '懈', '窿', '缰', '壕', '藐', '檬', '檐', '檩', '檀', '礁', '磷', '了', '瞬',
196 | '瞳', '瞪', '曙', '蹋', '蟋', '蟀', '嚎', '赡', '镣', '魏', '簇', '儡', '徽', '爵', '朦', '臊', '鳄', '糜', '癌', '懦',
197 | '豁', '臀', '藕', '藤', '瞻', '嚣', '鳍', '癞', '瀑', '襟', '璧', '戳', '攒', '孽', '蘑', '藻', '鳖', '蹭', '蹬', '簸',
198 | '簿', '蟹', '靡', '癣', '羹', '鬓', '攘', '蠕', '巍', '鳞', '糯', '譬', '霹', '躏', '髓', '蘸', '镶', '瓤', '矗', '圳',
199 | '珏', '蕙', '旻', '涅', '攸', '嘛', '醪', '缪', '噗', '瞨', '靳', '帷', '徨',
200 | ]
201 |
202 | FLOAT = ['.']
203 |
204 | SIMPLE_CATEGORY_MODEL = dict(
205 | NUMERIC=NUMBER,
206 | ALPHANUMERIC=NUMBER + ALPHA_LOWER + ALPHA_UPPER,
207 | ALPHANUMERIC_LOWER=NUMBER + ALPHA_LOWER,
208 | ALPHANUMERIC_UPPER=NUMBER + ALPHA_UPPER,
209 | ALPHABET_LOWER=ALPHA_LOWER,
210 | ALPHABET_UPPER=ALPHA_UPPER,
211 | ALPHABET=ALPHA_LOWER + ALPHA_UPPER,
212 | ARITHMETIC=NUMBER + ARITHMETIC,
213 | FLOAT=NUMBER + FLOAT,
214 | CHS_3500=CHINESE_3500,
215 | ALPHANUMERIC_MIX_CHS_3500_LOWER=NUMBER + ALPHA_LOWER + CHINESE_3500
216 | )
217 |
218 |
219 | def encode_maps(source):
220 | return {category: i for i, category in enumerate(source, 0)}
221 |
222 |
223 | @unique
224 | class ModelScene(Enum):
225 |     """Model scene enumeration"""
226 | Classification = 'Classification'
227 |
228 |
229 | @unique
230 | class ModelField(Enum):
231 |     """Model field enumeration"""
232 | Image = 'Image'
233 | Text = 'Text'
234 |
235 |
236 | MODEL_SCENE_MAP = {
237 | 'Classification': ModelScene.Classification
238 | }
239 |
240 | MODEL_FIELD_MAP = {
241 | 'Image': ModelField.Image,
242 | 'Text': ModelField.Text
243 | }
244 |
245 |
246 | class ModelConfig(object):
247 |
248 | @staticmethod
249 | def category_extract(param):
250 | if isinstance(param, list):
251 | return param
252 | if isinstance(param, str):
253 | if param in SIMPLE_CATEGORY_MODEL.keys():
254 | return SIMPLE_CATEGORY_MODEL.get(param)
255 | raise ValueError(
256 | "Category set configuration error, customized category set should be list type"
257 | )
258 |
259 | @property
260 | def model_conf(self) -> dict:
261 | if self.model_content:
262 | return self.model_content
263 | with open(self.model_conf_path, 'r', encoding="utf-8") as sys_fp:
264 | sys_stream = sys_fp.read()
265 | return yaml.load(sys_stream, Loader=yaml.SafeLoader)
266 |
267 | def __init__(self, model_conf_path=None, model_content=None):
268 | self.model_content = model_content
269 | self.model_path = model_conf_path
270 | self.graph_path = os.path.dirname(self.model_path) if model_conf_path else ""
271 | self.model_conf_path = model_conf_path
272 | self.model_conf_demo = 'model_demo.yaml'
273 |
274 | """MODEL"""
275 | self.model_root: dict = self.model_conf['Model']
276 | self.model_name: str = self.model_root.get('ModelName')
277 | self.model_version: float = self.model_root.get('Version')
278 | self.model_version = self.model_version if self.model_version else 1.0
279 | self.model_field_param: str = self.model_root.get('ModelField')
280 | self.model_field: ModelField = self.param_convert(
281 | source=self.model_field_param,
282 | param_map=MODEL_FIELD_MAP,
283 | text="Current model field ({model_field}) is not supported".format(model_field=self.model_field_param),
284 | code=50002
285 | )
286 |
287 | self.model_scene_param: str = self.model_root.get('ModelScene')
288 |
289 | self.model_scene: ModelScene = self.param_convert(
290 | source=self.model_scene_param,
291 | param_map=MODEL_SCENE_MAP,
292 | text="Current model scene ({model_scene}) is not supported".format(model_scene=self.model_scene_param),
293 | code=50001
294 | )
295 |
296 | """SYSTEM"""
297 | self.checkpoint_tag = 'checkpoint'
298 | self.system_root: dict = self.model_conf['System']
299 | self.memory_usage: float = self.system_root.get('MemoryUsage')
300 |
301 | """FIELD PARAM - IMAGE"""
302 | self.field_root: dict = self.model_conf['FieldParam']
303 | self.category_param = self.field_root.get('Category')
304 | self.category_value = self.category_extract(self.category_param)
305 | if self.category_value is None:
306 | raise Exception(
307 | "The category set type does not exist, there is no category set named {}".format(self.category_param),
308 | )
309 | self.category: list = SPACE_TOKEN + self.category_value
310 | self.category_num: int = len(self.category)
311 | self.image_channel: int = self.field_root.get('ImageChannel')
312 | self.image_width: int = self.field_root.get('ImageWidth')
313 | self.image_height: int = self.field_root.get('ImageHeight')
314 | self.resize: list = self.field_root.get('Resize')
315 | self.output_split = self.field_root.get('OutputSplit')
316 | self.output_split = self.output_split if self.output_split else ""
317 | self.category_split = self.field_root.get('CategorySplit')
318 | self.corp_params = self.field_root.get('CorpParams')
319 | self.output_coord = self.field_root.get('OutputCoord')
320 | self.batch_model = self.field_root.get('BatchModel')
321 |
322 | """PRETREATMENT"""
323 | self.pretreatment_root = self.model_conf.get('Pretreatment')
324 | self.pre_binaryzation = self.get_var(self.pretreatment_root, 'Binaryzation', -1)
325 | self.pre_replace_transparent = self.get_var(self.pretreatment_root, 'ReplaceTransparent', True)
326 | self.pre_horizontal_stitching = self.get_var(self.pretreatment_root, 'HorizontalStitching', False)
327 | self.pre_concat_frames = self.get_var(self.pretreatment_root, 'ConcatFrames', -1)
328 | self.pre_blend_frames = self.get_var(self.pretreatment_root, 'BlendFrames', -1)
329 | self.exec_map = self.pretreatment_root.get('ExecuteMap')
330 | """COMPILE_MODEL"""
331 | if self.graph_path:
332 | self.compile_model_path = os.path.join(self.graph_path, '{}.onnx'.format(self.model_name))
333 | if not os.path.exists(self.compile_model_path):
334 | if not os.path.exists(self.graph_path):
335 | os.makedirs(self.graph_path)
336 | raise ValueError(
337 | '{} not found, please put the trained model in the current directory.'.format(self.compile_model_path)
338 | )
339 | else:
340 | self.model_exists = True
341 | else:
342 | self.model_exists = True if self.model_content else False
343 | self.compile_model_path = ""
344 |
345 | @staticmethod
346 | def param_convert(source, param_map: dict, text, code, default=None):
347 | if source is None:
348 | return default
349 | if source not in param_map.keys():
350 | raise Exception(text)
351 | return param_map[source]
352 |
353 | def size_match(self, size_str):
354 | return size_str == self.size_string
355 |
356 | @staticmethod
357 | def get_var(src: dict, name: str, default=None):
358 | if not src:
359 | return default
360 | return src.get(name)
361 |
362 | @property
363 | def size_string(self):
364 | return "{}x{}".format(self.image_width, self.image_height)
365 |
366 |
367 | class Model(object):
368 | model_conf: ModelConfig
369 | graph_bytes: object = None
370 |
371 | def __init__(self, conf_path: str, source_bytes: bytes = None, key=None):
372 | if conf_path:
373 | self.model_conf = ModelConfig(model_conf_path=conf_path)
374 | # self.graph_bytes = self.model_conf.compile_model_path
375 | if source_bytes:
376 | model_conf, self.graph_bytes = self.parse_model(source_bytes, key)
377 | self.model_conf = ModelConfig(model_content=model_conf)
378 |
379 | @staticmethod
380 | def parse_model(source_bytes: bytes, key=None):
381 | split_tag = b'-#||#-'
382 |
383 | if not key:
384 | key = [b"_____" + i.encode("utf8") + b"_____" for i in "&coriander"]
385 | if isinstance(key, str):
386 | key = [b"_____" + i.encode("utf8") + b"_____" for i in key]
387 | key_len_int = len(key)
388 | model_bytes_list = []
389 | graph_bytes_list = []
390 | slice_index = source_bytes.index(key[0])
391 | split_tag_len = len(split_tag)
392 | slice_0 = source_bytes[0: slice_index].split(split_tag)
393 | model_slice_len = len(slice_0[1])
394 | graph_slice_len = len(slice_0[0])
395 | slice_len = split_tag_len + model_slice_len + graph_slice_len
396 |
397 | for i in range(key_len_int - 1):
398 | slice_index = source_bytes.index(key[i])
399 | print(slice_index, slice_index - slice_len)
400 | slices = source_bytes[slice_index - slice_len: slice_index].split(split_tag)
401 | model_bytes_list.append(slices[1])
402 | graph_bytes_list.append(slices[0])
403 | slices = source_bytes.split(key[-2])[1][:-len(key[-1])].split(split_tag)
404 |
405 | model_bytes_list.append(slices[1])
406 | graph_bytes_list.append(slices[0])
407 | model_bytes = b"".join(model_bytes_list)
408 | model_conf: dict = pickle.loads(model_bytes)
409 | graph_bytes: bytes = b"".join(graph_bytes_list)
410 | return model_conf, graph_bytes
411 |
412 |
413 | class GraphSession(object):
414 | def __init__(self, model: Model):
415 | self.model_conf = model.model_conf
416 | self.size_str = self.model_conf.size_string
417 | self.model_name = self.model_conf.model_name
418 | self.graph_name = self.model_conf.model_name
419 | self.version = self.model_conf.model_version
420 | self.graph_bytes = model.graph_bytes
421 | self.sess = ort.InferenceSession(
422 | self.model_conf.compile_model_path if not model.graph_bytes else model.graph_bytes
423 | )
424 |
425 |
426 | class Interface(object):
427 |
428 | def __init__(self, graph_session: GraphSession):
429 | self.graph_sess = graph_session
430 | self.model_conf = graph_session.model_conf
431 | self.size_str = self.model_conf.size_string
432 | self.graph_name = self.graph_sess.graph_name
433 | self.version = self.graph_sess.version
434 | self.model_category = self.model_conf.category
435 | if self.graph_sess.sess:
436 | self.sess = self.graph_sess.sess
437 |
438 | @property
439 | def name(self):
440 | return self.graph_name
441 |
442 | @property
443 | def size(self):
444 | return self.size_str
445 |
446 | def predict_batch(self, image_batch, output_split=None):
447 | predict_text = self.predict_func(
448 | image_batch,
449 | self.sess,
450 | self.model_conf,
451 | output_split
452 | )
453 | return predict_text
454 |
455 | @staticmethod
456 | def decode_maps(categories):
457 | return {index: category for index, category in enumerate(categories, 0)}
458 |
459 | def predict_func(self, image_batch, _sess, model: ModelConfig, output_split=None):
460 | if isinstance(image_batch, list):
461 | image_batch = np.asarray(image_batch)
462 |
463 | output_split = model.output_split if output_split is None else output_split
464 |
465 | category_split = model.category_split if model.category_split else ""
466 |
467 | dense_decoded_code = _sess.run(["dense_decoded:0"], input_feed={
468 | "input:0": image_batch,
469 | })
470 | decoded_expression = []
471 |
472 | for item in dense_decoded_code[0]:
473 | expression = []
474 | for i in item:
475 | if i == -1 or i == model.category_num:
476 | expression.append("")
477 | else:
478 | expression.append(self.decode_maps(model.category)[i])
479 |
480 | decoded_expression.append(category_split.join(expression))
481 | return output_split.join(decoded_expression) if len(decoded_expression) > 1 else decoded_expression[0]
482 |
483 |
484 | class Pretreatment(object):
485 |
486 | def __init__(self, origin):
487 | self.origin = origin
488 |
489 | def get(self):
490 | return self.origin
491 |
492 | def binarization(self, value, modify=False):
493 | ret, _binarization = cv2.threshold(self.origin, value, 255, cv2.THRESH_BINARY)
494 | if modify:
495 | self.origin = _binarization
496 | return _binarization
497 |
498 | @staticmethod
499 | def preprocessing(image, binaryzation=-1):
500 | pretreatment = Pretreatment(image)
501 | if binaryzation > 0:
502 | pretreatment.binarization(binaryzation, True)
503 | return pretreatment.get()
504 |
505 | @staticmethod
506 | def preprocessing_by_func(exec_map, key, src_arr):
507 | if not exec_map:
508 | return src_arr
509 | target_arr = cv2.cvtColor(src_arr, cv2.COLOR_RGB2BGR)
510 | for sentence in exec_map.get(key):
511 | if sentence.startswith("@@"):
512 | target_arr = eval(sentence[2:])
513 | elif sentence.startswith("$$"):
514 | exec(sentence[2:])
515 | return cv2.cvtColor(target_arr, cv2.COLOR_BGR2RGB)
516 |
517 |
518 | class ImageUtils(object):
519 |
520 | @staticmethod
521 | def get_bytes_batch(image_bytes):
522 | try:
523 | bytes_batch = [image_bytes]
524 | except binascii.Error:
525 | return None, "INVALID_BASE64_STRING"
526 | what_img = [ImageUtils.test_image(i) for i in bytes_batch]
527 | if None in what_img:
528 | return None, "INVALID_IMAGE_FORMAT"
529 | return bytes_batch, "SUCCESS"
530 |
531 | @staticmethod
532 | def get_image_batch(model: ModelConfig, bytes_batch, param_key=None):
533 | # Note that there are two return objects here.
534 | # 1.image_batch, 2.response
535 |
536 | def load_image(image_bytes):
537 | data_stream = io.BytesIO(image_bytes)
538 | pil_image = PIL_Image.open(data_stream)
539 |
540 | gif_handle = model.pre_concat_frames != -1 or model.pre_blend_frames != -1
541 |
542 | if pil_image.mode == 'P' and not gif_handle:
543 | pil_image = pil_image.convert('RGB')
544 |
545 | rgb = pil_image.split()
546 | size = pil_image.size
547 |
548 | if (len(rgb) > 3 and model.pre_replace_transparent) and not gif_handle:
549 | background = PIL_Image.new('RGB', pil_image.size, (255, 255, 255))
550 | background.paste(pil_image, (0, 0, size[0], size[1]), pil_image)
551 | pil_image = background
552 |
553 | im = np.asarray(pil_image)
554 |
555 | im = Pretreatment.preprocessing_by_func(
556 | exec_map=model.exec_map,
557 | key=param_key,
558 | src_arr=im
559 | )
560 |
561 | if model.image_channel == 1 and len(im.shape) == 3:
562 | im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
563 |
564 | im = Pretreatment.preprocessing(
565 | image=im,
566 | binaryzation=model.pre_binaryzation,
567 | )
568 |
569 | if model.pre_horizontal_stitching:
570 | up_slice = im[0: int(size[1] / 2), 0: size[0]]
571 | down_slice = im[int(size[1] / 2): size[1], 0: size[0]]
572 | im = np.concatenate((up_slice, down_slice), axis=1)
573 |
574 | image = im.astype(np.float32)
575 | if model.resize[0] == -1:
576 | ratio = model.resize[1] / size[1]
577 | resize_width = int(ratio * size[0])
578 | image = cv2.resize(image, (resize_width, model.resize[1]))
579 | else:
580 | image = cv2.resize(image, (model.resize[0], model.resize[1]))
581 | image = image.swapaxes(0, 1)
582 | return (image[:, :, np.newaxis] if model.image_channel == 1 else image[:, :]) / 255.
583 |
584 | try:
585 | image_batch = [load_image(i) for i in bytes_batch]
586 | return image_batch, "SUCCESS"
587 | except OSError:
588 | return None, "IMAGE_DAMAGE"
589 | except ValueError as _e:
590 | print(_e)
591 | return None, "IMAGE_SIZE_NOT_MATCH_GRAPH"
592 |
593 | @staticmethod
594 | def size_of_image(image_bytes: bytes):
595 | _null_size = tuple((-1, -1))
596 | try:
597 | data_stream = io.BytesIO(image_bytes)
598 | size = PIL_Image.open(data_stream).size
599 | return size
600 | except OSError:
601 | return _null_size
602 | except ValueError:
603 | return _null_size
604 |
605 | @staticmethod
606 | def test_image(h):
607 | """JPEG"""
608 | if h[:3] == b"\xff\xd8\xff":
609 | return 'jpeg'
610 | """PNG"""
611 | if h[:8] == b"\211PNG\r\n\032\n":
612 | return 'png'
613 | """GIF ('87 and '89 variants)"""
614 | if h[:6] in (b'GIF87a', b'GIF89a'):
615 | return 'gif'
616 | """TIFF (can be in Motorola or Intel byte order)"""
617 | if h[:2] in (b'MM', b'II'):
618 | return 'tiff'
619 | if h[:2] == b'BM':
620 | return 'bmp'
621 | """SGI image library"""
622 | if h[:2] == b'\001\332':
623 | return 'rgb'
624 | """PBM (portable bitmap)"""
625 |         if len(h) >= 3 and \
626 |                 h[0:1] == b'P' and h[1] in b'14' and h[2] in b' \t\n\r':
627 |             return 'pbm'
628 |         """PGM (portable graymap)"""
629 |         if len(h) >= 3 and \
630 |                 h[0:1] == b'P' and h[1] in b'25' and h[2] in b' \t\n\r':
631 |             return 'pgm'
632 |         """PPM (portable pixmap)"""
633 |         if len(h) >= 3 and h[0:1] == b'P' and h[1] in b'36' and h[2] in b' \t\n\r':
634 |             return 'ppm'
635 | """Sun raster file"""
636 | if h[:4] == b'\x59\xA6\x6A\x95':
637 | return 'rast'
638 | """X bitmap (X10 or X11)"""
639 | s = b'#define '
640 | if h[:len(s)] == s:
641 | return 'xbm'
642 | return None
643 |
644 |
645 | class SDK(object):
646 |
647 | def __init__(self, conf_path=None, model_entity: bytes = None):
648 | if not conf_path and not model_entity:
649 | raise ValueError('One of parameters conf_path and model_entity must be filled')
650 | model = Model(conf_path=conf_path, source_bytes=model_entity)
651 | self.model_conf = model.model_conf
652 | self.graph_session = GraphSession(model)
653 | self.interface = Interface(self.graph_session)
654 |
655 | def predict(self, image_bytes, param_key=None):
656 | bytes_batch, message = ImageUtils.get_bytes_batch(image_bytes)
657 | if not bytes_batch:
658 | raise ValueError(message)
659 | image_batch, message = ImageUtils.get_image_batch(self.model_conf, bytes_batch, param_key=param_key)
660 | if not image_batch:
661 | raise ValueError(message)
662 | result = self.interface.predict_batch(image_batch, None)
663 | return result
664 |
665 |
666 | if __name__ == '__main__':
667 | # FROM PATH
668 | # sdk = SDK(conf_path=r"model.yaml")
669 | # with open(r"H:\TrainSet\1541187040676.jpg", "rb") as f:
670 | # b = f.read()
671 | # for i in [b] * 1000:
672 | # t1 = time.time()
673 | # print(sdk.predict(b), (time.time() - t1) * 1000)
674 |
675 | # FROM BYTES
676 | with open(r"model.pl", "rb") as f:
677 | b = f.read()
678 | sdk = SDK(model_entity=b)
679 | with open(r"1540868881850.jpg", "rb") as f:
680 | b = f.read()
681 | for i in [b] * 1000:
682 | t1 = time.time()
683 | print(sdk.predict(b), (time.time() - t1) * 1000)
684 |
--------------------------------------------------------------------------------
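Usage note: the commented-out __main__ block above already sketches both loading modes of this ONNX SDK. Below is a minimal, hypothetical example; the file names are placeholders, and it assumes the packages from /sdk/onnx/requirements.txt (plus onnxruntime) are installed, that it is run from the repository root, and that the compiled <ModelName>.onnx graph sits in the same directory as the YAML config (that is how ModelConfig resolves compile_model_path).

    from sdk.onnx.sdk import SDK

    # Load from a model config path; the directory part of the path is used
    # as graph_path, so it must contain the compiled .onnx model.
    sdk = SDK(conf_path="graph/your_model.yaml")   # placeholder path

    # Alternatively, load a packed model entity (pickled config + graph bytes):
    # with open("model.pl", "rb") as f:
    #     sdk = SDK(model_entity=f.read())

    with open("captcha_sample.jpg", "rb") as f:    # placeholder image
        image_bytes = f.read()

    # predict() accepts raw image bytes and returns the decoded label string;
    # param_key is only needed when the model config defines an ExecuteMap.
    print(sdk.predict(image_bytes))
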
/sdk/pb/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
--------------------------------------------------------------------------------
/sdk/pb/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.14
2 | numpy
3 | pillow
4 | opencv-python==3.4.5.20
5 | pyyaml>=3.13
--------------------------------------------------------------------------------
/sdk/tflite/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
--------------------------------------------------------------------------------
/sdk/tflite/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.14
2 | numpy
3 | pillow
4 | opencv-python==3.4.5.20
5 | pyyaml>=3.13
--------------------------------------------------------------------------------
/signature.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | from functools import wraps
5 | from constants import ServerType
6 | from utils import *
7 |
8 |
9 | class InvalidUsage(Exception):
10 |
11 | def __init__(self, message, code=None):
12 | Exception.__init__(self)
13 | self.message = message
14 | self.success = False
15 | self.code = code
16 |
17 | def to_dict(self):
18 | rv = {'code': self.code, 'message': self.message, 'success': self.success}
19 | return rv
20 |
21 |
22 | class Signature(object):
23 | """ api signature authentication """
24 |
25 | def __init__(self, server_type: ServerType, conf: Config):
26 | self.conf = conf
27 | self._except = Response(self.conf.response_def_map)
28 | self._auth = []
29 | self._timestamp_expiration = 120
30 | self.request = None
31 | self.type = server_type
32 |
33 | def set_auth(self, auth):
34 | self._auth = auth
35 |
36 | def _check_req_timestamp(self, req_timestamp):
37 | """ Check the timestamp
38 |         @param req_timestamp str,int: Timestamp in the request parameter (10 digits)
39 | """
40 | if len(str(req_timestamp)) == 10:
41 | req_timestamp = int(req_timestamp)
42 | now_timestamp = SignUtils.timestamp()
43 | if now_timestamp - self._timestamp_expiration <= req_timestamp <= now_timestamp + self._timestamp_expiration:
44 | return True
45 | return False
46 |
47 | def _check_req_access_key(self, req_access_key):
48 | """ Check the access_key in the request parameter
49 |         @param req_access_key str: access key in the request parameter
50 | """
51 | if req_access_key in [i['accessKey'] for i in self._auth if "accessKey" in i]:
52 | return True
53 | return False
54 |
55 | def _get_secret_key(self, access_key):
56 | """ Obtain the corresponding secret_key according to access_key
57 |         @param access_key str: access key in the request parameter
58 | """
59 | secret_keys = [i['secretKey'] for i in self._auth if i.get('accessKey') == access_key]
60 | return "" if not secret_keys else secret_keys[0]
61 |
62 | def _sign(self, args):
63 | """ MD5 signature
64 | @param args: All query parameters (public and private) requested in addition to signature
65 | """
66 | if "sign" in args:
67 | args.pop("sign")
68 | access_key = args["accessKey"]
69 | query_string = '&'.join(['{}={}'.format(k, v) for (k, v) in sorted(args.items())])
70 | query_string = '&'.join([query_string, self._get_secret_key(access_key)])
71 | return SignUtils.md5(query_string).upper()
72 |
73 | def _verification(self, req_params, tornado_handler=None):
74 | """ Verify that the request is valid
75 | @param req_params: All query parameters requested (public and private)
76 | """
77 | try:
78 | req_signature = req_params["sign"]
79 | req_timestamp = req_params["timestamp"]
80 | req_access_key = req_params["accessKey"]
81 | except KeyError:
82 | raise InvalidUsage(**self._except.INVALID_PUBLIC_PARAMS)
83 | except Exception:
84 | raise InvalidUsage(**self._except.UNKNOWN_SERVER_ERROR)
85 | else:
86 | if self.type == ServerType.FLASK or self.type == ServerType.SANIC:
87 | from flask.app import HTTPException, json
88 | # NO.1 Check the timestamp
89 | if not self._check_req_timestamp(req_timestamp):
90 | raise HTTPException(response=json.jsonify(self._except.INVALID_TIMESTAMP))
91 | # NO.2 Check the access_id
92 | if not self._check_req_access_key(req_access_key):
93 | raise HTTPException(response=json.jsonify(self._except.INVALID_ACCESS_KEY))
94 | # NO.3 Check the sign
95 | if req_signature == self._sign(req_params):
96 | return True
97 | else:
98 | raise HTTPException(response=json.jsonify(self._except.INVALID_QUERY_STRING))
99 | elif self.type == ServerType.TORNADO:
100 | from tornado.web import HTTPError
101 | # NO.1 Check the timestamp
102 | if not self._check_req_timestamp(req_timestamp):
103 | return tornado_handler.write_error(self._except.INVALID_TIMESTAMP['code'])
104 | # NO.2 Check the access_id
105 | if not self._check_req_access_key(req_access_key):
106 | return tornado_handler.write_error(self._except.INVALID_ACCESS_KEY['code'])
107 | # NO.3 Check the sign
108 | if req_signature == self._sign(req_params):
109 | return True
110 | else:
111 | return tornado_handler.write_error(self._except.INVALID_QUERY_STRING['code'])
112 | raise Exception('Unknown Server Type')
113 |
114 | def signature_required(self, f):
115 | @wraps(f)
116 | def decorated_function(*args, **kwargs):
117 | if self.type == ServerType.FLASK:
118 | from flask import request
119 | params = request.json
120 | elif self.type == ServerType.TORNADO:
121 | from tornado.escape import json_decode
122 | params = json_decode(args[0].request.body)
123 | elif self.type == ServerType.SANIC:
124 | params = args[0].json
125 | else:
126 | raise UserWarning('Illegal type, the current version is not supported at this time.')
127 | result = self._verification(params, args[0] if self.type == ServerType.TORNADO else None)
128 | if result is True:
129 | return f(*args, **kwargs)
130 | return decorated_function
131 |
--------------------------------------------------------------------------------
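Client-side note: to pass signature_required, a caller has to reproduce _sign exactly. The sketch below mirrors _sign and the public parameters checked in _verification; it assumes SignUtils.md5 (defined in utils.py, not shown in this excerpt) is a plain hex MD5 digest, and the helper name and parameter values are hypothetical.

    import time
    import hashlib

    def sign_request(params: dict, access_key: str, secret_key: str) -> dict:
        # Add the public parameters the server checks for (10-digit timestamp).
        params = dict(params, accessKey=access_key, timestamp=int(time.time()))
        # Sort every key/value pair, join as k=v with '&', then append the secret key.
        query_string = '&'.join('{}={}'.format(k, v) for k, v in sorted(params.items()))
        query_string = '&'.join([query_string, secret_key])
        # Upper-cased hex MD5, matching Signature._sign.
        params['sign'] = hashlib.md5(query_string.encode('utf8')).hexdigest().upper()
        return params

    # The returned dict (including 'sign') is what gets POSTed as the JSON body,
    # since _verification signs all request parameters except 'sign' itself.
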
/test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class SquareTest(tf.test.TestCase):
5 |
6 | def testSquare(self):
7 | with self.test_session():
8 | x = tf.square([2, 3])
9 | self.assertAllEqual(x.eval(), [4, 9])
10 |
11 |
12 | if __name__ == '__main__':
13 | tf.test.main()
14 |
--------------------------------------------------------------------------------
/tornado_server.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import os
5 | import uuid
6 | import time
7 | import json
8 | import platform
9 | import numpy as np
10 | import asyncio
11 | import hashlib
12 | import optparse
13 | import threading
14 | import tornado.ioloop
15 | import tornado.log
16 | import tornado.gen
17 | import tornado.httpserver
18 | import tornado.options
19 | from pytz import utc
20 | from apscheduler.triggers.interval import IntervalTrigger
21 | from apscheduler.schedulers.background import BackgroundScheduler
22 | from tornado.web import RequestHandler
23 | from constants import Response
24 | from json.decoder import JSONDecodeError
25 | from tornado.escape import json_decode, json_encode
26 | from interface import InterfaceManager, Interface
27 | from config import Config, blacklist, set_blacklist, whitelist, get_version
28 | from utils import ImageUtils, ParamUtils, Arithmetic
29 | from signature import Signature, ServerType
30 | from tornado.concurrent import run_on_executor
31 | from concurrent.futures import ThreadPoolExecutor
32 | from middleware import *
33 | from event_loop import event_loop
34 |
35 | tornado.options.define('ip_blacklist', default=list(), type=list)
36 | tornado.options.define('ip_whitelist', default=list(), type=list)
37 | tornado.options.define('ip_risk_times', default=dict(), type=dict)
38 | tornado.options.define('request_count', default=dict(), type=dict)
39 | tornado.options.define('global_request_count', default=0, type=int)
40 | model_path = "model"
41 | system_config = Config(conf_path="config.yaml", model_path=model_path, graph_path="graph")
42 | sign = Signature(ServerType.TORNADO, system_config)
43 | arithmetic = Arithmetic()
44 | semaphore = asyncio.Semaphore(500)
45 |
46 | scheduler = BackgroundScheduler(timezone='Asia/Shanghai')
47 |
48 |
49 | class BaseHandler(RequestHandler):
50 |
51 | def __init__(self, application, request, **kwargs):
52 | super().__init__(application, request, **kwargs)
53 | self.exception = Response(system_config.response_def_map)
54 | self.executor = ThreadPoolExecutor(workers)
55 | self.image_utils = ImageUtils(system_config)
56 |
57 | @property
58 | def request_incr(self):
59 | if self.request.remote_ip not in tornado.options.options.request_count:
60 | tornado.options.options.request_count[self.request.remote_ip] = 1
61 | else:
62 | tornado.options.options.request_count[self.request.remote_ip] += 1
63 | return tornado.options.options.request_count[self.request.remote_ip]
64 |
65 | def request_desc(self):
66 | if self.request.remote_ip not in tornado.options.options.request_count:
67 | return
68 | else:
69 | tornado.options.options.request_count[self.request.remote_ip] -= 1
70 |
71 | @property
72 | def global_request_incr(self):
73 | tornado.options.options.global_request_count += 1
74 | return tornado.options.options.global_request_count
75 |
76 | @staticmethod
77 | def global_request_desc():
78 | tornado.options.options.global_request_count -= 1
79 |
80 | @staticmethod
81 | def risk_ip_count(ip):
82 | if ip not in tornado.options.options.ip_risk_times:
83 | tornado.options.options.ip_risk_times[ip] = 1
84 | else:
85 | tornado.options.options.ip_risk_times[ip] += 1
86 |
87 | @staticmethod
88 | def risk_ip(ip):
89 | return tornado.options.options.ip_risk_times[ip]
90 |
91 | def data_received(self, chunk):
92 | pass
93 |
94 | def parse_param(self):
95 | try:
96 | data = json_decode(self.request.body)
97 | except JSONDecodeError:
98 | data = self.request.body_arguments
99 | except UnicodeDecodeError:
100 | raise tornado.web.HTTPError(401)
101 | if not data:
102 | raise tornado.web.HTTPError(400)
103 | return data
104 |
105 | def write_error(self, code, **kw):
106 | err_resp = dict(StatusCode=code, Message=system_config.error_message[code], StatusBool=False)
107 | if code in system_config.error_message:
108 | code_dict = Response.parse(err_resp, system_config.response_def_map)
109 | else:
110 | code_dict = self.exception.find(code)
111 | return self.finish(json.dumps(code_dict, ensure_ascii=False))
112 |
113 |
114 | class NoAuthHandler(BaseHandler):
115 | uid_key: str = system_config.response_def_map['Uid']
116 | message_key: str = system_config.response_def_map['Message']
117 | status_bool_key = system_config.response_def_map['StatusBool']
118 | status_code_key = system_config.response_def_map['StatusCode']
119 |
120 | @staticmethod
121 | def save_image(uid, label, image_bytes):
122 | if system_config.save_path:
123 | if not os.path.exists(system_config.save_path):
124 | os.makedirs(system_config.save_path)
125 | save_name = "{}_{}.png".format(label, uid)
126 | with open(os.path.join(system_config.save_path, save_name), "wb") as f:
127 | f.write(image_bytes)
128 |
129 | @run_on_executor
130 | def predict(self, interface: Interface, image_batch, split_char):
131 | result = interface.predict_batch(image_batch, split_char)
132 | if 'ARITHMETIC' in interface.model_category:
133 | if '=' in result or '+' in result or '-' in result or '×' in result or '÷' in result:
134 | result = result.replace("×", "*").replace("÷", "/")
135 | result = str(int(arithmetic.calc(result)))
136 | return result
137 |
138 | @staticmethod
139 | def match_blacklist(ip: str):
140 | for black_ip in tornado.options.options.ip_blacklist:
141 | if ip.startswith(black_ip):
142 | return True
143 | return False
144 |
145 | @staticmethod
146 | def match_whitelist(ip: str):
147 | for white_ip in tornado.options.options.ip_whitelist:
148 | if ip.startswith(white_ip):
149 | return True
150 | return False
151 |
152 | async def options(self):
153 | self.set_status(204)
154 | return self.finish()
155 |
156 | @tornado.gen.coroutine
157 | def post(self):
158 | uid = str(uuid.uuid1())
159 | start_time = time.time()
160 | data = self.parse_param()
161 | request_def_map = system_config.request_def_map
162 | input_data_key = request_def_map['InputData']
163 | model_name_key = request_def_map['ModelName']
164 | if input_data_key not in data.keys():
165 | raise tornado.web.HTTPError(400)
166 |
167 | model_name = ParamUtils.filter(data.get(model_name_key))
168 | output_split = ParamUtils.filter(data.get('output_split'))
169 | need_color = ParamUtils.filter(data.get('need_color'))
170 | param_key = ParamUtils.filter(data.get('param_key'))
171 | extract_rgb = ParamUtils.filter(data.get('extract_rgb'))
172 |
173 | request_incr = self.request_incr
174 | global_count = self.global_request_incr
175 | request_count = " - Count[{}]".format(request_incr)
176 | log_params = " - ParamKey[{}]".format(param_key) if param_key else ""
177 | log_params += " - NeedColor[{}]".format(need_color) if need_color else ""
178 |
179 | if interface_manager.total == 0:
180 | self.request_desc()
181 | self.global_request_desc()
182 | logger.info('There is currently no model deployment and services are not available.')
183 | return self.finish(json_encode(
184 | {self.uid_key: uid, self.message_key: "", self.status_bool_key: False, self.status_code_key: -999}
185 | ))
186 | bytes_batch, response = self.image_utils.get_bytes_batch(data[input_data_key])
187 |
188 | if not (opt.low_hour == -1 or opt.up_hour == -1) and not (opt.low_hour <= time.localtime().tm_hour <= opt.up_hour):
189 | logger.info("[{}] - [{} {}] | - Response[{}] - {} ms".format(
190 | uid, self.request.remote_ip, self.request.uri, "Not in open time.",
191 | (time.time() - start_time) * 1000)
192 | )
193 | return self.finish(json.dumps({
194 | self.uid_key: uid,
195 | self.message_key: system_config.illegal_time_msg.format(opt.low_hour, opt.up_hour),
196 | self.status_bool_key: False,
197 | self.status_code_key: -250
198 | }, ensure_ascii=False))
199 |
200 | if not bytes_batch:
201 | logger.error('[{}] - [{} {}] | - Response[{}] - {} ms'.format(
202 | uid, self.request.remote_ip, self.request.uri, response,
203 | (time.time() - start_time) * 1000)
204 | )
205 | return self.finish(json_encode(response))
206 |
207 | image_sample = bytes_batch[0]
208 | image_size = ImageUtils.size_of_image(image_sample)
209 | size_string = "{}x{}".format(image_size[0], image_size[1])
210 | if system_config.request_size_limit and size_string not in system_config.request_size_limit:
211 | self.request_desc()
212 | self.global_request_desc()
213 | logger.info('[{}] - [{} {}] | Size[{}] - [{}][{}] - Error[{}] - {} ms'.format(
214 | uid, self.request.remote_ip, self.request.uri, size_string, global_count, log_params,
215 | "Image size is invalid.",
216 | round((time.time() - start_time) * 1000))
217 | )
218 | msg = system_config.request_size_limit.get("msg")
219 | msg = msg if msg else "The size of the picture is wrong. " \
220 | "Only the original image is supported. " \
221 | "Please do not take a screenshot!"
222 | return self.finish(json.dumps({
223 | self.uid_key: uid,
224 | self.message_key: msg,
225 | self.status_bool_key: False,
226 | self.status_code_key: -250
227 | }, ensure_ascii=False))
228 |
229 | if system_config.use_whitelist:
230 | assert_whitelist = self.match_whitelist(self.request.remote_ip)
231 | if not assert_whitelist:
232 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format(
233 | uid, self.request.remote_ip, self.request.uri, size_string, request_count, log_params,
234 | "Whitelist limit",
235 | round((time.time() - start_time) * 1000))
236 | )
237 | return self.finish(json.dumps({
238 | self.uid_key: uid,
239 | self.message_key: "Only allow IP access in the whitelist",
240 | self.status_bool_key: False,
241 | self.status_code_key: -111
242 | }, ensure_ascii=False))
243 |
244 | if global_request_limit != -1 and global_count > global_request_limit:
245 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format(
246 | uid, self.request.remote_ip, self.request.uri, size_string, global_count, log_params,
247 | "Maximum number of requests exceeded (G)",
248 | round((time.time() - start_time) * 1000))
249 | )
250 | return self.finish(json.dumps({
251 | self.uid_key: uid,
252 | self.message_key: system_config.exceeded_msg,
253 | self.status_bool_key: False,
254 | self.status_code_key: -555
255 | }, ensure_ascii=False))
256 |
257 | assert_blacklist = self.match_blacklist(self.request.remote_ip)
258 | if assert_blacklist:
259 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format(
260 | uid, self.request.remote_ip, self.request.uri, size_string, request_count, log_params,
261 | "The ip is on the risk blacklist (IP)",
262 | round((time.time() - start_time) * 1000))
263 | )
264 | return self.finish(json.dumps({
265 | self.uid_key: uid,
266 | self.message_key: system_config.exceeded_msg,
267 | self.status_bool_key: False,
268 | self.status_code_key: -110
269 | }, ensure_ascii=False))
270 | if request_limit != -1 and request_incr > request_limit:
271 | self.risk_ip_count(self.request.remote_ip)
272 | assert_blacklist_trigger = system_config.blacklist_trigger_times != -1
273 | if self.risk_ip(
274 | self.request.remote_ip) > system_config.blacklist_trigger_times and assert_blacklist_trigger:
275 | if self.request.remote_ip not in blacklist():
276 | set_blacklist(self.request.remote_ip)
277 | update_blacklist()
278 | logger.info('[{}] - [{} {}] | Size[{}]{}{} - Error[{}] - {} ms'.format(
279 | uid, self.request.remote_ip, self.request.uri, size_string, request_count, log_params,
280 | "Maximum number of requests exceeded (IP)",
281 | round((time.time() - start_time) * 1000))
282 | )
283 | return self.finish(json.dumps({
284 | self.uid_key: uid,
285 | self.message_key: system_config.exceeded_msg,
286 | self.status_bool_key: False,
287 | self.status_code_key: -444
288 | }, ensure_ascii=False))
289 | if model_name_key in data and data[model_name_key]:
290 | interface: Interface = interface_manager.get_by_name(model_name)
291 | else:
292 | interface: Interface = interface_manager.get_by_size(size_string)
293 | if not interface:
294 | self.request_desc()
295 | self.global_request_desc()
296 | logger.info('Service is not ready!')
297 | return self.finish(json_encode(
298 | {self.uid_key: uid, self.message_key: "", self.status_bool_key: False, self.status_code_key: 999}
299 | ))
300 |
301 | output_split = output_split if 'output_split' in data else interface.model_conf.output_split
302 |
303 | if interface.model_conf.corp_params:
304 | bytes_batch = corp_to_multi.parse_multi_img(bytes_batch, interface.model_conf.corp_params)
305 | if interface.model_conf.pre_freq_frames != -1:
306 | bytes_batch = gif_frames.all_frames(bytes_batch)
307 | exec_map = interface.model_conf.exec_map
308 | if exec_map and len(exec_map.keys()) > 1 and not param_key:
309 | self.request_desc()
310 | self.global_request_desc()
311 | logger.info('[{}] - [{} {}] | [{}] - Size[{}]{}{} - Error[{}] - {} ms'.format(
312 | uid, self.request.remote_ip, self.request.uri, interface.name, size_string, request_count, log_params,
313 | "The model is missing the param_key parameter because the model is configured with ExecuteMap.",
314 | round((time.time() - start_time) * 1000))
315 | )
316 | return self.finish(json_encode(
317 | {
318 | self.uid_key: uid,
319 | self.message_key: "Missing the parameter [param_key].",
320 | self.status_bool_key: False,
321 | self.status_code_key: 474
322 | }
323 | ))
324 | elif exec_map and param_key and param_key not in exec_map:
325 | self.request_desc()
326 | self.global_request_desc()
327 | logger.info('[{}] - [{} {}] | [{}] - Size[{}]{}{} - Error[{}] - {} ms'.format(
328 | uid, self.request.remote_ip, self.request.uri, interface.name, size_string, request_count, log_params,
329 | "The param_key parameter is not support in the model.",
330 | round((time.time() - start_time) * 1000))
331 | )
332 | return self.finish(json_encode(
333 | {
334 | self.uid_key: uid,
335 | self.message_key: "Not support the parameter [param_key].",
336 | self.status_bool_key: False,
337 | self.status_code_key: 474
338 | }
339 | ))
340 | elif exec_map and len(exec_map.keys()) == 1:
341 | param_key = list(interface.model_conf.exec_map.keys())[0]
342 |
343 | if interface.model_conf.external_model and interface.model_conf.corp_params:
344 | result = []
345 | len_of_result = []
346 | pre_corp_num = 0
347 | for corp_param in interface.model_conf.corp_params:
348 | corp_size = corp_param['corp_size']
349 | corp_num_list = corp_param['corp_num']
350 | corp_num = corp_num_list[0] * corp_num_list[1]
351 | sub_bytes_batch = bytes_batch[pre_corp_num: pre_corp_num + corp_num]
352 | pre_corp_num = corp_num
353 | size_string = "{}x{}".format(corp_size[0], corp_size[1])
354 |
355 | sub_interface = interface_manager.get_by_size(size_string)
356 |
357 | image_batch, response = ImageUtils.get_image_batch(
358 | sub_interface.model_conf, sub_bytes_batch, param_key=param_key
359 | )
360 |
361 | text = yield self.predict(
362 | sub_interface, image_batch, output_split, size_string, start_time, log_params, request_count,
363 | uid=uid
364 | )
365 | result.append(text)
366 | len_of_result.append(len(result[0].split(sub_interface.model_conf.category_split)))
367 |
368 | response[self.message_key] = interface.model_conf.output_split.join(result)
369 | if interface.model_conf.corp_params and interface.model_conf.output_coord:
370 | # final_result = auxiliary_result + "," + response[self.message_key]
371 | # if auxiliary_result else response[self.message_key]
372 | final_result = response[self.message_key]
373 | response[self.message_key] = corp_to_multi.get_coordinate(
374 | label=final_result,
375 | param_group=interface.model_conf.corp_params,
376 | title_index=[i for i in range(len_of_result[0])]
377 | )
378 | return self.finish(json.dumps(response, ensure_ascii=False).replace("</", "<\\/"))
379 | else:
380 | image_batch, response = ImageUtils.get_image_batch(
381 | interface.model_conf,
382 | bytes_batch,
383 | param_key=param_key,
384 | extract_rgb=extract_rgb
385 | )
386 |
387 | if not image_batch:
388 | self.request_desc()
389 | self.global_request_desc()
390 | logger.error('[{}] - [{} {}] | [{}] - Size[{}] - Response[{}] - {} ms'.format(
391 | uid, self.request.remote_ip, self.request.uri, interface.name, size_string, response,
392 | round((time.time() - start_time) * 1000))
393 | )
394 | response[self.uid_key] = uid
395 | return self.finish(json_encode(response))
396 |
397 | predict_result = yield self.predict(interface, image_batch, output_split)
398 | if interface.model_conf.pre_freq_frames != -1:
399 | predict_result = predict_result.split(interface.model_conf.output_split)
400 | predict_result = [
401 | i for i in predict_result
402 | if interface.model_conf.max_label_num >= len(i) >= interface.model_conf.min_label_num
403 | ]
404 | predict_result = gif_frames.get_continuity_max(predict_result)
405 | # if need_color:
406 | # # only support six label and size [90x35].
407 | # color_batch = np.resize(image_batch[0], (90, 35, 3))
408 | # need_index = color_extract.predict_color(image_batch=[color_batch], color=color_map[need_color])
409 | # predict_result = "".join([v for i, v in enumerate(predict_result) if i in need_index])
410 |
411 | uid_str = "[{}] - ".format(uid)
412 | logger.info('{}[{} {}] | [{}] - Size[{}]{}{} - Predict[{}] - {} ms'.format(
413 | uid_str, self.request.remote_ip, self.request.uri, interface.name, size_string, request_count, log_params,
414 | predict_result,
415 | round((time.time() - start_time) * 1000))
416 | )
417 | response[self.message_key] = predict_result
418 | response[self.uid_key] = uid
419 | self.executor.submit(self.save_image, uid, response[self.message_key], bytes_batch[0])
420 | if interface.model_conf.corp_params and interface.model_conf.output_coord:
421 | # final_result = auxiliary_result + "," + response[self.message_key]
422 | # if auxiliary_result else response[self.message_key]
423 | final_result = response[self.message_key]
424 | response[self.message_key] = corp_to_multi.get_coordinate(
425 | label=final_result,
426 | param_group=interface.model_conf.corp_params,
427 | title_index=[0]
428 | )
429 | return self.finish(json.dumps(response, ensure_ascii=False).replace("</", "<\\/"))
430 |
431 |
432 | class AuthHandler(NoAuthHandler):
433 |
434 | @sign.signature_required
435 | def post(self):
436 | return super().post()
437 |
438 |
439 | class SimpleHandler(BaseHandler):
440 | uid_key: str = system_config.response_def_map['Uid']
441 | message_key: str = system_config.response_def_map['Message']
442 | status_bool_key = system_config.response_def_map['StatusBool']
443 | status_code_key = system_config.response_def_map['StatusCode']
444 |
445 | def post(self):
446 | uid = str(uuid.uuid1())
447 | param_key = None
448 | start_time = time.time()
449 | if interface_manager.total == 0:
450 | logger.info('There is currently no model deployment and services are not available.')
451 | return self.finish(json_encode(
452 | {self.uid_key: uid, self.message_key: "", self.status_bool_key: False, self.status_code_key: -999}
453 | ))
454 |
455 | bytes_batch, response = self.image_utils.get_bytes_batch(self.request.body)
456 |
457 | if not bytes_batch:
458 | logger.error('Response[{}] - {} ms'.format(
459 | response,
460 | (time.time() - start_time) * 1000)
461 | )
462 | return self.finish(json_encode(response))
463 |
464 | image_sample = bytes_batch[0]
465 | image_size = ImageUtils.size_of_image(image_sample)
466 | size_string = "{}x{}".format(image_size[0], image_size[1])
467 |
468 | interface = interface_manager.get_by_size(size_string)
469 | if not interface:
470 | logger.info('Service is not ready!')
471 | return self.finish(json_encode(
472 | {self.message_key: "", self.status_bool_key: False, self.status_code_key: 999}
473 | ))
474 |
475 | exec_map = interface.model_conf.exec_map
476 | if exec_map and len(exec_map.keys()) > 1:
477 | logger.info('[{}] - [{} {}] | [{}] - Size[{}] - Error[{}] - {} ms'.format(
478 | uid, self.request.remote_ip, self.request.uri, interface.name, size_string,
479 | "The model is configured with ExecuteMap, but the api do not support this param.",
480 | round((time.time() - start_time) * 1000))
481 | )
482 | return self.finish(json_encode(
483 | {
484 | self.message_key: "the api do not support [ExecuteMap].",
485 | self.status_bool_key: False,
486 | self.status_code_key: 474
487 | }
488 | ))
489 | elif exec_map and len(exec_map.keys()) == 1:
490 | param_key = list(interface.model_conf.exec_map.keys())[0]
491 |
492 | image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, param_key=param_key)
493 |
494 | if not image_batch:
495 | logger.error('[{}] - [{}] | [{}] - Size[{}] - Response[{}] - {} ms'.format(
496 | uid, self.request.remote_ip, interface.name, size_string, response,
497 | (time.time() - start_time) * 1000)
498 | )
499 | return self.finish(json_encode(response))
500 |
501 | result = interface.predict_batch(image_batch, None)
502 | logger.info('[{}] - [{}] | [{}] - Size[{}] - Predict[{}] - {} ms'.format(
503 | uid, self.request.remote_ip, interface.name, size_string, result, (time.time() - start_time) * 1000)
504 | )
505 | response[self.uid_key] = uid
506 | response[self.message_key] = result
507 | return self.write(json.dumps(response, ensure_ascii=False).replace("</", "<\\/"))
508 |
509 |
510 | class ServiceHandler(BaseHandler):
511 |
512 | def get(self):
513 | response = {
514 | "total": interface_manager.total,
515 | "online": interface_manager.online_names,
516 | "invalid": interface_manager.invalid_group,
517 | "blacklist": tornado.options.options.ip_blacklist
518 | }
519 | return self.finish(json.dumps(response, ensure_ascii=False, indent=2))
520 |
521 |
522 | class FileHandler(tornado.web.StaticFileHandler):
523 | def data_received(self, chunk):
524 | pass
525 |
526 | def set_extra_headers(self, path):
527 | self.set_header("Cache-control", "no-cache")
528 |
529 |
530 | class HeartBeatHandler(BaseHandler):
531 |
532 | def get(self):
533 | self.finish("")
534 |
535 |
536 | def clear_specific_job():
537 | tornado.options.options.request_count = {}
538 |
539 |
540 | def clear_global_job():
541 | tornado.options.options.global_request_count = 0
542 |
543 |
544 | def update_blacklist():
545 | tornado.options.options.ip_blacklist = blacklist()
546 |
547 |
548 | def make_app(route: list):
549 | return tornado.web.Application(
550 | [
551 | (i['Route'], globals()[i['Class']], i.get("Param"))
552 | if "Param" in i else
553 | (i['Route'], globals()[i['Class']]) for i in route
554 | ],
555 | static_path=os.path.join(os.path.dirname(__file__), "resource"),
556 | )
557 |
558 |
559 | trigger_specific = IntervalTrigger(seconds=system_config.request_count_interval)
560 | trigger_global = IntervalTrigger(seconds=system_config.g_request_count_interval)
561 | trigger_blacklist = IntervalTrigger(seconds=10)
562 | scheduler.add_job(update_blacklist, trigger_blacklist)
563 | scheduler.add_job(clear_specific_job, trigger_specific)
564 | scheduler.add_job(clear_global_job, trigger_global)
565 | scheduler.start()
566 |
567 | if __name__ == "__main__":
568 | parser = optparse.OptionParser()
569 |
570 | request_limit = system_config.request_limit
571 | global_request_limit = system_config.global_request_limit
572 |
573 | parser.add_option('-p', '--port', type="int", default=system_config.default_port, dest="port")
574 | parser.add_option('-w', '--workers', type="int", default=50, dest="workers")
575 | parser.add_option('--up_hour', type="int", default=-1, dest="up_hour")
576 | parser.add_option('--low_hour', type="int", default=-1, dest="low_hour")
577 |
578 | opt, args = parser.parse_args()
579 | server_port = opt.port
580 |
581 | if platform.system() == 'Windows':
582 | # os.system("chcp 65001")
583 | os.system("title=Eve-DL Platform v0.1({}) ^| [{}]".format(get_version(), server_port))
584 |
585 | workers = opt.workers
586 | logger = system_config.logger
587 | # print('=============WITHOUT_LOGGER=============', system_config.without_logger)
588 | tornado.log.enable_pretty_logging(logger=logger)
589 | interface_manager = InterfaceManager()
590 | threading.Thread(target=lambda: event_loop(system_config, model_path, interface_manager)).start()
591 |
592 | sign.set_auth([{'accessKey': system_config.access_key, 'secretKey': system_config.secret_key}])
593 |
594 | tornado.options.options.ip_whitelist = whitelist()
595 |
596 | server_host = "0.0.0.0"
597 | logger.info('Running on http://{}:{}/ '.format(server_host, server_port))
598 | app = make_app(system_config.route_map)
599 | http_server = tornado.httpserver.HTTPServer(app)
600 | http_server.bind(server_port, server_host)
601 | http_server.start(1)
602 | tornado.ioloop.IOLoop.instance().start()
603 |
--------------------------------------------------------------------------------
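A minimal client sketch for the Tornado service above, useful for smoke-testing a deployed model. It assumes the NoAuthHandler route accepts a JSON body carrying a base64-encoded image; the URL path ("/captcha/v1"), the port, and the "image" field name are placeholders here — the real values come from system_config (route_map, default_port, and the request definition in config.yaml).

import base64
import requests

def recognize(image_path: str, host: str = "http://127.0.0.1:19952") -> dict:
    # Read the captcha image and base64-encode it for the JSON body.
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    # NoAuthHandler parses the JSON body directly; AuthHandler would
    # additionally require the fields checked by sign.signature_required.
    resp = requests.post(host + "/captcha/v1", json={"image": b64})
    return resp.json()

if __name__ == "__main__":
    print(recognize("sample.png"))  # "sample.png" is a placeholder path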
/tornado_server.spec:
--------------------------------------------------------------------------------
1 | # -*- mode: python -*-
2 | # Used to package as a single executable
3 | # This is a configuration file
4 | import cv2
5 | import os
6 | from PyInstaller.utils.hooks import collect_all
7 |
8 |
9 | block_cipher = None
10 |
11 | binaries, hiddenimports = [], ['numpy.core._dtype_ctypes']
12 | tmp_ret = collect_all('tzdata')
13 | added_files = [('resource/icon.ico', 'resource'), ('resource/favorite.ico', '.'), ('resource/VERSION', 'astor'), ('resource/VERSION', '.')]
14 | added_files += tmp_ret[0]; binaries += tmp_ret[1]
15 | hiddenimports += tmp_ret[2]
16 |
17 | a = Analysis(['tornado_server.py'],
18 | pathex=['.', os.path.join(os.path.dirname(cv2.__file__), 'config-3.9')],
19 | binaries=binaries,
20 | datas=added_files,
21 | hiddenimports=hiddenimports,
22 | hookspath=[],
23 | runtime_hooks=[],
24 | excludes=[],
25 | win_no_prefer_redirects=False,
26 | win_private_assemblies=False,
27 | cipher=block_cipher)
28 | pyz = PYZ(a.pure, a.zipped_data,
29 | cipher=block_cipher)
30 | exe = EXE(pyz,
31 | a.scripts,
32 | a.binaries,
33 | a.zipfiles,
34 | a.datas,
35 | [],
36 | name='captcha_platform_tornado',
37 | debug=False,
38 | strip=False,
39 | upx=True,
40 | runtime_tmpdir=None,
41 | console=True,
42 | icon='resource/icon.ico')
43 |
--------------------------------------------------------------------------------
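The spec above bundles the Tornado server into a single executable (onefile mode: runtime_tmpdir=None, with binaries, zipfiles and datas packed into the EXE). A minimal build sketch, assuming PyInstaller and the project requirements are installed in the current environment; it is equivalent to running `pyinstaller tornado_server.spec` from the project root.

import PyInstaller.__main__

PyInstaller.__main__.run([
    "--clean",      # clear PyInstaller's build cache first
    "--noconfirm",  # overwrite the dist/ output without prompting
    "tornado_server.spec",
])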
/tornado_server_gpu.spec:
--------------------------------------------------------------------------------
1 | # -*- mode: python -*-
2 | # Used to package as a single executable
3 | # This is a configuration file
4 |
5 | block_cipher = None
6 |
7 | added_files = [('resource/icon.ico', 'resource'), ('resource/VERSION', 'astor')]
8 |
9 | a = Analysis(['tornado_server.py'],
10 | pathex=['.'],
11 | binaries=[],
12 | datas=added_files,
13 | hiddenimports=['numpy.core._dtype_ctypes'],
14 | hookspath=[],
15 | runtime_hooks=[],
16 | excludes=[],
17 | win_no_prefer_redirects=False,
18 | win_private_assemblies=False,
19 | cipher=block_cipher)
20 | pyz = PYZ(a.pure, a.zipped_data,
21 | cipher=block_cipher)
22 | exe = EXE(pyz,
23 | a.scripts,
24 | [],
25 | exclude_binaries=True,
26 | name='captcha_platform_tornado_gpu',
27 | debug=False,
28 | bootloader_ignore_signals=False,
29 | strip=False,
30 | upx=True,
31 | console=True,
32 | icon='resource/icon.ico')
33 | coll = COLLECT(exe,
34 | a.binaries,
35 | a.zipfiles,
36 | a.datas,
37 | strip=False,
38 | upx=True,
39 | upx_exclude=[],
40 | name='captcha_platform_tornado_gpu')
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Author: kerlomz
4 | import io
5 | import re
6 | import os
7 | import cv2
8 | import time
9 | import base64
10 | import functools
11 | import binascii
12 | import datetime
13 | import hashlib
14 | import numpy as np
15 | import tensorflow as tf
16 | from PIL import Image as PIL_Image
17 | from constants import Response, SystemConfig
18 | from pretreatment import preprocessing, preprocessing_by_func
19 | from config import ModelConfig, Config
20 | from middleware.impl.gif_frames import concat_frames, blend_frame
21 | from middleware.impl.rgb_filter import rgb_filter
22 |
23 |
24 | class Arithmetic(object):
25 |
26 | def calc(self, formula):
27 | formula = re.sub(' ', '', formula)
28 | formula_ret = 0
29 | match_brackets = re.search(r'\([^()]+\)', formula)
30 | if match_brackets:
31 | calc_result = self.calc(match_brackets.group().strip("(,)"))
32 | formula = formula.replace(match_brackets.group(), str(calc_result))
33 | return self.calc(formula)
34 | else:
35 | formula = formula.replace('--', '+').replace('++', '+').replace('-+', '-').replace('+-', '-')
36 | while re.findall(r"[*/]", formula):
37 | get_formula = re.search(r"[.\d]+[*/]+[-]?[.\d]+", formula)
38 | if get_formula:
39 | get_formula_str = get_formula.group()
40 | if get_formula_str.count("*"):
41 | formula_list = get_formula_str.split("*")
42 | ret = float(formula_list[0]) * float(formula_list[1])
43 | else:
44 | formula_list = get_formula_str.split("/")
45 | ret = float(formula_list[0]) / float(formula_list[1])
46 | formula = formula.replace(get_formula_str, str(ret)).replace('--', '+').replace('++', '+')
47 | formula = re.findall(r'[-]?[.\d]+', formula)
48 | for num in formula:
49 | formula_ret += float(num)
50 | return formula_ret
51 |
52 |
53 | class ParamUtils(object):
54 |
55 | @staticmethod
56 | def filter(param):
57 | if isinstance(param, list) and len(param) > 0 and isinstance(param[0], bytes):
58 | return param[0].decode()
59 | return param
60 |
61 |
62 | class SignUtils(object):
63 |
64 | @staticmethod
65 | def md5(text):
66 | return hashlib.md5(text.encode('utf-8')).hexdigest()
67 |
68 | @staticmethod
69 | def timestamp():
70 | return int(time.mktime(datetime.datetime.now().timetuple()))
71 |
72 |
73 | class PathUtils(object):
74 |
75 | @staticmethod
76 | def get_file_name(path: str):
77 | if '/' in path:
78 | return path.split('/')[-1]
79 | elif '\\' in path:
80 | return path.split('\\')[-1]
81 | else:
82 | return path
83 |
84 |
85 | class ImageUtils(object):
86 |
87 | def __init__(self, conf: Config):
88 | self.conf = conf
89 |
90 | def get_bytes_batch(self, base64_or_bytes):
91 | response = Response(self.conf.response_def_map)
92 | b64_filter_s = lambda s: re.sub("data:image/.+?base64,", "", s, 1) if ',' in s else s
93 | b64_filter_b = lambda s: re.sub(b"data:image/.+?base64,", b"", s, 1) if b',' in s else s
94 | try:
95 | if isinstance(base64_or_bytes, bytes):
96 | if self.conf.split_flag in base64_or_bytes:
97 | bytes_batch = base64_or_bytes.split(self.conf.split_flag)
98 | else:
99 | bytes_batch = [base64_or_bytes]
100 | elif isinstance(base64_or_bytes, list):
101 | bytes_batch = [base64.b64decode(b64_filter_s(i).encode('utf-8')) for i in base64_or_bytes if
102 | isinstance(i, str)]
103 | if not bytes_batch:
104 | bytes_batch = [base64.b64decode(b64_filter_b(i)) for i in base64_or_bytes if isinstance(i, bytes)]
105 | else:
106 | base64_or_bytes = b64_filter_s(base64_or_bytes)
107 | bytes_batch = base64.b64decode(base64_or_bytes.encode('utf-8')).split(self.conf.split_flag)
108 |
109 | except binascii.Error:
110 | return None, response.INVALID_BASE64_STRING
111 | what_img = [ImageUtils.test_image(i) for i in bytes_batch]
112 |
113 | if None in what_img:
114 | return None, response.INVALID_IMAGE_FORMAT
115 | return bytes_batch, response.SUCCESS
116 |
117 | @staticmethod
118 | def get_image_batch(model: ModelConfig, bytes_batch, param_key=None, extract_rgb: list = None):
119 | # Note that this method returns two objects:
120 | # 1. image_batch, 2. response
121 |
122 | response = Response(model.conf.response_def_map)
123 |
124 | def load_image(image_bytes: bytes):
125 | data_stream = io.BytesIO(image_bytes)
126 | pil_image = PIL_Image.open(data_stream)
127 |
128 | gif_handle = model.pre_concat_frames != -1 or model.pre_blend_frames != -1
129 |
130 | if pil_image.mode == 'P' and not gif_handle:
131 | pil_image = pil_image.convert('RGB')
132 |
133 | rgb = pil_image.split()
134 | size = pil_image.size
135 |
136 | if (len(rgb) > 3 and model.pre_replace_transparent) and not gif_handle:
137 | background = PIL_Image.new('RGB', pil_image.size, (255, 255, 255))
138 | try:
139 | background.paste(pil_image, (0, 0, size[0], size[1]), pil_image)
140 | pil_image = background
141 | except Exception:
142 | pil_image = pil_image.convert('RGB')
143 |
144 | if len(pil_image.split()) > 3 and model.image_channel == 3:
145 | pil_image = pil_image.convert('RGB')
146 |
147 | if model.pre_concat_frames != -1:
148 | im = concat_frames(pil_image, model.pre_concat_frames)
149 | elif model.pre_blend_frames != -1:
150 | im = blend_frame(pil_image, model.pre_blend_frames)
151 | else:
152 | im = np.asarray(pil_image)
153 |
154 | if extract_rgb:
155 | im = rgb_filter(im, extract_rgb)
156 |
157 | im = preprocessing_by_func(
158 | exec_map=model.exec_map,
159 | key=param_key,
160 | src_arr=im
161 | )
162 |
163 | if model.image_channel == 1 and len(im.shape) == 3:
164 | im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
165 |
166 | im = preprocessing(
167 | image=im,
168 | binaryzation=model.pre_binaryzation,
169 | )
170 |
171 | if model.pre_horizontal_stitching:
172 | up_slice = im[0: int(size[1] / 2), 0: size[0]]
173 | down_slice = im[int(size[1] / 2): size[1], 0: size[0]]
174 | im = np.concatenate((up_slice, down_slice), axis=1)
175 |
176 | image = im.astype(np.float32)
177 | if model.resize[0] == -1:
178 | ratio = model.resize[1] / size[1]
179 | resize_width = int(ratio * size[0])
180 | image = cv2.resize(image, (resize_width, model.resize[1]))
181 | else:
182 | image = cv2.resize(image, (model.resize[0], model.resize[1]))
183 | image = image.swapaxes(0, 1)
184 | return (image[:, :, np.newaxis] if model.image_channel == 1 else image[:, :]) / 255.
185 |
186 | try:
187 | image_batch = [load_image(i) for i in bytes_batch]
188 | return image_batch, response.SUCCESS
189 | except OSError:
190 | return None, response.IMAGE_DAMAGE
191 | except ValueError as _e:
192 | print(_e)
193 | return None, response.IMAGE_SIZE_NOT_MATCH_GRAPH
194 |
195 | @staticmethod
196 | def size_of_image(image_bytes: bytes):
197 | _null_size = tuple((-1, -1))
198 | try:
199 | data_stream = io.BytesIO(image_bytes)
200 | size = PIL_Image.open(data_stream).size
201 | return size
202 | except OSError:
203 | return _null_size
204 | except ValueError:
205 | return _null_size
206 |
207 | @staticmethod
208 | def test_image(h):
209 | """JPEG"""
210 | if h[:3] == b"\xff\xd8\xff":
211 | return 'jpeg'
212 | """PNG"""
213 | if h[:8] == b"\211PNG\r\n\032\n":
214 | return 'png'
215 | """GIF ('87 and '89 variants)"""
216 | if h[:6] in (b'GIF87a', b'GIF89a'):
217 | return 'gif'
218 | """TIFF (can be in Motorola or Intel byte order)"""
219 | if h[:2] in (b'MM', b'II'):
220 | return 'tiff'
221 | if h[:2] == b'BM':
222 | return 'bmp'
223 | """SGI image library"""
224 | if h[:2] == b'\001\332':
225 | return 'rgb'
226 | """PBM (portable bitmap)"""
227 | if len(h) >= 3 and \
228 | h[0:1] == b'P' and h[1] in b'14' and h[2] in b' \t\n\r':
229 | return 'pbm'
230 | """PGM (portable graymap)"""
231 | if len(h) >= 3 and \
232 | h[0:1] == b'P' and h[1] in b'25' and h[2] in b' \t\n\r':
233 | return 'pgm'
234 | """PPM (portable pixmap)"""
235 | if len(h) >= 3 and h[0:1] == b'P' and h[1] in b'36' and h[2] in b' \t\n\r':
236 | return 'ppm'
237 | """Sun raster file"""
238 | if h[:4] == b'\x59\xA6\x6A\x95':
239 | return 'rast'
240 | """X bitmap (X10 or X11)"""
241 | s = b'#define '
242 | if h[:len(s)] == s:
243 | return 'xbm'
244 | return None
245 |
246 |
247 | class SystemUtils(object):
248 |
249 | @staticmethod
250 | def datetime(origin=None, microseconds=None):
251 | now = origin if origin else time.time()
252 | if microseconds:
253 | return (
254 | datetime.datetime.fromtimestamp(now) + datetime.timedelta(microseconds=microseconds)
255 | ).strftime('%Y-%m-%d %H:%M:%S.%f')
256 | return datetime.datetime.fromtimestamp(now).strftime('%Y-%m-%d %H:%M:%S.%f')
257 |
258 | @staticmethod
259 | def isdir(sftp, path):
260 | from stat import S_ISDIR
261 | try:
262 | return S_ISDIR(sftp.stat(path).st_mode)
263 | except IOError:
264 | return False
265 |
266 | @staticmethod
267 | def empty(sftp, path):
268 | from paramiko import SFTPClient
269 | if not SystemUtils.isdir(sftp, path):
270 | sftp.mkdir(path)
271 |
272 | files = sftp.listdir(path=path)
273 |
274 | for f in files:
275 | file_path = os.path.join(path, f)
276 | if SystemUtils.isdir(sftp, file_path):
277 | SystemUtils.empty(sftp, file_path)
278 | else:
279 | sftp.remove(file_path)
280 |
--------------------------------------------------------------------------------
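Two helpers in utils.py are easy to sanity-check in isolation. A small sketch, assuming it is run from the project root so utils.py (and its heavy dependencies such as tensorflow and cv2) import cleanly; "sample.png" is a placeholder path.

from utils import Arithmetic, ImageUtils

# Arithmetic.calc evaluates simple arithmetic expressions (presumably from
# arithmetic captchas), including nested brackets: 1+2*(3-1) -> 5.0
assert Arithmetic().calc("1+2*(3-1)") == 5.0

# ImageUtils.test_image sniffs the format from the magic bytes of raw data,
# and size_of_image reads the dimensions via PIL.
with open("sample.png", "rb") as f:
    data = f.read()
print(ImageUtils.test_image(data))     # -> 'png' for a PNG file
print(ImageUtils.size_of_image(data))  # -> (width, height)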