├── .gitignore ├── README.md ├── bentofile.yaml ├── configuration.yaml ├── images └── result.png ├── import_models.py ├── requirements.txt ├── service.py ├── start-server.py └── txt2img_test.sh /.gitignore: -------------------------------------------------------------------------------- 1 | output.jpg 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | *.dist-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | poetry.lock 31 | 32 | # IDE 33 | .idea 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # pip 46 | pip-wheel-metadata 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | .test_bundle/ 60 | *.xml 61 | 62 | # Sphinx documentation 63 | docs/build/ 64 | docs/source/_build 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | env.bak/ 88 | venv.bak/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mypy 98 | .mypy_cache/ 99 | 100 | # MacOS X 101 | .DS_Store 102 | 103 | # vscode config 104 | .vscode/ 105 | 106 | # Integration test data 107 | tests/integration/data 108 | 109 | # Emacs buffers 110 | *~ 111 | 112 | # setuptools_scm generated version file 113 | typings 114 | **/*/_version.py 115 | 116 | # test files 117 | catboost_info 118 | 119 | # ignore pyvenv 120 | pyvenv.cfg 121 | 122 | # bazel generated files 123 | bazel-* 124 | grpc-client/cpp/bentoml 125 | grpc-client/go/bentoml 126 | grpc-client/node/bentoml 127 | node_modules 128 | thirdparty 129 | 130 | # Swift-related 131 | **/*/Sources/bentoml 132 | .DS_Store 133 | **/*/.build 134 | /Packages 135 | /*.xcodeproj 136 | xcuserdata/ 137 | DerivedData/ 138 | .swiftpm/config/registries.json 139 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata 140 | .netrc 141 | Package.resolved 142 | 143 | # Java-Kotlin 144 | **/*/java/**/*/bentoml 145 | **/*/kotlin/**/*/bentoml 146 | .gradle 147 | 148 | # PHP generated stubs 149 | **/*/php/Bentoml 150 | **/*/php/GPBMetadata 151 | composer.lock 152 | vendor 153 | *.pkl 154 | *.ckpt 155 | 156 | # Triton-related 157 | *.pt 158 | *.savedmodel 159 | *.onnx 160 | *.engine 161 | *.plan 162 | *.bin 163 | yolov5* 164 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Deep Floyd IF with Multiple GPUs
3 | 
4 | Serve a state-of-the-art stable diffusion model across multiple GPUs with ease.
5 | 
6 | Powered by BentoML 🍱 + StabilityAI 🎨 + HuggingFace 🤗
7 | 
8 | 
9 | ## 📖 Introduction 📖
10 | - 🧪 **Stable Diffusion**: Stable Diffusion is a deep learning, text-to-image model primarily used to generate detailed images conditioned on text descriptions.
11 | 
12 | - 🔮 **[IF by DeepFloyd Lab](https://github.com/deep-floyd/IF)**: IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.
13 | 
14 | - 🚀 **BentoML with IF and GPUs**: In this project, BentoML demonstrates how to serve IF models easily across multiple GPUs.
15 | 
16 | - 🎛️ **Interactive Experience with Gradio UI**: You can play with the hosted IF model through an interactive Gradio UI.
17 | 
18 | ## 🏃‍♂️ Running the Service 🏃‍♂️
19 | ### Prerequisites
20 | To run this project locally, you will need the following:
21 | - Python 3.8+
22 | - `pip` installed
23 | - At least two GPUs with 16GB VRAM each, or one GPU with 40GB VRAM
24 | 
25 | ### Installing Dependencies
26 | It is recommended to use a virtual environment in your Python projects for dependency isolation. Run the following to install dependencies:
27 | ```bash
28 | pip install -r requirements.txt
29 | ```
30 | 
31 | ### Import the IF Models
32 | To download the IF models to your local Bento Store:
33 | ```bash
34 | python import_models.py
35 | ```
36 | 
37 | However, if you have never downloaded models from Hugging Face via the command line, you may need to authenticate first using the following commands:
38 | 
39 | ```
40 | pip install -U huggingface_hub
41 | huggingface-cli login
42 | ```
43 | 
44 | ### Run the Web Server with Gradio
45 | Run the server locally with a web UI powered by Gradio:
46 | 
47 | ```bash
48 | # For a GPU with more than 40GB VRAM, run all models on the same GPU
49 | python start-server.py
50 | 
51 | # For two Tesla T4s with 15GB VRAM each,
52 | # assign the stage1 model to the first GPU,
53 | # and the stage2 and stage3 models to the second GPU
54 | python start-server.py --stage1-gpu=0 --stage2-gpu=1 --stage3-gpu=1
55 | 
56 | # For one Tesla T4 with 15GB VRAM and two additional GPUs with smaller VRAM,
57 | # assign the stage1 model to the T4,
58 | # and the stage2 and stage3 models to the second and third GPUs respectively
59 | python start-server.py --stage1-gpu=0 --stage2-gpu=1 --stage3-gpu=2
60 | ```
61 | 
62 | Then you can visit the web UI at http://localhost:7860. BentoML's API endpoint is also accessible at http://localhost:3000. To show all the options you can change (such as the server's port), run `python start-server.py --help`.
63 | 
64 | ### Example Prompt
65 | **Prompt**
66 | ```
67 | orange and black, head shot of a woman standing under street lights, dark theme,
68 | Frank Miller, cinema, ultra realistic, ambiance, insanely detailed and intricate,
69 | hyper realistic, 8k resolution, photorealistic, highly textured, intricate details
70 | ```
71 | **Negative Prompt**
72 | ```
73 | tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,
74 | mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed,
75 | cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text,
76 | watermark, grainy
77 | ```
78 | **Results**
79 | ![Generated Image](images/result.png)
80 | 
81 | ## 🚀 Bringing it to Production 🚀
82 | For this project, which requires large GPU devices that we typically do not have locally, it is particularly beneficial to deploy via [☁️ BentoCloud](https://www.bentoml.com/bento-cloud/) -- a managed distributed compute platform for Machine Learning Serving.
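Whichever deployment path you take, the project first needs to be packaged as a Bento. A minimal sketch, assuming the models have already been imported with `import_models.py` and the commands are run from this directory; the exact Bento tag to containerize is whatever `bentoml build` prints (`DeepFloyd-IF:latest` below is only the expected default derived from the service name):

```bash
# Package service.py and configuration.yaml, per bentofile.yaml, into a Bento;
# models referenced by the service are pulled from the local Bento Store
bentoml build

# Optionally turn the built Bento into an OCI image for deployment
bentoml containerize DeepFloyd-IF:latest
```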
83 | 84 | Otherwise, BentoML offers a number of options for deploying and hosting online ML services into production, learn more at [Deploying Bento](https://docs.bentoml.org/en/latest/concepts/deploy.html). 85 | 86 | 87 | ## 👥 Community 👥 88 | BentoML has a thriving open source community where thousands of ML/AI practitioners are 89 | contributing to the project, helping other users and discussing the future of AI. 👉 [Pop into our Slack community!](https://l.bentoml.com/join-slack) 90 | -------------------------------------------------------------------------------- /bentofile.yaml: -------------------------------------------------------------------------------- 1 | service: "service.py:svc" 2 | include: 3 | - "service.py" 4 | - "configuration.yaml" 5 | python: 6 | packages: 7 | - torch>=2.0 8 | - transformers 9 | - accelerate 10 | - diffusers 11 | - triton 12 | - sentencepiece 13 | docker: 14 | distro: debian 15 | cuda_version: "11.6" 16 | env: 17 | BENTOML_CONFIG: "src/configuration.yaml" 18 | -------------------------------------------------------------------------------- /configuration.yaml: -------------------------------------------------------------------------------- 1 | api_server: 2 | timeout: 1800 3 | runners: 4 | timeout: 1500 5 | stage1_runner: 6 | resources: 7 | nvidia.com/gpu: [0] 8 | stage2_runner: 9 | resources: 10 | nvidia.com/gpu: [1] 11 | stage3_runner: 12 | resources: 13 | nvidia.com/gpu: [1] 14 | -------------------------------------------------------------------------------- /images/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentoml/IF-multi-GPUs-demo/bb9289dba841017ae25f5ac98b7b6b83f6ec38e3/images/result.png -------------------------------------------------------------------------------- /import_models.py: -------------------------------------------------------------------------------- 1 | import bentoml 2 | from bentoml.exceptions import NotFound 3 | import diffusers 4 | 5 | if __name__ == "__main__": 6 | stage1_signatures = { 7 | "__call__": {"batchable": False}, 8 | "encode_prompt": {"batchable": False}, 9 | } 10 | 11 | stage1_model_tag = "IF-stage1:v1.0" 12 | try: 13 | bentoml.diffusers.get(stage1_model_tag) 14 | except NotFound: 15 | bentoml.diffusers.import_model( 16 | stage1_model_tag, "DeepFloyd/IF-I-XL-v1.0", 17 | signatures=stage1_signatures, 18 | variant="fp16", 19 | pipeline_class=diffusers.DiffusionPipeline,) 20 | 21 | stage2_model_tag = "IF-stage2:v1.0" 22 | try: 23 | bentoml.diffusers.get(stage2_model_tag) 24 | except NotFound: 25 | bentoml.diffusers.import_model( 26 | stage2_model_tag, "DeepFloyd/IF-II-L-v1.0", 27 | variant="fp16", 28 | pipeline_class=diffusers.DiffusionPipeline 29 | ) 30 | 31 | upscaler_model_name = "sd-upscaler" 32 | try: 33 | bentoml.diffusers.get(upscaler_model_name) 34 | except NotFound: 35 | bentoml.diffusers.import_model( 36 | upscaler_model_name, "stabilityai/stable-diffusion-x4-upscaler" 37 | ) 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bentoml[io-image]>=1.0.22 2 | transformers 3 | diffusers 4 | torch>=2.0 5 | accelerate 6 | sentencepiece 7 | triton; sys_platform == "linux" 8 | gradio==3.20.1 9 | -------------------------------------------------------------------------------- /service.py: -------------------------------------------------------------------------------- 1 | from diffusers import 
DiffusionPipeline 2 | 3 | import bentoml 4 | from bentoml.io import JSON 5 | from bentoml.io import Image 6 | 7 | stage1_model = bentoml.diffusers.get("IF-stage1:v1.0") 8 | stage1_runner = stage1_model.with_options( 9 | enable_model_cpu_offload=True, 10 | ).to_runner(name="stage1_runner") 11 | 12 | stage2_model = bentoml.diffusers.get("IF-stage2:v1.0") 13 | stage2_runner = stage2_model.with_options( 14 | enable_model_cpu_offload=True, 15 | load_pretrained_extra_kwargs={"text_encoder": None}, 16 | ).to_runner(name="stage2_runner") 17 | 18 | stage3_model = bentoml.diffusers.get("sd-upscaler:latest") 19 | stage3_runner = stage3_model.with_options( 20 | pipeline_class=DiffusionPipeline, 21 | enable_model_cpu_offload=True, 22 | ).to_runner(name="stage3_runner") 23 | 24 | 25 | svc = bentoml.Service( 26 | "DeepFloyd-IF", runners=[stage1_runner, stage2_runner, stage3_runner] 27 | ) 28 | 29 | 30 | @svc.api(input=JSON(), output=Image()) 31 | def txt2img(input_data): 32 | prompt = input_data.get("prompt") 33 | negative_prompt = input_data.get("negative_prompt") 34 | prompt_embeds, negative_embeds = stage1_runner.encode_prompt.run( 35 | prompt=prompt, negative_prompt=negative_prompt 36 | ) 37 | res_t = stage1_runner.run( 38 | prompt_embeds=prompt_embeds, 39 | negative_prompt_embeds=negative_embeds, 40 | output_type="pt", 41 | ) 42 | images = res_t[0] 43 | res_t = stage2_runner.run( 44 | image=images, 45 | prompt_embeds=prompt_embeds, 46 | negative_prompt_embeds=negative_embeds, 47 | output_type="pt", 48 | ) 49 | images = res_t[0] 50 | res_t = stage3_runner.run( 51 | prompt=prompt, negative_prompt=negative_prompt, image=images, noise_level=100 52 | ) 53 | images = res_t[0] 54 | return images[0] 55 | -------------------------------------------------------------------------------- /start-server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import bentoml 4 | from service import svc 5 | 6 | if __name__ == "__main__": 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--stage1-gpu", type=int, dest="gpu1", default=0, 10 | help="stage1 gpu index (default to 0)") 11 | parser.add_argument("--stage2-gpu", type=int, dest="gpu2", default=0, 12 | help="stage2 gpu index (default to 0)") 13 | parser.add_argument("--stage3-gpu", type=int, dest="gpu3", default=0, 14 | help="stage3 gpu index (default to 0)") 15 | parser.add_argument("--bentoml-api-workers", type=int, dest="api_workers", default=1, 16 | help="BentoML api server worker number (default to 1)") 17 | parser.add_argument("--bentoml-server-port", type=int, dest="bentoml_port", default=3000, 18 | help="BentoML server port (default to 3000)") 19 | parser.add_argument("--bentoml-server-host", type=str, dest="bentoml_host", default="0.0.0.0", 20 | help="BentoML server host (default to 0.0.0.0)") 21 | parser.add_argument("--gradio-server-port", type=int, dest="gradio_port", default=7860, 22 | help="gradio web UI server port (default to 7860)") 23 | parser.add_argument("--gradio-server-name", type=str, dest="gradio_host", default="0.0.0.0", 24 | help="gradio web UI server host (default to 0.0.0.0)") 25 | parser.add_argument("--gradio-share", type=bool, default=False, 26 | help="share your gradio app (default to False)") 27 | parser.add_argument("--debug", type=bool, default=False, 28 | help="debug mode with BentoML log output (default to False)") 29 | 30 | args = parser.parse_args() 31 | 32 | env_vars = { 33 | "BENTOML_CONFIG_OPTIONS": f""" 34 | 
runners.stage1_runner.resources."nvidia.com/gpu"[0]={args.gpu1} 35 | runners.stage2_runner.resources."nvidia.com/gpu"[0]={args.gpu2} 36 | runners.stage3_runner.resources."nvidia.com/gpu"[0]={args.gpu3} 37 | api_server.workers={args.api_workers} 38 | api_server.timeout=1800 39 | runners.timeout=1500 40 | """ 41 | } 42 | 43 | server = bentoml.HTTPServer(svc, port=args.bentoml_port, host=args.bentoml_host) 44 | server.timeout = 300 45 | 46 | if args.debug: 47 | stdout = sys.stdout 48 | stderr = sys.stderr 49 | else: 50 | stdout = None 51 | stderr = sys.stderr 52 | with server.start(env=env_vars, blocking=False, 53 | stdout=stdout, stderr=stderr): 54 | 55 | client = server.get_client() 56 | 57 | # gradio interface 58 | 59 | def inference(prompt, negative_prompt=""): 60 | img = client.txt2img(dict(prompt=prompt, negative_prompt=negative_prompt)) 61 | return img 62 | 63 | import gradio as gr 64 | with gr.Blocks() as demo: 65 | 66 | with gr.Row(): 67 | 68 | with gr.Column(scale=55): 69 | gallery = gr.Image( 70 | label="Generated images", show_label=False, 71 | ).style(grid=[2], height="auto", container=True) 72 | 73 | with gr.Column(scale=45): 74 | generate = gr.Button( 75 | value="Generate", 76 | variant="secondary" 77 | ).style(container=True) 78 | 79 | with gr.Group(): 80 | prompt = gr.Textbox( 81 | label="Prompt", 82 | show_label=False, 83 | max_lines=3, 84 | placeholder="Enter prompt", 85 | lines=3 86 | ).style(container=False) 87 | 88 | neg_prompt = gr.Textbox( 89 | label="Negative prompt", 90 | placeholder="What to exclude from the image" 91 | ) 92 | 93 | inputs = [prompt, neg_prompt] 94 | outputs = [gallery] 95 | prompt.submit(inference, inputs=inputs, outputs=outputs, show_progress=True) 96 | generate.click(inference, inputs=inputs, outputs=outputs, show_progress=True) 97 | 98 | demo.queue().launch( 99 | server_name=args.gradio_host, 100 | server_port=args.gradio_port, 101 | share=args.gradio_share 102 | ) 103 | -------------------------------------------------------------------------------- /txt2img_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo -n "enter prompt (e.g. a bento box): " 3 | while read -r PROMPT && [ -z "$PROMPT" ]; do echo -n "enter prompt (e.g. a bento box): "; done 4 | echo -n "enter negative prompt (e.g. low-res,blurry): " 5 | while read -r N_PROMPT && [ -z "$N_PROMPT" ]; do echo -n "enter negative prompt (e.g. low-res,blurry): "; done 6 | 7 | if [[ -n "$PROMPT" ]]; 8 | then 9 | curl -X POST http://127.0.0.1:3000/txt2img -H 'Content-Type: application/json' -d "{\"prompt\":\"$PROMPT\",\"negative_prompt\":\"$N_PROMPT\"}" --output output.jpg 10 | else 11 | echo "No input" 12 | fi 13 | --------------------------------------------------------------------------------
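The curl script above assumes a running server. Besides `python start-server.py` (which also launches the Gradio UI), the BentoML API server can be started on its own — a minimal sketch, assuming the models were imported and the command is run from the project root; `configuration.yaml` in this repo supplies the same runner-to-GPU mapping used by the demo:

```bash
# Start only the BentoML HTTP API server (no Gradio UI), applying the
# runner timeouts and GPU assignments from configuration.yaml
BENTOML_CONFIG=configuration.yaml bentoml serve service.py:svc --port 3000
```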