├── api
│   ├── __init__.py
│   ├── models.py
│   ├── Dockerfile
│   └── app.py
├── .gitignore
├── img
│   ├── doc.gif
│   ├── schema.jpg
│   └── webapp.gif
├── webapp
│   ├── Dockerfile
│   ├── app.py
│   └── templates
│       ├── static
│       │   ├── styles.css
│       │   └── app.js
│       └── index.html
├── requirements.txt
├── celery_tasks
│   ├── app_worker.py
│   ├── tasks.py
│   └── yolo.py
├── docker-compose.yml
└── README.md
/api/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | __pycache__
3 | .DS_Store
--------------------------------------------------------------------------------
/img/doc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GoldVelvet/YOLOv5-fastapi-celery-redis-rabbitmq/HEAD/img/doc.gif
--------------------------------------------------------------------------------
/img/schema.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GoldVelvet/YOLOv5-fastapi-celery-redis-rabbitmq/HEAD/img/schema.jpg
--------------------------------------------------------------------------------
/img/webapp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GoldVelvet/YOLOv5-fastapi-celery-redis-rabbitmq/HEAD/img/webapp.gif
--------------------------------------------------------------------------------
/webapp/Dockerfile:
--------------------------------------------------------------------------------
# Demo front-end image; the base image presumably (per its name) already
# provides FastAPI and uvicorn.
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7

# Extra runtime packages: async file serving, multipart form parsing and
# Jinja2 templating.
RUN pip install -U \
        pip \
        aiofiles \
        python-multipart \
        jinja2

# docker-compose builds with the repository root as context, so the webapp
# sources are copied from /webapp into the image's /app directory.
COPY /webapp /app
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | celery==4.4.7
2 | fastapi
3 | matplotlib
4 | opencv-python
5 | aiofiles
6 | python-multipart
7 | pyyaml
8 | seaborn
9 | uvicorn
10 | redis
--------------------------------------------------------------------------------
/api/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
class Task(BaseModel):
    """Response body describing an asynchronous prediction task."""

    # Celery task id; the client polls /api/result/{task_id} or
    # /api/status/{task_id} with it.
    task_id: str
    # Either a Celery state string or an app-level marker such as 'PROCESSING'.
    status: str
7 |
8 |
class Prediction(Task):
    """Task response that also carries the task's output.

    ``task_id`` and ``status`` are inherited from ``Task``; only the payload
    field is declared here.
    """

    # '' while the task is pending. On success the /api/result route puts the
    # prediction dict here ({'file_name': ..., 'bbox': [...]}), and on failure
    # an error message.
    # NOTE(review): ``str`` under-describes that dict in the OpenAPI schema;
    # consider widening the annotation.
    result: str
13 |
--------------------------------------------------------------------------------
/celery_tasks/app_worker.py:
--------------------------------------------------------------------------------
1 | import os
2 | from celery import Celery
3 |
# The defaults are the docker-compose service hostnames. Both URIs can be
# overridden through the environment, e.g. when running outside compose.
BROKER_URI = os.environ.get('BROKER_URI', 'amqp://rabbitmq')
BACKEND_URI = os.environ.get('BACKEND_URI', 'redis://redis')

# Celery application used both by the worker and by the API, which enqueues
# tasks through celery_tasks.tasks. ``include`` makes the worker import the
# task module so that its tasks get registered.
app = Celery(
    'celery_tasks',
    broker=BROKER_URI,
    backend=BACKEND_URI,
    include=['celery_tasks.tasks'],
)
13 |
--------------------------------------------------------------------------------
/api/Dockerfile:
--------------------------------------------------------------------------------
FROM pytorch/pytorch:latest
WORKDIR /app

# Refresh the package index and install in the same layer, so that a cached
# (stale) `apt-get update` layer is never paired with a newer install line.
# The system libraries are presumably required by opencv-python / ffmpeg.
RUN apt-get update \
    && apt-get install -y ffmpeg libsm6 libxext6 \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies before copying the sources, so this layer stays
# cached when only application code changes.
COPY requirements.txt requirements.txt
RUN pip install -U pip \
    && pip install -r requirements.txt

COPY . .

