├── api ├── __init__.py ├── models.py ├── Dockerfile └── app.py ├── .gitignore ├── img ├── doc.gif ├── schema.jpg └── webapp.gif ├── webapp ├── Dockerfile ├── app.py └── templates │ ├── static │ ├── styles.css │ └── app.js │ └── index.html ├── requirements.txt ├── celery_tasks ├── app_worker.py ├── tasks.py └── yolo.py ├── docker-compose.yml └── README.md /api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | __pycache__ 3 | .DS_Store 4 | # Runtime uploads and YOLO result images written by the API/worker (the repo is bind-mounted into the worker container) 5 | api/uploads/ 6 | api/static/results/ -------------------------------------------------------------------------------- /img/doc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GoldVelvet/YOLOv5-fastapi-celery-redis-rabbitmq/HEAD/img/doc.gif -------------------------------------------------------------------------------- /img/schema.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GoldVelvet/YOLOv5-fastapi-celery-redis-rabbitmq/HEAD/img/schema.jpg -------------------------------------------------------------------------------- /img/webapp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GoldVelvet/YOLOv5-fastapi-celery-redis-rabbitmq/HEAD/img/webapp.gif -------------------------------------------------------------------------------- /webapp/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7 2 | RUN pip install -U pip aiofiles python-multipart jinja2 3 | COPY /webapp /app -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | celery==4.4.7 2 | fastapi 3 | matplotlib 4 | opencv-python 5 | aiofiles 6 | python-multipart 7 | pyyaml 8 | seaborn 9 | uvicorn 10 | redis -------------------------------------------------------------------------------- /api/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Task(BaseModel): 5 | task_id: str 6 | status: str 7 | 8 | 9 | class Prediction(Task): 10 | task_id: str 11 | status: str 12 | result: str 13 | -------------------------------------------------------------------------------- /celery_tasks/app_worker.py: -------------------------------------------------------------------------------- 1 | import os 2 | from celery import Celery 3 | 4 | BROKER_URI = 'amqp://rabbitmq' 5 | BACKEND_URI = 'redis://redis' 6 | 7 | app = Celery( 8 | 'celery_tasks', 9 | broker=BROKER_URI, 10 | backend=BACKEND_URI, 11 | include=['celery_tasks.tasks'] 12 | ) 13 | -------------------------------------------------------------------------------- /api/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | WORKDIR /app 3 | RUN apt-get update 4 | RUN apt-get install ffmpeg libsm6 libxext6 -y 5 | COPY requirements.txt requirements.txt 6 | RUN pip install -U pip 7 | RUN pip install -r requirements.txt 8 | COPY . . 
9 | RUN rm -rf /app/api/uploads/* 10 | RUN rm -rf /app/api/static/* -------------------------------------------------------------------------------- /webapp/app.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request 2 | from fastapi.responses import HTMLResponse 3 | from fastapi.staticfiles import StaticFiles 4 | from fastapi.templating import Jinja2Templates 5 | 6 | app = FastAPI() 7 | app.mount("/static", StaticFiles(directory="templates/static"), name="static") 8 | 9 | 10 | templates = Jinja2Templates(directory="templates") 11 | 12 | 13 | @app.get("/", response_class=HTMLResponse) 14 | async def index(request: Request): 15 | return templates.TemplateResponse("index.html", context={'request': request}) 16 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | rabbitmq: 3 | container_name: rabbitmq 4 | image: rabbitmq:3-management 5 | ports: 6 | - 5672:5672 7 | - 15672:15672 8 | 9 | redis: 10 | container_name: redis 11 | image: redis 12 | ports: 13 | - "6379:6379" 14 | 15 | worker: 16 | container_name: worker 17 | build: 18 | dockerfile: api/Dockerfile 19 | context: . 20 | command: sh -c "(cd api && uvicorn app:app --host 0.0.0.0 --port 8000) & exec celery -A celery_tasks.app_worker worker -l INFO --pool=solo" # API runs in a background subshell so its cd stays local; exec hands the shell's process to celery so it receives stop signals directly 21 | volumes: 22 | - .:/app 23 | ports: 24 | - "8000:8000" 25 | 26 | webapp: 27 | container_name: webapp 28 | build: 29 | dockerfile: webapp/Dockerfile 30 | context: . 
31 | command: sh -c "cd /app && uvicorn app:app --host 0.0.0.0 --port 80 --reload" 32 | ports: 33 | - "80:80" -------------------------------------------------------------------------------- /celery_tasks/tasks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from celery import Task 3 | from celery.exceptions import MaxRetriesExceededError 4 | from .app_worker import app 5 | from .yolo import YoloModel 6 | 7 | 8 | class PredictTask(Task): 9 | def __init__(self): 10 | super().__init__() 11 | self.model = None 12 | 13 | def __call__(self, *args, **kwargs): 14 | if not self.model: 15 | logging.info('Loading Model...') 16 | self.model = YoloModel() 17 | logging.info('Model loaded') 18 | return self.run(*args, **kwargs) 19 | 20 | 21 | @app.task(ignore_result=False, bind=True, base=PredictTask) 22 | def predict_image(self, data): 23 | try: 24 | data_pred = self.model.predict(data) 25 | return {'status': 'SUCCESS', 'result': data_pred} 26 | except Exception as ex: 27 | try: 28 | self.retry(countdown=2) 29 | except MaxRetriesExceededError as ex: 30 | return {'status': 'FAIL', 'result': 'max retried achieved'} 31 | -------------------------------------------------------------------------------- /webapp/templates/static/styles.css: -------------------------------------------------------------------------------- 1 | h2 { 2 | font-size: 30px; 3 | text-align: center; 4 | width: 100%; 5 | background: #1462dc; 6 | padding: 13px; 7 | line-height: 30px; 8 | color: white; 9 | } 10 | 11 | h4 { 12 | margin-top: 20px; 13 | text-align: center; 14 | } 15 | 16 | div.preview, 17 | .centermask { 18 | float: left; 19 | width: 33.3%; 20 | height: auto; 21 | text-align: center; 22 | margin-bottom: 20px; 23 | } 24 | 25 | .centermask { 26 | width: 100%; 27 | } 28 | 29 | .img, 30 | .img-centermask { 31 | height: auto; 32 | max-width: 300px; 33 | background-size: contain; 34 | } 35 | 36 | .img-centermask { 37 | max-width: 800px; 38 | } 39 
| 40 | span.class-name { 41 | display: block; 42 | width: 100%; 43 | text-align: center; 44 | } 45 | 46 | .header_label { 47 | margin-top: 10px; 48 | display: none; 49 | text-align: center; 50 | width: 100%; 51 | background: #96c0ff; 52 | } 53 | 54 | .col-sm-10 { 55 | text-align: center !important; 56 | } 57 | 58 | th, 59 | td { 60 | padding: 10px; 61 | } 62 | 63 | textarea { 64 | border: 0px; 65 | } 66 | 67 | #row_detail { 68 | margin-top: 20px; 69 | } -------------------------------------------------------------------------------- /celery_tasks/yolo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | class YoloModel: 7 | def __init__(self): 8 | self.model = torch.hub.load('ultralytics/yolov5', 'yolov5x', pretrained=True) 9 | self.model.eval() 10 | 11 | def predict(self, img): 12 | try: 13 | with torch.no_grad(): 14 | result = self.model(img) 15 | result.save('api/static/results/') 16 | final_result = {} 17 | data = [] 18 | file_name = f'static/{result.files[0]}' 19 | 20 | for i in range(len(result.xywhn[0])): 21 | x, y, w, h, prob, cls = result.xywhn[0][i].numpy() 22 | preds = {} 23 | preds['x'] = str(x) 24 | preds['y'] = str(y) 25 | preds['w'] = str(w) 26 | preds['h'] = str(h) 27 | preds['prob'] = str(prob) 28 | preds['class'] = result.names[int(cls)] 29 | data.append(preds) 30 | 31 | return {'file_name': file_name, 'bbox': data} 32 | except Exception as ex: 33 | logging.error(str(ex)) 34 | return None 35 | -------------------------------------------------------------------------------- /webapp/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Yolo v5 5 | 6 | 7 | 8 | 9 | 10 | 12 | 13 | 14 | 15 | 16 |

Yolo v5 - Object Detection

17 |

Demo application showing how to consume the API which performs object detection using Yolo v5.

18 |
19 |
20 |
21 |
22 | 23 | 24 |
25 |
26 | 27 | 28 |
29 |
30 |
31 |
32 |
33 |
34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 |
Task IdStatusAction
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 | 50 |
51 |
52 | 53 | 54 | -------------------------------------------------------------------------------- /api/app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.realpath(os.path.pardir)) 4 | from fastapi import FastAPI, File, UploadFile 5 | from fastapi.staticfiles import StaticFiles 6 | from fastapi.middleware.cors import CORSMiddleware 7 | from fastapi.responses import JSONResponse 8 | from celery_tasks.tasks import predict_image 9 | from celery.result import AsyncResult 10 | from models import Task, Prediction 11 | import uuid 12 | import logging 13 | from pydantic.typing import List 14 | import numpy as np 15 | 16 | UPLOAD_FOLDER = 'uploads' 17 | STATIC_FOLDER = 'static/results' 18 | 19 | isdir = os.path.isdir(UPLOAD_FOLDER) 20 | if not isdir: 21 | os.makedirs(UPLOAD_FOLDER) 22 | 23 | isdir = os.path.isdir(STATIC_FOLDER) 24 | if not isdir: 25 | os.makedirs(STATIC_FOLDER) 26 | 27 | origins = [ 28 | "http://localhost", 29 | "http://localhost:8080", 30 | ] 31 | 32 | app = FastAPI() 33 | app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static") 34 | app.add_middleware( 35 | CORSMiddleware, 36 | allow_origins=origins, 37 | allow_credentials=True, 38 | allow_methods=["*"], 39 | allow_headers=["*"], 40 | ) 41 | 42 | 43 | @app.post('/api/process') 44 | async def process(files: List[UploadFile] = File(...)): 45 | tasks = [] 46 | try: 47 | for file in files: 48 | d = {} 49 | try: 50 | name = str(uuid.uuid4()).split('-')[0] 51 | ext = file.filename.split('.')[-1] 52 | file_name = f'{UPLOAD_FOLDER}/{name}.{ext}' 53 | with open(file_name, 'wb+') as f: 54 | f.write(file.file.read()) 55 | f.close() 56 | 57 | # start task prediction 58 | task_id = predict_image.delay(os.path.join('api', file_name)) 59 | d['task_id'] = str(task_id) 60 | d['status'] = 'PROCESSING' 61 | d['url_result'] = f'/api/result/{task_id}' 62 | except Exception as ex: 63 | logging.info(ex) 64 | 
d['task_id'] = str(task_id) 65 | d['status'] = 'ERROR' 66 | d['url_result'] = '' 67 | tasks.append(d) 68 | return JSONResponse(status_code=202, content=tasks) 69 | except Exception as ex: 70 | logging.info(ex) 71 | return JSONResponse(status_code=400, content=[]) 72 | 73 | 74 | @app.get('/api/result/{task_id}', response_model=Prediction) 75 | async def result(task_id: str): 76 | task = AsyncResult(task_id) 77 | 78 | # Task Not Ready 79 | if not task.ready(): 80 | return JSONResponse(status_code=202, content={'task_id': str(task_id), 'status': task.status, 'result': ''}) 81 | 82 | # Task done: return the value 83 | task_result = task.get() 84 | result = task_result.get('result') 85 | return JSONResponse(status_code=200, content={'task_id': str(task_id), 'status': task_result.get('status'), 'result': result}) 86 | 87 | 88 | @app.get('/api/status/{task_id}', response_model=Prediction) 89 | async def status(task_id: str): 90 | task = AsyncResult(task_id) 91 | return JSONResponse(status_code=200, content={'task_id': str(task_id), 'status': task.status, 'result': ''}) 92 | -------------------------------------------------------------------------------- /webapp/templates/static/app.js: -------------------------------------------------------------------------------- 1 | var URL = 'http://localhost:8000' 2 | var URL_STATUS = 'http://localhost:8000/api/status/' 3 | var results = [] 4 | var status_list = [] 5 | var res = '' 6 | jQuery(document).ready(function () { 7 | $('#row_detail').hide() 8 | $("#row_results").hide(); 9 | $('#btn-process').on('click', function () { 10 | var form_data = new FormData(); 11 | files = $('#input_file').prop('files') 12 | for (i = 0; i < files.length; i++) 13 | form_data.append('files', $('#input_file').prop('files')[i]); 14 | 15 | $.ajax({ 16 | url: URL + '/api/process', 17 | type: "post", 18 | data: form_data, 19 | enctype: 'multipart/form-data', 20 | contentType: false, 21 | processData: false, 22 | cache: false, 23 | beforeSend: function () 
{ 24 | results = [] 25 | status_list = [] 26 | $("#table_result > tbody").html(''); 27 | $('#row_detail').hide(); 28 | $("#row_results").hide(); 29 | }, 30 | }).done(function (jsondata, textStatus, jqXHR) { 31 | for (i = 0; i < jsondata.length; i++) { 32 | task_id = jsondata[i]['task_id'] 33 | status = jsondata[i]['status'] 34 | results.push(URL + jsondata[i]['url_result']) 35 | status_list.push(task_id) 36 | result_button = `