├── .gitignore ├── LICENSE ├── LICENSE-MODEL ├── README.md ├── assets ├── model-variants.jpg ├── modelfigure.png ├── rick.jpeg ├── stable-inpainting │ ├── inpainting.gif │ └── merged-leopards.png └── stable-samples │ ├── depth2img │ ├── d2i.gif │ ├── depth2fantasy.jpeg │ ├── depth2img01.png │ ├── depth2img02.png │ ├── merged-0000.png │ ├── merged-0004.png │ ├── merged-0005.png │ ├── midas.jpeg │ └── old_man.png │ ├── img2img │ ├── mountains-1.png │ ├── mountains-2.png │ ├── mountains-3.png │ ├── sketch-mountains-input.jpg │ ├── upscaling-in.png │ └── upscaling-out.png │ ├── stable-unclip │ ├── houses_out.jpeg │ ├── oldcar000.jpeg │ ├── oldcar500.jpeg │ ├── oldcar800.jpeg │ ├── panda.jpg │ ├── plates_out.jpeg │ ├── unclip-variations.png │ └── unclip-variations_noise.png │ ├── txt2img │ ├── 768 │ │ ├── merged-0001.png │ │ ├── merged-0002.png │ │ ├── merged-0003.png │ │ ├── merged-0004.png │ │ ├── merged-0005.png │ │ └── merged-0006.png │ ├── 000002025.png │ ├── 000002035.png │ ├── merged-0001.png │ ├── merged-0003.png │ ├── merged-0005.png │ ├── merged-0006.png │ └── merged-0007.png │ └── upscaling │ ├── merged-dog.png │ ├── sampled-bear-x4.png │ └── snow-leopard-x4.png ├── checkpoints └── checkpoints.txt ├── configs ├── karlo │ ├── decoder_900M_vit_l.yaml │ ├── improved_sr_64_256_1.4B.yaml │ └── prior_1B_vit_l.yaml └── stable-diffusion │ ├── intel │ ├── v2-inference-bf16.yaml │ ├── v2-inference-fp32.yaml │ ├── v2-inference-v-bf16.yaml │ └── v2-inference-v-fp32.yaml │ ├── v2-1-stable-unclip-h-inference.yaml │ ├── v2-1-stable-unclip-l-inference.yaml │ ├── v2-inference-v.yaml │ ├── v2-inference.yaml │ ├── v2-inpainting-inference.yaml │ ├── v2-midas-inference.yaml │ └── x4-upscaling.yaml ├── doc └── UNCLIP.MD ├── environment.yaml ├── ldm ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── karlo │ │ ├── __init__.py │ │ ├── diffusers_pipeline.py │ │ └── kakao │ │ │ ├── __init__.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── clip.py │ │ │ ├── decoder_model.py │ │ │ ├── prior_model.py │ │ │ ├── sr_256_1k.py │ │ │ └── sr_64_256.py │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── diffusion │ │ │ │ ├── gaussian_diffusion.py │ │ │ │ └── respace.py │ │ │ ├── nn.py │ │ │ ├── resample.py │ │ │ ├── unet.py │ │ │ └── xf.py │ │ │ ├── sampler.py │ │ │ └── template.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── modelcard.md ├── requirements.txt ├── scripts ├── gradio │ ├── depth2img.py │ ├── inpainting.py │ └── superresolution.py ├── img2img.py ├── streamlit │ ├── depth2img.py │ ├── inpainting.py │ ├── stableunclip.py │ └── superresolution.py ├── tests │ └── test_watermark.py └── txt2img.py └── setup.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Generated by project 2 | outputs/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # General MacOS 13 | .DS_Store 14 | .AppleDouble 15 | .LSOverride 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # IDEs 164 | .idea/ 165 | .vscode/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/assets/ : binary image files (full list in the directory tree above); each file is served from
https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ followed by its repository path (e.g. assets/rick.jpeg, assets/stable-samples/txt2img/merged-0001.png)
--------------------------------------------------------------------------------
/checkpoints/checkpoints.txt:
--------------------------------------------------------------------------------
1 | Put unCLIP checkpoints here.
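`checkpoints.txt` only marks where the weights belong. As a hypothetical illustration (the Hugging Face repository and the `sd21-unclip-l.ckpt` file name are taken from `doc/UNCLIP.MD` further down; substitute whichever checkpoint you actually need), fetching an unCLIP checkpoint into this folder could look like:

```python
# Hypothetical helper: download an unCLIP checkpoint into checkpoints/.
# The URL follows the Hugging Face "resolve" pattern for the repo referenced in doc/UNCLIP.MD.
import pathlib
import requests

URL = ("https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/"
       "resolve/main/sd21-unclip-l.ckpt")

target = pathlib.Path("checkpoints") / URL.rsplit("/", 1)[-1]
target.parent.mkdir(parents=True, exist_ok=True)

with requests.get(URL, stream=True, timeout=60) as r:
    r.raise_for_status()
    with open(target, "wb") as f:
        for chunk in r.iter_content(chunk_size=1 << 20):  # stream in 1 MiB chunks
            f.write(chunk)
print(f"saved {target}")
```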
-------------------------------------------------------------------------------- /configs/karlo/decoder_900M_vit_l.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: t2i-decoder 3 | diffusion_sampler: uniform 4 | hparams: 5 | image_size: 64 6 | num_channels: 320 7 | num_res_blocks: 3 8 | channel_mult: '' 9 | attention_resolutions: 32,16,8 10 | num_heads: -1 11 | num_head_channels: 64 12 | num_heads_upsample: -1 13 | use_scale_shift_norm: true 14 | dropout: 0.1 15 | clip_dim: 768 16 | clip_emb_mult: 4 17 | text_ctx: 77 18 | xf_width: 1536 19 | xf_layers: 0 20 | xf_heads: 0 21 | xf_final_ln: false 22 | resblock_updown: true 23 | learn_sigma: true 24 | text_drop: 0.3 25 | clip_emb_type: image 26 | clip_emb_drop: 0.1 27 | use_plm: true 28 | 29 | diffusion: 30 | steps: 1000 31 | learn_sigma: true 32 | sigma_small: false 33 | noise_schedule: squaredcos_cap_v2 34 | use_kl: false 35 | predict_xstart: false 36 | rescale_learned_sigmas: true 37 | timestep_respacing: '' 38 | -------------------------------------------------------------------------------- /configs/karlo/improved_sr_64_256_1.4B.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: improved_sr_64_256 3 | diffusion_sampler: uniform 4 | hparams: 5 | channels: 320 6 | depth: 3 7 | channels_multiple: 8 | - 1 9 | - 2 10 | - 3 11 | - 4 12 | dropout: 0.0 13 | 14 | diffusion: 15 | steps: 1000 16 | learn_sigma: false 17 | sigma_small: true 18 | noise_schedule: squaredcos_cap_v2 19 | use_kl: false 20 | predict_xstart: false 21 | rescale_learned_sigmas: true 22 | timestep_respacing: '7' 23 | 24 | 25 | sampling: 26 | timestep_respacing: '7' # fix 27 | clip_denoise: true 28 | -------------------------------------------------------------------------------- /configs/karlo/prior_1B_vit_l.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: prior 3 | diffusion_sampler: uniform 4 | hparams: 5 | text_ctx: 77 6 | xf_width: 2048 7 | xf_layers: 20 8 | xf_heads: 32 9 | xf_final_ln: true 10 | text_drop: 0.2 11 | clip_dim: 768 12 | 13 | diffusion: 14 | steps: 1000 15 | learn_sigma: false 16 | sigma_small: true 17 | noise_schedule: squaredcos_cap_v2 18 | use_kl: false 19 | predict_xstart: true 20 | rescale_learned_sigmas: false 21 | timestep_respacing: '' 22 | -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-bf16.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: "jpg" 14 | cond_stage_key: "txt" 15 | image_size: 64 16 | channels: 4 17 | cond_stage_trainable: false 18 | conditioning_key: crossattn 19 | monitor: val/loss_simple_ema 20 | scale_factor: 0.18215 21 | use_ema: False # we set this to false because this is an inference only config 22 | 23 | unet_config: 24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 25 | params: 26 | use_checkpoint: False 27 | use_fp16: False 28 | use_bf16: True 29 | image_size: 32 # unused 30 | in_channels: 4 31 | out_channels: 4 32 | model_channels: 320 33 | attention_resolutions: [ 4, 2, 1 ] 34 | num_res_blocks: 2 35 | 
channel_mult: [ 1, 2, 4, 4 ] 36 | num_head_channels: 64 # need to fix for flash-attn 37 | use_spatial_transformer: True 38 | use_linear_in_transformer: True 39 | transformer_depth: 1 40 | context_dim: 1024 41 | legacy: False 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | #attn_type: "vanilla-xformers" 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 69 | params: 70 | freeze: True 71 | layer: "penultimate" 72 | -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-fp32.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: "jpg" 14 | cond_stage_key: "txt" 15 | image_size: 64 16 | channels: 4 17 | cond_stage_trainable: false 18 | conditioning_key: crossattn 19 | monitor: val/loss_simple_ema 20 | scale_factor: 0.18215 21 | use_ema: False # we set this to false because this is an inference only config 22 | 23 | unet_config: 24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 25 | params: 26 | use_checkpoint: False 27 | use_fp16: False 28 | image_size: 32 # unused 29 | in_channels: 4 30 | out_channels: 4 31 | model_channels: 320 32 | attention_resolutions: [ 4, 2, 1 ] 33 | num_res_blocks: 2 34 | channel_mult: [ 1, 2, 4, 4 ] 35 | num_head_channels: 64 # need to fix for flash-attn 36 | use_spatial_transformer: True 37 | use_linear_in_transformer: True 38 | transformer_depth: 1 39 | context_dim: 1024 40 | legacy: False 41 | 42 | first_stage_config: 43 | target: ldm.models.autoencoder.AutoencoderKL 44 | params: 45 | embed_dim: 4 46 | monitor: val/rec_loss 47 | ddconfig: 48 | #attn_type: "vanilla-xformers" 49 | double_z: true 50 | z_channels: 4 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 2 58 | - 4 59 | - 4 60 | num_res_blocks: 2 61 | attn_resolutions: [] 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | 66 | cond_stage_config: 67 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 68 | params: 69 | freeze: True 70 | layer: "penultimate" 71 | -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-v-bf16.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | parameterization: "v" 9 | linear_start: 0.00085 10 | linear_end: 0.0120 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: "jpg" 15 | cond_stage_key: "txt" 16 | image_size: 64 17 | channels: 4 18 | cond_stage_trainable: false 19 | conditioning_key: crossattn 20 | monitor: 
val/loss_simple_ema 21 | scale_factor: 0.18215 22 | use_ema: False # we set this to false because this is an inference only config 23 | 24 | unet_config: 25 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 26 | params: 27 | use_checkpoint: False 28 | use_fp16: False 29 | use_bf16: True 30 | image_size: 32 # unused 31 | in_channels: 4 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | #attn_type: "vanilla-xformers" 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 70 | params: 71 | freeze: True 72 | layer: "penultimate" 73 | -------------------------------------------------------------------------------- /configs/stable-diffusion/intel/v2-inference-v-fp32.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022 Intel Corporation 2 | # SPDX-License-Identifier: MIT 3 | 4 | model: 5 | base_learning_rate: 1.0e-4 6 | target: ldm.models.diffusion.ddpm.LatentDiffusion 7 | params: 8 | parameterization: "v" 9 | linear_start: 0.00085 10 | linear_end: 0.0120 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: "jpg" 15 | cond_stage_key: "txt" 16 | image_size: 64 17 | channels: 4 18 | cond_stage_trainable: false 19 | conditioning_key: crossattn 20 | monitor: val/loss_simple_ema 21 | scale_factor: 0.18215 22 | use_ema: False # we set this to false because this is an inference only config 23 | 24 | unet_config: 25 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 26 | params: 27 | use_checkpoint: False 28 | use_fp16: False 29 | image_size: 32 # unused 30 | in_channels: 4 31 | out_channels: 4 32 | model_channels: 320 33 | attention_resolutions: [ 4, 2, 1 ] 34 | num_res_blocks: 2 35 | channel_mult: [ 1, 2, 4, 4 ] 36 | num_head_channels: 64 # need to fix for flash-attn 37 | use_spatial_transformer: True 38 | use_linear_in_transformer: True 39 | transformer_depth: 1 40 | context_dim: 1024 41 | legacy: False 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | #attn_type: "vanilla-xformers" 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 69 | params: 70 | freeze: True 71 | layer: "penultimate" 72 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 
2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion 4 | params: 5 | embedding_dropout: 0.25 6 | parameterization: "v" 7 | linear_start: 0.00085 8 | linear_end: 0.0120 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "jpg" 12 | cond_stage_key: "txt" 13 | image_size: 96 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn-adm 17 | scale_factor: 0.18215 18 | monitor: val/loss_simple_ema 19 | use_ema: False 20 | 21 | embedder_config: 22 | target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 23 | 24 | noise_aug_config: 25 | target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation 26 | params: 27 | timestep_dim: 1024 28 | noise_schedule_config: 29 | timesteps: 1000 30 | beta_schedule: squaredcos_cap_v2 31 | 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | num_classes: "sequential" 36 | adm_in_channels: 2048 37 | use_checkpoint: True 38 | image_size: 32 # unused 39 | in_channels: 4 40 | out_channels: 4 41 | model_channels: 320 42 | attention_resolutions: [ 4, 2, 1 ] 43 | num_res_blocks: 2 44 | channel_mult: [ 1, 2, 4, 4 ] 45 | num_head_channels: 64 # need to fix for flash-attn 46 | use_spatial_transformer: True 47 | use_linear_in_transformer: True 48 | transformer_depth: 1 49 | context_dim: 1024 50 | legacy: False 51 | 52 | first_stage_config: 53 | target: ldm.models.autoencoder.AutoencoderKL 54 | params: 55 | embed_dim: 4 56 | monitor: val/rec_loss 57 | ddconfig: 58 | attn_type: "vanilla-xformers" 59 | double_z: true 60 | z_channels: 4 61 | resolution: 256 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [ ] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | 76 | cond_stage_config: 77 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 78 | params: 79 | freeze: True 80 | layer: "penultimate" 81 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion 4 | params: 5 | embedding_dropout: 0.25 6 | parameterization: "v" 7 | linear_start: 0.00085 8 | linear_end: 0.0120 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "jpg" 12 | cond_stage_key: "txt" 13 | image_size: 96 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn-adm 17 | scale_factor: 0.18215 18 | monitor: val/loss_simple_ema 19 | use_ema: False 20 | 21 | embedder_config: 22 | target: ldm.modules.encoders.modules.ClipImageEmbedder 23 | params: 24 | model: "ViT-L/14" 25 | 26 | noise_aug_config: 27 | target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation 28 | params: 29 | clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th" 30 | timestep_dim: 768 31 | noise_schedule_config: 32 | timesteps: 1000 33 | beta_schedule: squaredcos_cap_v2 34 | 35 | unet_config: 36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 37 | params: 38 | num_classes: "sequential" 39 | adm_in_channels: 1536 40 | use_checkpoint: True 41 | image_size: 32 # unused 42 | in_channels: 4 43 | out_channels: 4 44 | model_channels: 320 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 
4, 4 ] 48 | num_head_channels: 64 # need to fix for flash-attn 49 | use_spatial_transformer: True 50 | use_linear_in_transformer: True 51 | transformer_depth: 1 52 | context_dim: 1024 53 | legacy: False 54 | 55 | first_stage_config: 56 | target: ldm.models.autoencoder.AutoencoderKL 57 | params: 58 | embed_dim: 4 59 | monitor: val/rec_loss 60 | ddconfig: 61 | attn_type: "vanilla-xformers" 62 | double_z: true 63 | z_channels: 4 64 | resolution: 256 65 | in_channels: 3 66 | out_ch: 3 67 | ch: 128 68 | ch_mult: 69 | - 1 70 | - 2 71 | - 4 72 | - 4 73 | num_res_blocks: 2 74 | attn_resolutions: [ ] 75 | dropout: 0.0 76 | lossconfig: 77 | target: torch.nn.Identity 78 | 79 | cond_stage_config: 80 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 81 | params: 82 | freeze: True 83 | layer: "penultimate" -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference-v.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "jpg" 12 | cond_stage_key: "txt" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | use_ema: False # we set this to false because this is an inference only config 20 | 21 | unet_config: 22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 23 | params: 24 | use_checkpoint: True 25 | use_fp16: True 26 | image_size: 32 # unused 27 | in_channels: 4 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: [ 4, 2, 1 ] 31 | num_res_blocks: 2 32 | channel_mult: [ 1, 2, 4, 4 ] 33 | num_head_channels: 64 # need to fix for flash-attn 34 | use_spatial_transformer: True 35 | use_linear_in_transformer: True 36 | transformer_depth: 1 37 | context_dim: 1024 38 | legacy: False 39 | 40 | first_stage_config: 41 | target: ldm.models.autoencoder.AutoencoderKL 42 | params: 43 | embed_dim: 4 44 | monitor: val/rec_loss 45 | ddconfig: 46 | #attn_type: "vanilla-xformers" 47 | double_z: true 48 | z_channels: 4 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: [] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 66 | params: 67 | freeze: True 68 | layer: "penultimate" 69 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False # we set this to false because this is an inference only config 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 
| params: 23 | use_checkpoint: True 24 | use_fp16: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | scale_factor: 0.18215 17 | monitor: val/loss_simple_ema 18 | finetune_keys: null 19 | use_ema: False 20 | 21 | unet_config: 22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 23 | params: 24 | use_checkpoint: True 25 | image_size: 32 # unused 26 | in_channels: 9 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | 69 | 70 | data: 71 | target: ldm.data.laion.WebDataModuleFromConfig 72 | params: 73 | tar_base: null # for concat as in LAION-A 74 | p_unsafe_threshold: 0.1 75 | filter_word_list: "data/filters.yaml" 76 | max_pwatermark: 0.45 77 | batch_size: 8 78 | num_workers: 6 79 | multinode: True 80 | min_size: 512 81 | train: 82 | shards: 83 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" 84 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" 85 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" 86 | - "pipe:aws s3 cp 
s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" 87 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" 88 | shuffle: 10000 89 | image_key: jpg 90 | image_transforms: 91 | - target: torchvision.transforms.Resize 92 | params: 93 | size: 512 94 | interpolation: 3 95 | - target: torchvision.transforms.RandomCrop 96 | params: 97 | size: 512 98 | postprocess: 99 | target: ldm.data.laion.AddMask 100 | params: 101 | mode: "512train-large" 102 | p_drop: 0.25 103 | # NOTE use enough shards to avoid empty validation loops in workers 104 | validation: 105 | shards: 106 | - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " 107 | shuffle: 0 108 | image_key: jpg 109 | image_transforms: 110 | - target: torchvision.transforms.Resize 111 | params: 112 | size: 512 113 | interpolation: 3 114 | - target: torchvision.transforms.CenterCrop 115 | params: 116 | size: 512 117 | postprocess: 118 | target: ldm.data.laion.AddMask 119 | params: 120 | mode: "512train-large" 121 | p_drop: 0.25 122 | 123 | lightning: 124 | find_unused_parameters: True 125 | modelcheckpoint: 126 | params: 127 | every_n_train_steps: 5000 128 | 129 | callbacks: 130 | metrics_over_trainsteps_checkpoint: 131 | params: 132 | every_n_train_steps: 10000 133 | 134 | image_logger: 135 | target: main.ImageLogger 136 | params: 137 | enable_autocast: False 138 | disabled: False 139 | batch_frequency: 1000 140 | max_images: 4 141 | increase_log_steps: False 142 | log_first_step: False 143 | log_images_kwargs: 144 | use_ema_scope: False 145 | inpaint: False 146 | plot_progressive_rows: False 147 | plot_diffusion_rows: False 148 | N: 4 149 | unconditional_guidance_scale: 5.0 150 | unconditional_guidance_label: [""] 151 | ddim_steps: 50 # todo check these out for depth2img, 152 | ddim_eta: 0.0 # todo check these out for depth2img, 153 | 154 | trainer: 155 | benchmark: True 156 | val_check_interval: 5000000 157 | num_sanity_val_steps: 0 158 | accumulate_grad_batches: 1 159 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-midas-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-07 3 | target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | scale_factor: 0.18215 17 | monitor: val/loss_simple_ema 18 | finetune_keys: null 19 | use_ema: False 20 | 21 | depth_stage_config: 22 | target: ldm.modules.midas.api.MiDaSInference 23 | params: 24 | model_type: "dpt_hybrid" 25 | 26 | unet_config: 27 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 28 | params: 29 | use_checkpoint: True 30 | image_size: 32 # unused 31 | in_channels: 5 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | 
#attn_type: "vanilla-xformers" 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [ ] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 70 | params: 71 | freeze: True 72 | layer: "penultimate" 73 | 74 | 75 | -------------------------------------------------------------------------------- /configs/stable-diffusion/x4-upscaling.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion 4 | params: 5 | parameterization: "v" 6 | low_scale_key: "lr" 7 | linear_start: 0.0001 8 | linear_end: 0.02 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: "jpg" 13 | cond_stage_key: "txt" 14 | image_size: 128 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: "hybrid-adm" 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.08333 20 | use_ema: False 21 | 22 | low_scale_config: 23 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation 24 | params: 25 | noise_schedule_config: # image space 26 | linear_start: 0.0001 27 | linear_end: 0.02 28 | max_noise_level: 350 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | use_checkpoint: True 34 | num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) 35 | image_size: 128 36 | in_channels: 7 37 | out_channels: 4 38 | model_channels: 256 39 | attention_resolutions: [ 2,4,8] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 2, 4] 42 | disable_self_attentions: [True, True, True, False] 43 | disable_middle_self_attn: False 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 1024 48 | legacy: False 49 | use_linear_in_transformer: True 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | ddconfig: 56 | # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) 57 | double_z: True 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 64 | num_res_blocks: 2 65 | attn_resolutions: [ ] 66 | dropout: 0.0 67 | 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 73 | params: 74 | freeze: True 75 | layer: "penultimate" 76 | 77 | -------------------------------------------------------------------------------- /doc/UNCLIP.MD: -------------------------------------------------------------------------------- 1 | ### Stable unCLIP 2 | 3 | [unCLIP](https://openai.com/dall-e-2/) is the approach behind OpenAI's [DALL·E 2](https://openai.com/dall-e-2/), 4 | trained to invert CLIP image embeddings. 5 | We finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings. 6 | This means that the model can be used to produce image variations, but can also be combined with a text-to-image 7 | embedding prior to yield a full text-to-image model at 768x768 resolution. 

If you would like to try a demo of this model on the web, please visit https://clipdrop.co/stable-diffusion-reimagine

We provide two models, trained on OpenAI CLIP-L and OpenCLIP-H image embeddings, respectively,
available from [https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/tree/main).
To use them, download the weights from Hugging Face and put them into the `checkpoints` folder.

#### Image Variations
![image-variations-l-1](../assets/stable-samples/stable-unclip/unclip-variations.png)

Diffusers integration
Stable UnCLIP Image Variations is integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library
```python
#pip install git+https://github.com/huggingface/diffusers.git transformers accelerate
import requests
import torch
from PIL import Image
from io import BytesIO

from diffusers import StableUnCLIPImg2ImgPipeline

#Start the StableUnCLIP Image variations pipeline
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")

#Get image from URL
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")

#Pipe to make the variation
images = pipe(init_image).images
images[0].save("tarsila_variation.png")
```
Check out the [Stable UnCLIP pipeline docs here](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_unclip)

Streamlit UI demo

Run
```
streamlit run scripts/streamlit/stableunclip.py
```
to launch a Streamlit script that can be used to make image variations with both models (CLIP-L and OpenCLIP-H).
These models can process a `noise_level`, which specifies an amount of Gaussian noise added to the CLIP embeddings.
This can be used to increase output variance, as in the following examples.

![image-variations-noise](../assets/stable-samples/stable-unclip/unclip-variations_noise.png)
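The same `noise_level` control is not limited to the Streamlit UI. The sketch below reuses the diffusers pipeline from the example above; it assumes that recent diffusers releases expose a `noise_level` argument on `StableUnCLIPImg2ImgPipeline` and that the `fp16` variant and the `diffusers.utils.load_image` helper are available (check the pipeline docs linked above before relying on the exact signature).

```python
# Minimal sketch (not part of the original doc): passing noise_level through diffusers.
# Assumption: StableUnCLIPImg2ImgPipeline accepts a `noise_level` kwarg that controls how
# much Gaussian noise is added to the CLIP image embedding before sampling.
import torch
from diffusers import StableUnCLIPImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
)

# Larger noise_level -> noisier embedding -> more diverse, less faithful variations.
for level in (0, 250, 500):
    pipe(init_image, noise_level=level).images[0].save(f"variation_noise{level}.png")
```

With `noise_level=0` this reproduces the plain image-variation behaviour; larger values trade fidelity to the input image for output diversity.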

### Stable Diffusion Meets Karlo
![panda](../assets/stable-samples/stable-unclip/panda.jpg)

Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125).
We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior and Stable Diffusion v2.1-768.

To run the model, first download the Karlo checkpoints:
```shell
mkdir -p checkpoints/karlo_models
cd checkpoints/karlo_models
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
cd ../../
```
and the finetuned SD2.1 unCLIP-L checkpoint from [here](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/blob/main/sd21-unclip-l.ckpt), and put the ckpt into the `checkpoints` folder.

Then, run

```
streamlit run scripts/streamlit/stableunclip.py
```
and pick the `use_karlo` option in the GUI.
The script optionally supports sampling from the full Karlo model. To use it, download the 64x64 decoder and 64->256 upscaler
via
```shell
cd checkpoints/karlo_models
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
cd ../../
```

-------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - albumentations==1.3.0 14 | - opencv-python==4.6.0.66 15 | - imageio==2.9.0 16 | - imageio-ffmpeg==0.4.2 17 | - pytorch-lightning==1.4.2 18 | - omegaconf==2.1.1 19 | - test-tube>=0.7.5 20 | - streamlit==1.12.1 21 | - einops==0.3.0 22 | - transformers==4.19.2 23 | - webdataset==0.2.5 24 | - kornia==0.6 25 | - open_clip_torch==2.0.2 26 | - invisible-watermark>=0.1.5 27 | - streamlit-drawable-canvas==0.8.0 28 | - torchmetrics==0.6.0 29 | - -e . 30 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1.
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 
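# keep an exponential moving average (EMA) of the weights via LitEma (ldm/modules/ema.py); ema_scope() below swaps the EMA weights in temporarily, e.g. for validation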
46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, 
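# the positional argument 1 below is optimizer_idx, selecting the discriminator term of the loss (cf. training_step above)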
reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
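# the random projection above is min-max rescaled to [-1, 1] before being returned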
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | MODEL_TYPES = { 7 | "eps": "noise", 8 | "v": "v" 9 | } 10 | 11 | 12 | class DPMSolverSampler(object): 13 | def __init__(self, model, device=torch.device("cuda"), **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.device = device 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != self.device: 23 | attr = attr.to(self.device) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
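# together with unconditional_guidance_scale, this implements classifier-free guidance; both are forwarded to model_wrapper() in the body below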
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | ctmp = conditioning[list(conditioning.keys())[0]] 54 | while isinstance(ctmp, list): ctmp = ctmp[0] 55 | if isinstance(ctmp, torch.Tensor): 56 | cbs = ctmp.shape[0] 57 | if cbs != batch_size: 58 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 59 | elif isinstance(conditioning, list): 60 | for ctmp in conditioning: 61 | if ctmp.shape[0] != batch_size: 62 | print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") 63 | else: 64 | if isinstance(conditioning, torch.Tensor): 65 | if conditioning.shape[0] != batch_size: 66 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 67 | 68 | # sampling 69 | C, H, W = shape 70 | size = (batch_size, C, H, W) 71 | 72 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 73 | 74 | device = self.model.betas.device 75 | if x_T is None: 76 | img = torch.randn(size, device=device) 77 | else: 78 | img = x_T 79 | 80 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 81 | 82 | model_fn = model_wrapper( 83 | lambda x, t, c: self.model.apply_model(x, t, c), 84 | ns, 85 | model_type=MODEL_TYPES[self.model.parameterization], 86 | guidance_type="classifier-free", 87 | condition=conditioning, 88 | unconditional_condition=unconditional_conditioning, 89 | guidance_scale=unconditional_guidance_scale, 90 | ) 91 | 92 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 93 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, 94 | lower_order_final=True) 95 | 96 | return x.to(device), None 97 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
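For example, append_dims(x, 4) turns a tensor of shape (B,) into one of shape (B, 1, 1, 1).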
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * 
torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 
65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/karlo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/karlo/__init__.py -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/karlo/kakao/__init__.py -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/karlo/kakao/models/__init__.py -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/clip.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | # ------------------------------------------------------------------------------------ 6 | # Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/) 7 | # ------------------------------------------------------------------------------------ 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import clip 14 | 15 | from clip.model import CLIP, convert_weights 16 | from clip.simple_tokenizer import SimpleTokenizer, default_bpe 17 | 18 | 19 | """===== Monkey-Patching original CLIP for JIT compile =====""" 20 | 21 | 22 | class LayerNorm(nn.LayerNorm): 23 | """Subclass torch's LayerNorm to handle fp16.""" 24 | 25 | def forward(self, x: torch.Tensor): 26 | orig_type = x.dtype 27 | ret = F.layer_norm( 28 | x.type(torch.float32), 29 | self.normalized_shape, 30 | self.weight, 31 | self.bias, 32 | self.eps, 33 | ) 34 | return ret.type(orig_type) 35 | 36 | 37 | clip.model.LayerNorm = LayerNorm 38 | delattr(clip.model.CLIP, "forward") 39 | 40 | """===== End of Monkey-Patching =====""" 41 | 42 | 43 | class CustomizedCLIP(CLIP): 44 | def __init__(self, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | 47 | @torch.jit.export 48 | def encode_image(self, image): 49 | return self.visual(image) 50 | 51 | @torch.jit.export 52 | def encode_text(self, text): 53 | # re-define this function to return unpooled text features 54 | 55 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 56 | 57 | x = x + self.positional_embedding.type(self.dtype) 58 | x = x.permute(1, 0, 2) # NLD -> LND 59 | x = self.transformer(x) 60 | x = x.permute(1, 0, 2) # LND -> NLD 61 | x = self.ln_final(x).type(self.dtype) 62 | 63 | x_seq = x 64 | # x.shape = [batch_size, n_ctx, transformer.width] 65 | # take features from the eot embedding (eot_token is the highest number in each sequence) 66 | x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 67 | 68 | return x_out, x_seq 69 | 70 | @torch.jit.ignore 71 | def forward(self, image, text): 72 | super().forward(image, text) 73 | 74 | @classmethod 75 | def load_from_checkpoint(cls, ckpt_path: str): 76 | state_dict = torch.load(ckpt_path, map_location="cpu").state_dict() 77 | 78 | vit = "visual.proj" in state_dict 79 | if vit: 80 | vision_width = state_dict["visual.conv1.weight"].shape[0] 81 | vision_layers = len( 82 | [ 83 | k 84 | for k in state_dict.keys() 85 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") 86 | ] 87 | ) 88 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 89 | grid_size = round( 90 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 91 | ) 92 | image_resolution = vision_patch_size * grid_size 93 | else: 94 | counts: list = [ 95 | len( 96 | set( 97 | k.split(".")[2] 98 | for k in state_dict 99 | if k.startswith(f"visual.layer{b}") 100 | ) 101 | ) 102 | for b in [1, 2, 3, 4] 103 | ] 104 | vision_layers = tuple(counts) 105 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 106 | output_width = round( 107 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 108 | ) 109 | vision_patch_size = None 110 | assert ( 111 | output_width**2 + 1 112 | == state_dict["visual.attnpool.positional_embedding"].shape[0] 113 | ) 114 | image_resolution = output_width * 32 115 | 116 | embed_dim = state_dict["text_projection"].shape[1] 117 | context_length = state_dict["positional_embedding"].shape[0] 118 | 
vocab_size = state_dict["token_embedding.weight"].shape[0] 119 | transformer_width = state_dict["ln_final.weight"].shape[0] 120 | transformer_heads = transformer_width // 64 121 | transformer_layers = len( 122 | set( 123 | k.split(".")[2] 124 | for k in state_dict 125 | if k.startswith("transformer.resblocks") 126 | ) 127 | ) 128 | 129 | model = cls( 130 | embed_dim, 131 | image_resolution, 132 | vision_layers, 133 | vision_width, 134 | vision_patch_size, 135 | context_length, 136 | vocab_size, 137 | transformer_width, 138 | transformer_heads, 139 | transformer_layers, 140 | ) 141 | 142 | for key in ["input_resolution", "context_length", "vocab_size"]: 143 | if key in state_dict: 144 | del state_dict[key] 145 | 146 | convert_weights(model) 147 | model.load_state_dict(state_dict) 148 | model.eval() 149 | model.float() 150 | return model 151 | 152 | 153 | class CustomizedTokenizer(SimpleTokenizer): 154 | def __init__(self): 155 | super().__init__(bpe_path=default_bpe()) 156 | 157 | self.sot_token = self.encoder["<|startoftext|>"] 158 | self.eot_token = self.encoder["<|endoftext|>"] 159 | 160 | def padded_tokens_and_mask(self, texts, text_ctx): 161 | assert isinstance(texts, list) and all( 162 | isinstance(elem, str) for elem in texts 163 | ), "texts should be a list of strings" 164 | 165 | all_tokens = [ 166 | [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts 167 | ] 168 | 169 | mask = [ 170 | [True] * min(text_ctx, len(tokens)) 171 | + [False] * max(text_ctx - len(tokens), 0) 172 | for tokens in all_tokens 173 | ] 174 | mask = torch.tensor(mask, dtype=torch.bool) 175 | result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int) 176 | for i, tokens in enumerate(all_tokens): 177 | if len(tokens) > text_ctx: 178 | tokens = tokens[:text_ctx] 179 | tokens[-1] = self.eot_token 180 | result[i, : len(tokens)] = torch.tensor(tokens) 181 | 182 | return result, mask 183 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/decoder_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion 10 | from ldm.modules.karlo.kakao.modules.unet import PLMImUNet 11 | 12 | 13 | class Text2ImProgressiveModel(torch.nn.Module): 14 | """ 15 | A decoder that generates 64x64px images based on the text prompt. 16 | 17 | :param config: yaml config to define the decoder. 18 | :param tokenizer: tokenizer used in clip. 
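Samples are yielded progressively by forward(), one per reverse-diffusion step.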
19 | """ 20 | 21 | def __init__( 22 | self, 23 | config, 24 | tokenizer, 25 | ): 26 | super().__init__() 27 | 28 | self._conf = config 29 | self._model_conf = config.model.hparams 30 | self._diffusion_kwargs = dict( 31 | steps=config.diffusion.steps, 32 | learn_sigma=config.diffusion.learn_sigma, 33 | sigma_small=config.diffusion.sigma_small, 34 | noise_schedule=config.diffusion.noise_schedule, 35 | use_kl=config.diffusion.use_kl, 36 | predict_xstart=config.diffusion.predict_xstart, 37 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 38 | timestep_respacing=config.diffusion.timestep_respacing, 39 | ) 40 | self._tokenizer = tokenizer 41 | 42 | self.model = self.create_plm_dec_model() 43 | 44 | cf_token, cf_mask = self.set_cf_text_tensor() 45 | self.register_buffer("cf_token", cf_token, persistent=False) 46 | self.register_buffer("cf_mask", cf_mask, persistent=False) 47 | 48 | @classmethod 49 | def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True): 50 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 51 | 52 | model = cls(config, tokenizer) 53 | model.load_state_dict(ckpt, strict=strict) 54 | return model 55 | 56 | def create_plm_dec_model(self): 57 | image_size = self._model_conf.image_size 58 | if self._model_conf.channel_mult == "": 59 | if image_size == 256: 60 | channel_mult = (1, 1, 2, 2, 4, 4) 61 | elif image_size == 128: 62 | channel_mult = (1, 1, 2, 3, 4) 63 | elif image_size == 64: 64 | channel_mult = (1, 2, 3, 4) 65 | else: 66 | raise ValueError(f"unsupported image size: {image_size}") 67 | else: 68 | channel_mult = tuple( 69 | int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",") 70 | ) 71 | assert 2 ** (len(channel_mult) + 2) == image_size 72 | 73 | attention_ds = [] 74 | for res in self._model_conf.attention_resolutions.split(","): 75 | attention_ds.append(image_size // int(res)) 76 | 77 | return PLMImUNet( 78 | text_ctx=self._model_conf.text_ctx, 79 | xf_width=self._model_conf.xf_width, 80 | in_channels=3, 81 | model_channels=self._model_conf.num_channels, 82 | out_channels=6 if self._model_conf.learn_sigma else 3, 83 | num_res_blocks=self._model_conf.num_res_blocks, 84 | attention_resolutions=tuple(attention_ds), 85 | dropout=self._model_conf.dropout, 86 | channel_mult=channel_mult, 87 | num_heads=self._model_conf.num_heads, 88 | num_head_channels=self._model_conf.num_head_channels, 89 | num_heads_upsample=self._model_conf.num_heads_upsample, 90 | use_scale_shift_norm=self._model_conf.use_scale_shift_norm, 91 | resblock_updown=self._model_conf.resblock_updown, 92 | clip_dim=self._model_conf.clip_dim, 93 | clip_emb_mult=self._model_conf.clip_emb_mult, 94 | clip_emb_type=self._model_conf.clip_emb_type, 95 | clip_emb_drop=self._model_conf.clip_emb_drop, 96 | ) 97 | 98 | def set_cf_text_tensor(self): 99 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) 100 | 101 | def get_sample_fn(self, timestep_respacing): 102 | use_ddim = timestep_respacing.startswith(("ddim", "fast")) 103 | 104 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 105 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 106 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 107 | sample_fn = ( 108 | diffusion.ddim_sample_loop_progressive 109 | if use_ddim 110 | else diffusion.p_sample_loop_progressive 111 | ) 112 | 113 | return sample_fn 114 | 115 | def forward( 116 | self, 117 | txt_feat, 118 | txt_feat_seq, 119 | tok, 120 | mask, 121 | img_feat=None, 122 | cf_guidance_scales=None, 123 
| timestep_respacing=None, 124 | ): 125 | # cfg should be enabled in inference 126 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) 127 | assert img_feat is not None 128 | 129 | bsz = txt_feat.shape[0] 130 | img_sz = self._model_conf.image_size 131 | 132 | def guided_model_fn(x_t, ts, **kwargs): 133 | half = x_t[: len(x_t) // 2] 134 | combined = torch.cat([half, half], dim=0) 135 | model_out = self.model(combined, ts, **kwargs) 136 | eps, rest = model_out[:, :3], model_out[:, 3:] 137 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 138 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * ( 139 | cond_eps - uncond_eps 140 | ) 141 | eps = torch.cat([half_eps, half_eps], dim=0) 142 | return torch.cat([eps, rest], dim=1) 143 | 144 | cf_feat = self.model.cf_param.unsqueeze(0) 145 | cf_feat = cf_feat.expand(bsz // 2, -1) 146 | feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0) 147 | 148 | cond = { 149 | "y": feat, 150 | "txt_feat": txt_feat, 151 | "txt_feat_seq": txt_feat_seq, 152 | "mask": mask, 153 | } 154 | sample_fn = self.get_sample_fn(timestep_respacing) 155 | sample_outputs = sample_fn( 156 | guided_model_fn, 157 | (bsz, 3, img_sz, img_sz), 158 | noise=None, 159 | device=txt_feat.device, 160 | clip_denoised=True, 161 | model_kwargs=cond, 162 | ) 163 | 164 | for out in sample_outputs: 165 | sample = out["sample"] 166 | yield sample if cf_guidance_scales is None else sample[ 167 | : sample.shape[0] // 2 168 | ] 169 | 170 | 171 | class Text2ImModel(Text2ImProgressiveModel): 172 | def forward( 173 | self, 174 | txt_feat, 175 | txt_feat_seq, 176 | tok, 177 | mask, 178 | img_feat=None, 179 | cf_guidance_scales=None, 180 | timestep_respacing=None, 181 | ): 182 | last_out = None 183 | for out in super().forward( 184 | txt_feat, 185 | txt_feat_seq, 186 | tok, 187 | mask, 188 | img_feat, 189 | cf_guidance_scales, 190 | timestep_respacing, 191 | ): 192 | last_out = out 193 | return last_out 194 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/prior_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion 10 | from ldm.modules.karlo.kakao.modules.xf import PriorTransformer 11 | 12 | 13 | class PriorDiffusionModel(torch.nn.Module): 14 | """ 15 | A prior that generates clip image feature based on the text prompt. 16 | 17 | :param config: yaml config to define the decoder. 18 | :param tokenizer: tokenizer used in clip. 19 | :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance). 20 | :param clip_std: std to noramlize the clip image feature (zero-mean, unit variance). 
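The sampled image feature is de-normalized with clip_mean/clip_std before forward() returns it.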
21 | """ 22 | 23 | def __init__(self, config, tokenizer, clip_mean, clip_std): 24 | super().__init__() 25 | 26 | self._conf = config 27 | self._model_conf = config.model.hparams 28 | self._diffusion_kwargs = dict( 29 | steps=config.diffusion.steps, 30 | learn_sigma=config.diffusion.learn_sigma, 31 | sigma_small=config.diffusion.sigma_small, 32 | noise_schedule=config.diffusion.noise_schedule, 33 | use_kl=config.diffusion.use_kl, 34 | predict_xstart=config.diffusion.predict_xstart, 35 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 36 | timestep_respacing=config.diffusion.timestep_respacing, 37 | ) 38 | self._tokenizer = tokenizer 39 | 40 | self.register_buffer("clip_mean", clip_mean[None, :], persistent=False) 41 | self.register_buffer("clip_std", clip_std[None, :], persistent=False) 42 | 43 | causal_mask = self.get_causal_mask() 44 | self.register_buffer("causal_mask", causal_mask, persistent=False) 45 | 46 | self.model = PriorTransformer( 47 | text_ctx=self._model_conf.text_ctx, 48 | xf_width=self._model_conf.xf_width, 49 | xf_layers=self._model_conf.xf_layers, 50 | xf_heads=self._model_conf.xf_heads, 51 | xf_final_ln=self._model_conf.xf_final_ln, 52 | clip_dim=self._model_conf.clip_dim, 53 | ) 54 | 55 | cf_token, cf_mask = self.set_cf_text_tensor() 56 | self.register_buffer("cf_token", cf_token, persistent=False) 57 | self.register_buffer("cf_mask", cf_mask, persistent=False) 58 | 59 | @classmethod 60 | def load_from_checkpoint( 61 | cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True 62 | ): 63 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 64 | 65 | model = cls(config, tokenizer, clip_mean, clip_std) 66 | model.load_state_dict(ckpt, strict=strict) 67 | return model 68 | 69 | def set_cf_text_tensor(self): 70 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) 71 | 72 | def get_sample_fn(self, timestep_respacing): 73 | use_ddim = timestep_respacing.startswith(("ddim", "fast")) 74 | 75 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 76 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 77 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 78 | sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop 79 | 80 | return sample_fn 81 | 82 | def get_causal_mask(self): 83 | seq_len = self._model_conf.text_ctx + 4 84 | mask = torch.empty(seq_len, seq_len) 85 | mask.fill_(float("-inf")) 86 | mask.triu_(1) 87 | mask = mask[None, ...] 
88 | return mask 89 | 90 | def forward( 91 | self, 92 | txt_feat, 93 | txt_feat_seq, 94 | mask, 95 | cf_guidance_scales=None, 96 | timestep_respacing=None, 97 | denoised_fn=True, 98 | ): 99 | # cfg should be enabled in inference 100 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) 101 | 102 | bsz_ = txt_feat.shape[0] 103 | bsz = bsz_ // 2 104 | 105 | def guided_model_fn(x_t, ts, **kwargs): 106 | half = x_t[: len(x_t) // 2] 107 | combined = torch.cat([half, half], dim=0) 108 | model_out = self.model(combined, ts, **kwargs) 109 | eps, rest = ( 110 | model_out[:, : int(x_t.shape[1])], 111 | model_out[:, int(x_t.shape[1]) :], 112 | ) 113 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 114 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * ( 115 | cond_eps - uncond_eps 116 | ) 117 | eps = torch.cat([half_eps, half_eps], dim=0) 118 | return torch.cat([eps, rest], dim=1) 119 | 120 | cond = { 121 | "text_emb": txt_feat, 122 | "text_enc": txt_feat_seq, 123 | "mask": mask, 124 | "causal_mask": self.causal_mask, 125 | } 126 | sample_fn = self.get_sample_fn(timestep_respacing) 127 | sample = sample_fn( 128 | guided_model_fn, 129 | (bsz_, self.model.clip_dim), 130 | noise=None, 131 | device=txt_feat.device, 132 | clip_denoised=False, 133 | denoised_fn=lambda x: torch.clamp(x, -10, 10), 134 | model_kwargs=cond, 135 | ) 136 | sample = (sample * self.clip_std) + self.clip_mean 137 | 138 | return sample[:bsz] 139 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/sr_256_1k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive 7 | 8 | 9 | class SupRes256to1kProgressive(SupRes64to256Progressive): 10 | pass # no difference currently 11 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/models/sr_64_256.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel 10 | from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion 11 | 12 | 13 | class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module): 14 | """ 15 | ImprovedSR model fine-tunes the pretrained DDPM-based SR model by using adversarial and perceptual losses. 16 | In specific, the low-resolution sample is iteratively recovered by 6 steps with the frozen pretrained SR model. 17 | In the following additional one step, a seperate fine-tuned model recovers high-frequency details. 18 | This approach greatly improves the fidelity of images of 256x256px, even with small number of reverse steps. 
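The two stages correspond to model_first_steps and model_last_step below; forward() expects timestep_respacing="7" (6 + 1 steps).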
19 | """ 20 | 21 | def __init__(self, config): 22 | super().__init__() 23 | 24 | self._config = config 25 | self._diffusion_kwargs = dict( 26 | steps=config.diffusion.steps, 27 | learn_sigma=config.diffusion.learn_sigma, 28 | sigma_small=config.diffusion.sigma_small, 29 | noise_schedule=config.diffusion.noise_schedule, 30 | use_kl=config.diffusion.use_kl, 31 | predict_xstart=config.diffusion.predict_xstart, 32 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 33 | ) 34 | 35 | self.model_first_steps = SuperResUNetModel( 36 | in_channels=3, # auto-changed to 6 inside the model 37 | model_channels=config.model.hparams.channels, 38 | out_channels=3, 39 | num_res_blocks=config.model.hparams.depth, 40 | attention_resolutions=(), # no attention 41 | dropout=config.model.hparams.dropout, 42 | channel_mult=config.model.hparams.channels_multiple, 43 | resblock_updown=True, 44 | use_middle_attention=False, 45 | ) 46 | self.model_last_step = SuperResUNetModel( 47 | in_channels=3, # auto-changed to 6 inside the model 48 | model_channels=config.model.hparams.channels, 49 | out_channels=3, 50 | num_res_blocks=config.model.hparams.depth, 51 | attention_resolutions=(), # no attention 52 | dropout=config.model.hparams.dropout, 53 | channel_mult=config.model.hparams.channels_multiple, 54 | resblock_updown=True, 55 | use_middle_attention=False, 56 | ) 57 | 58 | @classmethod 59 | def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True): 60 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 61 | 62 | model = cls(config) 63 | model.load_state_dict(ckpt, strict=strict) 64 | return model 65 | 66 | def get_sample_fn(self, timestep_respacing): 67 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 68 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 69 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 70 | return diffusion.p_sample_loop_progressive_for_improved_sr 71 | 72 | def forward(self, low_res, timestep_respacing="7", **kwargs): 73 | assert ( 74 | timestep_respacing == "7" 75 | ), "different respacing method may work, but no guaranteed" 76 | 77 | sample_fn = self.get_sample_fn(timestep_respacing) 78 | sample_outputs = sample_fn( 79 | self.model_first_steps, 80 | self.model_last_step, 81 | shape=low_res.shape, 82 | clip_denoised=True, 83 | model_kwargs=dict(low_res=low_res), 84 | **kwargs, 85 | ) 86 | for x in sample_outputs: 87 | sample = x["sample"] 88 | yield sample 89 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | 6 | from .diffusion import gaussian_diffusion as gd 7 | from .diffusion.respace import ( 8 | SpacedDiffusion, 9 | space_timesteps, 10 | ) 11 | 12 | 13 | def create_gaussian_diffusion( 14 | steps, 15 | learn_sigma, 16 | sigma_small, 17 | noise_schedule, 18 | use_kl, 19 | predict_xstart, 20 | rescale_learned_sigmas, 21 | timestep_respacing, 22 | ): 23 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 24 | if use_kl: 25 | loss_type = gd.LossType.RESCALED_KL 26 | elif rescale_learned_sigmas: 27 | loss_type = gd.LossType.RESCALED_MSE 28 | else: 29 | loss_type = gd.LossType.MSE 30 | if not 
timestep_respacing: 31 | timestep_respacing = [steps] 32 | 33 | return SpacedDiffusion( 34 | use_timesteps=space_timesteps(steps, timestep_respacing), 35 | betas=betas, 36 | model_mean_type=( 37 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 38 | ), 39 | model_var_type=( 40 | ( 41 | gd.ModelVarType.FIXED_LARGE 42 | if not sigma_small 43 | else gd.ModelVarType.FIXED_SMALL 44 | ) 45 | if not learn_sigma 46 | else gd.ModelVarType.LEARNED_RANGE 47 | ), 48 | loss_type=loss_type, 49 | ) 50 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | 6 | import torch as th 7 | 8 | from .gaussian_diffusion import GaussianDiffusion 9 | 10 | 11 | def space_timesteps(num_timesteps, section_counts): 12 | """ 13 | Create a list of timesteps to use from an original diffusion process, 14 | given the number of timesteps we want to take from equally-sized portions 15 | of the original process. 16 | 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | 21 | :param num_timesteps: the number of diffusion steps in the original 22 | process to divide up. 23 | :param section_counts: either a list of numbers, or a string containing 24 | comma-separated numbers, indicating the step count 25 | per section. As a special case, use "ddimN" where N 26 | is a number of steps to use the striding from the 27 | DDIM paper. 28 | :return: a set of diffusion steps from the original process to use. 29 | """ 30 | if isinstance(section_counts, str): 31 | if section_counts.startswith("ddim"): 32 | desired_count = int(section_counts[len("ddim") :]) 33 | for i in range(1, num_timesteps): 34 | if len(range(0, num_timesteps, i)) == desired_count: 35 | return set(range(0, num_timesteps, i)) 36 | raise ValueError( 37 | f"cannot create exactly {num_timesteps} steps with an integer stride" 38 | ) 39 | elif section_counts == "fast27": 40 | steps = space_timesteps(num_timesteps, "10,10,3,2,2") 41 | # Help reduce DDIM artifacts from noisiest timesteps. 
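# i.e. drop the final (noisiest) timestep and use t = num_timesteps - 3 in its place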
42 | steps.remove(num_timesteps - 1) 43 | steps.add(num_timesteps - 3) 44 | return steps 45 | section_counts = [int(x) for x in section_counts.split(",")] 46 | size_per = num_timesteps // len(section_counts) 47 | extra = num_timesteps % len(section_counts) 48 | start_idx = 0 49 | all_steps = [] 50 | for i, section_count in enumerate(section_counts): 51 | size = size_per + (1 if i < extra else 0) 52 | if size < section_count: 53 | raise ValueError( 54 | f"cannot divide section of {size} steps into {section_count}" 55 | ) 56 | if section_count <= 1: 57 | frac_stride = 1 58 | else: 59 | frac_stride = (size - 1) / (section_count - 1) 60 | cur_idx = 0.0 61 | taken_steps = [] 62 | for _ in range(section_count): 63 | taken_steps.append(start_idx + round(cur_idx)) 64 | cur_idx += frac_stride 65 | all_steps += taken_steps 66 | start_idx += size 67 | return set(all_steps) 68 | 69 | 70 | class SpacedDiffusion(GaussianDiffusion): 71 | """ 72 | A diffusion process which can skip steps in a base diffusion process. 73 | 74 | :param use_timesteps: a collection (sequence or set) of timesteps from the 75 | original diffusion process to retain. 76 | :param kwargs: the kwargs to create the base diffusion process. 77 | """ 78 | 79 | def __init__(self, use_timesteps, **kwargs): 80 | self.use_timesteps = set(use_timesteps) 81 | self.original_num_steps = len(kwargs["betas"]) 82 | 83 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 84 | last_alpha_cumprod = 1.0 85 | new_betas = [] 86 | timestep_map = [] 87 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 88 | if i in self.use_timesteps: 89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 90 | last_alpha_cumprod = alpha_cumprod 91 | timestep_map.append(i) 92 | kwargs["betas"] = th.tensor(new_betas).numpy() 93 | super().__init__(**kwargs) 94 | self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False) 95 | 96 | def p_mean_variance(self, model, *args, **kwargs): 97 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | def wrapped(x, ts, **kwargs): 107 | ts_cpu = ts.detach().to("cpu") 108 | return model( 109 | x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs 110 | ) 111 | 112 | return wrapped 113 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/nn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class GroupNorm32(nn.GroupNorm): 13 | def __init__(self, num_groups, num_channels, swish, eps=1e-5): 14 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) 15 | self.swish = swish 16 | 17 | def forward(self, x): 18 | y = super().forward(x.float()).to(x.dtype) 19 | if self.swish == 1.0: 20 | y = F.silu(y) 21 | 
elif self.swish: 22 | y = y * F.sigmoid(y * float(self.swish)) 23 | return y 24 | 25 | 26 | def conv_nd(dims, *args, **kwargs): 27 | """ 28 | Create a 1D, 2D, or 3D convolution module. 29 | """ 30 | if dims == 1: 31 | return nn.Conv1d(*args, **kwargs) 32 | elif dims == 2: 33 | return nn.Conv2d(*args, **kwargs) 34 | elif dims == 3: 35 | return nn.Conv3d(*args, **kwargs) 36 | raise ValueError(f"unsupported dimensions: {dims}") 37 | 38 | 39 | def linear(*args, **kwargs): 40 | """ 41 | Create a linear module. 42 | """ 43 | return nn.Linear(*args, **kwargs) 44 | 45 | 46 | def avg_pool_nd(dims, *args, **kwargs): 47 | """ 48 | Create a 1D, 2D, or 3D average pooling module. 49 | """ 50 | if dims == 1: 51 | return nn.AvgPool1d(*args, **kwargs) 52 | elif dims == 2: 53 | return nn.AvgPool2d(*args, **kwargs) 54 | elif dims == 3: 55 | return nn.AvgPool3d(*args, **kwargs) 56 | raise ValueError(f"unsupported dimensions: {dims}") 57 | 58 | 59 | def zero_module(module): 60 | """ 61 | Zero out the parameters of a module and return it. 62 | """ 63 | for p in module.parameters(): 64 | p.detach().zero_() 65 | return module 66 | 67 | 68 | def scale_module(module, scale): 69 | """ 70 | Scale the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().mul_(scale) 74 | return module 75 | 76 | 77 | def normalization(channels, swish=0.0): 78 | """ 79 | Make a standard normalization layer, with an optional swish activation. 80 | 81 | :param channels: number of input channels. 82 | :return: an nn.Module for normalization. 83 | """ 84 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) 85 | 86 | 87 | def timestep_embedding(timesteps, dim, max_period=10000): 88 | """ 89 | Create sinusoidal timestep embeddings. 90 | 91 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 92 | These may be fractional. 93 | :param dim: the dimension of the output. 94 | :param max_period: controls the minimum frequency of the embeddings. 95 | :return: an [N x dim] Tensor of positional embeddings. 96 | """ 97 | half = dim // 2 98 | freqs = th.exp( 99 | -math.log(max_period) 100 | * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) 101 | / half 102 | ) 103 | args = timesteps[:, None].float() * freqs[None] 104 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 105 | if dim % 2: 106 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 107 | return embedding 108 | 109 | 110 | def mean_flat(tensor): 111 | """ 112 | Take the mean over all non-batch dimensions. 113 | """ 114 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 115 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/resample.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | from abc import abstractmethod 6 | 7 | import torch as th 8 | 9 | 10 | def create_named_schedule_sampler(name, diffusion): 11 | """ 12 | Create a ScheduleSampler from a library of pre-defined samplers. 13 | 14 | :param name: the name of the sampler. 15 | :param diffusion: the diffusion object to sample for. 
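Currently only the "uniform" sampler is implemented.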
16 | """ 17 | if name == "uniform": 18 | return UniformSampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(th.nn.Module): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / th.sum(w) 54 | indices = p.multinomial(batch_size, replacement=True) 55 | weights = 1 / (len(p) * p[indices]) 56 | return indices, weights 57 | 58 | 59 | class UniformSampler(ScheduleSampler): 60 | def __init__(self, diffusion): 61 | super(UniformSampler, self).__init__() 62 | self.diffusion = diffusion 63 | self.register_buffer( 64 | "_weights", th.ones([diffusion.num_timesteps]), persistent=False 65 | ) 66 | 67 | def weights(self): 68 | return self._weights 69 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/modules/xf.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from the repos below: 3 | # (a) Guided-Diffusion (https://github.com/openai/guided-diffusion) 4 | # (b) CLIP ViT (https://github.com/openai/CLIP/) 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import math 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .nn import timestep_embedding 14 | 15 | 16 | def convert_module_to_f16(param): 17 | """ 18 | Convert primitive modules to float16. 19 | """ 20 | if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 21 | param.weight.data = param.weight.data.half() 22 | if param.bias is not None: 23 | param.bias.data = param.bias.data.half() 24 | 25 | 26 | class LayerNorm(nn.LayerNorm): 27 | """ 28 | Implementation that supports fp16 inputs but fp32 gains/biases. 
29 | """ 30 | 31 | def forward(self, x: th.Tensor): 32 | return super().forward(x.float()).to(x.dtype) 33 | 34 | 35 | class MultiheadAttention(nn.Module): 36 | def __init__(self, n_ctx, width, heads): 37 | super().__init__() 38 | self.n_ctx = n_ctx 39 | self.width = width 40 | self.heads = heads 41 | self.c_qkv = nn.Linear(width, width * 3) 42 | self.c_proj = nn.Linear(width, width) 43 | self.attention = QKVMultiheadAttention(heads, n_ctx) 44 | 45 | def forward(self, x, mask=None): 46 | x = self.c_qkv(x) 47 | x = self.attention(x, mask=mask) 48 | x = self.c_proj(x) 49 | return x 50 | 51 | 52 | class MLP(nn.Module): 53 | def __init__(self, width): 54 | super().__init__() 55 | self.width = width 56 | self.c_fc = nn.Linear(width, width * 4) 57 | self.c_proj = nn.Linear(width * 4, width) 58 | self.gelu = nn.GELU() 59 | 60 | def forward(self, x): 61 | return self.c_proj(self.gelu(self.c_fc(x))) 62 | 63 | 64 | class QKVMultiheadAttention(nn.Module): 65 | def __init__(self, n_heads: int, n_ctx: int): 66 | super().__init__() 67 | self.n_heads = n_heads 68 | self.n_ctx = n_ctx 69 | 70 | def forward(self, qkv, mask=None): 71 | bs, n_ctx, width = qkv.shape 72 | attn_ch = width // self.n_heads // 3 73 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 74 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 75 | q, k, v = th.split(qkv, attn_ch, dim=-1) 76 | weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale) 77 | wdtype = weight.dtype 78 | if mask is not None: 79 | weight = weight + mask[:, None, ...] 80 | weight = th.softmax(weight, dim=-1).type(wdtype) 81 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 82 | 83 | 84 | class ResidualAttentionBlock(nn.Module): 85 | def __init__( 86 | self, 87 | n_ctx: int, 88 | width: int, 89 | heads: int, 90 | ): 91 | super().__init__() 92 | 93 | self.attn = MultiheadAttention( 94 | n_ctx, 95 | width, 96 | heads, 97 | ) 98 | self.ln_1 = LayerNorm(width) 99 | self.mlp = MLP(width) 100 | self.ln_2 = LayerNorm(width) 101 | 102 | def forward(self, x, mask=None): 103 | x = x + self.attn(self.ln_1(x), mask=mask) 104 | x = x + self.mlp(self.ln_2(x)) 105 | return x 106 | 107 | 108 | class Transformer(nn.Module): 109 | def __init__( 110 | self, 111 | n_ctx: int, 112 | width: int, 113 | layers: int, 114 | heads: int, 115 | ): 116 | super().__init__() 117 | self.n_ctx = n_ctx 118 | self.width = width 119 | self.layers = layers 120 | self.resblocks = nn.ModuleList( 121 | [ 122 | ResidualAttentionBlock( 123 | n_ctx, 124 | width, 125 | heads, 126 | ) 127 | for _ in range(layers) 128 | ] 129 | ) 130 | 131 | def forward(self, x, mask=None): 132 | for block in self.resblocks: 133 | x = block(x, mask=mask) 134 | return x 135 | 136 | 137 | class PriorTransformer(nn.Module): 138 | """ 139 | A Causal Transformer that conditions on CLIP text embedding, text. 140 | 141 | :param text_ctx: number of text tokens to expect. 142 | :param xf_width: width of the transformer. 143 | :param xf_layers: depth of the transformer. 144 | :param xf_heads: heads in the transformer. 145 | :param xf_final_ln: use a LayerNorm after the output layer. 146 | :param clip_dim: dimension of clip feature. 
147 | """ 148 | 149 | def __init__( 150 | self, 151 | text_ctx, 152 | xf_width, 153 | xf_layers, 154 | xf_heads, 155 | xf_final_ln, 156 | clip_dim, 157 | ): 158 | super().__init__() 159 | 160 | self.text_ctx = text_ctx 161 | self.xf_width = xf_width 162 | self.xf_layers = xf_layers 163 | self.xf_heads = xf_heads 164 | self.clip_dim = clip_dim 165 | self.ext_len = 4 166 | 167 | self.time_embed = nn.Sequential( 168 | nn.Linear(xf_width, xf_width), 169 | nn.SiLU(), 170 | nn.Linear(xf_width, xf_width), 171 | ) 172 | self.text_enc_proj = nn.Linear(clip_dim, xf_width) 173 | self.text_emb_proj = nn.Linear(clip_dim, xf_width) 174 | self.clip_img_proj = nn.Linear(clip_dim, xf_width) 175 | self.out_proj = nn.Linear(xf_width, clip_dim) 176 | self.transformer = Transformer( 177 | text_ctx + self.ext_len, 178 | xf_width, 179 | xf_layers, 180 | xf_heads, 181 | ) 182 | if xf_final_ln: 183 | self.final_ln = LayerNorm(xf_width) 184 | else: 185 | self.final_ln = None 186 | 187 | self.positional_embedding = nn.Parameter( 188 | th.empty(1, text_ctx + self.ext_len, xf_width) 189 | ) 190 | self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width))) 191 | 192 | nn.init.normal_(self.prd_emb, std=0.01) 193 | nn.init.normal_(self.positional_embedding, std=0.01) 194 | 195 | def forward( 196 | self, 197 | x, 198 | timesteps, 199 | text_emb=None, 200 | text_enc=None, 201 | mask=None, 202 | causal_mask=None, 203 | ): 204 | bsz = x.shape[0] 205 | mask = F.pad(mask, (0, self.ext_len), value=True) 206 | 207 | t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width)) 208 | text_enc = self.text_enc_proj(text_enc) 209 | text_emb = self.text_emb_proj(text_emb) 210 | x = self.clip_img_proj(x) 211 | 212 | input_seq = [ 213 | text_enc, 214 | text_emb[:, None, :], 215 | t_emb[:, None, :], 216 | x[:, None, :], 217 | self.prd_emb.to(x.dtype).expand(bsz, -1, -1), 218 | ] 219 | input = th.cat(input_seq, dim=1) 220 | input = input + self.positional_embedding.to(input.dtype) 221 | 222 | mask = th.where(mask, 0.0, float("-inf")) 223 | mask = (mask[:, None, :] + causal_mask).to(input.dtype) 224 | 225 | out = self.transformer(input, mask=mask) 226 | if self.final_ln is not None: 227 | out = self.final_ln(out) 228 | 229 | out = self.out_proj(out[:, -1]) 230 | 231 | return out 232 | -------------------------------------------------------------------------------- /ldm/modules/karlo/kakao/template.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import os 7 | import logging 8 | import torch 9 | 10 | from omegaconf import OmegaConf 11 | 12 | from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer 13 | from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel 14 | from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel 15 | from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel 16 | 17 | 18 | SAMPLING_CONF = { 19 | "default": { 20 | "prior_sm": "25", 21 | "prior_n_samples": 1, 22 | "prior_cf_scale": 4.0, 23 | "decoder_sm": "50", 24 | "decoder_cf_scale": 8.0, 25 | "sr_sm": "7", 26 | }, 27 | "fast": { 28 | "prior_sm": "25", 29 | "prior_n_samples": 1, 30 | "prior_cf_scale": 4.0, 31 | "decoder_sm": "25", 32 | "decoder_cf_scale": 8.0, 33 | "sr_sm": "7", 34 | }, 35 | } 36 | 37 | CKPT_PATH = { 38 | "prior": "prior-ckpt-step=01000000-of-01000000.ckpt", 39 | "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt", 40 | "sr_256": "improved-sr-ckpt-step=1.2M.ckpt", 41 | } 42 | 43 | 44 | class BaseSampler: 45 | _PRIOR_CLASS = PriorDiffusionModel 46 | _DECODER_CLASS = Text2ImProgressiveModel 47 | _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel 48 | 49 | def __init__( 50 | self, 51 | root_dir: str, 52 | sampling_type: str = "fast", 53 | ): 54 | self._root_dir = root_dir 55 | 56 | sampling_type = SAMPLING_CONF[sampling_type] 57 | self._prior_sm = sampling_type["prior_sm"] 58 | self._prior_n_samples = sampling_type["prior_n_samples"] 59 | self._prior_cf_scale = sampling_type["prior_cf_scale"] 60 | 61 | assert self._prior_n_samples == 1 62 | 63 | self._decoder_sm = sampling_type["decoder_sm"] 64 | self._decoder_cf_scale = sampling_type["decoder_cf_scale"] 65 | 66 | self._sr_sm = sampling_type["sr_sm"] 67 | 68 | def __repr__(self): 69 | line = "" 70 | line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n" 71 | line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n" 72 | line += f"SR(64->256), sampling method: {self._sr_sm}" 73 | 74 | return line 75 | 76 | def load_clip(self, clip_path: str): 77 | clip = CustomizedCLIP.load_from_checkpoint( 78 | os.path.join(self._root_dir, clip_path) 79 | ) 80 | clip = torch.jit.script(clip) 81 | clip.cuda() 82 | clip.eval() 83 | 84 | self._clip = clip 85 | self._tokenizer = CustomizedTokenizer() 86 | 87 | def load_prior( 88 | self, 89 | ckpt_path: str, 90 | clip_stat_path: str, 91 | prior_config: str = "configs/prior_1B_vit_l.yaml" 92 | ): 93 | logging.info(f"Loading prior: {ckpt_path}") 94 | 95 | config = OmegaConf.load(prior_config) 96 | clip_mean, clip_std = torch.load( 97 | os.path.join(self._root_dir, clip_stat_path), map_location="cpu" 98 | ) 99 | 100 | prior = self._PRIOR_CLASS.load_from_checkpoint( 101 | config, 102 | self._tokenizer, 103 | clip_mean, 104 | clip_std, 105 | os.path.join(self._root_dir, ckpt_path), 106 | strict=True, 107 | ) 108 | prior.cuda() 109 | prior.eval() 110 | logging.info("done.") 111 | 112 | self._prior = prior 113 | 114 | def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"): 115 | logging.info(f"Loading decoder: {ckpt_path}") 116 | 117 | config = OmegaConf.load(decoder_config) 118 | decoder = self._DECODER_CLASS.load_from_checkpoint( 119 | config, 120 | self._tokenizer, 121 | os.path.join(self._root_dir, ckpt_path), 122 | strict=True, 123 | ) 124 | decoder.cuda() 
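        # like the prior and SR model, the decoder is used for inference only,
        # so it is moved to the GPU and switched to eval mode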
125 | decoder.eval() 126 | logging.info("done.") 127 | 128 | self._decoder = decoder 129 | 130 | def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"): 131 | logging.info(f"Loading SR(64->256): {ckpt_path}") 132 | 133 | config = OmegaConf.load(sr_config) 134 | sr = self._SR256_CLASS.load_from_checkpoint( 135 | config, os.path.join(self._root_dir, ckpt_path), strict=True 136 | ) 137 | sr.cuda() 138 | sr.eval() 139 | logging.info("done.") 140 | 141 | self._sr_64_256 = sr -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, 
net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stablediffusion/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
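        Checkpoints that also contain optimizer state are unwrapped via their
        "model" key before being passed to load_state_dict.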
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | 
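# A minimal usage sketch (not part of the original file; the checkpoint path is the
# dpt_hybrid entry from ldm/modules/midas/api.py and must exist locally):
#
#   import torch
#   model = DPTDepthModel(
#       path="midas_models/dpt_hybrid-midas-501f0c75.pt",
#       backbone="vitb_rn50_384",
#       non_negative=True,
#   )
#   with torch.no_grad():
#       depth = model(torch.randn(1, 3, 384, 384))  # relative depth, shape (1, H, W)
#
# MiDaSInference in api.py wraps this model and interpolates the prediction back to
# the exact input resolution.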
-------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
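It defines MidasNet_small, a lighter variant built on an efficientnet_lite3
backbone with custom fusion blocks.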
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
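        Optionally switches the input to channels_last memory format, then runs
        the encoder, the four fusion blocks and the output convolution head.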
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
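    The actual output size depends on resize_method, keep_aspect_ratio and
    ensure_multiple_of (see get_size below), so it may differ from the requested
    width and height.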
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
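    Handles both colour ("PF") and greyscale ("Pf") headers; the data is returned
    flipped vertically, together with the scale value whose sign in the file
    encodes the endianness.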
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
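    The first batch element is squeezed, moved to the CPU and resized with
    bicubic interpolation.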
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def autocast(f): 12 | def do_autocast(*args, **kwargs): 13 | with torch.cuda.amp.autocast(enabled=True, 14 | dtype=torch.get_autocast_gpu_dtype(), 15 | cache_enabled=torch.is_autocast_cache_enabled()): 16 | return f(*args, **kwargs) 17 | 18 | return do_autocast 19 | 20 | 21 | def log_txt_as_img(wh, xc, size=10): 22 | # wh a tuple of (width, height) 23 | # xc a list of captions to plot 24 | b = len(xc) 25 | txts = list() 26 | for bi in range(b): 27 | txt = Image.new("RGB", wh, color="white") 28 | draw = ImageDraw.Draw(txt) 29 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 30 | nc = int(40 * (wh[0] / 256)) 31 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 32 | 33 | try: 34 | draw.text((0, 0), lines, fill="black", font=font) 35 | except UnicodeEncodeError: 36 | print("Cant encode string for logging. Skipping.") 37 | 38 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 39 | txts.append(txt) 40 | txts = np.stack(txts) 41 | txts = torch.tensor(txts) 42 | return txts 43 | 44 | 45 | def ismap(x): 46 | if not isinstance(x, torch.Tensor): 47 | return False 48 | return (len(x.shape) == 4) and (x.shape[1] > 3) 49 | 50 | 51 | def isimage(x): 52 | if not isinstance(x,torch.Tensor): 53 | return False 54 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 55 | 56 | 57 | def exists(x): 58 | return x is not None 59 | 60 | 61 | def default(val, d): 62 | if exists(val): 63 | return val 64 | return d() if isfunction(d) else d 65 | 66 | 67 | def mean_flat(tensor): 68 | """ 69 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 70 | Take the mean over all non-batch dimensions. 
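    For example, a tensor of shape [N, C, H, W] is reduced to a vector of shape [N].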
71 | """ 72 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 73 | 74 | 75 | def count_params(model, verbose=False): 76 | total_params = sum(p.numel() for p in model.parameters()) 77 | if verbose: 78 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 79 | return total_params 80 | 81 | 82 | def instantiate_from_config(config): 83 | if not "target" in config: 84 | if config == '__is_first_stage__': 85 | return None 86 | elif config == "__is_unconditional__": 87 | return None 88 | raise KeyError("Expected key `target` to instantiate.") 89 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 90 | 91 | 92 | def get_obj_from_str(string, reload=False): 93 | module, cls = string.rsplit(".", 1) 94 | if reload: 95 | module_imp = importlib.import_module(module) 96 | importlib.reload(module_imp) 97 | return getattr(importlib.import_module(module, package=None), cls) 98 | 99 | 100 | class AdamWwithEMAandWings(optim.Optimizer): 101 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 102 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 103 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 104 | ema_power=1., param_names=()): 105 | """AdamW that saves EMA versions of the parameters.""" 106 | if not 0.0 <= lr: 107 | raise ValueError("Invalid learning rate: {}".format(lr)) 108 | if not 0.0 <= eps: 109 | raise ValueError("Invalid epsilon value: {}".format(eps)) 110 | if not 0.0 <= betas[0] < 1.0: 111 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 112 | if not 0.0 <= betas[1] < 1.0: 113 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 114 | if not 0.0 <= weight_decay: 115 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 116 | if not 0.0 <= ema_decay <= 1.0: 117 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 118 | defaults = dict(lr=lr, betas=betas, eps=eps, 119 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 120 | ema_power=ema_power, param_names=param_names) 121 | super().__init__(params, defaults) 122 | 123 | def __setstate__(self, state): 124 | super().__setstate__(state) 125 | for group in self.param_groups: 126 | group.setdefault('amsgrad', False) 127 | 128 | @torch.no_grad() 129 | def step(self, closure=None): 130 | """Performs a single optimization step. 131 | Args: 132 | closure (callable, optional): A closure that reevaluates the model 133 | and returns the loss. 
134 | """ 135 | loss = None 136 | if closure is not None: 137 | with torch.enable_grad(): 138 | loss = closure() 139 | 140 | for group in self.param_groups: 141 | params_with_grad = [] 142 | grads = [] 143 | exp_avgs = [] 144 | exp_avg_sqs = [] 145 | ema_params_with_grad = [] 146 | state_sums = [] 147 | max_exp_avg_sqs = [] 148 | state_steps = [] 149 | amsgrad = group['amsgrad'] 150 | beta1, beta2 = group['betas'] 151 | ema_decay = group['ema_decay'] 152 | ema_power = group['ema_power'] 153 | 154 | for p in group['params']: 155 | if p.grad is None: 156 | continue 157 | params_with_grad.append(p) 158 | if p.grad.is_sparse: 159 | raise RuntimeError('AdamW does not support sparse gradients') 160 | grads.append(p.grad) 161 | 162 | state = self.state[p] 163 | 164 | # State initialization 165 | if len(state) == 0: 166 | state['step'] = 0 167 | # Exponential moving average of gradient values 168 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 169 | # Exponential moving average of squared gradient values 170 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 171 | if amsgrad: 172 | # Maintains max of all exp. moving avg. of sq. grad. values 173 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 174 | # Exponential moving average of parameter values 175 | state['param_exp_avg'] = p.detach().float().clone() 176 | 177 | exp_avgs.append(state['exp_avg']) 178 | exp_avg_sqs.append(state['exp_avg_sq']) 179 | ema_params_with_grad.append(state['param_exp_avg']) 180 | 181 | if amsgrad: 182 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 183 | 184 | # update the steps for each param group update 185 | state['step'] += 1 186 | # record the step after step update 187 | state_steps.append(state['step']) 188 | 189 | optim._functional.adamw(params_with_grad, 190 | grads, 191 | exp_avgs, 192 | exp_avg_sqs, 193 | max_exp_avg_sqs, 194 | state_steps, 195 | amsgrad=amsgrad, 196 | beta1=beta1, 197 | beta2=beta2, 198 | lr=group['lr'], 199 | weight_decay=group['weight_decay'], 200 | eps=group['eps'], 201 | maximize=False) 202 | 203 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 204 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 205 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 206 | 207 | return loss -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | opencv-python 3 | pudb==2019.2 4 | imageio==2.9.0 5 | imageio-ffmpeg==0.4.2 6 | pytorch-lightning==1.4.2 7 | torchmetrics==0.6 8 | omegaconf==2.1.1 9 | test-tube>=0.7.5 10 | streamlit>=0.73.1 11 | einops==0.3.0 12 | transformers==4.19.2 13 | webdataset==0.2.5 14 | open-clip-torch==2.7.0 15 | gradio==3.13.2 16 | kornia==0.6 17 | invisible-watermark>=0.1.5 18 | streamlit-drawable-canvas==0.8.0 19 | -e . 
20 | -------------------------------------------------------------------------------- /scripts/gradio/depth2img.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import gradio as gr 5 | from PIL import Image 6 | from omegaconf import OmegaConf 7 | from einops import repeat, rearrange 8 | from pytorch_lightning import seed_everything 9 | from imwatermark import WatermarkEncoder 10 | 11 | from scripts.txt2img import put_watermark 12 | from ldm.util import instantiate_from_config 13 | from ldm.models.diffusion.ddim import DDIMSampler 14 | from ldm.data.util import AddMiDaS 15 | 16 | torch.set_grad_enabled(False) 17 | 18 | 19 | def initialize_model(config, ckpt): 20 | config = OmegaConf.load(config) 21 | model = instantiate_from_config(config.model) 22 | model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) 23 | 24 | device = torch.device( 25 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | model = model.to(device) 27 | sampler = DDIMSampler(model) 28 | return sampler 29 | 30 | 31 | def make_batch_sd( 32 | image, 33 | txt, 34 | device, 35 | num_samples=1, 36 | model_type="dpt_hybrid" 37 | ): 38 | image = np.array(image.convert("RGB")) 39 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 40 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 41 | midas_trafo = AddMiDaS(model_type=model_type) 42 | batch = { 43 | "jpg": image, 44 | "txt": num_samples * [txt], 45 | } 46 | batch = midas_trafo(batch) 47 | batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w') 48 | batch["jpg"] = repeat(batch["jpg"].to(device=device), 49 | "1 ... -> n ...", n=num_samples) 50 | batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to( 51 | device=device), "1 ... -> n ...", n=num_samples) 52 | return batch 53 | 54 | 55 | def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None, 56 | do_full_sample=False): 57 | device = torch.device( 58 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 59 | model = sampler.model 60 | seed_everything(seed) 61 | 62 | print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") 63 | wm = "SDV2" 64 | wm_encoder = WatermarkEncoder() 65 | wm_encoder.set_watermark('bytes', wm.encode('utf-8')) 66 | 67 | with torch.no_grad(),\ 68 | torch.autocast("cuda"): 69 | batch = make_batch_sd( 70 | image, txt=prompt, device=device, num_samples=num_samples) 71 | z = model.get_first_stage_encoding(model.encode_first_stage( 72 | batch[model.first_stage_key])) # move to latent space 73 | c = model.cond_stage_model.encode(batch["txt"]) 74 | c_cat = list() 75 | for ck in model.concat_keys: 76 | cc = batch[ck] 77 | cc = model.depth_model(cc) 78 | depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], 79 | keepdim=True) 80 | display_depth = (cc - depth_min) / (depth_max - depth_min) 81 | depth_image = Image.fromarray( 82 | (display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8)) 83 | cc = torch.nn.functional.interpolate( 84 | cc, 85 | size=z.shape[2:], 86 | mode="bicubic", 87 | align_corners=False, 88 | ) 89 | depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], 90 | keepdim=True) 91 | cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1. 
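            # the resized depth map is normalised to [-1, 1] and concatenated into the
            # conditioning below; the text prompt is passed separately via cross-attention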
92 | c_cat.append(cc) 93 | c_cat = torch.cat(c_cat, dim=1) 94 | # cond 95 | cond = {"c_concat": [c_cat], "c_crossattn": [c]} 96 | 97 | # uncond cond 98 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 99 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} 100 | if not do_full_sample: 101 | # encode (scaled latent) 102 | z_enc = sampler.stochastic_encode( 103 | z, torch.tensor([t_enc] * num_samples).to(model.device)) 104 | else: 105 | z_enc = torch.randn_like(z) 106 | # decode it 107 | samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, 108 | unconditional_conditioning=uc_full, callback=callback) 109 | x_samples_ddim = model.decode_first_stage(samples) 110 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 111 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 112 | return [depth_image] + [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] 113 | 114 | 115 | def pad_image(input_image): 116 | pad_w, pad_h = np.max(((2, 2), np.ceil( 117 | np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size 118 | im_padded = Image.fromarray( 119 | np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) 120 | return im_padded 121 | 122 | 123 | def predict(input_image, prompt, steps, num_samples, scale, seed, eta, strength): 124 | init_image = input_image.convert("RGB") 125 | image = pad_image(init_image) # resize to integer multiple of 32 126 | 127 | sampler.make_schedule(steps, ddim_eta=eta, verbose=True) 128 | assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]' 129 | do_full_sample = strength == 1. 130 | t_enc = min(int(strength * steps), steps-1) 131 | result = paint( 132 | sampler=sampler, 133 | image=image, 134 | prompt=prompt, 135 | t_enc=t_enc, 136 | seed=seed, 137 | scale=scale, 138 | num_samples=num_samples, 139 | callback=None, 140 | do_full_sample=do_full_sample 141 | ) 142 | return result 143 | 144 | 145 | sampler = initialize_model(sys.argv[1], sys.argv[2]) 146 | 147 | block = gr.Blocks().queue() 148 | with block: 149 | with gr.Row(): 150 | gr.Markdown("## Stable Diffusion Depth2Img") 151 | 152 | with gr.Row(): 153 | with gr.Column(): 154 | input_image = gr.Image(source='upload', type="pil") 155 | prompt = gr.Textbox(label="Prompt") 156 | run_button = gr.Button(label="Run") 157 | with gr.Accordion("Advanced options", open=False): 158 | num_samples = gr.Slider( 159 | label="Images", minimum=1, maximum=4, value=1, step=1) 160 | ddim_steps = gr.Slider(label="Steps", minimum=1, 161 | maximum=50, value=50, step=1) 162 | scale = gr.Slider( 163 | label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1 164 | ) 165 | strength = gr.Slider( 166 | label="Strength", minimum=0.0, maximum=1.0, value=0.9, step=0.01 167 | ) 168 | seed = gr.Slider( 169 | label="Seed", 170 | minimum=0, 171 | maximum=2147483647, 172 | step=1, 173 | randomize=True, 174 | ) 175 | eta = gr.Number(label="eta (DDIM)", value=0.0) 176 | with gr.Column(): 177 | gallery = gr.Gallery(label="Generated images", show_label=False).style( 178 | grid=[2], height="auto") 179 | 180 | run_button.click(fn=predict, inputs=[ 181 | input_image, prompt, ddim_steps, num_samples, scale, seed, eta, strength], outputs=[gallery]) 182 | 183 | 184 | block.launch() 185 | -------------------------------------------------------------------------------- /scripts/gradio/inpainting.py: -------------------------------------------------------------------------------- 1 | 
import sys 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import gradio as gr 6 | from PIL import Image 7 | from omegaconf import OmegaConf 8 | from einops import repeat 9 | from imwatermark import WatermarkEncoder 10 | from pathlib import Path 11 | 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.util import instantiate_from_config 14 | 15 | 16 | torch.set_grad_enabled(False) 17 | 18 | 19 | def put_watermark(img, wm_encoder=None): 20 | if wm_encoder is not None: 21 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 22 | img = wm_encoder.encode(img, 'dwtDct') 23 | img = Image.fromarray(img[:, :, ::-1]) 24 | return img 25 | 26 | 27 | def initialize_model(config, ckpt): 28 | config = OmegaConf.load(config) 29 | model = instantiate_from_config(config.model) 30 | 31 | model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) 32 | 33 | device = torch.device( 34 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 35 | model = model.to(device) 36 | sampler = DDIMSampler(model) 37 | 38 | return sampler 39 | 40 | 41 | def make_batch_sd( 42 | image, 43 | mask, 44 | txt, 45 | device, 46 | num_samples=1): 47 | image = np.array(image.convert("RGB")) 48 | image = image[None].transpose(0, 3, 1, 2) 49 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 50 | 51 | mask = np.array(mask.convert("L")) 52 | mask = mask.astype(np.float32) / 255.0 53 | mask = mask[None, None] 54 | mask[mask < 0.5] = 0 55 | mask[mask >= 0.5] = 1 56 | mask = torch.from_numpy(mask) 57 | 58 | masked_image = image * (mask < 0.5) 59 | 60 | batch = { 61 | "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples), 62 | "txt": num_samples * [txt], 63 | "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), 64 | "masked_image": repeat(masked_image.to(device=device), "1 ... 
-> n ...", n=num_samples), 65 | } 66 | return batch 67 | 68 | 69 | def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512): 70 | device = torch.device( 71 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 72 | model = sampler.model 73 | 74 | print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") 75 | wm = "SDV2" 76 | wm_encoder = WatermarkEncoder() 77 | wm_encoder.set_watermark('bytes', wm.encode('utf-8')) 78 | 79 | prng = np.random.RandomState(seed) 80 | start_code = prng.randn(num_samples, 4, h // 8, w // 8) 81 | start_code = torch.from_numpy(start_code).to( 82 | device=device, dtype=torch.float32) 83 | 84 | with torch.no_grad(), \ 85 | torch.autocast("cuda"): 86 | batch = make_batch_sd(image, mask, txt=prompt, 87 | device=device, num_samples=num_samples) 88 | 89 | c = model.cond_stage_model.encode(batch["txt"]) 90 | 91 | c_cat = list() 92 | for ck in model.concat_keys: 93 | cc = batch[ck].float() 94 | if ck != model.masked_image_key: 95 | bchw = [num_samples, 4, h // 8, w // 8] 96 | cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) 97 | else: 98 | cc = model.get_first_stage_encoding( 99 | model.encode_first_stage(cc)) 100 | c_cat.append(cc) 101 | c_cat = torch.cat(c_cat, dim=1) 102 | 103 | # cond 104 | cond = {"c_concat": [c_cat], "c_crossattn": [c]} 105 | 106 | # uncond cond 107 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 108 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} 109 | 110 | shape = [model.channels, h // 8, w // 8] 111 | samples_cfg, intermediates = sampler.sample( 112 | ddim_steps, 113 | num_samples, 114 | shape, 115 | cond, 116 | verbose=False, 117 | eta=1.0, 118 | unconditional_guidance_scale=scale, 119 | unconditional_conditioning=uc_full, 120 | x_T=start_code, 121 | ) 122 | x_samples_ddim = model.decode_first_stage(samples_cfg) 123 | 124 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, 125 | min=0.0, max=1.0) 126 | 127 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 128 | return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] 129 | 130 | def pad_image(input_image): 131 | pad_w, pad_h = np.max(((2, 2), np.ceil( 132 | np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size 133 | im_padded = Image.fromarray( 134 | np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) 135 | return im_padded 136 | 137 | def predict(input_image, prompt, ddim_steps, num_samples, scale, seed): 138 | init_image = input_image["image"].convert("RGB") 139 | init_mask = input_image["mask"].convert("RGB") 140 | image = pad_image(init_image) # resize to integer multiple of 32 141 | mask = pad_image(init_mask) # resize to integer multiple of 32 142 | width, height = image.size 143 | print("Inpainting...", width, height) 144 | 145 | result = inpaint( 146 | sampler=sampler, 147 | image=image, 148 | mask=mask, 149 | prompt=prompt, 150 | seed=seed, 151 | scale=scale, 152 | ddim_steps=ddim_steps, 153 | num_samples=num_samples, 154 | h=height, w=width 155 | ) 156 | 157 | return result 158 | 159 | 160 | sampler = initialize_model(sys.argv[1], sys.argv[2]) 161 | 162 | block = gr.Blocks().queue() 163 | with block: 164 | with gr.Row(): 165 | gr.Markdown("## Stable Diffusion Inpainting") 166 | 167 | with gr.Row(): 168 | with gr.Column(): 169 | input_image = gr.Image(source='upload', tool='sketch', type="pil") 170 | prompt = gr.Textbox(label="Prompt") 171 | run_button = 
gr.Button(label="Run") 172 | with gr.Accordion("Advanced options", open=False): 173 | num_samples = gr.Slider( 174 | label="Images", minimum=1, maximum=4, value=4, step=1) 175 | ddim_steps = gr.Slider(label="Steps", minimum=1, 176 | maximum=50, value=45, step=1) 177 | scale = gr.Slider( 178 | label="Guidance Scale", minimum=0.1, maximum=30.0, value=10, step=0.1 179 | ) 180 | seed = gr.Slider( 181 | label="Seed", 182 | minimum=0, 183 | maximum=2147483647, 184 | step=1, 185 | randomize=True, 186 | ) 187 | with gr.Column(): 188 | gallery = gr.Gallery(label="Generated images", show_label=False).style( 189 | grid=[2], height="auto") 190 | 191 | run_button.click(fn=predict, inputs=[ 192 | input_image, prompt, ddim_steps, num_samples, scale, seed], outputs=[gallery]) 193 | 194 | 195 | block.launch() 196 | -------------------------------------------------------------------------------- /scripts/gradio/superresolution.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import gradio as gr 5 | from PIL import Image 6 | from omegaconf import OmegaConf 7 | from einops import repeat, rearrange 8 | from pytorch_lightning import seed_everything 9 | from imwatermark import WatermarkEncoder 10 | 11 | from scripts.txt2img import put_watermark 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion 14 | from ldm.util import exists, instantiate_from_config 15 | 16 | 17 | torch.set_grad_enabled(False) 18 | 19 | 20 | def initialize_model(config, ckpt): 21 | config = OmegaConf.load(config) 22 | model = instantiate_from_config(config.model) 23 | model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) 24 | 25 | device = torch.device( 26 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 27 | model = model.to(device) 28 | sampler = DDIMSampler(model) 29 | return sampler 30 | 31 | 32 | def make_batch_sd( 33 | image, 34 | txt, 35 | device, 36 | num_samples=1, 37 | ): 38 | image = np.array(image.convert("RGB")) 39 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 40 | batch = { 41 | "lr": rearrange(image, 'h w c -> 1 c h w'), 42 | "txt": num_samples * [txt], 43 | } 44 | batch["lr"] = repeat(batch["lr"].to(device=device), 45 | "1 ... 
-> n ...", n=num_samples) 46 | return batch 47 | 48 | 49 | def make_noise_augmentation(model, batch, noise_level=None): 50 | x_low = batch[model.low_scale_key] 51 | x_low = x_low.to(memory_format=torch.contiguous_format).float() 52 | x_aug, noise_level = model.low_scale_model(x_low, noise_level) 53 | return x_aug, noise_level 54 | 55 | 56 | def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None): 57 | device = torch.device( 58 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 59 | model = sampler.model 60 | seed_everything(seed) 61 | prng = np.random.RandomState(seed) 62 | start_code = prng.randn(num_samples, model.channels, h, w) 63 | start_code = torch.from_numpy(start_code).to( 64 | device=device, dtype=torch.float32) 65 | 66 | print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") 67 | wm = "SDV2" 68 | wm_encoder = WatermarkEncoder() 69 | wm_encoder.set_watermark('bytes', wm.encode('utf-8')) 70 | with torch.no_grad(),\ 71 | torch.autocast("cuda"): 72 | batch = make_batch_sd( 73 | image, txt=prompt, device=device, num_samples=num_samples) 74 | c = model.cond_stage_model.encode(batch["txt"]) 75 | c_cat = list() 76 | if isinstance(model, LatentUpscaleFinetuneDiffusion): 77 | for ck in model.concat_keys: 78 | cc = batch[ck] 79 | if exists(model.reshuffle_patch_size): 80 | assert isinstance(model.reshuffle_patch_size, int) 81 | cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', 82 | p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size) 83 | c_cat.append(cc) 84 | c_cat = torch.cat(c_cat, dim=1) 85 | # cond 86 | cond = {"c_concat": [c_cat], "c_crossattn": [c]} 87 | # uncond cond 88 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 89 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} 90 | elif isinstance(model, LatentUpscaleDiffusion): 91 | x_augment, noise_level = make_noise_augmentation( 92 | model, batch, noise_level) 93 | cond = {"c_concat": [x_augment], 94 | "c_crossattn": [c], "c_adm": noise_level} 95 | # uncond cond 96 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 97 | uc_full = {"c_concat": [x_augment], "c_crossattn": [ 98 | uc_cross], "c_adm": noise_level} 99 | else: 100 | raise NotImplementedError() 101 | 102 | shape = [model.channels, h, w] 103 | samples, intermediates = sampler.sample( 104 | steps, 105 | num_samples, 106 | shape, 107 | cond, 108 | verbose=False, 109 | eta=eta, 110 | unconditional_guidance_scale=scale, 111 | unconditional_conditioning=uc_full, 112 | x_T=start_code, 113 | callback=callback 114 | ) 115 | with torch.no_grad(): 116 | x_samples_ddim = model.decode_first_stage(samples) 117 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 118 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 119 | return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] 120 | 121 | 122 | def pad_image(input_image): 123 | pad_w, pad_h = np.max(((2, 2), np.ceil( 124 | np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size 125 | im_padded = Image.fromarray( 126 | np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) 127 | return im_padded 128 | 129 | 130 | def predict(input_image, prompt, steps, num_samples, scale, seed, eta, noise_level): 131 | init_image = input_image.convert("RGB") 132 | image = pad_image(init_image) # resize to integer multiple of 32 133 | width, height = image.size 134 | 
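# Note: pad_image above edge-pads width/height up to the next multiple of 64 (minimum 128 px).
# For the latent upscaler (LatentUpscaleDiffusion), the low-resolution conditioning image is
# noised to the chosen level by model.low_scale_model inside make_noise_augmentation(),
# and the level itself reaches the UNet as "c_adm" (see paint() above).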
135 | noise_level = torch.Tensor( 136 | num_samples * [noise_level]).to(sampler.model.device).long() 137 | sampler.make_schedule(steps, ddim_eta=eta, verbose=True) 138 | result = paint( 139 | sampler=sampler, 140 | image=image, 141 | prompt=prompt, 142 | seed=seed, 143 | scale=scale, 144 | h=height, w=width, steps=steps, 145 | num_samples=num_samples, 146 | callback=None, 147 | noise_level=noise_level 148 | ) 149 | return result 150 | 151 | 152 | sampler = initialize_model(sys.argv[1], sys.argv[2]) 153 | 154 | block = gr.Blocks().queue() 155 | with block: 156 | with gr.Row(): 157 | gr.Markdown("## Stable Diffusion Upscaling") 158 | 159 | with gr.Row(): 160 | with gr.Column(): 161 | input_image = gr.Image(source='upload', type="pil") 162 | gr.Markdown( 163 | "Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat") 164 | prompt = gr.Textbox(label="Prompt") 165 | run_button = gr.Button(label="Run") 166 | with gr.Accordion("Advanced options", open=False): 167 | num_samples = gr.Slider( 168 | label="Number of Samples", minimum=1, maximum=4, value=1, step=1) 169 | steps = gr.Slider(label="DDIM Steps", minimum=2, 170 | maximum=200, value=75, step=1) 171 | scale = gr.Slider( 172 | label="Scale", minimum=0.1, maximum=30.0, value=10, step=0.1 173 | ) 174 | seed = gr.Slider( 175 | label="Seed", 176 | minimum=0, 177 | maximum=2147483647, 178 | step=1, 179 | randomize=True, 180 | ) 181 | eta = gr.Number(label="eta (DDIM)", 182 | value=0.0, min=0.0, max=1.0) 183 | noise_level = None 184 | if isinstance(sampler.model, LatentUpscaleDiffusion): 185 | # TODO: make this work for all models 186 | noise_level = gr.Number( 187 | label="Noise Augmentation", min=0, max=350, value=20, step=1) 188 | 189 | with gr.Column(): 190 | gallery = gr.Gallery(label="Generated images", show_label=False).style( 191 | grid=[2], height="auto") 192 | 193 | run_button.click(fn=predict, inputs=[ 194 | input_image, prompt, steps, num_samples, scale, seed, eta, noise_level], outputs=[gallery]) 195 | 196 | 197 | block.launch() 198 | -------------------------------------------------------------------------------- /scripts/streamlit/depth2img.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import streamlit as st 5 | from PIL import Image 6 | from omegaconf import OmegaConf 7 | from einops import repeat, rearrange 8 | from pytorch_lightning import seed_everything 9 | from imwatermark import WatermarkEncoder 10 | 11 | from scripts.txt2img import put_watermark 12 | from ldm.util import instantiate_from_config 13 | from ldm.models.diffusion.ddim import DDIMSampler 14 | from ldm.data.util import AddMiDaS 15 | 16 | torch.set_grad_enabled(False) 17 | 18 | 19 | @st.cache(allow_output_mutation=True) 20 | def initialize_model(config, ckpt): 21 | config = OmegaConf.load(config) 22 | model = instantiate_from_config(config.model) 23 | model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) 24 | 25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | model = model.to(device) 27 | sampler = DDIMSampler(model) 28 | return sampler 29 | 30 | 31 | def make_batch_sd( 32 | image, 33 | txt, 34 | device, 35 | num_samples=1, 36 | model_type="dpt_hybrid" 37 | ): 38 | image = np.array(image.convert("RGB")) 39 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 40 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 41 | midas_trafo = 
AddMiDaS(model_type=model_type) 42 | batch = { 43 | "jpg": image, 44 | "txt": num_samples * [txt], 45 | } 46 | batch = midas_trafo(batch) 47 | batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w') 48 | batch["jpg"] = repeat(batch["jpg"].to(device=device), "1 ... -> n ...", n=num_samples) 49 | batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(device=device), "1 ... -> n ...", n=num_samples) 50 | return batch 51 | 52 | 53 | def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None, 54 | do_full_sample=False): 55 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 56 | model = sampler.model 57 | seed_everything(seed) 58 | 59 | print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") 60 | wm = "SDV2" 61 | wm_encoder = WatermarkEncoder() 62 | wm_encoder.set_watermark('bytes', wm.encode('utf-8')) 63 | 64 | with torch.no_grad(),\ 65 | torch.autocast("cuda"): 66 | batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples) 67 | z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space 68 | c = model.cond_stage_model.encode(batch["txt"]) 69 | c_cat = list() 70 | for ck in model.concat_keys: 71 | cc = batch[ck] 72 | cc = model.depth_model(cc) 73 | depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], 74 | keepdim=True) 75 | display_depth = (cc - depth_min) / (depth_max - depth_min) 76 | st.image(Image.fromarray((display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8))) 77 | cc = torch.nn.functional.interpolate( 78 | cc, 79 | size=z.shape[2:], 80 | mode="bicubic", 81 | align_corners=False, 82 | ) 83 | depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], 84 | keepdim=True) 85 | cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1. 
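# Same depth-conditioning steps as scripts/gradio/depth2img.py; the min/max are recomputed
# after the bicubic resize, presumably because interpolation can overshoot the original range.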
86 | c_cat.append(cc)
87 | c_cat = torch.cat(c_cat, dim=1)
88 | # cond
89 | cond = {"c_concat": [c_cat], "c_crossattn": [c]}
90 |
91 | # uncond cond
92 | uc_cross = model.get_unconditional_conditioning(num_samples, "")
93 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
94 | if not do_full_sample:
95 | # encode (scaled latent)
96 | z_enc = sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
97 | else:
98 | z_enc = torch.randn_like(z)
99 | # decode it
100 | samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
101 | unconditional_conditioning=uc_full, callback=callback)
102 | x_samples_ddim = model.decode_first_stage(samples)
103 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
104 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
105 | return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
106 |
107 |
108 | def run():
109 | st.title("Stable Diffusion Depth2Img")
110 | # run via streamlit run scripts/streamlit/depth2img.py
111 | sampler = initialize_model(sys.argv[1], sys.argv[2])
112 |
113 | image = st.file_uploader("Image", ["jpg", "png"])
114 | if image:
115 | image = Image.open(image)
116 | w, h = image.size
117 | st.text(f"loaded input image of size ({w}, {h})")
118 | width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
119 | image = image.resize((width, height))
120 | st.text(f"resized input image to size ({width}, {height}) (w, h)")
121 | st.image(image)
122 |
123 | prompt = st.text_input("Prompt")
124 |
125 | seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
126 | num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
127 | scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
128 | steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1)
129 | strength = st.slider("Strength", min_value=0., max_value=1., value=0.9)
130 |
131 | t_progress = st.progress(0)
132 | def t_callback(t):
133 | t_progress.progress(min((t + 1) / t_enc, 1.))
134 |
135 | assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
136 | do_full_sample = strength == 1.
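# img2img-style (SDEdit) sampling: paint() encodes the input to latents, noises it to
# timestep t_enc (roughly strength * steps), then denoises it under the depth + text
# conditioning; strength == 1.0 instead samples from pure noise (do_full_sample).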
137 | t_enc = min(int(strength * steps), steps-1) 138 | sampler.make_schedule(steps, ddim_eta=0., verbose=True) 139 | if st.button("Sample"): 140 | result = paint( 141 | sampler=sampler, 142 | image=image, 143 | prompt=prompt, 144 | t_enc=t_enc, 145 | seed=seed, 146 | scale=scale, 147 | num_samples=num_samples, 148 | callback=t_callback, 149 | do_full_sample=do_full_sample, 150 | ) 151 | st.write("Result") 152 | for image in result: 153 | st.image(image, output_format='PNG') 154 | 155 | 156 | if __name__ == "__main__": 157 | run() 158 | -------------------------------------------------------------------------------- /scripts/streamlit/inpainting.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import streamlit as st 6 | from PIL import Image 7 | from omegaconf import OmegaConf 8 | from einops import repeat 9 | from streamlit_drawable_canvas import st_canvas 10 | from imwatermark import WatermarkEncoder 11 | 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.util import instantiate_from_config 14 | 15 | 16 | torch.set_grad_enabled(False) 17 | 18 | 19 | def put_watermark(img, wm_encoder=None): 20 | if wm_encoder is not None: 21 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 22 | img = wm_encoder.encode(img, 'dwtDct') 23 | img = Image.fromarray(img[:, :, ::-1]) 24 | return img 25 | 26 | 27 | @st.cache(allow_output_mutation=True) 28 | def initialize_model(config, ckpt): 29 | config = OmegaConf.load(config) 30 | model = instantiate_from_config(config.model) 31 | 32 | model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) 33 | 34 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 35 | model = model.to(device) 36 | sampler = DDIMSampler(model) 37 | 38 | return sampler 39 | 40 | 41 | def make_batch_sd( 42 | image, 43 | mask, 44 | txt, 45 | device, 46 | num_samples=1): 47 | image = np.array(image.convert("RGB")) 48 | image = image[None].transpose(0, 3, 1, 2) 49 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 50 | 51 | mask = np.array(mask.convert("L")) 52 | mask = mask.astype(np.float32) / 255.0 53 | mask = mask[None, None] 54 | mask[mask < 0.5] = 0 55 | mask[mask >= 0.5] = 1 56 | mask = torch.from_numpy(mask) 57 | 58 | masked_image = image * (mask < 0.5) 59 | 60 | batch = { 61 | "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples), 62 | "txt": num_samples * [txt], 63 | "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), 64 | "masked_image": repeat(masked_image.to(device=device), "1 ... 
-> n ...", n=num_samples), 65 | } 66 | return batch 67 | 68 | 69 | def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.): 70 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 71 | model = sampler.model 72 | 73 | print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") 74 | wm = "SDV2" 75 | wm_encoder = WatermarkEncoder() 76 | wm_encoder.set_watermark('bytes', wm.encode('utf-8')) 77 | 78 | prng = np.random.RandomState(seed) 79 | start_code = prng.randn(num_samples, 4, h // 8, w // 8) 80 | start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32) 81 | 82 | with torch.no_grad(), \ 83 | torch.autocast("cuda"): 84 | batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples) 85 | 86 | c = model.cond_stage_model.encode(batch["txt"]) 87 | 88 | c_cat = list() 89 | for ck in model.concat_keys: 90 | cc = batch[ck].float() 91 | if ck != model.masked_image_key: 92 | bchw = [num_samples, 4, h // 8, w // 8] 93 | cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) 94 | else: 95 | cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) 96 | c_cat.append(cc) 97 | c_cat = torch.cat(c_cat, dim=1) 98 | 99 | # cond 100 | cond = {"c_concat": [c_cat], "c_crossattn": [c]} 101 | 102 | # uncond cond 103 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 104 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} 105 | 106 | shape = [model.channels, h // 8, w // 8] 107 | samples_cfg, intermediates = sampler.sample( 108 | ddim_steps, 109 | num_samples, 110 | shape, 111 | cond, 112 | verbose=False, 113 | eta=eta, 114 | unconditional_guidance_scale=scale, 115 | unconditional_conditioning=uc_full, 116 | x_T=start_code, 117 | ) 118 | x_samples_ddim = model.decode_first_stage(samples_cfg) 119 | 120 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, 121 | min=0.0, max=1.0) 122 | 123 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 124 | return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] 125 | 126 | 127 | def run(): 128 | st.title("Stable Diffusion Inpainting") 129 | 130 | sampler = initialize_model(sys.argv[1], sys.argv[2]) 131 | 132 | image = st.file_uploader("Image", ["jpg", "png"]) 133 | if image: 134 | image = Image.open(image) 135 | w, h = image.size 136 | print(f"loaded input image of size ({w}, {h})") 137 | width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32 138 | image = image.resize((width, height)) 139 | 140 | prompt = st.text_input("Prompt") 141 | 142 | seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0) 143 | num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1) 144 | scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=10., step=0.1) 145 | ddim_steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1) 146 | eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.) 
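# eta controls the stochasticity of the DDIM sampler: 0.0 is the deterministic DDIM
# update, 1.0 injects DDPM-like noise at every step (the gradio inpainting demo above
# hard-codes eta=1.0).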
147 | 148 | fill_color = "rgba(255, 255, 255, 0.0)" 149 | stroke_width = st.number_input("Brush Size", 150 | value=64, 151 | min_value=1, 152 | max_value=100) 153 | stroke_color = "rgba(255, 255, 255, 1.0)" 154 | bg_color = "rgba(0, 0, 0, 1.0)" 155 | drawing_mode = "freedraw" 156 | 157 | st.write("Canvas") 158 | st.caption( 159 | "Draw a mask to inpaint, then click the 'Send to Streamlit' button (bottom left, with an arrow on it).") 160 | canvas_result = st_canvas( 161 | fill_color=fill_color, 162 | stroke_width=stroke_width, 163 | stroke_color=stroke_color, 164 | background_color=bg_color, 165 | background_image=image, 166 | update_streamlit=False, 167 | height=height, 168 | width=width, 169 | drawing_mode=drawing_mode, 170 | key="canvas", 171 | ) 172 | if canvas_result: 173 | mask = canvas_result.image_data 174 | mask = mask[:, :, -1] > 0 175 | if mask.sum() > 0: 176 | mask = Image.fromarray(mask) 177 | 178 | result = inpaint( 179 | sampler=sampler, 180 | image=image, 181 | mask=mask, 182 | prompt=prompt, 183 | seed=seed, 184 | scale=scale, 185 | ddim_steps=ddim_steps, 186 | num_samples=num_samples, 187 | h=height, w=width, eta=eta 188 | ) 189 | st.write("Inpainted") 190 | for image in result: 191 | st.image(image, output_format='PNG') 192 | 193 | 194 | if __name__ == "__main__": 195 | run() -------------------------------------------------------------------------------- /scripts/streamlit/superresolution.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import streamlit as st 5 | from PIL import Image 6 | from omegaconf import OmegaConf 7 | from einops import repeat, rearrange 8 | from pytorch_lightning import seed_everything 9 | from imwatermark import WatermarkEncoder 10 | 11 | from scripts.txt2img import put_watermark 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion 14 | from ldm.util import exists, instantiate_from_config 15 | 16 | 17 | torch.set_grad_enabled(False) 18 | 19 | 20 | @st.cache(allow_output_mutation=True) 21 | def initialize_model(config, ckpt): 22 | config = OmegaConf.load(config) 23 | model = instantiate_from_config(config.model) 24 | model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) 25 | 26 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 27 | model = model.to(device) 28 | sampler = DDIMSampler(model) 29 | return sampler 30 | 31 | 32 | def make_batch_sd( 33 | image, 34 | txt, 35 | device, 36 | num_samples=1, 37 | ): 38 | image = np.array(image.convert("RGB")) 39 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 40 | batch = { 41 | "lr": rearrange(image, 'h w c -> 1 c h w'), 42 | "txt": num_samples * [txt], 43 | } 44 | batch["lr"] = repeat(batch["lr"].to(device=device), "1 ... 
-> n ...", n=num_samples) 45 | return batch 46 | 47 | 48 | def make_noise_augmentation(model, batch, noise_level=None): 49 | x_low = batch[model.low_scale_key] 50 | x_low = x_low.to(memory_format=torch.contiguous_format).float() 51 | x_aug, noise_level = model.low_scale_model(x_low, noise_level) 52 | return x_aug, noise_level 53 | 54 | 55 | def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None): 56 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 57 | model = sampler.model 58 | seed_everything(seed) 59 | prng = np.random.RandomState(seed) 60 | start_code = prng.randn(num_samples, model.channels, h , w) 61 | start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32) 62 | 63 | print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") 64 | wm = "SDV2" 65 | wm_encoder = WatermarkEncoder() 66 | wm_encoder.set_watermark('bytes', wm.encode('utf-8')) 67 | with torch.no_grad(),\ 68 | torch.autocast("cuda"): 69 | batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples) 70 | c = model.cond_stage_model.encode(batch["txt"]) 71 | c_cat = list() 72 | if isinstance(model, LatentUpscaleFinetuneDiffusion): 73 | for ck in model.concat_keys: 74 | cc = batch[ck] 75 | if exists(model.reshuffle_patch_size): 76 | assert isinstance(model.reshuffle_patch_size, int) 77 | cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', 78 | p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size) 79 | c_cat.append(cc) 80 | c_cat = torch.cat(c_cat, dim=1) 81 | # cond 82 | cond = {"c_concat": [c_cat], "c_crossattn": [c]} 83 | # uncond cond 84 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 85 | uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} 86 | elif isinstance(model, LatentUpscaleDiffusion): 87 | x_augment, noise_level = make_noise_augmentation(model, batch, noise_level) 88 | cond = {"c_concat": [x_augment], "c_crossattn": [c], "c_adm": noise_level} 89 | # uncond cond 90 | uc_cross = model.get_unconditional_conditioning(num_samples, "") 91 | uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level} 92 | else: 93 | raise NotImplementedError() 94 | 95 | shape = [model.channels, h, w] 96 | samples, intermediates = sampler.sample( 97 | steps, 98 | num_samples, 99 | shape, 100 | cond, 101 | verbose=False, 102 | eta=eta, 103 | unconditional_guidance_scale=scale, 104 | unconditional_conditioning=uc_full, 105 | x_T=start_code, 106 | callback=callback 107 | ) 108 | with torch.no_grad(): 109 | x_samples_ddim = model.decode_first_stage(samples) 110 | result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 111 | result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 112 | st.text(f"upscaled image shape: {result.shape}") 113 | return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] 114 | 115 | 116 | def run(): 117 | st.title("Stable Diffusion Upscaling") 118 | # run via streamlit run scripts/demo/depth2img.py 119 | sampler = initialize_model(sys.argv[1], sys.argv[2]) 120 | 121 | image = st.file_uploader("Image", ["jpg", "png"]) 122 | if image: 123 | image = Image.open(image) 124 | w, h = image.size 125 | st.text(f"loaded input image of size ({w}, {h})") 126 | width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 127 | image = image.resize((width, height)) 128 | st.text(f"resized input image to size 
({width}, {height} (w, h))") 129 | st.image(image) 130 | 131 | st.write(f"\n Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat'") 132 | prompt = st.text_input("Prompt", "a high quality professional photograph") 133 | 134 | seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0) 135 | num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1) 136 | scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1) 137 | steps = st.slider("DDIM Steps", min_value=2, max_value=250, value=50, step=1) 138 | eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.) 139 | 140 | noise_level = None 141 | if isinstance(sampler.model, LatentUpscaleDiffusion): 142 | # TODO: make this work for all models 143 | noise_level = st.sidebar.number_input("Noise Augmentation", min_value=0, max_value=350, value=20) 144 | noise_level = torch.Tensor(num_samples * [noise_level]).to(sampler.model.device).long() 145 | 146 | t_progress = st.progress(0) 147 | def t_callback(t): 148 | t_progress.progress(min((t + 1) / steps, 1.)) 149 | 150 | sampler.make_schedule(steps, ddim_eta=eta, verbose=True) 151 | if st.button("Sample"): 152 | result = paint( 153 | sampler=sampler, 154 | image=image, 155 | prompt=prompt, 156 | seed=seed, 157 | scale=scale, 158 | h=height, w=width, steps=steps, 159 | num_samples=num_samples, 160 | callback=t_callback, 161 | noise_level=noise_level, 162 | eta=eta 163 | ) 164 | st.write("Result") 165 | for image in result: 166 | st.image(image, output_format='PNG') 167 | 168 | 169 | if __name__ == "__main__": 170 | run() 171 | -------------------------------------------------------------------------------- /scripts/tests/test_watermark.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import fire 3 | from imwatermark import WatermarkDecoder 4 | 5 | 6 | def testit(img_path): 7 | bgr = cv2.imread(img_path) 8 | decoder = WatermarkDecoder('bytes', 136) 9 | watermark = decoder.decode(bgr, 'dwtDct') 10 | try: 11 | dec = watermark.decode('utf-8') 12 | except: 13 | dec = "null" 14 | print(dec) 15 | 16 | 17 | if __name__ == "__main__": 18 | fire.Fire(testit) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='stable-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
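A minimal, hypothetical round-trip sketch for the invisible watermark used by the demo scripts above. It reuses only the imwatermark calls that already appear in this repository; the synthetic noise image (rng, bgr) and the 32-bit length derived from the "SDV2" payload are illustrative assumptions, and recovery from such an artificial image is not guaranteed.

import numpy as np
from imwatermark import WatermarkEncoder, WatermarkDecoder

wm = "SDV2"  # same payload the gradio/streamlit scripts embed
encoder = WatermarkEncoder()
encoder.set_watermark('bytes', wm.encode('utf-8'))

# Hypothetical stand-in for a generated sample; real model outputs are the intended input.
rng = np.random.default_rng(0)
bgr = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
marked = encoder.encode(bgr, 'dwtDct')

# The bit length given to WatermarkDecoder must match the embedded payload:
# len(wm) * 8 = 32 bits for "SDV2" (scripts/tests/test_watermark.py decodes 136 bits,
# i.e. a 17-byte payload).
decoder = WatermarkDecoder('bytes', len(wm) * 8)
recovered = decoder.decode(marked, 'dwtDct')
print(recovered.decode('utf-8', errors='replace'))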