# Drop any uploads or results copied in from the build context. `rm -rf`
# already succeeds when nothing matches, so no `; exit 0` workaround is needed.
RUN rm -rf /app/api/uploads/* /app/api/static/*
--------------------------------------------------------------------------------
/webapp/app.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Request
2 | from fastapi.responses import HTMLResponse
3 | from fastapi.staticfiles import StaticFiles
4 | from fastapi.templating import Jinja2Templates
5 |
# Demo front-end: a single HTML page plus its static assets.
app = FastAPI()

# app.js and styles.css live in templates/static/ and are exposed at /static.
app.mount(
    "/static",
    StaticFiles(directory="templates/static"),
    name="static",
)

# Jinja2 environment rooted at the templates/ directory.
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the demo page.

    Starlette's Jinja2Templates expects the incoming request in the context.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", context=page_context)
16 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
services:
  # Celery broker (AMQP on 5672, management UI on 15672).
  rabbitmq:
    container_name: rabbitmq
    image: rabbitmq:3-management
    ports:
      - "5672:5672"
      - "15672:15672"

  # Celery result backend.
  redis:
    container_name: redis
    image: redis
    ports:
      - "6379:6379"

  # Runs both the FastAPI service (port 8000) and the Celery worker.
  worker:
    container_name: worker
    build:
      dockerfile: api/Dockerfile
      context: .
    # The API is started in a backgrounded subshell, so its `cd api` does not
    # affect the main shell. Celery then runs in the foreground from /app (the
    # image WORKDIR) and keeps the container alive.
    command: sh -c "(cd api && uvicorn app:app --host 0.0.0.0 --port 8000) & celery -A celery_tasks.app_worker worker -l INFO --pool=solo"
    volumes:
      - .:/app
    ports:
      - "8000:8000"
    # celery_tasks/app_worker.py addresses these services by name.
    # depends_on only orders start-up; it does not wait for readiness.
    depends_on:
      - rabbitmq
      - redis

  # Demo front-end on port 80.
  webapp:
    container_name: webapp
    build:
      dockerfile: webapp/Dockerfile
      context: .
    command: sh -c "cd /app && uvicorn app:app --host 0.0.0.0 --port 80 --reload"
    ports:
      - "80:80"
--------------------------------------------------------------------------------
/celery_tasks/tasks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from celery import Task
3 | from celery.exceptions import MaxRetriesExceededError
4 | from .app_worker import app
5 | from .yolo import YoloModel
6 |
7 |
class PredictTask(Task):
    """Celery base task that loads the YOLO model lazily, once per process.

    Celery keeps a single task instance per process, so the model loaded on
    the first call is reused by every later call in that process.
    """

    def __init__(self):
        super().__init__()
        # Deferred to the first call so that merely creating the task (e.g. in
        # the API process, which only enqueues it) does not load the model.
        self.model = None

    def __call__(self, *args, **kwargs):
        if self.model is None:
            logging.info('Loading Model...')
            self.model = YoloModel()
            logging.info('Model loaded')
        return self.run(*args, **kwargs)
19 |
20 |
@app.task(ignore_result=False, bind=True, base=PredictTask)
def predict_image(self, data):
    """Run YOLO inference on the image path ``data``.

    Returns:
        ``{'status': 'SUCCESS', 'result': <prediction dict>}`` on success, or
        ``{'status': 'FAIL', 'result': 'max retried achieved'}`` once retries
        are exhausted. Failed attempts are retried after 2 seconds, up to the
        task's ``max_retries``.
    """
    try:
        data_pred = self.model.predict(data)
        if data_pred is None:
            # YoloModel.predict logs and swallows its own errors, returning
            # None. Treat that as a failure so the retry path runs, instead of
            # reporting SUCCESS with an empty result.
            raise RuntimeError(f'prediction failed for {data!r}')
        return {'status': 'SUCCESS', 'result': data_pred}
    except Exception:
        logging.exception('predict_image failed for %s', data)
        try:
            # retry() raises celery's Retry to reschedule this task. Once
            # max_retries is exceeded it raises MaxRetriesExceededError instead.
            self.retry(countdown=2)
        except MaxRetriesExceededError:
            return {'status': 'FAIL', 'result': 'max retried achieved'}
31 |
--------------------------------------------------------------------------------
/webapp/templates/static/styles.css:
--------------------------------------------------------------------------------
/* Page title bar. */
h2 {
  width: 100%;
  padding: 13px;
  background: #1462dc;
  color: #fff;
  font-size: 30px;
  line-height: 30px;
  text-align: center;
}

h4 {
  margin-top: 20px;
  text-align: center;
}

/* Floated preview tiles at a third of the row width; .centermask shares the
   layout but is widened to the full row by the next rule. */
div.preview,
.centermask {
  float: left;
  width: 33.3%;
  height: auto;
  margin-bottom: 20px;
  text-align: center;
}

.centermask {
  width: 100%;
}

