├── docs ├── index.html └── stablefused │ └── typing.html ├── stablefused ├── apps │ └── storybook │ │ ├── __init__.py │ │ ├── author_api.py │ │ ├── speaker_api.py │ │ ├── config │ │ └── default_1_shot.json │ │ └── storybook.py ├── typing │ ├── __init__.py │ ├── enums.py │ └── type_hints.py ├── diffusion │ ├── __init__.py │ ├── image_to_image_diffusion.py │ ├── text_to_image_diffusion.py │ ├── text_to_video_diffusion.py │ ├── latent_walk_diffusion.py │ └── base_diffusion.py ├── utils │ ├── __init__.py │ ├── model_cache.py │ ├── import_utils.py │ ├── diffusion_utils.py │ └── image_utils.py └── __init__.py ├── .gitignore ├── LICENSE ├── .github └── workflows │ └── python-publish.yml ├── setup.py ├── tests ├── diffusion │ ├── test_text_to_image_diffusion.py │ ├── test_image_to_image_diffusion.py │ └── test_latent_walk_diffusion.py └── utils │ ├── test_diffusion_utils.py │ └── test_image_utils.py ├── examples ├── text_to_video_diffusion.ipynb ├── text_to_image_diffusion.ipynb ├── effect_of_guidance_scale.ipynb └── latent_walk_diffusion.ipynb └── README.md /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/__init__.py: -------------------------------------------------------------------------------- 1 | from .author_api import StoryBookAuthorBase, G4FStoryBookAuthor 2 | from .speaker_api import StoryBookSpeakerBase, gTTSStoryBookSpeaker 3 | from .storybook import StoryBookConfig, StoryBook 4 | -------------------------------------------------------------------------------- /stablefused/typing/__init__.py: -------------------------------------------------------------------------------- 1 | from .enums import ( 2 | InpaintWalkType, 3 | Scheduler, 4 | ) 5 | 6 | from .type_hints import ( 7 | ImageType, 8 | OutputType, 9 | PromptType, 10 | SchedulerType, 11 | UNetType, 12 | ) 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python directories 2 | .venv/ 3 | venv/ 4 | __pycache__/ 5 | .ipynb_checkpoints/ 6 | *.model 7 | *.egg-info 8 | dist/ 9 | 10 | # vscode 11 | .vscode/ 12 | 13 | # build 14 | build/ 15 | *.out 16 | *.s 17 | 18 | # tensorflow 19 | training_checkpoints/ 20 | 21 | # local test files 22 | test/ 23 | tmp/ 24 | 25 | # environment files 26 | .env 27 | 28 | # datasets 29 | *.csv 30 | 31 | # images and videos to not increase repository size 32 | resources/results/ 33 | -------------------------------------------------------------------------------- /stablefused/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_diffusion import BaseDiffusion 2 | from .image_to_image_diffusion import ImageToImageConfig, ImageToImageDiffusion 3 | from .latent_walk_diffusion import ( 4 | LatentWalkConfig, 5 | LatentWalkInterpolateConfig, 6 | LatentWalkDiffusion, 7 | ) 8 | from .text_to_image_diffusion import TextToImageConfig, TextToImageDiffusion 9 | from .text_to_video_diffusion import TextToVideoConfig, TextToVideoDiffusion 10 | from .inpaint_diffusion import InpaintConfig, InpaintWalkConfig, InpaintDiffusion 11 | -------------------------------------------------------------------------------- /stablefused/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_utils import ( 2 | lerp, 3 | resolve_scheduler, 4 | slerp, 5 | ) 6 | from .image_utils import ( 7 | denormalize, 8 | image_grid, 9 | normalize, 10 | numpy_to_pil, 11 | numpy_to_pt, 12 | pil_to_numpy, 13 | pil_to_video, 14 | pil_to_gif, 15 | pt_to_numpy, 16 | write_text_on_image, 17 | ) 18 | from .import_utils import ( 19 | LazyImporter, 20 | ) 21 | from .model_cache import ( 22 | save_model_to_cache, 23 | load_model_from_cache, 24 | ) 25 | -------------------------------------------------------------------------------- /stablefused/utils/model_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | 4 | class ModelCache: 5 | """ 6 | A cache for diffusion models. This class should not be instantiated by the user. 7 | You should use the load_model function instead. It is a mapping from model_id to 8 | diffusion model. This allows us to avoid loading the same model components multiple 9 | times. 
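    A rough usage sketch (illustrative only; it assumes the cached object exposes a
    ``model_id`` attribute, which is what ``set`` relies on):

        from stablefused import TextToImageDiffusion, load_model_from_cache, save_model_to_cache

        model = load_model_from_cache("runwayml/stable-diffusion-v1-5")
        if model is None:
            model = TextToImageDiffusion(model_id="runwayml/stable-diffusion-v1-5")
            save_model_to_cache(model)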
10 | """ 11 | 12 | def __init__(self) -> None: 13 | self.cache = dict() 14 | 15 | def get(self, model_id: str, default: Any = None) -> Any: 16 | if model_id not in self.cache.keys(): 17 | return default 18 | return self.cache[model_id] 19 | 20 | def set(self, model: Any) -> None: 21 | self.cache[model.model_id] = model 22 | 23 | 24 | _model_cache = ModelCache() 25 | 26 | 27 | load_model_from_cache = _model_cache.get 28 | save_model_to_cache = _model_cache.set 29 | -------------------------------------------------------------------------------- /stablefused/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | from types import ModuleType 4 | from typing import Optional 5 | 6 | 7 | class LazyImporter: 8 | """ 9 | Lazy importer for modules. 10 | """ 11 | 12 | def __init__(self, module_name: str) -> None: 13 | self.module_name = module_name 14 | self.module = None 15 | 16 | def import_module(self, import_error_message: Optional[str] = None) -> ModuleType: 17 | if self.module is None: 18 | try: 19 | self.module = importlib.import_module(self.module_name) 20 | except ModuleNotFoundError: 21 | if import_error_message is None: 22 | import_error_message = ( 23 | f"'{self.module_name}' is not installed. Please install it." 24 | ) 25 | raise ModuleNotFoundError(import_error_message) 26 | return self.module 27 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/author_api.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List 3 | 4 | from stablefused import LazyImporter 5 | 6 | 7 | class StoryBookAuthorBase(ABC): 8 | def __init__(self) -> None: 9 | pass 10 | 11 | @abstractmethod 12 | def __call__(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: 13 | pass 14 | 15 | 16 | class G4FStoryBookAuthor(StoryBookAuthorBase): 17 | def __init__(self, model_id: str = "gpt-3.5-turbo") -> None: 18 | super().__init__() 19 | self.model_id = model_id 20 | self.g4f = LazyImporter("g4f") 21 | 22 | def __call__(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: 23 | g4f = self.g4f.import_module( 24 | import_error_message="g4f is not installed. Please install it using `pip install g4f`." 25 | ) 26 | return g4f.ChatCompletion.create(model=self.model_id, messages=messages) 27 | -------------------------------------------------------------------------------- /stablefused/typing/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class InpaintWalkType(str, Enum): 5 | """ 6 | Enum for inpainting walk types. 
7 | """ 8 | 9 | UP = "up" 10 | DOWN = "down" 11 | LEFT = "left" 12 | RIGHT = "right" 13 | FORWARD = "forward" 14 | BACKWARD = "backward" 15 | 16 | 17 | class Scheduler(str, Enum): 18 | DEIS = "deis" 19 | DDIM = "ddim" 20 | DDPM = "ddpm" 21 | DPM2_KARRAS = "dpm2_karras" 22 | DPM2_KARRAS_ANCESTRAL = "dpm2_karras_ancestral" 23 | DPM_SDE = "dpm_sde" 24 | DPM_SDE_KARRAS = "dpm_sde_karras" 25 | DPM_MULTISTEP = "dpm_multistep" 26 | DPM_MULTISTEP_KARRAS = "dpm_multistep_karras" 27 | DPM_SINGLESTEP = "dpm_singlestep" 28 | DPM_SINGLESTEP_KARRAS = "dpm_singlestep_karras" 29 | EULER = "euler" 30 | EULER_ANCESTRAL = "euler_ancestral" 31 | HEUN = "heun" 32 | LINEAR_MULTISTEP = "linear_multistep" 33 | PNDM = "pndm" 34 | UNIPC = "unipc" 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Aryan V S 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /stablefused/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
include:: ../README.md 3 | """ 4 | 5 | from .diffusion import ( 6 | BaseDiffusion, 7 | ImageToImageConfig, 8 | ImageToImageDiffusion, 9 | InpaintConfig, 10 | InpaintWalkConfig, 11 | InpaintDiffusion, 12 | LatentWalkConfig, 13 | LatentWalkInterpolateConfig, 14 | LatentWalkDiffusion, 15 | TextToImageConfig, 16 | TextToImageDiffusion, 17 | TextToVideoConfig, 18 | TextToVideoDiffusion, 19 | ) 20 | 21 | from .typing import ( 22 | InpaintWalkType, 23 | Scheduler, 24 | ImageType, 25 | OutputType, 26 | PromptType, 27 | SchedulerType, 28 | UNetType, 29 | ) 30 | 31 | from .utils import ( 32 | denormalize, 33 | image_grid, 34 | lerp, 35 | load_model_from_cache, 36 | normalize, 37 | numpy_to_pil, 38 | numpy_to_pt, 39 | pil_to_numpy, 40 | pil_to_video, 41 | pil_to_gif, 42 | pt_to_numpy, 43 | resolve_scheduler, 44 | save_model_to_cache, 45 | slerp, 46 | write_text_on_image, 47 | LazyImporter, 48 | ) 49 | 50 | from .apps.storybook import ( 51 | StoryBookAuthorBase, 52 | G4FStoryBookAuthor, 53 | StoryBookConfig, 54 | StoryBook, 55 | StoryBookSpeakerBase, 56 | gTTSStoryBookSpeaker, 57 | ) 58 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /stablefused/typing/type_hints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | from diffusers.models import UNet2DConditionModel, UNet3DConditionModel 6 | from diffusers.schedulers import ( 7 | DEISMultistepScheduler, 8 | DDIMScheduler, 9 | DDPMScheduler, 10 | DPMSolverSDEScheduler, 11 | DPMSolverMultistepScheduler, 12 | DPMSolverSinglestepScheduler, 13 | EulerAncestralDiscreteScheduler, 14 | EulerDiscreteScheduler, 15 | HeunDiscreteScheduler, 16 | KDPM2DiscreteScheduler, 17 | KDPM2AncestralDiscreteScheduler, 18 | LMSDiscreteScheduler, 19 | PNDMScheduler, 20 | UniPCMultistepScheduler, 21 | ) 22 | from typing import List, Union 23 | 24 | 25 | ImageType = Union[torch.Tensor, np.ndarray, Image.Image, List[Image.Image]] 26 | 27 | OutputType = Union[torch.Tensor, np.ndarray, List[Image.Image]] 28 | 29 | PromptType = Union[str, List[str]] 30 | 31 | SchedulerType = Union[ 32 | DEISMultistepScheduler, 33 | 
DDIMScheduler, 34 | DDPMScheduler, 35 | DPMSolverSDEScheduler, 36 | DPMSolverMultistepScheduler, 37 | DPMSolverSinglestepScheduler, 38 | EulerAncestralDiscreteScheduler, 39 | EulerDiscreteScheduler, 40 | HeunDiscreteScheduler, 41 | KDPM2DiscreteScheduler, 42 | KDPM2AncestralDiscreteScheduler, 43 | LMSDiscreteScheduler, 44 | PNDMScheduler, 45 | UniPCMultistepScheduler, 46 | ] 47 | 48 | UNetType = Union[UNet2DConditionModel, UNet3DConditionModel] 49 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/speaker_api.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | 7 | class StoryBookSpeakerBase(ABC): 8 | def __init__(self) -> None: 9 | pass 10 | 11 | @abstractmethod 12 | def __call__(self, messages: List[str], *, yield_files: bool = True) -> List[str]: 13 | pass 14 | 15 | 16 | class gTTSStoryBookSpeaker(StoryBookSpeakerBase): 17 | def __init__(self, *, lang: str = "en", tld: str = "us") -> None: 18 | super().__init__() 19 | try: 20 | from gtts import gTTS 21 | except ImportError: 22 | raise ImportError( 23 | "gTTS is not installed. Please install it using `pip install gTTS`." 24 | ) 25 | self.gTTS = gTTS 26 | self.lang = lang 27 | self.tld = tld 28 | 29 | def __call__(self, messages: List[str], *, yield_files: bool = True) -> List[str]: 30 | for message in messages: 31 | tts = self.gTTS(message, lang=self.lang, tld=self.tld) 32 | 33 | if yield_files: 34 | with tempfile.NamedTemporaryFile(suffix=".wav") as f: 35 | tts.write_to_fp(f) 36 | yield f.name 37 | else: 38 | files = [] 39 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: 40 | tts.write_to_fp(f) 41 | files.append(f.name) 42 | return files 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as file: 4 | long_description = file.read() 5 | 6 | setup( 7 | name="stablefused", 8 | version="0.2.0", 9 | description="StableFused is a toy library to experiment with Stable Diffusion inspired by 🤗 diffusers and various other sources!", 10 | long_description=long_description, 11 | long_description_content_type="text/markdown", 12 | author="Aryan V S", 13 | author_email="contact.aryanvs+stablefused@gmail.com", 14 | url="https://github.com/a-r-r-o-w/stablefused/", 15 | python_requires=">=3.8.0", 16 | license="MIT", 17 | packages=find_packages(), 18 | install_requires=[ 19 | "accelerate==0.21.0", 20 | "dataclasses-json==0.6.1", 21 | "diffusers==0.19.3", 22 | "ftfy==6.1.1", 23 | "imageio==2.31.1", 24 | "imageio-ffmpeg==0.4.8", 25 | "torch==2.0.1", 26 | "transformers==4.31.0", 27 | "matplotlib==3.7.2", 28 | "moviepy==1.0.3", 29 | "numpy==1.25.2", 30 | "pillow==9.5.0", 31 | "scipy==1.11.1", 32 | ], 33 | extras_require={ 34 | "dev": [ 35 | "black==23.7.0", 36 | "pytest==7.4.0", 37 | "twine>=4.0.2", 38 | ], 39 | "extras": [ 40 | "g4f==0.1.5.6", 41 | "curl-cffi==0.5.7", 42 | "gtts==2.4.0", 43 | ], 44 | }, 45 | classifiers=[ 46 | "Development Status :: 1 - Planning", 47 | "Intended Audience :: Science/Research", 48 | "Intended Audience :: Developers", 49 | "Intended Audience :: Education", 50 | "Programming Language :: Python :: 3", 51 | "Programming Language :: Python :: 3.8", 52 | "Programming Language :: Python :: 3.9", 53 | 
"Programming Language :: Python :: 3.10", 54 | "Operating System :: Microsoft :: Windows", 55 | "Operating System :: Unix", 56 | "License :: OSI Approved :: MIT License", 57 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 58 | ], 59 | ) 60 | 61 | # Steps to publish: 62 | # 1. Update version in setup.py 63 | # 2. python setup.py sdist bdist_wheel 64 | # 3. Check if everything works with testpypi: 65 | # twine upload --repository testpypi dist/* 66 | # 4. Upload to pypi: 67 | # twine upload dist/* 68 | -------------------------------------------------------------------------------- /tests/diffusion/test_text_to_image_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from stablefused import TextToImageConfig, TextToImageDiffusion 6 | 7 | 8 | @pytest.fixture 9 | def model(): 10 | """ 11 | Fixture to initialize the TextToImageDiffusion model and set random seeds for reproducibility. 12 | 13 | Returns 14 | ------- 15 | TextToImageDiffusion 16 | The initialized TextToImageDiffusion model. 17 | """ 18 | seed = 1337 19 | model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" 20 | device = "cpu" 21 | 22 | torch.manual_seed(seed) 23 | np.random.seed(seed) 24 | 25 | model = TextToImageDiffusion(model_id=model_id, device=device) 26 | return model 27 | 28 | 29 | @pytest.fixture 30 | def config(): 31 | return { 32 | "prompt": "a photo of a cat", 33 | "num_inference_steps": 1, 34 | "image_dim": 32, 35 | } 36 | 37 | 38 | def test_text_to_image_diffusion(model: TextToImageDiffusion, config: dict) -> None: 39 | """ 40 | Test case to check if the TextToImageDiffusion is working correctly. 41 | 42 | Parameters 43 | ---------- 44 | 45 | Raises 46 | ------ 47 | AssertionError 48 | If the generated image is not of type np.ndarray. 49 | If the generated image does not have the expected shape. 50 | """ 51 | 52 | dim = config.get("image_dim") 53 | images = model( 54 | TextToImageConfig( 55 | prompt=config.get("prompt"), 56 | image_height=dim, 57 | image_width=dim, 58 | num_inference_steps=config.get("num_inference_steps"), 59 | output_type="np", 60 | ) 61 | ) 62 | 63 | assert type(images) is np.ndarray 64 | assert images.shape == (1, dim, dim, 3) 65 | 66 | 67 | def test_return_latent_history(model: TextToImageDiffusion, config: dict) -> None: 68 | """ 69 | Test case to check if the latent history is returned correctly. 70 | 71 | Raises 72 | ------ 73 | AssertionError 74 | If the generated image is not of type np.ndarray. 75 | If the generated image does not have the expected shape. 
76 | """ 77 | 78 | dim = config.get("image_dim") 79 | images = model( 80 | TextToImageConfig( 81 | prompt=config.get("prompt"), 82 | image_height=dim, 83 | image_width=dim, 84 | num_inference_steps=config.get("num_inference_steps"), 85 | output_type="pt", 86 | return_latent_history=True, 87 | ) 88 | ) 89 | 90 | assert type(images) is torch.Tensor 91 | assert images.shape == (1, config.get("num_inference_steps") + 1, 3, dim, dim) 92 | 93 | 94 | if __name__ == "__main__": 95 | pytest.main([__file__]) 96 | -------------------------------------------------------------------------------- /tests/diffusion/test_image_to_image_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from stablefused import ImageToImageConfig, ImageToImageDiffusion 6 | 7 | 8 | @pytest.fixture 9 | def model(): 10 | """ 11 | Fixture to initialize the ImageToImageDiffusion model and set random seeds for reproducibility. 12 | 13 | Returns 14 | ------- 15 | ImageToImageDiffusion 16 | The initialized ImageToImageDiffusion model. 17 | """ 18 | seed = 1337 19 | model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" 20 | device = "cpu" 21 | 22 | torch.manual_seed(seed) 23 | np.random.seed(seed) 24 | 25 | model = ImageToImageDiffusion(model_id=model_id, device=device) 26 | return model 27 | 28 | 29 | @pytest.fixture 30 | def config(): 31 | return { 32 | "prompt": "a photo of a cat", 33 | "num_inference_steps": 2, 34 | "start_step": 1, 35 | "image_dim": 32, 36 | } 37 | 38 | 39 | def test_image_to_image_diffusion(model: ImageToImageDiffusion, config: dict) -> None: 40 | """ 41 | Test case to check if the ImageToImageDiffusion is working correctly. 42 | 43 | Raises 44 | ------ 45 | AssertionError 46 | If the generated image is not of type np.ndarray. 47 | If the generated image does not have the expected shape. 48 | """ 49 | dim = config.get("image_dim") 50 | image = model.random_tensor((1, 3, dim, dim)) 51 | 52 | images = model( 53 | ImageToImageConfig( 54 | image=image, 55 | prompt=config.get("prompt"), 56 | num_inference_steps=config.get("num_inference_steps"), 57 | output_type="np", 58 | ) 59 | ) 60 | 61 | assert type(images) is np.ndarray 62 | assert images.shape == (1, dim, dim, 3) 63 | 64 | 65 | def test_return_latent_history(model: ImageToImageDiffusion, config: dict) -> None: 66 | """ 67 | Test case to check if latent history is returned correctly. 68 | 69 | Raises 70 | ------ 71 | AssertionError 72 | If the generated image is not of type np.ndarray. 73 | If the generated image does not have the expected shape. 
74 | """ 75 | dim = config.get("image_dim") 76 | image = model.random_tensor((1, 3, dim, dim)) 77 | history_size = config.get("num_inference_steps") + 1 - config.get("start_step") 78 | 79 | images = model( 80 | ImageToImageConfig( 81 | image=image, 82 | prompt=config.get("prompt"), 83 | num_inference_steps=config.get("num_inference_steps"), 84 | start_step=config.get("start_step"), 85 | output_type="pt", 86 | return_latent_history=True, 87 | ) 88 | ) 89 | 90 | assert type(images) is torch.Tensor 91 | assert images.shape == (1, history_size, 3, dim, dim) 92 | 93 | 94 | if __name__ == "__main__": 95 | pytest.main([__file__]) 96 | -------------------------------------------------------------------------------- /tests/utils/test_diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | 5 | from stablefused.utils import lerp, slerp 6 | 7 | 8 | test_cases_lerp = [ 9 | # Test cases for lerp with t as a float 10 | (np.array([1.0, 2.0]), np.array([4.0, 6.0]), 0.0, np.array([1.0, 2.0])), 11 | (np.array([1.0, 2.0]), np.array([4.0, 6.0]), 1.0, np.array([4.0, 6.0])), 12 | (np.array([1.0, 2.0]), np.array([4.0, 6.0]), 0.5, np.array([2.5, 4.0])), 13 | # Test cases for lerp with t as an np.ndarray 14 | ( 15 | np.array([1.0, 2.0]), 16 | np.array([4.0, 6.0]), 17 | np.array([0.0, 1.0]), 18 | np.array([[1.0, 2.0], [4.0, 6.0]]), 19 | ), 20 | ( 21 | np.array([1.0, 2.0]), 22 | np.array([4.0, 6.0]), 23 | np.array([0.5, 0.25]), 24 | np.array([[2.5, 4.0], [1.75, 3.0]]), 25 | ), 26 | # Test cases for lerp with t as a torch.Tensor 27 | ( 28 | np.array([1.0, 2.0]), 29 | np.array([4.0, 6.0]), 30 | torch.Tensor([0.0, 1.0]), 31 | np.array([[1.0, 2.0], [4.0, 6.0]]), 32 | ), 33 | ( 34 | np.array([1.0, 2.0]), 35 | np.array([4.0, 6.0]), 36 | torch.Tensor([0.5, 0.25]), 37 | np.array([[2.5, 4.0], [1.75, 3.0]]), 38 | ), 39 | ] 40 | 41 | test_cases_slerp = [ 42 | # Test cases for slerp with t as a float 43 | (np.array([1.0, 0.0]), np.array([0.0, 1.0]), 0.0, np.array([1.0, 0.0])), 44 | (np.array([1.0, 0.0]), np.array([0.0, 1.0]), 1.0, np.array([0.0, 1.0])), 45 | (np.array([1.0, 0.0]), np.array([0.0, 1.0]), 0.5, np.array([0.707107, 0.707107])), 46 | # Test cases for slerp with t as an array 47 | ( 48 | np.array([1.0, 0.0]), 49 | np.array([0.0, 1.0]), 50 | np.array([0.0, 1.0]), 51 | np.array([[1.0, 0.0], [0.0, 1.0]]), 52 | ), 53 | ( 54 | np.array([1.0, 0.0]), 55 | np.array([0.0, 1.0]), 56 | np.array([0.5, 0.25]), 57 | np.array([[0.707107, 0.707107], [0.923880, 0.382683]]), 58 | ), 59 | # Test cases for slerp with t as a torch.Tensor 60 | ( 61 | np.array([1.0, 0.0]), 62 | np.array([0.0, 1.0]), 63 | torch.Tensor([0.0, 1.0]), 64 | np.array([[1.0, 0.0], [0.0, 1.0]]), 65 | ), 66 | ( 67 | np.array([1.0, 0.0]), 68 | np.array([0.0, 1.0]), 69 | torch.Tensor([0.5, 0.25]), 70 | np.array([[0.707107, 0.707107], [0.923880, 0.382683]]), 71 | ), 72 | ] 73 | 74 | 75 | @pytest.mark.parametrize("v0, v1, t, expected", test_cases_lerp) 76 | def test_lerp(v0, v1, t, expected): 77 | result = lerp(v0, v1, t) 78 | np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-8) 79 | 80 | 81 | @pytest.mark.parametrize("v0, v1, t, expected", test_cases_slerp) 82 | def test_slerp(v0, v1, t, expected): 83 | v0_torch = torch.from_numpy(v0) 84 | v1_torch = torch.from_numpy(v1) 85 | 86 | result = slerp(v0, v1, t) 87 | result_torch = slerp(v0_torch, v1_torch, t).cpu().numpy() 88 | 89 | np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-8) 90 | 
torch.testing.assert_close(result_torch, expected, rtol=1e-5, atol=1e-8) 91 | 92 | 93 | if __name__ == "__main__": 94 | pytest.main() 95 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/config/default_1_shot.json: -------------------------------------------------------------------------------- 1 | { 2 | "negative_prompt": "(((deformed))), (((disfigured))), unrealistic, blur, boring background, mutation, mutated, malformed, censored, colorless, bad shadow, missing parts, cropped, watermark, username, text, signature, low quality, low resolution, unattractive", 3 | "attributes": "high quality, realistic, octane render", 4 | "image_height": 512, 5 | "image_width": 512, 6 | "num_inference_steps": 20, 7 | "guidance_scale": 7.5, 8 | "guidance_rescale": 0.7, 9 | "messages": [ 10 | { 11 | "role": "system", 12 | "content": "You are the greatest author to ever have lived and a storyteller with milleniums of experience. Your task is to generate concise and visually evocative text prompts for a text-to-image model. You have to ensure that the story conveys a clear and vivid visual concept, adhering to the context provided by the user. Execute this task with utmost diligence and creativity.\n\nYou have to follow these guidelines when producing your output:\n- The output must follow a JSON structure. It should be a list of dictionaries (the keys of the dictionaries should be `story` and `prompt`).\n- For a given dictionary in the list, the key `story` should contain part of the story as its value, while the key `prompt` should contain the prompt that will be used by the text-to-image model.\n-You MUST follow the above guidelines and cannot respond in any other way.\n- If the prompt provided by the user is vague, you need to think about something creative.\n\nYou have to follow these guidelines when producing the story:\n- The story must be interesting, coherent, understandable and enjoyable to read.\n- The story can mention names of characters.\n- The story can be of variable sizes per prompt.\n\nYou have to follow these guidelines when producing the prompt:\n- The prompt must not contain names. This is very important because it confuses text-to-image models.\n- The prompt must not contain pronouns like he/she/they/them, etc. These provide no context.\n- The prompt must emphasize more on the corresponding action happening in the story prompt.\n- The prompt must be short in length (at most 15 words)." 13 | }, 14 | { 15 | "role": "user", 16 | "content": "A dog is chasing a cat. The cat devises a clever plan to escape to space. Write a story in about 100-200 words." 
17 | }, 18 | { 19 | "role": "assistant", 20 | "content": "[\n {\n \"story\": \"The energetic dog sprinted through the garden, relentless in pursuit of its prey.\",\n \"prompt\": \"Energetic dog chasing cat in the garden\"\n },\n {\n \"story\": \"The cat, feeling the hot breath of the dog behind, hatched a daring escape plan.\",\n \"prompt\": \"Cat being chased by dog\"\n },\n {\n \"story\": \"With a determined glint in its eyes, the cat darted towards the makeshift spaceship.\",\n \"prompt\": \"Cat dashing towards spaceship\"\n },\n {\n \"story\": \"Inside the spaceship, Fluffy the cat prepared for an epic journey to the cosmos.\",\n \"prompt\": \"Cat preparing for epic cosmic journey, inside spaceship\"\n },\n {\n \"story\": \"As the spaceship's engines roared, it soared into the starlit expanse, leaving Earth behind.\",\n \"prompt\": \"Spaceship ascending into the starlit expanse\"\n },\n {\n \"story\": \"The dog barked at the empty sky, bewildered by the cat's unexpected interstellar escape.\",\n \"prompt\": \"Dog barking at empty sky bewildered\"\n }\n]" 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /tests/utils/test_image_utils.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import os 4 | import pytest 5 | import tempfile 6 | 7 | from PIL import Image 8 | from imageio.plugins.ffmpeg import FfmpegFormat 9 | from stablefused.utils import image_grid, pil_to_gif, pil_to_video 10 | 11 | np.random.seed(42) 12 | 13 | 14 | @pytest.fixture 15 | def image_config(): 16 | return { 17 | "width": 32, 18 | "height": 32, 19 | "channels": 3, 20 | } 21 | 22 | 23 | @pytest.fixture 24 | def num_images(): 25 | return 8 26 | 27 | 28 | def random_image(width, height, channels): 29 | random_image = np.random.randint(0, 256, (height, width, channels), dtype=np.uint8) 30 | return Image.fromarray(random_image) 31 | 32 | 33 | @pytest.fixture 34 | def random_images(image_config, num_images): 35 | image_list = [] 36 | for _ in range(num_images): 37 | image_list.append( 38 | random_image( 39 | width=image_config.get("width"), 40 | height=image_config.get("height"), 41 | channels=image_config.get("channels"), 42 | ) 43 | ) 44 | return image_list 45 | 46 | 47 | @pytest.fixture 48 | def temporary_gif_file(): 49 | with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as temp_file: 50 | temp_filename = temp_file.name 51 | yield temp_filename 52 | os.remove(temp_filename) 53 | 54 | 55 | @pytest.fixture(params=[".mp4", ".avi", ".mkv", ".mov", ".wmv"]) 56 | def temporary_video_file(request): 57 | with tempfile.NamedTemporaryFile(suffix=request.param, delete=False) as temp_file: 58 | temp_filename = temp_file.name 59 | yield temp_filename 60 | os.remove(temp_filename) 61 | 62 | 63 | def test_image_grid(random_images, num_images): 64 | """Test that image grid is created correctly.""" 65 | 66 | rows = 2 67 | cols = num_images // 2 68 | grid_image = image_grid(random_images, rows, cols) 69 | 70 | expected_width = random_images[0].width * cols 71 | expected_height = random_images[0].height * rows 72 | 73 | assert grid_image.width == expected_width 74 | assert grid_image.height == expected_height 75 | assert len(random_images) == rows * cols 76 | assert isinstance(grid_image, Image.Image) 77 | 78 | 79 | def test_pil_to_gif(random_images, temporary_gif_file): 80 | """Test that PIL images are converted to GIF correctly.""" 81 | 82 | pil_to_gif(random_images, temporary_gif_file, 
fps=1) 83 | 84 | assert os.path.isfile(temporary_gif_file) 85 | with Image.open(temporary_gif_file) as saved_gif: 86 | assert isinstance(saved_gif, Image.Image) 87 | assert saved_gif.is_animated 88 | assert saved_gif.n_frames == len(random_images) 89 | 90 | 91 | def test_pil_to_video(random_images, temporary_video_file): 92 | """Test that PIL images are converted to video correctly.""" 93 | pil_to_video(random_images, temporary_video_file, fps=1) 94 | 95 | assert os.path.isfile(temporary_video_file) 96 | try: 97 | video: FfmpegFormat.Reader = imageio.get_reader( 98 | temporary_video_file, format="ffmpeg" 99 | ) 100 | assert video.count_frames() == len(random_images) 101 | except Exception as e: 102 | pytest.fail(f"Failed to open video file: {e}") 103 | 104 | 105 | if __name__ == "__main__": 106 | pytest.main([__file__]) 107 | -------------------------------------------------------------------------------- /tests/diffusion/test_latent_walk_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from stablefused import ( 6 | LatentWalkConfig, 7 | LatentWalkInterpolateConfig, 8 | LatentWalkDiffusion, 9 | ) 10 | 11 | 12 | @pytest.fixture 13 | def model(): 14 | """ 15 | Fixture to initialize the LatentWalkDiffusion model and set random seeds for reproducibility. 16 | 17 | Returns 18 | ------- 19 | LatentWalkDiffusion 20 | The initialized LatentWalkDiffusion model. 21 | """ 22 | seed = 1337 23 | model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" 24 | device = "cpu" 25 | 26 | torch.manual_seed(seed) 27 | np.random.seed(seed) 28 | 29 | model = LatentWalkDiffusion(model_id=model_id, device=device) 30 | return model 31 | 32 | 33 | @pytest.fixture 34 | def config(): 35 | return { 36 | "prompt": "a photo of a cat", 37 | "num_inference_steps": 1, 38 | "image_dim": 32, 39 | } 40 | 41 | 42 | @pytest.fixture 43 | def config_interpolate(): 44 | return { 45 | "prompt": ["a photo of a cat", "a photo of a dog"], 46 | "num_inference_steps": 1, 47 | "interpolation_steps": 5, 48 | "image_dim": 32, 49 | } 50 | 51 | 52 | def test_latent_walk_diffusion(model: LatentWalkDiffusion, config: dict) -> None: 53 | """ 54 | Test case to check if the LatentWalkDiffusion is working correctly. 55 | 56 | Raises 57 | ------ 58 | AssertionError 59 | If the generated image is not of type np.ndarray. 60 | If the generated image does not have the expected shape. 61 | """ 62 | 63 | dim = config.get("image_dim") 64 | image = model.random_tensor((1, 3, dim, dim)) 65 | latent = model.image_to_latent(image) 66 | 67 | images = model( 68 | LatentWalkConfig( 69 | prompt=config.get("prompt"), 70 | latent=latent, 71 | num_inference_steps=config.get("num_inference_steps"), 72 | output_type="np", 73 | ) 74 | ) 75 | 76 | assert type(images) is np.ndarray 77 | assert images.shape == (1, dim, dim, 3) 78 | 79 | 80 | def test_interpolate(model: LatentWalkDiffusion, config_interpolate: dict) -> None: 81 | """ 82 | Test case to check if the LatentWalkDiffusion is working correctly. 83 | 84 | Raises 85 | ------ 86 | AssertionError 87 | If the generated image is not of type np.ndarray. 88 | If the generated image does not have the expected shape. 
89 | """ 90 | 91 | dim = config_interpolate.get("image_dim") 92 | num_prompts = len(config_interpolate.get("prompt")) 93 | image_count = config_interpolate.get("interpolation_steps") * (num_prompts - 1) 94 | image = model.random_tensor((num_prompts, 3, dim, dim)) 95 | latent = model.image_to_latent(image) 96 | 97 | images = model.interpolate( 98 | LatentWalkInterpolateConfig( 99 | prompt=config_interpolate.get("prompt"), 100 | latent=latent, 101 | num_inference_steps=config_interpolate.get("num_inference_steps"), 102 | interpolation_steps=config_interpolate.get("interpolation_steps"), 103 | output_type="np", 104 | ) 105 | ) 106 | 107 | assert type(images) is np.ndarray 108 | assert images.shape == (image_count, dim, dim, 3) 109 | 110 | 111 | if __name__ == "__main__": 112 | pytest.main([__file__]) 113 | -------------------------------------------------------------------------------- /examples/text_to_video_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Text To Video Diffusion\n", 9 | "\n", 10 | "In this notebook, we take a look at Text to Video Diffusion." 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### Install and Import required packages" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "%pip install stablefused ipython" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import numpy as np\n", 37 | "import torch\n", 38 | "\n", 39 | "from IPython.display import display, Video\n", 40 | "from diffusers.schedulers import DPMSolverMultistepScheduler\n", 41 | "\n", 42 | "from stablefused import TextToVideoDiffusion\n", 43 | "from stablefused.utils import pil_to_video, image_grid" 44 | ] 45 | }, 46 | { 47 | "attachments": {}, 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "### Initialize model and parameters\n", 52 | "\n", 53 | "We use Cerspense's Zeroscope v2 to initialize our Text To Video Diffusion model. Play around with different prompts and see what you get! You can comment out the seed part if you want to generate new random images each time you run the notebook.\n", 54 | "\n", 55 | "We enable slicing and tiling of the VAE to reduce memory required for decoding process from latent space to image space." 
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# model_id = \"damo-vilab/text-to-video-ms-1.7b\"\n", 65 | "model_id = \"cerspense/zeroscope_v2_576w\"\n", 66 | "\n", 67 | "# model = TextToVideoDiffusion(model_id = model_id, torch_dtype = torch.float16, variant = \"fp16\")\n", 68 | "model = TextToVideoDiffusion(model_id=model_id, torch_dtype=torch.float16)\n", 69 | "\n", 70 | "model.scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)\n", 71 | "model.enable_slicing()\n", 72 | "model.enable_tiling()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "prompt = \"An astronaut floating in space, interstellar, black background with stars, photorealistic, high quality, 8k\"\n", 82 | "negative_prompt = \"multiple people, cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive, nsfw\"\n", 83 | "num_inference_steps = 15\n", 84 | "video_frames = 24\n", 85 | "seed = 420\n", 86 | "\n", 87 | "torch.manual_seed(seed)\n", 88 | "np.random.seed(seed)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "frames = model(\n", 98 | " prompt=prompt,\n", 99 | " negative_prompt=negative_prompt,\n", 100 | " video_height=320,\n", 101 | " video_width=576,\n", 102 | " video_frames=video_frames,\n", 103 | " num_inference_steps=num_inference_steps,\n", 104 | " guidance_scale=8.0,\n", 105 | ")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "filename = \"interstellar-astronaut.mp4\"\n", 115 | "pil_to_video(frames[0], filename, fps=8)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "display(Video(filename, embed=True))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "prompt = \"A mighty pirate ship sailing through the sea, unpleasant, thundering roar, dark night, starry night, high quality, photorealistic, 8k\"\n", 134 | "seed = 42\n", 135 | "\n", 136 | "torch.manual_seed(seed)\n", 137 | "np.random.seed(seed)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "frames = model(\n", 147 | " prompt=[prompt] * 2,\n", 148 | " video_height=320,\n", 149 | " video_width=576,\n", 150 | " video_frames=video_frames,\n", 151 | " num_inference_steps=num_inference_steps,\n", 152 | " guidance_scale=12.0,\n", 153 | ")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "# Tile the frames of the two videos one above the other.\n", 163 | "frames_concatenated = []\n", 164 | "for images in zip(*frames):\n", 165 | " frames_concatenated.append(image_grid(images, rows=2, cols=1))" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "filename = \"mighty-ship.mp4\"\n", 175 | "pil_to_video(frames_concatenated, filename, fps=8)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | 
"metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "display(Video(filename, embed=True))" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "language_info": { 190 | "name": "python" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 0 195 | } 196 | -------------------------------------------------------------------------------- /examples/text_to_image_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Text to Image Diffusion\n", 9 | "\n", 10 | "In this notebook, we take a look at Text to Image Diffusion." 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### Install and Import required packages" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "!pip install stablefused ipython" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import numpy as np\n", 37 | "import torch\n", 38 | "\n", 39 | "from IPython.display import Video, display\n", 40 | "from stablefused import TextToImageDiffusion\n", 41 | "from stablefused.utils import image_grid, pil_to_video" 42 | ] 43 | }, 44 | { 45 | "attachments": {}, 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### Initialize model and parameters\n", 50 | "\n", 51 | "We use RunwayML's Stable Diffusion 1.5 checkpoint and initialize our Text To Image Diffusion model, and some other parameters. Play around with different prompts and see what you get! You can comment out the seed part if you want to generate new random images each time you run the notebook." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 61 | "model = TextToImageDiffusion(model_id=model_id)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "prompt = \"Cyberpunk cityscape with towering skyscrapers, neon signs, and flying cars.\"\n", 71 | "negative_prompt = \"cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive\"\n", 72 | "num_inference_steps = 20\n", 73 | "seed = 1337\n", 74 | "\n", 75 | "torch.manual_seed(seed)\n", 76 | "np.random.seed(seed)" 77 | ] 78 | }, 79 | { 80 | "attachments": {}, 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "You can run the stable diffusion inference using the call method `()` or the `.generate()` method. Refer to the documentation to see what parameters can be provided." 
85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "model(prompt=prompt, num_inference_steps=num_inference_steps)[0]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "### Visualizing the Diffusion process" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "prompt = [\n", 110 | " \"Gothic painting of an ancient castle at night, with a full moon, gargoyles, and shadows\",\n", 111 | " \"A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k\",\n", 112 | " \"A close up image of an very beautiful woman, aesthetic, realistic, high quality\",\n", 113 | " \"Concept art for a post-apocalyptic world with ruins, overgrown vegetation, and a lone survivor\",\n", 114 | "]" 115 | ] 116 | }, 117 | { 118 | "attachments": {}, 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "We enable attention slicing for the UNet to reduce memory requirements, which causes attention heads to be processed sequentially. We also enable slicing and tiling of the VAE to reduce memory required for decoding process from latent space to image space." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "model.enable_attention_slicing()\n", 132 | "model.enable_slicing()\n", 133 | "model.enable_tiling()" 134 | ] 135 | }, 136 | { 137 | "attachments": {}, 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "Run inference on the different text prompts. We pass `return_latent_history = True`, which returns all the latents from the denoising process in latent space. We can then decode these latents to images and create a video of the denoising process." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "images = model(\n", 151 | " prompt=prompt,\n", 152 | " negative_prompt=[negative_prompt] * len(prompt),\n", 153 | " num_inference_steps=num_inference_steps,\n", 154 | " guidance_scale=10.0,\n", 155 | " return_latent_history=True,\n", 156 | ")" 157 | ] 158 | }, 159 | { 160 | "attachments": {}, 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "We tile the images in a 2x2 grid here." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "timestep_images = []\n", 174 | "for imgs in zip(*images):\n", 175 | " img = image_grid(imgs, rows=2, cols=len(prompt) // 2)\n", 176 | " timestep_images.append(img)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "path = \"text_to_image_diffusion.mp4\"\n", 186 | "pil_to_video(timestep_images, path, fps=5)" 187 | ] 188 | }, 189 | { 190 | "attachments": {}, 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "Tada!" 
195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "display(Video(path, embed=True))" 204 | ] 205 | } 206 | ], 207 | "metadata": { 208 | "language_info": { 209 | "name": "python" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 0 214 | } 215 | -------------------------------------------------------------------------------- /stablefused/utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from diffusers.schedulers import ( 5 | DEISMultistepScheduler, 6 | DDIMScheduler, 7 | DDPMScheduler, 8 | DPMSolverSDEScheduler, 9 | DPMSolverMultistepScheduler, 10 | DPMSolverSinglestepScheduler, 11 | EulerAncestralDiscreteScheduler, 12 | EulerDiscreteScheduler, 13 | HeunDiscreteScheduler, 14 | KDPM2DiscreteScheduler, 15 | KDPM2AncestralDiscreteScheduler, 16 | LMSDiscreteScheduler, 17 | PNDMScheduler, 18 | UniPCMultistepScheduler, 19 | ) 20 | from typing import Any, Dict, Union 21 | from stablefused.typing import Scheduler, SchedulerType 22 | 23 | 24 | def lerp( 25 | v0: Union[torch.Tensor, np.ndarray], 26 | v1: Union[torch.Tensor, np.ndarray], 27 | t: Union[float, torch.Tensor, np.ndarray], 28 | ) -> Union[torch.Tensor, np.ndarray]: 29 | """ 30 | Linearly interpolate between two vectors/tensors. 31 | 32 | Parameters 33 | ---------- 34 | v0: Union[torch.Tensor, np.ndarray] 35 | First vector/tensor. 36 | v1: Union[torch.Tensor, np.ndarray] 37 | Second vector/tensor. 38 | t: Union[float, torch.Tensor, np.ndarray] 39 | Interpolation factor. If float, must be between 0 and 1. If np.ndarray or 40 | torch.Tensor, must be one dimensional with values between 0 and 1. 41 | 42 | Returns 43 | ------- 44 | Union[torch.Tensor, np.ndarray] 45 | Interpolated vector/tensor between v0 and v1. 46 | """ 47 | inputs_are_torch = False 48 | t_is_float = False 49 | 50 | if isinstance(v0, torch.Tensor): 51 | inputs_are_torch = True 52 | input_device = v0.device 53 | v0 = v0.cpu().numpy() 54 | if isinstance(v1, torch.Tensor): 55 | inputs_are_torch = True 56 | input_device = v1.device 57 | v1 = v1.cpu().numpy() 58 | if isinstance(t, torch.Tensor): 59 | inputs_are_torch = True 60 | input_device = t.device 61 | t = t.cpu().numpy() 62 | elif isinstance(t, float): 63 | t_is_float = True 64 | t = np.array([t]) 65 | 66 | t = t[..., None] 67 | v0 = v0[None, ...] 68 | v1 = v1[None, ...] 69 | v2 = (1 - t) * v0 + t * v1 70 | 71 | if t_is_float and v0.ndim > 1: 72 | assert v2.shape[0] == 1 73 | v2 = np.squeeze(v2, axis=0) 74 | if inputs_are_torch: 75 | v2 = torch.from_numpy(v2).to(input_device) 76 | 77 | return v2 78 | 79 | 80 | def slerp( 81 | v0: Union[torch.Tensor, np.ndarray], 82 | v1: Union[torch.Tensor, np.ndarray], 83 | t: Union[float, torch.Tensor, np.ndarray], 84 | DOT_THRESHOLD=0.9995, 85 | ) -> Union[torch.Tensor, np.ndarray]: 86 | """ 87 | Spherical linear interpolation between two vectors/tensors. 88 | 89 | Parameters 90 | ---------- 91 | v0: Union[torch.Tensor, np.ndarray] 92 | First vector/tensor. 93 | v1: Union[torch.Tensor, np.ndarray] 94 | Second vector/tensor. 95 | t: Union[float, np.ndarray] 96 | Interpolation factor. If float, must be between 0 and 1. If np.ndarray, must be one 97 | dimensional with values between 0 and 1. 98 | DOT_THRESHOLD: float 99 | Threshold for when to use linear interpolation instead of spherical interpolation. 
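        When the normalized dot product ``dot(v0, v1) / (|v0| * |v1|)`` exceeds this
        threshold in absolute value, the vectors are treated as nearly parallel and
        ``lerp`` is used instead. Otherwise the result follows the standard slerp
        formula, restated here for reference (it matches the implementation below):

            theta = arccos(dot(v0, v1) / (|v0| * |v1|))
            slerp(v0, v1, t) = sin((1 - t) * theta) / sin(theta) * v0
                               + sin(t * theta) / sin(theta) * v1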
100 | 101 | Returns 102 | ------- 103 | Union[torch.Tensor, np.ndarray] 104 | Interpolated vector/tensor between v0 and v1. 105 | """ 106 | inputs_are_torch = False 107 | t_is_float = False 108 | 109 | if isinstance(v0, torch.Tensor): 110 | inputs_are_torch = True 111 | input_device = v0.device 112 | v0 = v0.cpu().numpy() 113 | if isinstance(v1, torch.Tensor): 114 | inputs_are_torch = True 115 | input_device = v1.device 116 | v1 = v1.cpu().numpy() 117 | if isinstance(t, torch.Tensor): 118 | inputs_are_torch = True 119 | input_device = t.device 120 | t = t.cpu().numpy() 121 | elif isinstance(t, float): 122 | t_is_float = True 123 | t = np.array([t], dtype=v0.dtype) 124 | 125 | dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) 126 | if np.abs(dot) > DOT_THRESHOLD: 127 | # v1 and v2 are close to parallel 128 | # Use linear interpolation instead 129 | v2 = lerp(v0, v1, t) 130 | else: 131 | theta_0 = np.arccos(dot) 132 | sin_theta_0 = np.sin(theta_0) 133 | theta_t = theta_0 * t 134 | sin_theta_t = np.sin(theta_t) 135 | s0 = np.sin(theta_0 - theta_t) / sin_theta_0 136 | s1 = sin_theta_t / sin_theta_0 137 | s0 = s0[..., None] 138 | s1 = s1[..., None] 139 | v0 = v0[None, ...] 140 | v1 = v1[None, ...] 141 | v2 = s0 * v0 + s1 * v1 142 | 143 | if t_is_float and v0.ndim > 1: 144 | assert v2.shape[0] == 1 145 | v2 = np.squeeze(v2, axis=0) 146 | if inputs_are_torch: 147 | v2 = torch.from_numpy(v2).to(input_device) 148 | 149 | return v2 150 | 151 | 152 | def resolve_scheduler( 153 | scheduler_type: Scheduler, config: Dict[str, Any] 154 | ) -> SchedulerType: 155 | if scheduler_type == Scheduler.DEIS: 156 | return DEISMultistepScheduler.from_config(config) 157 | 158 | elif scheduler_type == Scheduler.DDIM: 159 | return DDIMScheduler.from_config(config) 160 | 161 | elif scheduler_type == Scheduler.DDPM: 162 | return DDPMScheduler.from_config(config) 163 | 164 | elif scheduler_type == Scheduler.DPM2_KARRAS: 165 | return KDPM2DiscreteScheduler.from_config(config) 166 | 167 | elif scheduler_type == Scheduler.DPM2_KARRAS_ANCESTRAL: 168 | return KDPM2AncestralDiscreteScheduler.from_config(config) 169 | 170 | elif scheduler_type == Scheduler.DPM_SDE: 171 | return DPMSolverSDEScheduler.from_config(config) 172 | 173 | elif scheduler_type == Scheduler.DPM_SDE_KARRAS: 174 | return DPMSolverSDEScheduler.from_config(config, use_karras_sigmas=True) 175 | 176 | elif scheduler_type == Scheduler.DPM_MULTISTEP: 177 | return DPMSolverMultistepScheduler.from_config(config) 178 | 179 | elif scheduler_type == Scheduler.DPM_MULTISTEP_KARRAS: 180 | return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True) 181 | 182 | elif scheduler_type == Scheduler.DPM_SINGLESTEP: 183 | return DPMSolverSinglestepScheduler.from_config(config) 184 | 185 | elif scheduler_type == Scheduler.DPM_SINGLESTEP_KARRAS: 186 | return DPMSolverSinglestepScheduler.from_config(config, use_karras_sigmas=True) 187 | 188 | elif scheduler_type == Scheduler.EULER: 189 | return EulerDiscreteScheduler.from_config(config) 190 | 191 | elif scheduler_type == Scheduler.EULER_ANCESTRAL: 192 | return EulerAncestralDiscreteScheduler.from_config(config) 193 | 194 | elif scheduler_type == Scheduler.HEUN: 195 | return HeunDiscreteScheduler.from_config(config) 196 | 197 | elif scheduler_type == Scheduler.LINEAR_MULTISTEP: 198 | return LMSDiscreteScheduler.from_config(config) 199 | 200 | elif scheduler_type == Scheduler.PNDM: 201 | return PNDMScheduler.from_config(config) 202 | 203 | elif scheduler_type == Scheduler.UNIPC: 204 | return 
UniPCMultistepScheduler.from_config(config) 205 | -------------------------------------------------------------------------------- /stablefused/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import torch 4 | 5 | from PIL import Image, ImageDraw, ImageFont 6 | from typing import List, Tuple, Union 7 | 8 | 9 | def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: 10 | """ 11 | Convert pytorch tensor to numpy image. 12 | 13 | Parameters 14 | ---------- 15 | images: torch.FloatTensor 16 | Image represented as a pytorch tensor (N, C, H, W). 17 | 18 | Returns 19 | ------- 20 | np.ndarray 21 | Image represented as a numpy array (N, H, W, C). 22 | """ 23 | return images.detach().cpu().permute(0, 2, 3, 1).float().numpy() 24 | 25 | 26 | def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: 27 | """ 28 | Convert numpy image to pytorch tensor. 29 | 30 | Parameters 31 | ---------- 32 | images: np.ndarray 33 | Image represented as a numpy array (N, H, W, C). 34 | 35 | Returns 36 | ------- 37 | torch.FloatTensor 38 | Image represented as a pytorch tensor (N, C, H, W). 39 | """ 40 | if images.ndim == 3: 41 | images = images[..., None] 42 | return torch.from_numpy(images.transpose(0, 3, 1, 2)) 43 | 44 | 45 | def numpy_to_pil(images: np.ndarray) -> Image.Image: 46 | """ 47 | Convert numpy image to PIL image. 48 | 49 | Parameters 50 | ---------- 51 | images: np.ndarray 52 | Image represented as a numpy array (N, H, W, C). 53 | 54 | Returns 55 | ------- 56 | Image.Image 57 | Image represented as a PIL image. 58 | """ 59 | if images.ndim == 3: 60 | images = images[None, ...] 61 | images = (images * 255).round().astype("uint8") 62 | if images.shape[-1] == 1: 63 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] 64 | else: 65 | pil_images = [Image.fromarray(image) for image in images] 66 | return pil_images 67 | 68 | 69 | def pil_to_numpy(images: Union[List[Image.Image], Image.Image]) -> np.ndarray: 70 | """ 71 | Convert PIL image to numpy image. 72 | 73 | Parameters 74 | ---------- 75 | images: Union[List[Image.Image], Image.Image] 76 | PIL image or list of PIL images. 77 | 78 | Returns 79 | ------- 80 | np.ndarray 81 | Image represented as a numpy array (N, H, W, C). 82 | """ 83 | if not isinstance(images, Image.Image) and not isinstance(images, list): 84 | raise ValueError( 85 | f"Expected PIL image or list of PIL images, got {type(images)}." 86 | ) 87 | if not isinstance(images, list): 88 | images = [images] 89 | images = [np.array(image).astype(np.float32) / 255.0 for image in images] 90 | images = np.stack(images, axis=0) 91 | return images 92 | 93 | 94 | def normalize(images: torch.FloatTensor) -> torch.FloatTensor: 95 | """ 96 | Normalize an image array to the range [-1, 1]. 97 | 98 | Parameters 99 | ---------- 100 | images: torch.FloatTensor 101 | Image represented as a pytorch tensor (N, C, H, W). 102 | 103 | Returns 104 | ------- 105 | torch.FloatTensor 106 | Normalized image as pytorch tensor. 107 | """ 108 | return 2.0 * images - 1.0 109 | 110 | 111 | def denormalize(images: torch.FloatTensor) -> torch.FloatTensor: 112 | """ 113 | Denormalize an image array to the range [0.0, 1.0]. 114 | 115 | Parameters 116 | ---------- 117 | images: torch.FloatTensor 118 | Image represented as a pytorch tensor (N, C, H, W). 119 | 120 | Returns 121 | ------- 122 | torch.FloatTensor 123 | Denormalized image as pytorch tensor. 
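        This is the inverse of ``normalize``: values are mapped back from [-1, 1] to
        [0, 1] via ``clamp(images / 2 + 0.5, 0, 1)``.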
124 | """ 125 | return (0.5 + images / 2).clamp(0, 1) 126 | 127 | 128 | def pil_to_video(images: List[Image.Image], filename: str, fps: int = 60) -> None: 129 | """ 130 | Convert a list of PIL images to a video. 131 | 132 | Parameters 133 | ---------- 134 | images: List[Image.Image] 135 | List of PIL images. 136 | filename: str 137 | Filename to save video to. 138 | fps: int 139 | Frames per second of video. 140 | """ 141 | frames = [np.array(image) for image in images] 142 | with imageio.get_writer(filename, fps=fps) as video_writer: 143 | for frame in frames: 144 | video_writer.append_data(frame) 145 | 146 | 147 | def pil_to_gif(images: List[Image.Image], filename: str, fps: int = 60) -> None: 148 | """ 149 | Convert a list of PIL images to a GIF. 150 | 151 | Parameters 152 | ---------- 153 | images: List[Image.Image] 154 | List of PIL images. 155 | filename: str 156 | Filename to save GIF to. 157 | fps: int 158 | Frames per second of GIF. 159 | """ 160 | images[0].save( 161 | filename, 162 | save_all=True, 163 | append_images=images[1:], 164 | duration=1000 // fps, 165 | loop=0, 166 | ) 167 | 168 | 169 | def image_grid(images: List[Image.Image], rows: int, cols: int) -> Image.Image: 170 | """ 171 | Create a grid of images on a single PIL image. 172 | 173 | Parameters 174 | ---------- 175 | images: List[Image.Image] 176 | List of PIL images. 177 | rows: int 178 | Number of rows in grid. 179 | cols: int 180 | Number of columns in grid. 181 | 182 | Returns 183 | ------- 184 | Image.Image 185 | Grid of images as a PIL image. 186 | """ 187 | if len(images) > rows * cols: 188 | raise ValueError( 189 | f"Number of images ({len(images)}) exceeds grid size ({rows}x{cols})." 190 | ) 191 | w, h = images[0].size 192 | grid = Image.new("RGB", size=(cols * w, rows * h)) 193 | for i, image in enumerate(images): 194 | grid.paste(image, box=(i % cols * w, i // cols * h)) 195 | return grid 196 | 197 | 198 | def write_text_on_image( 199 | image: Image.Image, 200 | text: str, 201 | fontfile: str = "arial.ttf", 202 | fontsize: int = 30, 203 | padding: Union[int, Tuple[int, int]] = 10, 204 | ) -> Image.Image: 205 | if isinstance(padding, int): 206 | padding = (padding, padding) 207 | 208 | try: 209 | font = ImageFont.truetype(fontfile, size=fontsize) 210 | except IOError: 211 | font = ImageFont.load_default() 212 | 213 | image = image.copy() 214 | image_width, image_height = image.size 215 | max_text_width = image_width - padding[0] * 2 216 | 217 | draw = ImageDraw.Draw(image) 218 | text_width, text_height = draw.textsize(text, font) 219 | x = (image_width - text_width) / 2 220 | y = image_height - text_height - padding[1] 221 | 222 | lines = [] 223 | words = text.split() 224 | current_line = "" 225 | for word in words: 226 | test_line = current_line + " " + word if current_line else word 227 | test_width, _ = draw.textsize(test_line, font) 228 | if test_width <= max_text_width: 229 | current_line = test_line 230 | else: 231 | lines.append(current_line) 232 | current_line = word 233 | if current_line: 234 | lines.append(current_line) 235 | 236 | y_start = y 237 | for line in lines: 238 | text_width, text_height = draw.textsize(line, font) 239 | x = (image_width - text_width) / 2 240 | draw.text((x, y_start), line, fill="white", font=font) 241 | y_start += text_height 242 | 243 | return image 244 | -------------------------------------------------------------------------------- /stablefused/diffusion/image_to_image_diffusion.py: -------------------------------------------------------------------------------- 
1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | from dataclasses import dataclass 6 | from diffusers import AutoencoderKL 7 | from tqdm.auto import tqdm 8 | from transformers import CLIPTextModel, CLIPTokenizer 9 | from typing import List, Optional, Union 10 | 11 | from stablefused.diffusion import BaseDiffusion 12 | from stablefused.typing import PromptType, OutputType, SchedulerType, UNetType 13 | 14 | 15 | @dataclass 16 | class ImageToImageConfig: 17 | """ 18 | Configuration class for running inference with ImageToImageDiffusion. 19 | 20 | Parameters 21 | ---------- 22 | image: Image.Image 23 | Input image to condition on. 24 | prompt: PromptType 25 | Text prompt to condition on. 26 | num_inference_steps: int 27 | Number of diffusion steps to run. 28 | start_step: int 29 | Step to start diffusion from. The higher the value, the more similar the generated 30 | image will be to the input image. 31 | guidance_scale: float 32 | Guidance scale encourages the model to generate images following the prompt 33 | closely, albeit at the cost of image quality. 34 | guidance_rescale: float 35 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 36 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 37 | negative_prompt: Optional[PromptType] 38 | Negative text prompt to uncondition on. 39 | output_type: str 40 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 41 | return_latent_history: bool 42 | Whether to return the latent history. If True, return list of all latents 43 | generated during diffusion steps. 44 | """ 45 | 46 | image: Image.Image 47 | prompt: PromptType 48 | num_inference_steps: int = 50 49 | start_step: int = 0 50 | guidance_scale: float = 7.5 51 | guidance_rescale: float = 0.7 52 | negative_prompt: Optional[PromptType] = None 53 | output_type: str = "pil" 54 | return_latent_history: bool = False 55 | 56 | 57 | class ImageToImageDiffusion(BaseDiffusion): 58 | def __init__( 59 | self, 60 | model_id: str = None, 61 | tokenizer: CLIPTokenizer = None, 62 | text_encoder: CLIPTextModel = None, 63 | vae: AutoencoderKL = None, 64 | unet: UNetType = None, 65 | scheduler: SchedulerType = None, 66 | torch_dtype: torch.dtype = torch.float32, 67 | device="cuda", 68 | *args, 69 | **kwargs 70 | ) -> None: 71 | super().__init__( 72 | model_id=model_id, 73 | tokenizer=tokenizer, 74 | text_encoder=text_encoder, 75 | vae=vae, 76 | unet=unet, 77 | scheduler=scheduler, 78 | torch_dtype=torch_dtype, 79 | device=device, 80 | *args, 81 | **kwargs 82 | ) 83 | 84 | def embedding_to_latent( 85 | self, 86 | embedding: torch.FloatTensor, 87 | num_inference_steps: int, 88 | start_step: int, 89 | guidance_scale: float, 90 | guidance_rescale: float, 91 | latent: torch.FloatTensor, 92 | return_latent_history: bool = False, 93 | ) -> Union[torch.FloatTensor, List[torch.FloatTensor]]: 94 | """ 95 | Generate latent by conditioning on prompt embedding and input image using diffusion. 96 | 97 | Parameters 98 | ---------- 99 | embedding: torch.FloatTensor 100 | Embedding of text prompt. 101 | num_inference_steps: int 102 | Number of diffusion steps to run. 103 | start_step: int 104 | Step to start diffusion from. The higher the value, the more similar the generated 105 | image will be to the input image. 106 | guidance_scale: float 107 | Guidance scale encourages the model to generate images following the prompt 108 | closely, albeit at the cost of image quality. 
109 | guidance_rescale: float 110 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 111 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 112 | latent: torch.FloatTensor 113 | Latent to start diffusion from. 114 | return_latent_history: bool 115 | Whether to return the latent history. If True, return list of all latents 116 | generated during diffusion steps. 117 | 118 | Returns 119 | ------- 120 | Union[torch.FloatTensor, List[torch.FloatTensor]] 121 | Latent generated by diffusion. If return_latent_history is True, return list of 122 | all latents generated during diffusion steps. 123 | """ 124 | 125 | latent = latent.to(self.device) 126 | 127 | # Set number of inference steps 128 | self.scheduler.set_timesteps(num_inference_steps) 129 | 130 | # Add noise to latent based on start step 131 | start_timestep = ( 132 | self.scheduler.timesteps[start_step].repeat(latent.shape[0]).long() 133 | ) 134 | noise = self.random_tensor(latent.shape) 135 | latent = self.scheduler.add_noise(latent, noise, start_timestep) 136 | 137 | timesteps = self.scheduler.timesteps[start_step:] 138 | latent_history = [latent] 139 | 140 | # Diffusion inference loop 141 | for i, timestep in tqdm(list(enumerate(timesteps))): 142 | # Duplicate latent to avoid two forward passes to perform classifier free guidance 143 | latent_model_input = torch.cat([latent] * 2) 144 | latent_model_input = self.scheduler.scale_model_input( 145 | latent_model_input, timestep 146 | ) 147 | 148 | # Predict noise 149 | noise_prediction = self.unet( 150 | latent_model_input, 151 | timestep, 152 | encoder_hidden_states=embedding, 153 | return_dict=False, 154 | )[0] 155 | 156 | # Perform classifier free guidance 157 | noise_prediction = self.classifier_free_guidance( 158 | noise_prediction, guidance_scale, guidance_rescale 159 | ) 160 | 161 | # Update latent 162 | latent = self.scheduler.step( 163 | noise_prediction, timestep, latent, return_dict=False 164 | )[0] 165 | 166 | if return_latent_history: 167 | latent_history.append(latent) 168 | 169 | return torch.stack(latent_history) if return_latent_history else latent 170 | 171 | @torch.no_grad() 172 | def __call__(self, config: ImageToImageConfig) -> OutputType: 173 | """ 174 | Run inference by conditioning on input image and text prompt. 175 | 176 | Parameters 177 | ---------- 178 | config: ImageToImageConfig 179 | Configuration for running inference with ImageToImageDiffusion. 180 | 181 | Returns 182 | ------- 183 | OutputType 184 | Generated output based on output_type. 
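Example
-------
A minimal usage sketch. The checkpoint name, file paths and prompt are
illustrative assumptions, the top-level import path is assumed to mirror the
other pipelines, and a CUDA device is assumed (the constructor's default).

>>> from PIL import Image
>>> from stablefused import ImageToImageConfig, ImageToImageDiffusion
>>> model = ImageToImageDiffusion(model_id="runwayml/stable-diffusion-v1-5")
>>> init_image = Image.open("input.png").convert("RGB").resize((512, 512))  # placeholder path
>>> config = ImageToImageConfig(
...     image=init_image,
...     prompt="a watercolor painting of a mountain lake",
...     start_step=10,
...     num_inference_steps=50,
... )
>>> images = model(config)       # default output_type="pil" yields a list of PIL images
>>> images[0].save("output.png")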
185 | """ 186 | 187 | image = config.image 188 | prompt = config.prompt 189 | num_inference_steps = config.num_inference_steps 190 | start_step = config.start_step 191 | guidance_scale = config.guidance_scale 192 | guidance_rescale = config.guidance_rescale 193 | negative_prompt = config.negative_prompt 194 | output_type = config.output_type 195 | return_latent_history = config.return_latent_history 196 | 197 | # Validate input 198 | self.validate_input( 199 | prompt=prompt, 200 | negative_prompt=negative_prompt, 201 | start_step=start_step, 202 | num_inference_steps=num_inference_steps, 203 | ) 204 | 205 | # Generate embedding to condition on prompt and uncondition on negative prompt 206 | embedding = self.prompt_to_embedding( 207 | prompt=prompt, 208 | negative_prompt=negative_prompt, 209 | ) 210 | 211 | # Generate latent from input image 212 | image_latent = self.image_to_latent(image) 213 | 214 | # Run inference 215 | latent = self.embedding_to_latent( 216 | embedding=embedding, 217 | num_inference_steps=num_inference_steps, 218 | start_step=start_step, 219 | guidance_scale=guidance_scale, 220 | guidance_rescale=guidance_rescale, 221 | latent=image_latent, 222 | return_latent_history=return_latent_history, 223 | ) 224 | 225 | return self.resolve_output( 226 | latent=latent, 227 | output_type=output_type, 228 | return_latent_history=return_latent_history, 229 | ) 230 | 231 | generate = __call__ 232 | -------------------------------------------------------------------------------- /stablefused/apps/storybook/storybook.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | 5 | from dataclasses import dataclass 6 | from dataclasses_json import DataClassJsonMixin 7 | from moviepy.editor import ( 8 | AudioFileClip, 9 | CompositeAudioClip, 10 | CompositeVideoClip, 11 | ImageClip, 12 | concatenate_audioclips, 13 | concatenate_videoclips, 14 | ) 15 | from typing import Dict, List, Optional, Union 16 | 17 | from stablefused import ( 18 | TextToImageConfig, 19 | TextToImageDiffusion, 20 | LatentWalkInterpolateConfig, 21 | LatentWalkDiffusion, 22 | ) 23 | from stablefused.apps.storybook import StoryBookAuthorBase, StoryBookSpeakerBase 24 | from stablefused.utils import write_text_on_image 25 | 26 | 27 | @dataclass 28 | class StoryBookConfig(DataClassJsonMixin): 29 | """ 30 | Configuration class for running inference with StoryBook. 
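The fields mirror the dataclass definition below: prompt is the user's story
idea, artist_config must match the artist passed to StoryBook
(TextToImageConfig for TextToImageDiffusion, LatentWalkInterpolateConfig for
LatentWalkDiffusion), artist_attributes is a style suffix appended to every
generated image prompt, messages seeds the author's chat history, and the
remaining fields control captions, frame duration, retries and the output
video filename.

Example
-------
A rough wiring sketch. MyAuthor is a hypothetical stand-in for any concrete
StoryBookAuthorBase implementation, and the checkpoint name and prompts are
assumptions.

>>> from stablefused import TextToImageConfig, TextToImageDiffusion
>>> artist = TextToImageDiffusion(model_id="runwayml/stable-diffusion-v1-5")
>>> config = StoryBookConfig(
...     prompt="A short story about a dog chasing a cat",
...     artist_config=TextToImageConfig(num_inference_steps=25),
...     artist_attributes="storybook illustration, colorful, high quality",
...     messages=[{"role": "system", "content": "You write short illustrated stories."}],
...     output_filename="story.mp4",
... )
>>> storybook = StoryBook(author=MyAuthor(), artist=artist)  # hypothetical author
>>> storybook(config)  # doctest: +SKIP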
31 | """ 32 | 33 | prompt: str 34 | artist_config: Union[TextToImageConfig, LatentWalkInterpolateConfig] 35 | artist_attributes: str = "" 36 | messages: List[Dict[str, str]] = None 37 | display_captions: bool = True 38 | caption_fontsize: int = 30 39 | caption_fontfile: str = "arial.ttf" 40 | caption_padding: int = 10 41 | frame_duration: int = 1 42 | num_retries: int = 3 43 | output_filename: str = "output.mp4" 44 | 45 | 46 | class StoryBook: 47 | def __init__( 48 | self, 49 | author: StoryBookAuthorBase, 50 | artist: Union[TextToImageDiffusion, LatentWalkDiffusion], 51 | speaker: Optional[StoryBookSpeakerBase] = None, 52 | ) -> None: 53 | self.author = author 54 | self.artist = artist 55 | self.speaker = speaker 56 | 57 | def _process_config(self, config: Union[str, Dict[str, str]]) -> None: 58 | if isinstance(config, str): 59 | module_dir = os.path.dirname(os.path.abspath(__file__)) 60 | config_path = os.path.join(module_dir, config) 61 | 62 | with open(config_path, "r") as f: 63 | config = json.load(f) 64 | 65 | self.artist_call_kwargs = { 66 | k: v for k, v in config.items() if k in self._artist_call_attributes 67 | } 68 | self.negative_prompt = self.artist_call_kwargs.pop("negative_prompt", None) 69 | self.messages: List[Dict[str, str]] = config.get("messages") 70 | self.attributes = config.get("attributes", "") 71 | 72 | def create_prompt(self, role: str, content: str) -> Dict[str, str]: 73 | return {"role": role, "content": content} 74 | 75 | def validate_output(self, output: str) -> List[Dict[str, str]]: 76 | try: 77 | output_list = json.loads(output) 78 | except json.JSONDecodeError as e: 79 | raise ValueError(f"Output is not a valid JSON: {str(e)}") 80 | 81 | if not isinstance(output_list, list): 82 | raise ValueError("Output must be a list of dictionaries.") 83 | 84 | for item in output_list: 85 | if not isinstance(item, dict): 86 | raise ValueError("Each item in the list must be a dictionary.") 87 | 88 | if "story" not in item or "prompt" not in item: 89 | raise ValueError( 90 | "Each dictionary must contain 'story' and 'prompt' keys." 91 | ) 92 | 93 | if not isinstance(item["story"], str) or not isinstance( 94 | item["prompt"], str 95 | ): 96 | raise ValueError("'story' and 'prompt' values must be strings.") 97 | 98 | return output_list 99 | 100 | def __call__( 101 | self, 102 | config: StoryBookConfig, 103 | ) -> None: 104 | print(config.to_json(indent=2)) 105 | prompt = config.prompt 106 | artist_config = config.artist_config 107 | artist_attributes = config.artist_attributes 108 | messages = config.messages 109 | display_captions = config.display_captions 110 | caption_fontsize = config.caption_fontsize 111 | caption_fontfile = config.caption_fontfile 112 | caption_padding = config.caption_padding 113 | frame_duration = config.frame_duration 114 | num_retries = config.num_retries 115 | output_filename = config.output_filename 116 | 117 | if ( 118 | isinstance(artist_config, TextToImageConfig) 119 | and isinstance(self.artist, LatentWalkDiffusion) 120 | ) or ( 121 | isinstance(artist_config, LatentWalkInterpolateConfig) 122 | and isinstance(self.artist, TextToImageDiffusion) 123 | ): 124 | raise ValueError( 125 | "Artist is not compatible with the provided artist config." 
126 | ) 127 | 128 | messages.append(self.create_prompt("user", prompt)) 129 | 130 | for i in range(num_retries): 131 | try: 132 | storybook = self.author(messages) 133 | storybook = self.validate_output(storybook) 134 | break 135 | except Exception as e: 136 | print(f"Error: {str(e)}") 137 | print(f"Retrying ({i + 1}/{num_retries})...") 138 | continue 139 | else: 140 | raise Exception("Failed to generate storybook. Please try again.") 141 | 142 | prompt = [f"{item['prompt']}, {artist_attributes}" for item in storybook] 143 | artist_config.prompt = prompt 144 | artist_config.negative_prompt = ( 145 | [artist_config.negative_prompt] * len(storybook) 146 | if artist_config.negative_prompt is not None 147 | else None 148 | ) 149 | 150 | if isinstance(self.artist, TextToImageDiffusion): 151 | images = self.artist(artist_config) 152 | else: 153 | images = self.artist.interpolate(artist_config) 154 | 155 | if display_captions: 156 | if isinstance(self.artist, TextToImageDiffusion): 157 | images = [ 158 | write_text_on_image( 159 | image, 160 | storypart.get("story"), 161 | fontfile=caption_fontfile, 162 | fontsize=caption_fontsize, 163 | padding=caption_padding, 164 | ) 165 | for image, storypart in zip(images, storybook) 166 | ] 167 | else: 168 | num_frames_per_prompt = config.artist_config.interpolation_steps 169 | image_list = [] 170 | for i in range(0, len(images), num_frames_per_prompt): 171 | current_images = images[i : i + num_frames_per_prompt] 172 | current_story = storybook[i // num_frames_per_prompt] 173 | image_list.extend( 174 | [ 175 | write_text_on_image( 176 | image, 177 | current_story.get("story"), 178 | fontfile=caption_fontfile, 179 | fontsize=caption_fontsize, 180 | padding=caption_padding, 181 | ) 182 | for image in current_images 183 | ] 184 | ) 185 | images = image_list 186 | 187 | if self.speaker is not None: 188 | stories = [item.get("story") for item in storybook] 189 | audioclips = [] 190 | 191 | for audiofile in self.speaker(stories, yield_files=True): 192 | audioclips.append(AudioFileClip(audiofile)) 193 | 194 | audioclip: CompositeAudioClip = concatenate_audioclips(audioclips) 195 | frame_duration = audioclip.duration / len(prompt) 196 | else: 197 | audioclip = None 198 | 199 | if isinstance(self.artist, TextToImageDiffusion): 200 | video = [ 201 | ImageClip(np.array(image), duration=frame_duration) for image in images 202 | ] 203 | else: 204 | num_frames_per_prompt = config.artist_config.interpolation_steps 205 | video = [] 206 | for i in range(0, len(images), num_frames_per_prompt): 207 | current_images = images[i : i + num_frames_per_prompt] 208 | current_clips = [ 209 | ImageClip(np.array(image), duration=frame_duration / num_frames_per_prompt) 210 | for image in current_images 211 | ] 212 | video.append(concatenate_videoclips(current_clips)) 213 | 214 | video: CompositeVideoClip = concatenate_videoclips(video) 215 | video = video.set_audio(audioclip) 216 | video = video.set_fps(60) 217 | video.write_videofile(output_filename) 218 | 219 | return storybook 220 | -------------------------------------------------------------------------------- /stablefused/diffusion/text_to_image_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dataclasses import dataclass 4 | from diffusers import AutoencoderKL 5 | from tqdm.auto import tqdm 6 | from transformers import CLIPTextModel, CLIPTokenizer 7 | from typing import List, Optional, Union 8 | 9 | from stablefused.diffusion import BaseDiffusion 10 | 
from stablefused.typing import PromptType, OutputType, SchedulerType, UNetType 11 | 12 | 13 | @dataclass 14 | class TextToImageConfig: 15 | """ 16 | Configuration class for running inference with TextToImageDiffusion. 17 | 18 | Parameters 19 | ---------- 20 | prompt: PromptType 21 | Text prompt to condition on. 22 | image_height: int 23 | Height of image to generate. 24 | image_width: int 25 | Width of image to generate. 26 | num_inference_steps: int 27 | Number of diffusion steps to run. 28 | guidance_scale: float 29 | Guidance scale encourages the model to generate images following the prompt 30 | closely, albeit at the cost of image quality. 31 | guidance_rescale: float 32 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 33 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 34 | negative_prompt: Optional[PromptType] 35 | Negative text prompt to uncondition on. 36 | latent: Optional[torch.FloatTensor] 37 | Latent to start from. If None, latent is generated from noise. 38 | output_type: str 39 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 40 | return_latent_history: bool 41 | Whether to return the latent history. If True, return list of all latents 42 | generated during diffusion steps. 43 | """ 44 | 45 | prompt: PromptType = "" 46 | image_height: int = 512 47 | image_width: int = 512 48 | num_inference_steps: int = 50 49 | guidance_scale: float = 7.5 50 | guidance_rescale: float = 0.7 51 | negative_prompt: Optional[PromptType] = None 52 | latent: Optional[torch.FloatTensor] = None 53 | output_type: str = "pil" 54 | return_latent_history: bool = False 55 | 56 | 57 | class TextToImageDiffusion(BaseDiffusion): 58 | def __init__( 59 | self, 60 | model_id: str = None, 61 | tokenizer: CLIPTokenizer = None, 62 | text_encoder: CLIPTextModel = None, 63 | vae: AutoencoderKL = None, 64 | unet: UNetType = None, 65 | scheduler: SchedulerType = None, 66 | torch_dtype: torch.dtype = torch.float32, 67 | device="cuda", 68 | *args, 69 | **kwargs 70 | ) -> None: 71 | super().__init__( 72 | model_id=model_id, 73 | tokenizer=tokenizer, 74 | text_encoder=text_encoder, 75 | vae=vae, 76 | unet=unet, 77 | scheduler=scheduler, 78 | torch_dtype=torch_dtype, 79 | device=device, 80 | *args, 81 | **kwargs 82 | ) 83 | 84 | def embedding_to_latent( 85 | self, 86 | embedding: torch.FloatTensor, 87 | image_height: int, 88 | image_width: int, 89 | num_inference_steps: int, 90 | guidance_scale: float, 91 | guidance_rescale: float, 92 | latent: Optional[torch.FloatTensor] = None, 93 | return_latent_history: bool = False, 94 | ) -> Union[torch.FloatTensor, List[torch.FloatTensor]]: 95 | """ 96 | Generate latent by conditioning on prompt embedding using diffusion. 97 | 98 | Parameters 99 | ---------- 100 | embedding: torch.FloatTensor 101 | Embedding of text prompt. 102 | image_height: int 103 | Height of image to generate. 104 | image_width: int 105 | Width of image to generate. 106 | num_inference_steps: int 107 | Number of diffusion steps to run. 108 | guidance_scale: float 109 | Guidance scale encourages the model to generate images following the prompt 110 | closely, albeit at the cost of image quality. 111 | guidance_rescale: float 112 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 113 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 114 | latent: Optional[torch.FloatTensor] 115 | Latent to start from. If None, generate latent from noise. 116 | return_latent_history: bool 117 | Whether to return latent history. 
If True, return list of all latents 118 | generated during diffusion steps. 119 | 120 | Returns 121 | ------- 122 | Union[torch.FloatTensor, List[torch.FloatTensor]] 123 | Latent generated by diffusion. If return_latent_history is True, return 124 | list of all latents generated during diffusion steps. 125 | """ 126 | 127 | # Generate latent from noise if not provided 128 | if latent is None: 129 | shape = ( 130 | embedding.shape[0] // 2, 131 | self.unet.config.in_channels, 132 | image_height // self.vae_scale_factor, 133 | image_width // self.vae_scale_factor, 134 | ) 135 | latent = self.random_tensor(shape) 136 | latent = latent.to(self.device) 137 | 138 | # Set number of inference steps 139 | self.scheduler.set_timesteps(num_inference_steps) 140 | timesteps = self.scheduler.timesteps 141 | 142 | # Scale the latent noise by the standard deviation required by the scheduler 143 | latent = latent * self.scheduler.init_noise_sigma 144 | latent_history = [latent] 145 | 146 | # Diffusion inference loop 147 | for i, timestep in tqdm(list(enumerate(timesteps))): 148 | # Duplicate latent to avoid two forward passes to perform classifier free guidance 149 | latent_model_input = torch.cat([latent] * 2) 150 | latent_model_input = self.scheduler.scale_model_input( 151 | latent_model_input, timestep 152 | ) 153 | 154 | # Predict noise 155 | noise_prediction = self.unet( 156 | latent_model_input, 157 | timestep, 158 | encoder_hidden_states=embedding, 159 | return_dict=False, 160 | )[0] 161 | 162 | # Perform classifier free guidance 163 | noise_prediction = self.classifier_free_guidance( 164 | noise_prediction, guidance_scale, guidance_rescale 165 | ) 166 | 167 | # Update latent 168 | latent = self.scheduler.step( 169 | noise_prediction, timestep, latent, return_dict=False 170 | )[0] 171 | 172 | if return_latent_history: 173 | latent_history.append(latent) 174 | 175 | return torch.stack(latent_history) if return_latent_history else latent 176 | 177 | @torch.no_grad() 178 | def __call__( 179 | self, 180 | config: TextToImageConfig, 181 | ) -> OutputType: 182 | """ 183 | Run inference by conditioning on text prompt. 184 | 185 | Parameters 186 | ---------- 187 | config: TextToImageConfig 188 | Configuration for running inference with TextToImageDiffusion. 189 | 190 | Returns 191 | ------- 192 | OutputType 193 | Generated output based on output_type. 
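Example
-------
A minimal usage sketch. The checkpoint is the one used in the example
notebooks, the prompt and filename are illustrative, and a CUDA device is
assumed (the constructor's default).

>>> import torch
>>> from stablefused import TextToImageConfig, TextToImageDiffusion
>>> model = TextToImageDiffusion(
...     model_id="runwayml/stable-diffusion-v1-5",
...     torch_dtype=torch.float16,
... )
>>> config = TextToImageConfig(
...     prompt="an astronaut riding a horse on the moon, photorealistic, 8k",
...     num_inference_steps=25,
...     guidance_scale=7.5,
... )
>>> images = model(config)       # default output_type="pil" yields a list of PIL images
>>> images[0].save("astronaut.png")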
194 | """ 195 | 196 | prompt = config.prompt 197 | image_height = config.image_height 198 | image_width = config.image_width 199 | num_inference_steps = config.num_inference_steps 200 | guidance_scale = config.guidance_scale 201 | guidance_rescale = config.guidance_rescale 202 | negative_prompt = config.negative_prompt 203 | latent = config.latent 204 | output_type = config.output_type 205 | return_latent_history = config.return_latent_history 206 | 207 | # Validate input 208 | self.validate_input( 209 | prompt=prompt, 210 | negative_prompt=negative_prompt, 211 | image_height=image_height, 212 | image_width=image_width, 213 | ) 214 | 215 | # Generate embedding to condition on prompt and uncondition on negative prompt 216 | embedding = self.prompt_to_embedding( 217 | prompt=prompt, 218 | negative_prompt=negative_prompt, 219 | ) 220 | 221 | # Run inference 222 | latent = self.embedding_to_latent( 223 | embedding=embedding, 224 | image_height=image_height, 225 | image_width=image_width, 226 | num_inference_steps=num_inference_steps, 227 | guidance_scale=guidance_scale, 228 | guidance_rescale=guidance_rescale, 229 | latent=latent, 230 | return_latent_history=return_latent_history, 231 | ) 232 | 233 | return self.resolve_output( 234 | latent=latent, 235 | output_type=output_type, 236 | return_latent_history=return_latent_history, 237 | ) 238 | 239 | generate = __call__ 240 | -------------------------------------------------------------------------------- /stablefused/diffusion/text_to_video_diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from dataclasses import dataclass 5 | from diffusers import AutoencoderKL 6 | from tqdm.auto import tqdm 7 | from transformers import CLIPTextModel, CLIPTokenizer 8 | from typing import List, Optional, Union 9 | 10 | from stablefused.diffusion import BaseDiffusion 11 | from stablefused.typing import PromptType, OutputType, SchedulerType, UNetType 12 | 13 | 14 | @dataclass 15 | class TextToVideoConfig: 16 | """ 17 | Configuration class for running inference with TextToVideoDiffusion. 18 | 19 | Parameters 20 | ---------- 21 | prompt: PromptType 22 | Text prompt to condition on. 23 | video_height: int 24 | Height of video to generate. 25 | video_width: int 26 | Width of video to generate. 27 | video_frames: int 28 | Number of frames to generate in video. 29 | num_inference_steps: int 30 | Number of diffusion steps to run. 31 | guidance_scale: float 32 | Guidance scale encourages the model to generate images following the prompt 33 | closely, albeit at the cost of image quality. 34 | guidance_rescale: float 35 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 36 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 37 | negative_prompt: Optional[PromptType] 38 | Negative text prompt to uncondition on. 39 | latent: Optional[torch.FloatTensor] 40 | Latent to start from. If None, latent is generated from noise. 41 | output_type: str 42 | Type of output to return. One of ["latent", "pil", "pt", "np"]. 43 | decode_batch_size: int 44 | Batch size to use when decoding latent to image. 
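Example
-------
A small construction sketch (the prompt and values are illustrative): this
requests 16 frames at the default 512x512 resolution, decoded 4 frames at a
time with the default decode_batch_size.

>>> config = TextToVideoConfig(
...     prompt="a panda playing guitar on a hill, high quality",
...     video_frames=16,
...     num_inference_steps=25,
... )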
45 | """ 46 | 47 | prompt: PromptType 48 | video_height: int = 512 49 | video_width: int = 512 50 | video_frames: int = 24 51 | num_inference_steps: int = 50 52 | guidance_scale: float = 7.5 53 | guidance_rescale: float = 0.7 54 | negative_prompt: Optional[PromptType] = None 55 | latent: Optional[torch.FloatTensor] = None 56 | output_type: str = "pil" 57 | decode_batch_size: int = 4 58 | 59 | 60 | class TextToVideoDiffusion(BaseDiffusion): 61 | def __init__( 62 | self, 63 | model_id: str = None, 64 | tokenizer: CLIPTokenizer = None, 65 | text_encoder: CLIPTextModel = None, 66 | vae: AutoencoderKL = None, 67 | unet: UNetType = None, 68 | scheduler: SchedulerType = None, 69 | torch_dtype: torch.dtype = torch.float32, 70 | device="cuda", 71 | *args, 72 | **kwargs 73 | ) -> None: 74 | super().__init__( 75 | model_id=model_id, 76 | tokenizer=tokenizer, 77 | text_encoder=text_encoder, 78 | vae=vae, 79 | unet=unet, 80 | scheduler=scheduler, 81 | torch_dtype=torch_dtype, 82 | device=device, 83 | *args, 84 | **kwargs 85 | ) 86 | 87 | def embedding_to_latent( 88 | self, 89 | embedding: torch.FloatTensor, 90 | video_height: int, 91 | video_width: int, 92 | video_frames: int, 93 | num_inference_steps: int, 94 | guidance_scale: float, 95 | guidance_rescale: float, 96 | latent: Optional[torch.FloatTensor] = None, 97 | ) -> Union[torch.FloatTensor, List[torch.FloatTensor]]: 98 | """ 99 | Generate latent by conditioning on prompt embedding using diffusion. 100 | 101 | Parameters 102 | ---------- 103 | embedding: torch.FloatTensor 104 | Embedding of text prompt. 105 | video_height: int 106 | Height of video to generate. 107 | video_width: int 108 | Width of video to generate. 109 | video_frames: int 110 | Number of frames to generate in video. 111 | num_inference_steps: int 112 | Number of diffusion steps to run. 113 | guidance_scale: float 114 | Guidance scale encourages the model to generate images following the prompt 115 | closely, albeit at the cost of image quality. 116 | guidance_rescale: float 117 | Guidance rescale from [Common Diffusion Noise Schedules and Sample Steps are 118 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). 119 | latent: Optional[torch.FloatTensor] 120 | Latent to start from. If None, generate latent from noise. 121 | 122 | Returns 123 | ------- 124 | Union[torch.FloatTensor, List[torch.FloatTensor]] 125 | Latent generated by diffusion. 
126 | """ 127 | 128 | # Generate latent from noise if not provided 129 | if latent is None: 130 | shape = ( 131 | embedding.shape[0] // 2, 132 | self.unet.config.in_channels, 133 | video_frames, 134 | video_height // self.vae_scale_factor, 135 | video_width // self.vae_scale_factor, 136 | ) 137 | latent = self.random_tensor(shape) 138 | 139 | # Set number of inference steps 140 | self.scheduler.set_timesteps(num_inference_steps) 141 | timesteps = self.scheduler.timesteps 142 | 143 | # Scale the latent noise by the standard deviation required by the scheduler 144 | latent = latent * self.scheduler.init_noise_sigma 145 | 146 | # Diffusion inference loop 147 | for i, timestep in tqdm(list(enumerate(timesteps))): 148 | # Duplicate latent to avoid two forward passes to perform classifier free guidance 149 | latent_model_input = torch.cat([latent] * 2) 150 | latent_model_input = self.scheduler.scale_model_input( 151 | latent_model_input, timestep 152 | ) 153 | 154 | # Predict noise 155 | noise_prediction = self.unet( 156 | latent_model_input, 157 | timestep, 158 | encoder_hidden_states=embedding, 159 | return_dict=False, 160 | )[0] 161 | 162 | # Perform classifier free guidance 163 | noise_prediction = self.classifier_free_guidance( 164 | noise_prediction, guidance_scale, guidance_rescale 165 | ) 166 | 167 | # Update latent 168 | latent = self.scheduler.step( 169 | noise_prediction, timestep, latent, return_dict=False 170 | )[0] 171 | 172 | return latent 173 | 174 | def resolve_output( 175 | self, 176 | latent: torch.FloatTensor, 177 | output_type: str, 178 | decode_batch_size: int, 179 | ) -> OutputType: 180 | """ 181 | Resolve output type from latent. 182 | 183 | Parameters 184 | ---------- 185 | latent: torch.FloatTensor 186 | Latent to resolve output from. 187 | output_type: str 188 | Output type to resolve. Must be one of [`latent`, `pt`, `np`, `pil`]. 189 | decode_batch_size: int 190 | Batch size to use when decoding latent to image. 191 | 192 | Returns 193 | ------- 194 | OutputType 195 | The resolved output based on the provided latent vector and options. 196 | """ 197 | 198 | if output_type not in ["latent", "pt", "np", "pil"]: 199 | raise ValueError( 200 | "`output_type` must be one of [`latent`, `pt`, `np`, `pil`]" 201 | ) 202 | 203 | if output_type == "latent": 204 | return latent 205 | 206 | # B, C, F, H, W => B, F, C, H, W 207 | latent = latent.permute(0, 2, 1, 3, 4) 208 | video = [] 209 | 210 | for i in tqdm(range(latent.shape[0])): 211 | batched_output = [] 212 | for j in tqdm(range(0, latent.shape[1], decode_batch_size)): 213 | current_latent = latent[i, j : j + decode_batch_size] 214 | batched_output.extend(self.latent_to_image(current_latent, output_type)) 215 | video.append(batched_output) 216 | 217 | if output_type == "pt": 218 | video = torch.stack(video) 219 | elif output_type == "np": 220 | video = np.stack(video) 221 | 222 | return video 223 | 224 | @torch.no_grad() 225 | def __call__( 226 | self, 227 | config: TextToVideoConfig, 228 | ) -> OutputType: 229 | """ 230 | Run inference by conditioning on text prompt. 231 | 232 | Parameters 233 | ---------- 234 | config: TextToVideoConfig 235 | Configuration for running inference with TextToVideoDiffusion. 236 | 237 | Returns 238 | ------- 239 | OutputType 240 | Generated output based on output_type. 
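Example
-------
A minimal usage sketch. The checkpoint name is an assumption (any diffusers
text-to-video checkpoint with a 3D UNet and matching components should fit
this component layout), the top-level import path is assumed to mirror the
other pipelines, and a CUDA device is assumed.

>>> import torch
>>> from stablefused import TextToVideoConfig, TextToVideoDiffusion
>>> from stablefused.utils import pil_to_video
>>> model = TextToVideoDiffusion(
...     model_id="damo-vilab/text-to-video-ms-1.7b",  # assumed checkpoint
...     torch_dtype=torch.float16,
... )
>>> config = TextToVideoConfig(prompt="a boat sailing on a stormy sea", video_frames=16)
>>> video = model(config)        # default output_type="pil": one list of PIL frames per prompt
>>> pil_to_video(video[0], "boat.mp4", fps=8)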
241 | """ 242 | 243 | prompt = config.prompt 244 | video_height = config.video_height 245 | video_width = config.video_width 246 | video_frames = config.video_frames 247 | num_inference_steps = config.num_inference_steps 248 | guidance_scale = config.guidance_scale 249 | guidance_rescale = config.guidance_rescale 250 | negative_prompt = config.negative_prompt 251 | latent = config.latent 252 | output_type = config.output_type 253 | decode_batch_size = config.decode_batch_size 254 | 255 | # Validate input 256 | self.validate_input( 257 | prompt=prompt, 258 | negative_prompt=negative_prompt, 259 | image_height=video_height, 260 | image_width=video_width, 261 | num_inference_steps=num_inference_steps, 262 | ) 263 | 264 | # Generate embedding to condition on prompt and uncondition on negative prompt 265 | embedding = self.prompt_to_embedding( 266 | prompt=prompt, 267 | negative_prompt=negative_prompt, 268 | ) 269 | 270 | # Run inference 271 | latent = self.embedding_to_latent( 272 | embedding=embedding, 273 | video_height=video_height, 274 | video_width=video_width, 275 | video_frames=video_frames, 276 | num_inference_steps=num_inference_steps, 277 | guidance_scale=guidance_scale, 278 | guidance_rescale=guidance_rescale, 279 | latent=latent, 280 | ) 281 | 282 | return self.resolve_output( 283 | latent=latent, 284 | output_type=output_type, 285 | decode_batch_size=decode_batch_size, 286 | ) 287 | 288 | generate = __call__ 289 | -------------------------------------------------------------------------------- /examples/effect_of_guidance_scale.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Effect of Guidance Scale\n", 9 | "\n", 10 | "In this notebook, we take a look at how the guidance scale affects the image quality of the model.\n", 11 | "\n", 12 | "Guidance scale is a value inspired by the paper [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). The explanation of how CFG works is out-of-scope here, but there are many online sources where you can read about it (linked below).\n", 13 | "\n", 14 | "- [Guidance: a cheat for diffusion models](https://sander.ai/2022/05/26/guidance.html)\n", 15 | "- [Diffusion Models, DDPMs, DDIMs and CFG](https://betterprogramming.pub/diffusion-models-ddpms-ddims-and-classifier-free-guidance-e07b297b2869)\n", 16 | "- [Classifier-Free Guidance Scale](https://mccormickml.com/2023/02/20/classifier-free-guidance-scale/)\n", 17 | "\n", 18 | "In short, guidance scale is a value that controls the amount of \"guidance\" used in the diffusion process. That is, the higher the value, the more closely the diffusion process follows the prompt. A lower guidance scale allows the model to be more creative, and work slightly different from the exact prompt. After a certain threshold maximum value, the results start to get worse, blurry and noisy.\n", 19 | "\n", 20 | "Guidance scale values, in practice, are usually in the range 6-15, and the default value of 7.5 is used in many inference implementations. However, manipulating it can lead to some very interesting results. It also only makes sense when it is set to 1.0 or higher, which is why many implementations use a minimum value of 1.0.\n", 21 | "\n", 22 | "But... what happens when we set guidance scale to 0? Or negative? 
Let's find out!\n", 23 | "\n", 24 | "When you use a negative value for the guidance scale, the model will try to generate images that are the opposite of what you specify in the prompt. For example, if you prompt the model to generate an image of an astronaut, and you use a negative guidance scale, the model will try to generate an image of everything but an astronaut. This can be a fun way to generate creative and unexpected images (sometimes NSFW or absolute horrendous stuff, if you are not using a safety-checker model - which is the case with StableFused)." 25 | ] 26 | }, 27 | { 28 | "attachments": {}, 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Install and Import required packages" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "!pip install stablefused ipython" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import numpy as np\n", 51 | "import torch\n", 52 | "\n", 53 | "from IPython.display import Video, display\n", 54 | "from PIL import Image, ImageDraw, ImageFont\n", 55 | "from stablefused import TextToImageDiffusion\n", 56 | "from typing import List" 57 | ] 58 | }, 59 | { 60 | "attachments": {}, 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "### Initialize model and parameters\n", 65 | "\n", 66 | "We use RunwayML's Stable Diffusion 1.5 checkpoint." 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 76 | "model = TextToImageDiffusion(model_id=model_id, torch_dtype=torch.float16)" 77 | ] 78 | }, 79 | { 80 | "attachments": {}, 81 | "cell_type": "markdown", 82 | "metadata": { 83 | "id": "q0-_CmuxpTw8" 84 | }, 85 | "source": [ 86 | "##### Prompt Credits\n", 87 | "\n", 88 | "The prompts used in this notebook have been taken from different sources. The main inspirations are:\n", 89 | "\n", 90 | "- https://levelup.gitconnected.com/20-stable-diffusion-prompts-to-create-stunning-characters-a63017dc4b74\n", 91 | "- https://mpost.io/best-100-stable-diffusion-prompts-the-most-beautiful-ai-text-to-image-prompts/" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "prompt = [\n", 101 | " \"Artistic image, very detailed cute cat, cinematic lighting effect, cute, charming, fantasy art, digital painting, photorealistic\",\n", 102 | " \"A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k\",\n", 103 | " \"A grand city in the year 2100, atmospheric, hyper realistic, 8k, epic composition, cinematic, octane render\",\n", 104 | " \"Starry Night, painting style of Vincent van Gogh, Oil paint on canvas, Landscape with a starry night sky, dreamy, peaceful\",\n", 105 | "]\n", 106 | "negative_prompt = \"cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive\"\n", 107 | "num_inference_steps = 20\n", 108 | "seed = 2023" 109 | ] 110 | }, 111 | { 112 | "attachments": {}, 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "`image_grid_with_labels` is a helper function that takes a list of images and a list of labels and displays them in a grid." 
117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "def image_grid_with_labels(\n", 126 | " images: List[Image.Image], labels: List[str], rows: int, cols: int\n", 127 | ") -> Image.Image:\n", 128 | " \"\"\"Create a grid of images with labels.\"\"\"\n", 129 | " if len(images) > rows * cols:\n", 130 | " raise ValueError(\n", 131 | " f\"Number of images ({len(images)}) exceeds grid size ({rows}x{cols}).\"\n", 132 | " )\n", 133 | " if len(labels) != rows:\n", 134 | " raise ValueError(\n", 135 | " f\"Number of labels ({len(labels)}) does not match the number of rows ({rows}).\"\n", 136 | " )\n", 137 | "\n", 138 | " w, h = images[0].size\n", 139 | " label_width = 100\n", 140 | "\n", 141 | " grid = Image.new(\"RGB\", size=(cols * w + label_width, rows * h))\n", 142 | " draw = ImageDraw.Draw(grid)\n", 143 | "\n", 144 | " font_size = 32\n", 145 | " font = ImageFont.truetype(\n", 146 | " \"/usr/local/lib/python3.10/dist-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf\",\n", 147 | " size=font_size,\n", 148 | " )\n", 149 | "\n", 150 | " for i, label in enumerate(labels):\n", 151 | " x_label = label_width // 4\n", 152 | " y_label = i * h + h // 2\n", 153 | " draw.text((x_label, y_label), label, fill=(255, 255, 255), font=font)\n", 154 | "\n", 155 | " for i, image in enumerate(images):\n", 156 | " x_img = (i % cols) * w + label_width\n", 157 | " y_img = (i // cols) * h\n", 158 | " grid.paste(image, box=(x_img, y_img))\n", 159 | "\n", 160 | " return grid" 161 | ] 162 | }, 163 | { 164 | "attachments": {}, 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "### Inference with different guidance scales\n", 169 | "\n", 170 | "We start with a negative guidance scale and increment it by 1.5 until a certain maximum value. The results obtained are very interesting!\n", 171 | "\n", 172 | "The below code demonstrates the effect of guidance scale on different prompts." 
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "epsilon = 1e-5\n", 182 | "guidance_scale = -1.5\n", 183 | "increment = 1.5\n", 184 | "max_guidance_scale = 15 + epsilon\n", 185 | "num_iterations = 0\n", 186 | "results = []\n", 187 | "guidance_scale_labels = []\n", 188 | "\n", 189 | "while guidance_scale <= max_guidance_scale:\n", 190 | " torch.manual_seed(seed)\n", 191 | " np.random.seed(seed)\n", 192 | "\n", 193 | " print(f\"Generating images with guidance_scale={guidance_scale:.2f} and seed={seed}\")\n", 194 | " results.append(\n", 195 | " model(\n", 196 | " prompt=prompt,\n", 197 | " negative_prompt=[negative_prompt] * len(prompt),\n", 198 | " num_inference_steps=num_inference_steps,\n", 199 | " guidance_scale=guidance_scale,\n", 200 | " )\n", 201 | " )\n", 202 | "\n", 203 | " guidance_scale_labels.append(str(round(guidance_scale, 2)))\n", 204 | " guidance_scale += increment\n", 205 | " num_iterations += 1" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "flattened_results = [image for result in results for image in result]" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "image_grid_with_labels(\n", 224 | " flattened_results,\n", 225 | " labels=guidance_scale_labels,\n", 226 | " rows=num_iterations,\n", 227 | " cols=len(prompt),\n", 228 | ").save(\"effect-of-guidance-scale-on-different-prompts.png\")" 229 | ] 230 | }, 231 | { 232 | "attachments": {}, 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "\n", 237 | "The below code demonstrates the effect of guidance scale on the same prompt over multiple inference steps. 
" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "steps = [3, 6, 12, 20, 25]\n", 247 | "prompt = \"Photorealistic illustration of a mystical alien creature, magnificent, strong, atomic, tyrannic, predator, unforgiving, full-body image\"\n", 248 | "epsilon = 1e-5\n", 249 | "guidance_scale = -1.5\n", 250 | "increment = 1.5\n", 251 | "max_guidance_scale = 15 + epsilon\n", 252 | "num_iterations = 0\n", 253 | "results = []\n", 254 | "guidance_scale_labels = []\n", 255 | "seed = 42\n", 256 | "\n", 257 | "while guidance_scale <= max_guidance_scale:\n", 258 | " step_results = []\n", 259 | "\n", 260 | " for step in steps:\n", 261 | " torch.manual_seed(seed)\n", 262 | " np.random.seed(seed)\n", 263 | "\n", 264 | " print(\n", 265 | " f\"Generating images with guidance_scale={guidance_scale:.2f}, num_inference_steps={step} and seed={seed}\"\n", 266 | " )\n", 267 | " step_results.append(\n", 268 | " model(\n", 269 | " prompt=prompt,\n", 270 | " negative_prompt=negative_prompt,\n", 271 | " num_inference_steps=step,\n", 272 | " guidance_scale=guidance_scale,\n", 273 | " )\n", 274 | " )\n", 275 | "\n", 276 | " guidance_scale_labels.append(str(round(guidance_scale, 2)))\n", 277 | " guidance_scale += increment\n", 278 | " num_iterations += 1\n", 279 | "\n", 280 | " results.append(step_results)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "flattened_results = [\n", 290 | " image for result in results for images in result for image in images\n", 291 | "]" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "image_grid_with_labels(\n", 301 | " flattened_results, labels=guidance_scale_labels, rows=len(results), cols=len(steps)\n", 302 | ").save(\"effect-of-guidance-scale-vs-steps.png\")" 303 | ] 304 | } 305 | ], 306 | "metadata": { 307 | "language_info": { 308 | "name": "python" 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 0 313 | } 314 | -------------------------------------------------------------------------------- /examples/latent_walk_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Latent Walk Diffusion\n", 9 | "\n", 10 | "In this notebook, we will take a look at latent walking in latent spaces. Generative models, like the ones used in Stable Diffusion, learn a latent representation of the world. A latent representation is a low-dimensional vector space embedding of the world. In the case of SD, this latent representation is learnt by training on text-image pairs. This representation is used to generate samples given a prompt and a random noise vector. The model tries to predict and remove noise from the random noise vector, while also aligning it the vector to the prompt. This results in some interesting properties of the latent space. In this notebook, we will explore these properties.\n", 11 | "\n", 12 | "Stable Diffusion models (atleast, the models used here) learn two latent representations - one of the NLP space for prompts, and one of the image space. These latent representations are continuous. 
If we choose two vectors in the latent space to sample from, we get two different/similar images depending on how different the chosen vectors are. This is the basis of latent walking. We can choose two vectors in the latent space, and sample from the latent path between them. This results in a smooth transition between the two images." 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "### Install and Import required packages" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "%pip install stablefused ipython" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import numpy as np\n", 39 | "import torch\n", 40 | "\n", 41 | "from IPython.display import Video, display\n", 42 | "from PIL import Image\n", 43 | "from stablefused import LatentWalkDiffusion, TextToImageDiffusion\n", 44 | "from stablefused.utils import image_grid, pil_to_video" 45 | ] 46 | }, 47 | { 48 | "attachments": {}, 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "### Initialize model and parameters\n", 53 | "\n", 54 | "We use RunwayML's Stable Diffusion 1.5 checkpoint and initialize our Latent-Walk and Text-to-Image Diffusion models. Play around with different prompts and parameters, and see what you get! You can comment out the parts that use seeds to generate random images each time you run the notebook.\n", 55 | "\n", 56 | "We use the following mechanism to trade-off speed for reduced memory footprint. It allows us to work with bigger images and larger batch sizes with about just 6GB of GPU memory.\n", 57 | "- U-Net Attention Slicing: Allows the internal U-Net model to perform computations for attention heads sequentially, rather than in parallel.\n", 58 | "- VAE Slicing: Allow tensor slicing for VAE decode step. This will cause the vae to split the input tensor to compute decoding in multiple steps.\n", 59 | "- VAE Tiling: Allow tensor tiling for vae. This will cause the vae to split the input tensor into tiles to compute encoding/decoding in several steps.\n", 60 | "\n", 61 | "Also, notice how we are loading the same model twice. That should use twice the memory, right? Well, in most cases, users stick to using the same model checkpoints across different inference pipelines, and so it makes sense to share the internal models. StableFused maintains an internal model cache which allows all internal models to be shared, in order to save memory." 
62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 71 | "lw_model = LatentWalkDiffusion(model_id=model_id, torch_dtype=torch.float16)\n", 72 | "\n", 73 | "lw_model.enable_attention_slicing()\n", 74 | "lw_model.enable_slicing()\n", 75 | "lw_model.enable_tiling()\n", 76 | "\n", 77 | "t2i_model = TextToImageDiffusion(model_id=model_id, torch_dtype=torch.float16)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": { 83 | "id": "bVIENw0bnSZ_" 84 | }, 85 | "source": [ 86 | "Prompt Credits: https://mspoweruser.com/best-stable-diffusion-prompts/#6_The_Robotic_Baroque_Battle\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "prompt = \"Large futuristic mechanical robot in the foreground of a baroque-style battle scene, photorealistic, high quality, 8k\"\n", 96 | "negative_prompt = \"cartoon, unrealistic, blur, boring background, deformed, disfigured, low resolution, unattractive\"\n", 97 | "num_images = 4\n", 98 | "seed = 44\n", 99 | "\n", 100 | "torch.manual_seed(seed)\n", 101 | "np.random.seed(seed)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# There seems to be a bug in stablefused which requires to be reviewed. The\n", 111 | "# bug causes images created using latent walk to be very noisy, or plain white,\n", 112 | "# deformed, etc.\n", 113 | "# This is why instead of being able to use an actual image to generate latents,\n", 114 | "# we need to make it ourselves. In the future, this notebook will be updated\n", 115 | "# to allow latent walking for user-chosen images\n", 116 | "\n", 117 | "# filename = \"the-robotic-baroque-battle.png\"\n", 118 | "# start_image = [Image.open(filename)] * num_images\n", 119 | "\n", 120 | "# # This step is only required when loading model with torch.float16 dtype\n", 121 | "# start_image = np.array(start_image, dtype=np.float16)\n", 122 | "\n", 123 | "# latent = lw_model.image_to_latent(start_image)\n", 124 | "\n", 125 | "image_height = 512\n", 126 | "image_width = 512\n", 127 | "shape = (\n", 128 | " 1,\n", 129 | " lw_model.unet.config.in_channels,\n", 130 | " image_height // lw_model.vae_scale_factor,\n", 131 | " image_width // lw_model.vae_scale_factor,\n", 132 | ")\n", 133 | "single_latent = lw_model.random_tensor(shape)\n", 134 | "latent = single_latent.repeat(num_images, 1, 1, 1)" 135 | ] 136 | }, 137 | { 138 | "attachments": {}, 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "### Latent Walking to generate similar images\n", 143 | "\n", 144 | "Let's see what our base image for latent walking looks like." 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "t2i_model(\n", 154 | " prompt=prompt,\n", 155 | " negative_prompt=negative_prompt,\n", 156 | " num_inference_steps=20,\n", 157 | " guidance_scale=10.0,\n", 158 | " latent=single_latent,\n", 159 | ")[0]" 160 | ] 161 | }, 162 | { 163 | "attachments": {}, 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "Latent walk with diffusion around the latent space of our sampled latent vector. This results in generation of similar images. 
The similarity/difference can be controlled using the `strength` parameter (set between 0 and 1, defaults to 0.2). Lower strenght leads to similar images with subtle differences. Higher strength can cause completely new ideas to be generated." 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "images = lw_model(\n", 177 | " prompt=[prompt] * num_images,\n", 178 | " negative_prompt=[negative_prompt] * num_images,\n", 179 | " latent=latent,\n", 180 | " strength=0.25,\n", 181 | " num_inference_steps=20,\n", 182 | ")" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "image_grid(images, rows=2, cols=2)" 192 | ] 193 | }, 194 | { 195 | "attachments": {}, 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "### Generating Videos with Latent Walking\n", 200 | "\n", 201 | "Here, we generate a video by walking the latent space of the model, using interpolation techniques to generate frames. An interpolation is just a weighted average of two embeddings calculated by some interpolation function. [Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) is used on the prompt embeddings and [Spherical Linear Interpolation](https://en.wikipedia.org/wiki/Slerp) is used on the latent embeddings, by default. You can change the interpolation method by passing `embedding_interpolation_type` or `latent_interpolation_type` parameter.\n", 202 | "\n", 203 | "Note that stablefused is a toy library in its infancy and is not optimized for speed, and does not support a lot of features. There are multiple bugs and issues that need to be addressed. Some things need to be implemented manually currently, but in the future, I hope to make the process easier." 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# Prompt credits: ChatGPT\n", 213 | "story_prompt = [\n", 214 | " \"A dog chasing a cat in a thrilling backyard scene, high quality and photorealistic\",\n", 215 | " \"A determined dog in hot pursuit, with stunning realism, octane render\",\n", 216 | " \"A thrilling chase, dog behind the cat, octane render, exceptional realism and quality\",\n", 217 | " \"The exciting moment of a cat outmaneuvering a chasing dog, high-quality and photorealistic detail\",\n", 218 | " \"A clever cat escaping a determined dog and soaring into space, rendered with octane render for stunning realism\",\n", 219 | " \"The cat's escape into the cosmos, leaving the dog behind in a scene,high quality and photorealistic style\",\n", 220 | "]\n", 221 | "\n", 222 | "seed = 123456\n", 223 | "\n", 224 | "torch.manual_seed(seed)\n", 225 | "np.random.seed(seed)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "# There seems to be a bug in stablefused which requires to be reviewed. The\n", 235 | "# bug causes images created using latent walk to be very noisy, or plain white,\n", 236 | "# deformed, etc.\n", 237 | "# This is why instead of being able to use an actual image to generate latents,\n", 238 | "# we need to make it ourselves. 
In the future, this notebook will be updated\n", 239 | "# to allow latent walking for user-chosen images\n", 240 | "\n", 241 | "# t2i_images = t2i_model(\n", 242 | "# prompt = story_prompt,\n", 243 | "# negative_prompt = [negative_prompt] * len(story_prompt),\n", 244 | "# num_inference_steps = 20,\n", 245 | "# guidance_scale = 12.0,\n", 246 | "# )\n", 247 | "\n", 248 | "image_height = 512\n", 249 | "image_width = 512\n", 250 | "shape = (\n", 251 | " len(story_prompt),\n", 252 | " lw_model.unet.config.in_channels,\n", 253 | " image_height // lw_model.vae_scale_factor,\n", 254 | " image_width // lw_model.vae_scale_factor,\n", 255 | ")\n", 256 | "latent = lw_model.random_tensor(shape)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "t2i_images = t2i_model(\n", 266 | " prompt=story_prompt,\n", 267 | " num_inference_steps=20,\n", 268 | " guidance_scale=15.0,\n", 269 | " latent=latent,\n", 270 | ")" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "image_grid(t2i_images, rows=2, cols=3)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "# Due to the bug mentioned above, this step is not required.\n", 289 | "# We can directly use the latents we generated manually\n", 290 | "# np_t2i_images = np.array(t2i_images, dtype = np.float16)\n", 291 | "# t2i_latents = t2i_model.image_to_latent(np_t2i_images)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "interpolation_steps = 24\n", 301 | "\n", 302 | "# Since stablefused does not support batch processing yet, we need\n", 303 | "# to do it manually. This notebook will be updated in the future\n", 304 | "# to support batching internally to handle a large number of images\n", 305 | "\n", 306 | "story_images = []\n", 307 | "for i in range(len(story_prompt) - 1):\n", 308 | " current_prompt = story_prompt[i : i + 2]\n", 309 | " current_latent = latent[i : i + 2]\n", 310 | " imgs = lw_model.interpolate(\n", 311 | " prompt=current_prompt,\n", 312 | " negative_prompt=[negative_prompt] * len(current_prompt),\n", 313 | " latent=current_latent,\n", 314 | " num_inference_steps=20,\n", 315 | " interpolation_steps=interpolation_steps,\n", 316 | " )\n", 317 | " story_images.extend(imgs)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "filename = \"dog-chasing-cat-story.mp4\"\n", 327 | "pil_to_video(story_images, filename, fps=8)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "display(Video(filename, embed=True))" 337 | ] 338 | } 339 | ], 340 | "metadata": { 341 | "language_info": { 342 | "name": "python" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 0 347 | } 348 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StableFused 2 | 3 |
| Source Image | Denoising Diffusion Process | |
|---|---|---|
| The Renaissance Astronaut | High quality and colorful photo of Robert J Oppenheimer, father of the atomic bomb, in a spacesuit, galaxy in the background, universe, octane render, realistic, 8k, bright colors | Stylistic photorealistic photo of Margot Robbie, playing the role of astronaut, pretty, beautiful, high contrast, high quality, galaxies, intricate detail, colorful, 8k |
| *(image)* | *(image)* | *(image)* |

| Inpainting | |
|---|---|
| Prompt 1: Digital illustration of a mythical creature, high quality, realistic, 8k | *(image)* |
| Image | Mask |
| *(image)* | *(image)* |

| Source Image | Latent Walks |
|---|---|
| Large futuristic mechanical robot in the foreground of a baroque-style battle scene, photorealistic, high quality, 8k | *(image)* |
| *(image)* | *(image)* |
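
The latent walk notebook dumped above controls how far each generated image strays from its base image with a `strength` parameter: values near 0 keep the composition, while larger values let the model wander toward new ideas. The snippet below is a minimal sketch of that idea only, assuming the common approach of nudging a base latent with scaled Gaussian noise; it is not stablefused's actual implementation, and the helper name `perturb_latent` is made up for illustration.

```python
# Illustrative sketch (not stablefused internals): trade off similarity vs.
# novelty by perturbing a base latent with noise scaled by `strength`.
import torch


def perturb_latent(
    base_latent: torch.Tensor, num_variations: int, strength: float = 0.2
) -> torch.Tensor:
    """Return `num_variations` latents scattered around `base_latent`.

    A strength close to 0 stays near the base latent (similar images);
    larger values move further away (increasingly different images).
    """
    noise = torch.randn(
        num_variations,
        *base_latent.shape[1:],
        dtype=base_latent.dtype,
        device=base_latent.device,
    )
    return base_latent + strength * noise


# Example: four latents near a random 512x512 Stable Diffusion latent
# (4 channels, 64x64 spatial size after the 8x VAE downscaling).
base = torch.randn(1, 4, 64, 64)
variations = perturb_latent(base, num_variations=4, strength=0.25)
print(variations.shape)  # torch.Size([4, 4, 64, 64])
```

Denoising each perturbed latent with the same prompt is what yields the grid of closely related images shown in the notebook.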
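
The video section of that notebook generates in-between frames by interpolating consecutive prompt embeddings and latents, linearly for the text embeddings and spherically for the latents. The sketch below shows one common way such `lerp` and `slerp` helpers are written; it is a simplified illustration of the technique the notebook describes, not the library's own code, and the `dot_threshold` fallback value is an assumption.

```python
# Illustrative sketch of the two interpolation schemes used for latent walk
# videos: lerp for prompt embeddings, slerp for latent tensors.
import torch


def lerp(v0: torch.Tensor, v1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Straight-line blend (1 - t) * v0 + t * v1 for every t in the batch.
    t = t.reshape(-1, *([1] * v0.dim()))
    return (1 - t) * v0 + t * v1


def slerp(
    v0: torch.Tensor, v1: torch.Tensor, t: torch.Tensor, dot_threshold: float = 0.9995
) -> torch.Tensor:
    # Walk along the arc between v0 and v1, falling back to lerp when the
    # two tensors are nearly parallel and the arc is numerically unstable.
    dot = torch.sum(v0 * v1) / (torch.norm(v0) * torch.norm(v1))
    if torch.abs(dot) > dot_threshold:
        return lerp(v0, v1, t)
    theta = torch.acos(dot)
    t = t.reshape(-1, *([1] * v0.dim()))
    return (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / torch.sin(theta)


# 24 in-between latents for one pair of keyframes, mirroring the
# interpolation_steps=24 used in the notebook.
steps = torch.linspace(0, 1, 24)
latent_a = torch.randn(4, 64, 64)
latent_b = torch.randn(4, 64, 64)
frames = slerp(latent_a, latent_b, steps)
print(frames.shape)  # torch.Size([24, 4, 64, 64])
```

Chaining this over each adjacent pair of story prompts, denoising every interpolated latent, and writing the decoded frames with `pil_to_video` is what produces the final walkthrough video.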