├── docs ├── index.html └── stablefused │ └── typing.html ├── stablefused ├── apps │ └── storybook │ │ ├── __init__.py │ │ ├── author_api.py │ │ ├── speaker_api.py │ │ ├── config │ │ └── default_1_shot.json │ │ └── storybook.py ├── typing │ ├── __init__.py │ ├── enums.py │ └── type_hints.py ├── diffusion │ ├── __init__.py │ ├── image_to_image_diffusion.py │ ├── text_to_image_diffusion.py │ ├── text_to_video_diffusion.py │ ├── latent_walk_diffusion.py │ └── base_diffusion.py ├── utils │ ├── __init__.py │ ├── model_cache.py │ ├── import_utils.py │ ├── diffusion_utils.py │ └── image_utils.py └── __init__.py ├── .gitignore ├── LICENSE ├── .github └── workflows │ └── python-publish.yml ├── setup.py ├── tests ├── diffusion │ ├── test_text_to_image_diffusion.py │ ├── test_image_to_image_diffusion.py │ └── test_latent_walk_diffusion.py └── utils │ ├── test_diffusion_utils.py │ └── test_image_utils.py ├── examples ├── text_to_video_diffusion.ipynb ├── text_to_image_diffusion.ipynb ├── effect_of_guidance_scale.ipynb └── latent_walk_diffusion.ipynb └── README.md /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/__init__.py: -------------------------------------------------------------------------------- 1 | from .author_api import StoryBookAuthorBase, G4FStoryBookAuthor 2 | from .speaker_api import StoryBookSpeakerBase, gTTSStoryBookSpeaker 3 | from .storybook import StoryBookConfig, StoryBook 4 | -------------------------------------------------------------------------------- /stablefused/typing/__init__.py: -------------------------------------------------------------------------------- 1 | from .enums import ( 2 | InpaintWalkType, 3 | Scheduler, 4 | ) 5 | 6 | from .type_hints import ( 7 | ImageType, 8 | OutputType, 9 | PromptType, 10 | SchedulerType, 11 | UNetType, 12 | ) 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python directories 2 | .venv/ 3 | venv/ 4 | __pycache__/ 5 | .ipynb_checkpoints/ 6 | *.model 7 | *.egg-info 8 | dist/ 9 | 10 | # vscode 11 | .vscode/ 12 | 13 | # build 14 | build/ 15 | *.out 16 | *.s 17 | 18 | # tensorflow 19 | training_checkpoints/ 20 | 21 | # local test files 22 | test/ 23 | tmp/ 24 | 25 | # environment files 26 | .env 27 | 28 | # datasets 29 | *.csv 30 | 31 | # images and videos to not increase repository size 32 | resources/results/ 33 | -------------------------------------------------------------------------------- /stablefused/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_diffusion import BaseDiffusion 2 | from .image_to_image_diffusion import ImageToImageConfig, ImageToImageDiffusion 3 | from .latent_walk_diffusion import ( 4 | LatentWalkConfig, 5 | LatentWalkInterpolateConfig, 6 | LatentWalkDiffusion, 7 | ) 8 | from .text_to_image_diffusion import TextToImageConfig, TextToImageDiffusion 9 | from .text_to_video_diffusion import TextToVideoConfig, TextToVideoDiffusion 10 | from .inpaint_diffusion import InpaintConfig, InpaintWalkConfig, InpaintDiffusion 11 | -------------------------------------------------------------------------------- /stablefused/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.diffusion_utils import ( 2 | lerp, 3 | resolve_scheduler, 4 | slerp, 5 | ) 6 | from .image_utils import ( 7 | denormalize, 8 | image_grid, 9 | normalize, 10 | numpy_to_pil, 11 | numpy_to_pt, 12 | pil_to_numpy, 13 | pil_to_video, 14 | pil_to_gif, 15 | pt_to_numpy, 16 | write_text_on_image, 17 | ) 18 | from .import_utils import ( 19 | LazyImporter, 20 | ) 21 | from .model_cache import ( 22 | save_model_to_cache, 23 | load_model_from_cache, 24 | ) 25 | -------------------------------------------------------------------------------- /stablefused/utils/model_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | 4 | class ModelCache: 5 | """ 6 | A cache for diffusion models. This class should not be instantiated by the user. 7 | You should use the load_model function instead. It is a mapping from model_id to 8 | diffusion model. This allows us to avoid loading the same model components multiple 9 | times. 10 | """ 11 | 12 | def __init__(self) -> None: 13 | self.cache = dict() 14 | 15 | def get(self, model_id: str, default: Any = None) -> Any: 16 | if model_id not in self.cache.keys(): 17 | return default 18 | return self.cache[model_id] 19 | 20 | def set(self, model: Any) -> None: 21 | self.cache[model.model_id] = model 22 | 23 | 24 | _model_cache = ModelCache() 25 | 26 | 27 | load_model_from_cache = _model_cache.get 28 | save_model_to_cache = _model_cache.set 29 | -------------------------------------------------------------------------------- /stablefused/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | from types import ModuleType 4 | from typing import Optional 5 | 6 | 7 | class LazyImporter: 8 | """ 9 | Lazy importer for modules. 10 | """ 11 | 12 | def __init__(self, module_name: str) -> None: 13 | self.module_name = module_name 14 | self.module = None 15 | 16 | def import_module(self, import_error_message: Optional[str] = None) -> ModuleType: 17 | if self.module is None: 18 | try: 19 | self.module = importlib.import_module(self.module_name) 20 | except ModuleNotFoundError: 21 | if import_error_message is None: 22 | import_error_message = ( 23 | f"'{self.module_name}' is not installed. Please install it." 24 | ) 25 | raise ModuleNotFoundError(import_error_message) 26 | return self.module 27 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/author_api.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List 3 | 4 | from stablefused import LazyImporter 5 | 6 | 7 | class StoryBookAuthorBase(ABC): 8 | def __init__(self) -> None: 9 | pass 10 | 11 | @abstractmethod 12 | def __call__(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: 13 | pass 14 | 15 | 16 | class G4FStoryBookAuthor(StoryBookAuthorBase): 17 | def __init__(self, model_id: str = "gpt-3.5-turbo") -> None: 18 | super().__init__() 19 | self.model_id = model_id 20 | self.g4f = LazyImporter("g4f") 21 | 22 | def __call__(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: 23 | g4f = self.g4f.import_module( 24 | import_error_message="g4f is not installed. Please install it using `pip install g4f`." 
25 | ) 26 | return g4f.ChatCompletion.create(model=self.model_id, messages=messages) 27 | -------------------------------------------------------------------------------- /stablefused/typing/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class InpaintWalkType(str, Enum): 5 | """ 6 | Enum for inpainting walk types. 7 | """ 8 | 9 | UP = "up" 10 | DOWN = "down" 11 | LEFT = "left" 12 | RIGHT = "right" 13 | FORWARD = "forward" 14 | BACKWARD = "backward" 15 | 16 | 17 | class Scheduler(str, Enum): 18 | DEIS = "deis" 19 | DDIM = "ddim" 20 | DDPM = "ddpm" 21 | DPM2_KARRAS = "dpm2_karras" 22 | DPM2_KARRAS_ANCESTRAL = "dpm2_karras_ancestral" 23 | DPM_SDE = "dpm_sde" 24 | DPM_SDE_KARRAS = "dpm_sde_karras" 25 | DPM_MULTISTEP = "dpm_multistep" 26 | DPM_MULTISTEP_KARRAS = "dpm_multistep_karras" 27 | DPM_SINGLESTEP = "dpm_singlestep" 28 | DPM_SINGLESTEP_KARRAS = "dpm_singlestep_karras" 29 | EULER = "euler" 30 | EULER_ANCESTRAL = "euler_ancestral" 31 | HEUN = "heun" 32 | LINEAR_MULTISTEP = "linear_multistep" 33 | PNDM = "pndm" 34 | UNIPC = "unipc" 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Aryan V S 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /stablefused/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
include:: ../README.md 3 | """ 4 | 5 | from .diffusion import ( 6 | BaseDiffusion, 7 | ImageToImageConfig, 8 | ImageToImageDiffusion, 9 | InpaintConfig, 10 | InpaintWalkConfig, 11 | InpaintDiffusion, 12 | LatentWalkConfig, 13 | LatentWalkInterpolateConfig, 14 | LatentWalkDiffusion, 15 | TextToImageConfig, 16 | TextToImageDiffusion, 17 | TextToVideoConfig, 18 | TextToVideoDiffusion, 19 | ) 20 | 21 | from .typing import ( 22 | InpaintWalkType, 23 | Scheduler, 24 | ImageType, 25 | OutputType, 26 | PromptType, 27 | SchedulerType, 28 | UNetType, 29 | ) 30 | 31 | from .utils import ( 32 | denormalize, 33 | image_grid, 34 | lerp, 35 | load_model_from_cache, 36 | normalize, 37 | numpy_to_pil, 38 | numpy_to_pt, 39 | pil_to_numpy, 40 | pil_to_video, 41 | pil_to_gif, 42 | pt_to_numpy, 43 | resolve_scheduler, 44 | save_model_to_cache, 45 | slerp, 46 | write_text_on_image, 47 | LazyImporter, 48 | ) 49 | 50 | from .apps.storybook import ( 51 | StoryBookAuthorBase, 52 | G4FStoryBookAuthor, 53 | StoryBookConfig, 54 | StoryBook, 55 | StoryBookSpeakerBase, 56 | gTTSStoryBookSpeaker, 57 | ) 58 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /stablefused/typing/type_hints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | from diffusers.models import UNet2DConditionModel, UNet3DConditionModel 6 | from diffusers.schedulers import ( 7 | DEISMultistepScheduler, 8 | DDIMScheduler, 9 | DDPMScheduler, 10 | DPMSolverSDEScheduler, 11 | DPMSolverMultistepScheduler, 12 | DPMSolverSinglestepScheduler, 13 | EulerAncestralDiscreteScheduler, 14 | EulerDiscreteScheduler, 15 | HeunDiscreteScheduler, 16 | KDPM2DiscreteScheduler, 17 | KDPM2AncestralDiscreteScheduler, 18 | LMSDiscreteScheduler, 19 | PNDMScheduler, 20 | UniPCMultistepScheduler, 21 | ) 22 | from typing import List, Union 23 | 24 | 25 | ImageType = Union[torch.Tensor, np.ndarray, Image.Image, List[Image.Image]] 26 | 27 | OutputType = Union[torch.Tensor, np.ndarray, List[Image.Image]] 28 | 29 | PromptType = Union[str, List[str]] 30 | 31 | SchedulerType = Union[ 32 | DEISMultistepScheduler, 33 | 
DDIMScheduler, 34 | DDPMScheduler, 35 | DPMSolverSDEScheduler, 36 | DPMSolverMultistepScheduler, 37 | DPMSolverSinglestepScheduler, 38 | EulerAncestralDiscreteScheduler, 39 | EulerDiscreteScheduler, 40 | HeunDiscreteScheduler, 41 | KDPM2DiscreteScheduler, 42 | KDPM2AncestralDiscreteScheduler, 43 | LMSDiscreteScheduler, 44 | PNDMScheduler, 45 | UniPCMultistepScheduler, 46 | ] 47 | 48 | UNetType = Union[UNet2DConditionModel, UNet3DConditionModel] 49 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/speaker_api.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | 7 | class StoryBookSpeakerBase(ABC): 8 | def __init__(self) -> None: 9 | pass 10 | 11 | @abstractmethod 12 | def __call__(self, messages: List[str], *, yield_files: bool = True) -> List[str]: 13 | pass 14 | 15 | 16 | class gTTSStoryBookSpeaker(StoryBookSpeakerBase): 17 | def __init__(self, *, lang: str = "en", tld: str = "us") -> None: 18 | super().__init__() 19 | try: 20 | from gtts import gTTS 21 | except ImportError: 22 | raise ImportError( 23 | "gTTS is not installed. Please install it using `pip install gTTS`." 24 | ) 25 | self.gTTS = gTTS 26 | self.lang = lang 27 | self.tld = tld 28 | 29 | def __call__(self, messages: List[str], *, yield_files: bool = True) -> List[str]: 30 | for message in messages: 31 | tts = self.gTTS(message, lang=self.lang, tld=self.tld) 32 | 33 | if yield_files: 34 | with tempfile.NamedTemporaryFile(suffix=".wav") as f: 35 | tts.write_to_fp(f) 36 | yield f.name 37 | else: 38 | files = [] 39 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: 40 | tts.write_to_fp(f) 41 | files.append(f.name) 42 | return files 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as file: 4 | long_description = file.read() 5 | 6 | setup( 7 | name="stablefused", 8 | version="0.2.0", 9 | description="StableFused is a toy library to experiment with Stable Diffusion inspired by 🤗 diffusers and various other sources!", 10 | long_description=long_description, 11 | long_description_content_type="text/markdown", 12 | author="Aryan V S", 13 | author_email="contact.aryanvs+stablefused@gmail.com", 14 | url="https://github.com/a-r-r-o-w/stablefused/", 15 | python_requires=">=3.8.0", 16 | license="MIT", 17 | packages=find_packages(), 18 | install_requires=[ 19 | "accelerate==0.21.0", 20 | "dataclasses-json==0.6.1", 21 | "diffusers==0.19.3", 22 | "ftfy==6.1.1", 23 | "imageio==2.31.1", 24 | "imageio-ffmpeg==0.4.8", 25 | "torch==2.0.1", 26 | "transformers==4.31.0", 27 | "matplotlib==3.7.2", 28 | "moviepy==1.0.3", 29 | "numpy==1.25.2", 30 | "pillow==9.5.0", 31 | "scipy==1.11.1", 32 | ], 33 | extras_require={ 34 | "dev": [ 35 | "black==23.7.0", 36 | "pytest==7.4.0", 37 | "twine>=4.0.2", 38 | ], 39 | "extras": [ 40 | "g4f==0.1.5.6", 41 | "curl-cffi==0.5.7", 42 | "gtts==2.4.0", 43 | ], 44 | }, 45 | classifiers=[ 46 | "Development Status :: 1 - Planning", 47 | "Intended Audience :: Science/Research", 48 | "Intended Audience :: Developers", 49 | "Intended Audience :: Education", 50 | "Programming Language :: Python :: 3", 51 | "Programming Language :: Python :: 3.8", 52 | "Programming Language :: Python :: 3.9", 53 | 
"Programming Language :: Python :: 3.10", 54 | "Operating System :: Microsoft :: Windows", 55 | "Operating System :: Unix", 56 | "License :: OSI Approved :: MIT License", 57 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 58 | ], 59 | ) 60 | 61 | # Steps to publish: 62 | # 1. Update version in setup.py 63 | # 2. python setup.py sdist bdist_wheel 64 | # 3. Check if everything works with testpypi: 65 | # twine upload --repository testpypi dist/* 66 | # 4. Upload to pypi: 67 | # twine upload dist/* 68 | -------------------------------------------------------------------------------- /tests/diffusion/test_text_to_image_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from stablefused import TextToImageConfig, TextToImageDiffusion 6 | 7 | 8 | @pytest.fixture 9 | def model(): 10 | """ 11 | Fixture to initialize the TextToImageDiffusion model and set random seeds for reproducibility. 12 | 13 | Returns 14 | ------- 15 | TextToImageDiffusion 16 | The initialized TextToImageDiffusion model. 17 | """ 18 | seed = 1337 19 | model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" 20 | device = "cpu" 21 | 22 | torch.manual_seed(seed) 23 | np.random.seed(seed) 24 | 25 | model = TextToImageDiffusion(model_id=model_id, device=device) 26 | return model 27 | 28 | 29 | @pytest.fixture 30 | def config(): 31 | return { 32 | "prompt": "a photo of a cat", 33 | "num_inference_steps": 1, 34 | "image_dim": 32, 35 | } 36 | 37 | 38 | def test_text_to_image_diffusion(model: TextToImageDiffusion, config: dict) -> None: 39 | """ 40 | Test case to check if the TextToImageDiffusion is working correctly. 41 | 42 | Parameters 43 | ---------- 44 | 45 | Raises 46 | ------ 47 | AssertionError 48 | If the generated image is not of type np.ndarray. 49 | If the generated image does not have the expected shape. 50 | """ 51 | 52 | dim = config.get("image_dim") 53 | images = model( 54 | TextToImageConfig( 55 | prompt=config.get("prompt"), 56 | image_height=dim, 57 | image_width=dim, 58 | num_inference_steps=config.get("num_inference_steps"), 59 | output_type="np", 60 | ) 61 | ) 62 | 63 | assert type(images) is np.ndarray 64 | assert images.shape == (1, dim, dim, 3) 65 | 66 | 67 | def test_return_latent_history(model: TextToImageDiffusion, config: dict) -> None: 68 | """ 69 | Test case to check if the latent history is returned correctly. 70 | 71 | Raises 72 | ------ 73 | AssertionError 74 | If the generated image is not of type np.ndarray. 75 | If the generated image does not have the expected shape. 
76 | """ 77 | 78 | dim = config.get("image_dim") 79 | images = model( 80 | TextToImageConfig( 81 | prompt=config.get("prompt"), 82 | image_height=dim, 83 | image_width=dim, 84 | num_inference_steps=config.get("num_inference_steps"), 85 | output_type="pt", 86 | return_latent_history=True, 87 | ) 88 | ) 89 | 90 | assert type(images) is torch.Tensor 91 | assert images.shape == (1, config.get("num_inference_steps") + 1, 3, dim, dim) 92 | 93 | 94 | if __name__ == "__main__": 95 | pytest.main([__file__]) 96 | -------------------------------------------------------------------------------- /tests/diffusion/test_image_to_image_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from stablefused import ImageToImageConfig, ImageToImageDiffusion 6 | 7 | 8 | @pytest.fixture 9 | def model(): 10 | """ 11 | Fixture to initialize the ImageToImageDiffusion model and set random seeds for reproducibility. 12 | 13 | Returns 14 | ------- 15 | ImageToImageDiffusion 16 | The initialized ImageToImageDiffusion model. 17 | """ 18 | seed = 1337 19 | model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" 20 | device = "cpu" 21 | 22 | torch.manual_seed(seed) 23 | np.random.seed(seed) 24 | 25 | model = ImageToImageDiffusion(model_id=model_id, device=device) 26 | return model 27 | 28 | 29 | @pytest.fixture 30 | def config(): 31 | return { 32 | "prompt": "a photo of a cat", 33 | "num_inference_steps": 2, 34 | "start_step": 1, 35 | "image_dim": 32, 36 | } 37 | 38 | 39 | def test_image_to_image_diffusion(model: ImageToImageDiffusion, config: dict) -> None: 40 | """ 41 | Test case to check if the ImageToImageDiffusion is working correctly. 42 | 43 | Raises 44 | ------ 45 | AssertionError 46 | If the generated image is not of type np.ndarray. 47 | If the generated image does not have the expected shape. 48 | """ 49 | dim = config.get("image_dim") 50 | image = model.random_tensor((1, 3, dim, dim)) 51 | 52 | images = model( 53 | ImageToImageConfig( 54 | image=image, 55 | prompt=config.get("prompt"), 56 | num_inference_steps=config.get("num_inference_steps"), 57 | output_type="np", 58 | ) 59 | ) 60 | 61 | assert type(images) is np.ndarray 62 | assert images.shape == (1, dim, dim, 3) 63 | 64 | 65 | def test_return_latent_history(model: ImageToImageDiffusion, config: dict) -> None: 66 | """ 67 | Test case to check if latent history is returned correctly. 68 | 69 | Raises 70 | ------ 71 | AssertionError 72 | If the generated image is not of type np.ndarray. 73 | If the generated image does not have the expected shape. 
74 | """ 75 | dim = config.get("image_dim") 76 | image = model.random_tensor((1, 3, dim, dim)) 77 | history_size = config.get("num_inference_steps") + 1 - config.get("start_step") 78 | 79 | images = model( 80 | ImageToImageConfig( 81 | image=image, 82 | prompt=config.get("prompt"), 83 | num_inference_steps=config.get("num_inference_steps"), 84 | start_step=config.get("start_step"), 85 | output_type="pt", 86 | return_latent_history=True, 87 | ) 88 | ) 89 | 90 | assert type(images) is torch.Tensor 91 | assert images.shape == (1, history_size, 3, dim, dim) 92 | 93 | 94 | if __name__ == "__main__": 95 | pytest.main([__file__]) 96 | -------------------------------------------------------------------------------- /tests/utils/test_diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | 5 | from stablefused.utils import lerp, slerp 6 | 7 | 8 | test_cases_lerp = [ 9 | # Test cases for lerp with t as a float 10 | (np.array([1.0, 2.0]), np.array([4.0, 6.0]), 0.0, np.array([1.0, 2.0])), 11 | (np.array([1.0, 2.0]), np.array([4.0, 6.0]), 1.0, np.array([4.0, 6.0])), 12 | (np.array([1.0, 2.0]), np.array([4.0, 6.0]), 0.5, np.array([2.5, 4.0])), 13 | # Test cases for lerp with t as an np.ndarray 14 | ( 15 | np.array([1.0, 2.0]), 16 | np.array([4.0, 6.0]), 17 | np.array([0.0, 1.0]), 18 | np.array([[1.0, 2.0], [4.0, 6.0]]), 19 | ), 20 | ( 21 | np.array([1.0, 2.0]), 22 | np.array([4.0, 6.0]), 23 | np.array([0.5, 0.25]), 24 | np.array([[2.5, 4.0], [1.75, 3.0]]), 25 | ), 26 | # Test cases for lerp with t as a torch.Tensor 27 | ( 28 | np.array([1.0, 2.0]), 29 | np.array([4.0, 6.0]), 30 | torch.Tensor([0.0, 1.0]), 31 | np.array([[1.0, 2.0], [4.0, 6.0]]), 32 | ), 33 | ( 34 | np.array([1.0, 2.0]), 35 | np.array([4.0, 6.0]), 36 | torch.Tensor([0.5, 0.25]), 37 | np.array([[2.5, 4.0], [1.75, 3.0]]), 38 | ), 39 | ] 40 | 41 | test_cases_slerp = [ 42 | # Test cases for slerp with t as a float 43 | (np.array([1.0, 0.0]), np.array([0.0, 1.0]), 0.0, np.array([1.0, 0.0])), 44 | (np.array([1.0, 0.0]), np.array([0.0, 1.0]), 1.0, np.array([0.0, 1.0])), 45 | (np.array([1.0, 0.0]), np.array([0.0, 1.0]), 0.5, np.array([0.707107, 0.707107])), 46 | # Test cases for slerp with t as an array 47 | ( 48 | np.array([1.0, 0.0]), 49 | np.array([0.0, 1.0]), 50 | np.array([0.0, 1.0]), 51 | np.array([[1.0, 0.0], [0.0, 1.0]]), 52 | ), 53 | ( 54 | np.array([1.0, 0.0]), 55 | np.array([0.0, 1.0]), 56 | np.array([0.5, 0.25]), 57 | np.array([[0.707107, 0.707107], [0.923880, 0.382683]]), 58 | ), 59 | # Test cases for slerp with t as a torch.Tensor 60 | ( 61 | np.array([1.0, 0.0]), 62 | np.array([0.0, 1.0]), 63 | torch.Tensor([0.0, 1.0]), 64 | np.array([[1.0, 0.0], [0.0, 1.0]]), 65 | ), 66 | ( 67 | np.array([1.0, 0.0]), 68 | np.array([0.0, 1.0]), 69 | torch.Tensor([0.5, 0.25]), 70 | np.array([[0.707107, 0.707107], [0.923880, 0.382683]]), 71 | ), 72 | ] 73 | 74 | 75 | @pytest.mark.parametrize("v0, v1, t, expected", test_cases_lerp) 76 | def test_lerp(v0, v1, t, expected): 77 | result = lerp(v0, v1, t) 78 | np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-8) 79 | 80 | 81 | @pytest.mark.parametrize("v0, v1, t, expected", test_cases_slerp) 82 | def test_slerp(v0, v1, t, expected): 83 | v0_torch = torch.from_numpy(v0) 84 | v1_torch = torch.from_numpy(v1) 85 | 86 | result = slerp(v0, v1, t) 87 | result_torch = slerp(v0_torch, v1_torch, t).cpu().numpy() 88 | 89 | np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-8) 90 | 
torch.testing.assert_close(result_torch, expected, rtol=1e-5, atol=1e-8) 91 | 92 | 93 | if __name__ == "__main__": 94 | pytest.main() 95 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/config/default_1_shot.json: -------------------------------------------------------------------------------- 1 | { 2 | "negative_prompt": "(((deformed))), (((disfigured))), unrealistic, blur, boring background, mutation, mutated, malformed, censored, colorless, bad shadow, missing parts, cropped, watermark, username, text, signature, low quality, low resolution, unattractive", 3 | "attributes": "high quality, realistic, octane render", 4 | "image_height": 512, 5 | "image_width": 512, 6 | "num_inference_steps": 20, 7 | "guidance_scale": 7.5, 8 | "guidance_rescale": 0.7, 9 | "messages": [ 10 | { 11 | "role": "system", 12 | "content": "You are the greatest author to ever have lived and a storyteller with milleniums of experience. Your task is to generate concise and visually evocative text prompts for a text-to-image model. You have to ensure that the story conveys a clear and vivid visual concept, adhering to the context provided by the user. Execute this task with utmost diligence and creativity.\n\nYou have to follow these guidelines when producing your output:\n- The output must follow a JSON structure. It should be a list of dictionaries (the keys of the dictionaries should be `story` and `prompt`).\n- For a given dictionary in the list, the key `story` should contain part of the story as its value, while the key `prompt` should contain the prompt that will be used by the text-to-image model.\n-You MUST follow the above guidelines and cannot respond in any other way.\n- If the prompt provided by the user is vague, you need to think about something creative.\n\nYou have to follow these guidelines when producing the story:\n- The story must be interesting, coherent, understandable and enjoyable to read.\n- The story can mention names of characters.\n- The story can be of variable sizes per prompt.\n\nYou have to follow these guidelines when producing the prompt:\n- The prompt must not contain names. This is very important because it confuses text-to-image models.\n- The prompt must not contain pronouns like he/she/they/them, etc. These provide no context.\n- The prompt must emphasize more on the corresponding action happening in the story prompt.\n- The prompt must be short in length (at most 15 words)." 13 | }, 14 | { 15 | "role": "user", 16 | "content": "A dog is chasing a cat. The cat devises a clever plan to escape to space. Write a story in about 100-200 words." 
17 | }, 18 | { 19 | "role": "assistant", 20 | "content": "[\n {\n \"story\": \"The energetic dog sprinted through the garden, relentless in pursuit of its prey.\",\n \"prompt\": \"Energetic dog chasing cat in the garden\"\n },\n {\n \"story\": \"The cat, feeling the hot breath of the dog behind, hatched a daring escape plan.\",\n \"prompt\": \"Cat being chased by dog\"\n },\n {\n \"story\": \"With a determined glint in its eyes, the cat darted towards the makeshift spaceship.\",\n \"prompt\": \"Cat dashing towards spaceship\"\n },\n {\n \"story\": \"Inside the spaceship, Fluffy the cat prepared for an epic journey to the cosmos.\",\n \"prompt\": \"Cat preparing for epic cosmic journey, inside spaceship\"\n },\n {\n \"story\": \"As the spaceship's engines roared, it soared into the starlit expanse, leaving Earth behind.\",\n \"prompt\": \"Spaceship ascending into the starlit expanse\"\n },\n {\n \"story\": \"The dog barked at the empty sky, bewildered by the cat's unexpected interstellar escape.\",\n \"prompt\": \"Dog barking at empty sky bewildered\"\n }\n]" 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /tests/utils/test_image_utils.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import os 4 | import pytest 5 | import tempfile 6 | 7 | from PIL import Image 8 | from imageio.plugins.ffmpeg import FfmpegFormat 9 | from stablefused.utils import image_grid, pil_to_gif, pil_to_video 10 | 11 | np.random.seed(42) 12 | 13 | 14 | @pytest.fixture 15 | def image_config(): 16 | return { 17 | "width": 32, 18 | "height": 32, 19 | "channels": 3, 20 | } 21 | 22 | 23 | @pytest.fixture 24 | def num_images(): 25 | return 8 26 | 27 | 28 | def random_image(width, height, channels): 29 | random_image = np.random.randint(0, 256, (height, width, channels), dtype=np.uint8) 30 | return Image.fromarray(random_image) 31 | 32 | 33 | @pytest.fixture 34 | def random_images(image_config, num_images): 35 | image_list = [] 36 | for _ in range(num_images): 37 | image_list.append( 38 | random_image( 39 | width=image_config.get("width"), 40 | height=image_config.get("height"), 41 | channels=image_config.get("channels"), 42 | ) 43 | ) 44 | return image_list 45 | 46 | 47 | @pytest.fixture 48 | def temporary_gif_file(): 49 | with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as temp_file: 50 | temp_filename = temp_file.name 51 | yield temp_filename 52 | os.remove(temp_filename) 53 | 54 | 55 | @pytest.fixture(params=[".mp4", ".avi", ".mkv", ".mov", ".wmv"]) 56 | def temporary_video_file(request): 57 | with tempfile.NamedTemporaryFile(suffix=request.param, delete=False) as temp_file: 58 | temp_filename = temp_file.name 59 | yield temp_filename 60 | os.remove(temp_filename) 61 | 62 | 63 | def test_image_grid(random_images, num_images): 64 | """Test that image grid is created correctly.""" 65 | 66 | rows = 2 67 | cols = num_images // 2 68 | grid_image = image_grid(random_images, rows, cols) 69 | 70 | expected_width = random_images[0].width * cols 71 | expected_height = random_images[0].height * rows 72 | 73 | assert grid_image.width == expected_width 74 | assert grid_image.height == expected_height 75 | assert len(random_images) == rows * cols 76 | assert isinstance(grid_image, Image.Image) 77 | 78 | 79 | def test_pil_to_gif(random_images, temporary_gif_file): 80 | """Test that PIL images are converted to GIF correctly.""" 81 | 82 | pil_to_gif(random_images, temporary_gif_file, 
fps=1) 83 | 84 | assert os.path.isfile(temporary_gif_file) 85 | with Image.open(temporary_gif_file) as saved_gif: 86 | assert isinstance(saved_gif, Image.Image) 87 | assert saved_gif.is_animated 88 | assert saved_gif.n_frames == len(random_images) 89 | 90 | 91 | def test_pil_to_video(random_images, temporary_video_file): 92 | """Test that PIL images are converted to video correctly.""" 93 | pil_to_video(random_images, temporary_video_file, fps=1) 94 | 95 | assert os.path.isfile(temporary_video_file) 96 | try: 97 | video: FfmpegFormat.Reader = imageio.get_reader( 98 | temporary_video_file, format="ffmpeg" 99 | ) 100 | assert video.count_frames() == len(random_images) 101 | except Exception as e: 102 | pytest.fail(f"Failed to open video file: {e}") 103 | 104 | 105 | if __name__ == "__main__": 106 | pytest.main([__file__]) 107 | -------------------------------------------------------------------------------- /tests/diffusion/test_latent_walk_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from stablefused import ( 6 | LatentWalkConfig, 7 | LatentWalkInterpolateConfig, 8 | LatentWalkDiffusion, 9 | ) 10 | 11 | 12 | @pytest.fixture 13 | def model(): 14 | """ 15 | Fixture to initialize the LatentWalkDiffusion model and set random seeds for reproducibility. 16 | 17 | Returns 18 | ------- 19 | LatentWalkDiffusion 20 | The initialized LatentWalkDiffusion model. 21 | """ 22 | seed = 1337 23 | model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" 24 | device = "cpu" 25 | 26 | torch.manual_seed(seed) 27 | np.random.seed(seed) 28 | 29 | model = LatentWalkDiffusion(model_id=model_id, device=device) 30 | return model 31 | 32 | 33 | @pytest.fixture 34 | def config(): 35 | return { 36 | "prompt": "a photo of a cat", 37 | "num_inference_steps": 1, 38 | "image_dim": 32, 39 | } 40 | 41 | 42 | @pytest.fixture 43 | def config_interpolate(): 44 | return { 45 | "prompt": ["a photo of a cat", "a photo of a dog"], 46 | "num_inference_steps": 1, 47 | "interpolation_steps": 5, 48 | "image_dim": 32, 49 | } 50 | 51 | 52 | def test_latent_walk_diffusion(model: LatentWalkDiffusion, config: dict) -> None: 53 | """ 54 | Test case to check if the LatentWalkDiffusion is working correctly. 55 | 56 | Raises 57 | ------ 58 | AssertionError 59 | If the generated image is not of type np.ndarray. 60 | If the generated image does not have the expected shape. 61 | """ 62 | 63 | dim = config.get("image_dim") 64 | image = model.random_tensor((1, 3, dim, dim)) 65 | latent = model.image_to_latent(image) 66 | 67 | images = model( 68 | LatentWalkConfig( 69 | prompt=config.get("prompt"), 70 | latent=latent, 71 | num_inference_steps=config.get("num_inference_steps"), 72 | output_type="np", 73 | ) 74 | ) 75 | 76 | assert type(images) is np.ndarray 77 | assert images.shape == (1, dim, dim, 3) 78 | 79 | 80 | def test_interpolate(model: LatentWalkDiffusion, config_interpolate: dict) -> None: 81 | """ 82 | Test case to check if the LatentWalkDiffusion is working correctly. 83 | 84 | Raises 85 | ------ 86 | AssertionError 87 | If the generated image is not of type np.ndarray. 88 | If the generated image does not have the expected shape. 
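    With N prompts and k interpolation steps, interpolate is expected to
    return k * (N - 1) images, matching the image_count computed below.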
89 | """ 90 | 91 | dim = config_interpolate.get("image_dim") 92 | num_prompts = len(config_interpolate.get("prompt")) 93 | image_count = config_interpolate.get("interpolation_steps") * (num_prompts - 1) 94 | image = model.random_tensor((num_prompts, 3, dim, dim)) 95 | latent = model.image_to_latent(image) 96 | 97 | images = model.interpolate( 98 | LatentWalkInterpolateConfig( 99 | prompt=config_interpolate.get("prompt"), 100 | latent=latent, 101 | num_inference_steps=config_interpolate.get("num_inference_steps"), 102 | interpolation_steps=config_interpolate.get("interpolation_steps"), 103 | output_type="np", 104 | ) 105 | ) 106 | 107 | assert type(images) is np.ndarray 108 | assert images.shape == (image_count, dim, dim, 3) 109 | 110 | 111 | if __name__ == "__main__": 112 | pytest.main([__file__]) 113 | -------------------------------------------------------------------------------- /examples/text_to_video_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Text To Video Diffusion\n", 9 | "\n", 10 | "In this notebook, we take a look at Text to Video Diffusion." 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### Install and Import required packages" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "%pip install stablefused ipython" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import numpy as np\n", 37 | "import torch\n", 38 | "\n", 39 | "from IPython.display import display, Video\n", 40 | "from diffusers.schedulers import DPMSolverMultistepScheduler\n", 41 | "\n", 42 | "from stablefused import TextToVideoDiffusion\n", 43 | "from stablefused.utils import pil_to_video, image_grid" 44 | ] 45 | }, 46 | { 47 | "attachments": {}, 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "### Initialize model and parameters\n", 52 | "\n", 53 | "We use Cerspense's Zeroscope v2 to initialize our Text To Video Diffusion model. Play around with different prompts and see what you get! You can comment out the seed part if you want to generate new random images each time you run the notebook.\n", 54 | "\n", 55 | "We enable slicing and tiling of the VAE to reduce memory required for decoding process from latent space to image space." 
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# model_id = \"damo-vilab/text-to-video-ms-1.7b\"\n", 65 | "model_id = \"cerspense/zeroscope_v2_576w\"\n", 66 | "\n", 67 | "# model = TextToVideoDiffusion(model_id = model_id, torch_dtype = torch.float16, variant = \"fp16\")\n", 68 | "model = TextToVideoDiffusion(model_id=model_id, torch_dtype=torch.float16)\n", 69 | "\n", 70 | "model.scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)\n", 71 | "model.enable_slicing()\n", 72 | "model.enable_tiling()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "prompt = \"An astronaut floating in space, interstellar, black background with stars, photorealistic, high quality, 8k\"\n", 82 | "negative_prompt = \"multiple people, cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive, nsfw\"\n", 83 | "num_inference_steps = 15\n", 84 | "video_frames = 24\n", 85 | "seed = 420\n", 86 | "\n", 87 | "torch.manual_seed(seed)\n", 88 | "np.random.seed(seed)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "frames = model(\n", 98 | " prompt=prompt,\n", 99 | " negative_prompt=negative_prompt,\n", 100 | " video_height=320,\n", 101 | " video_width=576,\n", 102 | " video_frames=video_frames,\n", 103 | " num_inference_steps=num_inference_steps,\n", 104 | " guidance_scale=8.0,\n", 105 | ")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "filename = \"interstellar-astronaut.mp4\"\n", 115 | "pil_to_video(frames[0], filename, fps=8)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "display(Video(filename, embed=True))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "prompt = \"A mighty pirate ship sailing through the sea, unpleasant, thundering roar, dark night, starry night, high quality, photorealistic, 8k\"\n", 134 | "seed = 42\n", 135 | "\n", 136 | "torch.manual_seed(seed)\n", 137 | "np.random.seed(seed)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "frames = model(\n", 147 | " prompt=[prompt] * 2,\n", 148 | " video_height=320,\n", 149 | " video_width=576,\n", 150 | " video_frames=video_frames,\n", 151 | " num_inference_steps=num_inference_steps,\n", 152 | " guidance_scale=12.0,\n", 153 | ")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "# Tile the frames of the two videos one above the other.\n", 163 | "frames_concatenated = []\n", 164 | "for images in zip(*frames):\n", 165 | " frames_concatenated.append(image_grid(images, rows=2, cols=1))" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "filename = \"mighty-ship.mp4\"\n", 175 | "pil_to_video(frames_concatenated, filename, fps=8)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | 
"metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "display(Video(filename, embed=True))" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "language_info": { 190 | "name": "python" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 0 195 | } 196 | -------------------------------------------------------------------------------- /examples/text_to_image_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Text to Image Diffusion\n", 9 | "\n", 10 | "In this notebook, we take a look at Text to Image Diffusion." 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### Install and Import required packages" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "!pip install stablefused ipython" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import numpy as np\n", 37 | "import torch\n", 38 | "\n", 39 | "from IPython.display import Video, display\n", 40 | "from stablefused import TextToImageDiffusion\n", 41 | "from stablefused.utils import image_grid, pil_to_video" 42 | ] 43 | }, 44 | { 45 | "attachments": {}, 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### Initialize model and parameters\n", 50 | "\n", 51 | "We use RunwayML's Stable Diffusion 1.5 checkpoint and initialize our Text To Image Diffusion model, and some other parameters. Play around with different prompts and see what you get! You can comment out the seed part if you want to generate new random images each time you run the notebook." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 61 | "model = TextToImageDiffusion(model_id=model_id)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "prompt = \"Cyberpunk cityscape with towering skyscrapers, neon signs, and flying cars.\"\n", 71 | "negative_prompt = \"cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive\"\n", 72 | "num_inference_steps = 20\n", 73 | "seed = 1337\n", 74 | "\n", 75 | "torch.manual_seed(seed)\n", 76 | "np.random.seed(seed)" 77 | ] 78 | }, 79 | { 80 | "attachments": {}, 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "You can run the stable diffusion inference using the call method `()` or the `.generate()` method. Refer to the documentation to see what parameters can be provided." 
85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "model(prompt=prompt, num_inference_steps=num_inference_steps)[0]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "### Visualizing the Diffusion process" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "prompt = [\n", 110 | " \"Gothic painting of an ancient castle at night, with a full moon, gargoyles, and shadows\",\n", 111 | " \"A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k\",\n", 112 | " \"A close up image of an very beautiful woman, aesthetic, realistic, high quality\",\n", 113 | " \"Concept art for a post-apocalyptic world with ruins, overgrown vegetation, and a lone survivor\",\n", 114 | "]" 115 | ] 116 | }, 117 | { 118 | "attachments": {}, 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "We enable attention slicing for the UNet to reduce memory requirements, which causes attention heads to be processed sequentially. We also enable slicing and tiling of the VAE to reduce memory required for decoding process from latent space to image space." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "model.enable_attention_slicing()\n", 132 | "model.enable_slicing()\n", 133 | "model.enable_tiling()" 134 | ] 135 | }, 136 | { 137 | "attachments": {}, 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "Run inference on the different text prompts. We pass `return_latent_history = True`, which returns all the latents from the denoising process in latent space. We can then decode these latents to images and create a video of the denoising process." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "images = model(\n", 151 | " prompt=prompt,\n", 152 | " negative_prompt=[negative_prompt] * len(prompt),\n", 153 | " num_inference_steps=num_inference_steps,\n", 154 | " guidance_scale=10.0,\n", 155 | " return_latent_history=True,\n", 156 | ")" 157 | ] 158 | }, 159 | { 160 | "attachments": {}, 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "We tile the images in a 2x2 grid here." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "timestep_images = []\n", 174 | "for imgs in zip(*images):\n", 175 | " img = image_grid(imgs, rows=2, cols=len(prompt) // 2)\n", 176 | " timestep_images.append(img)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "path = \"text_to_image_diffusion.mp4\"\n", 186 | "pil_to_video(timestep_images, path, fps=5)" 187 | ] 188 | }, 189 | { 190 | "attachments": {}, 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "Tada!" 
195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "display(Video(path, embed=True))" 204 | ] 205 | } 206 | ], 207 | "metadata": { 208 | "language_info": { 209 | "name": "python" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 0 214 | } 215 | -------------------------------------------------------------------------------- /stablefused/utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from diffusers.schedulers import ( 5 | DEISMultistepScheduler, 6 | DDIMScheduler, 7 | DDPMScheduler, 8 | DPMSolverSDEScheduler, 9 | DPMSolverMultistepScheduler, 10 | DPMSolverSinglestepScheduler, 11 | EulerAncestralDiscreteScheduler, 12 | EulerDiscreteScheduler, 13 | HeunDiscreteScheduler, 14 | KDPM2DiscreteScheduler, 15 | KDPM2AncestralDiscreteScheduler, 16 | LMSDiscreteScheduler, 17 | PNDMScheduler, 18 | UniPCMultistepScheduler, 19 | ) 20 | from typing import Any, Dict, Union 21 | from stablefused.typing import Scheduler, SchedulerType 22 | 23 | 24 | def lerp( 25 | v0: Union[torch.Tensor, np.ndarray], 26 | v1: Union[torch.Tensor, np.ndarray], 27 | t: Union[float, torch.Tensor, np.ndarray], 28 | ) -> Union[torch.Tensor, np.ndarray]: 29 | """ 30 | Linearly interpolate between two vectors/tensors. 31 | 32 | Parameters 33 | ---------- 34 | v0: Union[torch.Tensor, np.ndarray] 35 | First vector/tensor. 36 | v1: Union[torch.Tensor, np.ndarray] 37 | Second vector/tensor. 38 | t: Union[float, torch.Tensor, np.ndarray] 39 | Interpolation factor. If float, must be between 0 and 1. If np.ndarray or 40 | torch.Tensor, must be one dimensional with values between 0 and 1. 41 | 42 | Returns 43 | ------- 44 | Union[torch.Tensor, np.ndarray] 45 | Interpolated vector/tensor between v0 and v1. 46 | """ 47 | inputs_are_torch = False 48 | t_is_float = False 49 | 50 | if isinstance(v0, torch.Tensor): 51 | inputs_are_torch = True 52 | input_device = v0.device 53 | v0 = v0.cpu().numpy() 54 | if isinstance(v1, torch.Tensor): 55 | inputs_are_torch = True 56 | input_device = v1.device 57 | v1 = v1.cpu().numpy() 58 | if isinstance(t, torch.Tensor): 59 | inputs_are_torch = True 60 | input_device = t.device 61 | t = t.cpu().numpy() 62 | elif isinstance(t, float): 63 | t_is_float = True 64 | t = np.array([t]) 65 | 66 | t = t[..., None] 67 | v0 = v0[None, ...] 68 | v1 = v1[None, ...] 69 | v2 = (1 - t) * v0 + t * v1 70 | 71 | if t_is_float and v0.ndim > 1: 72 | assert v2.shape[0] == 1 73 | v2 = np.squeeze(v2, axis=0) 74 | if inputs_are_torch: 75 | v2 = torch.from_numpy(v2).to(input_device) 76 | 77 | return v2 78 | 79 | 80 | def slerp( 81 | v0: Union[torch.Tensor, np.ndarray], 82 | v1: Union[torch.Tensor, np.ndarray], 83 | t: Union[float, torch.Tensor, np.ndarray], 84 | DOT_THRESHOLD=0.9995, 85 | ) -> Union[torch.Tensor, np.ndarray]: 86 | """ 87 | Spherical linear interpolation between two vectors/tensors. 88 | 89 | Parameters 90 | ---------- 91 | v0: Union[torch.Tensor, np.ndarray] 92 | First vector/tensor. 93 | v1: Union[torch.Tensor, np.ndarray] 94 | Second vector/tensor. 95 | t: Union[float, np.ndarray] 96 | Interpolation factor. If float, must be between 0 and 1. If np.ndarray, must be one 97 | dimensional with values between 0 and 1. 98 | DOT_THRESHOLD: float 99 | Threshold for when to use linear interpolation instead of spherical interpolation. 
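        When the absolute dot product of the normalized inputs exceeds this
        threshold, the vectors are treated as nearly parallel and lerp is used
        instead; otherwise the standard slerp formula is applied:
            v2 = sin((1 - t) * theta) / sin(theta) * v0
               + sin(t * theta) / sin(theta) * v1
        where theta = arccos(dot(v0, v1) / (|v0| * |v1|)).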
100 | 101 | Returns 102 | ------- 103 | Union[torch.Tensor, np.ndarray] 104 | Interpolated vector/tensor between v0 and v1. 105 | """ 106 | inputs_are_torch = False 107 | t_is_float = False 108 | 109 | if isinstance(v0, torch.Tensor): 110 | inputs_are_torch = True 111 | input_device = v0.device 112 | v0 = v0.cpu().numpy() 113 | if isinstance(v1, torch.Tensor): 114 | inputs_are_torch = True 115 | input_device = v1.device 116 | v1 = v1.cpu().numpy() 117 | if isinstance(t, torch.Tensor): 118 | inputs_are_torch = True 119 | input_device = t.device 120 | t = t.cpu().numpy() 121 | elif isinstance(t, float): 122 | t_is_float = True 123 | t = np.array([t], dtype=v0.dtype) 124 | 125 | dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) 126 | if np.abs(dot) > DOT_THRESHOLD: 127 | # v1 and v2 are close to parallel 128 | # Use linear interpolation instead 129 | v2 = lerp(v0, v1, t) 130 | else: 131 | theta_0 = np.arccos(dot) 132 | sin_theta_0 = np.sin(theta_0) 133 | theta_t = theta_0 * t 134 | sin_theta_t = np.sin(theta_t) 135 | s0 = np.sin(theta_0 - theta_t) / sin_theta_0 136 | s1 = sin_theta_t / sin_theta_0 137 | s0 = s0[..., None] 138 | s1 = s1[..., None] 139 | v0 = v0[None, ...] 140 | v1 = v1[None, ...] 141 | v2 = s0 * v0 + s1 * v1 142 | 143 | if t_is_float and v0.ndim > 1: 144 | assert v2.shape[0] == 1 145 | v2 = np.squeeze(v2, axis=0) 146 | if inputs_are_torch: 147 | v2 = torch.from_numpy(v2).to(input_device) 148 | 149 | return v2 150 | 151 | 152 | def resolve_scheduler( 153 | scheduler_type: Scheduler, config: Dict[str, Any] 154 | ) -> SchedulerType: 155 | if scheduler_type == Scheduler.DEIS: 156 | return DEISMultistepScheduler.from_config(config) 157 | 158 | elif scheduler_type == Scheduler.DDIM: 159 | return DDIMScheduler.from_config(config) 160 | 161 | elif scheduler_type == Scheduler.DDPM: 162 | return DDPMScheduler.from_config(config) 163 | 164 | elif scheduler_type == Scheduler.DPM2_KARRAS: 165 | return KDPM2DiscreteScheduler.from_config(config) 166 | 167 | elif scheduler_type == Scheduler.DPM2_KARRAS_ANCESTRAL: 168 | return KDPM2AncestralDiscreteScheduler.from_config(config) 169 | 170 | elif scheduler_type == Scheduler.DPM_SDE: 171 | return DPMSolverSDEScheduler.from_config(config) 172 | 173 | elif scheduler_type == Scheduler.DPM_SDE_KARRAS: 174 | return DPMSolverSDEScheduler.from_config(config, use_karras_sigmas=True) 175 | 176 | elif scheduler_type == Scheduler.DPM_MULTISTEP: 177 | return DPMSolverMultistepScheduler.from_config(config) 178 | 179 | elif scheduler_type == Scheduler.DPM_MULTISTEP_KARRAS: 180 | return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True) 181 | 182 | elif scheduler_type == Scheduler.DPM_SINGLESTEP: 183 | return DPMSolverSinglestepScheduler.from_config(config) 184 | 185 | elif scheduler_type == Scheduler.DPM_SINGLESTEP_KARRAS: 186 | return DPMSolverSinglestepScheduler.from_config(config, use_karras_sigmas=True) 187 | 188 | elif scheduler_type == Scheduler.EULER: 189 | return EulerDiscreteScheduler.from_config(config) 190 | 191 | elif scheduler_type == Scheduler.EULER_ANCESTRAL: 192 | return EulerAncestralDiscreteScheduler.from_config(config) 193 | 194 | elif scheduler_type == Scheduler.HEUN: 195 | return HeunDiscreteScheduler.from_config(config) 196 | 197 | elif scheduler_type == Scheduler.LINEAR_MULTISTEP: 198 | return LMSDiscreteScheduler.from_config(config) 199 | 200 | elif scheduler_type == Scheduler.PNDM: 201 | return PNDMScheduler.from_config(config) 202 | 203 | elif scheduler_type == Scheduler.UNIPC: 204 | return 
UniPCMultistepScheduler.from_config(config) 205 | -------------------------------------------------------------------------------- /stablefused/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import torch 4 | 5 | from PIL import Image, ImageDraw, ImageFont 6 | from typing import List, Tuple, Union 7 | 8 | 9 | def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: 10 | """ 11 | Convert pytorch tensor to numpy image. 12 | 13 | Parameters 14 | ---------- 15 | images: torch.FloatTensor 16 | Image represented as a pytorch tensor (N, C, H, W). 17 | 18 | Returns 19 | ------- 20 | np.ndarray 21 | Image represented as a numpy array (N, H, W, C). 22 | """ 23 | return images.detach().cpu().permute(0, 2, 3, 1).float().numpy() 24 | 25 | 26 | def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: 27 | """ 28 | Convert numpy image to pytorch tensor. 29 | 30 | Parameters 31 | ---------- 32 | images: np.ndarray 33 | Image represented as a numpy array (N, H, W, C). 34 | 35 | Returns 36 | ------- 37 | torch.FloatTensor 38 | Image represented as a pytorch tensor (N, C, H, W). 39 | """ 40 | if images.ndim == 3: 41 | images = images[..., None] 42 | return torch.from_numpy(images.transpose(0, 3, 1, 2)) 43 | 44 | 45 | def numpy_to_pil(images: np.ndarray) -> Image.Image: 46 | """ 47 | Convert numpy image to PIL image. 48 | 49 | Parameters 50 | ---------- 51 | images: np.ndarray 52 | Image represented as a numpy array (N, H, W, C). 53 | 54 | Returns 55 | ------- 56 | Image.Image 57 | Image represented as a PIL image. 58 | """ 59 | if images.ndim == 3: 60 | images = images[None, ...] 61 | images = (images * 255).round().astype("uint8") 62 | if images.shape[-1] == 1: 63 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] 64 | else: 65 | pil_images = [Image.fromarray(image) for image in images] 66 | return pil_images 67 | 68 | 69 | def pil_to_numpy(images: Union[List[Image.Image], Image.Image]) -> np.ndarray: 70 | """ 71 | Convert PIL image to numpy image. 72 | 73 | Parameters 74 | ---------- 75 | images: Union[List[Image.Image], Image.Image] 76 | PIL image or list of PIL images. 77 | 78 | Returns 79 | ------- 80 | np.ndarray 81 | Image represented as a numpy array (N, H, W, C). 82 | """ 83 | if not isinstance(images, Image.Image) and not isinstance(images, list): 84 | raise ValueError( 85 | f"Expected PIL image or list of PIL images, got {type(images)}." 86 | ) 87 | if not isinstance(images, list): 88 | images = [images] 89 | images = [np.array(image).astype(np.float32) / 255.0 for image in images] 90 | images = np.stack(images, axis=0) 91 | return images 92 | 93 | 94 | def normalize(images: torch.FloatTensor) -> torch.FloatTensor: 95 | """ 96 | Normalize an image array to the range [-1, 1]. 97 | 98 | Parameters 99 | ---------- 100 | images: torch.FloatTensor 101 | Image represented as a pytorch tensor (N, C, H, W). 102 | 103 | Returns 104 | ------- 105 | torch.FloatTensor 106 | Normalized image as pytorch tensor. 107 | """ 108 | return 2.0 * images - 1.0 109 | 110 | 111 | def denormalize(images: torch.FloatTensor) -> torch.FloatTensor: 112 | """ 113 | Denormalize an image array to the range [0.0, 1.0]. 114 | 115 | Parameters 116 | ---------- 117 | images: torch.FloatTensor 118 | Image represented as a pytorch tensor (N, C, H, W). 119 | 120 | Returns 121 | ------- 122 | torch.FloatTensor 123 | Denormalized image as pytorch tensor. 
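        This is the inverse of normalize: normalize maps x to 2x - 1 and
        denormalize maps y back to y / 2 + 0.5 (clamped to [0, 1]), so
        denormalize(normalize(images)) recovers images in the [0, 1] range.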
124 | """ 125 | return (0.5 + images / 2).clamp(0, 1) 126 | 127 | 128 | def pil_to_video(images: List[Image.Image], filename: str, fps: int = 60) -> None: 129 | """ 130 | Convert a list of PIL images to a video. 131 | 132 | Parameters 133 | ---------- 134 | images: List[Image.Image] 135 | List of PIL images. 136 | filename: str 137 | Filename to save video to. 138 | fps: int 139 | Frames per second of video. 140 | """ 141 | frames = [np.array(image) for image in images] 142 | with imageio.get_writer(filename, fps=fps) as video_writer: 143 | for frame in frames: 144 | video_writer.append_data(frame) 145 | 146 | 147 | def pil_to_gif(images: List[Image.Image], filename: str, fps: int = 60) -> None: 148 | """ 149 | Convert a list of PIL images to a GIF. 150 | 151 | Parameters 152 | ---------- 153 | images: List[Image.Image] 154 | List of PIL images. 155 | filename: str 156 | Filename to save GIF to. 157 | fps: int 158 | Frames per second of GIF. 159 | """ 160 | images[0].save( 161 | filename, 162 | save_all=True, 163 | append_images=images[1:], 164 | duration=1000 // fps, 165 | loop=0, 166 | ) 167 | 168 | 169 | def image_grid(images: List[Image.Image], rows: int, cols: int) -> Image.Image: 170 | """ 171 | Create a grid of images on a single PIL image. 172 | 173 | Parameters 174 | ---------- 175 | images: List[Image.Image] 176 | List of PIL images. 177 | rows: int 178 | Number of rows in grid. 179 | cols: int 180 | Number of columns in grid. 181 | 182 | Returns 183 | ------- 184 | Image.Image 185 | Grid of images as a PIL image. 186 | """ 187 | if len(images) > rows * cols: 188 | raise ValueError( 189 | f"Number of images ({len(images)}) exceeds grid size ({rows}x{cols})." 190 | ) 191 | w, h = images[0].size 192 | grid = Image.new("RGB", size=(cols * w, rows * h)) 193 | for i, image in enumerate(images): 194 | grid.paste(image, box=(i % cols * w, i // cols * h)) 195 | return grid 196 | 197 | 198 | def write_text_on_image( 199 | image: Image.Image, 200 | text: str, 201 | fontfile: str = "arial.ttf", 202 | fontsize: int = 30, 203 | padding: Union[int, Tuple[int, int]] = 10, 204 | ) -> Image.Image: 205 | if isinstance(padding, int): 206 | padding = (padding, padding) 207 | 208 | try: 209 | font = ImageFont.truetype(fontfile, size=fontsize) 210 | except IOError: 211 | font = ImageFont.load_default() 212 | 213 | image = image.copy() 214 | image_width, image_height = image.size 215 | max_text_width = image_width - padding[0] * 2 216 | 217 | draw = ImageDraw.Draw(image) 218 | text_width, text_height = draw.textsize(text, font) 219 | x = (image_width - text_width) / 2 220 | y = image_height - text_height - padding[1] 221 | 222 | lines = [] 223 | words = text.split() 224 | current_line = "" 225 | for word in words: 226 | test_line = current_line + " " + word if current_line else word 227 | test_width, _ = draw.textsize(test_line, font) 228 | if test_width <= max_text_width: 229 | current_line = test_line 230 | else: 231 | lines.append(current_line) 232 | current_line = word 233 | if current_line: 234 | lines.append(current_line) 235 | 236 | y_start = y 237 | for line in lines: 238 | text_width, text_height = draw.textsize(line, font) 239 | x = (image_width - text_width) / 2 240 | draw.text((x, y_start), line, fill="white", font=font) 241 | y_start += text_height 242 | 243 | return image 244 | -------------------------------------------------------------------------------- /stablefused/diffusion/image_to_image_diffusion.py: -------------------------------------------------------------------------------- 
1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | from dataclasses import dataclass 6 | from diffusers import AutoencoderKL 7 | from tqdm.auto import tqdm 8 | from transformers import CLIPTextModel, CLIPTokenizer 9 | from typing import List, Optional, Union 10 | 11 | from stablefused.diffusion import BaseDiffusion 12 | from stablefused.typing import PromptType, OutputType, SchedulerType, UNetType 13 | 14 | 15 | @dataclass 16 | class ImageToImageConfig: 17 | """ 18 | Configuration class for running inference with ImageToImageDiffusion. 19 | 20 | Parameters 21 | ---------- 22 | image: Image.Image 23 | Input image to condition on. 24 | prompt: PromptType 25 | Text prompt to condition on. 26 | num_inference_steps: int 27 | Number of diffusion steps to run. 28 | start_step: int 29 | Step to start diffusion from. The higher the value, the more similar the generated 30 | image will be to the input image. 31 | guidance_scale: float 32 | Guidance scale encourages the model to generate images following the prompt 33 | closely, albeit at the cost of image quality. 34 | guidance_rescale: float 35 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 36 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 37 | negative_prompt: Optional[PromptType] 38 | Negative text prompt to uncondition on. 39 | output_type: str 40 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 41 | return_latent_history: bool 42 | Whether to return the latent history. If True, return list of all latents 43 | generated during diffusion steps. 44 | """ 45 | 46 | image: Image.Image 47 | prompt: PromptType 48 | num_inference_steps: int = 50 49 | start_step: int = 0 50 | guidance_scale: float = 7.5 51 | guidance_rescale: float = 0.7 52 | negative_prompt: Optional[PromptType] = None 53 | output_type: str = "pil" 54 | return_latent_history: bool = False 55 | 56 | 57 | class ImageToImageDiffusion(BaseDiffusion): 58 | def __init__( 59 | self, 60 | model_id: str = None, 61 | tokenizer: CLIPTokenizer = None, 62 | text_encoder: CLIPTextModel = None, 63 | vae: AutoencoderKL = None, 64 | unet: UNetType = None, 65 | scheduler: SchedulerType = None, 66 | torch_dtype: torch.dtype = torch.float32, 67 | device="cuda", 68 | *args, 69 | **kwargs 70 | ) -> None: 71 | super().__init__( 72 | model_id=model_id, 73 | tokenizer=tokenizer, 74 | text_encoder=text_encoder, 75 | vae=vae, 76 | unet=unet, 77 | scheduler=scheduler, 78 | torch_dtype=torch_dtype, 79 | device=device, 80 | *args, 81 | **kwargs 82 | ) 83 | 84 | def embedding_to_latent( 85 | self, 86 | embedding: torch.FloatTensor, 87 | num_inference_steps: int, 88 | start_step: int, 89 | guidance_scale: float, 90 | guidance_rescale: float, 91 | latent: torch.FloatTensor, 92 | return_latent_history: bool = False, 93 | ) -> Union[torch.FloatTensor, List[torch.FloatTensor]]: 94 | """ 95 | Generate latent by conditioning on prompt embedding and input image using diffusion. 96 | 97 | Parameters 98 | ---------- 99 | embedding: torch.FloatTensor 100 | Embedding of text prompt. 101 | num_inference_steps: int 102 | Number of diffusion steps to run. 103 | start_step: int 104 | Step to start diffusion from. The higher the value, the more similar the generated 105 | image will be to the input image. 106 | guidance_scale: float 107 | Guidance scale encourages the model to generate images following the prompt 108 | closely, albeit at the cost of image quality. 
109 | guidance_rescale: float 110 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 111 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 112 | latent: torch.FloatTensor 113 | Latent to start diffusion from. 114 | return_latent_history: bool 115 | Whether to return the latent history. If True, return list of all latents 116 | generated during diffusion steps. 117 | 118 | Returns 119 | ------- 120 | Union[torch.FloatTensor, List[torch.FloatTensor]] 121 | Latent generated by diffusion. If return_latent_history is True, return list of 122 | all latents generated during diffusion steps. 123 | """ 124 | 125 | latent = latent.to(self.device) 126 | 127 | # Set number of inference steps 128 | self.scheduler.set_timesteps(num_inference_steps) 129 | 130 | # Add noise to latent based on start step 131 | start_timestep = ( 132 | self.scheduler.timesteps[start_step].repeat(latent.shape[0]).long() 133 | ) 134 | noise = self.random_tensor(latent.shape) 135 | latent = self.scheduler.add_noise(latent, noise, start_timestep) 136 | 137 | timesteps = self.scheduler.timesteps[start_step:] 138 | latent_history = [latent] 139 | 140 | # Diffusion inference loop 141 | for i, timestep in tqdm(list(enumerate(timesteps))): 142 | # Duplicate latent to avoid two forward passes to perform classifier free guidance 143 | latent_model_input = torch.cat([latent] * 2) 144 | latent_model_input = self.scheduler.scale_model_input( 145 | latent_model_input, timestep 146 | ) 147 | 148 | # Predict noise 149 | noise_prediction = self.unet( 150 | latent_model_input, 151 | timestep, 152 | encoder_hidden_states=embedding, 153 | return_dict=False, 154 | )[0] 155 | 156 | # Perform classifier free guidance 157 | noise_prediction = self.classifier_free_guidance( 158 | noise_prediction, guidance_scale, guidance_rescale 159 | ) 160 | 161 | # Update latent 162 | latent = self.scheduler.step( 163 | noise_prediction, timestep, latent, return_dict=False 164 | )[0] 165 | 166 | if return_latent_history: 167 | latent_history.append(latent) 168 | 169 | return torch.stack(latent_history) if return_latent_history else latent 170 | 171 | @torch.no_grad() 172 | def __call__(self, config: ImageToImageConfig) -> OutputType: 173 | """ 174 | Run inference by conditioning on input image and text prompt. 175 | 176 | Parameters 177 | ---------- 178 | config: ImageToImageConfig 179 | Configuration for running inference with ImageToImageDiffusion. 180 | 181 | Returns 182 | ------- 183 | OutputType 184 | Generated output based on output_type. 
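        Examples
        --------
        A minimal sketch. The checkpoint id, file names and prompt below are
        placeholders, and a CUDA device (the constructor default) is assumed.

        >>> from PIL import Image
        >>> from stablefused.diffusion import ImageToImageConfig, ImageToImageDiffusion
        >>> model = ImageToImageDiffusion(model_id="runwayml/stable-diffusion-v1-5")
        >>> config = ImageToImageConfig(
        ...     image=Image.open("input.png").convert("RGB").resize((512, 512)),
        ...     prompt="a watercolor painting of a mountain village",
        ...     start_step=10,
        ... )
        >>> images = model(config)
        >>> images[0].save("output.png")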
185 | """ 186 | 187 | image = config.image 188 | prompt = config.prompt 189 | num_inference_steps = config.num_inference_steps 190 | start_step = config.start_step 191 | guidance_scale = config.guidance_scale 192 | guidance_rescale = config.guidance_rescale 193 | negative_prompt = config.negative_prompt 194 | output_type = config.output_type 195 | return_latent_history = config.return_latent_history 196 | 197 | # Validate input 198 | self.validate_input( 199 | prompt=prompt, 200 | negative_prompt=negative_prompt, 201 | start_step=start_step, 202 | num_inference_steps=num_inference_steps, 203 | ) 204 | 205 | # Generate embedding to condition on prompt and uncondition on negative prompt 206 | embedding = self.prompt_to_embedding( 207 | prompt=prompt, 208 | negative_prompt=negative_prompt, 209 | ) 210 | 211 | # Generate latent from input image 212 | image_latent = self.image_to_latent(image) 213 | 214 | # Run inference 215 | latent = self.embedding_to_latent( 216 | embedding=embedding, 217 | num_inference_steps=num_inference_steps, 218 | start_step=start_step, 219 | guidance_scale=guidance_scale, 220 | guidance_rescale=guidance_rescale, 221 | latent=image_latent, 222 | return_latent_history=return_latent_history, 223 | ) 224 | 225 | return self.resolve_output( 226 | latent=latent, 227 | output_type=output_type, 228 | return_latent_history=return_latent_history, 229 | ) 230 | 231 | generate = __call__ 232 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/storybook.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | 5 | from dataclasses import dataclass 6 | from dataclasses_json import DataClassJsonMixin 7 | from moviepy.editor import ( 8 | AudioFileClip, 9 | CompositeAudioClip, 10 | CompositeVideoClip, 11 | ImageClip, 12 | concatenate_audioclips, 13 | concatenate_videoclips, 14 | ) 15 | from typing import Dict, List, Optional, Union 16 | 17 | from stablefused import ( 18 | TextToImageConfig, 19 | TextToImageDiffusion, 20 | LatentWalkInterpolateConfig, 21 | LatentWalkDiffusion, 22 | ) 23 | from stablefused.apps.storybook import StoryBookAuthorBase, StoryBookSpeakerBase 24 | from stablefused.utils import write_text_on_image 25 | 26 | 27 | @dataclass 28 | class StoryBookConfig(DataClassJsonMixin): 29 | """ 30 | Configuration class for running inference with StoryBook. 
31 | """ 32 | 33 | prompt: str 34 | artist_config: Union[TextToImageConfig, LatentWalkInterpolateConfig] 35 | artist_attributes: str = "" 36 | messages: List[Dict[str, str]] = None 37 | display_captions: bool = True 38 | caption_fontsize: int = 30 39 | caption_fontfile: str = "arial.ttf" 40 | caption_padding: int = 10 41 | frame_duration: int = 1 42 | num_retries: int = 3 43 | output_filename: str = "output.mp4" 44 | 45 | 46 | class StoryBook: 47 | def __init__( 48 | self, 49 | author: StoryBookAuthorBase, 50 | artist: Union[TextToImageDiffusion, LatentWalkDiffusion], 51 | speaker: Optional[StoryBookSpeakerBase] = None, 52 | ) -> None: 53 | self.author = author 54 | self.artist = artist 55 | self.speaker = speaker 56 | 57 | def _process_config(self, config: Union[str, Dict[str, str]]) -> None: 58 | if isinstance(config, str): 59 | module_dir = os.path.dirname(os.path.abspath(__file__)) 60 | config_path = os.path.join(module_dir, config) 61 | 62 | with open(config_path, "r") as f: 63 | config = json.load(f) 64 | 65 | self.artist_call_kwargs = { 66 | k: v for k, v in config.items() if k in self._artist_call_attributes 67 | } 68 | self.negative_prompt = self.artist_call_kwargs.pop("negative_prompt", None) 69 | self.messages: List[Dict[str, str]] = config.get("messages") 70 | self.attributes = config.get("attributes", "") 71 | 72 | def create_prompt(self, role: str, content: str) -> Dict[str, str]: 73 | return {"role": role, "content": content} 74 | 75 | def validate_output(self, output: str) -> List[Dict[str, str]]: 76 | try: 77 | output_list = json.loads(output) 78 | except json.JSONDecodeError as e: 79 | raise ValueError(f"Output is not a valid JSON: {str(e)}") 80 | 81 | if not isinstance(output_list, list): 82 | raise ValueError("Output must be a list of dictionaries.") 83 | 84 | for item in output_list: 85 | if not isinstance(item, dict): 86 | raise ValueError("Each item in the list must be a dictionary.") 87 | 88 | if "story" not in item or "prompt" not in item: 89 | raise ValueError( 90 | "Each dictionary must contain 'story' and 'prompt' keys." 91 | ) 92 | 93 | if not isinstance(item["story"], str) or not isinstance( 94 | item["prompt"], str 95 | ): 96 | raise ValueError("'story' and 'prompt' values must be strings.") 97 | 98 | return output_list 99 | 100 | def __call__( 101 | self, 102 | config: StoryBookConfig, 103 | ) -> None: 104 | print(config.to_json(indent=2)) 105 | prompt = config.prompt 106 | artist_config = config.artist_config 107 | artist_attributes = config.artist_attributes 108 | messages = config.messages 109 | display_captions = config.display_captions 110 | caption_fontsize = config.caption_fontsize 111 | caption_fontfile = config.caption_fontfile 112 | caption_padding = config.caption_padding 113 | frame_duration = config.frame_duration 114 | num_retries = config.num_retries 115 | output_filename = config.output_filename 116 | 117 | if ( 118 | isinstance(artist_config, TextToImageConfig) 119 | and isinstance(self.artist, LatentWalkDiffusion) 120 | ) or ( 121 | isinstance(artist_config, LatentWalkInterpolateConfig) 122 | and isinstance(self.artist, TextToImageDiffusion) 123 | ): 124 | raise ValueError( 125 | "Artist is not compatible with the provided artist config." 
126 | ) 127 | 128 | messages.append(self.create_prompt("user", prompt)) 129 | 130 | for i in range(num_retries): 131 | try: 132 | storybook = self.author(messages) 133 | storybook = self.validate_output(storybook) 134 | break 135 | except Exception as e: 136 | print(f"Error: {str(e)}") 137 | print(f"Retrying ({i + 1}/{num_retries})...") 138 | continue 139 | else: 140 | raise Exception("Failed to generate storybook. Please try again.") 141 | 142 | prompt = [f"{item['prompt']}, {artist_attributes}" for item in storybook] 143 | artist_config.prompt = prompt 144 | artist_config.negative_prompt = ( 145 | [artist_config.negative_prompt] * len(storybook) 146 | if artist_config.negative_prompt is not None 147 | else None 148 | ) 149 | 150 | if isinstance(self.artist, TextToImageDiffusion): 151 | images = self.artist(artist_config) 152 | else: 153 | images = self.artist.interpolate(artist_config) 154 | 155 | if display_captions: 156 | if isinstance(self.artist, TextToImageDiffusion): 157 | images = [ 158 | write_text_on_image( 159 | image, 160 | storypart.get("story"), 161 | fontfile=caption_fontfile, 162 | fontsize=caption_fontsize, 163 | padding=caption_padding, 164 | ) 165 | for image, storypart in zip(images, storybook) 166 | ] 167 | else: 168 | num_frames_per_prompt = config.artist_config.interpolation_steps 169 | image_list = [] 170 | for i in range(0, len(images), num_frames_per_prompt): 171 | current_images = images[i : i + num_frames_per_prompt] 172 | current_story = storybook[i // num_frames_per_prompt] 173 | image_list.extend( 174 | [ 175 | write_text_on_image( 176 | image, 177 | current_story.get("story"), 178 | fontfile=caption_fontfile, 179 | fontsize=caption_fontsize, 180 | padding=caption_padding, 181 | ) 182 | for image in current_images 183 | ] 184 | ) 185 | images = image_list 186 | 187 | if self.speaker is not None: 188 | stories = [item.get("story") for item in storybook] 189 | audioclips = [] 190 | 191 | for audiofile in self.speaker(stories, yield_files=True): 192 | audioclips.append(AudioFileClip(audiofile)) 193 | 194 | audioclip: CompositeAudioClip = concatenate_audioclips(audioclips) 195 | frame_duration = audioclip.duration / len(prompt) 196 | else: 197 | audioclip = None 198 | 199 | if isinstance(self.artist, TextToImageDiffusion): 200 | video = [ 201 | ImageClip(np.array(image), duration=frame_duration) for image in images 202 | ] 203 | else: 204 | num_frames_per_prompt = config.artist_config.interpolation_steps 205 | video = [] 206 | for i in range(0, len(images), num_frames_per_prompt): 207 | current_images = images[i : i + num_frames_per_prompt] 208 | current_clips = [ 209 | ImageClip(np.array(image), duration=frame_duration / num_frames_per_prompt) 210 | for image in current_images 211 | ] 212 | video.append(concatenate_videoclips(current_clips)) 213 | 214 | video: CompositeVideoClip = concatenate_videoclips(video) 215 | video = video.set_audio(audioclip) 216 | video = video.set_fps(60) 217 | video.write_videofile(output_filename) 218 | 219 | return storybook 220 | -------------------------------------------------------------------------------- /stablefused/diffusion/text_to_image_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dataclasses import dataclass 4 | from diffusers import AutoencoderKL 5 | from tqdm.auto import tqdm 6 | from transformers import CLIPTextModel, CLIPTokenizer 7 | from typing import List, Optional, Union 8 | 9 | from stablefused.diffusion import BaseDiffusion 10 | 
from stablefused.typing import PromptType, OutputType, SchedulerType, UNetType 11 | 12 | 13 | @dataclass 14 | class TextToImageConfig: 15 | """ 16 | Configuration class for running inference with TextToImageDiffusion. 17 | 18 | Parameters 19 | ---------- 20 | prompt: PromptType 21 | Text prompt to condition on. 22 | image_height: int 23 | Height of image to generate. 24 | image_width: int 25 | Width of image to generate. 26 | num_inference_steps: int 27 | Number of diffusion steps to run. 28 | guidance_scale: float 29 | Guidance scale encourages the model to generate images following the prompt 30 | closely, albeit at the cost of image quality. 31 | guidance_rescale: float 32 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 33 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 34 | negative_prompt: Optional[PromptType] 35 | Negative text prompt to uncondition on. 36 | latent: Optional[torch.FloatTensor] 37 | Latent to start from. If None, latent is generated from noise. 38 | output_type: str 39 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 40 | return_latent_history: bool 41 | Whether to return the latent history. If True, return list of all latents 42 | generated during diffusion steps. 43 | """ 44 | 45 | prompt: PromptType = "" 46 | image_height: int = 512 47 | image_width: int = 512 48 | num_inference_steps: int = 50 49 | guidance_scale: float = 7.5 50 | guidance_rescale: float = 0.7 51 | negative_prompt: Optional[PromptType] = None 52 | latent: Optional[torch.FloatTensor] = None 53 | output_type: str = "pil" 54 | return_latent_history: bool = False 55 | 56 | 57 | class TextToImageDiffusion(BaseDiffusion): 58 | def __init__( 59 | self, 60 | model_id: str = None, 61 | tokenizer: CLIPTokenizer = None, 62 | text_encoder: CLIPTextModel = None, 63 | vae: AutoencoderKL = None, 64 | unet: UNetType = None, 65 | scheduler: SchedulerType = None, 66 | torch_dtype: torch.dtype = torch.float32, 67 | device="cuda", 68 | *args, 69 | **kwargs 70 | ) -> None: 71 | super().__init__( 72 | model_id=model_id, 73 | tokenizer=tokenizer, 74 | text_encoder=text_encoder, 75 | vae=vae, 76 | unet=unet, 77 | scheduler=scheduler, 78 | torch_dtype=torch_dtype, 79 | device=device, 80 | *args, 81 | **kwargs 82 | ) 83 | 84 | def embedding_to_latent( 85 | self, 86 | embedding: torch.FloatTensor, 87 | image_height: int, 88 | image_width: int, 89 | num_inference_steps: int, 90 | guidance_scale: float, 91 | guidance_rescale: float, 92 | latent: Optional[torch.FloatTensor] = None, 93 | return_latent_history: bool = False, 94 | ) -> Union[torch.FloatTensor, List[torch.FloatTensor]]: 95 | """ 96 | Generate latent by conditioning on prompt embedding using diffusion. 97 | 98 | Parameters 99 | ---------- 100 | embedding: torch.FloatTensor 101 | Embedding of text prompt. 102 | image_height: int 103 | Height of image to generate. 104 | image_width: int 105 | Width of image to generate. 106 | num_inference_steps: int 107 | Number of diffusion steps to run. 108 | guidance_scale: float 109 | Guidance scale encourages the model to generate images following the prompt 110 | closely, albeit at the cost of image quality. 111 | guidance_rescale: float 112 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 113 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 114 | latent: Optional[torch.FloatTensor] 115 | Latent to start from. If None, generate latent from noise. 116 | return_latent_history: bool 117 | Whether to return latent history. 
If True, return list of all latents 118 | generated during diffusion steps. 119 | 120 | Returns 121 | ------- 122 | Union[torch.FloatTensor, List[torch.FloatTensor]] 123 | Latent generated by diffusion. If return_latent_history is True, return 124 | list of all latents generated during diffusion steps. 125 | """ 126 | 127 | # Generate latent from noise if not provided 128 | if latent is None: 129 | shape = ( 130 | embedding.shape[0] // 2, 131 | self.unet.config.in_channels, 132 | image_height // self.vae_scale_factor, 133 | image_width // self.vae_scale_factor, 134 | ) 135 | latent = self.random_tensor(shape) 136 | latent = latent.to(self.device) 137 | 138 | # Set number of inference steps 139 | self.scheduler.set_timesteps(num_inference_steps) 140 | timesteps = self.scheduler.timesteps 141 | 142 | # Scale the latent noise by the standard deviation required by the scheduler 143 | latent = latent * self.scheduler.init_noise_sigma 144 | latent_history = [latent] 145 | 146 | # Diffusion inference loop 147 | for i, timestep in tqdm(list(enumerate(timesteps))): 148 | # Duplicate latent to avoid two forward passes to perform classifier free guidance 149 | latent_model_input = torch.cat([latent] * 2) 150 | latent_model_input = self.scheduler.scale_model_input( 151 | latent_model_input, timestep 152 | ) 153 | 154 | # Predict noise 155 | noise_prediction = self.unet( 156 | latent_model_input, 157 | timestep, 158 | encoder_hidden_states=embedding, 159 | return_dict=False, 160 | )[0] 161 | 162 | # Perform classifier free guidance 163 | noise_prediction = self.classifier_free_guidance( 164 | noise_prediction, guidance_scale, guidance_rescale 165 | ) 166 | 167 | # Update latent 168 | latent = self.scheduler.step( 169 | noise_prediction, timestep, latent, return_dict=False 170 | )[0] 171 | 172 | if return_latent_history: 173 | latent_history.append(latent) 174 | 175 | return torch.stack(latent_history) if return_latent_history else latent 176 | 177 | @torch.no_grad() 178 | def __call__( 179 | self, 180 | config: TextToImageConfig, 181 | ) -> OutputType: 182 | """ 183 | Run inference by conditioning on text prompt. 184 | 185 | Parameters 186 | ---------- 187 | config: TextToImageConfig 188 | Configuration for running inference with TextToImageDiffusion. 189 | 190 | Returns 191 | ------- 192 | OutputType 193 | Generated output based on output_type. 
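        Examples
        --------
        A minimal sketch; the prompt is a placeholder and a CUDA device (the
        constructor default) is assumed.

        >>> from stablefused import TextToImageConfig, TextToImageDiffusion
        >>> model = TextToImageDiffusion(model_id="runwayml/stable-diffusion-v1-5")
        >>> config = TextToImageConfig(
        ...     prompt="a lighthouse on a stormy coast, oil painting",
        ...     num_inference_steps=25,
        ... )
        >>> images = model(config)
        >>> images[0].save("lighthouse.png")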
194 | """ 195 | 196 | prompt = config.prompt 197 | image_height = config.image_height 198 | image_width = config.image_width 199 | num_inference_steps = config.num_inference_steps 200 | guidance_scale = config.guidance_scale 201 | guidance_rescale = config.guidance_rescale 202 | negative_prompt = config.negative_prompt 203 | latent = config.latent 204 | output_type = config.output_type 205 | return_latent_history = config.return_latent_history 206 | 207 | # Validate input 208 | self.validate_input( 209 | prompt=prompt, 210 | negative_prompt=negative_prompt, 211 | image_height=image_height, 212 | image_width=image_width, 213 | ) 214 | 215 | # Generate embedding to condition on prompt and uncondition on negative prompt 216 | embedding = self.prompt_to_embedding( 217 | prompt=prompt, 218 | negative_prompt=negative_prompt, 219 | ) 220 | 221 | # Run inference 222 | latent = self.embedding_to_latent( 223 | embedding=embedding, 224 | image_height=image_height, 225 | image_width=image_width, 226 | num_inference_steps=num_inference_steps, 227 | guidance_scale=guidance_scale, 228 | guidance_rescale=guidance_rescale, 229 | latent=latent, 230 | return_latent_history=return_latent_history, 231 | ) 232 | 233 | return self.resolve_output( 234 | latent=latent, 235 | output_type=output_type, 236 | return_latent_history=return_latent_history, 237 | ) 238 | 239 | generate = __call__ 240 | -------------------------------------------------------------------------------- /stablefused/diffusion/text_to_video_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from dataclasses import dataclass 5 | from diffusers import AutoencoderKL 6 | from tqdm.auto import tqdm 7 | from transformers import CLIPTextModel, CLIPTokenizer 8 | from typing import List, Optional, Union 9 | 10 | from stablefused.diffusion import BaseDiffusion 11 | from stablefused.typing import PromptType, OutputType, SchedulerType, UNetType 12 | 13 | 14 | @dataclass 15 | class TextToVideoConfig: 16 | """ 17 | Configuration class for running inference with TextToVideoDiffusion. 18 | 19 | Parameters 20 | ---------- 21 | prompt: PromptType 22 | Text prompt to condition on. 23 | video_height: int 24 | Height of video to generate. 25 | video_width: int 26 | Width of video to generate. 27 | video_frames: int 28 | Number of frames to generate in video. 29 | num_inference_steps: int 30 | Number of diffusion steps to run. 31 | guidance_scale: float 32 | Guidance scale encourages the model to generate images following the prompt 33 | closely, albeit at the cost of image quality. 34 | guidance_rescale: float 35 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 36 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 37 | negative_prompt: Optional[PromptType] 38 | Negative text prompt to uncondition on. 39 | latent: Optional[torch.FloatTensor] 40 | Latent to start from. If None, latent is generated from noise. 41 | output_type: str 42 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 43 | decode_batch_size: int 44 | Batch size to use when decoding latent to image. 
45 | """ 46 | 47 | prompt: PromptType 48 | video_height: int = 512 49 | video_width: int = 512 50 | video_frames: int = 24 51 | num_inference_steps: int = 50 52 | guidance_scale: float = 7.5 53 | guidance_rescale: float = 0.7 54 | negative_prompt: Optional[PromptType] = None 55 | latent: Optional[torch.FloatTensor] = None 56 | output_type: str = "pil" 57 | decode_batch_size: int = 4 58 | 59 | 60 | class TextToVideoDiffusion(BaseDiffusion): 61 | def __init__( 62 | self, 63 | model_id: str = None, 64 | tokenizer: CLIPTokenizer = None, 65 | text_encoder: CLIPTextModel = None, 66 | vae: AutoencoderKL = None, 67 | unet: UNetType = None, 68 | scheduler: SchedulerType = None, 69 | torch_dtype: torch.dtype = torch.float32, 70 | device="cuda", 71 | *args, 72 | **kwargs 73 | ) -> None: 74 | super().__init__( 75 | model_id=model_id, 76 | tokenizer=tokenizer, 77 | text_encoder=text_encoder, 78 | vae=vae, 79 | unet=unet, 80 | scheduler=scheduler, 81 | torch_dtype=torch_dtype, 82 | device=device, 83 | *args, 84 | **kwargs 85 | ) 86 | 87 | def embedding_to_latent( 88 | self, 89 | embedding: torch.FloatTensor, 90 | video_height: int, 91 | video_width: int, 92 | video_frames: int, 93 | num_inference_steps: int, 94 | guidance_scale: float, 95 | guidance_rescale: float, 96 | latent: Optional[torch.FloatTensor] = None, 97 | ) -> Union[torch.FloatTensor, List[torch.FloatTensor]]: 98 | """ 99 | Generate latent by conditioning on prompt embedding using diffusion. 100 | 101 | Parameters 102 | ---------- 103 | embedding: torch.FloatTensor 104 | Embedding of text prompt. 105 | video_height: int 106 | Height of video to generate. 107 | video_width: int 108 | Width of video to generate. 109 | video_frames: int 110 | Number of frames to generate in video. 111 | num_inference_steps: int 112 | Number of diffusion steps to run. 113 | guidance_scale: float 114 | Guidance scale encourages the model to generate images following the prompt 115 | closely, albeit at the cost of image quality. 116 | guidance_rescale: float 117 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 118 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 119 | latent: Optional[torch.FloatTensor] 120 | Latent to start from. If None, generate latent from noise. 121 | 122 | Returns 123 | ------- 124 | Union[torch.FloatTensor, List[torch.FloatTensor]] 125 | Latent generated by diffusion. 
126 | """ 127 | 128 | # Generate latent from noise if not provided 129 | if latent is None: 130 | shape = ( 131 | embedding.shape[0] // 2, 132 | self.unet.config.in_channels, 133 | video_frames, 134 | video_height // self.vae_scale_factor, 135 | video_width // self.vae_scale_factor, 136 | ) 137 | latent = self.random_tensor(shape) 138 | 139 | # Set number of inference steps 140 | self.scheduler.set_timesteps(num_inference_steps) 141 | timesteps = self.scheduler.timesteps 142 | 143 | # Scale the latent noise by the standard deviation required by the scheduler 144 | latent = latent * self.scheduler.init_noise_sigma 145 | 146 | # Diffusion inference loop 147 | for i, timestep in tqdm(list(enumerate(timesteps))): 148 | # Duplicate latent to avoid two forward passes to perform classifier free guidance 149 | latent_model_input = torch.cat([latent] * 2) 150 | latent_model_input = self.scheduler.scale_model_input( 151 | latent_model_input, timestep 152 | ) 153 | 154 | # Predict noise 155 | noise_prediction = self.unet( 156 | latent_model_input, 157 | timestep, 158 | encoder_hidden_states=embedding, 159 | return_dict=False, 160 | )[0] 161 | 162 | # Perform classifier free guidance 163 | noise_prediction = self.classifier_free_guidance( 164 | noise_prediction, guidance_scale, guidance_rescale 165 | ) 166 | 167 | # Update latent 168 | latent = self.scheduler.step( 169 | noise_prediction, timestep, latent, return_dict=False 170 | )[0] 171 | 172 | return latent 173 | 174 | def resolve_output( 175 | self, 176 | latent: torch.FloatTensor, 177 | output_type: str, 178 | decode_batch_size: int, 179 | ) -> OutputType: 180 | """ 181 | Resolve output type from latent. 182 | 183 | Parameters 184 | ---------- 185 | latent: torch.FloatTensor 186 | Latent to resolve output from. 187 | output_type: str 188 | Output type to resolve. Must be one of [`latent`, `pt`, `np`, `pil`]. 189 | decode_batch_size: int 190 | Batch size to use when decoding latent to image. 191 | 192 | Returns 193 | ------- 194 | OutputType 195 | The resolved output based on the provided latent vector and options. 196 | """ 197 | 198 | if output_type not in ["latent", "pt", "np", "pil"]: 199 | raise ValueError( 200 | "`output_type` must be one of [`latent`, `pt`, `np`, `pil`]" 201 | ) 202 | 203 | if output_type == "latent": 204 | return latent 205 | 206 | # B, C, F, H, W => B, F, C, H, W 207 | latent = latent.permute(0, 2, 1, 3, 4) 208 | video = [] 209 | 210 | for i in tqdm(range(latent.shape[0])): 211 | batched_output = [] 212 | for j in tqdm(range(0, latent.shape[1], decode_batch_size)): 213 | current_latent = latent[i, j : j + decode_batch_size] 214 | batched_output.extend(self.latent_to_image(current_latent, output_type)) 215 | video.append(batched_output) 216 | 217 | if output_type == "pt": 218 | video = torch.stack(video) 219 | elif output_type == "np": 220 | video = np.stack(video) 221 | 222 | return video 223 | 224 | @torch.no_grad() 225 | def __call__( 226 | self, 227 | config: TextToVideoConfig, 228 | ) -> OutputType: 229 | """ 230 | Run inference by conditioning on text prompt. 231 | 232 | Parameters 233 | ---------- 234 | config: TextToVideoConfig 235 | Configuration for running inference with TextToVideoDiffusion. 236 | 237 | Returns 238 | ------- 239 | OutputType 240 | Generated output based on output_type. 
241 | """ 242 | 243 | prompt = config.prompt 244 | video_height = config.video_height 245 | video_width = config.video_width 246 | video_frames = config.video_frames 247 | num_inference_steps = config.num_inference_steps 248 | guidance_scale = config.guidance_scale 249 | guidance_rescale = config.guidance_rescale 250 | negative_prompt = config.negative_prompt 251 | latent = config.latent 252 | output_type = config.output_type 253 | decode_batch_size = config.decode_batch_size 254 | 255 | # Validate input 256 | self.validate_input( 257 | prompt=prompt, 258 | negative_prompt=negative_prompt, 259 | image_height=video_height, 260 | image_width=video_width, 261 | num_inference_steps=num_inference_steps, 262 | ) 263 | 264 | # Generate embedding to condition on prompt and uncondition on negative prompt 265 | embedding = self.prompt_to_embedding( 266 | prompt=prompt, 267 | negative_prompt=negative_prompt, 268 | ) 269 | 270 | # Run inference 271 | latent = self.embedding_to_latent( 272 | embedding=embedding, 273 | video_height=video_height, 274 | video_width=video_width, 275 | video_frames=video_frames, 276 | num_inference_steps=num_inference_steps, 277 | guidance_scale=guidance_scale, 278 | guidance_rescale=guidance_rescale, 279 | latent=latent, 280 | ) 281 | 282 | return self.resolve_output( 283 | latent=latent, 284 | output_type=output_type, 285 | decode_batch_size=decode_batch_size, 286 | ) 287 | 288 | generate = __call__ 289 | -------------------------------------------------------------------------------- /examples/effect_of_guidance_scale.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Effect of Guidance Scale\n", 9 | "\n", 10 | "In this notebook, we take a look at how the guidance scale affects the image quality of the model.\n", 11 | "\n", 12 | "Guidance scale is a value inspired by the paper [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). The explanation of how CFG works is out-of-scope here, but there are many online sources where you can read about it (linked below).\n", 13 | "\n", 14 | "- [Guidance: a cheat for diffusion models](https://sander.ai/2022/05/26/guidance.html)\n", 15 | "- [Diffusion Models, DDPMs, DDIMs and CFG](https://betterprogramming.pub/diffusion-models-ddpms-ddims-and-classifier-free-guidance-e07b297b2869)\n", 16 | "- [Classifier-Free Guidance Scale](https://mccormickml.com/2023/02/20/classifier-free-guidance-scale/)\n", 17 | "\n", 18 | "In short, guidance scale is a value that controls the amount of \"guidance\" used in the diffusion process. That is, the higher the value, the more closely the diffusion process follows the prompt. A lower guidance scale allows the model to be more creative, and work slightly different from the exact prompt. After a certain threshold maximum value, the results start to get worse, blurry and noisy.\n", 19 | "\n", 20 | "Guidance scale values, in practice, are usually in the range 6-15, and the default value of 7.5 is used in many inference implementations. However, manipulating it can lead to some very interesting results. It also only makes sense when it is set to 1.0 or higher, which is why many implementations use a minimum value of 1.0.\n", 21 | "\n", 22 | "But... what happens when we set guidance scale to 0? Or negative? 
Let's find out!\n", 23 | "\n", 24 | "When you use a negative value for the guidance scale, the model will try to generate images that are the opposite of what you specify in the prompt. For example, if you prompt the model to generate an image of an astronaut, and you use a negative guidance scale, the model will try to generate an image of everything but an astronaut. This can be a fun way to generate creative and unexpected images (sometimes NSFW or absolute horrendous stuff, if you are not using a safety-checker model - which is the case with StableFused)." 25 | ] 26 | }, 27 | { 28 | "attachments": {}, 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Install and Import required packages" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "!pip install stablefused ipython" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import numpy as np\n", 51 | "import torch\n", 52 | "\n", 53 | "from IPython.display import Video, display\n", 54 | "from PIL import Image, ImageDraw, ImageFont\n", 55 | "from stablefused import TextToImageDiffusion\n", 56 | "from typing import List" 57 | ] 58 | }, 59 | { 60 | "attachments": {}, 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "### Initialize model and parameters\n", 65 | "\n", 66 | "We use RunwayML's Stable Diffusion 1.5 checkpoint." 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 76 | "model = TextToImageDiffusion(model_id=model_id, torch_dtype=torch.float16)" 77 | ] 78 | }, 79 | { 80 | "attachments": {}, 81 | "cell_type": "markdown", 82 | "metadata": { 83 | "id": "q0-_CmuxpTw8" 84 | }, 85 | "source": [ 86 | "##### Prompt Credits\n", 87 | "\n", 88 | "The prompts used in this notebook have been taken from different sources. The main inspirations are:\n", 89 | "\n", 90 | "- https://levelup.gitconnected.com/20-stable-diffusion-prompts-to-create-stunning-characters-a63017dc4b74\n", 91 | "- https://mpost.io/best-100-stable-diffusion-prompts-the-most-beautiful-ai-text-to-image-prompts/" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "prompt = [\n", 101 | " \"Artistic image, very detailed cute cat, cinematic lighting effect, cute, charming, fantasy art, digital painting, photorealistic\",\n", 102 | " \"A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k\",\n", 103 | " \"A grand city in the year 2100, atmospheric, hyper realistic, 8k, epic composition, cinematic, octane render\",\n", 104 | " \"Starry Night, painting style of Vincent van Gogh, Oil paint on canvas, Landscape with a starry night sky, dreamy, peaceful\",\n", 105 | "]\n", 106 | "negative_prompt = \"cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive\"\n", 107 | "num_inference_steps = 20\n", 108 | "seed = 2023" 109 | ] 110 | }, 111 | { 112 | "attachments": {}, 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "`image_grid_with_labels` is a helper function that takes a list of images and a list of labels and displays them in a grid." 
117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "def image_grid_with_labels(\n", 126 | " images: List[Image.Image], labels: List[str], rows: int, cols: int\n", 127 | ") -> Image.Image:\n", 128 | " \"\"\"Create a grid of images with labels.\"\"\"\n", 129 | " if len(images) > rows * cols:\n", 130 | " raise ValueError(\n", 131 | " f\"Number of images ({len(images)}) exceeds grid size ({rows}x{cols}).\"\n", 132 | " )\n", 133 | " if len(labels) != rows:\n", 134 | " raise ValueError(\n", 135 | " f\"Number of labels ({len(labels)}) does not match the number of rows ({rows}).\"\n", 136 | " )\n", 137 | "\n", 138 | " w, h = images[0].size\n", 139 | " label_width = 100\n", 140 | "\n", 141 | " grid = Image.new(\"RGB\", size=(cols * w + label_width, rows * h))\n", 142 | " draw = ImageDraw.Draw(grid)\n", 143 | "\n", 144 | " font_size = 32\n", 145 | " font = ImageFont.truetype(\n", 146 | " \"/usr/local/lib/python3.10/dist-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf\",\n", 147 | " size=font_size,\n", 148 | " )\n", 149 | "\n", 150 | " for i, label in enumerate(labels):\n", 151 | " x_label = label_width // 4\n", 152 | " y_label = i * h + h // 2\n", 153 | " draw.text((x_label, y_label), label, fill=(255, 255, 255), font=font)\n", 154 | "\n", 155 | " for i, image in enumerate(images):\n", 156 | " x_img = (i % cols) * w + label_width\n", 157 | " y_img = (i // cols) * h\n", 158 | " grid.paste(image, box=(x_img, y_img))\n", 159 | "\n", 160 | " return grid" 161 | ] 162 | }, 163 | { 164 | "attachments": {}, 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "### Inference with different guidance scales\n", 169 | "\n", 170 | "We start with a negative guidance scale and increment it by 1.5 until a certain maximum value. The results obtained are very interesting!\n", 171 | "\n", 172 | "The below code demonstrates the effect of guidance scale on different prompts." 
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "epsilon = 1e-5\n", 182 | "guidance_scale = -1.5\n", 183 | "increment = 1.5\n", 184 | "max_guidance_scale = 15 + epsilon\n", 185 | "num_iterations = 0\n", 186 | "results = []\n", 187 | "guidance_scale_labels = []\n", 188 | "\n", 189 | "while guidance_scale <= max_guidance_scale:\n", 190 | " torch.manual_seed(seed)\n", 191 | " np.random.seed(seed)\n", 192 | "\n", 193 | " print(f\"Generating images with guidance_scale={guidance_scale:.2f} and seed={seed}\")\n", 194 | " results.append(\n", 195 | " model(\n", 196 | " prompt=prompt,\n", 197 | " negative_prompt=[negative_prompt] * len(prompt),\n", 198 | " num_inference_steps=num_inference_steps,\n", 199 | " guidance_scale=guidance_scale,\n", 200 | " )\n", 201 | " )\n", 202 | "\n", 203 | " guidance_scale_labels.append(str(round(guidance_scale, 2)))\n", 204 | " guidance_scale += increment\n", 205 | " num_iterations += 1" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "flattened_results = [image for result in results for image in result]" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "image_grid_with_labels(\n", 224 | " flattened_results,\n", 225 | " labels=guidance_scale_labels,\n", 226 | " rows=num_iterations,\n", 227 | " cols=len(prompt),\n", 228 | ").save(\"effect-of-guidance-scale-on-different-prompts.png\")" 229 | ] 230 | }, 231 | { 232 | "attachments": {}, 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "\n", 237 | "The below code demonstrates the effect of guidance scale on the same prompt over multiple inference steps. 
" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "steps = [3, 6, 12, 20, 25]\n", 247 | "prompt = \"Photorealistic illustration of a mystical alien creature, magnificent, strong, atomic, tyrannic, predator, unforgiving, full-body image\"\n", 248 | "epsilon = 1e-5\n", 249 | "guidance_scale = -1.5\n", 250 | "increment = 1.5\n", 251 | "max_guidance_scale = 15 + epsilon\n", 252 | "num_iterations = 0\n", 253 | "results = []\n", 254 | "guidance_scale_labels = []\n", 255 | "seed = 42\n", 256 | "\n", 257 | "while guidance_scale <= max_guidance_scale:\n", 258 | " step_results = []\n", 259 | "\n", 260 | " for step in steps:\n", 261 | " torch.manual_seed(seed)\n", 262 | " np.random.seed(seed)\n", 263 | "\n", 264 | " print(\n", 265 | " f\"Generating images with guidance_scale={guidance_scale:.2f}, num_inference_steps={step} and seed={seed}\"\n", 266 | " )\n", 267 | " step_results.append(\n", 268 | " model(\n", 269 | " prompt=prompt,\n", 270 | " negative_prompt=negative_prompt,\n", 271 | " num_inference_steps=step,\n", 272 | " guidance_scale=guidance_scale,\n", 273 | " )\n", 274 | " )\n", 275 | "\n", 276 | " guidance_scale_labels.append(str(round(guidance_scale, 2)))\n", 277 | " guidance_scale += increment\n", 278 | " num_iterations += 1\n", 279 | "\n", 280 | " results.append(step_results)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "flattened_results = [\n", 290 | " image for result in results for images in result for image in images\n", 291 | "]" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "image_grid_with_labels(\n", 301 | " flattened_results, labels=guidance_scale_labels, rows=len(results), cols=len(steps)\n", 302 | ").save(\"effect-of-guidance-scale-vs-steps.png\")" 303 | ] 304 | } 305 | ], 306 | "metadata": { 307 | "language_info": { 308 | "name": "python" 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 0 313 | } 314 | -------------------------------------------------------------------------------- /examples/latent_walk_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Latent Walk Diffusion\n", 9 | "\n", 10 | "In this notebook, we will take a look at latent walking in latent spaces. Generative models, like the ones used in Stable Diffusion, learn a latent representation of the world. A latent representation is a low-dimensional vector space embedding of the world. In the case of SD, this latent representation is learnt by training on text-image pairs. This representation is used to generate samples given a prompt and a random noise vector. The model tries to predict and remove noise from the random noise vector, while also aligning it the vector to the prompt. This results in some interesting properties of the latent space. In this notebook, we will explore these properties.\n", 11 | "\n", 12 | "Stable Diffusion models (atleast, the models used here) learn two latent representations - one of the NLP space for prompts, and one of the image space. These latent representations are continuous. 
If we choose two vectors in the latent space to sample from, we get two different/similar images depending on how different the chosen vectors are. This is the basis of latent walking. We can choose two vectors in the latent space, and sample from the latent path between them. This results in a smooth transition between the two images." 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "### Install and Import required packages" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "%pip install stablefused ipython" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import numpy as np\n", 39 | "import torch\n", 40 | "\n", 41 | "from IPython.display import Video, display\n", 42 | "from PIL import Image\n", 43 | "from stablefused import LatentWalkDiffusion, TextToImageDiffusion\n", 44 | "from stablefused.utils import image_grid, pil_to_video" 45 | ] 46 | }, 47 | { 48 | "attachments": {}, 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "### Initialize model and parameters\n", 53 | "\n", 54 | "We use RunwayML's Stable Diffusion 1.5 checkpoint and initialize our Latent-Walk and Text-to-Image Diffusion models. Play around with different prompts and parameters, and see what you get! You can comment out the parts that use seeds to generate random images each time you run the notebook.\n", 55 | "\n", 56 | "We use the following mechanism to trade-off speed for reduced memory footprint. It allows us to work with bigger images and larger batch sizes with about just 6GB of GPU memory.\n", 57 | "- U-Net Attention Slicing: Allows the internal U-Net model to perform computations for attention heads sequentially, rather than in parallel.\n", 58 | "- VAE Slicing: Allow tensor slicing for VAE decode step. This will cause the vae to split the input tensor to compute decoding in multiple steps.\n", 59 | "- VAE Tiling: Allow tensor tiling for vae. This will cause the vae to split the input tensor into tiles to compute encoding/decoding in several steps.\n", 60 | "\n", 61 | "Also, notice how we are loading the same model twice. That should use twice the memory, right? Well, in most cases, users stick to using the same model checkpoints across different inference pipelines, and so it makes sense to share the internal models. StableFused maintains an internal model cache which allows all internal models to be shared, in order to save memory." 
62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 71 | "lw_model = LatentWalkDiffusion(model_id=model_id, torch_dtype=torch.float16)\n", 72 | "\n", 73 | "lw_model.enable_attention_slicing()\n", 74 | "lw_model.enable_slicing()\n", 75 | "lw_model.enable_tiling()\n", 76 | "\n", 77 | "t2i_model = TextToImageDiffusion(model_id=model_id, torch_dtype=torch.float16)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": { 83 | "id": "bVIENw0bnSZ_" 84 | }, 85 | "source": [ 86 | "Prompt Credits: https://mspoweruser.com/best-stable-diffusion-prompts/#6_The_Robotic_Baroque_Battle\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "prompt = \"Large futuristic mechanical robot in the foreground of a baroque-style battle scene, photorealistic, high quality, 8k\"\n", 96 | "negative_prompt = \"cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive\"\n", 97 | "num_images = 4\n", 98 | "seed = 44\n", 99 | "\n", 100 | "torch.manual_seed(seed)\n", 101 | "np.random.seed(seed)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# There seems to be a bug in stablefused which requires to be reviewed. The\n", 111 | "# bug causes images created using latent walk to be very noisy, or plain white,\n", 112 | "# deformed, etc.\n", 113 | "# This is why instead of being able to use an actual image to generate latents,\n", 114 | "# we need to make it ourselves. In the future, this notebook will be updated\n", 115 | "# to allow latent walking for user-chosen images\n", 116 | "\n", 117 | "# filename = \"the-robotic-baroque-battle.png\"\n", 118 | "# start_image = [Image.open(filename)] * num_images\n", 119 | "\n", 120 | "# # This step is only required when loading model with torch.float16 dtype\n", 121 | "# start_image = np.array(start_image, dtype=np.float16)\n", 122 | "\n", 123 | "# latent = lw_model.image_to_latent(start_image)\n", 124 | "\n", 125 | "image_height = 512\n", 126 | "image_width = 512\n", 127 | "shape = (\n", 128 | " 1,\n", 129 | " lw_model.unet.config.in_channels,\n", 130 | " image_height // lw_model.vae_scale_factor,\n", 131 | " image_width // lw_model.vae_scale_factor,\n", 132 | ")\n", 133 | "single_latent = lw_model.random_tensor(shape)\n", 134 | "latent = single_latent.repeat(num_images, 1, 1, 1)" 135 | ] 136 | }, 137 | { 138 | "attachments": {}, 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "### Latent Walking to generate similar images\n", 143 | "\n", 144 | "Let's see what our base image for latent walking looks like." 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "t2i_model(\n", 154 | " prompt=prompt,\n", 155 | " negative_prompt=negative_prompt,\n", 156 | " num_inference_steps=20,\n", 157 | " guidance_scale=10.0,\n", 158 | " latent=single_latent,\n", 159 | ")[0]" 160 | ] 161 | }, 162 | { 163 | "attachments": {}, 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "Latent walk with diffusion around the latent space of our sampled latent vector. This results in generation of similar images. 
The similarity/difference can be controlled using the `strength` parameter (set between 0 and 1, defaults to 0.2). Lower strenght leads to similar images with subtle differences. Higher strength can cause completely new ideas to be generated." 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "images = lw_model(\n", 177 | " prompt=[prompt] * num_images,\n", 178 | " negative_prompt=[negative_prompt] * num_images,\n", 179 | " latent=latent,\n", 180 | " strength=0.25,\n", 181 | " num_inference_steps=20,\n", 182 | ")" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "image_grid(images, rows=2, cols=2)" 192 | ] 193 | }, 194 | { 195 | "attachments": {}, 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "### Generating Videos with Latent Walking\n", 200 | "\n", 201 | "Here, we generate a video by walking the latent space of the model, using interpolation techniques to generate frames. An interpolation is just a weighted average of two embeddings calculated by some interpolation function. [Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) is used on the prompt embeddings and [Spherical Linear Interpolation](https://en.wikipedia.org/wiki/Slerp) is used on the latent embeddings, by default. You can change the interpolation method by passing `embedding_interpolation_type` or `latent_interpolation_type` parameter.\n", 202 | "\n", 203 | "Note that stablefused is a toy library in its infancy and is not optimized for speed, and does not support a lot of features. There are multiple bugs and issues that need to be addressed. Some things need to be implemented manually currently, but in the future, I hope to make the process easier." 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# Prompt credits: ChatGPT\n", 213 | "story_prompt = [\n", 214 | " \"A dog chasing a cat in a thrilling backyard scene, high quality and photorealistic\",\n", 215 | " \"A determined dog in hot pursuit, with stunning realism, octane render\",\n", 216 | " \"A thrilling chase, dog behind the cat, octane render, exceptional realism and quality\",\n", 217 | " \"The exciting moment of a cat outmaneuvering a chasing dog, high-quality and photorealistic detail\",\n", 218 | " \"A clever cat escaping a determined dog and soaring into space, rendered with octane render for stunning realism\",\n", 219 | " \"The cat's escape into the cosmos, leaving the dog behind in a scene,high quality and photorealistic style\",\n", 220 | "]\n", 221 | "\n", 222 | "seed = 123456\n", 223 | "\n", 224 | "torch.manual_seed(seed)\n", 225 | "np.random.seed(seed)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "# There seems to be a bug in stablefused which requires to be reviewed. The\n", 235 | "# bug causes images created using latent walk to be very noisy, or plain white,\n", 236 | "# deformed, etc.\n", 237 | "# This is why instead of being able to use an actual image to generate latents,\n", 238 | "# we need to make it ourselves. 
In the future, this notebook will be updated\n", 239 | "# to allow latent walking for user-chosen images\n", 240 | "\n", 241 | "# t2i_images = t2i_model(\n", 242 | "# prompt = story_prompt,\n", 243 | "# negative_prompt = [negative_prompt] * len(story_prompt),\n", 244 | "# num_inference_steps = 20,\n", 245 | "# guidance_scale = 12.0,\n", 246 | "# )\n", 247 | "\n", 248 | "image_height = 512\n", 249 | "image_width = 512\n", 250 | "shape = (\n", 251 | " len(story_prompt),\n", 252 | " lw_model.unet.config.in_channels,\n", 253 | " image_height // lw_model.vae_scale_factor,\n", 254 | " image_width // lw_model.vae_scale_factor,\n", 255 | ")\n", 256 | "latent = lw_model.random_tensor(shape)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "t2i_images = t2i_model(\n", 266 | " prompt=story_prompt,\n", 267 | " num_inference_steps=20,\n", 268 | " guidance_scale=15.0,\n", 269 | " latent=latent,\n", 270 | ")" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "image_grid(t2i_images, rows=2, cols=3)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "# Due to the bug mentioned above, this step is not required.\n", 289 | "# We can directly use the latents we generated manually\n", 290 | "# np_t2i_images = np.array(t2i_images, dtype = np.float16)\n", 291 | "# t2i_latents = t2i_model.image_to_latent(np_t2i_images)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "interpolation_steps = 24\n", 301 | "\n", 302 | "# Since stablefused does not support batch processing yet, we need\n", 303 | "# to do it manually. This notebook will be updated in the future\n", 304 | "# to support batching internally to handle a large number of images\n", 305 | "\n", 306 | "story_images = []\n", 307 | "for i in range(len(story_prompt) - 1):\n", 308 | " current_prompt = story_prompt[i : i + 2]\n", 309 | " current_latent = latent[i : i + 2]\n", 310 | " imgs = lw_model.interpolate(\n", 311 | " prompt=current_prompt,\n", 312 | " negative_prompt=[negative_prompt] * len(current_prompt),\n", 313 | " latent=current_latent,\n", 314 | " num_inference_steps=20,\n", 315 | " interpolation_steps=interpolation_steps,\n", 316 | " )\n", 317 | " story_images.extend(imgs)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "filename = \"dog-chasing-cat-story.mp4\"\n", 327 | "pil_to_video(story_images, filename, fps=8)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "display(Video(filename, embed=True))" 337 | ] 338 | } 339 | ], 340 | "metadata": { 341 | "language_info": { 342 | "name": "python" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 0 347 | } 348 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StableFused 2 | 3 |

4 | 5 | 6 | pypi 7 | 8 | 9 | 10 | docs 11 | 12 | 13 |

14 | 15 | StableFused is a toy library to experiment with Stable Diffusion, inspired by 🤗 diffusers and various other sources! One of the main reasons I'm working on this project is to learn more about Stable Diffusion and generative models in general. It is my current area of research at university. 16 | 17 | ## Installation 18 | 19 | It is recommended to use a virtual environment. You can use [venv](https://docs.python.org/3/library/venv.html) or [conda](https://docs.conda.io/en/latest/) to create one. 20 | 21 | Unix: 22 | ```bash 23 | python -m venv venv 24 | source venv/bin/activate 25 | ``` 26 | 27 | Windows: 28 | ```PowerShell 29 | python -m venv venv 30 | venv\Scripts\activate 31 | ``` 32 | 33 | For usage, install the package from PyPI. 34 | 35 | ```bash 36 | pip install stablefused 37 | ``` 38 | 39 | For development, fork the repository, clone it, and install the package in editable mode. 40 | 41 | ```bash 42 | git clone https://github.com/<your-username>/stablefused.git 43 | cd stablefused 44 | pip install -e ".[dev]" 45 | ``` 46 | 47 | ## Usage 48 | 49 | Check out the [examples](https://github.com/a-r-r-o-w/stablefused/tree/main/examples) folder for notebooks 🥰 50 | 51 | ## Contributing 52 | 53 | Contributions are welcome! Note that this project is not a serious implementation for training/inference/fine-tuning diffusion models. It is a toy library. I am working on it for fun and experimentation purposes (and because I'm too stupid to modify large codebases and understand what's going on). 54 | 55 | As I'm not an expert in this field, I have probably made a lot of mistakes. If you find any, please open an issue or a PR. I'll be happy to learn from you! 56 | 57 | ## Acknowledgements/Resources 58 | 59 | The following sources have been very helpful to me in understanding Stable Diffusion. I highly recommend checking them out! 60 | 61 | - [🤗 diffusers](https://github.com/huggingface/diffusers) 62 | - [Karpathy's gist on latent walking](https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355) 63 | - [Nateraw's stable-diffusion-videos](https://github.com/nateraw/stable-diffusion-videos) 64 | - [AnimateDiff](https://github.com/guoyww/AnimateDiff/) 65 | - [RunwayML StableDiffusion](https://github.com/runwayml/stable-diffusion) 66 | - [🤗 Annotated Diffusion Blog](https://huggingface.co/blog/annotated-diffusion) 67 | - [Keras CV](https://github.com/keras-team/keras-cv) 68 | - [Lilian Weng's Blogs](https://lilianweng.github.io/) 69 | - [Emilio Dorigatti's Blogs](https://e-dorigatti.github.io/) 70 | - [The AI Summer Diffusion Models Blog](https://theaisummer.com/diffusion-models/) 71 | 72 | ## Results 73 | 74 | ### Visualization of diffusion process 75 | 76 | Refer to the notebooks for more details and enjoy the denoising process! 77 |
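Before the result galleries below, here is a rough sketch of the text-to-image call pattern used throughout the example notebooks. Treat it as a sketch only: the keyword-argument call style mirrors the notebooks, the top-level imports are assumed to be re-exported by the `stablefused` package, and exact signatures may differ between versions (the `LatentWalkDiffusion` source later in this dump, for instance, shows a config-object API).

```python
# Hypothetical quick-start sketch; mirrors the call pattern used in the example
# notebooks. The keyword-argument call style and the top-level imports are
# assumptions -- check the notebooks for the exact API of your installed version.
import torch

from stablefused import TextToImageDiffusion, image_grid

model = TextToImageDiffusion(
    model_id="runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

# Returns a list of PIL images, one per prompt
images = model(
    prompt="A painting of a cat, in the style of Vincent Van Gogh, hanging in a room",
    num_inference_steps=20,
    guidance_scale=7.5,
)

# Arrange the generated images into a single grid for quick inspection
image_grid(images, rows=1, cols=len(images))
```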
79 | Text to Image 80 | 81 | These results are generated using the [Text to Image](https://github.com/a-r-r-o-w/stablefused/blob/main/examples/text_to_image_diffusion.ipynb) notebook. 82 | 83 |
84 | 87 |
88 |
89 | 90 |
91 | Image to Image 92 | 93 | These results are generated using the [Image to Image](https://github.com/a-r-r-o-w/stablefused/blob/main/examples/image_to_image_diffusion.ipynb) notebook. 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 |
Source Image | Denoising Diffusion Process
The Renaissance Astronaut | High quality and colorful photo of Robert J Oppenheimer, father of the atomic bomb, in a spacesuit, galaxy in the background, universe, octane render, realistic, 8k, bright colors | Stylistic photorealisic photo of Margot Robbie, playing the role of astronaut, pretty, beautiful, high contrast, high quality, galaxies, intricate detail, colorful, 8k
114 | 115 |
116 | PS 117 | In my experiments, the results from Image to Image Diffusion don't look great. It might be a bug in my implementation, which I'll have to look into later... 118 |
119 |
120 | 121 | ### Text to Video 122 | 123 | There is a lot of ongoing research on the generation of videos from text prompts. It is also my current area of research at university. The implementation here is adapted from [AnimateDiff](https://animatediff.github.io/). 124 | 125 | There is immense potential in developing this kind of technology, and its possible use cases are unlimited - personalized educational content, marketing and advertising, creativity and art, to name a few. Imagine a world where you have your own personal ChatGPT/Bard-like assistants for visual learning - a model that can generate [3Blue1Brown](https://www.youtube.com/c/3blue1brown)-style videos explaining science topics, or depict a story! Current models are not that capable yet, but this is where we are headed, I think, and it is what my team and I are researching. The future of this technology will be fascinating to witness! 126 | 127 |
128 | Text to Video 129 | 130 | These results are generated using the [Text to Video](https://github.com/a-r-r-o-w/stablefused/blob/main/examples/text_to_video_diffusion.ipynb) notebook. 131 | 132 |
133 | 134 | | Text to Video | 135 | | :-: | 136 | | _An astronaut floating in space, interstellar, black background with stars, photorealistic, high quality, 8k_ | 137 | | | 138 | | _A mighty pirate ship sailing through the sea, unpleasant, thundering roar, dark night, starry night, high quality, photorealistic, 8k_ | 139 | | | 140 | 141 |
142 |
143 | 144 | ### Inpainting 145 | 146 | Image inpainting is a technique that aims to fill in missing or damaged parts of an image. It is used to restore or repair images by extrapolating the surrounding information to recreate the missing regions seamlessly. 147 | 148 | These results are generated using the [Inpainting](https://github.com/a-r-r-o-w/stablefused/blob/main/examples/inpaint_diffusion.ipynb) notebook. 149 | 150 |
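Conceptually, mask-based inpainting boils down to a per-pixel blend: regions marked by the mask are regenerated, while everything else is kept from the source image. The NumPy sketch below illustrates only that blending idea; the actual pipeline performs this kind of mixing on latents inside the diffusion loop, so do not read it as stablefused's implementation.

```python
# Conceptual sketch of the masked blend behind inpainting: keep original pixels
# where mask == 0 and take generated pixels where mask == 1. This is only an
# illustration of the idea, not stablefused's InpaintDiffusion implementation.
import numpy as np

height, width = 64, 64
original = np.random.rand(height, width, 3)   # stand-in for the source image
generated = np.random.rand(height, width, 3)  # stand-in for freshly generated content

# Binary mask: regenerate only the right half of the image
mask = np.zeros((height, width, 1), dtype=np.float32)
mask[:, width // 2 :] = 1.0

inpainted = mask * generated + (1.0 - mask) * original
print(inpainted.shape)  # (64, 64, 3)
```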
151 | Inpainting using a fixed mask and different prompts 152 | 153 |
154 | 155 | 156 | 157 | 158 | 159 | 160 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 |
Inpainting
161 |

Prompt 1: Digital illustration of a mythical creature, high quality, realistic, 8k
Prompt 2: Digital illustration of a mythical creature, high quality, realistic, 8k
Prompt 3: Digital illustration of a dragon, high quality, realistic, octane render, 8k
Prompt 4: Digital illustration of a ferocious lion, high quality, realistic, octane render, 8k
Prompt 5: Digital illustration of an evil white rabbit, high quality, realistic, 8k
Prompt 6: Digital illustration of samurai with a moon-like object in the background, high quality, realistic, octane render, 8k

162 |
Image Mask
177 |
178 |
179 | 180 |
181 | Infinite Zoom In 182 | 183 | **Prompt:** _A painting of a cat, in the style of Vincent Van Gogh, hanging in a room_ 184 | 185 |
186 | 187 |
188 |
189 | 190 |
191 | Pan and Zoom Out 192 | 193 | **Prompt:** _Post-apocalyptic world with ruins, overgrown vegetation, and a lone survivor_ 194 | 195 |
196 | 197 |
198 |
199 | 200 | ### Understanding the effect of Guidance Scale 201 | 202 | Guidance scale is a parameter inspired by the paper [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). The explanation of how CFG works is out-of-scope here, but there are many online sources where you can read about it (linked below). 203 | 204 | - [Guidance: a cheat for diffusion models](https://sander.ai/2022/05/26/guidance.html) 205 | - [Diffusion Models, DDPMs, DDIMs and CFG](https://betterprogramming.pub/diffusion-models-ddpms-ddims-and-classifier-free-guidance-e07b297b2869) 206 | - [Classifier-Free Guidance Scale](https://mccormickml.com/2023/02/20/classifier-free-guidance-scale/) 207 | 208 | In short, guidance scale is a value that controls the amount of "guidance" used in the diffusion process. That is, the higher the value, the more closely the diffusion process follows the prompt. A lower guidance scale allows the model to be more creative and deviate slightly from the exact prompt. Beyond a certain threshold, the results start to get worse - blurry and noisy. 209 | 210 | Guidance scale values, in practice, are usually in the range 6-15, and the default value of 7.5 is used in many inference implementations. However, manipulating it can lead to some very interesting results. It also only makes sense when it is set to 1.0 or higher, which is why many implementations use a minimum value of 1.0. 211 | 212 | But... what happens when we set guidance scale to 0? Or negative? Let's find out! 213 | 214 | When you use a negative value for the guidance scale, the model will try to generate images that are the opposite of what you specify in the prompt. For example, if you prompt the model to generate an image of an astronaut, and you use a negative guidance scale, the model will try to generate an image of everything but an astronaut. This can be a fun way to generate creative and unexpected images (sometimes NSFW or absolutely horrendous stuff, if you are not using a safety-checker model - which is the case with StableFused). 215 | 216 | ##### Results 217 | 218 | The original images produced are too large to display in high quality here. You can find them in my [Drive](https://drive.google.com/drive/folders/13eZsi7y1LZxUHlaxagGTPS6pLwzBysU6?usp=sharing). These images are compressed from ~30 MB to ~6 MB in order for GitHub to accept uploads. 219 | 220 |
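Concretely, classifier-free guidance combines an unconditional and a prompt-conditioned noise prediction as `guided = uncond + scale * (cond - uncond)`, which is also the combination used by `classifier_free_guidance` in `base_diffusion.py` further down in this dump. The toy sketch below just evaluates that formula on dummy tensors to show what scales of 0, 1, 7.5, and a negative value do.

```python
# Sketch of the classifier-free guidance combination used in diffusion samplers:
#   guided = uncond + scale * (cond - uncond)
# scale = 0 ignores the prompt, scale = 1 uses only the conditional prediction,
# larger scales push harder towards the prompt, and negative scales push away
# from it. The tensors here are dummies standing in for the UNet's predictions.
import torch

def classifier_free_guidance(noise_uncond, noise_cond, guidance_scale):
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

noise_uncond = torch.zeros(1, 4, 8, 8)
noise_cond = torch.ones(1, 4, 8, 8)

for scale in (0.0, 1.0, 7.5, -5.0):
    guided = classifier_free_guidance(noise_uncond, noise_cond, scale)
    print(scale, guided.mean().item())
# 0.0 -> 0.0 (prompt ignored), 1.0 -> 1.0, 7.5 -> 7.5, -5.0 -> -5.0 (pushed away from the prompt)
```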
221 | 222 | Effect of Guidance Scale on Different Prompts 223 | 224 | 225 | | Effect of Guidance Scale on Different Prompts | 226 | | --- | 227 | | Each image is sampled with the same prompt and seed to ensure only the guidance scale plays a role.
**Column 1:** _Artistic image, very detailed cute cat, cinematic lighting effect, cute, charming, fantasy art, digital painting, photorealistic_
**Column 2:** _A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k_
**Column 3:** _A grand city in the year 2100, atmospheric, hyper realistic, 8k, epic composition, cinematic, octane render_
**Column 4:** _Starry Night, painting style of Vincent van Gogh, Oil paint on canvas, Landscape with a starry night sky, dreamy, peaceful_ | 228 | |
| 229 |
230 | 231 |
232 | 233 | Effect of Guidance Scale with increased number of inference steps 234 | 235 | 236 | | Effect of Guidance Scale with increased number of inference steps | 237 | | --- | 238 | | Columns have number of inference steps set to 3, 6, 12, 20, 25.
**Prompt:** _Photorealistic illustration of a mystical alien creature, magnificent, strong, atomic, tyrannic, predator, unforgiving, full-body image_ | 239 | |
| 240 | |
| 241 |
242 | 243 | ### Latent Walk 244 | 245 | Generative models, like the ones used in Stable Diffusion, learn a latent representation of the world. A latent representation is a low-dimensional vector space embedding of the world. In the case of SD, this latent representation is learnt by training on text-image pairs. This representation is used to generate samples given a prompt and a random noise vector. The model tries to predict and remove noise from the random noise vector, while also aligning the vector to the prompt. This results in some interesting properties of the latent space. 246 | 247 | Stable Diffusion models (at least, the models used here) learn two latent representations - one of the NLP space for prompts, and one of the image space. These latent representations are continuous. If we choose two vectors in the latent space to sample from, we get images that are more or less similar depending on how far apart the chosen vectors are. This is the basis of latent walking. We can choose two vectors in the latent space, and sample from the latent path between them. This results in a smooth transition between the two images. 248 | 249 |
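A latent walk is simply interpolation between points in that continuous space: stablefused exposes `lerp` and `slerp` helpers and, by default, uses linear interpolation for prompt embeddings and spherical linear interpolation for latents. The sketch below uses a textbook slerp formulation on random latents, so treat it as an illustration of the idea rather than the library's exact implementation.

```python
# Illustration of walking between two latent points. lerp is a plain weighted
# average; slerp interpolates along the arc between the two vectors, which tends
# to preserve the norm statistics diffusion models expect. This is a textbook
# slerp, not necessarily identical to stablefused.utils.slerp.
import torch

def lerp(a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
    # Plain weighted average between the two endpoints
    return (1.0 - t) * a + t * b

def slerp(a: torch.Tensor, b: torch.Tensor, t: float, eps: float = 1e-7) -> torch.Tensor:
    # Spherical linear interpolation: walk along the arc between a and b
    a_flat, b_flat = a.flatten(), b.flatten()
    cos_omega = torch.clamp(
        torch.dot(a_flat, b_flat) / (a_flat.norm() * b_flat.norm() + eps), -1.0, 1.0
    )
    omega = torch.acos(cos_omega)
    if omega.abs() < eps:  # nearly parallel endpoints: fall back to lerp
        return lerp(a, b, t)
    sin_omega = torch.sin(omega)
    return (torch.sin((1.0 - t) * omega) / sin_omega) * a + (
        torch.sin(t * omega) / sin_omega
    ) * b

torch.manual_seed(0)
latent_a = torch.randn(4, 64, 64)  # two endpoints in latent space
latent_b = torch.randn(4, 64, 64)

# Eight intermediate latents; decoding each one yields a smooth visual transition
walk = [slerp(latent_a, latent_b, float(t)) for t in torch.linspace(0, 1, 8)]
print(len(walk), walk[0].shape)
```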
250 | Similar Image Generation by sampling latent space 251 | 252 | The results below show just how information-rich the latent space of these stable diffusion models is. 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 |
Source Image | Latent Walks
264 | Large futuristic mechanical robot in the foreground of a baroque-style battle scene, photorealistic, high quality, 8k 265 |
269 | 270 | 272 | 273 |
277 |
278 | 279 |
280 | 281 | Generating Latent Walk videos 282 | 283 | 284 | | Generating Latent Walk videos | 285 | | --- | 286 | | **Prompt 1:** _A dog chasing a cat in a thrilling backyard scene, high quality and photorealistic_
**Prompt 2:** _A determined dog in hot pursuit, with stunning realism, octane render_
**Prompt 3:** _A thrilling chase, dog behind the cat, octane render, exceptional realism and quality_
**Prompt 4:** _The exciting moment of a cat outmaneuvering a chasing dog, high-quality and photorealistic detail_
**Prompt 5:** _A clever cat escaping a determined dog and soaring into space, rendered with octane render for stunning realism_
**Prompt 6:** _The cat's escape into the cosmos, leaving the dog behind in a scene,high quality and photorealistic style_
| 287 | |
| 288 | 289 | Note that these results aren't very good. I tried different seeds, but I couldn't make a great video for this story. I did try some other prompts and got better results, but I like this story, so I'm sticking with it 🤓 290 | You can improve the results by using better prompts and by increasing the number of interpolation and inference steps. 291 |
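As a rough guide to clip length with the settings used in this example (6 prompts, 24 interpolation steps per consecutive pair, written out at 8 fps): 5 transitions × 24 frames = 120 frames, i.e. about 15 seconds of video. The helper below is plain arithmetic for that bookkeeping and is not part of stablefused.

```python
# Back-of-the-envelope for latent-walk video length. Values mirror the
# latent_walk_diffusion notebook: 6 prompts, 24 interpolation steps between each
# consecutive pair, written out at 8 fps. Illustrative helper only.
def walk_video_stats(num_prompts: int, interpolation_steps: int, fps: int):
    transitions = num_prompts - 1
    frames = transitions * interpolation_steps
    seconds = frames / fps
    return frames, seconds

frames, seconds = walk_video_stats(num_prompts=6, interpolation_steps=24, fps=8)
print(frames, seconds)  # 120 frames -> 15.0 seconds
```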
292 | 293 | ## Future 294 | 295 | At the moment, I'm not sure if I'll continue to expand on this project, but if I do, here are some things I have in mind (in no particular order, and for documentation purposes): 296 | 297 | - Add support for more techniques of inference - explore new sampling techniques and optimize diffusion paths 298 | - Implement and stay up-to-date with the latest papers in the field 299 | - Removing 🧨 diffusers as a dependency by implementing all required components myself 300 | - Create user-friendly web demos or GUI tools to make experimentation easier. 301 | - Add LoRA, training and fine-tuning support 302 | - Improve codebase, documentation and tests 303 | - Improve support for not only Stable Diffusion, but other diffusion techniques, involving but not limited to audio, video, etc. 304 | 305 | ## License 306 | 307 | MIT 308 | -------------------------------------------------------------------------------- /stablefused/diffusion/latent_walk_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from dataclasses import dataclass 5 | from diffusers import AutoencoderKL 6 | from tqdm.auto import tqdm 7 | from transformers import CLIPTextModel, CLIPTokenizer 8 | from typing import List, Optional, Union 9 | 10 | from stablefused.diffusion import BaseDiffusion 11 | from stablefused.typing import PromptType, OutputType, SchedulerType, UNetType 12 | from stablefused.utils import lerp, slerp 13 | 14 | 15 | @dataclass 16 | class LatentWalkConfig: 17 | """ 18 | Configuration class for running inference using LatentWalkDiffusion. 19 | 20 | Parameters 21 | ---------- 22 | prompt: PromptType 23 | Text prompt to condition on. 24 | latent: torch.FloatTensor 25 | Latent to start from. 26 | strength: float 27 | The strength of the latent modification, controlling the amount of noise added. 28 | num_inference_steps: int 29 | Number of diffusion steps to run. 30 | guidance_scale: float 31 | Guidance scale encourages the model to generate images following the prompt 32 | closely, albeit at the cost of image quality. 33 | guidance_rescale: float 34 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 35 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 36 | negative_prompt: Optional[PromptType] 37 | Negative text prompt to uncondition on. 38 | output_type: str 39 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 40 | return_latent_history: bool 41 | Whether to return the latent history. If True, return list of all latents 42 | generated during diffusion steps. 43 | """ 44 | 45 | prompt: PromptType 46 | latent: torch.FloatTensor 47 | strength: float = 0.2 48 | num_inference_steps: int = 50 49 | guidance_scale: float = 7.5 50 | guidance_rescale: float = 0.7 51 | negative_prompt: Optional[PromptType] = None 52 | output_type: str = "pil" 53 | return_latent_history: bool = False 54 | 55 | 56 | @dataclass 57 | class LatentWalkInterpolateConfig: 58 | """ 59 | Configuration class for running interpolation using LatentWalkDiffusion. 60 | 61 | Parameters 62 | ---------- 63 | prompt: List[str] 64 | List of text prompts to condition on. 65 | latent: Optional[torch.FloatTensor] 66 | Latents to interpolate between. If None, latents are generated from noise 67 | but image_height and image_width must be provided. 68 | image_height: Optional[int] 69 | Height of image to generate. 70 | image_width: Optional[int] 71 | Width of image to generate. 
72 | num_inference_steps: int 73 | Number of diffusion steps to run. 74 | interpolation_steps: Union[int, List[int]] 75 | Number of interpolation steps to run. 76 | guidance_scale: float 77 | Guidance scale encourages the model to generate images following the prompt 78 | closely, albeit at the cost of image quality. 79 | guidance_rescale: float 80 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 81 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 82 | negative_prompt: Optional[List[str]] 83 | Negative text prompts to uncondition on. 84 | output_type: str 85 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 86 | return_latent_history: bool 87 | Whether to return the latent history. If True, return list of all latents 88 | generated during diffusion steps. 89 | embedding_interpolation_type: str 90 | Type of interpolation to run for text embeddings. One of ["lerp", "slerp"]. 91 | latent_interpolation_type: str 92 | Type of interpolation to run for latents. One of ["lerp", "slerp"]. 93 | """ 94 | 95 | prompt: List[str] = None 96 | latent: Optional[torch.FloatTensor] = None 97 | image_height: Optional[int] = None 98 | image_width: Optional[int] = None 99 | num_inference_steps: int = 50 100 | interpolation_steps: Union[int, List[int]] = 8 101 | guidance_scale: float = 7.5 102 | guidance_rescale: float = 0.7 103 | negative_prompt: Optional[List[str]] = None 104 | output_type: str = "pil" 105 | return_latent_history: bool = False 106 | embedding_interpolation_type: str = "lerp" 107 | latent_interpolation_type: str = "slerp" 108 | 109 | 110 | class LatentWalkDiffusion(BaseDiffusion): 111 | def __init__( 112 | self, 113 | model_id: str = None, 114 | tokenizer: CLIPTokenizer = None, 115 | text_encoder: CLIPTextModel = None, 116 | vae: AutoencoderKL = None, 117 | unet: UNetType = None, 118 | scheduler: SchedulerType = None, 119 | torch_dtype: torch.dtype = torch.float32, 120 | device="cuda", 121 | *args, 122 | **kwargs, 123 | ) -> None: 124 | super().__init__( 125 | model_id=model_id, 126 | tokenizer=tokenizer, 127 | text_encoder=text_encoder, 128 | vae=vae, 129 | unet=unet, 130 | scheduler=scheduler, 131 | torch_dtype=torch_dtype, 132 | device=device, 133 | *args, 134 | **kwargs, 135 | ) 136 | 137 | def modify_latent( 138 | self, 139 | latent: torch.FloatTensor, 140 | strength: float, 141 | ) -> torch.FloatTensor: 142 | """ 143 | Modify a latent vector by adding noise. 144 | 145 | Parameters 146 | ---------- 147 | latent: torch.FloatTensor 148 | The input latent vector to modify. 149 | strength: float 150 | The strength of the modification, controlling the amount of noise added. 151 | 152 | Returns 153 | ------- 154 | torch.FloatTensor 155 | Modified latent vector. 156 | """ 157 | noise = self.random_tensor(latent.shape) 158 | new_latent = (1 - strength) * latent + strength * noise 159 | new_latent = (new_latent - new_latent.mean()) / new_latent.std() 160 | return new_latent 161 | 162 | def embedding_to_latent( 163 | self, 164 | embedding: torch.FloatTensor, 165 | num_inference_steps: int, 166 | guidance_scale: float, 167 | guidance_rescale: float, 168 | latent: torch.FloatTensor, 169 | return_latent_history: bool = False, 170 | ) -> Union[torch.FloatTensor, List[torch.FloatTensor]]: 171 | """ 172 | Generate latent by conditioning on prompt embedding using diffusion. 173 | 174 | Parameters 175 | ---------- 176 | embedding: torch.FloatTensor 177 | Embedding of text prompt. 178 | num_inference_steps: int 179 | Number of diffusion steps to run. 
180 | guidance_scale: float 181 | Guidance scale encourages the model to generate images following the prompt 182 | closely, albeit at the cost of image quality. 183 | guidance_rescale: float 184 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 185 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 186 | latent: torch.FloatTensor 187 | Latent to start diffusion from. 188 | return_latent_history: bool 189 | Whether to return latent history. If True, return list of all latents 190 | generated during diffusion steps. 191 | 192 | Returns 193 | ------- 194 | Union[torch.FloatTensor, List[torch.FloatTensor]] 195 | Latent generated by diffusion. If return_latent_history is True, return list of 196 | all latents generated during diffusion steps. 197 | """ 198 | 199 | latent = latent.to(self.device) 200 | 201 | # Set number of inference steps 202 | self.scheduler.set_timesteps(num_inference_steps) 203 | timesteps = self.scheduler.timesteps 204 | 205 | # Scale the latent noise by the standard deviation required by the scheduler 206 | latent = latent * self.scheduler.init_noise_sigma 207 | latent_history = [latent] 208 | 209 | # Diffusion inference loop 210 | for i, timestep in tqdm(list(enumerate(timesteps))): 211 | # Duplicate latent to avoid two forward passes to perform classifier free guidance 212 | latent_model_input = torch.cat([latent] * 2) 213 | latent_model_input = self.scheduler.scale_model_input( 214 | latent_model_input, timestep 215 | ) 216 | 217 | # Predict noise 218 | noise_prediction = self.unet( 219 | latent_model_input, 220 | timestep, 221 | encoder_hidden_states=embedding, 222 | return_dict=False, 223 | )[0] 224 | 225 | # Perform classifier free guidance 226 | noise_prediction = self.classifier_free_guidance( 227 | noise_prediction, guidance_scale, guidance_rescale 228 | ) 229 | 230 | # Update latent 231 | latent = self.scheduler.step( 232 | noise_prediction, timestep, latent, return_dict=False 233 | )[0] 234 | 235 | if return_latent_history: 236 | latent_history.append(latent) 237 | 238 | return torch.stack(latent_history) if return_latent_history else latent 239 | 240 | def interpolate_embedding( 241 | self, 242 | embedding: torch.FloatTensor, 243 | interpolation_steps: Union[int, List[int]], 244 | interpolation_type: str, 245 | ) -> torch.FloatTensor: 246 | """ 247 | Interpolate based on interpolation type. 248 | 249 | Parameters 250 | ---------- 251 | embedding: torch.FloatTensor 252 | Embedding of text prompt. 253 | interpolation_steps: Union[int, List[int]] 254 | Number of interpolation steps to run. 255 | embedding_interpolation_type: str 256 | Type of interpolation to run. One of ["lerp", "slerp"]. 257 | 258 | Returns 259 | ------- 260 | torch.FloatTensor 261 | Interpolated embedding. 262 | """ 263 | 264 | if interpolation_type == "lerp": 265 | interpolation_fn = lerp 266 | elif interpolation_type == "slerp": 267 | interpolation_fn = slerp 268 | else: 269 | raise ValueError( 270 | f"embedding_interpolation_type must be one of ['lerp', 'slerp'], got {interpolation_type}." 
271 | ) 272 | 273 | # Split embedding into unconditional and text embeddings 274 | unconditional_embedding, text_embedding = embedding.chunk(2) 275 | steps = ( 276 | torch.linspace(0, 1, interpolation_steps, dtype=embedding.dtype) 277 | .cpu() 278 | .numpy() 279 | ) 280 | steps = np.expand_dims(steps, axis=tuple(range(1, text_embedding.ndim))) 281 | interpolations = [] 282 | 283 | # Interpolate between text embeddings 284 | # TODO: Think of a better way of doing this 285 | # See if it can be done parallelly instead 286 | for i in range(text_embedding.shape[0] - 1): 287 | interpolations.append( 288 | interpolation_fn( 289 | text_embedding[i], text_embedding[i + 1], steps 290 | ).squeeze(dim=1) 291 | ) 292 | interpolations = torch.cat(interpolations) 293 | 294 | # TODO: Think of a better way of doing this 295 | # It can be done because all unconditional embeddings are the same 296 | single_unconditional_embedding = unconditional_embedding[0].unsqueeze(dim=0) 297 | unconditional_embedding = single_unconditional_embedding.repeat( 298 | interpolations.shape[0], 1, 1 299 | ) 300 | interpolations = torch.cat([unconditional_embedding, interpolations]) 301 | 302 | return interpolations 303 | 304 | def interpolate_latent( 305 | self, 306 | latent: torch.FloatTensor, 307 | interpolation_steps: Union[int, List[int]], 308 | interpolation_type: str, 309 | ) -> torch.FloatTensor: 310 | """ 311 | Interpolate latent based on interpolation type. 312 | 313 | Parameters 314 | ---------- 315 | latent: torch.FloatTensor 316 | Latent to interpolate. 317 | interpolation_steps: Union[int, List[int]] 318 | Number of interpolation steps to run. 319 | latent_interpolation_type: str 320 | Type of interpolation to run. One of ["lerp", "slerp"]. 321 | 322 | Returns 323 | ------- 324 | torch.FloatTensor 325 | Interpolated latent. 326 | """ 327 | 328 | if interpolation_type == "lerp": 329 | interpolation_fn = lerp 330 | elif interpolation_type == "slerp": 331 | interpolation_fn = slerp 332 | 333 | steps = ( 334 | torch.linspace(0, 1, interpolation_steps, dtype=latent.dtype).cpu().numpy() 335 | ) 336 | steps = np.expand_dims(steps, axis=tuple(range(1, latent.ndim))) 337 | interpolations = [] 338 | 339 | # Interpolate between latents 340 | # TODO: Think of a better way of doing this 341 | # See if it can be done parallelly instead 342 | for i in range(latent.shape[0] - 1): 343 | interpolations.append( 344 | interpolation_fn(latent[i], latent[i + 1], steps).squeeze(dim=1) 345 | ) 346 | 347 | return torch.cat(interpolations) 348 | 349 | @torch.no_grad() 350 | def __call__( 351 | self, 352 | config: LatentWalkConfig, 353 | ) -> OutputType: 354 | """ 355 | Run inference by conditioning on text prompt starting from provided latent tensor. 356 | 357 | Parameters 358 | ---------- 359 | config: LatentWalkConfig 360 | Configuration for running inference using LatentWalkDiffusion. 361 | 362 | Returns 363 | ------- 364 | OutputType 365 | Generated output based on output_type. 
366 | """ 367 | 368 | prompt = config.prompt 369 | latent = config.latent 370 | strength = config.strength 371 | num_inference_steps = config.num_inference_steps 372 | guidance_scale = config.guidance_scale 373 | guidance_rescale = config.guidance_rescale 374 | negative_prompt = config.negative_prompt 375 | output_type = config.output_type 376 | return_latent_history = config.return_latent_history 377 | 378 | # Validate input 379 | self.validate_input( 380 | prompt=prompt, 381 | negative_prompt=negative_prompt, 382 | strength=strength, 383 | ) 384 | 385 | # Generate embedding to condition on prompt and uncondition on negative prompt 386 | embedding = self.prompt_to_embedding( 387 | prompt=prompt, 388 | negative_prompt=negative_prompt, 389 | ) 390 | 391 | # Modify latent 392 | latent = self.modify_latent(latent, strength) 393 | 394 | # Run inference 395 | latent = self.embedding_to_latent( 396 | embedding=embedding, 397 | num_inference_steps=num_inference_steps, 398 | guidance_scale=guidance_scale, 399 | guidance_rescale=guidance_rescale, 400 | latent=latent, 401 | return_latent_history=return_latent_history, 402 | ) 403 | 404 | return self.resolve_output( 405 | latent=latent, 406 | output_type=output_type, 407 | return_latent_history=return_latent_history, 408 | ) 409 | 410 | generate = __call__ 411 | 412 | @torch.no_grad() 413 | def interpolate( 414 | self, 415 | config: LatentWalkInterpolateConfig, 416 | ) -> OutputType: 417 | """ 418 | Run inference by conditioning on text prompts and interpolating between them. 419 | 420 | Parameters 421 | ---------- 422 | config: LatentWalkInterpolateConfig 423 | Configuration for running interpolation using LatentWalkDiffusion. 424 | 425 | Returns 426 | ------- 427 | OutputType 428 | Generated output based on output_type. 
429 | """ 430 | 431 | prompt = config.prompt 432 | latent = config.latent 433 | image_height = config.image_height 434 | image_width = config.image_width 435 | num_inference_steps = config.num_inference_steps 436 | interpolation_steps = config.interpolation_steps 437 | guidance_scale = config.guidance_scale 438 | guidance_rescale = config.guidance_rescale 439 | negative_prompt = config.negative_prompt 440 | output_type = config.output_type 441 | return_latent_history = config.return_latent_history 442 | embedding_interpolation_type = config.embedding_interpolation_type 443 | latent_interpolation_type = config.latent_interpolation_type 444 | 445 | # Validate input 446 | self.validate_input( 447 | prompt=prompt, 448 | negative_prompt=negative_prompt, 449 | image_height=image_height, 450 | image_width=image_width, 451 | ) 452 | 453 | # There should be atleast 2 prompts to run interpolation 454 | if not isinstance(prompt, list): 455 | raise ValueError(f"prompt must be a list of strings, not {type(prompt)}") 456 | if len(prompt) < 2: 457 | raise ValueError( 458 | f"prompt must be a list of at least 2 strings, not {len(prompt)}" 459 | ) 460 | if isinstance(interpolation_steps, int): 461 | pass 462 | # interpolation_steps = [interpolation_steps] * (len(prompt) - 1) 463 | elif isinstance(interpolation_steps, list): 464 | if len(interpolation_steps) != len(prompt) - 1: 465 | raise ValueError( 466 | f"interpolation_steps must be a list of length len(prompt) - 1, not {len(interpolation_steps)}" 467 | ) 468 | raise NotImplementedError( 469 | "interpolation_steps as a list is not yet implemented" 470 | ) 471 | else: 472 | raise ValueError( 473 | f"interpolation_steps must be an int or list, not {type(interpolation_steps)}" 474 | ) 475 | 476 | if latent is None: 477 | shape = ( 478 | len(prompt), 479 | self.unet.config.in_channels, 480 | image_height // self.vae_scale_factor, 481 | image_width // self.vae_scale_factor, 482 | ) 483 | latent = self.random_tensor(shape) 484 | elif len(prompt) != latent.shape[0]: 485 | raise ValueError( 486 | f"prompt and latent must be of the same length, not {len(prompt)} and {latent.shape[0]}" 487 | ) 488 | 489 | # Generate embedding to condition on prompt and uncondition on negative prompt 490 | embedding = self.prompt_to_embedding( 491 | prompt=prompt, 492 | negative_prompt=negative_prompt, 493 | ) 494 | 495 | # Interpolate between embeddings 496 | interpolated_embedding = self.interpolate_embedding( 497 | embedding=embedding, 498 | interpolation_steps=interpolation_steps, 499 | interpolation_type=embedding_interpolation_type, 500 | ) 501 | 502 | # Interpolate between latents 503 | interpolated_latent = self.interpolate_latent( 504 | latent=latent, 505 | interpolation_steps=interpolation_steps, 506 | interpolation_type=latent_interpolation_type, 507 | ) 508 | 509 | # Run inference 510 | latent = self.embedding_to_latent( 511 | embedding=interpolated_embedding, 512 | num_inference_steps=num_inference_steps, 513 | guidance_scale=guidance_scale, 514 | guidance_rescale=guidance_rescale, 515 | latent=interpolated_latent, 516 | return_latent_history=return_latent_history, 517 | ) 518 | 519 | return self.resolve_output( 520 | latent=latent, 521 | output_type=output_type, 522 | return_latent_history=return_latent_history, 523 | ) 524 | -------------------------------------------------------------------------------- /stablefused/diffusion/base_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 
| 4 | from PIL import Image 5 | from abc import ABC, abstractmethod 6 | from diffusers import ( 7 | AutoencoderKL, 8 | DiffusionPipeline, 9 | ) 10 | from tqdm.auto import tqdm 11 | from transformers import CLIPTextModel, CLIPTokenizer 12 | from typing import Any, List, Optional, Tuple, Union 13 | 14 | from stablefused.typing import ( 15 | PromptType, 16 | OutputType, 17 | Scheduler, 18 | SchedulerType, 19 | UNetType, 20 | ) 21 | from stablefused.utils import ( 22 | denormalize, 23 | load_model_from_cache, 24 | normalize, 25 | numpy_to_pil, 26 | numpy_to_pt, 27 | pil_to_numpy, 28 | pt_to_numpy, 29 | resolve_scheduler, 30 | save_model_to_cache, 31 | ) 32 | 33 | 34 | class BaseDiffusion(ABC): 35 | def __init__( 36 | self, 37 | model_id: str = None, 38 | tokenizer: CLIPTokenizer = None, 39 | text_encoder: CLIPTextModel = None, 40 | vae: AutoencoderKL = None, 41 | unet: UNetType = None, 42 | scheduler: SchedulerType = None, 43 | torch_dtype: torch.dtype = torch.float32, 44 | device="cuda", 45 | use_cache=True, 46 | *args, 47 | **kwargs, 48 | ) -> None: 49 | self.device: str = device 50 | self.torch_dtype: torch.dtype = torch_dtype 51 | self.model_id: str = model_id 52 | 53 | self.tokenizer: CLIPTokenizer 54 | self.text_encoder: CLIPTextModel 55 | self.vae: AutoencoderKL 56 | self.unet: UNetType 57 | self.scheduler: SchedulerType 58 | self.vae_scale_factor: int 59 | 60 | if model_id is None: 61 | if ( 62 | tokenizer is None 63 | or text_encoder is None 64 | or vae is None 65 | or unet is None 66 | or scheduler is None 67 | ): 68 | raise ValueError( 69 | "Either (`model_id`) or (`tokenizer`, `text_encoder`, `vae`, `unet`, `scheduler`) must be provided." 70 | ) 71 | 72 | self.tokenizer = tokenizer 73 | self.text_encoder = text_encoder 74 | self.vae = vae 75 | self.unet = unet 76 | self.scheduler = scheduler 77 | else: 78 | model = DiffusionPipeline.from_pretrained( 79 | model_id, torch_dtype=torch_dtype, *args, **kwargs 80 | ) 81 | self.tokenizer = model.tokenizer 82 | self.text_encoder = model.text_encoder 83 | self.vae = model.vae 84 | self.unet = model.unet 85 | self.scheduler = model.scheduler 86 | 87 | self.to(self.device) 88 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 89 | 90 | if use_cache and model_id is not None: 91 | model = load_model_from_cache(model_id, None) 92 | if model is None: 93 | save_model_to_cache(self) 94 | else: 95 | print("here") 96 | self.share_components_with(model) 97 | 98 | def to(self, device: str) -> "BaseDiffusion": 99 | """ 100 | Move model to specified compute device. 101 | 102 | Parameters 103 | ---------- 104 | device: str 105 | The device to move the model to. Must be one of `cuda` or `cpu`. 106 | """ 107 | self.device = device 108 | self.text_encoder = self.text_encoder.to(device) 109 | self.vae = self.vae.to(device) 110 | self.unet = self.unet.to(device) 111 | return self 112 | 113 | def share_components_with(self, model: "BaseDiffusion") -> None: 114 | """ 115 | Share components with another model. This allows for sharing of the 116 | different internal components of the model, such as the text encoder, 117 | VAE, and UNet. This is useful for reducing memory usage when using 118 | multiple diffusion pipelines with the same checkpoint at the same time. 119 | 120 | Parameters 121 | ---------- 122 | model: BaseDiffusion 123 | The model to share components with. 
124 | """ 125 | self.device = model.device 126 | self.torch_dtype = model.torch_dtype 127 | self.tokenizer = model.tokenizer 128 | self.text_encoder = model.text_encoder 129 | self.vae = model.vae 130 | self.unet = model.unet 131 | self.scheduler = model.scheduler 132 | self.vae_scale_factor = model.vae_scale_factor 133 | 134 | def set_scheduler(self, scheduler: Scheduler) -> None: 135 | """ 136 | Set the scheduler for the diffusion pipeline. 137 | 138 | Parameters 139 | ---------- 140 | scheduler: SchedulerType 141 | The scheduler to use for the diffusion pipeline. 142 | """ 143 | self.scheduler = resolve_scheduler(scheduler, self.scheduler.config) 144 | 145 | def enable_attention_slicing(self, slice_size: Optional[int] = -1) -> None: 146 | """ 147 | Enable attention slicing. By default, the attention head is sliced in half. 148 | This is a good tradeoff between memory and performance. 149 | 150 | Parameters 151 | ---------- 152 | slice_size: int 153 | The size of the attention slice. If -1, the attention head is sliced in 154 | half. If None, attention slicing is disabled. 155 | """ 156 | if slice_size == -1: 157 | slice_size = self.unet.config.attention_head_dim // 2 158 | self.unet.set_attention_slice(slice_size) 159 | 160 | def disable_attention_slicing(self) -> None: 161 | """Disable attention slicing.""" 162 | self.unet.set_attention_slice(None) 163 | 164 | def enable_slicing(self) -> None: 165 | """ 166 | Allow tensor slicing for vae decode step. This will cause the vae to split 167 | the input tensor to compute decoding in multiple steps. This will save 168 | memory and allow for larger batch sizes, but will affect performance slightly. 169 | """ 170 | self.vae.enable_slicing() 171 | 172 | def disable_slicing(self) -> None: 173 | """Disable tensor slicing for vae decode step.""" 174 | self.vae.disable_slicing() 175 | 176 | def enable_tiling(self) -> None: 177 | """ 178 | Allow tensor tiling for vae. This will cause the vae to split the input tensor 179 | into tiles to compute encoding/decoding in several steps. This will save a large 180 | amount of memory and allow processing larger images, but will affect performance. 181 | """ 182 | self.vae.enable_tiling() 183 | 184 | def disable_tiling(self) -> None: 185 | """Disable tensor tiling for vae.""" 186 | self.vae.disable_tiling() 187 | 188 | @staticmethod 189 | def validate_input( 190 | prompt: PromptType = None, 191 | negative_prompt: PromptType = None, 192 | image_height: int = None, 193 | image_width: int = None, 194 | start_step: int = None, 195 | num_inference_steps: int = None, 196 | strength: float = None, 197 | ) -> None: 198 | """ 199 | Validate input parameters. 200 | TODO: This needs to be removed and improved. More checks need to be added. 201 | 202 | Parameters 203 | ---------- 204 | prompt: PromptType 205 | The prompt(s) to condition on. 206 | negative_prompt: PromptType 207 | The negative prompt(s) to condition on. 208 | image_height: int 209 | The height of the image to generate. 210 | image_width: int 211 | The width of the image to generate. 212 | start_step: int 213 | The step to start inference from. 214 | num_inference_steps: int 215 | The number of inference steps to perform. 216 | strength: float 217 | The strength of the noise mixing when performing LatentWalkDiffusion. 
218 | """ 219 | if image_height is not None and image_width is not None: 220 | if image_height % 8 != 0 or image_width % 8 != 0: 221 | raise ValueError( 222 | "`image_height` and `image_width` must a multiple of 8" 223 | ) 224 | if negative_prompt is not None: 225 | if type(prompt) is not type(negative_prompt): 226 | raise TypeError( 227 | "Type of `prompt` and `negative_prompt` must be the same" 228 | ) 229 | if isinstance(prompt, list) and len(prompt) != len(negative_prompt): 230 | raise ValueError( 231 | "Length of `prompt` list and `negative_prompt` list should match" 232 | ) 233 | if start_step is not None: 234 | if num_inference_steps is None: 235 | raise ValueError( 236 | "`num_inference_steps` must be provided if `start_step` is provided" 237 | ) 238 | if start_step < 0 or start_step >= num_inference_steps: 239 | raise ValueError( 240 | "`start_step` must be in the range [0, `num_inference_steps` - 1]" 241 | ) 242 | if strength is not None: 243 | if strength < 0 or strength > 1: 244 | raise ValueError("`strength` must be in the range [0.0, 1.0]") 245 | 246 | @abstractmethod 247 | def embedding_to_latent(self, *args: Any, **kwargs: Any) -> Any: 248 | """ 249 | Abstract method for converting an embedding to a latent vector. This method 250 | must be implemented by all subclasses. 251 | """ 252 | pass 253 | 254 | @abstractmethod 255 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 256 | """ 257 | Abstract method for performing inference. This method must be implemented 258 | by all subclasses. 259 | """ 260 | pass 261 | 262 | def random_tensor(self, shape: Union[List[int], Tuple[int]]) -> torch.FloatTensor: 263 | """ 264 | Generate a random tensor of the specified shape. 265 | 266 | Parameters 267 | ---------- 268 | shape: List[int] or Tuple[int] 269 | The shape of the random tensor to generate. 270 | 271 | Returns 272 | ------- 273 | torch.FloatTensor 274 | A random tensor of the specified shape on the same device and dtype 275 | as model. 276 | """ 277 | rand_tensor = torch.randn(shape, device=self.device, dtype=self.torch_dtype) 278 | return rand_tensor 279 | 280 | def prompt_to_embedding( 281 | self, 282 | prompt: PromptType, 283 | negative_prompt: Optional[PromptType] = None, 284 | ) -> torch.FloatTensor: 285 | """ 286 | Convert a prompt or a list of prompts into a text embedding. 287 | 288 | Parameters 289 | ---------- 290 | prompt: PromptType 291 | The prompt or a list of prompts to convert into an embedding. Used 292 | for conditioning. 293 | negative_prompt: Optional[PromptType] 294 | A negative prompt or a list of negative prompts, by default None. 295 | Use for unconditioning. If not provided, an empty string ('') will 296 | be used to generate the unconditional embeddings. 297 | 298 | Returns 299 | ------- 300 | torch.FloatTensor 301 | A text embedding generated from the given prompt(s) and, if provided, 302 | the negative prompt(s). 
303 | """ 304 | 305 | if negative_prompt is not None and type(negative_prompt) is not type(prompt): 306 | raise TypeError( 307 | f"`negative_prompt` must have the same type as `prompt` ({type(prompt)}), but found {type(negative_prompt)}" 308 | ) 309 | 310 | if isinstance(prompt, str): 311 | batch_size = 1 312 | prompt = [prompt] 313 | if negative_prompt is not None: 314 | negative_prompt = [negative_prompt] 315 | elif isinstance(prompt, list): 316 | batch_size = len(prompt) 317 | else: 318 | raise TypeError("`prompt` must be a string or a list of strings") 319 | 320 | # Tokenize the prompt(s) 321 | text_input = self.tokenizer( 322 | prompt, 323 | padding="max_length", 324 | max_length=self.tokenizer.model_max_length, 325 | truncation=True, 326 | return_tensors="pt", 327 | ) 328 | 329 | # Enable use of attention_mask if the text_encoder supports it 330 | if ( 331 | hasattr(self.text_encoder.config, "use_attention_mask") 332 | and self.text_encoder.config.use_attention_mask 333 | ): 334 | attention_mask = text_input.attention_mask.to(self.device) 335 | else: 336 | attention_mask = None 337 | 338 | # Generate text embedding 339 | text_embedding = self.text_encoder( 340 | text_input.input_ids.to(self.device), attention_mask=attention_mask 341 | )[0] 342 | 343 | # Unconditioning input is an empty string if negative_prompt is not provided 344 | if negative_prompt is None: 345 | unconditioning_input = [""] * batch_size 346 | else: 347 | unconditioning_input = negative_prompt 348 | 349 | # Tokenize the unconditioning input 350 | unconditioning_input = self.tokenizer( 351 | unconditioning_input, 352 | padding="max_length", 353 | max_length=self.tokenizer.model_max_length, 354 | truncation=True, 355 | return_tensors="pt", 356 | ) 357 | 358 | # Generate unconditional embedding 359 | unconditional_embedding = self.text_encoder( 360 | unconditioning_input.input_ids.to(self.device), 361 | attention_mask=attention_mask, 362 | )[0] 363 | 364 | # Concatenate unconditional and conditional embeddings 365 | embedding = torch.cat([unconditional_embedding, text_embedding]) 366 | return embedding 367 | 368 | def classifier_free_guidance( 369 | self, 370 | noise_prediction: torch.FloatTensor, 371 | guidance_scale: float, 372 | guidance_rescale: float, 373 | ) -> torch.FloatTensor: 374 | """ 375 | Apply classifier-free guidance to noise prediction. 376 | 377 | Parameters 378 | ---------- 379 | noise_prediction: torch.FloatTensor 380 | The noise prediction tensor to which guidance will be applied. 381 | guidance_scale: float 382 | The scale factor for applying guidance to the noise prediction. 383 | guidance_rescale: float 384 | The rescale factor for adjusting the noise prediction based on 385 | guidance. Based on findings in Section 3.4 of [Common Diffusion 386 | Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 387 | 388 | Returns 389 | ------- 390 | torch.FloatTensor 391 | The noise prediction tensor after applying classifier-free guidance. 
392 | """ 393 | 394 | # Perform guidance 395 | noise_unconditional, noise_prompt = noise_prediction.chunk(2) 396 | noise_prediction = noise_unconditional + guidance_scale * ( 397 | noise_prompt - noise_unconditional 398 | ) 399 | 400 | # Skip computing std if guidance_rescale is 0 401 | if guidance_rescale > 0: 402 | std_prompt = noise_prompt.std( 403 | dim=list(range(1, noise_prompt.ndim)), keepdim=True 404 | ) 405 | std_prediction = noise_prediction.std( 406 | dim=list(range(1, noise_prediction.ndim)), keepdim=True 407 | ) 408 | noise_prediction_rescaled = noise_prediction * (std_prompt / std_prediction) 409 | noise_prediction = ( 410 | noise_prediction * (1 - guidance_rescale) 411 | + noise_prediction_rescaled * guidance_rescale 412 | ) 413 | 414 | return noise_prediction 415 | 416 | def latent_to_image( 417 | self, latent: torch.FloatTensor, output_type: str 418 | ) -> OutputType: 419 | """ 420 | Convert a latent tensor to an image in the specified output format. 421 | 422 | Parameters 423 | ---------- 424 | latent: torch.FloatTensor 425 | The latent tensor to convert into an image. 426 | output_type: str 427 | The desired output format for the image. Should be one of [`pt`, `np`, `pil`]. 428 | 429 | Returns 430 | ------- 431 | OutputType 432 | An image in the specified output format. 433 | """ 434 | if output_type not in ["pt", "np", "pil"]: 435 | raise ValueError("`output_type` must be one of [`pt`, `np`, `pil`]") 436 | 437 | image = self.vae.decode( 438 | latent / self.vae.config.scaling_factor, return_dict=False 439 | )[0] 440 | image = denormalize(image) 441 | 442 | if output_type == "pt": 443 | return image 444 | 445 | image = pt_to_numpy(image) 446 | 447 | if output_type == "np": 448 | return image 449 | 450 | image = numpy_to_pil(image) 451 | return image 452 | 453 | def image_to_latent( 454 | self, 455 | image: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor], 456 | ) -> torch.FloatTensor: 457 | """ 458 | Convert an image or a list of images into a latent tensor. 459 | 460 | Parameters 461 | ---------- 462 | image: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor] 463 | The input image(s) to convert into a latent tensor. Supported types are 464 | `PIL.Image.Image`, `np.ndarray`, and `torch.Tensor`. 465 | 466 | Returns 467 | ------- 468 | torch.FloatTensor 469 | A latent tensor representing the input image(s). 470 | """ 471 | if ( 472 | not isinstance(image, Image.Image) 473 | and not isinstance(image, list) 474 | and not isinstance(image, np.ndarray) 475 | and not isinstance(image, torch.Tensor) 476 | ): 477 | raise TypeError( 478 | "`image` type must be one of (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`). Other types are not supported yet" 479 | ) 480 | if isinstance(image, Image.Image): 481 | image: List[Image.Image] = [image] 482 | 483 | if isinstance(image[0], Image.Image): 484 | image: np.ndarray = pil_to_numpy(image) 485 | 486 | if isinstance(image[0], np.ndarray): 487 | image: torch.FloatTensor = numpy_to_pt(image) 488 | 489 | image = image.to(self.device) 490 | latent = ( 491 | self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor 492 | ) 493 | 494 | return latent 495 | 496 | def resolve_output( 497 | self, 498 | latent: torch.FloatTensor, 499 | output_type: str, 500 | return_latent_history: bool, 501 | ) -> Union[OutputType, List[OutputType]]: 502 | """ 503 | Resolve the output from the latent based on the provided output options. 
504 | 505 | Parameters 506 | ---------- 507 | latent: torch.FloatTensor 508 | The latent tensor representing the content to be resolved. 509 | output_type: str 510 | The desired output format. Should be one of [`latent`, `pt`, `np`, `pil`]. 511 | return_latent_history: bool 512 | If True, it means that the input latent tensor contains a tensor of latent 513 | tensor for each inference step. This requires decoding each latent tensor 514 | and returning a list of images. If False, decoding occurs directly. 515 | 516 | Returns 517 | ------- 518 | Union[OutputType, List[OutputType]] 519 | The resolved output based on the provided latent vector and options. 520 | """ 521 | if output_type not in ["latent", "pt", "np", "pil"]: 522 | raise ValueError( 523 | "`output_type` must be one of [`latent`, `pt`, `np`, `pil`]" 524 | ) 525 | 526 | if output_type == "latent": 527 | return latent 528 | 529 | if return_latent_history: 530 | # Transpose latent tensor from [num_steps, batch_size, *latent_dim] to 531 | # [batch_size, num_steps, *latent_dim]. 532 | # This is done so that the history of latent vectors for each prompt 533 | # is returned as a row instead of a column. It is what the user would 534 | # intuitively expect. 535 | latent = torch.transpose(latent, 0, 1) 536 | image = [ 537 | self.latent_to_image(l, output_type) 538 | for _, l in list(enumerate(tqdm(latent))) 539 | ] 540 | 541 | if output_type == "pt": 542 | image = torch.stack(image) 543 | elif output_type == "np": 544 | image = np.stack(image) 545 | else: 546 | # output type is "pil" so we can just return as a python list 547 | pass 548 | else: 549 | image = self.latent_to_image(latent, output_type) 550 | 551 | return image 552 | -------------------------------------------------------------------------------- /docs/stablefused/typing.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | stablefused.typing API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 45 |
46 |
47 |

48 | stablefused.typing

49 | 50 | 51 | 52 | 53 | 54 | 55 |
 1from .enums import (
 56 |  2    InpaintWalkType,
 57 |  3    Scheduler,
 58 |  4)
 59 |  5
 60 |  6from .type_hints import (
 61 |  7    ImageType,
 62 |  8    OutputType,
 63 |  9    PromptType,
 64 | 10    SchedulerType,
 65 | 11    UNetType,
 66 | 12)
 67 | 
68 | 69 | 70 |
71 |
72 | 254 | --------------------------------------------------------------------------------