├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.md ├── configs ├── ddpo_sd_imagenet.yml ├── ddpo_sd_imagenet_lora.yml └── ddpo_sd_pickapic.yml ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── denoisers.rst │ ├── example.rst │ ├── index.rst │ ├── installation.rst │ ├── pipeline.rst │ ├── reward_modelling.rst │ ├── sampling.rst │ └── trainer.rst ├── examples ├── contrasted_panda_inference.py ├── contrasted_panda_training.py ├── ddpo_imagenet_lora_inference.py ├── train_ddpo_imagenet.py ├── train_ddpo_imagenet_lora.py └── train_ddpo_pickapic.py ├── pyproject.toml ├── requirements.txt ├── src └── drlx │ ├── __init__.py │ ├── configs.py │ ├── denoisers │ ├── __init__.py │ └── ldm_unet.py │ ├── pipeline │ ├── __init__.py │ ├── imagenet_animal_prompts.py │ └── pickapic_prompts.py │ ├── reward_modelling │ ├── __init__.py │ ├── aesthetics.py │ ├── pickscore.py │ └── toy_rewards.py │ ├── sampling │ └── __init__.py │ ├── trainer │ ├── __init__.py │ └── ddpo_trainer.py │ └── utils │ └── __init__.py ├── tests ├── accelerate_checkpoint_test.py └── ddpo_unet_pipeline_test.py └── visualization ├── README.md └── vid_from_sample.py /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | *.pth 3 | wandb/* 4 | __pycache__/ 5 | checkpoints/ 6 | *.egg-info/ 7 | output/ 8 | temp/ 9 | LOC_synset_mapping.txt -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | 6 | python: 7 | version: 3.8 8 | install: 9 | - requirements: docs/requirements.txt 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 CarperAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Reinforcement Learning X 2 | 3 | DRLX is a library for distributed training of diffusion models via RL. 
It is meant to wrap around 🤗 Hugging Face's [Diffusers](https://huggingface.co/docs/diffusers/) library and uses [Accelerate](https://huggingface.co/docs/accelerate/) for Multi-GPU and Multi-Node (as of yet untested) 4 | 5 | **News (09/27/2023): Check out our blog post with some recent experiments [here](https://carper.ai/enhancing-diffusion-models-with-reinforcement-learning/)!** 6 | 7 | 📖 **[Documentation](https://DRLX.readthedocs.io)** 8 | 9 | # Setup 10 | 11 | First make sure you've installed [OpenCLIP](https://github.com/openai/CLIP.git). Afterwards, you can install the library from pypi: 12 | 13 | ```sh 14 | pip install drlx 15 | ``` 16 | 17 | or from source: 18 | 19 | ```sh 20 | pip install git+https://github.com/CarperAI/DRLX.git 21 | ``` 22 | 23 | # How to use 24 | 25 | Currently we have only tested the library with Stable Diffusion 1.4, 1.5, and 2.1, but the plug and play nature of it means that realistically any denoiser from most pipelines should be usable. Models saved with DRLX are compatible with the pipeline they originated from and can be loaded like any other pretrained model. Currently the only algorithm supported for training is [DDPO](https://arxiv.org/abs/2305.13301). 26 | 27 | ```python 28 | from drlx.reward_modelling.aesthetics import Aesthetics 29 | from drlx.pipeline.pickapic_prompts import PickAPicPrompts 30 | from drlx.trainer.ddpo_trainer import DDPOTrainer 31 | from drlx.configs import DRLXConfig 32 | 33 | # We import a reward model, a prompt pipeline, the trainer and config 34 | 35 | pipe = PickAPicPrompts() 36 | config = DRLXConfig.load_yaml("configs/my_cfg.yml") 37 | trainer = DDPOTrainer(config) 38 | 39 | trainer.train(pipe, Aesthetics()) 40 | ``` 41 | 42 | And then to use a trained model for inference: 43 | 44 | ```python 45 | pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp") 46 | prompt = "A mad panda scientist" 47 | image = pipe(prompt).images[0] 48 | image.save("test.jpeg") 49 | ``` 50 | 51 | # Accelerated Training 52 | 53 | ```bash 54 | accelerate config 55 | accelerate launch -m [your module] 56 | ``` 57 | 58 | # Roadmap 59 | 60 | - [x] Initial launch and DDPO 61 | - [x] PickScore Tuned Models 62 | - [ ] DPO 63 | - [ ] SDXL support 64 | -------------------------------------------------------------------------------- /configs/ddpo_sd_imagenet.yml: -------------------------------------------------------------------------------- 1 | method: 2 | name : "DDPO" 3 | 4 | model: 5 | model_path: "stabilityai/stable-diffusion-2-1-base" 6 | model_arch_type: "LDMUnet" 7 | attention_slicing: True 8 | xformers_memory_efficient: True 9 | gradient_checkpointing: True 10 | 11 | sampler: 12 | num_inference_steps: 50 13 | 14 | optimizer: 15 | name: "adamw" 16 | kwargs: 17 | lr: 1.0e-5 18 | weight_decay: 1.0e-4 19 | betas: [0.9, 0.999] 20 | 21 | scheduler: 22 | name: "linear" # Name of learning rate scheduler 23 | kwargs: 24 | start_factor: 1.0 25 | end_factor: 1.0 26 | 27 | logging: 28 | run_name: 'ddpo_sd_imagenet' 29 | wandb_project: 'DRLX' 30 | 31 | train: 32 | num_epochs: 200 33 | num_samples_per_epoch: 256 34 | batch_size: 4 35 | sample_batch_size: 32 36 | grad_clip: 1.0 37 | checkpoint_interval: 50 38 | tf32: True 39 | suppress_log_keywords: "diffusers.pipelines,transformers" 40 | save_samples: False -------------------------------------------------------------------------------- /configs/ddpo_sd_imagenet_lora.yml: -------------------------------------------------------------------------------- 1 | method: 2 | name : "DDPO" 3 | 4 | model: 5 | 
model_path: "stabilityai/stable-diffusion-2-1-base" 6 | model_arch_type: "LDMUnet" 7 | attention_slicing: True 8 | xformers_memory_efficient: True 9 | gradient_checkpointing: True 10 | lora_rank: 4 11 | sampler: 12 | num_inference_steps: 50 13 | 14 | optimizer: 15 | name: "adamw" 16 | kwargs: 17 | lr: 1.0e-4 18 | weight_decay: 1.0e-4 19 | betas: [0.9, 0.999] 20 | 21 | scheduler: 22 | name: "linear" # Name of learning rate scheduler 23 | kwargs: 24 | start_factor: 1.0 25 | end_factor: 1.0 26 | 27 | logging: 28 | run_name: 'ddpo_sd_imagenet_lora' 29 | wandb_project: 'DRLX' 30 | 31 | train: 32 | num_epochs: 200 33 | num_samples_per_epoch: 256 34 | batch_size: 4 35 | sample_batch_size: 32 36 | grad_clip: 1.0 37 | checkpoint_interval: 10 38 | tf32: True 39 | suppress_log_keywords: "diffusers.pipelines,transformers" 40 | save_samples: False -------------------------------------------------------------------------------- /configs/ddpo_sd_pickapic.yml: -------------------------------------------------------------------------------- 1 | method: 2 | name : "DDPO" 3 | 4 | model: 5 | model_path: "stabilityai/stable-diffusion-2-1-base" 6 | model_arch_type: "LDMUnet" 7 | attention_slicing: True 8 | xformers_memory_efficient: True 9 | gradient_checkpointing: True 10 | 11 | sampler: 12 | guidance_scale: 7.5 13 | num_inference_steps: 50 14 | 15 | optimizer: 16 | name: "adamw" 17 | kwargs: 18 | lr: 1.0e-5 19 | weight_decay: 1.0e-4 20 | betas: [0.9, 0.999] 21 | 22 | scheduler: 23 | name: "linear" # Name of learning rate scheduler 24 | kwargs: 25 | start_factor: 1.0 26 | end_factor: 1.0 27 | 28 | logging: 29 | run_name: 'ddpo_sd_pickapic_pickscore' 30 | wandb_project: 'DRLX' 31 | 32 | train: 33 | num_epochs: 450 34 | num_samples_per_epoch: 256 35 | batch_size: 4 36 | sample_batch_size: 32 37 | grad_clip: 1.0 38 | checkpoint_interval: 50 39 | tf32: True 40 | suppress_log_keywords: "diffusers.pipelines,transformers" 41 | save_samples: False -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchtyping 4 | einops 5 | diffusers 6 | transformers 7 | accelerate 8 | xformers 9 | wandb 10 | fastprogress 11 | matplotlib 12 | git+https://github.com/openai/CLIP.git 13 | tqdm 14 | sphinx_rtd_theme 15 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../../src/')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'DRLX' 21 | copyright = '2023, CarperAI' 22 | author = 'CarperAI' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '1.0' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx_rtd_theme', 36 | ] 37 | 38 | # Add any paths that contain templates here, relative to this directory. 39 | templates_path = ['_templates'] 40 | 41 | # List of patterns, relative to source directory, that match files and 42 | # directories to ignore when looking for source files. 43 | # This pattern also affects html_static_path and html_extra_path. 44 | exclude_patterns = [] 45 | 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | # The theme to use for HTML and HTML Help pages. See the documentation for 50 | # a list of builtin themes. 51 | # 52 | html_theme = 'sphinx_rtd_theme' 53 | 54 | # Add any paths that contain custom static files (such as style sheets) here, 55 | # relative to this directory. They are copied after the builtin static files, 56 | # so a file named "default.css" will overwrite the builtin "default.css". 57 | html_static_path = ['_static'] 58 | 59 | -------------------------------------------------------------------------------- /docs/source/denoisers.rst: -------------------------------------------------------------------------------- 1 | .. _denoisers: 2 | 3 | Denoisers 4 | ========= 5 | 6 | DRLX generally uses conditioned denoisers for diffusion modelling. 
Currently, the library is made with text conditioning in mind, the base classes are with generalizability in mind, and to this end the conditional denoiser 7 | supports any kind of conditioning signal that produces an embedding. 8 | 9 | BaseConditionalDenoiser 10 | ------------------------- 11 | 12 | .. automodule:: drlx.denoisers 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | 17 | LDMUNet 18 | ---------- 19 | 20 | .. automodule:: drlx.denoisers.ldm_unet 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | -------------------------------------------------------------------------------- /docs/source/example.rst: -------------------------------------------------------------------------------- 1 | .. _example: 2 | 3 | DRLX Example 4 | ============ 5 | 6 | This example demonstrates how to use DRLX to train a model with a custom prompt pipeline and reward model. The prompt pipeline will repeatedly provide the same prompt, "Photo of a mad scientist panda", and the reward model will reward images for having high contrast. 7 | 8 | Custom Prompt Pipeline 9 | ----------------------- 10 | 11 | First, we define a custom prompt pipeline that only gives a single phrase "Photo of a mad scientist panda" over and over. 12 | 13 | .. code-block:: python 14 | 15 | from drlx.pipeline import PromptPipeline 16 | 17 | class MadScientistPandaPrompts(PromptPipeline): 18 | """ 19 | Custom prompt pipeline that only gives a single phrase "Photo of a mad scientist panda" over and over. 20 | """ 21 | def __getitem__(self, index): 22 | return "Photo of a mad scientist panda" 23 | 24 | def __len__(self): 25 | return 100000 # arbitrary 26 | 27 | Custom Reward 28 | ---------------- 29 | 30 | Next, we define a custom reward model that rewards images for having high contrast. The contrast is calculated as the standard deviation of the pixel intensities. 31 | 32 | .. code-block:: python 33 | 34 | from drlx.reward_modelling import RewardModel 35 | import numpy as np 36 | import torch 37 | 38 | class HighContrastReward(RewardModel): 39 | """ 40 | Rewards high contrast in the image. 41 | """ 42 | def forward(self, images, prompts): 43 | # If the input is a list of PIL Images, convert to numpy array 44 | if isinstance(images, list): 45 | images = np.array([np.array(img) for img in images]) 46 | 47 | # Calculate the standard deviation of the pixel intensities for each image 48 | contrast = images.std(axis=(1,2,3)) # N 49 | 50 | return torch.from_numpy(contrast) 51 | 52 | Training Setup 53 | --------------- 54 | 55 | Now, we set up the training process. We use the MadScientistPandaPrompts as the prompt pipeline and the HighContrastReward as the reward model. 56 | 57 | .. code-block:: python 58 | 59 | from drlx.trainer.ddpo_trainer import DDPOTrainer 60 | from drlx.configs import DRLXConfig 61 | from drlx.reward_modelling.toy_rewards import JPEGCompressability 62 | from drlx.reward_modelling.aesthetics import Aesthetics 63 | from drlx.utils import get_latest_checkpoint 64 | 65 | # Pipeline first 66 | from drlx.pipeline.pickapic_prompts import PickAPicPrompts 67 | 68 | import torch 69 | 70 | pipe = MadScientistPandaPrompts() 71 | 72 | config = DRLXConfig.load_yaml("configs/ddpo_sd.yml") 73 | trainer = DDPOTrainer(config) 74 | 75 | trainer.train(pipe, HighContrastReward()) 76 | 77 | For accelerated training, simply run the following command: 78 | 79 | .. 
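code-block:: bash

    accelerate config

The ``accelerate config`` step above mirrors the setup shown in the README and only needs to be run once per machine. Once Accelerate is configured, launch your training script as a module:

..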
code-block:: bash 80 | 81 | accelerate launch -m [script] 82 | 83 | Loading the Model and Performing Inference 84 | -------------------------------------------- 85 | 86 | After training, we can load the model and perform inference with it using a default sampler. 87 | 88 | .. code-block:: python 89 | 90 | # Load the trainer from a checkpoint if you want to resume training 91 | # Trainer by default saves both output and checkpoint in separate folders specified by run_name 92 | checkpoint_path = "checkpoints/run_name" 93 | output_path = "output/run_name" 94 | trainer.load_checkpoint(checkpoint_path) 95 | 96 | # Otherwise, you can just use a pretrained pipeline 97 | from diffusers import StableDiffusionPipeline 98 | 99 | pipe = StableDiffusionPipeline.from_pretrained(output_path, local_files_only = True) 100 | 101 | To actually run this code or make tweaks, please see the notebooks or scripts under the examples folder. 102 | 103 | 104 | 105 | 106 | 107 |
-------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. DRLX documentation master file, created by 2 | sphinx-quickstart on Tue Aug 15 13:19:31 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to DRLX's documentation! 7 | ================================ 8 | DRLX (Diffusion Reinforcement Learning X) is a library made to simplify the training of diffusion models with reinforcement learning, powered by `diffusers <https://huggingface.co/docs/diffusers/>`_ and `accelerate <https://huggingface.co/docs/accelerate/>`_. 9 | 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Setup: 14 | 15 | installation 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Getting Started: 20 | 21 | example 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Documentation: 26 | 27 | pipeline 28 | sampling 29 | denoisers 30 | reward_modelling 31 | trainer 32 | 33 | Indices and tables 34 | ================== 35 | 36 | * :ref:`genindex` 37 | * :ref:`modindex` 38 | * :ref:`search` 39 | 40 |
-------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | Installation and Setup 4 | ====================== 5 | 6 | todo
-------------------------------------------------------------------------------- /docs/source/pipeline.rst: -------------------------------------------------------------------------------- 1 | .. _pipeline: 2 | 3 | Pipeline 4 | ======== 5 | 6 | The pipeline module in DRLX is used for data preparation when training a model with RL. It includes the base class `Pipeline`, its subclass `PromptPipeline`, and the prompt pipelines `PickAPicPrompts` and `ImagenetAnimalPrompts`. 7 | 8 | Pipeline 9 | ---------- 10 | 11 | .. autoclass:: drlx.pipeline.Pipeline 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: 15 | 16 | PromptPipeline 17 | --------------- 18 | 19 | .. autoclass:: drlx.pipeline.PromptPipeline 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | PickAPicPrompts 25 | ----------------- 26 | 27 | .. automodule:: drlx.pipeline.pickapic_prompts 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | ImagenetAnimalPrompts 33 | ----------------------- 34 | 35 | ..
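code-block:: python

    # Minimal usage sketch based on examples/train_ddpo_imagenet.py:
    # each index yields a randomly sampled ImageNet animal prompt.
    from drlx.pipeline.imagenet_animal_prompts import ImagenetAnimalPrompts

    prompts = ImagenetAnimalPrompts(prefix='', postfix='', num=256)
    prompts[0]  # e.g. 'llama'

The snippet above is only an illustrative sketch taken from the bundled example script; the full module reference follows.

..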
automodule:: drlx.pipeline.imagenet_animal_prompts 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 |
-------------------------------------------------------------------------------- /docs/source/reward_modelling.rst: -------------------------------------------------------------------------------- 1 | .. _reward_modelling: 2 | 3 | Reward Modelling 4 | ================ 5 | 6 | Reward models generate the reward signal used during RL training of an image generation model. Typically, they take an image and return a scalar reward. Some also use the prompt when computing the reward, but this is not necessary. 7 | The library includes some toy rewards intended primarily for debugging. 8 | 9 | Toy Rewards 10 | ----------- 11 | 12 | .. automodule:: drlx.reward_modelling.toy_rewards 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | 17 | Aesthetics 18 | ---------- 19 | 20 | .. autoclass:: drlx.reward_modelling.aesthetics.Aesthetics 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | PickScore (WIP) 26 | ---------------- 27 | 28 | .. automodule:: drlx.reward_modelling.pickscore 29 | :members: 30 | :undoc-members: 31 | :show-inheritance:
-------------------------------------------------------------------------------- /docs/source/sampling.rst: -------------------------------------------------------------------------------- 1 | .. _sampling: 2 | 3 | Sampling 4 | ======== 5 | 6 | DRLX provides multiple samplers, since different methods often require a specific sampling procedure. A default sampler is also included for inference purposes. 7 | 8 | Sampler 9 | --------- 10 | 11 | .. automodule:: drlx.sampling 12 | :members: Sampler 13 | :undoc-members: 14 | :show-inheritance: 15 | 16 | DDPOSampler 17 | ------------- 18 | 19 | .. automodule:: drlx.sampling 20 | :members: DDPOSampler 21 | :undoc-members: 22 | :show-inheritance: 23 |
-------------------------------------------------------------------------------- /docs/source/trainer.rst: -------------------------------------------------------------------------------- 1 | .. _trainers: 2 | 3 | Trainers 4 | ======== 5 | 6 | DRLX provides a base trainer class and specific trainers for different methods. The base trainer class provides basic functionality such as setting up the optimizer, scheduler, and model, as well as saving and loading checkpoints. The specific trainers extend the base trainer and implement the training process for the specific method. 7 | 8 | BaseTrainer 9 | ------------ 10 | 11 | .. automodule:: drlx.trainer 12 | :members: BaseTrainer 13 | :undoc-members: 14 | :show-inheritance: 15 | 16 | DDPOTrainer 17 | ------------- 18 | 19 | ..
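code-block:: python

    # Minimal sketch mirroring the README quick-start: train Stable Diffusion
    # with DDPO on Pick-a-Pic prompts against an aesthetic reward.
    from drlx.reward_modelling.aesthetics import Aesthetics
    from drlx.pipeline.pickapic_prompts import PickAPicPrompts
    from drlx.trainer.ddpo_trainer import DDPOTrainer
    from drlx.configs import DRLXConfig

    config = DRLXConfig.load_yaml("configs/ddpo_sd_pickapic.yml")
    trainer = DDPOTrainer(config)
    trainer.train(PickAPicPrompts(), Aesthetics())

This is only an illustrative sketch of the typical entry point; the full trainer API is documented below.

..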
automodule:: drlx.trainer.ddpo_trainer 20 | :members: DDPOTrainer 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /examples/contrasted_panda_inference.py: -------------------------------------------------------------------------------- 1 | # Inference: 2 | 3 | from diffusers import StableDiffusionPipeline 4 | import torch 5 | 6 | pipe = StableDiffusionPipeline.from_pretrained("out/contrasting_panda", torch_dtype=torch.float16, local_files_only = True).to('cuda') 7 | pipe.enable_attention_slicing() 8 | 9 | prompt = "A mad panda scientist" 10 | image = pipe(prompt).images[0] 11 | image.save("test.jpeg") 12 | -------------------------------------------------------------------------------- /examples/contrasted_panda_training.py: -------------------------------------------------------------------------------- 1 | # Set up pipeline with repeating prompts 2 | 3 | from drlx.pipeline import PromptPipeline 4 | 5 | class MadScientistPandaPrompts(PromptPipeline): 6 | """ 7 | Custom prompt pipeline that only gives a single phrase "Photo of a mad scientist panda" over and over. 8 | """ 9 | def __getitem__(self, index): 10 | return "Photo of a mad scientist panda" 11 | 12 | def __len__(self): 13 | return 100000 # arbitrary 14 | 15 | # Next we make our reward model 16 | 17 | from drlx.reward_modelling import RewardModel 18 | import numpy as np 19 | import torch 20 | 21 | class HighContrastReward(RewardModel): 22 | """ 23 | Rewards high contrast in the image. 24 | """ 25 | def forward(self, images, prompts): 26 | # If the input is a list of PIL Images, convert to numpy array 27 | if isinstance(images, list): 28 | images = np.array([np.array(img) for img in images]) 29 | 30 | # Calculate the standard deviation of the pixel intensities for each image 31 | contrast = images.std(axis=(1,2,3)) # N 32 | 33 | return torch.from_numpy(contrast) 34 | 35 | # Next, we setup trainer using default config 36 | 37 | from drlx.trainer.ddpo_trainer import DDPOTrainer 38 | from drlx.configs import DRLXConfig 39 | 40 | pipe = MadScientistPandaPrompts() 41 | 42 | config = DRLXConfig.load_yaml("configs/ddpo_sd.yml") 43 | 44 | # Some changes to config for our use-case 45 | config.train.num_samples_per_epoch = 32 46 | config.train.batch_size = 4 # adjust as needed 47 | config.logging.run_name = "contrasting_panda" 48 | 49 | trainer = DDPOTrainer(config) 50 | 51 | # If we wanted to resume a run... 
we can make this little change 52 | from drlx.utils import get_latest_checkpoint 53 | 54 | RESUME = False 55 | if RESUME: 56 | cp_dir = get_latest_checkpoint(f"checkpoints/{config.logging.run_name}") 57 | trainer.load_checkpoint(cp_dir) 58 | 59 | trainer.train(pipe, HighContrastReward()) -------------------------------------------------------------------------------- /examples/ddpo_imagenet_lora_inference.py: -------------------------------------------------------------------------------- 1 | # Inference: 2 | 3 | from diffusers import StableDiffusionPipeline 4 | import torch 5 | 6 | pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16).to('cuda') 7 | pipe.load_lora_weights("output/ddpo_sd_imagenet_lora") 8 | pipe.enable_attention_slicing() 9 | 10 | prompt = "llama" 11 | image = pipe(prompt).images[0] 12 | image.save("test.jpeg") 13 | -------------------------------------------------------------------------------- /examples/train_ddpo_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from drlx.trainer.ddpo_trainer import DDPOTrainer 3 | from drlx.configs import DRLXConfig 4 | from drlx.reward_modelling.aesthetics import Aesthetics 5 | from drlx.pipeline.imagenet_animal_prompts import ImagenetAnimalPrompts 6 | from drlx.utils import get_latest_checkpoint 7 | 8 | config = DRLXConfig.load_yaml("configs/ddpo_sd_imagenet.yml") 9 | 10 | pipe = ImagenetAnimalPrompts(prefix='', postfix='', num=config.train.num_samples_per_epoch) 11 | resume = False 12 | 13 | trainer = DDPOTrainer(config) 14 | 15 | if resume: 16 | cp_dir = get_latest_checkpoint(f"checkpoints/{config.logging.run_name}") 17 | trainer.load_checkpoint(cp_dir) 18 | 19 | trainer.train(pipe, Aesthetics()) -------------------------------------------------------------------------------- /examples/train_ddpo_imagenet_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from drlx.trainer.ddpo_trainer import DDPOTrainer 3 | from drlx.configs import DRLXConfig 4 | from drlx.reward_modelling.aesthetics import Aesthetics 5 | from drlx.pipeline.imagenet_animal_prompts import ImagenetAnimalPrompts 6 | from drlx.utils import get_latest_checkpoint 7 | 8 | config = DRLXConfig.load_yaml("configs/ddpo_sd_imagenet_lora.yml") 9 | 10 | pipe = ImagenetAnimalPrompts(prefix='', postfix='', num=config.train.num_samples_per_epoch) 11 | resume = False 12 | 13 | trainer = DDPOTrainer(config) 14 | 15 | if resume: 16 | cp_dir = get_latest_checkpoint(f"checkpoints/{config.logging.run_name}") 17 | trainer.load_checkpoint(cp_dir) 18 | 19 | trainer.train(pipe, Aesthetics()) -------------------------------------------------------------------------------- /examples/train_ddpo_pickapic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from drlx.trainer.ddpo_trainer import DDPOTrainer 3 | from drlx.configs import DRLXConfig 4 | from drlx.pipeline.pickapic_prompts import PickAPicPrompts, PickAPicReplacementPrompts 5 | from drlx.reward_modelling.pickscore import PickScoreModel 6 | from drlx.reward_modelling.aesthetics import Aesthetics 7 | from drlx.utils import get_latest_checkpoint 8 | 9 | pipe = PickAPicPrompts() 10 | resume = False 11 | 12 | config = DRLXConfig.load_yaml("configs/ddpo_sd_pickapic.yml") 13 | trainer = DDPOTrainer(config) 14 | 15 | if resume: 16 | cp_dir = get_latest_checkpoint(f"checkpoints/{config.logging.run_name}") 17 
| trainer.load_checkpoint(cp_dir) 18 | 19 | trainer.train(pipe, PickScoreModel()) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "setuptools_scm[toml]>=7"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools.dynamic] 6 | dependencies = {file = ["requirements.txt"]} 7 | 8 | [tool.setuptools_scm] 9 | 10 | [tool.isort] 11 | profile = "black" 12 | 13 | [tool.flake8] 14 | max-line-length = 88 15 | extend-ignore = [ 16 | "E203", 17 | "E501", 18 | "W503", 19 | "F811", 20 | ] 21 | extend-exclude = ["docs"] 22 | count = true 23 | statistics = true 24 | 25 | [tool.pydocstyle] 26 | convention = "google" 27 | add-ignore = "D10, D212" 28 | 29 | [tool.pytest.ini_options] 30 | minversion = "6.0" 31 | addopts = "-rA -x --doctest-modules --color=yes" # --cov=openelm" # Uncomment this for coverage by default 32 | testpaths = ["tests"] 33 | doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"] 34 | 35 | [project] 36 | name = "drlx" 37 | version = "0.0.2" 38 | description = "DRLX is a library for distributed training of diffusion models via RL" 39 | authors = [{name = "CarperAI"}] 40 | readme = "README.md" 41 | requires-python = ">=3.9" 42 | license = {text = "MIT"} 43 | dynamic = ["dependencies"] 44 | classifiers=[ 45 | # Trove classifiers 46 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 47 | "Development Status :: 3 - Alpha", 48 | "Environment :: Console", 49 | "Intended Audience :: Science/Research", 50 | "Intended Audience :: Developers", 51 | "License :: OSI Approved :: MIT License", 52 | "Natural Language :: English", 53 | "Programming Language :: Python", 54 | "Programming Language :: Python :: 3", 55 | "Programming Language :: Python :: 3.9", 56 | "Programming Language :: Python :: 3.10", 57 | "Programming Language :: Python :: 3 :: Only", 58 | "Programming Language :: Python :: Implementation :: CPython", 59 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 60 | "Typing :: Typed", 61 | "Operating System :: Unix", 62 | ] 63 | 64 | [project.optional-dependencies] 65 | dev = [ 66 | "black", 67 | "isort", 68 | "flake8", 69 | "flake8-pyproject", 70 | "pydocstyle", 71 | "mypy", 72 | "pre-commit", 73 | "pytest", 74 | "pytest-cov", 75 | ] 76 | benchmarks = [ 77 | "pygraphviz", 78 | "graphviz", 79 | "openai", 80 | ] 81 | docs = [ 82 | "sphinx==5.3.0", 83 | "sphinx_rtd_theme", 84 | "sphinx_autodoc_typehints", 85 | ] 86 | triton = [ 87 | "tritonclient[all]", 88 | ] 89 | notebook = ["ipython"] 90 | sodaracer = [ 91 | "swig>=4.1.0", 92 | "box2d-py==2.3.8", 93 | "pygame" 94 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchtyping 4 | einops 5 | diffusers 6 | peft 7 | transformers 8 | accelerate 9 | datasets 10 | xformers 11 | wandb 12 | fastprogress 13 | matplotlib 14 | tqdm 15 | -------------------------------------------------------------------------------- /src/drlx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/DRLX/3a23b506a5199dcf2162ab7c0e05a9a8259c567a/src/drlx/__init__.py -------------------------------------------------------------------------------- /src/drlx/configs.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | from copy import deepcopy 3 | from dataclasses import dataclass, field, asdict 4 | from typing import Any, Dict, List, Optional, Set 5 | import yaml 6 | 7 | @dataclass 8 | class ConfigClass: 9 | @classmethod 10 | def from_dict(cls, cfg : Dict[str, Any]): 11 | return cls(**cfg) 12 | 13 | def to_dict(self): 14 | return asdict(self) 15 | 16 | 17 | # specifies a dictionary of method configs 18 | _METHODS: Dict[str, Any] = {} # registry 19 | 20 | 21 | def register_method(name): 22 | """Decorator used register a method config 23 | Args: 24 | name: Name of the method 25 | """ 26 | 27 | def register_class(cls, name): 28 | _METHODS[name] = cls 29 | setattr(sys.modules[__name__], name, cls) 30 | return cls 31 | 32 | if isinstance(name, str): 33 | name = name.lower() 34 | return lambda c: register_class(c, name) 35 | 36 | cls = name 37 | name = cls.__name__ 38 | register_class(cls, name.lower()) 39 | 40 | return cls 41 | 42 | @dataclass 43 | @register_method 44 | class MethodConfig(ConfigClass): 45 | """ 46 | Config for a certain RL method. 47 | 48 | :param name: Name of the method 49 | :type name: str 50 | """ 51 | 52 | name: str = None 53 | 54 | def get_method(name: str) -> MethodConfig: 55 | """ 56 | Return constructor for specified method config 57 | """ 58 | name = name.lower() 59 | if name in _METHODS: 60 | return _METHODS[name] 61 | else: 62 | raise Exception("Error: Trying to access a method that has not been registered") 63 | 64 | @register_method("DDPO") 65 | @dataclass 66 | class DDPOConfig(MethodConfig): 67 | """ 68 | Config for DDPO-related hyperparameters including per prompt stat tracker 69 | 70 | :param clip_advantages: Maximum absolute value of advantages 71 | :type clip_advantages: float 72 | 73 | :param clip_ratio: Maximum absolute value of ratio of new to old policy 74 | :type clip_ratio: float 75 | 76 | :param num_inner_epochs: Number of epochs to train the policy for 77 | :type num_inner_epochs: int 78 | 79 | :param buffer_size: Number of samples to keep in the buffer 80 | :type buffer_size: int 81 | 82 | :param min_count: Minimum number of samples to keep in the buffer before 83 | calculating statistics 84 | :type min_count: int 85 | """ 86 | name : str = "DDPO" 87 | clip_advantages: float = 10.0 88 | clip_ratio: float = 1e-4 89 | num_inner_epochs: int = 1 90 | 91 | buffer_size: int = 32 # Set to None to avoid using per prompt stat tracker 92 | min_count: int = 16 93 | 94 | @dataclass 95 | class TrainConfig(ConfigClass): 96 | """ 97 | Config for training 98 | 99 | :param batch_size: Batch size 100 | :type batch_size: int 101 | 102 | :param target_batch: Target batch size with gradient accumulation 103 | :type target_batch: int 104 | 105 | :param sample_batch_size: Batch size to use during inference only 106 | :type sample_batch_size: int 107 | 108 | :param num_epochs: Number of epochs to train for 109 | :type num_epochs: int 110 | 111 | :param total_samples: Provide this as alternative to epochs. Computes required epochs to see this may samples. 
112 | :type total_samples: int 113 | 114 | :param num_samples_per_epoch: Number of samples to use per epoch 115 | :type num_samples_per_epoch: int 116 | 117 | :param grad_clip: Maximum absolute value of gradient 118 | :type grad_clip: float 119 | 120 | :param checkpoint_interval: Number of epochs between checkpoints 121 | :type checkpoint_interval: int 122 | 123 | :param checkpoint_path: Path to save checkpoints to 124 | :type checkpoint_path: str 125 | 126 | :param seed: Random seed 127 | :type seed: int 128 | 129 | :param tf32: Use tf32 precision 130 | :type tf32: bool 131 | 132 | :suppress_log_keywords: List of prefixes for loggers to suppress warnings from during training. Type as single string with different prefixes delimited by commas. 133 | :type suppress_log_keywords: str 134 | 135 | :param sample_prompts: List of sample prompts to use for fixed sample taken every training step 136 | :type sample_prompts: List[str] 137 | 138 | :param save_samples: Save samples locally? 139 | :type save_samples: bool 140 | """ 141 | batch_size: int = 4 142 | target_batch: int = None 143 | sample_batch_size: int = 8 144 | num_epochs: int = 50 145 | total_samples: int = None 146 | num_samples_per_epoch: int = 256 147 | grad_clip: float = 1.0 148 | checkpoint_interval: int = 10 149 | checkpoint_path: str = "checkpoints" 150 | seed: int = 0 151 | tf32: bool = False 152 | suppress_log_keywords: str = None 153 | sample_prompts : List[str] = None 154 | save_samples : bool = True 155 | 156 | 157 | @dataclass 158 | class LoggingConfig(ConfigClass): 159 | """ 160 | Config for logging 161 | 162 | :param log_with: Logging backend to use (either "wandb" or "tensorboard") 163 | :type log_with: str 164 | 165 | :param run_name: Name of the run. Also used during saving. 166 | :type run_name: str 167 | 168 | :param wandb_entity: Name of the wandb entity to log to 169 | :type wandb_entity: str 170 | 171 | :param wandb_project: Name of the wandb project to log to 172 | :type wandb_project: str 173 | """ 174 | log_with: str = "wandb" # "wandb" or "tensorboard" 175 | log_dir: str = None 176 | run_name: str = 'ddpo_exp' 177 | wandb_entity: str = None 178 | wandb_project: str = None 179 | 180 | 181 | 182 | @dataclass 183 | class OptimizerConfig(ConfigClass): 184 | """ 185 | Config for an optimizer. 186 | 187 | :param name: Name of the optimizer 188 | :type name: str 189 | 190 | :param kwargs: Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay) 191 | :type kwargs: Dict[str, Any] 192 | """ 193 | 194 | name: str = None 195 | kwargs: Dict[str, Any] = field(default_factory=dict) 196 | 197 | 198 | @dataclass 199 | class SchedulerConfig(ConfigClass): 200 | """ 201 | Config for a learning rate scheduler. 202 | 203 | :param name: Name of the scheduler 204 | :type name: str 205 | 206 | :param kwargs: Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max) 207 | :type kwargs: Dict[str, Any] 208 | """ 209 | 210 | name: str = None 211 | kwargs: Dict[str, Any] = field(default_factory=dict) 212 | 213 | 214 | @dataclass 215 | class ModelConfig(ConfigClass): 216 | """ 217 | Config for a model. 218 | 219 | :param model_path: Path or name of the model (local or on huggingface hub) 220 | :type model_path: str 221 | 222 | :param model_arch_type: Type of model architecture. 223 | :type model_arch_type: str 224 | 225 | :param use_safetensors: Use safe tensors when loading pipeline? 
226 | :type use_safetensors: bool 227 | 228 | :param local_model: Force model to load checkpoint locally only 229 | :type local_model: bool 230 | 231 | :param attention_slicing: Whether to use attention slicing 232 | :type attention_slicing: bool 233 | 234 | :param xformers_memory_efficient: Whether to use memory efficient attention implementation from xformers 235 | :type xformers_memory_efficient: bool 236 | 237 | :param gradient_checkpointing: Whether to use gradient checkpointing 238 | :type gradient_checkpointing: bool 239 | 240 | :param lora_rank: Rank of LoRA matrix 241 | :type lora_rank: int 242 | """ 243 | 244 | model_path: str = None 245 | model_arch_type: str = None 246 | use_safetensors : bool = False 247 | local_model : bool = False 248 | attention_slicing: bool = False 249 | xformers_memory_efficient: bool = False 250 | gradient_checkpointing: bool = False 251 | lora_rank: int = None 252 | 253 | 254 | 255 | @dataclass 256 | class SamplerConfig(ConfigClass): 257 | guidance_scale : float = 5.0 # if guidance is being used 258 | guidance_rescale : float = None # see https://arxiv.org/pdf/2305.08891.pdf 259 | num_inference_steps : int = 50 260 | eta : float = 1 261 | postprocess : bool = False # If true, post processes latents to images (uint8 np arrays) 262 | img_size : int = 512 263 | 264 | def load_yaml(yml_fp : str) -> Dict[str, ConfigClass]: 265 | with open(yml_fp, mode = 'r') as file: 266 | config = yaml.safe_load(file) 267 | d = {} 268 | if config["model"]: 269 | d["model"] = ModelConfig.from_dict(config["model"]) 270 | if config["train"]: 271 | d["train"] = TrainConfig.from_dict(config["train"]) 272 | if config["sampler"]: 273 | d["sampler"] = SamplerConfig.from_dict(config["sampler"]) 274 | 275 | return d 276 | 277 | 278 | def merge(base: Dict, update: Dict, updated: Set) -> Dict: 279 | "Recursively updates a nested dictionary with new values" 280 | for k, v in base.items(): 281 | if k in update and isinstance(v, dict): 282 | base[k] = merge(v, update[k], updated) 283 | updated.add(k) 284 | elif k in update: 285 | base[k] = update[k] 286 | updated.add(k) 287 | 288 | return base 289 | 290 | @dataclass 291 | class DRLXConfig(ConfigClass): 292 | """ 293 | Top-level config 294 | 295 | :param model: Model config 296 | :type model: ModelConfig 297 | 298 | :param optimizer: Optimizer config 299 | :type optimizer: OptimizerConfig 300 | 301 | :param scheduler: Scheduler config 302 | :type scheduler: SchedulerConfig 303 | 304 | :param train: Training config 305 | :type train: TrainConfig 306 | 307 | :param logging: Logging config 308 | :type logging: LoggingConfig 309 | 310 | :param method: Method config 311 | :type method: MethodConfig 312 | """ 313 | 314 | model: ModelConfig 315 | sampler: SamplerConfig 316 | optimizer: OptimizerConfig 317 | scheduler: SchedulerConfig 318 | train: TrainConfig 319 | logging: LoggingConfig 320 | method: MethodConfig 321 | 322 | @classmethod 323 | def load_yaml(cls, yml_fp: str): 324 | """ 325 | Load yaml file as DRLXConfig. 326 | 327 | :param yml_fp: Path to yaml file 328 | :type yml_fp: str 329 | """ 330 | with open(yml_fp, mode="r") as file: 331 | config = yaml.safe_load(file) 332 | return cls.from_dict(config) 333 | 334 | def to_dict(self): 335 | """ 336 | Convert TRLConfig to dictionary. 
337 | """ 338 | data = { 339 | "method": self.method.__dict__, 340 | "model": self.model.__dict__, 341 | "sampler": self.sampler.__dict__, 342 | "optimizer": self.optimizer.__dict__, 343 | "scheduler": self.scheduler.__dict__, 344 | "train": self.train.__dict__, 345 | "logging": self.logging.__dict__ 346 | } 347 | 348 | return data 349 | 350 | @classmethod 351 | def from_dict(cls, config: Dict): 352 | """ 353 | Convert dictionary to DRLXConfig. 354 | """ 355 | return cls( 356 | method=get_method(config["method"]["name"]).from_dict(config["method"]), 357 | model=ModelConfig.from_dict(config["model"]), 358 | sampler=SamplerConfig.from_dict(config["sampler"]), 359 | optimizer=OptimizerConfig.from_dict(config["optimizer"]), 360 | scheduler=SchedulerConfig.from_dict(config["scheduler"]), 361 | train=TrainConfig.from_dict(config["train"]), 362 | logging=LoggingConfig.from_dict(config["logging"]), 363 | ) 364 | 365 | @classmethod 366 | def update(cls, baseconfig: Dict, config: Dict): 367 | update = {} 368 | # unflatten a string variable name into a nested dictionary 369 | # key1.key2.key3: value -> {key1: {key2: {key3: value}}} 370 | for name, value in config.items(): 371 | if isinstance(value, dict): 372 | update[name] = value 373 | else: 374 | *layers, var = name.split(".") 375 | if layers: 376 | d = update.setdefault(layers[0], {}) 377 | for layer in layers[1:]: 378 | d = d.setdefault(layer, {}) 379 | d[var] = value 380 | 381 | if not isinstance(baseconfig, Dict): 382 | baseconfig = baseconfig.to_dict() 383 | 384 | updates = set() 385 | merged = merge(baseconfig, update, updates) 386 | 387 | for param in update: 388 | if param not in updates: 389 | raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)") 390 | 391 | return cls.from_dict(merged) 392 | 393 | def __str__(self): 394 | """Returns a human-readable string representation of the config.""" 395 | import json 396 | 397 | return json.dumps(self.to_dict(), indent=4) 398 | -------------------------------------------------------------------------------- /src/drlx/denoisers/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Dict, Any, Optional, Tuple 2 | from torchtyping import TensorType 3 | 4 | from abc import abstractmethod 5 | 6 | import os 7 | 8 | import torch 9 | from torch import nn 10 | import numpy as np 11 | 12 | from drlx.configs import ModelConfig, SamplerConfig 13 | from drlx.sampling import Sampler 14 | 15 | class BaseConditionalDenoiser(nn.Module): 16 | """ 17 | Base class for any denoiser that takes a conditioning signal during denoising process, including text conditioned denoisers. 18 | 19 | :param config: Configuration for model 20 | :type config: ModelConfig 21 | 22 | :param sampler_config: Configuration for sampler (optional). If provided, will create a default sampler. 23 | :type sampler_config: SamplerConfig 24 | 25 | :param sampler: Can be provided as alternative to sampler_config (also optional). If neither are provided, a default sampler will be used. 
26 | :type sampler: Sampler 27 | """ 28 | def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None, sampler : Sampler = None): 29 | super().__init__() 30 | 31 | self.config = config 32 | self.scheduler = None 33 | 34 | if sampler_config is None and sampler is None: 35 | self.sampler = Sampler(SamplerConfig()) 36 | else: 37 | self.sampler = Sampler(sampler_config) if sampler_config is not None else sampler 38 | 39 | def sample(self, **kwargs): 40 | """ 41 | Use the sampler to sample an image. Will require postprocess to output an image. Note that different samplers have different outputs. 42 | 43 | :param kwargs: Keyword arguments to sampler 44 | 45 | :return: Varies per sampler but always includes denoised latent/images 46 | """ 47 | kwargs['denoiser'] = self 48 | return self.sampler.sample(**kwargs) 49 | 50 | @abstractmethod 51 | def get_input_shape(self) -> Tuple: 52 | """ 53 | Get input shape for denoiser. Useful during training + sampling when shape of input noise to denoiser is needed. 54 | 55 | :return: Input shape as a tuple 56 | :rtype: Tuple[int] 57 | """ 58 | pass 59 | 60 | @abstractmethod 61 | def preprocess(self, *inputs) -> TensorType["batch", "embedding_dim"]: 62 | """ 63 | Called on the conditioning input (typically: tokenizes text prompt) 64 | 65 | :return: Conditioning input embeddings (i.e. text embeddings) as tensors 66 | :rtype: torch.Tensor 67 | """ 68 | pass 69 | 70 | @abstractmethod 71 | def postprocess(self, output) -> np.ndarray: 72 | """ 73 | Called on the output from the model after sampling to give final image 74 | 75 | :return: Final denoised image as uint8 numpy array 76 | :rtype: np.ndarray 77 | """ 78 | pass 79 | 80 | @abstractmethod 81 | def forward(self, *inputs) -> TensorType["batch", "channels", "height", "width"]: 82 | """ 83 | Forward pass for denoiser. Output varies based on prediction type. 84 | """ 85 | pass 86 | 87 | # === LATENT DIFFUSION === 88 | 89 | @abstractmethod 90 | def encode(self, pixel_values : TensorType["batch", "channels", "height", "width"]) -> torch.Tensor: 91 | """ 92 | Encode image into latent vector 93 | """ 94 | pass 95 | 96 | @abstractmethod 97 | def decode(self, latent : torch.Tensor) -> TensorType["batch", "channels", "height", "width"]: 98 | """ 99 | Decode latent vector into an image (typically called in postprocess) 100 | """ 101 | pass 102 | 103 | -------------------------------------------------------------------------------- /src/drlx/denoisers/ldm_unet.py: -------------------------------------------------------------------------------- 1 | from torchtyping import TensorType 2 | from typing import Iterable, Union, Callable, Type, Tuple 3 | 4 | import torch 5 | import numpy as np 6 | from diffusers import UNet2DConditionModel, DDIMScheduler 7 | 8 | from drlx.denoisers import BaseConditionalDenoiser 9 | from drlx.configs import ModelConfig, SamplerConfig 10 | from drlx.sampling import Sampler 11 | 12 | from peft import LoraConfig 13 | 14 | class LDMUNet(BaseConditionalDenoiser): 15 | """ 16 | Class for Latent Diffusion Model UNet denoiser. Can optionally pass sampler information, though it is not required. Generally used in tandem with a diffusers pipeline. 17 | 18 | :param config: Configuration for model 19 | :type config: ModelConfig 20 | 21 | :param sampler_config: Configuration for sampler (optional). If provided, will create a default sampler. 22 | :type sampler_config: SamplerConfig 23 | 24 | :param sampler: Can be provided as alternative to sampler_config (also optional). 
If neither are provided, a default sampler will be used. 25 | :type sampler: Sampler 26 | """ 27 | def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None, sampler : Sampler = None): 28 | super().__init__(config, sampler_config, sampler) 29 | 30 | self.unet : UNet2DConditionModel = None 31 | self.text_encoder = None 32 | self.vae = None 33 | self.encode_prompt : Callable = None 34 | 35 | self.tokenizer = None 36 | self.scheduler = None 37 | 38 | self.scale_factor = None 39 | 40 | def get_input_shape(self) -> Tuple[int]: 41 | """ 42 | Figure out latent noise input shape for the UNet. Requires that unet and vae are defined 43 | 44 | :return: Input shape as a tuple 45 | :rtype: Tuple[int] 46 | """ 47 | assert self.unet and self.vae, "Cannot get input shape if model not initialized" 48 | 49 | in_channels = self.unet.config.in_channels 50 | sample_size = self.sampler.config.img_size // self.scale_factor 51 | 52 | return (in_channels, sample_size, sample_size) 53 | 54 | def from_pretrained_pipeline(self, cls : Type, path : str): 55 | """ 56 | Get unet from some pretrained model pipeline 57 | 58 | :param cls: Class to use for pipeline (i.e. StableDiffusionPipeline) 59 | :type cls: Type 60 | 61 | :param path: Path to pretrained pipeline 62 | :type path: str 63 | 64 | :return: an LDMUNet object with UNet, Text Encoder, VAE, tokenizer and scheduler from pretrained pipeline. Also returns the pretrained pipeline in case caller needs it. 65 | :rtype: LDMUNet 66 | """ 67 | 68 | pipe = cls.from_pretrained(path, use_safetensors = self.config.use_safetensors, local_files_only = self.config.local_model) 69 | 70 | if self.config.attention_slicing: pipe.enable_attention_slicing() 71 | if self.config.xformers_memory_efficient: pipe.enable_xformers_memory_efficient_attention() 72 | 73 | self.unet = pipe.unet 74 | self.text_encoder = pipe.text_encoder 75 | self.vae = pipe.vae 76 | self.scale_factor = pipe.vae_scale_factor 77 | self.encode_prompt = pipe._encode_prompt 78 | 79 | self.text_encoder.requires_grad_(False) 80 | self.vae.requires_grad_(False) 81 | self.unet.requires_grad_(not self.config.lora_rank) 82 | 83 | self.tokenizer = pipe.tokenizer 84 | self.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 85 | 86 | if self.config.gradient_checkpointing: self.unet.enable_gradient_checkpointing() 87 | 88 | if self.config.lora_rank: 89 | peft_config = LoraConfig( 90 | r=self.config.lora_rank, 91 | lora_alpha=self.config.lora_rank, 92 | init_lora_weights="gaussian", 93 | target_modules=["to_k", "to_q", "to_v", "to_out.0"], 94 | ) 95 | 96 | self.unet.add_adapter(peft_config) 97 | for param in self.unet.parameters(): 98 | # only upcast trainable parameters (LoRA) into fp32 99 | if param.requires_grad: 100 | param.data = param.to(torch.float32) 101 | 102 | return self, pipe 103 | 104 | def preprocess(self, text : Iterable[str], mode = "tokens", **embed_kwargs): 105 | """ 106 | Preprocess text input, either into tokens or into embeddings. 
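For example, tokenizing a batch of prompts (illustrative sketch)::

            input_ids, attention_mask = denoiser.preprocess(["a photo of a dog"], mode="tokens")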
107 | 108 | :param mode: Either "tokens" or "embeds" 109 | :type mode: str 110 | 111 | :param text: Text to preprocess 112 | :type text: Iterable[str] 113 | 114 | :return: Either a tuple of tensors for input_ids and attention_mask or a tensor of embeddings 115 | :rtype: Union[Tuple[Tensor, Tensor], Tensor] 116 | """ 117 | 118 | if mode == "tokens": 119 | tok_out = self.tokenizer( 120 | text, 121 | padding = 'max_length', 122 | max_length = self.tokenizer.model_max_length, 123 | truncation = True, 124 | return_tensors = "pt" 125 | ) 126 | return tok_out.input_ids, tok_out.attention_mask 127 | elif mode == "embeds": 128 | return self.encode_prompt(text, **embed_kwargs) 129 | else: 130 | raise ValueError("Invalid mode specified for preprocessing") 131 | 132 | @torch.no_grad() 133 | def postprocess(self, output : TensorType["batch", "channels", "height", "width"], vae_device = None): 134 | """ 135 | Post process 136 | """ 137 | if vae_device is not None: 138 | self.vae = self.vae.to(vae_device) 139 | output = output.to(vae_device) 140 | images = self.vae.decode(1 / 0.18215 * output).sample 141 | images = (images / 2 + 0.5).clamp(0, 1) 142 | images = images.detach().cpu().permute(0,2,3,1).numpy() 143 | images = (images * 255).round().astype(np.uint8) 144 | return images 145 | 146 | def forward( 147 | self, 148 | pixel_values : TensorType["batch", "channels", "height", "width"], 149 | time_step : Union[TensorType["batch"], int], # Note diffusers tyically does 999->0 as steps 150 | input_ids : TensorType["batch", "seq_len"] = None, 151 | attention_mask : TensorType["batch", "seq_len"] = None, 152 | text_embeds : TensorType["batch", "d"] = None 153 | ) -> TensorType["batch", "channels", "height", "width"]: 154 | """ 155 | For text conditioned UNET, inputs are assumed to be: 156 | pixel_values, input_ids, attention_mask, time_step 157 | """ 158 | with torch.no_grad(): 159 | if text_embeds is None: 160 | text_embeds = self.text_encoder(input_ids, attention_mask)[0] 161 | 162 | return self.unet( 163 | pixel_values, 164 | time_step, 165 | encoder_hidden_states = text_embeds 166 | ).sample 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /src/drlx/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Callable, Iterable, Tuple, Any 3 | 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import DataLoader, Dataset 8 | 9 | class Pipeline(Dataset): 10 | """ 11 | Pipeline for data during RL training. Subclasses should define some dataset with getitem and len methods. 12 | 13 | :param prep_fn: Function that will be called on iterable of data elements from the pipeline. Not always required, and by default is simply an identity function. 14 | :type prep_fn: Callable 15 | """ 16 | def __init__(self, prep_fn : Callable = None): 17 | super().__init__() 18 | 19 | if not prep_fn: 20 | self.prep : Callable = lambda x: x # identity by default 21 | else: 22 | self.prep = prep_fn 23 | 24 | @abstractmethod 25 | def __getitem__(self, index): 26 | pass 27 | 28 | @abstractmethod 29 | def __len__(self): 30 | pass 31 | 32 | def create_train_loader(self, **kwargs) -> DataLoader: 33 | """ 34 | Create loader for training data. Default behaviour is to just call create_loader (i.e. 
assumes there is no split) 35 | """ 36 | return self.create_loader(**kwargs) 37 | 38 | @abstractmethod 39 | def create_val_loader(self, **kwargs) -> DataLoader: 40 | """ 41 | Create validation loader. 42 | """ 43 | pass 44 | 45 | @classmethod 46 | def make_default_collate(self, prep : Callable): 47 | """ 48 | Creates a default collate function for the dataloader that assumes dataset elements are tuples of images and strings. 49 | """ 50 | def collate(batch : Iterable[Tuple[Image.Image, str]]): 51 | img_batch = [d[0] for d in batch] 52 | txt_batch = [d[1] for d in batch] 53 | 54 | return prep(img_batch, txt_batch) 55 | 56 | return collate 57 | 58 | def create_loader(self, **kwargs) -> DataLoader: 59 | """ 60 | Create dataloader over self. Assumes __getitem__ and __len__ are implemented. 61 | 62 | :param kwargs: Keyword arguments for the created pytorch dataloader 63 | 64 | :return: Dataloader for dataset within pipeline 65 | :rtype: DataLoader 66 | """ 67 | if self.prep is None: 68 | raise ValueError("Preprocessing function must be set before creating a dataloader.") 69 | 70 | if 'shuffle' in kwargs: 71 | if kwargs['shuffle'] and 'generator' not in kwargs: 72 | generator = torch.Generator() 73 | generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) 74 | kwargs['generator'] = generator 75 | 76 | return DataLoader(self, collate_fn = self.make_default_collate(self.prep), **kwargs) 77 | 78 | class PromptPipeline(Pipeline): 79 | """ 80 | Base class for a pipeline that provides text prompts only. 81 | """ 82 | 83 | @classmethod 84 | def make_default_collate(self, prep : Callable): 85 | """ 86 | Default collate for a prompt pipeline which assumes the dataset elements are simply strings. 87 | """ 88 | def collate(batch : Iterable[str]): 89 | return prep(batch) 90 | 91 | return collate 92 | 93 | -------------------------------------------------------------------------------- /src/drlx/pipeline/imagenet_animal_prompts.py: -------------------------------------------------------------------------------- 1 | import random 2 | import requests 3 | from pathlib import Path 4 | from drlx.pipeline import PromptPipeline 5 | 6 | class ImagenetAnimalPrompts(PromptPipeline): 7 | """ 8 | Pipeline of prompts consisting of animals from ImageNet, as used in the original `DDPO paper `_. 
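Illustrative usage (prompts are sampled at random, so the output shown is only an example)::

        prompts = ImagenetAnimalPrompts(num=1000)
        prompts[0]  # e.g. 'A picture of a goldfish, 4k unreal engine'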
9 | """ 10 | def __init__(self, prefix='A picture of a ', postfix=', 4k unreal engine', num=10000, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | if not Path('LOC_synset_mapping.txt').exists(): 13 | r = requests.get("https://raw.githubusercontent.com/formigone/tf-imagenet/master/LOC_synset_mapping.txt") 14 | with open("LOC_synset_mapping.txt", "wb") as f: f.write(r.content) 15 | self.synsets = {k:v for k,v in [o.split(',')[0].split(' ', maxsplit=1) for o in Path('LOC_synset_mapping.txt').read_text().splitlines()]} 16 | self.imagenet_classes = list(self.synsets.values()) 17 | self.prefix = prefix 18 | self.postfix = postfix 19 | self.num = num 20 | 21 | def __getitem__(self, index): 22 | animal = random.choice(self.imagenet_classes[:397]) 23 | return f'{self.prefix}{animal}{self.postfix}' 24 | 25 | def __len__(self): 26 | 'Denotes the total number of samples' 27 | return self.num 28 | 29 | -------------------------------------------------------------------------------- /src/drlx/pipeline/pickapic_prompts.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import torch 3 | 4 | from drlx.pipeline import PromptPipeline 5 | 6 | class PickAPicPrompts(PromptPipeline): 7 | """ 8 | Prompt pipeline consisting of prompts from the `PickAPic dataset `_ training set. 9 | """ 10 | def __init__(self, *args): 11 | super().__init__(*args) 12 | 13 | self.dataset = load_dataset("carperai/pickapic_v1_no_images_training_sfw")["train"] 14 | 15 | def __getitem__(self, index): 16 | return self.dataset[index]['caption'] 17 | 18 | def __len__(self): 19 | return len(self.dataset) 20 | 21 | class PickAPicReplacementPrompts(PromptPipeline): 22 | """ 23 | Prompt pipeline consisting of prompts from the `PickAPic dataset `_ training set. 24 | Differs from main pipeline in that prompts are picked with replacement from some sample size. i.e. when create loader is called, 25 | a sample of the dataset of size provided is drawn. The dataloader draws from this small sample with replacement. With the default 26 | value of 500 and a sample_size of 256, one can expect ~80 duplicates. Duplicates allow the model to see the reward 27 | from multiple generations given the same prompt, potentially providing a stronger learning signal. It is assumed the sample 28 | used during a training epoch is smaller than n_sample. 29 | 30 | :param n_sample: Whenever a dataloader is created, it creates a subset of the pipeline with this size (at random), then draws with replacement. 
31 | :type n_sample: int 32 | """ 33 | def __init__(self, n_sample : int = 500, *args): 34 | super().__init__(*args) 35 | 36 | self.dataset = load_dataset("carperai/pickapic_v1_no_images_training_sfw")["train"] 37 | self.n_sample = n_sample 38 | self.indices = None # indices of the subset 39 | self.original_length = len(self.dataset) 40 | 41 | def __getitem__(self, index): 42 | true_index = self.indices[index].item() 43 | return self.dataset[true_index]['caption'] 44 | 45 | def __len__(self): 46 | return self.n_sample 47 | 48 | def subset_shuffle(self): 49 | indices = torch.randint(self.original_length, (self.n_sample,)) 50 | self.indices = indices[torch.randint(self.n_sample, (self.n_sample,))] # randomly sample with replacement 51 | 52 | def create_loader(self, **kwargs): 53 | self.subset_shuffle() 54 | return super().create_loader(**kwargs) -------------------------------------------------------------------------------- /src/drlx/reward_modelling/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from torchtyping import TensorType 4 | from typing import Iterable 5 | 6 | from transformers import AutoProcessor, AutoModel 7 | import torch 8 | from torch import nn 9 | from PIL import Image 10 | 11 | from drlx.utils import any_chunk 12 | 13 | class RewardModel(nn.Module): 14 | """ 15 | Generalized reward model. Can be a wrapper for any black-box function 16 | that produces reward given pixel values and text input. 17 | """ 18 | def __init__(self): 19 | super().__init__() 20 | 21 | @abstractmethod 22 | def preprocess( 23 | self, *inputs 24 | ) -> Iterable[torch.Tensor]: 25 | """ 26 | Preprocess any form of data into something that can be input into model (generally PIL images and text strings) 27 | """ 28 | pass 29 | 30 | @abstractmethod 31 | def forward( 32 | self, 33 | *inputs 34 | ) -> TensorType["batch"]: 35 | """ 36 | Given any form of raw data (may not be tensors, may not even be batched), processes into reward scores. Inputs must all be iterable 37 | """ 38 | pass 39 | 40 | class NNRewardModel(nn.Module): 41 | """ 42 | Any reward model that requires a neural network. Currently single GPU. 
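Subclasses implement ``preprocess`` and ``_forward``; the base ``forward`` chunks the raw inputs by ``batch_size`` and concatenates the per-chunk scores. A minimal sketch of such a subclass (hypothetical, for illustration only)::

        class PromptLengthReward(NNRewardModel):
            def preprocess(self, images, prompts):
                # Convert raw inputs into tensors for the scoring step
                return (torch.tensor([len(p) for p in prompts], dtype=torch.float),)

            def _forward(self, prompt_lengths):
                # Score one preprocessed chunk
                return prompt_lengths / 100.0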
43 | 44 | :param device: Device to store model on 45 | :type device: str 46 | 47 | :param dtype: Data type to use for model input 48 | 49 | :param batch_size: Batch size to pass input in during inference 50 | """ 51 | def __init__(self, device='cpu', dtype=torch.float, batch_size=1): 52 | super().__init__() 53 | 54 | self.device = device 55 | self.dtype = dtype 56 | self.batch_size = batch_size 57 | 58 | @abstractmethod 59 | def _forward(self, *inputs) -> Iterable[float]: 60 | """ 61 | Actual forward pass on a single batch of data 62 | 63 | :param inputs: Arbitrary inputs to reward model 64 | 65 | :return: Rewards across batch of inputs 66 | :rtype: Iterable[float] 67 | """ 68 | pass 69 | 70 | def forward( 71 | self, 72 | *inputs 73 | ) -> TensorType["batch"]: 74 | """ 75 | Wrapper around _forward which chunks inputs based on batch size 76 | 77 | :param inputs: Arbitrary inputs to reward model 78 | 79 | :return Rewards across batch of inputs 80 | :rtype: torch.Tensor 81 | """ 82 | inputs = [any_chunk(input, self.batch_size) for input in inputs] 83 | batched_inputs = zip(*inputs) 84 | outputs = [self._forward(*self.preprocess(*batch)) for batch in batched_inputs] 85 | return torch.cat(outputs) 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /src/drlx/reward_modelling/aesthetics.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | import requests 7 | import os 8 | import clip 9 | from PIL import Image 10 | 11 | from drlx.reward_modelling import RewardModel 12 | 13 | class MLP(nn.Module): 14 | def __init__(self, input_size, xcol='emb', ycol='avg_rating'): 15 | super().__init__() 16 | self.input_size = input_size 17 | self.xcol = xcol 18 | self.ycol = ycol 19 | self.layers = nn.Sequential( 20 | nn.Linear(self.input_size, 1024), 21 | #nn.ReLU(), 22 | nn.Dropout(0.2), 23 | nn.Linear(1024, 128), 24 | #nn.ReLU(), 25 | nn.Dropout(0.2), 26 | nn.Linear(128, 64), 27 | #nn.ReLU(), 28 | nn.Dropout(0.1), 29 | 30 | nn.Linear(64, 16), 31 | #nn.ReLU(), 32 | 33 | nn.Linear(16, 1) 34 | ) 35 | 36 | def forward(self, x): 37 | return self.layers(x) 38 | 39 | def load_aesthetic_model_weights(cache="."): 40 | """ 41 | Load aesthetic model weights 42 | 43 | :param cache: Stores the downloaded weights here 44 | :type cache: str 45 | """ 46 | weights_fname = "sac+logos+ava1-l14-linearMSE.pth" 47 | loadpath = os.path.join(cache, weights_fname) 48 | 49 | if not os.path.exists(loadpath): 50 | url = ( 51 | "https://github.com/christophschuhmann/" 52 | f"improved-aesthetic-predictor/blob/main/{weights_fname}?raw=true" 53 | ) 54 | r = requests.get(url) 55 | 56 | with open(loadpath, "wb") as f: 57 | f.write(r.content) 58 | 59 | weights = torch.load(loadpath, map_location=torch.device("cpu")) 60 | return weights 61 | 62 | def aesthetic_model_normalize(a, axis=-1, order=2): 63 | """ 64 | Normalize output from aesthetics model 65 | """ 66 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) 67 | l2[l2 == 0] = 1 68 | return a / np.expand_dims(l2, axis) 69 | 70 | def aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model): 71 | imgs = torch.stack([preprocess(Image.fromarray(img)).cuda() for img in imgs]) 72 | with torch.no_grad(): image_features = clip_model.encode_image(imgs) 73 | im_emb_arr = aesthetic_model_normalize(image_features.cpu().detach().numpy()) 74 | prediction = 
aesthetic_model(torch.from_numpy(im_emb_arr).float().cuda()) 75 | return prediction 76 | 77 | class Aesthetics(RewardModel): 78 | """ 79 | Reward model that rewards images with higher aesthetic score. Uses CLIP and an MLP (not put on any device by default) 80 | 81 | :param device: Device to load model on 82 | :type device: torch.device 83 | """ 84 | def __init__(self, device = None): 85 | super().__init__() 86 | self.model = MLP(768) 87 | self.model.load_state_dict(load_aesthetic_model_weights()) 88 | self.clip_model, self.preprocess = clip.load("ViT-L/14", device=device if device is not None else 'cpu') 89 | 90 | if device is not None: 91 | self.model.to(device) 92 | 93 | def forward(self, images : np.ndarray, prompts : Iterable[str]): 94 | return aesthetic_scoring( 95 | images, 96 | self.preprocess, 97 | self.clip_model, 98 | self.model 99 | ) 100 | -------------------------------------------------------------------------------- /src/drlx/reward_modelling/pickscore.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | from torchtyping import TensorType 3 | 4 | import torch 5 | from transformers import AutoModel, AutoProcessor 6 | from PIL import Image 7 | 8 | from drlx.reward_modelling import NNRewardModel 9 | 10 | class PickScoreModel(NNRewardModel): 11 | """ 12 | Reward model using PickScore model from PickAPic 13 | """ 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) 16 | 17 | processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 18 | model_path = "yuvalkirstain/PickScore_v1" 19 | 20 | self.model = AutoModel.from_pretrained(model_path).to(self.device).to(self.dtype) 21 | self.processor = AutoProcessor.from_pretrained(processor_path) 22 | 23 | def preprocess(self, images : Iterable[Image.Image], prompts : Iterable[str]): 24 | """ 25 | Preprocess images and prompts into tensors, making sure to move to correct device and data type 26 | """ 27 | image_inputs = self.processor( 28 | images=images, 29 | padding=True, 30 | truncation=True, 31 | max_length=77, 32 | return_tensors="pt" 33 | ) 34 | 35 | text_inputs = self.processor( 36 | text=prompts, 37 | padding=True, 38 | truncation=True, 39 | max_length=77, 40 | return_tensors="pt" 41 | ) 42 | 43 | pixels, ids, mask = image_inputs['pixel_values'], text_inputs['input_ids'], text_inputs['attention_mask'] 44 | pixels = pixels.to(device = self.device, dtype = self.dtype) 45 | ids = ids.to(device = self.device) 46 | mask = mask.to(device = self.device) 47 | return pixels, ids, mask 48 | 49 | 50 | @torch.no_grad() # This repo does not train the model, so in general, no_grad will be used here 51 | def _forward( 52 | self, 53 | pixel_values : TensorType["batch", "channels", "height", "width"], 54 | input_ids : TensorType["batch", "sequence"], 55 | attention_mask : TensorType["batch", "sequence"] 56 | ) -> TensorType["batch"]: 57 | 58 | image_embs = self.model.get_image_features(pixel_values=pixel_values.cuda()) 59 | image_embs /= image_embs.norm(dim=-1, keepdim=True) 60 | 61 | text_embs = self.model.get_text_features(input_ids=input_ids.cuda(), attention_mask=attention_mask.cuda()) 62 | text_embs /= text_embs.norm(dim=-1, keepdim=True) 63 | 64 | scores = torch.einsum('bd,bd->b', image_embs, text_embs) 65 | scores = self.model.logit_scale.exp() * scores 66 | 67 | return scores -------------------------------------------------------------------------------- /src/drlx/reward_modelling/toy_rewards.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Toy reward models for testing purposes 3 | """ 4 | 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | import torch 9 | import numpy as np 10 | 11 | from drlx.reward_modelling import RewardModel 12 | 13 | class AverageBlueReward(RewardModel): 14 | """ 15 | Rewards "blue-ness" of image 16 | """ 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, images, prompts): 21 | # If the input is a list of PIL Images, convert to numpy array 22 | if isinstance(images, list): 23 | images = np.array([np.array(img) for img in images]) 24 | 25 | blue_channel = images[:,:,:,2] # N x H x W 26 | 27 | # Calculate the mean of the blue channel for each image 28 | blueness = blue_channel.astype(float).mean(axis=(1,2)) # N 29 | blueness = (2 * blueness - 255)/255 # normalize to [-1, 1] 30 | 31 | return torch.from_numpy(blueness) 32 | 33 | class JPEGCompressability(RewardModel): 34 | """ 35 | Rewards JPEG compressibility of an image, i.e. the negative size of its JPEG encoding in kilobytes (from https://arxiv.org/pdf/2305.13301.pdf) 36 | """ 37 | def __init__(self, quality=10): 38 | super().__init__() 39 | self.quality = quality 40 | 41 | def encode_jpeg(self, x, quality = 95): 42 | img = Image.fromarray(x) 43 | buffer = BytesIO() 44 | img.save(buffer, 'JPEG', quality=quality) 45 | jpeg = buffer.getvalue() 46 | bytes = np.frombuffer(jpeg, dtype = np.uint8) 47 | return len(bytes) / 1000 48 | 49 | def forward(self, images, prompts): 50 | scores = [-1 * self.encode_jpeg(img, quality=self.quality) for img in images] 51 | return torch.tensor(scores) -------------------------------------------------------------------------------- /src/drlx/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Iterable, Tuple, Any, Optional 2 | from torchtyping import TensorType 3 | 4 | import torch 5 | from tqdm import tqdm 6 | import math 7 | import einops as eo 8 | 9 | from drlx.utils import rescale_noise_cfg 10 | 11 | from drlx.configs import SamplerConfig, DDPOConfig 12 | 13 | class Sampler: 14 | """ 15 | Generic class for sampling generations using a denoiser.
Assumes LDMUnet 16 | """ 17 | def __init__(self, config : SamplerConfig = SamplerConfig()): 18 | self.config = config 19 | 20 | def cfg_rescale(self, pred : TensorType["2 * b", "c", "h", "w"]): 21 | """ 22 | Applies classifier-free guidance to prediction and rescales if guidance_rescale is set 23 | 24 | :param pred: 25 | Assumed to be batched repeated prediction with first half consisting of 26 | unconditioned (empty token) predictions and second half being conditioned 27 | predictions 28 | """ 29 | 30 | pred_uncond, pred_cond = pred.chunk(2) 31 | pred = pred_uncond + self.config.guidance_scale * (pred_cond - pred_uncond) 32 | 33 | if self.config.guidance_rescale is not None: 34 | pred = rescale_noise_cfg(pred, pred_cond, self.config.guidance_rescale) 35 | 36 | return pred 37 | 38 | @torch.no_grad() 39 | def sample(self, prompts : Iterable[str], denoiser, device = None, show_progress : bool = False, accelerator = None): 40 | """ 41 | Samples latents given some prompts and a denoiser 42 | 43 | :param prompts: Text prompts for image generation (to condition denoiser) 44 | :param denoiser: Model to use for denoising 45 | :param device: Device on which to perform model inference 46 | :param show_progress: Whether to display a progress bar for the sampling steps 47 | :param accelerator: Accelerator object for accelerated training (optional) 48 | 49 | :return: Latents unless postprocess flag is set to true in config, in which case VAE decoded latents are returned (i.e. images) 50 | """ 51 | if accelerator is None: 52 | denoiser_unwrapped = denoiser 53 | else: 54 | denoiser_unwrapped = accelerator.unwrap_model(denoiser) 55 | 56 | scheduler = denoiser_unwrapped.scheduler 57 | preprocess = denoiser_unwrapped.preprocess 58 | noise_shape = denoiser_unwrapped.get_input_shape() 59 | 60 | text_embeds = preprocess( 61 | prompts, mode = "embeds", device = device, 62 | num_images_per_prompt = 1, 63 | do_classifier_free_guidance = self.config.guidance_scale > 1.0 64 | ).detach() 65 | 66 | scheduler.set_timesteps(self.config.num_inference_steps, device = device) 67 | latents = torch.randn(len(prompts), *noise_shape, device = device) 68 | 69 | for i, t in enumerate(tqdm(scheduler.timesteps, disable = not show_progress)): 70 | input = torch.cat([latents] * 2) 71 | input = scheduler.scale_model_input(input, t) 72 | 73 | pred = denoiser( 74 | pixel_values=input, 75 | time_step = t, 76 | text_embeds = text_embeds 77 | ) 78 | 79 | # guidance 80 | pred = self.cfg_rescale(pred) 81 | 82 | # step backward 83 | scheduler_out = scheduler.step(pred, t, latents, self.config.eta) 84 | latents = scheduler_out.prev_sample 85 | 86 | if self.config.postprocess: 87 | return denoiser_unwrapped.postprocess(latents) 88 | else: 89 | return latents 90 | 91 | class DDPOSampler(Sampler): 92 | def step_and_logprobs(self, 93 | scheduler, 94 | pred : TensorType["b", "c", "h", "w"], 95 | t : float, 96 | latents : TensorType["b", "c", "h", "w"], 97 | old_pred : Optional[TensorType["b", "c", "h", "w"]] = None 98 | ): 99 | """ 100 | Steps backwards using scheduler. Considers the prediction as an action sampled 101 | from a normal distribution and returns average log probability for that prediction.
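Concretely, with posterior mean ``prev_sample_mean`` and standard deviation ``std_dev_t = eta * variance ** 0.5``, the per-element log-density is ``-(action - prev_sample_mean) ** 2 / (2 * std_dev_t ** 2) - log(std_dev_t) - 0.5 * log(2 * pi)``, averaged over channel and spatial dimensions to give one log probability per batch element.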
102 | Can also be used to find probability of current model giving some other prediction (old_pred) 103 | 104 | :param scheduler: Scheduler being used for diffusion process 105 | :param pred: Denoiser prediction with CFG and scaling accounted for 106 | :param t: Timestep in diffusion process 107 | :param latents: Latent vector given as input to denoiser 108 | :param old_pred: Alternate prediction. If given, computes log probability of current model predicting alternative output. 109 | """ 110 | scheduler_out = scheduler.step(pred, t, latents, self.config.eta, variance_noise=0) 111 | 112 | # computing log_probs 113 | t_1 = t - scheduler.config.num_train_timesteps // self.config.num_inference_steps 114 | variance = scheduler._get_variance(t, t_1) 115 | std_dev_t = self.config.eta * variance ** 0.5 116 | prev_sample_mean = scheduler_out.prev_sample 117 | prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t 118 | 119 | std_dev_t = torch.clip(std_dev_t, 1e-6) # force sigma > 1e-6 120 | 121 | # If old_pred provided, we are finding probability of new model outputting same action as before 122 | # Otherwise finding probability of current action 123 | action = old_pred if old_pred is not None else prev_sample # Log prob of new model giving old output 124 | log_probs = -((action.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t ** 2) - torch.log(std_dev_t) - math.log(math.sqrt(2 * math.pi)) 125 | log_probs = eo.reduce(log_probs, 'b c h w -> b', 'mean') 126 | 127 | return prev_sample, log_probs 128 | 129 | @torch.no_grad() 130 | def sample( 131 | self, prompts, denoiser, device, 132 | show_progress : bool = False, 133 | accelerator = None 134 | ) -> Iterable[torch.Tensor]: 135 | """ 136 | DDPO sampling is analagous to playing a game in an RL environment. This function samples 137 | given denoiser and prompts but in addition to giving latents also gives log probabilities 138 | for predictions as well as ALL predictions (i.e. at each timestep) 139 | 140 | :param prompts: Text prompts to condition denoiser 141 | :param denoiser: Denoising model 142 | :param device: Device to do inference on 143 | :param show_progress: Display progress bar? 
144 | :param accelerator: Accelerator object for accelerated training (optional) 145 | 146 | :return: triple of final denoised latents, all model predictions, all log probabilities for each prediction 147 | """ 148 | 149 | if accelerator is None: 150 | denoiser_unwrapped = denoiser 151 | else: 152 | denoiser_unwrapped = accelerator.unwrap_model(denoiser) 153 | 154 | scheduler = denoiser_unwrapped.scheduler 155 | preprocess = denoiser_unwrapped.preprocess 156 | noise_shape = denoiser_unwrapped.get_input_shape() 157 | 158 | text_embeds = preprocess( 159 | prompts, mode = "embeds", device = device, 160 | num_images_per_prompt = 1, 161 | do_classifier_free_guidance = self.config.guidance_scale > 1.0 162 | ).detach() 163 | 164 | scheduler.set_timesteps(self.config.num_inference_steps, device = device) 165 | latents = torch.randn(len(prompts), *noise_shape, device = device) 166 | 167 | all_step_preds, all_log_probs = [latents], [] 168 | 169 | for t in tqdm(scheduler.timesteps, disable = not show_progress): 170 | latent_input = torch.cat([latents] * 2) 171 | latent_input = scheduler.scale_model_input(latent_input, t) 172 | 173 | pred = denoiser( 174 | pixel_values = latent_input, 175 | time_step = t, 176 | text_embeds = text_embeds 177 | ) 178 | 179 | # cfg 180 | pred = self.cfg_rescale(pred) 181 | 182 | # step 183 | prev_sample, log_probs = self.step_and_logprobs(scheduler, pred, t, latents) 184 | 185 | all_step_preds.append(prev_sample) 186 | all_log_probs.append(log_probs) 187 | latents = prev_sample 188 | 189 | return latents, torch.stack(all_step_preds), torch.stack(all_log_probs) 190 | 191 | def compute_loss( 192 | self, prompts, denoiser, device, 193 | show_progress : bool = False, 194 | advantages = None, old_preds = None, old_log_probs = None, 195 | method_config : DDPOConfig = None, 196 | accelerator = None 197 | ): 198 | 199 | 200 | """ 201 | Computes the loss for the DDPO sampling process. This function is used to train the denoiser model. 
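Internally, for each timestep the probability ratio ``r = exp(log_probs - old_log_probs)`` is formed and the clipped PPO surrogate ``max(-A * r, -A * clip(r, 1 - clip_ratio, 1 + clip_ratio))`` is averaged over the batch, with the advantages ``A`` first clipped to ``[-clip_advantages, clip_advantages]``; the backward pass is taken per timestep so gradients accumulate across the trajectory.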
202 | 203 | :param prompts: Text prompts to condition the denoiser 204 | :param denoiser: Denoising model 205 | :param device: Device to perform model inference on 206 | :param show_progress: Whether to display a progress bar for the sampling steps 207 | :param advantages: Normalized advantages obtained from reward computation 208 | :param old_preds: Previous predictions from past model 209 | :param old_log_probs: Log probabilities of predictions from past model 210 | :param method_config: Configuration for the DDPO method 211 | :param accelerator: Accelerator object for accelerated training (optional) 212 | 213 | :return: Total loss computed over the sampling process 214 | """ 215 | 216 | # All metrics are reduced and gathered before result is returned 217 | metrics = { 218 | "loss" : [], 219 | "kl_div" : [], # ~ KL div between new policy and old one (average) 220 | "clip_frac" : [], # Proportion of policy updates where magnitude of update was clipped 221 | } 222 | 223 | if accelerator is None: 224 | denoiser_unwrapped = denoiser 225 | else: 226 | denoiser_unwrapped = accelerator.unwrap_model(denoiser) 227 | 228 | scheduler = denoiser_unwrapped.scheduler 229 | preprocess = denoiser_unwrapped.preprocess 230 | 231 | adv_clip = method_config.clip_advantages # clip value for advantages 232 | pi_clip = method_config.clip_ratio # clip value for policy ratio 233 | 234 | text_embeds = preprocess( 235 | prompts, mode = "embeds", device = device, 236 | num_images_per_prompt = 1, 237 | do_classifier_free_guidance = self.config.guidance_scale > 1.0 238 | ).detach() 239 | 240 | scheduler.set_timesteps(self.config.num_inference_steps, device = device) 241 | total_loss = 0. 242 | 243 | for i, t in enumerate(tqdm(scheduler.timesteps, disable = not show_progress)): 244 | latent_input = torch.cat([old_preds[i].detach()] * 2) 245 | latent_input = scheduler.scale_model_input(latent_input, t) 246 | 247 | pred = denoiser( 248 | pixel_values = latent_input, 249 | time_step = t, 250 | text_embeds = text_embeds 251 | ) 252 | 253 | # cfg 254 | pred = self.cfg_rescale(pred) 255 | 256 | # step 257 | prev_sample, log_probs = self.step_and_logprobs( 258 | scheduler, pred, t, old_preds[i], 259 | old_preds[i+1] 260 | ) 261 | 262 | # Need to be computed and detached again because of autograd weirdness 263 | clipped_advs = torch.clip(advantages,-adv_clip,adv_clip).detach() 264 | 265 | # ppo actor loss 266 | 267 | ratio = torch.exp(log_probs - old_log_probs[i].detach()) 268 | surr1 = -clipped_advs * ratio 269 | surr2 = -clipped_advs * torch.clip(ratio, 1. - pi_clip, 1. 
+ pi_clip) 270 | loss = torch.max(surr1, surr2).mean() 271 | if accelerator is not None: 272 | accelerator.backward(loss) 273 | else: 274 | loss.backward() 275 | 276 | # Metric computations 277 | kl_div = 0.5 * (log_probs - old_log_probs[i]).mean() ** 2 278 | clip_frac = ((ratio < 1 - pi_clip) | (ratio > 1 + pi_clip)).float().mean() 279 | 280 | metrics["loss"].append(loss.item()) 281 | metrics["kl_div"].append(kl_div.item()) 282 | metrics["clip_frac"].append(clip_frac.item()) 283 | 284 | # Reduce across timesteps then across devices 285 | for k in metrics: 286 | metrics[k] = torch.tensor(metrics[k]).mean().cuda() # Needed for reduction to work 287 | if accelerator is not None: 288 | metrics = accelerator.reduce(metrics, 'mean') 289 | 290 | return metrics -------------------------------------------------------------------------------- /src/drlx/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Dict, Any, Iterable 2 | from torchtyping import TensorType 3 | 4 | from abc import abstractmethod 5 | import os 6 | 7 | import torch 8 | 9 | from drlx.configs import DRLXConfig 10 | from drlx.reward_modelling import RewardModel 11 | from drlx.denoisers.ldm_unet import LDMUNet 12 | from drlx.pipeline import Pipeline 13 | from drlx.utils import get_optimizer_class, get_scheduler_class, get_diffusion_pipeline_class 14 | 15 | from PIL import Image 16 | 17 | class BaseTrainer: 18 | """ 19 | Base class for any DRLX trainer 20 | """ 21 | def __init__(self, config : DRLXConfig): 22 | self.config = config 23 | 24 | if self.config.train.tf32: 25 | torch.backends.cuda.matmul.allow_tf32 = True 26 | torch.backends.cudnn.allow_tf32 = True 27 | 28 | # Assume these are defined in base classes 29 | self.optimizer = None 30 | self.scheduler = None 31 | self.model = None 32 | 33 | def setup_optimizer(self): 34 | """ 35 | Returns an optimizer derived from an instance's config 36 | """ 37 | optimizer_class = get_optimizer_class(self.config.optimizer.name) 38 | optimizer = optimizer_class( 39 | filter(lambda p: p.requires_grad, self.model.parameters()), 40 | **self.config.optimizer.kwargs, 41 | ) 42 | return optimizer 43 | 44 | def setup_scheduler(self): 45 | """ 46 | Returns a learning rate scheduler derived from an instance's config 47 | """ 48 | scheduler_class = get_scheduler_class(self.config.scheduler.name) 49 | scheduler = scheduler_class(self.optimizer, **self.config.scheduler.kwargs) 50 | return scheduler 51 | 52 | def get_arch(self, config): 53 | """ 54 | Get model class from arch_name in config file. Currently only supports LDMUNet 55 | """ 56 | model_name = LDMUNet # nothing else is supported for now (TODO: add support for other models) 57 | return model_name 58 | 59 | @abstractmethod 60 | def train(self, pipeline : Pipeline, reward_fn : Callable[[Iterable[Image.Image], Iterable[str]], TensorType["batch"]]): 61 | """ 62 | Trains model on a given pipeline using a given reward function. 63 | 64 | :param pipeline: Data pipeline used for training 65 | :param reward_fn: Function used to get rewards. Should take tuples of images (either as a sequence of numpy arrays, or as a list of images) 66 | """ 67 | pass 68 | 69 | def save_checkpoint(self, fp : str, components : Dict[str, Any], index : int = None): 70 | """ 71 | Basic checkpoint saving for any derived trainer to use 72 | 73 | :param fp: Path to save checkpoint to 74 | :type fp: str 75 | 76 | :param components: Dictionary of all components to save (i.e. 
model, optimizer, scheduler, etc.) 77 | :type components: Dict 78 | 79 | :param index: When provided, uses fp as a root folder and puts checkpoint under a subdirectory that is named numerically with index 80 | :type index: Optional[int] 81 | """ 82 | if not os.path.exists(fp): 83 | os.makedirs(fp) 84 | 85 | if index is not None: 86 | fp = os.path.join(fp, str(index)) 87 | if not os.path.exists(fp): 88 | os.makedirs(fp) 89 | 90 | for key, component in components.items(): 91 | torch.save(component, os.path.join(fp, f"{key}.pt")) 92 | 93 | def load_checkpoint(self, fp: str, index: int = None) -> Dict[str, Any]: 94 | """ 95 | Basic checkpoint loading for derived trainers to use. 96 | 97 | :param fp: Path to load checkpoint from 98 | :type fp: str 99 | 100 | :param index: When provided, uses fp as root and loads subdirectory with numerical name given by index 101 | :type index: Optional[int] 102 | 103 | :return: Dictionary of components and their states 104 | :rtype: Dict 105 | """ 106 | # If an index is given, update the file path to include the subdirectory with the index as its name 107 | if index is not None: 108 | fp = os.path.join(fp, str(index)) 109 | 110 | # Initialize an empty dictionary to store the loaded components 111 | loaded_components = {} 112 | 113 | # Iterate through the files in the directory 114 | for file_name in os.listdir(fp): 115 | # Check if the file has a .pt extension 116 | if file_name.endswith(".pt"): 117 | # Load the component using torch.load and add it to the loaded_components dictionary 118 | key = file_name[:-3] # Remove the .pt extension from the file name to get the key 119 | component = torch.load(os.path.join(fp, file_name)) 120 | loaded_components[key] = component 121 | 122 | return loaded_components 123 | -------------------------------------------------------------------------------- /src/drlx/trainer/ddpo_trainer.py: -------------------------------------------------------------------------------- 1 | from torchtyping import TensorType 2 | from typing import Iterable, Tuple, Callable 3 | 4 | from accelerate import Accelerator 5 | from drlx.configs import DRLXConfig, DDPOConfig 6 | from drlx.trainer import BaseTrainer 7 | from drlx.sampling import DDPOSampler 8 | from drlx.utils import suppress_warnings, Timer, PerPromptStatTracker, scoped_seed, save_images 9 | 10 | import torch 11 | import einops as eo 12 | import os 13 | import gc 14 | import logging 15 | from torch.utils.data import DataLoader, Dataset 16 | from tqdm import tqdm 17 | import numpy as np 18 | import wandb 19 | import accelerate.utils 20 | from PIL import Image 21 | 22 | from diffusers import StableDiffusionPipeline 23 | 24 | from diffusers.utils import convert_state_dict_to_diffusers 25 | from peft.utils import get_peft_model_state_dict 26 | 27 | class DDPOExperienceReplay(Dataset): 28 | """ 29 | Utility class to compute advantages and create dataloader from sampling experiences. 
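Rewards are computed batch-by-batch with ``reward_fn``, gathered across all processes, turned into per-prompt-normalized advantages by the provided ``PerPromptStatTracker``, and each process then keeps only its own slice of the advantage tensor.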
30 | """ 31 | 32 | def __init__(self, 33 | accelerator: Accelerator, 34 | reward_fn: callable, 35 | ppst: PerPromptStatTracker, 36 | imgs : Iterable[Iterable], 37 | prompts : Iterable[Iterable[str]], 38 | all_step_preds : Iterable[TensorType["t","b","c","h","w"]], 39 | log_probs : Iterable[TensorType["t", "b"]], 40 | **dataloader_kwargs 41 | ): 42 | # Compute rewards first 43 | rewards = [reward_fn(img_batch, prompt_batch) 44 | for img_batch, prompt_batch in zip(imgs, prompts)] 45 | 46 | 47 | # Combine all_step_preds, log_probs 48 | self.all_step_preds = torch.cat(all_step_preds, dim = 1) # [t, n, c, h, w] 49 | self.log_probs = torch.cat(log_probs, dim = 1) # [t, n] 50 | self.rewards = torch.cat(rewards) # [n] 51 | 52 | # Gather all rewards 53 | self.all_rewards = accelerator.gather(self.rewards).detach().cpu().numpy() 54 | 55 | # Prompts is list of batches of prompts (list of list of strings) 56 | # Iterate through each batch and each prompt within it to unwrap into single list of prompts 57 | self.prompts = [prompt for prompt_list in prompts for prompt in prompt_list] 58 | 59 | # Gather all prompts 60 | self.all_prompts = accelerate.utils.gather_object(self.prompts) 61 | 62 | # Compute advantages 63 | advantages = torch.from_numpy(ppst.update(np.array(self.all_prompts), self.all_rewards)).float() 64 | self.advantages = advantages.reshape(accelerator.num_processes, -1)[accelerator.process_index].to(accelerator.device) 65 | 66 | def __getitem__(self, i): 67 | return self.all_step_preds[:,i], self.log_probs[:,i], self.advantages[i], self.prompts[i] 68 | 69 | def __len__(self): 70 | return self.all_step_preds.size(1) 71 | 72 | def create_loader(self, **kwargs): 73 | def collate(batch): 74 | # unzip the batch 75 | all_steps, log_probs, advs, prompts = list(zip(*batch)) 76 | all_steps = torch.stack(all_steps, dim = 1) 77 | log_probs = torch.stack(log_probs, dim = 1) 78 | advs = torch.stack(advs) 79 | prompts = list(prompts) 80 | 81 | return (all_steps, log_probs, advs, prompts) 82 | 83 | return DataLoader(self, collate_fn=collate, **kwargs) 84 | 85 | class DDPOTrainer(BaseTrainer): 86 | """ 87 | DDPO Accelerated Trainer initilization from config. 
During init, sets up model, optimizer, sampler and logging 88 | 89 | :param config: DRLX config 90 | :type config: DRLXConfig 91 | """ 92 | 93 | def __init__(self, config : DRLXConfig): 94 | super().__init__(config) 95 | 96 | assert isinstance(self.config.method, DDPOConfig), "ERROR: Method config must be DDPO config" 97 | 98 | # Figure out batch size and accumulation steps 99 | if self.config.train.target_batch is not None: # Just use normal batch_size 100 | self.accum_steps = (self.config.train.target_batch // self.config.train.batch_size) 101 | else: 102 | self.accum_steps = 1 103 | 104 | self.accelerator = Accelerator( 105 | log_with = config.logging.log_with, 106 | gradient_accumulation_steps = self.accum_steps 107 | ) 108 | 109 | # Disable tokenizer warnings since they clutter the CLI 110 | kw_str = self.config.train.suppress_log_keywords 111 | if kw_str is not None: 112 | for prefix in kw_str.split(","): 113 | suppress_warnings(prefix.strip()) 114 | 115 | self.pipe = None # Store reference to pipeline so that we can use save_pretrained later 116 | self.model = self.setup_model() 117 | self.optimizer = self.setup_optimizer() 118 | self.scheduler = self.setup_scheduler() 119 | 120 | self.sampler = self.model.sampler 121 | self.model, self.optimizer, self.scheduler = self.accelerator.prepare( 122 | self.model, self.optimizer, self.scheduler 123 | ) 124 | 125 | # Setup tracking 126 | 127 | tracker_kwargs = {} 128 | self.use_wandb = not (config.logging.wandb_project is None) 129 | if self.use_wandb: 130 | log = config.logging 131 | tracker_kwargs["wandb"] = { 132 | "name" : log.run_name, 133 | "entity" : log.wandb_entity, 134 | "mode" : "online" 135 | } 136 | 137 | self.accelerator.init_trackers( 138 | project_name = log.wandb_project, 139 | config = config.to_dict(), 140 | init_kwargs = tracker_kwargs 141 | ) 142 | 143 | self.world_size = self.accelerator.state.num_processes 144 | 145 | def setup_model(self): 146 | """ 147 | Set up model from config. 
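Instantiates an ``LDMUNet`` with a ``DDPOSampler`` and, when ``model_path`` is set, loads weights through ``from_pretrained_pipeline`` using a diffusers ``StableDiffusionPipeline``, keeping a reference to that pipeline so it can later be saved with ``save_pretrained``.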
148 | """ 149 | model = self.get_arch(self.config)(self.config.model, sampler = DDPOSampler(self.config.sampler)) 150 | if self.config.model.model_path is not None: 151 | model, pipe = model.from_pretrained_pipeline(StableDiffusionPipeline, self.config.model.model_path) 152 | 153 | self.pipe = pipe 154 | return model 155 | 156 | def loss( 157 | self, 158 | x_t : TensorType["timesteps", "batch", "channels", "height", "width"], 159 | log_probs_t : TensorType["timesteps", "batch"], 160 | advantages : TensorType["batch"], 161 | prompts : Iterable[str] 162 | ): 163 | """ 164 | Get loss for training 165 | 166 | :param x_t: Samples across time steps and across batch 167 | :type x_t: torch.Tensor 168 | 169 | :param log_probs_t: Log probabilities for each sample prediction 170 | :type log_probs_t: torch.Tensor 171 | 172 | :advantages: Advantages associated with each image across batch 173 | :type advantages: torch.Tensor 174 | 175 | :prompts: Prompts used for generation across the batch 176 | :type prompts: Iterable[str] 177 | 178 | :return: loss 179 | :rtype: torch.Tensor 180 | """ 181 | return self.sampler.compute_loss( 182 | prompts=prompts, 183 | denoiser=self.model, 184 | device=self.accelerator.device, 185 | advantages=advantages, 186 | old_preds=x_t, 187 | old_log_probs=log_probs_t, 188 | show_progress=self.accelerator.is_main_process, 189 | method_config=self.config.method, 190 | accelerator=self.accelerator 191 | ) 192 | 193 | def sample(self, prompts : Iterable[str]) -> Tuple[torch.Tensor]: 194 | """ 195 | Sample predictions, predictions at time steps and log probabilities from sampler 196 | 197 | :param prompts: Batched prompts to use for sampling 198 | :type prompts: Iterable[str] 199 | 200 | :return: 3 Tensors: final predictions for latent, all step predictions during denoising process, and log probabilities for each prediction 201 | :rtype: Tuple[torch.Tensor] 202 | """ 203 | preds, all_preds, log_probs = self.sampler.sample( 204 | prompts = prompts, 205 | denoiser = self.model, 206 | device = self.accelerator.device, 207 | accelerator = self.accelerator 208 | ) 209 | 210 | return preds, all_preds, log_probs 211 | 212 | def sample_and_calculate_rewards(self, prompts : Iterable[str], reward_fn : Callable) -> Tuple: 213 | """ 214 | Samples a batch of images and calculates the rewards for each image 215 | 216 | :param prompts: Batch of prompts to sample with 217 | :type prompts: Iterable[str] 218 | 219 | :param reward_fn: Function to be called on final images and prompts to be used for reward computation 220 | :type reward_fn: Callable[[np.ndarray, Iterable[str]], Iterable[float]] 221 | 222 | :return: Final images, rewards, all step predictions, log probabilities for predictions 223 | :rtype: Tuple 224 | """ 225 | 226 | preds, all_preds, log_probs = self.sample(prompts) 227 | imgs = self.accelerator.unwrap_model(self.model).postprocess(preds) 228 | 229 | rewards = reward_fn(imgs, prompts).to(self.accelerator.device) 230 | return imgs, rewards, all_preds, log_probs 231 | 232 | def train(self, prompt_pipeline, reward_fn): 233 | """ 234 | Trains the model based on config parameters. Needs to be passed a prompt pipeline and reward function. 235 | 236 | :param prompt_pipeline: Pipeline to draw text prompts from. Should be composed of just strings. 237 | :type prompt_pipeline: PromptPipeline 238 | 239 | :param reward_fn: Any function that returns a tensor of scalar rewards given np array of images (uint8) and text prompts (strings). 
240 | It is fine to have a reward function that only rewards images without looking at prompts, simply add prompts as a dummy input. 241 | :type reward_fn: Callable[[np.array, Iterable[str], torch.Tensor] 242 | """ 243 | 244 | # === SETUP === 245 | 246 | # Singular dataloader made to get a sample of prompts 247 | # This sample batch is dependent on config seed so it can be same across runs 248 | with scoped_seed(self.config.train.seed): 249 | dataloader = self.accelerator.prepare( 250 | prompt_pipeline.create_train_loader(batch_size = self.config.train.batch_size, shuffle = False) 251 | ) 252 | sample_prompts = self.config.train.sample_prompts 253 | if sample_prompts is None: 254 | sample_prompts = [] 255 | if len(sample_prompts) < self.config.train.batch_size: 256 | new_sample_prompts = next(iter(dataloader)) 257 | sample_prompts += new_sample_prompts 258 | sample_prompts = sample_prompts[:self.config.train.batch_size] 259 | 260 | assert isinstance(self.sampler, DDPOSampler), "Error: Model Sampler for DDPO training must be DDPO sampler" 261 | 262 | per_prompt_stat_tracker = PerPromptStatTracker(self.config.method.buffer_size, self.config.method.min_count) 263 | 264 | if isinstance(reward_fn, torch.nn.Module): 265 | reward_fn = self.accelerator.prepare(reward_fn) 266 | 267 | # Set the epoch count 268 | outer_epochs = self.config.train.num_epochs 269 | if self.config.train.total_samples is not None: 270 | outer_epochs = int(self.config.train.total_samples // self.config.train.num_samples_per_epoch) 271 | 272 | # Timer to measure time per 1k images (as metric) 273 | timer = Timer() 274 | def time_per_1k(n_samples : int): 275 | total_time = timer.hit() 276 | return total_time * 1000 / n_samples 277 | 278 | # === MAIN TRAINING LOOP === 279 | 280 | mean_rewards = [] 281 | accum = 0 282 | last_epoch_time = timer.hit() 283 | for epoch in range(outer_epochs): 284 | 285 | # Clean up unused resources 286 | preds, all_step_preds, log_probs, all_prompts = [], [], [], [] 287 | self.accelerator._dataloaders = [] # Clear dataloaders 288 | gc.collect() 289 | torch.cuda.empty_cache() 290 | 291 | self.accelerator.print(f"Epoch {epoch}/{outer_epochs}. {epoch * self.config.train.num_samples_per_epoch} samples seen. 
Averaging {last_epoch_time:.2f}s/1k samples.") 292 | 293 | # Make a new dataloader to reshuffle data 294 | dataloader = prompt_pipeline.create_train_loader(batch_size = self.config.train.sample_batch_size, shuffle = True) 295 | dataloader = self.accelerator.prepare(dataloader) 296 | 297 | # Sample (play the game) 298 | data_steps = self.config.train.num_samples_per_epoch // self.config.train.sample_batch_size // self.world_size 299 | self.accelerator.print("Sampling...") 300 | for i, prompts in enumerate(tqdm(dataloader, total = data_steps, disable=not self.accelerator.is_main_process)): 301 | if i >= data_steps: 302 | break 303 | 304 | batch_preds, batch_all_step_preds, batch_log_probs = self.sample(prompts) 305 | 306 | preds.append(batch_preds) 307 | all_step_preds.append(batch_all_step_preds) 308 | log_probs.append(batch_log_probs) 309 | all_prompts.append(prompts) 310 | 311 | # Get rewards from experiences 312 | self.accelerator.wait_for_everyone() 313 | unwrapped_model = self.accelerator.unwrap_model(self.model) 314 | imgs = [unwrapped_model.postprocess(pred) for pred in preds] 315 | 316 | # Experience replay computes normalized rewards, 317 | # then is used to create a loader for training 318 | exp_replay = DDPOExperienceReplay( 319 | self.accelerator, 320 | reward_fn, per_prompt_stat_tracker, 321 | imgs, all_prompts, 322 | all_step_preds, log_probs 323 | ) 324 | all_rewards = self.accelerator.gather(exp_replay.rewards).detach().cpu().numpy() 325 | 326 | experience_loader = exp_replay.create_loader(batch_size = self.config.train.batch_size) 327 | 328 | mean_rewards.append(all_rewards.mean().item()) 329 | 330 | # Consistent prompt sample for logging 331 | with scoped_seed(self.config.train.seed): 332 | sample_imgs_np, sample_rewards, _, _ = self.sample_and_calculate_rewards(sample_prompts, reward_fn) 333 | sample_imgs = [wandb.Image(Image.fromarray(img), caption=prompt + f', {reward.item()}') for img, prompt, reward in zip(sample_imgs_np, sample_prompts, sample_rewards)] 334 | batch_imgs = [wandb.Image(Image.fromarray(img), caption=prompt) for img, prompt in zip(imgs[-1], all_prompts[-1])] 335 | 336 | # Logging 337 | if self.use_wandb: 338 | self.accelerator.log({ 339 | "mean_reward" : mean_rewards[-1], 340 | "reward_hist" : wandb.Histogram(all_rewards), 341 | "time_per_1k" : last_epoch_time, 342 | "img_batch" : batch_imgs, 343 | "img_sample" : sample_imgs 344 | }) 345 | # save images 346 | if self.accelerator.is_main_process and self.config.train.save_samples: 347 | save_images(sample_imgs_np, f"./samples/{self.config.logging.run_name}/{epoch}") 348 | 349 | 350 | 351 | # Inner epochs and actual training 352 | self.accelerator.print("Training...") 353 | experience_loader = self.accelerator.prepare(experience_loader) 354 | # Inner epochs normally one, disable progress bar when this is the case 355 | for inner_epoch in tqdm(range(self.config.method.num_inner_epochs), 356 | disable=(not self.accelerator.is_main_process) or (self.config.method.num_inner_epochs == 1) 357 | ): 358 | for (all_step_preds, log_probs, advantages, prompts) in tqdm(experience_loader, disable=not self.accelerator.is_main_process): 359 | with self.accelerator.accumulate(self.model): # Accumulate across minibatches 360 | metrics = self.loss(all_step_preds, log_probs, advantages, prompts) 361 | self.accelerator.clip_grad_norm_(filter(lambda p: p.requires_grad, self.model.parameters()), self.config.train.grad_clip) 362 | self.optimizer.step() 363 | self.scheduler.step() 364 | self.optimizer.zero_grad() 365 | if 
self.use_wandb: 366 | self.accelerator.log({ #TODO: add approx_kl tracking 367 | "loss" : metrics["loss"], 368 | "kl_div" : metrics["kl_div"], 369 | "clip_frac" : metrics["clip_frac"], 370 | "lr" : self.scheduler.get_last_lr()[0], 371 | "epoch": epoch 372 | }) 373 | 374 | # Save model every [interval] epochs 375 | accum += 1 376 | if accum % self.config.train.checkpoint_interval == 0 and self.config.train.checkpoint_interval > 0: 377 | self.accelerator.print("Saving...") 378 | base_path = f"./checkpoints/{self.config.logging.run_name}" 379 | output_path = f"./output/{self.config.logging.run_name}" 380 | self.accelerator.wait_for_everyone() 381 | self.save_checkpoint(f"{base_path}/{accum}") 382 | self.save_pretrained(output_path) 383 | 384 | last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) 385 | 386 | del metrics, dataloader, experience_loader 387 | 388 | def save_checkpoint(self, fp : str, components = None): 389 | """ 390 | Save checkpoint in main process 391 | 392 | :param fp: File path to save checkpoint to 393 | """ 394 | if self.accelerator.is_main_process: 395 | os.makedirs(fp, exist_ok = True) 396 | self.accelerator.save_state(output_dir=fp) 397 | self.accelerator.wait_for_everyone() # need to use this twice or a corrupted state is saved 398 | 399 | def save_pretrained(self, fp : str): 400 | """ 401 | Save model into pretrained pipeline so it can be loaded in pipeline later 402 | 403 | :param fp: File path to save to 404 | """ 405 | if self.accelerator.is_main_process: 406 | os.makedirs(fp, exist_ok = True) 407 | unwrapped_model = self.accelerator.unwrap_model(self.model) 408 | if self.config.model.lora_rank is not None: 409 | unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_model.unet)) 410 | StableDiffusionPipeline.save_lora_weights(fp, unet_lora_layers=unet_lora_state_dict, safe_serialization = unwrapped_model.config.use_safetensors) 411 | else: 412 | self.pipe.unet = unwrapped_model.unet 413 | self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.use_safetensors) 414 | self.accelerator.wait_for_everyone() 415 | 416 | def extract_pipeline(self): 417 | """ 418 | Return original pipeline with finetuned denoiser plugged in 419 | 420 | :return: Diffusers pipeline 421 | """ 422 | 423 | self.pipe.unet = self.accelerator.unwrap_model(self.model).unet 424 | return self.pipe 425 | 426 | def load_checkpoint(self, fp : str): 427 | """ 428 | Load checkpoint 429 | 430 | :param fp: File path to checkpoint to load from 431 | """ 432 | self.accelerator.load_state(fp) 433 | self.accelerator.print("Succesfully loaded checkpoint") 434 | -------------------------------------------------------------------------------- /src/drlx/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from enum import Enum 4 | from itertools import repeat 5 | from typing import Any, Dict, Iterable, Tuple 6 | from collections import deque 7 | import torch 8 | from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR 9 | from diffusers import StableDiffusionPipeline 10 | import logging 11 | import time 12 | from contextlib import contextmanager 13 | from PIL import Image 14 | 15 | import numpy as np 16 | 17 | class OptimizerName(str, Enum): 18 | """Supported optimizer names""" 19 | 20 | ADAM: str = "adam" 21 | ADAMW: str = "adamw" 22 | ADAM_8BIT_BNB: str = "adam_8bit_bnb" 23 | ADAMW_8BIT_BNB: str = "adamw_8bit_bnb" 24 | SGD: str = "sgd" 25 | 26 | 27 | def 
get_optimizer_class(name: OptimizerName): 28 | """ 29 | Returns the optimizer class with the given name 30 | 31 | Args: 32 | name (str): Name of the optimizer as found in `OptimizerNames` 33 | """ 34 | if name == OptimizerName.ADAM: 35 | return torch.optim.Adam 36 | if name == OptimizerName.ADAMW: 37 | return torch.optim.AdamW 38 | if name == OptimizerName.ADAM_8BIT_BNB.value: 39 | try: 40 | from bitsandbytes.optim import Adam8bit 41 | 42 | return Adam8bit 43 | except ImportError: 44 | raise ImportError( 45 | "You must install the `bitsandbytes` package to use the 8-bit Adam. " 46 | "Install with: `pip install bitsandbytes`" 47 | ) 48 | if name == OptimizerName.ADAMW_8BIT_BNB.value: 49 | try: 50 | from bitsandbytes.optim import AdamW8bit 51 | 52 | return AdamW8bit 53 | except ImportError: 54 | raise ImportError( 55 | "You must install the `bitsandbytes` package to use 8-bit AdamW. " 56 | "Install with: `pip install bitsandbytes`" 57 | ) 58 | if name == OptimizerName.SGD.value: 59 | return torch.optim.SGD 60 | supported_optimizers = [o.value for o in OptimizerName] 61 | raise ValueError(f"`{name}` is not a supported optimizer. " f"Supported optimizers are: {supported_optimizers}") 62 | 63 | 64 | class SchedulerName(str, Enum): 65 | """Supported scheduler names""" 66 | 67 | COSINE_ANNEALING = "cosine_annealing" 68 | LINEAR = "linear" 69 | 70 | 71 | def get_scheduler_class(name: SchedulerName): 72 | """ 73 | Returns the scheduler class with the given name 74 | """ 75 | if name == SchedulerName.COSINE_ANNEALING: 76 | return CosineAnnealingLR 77 | if name == SchedulerName.LINEAR: 78 | return LinearLR 79 | supported_schedulers = [s.value for s in SchedulerName] 80 | raise ValueError(f"`{name}` is not a supported scheduler. " f"Supported schedulers are: {supported_schedulers}") 81 | 82 | 83 | 84 | class DiffusionPipelineName(str, Enum): 85 | """Supported diffusion pipeline names""" 86 | StableDiffusion = "stable_diffusion" 87 | 88 | def get_diffusion_pipeline_class(name: DiffusionPipelineName): 89 | """ 90 | Returns the diffusion pipeline class with the given name 91 | """ 92 | if name == DiffusionPipelineName.StableDiffusion: 93 | return StableDiffusionPipeline 94 | supported_diffusion_pipelines = [d.value for d in DiffusionPipelineName] 95 | raise ValueError(f"`{name}` is not a supported diffusion pipeline. 
" f"Supported diffusion pipelines are: {supported_diffusion_pipelines}") 96 | 97 | def any_chunk(x, chunk_size): 98 | """ 99 | Chunks any iterable by chunk size 100 | """ 101 | is_tensor = isinstance(x, torch.Tensor) 102 | 103 | x_chunks = [x[i:i+chunk_size] for i in range(0, len(x), chunk_size)] 104 | return torch.stack(x_chunks) if is_tensor else x_chunks 105 | 106 | def suppress_warnings(prefix : str): 107 | """ 108 | With logging module, suppresses any warnings that are coming from a logger 109 | with a given prefix 110 | """ 111 | 112 | names = logging.root.manager.loggerDict 113 | names = list(filter(lambda x: x.startswith(prefix), names)) 114 | for name in names: 115 | logging.getLogger(name).setLevel(logging.ERROR) 116 | 117 | class Timer: 118 | """ 119 | Utility class for timing models 120 | """ 121 | def __init__(self): 122 | self.time = time.time() 123 | 124 | def hit(self) -> float: 125 | """ 126 | Restarts timer and returns the time in seconds since last restart or initialization 127 | """ 128 | new_time = time.time() 129 | res = new_time - self.time 130 | self.time = new_time 131 | return res 132 | 133 | def get_latest_checkpoint(checkpoint_root): 134 | """ 135 | Assume folder root_dir stores checkpoints for model, all named numerically (in terms of training steps associated with said checkpoints). 136 | This function returns the path to the latest checkpoint, aka the subfolder with largest numerical name. Returns none if the root dir is empty 137 | """ 138 | subdirs = glob.glob(os.path.join(checkpoint_root, '*')) 139 | if not subdirs: 140 | return None 141 | 142 | # Filter out any paths that are not directories or are not numeric 143 | subdirs = [s for s in subdirs if os.path.isdir(s) and os.path.basename(s).isdigit()] 144 | # Find the maximum directory number (assuming all subdirectories are numeric) 145 | latest_checkpoint = max(subdirs, key=lambda s: int(os.path.basename(s))) 146 | return latest_checkpoint 147 | 148 | class PerPromptStatTracker: 149 | """ 150 | Stat tracker to normalize rewards across prompts. If there is a sufficient number of duplicate prompts, averages across rewards given for that specific prompts. Otherwise, simply averages across all rewards. 
151 | 152 | :param buffer_size: How many prompts to consider for average 153 | :type buffer_size: int 154 | 155 | :param min_count: How many duplicates for a prompt minimum before we average over that prompt and not over all prompts 156 | :type min_count: int 157 | """ 158 | def __init__(self, buffer_size : int, min_count : int): 159 | self.buffer_size = buffer_size 160 | self.min_count = min_count 161 | self.stats = {} 162 | 163 | def update(self, prompts, rewards): 164 | unique = np.unique(prompts) 165 | advantages = np.empty_like(rewards) 166 | for prompt in unique: 167 | prompt_rewards = rewards[prompts == prompt] 168 | if prompt not in self.stats: 169 | self.stats[prompt] = deque(maxlen=self.buffer_size) 170 | self.stats[prompt].extend(prompt_rewards) 171 | 172 | if len(self.stats[prompt]) < self.min_count: 173 | mean = np.mean(rewards) 174 | std = np.std(rewards) + 1e-6 175 | else: 176 | mean = np.mean(self.stats[prompt]) 177 | std = np.std(self.stats[prompt]) + 1e-6 178 | advantages[prompts == prompt] = (prompt_rewards - mean) / std 179 | 180 | return advantages 181 | 182 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 183 | """ 184 | From [diffusers repository](https://github.com/huggingface/diffusers/blob/a7508a76f025fcbe104c28f73dd17c8e866f655b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L58). 185 | Copied here due to import errors when attempting to import from package 186 | """ 187 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 188 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 189 | # rescale the results from guidance (fixes overexposure) 190 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 191 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 192 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 193 | return noise_cfg 194 | 195 | @contextmanager 196 | def scoped_seed(seed : int = 0): 197 | """ 198 | Set torch seed within a context. Useful for deterministic sampling. 
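A short usage sketch (hypothetical values)::

        with scoped_seed(42):
            noise = torch.randn(4, 4, 64, 64)  # deterministic given the scoped seed
        # the RNG state from before the block is restored here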
199 | 200 | :param seed: Seed to use for random state 201 | :type seed: int 202 | """ 203 | # Record the state of the RNG 204 | cpu_rng_state = torch.get_rng_state() 205 | if torch.cuda.is_available(): 206 | cuda_rng_state = torch.cuda.get_rng_state() 207 | 208 | # Set the desired seed 209 | torch.manual_seed(seed) 210 | 211 | try: 212 | yield 213 | finally: 214 | # Restore the previous RNG state after exiting the scope 215 | torch.set_rng_state(cpu_rng_state) 216 | if torch.cuda.is_available(): 217 | torch.cuda.set_rng_state(cuda_rng_state) 218 | 219 | def save_images(images : np.array, fp : str): 220 | """ 221 | Saves images to folder designated by fp 222 | """ 223 | 224 | os.makedirs(fp, exist_ok = True) 225 | 226 | images = [Image.fromarray(image) for image in images] 227 | for i, image in enumerate(images): 228 | image.save(os.path.join(fp,f"{i}.png")) 229 | 230 | 231 | -------------------------------------------------------------------------------- /tests/accelerate_checkpoint_test.py: -------------------------------------------------------------------------------- 1 | from drlx.trainer.ddpo_trainer import DDPOTrainer 2 | from drlx.configs import DRLXConfig 3 | import torch 4 | import os 5 | 6 | config = DRLXConfig.load_yaml("configs/ddpo_sd_imagenet.yml") 7 | trainer = DDPOTrainer(config) 8 | 9 | fp = "./checkpoints_saving_test" 10 | trainer.save_pretrained("./output/saving_test") 11 | trainer.save_checkpoint(fp) 12 | 13 | trainer.load_checkpoint(fp) 14 | 15 | from diffusers import StableDiffusionPipeline 16 | 17 | #pipe = trainer.extract_pipeline() 18 | pipe = StableDiffusionPipeline.from_pretrained("./output/saving_test") 19 | print("Successfully loaded pipeline") -------------------------------------------------------------------------------- /tests/ddpo_unet_pipeline_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script tests if pipelines + LDM unet work with DDPO sampler 3 | """ 4 | 5 | from drlx.denoisers.ldm_unet import LDMUNet 6 | from drlx.configs import ModelConfig, SamplerConfig 7 | from drlx.sampling import DDPOSampler 8 | from diffusers import StableDiffusionPipeline 9 | 10 | import torch 11 | import numpy as np 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | # Pipeline first 16 | from drlx.pipeline.prompt_pipeline import PromptPipeline 17 | 18 | class ToyPipeline(PromptPipeline): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | self.dataset = ["A cat", "A dog", "A bird", "A fish"] * 100 23 | 24 | def __getitem__(self, index): 25 | return self.dataset[index] 26 | 27 | def __len__(self): 28 | return len(self.dataset) 29 | 30 | def create_loader(self, batch_size): 31 | return torch.utils.data.DataLoader(self, batch_size) 32 | 33 | 34 | model = LDMUNet(ModelConfig(), sampler = DDPOSampler()) 35 | model.from_pretrained_pipeline(StableDiffusionPipeline, "CompVis/stable-diffusion-v1-4") 36 | model = model.to('cuda') 37 | 38 | pipe = ToyPipeline() 39 | 40 | text = "A cat" 41 | input_ids, attention_mask = model.preprocess([text]) 42 | 43 | loader = pipe.create_loader(8) 44 | 45 | in_shape = model.get_input_shape() 46 | 47 | for prompts in loader: 48 | with torch.no_grad(): 49 | 50 | latents, all_preds, log_probs = model.sampler.sample( 51 | prompts, 52 | model, 53 | device = 'cuda' 54 | ) 55 | 56 | print(latents.shape) 57 | 58 | exit() 59 | -------------------------------------------------------------------------------- /visualization/README.md: 
-------------------------------------------------------------------------------- 1 | # Visualization 2 | 3 | This folder contains scripts useful for visualizing results from samples. The scripts contain variables that you can set to customize them. 4 | 5 | `vid_from_sample.py` generates a video from samples saved during a training run. You can control the FPS, which run to animate, and the index of the sample from that run you want to animate. The result is saved as an mp4 in the same folder as the samples. -------------------------------------------------------------------------------- /visualization/vid_from_sample.py: -------------------------------------------------------------------------------- 1 | import skvideo.io as skv 2 | from PIL import Image, ImageDraw, ImageFont 3 | from tqdm import tqdm 4 | import os 5 | import numpy as np 6 | 7 | RUN_NAME = "new_samples" 8 | SAMPLE_INDEX = 0 # Which sample prompt to animate? 9 | FPS = 15 10 | 11 | def animate(run_name, sample_index, fps, add_timestamps = False): 12 | """ 13 | Takes a specific run and a specific sample prompt, and animates all the images produced by sampling with that prompt during training. 14 | Useful for visualizing how a model's generations for a single prompt changed over time. Note this script may raise permission errors 15 | when trying to save videos with FFMPEG if it is not set up properly. 16 | 17 | :param run_name: Name of run to draw samples from 18 | :param sample_index: Index of sample to animate 19 | :param fps: FPS for resulting video 20 | :param add_timestamps: Add timestamps? Note: requires the font to be installed, otherwise this will raise an error 21 | """ 22 | root_path = f"./samples/{run_name}/" 23 | 24 | paths = os.listdir(root_path) 25 | paths = list(sorted(paths, key=lambda x : int(x))) 26 | paths = [os.path.join(root_path, path, str(sample_index)+".png") for path in paths] 27 | 28 | print("Loading as Images") 29 | imgs = [Image.open(path) for path in tqdm(paths)] 30 | 31 | if add_timestamps: 32 | print("Adding timestamps") 33 | for i in tqdm(range(len(imgs))): 34 | draw = ImageDraw.Draw(imgs[i]) 35 | font = ImageFont.truetype("arial", 20) 36 | draw.text((10, 10), str(i), fill="white", font=font) 37 | 38 | print("Saving as video") 39 | frames = np.stack([np.asarray(img) for img in imgs]) 40 | 41 | skv.vwrite(os.path.join(root_path, f"{sample_index}.mp4"), frames, outputdict={"-r": f"{fps}"}) 42 | 43 | 44 | if __name__ == "__main__": 45 | animate(RUN_NAME, SAMPLE_INDEX, FPS) 46 | 47 | 48 | --------------------------------------------------------------------------------