.img,
.img-centermask {
  max-width: 300px;
  height: auto;
  background-size: contain;
}

.img-centermask {
  max-width: 800px;
}

span.class-name {
  display: block;
  width: 100%;
  text-align: center;
}

/* Hidden initially; presumably shown by app.js when results arrive. */
.header_label {
  display: none;
  width: 100%;
  margin-top: 10px;
  background: #96c0ff;
  text-align: center;
}

.col-sm-10 {
  text-align: center !important;
}

th,
td {
  padding: 10px;
}

textarea {
  border: 0;
}

#row_detail {
  margin-top: 20px;
}
--------------------------------------------------------------------------------
/celery_tasks/yolo.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import matplotlib.pyplot as plt
4 |
5 |
class YoloModel:
    """Wrapper around a pretrained YOLOv5x model loaded through torch.hub."""

    def __init__(self):
        # torch.hub fetches the ultralytics/yolov5 repo and weights if not cached.
        self.model = torch.hub.load('ultralytics/yolov5', 'yolov5x', pretrained=True)
        self.model.eval()

    def predict(self, img):
        """Detect objects in ``img`` and save an annotated copy.

        Args:
            img: Any input the hub model accepts; this project passes an image
                file path.

        Returns:
            ``{'file_name': 'static/<saved name>', 'bbox': [...]}``. Each bbox
            entry maps 'x', 'y', 'w', 'h', 'prob' to stringified values taken
            from ``xywhn`` (normalized box plus confidence), and 'class' to the
            label name. Returns ``None`` if anything fails; the error is logged.
        """
        try:
            with torch.no_grad():
                result = self.model(img)
            # The worker writes the annotated image under api/static/results/.
            # The API serves that folder at /static, hence the URL prefix below.
            # NOTE(review): newer yolov5 releases may increment an existing
            # save_dir (results2, ...) unless exist_ok=True; confirm the hub
            # version writes where the API expects.
            result.save('api/static/results/')
            file_name = f'static/{result.files[0]}'

            data = []
            # One row per detection in the first (only) image.
            for row in result.xywhn[0]:
                # detach/cpu so this also works when the hub model runs on GPU.
                x, y, w, h, prob, cls = row.detach().cpu().numpy()
                data.append({
                    'x': str(x),
                    'y': str(y),
                    'w': str(w),
                    'h': str(h),
                    'prob': str(prob),
                    'class': result.names[int(cls)],
                })

            return {'file_name': file_name, 'bbox': data}
        except Exception:
            # Best effort: keep the traceback in the log and signal failure
            # with None.
            logging.exception('YOLO prediction failed')
            return None
35 |
--------------------------------------------------------------------------------
/webapp/templates/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Yolo v5
5 |
6 |
7 |
8 |
9 |
10 |
12 |
13 |
14 |
15 |
16 | Yolo v5 - Object Detection
17 | Demo application showing how to consume the API which performs object detection using Yolo v5.
18 |
19 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/api/app.py:
--------------------------------------------------------------------------------
import logging
import os
import sys
import uuid

# Make the repository root importable so `celery_tasks` resolves when the API
# is started from inside api/ (as docker-compose does).
sys.path.insert(0, os.path.realpath(os.path.pardir))

import numpy as np
from celery.result import AsyncResult
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles

try:
    from pydantic.typing import List
except ImportError:
    # pydantic v2 removed pydantic.typing; the List it exposed is typing.List.
    from typing import List

from celery_tasks.tasks import predict_image
from models import Task, Prediction
15 |
# Uploaded images, relative to the api/ working directory.
UPLOAD_FOLDER = 'uploads'
# Annotated images written by the worker; served below at /static.
STATIC_FOLDER = 'static/results'

# exist_ok=True avoids the check-then-create race of isdir() + makedirs().
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(STATIC_FOLDER, exist_ok=True)

# Browser origins allowed to call this API cross-origin.
origins = [
    "http://localhost",
    "http://localhost:8080",
]

app = FastAPI()
app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static")
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
41 |
42 |
def _safe_upload_path(filename):
    """Return a fresh, unique path in UPLOAD_FOLDER for an uploaded ``filename``.

    Only the client's extension is kept, and only if it is purely
    alphanumeric. A crafted name such as 'x./../../etc/y' therefore cannot
    steer the write outside UPLOAD_FOLDER.
    """
    _, ext = os.path.splitext(os.path.basename(filename or ''))
    ext = ext.lstrip('.')
    name = uuid.uuid4().hex
    if ext.isalnum():
        return f'{UPLOAD_FOLDER}/{name}.{ext}'
    return f'{UPLOAD_FOLDER}/{name}'


