├── .github
│   └── workflows
│       ├── black.yml
│       ├── test-build.yaml
│       └── test-inference.yml
├── .gitignore
├── CODEOWNERS
├── LICENSE-CODE
├── README.md
├── assets
│   ├── 000.jpg
│   ├── 001_with_eval.png
│   ├── sv3d.gif
│   ├── sv4d.gif
│   ├── sv4d2.gif
│   ├── sv4d_videos
│   │   ├── bear.gif
│   │   ├── bee.gif
│   │   ├── bmx-bumps.gif
│   │   ├── camel.gif
│   │   ├── chameleon.gif
│   │   ├── chest.gif
│   │   ├── cows.gif
│   │   ├── dance-twirl.gif
│   │   ├── flag.gif
│   │   ├── gear.gif
│   │   ├── hike.gif
│   │   ├── horsejump-low.gif
│   │   ├── robot.gif
│   │   ├── snowboard.gif
│   │   ├── test_video1.mp4
│   │   └── windmill.gif
│   ├── test_image.png
│   ├── tile.gif
│   └── turbo_tile.png
├── configs
│   ├── example_training
│   │   ├── autoencoder
│   │   │   └── kl-f4
│   │   │       ├── imagenet-attnfree-logvar.yaml
│   │   │       └── imagenet-kl_f8_8chn.yaml
│   │   ├── imagenet-f8_cond.yaml
│   │   ├── toy
│   │   │   ├── cifar10_cond.yaml
│   │   │   ├── mnist.yaml
│   │   │   ├── mnist_cond.yaml
│   │   │   ├── mnist_cond_discrete_eps.yaml
│   │   │   ├── mnist_cond_l1_loss.yaml
│   │   │   └── mnist_cond_with_ema.yaml
│   │   ├── txt2img-clipl-legacy-ucg-training.yaml
│   │   └── txt2img-clipl.yaml
│   └── inference
│       ├── sd_2_1.yaml
│       ├── sd_2_1_768.yaml
│       ├── sd_xl_base.yaml
│       ├── sd_xl_refiner.yaml
│       ├── sv3d_p.yaml
│       ├── sv3d_u.yaml
│       ├── svd.yaml
│       └── svd_image_decoder.yaml
├── data
│   └── DejaVuSans.ttf
├── main.py
├── model_licenses
│   ├── LICENCE-SD-Turbo
│   ├── LICENSE-SDXL-Turbo
│   ├── LICENSE-SDXL0.9
│   ├── LICENSE-SDXL1.0
│   ├── LICENSE-SV3D
│   └── LICENSE-SVD
├── pyproject.toml
├── pytest.ini
├── requirements
│   └── pt2.txt
├── scripts
│   ├── __init__.py
│   ├── demo
│   │   ├── __init__.py
│   │   ├── detect.py
│   │   ├── discretization.py
│   │   ├── gradio_app.py
│   │   ├── gradio_app_sv4d.py
│   │   ├── sampling.py
│   │   ├── streamlit_helpers.py
│   │   ├── sv3d_helpers.py
│   │   ├── sv4d_helpers.py
│   │   ├── turbo.py
│   │   └── video_sampling.py
│   ├── sampling
│   │   ├── configs
│   │   │   ├── sv3d_p.yaml
│   │   │   ├── sv3d_u.yaml
│   │   │   ├── sv4d.yaml
│   │   │   ├── sv4d2.yaml
│   │   │   ├── sv4d2_8views.yaml
│   │   │   ├── svd.yaml
│   │   │   ├── svd_image_decoder.yaml
│   │   │   ├── svd_xt.yaml
│   │   │   ├── svd_xt_1_1.yaml
│   │   │   └── svd_xt_image_decoder.yaml
│   │   ├── simple_video_sample.py
│   │   ├── simple_video_sample_4d.py
│   │   └── simple_video_sample_4d2.py
│   ├── tests
│   │   └── attention.py
│   └── util
│       ├── __init__.py
│       └── detection
│           ├── __init__.py
│           ├── nsfw_and_watermark_dectection.py
│           ├── p_head_v1.npz
│           └── w_head_v1.npz
├── sgm
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── dataset.py
│   │   └── mnist.py
│   ├── inference
│   │   ├── api.py
│   │   └── helpers.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   └── diffusion.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── autoencoding
│   │   │   ├── __init__.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── discriminator_loss.py
│   │   │   │   └── lpips.py
│   │   │   ├── lpips
│   │   │   │   ├── __init__.py
│   │   │   │   ├── loss
│   │   │   │   │   ├── .gitignore
│   │   │   │   │   ├── LICENSE
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── lpips.py
│   │   │   │   ├── model
│   │   │   │   │   ├── LICENSE
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── model.py
│   │   │   │   ├── util.py
│   │   │   │   └── vqperceptual.py
│   │   │   ├── regularizers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   └── quantize.py
│   │   │   └── temporal_ae.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── denoiser.py
│   │   │   ├── denoiser_scaling.py
│   │   │   ├── denoiser_weighting.py
│   │   │   ├── discretizer.py
│   │   │   ├── guiders.py
│   │   │   ├── loss.py
│   │   │   ├── loss_weighting.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   ├── sampling.py
│   │   │   ├── sampling_utils.py
│   │   │   ├── sigma_sampling.py
│   │   │   ├── util.py
│   │   │   ├── video_model.py
│   │   │   └── wrappers.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── spacetime_attention.py
│   │   └── video_attention.py
└── util.py └── tests └── inference └── test_inference.py /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Run black 2 | on: [pull_request] 3 | 4 | jobs: 5 | lint: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v3 9 | - name: Install venv 10 | run: | 11 | sudo apt-get -y install python3.10-venv 12 | - uses: psf/black@stable 13 | with: 14 | options: "--check --verbose -l88" 15 | src: "./sgm ./scripts ./main.py" 16 | -------------------------------------------------------------------------------- /.github/workflows/test-build.yaml: -------------------------------------------------------------------------------- 1 | name: Build package 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | 8 | jobs: 9 | build: 10 | name: Build 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.8", "3.10"] 16 | requirements-file: ["pt2", "pt13"] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -r requirements/${{ matrix.requirements-file }}.txt 27 | pip install . -------------------------------------------------------------------------------- /.github/workflows/test-inference.yml: -------------------------------------------------------------------------------- 1 | name: Test inference 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | test: 11 | name: "Test inference" 12 | # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment 13 | if: github.repository == 'stability-ai/generative-models' 14 | runs-on: [self-hosted, slurm, g40] 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: "Symlink checkpoints" 18 | run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints 19 | - name: "Setup python" 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" 23 | - name: "Install Hatch" 24 | run: pip install hatch 25 | - name: "Run inference tests" 26 | run: hatch run ci:test-inference --junit-xml test-results.xml 27 | - name: Surface failing tests 28 | if: always() 29 | uses: pmeier/pytest-results-action@main 30 | with: 31 | path: test-results.xml 32 | summary: true 33 | display-options: fEX 34 | fail-on-empty: true 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # extensions 2 | *.egg-info 3 | *.py[cod] 4 | 5 | # envs 6 | .pt13 7 | .pt2 8 | 9 | # directories 10 | /checkpoints 11 | /dist 12 | /outputs 13 | /build 14 | /src 15 | /.vscode 16 | **/__pycache__/ 17 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | .github @Stability-AI/infrastructure -------------------------------------------------------------------------------- /LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /assets/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/000.jpg -------------------------------------------------------------------------------- /assets/001_with_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/001_with_eval.png -------------------------------------------------------------------------------- /assets/sv3d.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv3d.gif -------------------------------------------------------------------------------- /assets/sv4d.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d.gif -------------------------------------------------------------------------------- /assets/sv4d2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d2.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/bear.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/bear.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/bee.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/bee.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/bmx-bumps.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/bmx-bumps.gif -------------------------------------------------------------------------------- 
/assets/sv4d_videos/camel.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/camel.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/chameleon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/chameleon.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/chest.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/chest.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/cows.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/cows.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/dance-twirl.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/dance-twirl.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/flag.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/flag.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/gear.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/gear.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/hike.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/hike.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/horsejump-low.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/horsejump-low.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/robot.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/robot.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/snowboard.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/snowboard.gif -------------------------------------------------------------------------------- /assets/sv4d_videos/test_video1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/test_video1.mp4 -------------------------------------------------------------------------------- /assets/sv4d_videos/windmill.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/sv4d_videos/windmill.gif -------------------------------------------------------------------------------- /assets/test_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/test_image.png -------------------------------------------------------------------------------- /assets/tile.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/tile.gif -------------------------------------------------------------------------------- /assets/turbo_tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/assets/turbo_tile.png -------------------------------------------------------------------------------- /configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: sgm.models.autoencoder.AutoencodingEngine 4 | params: 5 | input_key: jpg 6 | monitor: val/rec_loss 7 | 8 | loss_config: 9 | target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator 10 | params: 11 | perceptual_weight: 0.25 12 | disc_start: 20001 13 | disc_weight: 0.5 14 | learn_logvar: True 15 | 16 | regularization_weights: 17 | kl_loss: 1.0 18 | 19 | regularizer_config: 20 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 21 | 22 | encoder_config: 23 | target: sgm.modules.diffusionmodules.model.Encoder 24 | params: 25 | attn_type: none 26 | double_z: True 27 | z_channels: 4 28 | resolution: 256 29 | in_channels: 3 30 | out_ch: 3 31 | ch: 128 32 | ch_mult: [1, 2, 4] 33 | num_res_blocks: 4 34 | attn_resolutions: [] 35 | dropout: 0.0 36 | 37 | decoder_config: 38 | target: sgm.modules.diffusionmodules.model.Decoder 39 | params: ${model.params.encoder_config.params} 40 | 41 | data: 42 | target: sgm.data.dataset.StableDataModuleFromConfig 43 | params: 44 | train: 45 | datapipeline: 46 | urls: 47 | - DATA-PATH 48 | pipeline_config: 49 | shardshuffle: 10000 50 | sample_shuffle: 10000 51 | 52 | decoders: 53 | - pil 54 | 55 | postprocessors: 56 | - target: sdata.mappers.TorchVisionImageTransforms 57 | params: 58 | key: jpg 59 | transforms: 60 | - target: torchvision.transforms.Resize 61 | params: 62 | size: 256 63 | interpolation: 3 64 | - target: torchvision.transforms.ToTensor 65 | - target: sdata.mappers.Rescaler 66 | - target: 
sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 67 | params: 68 | h_key: height 69 | w_key: width 70 | 71 | loader: 72 | batch_size: 8 73 | num_workers: 4 74 | 75 | 76 | lightning: 77 | strategy: 78 | target: pytorch_lightning.strategies.DDPStrategy 79 | params: 80 | find_unused_parameters: True 81 | 82 | modelcheckpoint: 83 | params: 84 | every_n_train_steps: 5000 85 | 86 | callbacks: 87 | metrics_over_trainsteps_checkpoint: 88 | params: 89 | every_n_train_steps: 50000 90 | 91 | image_logger: 92 | target: main.ImageLogger 93 | params: 94 | enable_autocast: False 95 | batch_frequency: 1000 96 | max_images: 8 97 | increase_log_steps: True 98 | 99 | trainer: 100 | devices: 0, 101 | limit_val_batches: 50 102 | benchmark: True 103 | accumulate_grad_batches: 1 104 | val_check_interval: 10000 -------------------------------------------------------------------------------- /configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: sgm.models.autoencoder.AutoencodingEngine 4 | params: 5 | input_key: jpg 6 | monitor: val/loss/rec 7 | disc_start_iter: 0 8 | 9 | encoder_config: 10 | target: sgm.modules.diffusionmodules.model.Encoder 11 | params: 12 | attn_type: vanilla-xformers 13 | double_z: true 14 | z_channels: 8 15 | resolution: 256 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | 24 | decoder_config: 25 | target: sgm.modules.diffusionmodules.model.Decoder 26 | params: ${model.params.encoder_config.params} 27 | 28 | regularizer_config: 29 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 30 | 31 | loss_config: 32 | target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator 33 | params: 34 | perceptual_weight: 0.25 35 | disc_start: 20001 36 | disc_weight: 0.5 37 | learn_logvar: True 38 | 39 | regularization_weights: 40 | kl_loss: 1.0 41 | 42 | data: 43 | target: sgm.data.dataset.StableDataModuleFromConfig 44 | params: 45 | train: 46 | datapipeline: 47 | urls: 48 | - DATA-PATH 49 | pipeline_config: 50 | shardshuffle: 10000 51 | sample_shuffle: 10000 52 | 53 | decoders: 54 | - pil 55 | 56 | postprocessors: 57 | - target: sdata.mappers.TorchVisionImageTransforms 58 | params: 59 | key: jpg 60 | transforms: 61 | - target: torchvision.transforms.Resize 62 | params: 63 | size: 256 64 | interpolation: 3 65 | - target: torchvision.transforms.ToTensor 66 | - target: sdata.mappers.Rescaler 67 | - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 68 | params: 69 | h_key: height 70 | w_key: width 71 | 72 | loader: 73 | batch_size: 8 74 | num_workers: 4 75 | 76 | 77 | lightning: 78 | strategy: 79 | target: pytorch_lightning.strategies.DDPStrategy 80 | params: 81 | find_unused_parameters: True 82 | 83 | modelcheckpoint: 84 | params: 85 | every_n_train_steps: 5000 86 | 87 | callbacks: 88 | metrics_over_trainsteps_checkpoint: 89 | params: 90 | every_n_train_steps: 50000 91 | 92 | image_logger: 93 | target: main.ImageLogger 94 | params: 95 | enable_autocast: False 96 | batch_frequency: 1000 97 | max_images: 8 98 | increase_log_steps: True 99 | 100 | trainer: 101 | devices: 0, 102 | limit_val_batches: 50 103 | benchmark: True 104 | accumulate_grad_batches: 1 105 | val_check_interval: 10000 106 | -------------------------------------------------------------------------------- 
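The training configs in this directory all follow the same convention: every component block names an importable class under target and passes its constructor arguments under params, and training is typically launched by pointing main.py at one of these files with --base. Below is a minimal sketch of how such a block is resolved; it assumes a helper equivalent to instantiate_from_config in sgm.util (as in related Stable Diffusion codebases), so verify the exact name and signature against the repository before relying on it.

# Sketch of the target/params instantiation pattern used throughout these configs.
# The helper mirrors what sgm.util is assumed to provide; placeholders such as
# DATA-PATH / CKPT_PATH in the YAML must be filled in before a real training run.
import importlib

from omegaconf import OmegaConf


def get_obj_from_str(path: str):
    # "sgm.models.autoencoder.AutoencodingEngine" -> the class object
    module_name, cls_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), cls_name)


def instantiate_from_config(config):
    # `config` is a mapping with a "target" string and an optional "params" mapping.
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


if __name__ == "__main__":
    cfg = OmegaConf.load(
        "configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml"
    )
    # Builds the configured sgm.models.autoencoder.AutoencodingEngine instance.
    model = instantiate_from_config(cfg.model)
    print(type(model))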
/configs/example_training/imagenet-f8_cond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | scale_factor: 0.13025 6 | disable_first_stage_autocast: True 7 | log_keys: 8 | - cls 9 | 10 | scheduler_config: 11 | target: sgm.lr_scheduler.LambdaLinearScheduler 12 | params: 13 | warm_up_steps: [10000] 14 | cycle_lengths: [10000000000000] 15 | f_start: [1.e-6] 16 | f_max: [1.] 17 | f_min: [1.] 18 | 19 | denoiser_config: 20 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 21 | params: 22 | num_idx: 1000 23 | 24 | scaling_config: 25 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 26 | discretization_config: 27 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 28 | 29 | network_config: 30 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | use_checkpoint: True 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 256 36 | attention_resolutions: [1, 2, 4] 37 | num_res_blocks: 2 38 | channel_mult: [1, 2, 4] 39 | num_head_channels: 64 40 | num_classes: sequential 41 | adm_in_channels: 1024 42 | transformer_depth: 1 43 | context_dim: 1024 44 | spatial_transformer_attn_type: softmax-xformers 45 | 46 | conditioner_config: 47 | target: sgm.modules.GeneralConditioner 48 | params: 49 | emb_models: 50 | - is_trainable: True 51 | input_key: cls 52 | ucg_rate: 0.2 53 | target: sgm.modules.encoders.modules.ClassEmbedder 54 | params: 55 | add_sequence_dim: True 56 | embed_dim: 1024 57 | n_classes: 1000 58 | 59 | - is_trainable: False 60 | ucg_rate: 0.2 61 | input_key: original_size_as_tuple 62 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 63 | params: 64 | outdim: 256 65 | 66 | - is_trainable: False 67 | input_key: crop_coords_top_left 68 | ucg_rate: 0.2 69 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 70 | params: 71 | outdim: 256 72 | 73 | first_stage_config: 74 | target: sgm.models.autoencoder.AutoencoderKL 75 | params: 76 | ckpt_path: CKPT_PATH 77 | embed_dim: 4 78 | monitor: val/rec_loss 79 | ddconfig: 80 | attn_type: vanilla-xformers 81 | double_z: true 82 | z_channels: 4 83 | resolution: 256 84 | in_channels: 3 85 | out_ch: 3 86 | ch: 128 87 | ch_mult: [1, 2, 4, 4] 88 | num_res_blocks: 2 89 | attn_resolutions: [] 90 | dropout: 0.0 91 | lossconfig: 92 | target: torch.nn.Identity 93 | 94 | loss_fn_config: 95 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 96 | params: 97 | loss_weighting_config: 98 | target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting 99 | sigma_sampler_config: 100 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 101 | params: 102 | num_idx: 1000 103 | 104 | discretization_config: 105 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 106 | 107 | sampler_config: 108 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 109 | params: 110 | num_steps: 50 111 | 112 | discretization_config: 113 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 114 | 115 | guider_config: 116 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 117 | params: 118 | scale: 5.0 119 | 120 | data: 121 | target: sgm.data.dataset.StableDataModuleFromConfig 122 | params: 123 | train: 124 | datapipeline: 125 | urls: 126 | # USER: adapt this path the root of your custom dataset 127 | - DATA_PATH 128 | pipeline_config: 129 | 
shardshuffle: 10000 130 | sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM 131 | 132 | decoders: 133 | - pil 134 | 135 | postprocessors: 136 | - target: sdata.mappers.TorchVisionImageTransforms 137 | params: 138 | key: jpg # USER: you might wanna adapt this for your custom dataset 139 | transforms: 140 | - target: torchvision.transforms.Resize 141 | params: 142 | size: 256 143 | interpolation: 3 144 | - target: torchvision.transforms.ToTensor 145 | - target: sdata.mappers.Rescaler 146 | 147 | - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 148 | params: 149 | h_key: height # USER: you might wanna adapt this for your custom dataset 150 | w_key: width # USER: you might wanna adapt this for your custom dataset 151 | 152 | loader: 153 | batch_size: 64 154 | num_workers: 6 155 | 156 | lightning: 157 | modelcheckpoint: 158 | params: 159 | every_n_train_steps: 5000 160 | 161 | callbacks: 162 | metrics_over_trainsteps_checkpoint: 163 | params: 164 | every_n_train_steps: 25000 165 | 166 | image_logger: 167 | target: main.ImageLogger 168 | params: 169 | disabled: False 170 | enable_autocast: False 171 | batch_frequency: 1000 172 | max_images: 8 173 | increase_log_steps: True 174 | log_first_step: False 175 | log_images_kwargs: 176 | use_ema_scope: False 177 | N: 8 178 | n_rows: 2 179 | 180 | trainer: 181 | devices: 0, 182 | benchmark: True 183 | num_sanity_val_steps: 0 184 | accumulate_grad_batches: 1 185 | max_epochs: 1000 -------------------------------------------------------------------------------- /configs/example_training/toy/cifar10_cond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 3 17 | out_channels: 3 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | num_classes: sequential 24 | adm_in_channels: 128 25 | 26 | conditioner_config: 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: True 31 | input_key: cls 32 | ucg_rate: 0.2 33 | target: sgm.modules.encoders.modules.ClassEmbedder 34 | params: 35 | embed_dim: 128 36 | n_classes: 10 37 | 38 | first_stage_config: 39 | target: sgm.models.autoencoder.IdentityFirstStage 40 | 41 | loss_fn_config: 42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 43 | params: 44 | loss_weighting_config: 45 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 46 | params: 47 | sigma_data: 1.0 48 | sigma_sampler_config: 49 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 50 | 51 | sampler_config: 52 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 53 | params: 54 | num_steps: 50 55 | 56 | discretization_config: 57 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 58 | 59 | guider_config: 60 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 61 | params: 62 | scale: 3.0 63 | 64 | data: 65 | target: sgm.data.cifar10.CIFAR10Loader 66 | params: 67 | batch_size: 512 68 | num_workers: 1 69 | 70 | lightning: 71 | modelcheckpoint: 
72 | params: 73 | every_n_train_steps: 5000 74 | 75 | callbacks: 76 | metrics_over_trainsteps_checkpoint: 77 | params: 78 | every_n_train_steps: 25000 79 | 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | disabled: False 84 | batch_frequency: 1000 85 | max_images: 64 86 | increase_log_steps: True 87 | log_first_step: False 88 | log_images_kwargs: 89 | use_ema_scope: False 90 | N: 64 91 | n_rows: 8 92 | 93 | trainer: 94 | devices: 0, 95 | benchmark: True 96 | num_sanity_val_steps: 0 97 | accumulate_grad_batches: 1 98 | max_epochs: 20 -------------------------------------------------------------------------------- /configs/example_training/toy/mnist.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 1 17 | out_channels: 1 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | 24 | first_stage_config: 25 | target: sgm.models.autoencoder.IdentityFirstStage 26 | 27 | loss_fn_config: 28 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 29 | params: 30 | loss_weighting_config: 31 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 32 | params: 33 | sigma_data: 1.0 34 | sigma_sampler_config: 35 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 36 | 37 | sampler_config: 38 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 39 | params: 40 | num_steps: 50 41 | 42 | discretization_config: 43 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 44 | 45 | data: 46 | target: sgm.data.mnist.MNISTLoader 47 | params: 48 | batch_size: 512 49 | num_workers: 1 50 | 51 | lightning: 52 | modelcheckpoint: 53 | params: 54 | every_n_train_steps: 5000 55 | 56 | callbacks: 57 | metrics_over_trainsteps_checkpoint: 58 | params: 59 | every_n_train_steps: 25000 60 | 61 | image_logger: 62 | target: main.ImageLogger 63 | params: 64 | disabled: False 65 | batch_frequency: 1000 66 | max_images: 64 67 | increase_log_steps: False 68 | log_first_step: False 69 | log_images_kwargs: 70 | use_ema_scope: False 71 | N: 64 72 | n_rows: 8 73 | 74 | trainer: 75 | devices: 0, 76 | benchmark: True 77 | num_sanity_val_steps: 0 78 | accumulate_grad_batches: 1 79 | max_epochs: 10 -------------------------------------------------------------------------------- /configs/example_training/toy/mnist_cond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 1 17 | out_channels: 1 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | num_classes: sequential 
24 | adm_in_channels: 128 25 | 26 | conditioner_config: 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: True 31 | input_key: cls 32 | ucg_rate: 0.2 33 | target: sgm.modules.encoders.modules.ClassEmbedder 34 | params: 35 | embed_dim: 128 36 | n_classes: 10 37 | 38 | first_stage_config: 39 | target: sgm.models.autoencoder.IdentityFirstStage 40 | 41 | loss_fn_config: 42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 43 | params: 44 | loss_weighting_config: 45 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 46 | params: 47 | sigma_data: 1.0 48 | sigma_sampler_config: 49 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 50 | 51 | sampler_config: 52 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 53 | params: 54 | num_steps: 50 55 | 56 | discretization_config: 57 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 58 | 59 | guider_config: 60 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 61 | params: 62 | scale: 3.0 63 | 64 | data: 65 | target: sgm.data.mnist.MNISTLoader 66 | params: 67 | batch_size: 512 68 | num_workers: 1 69 | 70 | lightning: 71 | modelcheckpoint: 72 | params: 73 | every_n_train_steps: 5000 74 | 75 | callbacks: 76 | metrics_over_trainsteps_checkpoint: 77 | params: 78 | every_n_train_steps: 25000 79 | 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | disabled: False 84 | batch_frequency: 1000 85 | max_images: 16 86 | increase_log_steps: True 87 | log_first_step: False 88 | log_images_kwargs: 89 | use_ema_scope: False 90 | N: 16 91 | n_rows: 4 92 | 93 | trainer: 94 | devices: 0, 95 | benchmark: True 96 | num_sanity_val_steps: 0 97 | accumulate_grad_batches: 1 98 | max_epochs: 20 -------------------------------------------------------------------------------- /configs/example_training/toy/mnist_cond_discrete_eps.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 7 | params: 8 | num_idx: 1000 9 | 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 12 | discretization_config: 13 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 14 | 15 | network_config: 16 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | in_channels: 1 19 | out_channels: 1 20 | model_channels: 32 21 | attention_resolutions: [] 22 | num_res_blocks: 4 23 | channel_mult: [1, 2, 2] 24 | num_head_channels: 32 25 | num_classes: sequential 26 | adm_in_channels: 128 27 | 28 | conditioner_config: 29 | target: sgm.modules.GeneralConditioner 30 | params: 31 | emb_models: 32 | - is_trainable: True 33 | input_key: cls 34 | ucg_rate: 0.2 35 | target: sgm.modules.encoders.modules.ClassEmbedder 36 | params: 37 | embed_dim: 128 38 | n_classes: 10 39 | 40 | first_stage_config: 41 | target: sgm.models.autoencoder.IdentityFirstStage 42 | 43 | loss_fn_config: 44 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 45 | params: 46 | loss_weighting_config: 47 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 48 | sigma_sampler_config: 49 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 50 | params: 51 | num_idx: 1000 52 | 53 | discretization_config: 54 | target: 
sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 55 | 56 | sampler_config: 57 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 58 | params: 59 | num_steps: 50 60 | 61 | discretization_config: 62 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 63 | 64 | guider_config: 65 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 66 | params: 67 | scale: 5.0 68 | 69 | data: 70 | target: sgm.data.mnist.MNISTLoader 71 | params: 72 | batch_size: 512 73 | num_workers: 1 74 | 75 | lightning: 76 | modelcheckpoint: 77 | params: 78 | every_n_train_steps: 5000 79 | 80 | callbacks: 81 | metrics_over_trainsteps_checkpoint: 82 | params: 83 | every_n_train_steps: 25000 84 | 85 | image_logger: 86 | target: main.ImageLogger 87 | params: 88 | disabled: False 89 | batch_frequency: 1000 90 | max_images: 16 91 | increase_log_steps: True 92 | log_first_step: False 93 | log_images_kwargs: 94 | use_ema_scope: False 95 | N: 16 96 | n_rows: 4 97 | 98 | trainer: 99 | devices: 0, 100 | benchmark: True 101 | num_sanity_val_steps: 0 102 | accumulate_grad_batches: 1 103 | max_epochs: 20 -------------------------------------------------------------------------------- /configs/example_training/toy/mnist_cond_l1_loss.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 1 17 | out_channels: 1 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | num_classes: sequential 24 | adm_in_channels: 128 25 | 26 | conditioner_config: 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: True 31 | input_key: cls 32 | ucg_rate: 0.2 33 | target: sgm.modules.encoders.modules.ClassEmbedder 34 | params: 35 | embed_dim: 128 36 | n_classes: 10 37 | 38 | first_stage_config: 39 | target: sgm.models.autoencoder.IdentityFirstStage 40 | 41 | loss_fn_config: 42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 43 | params: 44 | loss_type: l1 45 | loss_weighting_config: 46 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 47 | params: 48 | sigma_data: 1.0 49 | sigma_sampler_config: 50 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 51 | 52 | sampler_config: 53 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 54 | params: 55 | num_steps: 50 56 | 57 | discretization_config: 58 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 59 | 60 | guider_config: 61 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 62 | params: 63 | scale: 3.0 64 | 65 | data: 66 | target: sgm.data.mnist.MNISTLoader 67 | params: 68 | batch_size: 512 69 | num_workers: 1 70 | 71 | lightning: 72 | modelcheckpoint: 73 | params: 74 | every_n_train_steps: 5000 75 | 76 | callbacks: 77 | metrics_over_trainsteps_checkpoint: 78 | params: 79 | every_n_train_steps: 25000 80 | 81 | image_logger: 82 | target: main.ImageLogger 83 | params: 84 | disabled: False 85 | batch_frequency: 1000 86 | max_images: 64 87 | increase_log_steps: True 88 | log_first_step: 
False 89 | log_images_kwargs: 90 | use_ema_scope: False 91 | N: 64 92 | n_rows: 8 93 | 94 | trainer: 95 | devices: 0, 96 | benchmark: True 97 | num_sanity_val_steps: 0 98 | accumulate_grad_batches: 1 99 | max_epochs: 20 -------------------------------------------------------------------------------- /configs/example_training/toy/mnist_cond_with_ema.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | use_ema: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 12 | params: 13 | sigma_data: 1.0 14 | 15 | network_config: 16 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | in_channels: 1 19 | out_channels: 1 20 | model_channels: 32 21 | attention_resolutions: [] 22 | num_res_blocks: 4 23 | channel_mult: [1, 2, 2] 24 | num_head_channels: 32 25 | num_classes: sequential 26 | adm_in_channels: 128 27 | 28 | conditioner_config: 29 | target: sgm.modules.GeneralConditioner 30 | params: 31 | emb_models: 32 | - is_trainable: True 33 | input_key: cls 34 | ucg_rate: 0.2 35 | target: sgm.modules.encoders.modules.ClassEmbedder 36 | params: 37 | embed_dim: 128 38 | n_classes: 10 39 | 40 | first_stage_config: 41 | target: sgm.models.autoencoder.IdentityFirstStage 42 | 43 | loss_fn_config: 44 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 45 | params: 46 | loss_weighting_config: 47 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 48 | params: 49 | sigma_data: 1.0 50 | sigma_sampler_config: 51 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 52 | 53 | sampler_config: 54 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 55 | params: 56 | num_steps: 50 57 | 58 | discretization_config: 59 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 60 | 61 | guider_config: 62 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 63 | params: 64 | scale: 3.0 65 | 66 | data: 67 | target: sgm.data.mnist.MNISTLoader 68 | params: 69 | batch_size: 512 70 | num_workers: 1 71 | 72 | lightning: 73 | modelcheckpoint: 74 | params: 75 | every_n_train_steps: 5000 76 | 77 | callbacks: 78 | metrics_over_trainsteps_checkpoint: 79 | params: 80 | every_n_train_steps: 25000 81 | 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | disabled: False 86 | batch_frequency: 1000 87 | max_images: 64 88 | increase_log_steps: True 89 | log_first_step: False 90 | log_images_kwargs: 91 | use_ema_scope: False 92 | N: 64 93 | n_rows: 8 94 | 95 | trainer: 96 | devices: 0, 97 | benchmark: True 98 | num_sanity_val_steps: 0 99 | accumulate_grad_batches: 1 100 | max_epochs: 20 -------------------------------------------------------------------------------- /configs/example_training/txt2img-clipl-legacy-ucg-training.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | scale_factor: 0.13025 6 | disable_first_stage_autocast: True 7 | log_keys: 8 | - txt 9 | 10 | scheduler_config: 11 | target: sgm.lr_scheduler.LambdaLinearScheduler 12 | params: 13 | warm_up_steps: [10000] 14 | cycle_lengths: [10000000000000] 15 | f_start: [1.e-6] 16 | f_max: [1.] 17 | f_min: [1.] 
18 | 19 | denoiser_config: 20 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 21 | params: 22 | num_idx: 1000 23 | 24 | scaling_config: 25 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 26 | discretization_config: 27 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 28 | 29 | network_config: 30 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | use_checkpoint: True 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [1, 2, 4] 37 | num_res_blocks: 2 38 | channel_mult: [1, 2, 4, 4] 39 | num_head_channels: 64 40 | num_classes: sequential 41 | adm_in_channels: 1792 42 | num_heads: 1 43 | transformer_depth: 1 44 | context_dim: 768 45 | spatial_transformer_attn_type: softmax-xformers 46 | 47 | conditioner_config: 48 | target: sgm.modules.GeneralConditioner 49 | params: 50 | emb_models: 51 | - is_trainable: True 52 | input_key: txt 53 | ucg_rate: 0.1 54 | legacy_ucg_value: "" 55 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 56 | params: 57 | always_return_pooled: True 58 | 59 | - is_trainable: False 60 | ucg_rate: 0.1 61 | input_key: original_size_as_tuple 62 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 63 | params: 64 | outdim: 256 65 | 66 | - is_trainable: False 67 | input_key: crop_coords_top_left 68 | ucg_rate: 0.1 69 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 70 | params: 71 | outdim: 256 72 | 73 | first_stage_config: 74 | target: sgm.models.autoencoder.AutoencoderKL 75 | params: 76 | ckpt_path: CKPT_PATH 77 | embed_dim: 4 78 | monitor: val/rec_loss 79 | ddconfig: 80 | attn_type: vanilla-xformers 81 | double_z: true 82 | z_channels: 4 83 | resolution: 256 84 | in_channels: 3 85 | out_ch: 3 86 | ch: 128 87 | ch_mult: [ 1, 2, 4, 4 ] 88 | num_res_blocks: 2 89 | attn_resolutions: [ ] 90 | dropout: 0.0 91 | lossconfig: 92 | target: torch.nn.Identity 93 | 94 | loss_fn_config: 95 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 96 | params: 97 | loss_weighting_config: 98 | target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting 99 | sigma_sampler_config: 100 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 101 | params: 102 | num_idx: 1000 103 | 104 | discretization_config: 105 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 106 | 107 | sampler_config: 108 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 109 | params: 110 | num_steps: 50 111 | 112 | discretization_config: 113 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 114 | 115 | guider_config: 116 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 117 | params: 118 | scale: 7.5 119 | 120 | data: 121 | target: sgm.data.dataset.StableDataModuleFromConfig 122 | params: 123 | train: 124 | datapipeline: 125 | urls: 126 | # USER: adapt this path the root of your custom dataset 127 | - DATA_PATH 128 | pipeline_config: 129 | shardshuffle: 10000 130 | sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM 131 | 132 | decoders: 133 | - pil 134 | 135 | postprocessors: 136 | - target: sdata.mappers.TorchVisionImageTransforms 137 | params: 138 | key: jpg # USER: you might wanna adapt this for your custom dataset 139 | transforms: 140 | - target: torchvision.transforms.Resize 141 | params: 142 | size: 256 143 | interpolation: 3 144 | - target: torchvision.transforms.ToTensor 145 | - target: sdata.mappers.Rescaler 146 | - 
target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 147 | # USER: you might wanna use non-default parameters due to your custom dataset 148 | 149 | loader: 150 | batch_size: 64 151 | num_workers: 6 152 | 153 | lightning: 154 | modelcheckpoint: 155 | params: 156 | every_n_train_steps: 5000 157 | 158 | callbacks: 159 | metrics_over_trainsteps_checkpoint: 160 | params: 161 | every_n_train_steps: 25000 162 | 163 | image_logger: 164 | target: main.ImageLogger 165 | params: 166 | disabled: False 167 | enable_autocast: False 168 | batch_frequency: 1000 169 | max_images: 8 170 | increase_log_steps: True 171 | log_first_step: False 172 | log_images_kwargs: 173 | use_ema_scope: False 174 | N: 8 175 | n_rows: 2 176 | 177 | trainer: 178 | devices: 0, 179 | benchmark: True 180 | num_sanity_val_steps: 0 181 | accumulate_grad_batches: 1 182 | max_epochs: 1000 -------------------------------------------------------------------------------- /configs/example_training/txt2img-clipl.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | scale_factor: 0.13025 6 | disable_first_stage_autocast: True 7 | log_keys: 8 | - txt 9 | 10 | scheduler_config: 11 | target: sgm.lr_scheduler.LambdaLinearScheduler 12 | params: 13 | warm_up_steps: [10000] 14 | cycle_lengths: [10000000000000] 15 | f_start: [1.e-6] 16 | f_max: [1.] 17 | f_min: [1.] 18 | 19 | denoiser_config: 20 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 21 | params: 22 | num_idx: 1000 23 | 24 | scaling_config: 25 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 26 | discretization_config: 27 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 28 | 29 | network_config: 30 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | use_checkpoint: True 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [1, 2, 4] 37 | num_res_blocks: 2 38 | channel_mult: [1, 2, 4, 4] 39 | num_head_channels: 64 40 | num_classes: sequential 41 | adm_in_channels: 1792 42 | num_heads: 1 43 | transformer_depth: 1 44 | context_dim: 768 45 | spatial_transformer_attn_type: softmax-xformers 46 | 47 | conditioner_config: 48 | target: sgm.modules.GeneralConditioner 49 | params: 50 | emb_models: 51 | - is_trainable: True 52 | input_key: txt 53 | ucg_rate: 0.1 54 | legacy_ucg_value: "" 55 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 56 | params: 57 | always_return_pooled: True 58 | 59 | - is_trainable: False 60 | ucg_rate: 0.1 61 | input_key: original_size_as_tuple 62 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 63 | params: 64 | outdim: 256 65 | 66 | - is_trainable: False 67 | input_key: crop_coords_top_left 68 | ucg_rate: 0.1 69 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 70 | params: 71 | outdim: 256 72 | 73 | first_stage_config: 74 | target: sgm.models.autoencoder.AutoencoderKL 75 | params: 76 | ckpt_path: CKPT_PATH 77 | embed_dim: 4 78 | monitor: val/rec_loss 79 | ddconfig: 80 | attn_type: vanilla-xformers 81 | double_z: true 82 | z_channels: 4 83 | resolution: 256 84 | in_channels: 3 85 | out_ch: 3 86 | ch: 128 87 | ch_mult: [1, 2, 4, 4] 88 | num_res_blocks: 2 89 | attn_resolutions: [] 90 | dropout: 0.0 91 | lossconfig: 92 | target: torch.nn.Identity 93 | 94 | loss_fn_config: 95 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 96 | params: 97 | 
loss_weighting_config: 98 | target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting 99 | sigma_sampler_config: 100 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 101 | params: 102 | num_idx: 1000 103 | 104 | discretization_config: 105 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 106 | 107 | sampler_config: 108 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 109 | params: 110 | num_steps: 50 111 | 112 | discretization_config: 113 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 114 | 115 | guider_config: 116 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 117 | params: 118 | scale: 7.5 119 | 120 | data: 121 | target: sgm.data.dataset.StableDataModuleFromConfig 122 | params: 123 | train: 124 | datapipeline: 125 | urls: 126 | # USER: adapt this path the root of your custom dataset 127 | - DATA_PATH 128 | pipeline_config: 129 | shardshuffle: 10000 130 | sample_shuffle: 10000 131 | 132 | 133 | decoders: 134 | - pil 135 | 136 | postprocessors: 137 | - target: sdata.mappers.TorchVisionImageTransforms 138 | params: 139 | key: jpg # USER: you might wanna adapt this for your custom dataset 140 | transforms: 141 | - target: torchvision.transforms.Resize 142 | params: 143 | size: 256 144 | interpolation: 3 145 | - target: torchvision.transforms.ToTensor 146 | - target: sdata.mappers.Rescaler 147 | # USER: you might wanna use non-default parameters due to your custom dataset 148 | - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 149 | # USER: you might wanna use non-default parameters due to your custom dataset 150 | 151 | loader: 152 | batch_size: 64 153 | num_workers: 6 154 | 155 | lightning: 156 | modelcheckpoint: 157 | params: 158 | every_n_train_steps: 5000 159 | 160 | callbacks: 161 | metrics_over_trainsteps_checkpoint: 162 | params: 163 | every_n_train_steps: 25000 164 | 165 | image_logger: 166 | target: main.ImageLogger 167 | params: 168 | disabled: False 169 | enable_autocast: False 170 | batch_frequency: 1000 171 | max_images: 8 172 | increase_log_steps: True 173 | log_first_step: False 174 | log_images_kwargs: 175 | use_ema_scope: False 176 | N: 8 177 | n_rows: 2 178 | 179 | trainer: 180 | devices: 0, 181 | benchmark: True 182 | num_sanity_val_steps: 0 183 | accumulate_grad_batches: 1 184 | max_epochs: 1000 -------------------------------------------------------------------------------- /configs/inference/sd_2_1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | use_checkpoint: True 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 320 24 | attention_resolutions: [4, 2, 1] 25 | num_res_blocks: 2 26 | channel_mult: [1, 2, 4, 4] 27 | num_head_channels: 64 28 | use_linear_in_transformer: True 29 | transformer_depth: 1 30 | context_dim: 1024 31 | 32 | conditioner_config: 33 | target: sgm.modules.GeneralConditioner 34 | params: 35 | emb_models: 36 | - 
is_trainable: False 37 | input_key: txt 38 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder 39 | params: 40 | freeze: true 41 | layer: penultimate 42 | 43 | first_stage_config: 44 | target: sgm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | double_z: true 50 | z_channels: 4 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: [1, 2, 4, 4] 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | lossconfig: 60 | target: torch.nn.Identity -------------------------------------------------------------------------------- /configs/inference/sd_2_1_768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | use_checkpoint: True 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 320 24 | attention_resolutions: [4, 2, 1] 25 | num_res_blocks: 2 26 | channel_mult: [1, 2, 4, 4] 27 | num_head_channels: 64 28 | use_linear_in_transformer: True 29 | transformer_depth: 1 30 | context_dim: 1024 31 | 32 | conditioner_config: 33 | target: sgm.modules.GeneralConditioner 34 | params: 35 | emb_models: 36 | - is_trainable: False 37 | input_key: txt 38 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder 39 | params: 40 | freeze: true 41 | layer: penultimate 42 | 43 | first_stage_config: 44 | target: sgm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | double_z: true 50 | z_channels: 4 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: [1, 2, 4, 4] 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | lossconfig: 60 | target: torch.nn.Identity -------------------------------------------------------------------------------- /configs/inference/sd_xl_base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.13025 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | adm_in_channels: 2816 21 | num_classes: sequential 22 | use_checkpoint: True 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 320 26 | attention_resolutions: [4, 2] 27 | num_res_blocks: 2 28 | channel_mult: [1, 2, 4] 29 | num_head_channels: 64 30 | use_linear_in_transformer: True 31 | transformer_depth: [1, 2, 10] 32 | context_dim: 2048 33 | spatial_transformer_attn_type: softmax-xformers 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | 
params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: txt 41 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 42 | params: 43 | layer: hidden 44 | layer_idx: 11 45 | 46 | - is_trainable: False 47 | input_key: txt 48 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 49 | params: 50 | arch: ViT-bigG-14 51 | version: laion2b_s39b_b160k 52 | freeze: True 53 | layer: penultimate 54 | always_return_pooled: True 55 | legacy: False 56 | 57 | - is_trainable: False 58 | input_key: original_size_as_tuple 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - is_trainable: False 64 | input_key: crop_coords_top_left 65 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 66 | params: 67 | outdim: 256 68 | 69 | - is_trainable: False 70 | input_key: target_size_as_tuple 71 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 72 | params: 73 | outdim: 256 74 | 75 | first_stage_config: 76 | target: sgm.models.autoencoder.AutoencoderKL 77 | params: 78 | embed_dim: 4 79 | monitor: val/rec_loss 80 | ddconfig: 81 | attn_type: vanilla-xformers 82 | double_z: true 83 | z_channels: 4 84 | resolution: 256 85 | in_channels: 3 86 | out_ch: 3 87 | ch: 128 88 | ch_mult: [1, 2, 4, 4] 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | -------------------------------------------------------------------------------- /configs/inference/sd_xl_refiner.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.13025 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | adm_in_channels: 2560 21 | num_classes: sequential 22 | use_checkpoint: True 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 384 26 | attention_resolutions: [4, 2] 27 | num_res_blocks: 2 28 | channel_mult: [1, 2, 4, 4] 29 | num_head_channels: 64 30 | use_linear_in_transformer: True 31 | transformer_depth: 4 32 | context_dim: [1280, 1280, 1280, 1280] 33 | spatial_transformer_attn_type: softmax-xformers 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: txt 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 42 | params: 43 | arch: ViT-bigG-14 44 | version: laion2b_s39b_b160k 45 | legacy: False 46 | freeze: True 47 | layer: penultimate 48 | always_return_pooled: True 49 | 50 | - is_trainable: False 51 | input_key: original_size_as_tuple 52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 53 | params: 54 | outdim: 256 55 | 56 | - is_trainable: False 57 | input_key: crop_coords_top_left 58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 59 | params: 60 | outdim: 256 61 | 62 | - is_trainable: False 63 | input_key: aesthetic_score 64 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 65 | params: 66 | outdim: 256 67 | 68 | first_stage_config: 69 | target: sgm.models.autoencoder.AutoencoderKL 70 | 
params: 71 | embed_dim: 4 72 | monitor: val/rec_loss 73 | ddconfig: 74 | attn_type: vanilla-xformers 75 | double_z: true 76 | z_channels: 4 77 | resolution: 256 78 | in_channels: 3 79 | out_ch: 3 80 | ch: 128 81 | ch_mult: [1, 2, 4, 4] 82 | num_res_blocks: 2 83 | attn_resolutions: [] 84 | dropout: 0.0 85 | lossconfig: 86 | target: torch.nn.Identity 87 | -------------------------------------------------------------------------------- /configs/inference/sv3d_p.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 1280 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - input_key: cond_frames_without_noise 40 | is_trainable: False 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: cond_frames 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 53 | params: 54 | disable_encoder_autocast: True 55 | n_cond_frames: 1 56 | n_copies: 1 57 | is_ae: True 58 | encoder_config: 59 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | attn_type: vanilla-xformers 65 | double_z: True 66 | z_channels: 4 67 | resolution: 256 68 | in_channels: 3 69 | out_ch: 3 70 | ch: 128 71 | ch_mult: [1, 2, 4, 4] 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | - input_key: cond_aug 79 | is_trainable: False 80 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 81 | params: 82 | outdim: 256 83 | 84 | - input_key: polars_rad 85 | is_trainable: False 86 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 87 | params: 88 | outdim: 512 89 | 90 | - input_key: azimuths_rad 91 | is_trainable: False 92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 93 | params: 94 | outdim: 512 95 | 96 | first_stage_config: 97 | target: sgm.models.autoencoder.AutoencodingEngine 98 | params: 99 | loss_config: 100 | target: torch.nn.Identity 101 | regularizer_config: 102 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 103 | encoder_config: 104 | target: torch.nn.Identity 105 | decoder_config: 106 | target: sgm.modules.diffusionmodules.model.Decoder 107 | params: 108 | attn_type: vanilla-xformers 109 
| double_z: True 110 | z_channels: 4 111 | resolution: 256 112 | in_channels: 3 113 | out_ch: 3 114 | ch: 128 115 | ch_mult: [ 1, 2, 4, 4 ] 116 | num_res_blocks: 2 117 | attn_resolutions: [ ] 118 | dropout: 0.0 -------------------------------------------------------------------------------- /configs/inference/sv3d_u.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 256 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - input_key: cond_frames_without_noise 40 | is_trainable: False 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: cond_frames 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 53 | params: 54 | disable_encoder_autocast: True 55 | n_cond_frames: 1 56 | n_copies: 1 57 | is_ae: True 58 | encoder_config: 59 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | attn_type: vanilla-xformers 65 | double_z: True 66 | z_channels: 4 67 | resolution: 256 68 | in_channels: 3 69 | out_ch: 3 70 | ch: 128 71 | ch_mult: [1, 2, 4, 4] 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | - input_key: cond_aug 79 | is_trainable: False 80 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 81 | params: 82 | outdim: 256 83 | 84 | first_stage_config: 85 | target: sgm.models.autoencoder.AutoencodingEngine 86 | params: 87 | loss_config: 88 | target: torch.nn.Identity 89 | regularizer_config: 90 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 91 | encoder_config: 92 | target: torch.nn.Identity 93 | decoder_config: 94 | target: sgm.modules.diffusionmodules.model.Decoder 95 | params: 96 | attn_type: vanilla-xformers 97 | double_z: True 98 | z_channels: 4 99 | resolution: 256 100 | in_channels: 3 101 | out_ch: 3 102 | ch: 128 103 | ch_mult: [ 1, 2, 4, 4 ] 104 | num_res_blocks: 2 105 | attn_resolutions: [ ] 106 | dropout: 0.0 -------------------------------------------------------------------------------- /configs/inference/svd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | 
params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 768 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: cond_frames_without_noise 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: fps_id 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 53 | params: 54 | outdim: 256 55 | 56 | - input_key: motion_bucket_id 57 | is_trainable: False 58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 59 | params: 60 | outdim: 256 61 | 62 | - input_key: cond_frames 63 | is_trainable: False 64 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 65 | params: 66 | disable_encoder_autocast: True 67 | n_cond_frames: 1 68 | n_copies: 1 69 | is_ae: True 70 | encoder_config: 71 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 72 | params: 73 | embed_dim: 4 74 | monitor: val/rec_loss 75 | ddconfig: 76 | attn_type: vanilla-xformers 77 | double_z: True 78 | z_channels: 4 79 | resolution: 256 80 | in_channels: 3 81 | out_ch: 3 82 | ch: 128 83 | ch_mult: [1, 2, 4, 4] 84 | num_res_blocks: 2 85 | attn_resolutions: [] 86 | dropout: 0.0 87 | lossconfig: 88 | target: torch.nn.Identity 89 | 90 | - input_key: cond_aug 91 | is_trainable: False 92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 93 | params: 94 | outdim: 256 95 | 96 | first_stage_config: 97 | target: sgm.models.autoencoder.AutoencodingEngine 98 | params: 99 | loss_config: 100 | target: torch.nn.Identity 101 | regularizer_config: 102 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 103 | encoder_config: 104 | target: sgm.modules.diffusionmodules.model.Encoder 105 | params: 106 | attn_type: vanilla 107 | double_z: True 108 | z_channels: 4 109 | resolution: 256 110 | in_channels: 3 111 | out_ch: 3 112 | ch: 128 113 | ch_mult: [1, 2, 4, 4] 114 | num_res_blocks: 2 115 | attn_resolutions: [] 116 | dropout: 0.0 117 | decoder_config: 118 | target: sgm.modules.autoencoding.temporal_ae.VideoDecoder 119 | params: 120 | attn_type: vanilla 121 | double_z: True 122 | z_channels: 4 123 | resolution: 256 124 | in_channels: 3 125 | out_ch: 3 126 | ch: 128 127 | ch_mult: [1, 2, 4, 4] 128 | num_res_blocks: 2 129 | attn_resolutions: [] 130 | dropout: 0.0 131 | video_kernel_size: [3, 1, 1] -------------------------------------------------------------------------------- 
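The inference configs above are plain YAML consumed through OmegaConf: the `model` block names a class to build (`target`) and its constructor arguments (`params`), recursively for every sub-module. A minimal sketch of how such a file is typically turned into a model object, assuming the `instantiate_from_config` helper from sgm/util.py; checkpoint loading, device placement, and precision handling are omitted, and the snippet is illustrative rather than code taken from this repository:

from omegaconf import OmegaConf

from sgm.util import instantiate_from_config

# Build the DiffusionEngine described by the `model` block of the config.
config = OmegaConf.load("configs/inference/svd.yaml")
model = instantiate_from_config(config.model)  # -> sgm.models.diffusion.DiffusionEngine
model.eval()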
/configs/inference/svd_image_decoder.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 768 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: cond_frames_without_noise 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: fps_id 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 53 | params: 54 | outdim: 256 55 | 56 | - input_key: motion_bucket_id 57 | is_trainable: False 58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 59 | params: 60 | outdim: 256 61 | 62 | - input_key: cond_frames 63 | is_trainable: False 64 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 65 | params: 66 | disable_encoder_autocast: True 67 | n_cond_frames: 1 68 | n_copies: 1 69 | is_ae: True 70 | encoder_config: 71 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 72 | params: 73 | embed_dim: 4 74 | monitor: val/rec_loss 75 | ddconfig: 76 | attn_type: vanilla-xformers 77 | double_z: True 78 | z_channels: 4 79 | resolution: 256 80 | in_channels: 3 81 | out_ch: 3 82 | ch: 128 83 | ch_mult: [1, 2, 4, 4] 84 | num_res_blocks: 2 85 | attn_resolutions: [] 86 | dropout: 0.0 87 | lossconfig: 88 | target: torch.nn.Identity 89 | 90 | - input_key: cond_aug 91 | is_trainable: False 92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 93 | params: 94 | outdim: 256 95 | 96 | first_stage_config: 97 | target: sgm.models.autoencoder.AutoencoderKL 98 | params: 99 | embed_dim: 4 100 | monitor: val/rec_loss 101 | ddconfig: 102 | attn_type: vanilla-xformers 103 | double_z: True 104 | z_channels: 4 105 | resolution: 256 106 | in_channels: 3 107 | out_ch: 3 108 | ch: 128 109 | ch_mult: [1, 2, 4, 4] 110 | num_res_blocks: 2 111 | attn_resolutions: [] 112 | dropout: 0.0 113 | lossconfig: 114 | target: torch.nn.Identity -------------------------------------------------------------------------------- /data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/data/DejaVuSans.ttf -------------------------------------------------------------------------------- 
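Each of these inference configs also carries a `scale_factor` (0.18215 for the SD 2.x/SVD latent space, 0.13025 for the SDXL refiner) relating denoiser latents to the first-stage autoencoder. As a short, hypothetical sketch of the usual convention (not code from this repository), decoding divides by that factor before calling the frozen autoencoder configured under `first_stage_config`:

import torch

def decode_latents(first_stage_model, z: torch.Tensor, scale_factor: float = 0.18215) -> torch.Tensor:
    # Latents produced by the denoiser are stored scaled by `scale_factor`;
    # undo that scaling before running the frozen first-stage decoder.
    z = z / scale_factor
    with torch.no_grad():
        return first_stage_model.decode(z)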
/model_licenses/LICENSE-SV3D: -------------------------------------------------------------------------------- 1 | STABILITY AI NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT 2 | Dated: March 18, 2024 3 | 4 | "Agreement" means this Stable Non-Commercial Research Community License Agreement. 5 | 6 | “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. 7 | 8 | "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws, (b) any modifications to a Model, and (c) any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. 9 | 10 | “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. 11 | 12 | "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. 13 | 14 | “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. 15 | 16 | “Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works. 17 | 18 | "Stability AI" or "we" means Stability AI Ltd and its affiliates. 19 | 20 | 21 | "Software" means Stability AI’s proprietary software made available under this Agreement. 22 | 23 | “Software Products” means the Models, Software and Documentation, individually or in any combination. 24 | 25 | 26 | 27 | 1. License Rights and Redistribution. 28 | a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only. 29 | b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact. 30 | c. 
If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 31 | 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. 32 | 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 33 | 4. Intellectual Property. 34 | a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works. 35 | b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works 36 | c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement. 37 | 5. Term and Termination. 
The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement. 38 | 39 | 6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law 40 | principles. 41 | 42 | -------------------------------------------------------------------------------- /model_licenses/LICENSE-SVD: -------------------------------------------------------------------------------- 1 | STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT 2 | Dated: November 21, 2023 3 | 4 | “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. 5 | 6 | "Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein. 7 | "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. 8 | “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. 9 | 10 | "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. 11 | 12 | "Stability AI" or "we" means Stability AI Ltd. 13 | 14 | "Software" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. 15 | 16 | “Software Products” means Software and Documentation. 17 | 18 | By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement. 19 | 20 | 21 | 22 | License Rights and Redistribution. 23 | Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use. 24 | b. 
If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 25 | 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. 26 | 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 27 | 3. Intellectual Property. 28 | a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products. 29 | Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. 30 | If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement. 31 | 4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. 
Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement. 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "sgm" 7 | dynamic = ["version"] 8 | description = "Stability Generative Models" 9 | readme = "README.md" 10 | license-files = { paths = ["LICENSE-CODE"] } 11 | requires-python = ">=3.8" 12 | 13 | [project.urls] 14 | Homepage = "https://github.com/Stability-AI/generative-models" 15 | 16 | [tool.hatch.version] 17 | path = "sgm/__init__.py" 18 | 19 | [tool.hatch.build] 20 | # This needs to be explicitly set so the configuration files 21 | # grafted into the `sgm` directory get included in the wheel's 22 | # RECORD file. 23 | include = [ 24 | "sgm", 25 | ] 26 | # The force-include configurations below make Hatch copy 27 | # the configs/ directory (containing the various YAML files required 28 | # to generatively model) into the source distribution and the wheel. 29 | 30 | [tool.hatch.build.targets.sdist.force-include] 31 | "./configs" = "sgm/configs" 32 | 33 | [tool.hatch.build.targets.wheel.force-include] 34 | "./configs" = "sgm/configs" 35 | 36 | [tool.hatch.envs.ci] 37 | skip-install = false 38 | 39 | dependencies = [ 40 | "pytest" 41 | ] 42 | 43 | [tool.hatch.envs.ci.scripts] 44 | test-inference = [ 45 | "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", 46 | "pip install -r requirements/pt2.txt", 47 | "pytest -v tests/inference/test_inference.py {args}", 48 | ] 49 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | inference: mark as inference test (deselect with '-m "not inference"') -------------------------------------------------------------------------------- /requirements/pt2.txt: -------------------------------------------------------------------------------- 1 | black==23.7.0 2 | chardet==5.1.0 3 | clip @ git+https://github.com/openai/CLIP.git 4 | einops>=0.6.1 5 | fairscale>=0.4.13 6 | fire>=0.5.0 7 | fsspec>=2023.6.0 8 | imageio[ffmpeg] 9 | imageio[pyav] 10 | invisible-watermark>=0.2.0 11 | kornia==0.6.9 12 | matplotlib>=3.7.2 13 | natsort>=8.4.0 14 | ninja>=1.11.1 15 | numpy>=1.24.4 16 | omegaconf>=2.3.0 17 | onnxruntime 18 | open-clip-torch>=2.20.0 19 | opencv-python==4.6.0.66 20 | pandas>=2.0.3 21 | pillow>=9.5.0 22 | pudb>=2022.1.3 23 | pytorch-lightning==2.0.1 24 | pyyaml>=6.0.1 25 | rembg 26 | scipy>=1.10.1 27 | streamlit>=0.73.1 28 | tensorboardx==2.6 29 | timm>=0.9.2 30 | tokenizers==0.12.1 31 | torch>=2.0.1 32 | torchaudio>=2.0.2 33 | torchdata==0.6.1 34 | torchmetrics>=1.0.1 35 | torchvision>=0.15.2 36 | tqdm>=4.65.0 37 | transformers==4.19.1 38 | triton==2.0.0 39 | urllib3<1.27,>=1.25.4 40 | wandb>=0.15.6 41 | webdataset>=0.2.33 42 | wheel>=0.41.0 43 | xformers>=0.0.20 44 | gradio 45 | streamlit-keyup==0.2.0 46 | -------------------------------------------------------------------------------- /scripts/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/scripts/demo/__init__.py -------------------------------------------------------------------------------- /scripts/demo/detect.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | try: 7 | from imwatermark import WatermarkDecoder 8 | except ImportError as e: 9 | try: 10 | # Assume some of the other dependencies such as torch are not fulfilled 11 | # import file without loading unnecessary libraries. 12 | import importlib.util 13 | import sys 14 | 15 | spec = importlib.util.find_spec("imwatermark.maxDct") 16 | assert spec is not None 17 | maxDct = importlib.util.module_from_spec(spec) 18 | sys.modules["maxDct"] = maxDct 19 | spec.loader.exec_module(maxDct) 20 | 21 | class WatermarkDecoder(object): 22 | """A minimal version of 23 | https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py 24 | to only reconstruct bits using dwtDct""" 25 | 26 | def __init__(self, wm_type="bytes", length=0): 27 | assert wm_type == "bits", "Only bits defined in minimal import" 28 | self._wmType = wm_type 29 | self._wmLen = length 30 | 31 | def reconstruct(self, bits): 32 | if len(bits) != self._wmLen: 33 | raise RuntimeError("bits are not matched with watermark length") 34 | 35 | return bits 36 | 37 | def decode(self, cv2Image, method="dwtDct", **configs): 38 | (r, c, channels) = cv2Image.shape 39 | if r * c < 256 * 256: 40 | raise RuntimeError("image too small, should be larger than 256x256") 41 | 42 | bits = [] 43 | assert method == "dwtDct" 44 | embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs) 45 | bits = embed.decode(cv2Image) 46 | return self.reconstruct(bits) 47 | 48 | except: 49 | raise e 50 | 51 | 52 | # A fixed 48-bit message that was choosen at random 53 | # WATERMARK_MESSAGE = 0xB3EC907BB19E 54 | WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 55 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 56 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] 57 | MATCH_VALUES = [ 58 | [27, "No watermark detected"], 59 | [33, "Partial watermark match. Cannot determine with certainty."], 60 | [ 61 | 35, 62 | ( 63 | "Likely watermarked. In our test 0.02% of real images were " 64 | 'falsely detected as "Likely watermarked"' 65 | ), 66 | ], 67 | [ 68 | 49, 69 | ( 70 | "Very likely watermarked. In our test no real images were " 71 | 'falsely detected as "Very likely watermarked"' 72 | ), 73 | ], 74 | ] 75 | 76 | 77 | class GetWatermarkMatch: 78 | def __init__(self, watermark): 79 | self.watermark = watermark 80 | self.num_bits = len(self.watermark) 81 | self.decoder = WatermarkDecoder("bits", self.num_bits) 82 | 83 | def __call__(self, x: np.ndarray) -> np.ndarray: 84 | """ 85 | Detects the number of matching bits the predefined watermark with one 86 | or multiple images. Images should be in cv2 format, e.g. h x w x c BGR. 
87 | 88 | Args: 89 | x: ([B], h w, c) in range [0, 255] 90 | 91 | Returns: 92 | number of matched bits ([B],) 93 | """ 94 | squeeze = len(x.shape) == 3 95 | if squeeze: 96 | x = x[None, ...] 97 | 98 | bs = x.shape[0] 99 | detected = np.empty((bs, self.num_bits), dtype=bool) 100 | for k in range(bs): 101 | detected[k] = self.decoder.decode(x[k], "dwtDct") 102 | result = np.sum(detected == self.watermark, axis=-1) 103 | if squeeze: 104 | return result[0] 105 | else: 106 | return result 107 | 108 | 109 | get_watermark_match = GetWatermarkMatch(WATERMARK_BITS) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument( 115 | "filename", 116 | nargs="+", 117 | type=str, 118 | help="Image files to check for watermarks", 119 | ) 120 | opts = parser.parse_args() 121 | 122 | print( 123 | """ 124 | This script tries to detect watermarked images. Please be aware of 125 | the following: 126 | - As the watermark is supposed to be invisible, there is the risk that 127 | watermarked images may not be detected. 128 | - To maximize the chance of detection make sure that the image has the same 129 | dimensions as when the watermark was applied (most likely 1024x1024 130 | or 512x512). 131 | - Specific image manipulation may drastically decrease the chance that 132 | watermarks can be detected. 133 | - There is also the chance that an image has the characteristics of the 134 | watermark by chance. 135 | - The watermark script is public, anybody may watermark any images, and 136 | could therefore claim it to be generated. 137 | - All numbers below are based on a test using 10,000 images without any 138 | modifications after applying the watermark. 139 | """ 140 | ) 141 | 142 | for fn in opts.filename: 143 | image = cv2.imread(fn) 144 | if image is None: 145 | print(f"Couldn't read {fn}. Skipping") 146 | continue 147 | 148 | num_bits = get_watermark_match(image) 149 | k = 0 150 | while num_bits > MATCH_VALUES[k][0]: 151 | k += 1 152 | print( 153 | f"{fn}: {MATCH_VALUES[k][1]}", 154 | f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n", 155 | sep="\n\t", 156 | ) 157 | -------------------------------------------------------------------------------- /scripts/demo/discretization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from sgm.modules.diffusionmodules.discretizer import Discretization 4 | 5 | 6 | class Img2ImgDiscretizationWrapper: 7 | """ 8 | wraps a discretizer, and prunes the sigmas 9 | params: 10 | strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) 11 | """ 12 | 13 | def __init__(self, discretization: Discretization, strength: float = 1.0): 14 | self.discretization = discretization 15 | self.strength = strength 16 | assert 0.0 <= self.strength <= 1.0 17 | 18 | def __call__(self, *args, **kwargs): 19 | # sigmas start large first, and decrease then 20 | sigmas = self.discretization(*args, **kwargs) 21 | print(f"sigmas after discretization, before pruning img2img: ", sigmas) 22 | sigmas = torch.flip(sigmas, (0,)) 23 | sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] 24 | print("prune index:", max(int(self.strength * len(sigmas)), 1)) 25 | sigmas = torch.flip(sigmas, (0,)) 26 | print(f"sigmas after pruning: ", sigmas) 27 | return sigmas 28 | 29 | 30 | class Txt2NoisyDiscretizationWrapper: 31 | """ 32 | wraps a discretizer, and prunes the sigmas 33 | params: 34 | strength: float between 0.0 and 1.0. 
0.0 means full sampling (all sigmas are returned) 35 | """ 36 | 37 | def __init__( 38 | self, discretization: Discretization, strength: float = 0.0, original_steps=None 39 | ): 40 | self.discretization = discretization 41 | self.strength = strength 42 | self.original_steps = original_steps 43 | assert 0.0 <= self.strength <= 1.0 44 | 45 | def __call__(self, *args, **kwargs): 46 | # sigmas start large first, and decrease then 47 | sigmas = self.discretization(*args, **kwargs) 48 | print(f"sigmas after discretization, before pruning img2img: ", sigmas) 49 | sigmas = torch.flip(sigmas, (0,)) 50 | if self.original_steps is None: 51 | steps = len(sigmas) 52 | else: 53 | steps = self.original_steps + 1 54 | prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0) 55 | sigmas = sigmas[prune_index:] 56 | print("prune index:", prune_index) 57 | sigmas = torch.flip(sigmas, (0,)) 58 | print(f"sigmas after pruning: ", sigmas) 59 | return sigmas 60 | -------------------------------------------------------------------------------- /scripts/demo/sv3d_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def generate_dynamic_cycle_xy_values( 8 | length=21, 9 | init_elev=0, 10 | num_components=84, 11 | frequency_range=(1, 5), 12 | amplitude_range=(0.5, 10), 13 | step_range=(0, 2), 14 | ): 15 | # Y values generation 16 | y_sequence = np.ones(length) * init_elev 17 | for _ in range(num_components): 18 | # Choose a frequency that will complete whole cycles in the sequence 19 | frequency = np.random.randint(*frequency_range) * (2 * np.pi / length) 20 | amplitude = np.random.uniform(*amplitude_range) 21 | phase_shift = np.random.choice([0, np.pi]) # np.random.uniform(0, 2 * np.pi) 22 | angles = ( 23 | np.linspace(0, frequency * length, length, endpoint=False) + phase_shift 24 | ) 25 | y_sequence += np.sin(angles) * amplitude 26 | # X values generation 27 | # Generate length - 1 steps since the last step is back to start 28 | steps = np.random.uniform(*step_range, length - 1) 29 | total_step_sum = np.sum(steps) 30 | # Calculate the scale factor to scale total steps to just under 360 31 | scale_factor = ( 32 | 360 - ((360 / length) * np.random.uniform(*step_range)) 33 | ) / total_step_sum 34 | # Apply the scale factor and generate the sequence of X values 35 | x_values = np.cumsum(steps * scale_factor) 36 | # Ensure the sequence starts at 0 and add the final step to complete the loop 37 | x_values = np.insert(x_values, 0, 0) 38 | return x_values, y_sequence 39 | 40 | 41 | def smooth_data(data, window_size): 42 | # Extend data at both ends by wrapping around to create a continuous loop 43 | pad_size = window_size 44 | padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size])) 45 | 46 | # Apply smoothing 47 | kernel = np.ones(window_size) / window_size 48 | smoothed_data = np.convolve(padded_data, kernel, mode="same") 49 | 50 | # Extract the smoothed data corresponding to the original sequence 51 | # Adjust the indices to account for the larger padding 52 | start_index = pad_size 53 | end_index = -pad_size if pad_size != 0 else None 54 | smoothed_original_data = smoothed_data[start_index:end_index] 55 | return smoothed_original_data 56 | 57 | 58 | # Function to generate and process the data 59 | def gen_dynamic_loop(length=21, elev_deg=0): 60 | while True: 61 | # Generate the combined X and Y values using the new function 62 | azim_values, elev_values = 
generate_dynamic_cycle_xy_values( 63 | length=84, init_elev=elev_deg 64 | ) 65 | # Smooth the Y values directly 66 | smoothed_elev_values = smooth_data(elev_values, 5) 67 | max_magnitude = np.max(np.abs(smoothed_elev_values)) 68 | if max_magnitude < 90: 69 | break 70 | subsample = 84 // length 71 | azim_rad = np.deg2rad(azim_values[::subsample]) 72 | elev_rad = np.deg2rad(smoothed_elev_values[::subsample]) 73 | # Make cond frame the last one 74 | return np.roll(azim_rad, -1), np.roll(elev_rad, -1) 75 | 76 | 77 | def plot_3D(azim, polar, save_path, dynamic=True): 78 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 79 | elev = np.deg2rad(90) - polar 80 | fig = plt.figure(figsize=(5, 5)) 81 | ax = fig.add_subplot(projection="3d") 82 | cm = plt.get_cmap("Greys") 83 | col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)] 84 | cm = plt.get_cmap("cool") 85 | col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))] 86 | xs = np.cos(elev) * np.cos(azim) 87 | ys = np.cos(elev) * np.sin(azim) 88 | zs = np.sin(elev) 89 | ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0]) 90 | xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1]) 91 | for i in range(len(xs) - 1): 92 | if dynamic: 93 | ax.quiver( 94 | xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i] 95 | ) 96 | else: 97 | ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i]) 98 | ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1]) 99 | ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k") 100 | ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k") 101 | ax.view_init(elev=30, azim=-20, roll=0) 102 | plt.savefig(save_path, bbox_inches="tight") 103 | plt.clf() 104 | plt.close() 105 | -------------------------------------------------------------------------------- /scripts/sampling/configs/sv3d_p.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/sv3d_p.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 1280 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - input_key: cond_frames_without_noise 41 | is_trainable: False 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: cond_frames 52 | is_trainable: False 53 | target: 
sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 54 | params: 55 | disable_encoder_autocast: True 56 | n_cond_frames: 1 57 | n_copies: 1 58 | is_ae: True 59 | encoder_config: 60 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 61 | params: 62 | embed_dim: 4 63 | monitor: val/rec_loss 64 | ddconfig: 65 | attn_type: vanilla-xformers 66 | double_z: True 67 | z_channels: 4 68 | resolution: 256 69 | in_channels: 3 70 | out_ch: 3 71 | ch: 128 72 | ch_mult: [1, 2, 4, 4] 73 | num_res_blocks: 2 74 | attn_resolutions: [] 75 | dropout: 0.0 76 | lossconfig: 77 | target: torch.nn.Identity 78 | 79 | - input_key: cond_aug 80 | is_trainable: False 81 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 82 | params: 83 | outdim: 256 84 | 85 | - input_key: polars_rad 86 | is_trainable: False 87 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 88 | params: 89 | outdim: 512 90 | 91 | - input_key: azimuths_rad 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 512 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencodingEngine 99 | params: 100 | loss_config: 101 | target: torch.nn.Identity 102 | regularizer_config: 103 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 104 | encoder_config: 105 | target: torch.nn.Identity 106 | decoder_config: 107 | target: sgm.modules.diffusionmodules.model.Decoder 108 | params: 109 | attn_type: vanilla-xformers 110 | double_z: True 111 | z_channels: 4 112 | resolution: 256 113 | in_channels: 3 114 | out_ch: 3 115 | ch: 128 116 | ch_mult: [ 1, 2, 4, 4 ] 117 | num_res_blocks: 2 118 | attn_resolutions: [ ] 119 | dropout: 0.0 120 | 121 | sampler_config: 122 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 123 | params: 124 | discretization_config: 125 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 126 | params: 127 | sigma_max: 700.0 128 | 129 | guider_config: 130 | target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider 131 | params: 132 | max_scale: 2.5 133 | -------------------------------------------------------------------------------- /scripts/sampling/configs/sv3d_u.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/sv3d_u.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 256 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: 
sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: cond_frames 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 54 | params: 55 | disable_encoder_autocast: True 56 | n_cond_frames: 1 57 | n_copies: 1 58 | is_ae: True 59 | encoder_config: 60 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 61 | params: 62 | embed_dim: 4 63 | monitor: val/rec_loss 64 | ddconfig: 65 | attn_type: vanilla-xformers 66 | double_z: True 67 | z_channels: 4 68 | resolution: 256 69 | in_channels: 3 70 | out_ch: 3 71 | ch: 128 72 | ch_mult: [1, 2, 4, 4] 73 | num_res_blocks: 2 74 | attn_resolutions: [] 75 | dropout: 0.0 76 | lossconfig: 77 | target: torch.nn.Identity 78 | 79 | - input_key: cond_aug 80 | is_trainable: False 81 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 82 | params: 83 | outdim: 256 84 | 85 | first_stage_config: 86 | target: sgm.models.autoencoder.AutoencodingEngine 87 | params: 88 | loss_config: 89 | target: torch.nn.Identity 90 | regularizer_config: 91 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 92 | encoder_config: 93 | target: torch.nn.Identity 94 | decoder_config: 95 | target: sgm.modules.diffusionmodules.model.Decoder 96 | params: 97 | attn_type: vanilla-xformers 98 | double_z: True 99 | z_channels: 4 100 | resolution: 256 101 | in_channels: 3 102 | out_ch: 3 103 | ch: 128 104 | ch_mult: [ 1, 2, 4, 4 ] 105 | num_res_blocks: 2 106 | attn_resolutions: [ ] 107 | dropout: 0.0 108 | 109 | sampler_config: 110 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 111 | params: 112 | discretization_config: 113 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 114 | params: 115 | sigma_max: 700.0 116 | 117 | guider_config: 118 | target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider 119 | params: 120 | max_scale: 2.5 121 | -------------------------------------------------------------------------------- /scripts/sampling/configs/svd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/svd.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: 
sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencodingEngine 99 | params: 100 | loss_config: 101 | target: torch.nn.Identity 102 | regularizer_config: 103 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 104 | encoder_config: 105 | target: sgm.modules.diffusionmodules.model.Encoder 106 | params: 107 | attn_type: vanilla 108 | double_z: True 109 | z_channels: 4 110 | resolution: 256 111 | in_channels: 3 112 | out_ch: 3 113 | ch: 128 114 | ch_mult: [1, 2, 4, 4] 115 | num_res_blocks: 2 116 | attn_resolutions: [] 117 | dropout: 0.0 118 | decoder_config: 119 | target: sgm.modules.autoencoding.temporal_ae.VideoDecoder 120 | params: 121 | attn_type: vanilla 122 | double_z: True 123 | z_channels: 4 124 | resolution: 256 125 | in_channels: 3 126 | out_ch: 3 127 | ch: 128 128 | ch_mult: [1, 2, 4, 4] 129 | num_res_blocks: 2 130 | attn_resolutions: [] 131 | dropout: 0.0 132 | video_kernel_size: [3, 1, 1] 133 | 134 | sampler_config: 135 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 136 | params: 137 | discretization_config: 138 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 139 | params: 140 | sigma_max: 700.0 141 | 142 | guider_config: 143 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 144 | params: 145 | max_scale: 2.5 146 | min_scale: 1.0 -------------------------------------------------------------------------------- /scripts/sampling/configs/svd_image_decoder.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/svd_image_decoder.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | 
in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencoderKL 99 | params: 100 | embed_dim: 4 101 | monitor: val/rec_loss 102 | ddconfig: 103 | attn_type: vanilla-xformers 104 | double_z: True 105 | z_channels: 4 106 | resolution: 256 107 | in_channels: 3 108 | out_ch: 3 109 | ch: 128 110 | ch_mult: [1, 2, 4, 4] 111 | num_res_blocks: 2 112 | attn_resolutions: [] 113 | dropout: 0.0 114 | lossconfig: 115 | target: torch.nn.Identity 116 | 117 | sampler_config: 118 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 119 | params: 120 | discretization_config: 121 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 122 | params: 123 | sigma_max: 700.0 124 | 125 | guider_config: 126 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 127 | params: 128 | max_scale: 2.5 129 | min_scale: 1.0 -------------------------------------------------------------------------------- /scripts/sampling/configs/svd_xt.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/svd_xt.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: 
sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencodingEngine 99 | params: 100 | loss_config: 101 | target: torch.nn.Identity 102 | regularizer_config: 103 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 104 | encoder_config: 105 | target: sgm.modules.diffusionmodules.model.Encoder 106 | params: 107 | attn_type: vanilla 108 | double_z: True 109 | z_channels: 4 110 | resolution: 256 111 | in_channels: 3 112 | out_ch: 3 113 | ch: 128 114 | ch_mult: [1, 2, 4, 4] 115 | num_res_blocks: 2 116 | attn_resolutions: [] 117 | dropout: 0.0 118 | decoder_config: 119 | target: sgm.modules.autoencoding.temporal_ae.VideoDecoder 120 | params: 121 | attn_type: vanilla 122 | double_z: True 123 | z_channels: 4 124 | resolution: 256 125 | in_channels: 3 126 | out_ch: 3 127 | ch: 128 128 | ch_mult: [1, 2, 4, 4] 129 | num_res_blocks: 2 130 | attn_resolutions: [] 131 | dropout: 0.0 132 | video_kernel_size: [3, 1, 1] 133 | 134 | sampler_config: 135 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 136 | params: 137 | discretization_config: 138 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 139 | params: 140 | sigma_max: 700.0 141 | 142 | guider_config: 143 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 144 | params: 145 | max_scale: 3.0 146 | min_scale: 1.5 
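Note: the SVD sampling configs in this directory differ mainly in the checkpoint path, the first-stage decoder (frame-wise AutoencoderKL vs. the temporal VideoDecoder) and the guider scales; svd_xt, for instance, uses a LinearPredictionGuider that interpolates the classifier-free guidance scale between min_scale 1.5 and max_scale 3.0 across the output frames. A minimal sketch of how such a config becomes a model, assuming the checkpoint referenced by ckpt_path is present under checkpoints/ (the file path and the eval() call are illustrative, not taken from the sampling scripts):

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# Every `target:` entry names a class to import; its `params:` mapping is passed to
# the constructor. Nested `target:`/`params:` blocks (conditioner, first stage, sampler)
# are instantiated the same way by the classes that receive them.
config = OmegaConf.load("scripts/sampling/configs/svd_xt.yaml")
model = instantiate_from_config(config.model).eval()  # builds sgm.models.diffusion.DiffusionEngine
# ckpt_path in the config is loaded during construction, so the engine is ready for sampling.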
-------------------------------------------------------------------------------- /scripts/sampling/configs/svd_xt_1_1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/svd_xt_1_1.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencodingEngine 99 | params: 100 | loss_config: 101 | target: torch.nn.Identity 102 | regularizer_config: 103 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 104 | encoder_config: 105 | target: sgm.modules.diffusionmodules.model.Encoder 106 | params: 107 | attn_type: vanilla 108 | double_z: True 109 | z_channels: 4 110 | resolution: 256 111 | in_channels: 3 112 | out_ch: 3 113 | ch: 128 114 | ch_mult: [1, 2, 4, 4] 115 | num_res_blocks: 2 116 | attn_resolutions: [] 117 | dropout: 0.0 118 | decoder_config: 119 | target: sgm.modules.autoencoding.temporal_ae.VideoDecoder 120 | params: 121 | attn_type: vanilla 122 | double_z: True 123 
| z_channels: 4 124 | resolution: 256 125 | in_channels: 3 126 | out_ch: 3 127 | ch: 128 128 | ch_mult: [1, 2, 4, 4] 129 | num_res_blocks: 2 130 | attn_resolutions: [] 131 | dropout: 0.0 132 | video_kernel_size: [3, 1, 1] 133 | 134 | sampler_config: 135 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 136 | params: 137 | discretization_config: 138 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 139 | params: 140 | sigma_max: 700.0 141 | 142 | guider_config: 143 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 144 | params: 145 | max_scale: 3.0 146 | min_scale: 1.5 147 | -------------------------------------------------------------------------------- /scripts/sampling/configs/svd_xt_image_decoder.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/svd_xt_image_decoder.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencoderKL 
99 | params: 100 | embed_dim: 4 101 | monitor: val/rec_loss 102 | ddconfig: 103 | attn_type: vanilla-xformers 104 | double_z: True 105 | z_channels: 4 106 | resolution: 256 107 | in_channels: 3 108 | out_ch: 3 109 | ch: 128 110 | ch_mult: [1, 2, 4, 4] 111 | num_res_blocks: 2 112 | attn_resolutions: [] 113 | dropout: 0.0 114 | lossconfig: 115 | target: torch.nn.Identity 116 | 117 | sampler_config: 118 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 119 | params: 120 | discretization_config: 121 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 122 | params: 123 | sigma_max: 700.0 124 | 125 | guider_config: 126 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 127 | params: 128 | max_scale: 3.0 129 | min_scale: 1.5 -------------------------------------------------------------------------------- /scripts/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/scripts/util/__init__.py -------------------------------------------------------------------------------- /scripts/util/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/scripts/util/detection/__init__.py -------------------------------------------------------------------------------- /scripts/util/detection/nsfw_and_watermark_dectection.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import clip 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as T 7 | from PIL import Image 8 | 9 | RESOURCES_ROOT = "scripts/util/detection/" 10 | 11 | 12 | def predict_proba(X, weights, biases): 13 | logits = X @ weights.T + biases 14 | proba = np.where( 15 | logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)) 16 | ) 17 | return proba.T 18 | 19 | 20 | def load_model_weights(path: str): 21 | model_weights = np.load(path) 22 | return model_weights["weights"], model_weights["biases"] 23 | 24 | 25 | def clip_process_images(images: torch.Tensor) -> torch.Tensor: 26 | min_size = min(images.shape[-2:]) 27 | return T.Compose( 28 | [ 29 | T.CenterCrop(min_size), # TODO: this might affect the watermark, check this 30 | T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True), 31 | T.Normalize( 32 | (0.48145466, 0.4578275, 0.40821073), 33 | (0.26862954, 0.26130258, 0.27577711), 34 | ), 35 | ] 36 | )(images) 37 | 38 | 39 | class DeepFloydDataFiltering(object): 40 | def __init__( 41 | self, verbose: bool = False, device: torch.device = torch.device("cpu") 42 | ): 43 | super().__init__() 44 | self.verbose = verbose 45 | self._device = None 46 | self.clip_model, _ = clip.load("ViT-L/14", device=device) 47 | self.clip_model.eval() 48 | 49 | self.cpu_w_weights, self.cpu_w_biases = load_model_weights( 50 | os.path.join(RESOURCES_ROOT, "w_head_v1.npz") 51 | ) 52 | self.cpu_p_weights, self.cpu_p_biases = load_model_weights( 53 | os.path.join(RESOURCES_ROOT, "p_head_v1.npz") 54 | ) 55 | self.w_threshold, self.p_threshold = 0.5, 0.5 56 | 57 | @torch.inference_mode() 58 | def __call__(self, images: torch.Tensor) -> torch.Tensor: 59 | imgs = clip_process_images(images) 60 | if self._device is None: 61 | self._device = next(p for p in self.clip_model.parameters()).device 62 | image_features = 
self.clip_model.encode_image(imgs.to(self._device)) 63 | image_features = image_features.detach().cpu().numpy().astype(np.float16) 64 | p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases) 65 | w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases) 66 | print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None 67 | query = p_pred > self.p_threshold 68 | if query.sum() > 0: 69 | print(f"Hit for p_threshold: {p_pred}") if self.verbose else None 70 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) 71 | query = w_pred > self.w_threshold 72 | if query.sum() > 0: 73 | print(f"Hit for w_threshold: {w_pred}") if self.verbose else None 74 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) 75 | return images 76 | 77 | 78 | def load_img(path: str) -> torch.Tensor: 79 | image = Image.open(path) 80 | if not image.mode == "RGB": 81 | image = image.convert("RGB") 82 | image_transforms = T.Compose( 83 | [ 84 | T.ToTensor(), 85 | ] 86 | ) 87 | return image_transforms(image)[None, ...] 88 | 89 | 90 | def test(root): 91 | from einops import rearrange 92 | 93 | filter = DeepFloydDataFiltering(verbose=True) 94 | for p in os.listdir((root)): 95 | print(f"running on {p}...") 96 | img = load_img(os.path.join(root, p)) 97 | filtered_img = filter(img) 98 | filtered_img = rearrange( 99 | 255.0 * (filtered_img.numpy())[0], "c h w -> h w c" 100 | ).astype(np.uint8) 101 | Image.fromarray(filtered_img).save( 102 | os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg") 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | import fire 108 | 109 | fire.Fire(test) 110 | print("done.") 111 | -------------------------------------------------------------------------------- /scripts/util/detection/p_head_v1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/scripts/util/detection/p_head_v1.npz -------------------------------------------------------------------------------- /scripts/util/detection/w_head_v1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/scripts/util/detection/w_head_v1.npz -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /sgm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import StableDataModuleFromConfig 2 | -------------------------------------------------------------------------------- /sgm/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class CIFAR10DataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 
17 | return len(self.dset) 18 | 19 | 20 | class CIFAR10Loader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.shuffle = shuffle 31 | self.train_dataset = CIFAR10DataDictWrapper( 32 | torchvision.datasets.CIFAR10( 33 | root=".data/", train=True, download=True, transform=transform 34 | ) 35 | ) 36 | self.test_dataset = CIFAR10DataDictWrapper( 37 | torchvision.datasets.CIFAR10( 38 | root=".data/", train=False, download=True, transform=transform 39 | ) 40 | ) 41 | 42 | def prepare_data(self): 43 | pass 44 | 45 | def train_dataloader(self): 46 | return DataLoader( 47 | self.train_dataset, 48 | batch_size=self.batch_size, 49 | shuffle=self.shuffle, 50 | num_workers=self.num_workers, 51 | ) 52 | 53 | def test_dataloader(self): 54 | return DataLoader( 55 | self.test_dataset, 56 | batch_size=self.batch_size, 57 | shuffle=self.shuffle, 58 | num_workers=self.num_workers, 59 | ) 60 | 61 | def val_dataloader(self): 62 | return DataLoader( 63 | self.test_dataset, 64 | batch_size=self.batch_size, 65 | shuffle=self.shuffle, 66 | num_workers=self.num_workers, 67 | ) 68 | -------------------------------------------------------------------------------- /sgm/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchdata.datapipes.iter 4 | import webdataset as wds 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningDataModule 7 | 8 | try: 9 | from sdata import create_dataset, create_dummy_dataset, create_loader 10 | except ImportError as e: 11 | print("#" * 100) 12 | print("Datasets not yet available") 13 | print("to enable, we need to add stable-datasets as a submodule") 14 | print("please use ``git submodule update --init --recursive``") 15 | print("and do ``pip install -e stable-datasets/`` from the root of this repo") 16 | print("#" * 100) 17 | exit(1) 18 | 19 | 20 | class StableDataModuleFromConfig(LightningDataModule): 21 | def __init__( 22 | self, 23 | train: DictConfig, 24 | validation: Optional[DictConfig] = None, 25 | test: Optional[DictConfig] = None, 26 | skip_val_loader: bool = False, 27 | dummy: bool = False, 28 | ): 29 | super().__init__() 30 | self.train_config = train 31 | assert ( 32 | "datapipeline" in self.train_config and "loader" in self.train_config 33 | ), "train config requires the fields `datapipeline` and `loader`" 34 | 35 | self.val_config = validation 36 | if not skip_val_loader: 37 | if self.val_config is not None: 38 | assert ( 39 | "datapipeline" in self.val_config and "loader" in self.val_config 40 | ), "validation config requires the fields `datapipeline` and `loader`" 41 | else: 42 | print( 43 | "Warning: No Validation datapipeline defined, using that one from training" 44 | ) 45 | self.val_config = train 46 | 47 | self.test_config = test 48 | if self.test_config is not None: 49 | assert ( 50 | "datapipeline" in self.test_config and "loader" in self.test_config 51 | ), "test config requires the fields `datapipeline` and `loader`" 52 | 53 | self.dummy = dummy 54 | if self.dummy: 55 | print("#" * 100) 56 | print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") 57 | print("#" * 100) 58 | 59 | def setup(self, stage: str) -> None: 60 | print("Preparing datasets") 61 | if self.dummy: 62 | 
data_fn = create_dummy_dataset 63 | else: 64 | data_fn = create_dataset 65 | 66 | self.train_datapipeline = data_fn(**self.train_config.datapipeline) 67 | if self.val_config: 68 | self.val_datapipeline = data_fn(**self.val_config.datapipeline) 69 | if self.test_config: 70 | self.test_datapipeline = data_fn(**self.test_config.datapipeline) 71 | 72 | def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: 73 | loader = create_loader(self.train_datapipeline, **self.train_config.loader) 74 | return loader 75 | 76 | def val_dataloader(self) -> wds.DataPipeline: 77 | return create_loader(self.val_datapipeline, **self.val_config.loader) 78 | 79 | def test_dataloader(self) -> wds.DataPipeline: 80 | return create_loader(self.test_datapipeline, **self.test_config.loader) 81 | -------------------------------------------------------------------------------- /sgm/data/mnist.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class MNISTDataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class MNISTLoader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 31 | self.shuffle = shuffle 32 | self.train_dataset = MNISTDataDictWrapper( 33 | torchvision.datasets.MNIST( 34 | root=".data/", train=True, download=True, transform=transform 35 | ) 36 | ) 37 | self.test_dataset = MNISTDataDictWrapper( 38 | torchvision.datasets.MNIST( 39 | root=".data/", train=False, download=True, transform=transform 40 | ) 41 | ) 42 | 43 | def prepare_data(self): 44 | pass 45 | 46 | def train_dataloader(self): 47 | return DataLoader( 48 | self.train_dataset, 49 | batch_size=self.batch_size, 50 | shuffle=self.shuffle, 51 | num_workers=self.num_workers, 52 | prefetch_factor=self.prefetch_factor, 53 | ) 54 | 55 | def test_dataloader(self): 56 | return DataLoader( 57 | self.test_dataset, 58 | batch_size=self.batch_size, 59 | shuffle=self.shuffle, 60 | num_workers=self.num_workers, 61 | prefetch_factor=self.prefetch_factor, 62 | ) 63 | 64 | def val_dataloader(self): 65 | return DataLoader( 66 | self.test_dataset, 67 | batch_size=self.batch_size, 68 | shuffle=self.shuffle, 69 | num_workers=self.num_workers, 70 | prefetch_factor=self.prefetch_factor, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | dset = MNISTDataDictWrapper( 76 | torchvision.datasets.MNIST( 77 | root=".data/", 78 | train=False, 79 | download=True, 80 | transform=transforms.Compose( 81 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 82 | ), 83 | ) 84 | ) 85 | ex = dset[0] 86 | -------------------------------------------------------------------------------- /sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def 
__init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | 
self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- /sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "GeneralLPIPSWithDiscriminator", 3 | "LatentLPIPS", 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | def __init__( 10 | self, 11 | decoder_config, 12 | perceptual_weight=1.0, 13 | latent_weight=1.0, 14 | scale_input_to_tgt_size=False, 15 | scale_tgt_to_input_size=False, 16 | perceptual_weight_on_inputs=0.0, 17 | ): 18 | super().__init__() 19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 21 | self.init_decoder(decoder_config) 22 | self.perceptual_loss = LPIPS().eval() 23 | self.perceptual_weight = perceptual_weight 24 | self.latent_weight = latent_weight 25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 26 | 27 | def init_decoder(self, config): 28 | self.decoder = instantiate_from_config(config) 29 | if hasattr(self.decoder, "encoder"): 30 | del self.decoder.encoder 31 | 32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): 33 | log = dict() 34 | loss = (latent_inputs - latent_predictions) ** 2 35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach() 36 | image_reconstructions = None 37 | if self.perceptual_weight > 0.0: 38 | image_reconstructions = self.decoder.decode(latent_predictions) 39 | image_targets = self.decoder.decode(latent_inputs) 40 | perceptual_loss = self.perceptual_loss( 41 | image_targets.contiguous(), image_reconstructions.contiguous() 42 | ) 43 | loss = ( 44 | self.latent_weight * loss.mean() 45 | + self.perceptual_weight * perceptual_loss.mean() 46 | ) 47 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() 48 | 49 | if self.perceptual_weight_on_inputs > 0.0: 50 | image_reconstructions = default( 51 | image_reconstructions, self.decoder.decode(latent_predictions) 52 | ) 53 
| if self.scale_input_to_tgt_size: 54 | image_inputs = torch.nn.functional.interpolate( 55 | image_inputs, 56 | image_reconstructions.shape[2:], 57 | mode="bicubic", 58 | antialias=True, 59 | ) 60 | elif self.scale_tgt_to_input_size: 61 | image_reconstructions = torch.nn.functional.interpolate( 62 | image_reconstructions, 63 | image_inputs.shape[2:], 64 | mode="bicubic", 65 | antialias=True, 66 | ) 67 | 68 | perceptual_loss2 = self.perceptual_loss( 69 | image_inputs.contiguous(), image_reconstructions.contiguous() 70 | ) 71 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() 72 | log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() 73 | return loss, log 74 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 53 | outs1[kk] 54 | ) 55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | def __init__(self): 69 | super(ScalingLayer, self).__init__() 70 | self.register_buffer( 71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 75 | ) 76 | 77 | def forward(self, inp): 78 | return (inp - self.shift) / self.scale 79 | 80 | 81 | class NetLinLayer(nn.Module): 82 | """A single linear layer which does a 1x1 conv""" 83 | 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = ( 87 | [ 88 | nn.Dropout(), 89 | ] 90 | if 
(use_dropout) 91 | else [] 92 | ) 93 | layers += [ 94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 95 | ] 96 | self.model = nn.Sequential(*layers) 97 | 98 | 99 | class vgg16(torch.nn.Module): 100 | def __init__(self, requires_grad=False, pretrained=True): 101 | super(vgg16, self).__init__() 102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 103 | self.slice1 = torch.nn.Sequential() 104 | self.slice2 = torch.nn.Sequential() 105 | self.slice3 = torch.nn.Sequential() 106 | self.slice4 = torch.nn.Sequential() 107 | self.slice5 = torch.nn.Sequential() 108 | self.N_slices = 5 109 | for x in range(4): 110 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(4, 9): 112 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(9, 16): 114 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(16, 23): 116 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(23, 30): 118 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 119 | if not requires_grad: 120 | for param in self.parameters(): 121 | param.requires_grad = False 122 | 123 | def forward(self, X): 124 | h = self.slice1(X) 125 | h_relu1_2 = h 126 | h = self.slice2(h) 127 | h_relu2_2 = h 128 | h = self.slice3(h) 129 | h_relu3_3 = h 130 | h = self.slice4(h) 131 | h_relu4_3 = h 132 | h = self.slice5(h) 133 | h_relu5_3 = h 134 | vgg_outputs = namedtuple( 135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 136 | ) 137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 138 | return out 139 | 140 | 141 | def normalize_tensor(x, eps=1e-10): 142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 143 | return x / (norm_factor + eps) 144 | 145 | 146 | def spatial_average(x, keepdim=True): 147 | return x.mean([2, 3], keepdim=keepdim) 148 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
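Note: the LPIPS metric defined in lpips/loss/lpips.py above is self-contained and can be exercised directly. A minimal sketch (the random inputs are placeholders; the [-1, 1] scaling matches the ScalingLayer convention, and the first call downloads both the torchvision VGG-16 weights and the vgg.pth linear heads via the URL map in lpips/util.py below):

import torch
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()                        # VGG-16 features plus pretrained linear heads, frozen
x = torch.rand(1, 3, 256, 256) * 2.0 - 1.0    # images are expected in [-1, 1]
y = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
d = lpips(x, y)                               # tensor of shape (1, 1, 1, 1); smaller means perceptually closer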
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = 
{"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 | if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 
113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import \ 9 | DiagonalGaussianDistribution 10 | from .base import AbstractRegularizer 11 | 12 | 13 | class DiagonalGaussianRegularizer(AbstractRegularizer): 14 | def __init__(self, sample: bool = True): 15 | super().__init__() 16 | self.sample = sample 17 | 18 | def get_trainable_parameters(self) -> Any: 19 | yield from () 20 | 21 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 22 | log = dict() 23 | posterior = DiagonalGaussianDistribution(z) 24 | if self.sample: 25 | z = posterior.sample() 26 | else: 27 | z = posterior.mode() 28 | kl_loss = posterior.kl() 29 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 30 | log["kl_loss"] = kl_loss 31 | return z, log 32 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity( 30 | predicted_indices: torch.Tensor, num_centroids: int 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 33 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 34 | encodings = ( 35 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 36 | ) 37 | avg_probs = encodings.mean(0) 38 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 39 | cluster_use = torch.sum(avg_probs > 0) 40 | return perplexity, cluster_use 41 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/sgm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | from .denoiser_scaling import DenoiserScaling 8 | from .discretizer import Discretization 9 | 10 | 11 | class Denoiser(nn.Module): 12 | def __init__(self, scaling_config: Dict): 13 | super().__init__() 14 | 15 | self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) 16 | 17 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 18 | return sigma 19 | 20 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 21 | return c_noise 22 | 23 | def forward( 24 | self, 25 | network: nn.Module, 26 | input: torch.Tensor, 27 | sigma: torch.Tensor, 28 | cond: Dict, 29 | **additional_model_inputs, 30 | ) -> torch.Tensor: 31 | sigma = self.possibly_quantize_sigma(sigma) 32 | sigma_shape = sigma.shape 33 | sigma = append_dims(sigma, input.ndim) 34 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 35 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 36 | return ( 37 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out 38 | + input * c_skip 39 | ) 40 | 41 | 42 | class DiscreteDenoiser(Denoiser): 43 | def __init__( 44 | self, 45 | scaling_config: Dict, 46 | num_idx: int, 47 | discretization_config: Dict, 48 | do_append_zero: bool = False, 49 | quantize_c_noise: bool = True, 50 | flip: bool = True, 51 | ): 52 | super().__init__(scaling_config) 53 | self.discretization: Discretization = instantiate_from_config( 54 | discretization_config 55 | ) 56 | sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) 57 | self.register_buffer("sigmas", sigmas) 58 | self.quantize_c_noise = quantize_c_noise 59 | self.num_idx = num_idx 60 | 61 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: 62 | dists = sigma - self.sigmas[:, None] 63 | return dists.abs().argmin(dim=0).view(sigma.shape) 64 | 65 | def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: 66 | return self.sigmas[idx] 67 | 68 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 69 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 70 | 71 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 72 | if self.quantize_c_noise: 73 | return self.sigma_to_idx(c_noise) 74 | else: 75 | return c_noise 76 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | @abstractmethod 9 | def __call__( 10 | self, sigma: torch.Tensor 11 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 12 | pass 13 | 14 | 15 | class EDMScaling: 16 | def __init__(self, sigma_data: float = 0.5): 17 | self.sigma_data = sigma_data 18 | 19 | def __call__( 20 | self, sigma: torch.Tensor 21 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 22 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 23 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 24 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 25 | c_noise = 0.25 * sigma.log() 26 | return c_skip, c_out, c_in, c_noise 27 | 28 | 29 | class EpsScaling: 30 | def __call__( 31 | self, sigma: torch.Tensor 32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 33 | c_skip = torch.ones_like(sigma, device=sigma.device) 34 | c_out = -sigma 35 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 36 | c_noise = sigma.clone() 37 | return c_skip, c_out, c_in, c_noise 38 | 39 | 40 | class VScaling: 41 | def __call__( 42 | self, sigma: torch.Tensor 43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 44 | c_skip = 1.0 / (sigma**2 + 1.0) 45 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 46 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 47 | c_noise = sigma.clone() 48 | return c_skip, c_out, c_in, c_noise 49 | 50 | 51 | class VScalingWithEDMcNoise(DenoiserScaling): 52 | def __call__( 53 | self, sigma: torch.Tensor 54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 55 | c_skip = 1.0 / (sigma**2 + 1.0) 56 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 57 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 58 | c_noise = 0.25 * sigma.log() 59 | return c_skip, c_out, c_in, c_noise 60 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = 
append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Literal, Optional, Tuple, Union 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | from ...util import append_dims, default 9 | 10 | logpy = logging.getLogger(__name__) 11 | 12 | 13 | class Guider(ABC): 14 | @abstractmethod 15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 16 | pass 17 | 18 | def prepare_inputs( 19 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 20 | ) -> Tuple[torch.Tensor, float, Dict]: 21 | pass 22 | 23 | 24 | class VanillaCFG(Guider): 25 | def __init__(self, scale: float): 26 | self.scale = scale 27 | 28 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 29 | x_u, x_c = x.chunk(2) 30 | x_pred = x_u + self.scale * (x_c - x_u) 31 | return x_pred 32 | 33 | def prepare_inputs(self, x, s, c, uc): 34 | c_out = dict() 35 | 36 | for k in c: 37 | if k in ["vector", "crossattn", "concat"]: 38 | c_out[k] = torch.cat((uc[k], c[k]), 0) 39 | else: 40 | assert c[k] == uc[k] 41 | c_out[k] = c[k] 42 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 43 | 44 | 45 | class IdentityGuider(Guider): 46 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 47 | return x 48 | 49 | def prepare_inputs( 50 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 51 | ) -> Tuple[torch.Tensor, float, Dict]: 52 | c_out = dict() 53 | 54 | for k in c: 55 | c_out[k] = c[k] 56 | 57 | return x, s, c_out 58 | 59 | 60 | class LinearPredictionGuider(Guider): 61 | def __init__( 62 | self, 63 | max_scale: float, 64 | num_frames: int, 65 | min_scale: float = 
1.0, 66 | additional_cond_keys: Optional[Union[List[str], str]] = None, 67 | ): 68 | self.min_scale = min_scale 69 | self.max_scale = max_scale 70 | self.num_frames = num_frames 71 | self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) 72 | 73 | additional_cond_keys = default(additional_cond_keys, []) 74 | if isinstance(additional_cond_keys, str): 75 | additional_cond_keys = [additional_cond_keys] 76 | self.additional_cond_keys = additional_cond_keys 77 | 78 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 79 | x_u, x_c = x.chunk(2) 80 | 81 | x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) 82 | x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) 83 | scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) 84 | scale = append_dims(scale, x_u.ndim).to(x_u.device) 85 | 86 | return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") 87 | 88 | def prepare_inputs( 89 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict 90 | ) -> Tuple[torch.Tensor, torch.Tensor, dict]: 91 | c_out = dict() 92 | 93 | for k in c: 94 | if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: 95 | c_out[k] = torch.cat((uc[k], c[k]), 0) 96 | else: 97 | # assert c[k] == uc[k] 98 | c_out[k] = c[k] 99 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 100 | 101 | 102 | class TrianglePredictionGuider(LinearPredictionGuider): 103 | def __init__( 104 | self, 105 | max_scale: float, 106 | num_frames: int, 107 | min_scale: float = 1.0, 108 | period: Union[float, List[float]] = 1.0, 109 | period_fusing: Literal["mean", "multiply", "max"] = "max", 110 | additional_cond_keys: Optional[Union[List[str], str]] = None, 111 | ): 112 | super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) 113 | values = torch.linspace(0, 1, num_frames) 114 | # Constructs a triangle wave 115 | if isinstance(period, float): 116 | period = [period] 117 | 118 | scales = [] 119 | for p in period: 120 | scales.append(self.triangle_wave(values, p)) 121 | 122 | if period_fusing == "mean": 123 | scale = sum(scales) / len(period) 124 | elif period_fusing == "multiply": 125 | scale = torch.prod(torch.stack(scales), dim=0) 126 | elif period_fusing == "max": 127 | scale = torch.max(torch.stack(scales), dim=0).values 128 | self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0) 129 | 130 | def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor: 131 | return 2 * (values / period - torch.floor(values / period + 0.5)).abs() 132 | 133 | 134 | class TrapezoidPredictionGuider(LinearPredictionGuider): 135 | def __init__( 136 | self, 137 | max_scale: float, 138 | num_frames: int, 139 | min_scale: float = 1.0, 140 | edge_perc: float = 0.1, 141 | additional_cond_keys: Optional[Union[List[str], str]] = None, 142 | ): 143 | super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) 144 | 145 | rise_steps = torch.linspace(min_scale, max_scale, int(num_frames * edge_perc)) 146 | fall_steps = torch.flip(rise_steps, [0]) 147 | self.scale = torch.cat( 148 | [ 149 | rise_steps, 150 | torch.ones(num_frames - 2 * int(num_frames * edge_perc)), 151 | fall_steps, 152 | ] 153 | ).unsqueeze(0) 154 | 155 | 156 | class SpatiotemporalPredictionGuider(LinearPredictionGuider): 157 | def __init__( 158 | self, 159 | max_scale: float, 160 | num_frames: int, 161 | num_views: int = 1, 162 | min_scale: float = 1.0, 163 | additional_cond_keys: Optional[Union[List[str], str]] = None, 164 | ): 165 | super().__init__(max_scale, 
num_frames, min_scale, additional_cond_keys) 166 | V = num_views 167 | T = num_frames // V 168 | scale = torch.zeros(num_frames).view(T, V) 169 | scale += torch.linspace(0, 1, T)[:,None] * 0.5 170 | scale += self.triangle_wave(torch.linspace(0, 1, V))[None,:] * 0.5 171 | scale = scale.flatten() 172 | self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0) 173 | 174 | def triangle_wave(self, values: torch.Tensor, period=1) -> torch.Tensor: 175 | return 2 * (values / period - torch.floor(values / period + 0.5)).abs() -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 7 | from ...modules.encoders.modules import GeneralConditioner 8 | from ...util import append_dims, instantiate_from_config 9 | from .denoiser import Denoiser 10 | 11 | 12 | class StandardDiffusionLoss(nn.Module): 13 | def __init__( 14 | self, 15 | sigma_sampler_config: dict, 16 | loss_weighting_config: dict, 17 | loss_type: str = "l2", 18 | offset_noise_level: float = 0.0, 19 | batch2model_keys: Optional[Union[str, List[str]]] = None, 20 | ): 21 | super().__init__() 22 | 23 | assert loss_type in ["l2", "l1", "lpips"] 24 | 25 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 26 | self.loss_weighting = instantiate_from_config(loss_weighting_config) 27 | 28 | self.loss_type = loss_type 29 | self.offset_noise_level = offset_noise_level 30 | 31 | if loss_type == "lpips": 32 | self.lpips = LPIPS().eval() 33 | 34 | if not batch2model_keys: 35 | batch2model_keys = [] 36 | 37 | if isinstance(batch2model_keys, str): 38 | batch2model_keys = [batch2model_keys] 39 | 40 | self.batch2model_keys = set(batch2model_keys) 41 | 42 | def get_noised_input( 43 | self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor 44 | ) -> torch.Tensor: 45 | noised_input = input + noise * sigmas_bc 46 | return noised_input 47 | 48 | def forward( 49 | self, 50 | network: nn.Module, 51 | denoiser: Denoiser, 52 | conditioner: GeneralConditioner, 53 | input: torch.Tensor, 54 | batch: Dict, 55 | ) -> torch.Tensor: 56 | cond = conditioner(batch) 57 | return self._forward(network, denoiser, cond, input, batch) 58 | 59 | def _forward( 60 | self, 61 | network: nn.Module, 62 | denoiser: Denoiser, 63 | cond: Dict, 64 | input: torch.Tensor, 65 | batch: Dict, 66 | ) -> Tuple[torch.Tensor, Dict]: 67 | additional_model_inputs = { 68 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 69 | } 70 | sigmas = self.sigma_sampler(input.shape[0]).to(input) 71 | 72 | noise = torch.randn_like(input) 73 | if self.offset_noise_level > 0.0: 74 | offset_shape = ( 75 | (input.shape[0], 1, input.shape[2]) 76 | if self.n_frames is not None 77 | else (input.shape[0], input.shape[1]) 78 | ) 79 | noise = noise + self.offset_noise_level * append_dims( 80 | torch.randn(offset_shape, device=input.device), 81 | input.ndim, 82 | ) 83 | sigmas_bc = append_dims(sigmas, input.ndim) 84 | noised_input = self.get_noised_input(sigmas_bc, noise, input) 85 | 86 | model_output = denoiser( 87 | network, noised_input, sigmas, cond, **additional_model_inputs 88 | ) 89 | w = append_dims(self.loss_weighting(sigmas), input.ndim) 90 | return self.get_loss(model_output, input, w) 91 | 92 | def get_loss(self, model_output, target, w): 
93 | if self.loss_type == "l2": 94 | return torch.mean( 95 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 96 | ) 97 | elif self.loss_type == "l1": 98 | return torch.mean( 99 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 100 | ) 101 | elif self.loss_type == "lpips": 102 | loss = self.lpips(model_output, target).reshape(-1) 103 | return loss 104 | else: 105 | raise NotImplementedError(f"Unknown loss type {self.loss_type}") 106 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss_weighting.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class DiffusionLossWeighting(ABC): 7 | @abstractmethod 8 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 9 | pass 10 | 11 | 12 | class UnitWeighting(DiffusionLossWeighting): 13 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 14 | return torch.ones_like(sigma, device=sigma.device) 15 | 16 | 17 | class EDMWeighting(DiffusionLossWeighting): 18 | def __init__(self, sigma_data: float = 0.5): 19 | self.sigma_data = sigma_data 20 | 21 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 22 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 23 | 24 | 25 | class VWeighting(EDMWeighting): 26 | def __init__(self): 27 | super().__init__(sigma_data=1.0) 28 | 29 | 30 | class EpsWeighting(DiffusionLossWeighting): 31 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 32 | return sigma**-2.0 33 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 8 | if order - 1 > i: 9 | raise ValueError(f"Order {order} too high for step {i}") 10 | 11 | def fn(tau): 12 | prod = 1.0 13 | for k in range(order): 14 | if j == k: 15 | continue 16 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 17 | return prod 18 | 19 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 20 | 21 | 22 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 23 | if not eta: 24 | return sigma_to, 0.0 25 | sigma_up = torch.minimum( 26 | sigma_to, 27 | eta 28 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 29 | ) 30 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 31 | return sigma_down, sigma_up 32 | 33 | 34 | def to_d(x, sigma, denoised): 35 | return (x - denoised) / append_dims(sigma, x.ndim) 36 | 37 | 38 | def to_neg_log_sigma(sigma): 39 | return sigma.log().neg() 40 | 41 | 42 | def to_sigma(neg_log_sigma): 43 | return neg_log_sigma.neg().exp() 44 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class 
DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | 23 | def idx_to_sigma(self, idx): 24 | return self.sigmas[idx] 25 | 26 | def __call__(self, n_samples, rand=None): 27 | idx = default( 28 | rand, 29 | torch.randint(0, self.num_idx, (n_samples,)), 30 | ) 31 | return self.idx_to_sigma(idx) 32 | 33 | 34 | class ZeroSampler: 35 | def __call__( 36 | self, n_samples: int, rand: Optional[torch.Tensor] = None 37 | ) -> torch.Tensor: 38 | return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5 39 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | if "cond_view" in c: 29 | return self.diffusion_model( 30 | x, 31 | timesteps=t, 32 | context=c.get("crossattn", None), 33 | y=c.get("vector", None), 34 | cond_view=c.get("cond_view", None), 35 | cond_motion=c.get("cond_motion", None), 36 | **kwargs, 37 | ) 38 | else: 39 | return self.diffusion_model( 40 | x, 41 | timesteps=t, 42 | context=c.get("crossattn", None), 43 | y=c.get("vector", None), 44 | **kwargs, 45 | ) 46 | -------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = 
torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/generative-models/0ad7de9a5cb53fd63d6d30a4f385485e72e08597/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /tests/inference/test_inference.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from PIL import Image 3 | import pytest 4 | from pytest import fixture 5 | import torch 6 | from typing import Tuple 7 | 8 | from sgm.inference.api import ( 9 | model_specs, 10 | SamplingParams, 11 | SamplingPipeline, 12 | Sampler, 13 | ModelArchitecture, 14 | ) 15 | import sgm.inference.helpers as helpers 16 | 17 | 18 | @pytest.mark.inference 19 | class TestInference: 20 | @fixture(scope="class", params=model_specs.keys()) 21 | def pipeline(self, request) -> SamplingPipeline: 22 | pipeline = SamplingPipeline(request.param) 23 | yield pipeline 24 | del pipeline 25 | torch.cuda.empty_cache() 26 | 27 | @fixture( 28 | scope="class", 29 | params=[ 30 | [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER], 31 | [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER], 32 | ], 33 | ids=["SDXL_V1", "SDXL_V0_9"], 34 | ) 35 | def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]: 36 | base_pipeline = SamplingPipeline(request.param[0]) 37 | refiner_pipeline = SamplingPipeline(request.param[1]) 38 | yield base_pipeline, refiner_pipeline 39 | del base_pipeline 40 | del refiner_pipeline 41 | torch.cuda.empty_cache() 42 | 43 | def create_init_image(self, h, w): 44 | image_array = numpy.random.rand(h, w, 3) * 255 45 | image = Image.fromarray(image_array.astype("uint8")).convert("RGB") 46 | return helpers.get_input_image_tensor(image) 47 | 48 | @pytest.mark.parametrize("sampler_enum", Sampler) 49 | def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum): 50 | output = pipeline.text_to_image( 51 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 52 | prompt="A professional photograph of an astronaut riding a pig", 53 | negative_prompt="", 54 | samples=1, 55 | ) 56 | 57 | assert output is not None 58 | 59 | @pytest.mark.parametrize("sampler_enum", Sampler) 60 | def test_img2img(self, pipeline: SamplingPipeline, sampler_enum): 61 | output = pipeline.image_to_image( 62 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 63 | image=self.create_init_image(pipeline.specs.height, pipeline.specs.width), 64 | prompt="A professional photograph of an astronaut riding a pig", 65 | negative_prompt="", 66 | samples=1, 67 | ) 68 | assert output is not None 69 | 70 | @pytest.mark.parametrize("sampler_enum", Sampler) 71 | @pytest.mark.parametrize( 72 | "use_init_image", [True, False], ids=["img2img", "txt2img"] 73 | ) 74 | def test_sdxl_with_refiner( 75 | self, 76 | sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline], 77 | sampler_enum, 78 | use_init_image, 79 | ): 80 | base_pipeline, refiner_pipeline = sdxl_pipelines 81 | if use_init_image: 82 | output = base_pipeline.image_to_image( 83 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 84 | image=self.create_init_image( 85 | base_pipeline.specs.height, base_pipeline.specs.width 86 | ), 87 | prompt="A 
professional photograph of an astronaut riding a pig", 88 | negative_prompt="", 89 | samples=1, 90 | return_latents=True, 91 | ) 92 | else: 93 | output = base_pipeline.text_to_image( 94 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 95 | prompt="A professional photograph of an astronaut riding a pig", 96 | negative_prompt="", 97 | samples=1, 98 | return_latents=True, 99 | ) 100 | 101 | assert isinstance(output, (tuple, list)) 102 | samples, samples_z = output 103 | assert samples is not None 104 | assert samples_z is not None 105 | refiner_pipeline.refiner( 106 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 107 | image=samples_z, 108 | prompt="A professional photograph of an astronaut riding a pig", 109 | negative_prompt="", 110 | samples=1, 111 | ) 112 | --------------------------------------------------------------------------------
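
The files above factor sampling into small, composable pieces: a `Discretization` produces the sigma schedule, a `DenoiserScaling` maps each sigma to the preconditioning coefficients `(c_skip, c_out, c_in, c_noise)` that `Denoiser.forward` wraps around the network, and a guider blends the unconditional and conditional branches of a classifier-free-guidance batch. The sketch below is not part of the repository; it is a minimal illustration of how these pieces fit together, assuming the `sgm` package defined above is installed and importable, and it feeds toy tensors rather than real model outputs.

# Minimal sketch (not repository code): compose the schedule, the
# preconditioning, and classifier-free guidance from the modules above.
import torch

from sgm.modules.diffusionmodules.discretizer import EDMDiscretization
from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling
from sgm.modules.diffusionmodules.guiders import VanillaCFG

# 1) Noise schedule: 10 sigmas from sigma_max down to sigma_min, with a
#    trailing zero appended (do_append_zero=True is the __call__ default).
sigmas = EDMDiscretization(sigma_min=0.002, sigma_max=80.0, rho=7.0)(10)
print(sigmas.shape)  # torch.Size([11])

# 2) For a given sigma, EDMScaling returns the coefficients that
#    Denoiser.forward applies as: network(x * c_in, c_noise, cond) * c_out + x * c_skip.
sigma = sigmas[:1]
c_skip, c_out, c_in, c_noise = EDMScaling(sigma_data=0.5)(sigma)

# 3) VanillaCFG expects the unconditional and conditional halves stacked
#    along dim 0 (as prepare_inputs builds them) and blends them by `scale`.
guider = VanillaCFG(scale=5.0)
denoised = torch.randn(2, 4, 8, 8)  # toy [uncond | cond] model outputs
guided = guider(denoised, sigma)    # -> shape (1, 4, 8, 8)
print(guided.shape)

Keeping the schedule, the preconditioning, and the guidance as separate objects is what lets `Denoiser`, `DiscreteDenoiser`, and the samplers swap any one of them through `instantiate_from_config` without changes to the surrounding code.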