├── .dockerignore ├── .env.example ├── .gitignore ├── Dockerfile ├── MANIFEST ├── Makefile ├── README.md ├── docker-compose.yml ├── hf.png ├── huggingfastapi ├── __init__.py ├── api │ ├── __init__.py │ └── routes │ │ ├── __init__.py │ │ ├── heartbeat.py │ │ ├── prediction.py │ │ └── router.py ├── core │ ├── __init__.py │ ├── config.py │ ├── event_handlers.py │ ├── messages.py │ └── security.py ├── main.py ├── models │ ├── __init__.py │ ├── heartbeat.py │ ├── payload.py │ └── prediction.py └── services │ ├── __init__.py │ ├── nlp.py │ └── utils.py ├── ml_model └── model_description.md ├── post_swagger.png ├── requirements.txt ├── response_swagger.png ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── conftest.py ├── test_api │ ├── __init__.py │ ├── test_api_auth.py │ ├── test_heartbeat.py │ └── test_prediction.py └── test_service │ ├── __init__.py │ └── test_models.py └── tox.ini /.dockerignore: -------------------------------------------------------------------------------- 1 | /images 2 | /tests 3 | .ascii-art 4 | 5 | *.yml 6 | *.ini 7 | 8 | # Hidden files 9 | .DS_store 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # data is not logged 20 | data/ 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | .vscode/ 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | htmlcov-py36/ 55 | .tox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | .vscode/ 92 | .pytest-cache/ 93 | .pytest_cache/ 94 | .empty/ 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # virtualenv 103 | .venv 104 | venv/ 105 | ENV/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .idea/ 120 | 121 | 122 | 123 | Child_Watch_API/ 124 | htmlcov-*/ 125 | 126 | # remove large model 127 | ml_model/*.joblib 128 | 129 | 130 | #venv 131 | venv/ -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | IS_DEBUG=False 3 | API_KEY=example_key 4 | DEFAULT_MODEL_PATH=./ml_model/models 5 | # Hugging Face Model 6 | QUESTION_ANSWER_MODEL=deepset/roberta-base-squad2 -------------------------------------------------------------------------------- /.gitignore:
-------------------------------------------------------------------------------- 1 | /ml_model/models 2 | 3 | # Hidden files 4 | .DS_store 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # data is not logged 15 | data/ 16 | 17 | # Distribution / packaging 18 | .Python 19 | env/ 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | .vscode/ 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | htmlcov-py36/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | .vscode/ 87 | .pytest-cache/ 88 | .pytest_cache/ 89 | .empty/ 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # dotenv 98 | .env 99 | 100 | # virtualenv 101 | .venv 102 | venv/ 103 | ENV/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .idea/ 118 | 119 | 120 | 121 | Child_Watch_API/ 122 | htmlcov-*/ 123 | 124 | # remove large model 125 | ml_model/*.joblib 126 | 127 | 128 | #venv 129 | venv/ 130 | 131 | 132 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8.2 2 | 3 | ENV PYTHONUNBUFFERED 1 4 | 5 | EXPOSE 8000 6 | WORKDIR /app 7 | 8 | COPY requirements.txt ./ 9 | RUN pip install --upgrade pip && \ 10 | pip install -r requirements.txt 11 | 12 | COPY . 
./ 13 | 14 | ENV PYTHONPATH huggingfastapi 15 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] -------------------------------------------------------------------------------- /MANIFEST: -------------------------------------------------------------------------------- 1 | # file GENERATED by distutils, do NOT edit 2 | setup.cfg 3 | setup.py 4 | huggingfastapi/__init__.py 5 | huggingfastapi/main.py 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | # Variable definitions 4 | # ----------------------------------------------------------------------------- 5 | 6 | DEFAULT_MODEL_PATH := ./ml_model/ 7 | 8 | ifeq ($(API_KEY),) 9 | API_KEY := $(shell python -c "import uuid;print(str(uuid.uuid4()))") 10 | endif 11 | 12 | 13 | # Target section and Global definitions 14 | # ----------------------------------------------------------------------------- 15 | .PHONY: all clean test install run deploy down 16 | 17 | all: clean test install run deploy down 18 | 19 | test: 20 | python -m pip install --upgrade pip && pip install setuptools tox 21 | tox 22 | 23 | install: generate_dot_env 24 | pip install --upgrade pip && pip install -r requirements.txt 25 | 26 | run: 27 | - python -c "import subprocess;print(subprocess.run(['hostname','-I'],capture_output=True,text=True).stdout.strip())" 28 | PYTHONPATH=huggingfastapi/ uvicorn huggingfastapi.main:app --reload --host 0.0.0.0 29 | 30 | build: generate_dot_env 31 | docker compose build 32 | 33 | deploy: generate_dot_env 34 | docker compose build 35 | docker compose up -d 36 | 37 | down: 38 | docker compose down 39 | 40 | generate_dot_env: 41 | @if [[ ! -e .env ]]; then \ 42 | cp .env.example .env; \ 43 | fi 44 | 45 | clean: 46 | -find . -name '*.pyc' -exec rm -rf {} \; 47 | -find . -name '__pycache__' -exec rm -rf {} \; 48 | -find . -name 'Thumbs.db' -exec rm -rf {} \; 49 | -find . -name '*~' -exec rm -rf {} \; 50 | -rm -rf .cache 51 | -rm -rf build 52 | -rm -rf dist 53 | -rm -rf *.egg-info 54 | -rm -rf htmlcov* 55 | -rm -rf .tox/ 56 | -rm -rf docs/_build 57 | -rm -r .coverage 58 | -rm -rf .pytest_cache -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Question Answering API 2 | > 🤗 Huggingface + ⚡ FastAPI = ❤️ Awesomeness. How to structure a Deep Learning model-serving REST API with FastAPI 3 | 4 | ![huggingfastapi](hf.png) 5 | How to serve Hugging Face models with FastAPI, one of Python's fastest REST API frameworks. 6 | 7 | A minimalistic project structure for development and production. 8 | 9 | Installation and setup instructions for 10 | running the model in development mode and serving a local RESTful API endpoint. 11 | 12 | ## Project structure 13 | 14 | Files related to the application are in the `huggingfastapi` or `tests` directories. 15 | Application parts are: 16 | 17 | huggingfastapi 18 | ├── api - Main API. 19 | │   └── routes - Web routes. 20 | ├── core - Application configuration, startup events, logging. 21 | ├── models - Pydantic models for the API. 22 | ├── services - NLP logic. 23 | └── main.py - FastAPI application creation and configuration.
24 | │ 25 | tests - Code without tests is an illusion 26 | 27 | ## Swagger Example 28 | ![post_swagger](post_swagger.png) 29 | ![response_swagger](response_swagger.png) 30 | 31 | ## Requirements 32 | 33 | Python 3.6+, Make and Docker 34 | 35 | ## Installation 36 | Install the required packages in your local environment (ideally a virtualenv, conda, etc.). 40 | 41 | ```sh 42 | python -m venv .venv 43 | source .venv/bin/activate 44 | make install 45 | ``` 46 | 47 | #### Running Localhost 48 | 49 | ```sh 50 | make run 51 | ``` 52 | 53 | #### Running Via Docker 54 | 55 | ```sh 56 | make deploy 57 | ``` 58 | 59 | #### Running Tests 60 | 61 | ```sh 62 | make test 63 | ``` 64 | 65 | ## Setup 66 | 1. Duplicate the `.env.example` file and rename it to `.env`. 67 | 68 | 2. In the `.env` file, configure the `API_KEY` entry. The key is used for authenticating our API.
69 | Run the script below to generate `.env`, then replace `example_key` with the generated UUID: 70 | 71 | ```bash 72 | make generate_dot_env 73 | python -c "import uuid;print(str(uuid.uuid4()))" 74 | 75 | ``` 76 | 77 | ## Run without `make` for development 78 | 79 | 1. Start your app with: 80 | ```bash 81 | PYTHONPATH=./huggingfastapi uvicorn main:app --reload 82 | ``` 83 | 84 | 2. Go to [http://localhost:8000/docs](http://localhost:8000/docs) for the Swagger UI, or [http://localhost:8000/redoc](http://localhost:8000/redoc) for the alternative ReDoc documentation 85 | 86 | 3. Click `Authorize` and enter the API key as created in the Setup step. 87 | 88 | 89 | ## Run Tests without using `make` 90 | 91 | Install the testing libraries and run `tox`: 92 | ```bash 93 | pip install tox pytest flake8 coverage bandit 94 | tox 95 | ``` 96 | This runs the test suite with coverage under Python 3.8, plus Flake8 and Bandit checks. 97 | 98 | 99 | # TODO 100 | - [X] Change make to invoke: invoke is overrated 101 | - [ ] Add an endpoint for uploading a text file and questions 102 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | app: 5 | build: . 6 | image: huggingfastapi:latest 7 | ports: 8 | - 8000:8000 9 | env_file: 10 | - .env 11 | container_name: huggingfastapi 12 | -------------------------------------------------------------------------------- /hf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/hf.png -------------------------------------------------------------------------------- /huggingfastapi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/huggingfastapi/__init__.py -------------------------------------------------------------------------------- /huggingfastapi/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/huggingfastapi/api/__init__.py -------------------------------------------------------------------------------- /huggingfastapi/api/routes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/huggingfastapi/api/routes/__init__.py -------------------------------------------------------------------------------- /huggingfastapi/api/routes/heartbeat.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | 3 | from huggingfastapi.models.heartbeat import HearbeatResult 4 | 5 | router = APIRouter() 6 | 7 | 8 | @router.get("/heartbeat", response_model=HearbeatResult, name="heartbeat") 9 | def get_hearbeat() -> HearbeatResult: 10 | heartbeat = HearbeatResult(is_alive=True) 11 | return heartbeat 12 | -------------------------------------------------------------------------------- /huggingfastapi/api/routes/prediction.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends 2 | from starlette.requests import Request 3 | 4 | from huggingfastapi.core import security 5 | from huggingfastapi.models.payload import 
QAPredictionPayload 6 | from huggingfastapi.models.prediction import QAPredictionResult 7 | from huggingfastapi.services.nlp import QAModel 8 | 9 | router = APIRouter() 10 | 11 | 12 | @router.post("/question", response_model=QAPredictionResult, name="question") 13 | def post_question( 14 | request: Request, 15 | authenticated: bool = Depends(security.validate_request), 16 | block_data: QAPredictionPayload = None, 17 | ) -> QAPredictionResult: 18 | """ 19 | #### Retrieves an answer from context given a question 20 | 21 | """ 22 | 23 | model: QAModel = request.app.state.model 24 | prediction: QAPredictionResult = model.predict(block_data) 25 | 26 | return prediction 27 | -------------------------------------------------------------------------------- /huggingfastapi/api/routes/router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | 3 | from huggingfastapi.api.routes import heartbeat, prediction 4 | 5 | api_router = APIRouter() 6 | api_router.include_router(heartbeat.router, tags=["health"], prefix="/health") 7 | api_router.include_router(prediction.router, tags=["prediction"], prefix="/v1") 8 | -------------------------------------------------------------------------------- /huggingfastapi/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/huggingfastapi/core/__init__.py -------------------------------------------------------------------------------- /huggingfastapi/core/config.py: -------------------------------------------------------------------------------- 1 | from starlette.config import Config 2 | from starlette.datastructures import Secret 3 | 4 | APP_VERSION = "0.1.0" 5 | APP_NAME = "Hugging FastAPI" 6 | API_PREFIX = "/api" 7 | 8 | config = Config(".env") 9 | 10 | API_KEY: Secret = config("API_KEY", cast=Secret) 11 | IS_DEBUG: bool = config("IS_DEBUG", cast=bool, default=False) 12 | 13 | DEFAULT_MODEL_PATH: str = config("DEFAULT_MODEL_PATH") 14 | QUESTION_ANSWER_MODEL: str = config("QUESTION_ANSWER_MODEL") 15 | -------------------------------------------------------------------------------- /huggingfastapi/core/event_handlers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from fastapi import FastAPI 4 | from loguru import logger 5 | 6 | from huggingfastapi.core.config import DEFAULT_MODEL_PATH 7 | from huggingfastapi.services.nlp import QAModel 8 | 9 | 10 | def _startup_model(app: FastAPI) -> None: 11 | model_path = DEFAULT_MODEL_PATH 12 | model_instance = QAModel(model_path) 13 | app.state.model = model_instance 14 | 15 | 16 | def _shutdown_model(app: FastAPI) -> None: 17 | app.state.model = None 18 | 19 | 20 | def start_app_handler(app: FastAPI) -> Callable: 21 | def startup() -> None: 22 | logger.info("Running app start handler.") 23 | _startup_model(app) 24 | 25 | return startup 26 | 27 | 28 | def stop_app_handler(app: FastAPI) -> Callable: 29 | def shutdown() -> None: 30 | logger.info("Running app shutdown handler.") 31 | _shutdown_model(app) 32 | 33 | return shutdown 34 | -------------------------------------------------------------------------------- /huggingfastapi/core/messages.py: -------------------------------------------------------------------------------- 1 | NO_API_KEY = "No API key provided." 2 | AUTH_REQ = "Authentication required." 
3 | HTTP_500_DETAIL = "Internal server error." 4 | 5 | # templates 6 | NO_VALID_PAYLOAD = "{} is not a valid payload." 7 | -------------------------------------------------------------------------------- /huggingfastapi/core/security.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from typing import Optional 3 | 4 | from fastapi import HTTPException, Security 5 | from fastapi.security.api_key import APIKeyHeader 6 | from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED 7 | 8 | from huggingfastapi.core import config 9 | from huggingfastapi.core.messages import AUTH_REQ, NO_API_KEY 10 | 11 | api_key = APIKeyHeader(name="token", auto_error=False) 12 | 13 | 14 | def validate_request(header: Optional[str] = Security(api_key)) -> bool: 15 | if header is None: 16 | raise HTTPException( 17 | status_code=HTTP_400_BAD_REQUEST, detail=NO_API_KEY, headers={} 18 | ) 19 | if not secrets.compare_digest(header, str(config.API_KEY)): 20 | raise HTTPException( 21 | status_code=HTTP_401_UNAUTHORIZED, detail=AUTH_REQ, headers={} 22 | ) 23 | return True 24 | -------------------------------------------------------------------------------- /huggingfastapi/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from huggingfastapi.api.routes.router import api_router 3 | from huggingfastapi.core.config import API_PREFIX, APP_NAME, APP_VERSION, IS_DEBUG 4 | from huggingfastapi.core.event_handlers import start_app_handler, stop_app_handler 5 | 6 | 7 | def get_app() -> FastAPI: 8 | fast_app = FastAPI(title=APP_NAME, version=APP_VERSION, debug=IS_DEBUG) 9 | fast_app.include_router(api_router, prefix=API_PREFIX) 10 | 11 | fast_app.add_event_handler("startup", start_app_handler(fast_app)) 12 | fast_app.add_event_handler("shutdown", stop_app_handler(fast_app)) 13 | 14 | return fast_app 15 | 16 | 17 | app = get_app() 18 | -------------------------------------------------------------------------------- /huggingfastapi/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/huggingfastapi/models/__init__.py -------------------------------------------------------------------------------- /huggingfastapi/models/heartbeat.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class HearbeatResult(BaseModel): 5 | is_alive: bool 6 | -------------------------------------------------------------------------------- /huggingfastapi/models/payload.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class QAPredictionPayload(BaseModel): 5 | 6 | context: str = "42 is the answer to life, the universe and everything." 7 | question: str = "What is the answer to life?" 
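    # These defaults double as a built-in example: POSTing an empty JSON body to
    # /api/v1/question falls back to this Hitchhiker's Guide pair
    # (see tests/test_api/test_prediction.py::test_prediction_nopayload).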
8 | -------------------------------------------------------------------------------- /huggingfastapi/models/prediction.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from huggingfastapi.core.config import QUESTION_ANSWER_MODEL 3 | 4 | 5 | class QAPredictionResult(BaseModel): 6 | 7 | score: float 8 | start: int 9 | end: int 10 | answer: str 11 | model: str = QUESTION_ANSWER_MODEL 12 | -------------------------------------------------------------------------------- /huggingfastapi/services/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/huggingfastapi/services/__init__.py -------------------------------------------------------------------------------- /huggingfastapi/services/nlp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from loguru import logger 3 | 4 | from transformers import AutoTokenizer 5 | from transformers import AutoModelForQuestionAnswering 6 | from transformers import pipeline 7 | 8 | from huggingfastapi.models.payload import QAPredictionPayload 9 | from huggingfastapi.models.prediction import QAPredictionResult 10 | from huggingfastapi.services.utils import ModelLoader 11 | from huggingfastapi.core.messages import NO_VALID_PAYLOAD 12 | from huggingfastapi.core.config import ( 13 | DEFAULT_MODEL_PATH, 14 | QUESTION_ANSWER_MODEL, 15 | ) 16 | 17 | 18 | class QAModel(object): 19 | def __init__(self, path=DEFAULT_MODEL_PATH): 20 | self.path = path 21 | self._load_local_model() 22 | 23 | def _load_local_model(self): 24 | 25 | tokenizer, model = ModelLoader( 26 | model_name=QUESTION_ANSWER_MODEL, 27 | model_directory=DEFAULT_MODEL_PATH, 28 | tokenizer_loader=AutoTokenizer, 29 | model_loader=AutoModelForQuestionAnswering, 30 | ).retrieve() 31 | 32 | self.nlp = pipeline("question-answering", model=model, tokenizer=tokenizer) 33 | 34 | def _pre_process(self, payload: QAPredictionPayload) -> List: 35 | logger.debug("Pre-processing payload.") 36 | result = [payload.context, payload.question] 37 | return result 38 | 39 | def _post_process(self, prediction: Dict) -> QAPredictionResult: 40 | logger.debug("Post-processing prediction.") 41 | 42 | qa = QAPredictionResult(**prediction) 43 | 44 | return qa 45 | 46 | def _predict(self, features: List) -> Dict: 47 | logger.debug("Predicting.") 48 | 49 | context, question = features 50 | 51 | QA_input = { 52 | "question": question, 53 | "context": context, 54 | } 55 | 56 | prediction_result = self.nlp(QA_input) 57 | 58 | return prediction_result 59 | 60 | def predict(self, payload: QAPredictionPayload) -> QAPredictionResult: 61 | if payload is None: 62 | raise ValueError(NO_VALID_PAYLOAD.format(payload)) 63 | 64 | pre_processed_payload = self._pre_process(payload) 65 | prediction = self._predict(pre_processed_payload) 66 | logger.info(prediction) 67 | post_processed_result = self._post_process(prediction) 68 | 69 | return post_processed_result 70 | -------------------------------------------------------------------------------- /huggingfastapi/services/utils.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from loguru import logger 3 | from pathlib import Path 4 | import torch 5 | from transformers import PreTrainedModel 6 | from transformers import PreTrainedTokenizer 7 | 8 | 9 | class ModelLoader: 10 | 
"""ModelLoader 11 | Downloading and Loading Hugging FaceModels 12 | Download occurs only when model is not located in the local model directory 13 | If model exists in local directory, load. 14 | 15 | """ 16 | 17 | def __init__( 18 | self, 19 | model_name: str, 20 | model_directory: str, 21 | tokenizer_loader: PreTrainedTokenizer, 22 | model_loader: PreTrainedModel, 23 | ): 24 | 25 | self.model_name = Path(model_name) 26 | self.model_directory = Path(model_directory) 27 | self.model_loader = model_loader 28 | self.tokenizer_loader = tokenizer_loader 29 | 30 | self.save_path = self.model_directory / self.model_name 31 | 32 | if not self.save_path.exists(): 33 | logger.debug(f"[+] {self.save_path} does not exit!") 34 | self.save_path.mkdir(parents=True, exist_ok=True) 35 | self.__download_model() 36 | 37 | self.tokenizer, self.model = self.__load_model() 38 | 39 | def __repr__(self): 40 | return f"{self.__class__.__name__}(model={self.save_path})" 41 | 42 | # Download model from HuggingFace 43 | def __download_model(self) -> None: 44 | 45 | logger.debug(f"[+] Downloading {self.model_name}") 46 | tokenizer = self.tokenizer_loader.from_pretrained(f"{self.model_name}") 47 | model = self.model_loader.from_pretrained(f"{self.model_name}") 48 | 49 | logger.debug(f"[+] Saving {self.model_name} to {self.save_path}") 50 | tokenizer.save_pretrained(f"{self.save_path}") 51 | model.save_pretrained(f"{self.save_path}") 52 | 53 | logger.debug("[+] Process completed") 54 | 55 | # Load model 56 | def __load_model(self) -> t.Tuple: 57 | 58 | logger.debug(f"[+] Loading model from {self.save_path}") 59 | tokenizer = self.tokenizer_loader.from_pretrained(f"{self.save_path}") 60 | # Check if GPU is available 61 | device = "cuda" if torch.cuda.is_available() else "cpu" 62 | logger.info(f"[+] Model loaded in {device} complete") 63 | model = self.model_loader.from_pretrained(f"{self.save_path}").to(device) 64 | 65 | logger.debug("[+] Loading completed") 66 | return tokenizer, model 67 | 68 | def retrieve(self) -> t.Tuple: 69 | 70 | """Retriver 71 | 72 | Returns: 73 | Tuple: tokenizer, model 74 | """ 75 | return self.tokenizer, self.model 76 | -------------------------------------------------------------------------------- /ml_model/model_description.md: -------------------------------------------------------------------------------- 1 | ## NLP Model Description 2 | 3 | Model: https://huggingface.co/deepset/roberta-base-squad2
4 | 5 | ## Authors 6 | Branden Chan: branden.chan [at] deepset.ai Timo Möller: timo.moeller [at] deepset.ai
7 | Malte Pietsch: malte.pietsch [at] deepset.ai Tanay Soni: tanay.soni [at] deepset.ai
8 | 9 | ## Settings 10 | Language model: roberta-base
11 | Language: English
12 | Downstream-task: Extractive QA
13 | Training data: SQuAD 2.0
14 | Eval data: SQuAD 2.0 15 | 16 | 17 | -------------------------------------------------------------------------------- /post_swagger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/post_swagger.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | uvicorn==0.18.3 2 | fastapi==0.85.0 3 | loguru==0.6.0 4 | 5 | transformers[torch]==4.22.1 6 | sentencepiece==0.1.97 -------------------------------------------------------------------------------- /response_swagger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/response_swagger.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = tests 3 | 4 | [coverage:run] 5 | source = huggingfastapi 6 | branch = True 7 | 8 | [coverage:report] 9 | precision = 2 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from distutils.core import setup 4 | 5 | setup( 6 | name="HUGGING_FASTAPI", 7 | version="0.1.0", 8 | description="QA Transformer NLP Model", 9 | author="Prayson W. Daniel", 10 | author_email="praysonwilfred@gmail.com", 11 | packages=["huggingfastapi"], 12 | ) 13 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from starlette.config import environ 3 | from starlette.testclient import TestClient 4 | 5 | 6 | environ["API_KEY"] = "example_key" 7 | 8 | 9 | from huggingfastapi.main import get_app  # noqa: E402 10 | 11 | 12 | @pytest.fixture() 13 | def test_client(): 14 | app = get_app() 15 | with TestClient(app) as test_client: 16 | yield test_client 17 | -------------------------------------------------------------------------------- /tests/test_api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/tests/test_api/__init__.py -------------------------------------------------------------------------------- /tests/test_api/test_api_auth.py: -------------------------------------------------------------------------------- 1 | from huggingfastapi.core import messages 2 | 3 | 4 | def test_auth_using_prediction_api_no_apikey_header(test_client) -> None: 5 | response = test_client.post("/api/v1/question") 6 | assert response.status_code == 400 7 | assert response.json() == {"detail": messages.NO_API_KEY} 8 | 9 | 10 | def test_auth_using_prediction_api_wrong_apikey_header(test_client) -> None: 11 | response = test_client.post( 12 | "/api/v1/question", 13 | 
json={"context": "test", "question": "test"}, 14 | headers={"token": "WRONG_TOKEN"}, 15 | ) 16 | assert response.status_code == 401 17 | assert response.json() == {"detail": messages.AUTH_REQ} 18 | -------------------------------------------------------------------------------- /tests/test_api/test_heartbeat.py: -------------------------------------------------------------------------------- 1 | from huggingfastapi.main import get_app 2 | 3 | app = get_app() 4 | 5 | 6 | def test_heartbeat(test_client) -> None: 7 | response = test_client.get("/api/health/heartbeat") 8 | assert response.status_code == 200 9 | assert response.json() == {"is_alive": True} 10 | 11 | 12 | def test_default_route(test_client) -> None: 13 | response = test_client.get("/") 14 | assert response.status_code == 404 15 | -------------------------------------------------------------------------------- /tests/test_api/test_prediction.py: -------------------------------------------------------------------------------- 1 | def test_prediction(test_client) -> None: 2 | response = test_client.post( 3 | "/api/v1/question", 4 | json={"context": "two plus two equal four", "question": "What is four?"}, 5 | headers={"token": "example_key"}, 6 | ) 7 | assert response.status_code == 200 8 | assert "score" in response.json() 9 | 10 | 11 | def test_prediction_nopayload(test_client) -> None: 12 | response = test_client.post( 13 | "/api/v1/question", json={}, headers={"token": "example_key"} 14 | ) 15 | """ 16 | ## if nopayload, default Hitchhiker's Guide to the Galaxy example is sent 17 | # context:"42 is the answer to life, the universe and everything." 18 | # question:"What is the answer to life?" 19 | # if no default, assert response.status_code == 422 20 | """ 21 | 22 | data = response.json() 23 | assert data["answer"] == "42" 24 | -------------------------------------------------------------------------------- /tests/test_service/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Proteusiq/huggingfastapi/29744d40015f49b51d9dafd6178a5363bc12962b/tests/test_service/__init__.py -------------------------------------------------------------------------------- /tests/test_service/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from huggingfastapi.core import config 4 | from huggingfastapi.models.payload import QAPredictionPayload 5 | from huggingfastapi.models.prediction import QAPredictionResult 6 | from huggingfastapi.services.nlp import QAModel 7 | 8 | 9 | def test_prediction(test_client) -> None: 10 | model_path = config.DEFAULT_MODEL_PATH 11 | qa = QAPredictionPayload.parse_obj( 12 | {"context": "two plus two equal four", "question": "What is four?"} 13 | ) 14 | 15 | tm = QAModel(model_path) 16 | result = tm.predict(qa) 17 | assert isinstance(result, QAPredictionResult) 18 | assert result.answer == "two plus two" 19 | 20 | 21 | def test_prediction_no_payload(test_client) -> None: 22 | model_path = config.DEFAULT_MODEL_PATH 23 | 24 | tm = QAModel(model_path) 25 | with pytest.raises(ValueError): 26 | result = tm.predict(None) 27 | assert isinstance(result, QAPredictionResult) 28 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = flake8,py38,bandit 3 | 4 | [testenv] 5 | deps = 6 | coverage 7 | pytest 8 | pytest-cov 9 | pycodestyle>=2.0 10 | 
commands = 11 | coverage erase 12 | pip install -r {toxinidir}/requirements.txt 13 | pytest --cov=tests --cov=huggingfastapi --cov-report=term-missing --cov-config=setup.cfg 14 | coverage report 15 | coverage html -d htmlcov-{envname} 16 | 17 | [testenv:autopep8] 18 | basepython = python3 19 | skip_install = true 20 | deps = 21 | autopep8>=1.5 22 | commands = 23 | autopep8 --in-place --aggressive -r huggingfastapi/ 24 | 25 | 26 | [testenv:flake8] 27 | basepython = python3 28 | skip_install = true 29 | deps = 30 | flake8 31 | flake8-bugbear 32 | flake8-colors 33 | flake8-typing-imports>=1.1 34 | pep8-naming 35 | commands = 36 | flake8 huggingfastapi/ tests/ setup.py 37 | 38 | [testenv:bandit] 39 | basepython = python3 40 | deps = 41 | bandit 42 | commands = 43 | bandit -r huggingfastapi/ 44 | 45 | [flake8] 46 | exclude = 47 | .tox, 48 | .git, 49 | __pycache__, 50 | docs/source/conf.py, 51 | build, 52 | dist, 53 | tests/fixtures/*, 54 | *.pyc, 55 | *.egg-info, 56 | .cache, 57 | .eggs 58 | max-complexity = 10 59 | import-order-style = google 60 | application-import-names = flake8 61 | max-line-length = 120 62 | ignore = 63 | B008, 64 | N803, 65 | N806, 66 | E126 --------------------------------------------------------------------------------
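A closing, hedged end-to-end sketch of calling the running service. The endpoint path (`API_PREFIX` `/api` plus the `/v1` router prefix), the `token` header name, and the payload fields are all taken from `router.py`, `security.py`, and `payload.py` above; the `requests` package is an assumption, since it is not listed in `requirements.txt`:

```python
import requests  # assumption: installed separately; not part of requirements.txt

# "/api" (API_PREFIX in core/config.py) + "/v1" (router prefix) + "/question"
API_URL = "http://localhost:8000/api/v1/question"

payload = {
    "context": "42 is the answer to life, the universe and everything.",
    "question": "What is the answer to life?",
}

# The API key travels in a header literally named "token" (see core/security.py);
# "example_key" matches the default in .env.example.
response = requests.post(API_URL, json=payload, headers={"token": "example_key"})
response.raise_for_status()
print(response.json())
# expected shape: {"score": ..., "start": ..., "end": ..., "answer": "42",
#                  "model": "deepset/roberta-base-squad2"}
```

A missing `token` header yields 400 with "No API key provided."; a wrong key yields 401 with "Authentication required." (see `core/security.py` and `tests/test_api/test_api_auth.py`).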