├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── BUG_ISSUE.md │ └── FEATURE_REQUEST.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── README.md ├── api ├── __init__.py ├── home │ ├── __init__.py │ └── router.py └── stable_diffusion │ ├── __init__.py │ ├── request.py │ ├── response.py │ └── router.py ├── app ├── __init__.py ├── server.py └── stable_diffusion │ ├── manager │ ├── __init__.py │ ├── manager.py │ └── schema.py │ ├── pipeline │ ├── __init__.py │ ├── image2image.py │ ├── inpaint.py │ ├── text2image.py │ └── types.py │ └── service.py ├── core ├── decorator │ └── singleton.py ├── dependencies │ ├── __init__.py │ └── models.py ├── logger │ └── __init__.py ├── middlewares │ └── logging.py ├── settings │ ├── __init__.py │ ├── enum.py │ └── settings.py └── utils │ └── convert_script.py ├── docker-compose.yaml ├── docker ├── api │ ├── Dockerfile │ └── start.sh └── frontend │ └── Dockerfile ├── env └── .gitignore ├── frontend ├── helps.py ├── inpaint.py ├── pages │ ├── image2image.py │ └── text2image.py ├── requirements.txt ├── settings.py ├── task.py └── utils.py ├── huggingface_model_download.py ├── main.py ├── requirements.txt └── src └── image ├── image2image ├── 1.png └── 2.png ├── inpaint ├── 0.png ├── 1.png └── 2.png └── text2image ├── 1.png └── 2.png /.dockerignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .vscode/ 3 | 4 | # git 5 | .git/ 6 | 7 | # CACHE 8 | **/__pycache__/** 9 | .pytest_cache 10 | .testmondata 11 | 12 | # DATASET 13 | **/DATASET/** 14 | 15 | # MODELS 16 | *.zip 17 | *.pt.zip 18 | *.pt 19 | *.pth 20 | *.onnx 21 | 22 | # ETC 23 | env/ 24 | static/ 25 | docker-compose.yaml 26 | *.ipynb 27 | 28 | # LOG 29 | logs/ 30 | 31 | # DEV 32 | debug.py 33 | 34 | # frontend 35 | frontend/ -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/BUG_ISSUE.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report Template 3 | about: 버그 리포트 템플릿 4 | title: "" 5 | assignees: "yslee" 6 | --- 7 | 8 | # System info 9 | 10 | # Describe of bug 11 | 12 | # Code example 13 | 14 | # Log or Screenshot 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request Template 3 | about: 기능 및 추가 요청을 위한 템플릿 4 | title: "" 5 | assignees: "yslee" 6 | --- 7 | 8 | # Description 9 | 10 | # TODO 11 | 12 | - [ ] todo 13 | - [ ] todo 14 | 15 | # ETC 16 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Summary 2 | 3 | content 4 | 5 | # Work 6 | 7 | content 8 | 9 | # Related issues [optional] 10 | 11 | write related issues 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .vscode/ 3 | 4 | # CACHE 5 | **/__pycache__/** 6 | .pytest_cache 7 | .testmondata 8 | 9 | # DATASET 10 | **/DATASET/** 11 | 12 | # MODELS 13 | *.zip 14 | *.pt.zip 15 | *.pt 16 | *.pth 17 | *.onnx 18 | 19 | # LOG 20 | tensorboard/ 21 | wandb/ 22 | logs/ 23 | 24 | temp/ 25 | 26 | # DEV 27 | debug.py 28 | *.ipynb 29 | ckpt/ 30 | 31 | static/ 32 | lpw_stable_diffusion/ 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Unofficial FastAPI implementation of the Stable-Diffusion API
2 | 
3 | An UNOFFICIAL [Stable-Diffusion](https://github.com/CompVis/stable-diffusion) API built with FastAPI
4 | 
5 | # Samples
6 | 
7 | | Text2Image-01 | Text2Image-02 |
8 | | :--------------------------------: | :--------------------------------: |
9 | | ![](./src/image/text2image/1.png) | ![](./src/image/text2image/2.png) |
10 | | Image2Image-01 | Image2Image-02 |
11 | | ![](./src/image/image2image/1.png) | ![](./src/image/image2image/2.png) |
12 | | Inpaint-01 | Inpaint-02 |
13 | | ![](./src/image/inpaint/0.png) | ![](./src/image/inpaint/0.png) |
14 | | ![](./src/image/inpaint/1.png) | ![](./src/image/inpaint/2.png) |
15 | 
16 | # Features
17 | - [x] long-prompt-weighting support
18 | - [x] text2image
19 | - [x] image2image
20 | - [x] inpaint
21 | - [x] negative-prompt
22 | - [x] celery async task (check the celery_task [branch](https://github.com/rapidrabbit76/stable-diffusion-API/tree/celery_task))
23 | - [x] original ```ckpt``` format support
24 | - [ ] object storage support
25 | - [ ] stable-diffusion 2.0 support
26 | - [ ] token size checker
27 | - [ ] JAX/Flax pipeline
28 | 
29 | # Requirements
30 | 
31 | 
32 | ## API
33 | 
34 | ```txt
35 | fastapi[all]==0.80.0
36 | fastapi-restful==0.4.3
37 | fastapi-health==0.4.0
38 | service-streamer==0.1.2
39 | pydantic==1.9.2
40 | diffusers==0.3.0
41 | transformers==4.19.2
42 | scipy
43 | ftfy
44 | ```
45 | 
46 | ## Frontend
47 | ```txt
48 | streamlit==1.12.2
49 | requests==2.27.1
50 | requests-toolbelt==0.9.1
51 | pydantic==1.8.2
52 | streamlit-drawable-canvas==0.9.2
53 | ```
54 | 
55 | 
56 | # API
57 | 
58 | 
59 | ## /text2image
60 | Creates images from a text prompt.
61 | 
62 | inputs:
63 | 
64 | - prompt(str): text prompt
65 | - num_images(int): number of images to generate
66 | - guidance_scale(float): guidance scale for stable-diffusion
67 | - height(int): image height
68 | - width(int): image width
69 | - seed(int): generator seed
70 | 
71 | outputs:
72 | 
73 | - prompt(str): input text prompt
74 | - task_id(str): uuid4 string
75 | - image_urls(List[str]): generated image URLs
76 | 
77 | 
78 | ## /image2image
79 | Creates images from an init image and a text prompt.
80 | 
81 | inputs:
82 | 
83 | - prompt(str): text prompt
84 | - init_image(image file): init image for the image2image task
85 | - num_images(int): number of images to generate
86 | - guidance_scale(float): guidance scale for stable-diffusion
87 | - seed(int): generator seed
88 | 
89 | outputs:
90 | 
91 | - prompt(str): input text prompt
92 | - task_id(str): uuid4 string
93 | - image_urls(List[str]): generated image URLs
94 | 
95 | 
96 | 
97 | # Environment variables
98 | 
99 | 
100 | ```bash
101 | # env settings are defined in
102 | >> ./core/settings/settings.py
103 | ```
104 | 
105 | | Name | Default | Desc |
106 | | ------------------------ | ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
107 | | MODEL_ID | CompVis/stable-diffusion-v1-4 | huggingface repo id or model path |
108 | | ENABLE_ATTENTION_SLICING | True | [Enable sliced attention computation.](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline.enable_attention_slicing) |
109 | | CUDA_DEVICE | "cuda" | target cuda device |
110 | | CUDA_DEVICES | [0] | visible cuda devices |
111 | | MB_BATCH_SIZE | 1 | Micro Batch: MAX Batch size |
112 | | MB_TIMEOUT | 120 | Micro Batch: timeout sec |
113 | | HUGGINGFACE_TOKEN | None | huggingface access token |
114 | | IMAGESERVER_URL | None | result image base url |
115 | | SAVE_DIR | static | result image save dir |
116 | | CORS_ALLOW_ORIGINS | [*] | cross origin resource sharing setting for FastAPI |
117 | 
118 | # RUN from code (API)
119 | 
120 | ## 1. Install python requirements
121 | ```bash
122 | pip install -r requirements.txt
123 | ```
124 | 
125 | ## 2. Download and cache the huggingface model
126 | ```bash
127 | python huggingface_model_download.py
128 | # check stable-diffusion model in huggingface cache dir
129 | [[ -d ~/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4 ]] && echo "exist"
130 | >> exist
131 | ```
132 | 
133 | ## 3. Update settings.py in ./core/settings/settings.py
134 | ```python
135 | # example
136 | class ModelSetting(BaseSettings):
137 |     MODEL_ID: str = "CompVis/stable-diffusion-v1-4"  # huggingface repo id
138 |     ENABLE_ATTENTION_SLICING: bool = True
139 |     ...
140 | class Settings(
141 |     ...
142 | ):
143 |     HUGGINGFACE_TOKEN: str = "YOUR HUGGINGFACE ACCESS TOKEN"
144 |     IMAGESERVER_URL: str = "http://localhost:3000/images"
145 |     SAVE_DIR: str = 'static'
146 |     ...
147 | ```
148 | 
149 | ## 4. RUN API from code
150 | ```bash
151 | bash docker/api/start.sh
152 | ```
153 | 
154 | # RUN from code (frontend)
155 | 
156 | ## 1. Install python requirements
157 | ```bash
158 | pip install \
159 |     streamlit==1.12.2 \
160 |     requests==2.27.1 \
161 |     requests-toolbelt==0.9.1 \
162 |     pydantic==1.8.2 \
163 |     streamlit-drawable-canvas==0.9.2
164 | ```
165 | 
166 | ## 2. RUN streamlit frontend
167 | ```bash
168 | streamlit run inpaint.py
169 | ```
170 | 
171 | 
172 | # RUN using Docker (docker-compose)
173 | 
174 | ## 1. Image Build
175 | ```bash
176 | docker-compose build
177 | ```
178 | 
179 | ## 2. Update the docker-compose.yaml file in the repo root
180 | ```yaml
181 | version: "3.7"
182 | services:
183 |   api:
184 |     ...
185 |     volumes:
186 |       # mount the huggingface model cache dir into the container root user's home dir
187 |       - /model:/model # if you load a pretrained model
188 |       - ...
189 |     environment:
190 |       ...
191 |       MODEL_ID: "CompVis/stable-diffusion-v1-4"
192 |       HUGGINGFACE_TOKEN: {YOUR HUGGINGFACE ACCESS TOKEN}
193 |       ...
194 | 
195 |     deploy:
196 |       ...
197 |   frontend:
198 |     ...
199 | ```
200 | 
201 | ## 3.
Container RUN 202 | ```bash 203 | docker-compose up -d 204 | # or API only 205 | docker-compsoe up -d api 206 | # or frontend only 207 | docker-compsoe up -d frontend 208 | ``` 209 | 210 | 211 | 212 | 213 | ## References 214 | - [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) 215 | - [huggingface, stable-diffusion](https://huggingface.co/CompVis) 216 | - [teamhide/fastapi-boilerplate](https://github.com/teamhide/fastapi-boilerplate) -------------------------------------------------------------------------------- /api/__init__.py: -------------------------------------------------------------------------------- 1 | from .stable_diffusion import router as StableDiffusionRouter 2 | from .home import router as HomeRouter 3 | -------------------------------------------------------------------------------- /api/home/__init__.py: -------------------------------------------------------------------------------- 1 | from .router import router 2 | -------------------------------------------------------------------------------- /api/home/router.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from fastapi import Response 3 | from fastapi_restful.cbv import cbv 4 | from fastapi_restful.inferring_router import InferringRouter 5 | 6 | from core.settings import get_settings 7 | 8 | 9 | router = InferringRouter() 10 | env = get_settings() 11 | 12 | 13 | @cbv(router) 14 | class Home: 15 | @router.get("/") 16 | async def index(self): 17 | """ELB check""" 18 | current_time = datetime.utcnow() 19 | msg = f"Notification API (UTC: {current_time.strftime('%Y.%m.%d %H:%M:%S')})" 20 | return Response(msg) 21 | -------------------------------------------------------------------------------- /api/stable_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .router import router 2 | -------------------------------------------------------------------------------- /api/stable_diffusion/request.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | from pydantic import BaseModel, Field 3 | from fastapi import Form, UploadFile, File 4 | import sys 5 | from random import randint 6 | from PIL import Image 7 | 8 | 9 | class PromptRequest(BaseModel): 10 | prompt: str = Field(..., description="text prompt") 11 | num_images: int = Field(1, description="num images", ge=1, le=2) 12 | 13 | 14 | def random_seed(seed: T.Optional[int] = Form(None)): 15 | seed = seed if seed is not None else randint(1, sys.maxsize) 16 | return seed 17 | 18 | 19 | def read_image(image: UploadFile) -> Image.Image: 20 | image = Image.open(image.file).convert("RGB") 21 | return image 22 | -------------------------------------------------------------------------------- /api/stable_diffusion/response.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | from pydantic import BaseModel, Field, HttpUrl 3 | 4 | 5 | class StableDiffussionResponse(BaseModel): 6 | prompt: str = Field(..., description="input prompt") 7 | task_id: str = Field(..., description="task id") 8 | image_urls: T.List[T.Union[str, HttpUrl]] = Field(..., description="image url") 9 | -------------------------------------------------------------------------------- /api/stable_diffusion/router.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | import os 3 | from uuid 
import uuid4 4 | 5 | from fastapi import Form, Depends, UploadFile, File 6 | from fastapi_restful.cbv import cbv 7 | from fastapi_restful.inferring_router import InferringRouter 8 | 9 | from .request import random_seed, read_image 10 | from .response import StableDiffussionResponse 11 | from app.stable_diffusion.service import StableDiffusionService 12 | from core.settings import get_settings 13 | 14 | router = InferringRouter() 15 | env = get_settings() 16 | 17 | 18 | @cbv(router) 19 | class StableDiffusion: 20 | svc: StableDiffusionService = Depends(StableDiffusionService) 21 | 22 | @router.post("/text2image", response_model=StableDiffussionResponse) 23 | def text2image( 24 | self, 25 | prompt: str = Form(), 26 | negative_prompt: str = Form(default=""), 27 | num_images: int = Form(1, description="num images", ge=1, le=8), 28 | steps: int = Form(25, ge=1), 29 | guidance_scale: float = Form( 30 | 7.5, description="guidance_scale", gt=0, le=20 31 | ), 32 | height: int = Form(512, description="result height"), 33 | width: int = Form(512, description="result width"), 34 | seed: T.Optional[int] = Form(None), 35 | ): 36 | task_id = str(uuid4()) 37 | 38 | images = self.svc.text2image( 39 | prompt=prompt, 40 | negative_prompt=negative_prompt, 41 | num_images=num_images, 42 | num_inference_steps=steps, 43 | guidance_scale=guidance_scale, 44 | height=height, 45 | width=width, 46 | seed=seed, 47 | ) 48 | 49 | info = { 50 | "task": "text2image", 51 | "prompt": prompt, 52 | "guidance_scale": guidance_scale, 53 | "height": height, 54 | "width": width, 55 | "seed": seed, 56 | "num_inference_steps": steps, 57 | } 58 | image_paths = self.svc.image_save(images, task_id, info=info) 59 | urls = [os.path.join(env.IMAGESERVER_URL, path) for path in image_paths] 60 | 61 | response = StableDiffussionResponse( 62 | prompt=prompt, 63 | task_id=task_id, 64 | image_urls=urls, 65 | ) 66 | return response 67 | 68 | @router.post("/image2image", response_model=StableDiffussionResponse) 69 | def image2image( 70 | self, 71 | prompt: str = Form(), 72 | negative_prompt: str = Form(default=""), 73 | init_image: UploadFile = File(...), 74 | num_images: int = Form(1, description="num images", ge=1, le=8), 75 | steps: int = Form(25, ge=1), 76 | strength: float = Form(0.8, ge=0, le=1.0), 77 | guidance_scale: float = Form( 78 | 7.5, description="guidance_scale", gt=0, le=20 79 | ), 80 | seed: T.Optional[int] = Form(None), 81 | ): 82 | init_image = read_image(init_image) 83 | task_id = str(uuid4()) 84 | 85 | images = self.svc.image2image( 86 | prompt=prompt, 87 | negative_prompt=negative_prompt, 88 | init_image=init_image, 89 | num_images=num_images, 90 | strength=strength, 91 | num_inference_steps=steps, 92 | guidance_scale=guidance_scale, 93 | seed=seed, 94 | ) 95 | 96 | info = { 97 | "task": "image2image", 98 | "prompt": prompt, 99 | "strength": strength, 100 | "guidance_scale": guidance_scale, 101 | "seed": seed, 102 | "num_inference_steps": steps, 103 | } 104 | image_paths = self.svc.image_save(images, task_id, info=info) 105 | init_image.save(os.path.join(env.SAVE_DIR, task_id, "init_image.webp")) 106 | urls = [os.path.join(env.IMAGESERVER_URL, path) for path in image_paths] 107 | 108 | response = StableDiffussionResponse( 109 | prompt=prompt, 110 | task_id=task_id, 111 | image_urls=urls, 112 | ) 113 | return response 114 | 115 | @router.post("/inpaint", response_model=StableDiffussionResponse) 116 | def inpaint( 117 | self, 118 | prompt: str = Form(), 119 | negative_prompt: str = Form(default=""), 120 | init_image: 
UploadFile = File(...), 121 | mask_image: UploadFile = File(...), 122 | num_images: int = Form(1, description="num images", ge=1, le=8), 123 | steps: int = Form(25, ge=1), 124 | strength: float = Form(0.8, ge=0, le=1.0), 125 | guidance_scale: float = Form( 126 | 7.5, description="guidance_scale", gt=0, le=20 127 | ), 128 | seed: T.Optional[int] = Form(None), 129 | ): 130 | init_image = read_image(init_image) 131 | mask_image = read_image(mask_image) 132 | 133 | task_id = str(uuid4()) 134 | images = self.svc.inpaint( 135 | prompt=prompt, 136 | negative_prompt=negative_prompt, 137 | init_image=init_image, 138 | mask_image=mask_image, 139 | num_inference_steps=steps, 140 | strength=strength, 141 | num_images=num_images, 142 | guidance_scale=guidance_scale, 143 | seed=seed, 144 | ) 145 | 146 | info = { 147 | "task": "inpaint", 148 | "prompt": prompt, 149 | "strength": strength, 150 | "guidance_scale": guidance_scale, 151 | "seed": seed, 152 | "num_inference_steps": steps, 153 | } 154 | image_paths = self.svc.image_save(images, task_id, info=info) 155 | init_image.save(os.path.join(env.SAVE_DIR, task_id, "init_image.webp")) 156 | mask_image.save(os.path.join(env.SAVE_DIR, task_id, "mask_image.webp")) 157 | urls = [os.path.join(env.IMAGESERVER_URL, path) for path in image_paths] 158 | 159 | response = StableDiffussionResponse( 160 | prompt=prompt, 161 | task_id=task_id, 162 | image_urls=urls, 163 | ) 164 | return response 165 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/app/__init__.py -------------------------------------------------------------------------------- /app/server.py: -------------------------------------------------------------------------------- 1 | from fastapi.staticfiles import StaticFiles 2 | import typing as T 3 | 4 | from fastapi import FastAPI, Response 5 | from fastapi.middleware.cors import CORSMiddleware 6 | from fastapi.middleware import Middleware 7 | 8 | import api 9 | from core.settings import env 10 | 11 | 12 | def init_router(app: FastAPI): 13 | app.mount( 14 | "/images", 15 | StaticFiles(directory=env.SAVE_DIR), 16 | name="result image", 17 | ) 18 | app.include_router(api.StableDiffusionRouter) 19 | app.include_router(api.HomeRouter) 20 | app.router.redirect_slashes = False 21 | 22 | 23 | def create_app() -> FastAPI: 24 | app = FastAPI( 25 | redoc_url=None, 26 | middleware=init_middleware(), 27 | ) 28 | init_router(app) 29 | return app 30 | 31 | 32 | def init_middleware() -> T.List[Middleware]: 33 | middleware = [ 34 | Middleware( 35 | CORSMiddleware, 36 | allow_origins=env.CORS_ALLOW_ORIGINS, 37 | allow_credentials=env.CORS_CREDENTIALS, 38 | allow_methods=env.CORS_ALLOW_METHODS, 39 | allow_headers=env.CORS_ALLOW_HEADERS, 40 | ), 41 | ] 42 | return middleware 43 | 44 | 45 | def init_settings(app: FastAPI): 46 | @app.on_event("startup") 47 | def startup_event(): 48 | from core.dependencies import models 49 | 50 | @app.on_event("shutdown") 51 | def shutdown_event(): 52 | pass 53 | 54 | 55 | app = create_app() 56 | init_settings(app) 57 | -------------------------------------------------------------------------------- /app/stable_diffusion/manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .manager import ( 2 | build_streamer, 3 | ) 4 | from .schema import ( 5 | 
Text2ImageTask, 6 | Image2ImageTask, 7 | InpaintTask, 8 | ) 9 | -------------------------------------------------------------------------------- /app/stable_diffusion/manager/manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | import typing as T 4 | import torch 5 | 6 | torch.backends.cudnn.benchmark = True 7 | import sys 8 | from random import randint 9 | from service_streamer import ThreadedStreamer 10 | from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler 11 | 12 | from app.stable_diffusion.manager.schema import ( 13 | InpaintTask, 14 | Text2ImageTask, 15 | Image2ImageTask, 16 | ) 17 | from core.settings import get_settings 18 | 19 | from core.utils.convert_script import conver_ckpt_to_diff 20 | 21 | from functools import lru_cache 22 | 23 | env = get_settings() 24 | 25 | _StableDiffusionTask = T.Union[ 26 | Text2ImageTask, 27 | Image2ImageTask, 28 | InpaintTask, 29 | ] 30 | 31 | 32 | @lru_cache() 33 | def build_pipeline(repo: str, device: str, enable_attention_slicing: bool): 34 | 35 | # convert ckpt to diffusers 36 | if repo.lower().endswith(".ckpt") and os.path.exists(repo): 37 | dump_path = repo[:-5] 38 | repo = conver_ckpt_to_diff(ckpt_path=repo, dump_path=dump_path) 39 | 40 | pipe = DiffusionPipeline.from_pretrained( 41 | repo, 42 | torch_dtype=torch.float16, 43 | revision="fp16", 44 | custom_pipeline="lpw_stable_diffusion", 45 | ) 46 | 47 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 48 | pipe.safety_checker = lambda images, clip_input: (images, False) 49 | 50 | if enable_attention_slicing: 51 | pipe.enable_attention_slicing() 52 | 53 | pipe = pipe.to(device) 54 | return pipe 55 | 56 | 57 | build_pipeline( 58 | repo=env.MODEL_ID, 59 | device=env.CUDA_DEVICE, 60 | enable_attention_slicing=env.ENABLE_ATTENTION_SLICING, 61 | ) 62 | 63 | 64 | class StableDiffusionManager: 65 | def __init__(self): 66 | self.pipe = build_pipeline( 67 | repo=env.MODEL_ID, 68 | device=env.CUDA_DEVICE, 69 | enable_attention_slicing=env.ENABLE_ATTENTION_SLICING, 70 | ) 71 | 72 | @torch.inference_mode() 73 | def predict( 74 | self, 75 | batch: T.List[_StableDiffusionTask], 76 | ): 77 | task = batch[0] 78 | pipeline = self.pipe 79 | if isinstance(task, Text2ImageTask): 80 | pipeline = self.pipe.text2img 81 | elif isinstance(task, Image2ImageTask): 82 | pipeline = self.pipe.img2img 83 | elif isinstance(task, InpaintTask): 84 | pipeline = self.pipe.inpaint 85 | else: 86 | raise NotImplementedError 87 | 88 | device = env.CUDA_DEVICE 89 | 90 | generator = self._get_generator(task, device) 91 | with torch.autocast("cuda" if device != "cpu" else "cpu"): 92 | task = task.dict() 93 | del task["seed"] 94 | images = pipeline(**task, generator=generator).images 95 | if device != "cpu": 96 | torch.cuda.empty_cache() 97 | 98 | return [images] 99 | 100 | def _get_generator(self, task: _StableDiffusionTask, device: str): 101 | generator = torch.Generator(device=device) 102 | seed = task.seed 103 | seed = seed if seed else randint(1, sys.maxsize) 104 | seed = seed if seed > 0 else randint(1, sys.maxsize) 105 | generator.manual_seed(seed) 106 | return generator 107 | 108 | 109 | @lru_cache(maxsize=1) 110 | def build_streamer() -> ThreadedStreamer: 111 | manager = StableDiffusionManager() 112 | streamer = ThreadedStreamer( 113 | manager.predict, 114 | batch_size=1, 115 | max_latency=0, 116 | ) 117 | return streamer 118 | 
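`build_streamer()` above wraps `StableDiffusionManager.predict` in a service-streamer `ThreadedStreamer` with `batch_size=1`, so callers submit a list of task objects and read the generated images back from a future. Below is a minimal sketch of driving the streamer directly, outside the FastAPI service, assuming the `Text2ImageTask` fields defined in `schema.py` and a configured environment (model cached, CUDA device available); the prompt value is hypothetical.

```python
from app.stable_diffusion.manager import Text2ImageTask, build_streamer

# build_streamer() is lru_cache'd, so repeated calls reuse the same
# ThreadedStreamer (and the single cached diffusers pipeline behind it).
streamer = build_streamer()

task = Text2ImageTask(
    prompt=["a watercolor painting of a lighthouse at dawn"],  # hypothetical prompt
    negative_prompt=[""],
    num_inference_steps=25,
    guidance_scale=7.5,
    height=512,       # must be a multiple of 64 (schema validator)
    width=512,
    seed=42,
)

# submit() takes a list of tasks; predict() consumes them as one micro-batch
# and returns [images], so result() yields a list of image batches.
future = streamer.submit([task])
batches = future.result(timeout=120)
images = [image for batch in batches for image in batch]
images[0].save("sample.png")
```

This mirrors what `StableDiffusionService._summit` does for the API endpoints: build one task per micro-batch of prompts, submit the list, then flatten the returned batches into a single list of PIL images.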
-------------------------------------------------------------------------------- /app/stable_diffusion/manager/schema.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, validator 2 | from PIL import Image 3 | import typing as T 4 | 5 | 6 | class Text2ImageTask(BaseModel): 7 | prompt: T.Union[str, T.List[str]] = Field(...) 8 | negative_prompt: T.Union[str, T.List[str]] = Field(...) 9 | num_inference_steps: int = Field(..., gt=0) 10 | guidance_scale: float = Field(..., ge=0.0) 11 | height: int 12 | width: int 13 | seed: T.Optional[int] 14 | 15 | @validator("height", "width") 16 | def size_constraint(cls, size): 17 | cond = size % 64 18 | if cond != 0: 19 | raise ValueError("height and width must multiple of 64") 20 | return size 21 | 22 | 23 | class Image2ImageTask(BaseModel): 24 | prompt: T.Union[str, T.List[str]] = Field(...) 25 | negative_prompt: T.Union[str, T.List[str]] = Field(...) 26 | image: T.Any 27 | strength: float = Field(..., ge=0.0, le=1.0) 28 | num_inference_steps: int = Field(..., gt=0) 29 | guidance_scale: float = Field(..., ge=0.0) 30 | seed: T.Optional[int] 31 | 32 | 33 | class InpaintTask(BaseModel): 34 | prompt: T.Union[str, T.List[str]] = Field(...) 35 | negative_prompt: T.Union[str, T.List[str]] = Field(...) 36 | image: T.Any 37 | mask_image: T.Any 38 | strength: float = Field(..., ge=0.0, le=1.0) 39 | num_inference_steps: int = Field(..., gt=0) 40 | guidance_scale: float = Field(..., ge=0.0) 41 | seed: T.Optional[int] 42 | -------------------------------------------------------------------------------- /app/stable_diffusion/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .text2image import StableDiffusionText2ImagePipeline 2 | from .image2image import StableDiffusionImg2ImgPipeline 3 | from .inpaint import StableDiffusionInpaintPipeline 4 | -------------------------------------------------------------------------------- /app/stable_diffusion/pipeline/image2image.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List, Optional, Union 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | from diffusers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler 5 | from diffusers import AutoencoderKL, UNet2DConditionModel 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 7 | DiffusionPipeline, 8 | ) 9 | 10 | import torch 11 | import PIL 12 | import numpy as np 13 | 14 | from .types import StableDiffusionCallback 15 | 16 | 17 | def preprocess(image) -> torch.Tensor: 18 | w, h = image.size 19 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 20 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 21 | image = np.array(image).astype(np.float32) / 255.0 22 | image = image[None].transpose(0, 3, 1, 2) 23 | image = torch.from_numpy(image) 24 | return 2.0 * image - 1.0 25 | 26 | 27 | class StableDiffusionImg2ImgPipeline(DiffusionPipeline): 28 | def __init__( 29 | self, 30 | vae: AutoencoderKL, 31 | text_encoder: CLIPTextModel, 32 | tokenizer: CLIPTokenizer, 33 | unet: UNet2DConditionModel, 34 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], 35 | ): 36 | super().__init__() 37 | scheduler = scheduler.set_format("pt") 38 | self.register_modules( 39 | vae=vae, 40 | text_encoder=text_encoder, 41 | tokenizer=tokenizer, 42 | unet=unet, 43 | scheduler=scheduler, 44 | ) 45 | 46 | 
def enable_attention_slicing( 47 | self, 48 | slice_size: Optional[Union[str, int]] = "auto", 49 | ): 50 | if slice_size == "auto": 51 | slice_size = self.unet.config.attention_head_dim // 2 52 | self.unet.set_attention_slice(slice_size) 53 | 54 | def disable_attention_slicing(self): 55 | self.enable_attention_slicing(None) 56 | 57 | @torch.inference_mode() 58 | def __call__( 59 | self, 60 | prompt: Union[str, List[str]], 61 | negative_prompt: Union[str, List[str]], 62 | init_image: Union[torch.FloatTensor, PIL.Image.Image], 63 | strength: float = 0.8, 64 | num_inference_steps: Optional[int] = 50, 65 | guidance_scale: Optional[float] = 7.5, 66 | eta: Optional[float] = 0.0, 67 | generator: Optional[torch.Generator] = None, 68 | callbacks: Optional[List[StableDiffusionCallback]] = None, 69 | **kwargs, 70 | ): 71 | if isinstance(prompt, str): 72 | batch_size = 1 73 | elif isinstance(prompt, list): 74 | batch_size = len(prompt) 75 | else: 76 | raise ValueError( 77 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 78 | ) 79 | 80 | if strength < 0 or strength > 1: 81 | raise ValueError( 82 | f"The value of strength should in [0.0, 1.0] but is {strength}" 83 | ) 84 | 85 | # set timesteps 86 | accepts_offset = "offset" in set( 87 | inspect.signature(self.scheduler.set_timesteps).parameters.keys() 88 | ) 89 | extra_set_kwargs = {} 90 | offset = 0 91 | if accepts_offset: 92 | offset = 1 93 | extra_set_kwargs["offset"] = 1 94 | 95 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 96 | 97 | if isinstance(init_image, PIL.Image.Image): 98 | init_image = preprocess(init_image) 99 | 100 | # encode the init image into latents and scale the latents 101 | init_latent_dist = self.vae.encode( 102 | init_image.to(self.device) 103 | ).latent_dist 104 | init_latents = init_latent_dist.sample(generator=generator) 105 | init_latents = 0.18215 * init_latents 106 | 107 | # expand init_latents for batch_size 108 | init_latents = torch.cat([init_latents] * batch_size) 109 | 110 | # get the original timestep using init_timestep 111 | init_timestep = int(num_inference_steps * strength) + offset 112 | init_timestep = min(init_timestep, num_inference_steps) 113 | if isinstance(self.scheduler, LMSDiscreteScheduler): 114 | timesteps = torch.tensor( 115 | [num_inference_steps - init_timestep] * batch_size, 116 | dtype=torch.long, 117 | device=self.device, 118 | ) 119 | else: 120 | timesteps = self.scheduler.timesteps[-init_timestep] 121 | timesteps = torch.tensor( 122 | [timesteps] * batch_size, dtype=torch.long, device=self.device 123 | ) 124 | 125 | # add noise to latents using the timesteps 126 | noise = torch.randn( 127 | init_latents.shape, generator=generator, device=self.device 128 | ) 129 | init_latents = self.scheduler.add_noise( 130 | init_latents, noise, timesteps 131 | ).to(self.device) 132 | 133 | # get prompt text embeddings 134 | text_input = self.tokenizer( 135 | prompt, 136 | padding="max_length", 137 | max_length=self.tokenizer.model_max_length, 138 | truncation=True, 139 | return_tensors="pt", 140 | ) 141 | text_embeddings = self.text_encoder( 142 | text_input.input_ids.to(self.device) 143 | )[0] 144 | 145 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 146 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 147 | # corresponds to doing no classifier free guidance. 
148 | do_classifier_free_guidance = guidance_scale > 1.0 149 | # get unconditional embeddings for classifier free guidance 150 | if do_classifier_free_guidance: 151 | max_length = text_input.input_ids.shape[-1] 152 | uncond_input = self.tokenizer( 153 | negative_prompt, 154 | padding="max_length", 155 | max_length=max_length, 156 | return_tensors="pt", 157 | ) 158 | uncond_embeddings = self.text_encoder( 159 | uncond_input.input_ids.to(self.device) 160 | )[0] 161 | 162 | # For classifier free guidance, we need to do two forward passes. 163 | # Here we concatenate the unconditional and text embeddings into a single batch 164 | # to avoid doing two forward passes 165 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 166 | 167 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 168 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 169 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 170 | # and should be between [0, 1] 171 | accepts_eta = "eta" in set( 172 | inspect.signature(self.scheduler.step).parameters.keys() 173 | ) 174 | extra_step_kwargs = {} 175 | if accepts_eta: 176 | extra_step_kwargs["eta"] = eta 177 | 178 | latents = init_latents 179 | 180 | t_start = max(num_inference_steps - init_timestep + offset, 0) 181 | for i, t in enumerate( 182 | self.progress_bar(self.scheduler.timesteps[t_start:]) 183 | ): 184 | t_index = t_start + i 185 | 186 | # expand the latents if we are doing classifier free guidance 187 | latent_model_input = ( 188 | torch.cat([latents] * 2) 189 | if do_classifier_free_guidance 190 | else latents 191 | ) 192 | 193 | # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas 194 | if isinstance(self.scheduler, LMSDiscreteScheduler): 195 | sigma = self.scheduler.sigmas[t_index] 196 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 197 | latent_model_input = latent_model_input / ( 198 | (sigma**2 + 1) ** 0.5 199 | ) 200 | latent_model_input = latent_model_input.to(self.unet.dtype) 201 | t = t.to(self.unet.dtype) 202 | 203 | # predict the noise residual 204 | noise_pred = self.unet( 205 | latent_model_input, t, encoder_hidden_states=text_embeddings 206 | ).sample 207 | 208 | # perform guidance 209 | if do_classifier_free_guidance: 210 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 211 | noise_pred = noise_pred_uncond + guidance_scale * ( 212 | noise_pred_text - noise_pred_uncond 213 | ) 214 | 215 | # compute the previous noisy sample x_t -> x_t-1 216 | if isinstance(self.scheduler, LMSDiscreteScheduler): 217 | latents = self.scheduler.step( 218 | noise_pred, t_index, latents, **extra_step_kwargs 219 | ).prev_sample 220 | else: 221 | latents = self.scheduler.step( 222 | noise_pred, t, latents, **extra_step_kwargs 223 | ).prev_sample 224 | 225 | if callbacks is None: 226 | continue 227 | 228 | for custom_callback in callbacks: 229 | custom_callback( 230 | latents=latents, 231 | noise_pred=noise_pred, 232 | ) 233 | # scale and decode the image latents with vae 234 | latents = 1 / 0.18215 * latents 235 | image = self.vae.decode(latents.to(self.vae.dtype)).sample 236 | 237 | image = (image / 2 + 0.5).clamp(0, 1) 238 | image = image.cpu().permute(0, 2, 3, 1).numpy() 239 | image = self.numpy_to_pil(image) 240 | return image 241 | -------------------------------------------------------------------------------- /app/stable_diffusion/pipeline/inpaint.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List, Optional, Union 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | from diffusers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler 5 | from diffusers import AutoencoderKL, UNet2DConditionModel 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 7 | DiffusionPipeline, 8 | ) 9 | import torch 10 | import PIL 11 | import numpy as np 12 | from .types import StableDiffusionCallback 13 | 14 | 15 | def preprocess_image(image): 16 | w, h = image.size 17 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 18 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 19 | image = np.array(image).astype(np.float32) / 255.0 20 | image = image[None].transpose(0, 3, 1, 2) 21 | image = torch.from_numpy(image) 22 | return 2.0 * image - 1.0 23 | 24 | 25 | def preprocess_mask(mask): 26 | mask = mask.convert("L") 27 | w, h = mask.size 28 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 29 | mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) 30 | mask = np.array(mask).astype(np.float32) / 255.0 31 | mask = np.tile(mask, (4, 1, 1)) 32 | mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? 33 | mask = 1 - mask # repaint white, keep black 34 | mask = torch.from_numpy(mask) 35 | return mask 36 | 37 | 38 | class StableDiffusionInpaintPipeline(DiffusionPipeline): 39 | def __init__( 40 | self, 41 | vae: AutoencoderKL, 42 | text_encoder: CLIPTextModel, 43 | tokenizer: CLIPTokenizer, 44 | unet: UNet2DConditionModel, 45 | scheduler: Union[DDIMScheduler, PNDMScheduler], 46 | ): 47 | super().__init__() 48 | scheduler = scheduler.set_format("pt") 49 | self.register_modules( 50 | vae=vae, 51 | text_encoder=text_encoder, 52 | tokenizer=tokenizer, 53 | unet=unet, 54 | scheduler=scheduler, 55 | ) 56 | 57 | def enable_attention_slicing( 58 | self, 59 | slice_size: Optional[Union[str, int]] = "auto", 60 | ): 61 | if slice_size == "auto": 62 | slice_size = self.unet.config.attention_head_dim // 2 63 | self.unet.set_attention_slice(slice_size) 64 | 65 | def disable_attention_slicing(self): 66 | self.enable_attention_slicing(None) 67 | 68 | @torch.no_grad() 69 | def __call__( 70 | self, 71 | prompt: Union[str, List[str]], 72 | negative_prompt: Union[str, List[str]], 73 | init_image: Union[torch.FloatTensor, PIL.Image.Image], 74 | mask_image: Union[torch.FloatTensor, PIL.Image.Image], 75 | strength: float = 0.8, 76 | num_inference_steps: Optional[int] = 50, 77 | guidance_scale: Optional[float] = 7.5, 78 | eta: Optional[float] = 0.0, 79 | generator: Optional[torch.Generator] = None, 80 | callbacks: Optional[List[StableDiffusionCallback]] = None, 81 | **kwargs, 82 | ): 83 | if isinstance(prompt, str): 84 | batch_size = 1 85 | elif isinstance(prompt, list): 86 | batch_size = len(prompt) 87 | else: 88 | raise ValueError( 89 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 90 | ) 91 | 92 | if strength < 0 or strength > 1: 93 | raise ValueError( 94 | f"The value of strength should in [0.0, 1.0] but is {strength}" 95 | ) 96 | 97 | # set timesteps 98 | accepts_offset = "offset" in set( 99 | inspect.signature(self.scheduler.set_timesteps).parameters.keys() 100 | ) 101 | extra_set_kwargs = {} 102 | offset = 0 103 | if accepts_offset: 104 | offset = 1 105 | extra_set_kwargs["offset"] = 1 106 | 107 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 
108 | 109 | # preprocess image 110 | init_image = preprocess_image(init_image).to(self.device) 111 | 112 | # encode the init image into latents and scale the latents 113 | init_latent_dist = self.vae.encode( 114 | init_image.to(self.device) 115 | ).latent_dist 116 | init_latents = init_latent_dist.sample(generator=generator) 117 | 118 | init_latents = 0.18215 * init_latents 119 | 120 | # Expand init_latents for batch_size 121 | init_latents = torch.cat([init_latents] * batch_size) 122 | init_latents_orig = init_latents 123 | 124 | # preprocess mask 125 | mask = preprocess_mask(mask_image).to(self.device) 126 | mask = torch.cat([mask] * batch_size) 127 | 128 | # check sizes 129 | if not mask.shape == init_latents.shape: 130 | raise ValueError("The mask and init_image should be the same size!") 131 | 132 | # get the original timestep using init_timestep 133 | init_timestep = int(num_inference_steps * strength) + offset 134 | init_timestep = min(init_timestep, num_inference_steps) 135 | timesteps = self.scheduler.timesteps[-init_timestep] 136 | timesteps = torch.tensor( 137 | [timesteps] * batch_size, dtype=torch.long, device=self.device 138 | ) 139 | 140 | # add noise to latents using the timesteps 141 | noise = torch.randn( 142 | init_latents.shape, generator=generator, device=self.device 143 | ) 144 | init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) 145 | 146 | # get prompt text embeddings 147 | text_input = self.tokenizer( 148 | prompt, 149 | padding="max_length", 150 | max_length=self.tokenizer.model_max_length, 151 | truncation=True, 152 | return_tensors="pt", 153 | ) 154 | text_embeddings = self.text_encoder( 155 | text_input.input_ids.to(self.device) 156 | )[0] 157 | 158 | do_classifier_free_guidance = guidance_scale > 1.0 159 | # get unconditional embeddings for classifier free guidance 160 | if do_classifier_free_guidance: 161 | max_length = text_input.input_ids.shape[-1] 162 | uncond_input = self.tokenizer( 163 | negative_prompt, 164 | padding="max_length", 165 | max_length=max_length, 166 | return_tensors="pt", 167 | ) 168 | uncond_embeddings = self.text_encoder( 169 | uncond_input.input_ids.to(self.device) 170 | )[0] 171 | 172 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 173 | 174 | accepts_eta = "eta" in set( 175 | inspect.signature(self.scheduler.step).parameters.keys() 176 | ) 177 | extra_step_kwargs = {} 178 | if accepts_eta: 179 | extra_step_kwargs["eta"] = eta 180 | 181 | latents = init_latents 182 | t_start = max(num_inference_steps - init_timestep + offset, 0) 183 | for i, t in enumerate(self.scheduler.timesteps[t_start:]): 184 | # expand the latents if we are doing classifier free guidance 185 | latent_model_input = ( 186 | torch.cat([latents] * 2) 187 | if do_classifier_free_guidance 188 | else latents 189 | ) 190 | 191 | # predict the noise residual 192 | noise_pred = self.unet( 193 | latent_model_input, t, encoder_hidden_states=text_embeddings 194 | ).sample 195 | 196 | # perform guidance 197 | if do_classifier_free_guidance: 198 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 199 | noise_pred = noise_pred_uncond + guidance_scale * ( 200 | noise_pred_text - noise_pred_uncond 201 | ) 202 | 203 | # compute the previous noisy sample x_t -> x_t-1 204 | latents = self.scheduler.step( 205 | noise_pred, t, latents, **extra_step_kwargs 206 | ).prev_sample 207 | 208 | # masking 209 | init_latents_proper = self.scheduler.add_noise( 210 | init_latents_orig, noise, t 211 | ) 212 | latents = (init_latents_proper * mask) + 
(latents * (1 - mask)) 213 | 214 | if callbacks is None: 215 | continue 216 | 217 | for custom_callback in callbacks: 218 | custom_callback( 219 | latents=latents, 220 | noise_pred=noise_pred, 221 | ) 222 | 223 | # scale and decode the image latents with vae 224 | latents = 1 / 0.18215 * latents 225 | image = self.vae.decode(latents).sample 226 | 227 | image = (image / 2 + 0.5).clamp(0, 1) 228 | image = image.cpu().permute(0, 2, 3, 1).numpy() 229 | 230 | image = self.numpy_to_pil(image) 231 | return image 232 | -------------------------------------------------------------------------------- /app/stable_diffusion/pipeline/text2image.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List, Optional, Union 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | from diffusers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler 5 | from diffusers import AutoencoderKL, UNet2DConditionModel 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 7 | DiffusionPipeline, 8 | ) 9 | import torch 10 | from .types import StableDiffusionCallback 11 | 12 | 13 | class StableDiffusionText2ImagePipeline(DiffusionPipeline): 14 | def __init__( 15 | self, 16 | vae: AutoencoderKL, 17 | text_encoder: CLIPTextModel, 18 | tokenizer: CLIPTokenizer, 19 | unet: UNet2DConditionModel, 20 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], 21 | ): 22 | super().__init__() 23 | scheduler = scheduler.set_format("pt") 24 | self.register_modules( 25 | vae=vae, 26 | text_encoder=text_encoder, 27 | tokenizer=tokenizer, 28 | unet=unet, 29 | scheduler=scheduler, 30 | ) 31 | 32 | def enable_attention_slicing( 33 | self, slice_size: Optional[Union[str, int]] = "auto" 34 | ): 35 | if slice_size == "auto": 36 | # half the attention head size is usually a good trade-off between 37 | # speed and memory 38 | slice_size = self.unet.config.attention_head_dim // 2 39 | self.unet.set_attention_slice(slice_size) 40 | 41 | def disable_attention_slicing(self): 42 | # set slice_size = `None` to disable `attention slicing` 43 | self.enable_attention_slicing(None) 44 | 45 | @torch.inference_mode() 46 | def __call__( 47 | self, 48 | prompt: Union[str, List[str]], 49 | negative_prompt: Union[str, List[str]] = "", 50 | height: Optional[int] = 512, 51 | width: Optional[int] = 512, 52 | num_inference_steps: Optional[int] = 50, 53 | guidance_scale: Optional[float] = 7.5, 54 | eta: Optional[float] = 0.0, 55 | generator: Optional[torch.Generator] = None, 56 | latents: Optional[torch.FloatTensor] = None, 57 | callbacks: Optional[List[StableDiffusionCallback]] = None, 58 | **kwargs, 59 | ): 60 | if isinstance(prompt, str): 61 | batch_size = 1 62 | elif isinstance(prompt, list): 63 | batch_size = len(prompt) 64 | 65 | if height % 8 != 0 or width % 8 != 0: 66 | raise ValueError( 67 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 
68 | ) 69 | 70 | text_input = self.tokenizer( 71 | prompt, 72 | padding="max_length", 73 | max_length=self.tokenizer.model_max_length, 74 | truncation=True, 75 | return_tensors="pt", 76 | ) 77 | text_embeddings = self.text_encoder( 78 | text_input.input_ids.to(self.device) 79 | )[0] 80 | 81 | do_classifier_free_guidance = guidance_scale > 1.0 82 | if do_classifier_free_guidance: 83 | max_length = text_input.input_ids.shape[-1] 84 | uncond_input = self.tokenizer( 85 | negative_prompt, 86 | padding="max_length", 87 | max_length=max_length, 88 | return_tensors="pt", 89 | ) 90 | uncond_embeddings = self.text_encoder( 91 | uncond_input.input_ids.to(self.device) 92 | )[0] 93 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 94 | 95 | latents_device = "cpu" if self.device.type == "mps" else self.device 96 | latents_shape = ( 97 | batch_size, 98 | self.unet.in_channels, 99 | height // 8, 100 | width // 8, 101 | ) 102 | if latents is None: 103 | latents = torch.randn( 104 | latents_shape, 105 | generator=generator, 106 | device=latents_device, 107 | ) 108 | else: 109 | if latents.shape != latents_shape: 110 | raise ValueError( 111 | f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}" 112 | ) 113 | latents = latents.to(self.device) 114 | 115 | # set timesteps 116 | accepts_offset = "offset" in set( 117 | inspect.signature(self.scheduler.set_timesteps).parameters.keys() 118 | ) 119 | extra_set_kwargs = {} 120 | if accepts_offset: 121 | extra_set_kwargs["offset"] = 1 122 | 123 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 124 | 125 | # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas 126 | if isinstance(self.scheduler, LMSDiscreteScheduler): 127 | latents = latents * self.scheduler.sigmas[0] 128 | 129 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 130 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
131 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 132 | # and should be between [0, 1] 133 | accepts_eta = "eta" in set( 134 | inspect.signature(self.scheduler.step).parameters.keys() 135 | ) 136 | extra_step_kwargs = {} 137 | if accepts_eta: 138 | extra_step_kwargs["eta"] = eta 139 | 140 | for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): 141 | # expand the latents if we are doing classifier free guidance 142 | latent_model_input = ( 143 | torch.cat([latents] * 2) 144 | if do_classifier_free_guidance 145 | else latents 146 | ) 147 | if isinstance(self.scheduler, LMSDiscreteScheduler): 148 | sigma = self.scheduler.sigmas[i] 149 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 150 | latent_model_input = latent_model_input / ( 151 | (sigma**2 + 1) ** 0.5 152 | ) 153 | 154 | # predict the noise residual 155 | noise_pred = self.unet( 156 | latent_model_input, t, encoder_hidden_states=text_embeddings 157 | ).sample 158 | 159 | # perform guidance 160 | if do_classifier_free_guidance: 161 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 162 | noise_pred = noise_pred_uncond + guidance_scale * ( 163 | noise_pred_text - noise_pred_uncond 164 | ) 165 | 166 | # compute the previous noisy sample x_t -> x_t-1 167 | if isinstance(self.scheduler, LMSDiscreteScheduler): 168 | latents = self.scheduler.step( 169 | noise_pred, i, latents, **extra_step_kwargs 170 | ).prev_sample 171 | else: 172 | latents = self.scheduler.step( 173 | noise_pred, t, latents, **extra_step_kwargs 174 | ).prev_sample 175 | 176 | if callbacks is None: 177 | continue 178 | 179 | for custom_callback in callbacks: 180 | custom_callback( 181 | latents=latents, 182 | noise_pred=noise_pred, 183 | ) 184 | 185 | # scale and decode the image latents with vae 186 | latents = 1 / 0.18215 * latents 187 | image = self.vae.decode(latents).sample 188 | 189 | image = (image / 2 + 0.5).clamp(0, 1) 190 | image = image.cpu().permute(0, 2, 3, 1).numpy() 191 | image = self.numpy_to_pil(image) 192 | return image 193 | -------------------------------------------------------------------------------- /app/stable_diffusion/pipeline/types.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class StableDiffusionCallback(metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def __call__(self, **kwds): 7 | pass 8 | -------------------------------------------------------------------------------- /app/stable_diffusion/service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as T 3 | from itertools import chain, islice 4 | from service_streamer import ThreadedStreamer 5 | 6 | import torch 7 | from core.settings import get_settings 8 | from fastapi import Depends 9 | from loguru import logger 10 | from PIL import Image 11 | import json 12 | 13 | 14 | from .manager import ( 15 | Text2ImageTask, 16 | Image2ImageTask, 17 | InpaintTask, 18 | ) 19 | 20 | from .manager import ( 21 | build_streamer, 22 | ) 23 | 24 | 25 | env = get_settings() 26 | 27 | 28 | def data_to_batch(datasets: T.List[T.Any], batch_size: int): 29 | iterator = iter(datasets) 30 | for first in iterator: 31 | yield list(chain([first], islice(iterator, batch_size - 1))) 32 | 33 | 34 | class StableDiffusionService: 35 | def __init__( 36 | self, 37 | streamer: ThreadedStreamer = Depends(build_streamer), 38 | ) -> None: 39 | logger.info(f"DI:{self.__class__.__name__}") 40 | self.streamer = 
streamer 41 | 42 | @torch.inference_mode() 43 | def text2image( 44 | self, 45 | prompt: str, 46 | negative_prompt: str = "", 47 | num_images: int = 1, 48 | num_inference_steps: int = 50, 49 | guidance_scale: float = 8.5, 50 | height=512, 51 | width=512, 52 | seed: T.Optional[int] = None, 53 | ) -> T.List[Image.Image]: 54 | prompts = [prompt] * num_images 55 | 56 | tasks = [ 57 | Text2ImageTask( 58 | prompt=prompt, 59 | negative_prompt=[negative_prompt] * len(prompt), 60 | num_inference_steps=num_inference_steps, 61 | guidance_scale=guidance_scale, 62 | height=height, 63 | width=width, 64 | seed=seed, 65 | ) 66 | for prompt in data_to_batch(prompts, batch_size=env.MB_BATCH_SIZE) 67 | ] 68 | images = self._summit(tasks) 69 | return images 70 | 71 | @torch.inference_mode() 72 | def image2image( 73 | self, 74 | prompt: str, 75 | negative_prompt: str, 76 | init_image: Image.Image, 77 | num_images: int = 1, 78 | strength: float = 0.8, 79 | num_inference_steps: int = 50, 80 | guidance_scale: float = 8.5, 81 | seed: int = 203, 82 | ) -> T.List[Image.Image]: 83 | origin_size = init_image.size 84 | w, h = origin_size 85 | w, h = map(lambda x: x - x % 64, (w, h)) 86 | if origin_size != (w, h): 87 | init_image = init_image.resize((w, h), resample=Image.LANCZOS) 88 | 89 | prompts = [prompt] * num_images 90 | 91 | tasks = [ 92 | Image2ImageTask( 93 | prompt=prompt, 94 | negative_prompt=[negative_prompt] * len(prompt), 95 | image=init_image, 96 | strength=strength, 97 | num_inference_steps=num_inference_steps, 98 | guidance_scale=guidance_scale, 99 | seed=seed, 100 | ) 101 | for prompt in data_to_batch(prompts, batch_size=env.MB_BATCH_SIZE) 102 | ] 103 | 104 | images = self._summit(tasks) 105 | images = self.postprocess(images, origin_size=origin_size) 106 | return images 107 | 108 | @torch.inference_mode() 109 | def inpaint( 110 | self, 111 | prompt: str, 112 | negative_prompt: str, 113 | init_image: Image.Image, 114 | mask_image: Image.Image, 115 | strength: float, 116 | num_images: int = 1, 117 | num_inference_steps: int = 50, 118 | guidance_scale: float = 8.5, 119 | seed: int = 203, 120 | ) -> T.List[Image.Image]: 121 | origin_size = init_image.size 122 | w, h = origin_size 123 | w, h = map(lambda x: x - x % 64, (w, h)) 124 | if origin_size != (w, h): 125 | init_image = init_image.resize((w, h), resample=Image.LANCZOS) 126 | mask_image = mask_image.resize((w, h), resample=Image.NEAREST) 127 | 128 | prompts = [prompt] * num_images 129 | 130 | tasks = [ 131 | InpaintTask( 132 | prompt=prompt, 133 | negative_prompt=[negative_prompt] * len(prompt), 134 | init_image=init_image, 135 | mask_image=mask_image, 136 | strength=strength, 137 | num_inference_steps=num_inference_steps, 138 | guidance_scale=guidance_scale, 139 | seed=seed, 140 | ) 141 | for prompt in data_to_batch(prompts, batch_size=env.MB_BATCH_SIZE) 142 | ] 143 | 144 | images = self._summit(tasks) 145 | images = self.postprocess(images, origin_size=origin_size) 146 | return images 147 | 148 | def _summit(self, tasks) -> T.List[Image.Image]: 149 | future = self.streamer.submit(tasks) 150 | batchs = future.result(timeout=env.MB_TIMEOUT) 151 | images = [] 152 | for batch in batchs: 153 | images += batch 154 | return images 155 | 156 | @classmethod 157 | def postprocess( 158 | cls, images: T.List[Image.Image], origin_size: T.Tuple[int, int] 159 | ): 160 | if origin_size == images[0].size: 161 | return images 162 | for i, image in enumerate(images): 163 | images[i] = image.resize(origin_size) 164 | return images 165 | 166 | @staticmethod 167 | def 
image_save(images: T.List[Image.Image], task_id: str, info: dict): 168 | save_dir = os.path.join(env.SAVE_DIR, task_id) 169 | os.makedirs(save_dir) 170 | image_urls = [] 171 | 172 | with open(os.path.join(save_dir, "info.json"), "w") as f: 173 | json.dump(info, f) 174 | 175 | for i, image in enumerate(images): 176 | filename = f"{str(i).zfill(2)}.webp" 177 | save_path = os.path.join(env.SAVE_DIR, task_id, filename) 178 | image_url = os.path.join(task_id, filename) 179 | image.save(save_path) 180 | image_urls.append(image_url) 181 | return image_urls 182 | -------------------------------------------------------------------------------- /core/decorator/singleton.py: -------------------------------------------------------------------------------- 1 | def singleton(cls): 2 | instances = {} 3 | 4 | def wrapper(*args, **kwargs): 5 | if cls not in instances: 6 | instances[cls] = cls(*args, **kwargs) 7 | return instances[cls] 8 | 9 | return wrapper 10 | -------------------------------------------------------------------------------- /core/dependencies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/core/dependencies/__init__.py -------------------------------------------------------------------------------- /core/dependencies/models.py: -------------------------------------------------------------------------------- 1 | from app.stable_diffusion.manager import build_streamer 2 | 3 | build_streamer() 4 | -------------------------------------------------------------------------------- /core/logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/core/logger/__init__.py -------------------------------------------------------------------------------- /core/middlewares/logging.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from fastapi import Request, Response 3 | from starlette.responses import StreamingResponse 4 | -------------------------------------------------------------------------------- /core/settings/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from .settings import Settings 3 | 4 | 5 | @lru_cache() 6 | def get_settings() -> Settings: 7 | setting = Settings("env/dev.env") 8 | return setting 9 | 10 | 11 | env = get_settings() 12 | -------------------------------------------------------------------------------- /core/settings/enum.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class DeploymentMode(str, Enum): 5 | DEV = "dev" 6 | TEST = "test" 7 | PRODUCTION = "prod" 8 | -------------------------------------------------------------------------------- /core/settings/settings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import typing as T 3 | 4 | from pydantic import BaseSettings 5 | 6 | 7 | class ModelSetting(BaseSettings): 8 | MODEL_ID: str = "CompVis/stable-diffusion-v1-4" 9 | ENABLE_ATTENTION_SLICING: bool = True 10 | 11 | 12 | class DeviceSettings(BaseSettings): 13 | CUDA_DEVICE = "cuda" 14 | CUDA_DEVICES = [0] 15 | 16 | 17 | class MicroBatchSettings(BaseSettings): 18 | MB_BATCH_SIZE = 2 19 | 
MB_TIMEOUT = 600 20 | 21 | 22 | class Settings( 23 | ModelSetting, 24 | DeviceSettings, 25 | MicroBatchSettings, 26 | ): 27 | HUGGINGFACE_TOKEN: str = "HUGGINGFACE_TOKEN" 28 | IMAGESERVER_URL: str = "http://localhost:3000/images" 29 | SAVE_DIR: str = "static" 30 | 31 | CORS_ALLOW_ORIGINS: T.List[str] = ["*"] 32 | CORS_CREDENTIALS: bool = True 33 | CORS_ALLOW_METHODS: T.List[str] = ["*"] 34 | CORS_ALLOW_HEADERS: T.List[str] = ["*"] 35 | -------------------------------------------------------------------------------- /core/utils/convert_script.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # THIS SCRIPTS FROM: https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py 18 | 19 | """ Conversion script for the LDM checkpoints. """ 20 | 21 | import argparse 22 | import os 23 | 24 | import torch 25 | 26 | 27 | try: 28 | from omegaconf import OmegaConf 29 | except ImportError: 30 | raise ImportError( 31 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." 32 | ) 33 | 34 | from diffusers import ( 35 | AutoencoderKL, 36 | DDIMScheduler, 37 | DPMSolverMultistepScheduler, 38 | EulerAncestralDiscreteScheduler, 39 | EulerDiscreteScheduler, 40 | LDMTextToImagePipeline, 41 | LMSDiscreteScheduler, 42 | PNDMScheduler, 43 | StableDiffusionPipeline, 44 | UNet2DConditionModel, 45 | ) 46 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import ( 47 | LDMBertConfig, 48 | LDMBertModel, 49 | ) 50 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 51 | from transformers import ( 52 | AutoFeatureExtractor, 53 | BertTokenizerFast, 54 | CLIPTextModel, 55 | CLIPTokenizer, 56 | ) 57 | 58 | 59 | def shave_segments(path, n_shave_prefix_segments=1): 60 | """ 61 | Removes segments. Positive values shave the first segments, negative shave the last segments. 
62 | """ 63 | if n_shave_prefix_segments >= 0: 64 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 65 | else: 66 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 67 | 68 | 69 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 70 | """ 71 | Updates paths inside resnets to the new naming scheme (local renaming) 72 | """ 73 | mapping = [] 74 | for old_item in old_list: 75 | new_item = old_item.replace("in_layers.0", "norm1") 76 | new_item = new_item.replace("in_layers.2", "conv1") 77 | 78 | new_item = new_item.replace("out_layers.0", "norm2") 79 | new_item = new_item.replace("out_layers.3", "conv2") 80 | 81 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 82 | new_item = new_item.replace("skip_connection", "conv_shortcut") 83 | 84 | new_item = shave_segments( 85 | new_item, n_shave_prefix_segments=n_shave_prefix_segments 86 | ) 87 | 88 | mapping.append({"old": old_item, "new": new_item}) 89 | 90 | return mapping 91 | 92 | 93 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 94 | """ 95 | Updates paths inside resnets to the new naming scheme (local renaming) 96 | """ 97 | mapping = [] 98 | for old_item in old_list: 99 | new_item = old_item 100 | 101 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 102 | new_item = shave_segments( 103 | new_item, n_shave_prefix_segments=n_shave_prefix_segments 104 | ) 105 | 106 | mapping.append({"old": old_item, "new": new_item}) 107 | 108 | return mapping 109 | 110 | 111 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 112 | """ 113 | Updates paths inside attentions to the new naming scheme (local renaming) 114 | """ 115 | mapping = [] 116 | for old_item in old_list: 117 | new_item = old_item 118 | 119 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 120 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 121 | 122 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 123 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 124 | 125 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 126 | 127 | mapping.append({"old": old_item, "new": new_item}) 128 | 129 | return mapping 130 | 131 | 132 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 133 | """ 134 | Updates paths inside attentions to the new naming scheme (local renaming) 135 | """ 136 | mapping = [] 137 | for old_item in old_list: 138 | new_item = old_item 139 | 140 | new_item = new_item.replace("norm.weight", "group_norm.weight") 141 | new_item = new_item.replace("norm.bias", "group_norm.bias") 142 | 143 | new_item = new_item.replace("q.weight", "query.weight") 144 | new_item = new_item.replace("q.bias", "query.bias") 145 | 146 | new_item = new_item.replace("k.weight", "key.weight") 147 | new_item = new_item.replace("k.bias", "key.bias") 148 | 149 | new_item = new_item.replace("v.weight", "value.weight") 150 | new_item = new_item.replace("v.bias", "value.bias") 151 | 152 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") 153 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 154 | 155 | new_item = shave_segments( 156 | new_item, n_shave_prefix_segments=n_shave_prefix_segments 157 | ) 158 | 159 | mapping.append({"old": old_item, "new": new_item}) 160 | 161 | return mapping 162 | 163 | 164 | def assign_to_checkpoint( 165 | paths, 166 | checkpoint, 167 | old_checkpoint, 168 | attention_paths_to_split=None, 169 | additional_replacements=None, 170 | config=None, 
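    # Note on arguments (descriptive comment): `paths` is the list of {"old": ..., "new": ...} mappings
    # produced by the renew_*_paths helpers above, `checkpoint` is the new diffusers-style state dict
    # being filled in, and `old_checkpoint` is the source LDM state dict whose tensors are copied across.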
171 | ): 172 | """ 173 | This does the final conversion step: take locally converted weights and apply a global renaming 174 | to them. It splits attention layers, and takes into account additional replacements 175 | that may arise. 176 | Assigns the weights to the new checkpoint. 177 | """ 178 | assert isinstance( 179 | paths, list 180 | ), "Paths should be a list of dicts containing 'old' and 'new' keys." 181 | 182 | # Splits the attention layers into three variables. 183 | if attention_paths_to_split is not None: 184 | for path, path_map in attention_paths_to_split.items(): 185 | old_tensor = old_checkpoint[path] 186 | channels = old_tensor.shape[0] // 3 187 | 188 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 189 | 190 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 191 | 192 | old_tensor = old_tensor.reshape( 193 | (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] 194 | ) 195 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 196 | 197 | checkpoint[path_map["query"]] = query.reshape(target_shape) 198 | checkpoint[path_map["key"]] = key.reshape(target_shape) 199 | checkpoint[path_map["value"]] = value.reshape(target_shape) 200 | 201 | for path in paths: 202 | new_path = path["new"] 203 | 204 | # These have already been assigned 205 | if ( 206 | attention_paths_to_split is not None 207 | and new_path in attention_paths_to_split 208 | ): 209 | continue 210 | 211 | # Global renaming happens here 212 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 213 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 214 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 215 | 216 | if additional_replacements is not None: 217 | for replacement in additional_replacements: 218 | new_path = new_path.replace(replacement["old"], replacement["new"]) 219 | 220 | # proj_attn.weight has to be converted from conv 1D to linear 221 | if "proj_attn.weight" in new_path: 222 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 223 | else: 224 | checkpoint[new_path] = old_checkpoint[path["old"]] 225 | 226 | 227 | def conv_attn_to_linear(checkpoint): 228 | keys = list(checkpoint.keys()) 229 | attn_keys = ["query.weight", "key.weight", "value.weight"] 230 | for key in keys: 231 | if ".".join(key.split(".")[-2:]) in attn_keys: 232 | if checkpoint[key].ndim > 2: 233 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 234 | elif "proj_attn.weight" in key: 235 | if checkpoint[key].ndim > 2: 236 | checkpoint[key] = checkpoint[key][:, :, 0] 237 | 238 | 239 | def create_unet_diffusers_config(original_config): 240 | """ 241 | Creates a config for the diffusers based on the config of the LDM model. 
242 | """ 243 | model_params = original_config.model.params 244 | unet_params = original_config.model.params.unet_config.params 245 | 246 | block_out_channels = [ 247 | unet_params.model_channels * mult for mult in unet_params.channel_mult 248 | ] 249 | 250 | down_block_types = [] 251 | resolution = 1 252 | for i in range(len(block_out_channels)): 253 | block_type = ( 254 | "CrossAttnDownBlock2D" 255 | if resolution in unet_params.attention_resolutions 256 | else "DownBlock2D" 257 | ) 258 | down_block_types.append(block_type) 259 | if i != len(block_out_channels) - 1: 260 | resolution *= 2 261 | 262 | up_block_types = [] 263 | for i in range(len(block_out_channels)): 264 | block_type = ( 265 | "CrossAttnUpBlock2D" 266 | if resolution in unet_params.attention_resolutions 267 | else "UpBlock2D" 268 | ) 269 | up_block_types.append(block_type) 270 | resolution //= 2 271 | 272 | config = dict( 273 | sample_size=model_params.image_size, 274 | in_channels=unet_params.in_channels, 275 | out_channels=unet_params.out_channels, 276 | down_block_types=tuple(down_block_types), 277 | up_block_types=tuple(up_block_types), 278 | block_out_channels=tuple(block_out_channels), 279 | layers_per_block=unet_params.num_res_blocks, 280 | cross_attention_dim=unet_params.context_dim, 281 | attention_head_dim=unet_params.num_heads, 282 | ) 283 | 284 | return config 285 | 286 | 287 | def create_vae_diffusers_config(original_config): 288 | """ 289 | Creates a config for the diffusers based on the config of the LDM model. 290 | """ 291 | vae_params = original_config.model.params.first_stage_config.params.ddconfig 292 | _ = original_config.model.params.first_stage_config.params.embed_dim 293 | 294 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] 295 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 296 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 297 | 298 | config = dict( 299 | sample_size=vae_params.resolution, 300 | in_channels=vae_params.in_channels, 301 | out_channels=vae_params.out_ch, 302 | down_block_types=tuple(down_block_types), 303 | up_block_types=tuple(up_block_types), 304 | block_out_channels=tuple(block_out_channels), 305 | latent_channels=vae_params.z_channels, 306 | layers_per_block=vae_params.num_res_blocks, 307 | ) 308 | return config 309 | 310 | 311 | def create_diffusers_schedular(original_config): 312 | schedular = DDIMScheduler( 313 | num_train_timesteps=original_config.model.params.timesteps, 314 | beta_start=original_config.model.params.linear_start, 315 | beta_end=original_config.model.params.linear_end, 316 | beta_schedule="scaled_linear", 317 | ) 318 | return schedular 319 | 320 | 321 | def create_ldm_bert_config(original_config): 322 | bert_params = original_config.model.parms.cond_stage_config.params 323 | config = LDMBertConfig( 324 | d_model=bert_params.n_embed, 325 | encoder_layers=bert_params.n_layer, 326 | encoder_ffn_dim=bert_params.n_embed * 4, 327 | ) 328 | return config 329 | 330 | 331 | def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): 332 | """ 333 | Takes a state dict and a config, and returns a converted checkpoint. 334 | """ 335 | 336 | # extract state_dict for UNet 337 | unet_state_dict = {} 338 | keys = list(checkpoint.keys()) 339 | 340 | unet_key = "model.diffusion_model." 
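    # The UNet weights live under the "model.diffusion_model." prefix in the original checkpoint.
    # The block below strips that prefix and, when `--extract_ema` is set and EMA weights are present,
    # swaps in the averaged "model_ema.*" copies of those weights instead of the live training weights.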
341 | # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA 342 | if sum(k.startswith("model_ema") for k in keys) > 100: 343 | print(f"Checkpoint {path} has both EMA and non-EMA weights.") 344 | if extract_ema: 345 | print( 346 | "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" 347 | " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." 348 | ) 349 | for key in keys: 350 | if key.startswith("model.diffusion_model"): 351 | flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) 352 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( 353 | flat_ema_key 354 | ) 355 | else: 356 | print( 357 | "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" 358 | " weights (usually better for inference), please make sure to add the `--extract_ema` flag." 359 | ) 360 | 361 | for key in keys: 362 | if key.startswith(unet_key): 363 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 364 | 365 | new_checkpoint = {} 366 | 367 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ 368 | "time_embed.0.weight" 369 | ] 370 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[ 371 | "time_embed.0.bias" 372 | ] 373 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[ 374 | "time_embed.2.weight" 375 | ] 376 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[ 377 | "time_embed.2.bias" 378 | ] 379 | 380 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 381 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 382 | 383 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 384 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 385 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 386 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 387 | 388 | # Retrieves the keys for the input blocks only 389 | num_input_blocks = len( 390 | { 391 | ".".join(layer.split(".")[:2]) 392 | for layer in unet_state_dict 393 | if "input_blocks" in layer 394 | } 395 | ) 396 | input_blocks = { 397 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] 398 | for layer_id in range(num_input_blocks) 399 | } 400 | 401 | # Retrieves the keys for the middle blocks only 402 | num_middle_blocks = len( 403 | { 404 | ".".join(layer.split(".")[:2]) 405 | for layer in unet_state_dict 406 | if "middle_block" in layer 407 | } 408 | ) 409 | middle_blocks = { 410 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] 411 | for layer_id in range(num_middle_blocks) 412 | } 413 | 414 | # Retrieves the keys for the output blocks only 415 | num_output_blocks = len( 416 | { 417 | ".".join(layer.split(".")[:2]) 418 | for layer in unet_state_dict 419 | if "output_blocks" in layer 420 | } 421 | ) 422 | output_blocks = { 423 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] 424 | for layer_id in range(num_output_blocks) 425 | } 426 | 427 | for i in range(1, num_input_blocks): 428 | block_id = (i - 1) // (config["layers_per_block"] + 1) 429 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 430 | 431 | resnets = [ 432 | key 433 | for key in input_blocks[i] 434 | if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 435 | ] 436 | attentions = [key for key in 
input_blocks[i] if f"input_blocks.{i}.1" in key] 437 | 438 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 439 | new_checkpoint[ 440 | f"down_blocks.{block_id}.downsamplers.0.conv.weight" 441 | ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") 442 | new_checkpoint[ 443 | f"down_blocks.{block_id}.downsamplers.0.conv.bias" 444 | ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") 445 | 446 | paths = renew_resnet_paths(resnets) 447 | meta_path = { 448 | "old": f"input_blocks.{i}.0", 449 | "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", 450 | } 451 | assign_to_checkpoint( 452 | paths, 453 | new_checkpoint, 454 | unet_state_dict, 455 | additional_replacements=[meta_path], 456 | config=config, 457 | ) 458 | 459 | if len(attentions): 460 | paths = renew_attention_paths(attentions) 461 | meta_path = { 462 | "old": f"input_blocks.{i}.1", 463 | "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", 464 | } 465 | assign_to_checkpoint( 466 | paths, 467 | new_checkpoint, 468 | unet_state_dict, 469 | additional_replacements=[meta_path], 470 | config=config, 471 | ) 472 | 473 | resnet_0 = middle_blocks[0] 474 | attentions = middle_blocks[1] 475 | resnet_1 = middle_blocks[2] 476 | 477 | resnet_0_paths = renew_resnet_paths(resnet_0) 478 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 479 | 480 | resnet_1_paths = renew_resnet_paths(resnet_1) 481 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 482 | 483 | attentions_paths = renew_attention_paths(attentions) 484 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 485 | assign_to_checkpoint( 486 | attentions_paths, 487 | new_checkpoint, 488 | unet_state_dict, 489 | additional_replacements=[meta_path], 490 | config=config, 491 | ) 492 | 493 | for i in range(num_output_blocks): 494 | block_id = i // (config["layers_per_block"] + 1) 495 | layer_in_block_id = i % (config["layers_per_block"] + 1) 496 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 497 | output_block_list = {} 498 | 499 | for layer in output_block_layers: 500 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 501 | if layer_id in output_block_list: 502 | output_block_list[layer_id].append(layer_name) 503 | else: 504 | output_block_list[layer_id] = [layer_name] 505 | 506 | if len(output_block_list) > 1: 507 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 508 | attentions = [ 509 | key for key in output_blocks[i] if f"output_blocks.{i}.1" in key 510 | ] 511 | 512 | resnet_0_paths = renew_resnet_paths(resnets) 513 | paths = renew_resnet_paths(resnets) 514 | 515 | meta_path = { 516 | "old": f"output_blocks.{i}.0", 517 | "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", 518 | } 519 | assign_to_checkpoint( 520 | paths, 521 | new_checkpoint, 522 | unet_state_dict, 523 | additional_replacements=[meta_path], 524 | config=config, 525 | ) 526 | 527 | if ["conv.weight", "conv.bias"] in output_block_list.values(): 528 | index = list(output_block_list.values()).index( 529 | ["conv.weight", "conv.bias"] 530 | ) 531 | new_checkpoint[ 532 | f"up_blocks.{block_id}.upsamplers.0.conv.weight" 533 | ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] 534 | new_checkpoint[ 535 | f"up_blocks.{block_id}.upsamplers.0.conv.bias" 536 | ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] 537 | 538 | # Clear attentions as they have been attributed above. 
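            # (If the upsampler conv sits at sub-index 1, its two conv.* keys were collected into
            # `attentions` above; they were just assigned as upsamplers, so they are dropped here.)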
539 | if len(attentions) == 2: 540 | attentions = [] 541 | 542 | if len(attentions): 543 | paths = renew_attention_paths(attentions) 544 | meta_path = { 545 | "old": f"output_blocks.{i}.1", 546 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 547 | } 548 | assign_to_checkpoint( 549 | paths, 550 | new_checkpoint, 551 | unet_state_dict, 552 | additional_replacements=[meta_path], 553 | config=config, 554 | ) 555 | else: 556 | resnet_0_paths = renew_resnet_paths( 557 | output_block_layers, n_shave_prefix_segments=1 558 | ) 559 | for path in resnet_0_paths: 560 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 561 | new_path = ".".join( 562 | [ 563 | "up_blocks", 564 | str(block_id), 565 | "resnets", 566 | str(layer_in_block_id), 567 | path["new"], 568 | ] 569 | ) 570 | 571 | new_checkpoint[new_path] = unet_state_dict[old_path] 572 | 573 | return new_checkpoint 574 | 575 | 576 | def convert_ldm_vae_checkpoint(checkpoint, config): 577 | # extract state dict for VAE 578 | vae_state_dict = {} 579 | vae_key = "first_stage_model." 580 | keys = list(checkpoint.keys()) 581 | for key in keys: 582 | if key.startswith(vae_key): 583 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 584 | 585 | new_checkpoint = {} 586 | 587 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 588 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 589 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ 590 | "encoder.conv_out.weight" 591 | ] 592 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 593 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ 594 | "encoder.norm_out.weight" 595 | ] 596 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ 597 | "encoder.norm_out.bias" 598 | ] 599 | 600 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 601 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 602 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ 603 | "decoder.conv_out.weight" 604 | ] 605 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 606 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ 607 | "decoder.norm_out.weight" 608 | ] 609 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ 610 | "decoder.norm_out.bias" 611 | ] 612 | 613 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 614 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 615 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 616 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 617 | 618 | # Retrieves the keys for the encoder down blocks only 619 | num_down_blocks = len( 620 | { 621 | ".".join(layer.split(".")[:3]) 622 | for layer in vae_state_dict 623 | if "encoder.down" in layer 624 | } 625 | ) 626 | down_blocks = { 627 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] 628 | for layer_id in range(num_down_blocks) 629 | } 630 | 631 | # Retrieves the keys for the decoder up blocks only 632 | num_up_blocks = len( 633 | { 634 | ".".join(layer.split(".")[:3]) 635 | for layer in vae_state_dict 636 | if "decoder.up" in layer 637 | } 638 | ) 639 | up_blocks = { 640 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] 641 | for layer_id in range(num_up_blocks) 642 | } 643 | 644 | for i in 
range(num_down_blocks): 645 | resnets = [ 646 | key 647 | for key in down_blocks[i] 648 | if f"down.{i}" in key and f"down.{i}.downsample" not in key 649 | ] 650 | 651 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 652 | new_checkpoint[ 653 | f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" 654 | ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") 655 | new_checkpoint[ 656 | f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" 657 | ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") 658 | 659 | paths = renew_vae_resnet_paths(resnets) 660 | meta_path = { 661 | "old": f"down.{i}.block", 662 | "new": f"down_blocks.{i}.resnets", 663 | } 664 | assign_to_checkpoint( 665 | paths, 666 | new_checkpoint, 667 | vae_state_dict, 668 | additional_replacements=[meta_path], 669 | config=config, 670 | ) 671 | 672 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 673 | num_mid_res_blocks = 2 674 | for i in range(1, num_mid_res_blocks + 1): 675 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 676 | 677 | paths = renew_vae_resnet_paths(resnets) 678 | meta_path = { 679 | "old": f"mid.block_{i}", 680 | "new": f"mid_block.resnets.{i - 1}", 681 | } 682 | assign_to_checkpoint( 683 | paths, 684 | new_checkpoint, 685 | vae_state_dict, 686 | additional_replacements=[meta_path], 687 | config=config, 688 | ) 689 | 690 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 691 | paths = renew_vae_attention_paths(mid_attentions) 692 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 693 | assign_to_checkpoint( 694 | paths, 695 | new_checkpoint, 696 | vae_state_dict, 697 | additional_replacements=[meta_path], 698 | config=config, 699 | ) 700 | conv_attn_to_linear(new_checkpoint) 701 | 702 | for i in range(num_up_blocks): 703 | block_id = num_up_blocks - 1 - i 704 | resnets = [ 705 | key 706 | for key in up_blocks[block_id] 707 | if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 708 | ] 709 | 710 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 711 | new_checkpoint[ 712 | f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" 713 | ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] 714 | new_checkpoint[ 715 | f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" 716 | ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] 717 | 718 | paths = renew_vae_resnet_paths(resnets) 719 | meta_path = { 720 | "old": f"up.{block_id}.block", 721 | "new": f"up_blocks.{i}.resnets", 722 | } 723 | assign_to_checkpoint( 724 | paths, 725 | new_checkpoint, 726 | vae_state_dict, 727 | additional_replacements=[meta_path], 728 | config=config, 729 | ) 730 | 731 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 732 | num_mid_res_blocks = 2 733 | for i in range(1, num_mid_res_blocks + 1): 734 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 735 | 736 | paths = renew_vae_resnet_paths(resnets) 737 | meta_path = { 738 | "old": f"mid.block_{i}", 739 | "new": f"mid_block.resnets.{i - 1}", 740 | } 741 | assign_to_checkpoint( 742 | paths, 743 | new_checkpoint, 744 | vae_state_dict, 745 | additional_replacements=[meta_path], 746 | config=config, 747 | ) 748 | 749 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 750 | paths = renew_vae_attention_paths(mid_attentions) 751 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 752 | assign_to_checkpoint( 753 | 
paths, 754 | new_checkpoint, 755 | vae_state_dict, 756 | additional_replacements=[meta_path], 757 | config=config, 758 | ) 759 | conv_attn_to_linear(new_checkpoint) 760 | return new_checkpoint 761 | 762 | 763 | def convert_ldm_bert_checkpoint(checkpoint, config): 764 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): 765 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight 766 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight 767 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight 768 | 769 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight 770 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias 771 | 772 | def _copy_linear(hf_linear, pt_linear): 773 | hf_linear.weight = pt_linear.weight 774 | hf_linear.bias = pt_linear.bias 775 | 776 | def _copy_layer(hf_layer, pt_layer): 777 | # copy layer norms 778 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) 779 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) 780 | 781 | # copy attn 782 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) 783 | 784 | # copy MLP 785 | pt_mlp = pt_layer[1][1] 786 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) 787 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) 788 | 789 | def _copy_layers(hf_layers, pt_layers): 790 | for i, hf_layer in enumerate(hf_layers): 791 | if i != 0: 792 | i += i 793 | pt_layer = pt_layers[i : i + 2] 794 | _copy_layer(hf_layer, pt_layer) 795 | 796 | hf_model = LDMBertModel(config).eval() 797 | 798 | # copy embeds 799 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight 800 | hf_model.model.embed_positions.weight.data = ( 801 | checkpoint.transformer.pos_emb.emb.weight 802 | ) 803 | 804 | # copy layer norm 805 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) 806 | 807 | # copy hidden layers 808 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) 809 | 810 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) 811 | 812 | return hf_model 813 | 814 | 815 | def convert_ldm_clip_checkpoint(checkpoint): 816 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 817 | 818 | keys = list(checkpoint.keys()) 819 | 820 | text_model_dict = {} 821 | 822 | for key in keys: 823 | if key.startswith("cond_stage_model.transformer"): 824 | text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[ 825 | key 826 | ] 827 | 828 | text_model.load_state_dict(text_model_dict) 829 | 830 | return text_model 831 | 832 | 833 | if __name__ == "__main__": 834 | parser = argparse.ArgumentParser() 835 | 836 | parser.add_argument( 837 | "--checkpoint_path", 838 | default="./ckpt/anime-diffusion.ckpt", 839 | type=str, 840 | required=True, 841 | help="Path to the checkpoint to convert.", 842 | ) 843 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml 844 | parser.add_argument( 845 | "--original_config_file", 846 | default=None, 847 | type=str, 848 | help="The YAML config file corresponding to the original architecture.", 849 | ) 850 | parser.add_argument( 851 | "--scheduler_type", 852 | default="pndm", 853 | type=str, 854 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']", 855 | ) 856 | parser.add_argument( 857 | "--extract_ema", 858 | action="store_true", 859 | help=( 860 | "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" 861 | " or not. 
Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" 862 | " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." 863 | ), 864 | ) 865 | parser.add_argument( 866 | "--dump_path", 867 | default="./ckpt/anime-diff", 868 | type=str, 869 | required=True, 870 | help="Path to the output model.", 871 | ) 872 | 873 | args = parser.parse_args() 874 | 875 | if args.original_config_file is None: 876 | os.system( 877 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" 878 | ) 879 | args.original_config_file = "./v1-inference.yaml" 880 | 881 | original_config = OmegaConf.load(args.original_config_file) 882 | 883 | checkpoint = torch.load(args.checkpoint_path) 884 | checkpoint = checkpoint["state_dict"] 885 | 886 | num_train_timesteps = original_config.model.params.timesteps 887 | beta_start = original_config.model.params.linear_start 888 | beta_end = original_config.model.params.linear_end 889 | if args.scheduler_type == "pndm": 890 | scheduler = PNDMScheduler( 891 | beta_end=beta_end, 892 | beta_schedule="scaled_linear", 893 | beta_start=beta_start, 894 | num_train_timesteps=num_train_timesteps, 895 | skip_prk_steps=True, 896 | ) 897 | elif args.scheduler_type == "lms": 898 | scheduler = LMSDiscreteScheduler( 899 | beta_start=beta_start, 900 | beta_end=beta_end, 901 | beta_schedule="scaled_linear", 902 | ) 903 | elif args.scheduler_type == "euler": 904 | scheduler = EulerDiscreteScheduler( 905 | beta_start=beta_start, 906 | beta_end=beta_end, 907 | beta_schedule="scaled_linear", 908 | ) 909 | elif args.scheduler_type == "euler-ancestral": 910 | scheduler = EulerAncestralDiscreteScheduler( 911 | beta_start=beta_start, 912 | beta_end=beta_end, 913 | beta_schedule="scaled_linear", 914 | ) 915 | elif args.scheduler_type == "dpm": 916 | scheduler = DPMSolverMultistepScheduler( 917 | beta_start=beta_start, 918 | beta_end=beta_end, 919 | beta_schedule="scaled_linear", 920 | ) 921 | elif args.scheduler_type == "ddim": 922 | scheduler = DDIMScheduler( 923 | beta_start=beta_start, 924 | beta_end=beta_end, 925 | beta_schedule="scaled_linear", 926 | clip_sample=False, 927 | set_alpha_to_one=False, 928 | ) 929 | else: 930 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") 931 | 932 | # Convert the UNet2DConditionModel model. 933 | unet_config = create_unet_diffusers_config(original_config) 934 | converted_unet_checkpoint = convert_ldm_unet_checkpoint( 935 | checkpoint, 936 | unet_config, 937 | path=args.checkpoint_path, 938 | extract_ema=args.extract_ema, 939 | ) 940 | 941 | unet = UNet2DConditionModel(**unet_config) 942 | unet.load_state_dict(converted_unet_checkpoint) 943 | 944 | # Convert the VAE model. 945 | vae_config = create_vae_diffusers_config(original_config) 946 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) 947 | 948 | vae = AutoencoderKL(**vae_config) 949 | vae.load_state_dict(converted_vae_checkpoint) 950 | 951 | # Convert the text model. 
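    # A `FrozenCLIPEmbedder` cond-stage target marks a Stable Diffusion checkpoint: its CLIP text
    # encoder is converted and wrapped in a StableDiffusionPipeline. Any other target is treated as
    # an original LDM checkpoint and assembled into an LDMTextToImagePipeline with LDMBert.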
952 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[ 953 | -1 954 | ] 955 | if text_model_type == "FrozenCLIPEmbedder": 956 | text_model = convert_ldm_clip_checkpoint(checkpoint) 957 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 958 | safety_checker = StableDiffusionSafetyChecker.from_pretrained( 959 | "CompVis/stable-diffusion-safety-checker" 960 | ) 961 | feature_extractor = AutoFeatureExtractor.from_pretrained( 962 | "CompVis/stable-diffusion-safety-checker" 963 | ) 964 | 965 | pipe = StableDiffusionPipeline( 966 | vae=vae, 967 | text_encoder=text_model, 968 | tokenizer=tokenizer, 969 | unet=unet, 970 | scheduler=scheduler, 971 | safety_checker=safety_checker, 972 | feature_extractor=feature_extractor, 973 | ) 974 | else: 975 | text_config = create_ldm_bert_config(original_config) 976 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) 977 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 978 | pipe = LDMTextToImagePipeline( 979 | vqvae=vae, 980 | bert=text_model, 981 | tokenizer=tokenizer, 982 | unet=unet, 983 | scheduler=scheduler, 984 | ) 985 | 986 | pipe.save_pretrained(args.dump_path) 987 | 988 | 989 | def conver_ckpt_to_diff( 990 | ckpt_path: str, 991 | dump_path: str, 992 | ): 993 | os.system( 994 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" 995 | ) 996 | original_config_file: str = "./v1-inference.yaml" 997 | original_config = OmegaConf.load(original_config_file) 998 | 999 | checkpoint = torch.load(ckpt_path) 1000 | checkpoint = checkpoint["state_dict"] 1001 | 1002 | num_train_timesteps = original_config.model.params.timesteps 1003 | beta_start = original_config.model.params.linear_start 1004 | beta_end = original_config.model.params.linear_end 1005 | scheduler = PNDMScheduler( 1006 | beta_end=beta_end, 1007 | beta_schedule="scaled_linear", 1008 | beta_start=beta_start, 1009 | num_train_timesteps=num_train_timesteps, 1010 | skip_prk_steps=True, 1011 | ) 1012 | 1013 | # Convert the UNet2DConditionModel model. 1014 | unet_config = create_unet_diffusers_config(original_config) 1015 | converted_unet_checkpoint = convert_ldm_unet_checkpoint( 1016 | checkpoint, 1017 | unet_config, 1018 | path=ckpt_path, 1019 | extract_ema=False, 1020 | ) 1021 | 1022 | unet = UNet2DConditionModel(**unet_config) 1023 | unet.load_state_dict(converted_unet_checkpoint) 1024 | 1025 | # Convert the VAE model. 1026 | vae_config = create_vae_diffusers_config(original_config) 1027 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) 1028 | 1029 | vae = AutoencoderKL(**vae_config) 1030 | vae.load_state_dict(converted_vae_checkpoint) 1031 | 1032 | # Convert the text model. 
1033 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[ 1034 | -1 1035 | ] 1036 | if text_model_type == "FrozenCLIPEmbedder": 1037 | text_model = convert_ldm_clip_checkpoint(checkpoint) 1038 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 1039 | safety_checker = StableDiffusionSafetyChecker.from_pretrained( 1040 | "CompVis/stable-diffusion-safety-checker" 1041 | ) 1042 | feature_extractor = AutoFeatureExtractor.from_pretrained( 1043 | "CompVis/stable-diffusion-safety-checker" 1044 | ) 1045 | pipe = StableDiffusionPipeline( 1046 | vae=vae, 1047 | text_encoder=text_model, 1048 | tokenizer=tokenizer, 1049 | unet=unet, 1050 | scheduler=scheduler, 1051 | safety_checker=safety_checker, 1052 | feature_extractor=feature_extractor, 1053 | ) 1054 | else: 1055 | text_config = create_ldm_bert_config(original_config) 1056 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) 1057 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 1058 | pipe = LDMTextToImagePipeline( 1059 | vqvae=vae, 1060 | bert=text_model, 1061 | tokenizer=tokenizer, 1062 | unet=unet, 1063 | scheduler=scheduler, 1064 | ) 1065 | 1066 | os.remove("v1-inference.yaml") 1067 | pipe.save_pretrained(dump_path) 1068 | del pipe 1069 | return dump_path 1070 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3.7" 2 | 3 | services: 4 | api: 5 | image: ys2lee/stable-diffusion-api:${TAG:-latest} 6 | build: 7 | context: . 8 | dockerfile: docker/api/Dockerfile 9 | 10 | restart: unless-stopped 11 | ports: 12 | - 3000:3000 13 | expose: 14 | - 3000 15 | volumes: 16 | - /home/{USERNAME}/.cache/huggingface:/root/.cache/huggingface 17 | - ./static:/app/static 18 | environment: 19 | MODEL_ID: "CompVis/stable-diffusion-v1-4" 20 | MB_BATCH_SIZE: 1 21 | 22 | CUDA_DEVICE: cuda 23 | HUGGINGFACE_TOKEN: ${TOKEN:-YOUR-HUGGINGFACE-ACCESS-TOKEN} 24 | IMAGESERVER_URL: http://localhost:3000/images 25 | MB_TIMEOUT: 120 26 | SAVE_DIR: static 27 | 28 | deploy: 29 | resources: 30 | reservations: 31 | devices: 32 | - driver: nvidia 33 | device_ids: [ '0' ] 34 | capabilities: [ gpu ] 35 | 36 | frontend: 37 | depends_on: 38 | - api 39 | 40 | image: ys2lee/stable-diffusion-streamlit:${TAG:-latest} 41 | build: 42 | context: ./frontend 43 | dockerfile: ../docker/frontend/Dockerfile 44 | 45 | restart: unless-stopped 46 | ports: 47 | - 8501:8501 48 | expose: 49 | - 8501 50 | environment: 51 | TZ: Asia/Seoul 52 | API_URL: http://api:3000 53 | ST_TITLE: Stable-diffusion 54 | ST_WIDE: "True" 55 | -------------------------------------------------------------------------------- /docker/api/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime 2 | WORKDIR /app 3 | 4 | RUN apt-get update && apt install -y \ 5 | libgl1-mesa-glx \ 6 | curl \ 7 | libglib2.0-0 && apt-get clean 8 | 9 | 10 | COPY requirements.txt . 11 | 12 | RUN pip install --no-cache-dir -r requirements.txt 13 | 14 | 15 | ADD . 
/app 16 | RUN chmod +x /app/docker/api/start.sh 17 | ENTRYPOINT /app/docker/api/start.sh 18 | -------------------------------------------------------------------------------- /docker/api/start.sh: -------------------------------------------------------------------------------- 1 | gunicorn app.server:app -k uvicorn.workers.UvicornWorker \ 2 | --bind 0.0.0.0:3000 \ 3 | --workers 1 -------------------------------------------------------------------------------- /docker/frontend/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8.12-slim-buster 2 | WORKDIR /app 3 | 4 | RUN apt-get update && apt install -y \ 5 | libgl1-mesa-glx \ 6 | curl \ 7 | libglib2.0-0 && apt-get clean && \ 8 | python -m pip install --upgrade pip 9 | 10 | RUN pip install --no-cache-dir \ 11 | streamlit==1.12.2 \ 12 | requests==2.27.1 \ 13 | requests-toolbelt==0.9.1 14 | 15 | COPY requirements.txt . 16 | RUN pip install -r requirements.txt --no-cache-dir 17 | 18 | COPY . /app 19 | 20 | EXPOSE 8501 21 | 22 | CMD ["sh", "-c", "streamlit run --server.address=0.0.0.0 /app/inpaint.py"] -------------------------------------------------------------------------------- /env/.gitignore: -------------------------------------------------------------------------------- 1 | *.env 2 | *.txt -------------------------------------------------------------------------------- /frontend/helps.py: -------------------------------------------------------------------------------- 1 | guidance_scale = """ 2 | How strongly the conditioning signal (prompt) is applied. \n 3 | Larger values can make the image look better, but reduce diversity \n 4 | In general, values around 7~8.5 give stable diffusion the most reliable results 5 | """ 6 | 7 | prompt = """text prompt""" 8 | 9 | init_image = """Base image used for generation""" 10 | 11 | strength = """ The amount of noise added to the image. 
\n 12 | Higher values produce more varied results, \n 13 | but the generated image follows the shape of the input image less closely""" -------------------------------------------------------------------------------- /frontend/inpaint.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | import os 3 | from io import BytesIO 4 | import streamlit as st 5 | import requests 6 | from PIL import Image 7 | from streamlit_image_comparison import image_comparison 8 | from streamlit_drawable_canvas import st_canvas 9 | from PIL import Image, ImageOps 10 | from settings import load_config, get_settings 11 | from task import Task 12 | 13 | load_config("Stable-Diffusion: inpaint task") 14 | 15 | import requests 16 | 17 | env = get_settings() 18 | 19 | URL = os.path.join(env.API_URL, Task.INPAINT) 20 | 21 | 22 | def main(): 23 | with st.sidebar: 24 | stroke_width = st.slider("Stroke width: ", 1, 50, 25) 25 | 26 | with st.form("metadata"): 27 | steps = st.slider( 28 | "Number of Steps", 29 | min_value=1, 30 | max_value=50, 31 | step=1, 32 | value=25, 33 | ) 34 | guidance_scale = st.slider( 35 | "Guidance scale", 36 | min_value=0.1, 37 | max_value=20.0, 38 | value=7.5, 39 | step=0.01, 40 | ) 41 | strength = st.slider( 42 | "strength", 43 | min_value=0.0, 44 | max_value=1.0, 45 | value=0.8, 46 | step=0.01, 47 | ) 48 | seed = st.number_input("Seed", value=-1) 49 | summit = st.form_submit_button("Predict") 50 | 51 | st.title(f"{env.ST_TITLE}: inpaint") 52 | prompt = st.text_area( 53 | label="Text Prompt", 54 | placeholder="Text Prompt", 55 | key="prompt", 56 | ) 57 | 58 | negative_prompt = st.text_area( 59 | label="Negative Text Prompt", 60 | placeholder="Text Prompt", 61 | key="nega-prompt", 62 | ) 63 | 64 | init_image = st.file_uploader("Init image", env.IMAGE_TYPES) 65 | 66 | if init_image: 67 | background_image = Image.open(init_image) 68 | w, h = background_image.size 69 | canvas = st_canvas( 70 | background_image=background_image, 71 | stroke_width=stroke_width, 72 | background_color="#FFFFFFFF", 73 | width=w, 74 | height=h, 75 | key="inpaint-canvas", 76 | ) 77 | 78 | st.markdown("---") 79 | if summit and not init_image: 80 | st.warning("Input init Image") 81 | return -1 82 | 83 | if summit: 84 | mask_image = get_mask_image(canvas) 85 | mask_image_bytes = BytesIO() 86 | mask_image.save(mask_image_bytes, format="WEBP") 87 | 88 | image_urls = predict( 89 | prompt=prompt, 90 | negative_prompt=negative_prompt, 91 | steps=int(steps), 92 | init_image=init_image.getvalue(), 93 | mask_image=mask_image_bytes.getvalue(), 94 | strength=float(strength), 95 | guidance_scale=float(guidance_scale), 96 | seed=seed, 97 | ) 98 | 99 | c1, c2 = st.columns([1, 1]) 100 | c1.title("Origin") 101 | c1.image(init_image) 102 | 103 | c2.title("Result") 104 | c2.image(image_urls) 105 | 106 | image_comparison( 107 | img1=Image.open(init_image), 108 | img2=image_urls[-1], 109 | label1="origin", 110 | label2="diffusion", 111 | ) 112 | 113 | 114 | def get_mask_image(canvas) -> Image.Image: 115 | mask_image = Image.fromarray(canvas.image_data) 116 | new_mask_image = Image.new("RGBA", mask_image.size, "WHITE") 117 | new_mask_image.paste(mask_image, mask=mask_image) 118 | new_mask_image = new_mask_image.convert("RGB") 119 | new_mask_image = ImageOps.invert(new_mask_image) 120 | return new_mask_image 121 | 122 | 123 | def predict( 124 | prompt: str, 125 | negative_prompt: str, 126 | steps: int, 127 | init_image: bytes, 128 | mask_image: bytes, 129 | guidance_scale: float, 130 | strength: float, 131 | seed: int, 132 | ) -> T.List[str]: 133 | 
print(prompt) 134 | prompt = " " if prompt is None else prompt 135 | negative_prompt = "" if negative_prompt is None else negative_prompt 136 | files = [ 137 | ("init_image", ("image.webp", init_image, "image/*")), 138 | ("mask_image", ("image.webp", mask_image, "image/*")), 139 | ] 140 | res = requests.post( 141 | URL, 142 | data={ 143 | "prompt": prompt, 144 | "negative_prompt": negative_prompt, 145 | "steps": steps, 146 | "num_images": 1, 147 | "guidance_scale": guidance_scale, 148 | "strength": strength, 149 | "seed": seed, 150 | }, 151 | files=files, 152 | headers={}, 153 | ) 154 | if not res.ok: 155 | st.error(res.text) 156 | 157 | output = res.json() 158 | image_urls = output["image_urls"] 159 | return image_urls 160 | 161 | 162 | main() 163 | -------------------------------------------------------------------------------- /frontend/pages/image2image.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | import os 3 | import streamlit as st 4 | import requests 5 | from PIL import Image 6 | from streamlit_image_comparison import image_comparison 7 | from settings import load_config, get_settings 8 | from task import Task 9 | 10 | load_config("Stable-Diffusion: image2image task") 11 | 12 | import requests 13 | 14 | env = get_settings() 15 | 16 | URL = os.path.join(env.API_URL, Task.IMAGE2IMAGE) 17 | 18 | 19 | def main(): 20 | st.title(f"{env.ST_TITLE}: Image to Image") 21 | st.sidebar.markdown("#Image to Image task") 22 | prompt = st.text_area( 23 | label="Text Prompt", 24 | value="A fantasy landscape, trending on artstation", 25 | key="prompt", 26 | ) 27 | 28 | negative_prompt = st.text_area( 29 | label="Negative Text Prompt", 30 | placeholder="Text Prompt", 31 | key="nega-prompt", 32 | ) 33 | init_image = st.file_uploader( 34 | "Init image", 35 | env.IMAGE_TYPES, 36 | ) 37 | 38 | st.markdown("---") 39 | 40 | with st.sidebar as bar, st.form("key") as form: 41 | steps = st.slider( 42 | "Number of Steps", 43 | min_value=1, 44 | max_value=50, 45 | step=1, 46 | value=25, 47 | ) 48 | guidance_scale = st.slider( 49 | "Guidance scale", 50 | min_value=0.1, 51 | max_value=20.0, 52 | value=7.5, 53 | step=0.01, 54 | ) 55 | strength = st.slider( 56 | "strength", 57 | min_value=0.0, 58 | max_value=1.0, 59 | value=0.8, 60 | step=0.01, 61 | ) 62 | seed = st.number_input("Seed", value=-1) 63 | summit = st.form_submit_button("Predict") 64 | 65 | if summit: 66 | image_urls = predict( 67 | prompt=prompt, 68 | negative_prompt=negative_prompt, 69 | steps=int(steps), 70 | init_image=init_image.getvalue(), 71 | strength=float(strength), 72 | guidance_scale=float(guidance_scale), 73 | seed=seed, 74 | ) 75 | 76 | c1, c2 = st.columns([1, 1]) 77 | c1.title("Origin") 78 | c1.image(init_image) 79 | 80 | c2.title("Result") 81 | c2.image(image_urls) 82 | 83 | image_comparison( 84 | img1=Image.open(init_image), 85 | img2=image_urls[-1], 86 | label1="origin", 87 | label2="diffusion", 88 | ) 89 | 90 | 91 | def predict( 92 | prompt: str, 93 | negative_prompt: str, 94 | init_image: bytes, 95 | steps: int, 96 | guidance_scale: float, 97 | strength: float, 98 | seed: int, 99 | ) -> T.List[str]: 100 | prompt = " " if prompt is None else prompt 101 | negative_prompt = "" if negative_prompt is None else negative_prompt 102 | files = [("init_image", ("image.jpg", init_image, "image/*"))] 103 | res = requests.post( 104 | URL, 105 | data={ 106 | "prompt": prompt, 107 | "negative_prompt": negative_prompt, 108 | "steps": steps, 109 | "num_images": 1, 110 | 
"guidance_scale": guidance_scale, 111 | "strength": strength, 112 | "seed": seed, 113 | }, 114 | files=files, 115 | headers={}, 116 | ) 117 | if not res.ok: 118 | st.error(res.text) 119 | output = res.json() 120 | image_urls = output["image_urls"] 121 | return image_urls 122 | 123 | 124 | main() 125 | -------------------------------------------------------------------------------- /frontend/pages/text2image.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | import os 3 | import sys 4 | import streamlit as st 5 | import requests 6 | from settings import load_config, get_settings 7 | from utils import image_grid 8 | from task import Task 9 | 10 | load_config("Stable-Diffusion: text2image") 11 | 12 | import requests 13 | 14 | env = get_settings() 15 | 16 | URL = os.path.join(env.API_URL, Task.TEXT2IMAGE) 17 | 18 | 19 | def main(): 20 | st.title(f"{env.ST_TITLE}: text to image") 21 | st.sidebar.markdown("# Text to Image task") 22 | prompt = st.text_area( 23 | label="Text Prompt", 24 | value="A fantasy landscape, trending on artstation", 25 | key="prompt", 26 | ) 27 | negative_prompt = st.text_area( 28 | label="Negative Text Prompt", 29 | placeholder="Text Prompt", 30 | key="nega-prompt", 31 | ) 32 | st.markdown("---") 33 | 34 | with st.sidebar as bar, st.form("key") as form: 35 | num_images = st.slider( 36 | "Number of Image", 37 | min_value=1, 38 | max_value=8, 39 | step=1, 40 | ) 41 | steps = st.slider( 42 | "Number of Steps", 43 | min_value=1, 44 | max_value=50, 45 | step=1, 46 | value=25, 47 | ) 48 | guidance_scale = st.slider( 49 | "Guidance scale", 50 | min_value=0.1, 51 | max_value=20.0, 52 | value=7.5, 53 | step=0.01, 54 | ) 55 | height = st.select_slider( 56 | "Height", options=[size for size in range(128, 1025, 64)], value=512 57 | ) 58 | width = st.select_slider( 59 | "Width", options=[size for size in range(128, 1025, 64)], value=512 60 | ) 61 | seed = st.number_input("Seed", value=-1) 62 | 63 | summit = st.form_submit_button("Predict") 64 | 65 | if summit: 66 | image_urls = predict( 67 | prompt=prompt, 68 | negative_prompt=negative_prompt, 69 | steps=int(steps), 70 | num_images=int(num_images), 71 | guidance_scale=float(guidance_scale), 72 | height=height, 73 | width=width, 74 | seed=seed, 75 | ) 76 | image_grid(image_urls) 77 | 78 | 79 | def predict( 80 | prompt: str, 81 | negative_prompt: str, 82 | steps: int, 83 | num_images: int, 84 | guidance_scale: float, 85 | height: int, 86 | width: int, 87 | seed: int, 88 | ) -> T.List[str]: 89 | prompt = " " if prompt is None else prompt 90 | negative_prompt = "" if negative_prompt is None else negative_prompt 91 | res = requests.post( 92 | URL, 93 | data={ 94 | "prompt": prompt, 95 | "negative_prompt": negative_prompt, 96 | "steps": steps, 97 | "num_images": num_images, 98 | "guidance_scale": guidance_scale, 99 | "height": height, 100 | "width": width, 101 | "seed": seed, 102 | }, 103 | headers={}, 104 | ) 105 | if not res.ok: 106 | st.error(res.json()) 107 | output = res.json() 108 | image_urls = output["image_urls"] 109 | return image_urls 110 | 111 | 112 | main() 113 | -------------------------------------------------------------------------------- /frontend/requirements.txt: -------------------------------------------------------------------------------- 1 | pydantic==1.8.2 2 | streamlit-drawable-canvas==0.9.2 3 | streamlit-image-comparison==0.0.2 -------------------------------------------------------------------------------- /frontend/settings.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import streamlit as st 3 | from functools import lru_cache 4 | from pydantic import BaseSettings 5 | 6 | try: 7 | from dotenv import load_dotenv 8 | 9 | load_dotenv() 10 | except ImportError: 11 | pass 12 | 13 | 14 | def load_config(title=None, icone=None): 15 | st.set_page_config( 16 | page_title=title if title is not None else os.getenv("ST_TITLE"), 17 | layout="wide" if os.getenv("ST_WIDE") == "True" else "centered", 18 | menu_items={}, 19 | ) 20 | 21 | 22 | class Settings(BaseSettings): 23 | API_URL: str = "http://localhost:3000" 24 | ST_TITLE: str = "Stable-diffusion" 25 | ST_WIDE: str = True 26 | IMAGE_TYPES = ["png", "jpg", "jpeg", "webp", "bmp"] 27 | 28 | 29 | @lru_cache() 30 | def get_settings() -> Settings: 31 | setting = Settings() 32 | return setting 33 | -------------------------------------------------------------------------------- /frontend/task.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | # Task endpoint 4 | class Task: 5 | INPAINT = "inpaint" 6 | IMAGE2IMAGE = "image2image" 7 | TEXT2IMAGE = "text2image" 8 | -------------------------------------------------------------------------------- /frontend/utils.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | 4 | def image_grid(images, columns=2, show_caption: bool = True): 5 | st.header("Result Images:") 6 | c1, c2 = st.columns(columns) 7 | try: 8 | for i in range(0, len(images), columns): 9 | c1.image(images[i], caption=images[i] if show_caption else None) 10 | c2.image(images[i + 1], caption=images[i + 1] if show_caption else None) 11 | except IndexError: 12 | pass 13 | -------------------------------------------------------------------------------- /huggingface_model_download.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # download stable diffusion model to local huggingface cache dir 3 | from core.dependencies import models 4 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | 3 | if __name__ == "__main__": 4 | pass 5 | app = "app.server:app" 6 | uvicorn.run( 7 | app, 8 | port=3000, 9 | host="0.0.0.0", 10 | workers=1, 11 | ) 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi[all]==0.80.0 2 | fastapi-restful==0.4.3 3 | fastapi-health==0.4.0 4 | service-streamer==0.1.2 5 | loguru==0.6.0 6 | gunicorn==20.1.0 7 | pydantic==1.9.2 8 | diffusers 9 | transformers 10 | accelerate 11 | scipy 12 | ftfy 13 | -------------------------------------------------------------------------------- /src/image/image2image/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/src/image/image2image/1.png -------------------------------------------------------------------------------- /src/image/image2image/2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/src/image/image2image/2.png -------------------------------------------------------------------------------- /src/image/inpaint/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/src/image/inpaint/0.png -------------------------------------------------------------------------------- /src/image/inpaint/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/src/image/inpaint/1.png -------------------------------------------------------------------------------- /src/image/inpaint/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/src/image/inpaint/2.png -------------------------------------------------------------------------------- /src/image/text2image/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/src/image/text2image/1.png -------------------------------------------------------------------------------- /src/image/text2image/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/stable-diffusion-API/25e2ae5bc83837865a48151db71447e66fabbfd0/src/image/text2image/2.png --------------------------------------------------------------------------------
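
Note: besides the CLI entry point in `core/utils/convert_script.py`, the module also exposes `conver_ckpt_to_diff()` for converting an original `ckpt` checkpoint into a diffusers-style directory from Python. A minimal usage sketch follows; the checkpoint and output paths are placeholders (not files shipped with the repository), and the helper downloads `v1-inference.yaml` via `wget`, so it needs network access:

```python
# Minimal sketch (paths below are placeholders, not part of the repository).
from core.utils.convert_script import conver_ckpt_to_diff

model_path = conver_ckpt_to_diff(
    ckpt_path="./ckpt/my-model.ckpt",   # original-format checkpoint to convert
    dump_path="./ckpt/my-model-diff",   # output directory for the converted pipeline
)

# Point the API at the converted weights, e.g. by starting it with
#   MODEL_ID=./ckpt/my-model-diff python main.py
```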