├── .gitignore
├── 4.jpg
├── README.md
├── api
│   ├── .dockerignore
│   ├── Dockerfile
│   ├── docker-entrypoint.sh
│   ├── requirements.txt
│   ├── src
│   │   ├── api_server.py
│   │   ├── config.py
│   │   ├── utils.py
│   │   └── worker.py
│   ├── trained_model
│   │   ├── trained_model.h5
│   │   └── trained_model.json
│   └── uwsgi.ini
├── docker-compose.yml
├── nginx
│   ├── Dockerfile
│   ├── nginx.conf
│   └── uwsgi_params
└── redis
    └── Dockerfile

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# vscode
.vscode/
--------------------------------------------------------------------------------
/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmed-mez/keras-rest-API/6c3a2c05f14c564f633927002d330051d6899d03/4.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Learning API using Flask, Keras, Redis, nginx and Docker
A scalable Flask API to interact with a pre-trained Keras model.

### Overview
The API uses Redis to queue incoming requests, batches them, and feeds them to the model for prediction; it then answers each client with a JSON payload containing the result of its request (the classes with the highest probabilities).
To support heavy load and to avoid the multiple-model syndrome, model loading and prediction run in a process separate from the one that receives requests and sends responses.

Please note that the model used in this project (a simple digit-recognition OCR model) is just an example; the main purpose of the project is the development and deployment of the API.
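
In a nutshell, the round trip between the two processes looks like the sketch below. This is a simplified illustration only, not the project's code — the real implementation lives in `api/src/api_server.py` and `api/src/worker.py` — and it assumes a Redis server reachable on `localhost:6379`; the names mirror `api/src/config.py`.

```
import json
import time
import uuid

from redis import StrictRedis

db = StrictRedis(host="localhost", port=6379, db=0)
IMAGE_QUEUE = "image_queue"


def enqueue_and_wait(payload):
    # web process: push a job onto the queue, then poll until the worker stores a result
    job_id = str(uuid.uuid4())
    db.rpush(IMAGE_QUEUE, json.dumps({"im_id": job_id, "image": payload}))
    while True:
        result = db.get(job_id)
        if result is not None:
            db.delete(job_id)
            return json.loads(result)
        time.sleep(0.25)


def worker_step():
    # worker process: grab a batch of jobs, "predict", store each result under its id
    jobs = db.lrange(IMAGE_QUEUE, 0, 31)  # BATCH_SIZE = 32
    for raw in jobs:
        job = json.loads(raw)
        db.set(job["im_id"], json.dumps({"echo": job["image"]}))  # stand-in for model.predict
    db.ltrim(IMAGE_QUEUE, len(jobs), -1)  # drop the processed entries
```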

### Deployment
The setup consists of 3 containers:
1. Flask app with uWSGI
2. Redis
3. nginx

The three containers are based on the official Docker images of python:2.7.15, Redis and nginx, respectively.

#### Setup
We just need to build the images and run them with `docker-compose`:

```
$ docker-compose build
$ docker-compose up
```

Heading over to `http://localhost:8080` should then show the following message in the browser:
`Welcome to the digits OCR Keras REST API`

#### Example
We can submit `POST` requests to the API on the `/predict` endpoint:

`$ curl -X POST -F image=@4.jpg 'http://localhost:8080/predict'`

The API responds with:

```
{
  "predictions": [
    {
      "label": "4",
      "probability": 0.9986100196838379
    },
    {
      "label": "9",
      "probability": 0.001174862147308886
    },
    {
      "label": "7",
      "probability": 0.00009050272637978196
    }
  ],
  "success": true
}
```
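
The same request can be sent from Python with the `requests` library (a client-side sketch; `requests` is not part of `api/requirements.txt` and would have to be installed separately):

```
import requests

# assumes the stack is up and nginx is listening on localhost:8080
with open("4.jpg", "rb") as f:
    response = requests.post("http://localhost:8080/predict", files={"image": f})

print(response.json())
# e.g. {"predictions": [{"label": "4", "probability": 0.998...}, ...], "success": True}
```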
### TODO
- [ ] Tests

--------------------------------------------------------------------------------
/api/.dockerignore:
--------------------------------------------------------------------------------
venv
**/*.pyc
--------------------------------------------------------------------------------
/api/Dockerfile:
--------------------------------------------------------------------------------
FROM python:2.7.15

# update packages
RUN apt-get update -y

# set workdir
WORKDIR /api

# pip install
RUN pip install uwsgi
COPY requirements.txt /api
RUN pip install -r requirements.txt
RUN rm requirements.txt

# deep learning model architecture and weights
COPY trained_model /api/trained_model

# uwsgi conf
COPY uwsgi.ini /etc/uwsgi/apps-available/api.ini

# code
COPY src /api/src

# logs folder
RUN mkdir logs

# entrypoint
COPY docker-entrypoint.sh /
RUN chmod +x /docker-entrypoint.sh
ENTRYPOINT ["/docker-entrypoint.sh"]
CMD ["run"]
--------------------------------------------------------------------------------
/api/docker-entrypoint.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

if [ "$1" = 'run' ]; then
    # start the prediction worker in the background, then run the API under uWSGI
    python /api/src/worker.py &
    uwsgi --ini /etc/uwsgi/apps-available/api.ini
else
    exec "$@"
fi
--------------------------------------------------------------------------------
/api/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.2.0
astor==0.6.2
backports.weakref==1.0.post1
bleach==1.5.0
click==6.7
enum34==1.1.6
Flask==1.0.2
funcsigs==1.0.2
futures==3.2.0
gast==0.2.0
grpcio==1.11.0
h5py==2.7.1
html5lib==0.9999999
itsdangerous==0.24
Jinja2==2.10
Keras==2.1.6
Markdown==2.6.11
MarkupSafe==1.0
mock==2.0.0
numpy==1.14.3
pbr==4.0.2
Pillow==5.1.0
protobuf==3.5.2.post1
PyYAML==3.12
redis==2.10.6
scipy==1.1.0
six==1.11.0
tensorboard==1.8.0
tensorflow==1.8.0
termcolor==1.1.0
Werkzeug==0.14.1
--------------------------------------------------------------------------------
/api/src/api_server.py:
--------------------------------------------------------------------------------
from config import (REDIS_HOST, REDIS_PORT, REDIS_DB,
                    IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_QUEUE,
                    CONSUMER_SLEEP, LOG_DIR)
from flask import Flask, request, jsonify
from utils import b64_encoding, prepare_image
from redis import StrictRedis
from uuid import uuid4
import logging
import json
import time

app = Flask(__name__)
db = StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB)
logging.basicConfig(filename=LOG_DIR + "/api.log", level=logging.INFO)


@app.route("/")
def index():
    return "Welcome to the digits OCR Keras REST API"


@app.route("/predict", methods=["POST"])
def predict():
    data = {"success": False}
    if request.method == "POST":
        if request.files.get("image"):
            image = request.files["image"].read()
            try:
                image = prepare_image(image)
            except Exception:
                logging.exception("Error in preparing image")
                return jsonify(data), 400
            im_id = str(uuid4())
            im_dict = {"im_id": im_id, "image": b64_encoding(image)}
            # send the image to the redis queue
            db.rpush(IMAGE_QUEUE, json.dumps(im_dict))
            logging.info("New image added to queue with id: %s", im_id)
            while True:
                # poll until the worker has stored a result for this id
                output = db.get(im_id)
                if output is not None:
                    # image processed, try to get the predictions
                    try:
                        output = output.decode("utf-8")
                        data["predictions"] = json.loads(output)
                        logging.info("Got predictions for image with id: %s", im_id)
                    except Exception:
                        logging.exception("Error in getting predictions for image with id: %s", im_id)
                        return jsonify(data), 400
                    finally:
                        db.delete(im_id)
                        logging.info("Deleting image from queue with id: %s", im_id)
                    break
                time.sleep(CONSUMER_SLEEP)
            data["success"] = True
            logging.info("Sent result for image with id: %s", im_id)
            return jsonify(data), 200
    logging.warning("Invalid request with files %s", request.files)
    return jsonify(data), 400


if __name__ == "__main__":
    app.run(host='0.0.0.0', debug=True)
--------------------------------------------------------------------------------
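
A possible starting point for the tests listed in the README's TODO is Flask's built-in test client. The sketch below is only a suggestion and not part of the repository (the file name `test_api_server.py` is hypothetical); importing `api_server` opens `/api/logs/api.log` and creates the Redis client, so it is meant to run where those paths and hosts exist, e.g. inside the `api` container.

```
# test_api_server.py (hypothetical) -- minimal sketch only
import unittest

from api_server import app


class ApiServerTestCase(unittest.TestCase):
    def setUp(self):
        self.client = app.test_client()

    def test_index(self):
        response = self.client.get("/")
        self.assertEqual(response.status_code, 200)
        self.assertIn("Welcome to the digits OCR Keras REST API",
                      response.data.decode("utf-8"))

    def test_predict_without_image(self):
        # a POST without an "image" file should be rejected with a 400
        response = self.client.post("/predict")
        self.assertEqual(response.status_code, 400)


if __name__ == "__main__":
    unittest.main()
```
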
/api/src/config.py:
--------------------------------------------------------------------------------
# initialize image dimensions
IMAGE_WIDTH = 28
IMAGE_HEIGHT = 28
IMAGE_SHAPE = IMAGE_WIDTH * IMAGE_HEIGHT

# initialize constants used for server queuing
IMAGE_QUEUE = "image_queue"
BATCH_SIZE = 32
WORKER_SLEEP = 0.25
CONSUMER_SLEEP = 0.25

# initialize Redis connection settings
REDIS_HOST = "redis"
REDIS_PORT = 6379
REDIS_DB = 0

# weights files
WEIGHTS_JSON = "/api/trained_model/trained_model.json"
WEIGHTS_H5 = "/api/trained_model/trained_model.h5"

# logging location
LOG_DIR = "/api/logs"
--------------------------------------------------------------------------------
/api/src/utils.py:
--------------------------------------------------------------------------------
from numpy import array, frombuffer, newaxis
from base64 import b64encode, b64decode
from PIL import Image, ImageFilter
from io import BytesIO
from config import IMAGE_WIDTH, IMAGE_HEIGHT


def b64_encoding(arr):
    return b64encode(arr).decode("utf-8")


def b64_decoding(enc_array, shape):
    return frombuffer(b64decode(enc_array)).reshape(shape)[newaxis]


def prepare_image(image, target_width=IMAGE_WIDTH, target_height=IMAGE_HEIGHT):
    """ Prepare an image to be processed

    Convert the image to grayscale and adapt its size so it can be
    processed by the OCR model.

    Arguments:
        image: file -- image to be prepared and processed by the api

    Returns:
        Numpy array
    """
    im = Image.open(BytesIO(image)).convert('L')
    width = float(im.size[0])
    height = float(im.size[1])
    # blank white canvas the resized digit is pasted onto
    newImage = Image.new('L', (target_width, target_height), (255))
    if width > height:
        # wider than tall: fit the width to 20 pixels and scale the height
        nheight = int(round((20.0 / width * height), 0))
        if nheight == 0:
            nheight = 1
        img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
        wtop = int(round(((target_height - nheight) / 2), 0))
        newImage.paste(img, (4, wtop))
    else:
        # taller than wide: fit the height to 20 pixels and scale the width
        nwidth = int(round((20.0 / height * width), 0))
        if nwidth == 0:
            nwidth = 1
        img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
        wleft = int(round(((target_width - nwidth) / 2), 0))
        newImage.paste(img, (wleft, 4))
    # invert and normalize pixel values to [0, 1]
    tv = list(newImage.getdata())
    tva = [(255 - x) * 1.0 / 255.0 for x in tv]
    return array(tva)[newaxis].copy(order="C")
--------------------------------------------------------------------------------
/api/src/worker.py:
--------------------------------------------------------------------------------
from config import (IMAGE_SHAPE, IMAGE_QUEUE, BATCH_SIZE,
                    WORKER_SLEEP, REDIS_HOST, REDIS_PORT, REDIS_DB,
                    WEIGHTS_JSON, WEIGHTS_H5, LOG_DIR)
from numpy import argpartition, argsort, vstack
from keras.models import model_from_json
from utils import b64_decoding
from redis import StrictRedis
import logging
import json
import time

db = StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB)
logging.basicConfig(filename=LOG_DIR + "/worker.log", level=logging.INFO)


def load_model():
    """ Load the Keras model in memory
    """
    with open(WEIGHTS_JSON, 'r') as json_file:
        loaded_model_json = json_file.read()
    model = model_from_json(loaded_model_json)
    model.load_weights(WEIGHTS_H5)
    return model


def decode_predictions(predictions, top=3):
    """ Interpret the predictions

    Arguments:
        predictions [[float]] -- predictions made by the model
    Yields:
        [(int, float)] -- the `top` (label index, probability) pairs of each
        prediction, sorted by decreasing probability
    """
    for pred in predictions:
        indexes = argpartition(pred, -top)[-top:]
        indexes = indexes[argsort(-pred[indexes])]
        preds = list()
        for i in xrange(top):  # xrange: the project targets Python 2 (python:2.7.15 base image)
            preds.append((indexes[i], pred[indexes[i]]))
        yield preds


def predict_process(target_shape=IMAGE_SHAPE):
    """ Worker process: load the model and poll Redis for images
    """
    model = load_model()
    assert model is not None
    logging.info("Model loaded successfully, start polling for images")
    while True:
        # start polling
        queue = db.lrange(IMAGE_QUEUE, 0, BATCH_SIZE - 1)
        image_IDs = []
        batch = None
        for q in queue:
            q = json.loads(q.decode("utf-8"))
            image = b64_decoding(q["image"], (target_shape,))
            if batch is None:
                batch = image
            else:
                batch = vstack([batch, image])
            image_IDs.append(q["im_id"])
        if len(image_IDs):
            # the queue contains images to be processed
            try:
                preds = model.predict(batch)
                results = decode_predictions(preds)
                logging.info("Batch predicted successfully with images ids: %s", image_IDs)
            except Exception:
                logging.exception("Error in prediction, batch with images ids: %s", image_IDs)
                raise
            for (image_id, result_set) in zip(image_IDs, results):
                output = []
                for (label, prob) in result_set:
                    res = {"label": str(label), "probability": float(prob)}
                    output.append(res)
                logging.info("Setting predictions in queue for image with id: %s", image_id)
                db.set(image_id, json.dumps(output))
            # remove the processed entries from the queue
            db.ltrim(IMAGE_QUEUE, len(image_IDs), -1)
        time.sleep(WORKER_SLEEP)


if __name__ == "__main__":
    logging.info("Starting process..")
    predict_process()
--------------------------------------------------------------------------------
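
For reference, the top-k selection used in `decode_predictions` above can be illustrated in isolation — `argpartition` finds the `top` largest entries without fully sorting, and `argsort` then orders just those entries. A standalone snippet, independent of the model:

```
from numpy import argpartition, argsort, array

pred = array([0.01, 0.02, 0.9, 0.002, 0.05, 0.003, 0.001, 0.004, 0.006, 0.004])
top = 3

indexes = argpartition(pred, -top)[-top:]   # indices of the 3 largest values, unordered
indexes = indexes[argsort(-pred[indexes])]  # order them by decreasing probability
print([(int(i), float(pred[i])) for i in indexes])  # [(2, 0.9), (4, 0.05), (1, 0.02)]
```
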
/api/trained_model/trained_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmed-mez/keras-rest-API/6c3a2c05f14c564f633927002d330051d6899d03/api/trained_model/trained_model.h5
--------------------------------------------------------------------------------
/api/trained_model/trained_model.json:
--------------------------------------------------------------------------------
{"class_name": "Sequential", "keras_version": "2.1.5", "config": [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 512, "batch_input_shape": [null, 784], "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dropout", "config": {"rate": 0.2, "noise_shape": null, "trainable": true, "seed": null, "name": "dropout_1"}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 512, "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dropout", "config": {"rate": 0.2, "noise_shape": null, "trainable": true, "seed": null, "name": "dropout_2"}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "softmax", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "use_bias": true, "activity_regularizer": null}}], "backend": "tensorflow"}
--------------------------------------------------------------------------------
/api/uwsgi.ini:
--------------------------------------------------------------------------------
[uwsgi]
chdir = /api/src
module = api_server:app
socket = 0.0.0.0:5000
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
api:
  build: api
  links:
    - redis

nginx:
  build: nginx
  ports:
    - "8080:80"
  links:
    - api

redis:
  build: redis
--------------------------------------------------------------------------------
/nginx/Dockerfile:
--------------------------------------------------------------------------------
FROM nginx
COPY uwsgi_params /etc/nginx/conf.d/uwsgi_params
COPY nginx.conf /etc/nginx/conf.d/default.conf
CMD nginx -c /etc/nginx/conf.d/default.conf
--------------------------------------------------------------------------------
/nginx/nginx.conf:
--------------------------------------------------------------------------------
daemon off;

events {
    worker_connections 1024;
}

http {
    server {
        listen 80;
        server_name localhost;

        location / {
            include uwsgi_params;
            uwsgi_pass api:5000;
        }
    }
}
--------------------------------------------------------------------------------
/nginx/uwsgi_params:
--------------------------------------------------------------------------------
uwsgi_param QUERY_STRING $query_string;
uwsgi_param REQUEST_METHOD $request_method;
uwsgi_param CONTENT_TYPE $content_type;
uwsgi_param CONTENT_LENGTH $content_length;
uwsgi_param REQUEST_URI $request_uri;
uwsgi_param PATH_INFO $document_uri;
uwsgi_param DOCUMENT_ROOT $document_root;
uwsgi_param SERVER_PROTOCOL $server_protocol;
uwsgi_param REMOTE_ADDR $remote_addr;
uwsgi_param REMOTE_PORT $remote_port;
uwsgi_param SERVER_ADDR $server_addr;
uwsgi_param SERVER_PORT $server_port;
uwsgi_param SERVER_NAME $server_name;
--------------------------------------------------------------------------------
/redis/Dockerfile:
--------------------------------------------------------------------------------
FROM redis
CMD redis-server --bind 0.0.0.0
--------------------------------------------------------------------------------