@app.post('/api/process')
async def process(files: List[UploadFile] = File(...)):
    """Save each uploaded image and enqueue a YOLO prediction task for it.

    Responds 202 with one entry per file, each holding:
      - ``task_id``;
      - ``status``: 'PROCESSING', or 'ERROR' if that file could not be queued;
      - ``url_result``: the URL to poll.
    Responds 400 with ``[]`` if the batch itself fails unexpectedly.
    """
    import shutil

    tasks = []
    try:
        for file in files:
            # Reset per file so an error never reports an unbound or stale id.
            task_id = ''
            try:
                file_name = _safe_upload_path(file.filename)
                with open(file_name, 'wb') as f:
                    # Stream to disk rather than reading the whole upload into memory.
                    shutil.copyfileobj(file.file, f)

                # The worker runs from the repository root, so prefix 'api/'.
                task_id = predict_image.delay(os.path.join('api', file_name)).id
                d = {
                    'task_id': str(task_id),
                    'status': 'PROCESSING',
                    'url_result': f'/api/result/{task_id}',
                }
            except Exception:
                logging.exception('Failed to queue upload %s', file.filename)
                d = {'task_id': str(task_id), 'status': 'ERROR', 'url_result': ''}
            tasks.append(d)
        return JSONResponse(status_code=202, content=tasks)
    except Exception:
        logging.exception('Failed to process upload batch')
        return JSONResponse(status_code=400, content=[])
72 |
73 |
@app.get('/api/result/{task_id}', response_model=Prediction)
async def result(task_id: str):
    """Return the outcome of a prediction task.

    - 202 with an empty ``result`` while the task is not ready.
    - 200 with the task's own ``status``/``result`` once it finishes normally.
    - 200 with the Celery state and the error text when the task raised or was
      revoked, instead of re-raising here and failing with a 500.
    """
    task = AsyncResult(task_id)

    # Task not ready: report its current Celery state.
    if not task.ready():
        return JSONResponse(status_code=202, content={'task_id': str(task_id), 'status': task.status, 'result': ''})

    # propagate=False returns the task's exception instead of raising it.
    task_result = task.get(propagate=False)
    if task.failed() or not isinstance(task_result, dict):
        return JSONResponse(status_code=200, content={'task_id': str(task_id), 'status': task.status, 'result': str(task_result)})

    return JSONResponse(status_code=200, content={'task_id': str(task_id), 'status': task_result.get('status'), 'result': task_result.get('result')})
86 |
87 |
@app.get('/api/status/{task_id}', response_model=Prediction)
async def status(task_id: str):
    """Report a task's current Celery state without fetching its result."""
    state = AsyncResult(task_id).status
    payload = {'task_id': str(task_id), 'status': state, 'result': ''}
    return JSONResponse(status_code=200, content=payload)
92 |
--------------------------------------------------------------------------------
/webapp/templates/static/app.js:
--------------------------------------------------------------------------------
// Base URL of the prediction API (the FastAPI service in the worker container).
// NOTE(review): a page-level global named URL shadows the browser's built-in
// URL constructor. Consider renaming it (and its uses further down).
var URL = 'http://localhost:8000'
// Derived from URL so the two endpoints cannot drift apart.
var URL_STATUS = URL + '/api/status/'
// Result URLs and task ids of the current upload batch; both are reset before
// each upload.
var results = []
var status_list = []
var res = ''
6 | jQuery(document).ready(function () {
7 | $('#row_detail').hide()
8 | $("#row_results").hide();
9 | $('#btn-process').on('click', function () {
10 | var form_data = new FormData();
11 | files = $('#input_file').prop('files')
12 | for (i = 0; i < files.length; i++)
13 | form_data.append('files', $('#input_file').prop('files')[i]);
14 |
15 | $.ajax({
16 | url: URL + '/api/process',
17 | type: "post",
18 | data: form_data,
19 | enctype: 'multipart/form-data',
20 | contentType: false,
21 | processData: false,
22 | cache: false,
23 | beforeSend: function () {
24 | results = []
25 | status_list = []
26 | $("#table_result > tbody").html('');
27 | $('#row_detail').hide();
28 | $("#row_results").hide();
29 | },
30 | }).done(function (jsondata, textStatus, jqXHR) {
31 | for (i = 0; i < jsondata.length; i++) {
32 | task_id = jsondata[i]['task_id']
33 | status = jsondata[i]['status']
34 | results.push(URL + jsondata[i]['url_result'])
35 | status_list.push(task_id)
36 | result_button = `