├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── app.py ├── download.py ├── requirements.txt ├── server.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Must use a CUDA version 11+
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

WORKDIR /

# Install git; removing the apt package lists afterwards keeps this layer small
RUN apt-get update && apt-get install -y git \
    && rm -rf /var/lib/apt/lists/*

# Install python packages (no pip cache is kept in the image)
RUN pip3 install --no-cache-dir --upgrade pip
COPY requirements.txt requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt

# Add your huggingface auth key here
# NOTE: ENV values are stored in the image metadata, so anyone who can pull
# the image can read this token. Keep the repo and the image private.
ENV HF_AUTH_TOKEN=your_token

# Add your model weight files
# (in this case we have a python script that downloads them at build time)
COPY download.py .
RUN python3 download.py

# We add the banana boilerplate here. It is copied AFTER the weight download
# so that editing it does not invalidate the cached (large) weight layer.
COPY server.py .
EXPOSE 8000

# Add your custom app code, init() and inference()
COPY app.py .

# Exec form: python3 runs as PID 1 and receives stop signals (SIGTERM) directly,
# instead of through an intermediate /bin/sh.
CMD ["python3", "-u", "server.py"]
--------------------------------------------------------------------------------
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Banana 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # 🍌 Banana Serverless 3 | 4 | This repo gives a basic framework for serving Stable Diffusion in production using simple HTTP servers. 5 | 6 | ## Instant Deploy 7 | Stable Diffusion is now available as a prebuilt model on Banana! [See how to deploy Stable Diffusion in seconds](https://docs.banana.dev/banana-docs/core-concepts/inference-server/1-click-deploy).
8 | 9 | 10 | # Quickstart 11 | 12 | If you want to customize beyond the prebuilt model: 13 | 14 | **[Follow the quickstart guide in Banana's documentation to use this repo](https://docs.banana.dev/banana-docs/quickstart).** 15 | 16 | *(choose the "GitHub Repository" deployment method)* 17 | 18 | ### Additional Steps (outside of the quickstart guide) 19 | 20 | 1. Create your own private repo and copy the files from this template repo into it. You'll want a private repo so that your Hugging Face access token stays secure. 21 | 2. Create a Hugging Face account to get permission to download and run the [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion-v1-4) text-to-image model. 22 | - Accept the terms and conditions for using the v1-4 [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion-v1-4) model. 23 | 3. Edit the `Dockerfile` in your private repo and replace the placeholder in `ENV HF_AUTH_TOKEN=your_token` with your Hugging Face access token. 24 | 25 | 26 | # Helpful Links 27 | Understand the 🍌 [Serverless framework](https://docs.banana.dev/banana-docs/core-concepts/inference-server/serverless-framework) and the functionality of each file within it. 28 | 29 | Generalize this framework to [deploy anything on Banana](https://docs.banana.dev/banana-docs/resources/how-to-serve-anything-on-banana). 30 | 31 | 32 |
33 | 34 | ## Use Banana for scale. 35 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import autocast 3 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler 4 | import base64 5 | from io import BytesIO 6 | import os 7 | 8 | # Init is ran on server startup 9 | # Load your model to GPU as a global variable here using the variable name "model" 10 | def init(): 11 | global model 12 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN") 13 | 14 | # this will substitute the default PNDM scheduler for K-LMS 15 | lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") 16 | 17 | model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=lms, use_auth_token=HF_AUTH_TOKEN) 18 | 19 | # Inference is ran for every server call 20 | # Reference your preloaded global model variable here. 
21 | def inference(model_inputs:dict) -> dict: 22 | global model 23 | 24 | # Parse out your arguments 25 | prompt = model_inputs.get('prompt', None) 26 | height = model_inputs.get('height', 512) 27 | width = model_inputs.get('width', 512) 28 | num_inference_steps = model_inputs.get('num_inference_steps', 50) 29 | guidance_scale = model_inputs.get('guidance_scale', 7.5) 30 | input_seed = model_inputs.get("seed",None) 31 | 32 | #If "seed" is not sent, we won't specify a seed in the call 33 | generator = None 34 | if input_seed != None: 35 | generator = torch.Generator("cuda").manual_seed(input_seed) 36 | 37 | if prompt == None: 38 | return {'message': "No prompt provided"} 39 | 40 | # Run the model 41 | with autocast("cuda"): 42 | image = model(prompt,height=height,width=width,num_inference_steps=num_inference_steps,guidance_scale=guidance_scale,generator=generator)["sample"][0] 43 | 44 | buffered = BytesIO() 45 | image.save(buffered,format="JPEG") 46 | image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') 47 | 48 | # Return the results as a dictionary 49 | return {'image_base64': image_base64} 50 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | # In this file, we define download_model 2 | # It runs during container build time to get model weights built into the container 3 | 4 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler 5 | import os 6 | 7 | def download_model(): 8 | # do a dry run of loading the huggingface model, which will download weights at build time 9 | #Set auth token which is required to download stable diffusion model weights 10 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN") 11 | 12 | lms = LMSDiscreteScheduler( 13 | beta_start=0.00085, 14 | beta_end=0.012, 15 | beta_schedule="scaled_linear" 16 | ) 17 | 18 | model = StableDiffusionPipeline.from_pretrained( 19 | 
"CompVis/stable-diffusion-v1-4", 20 | scheduler=lms, 21 | use_auth_token=HF_AUTH_TOKEN 22 | ) 23 | 24 | if __name__ == "__main__": 25 | download_model() 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sanic==22.6.2 2 | diffusers 3 | transformers 4 | scipy 5 | websockets==10.0 6 | accelerate -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | # Do not edit if deploying to Banana Serverless 2 | # This file is boilerplate for the http server, and follows a strict interface. 3 | 4 | # Instead, edit the init() and inference() functions in app.py 5 | 6 | from sanic import Sanic, response 7 | import subprocess 8 | import app as user_src 9 | 10 | # We do the model load-to-GPU step on server startup 11 | # so the model object is available globally for reuse 12 | user_src.init() 13 | 14 | # Create the http server app 15 | server = Sanic("my_app") 16 | 17 | # Healthchecks verify that the environment is correct on Banana Serverless 18 | @server.route('/healthcheck', methods=["GET"]) 19 | def healthcheck(request): 20 | # dependency free way to check if GPU is visible 21 | gpu = False 22 | out = subprocess.run("nvidia-smi", shell=True) 23 | if out.returncode == 0: # success state on shell command 24 | gpu = True 25 | 26 | return response.json({"state": "healthy", "gpu": gpu}) 27 | 28 | # Inference POST handler at '/' is called for every http call from Banana 29 | @server.route('/', methods=["POST"]) 30 | def inference(request): 31 | try: 32 | model_inputs = response.json.loads(request.json) 33 | except: 34 | model_inputs = request.json 35 | 36 | output = user_src.inference(model_inputs) 37 | 38 | return response.json(output) 39 | 40 | 41 | if __name__ == '__main__': 42 | server.run(host='0.0.0.0', port="8000", 
workers=1) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # This file is used to verify your http server acts as expected 2 | # Run it with `python3 test.py`` 3 | 4 | import requests 5 | import base64 6 | from io import BytesIO 7 | from PIL import Image 8 | 9 | model_inputs = {'prompt': 'realistic field of grass'} 10 | 11 | res = requests.post('http://localhost:8000/', json = model_inputs) 12 | 13 | image_byte_string = res.json()["image_base64"] 14 | 15 | image_encoded = image_byte_string.encode('utf-8') 16 | image_bytes = BytesIO(base64.b64decode(image_encoded)) 17 | image = Image.open(image_bytes) 18 | image.save("output.jpg") --------------------------------------------------------------------------------