├── .env.sample
├── Dockerfile
├── README.md
├── app.py
├── celery_worker.py
├── docker-compose.yml
├── model_loader.py
├── redis_server.py
├── requirements.txt
├── run.py
└── utils.py

/.env.sample:
--------------------------------------------------------------------------------
HUGGINGFACE_TOKEN=...
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvidia/cuda:12.2.0-base-ubuntu20.04

# Avoid interactive tzdata prompts while installing packages
ENV DEBIAN_FRONTEND=noninteractive

# Install Python 3.9 from the deadsnakes PPA (the base image ships with Python 3.8)
RUN apt-get update && apt-get install -y software-properties-common
RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-get update && apt-get install -y python3.9 python3.9-dev python3.9-distutils python3-pip

WORKDIR /app
COPY . /app

# Install the dependencies against Python 3.9 explicitly, not the system Python 3.8
RUN python3.9 -m pip install --no-cache-dir -r requirements.txt
EXPOSE 80

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Serving a Scalable FastAPI Application Leveraging Celery and Redis

### Pre-requisites:

- You must deploy this on a machine with a GPU that has ~16 GB of memory
- You will need `docker compose` (HINT: do not use the legacy `docker-compose`)
  - [Documentation](https://docs.docker.com/compose/install/linux/#install-the-plugin-manually)
- You will need to ensure `nvidia-ctk --version` produces valid output
  - [Documentation](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)

### Tasks:

1. Provide a simple system diagram (created in whatever format you feel best communicates the flow) for the application.
2. Provide an example output from the model.
   - NOTE: Getting an output is a multi-step process (see the sketch below).

This is meant to be a challenging task - so you might need to spend some time troubleshooting and tinkering!
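
### Example request flow:

The sketch below (not part of the repository's source files) illustrates that multi-step process using only the Python standard library. It assumes the stack is already running via `docker compose up --build`, that the `web` service is published on `localhost:8000` as in `docker-compose.yml`, and the prompt string is purely illustrative:

```python
import http.client
import json
import time

API_HOST, API_PORT = "localhost", 8000  # matches the `web` port mapping in docker-compose.yml

# Step 1: submit a prompt; the API responds immediately with a Celery task id.
conn = http.client.HTTPConnection(API_HOST, API_PORT)
conn.request(
    "POST",
    "/generateText",
    json.dumps({"prompt": "Explain what Celery and Redis do in this application."}),
    {"Content-type": "application/json"},
)
task_id = json.loads(conn.getresponse().read())["task_id"]
conn.close()

# Step 2: poll the task endpoint until the GPU worker has finished generating.
while True:
    conn = http.client.HTTPConnection(API_HOST, API_PORT)
    conn.request("GET", f"/task/{task_id}")
    body = json.loads(conn.getresponse().read())
    conn.close()
    if "result" in body:
        # body looks like {"result": "<generated text>", "time": <seconds>, "memory": <GB>}
        print(body["result"])
        break
    time.sleep(2)
```

The same flow is automated end-to-end by `run.py`.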
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI
from pydantic import BaseModel
from celery.result import AsyncResult
from typing import Any
from celery_worker import generate_text_task
from dotenv import load_dotenv

load_dotenv()

app = FastAPI()


class Prompt(BaseModel):
    prompt: str


@app.post("/generateText")
async def generate_text(prompt: Prompt) -> Any:
    # Enqueue the generation job and return the Celery task id immediately.
    task = generate_text_task.delay(prompt.prompt)
    return {"task_id": task.id}


@app.get("/task/{task_id}")
async def get_generate_text(task_id: str):
    # Once the task is ready, the worker's result is (decoded text, execution time, peak GPU memory).
    task = AsyncResult(task_id)
    if task.ready():
        task_result = task.get()
        return {
            "result": task_result[0],
            "time": task_result[1],
            "memory": task_result[2],
        }
    else:
        return {"status": "Task Pending"}
--------------------------------------------------------------------------------
/celery_worker.py:
--------------------------------------------------------------------------------
from celery import Celery, signals
from utils import generate_output
from model_loader import ModelLoader


def make_celery(app_name=__name__):
    # The `redis` service from docker-compose acts as both broker and result backend.
    backend = broker = "redis://redis:6379/0"
    return Celery(app_name, backend=backend, broker=broker)


celery = make_celery()

model_loader = None
model_path = "meta-llama/Llama-2-7b-chat-hf"


@signals.worker_process_init.connect
def setup_model(signal, sender, **kwargs):
    # Load the model once per worker process rather than once per task.
    global model_loader
    model_loader = ModelLoader(model_path)


@celery.task
def generate_text_task(prompt):
    # generate_output is wrapped by memory_decorator(time_decorator(...)), so it
    # returns (peak GPU memory in GB, execution time in seconds, output tensor).
    memory, time, outputs = generate_output(
        prompt, model_loader.model, model_loader.tokenizer
    )
    return model_loader.tokenizer.decode(outputs[0]), time, memory
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
version: '3'
services:
  web:
    build: .
    command: ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
    volumes:
      - .:/app
    ports:
      - 8000:80
    depends_on:
      - redis
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
  worker:
    build: .
    command: celery -A celery_worker worker -P solo --loglevel=info
    volumes:
      - .:/app
    depends_on:
      - redis
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
  redis:
    image: "redis:alpine"
    ports:
      - 6379:6379
--------------------------------------------------------------------------------
/model_loader.py:
--------------------------------------------------------------------------------
import os
from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    AutoTokenizer,
    BitsAndBytesConfig,
)
import torch
from dotenv import load_dotenv

load_dotenv()


class ModelLoader:
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.config = AutoConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
        )
        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, use_auth_token=os.getenv("HUGGINGFACE_TOKEN")
        )

    def _load_model(self):
        # Load the weights with 4-bit NF4 quantization so Llama-2-7b-chat fits within ~16 GB of GPU memory.
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            quantization_config=nf4_config,
            trust_remote_code=True,
            device_map="auto",
            use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
        )
        return model
--------------------------------------------------------------------------------
/redis_server.py:
--------------------------------------------------------------------------------
# Standalone helper for running Redis locally without Docker; the docker-compose
# stack uses the official redis:alpine image instead and does not use this script.
import subprocess


def install_redis_server(redis_version):
    # The `redis-server` PyPI package is expected to provide a bundled redis-server binary.
    try:
        subprocess.check_call(["pip", "install", f"redis-server=={redis_version}"])
        print(f"Redis server version {redis_version} installed successfully.")
    except subprocess.CalledProcessError:
        print("Failed to install Redis server.")
        exit(1)


def start_redis_server():
    try:
        # Imported here so the package only has to exist after install_redis_server() has run.
        # NOTE: this file shares its name with that package, so run the script from outside the
        # repository root (or rename it) to avoid the local module shadowing the installed one.
        import redis_server

        subprocess.Popen([redis_server.REDIS_SERVER_PATH])
        print("Redis server started successfully.")
    except Exception as e:
        print("Failed to start Redis server:", str(e))
        exit(1)


def main():
    redis_version = "6.0.9"
    install_redis_server(redis_version)
    start_redis_server()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
fastapi==0.99.1
uvicorn==0.22.0
pydantic==1.10.10
celery==5.3.1
redis==4.6.0
python-dotenv==1.0.0
transformers==4.30.2
torch==2.0.1
accelerate==0.21.0
bitsandbytes==0.41.0
scipy==1.10.1
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import http.client
import json
import time

API_HOST = "localhost"
API_PORT = 8000


def generate_text(prompt):
    # Submit the prompt to the API and return the Celery task id it hands back.
    conn = http.client.HTTPConnection(API_HOST, API_PORT)
    headers = {"Content-type": "application/json"}
    data = {"prompt": prompt}
    json_data = json.dumps(data)
    conn.request("POST", "/generateText", json_data, headers)
    response = conn.getresponse()
    result = json.loads(response.read().decode())
    conn.close()
    return result["task_id"]


def get_task_status(task_id):
    # Poll the status endpoint exposed by app.py.
    conn = http.client.HTTPConnection(API_HOST, API_PORT)
    conn.request("GET", f"/task/{task_id}")
    response = conn.getresponse()
    status = response.read().decode()
    conn.close()
    return status


def main():
    prompt = input("Enter the prompt: ")

    task_id = generate_text(prompt)
    # Keep polling every couple of seconds until the worker reports a result.
    while True:
        status = get_task_status(task_id)
        if "Task Pending" not in status:
            print(status)
            break
        time.sleep(2)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
conn.request("POST", "/generateText/", json_data, headers) 15 | response = conn.getresponse() 16 | result = json.loads(response.read().decode()) 17 | conn.close() 18 | return result["task_id"] 19 | 20 | 21 | def get_task_status(task_id): 22 | conn = http.client.HTTPConnection(API_HOST, API_PORT) 23 | conn.request("GET", f"/generateTextTask/{task_id}") 24 | response = conn.getresponse() 25 | status = response.read().decode() 26 | conn.close() 27 | return status 28 | 29 | 30 | def main(): 31 | prompt = input("Enter the prompt: ") 32 | 33 | task_id = generate_text(prompt) 34 | while True: 35 | status = get_task_status(task_id) 36 | if "Task Pending" not in status: 37 | print(status) 38 | break 39 | time.sleep(2) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import functools 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | 6 | 7 | def time_decorator(func): 8 | @functools.wraps(func) 9 | def wrapper(*args, **kwargs): 10 | start_time = time.time() 11 | result = func(*args, **kwargs) 12 | end_time = time.time() 13 | exec_time = end_time - start_time 14 | return (result, exec_time) 15 | 16 | return wrapper 17 | 18 | 19 | def memory_decorator(func): 20 | @functools.wraps(func) 21 | def wrapper(*args, **kwargs): 22 | torch.cuda.empty_cache() 23 | torch.cuda.reset_peak_memory_stats() 24 | result, exec_time = func(*args, **kwargs) 25 | peak_mem = torch.cuda.max_memory_allocated() 26 | peak_mem_consumption = peak_mem / 1e9 27 | return peak_mem_consumption, exec_time, result 28 | 29 | return wrapper 30 | 31 | 32 | @memory_decorator 33 | @time_decorator 34 | def generate_output( 35 | prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer 36 | ) -> torch.Tensor: 37 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 38 | input_ids = input_ids.to("cuda") 39 | outputs = model.generate(input_ids, max_length=500) 40 | return outputs 41 | --------------------------------------------------------------------------------