├── LICENSE
├── README.md
├── model.py
├── predict_by_server.py
├── predict_folder.py
├── predict_one.py
├── run_a_server.py
├── testImg
│   ├── 0.png
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   ├── 7.png
│   ├── 8.png
│   └── 9.png
└── trained_weights.h5

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Brief-rf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Usage

## 1. Test locally
- Run ```python3 predict_one.py```. By default it predicts ```0.png``` in the ```testImg``` folder.
- Run ```python3 predict_folder.py``` to predict every image under ```testImg```.
## 2. Deploy to a server
- Run ```python3 run_a_server.py 8888``` to serve the API on port 8888; it can also be run and tested locally. If no port is given, the default is 7000.
## 3. Predict via the deployed API
Run ```python3 predict_by_server.py``` to call the deployed API and predict every image in the ```testImg``` folder.

## 4. Package as an executable
```shell
pyinstaller -F run_a_server.py
```
This packages the server into a single executable, generated in the dist folder once packaging finishes. **It runs directly, without a Python environment.**
> Note: when running the packaged binary, trained_weights.h5 must be placed in the same directory.
## Other notes
- In testing, predicting one image via the local API takes about 0.05 s on average; on a 1-core/2 GB server, memory usage stays between 200 and 400 MB.
- The TensorFlow and Keras versions used are 1.15.2 and 2.3.1, respectively.
- ```model.py``` holds the network architecture.
- trained_weights.h5 is the trained weights file.
- A quick curl test:
```shell
curl -X POST -F image=@testImg/0.png 'http://localhost:7000/predict'
```
On success it returns:
```json
{"predictions":0.553,"success":true}
```
The model only accepts input images of size 140x360; other sizes are not supported. If you are able, modify the network architecture to fit the size you need.

> Disclaimer: the scripts in this repository are for testing, learning, and research only. No guarantee is made of their legality, accuracy, completeness, or effectiveness; use your own judgment. Do not use any part of this project for commercial or illegal purposes; the consequences of doing so are the user's alone. If you believe this project's content may infringe your rights, please contact me and I will remove the files promptly. By using or copying any content of this repository, you are deemed to have accepted this disclaimer.
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from keras.layers import Input, Conv2D, Dense, Flatten, AveragePooling2D, Add, Concatenate
from keras.models import Model


def brief_net(input_shape=(170, 275, 3), output_shape=1):
    input = Input(shape=input_shape)
    x = Conv2D(filters=128, kernel_size=1, strides=1, padding='same', activation='relu')(input)
    shortcut_ = Conv2D(filters=32, kernel_size=3, strides=4, padding='same', activation='relu')(x)
    x = Conv2D(filters=64, kernel_size=3, strides=1, padding='same', activation='relu')(x)

    shortcut = AveragePooling2D(pool_size=2, strides=2, padding='same')(x)
    shortcut = Conv2D(filters=16, kernel_size=3, strides=2, padding="same", activation='relu')(shortcut)
    x = Conv2D(filters=32, kernel_size=5, strides=2, padding='same', activation='relu')(x)
    x = Conv2D(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(x)
    x = Add()([x, shortcut])
    x = Concatenate(axis=-1)([x, shortcut_])
    x = Conv2D(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(x)
    x = Flatten()(x)
    x = Dense(units=12, activation='relu')(x)

    x = Dense(units=6, activation='relu')(x)

    output = Dense(units=output_shape, activation="sigmoid")(x)

    model = Model(input, output)

    return model

if __name__ == '__main__':
    model = brief_net()
    model.summary()
--------------------------------------------------------------------------------
/predict_by_server.py:
--------------------------------------------------------------------------------
import requests
import os
import cv2
import time

def draw_line(img_folder, img_name, x_val):
    img_path = os.path.join(img_folder, "pred" + img_name)
    img = cv2.imread(os.path.join(img_folder, img_name))
    # the images are 140 px tall, so draw the vertical marker over the full height
    cv2.line(img, pt1=(x_val, 0), pt2=(x_val, 140), color=(255, 0, 0), thickness=2)
    cv2.imwrite(img_path, img)

def sendImg(img_path):
    KERAS_REST_API_URL = "http://127.0.0.1:7000/predict"
    image = open(img_path, "rb").read()
    payload = {"image": image}
    r = requests.post(KERAS_REST_API_URL, files=payload).json()
    if r["success"]:
        print(r)
        return r['predictions']

img_path = "testImg"
timeList = []
for testimg in os.listdir(img_path):
    img_p = os.path.join(img_path, testimg)
    start = time.time()
    pre = sendImg(img_p)
    cost_time = time.time()-start
    draw_line(img_path, testimg, int(pre*360))
    timeList.append(cost_time)
    print("{} predicted x: {} time: {}s".format(testimg, pre, cost_time))

print("Average time per image: {}s".format(sum(timeList)/len(timeList)))
--------------------------------------------------------------------------------
/predict_folder.py:
--------------------------------------------------------------------------------
import os
from model import brief_net
import numpy as np
import cv2

def draw_line(img_folder, img_name, x_val):
    img_path = os.path.join(img_folder, "pred"+img_name)
    img = cv2.imread(os.path.join(img_folder, img_name))
    print(x_val)
    # the images are 140 px tall, so draw the vertical marker over the full height
    cv2.line(img, pt1=(x_val, 0), pt2=(x_val, 140), color=(0, 255, 255), thickness=2)
    cv2.imwrite(img_path, img)

from PIL import Image

img_path = "testImg"

if __name__ == '__main__':
    # build the model for the 140x360 input and load the trained weights
    model = brief_net(input_shape=(140, 360, 3), output_shape=1)
    model.load_weights("trained_weights.h5")
    for testimg in os.listdir(img_path):
        img = Image.open(os.path.join(img_path, testimg))
        img = np.array(img) / 255.
        img = img[:,:,:3]
        print(img.shape)
        img = (np.expand_dims(img, 0))
        pre = model.predict(img)
        draw_line(img_path, testimg, int(pre*360))
    print("Done")
--------------------------------------------------------------------------------
/predict_one.py:
--------------------------------------------------------------------------------
import requests
import os
import cv2

def draw_line(img_folder, img_name, x_val):
    img_path = os.path.join(img_folder, "pred" + img_name)
    img = cv2.imread(os.path.join(img_folder, img_name))
    # the images are 140 px tall, so draw the vertical marker over the full height
    cv2.line(img, pt1=(x_val, 0), pt2=(x_val, 140), color=(255, 0, 0), thickness=2)
    print(img_path)
    cv2.imwrite(img_path, img)

def sendImg(img_path):
    KERAS_REST_API_URL = "http://127.0.0.1:7000/predict"
    image = open(img_path, "rb").read()
    payload = {"image": image}
    r = requests.post(KERAS_REST_API_URL, files=payload).json()
    if r["success"]:
        print(r)
        return r['predictions']

# folder holding the test images
test_img_folder = "testImg"

img_name = "0.png"
# build the image path
img_path = os.path.join(test_img_folder, img_name)
# the predicted value pre lies in 0-1: the normalized x-center of the slider
pre = sendImg(img_path)

# multiply by 360 because the original image is 140x360, i.e. 360 px wide
draw_line(test_img_folder, img_name, int(pre*360))
print("{} predicted x: {}".format(img_name, pre))
--------------------------------------------------------------------------------
/run_a_server.py:
--------------------------------------------------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
from PIL import Image
import sys
import flask
import io
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from model import brief_net
from tensorflow.python.keras.backend import set_session
from gevent import pywsgi

sess = tf.Session()
graph = tf.get_default_graph()
app = flask.Flask(__name__)

def load_model():
    global model
    model = brief_net(input_shape=(140, 360, 3), output_shape=1)
    set_session(sess)
    try:
        model.load_weights("trained_weights.h5")
    except Exception as e:
        print(e)
        sys.exit(1)  # nonzero exit code: loading the weights failed
    global graph
    graph = tf.get_default_graph()

def prepare_image(image):
    image = np.array(image) / 255.
    image = image[:,:,:3]
    image = (np.expand_dims(image, axis=0))
    return image

@app.route("/predict", methods=["POST"])
def predict():
    global sess, graph
    data = {"success": False}
    if flask.request.method == "POST":
        if flask.request.files.get("image"):
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image))
            image = prepare_image(image)
            with graph.as_default():
                set_session(sess)
                preds = model.predict(image)
            data["predictions"] = round(preds.tolist()[0][0], 3)
            data["success"] = True
    return flask.jsonify(data)

if __name__ == "__main__":
    port = 7000
    try:
        port = int(sys.argv[1])
    except (IndexError, ValueError):
        pass  # no valid port argument given; keep the default
    print("* Loading Keras model and starting Flask server...")
    load_model()
    print("* Model loaded successfully!")
    server = pywsgi.WSGIServer(('0.0.0.0', port), app)
    print("* Listening on http://0.0.0.0:{}/predict".format(port))
    server.serve_forever()
--------------------------------------------------------------------------------
/testImg/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/0.png
--------------------------------------------------------------------------------
/testImg/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/1.png
--------------------------------------------------------------------------------
/testImg/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/2.png
--------------------------------------------------------------------------------
/testImg/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/3.png
--------------------------------------------------------------------------------
/testImg/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/4.png
--------------------------------------------------------------------------------
/testImg/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/5.png
--------------------------------------------------------------------------------
/testImg/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/6.png
--------------------------------------------------------------------------------
/testImg/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/7.png
--------------------------------------------------------------------------------
/testImg/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/8.png
--------------------------------------------------------------------------------
/testImg/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/testImg/9.png
--------------------------------------------------------------------------------
/trained_weights.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brief-rf/jd_joy_verification/abebab4e64fb8f7c6d21ec872185bf3ad9d4d6b2/trained_weights.h5
--------------------------------------------------------------------------------
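The preprocessing that run_a_server.py and predict_folder.py apply before inference, and the prediction-to-pixel mapping used by the predict scripts, can be exercised without the model or the weights file. A minimal sketch; the synthetic RGBA array below is a stand-in for a real 140x360 captcha screenshot, not an actual test image:

```python
import numpy as np

def prepare_image(image):
    # Mirrors the repo's preprocessing: scale pixel values to [0, 1],
    # keep only the first three (RGB) channels, add a batch dimension.
    image = np.array(image) / 255.
    image = image[:, :, :3]
    return np.expand_dims(image, axis=0)

# Synthetic 140x360 RGBA image (all white) standing in for a real screenshot.
fake = np.full((140, 360, 4), 255, dtype=np.uint8)
batch = prepare_image(fake)
print(batch.shape)  # (1, 140, 360, 3) -- what brief_net(input_shape=(140, 360, 3)) expects

# A sigmoid prediction p in [0, 1] maps to a pixel column via int(p * 360),
# since 360 is the image width; e.g. the README's sample 0.553 lands near column 199.
print(int(0.553 * 360))  # 199
```

The alpha-channel slice matters because PNG screenshots are often RGBA; without `[:, :, :3]` the batch would not match the model's 3-channel input.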