├── .gitignore ├── LICENSE ├── README.md ├── _in ├── pix │ ├── 6458524847_2f4c361183_k.jpg │ ├── 8399166846_f6fb4e4b8e_k.jpg │ ├── alex-iby-G_Pk4D9rMLs.jpg │ ├── bench2.jpg │ ├── bertrand-gabioud-CpuFzIsHYJ0.jpg │ ├── billow926-12-Wc-Zgx6Y.jpg │ ├── mask │ │ ├── 6458524847_2f4c361183_k_mask.jpg │ │ ├── 8399166846_f6fb4e4b8e_k_mask.jpg │ │ ├── alex-iby-G_Pk4D9rMLs_mask.jpg │ │ ├── bench2_mask.jpg │ │ ├── bertrand-gabioud-CpuFzIsHYJ0_mask.jpg │ │ ├── billow926-12-Wc-Zgx6Y_mask.jpg │ │ ├── overture-creations-5sI6fQgYIuo_mask.jpg │ │ └── photo-1583445095369-9c651e7e5d34_mask.jpg │ ├── overture-creations-5sI6fQgYIuo.jpg │ └── photo-1583445095369-9c651e7e5d34.jpg └── something.jpg ├── download.py ├── img.bat ├── inpaint.bat ├── model_half.py ├── requirements.txt ├── src ├── _sdrun.py ├── custom │ ├── __init__.py │ ├── composenW.py │ ├── compress.py │ ├── convert.py │ ├── finetune_data.py │ ├── get_deltas.py │ ├── model.py │ └── modules.py ├── latwalk.py ├── ldm │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base.py │ │ └── personalized.py │ ├── models │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── ddim.py │ │ │ └── ddpm.py │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── embedding_manager.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ └── midas │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── blocks.py │ │ │ ├── dpt_depth.py │ │ │ ├── midas_net.py │ │ │ ├── midas_net_custom.py │ │ │ ├── transforms.py │ │ │ └── vit.py │ │ │ └── utils.py │ └── util.py ├── sampling.py ├── train.py ├── txt2mask.py ├── utils.py ├── xtra │ ├── clipseg │ │ ├── clipseg.py │ │ └── vitseg.py │ ├── k_diffusion │ │ ├── __init__.py │ │ ├── augmentation.py │ │ ├── config.py │ │ ├── evaluation.py │ │ ├── external.py │ │ ├── gns.py │ │ ├── layers.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── image_v1.py │ │ ├── sampling.py │ │ └── utils.py │ ├── open_clip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── factory.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── model_configs │ │ │ ├── RN101-quickgelu.json │ │ │ ├── RN101.json │ │ │ ├── RN50-quickgelu.json │ │ │ ├── RN50.json │ │ │ ├── RN50x16.json │ │ │ ├── RN50x4.json │ │ │ ├── ViT-B-16-plus-240.json │ │ │ ├── ViT-B-16-plus.json │ │ │ ├── ViT-B-16.json │ │ │ ├── ViT-B-32-plus-256.json │ │ │ ├── ViT-B-32-quickgelu.json │ │ │ ├── ViT-B-32.json │ │ │ ├── ViT-H-14.json │ │ │ ├── ViT-H-16.json │ │ │ ├── ViT-L-14-280.json │ │ │ ├── ViT-L-14-336.json │ │ │ ├── ViT-L-14.json │ │ │ ├── ViT-L-16-320.json │ │ │ ├── ViT-L-16.json │ │ │ ├── ViT-g-14.json │ │ │ ├── timm-efficientnetv2_rw_s.json │ │ │ ├── timm-resnet50d.json │ │ │ ├── timm-resnetaa50d.json │ │ │ ├── timm-resnetblur50.json │ │ │ ├── timm-swin_base_patch4_window7_224.json │ │ │ ├── timm-vit_base_patch16_224.json │ │ │ ├── timm-vit_base_patch32_224.json │ │ │ └── timm-vit_small_patch16_224.json │ │ ├── openai.py │ │ ├── pretrained.py │ │ ├── src.zip │ │ ├── test_simple.py │ │ ├── timm_model.py │ │ ├── tokenizer.py │ │ ├── transform.py │ │ ├── utils.py │ │ └── version.py │ └── taming │ │ ├── __init__.py │ │ ├── data.zip │ │ ├── lr_scheduler.py │ │ ├── models │ │ ├── __init__.py │ │ ├── cond_transformer.py │ │ ├── dummy_cond_stage.py │ │ └── vqgan.py │ │ ├── modules │ │ ├── 
__init__.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ └── model.py │ │ ├── discriminator │ │ │ ├── __init__.py │ │ │ └── model.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── lpips.py │ │ │ ├── segmentation.py │ │ │ └── vqperceptual.py │ │ ├── misc │ │ │ ├── __init__.py │ │ │ └── coord.py │ │ ├── transformer │ │ │ ├── __init__.py │ │ │ ├── mingpt.py │ │ │ └── permuter.py │ │ ├── util.py │ │ └── vqvae │ │ │ ├── __init__.py │ │ │ └── quantize.py │ │ └── util.py └── yaml │ ├── v1-finetune-custom.yaml │ ├── v1-finetune.yaml │ ├── v1-finetune_style.yaml │ ├── v1-inference.yaml │ ├── v1-inpainting.yaml │ ├── v2-inference-v.yaml │ ├── v2-inference.yaml │ ├── v2-inpainting.yaml │ └── v2-midas.yaml ├── train.bat ├── txt.bat └── walk.bat /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | embed/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 vadim epstein 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion for studies 2 | 3 |

4 | 5 | This is yet another Stable Diffusion compilation, aimed to be functional, clean & compact enough for various experiments. There's no GUI here, as the target audience is creative coders rather than post-Photoshop users. For the latter, one may check [InvokeAI] or [AUTOMATIC1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) as a convenient production tool, or [Deforum] for precisely controlled animations. 6 | 7 | The code is based on the [CompVis] and [Stability AI] libraries and heavily borrows from [this repo](https://github.com/AmericanPresidentJimmyCarter/stable-diffusion), with occasional additions from [InvokeAI] and [Deforum], as well as the others mentioned below. The following codebases are partially included here (to ensure compatibility and ease of setup): [k-diffusion](https://github.com/crowsonkb/k-diffusion), [Taming Transformers](https://github.com/CompVis/taming-transformers), [OpenCLIP], [CLIPseg]. 8 | **There is also a [similar repo](https://github.com/eps696/SD), based on the [diffusers] library, which is more logical and up-to-date.** 9 | 10 | Current functions: 11 | * Text to image 12 | * Image re- and in-painting 13 | * Latent interpolations (with text prompts and images) 14 | 15 | Fine-tuning with your images: 16 | * Add a subject (new token) with [textual inversion] 17 | * Add a subject (prompt embedding + Unet delta) with [custom diffusion] 18 | 19 | Other features: 20 | * Memory-efficient with `xformers` (high resolution on a 6 GB VRAM GPU) 21 | * Use of special depth/inpainting and v2 models 22 | * Masking with text via [CLIPseg] 23 | * Weighted multi-prompts 24 | * to be continued... 25 | 26 | More details and a Colab version will follow. 27 | 28 | ## Setup 29 | 30 | Install CUDA 11.6. Set up the Conda environment: 31 | ``` 32 | conda create -n SD python=3.10 numpy pillow 33 | activate SD 34 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 35 | pip install -r requirements.txt 36 | ``` 37 | Install the `xformers` library to increase performance. It makes it possible to run SD at any resolution on lower-grade hardware (e.g. video cards with 6 GB VRAM). If you're on Windows, first ensure that you have Visual Studio 2019 installed. 38 | ``` 39 | pip install git+https://github.com/facebookresearch/xformers.git 40 | ``` 41 | Download the Stable Diffusion ([1.5](https://huggingface.co/CompVis/stable-diffusion), [1.5-inpaint](https://huggingface.co/runwayml/stable-diffusion-inpainting), [2-inpaint](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting), [2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth), [2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base), [2.1-v](https://huggingface.co/stabilityai/stable-diffusion-2-1)), [OpenCLIP], [custom VAE](https://huggingface.co/stabilityai/sd-vae-ft-ema-original), [CLIPseg] and [MiDaS](https://github.com/isl-org/MiDaS) models (mostly converted to `float16` for faster loading) with the command below. Licensing info is available on their webpages. 
42 | ``` 43 | python download.py 44 | ``` 45 | 46 | ## Operations 47 | 48 | Examples of usage: 49 | 50 | * Generate an image from a text prompt: 51 | ``` 52 | python src/_sdrun.py -t "hello world" --size 1024-576 53 | ``` 54 | * Redraw an image with an existing style embedding: 55 | ``` 56 | python src/_sdrun.py -im _in/something.jpg -t "" 57 | ``` 58 | * Redraw a directory of images, keeping the basic forms intact: 59 | ``` 60 | python src/_sdrun.py -im _in/pix -t "neon light glow" --model v2d 61 | ``` 62 | * Inpaint a directory of images with the RunwayML model, turning humans into robots: 63 | ``` 64 | python src/_sdrun.py -im _in/pix --mask "human, person" -t "steampunk robot" --model 15i 65 | ``` 66 | * Make a video, interpolating between the lines of a text file: 67 | ``` 68 | python src/latwalk.py -t yourfile.txt --size 1024-576 69 | ``` 70 | * Same, with drawing over a masked image: 71 | ``` 72 | python src/latwalk.py -t yourfile.txt -im _in/pix/bench2.jpg --mask _in/pix/mask/bench2_mask.jpg 73 | ``` 74 | Check other options by running these scripts with the `--help` option; try various models, samplers, noisers, etc. 75 | Text prompts may include either special tokens (e.g. ``) or weights (like `good prompt :1 | also good prompt :1 | bad prompt :-0.5`). The latter may degrade overall accuracy though. 76 | Interpolated videos may be further smoothed out with [FILM](https://github.com/google-research/frame-interpolation). 77 | 78 | There are also Windows bat-files that slightly simplify and automate these commands. 79 | 80 | ## Fine-tuning 81 | 82 | * Train a prompt embedding for a specific subject (e.g. a cat) with [textual inversion]: 83 | ``` 84 | python src/train.py --token mycat1 --term cat --data data/mycat1 85 | ``` 86 | * Do the same with [custom diffusion]: 87 | ``` 88 | python src/train.py --token mycat1 --term cat --data data/mycat1 --reg_data data/cat 89 | ``` 90 | Results of the trainings above will be saved under the `train` directory. 91 | 92 | Custom diffusion trains faster and can achieve impressive reproduction quality with simple prompts similar to the training ones, but it can entirely miss the point if the prompt is too complex or too far from the original category. The resulting file is 73 MB (it can be compressed to ~16 MB). Note that in this case you'll need both target reference images (`data/mycat1`) and more random images of similar subjects (`data/cat`); the latter can be generated with SD itself. 93 | Textual inversion is more generic but also more stable. Its embeddings can be easily combined without additional retraining. The resulting file is ~5 KB. 94 | 95 | * Generate an image with an embedding from [textual inversion]. You'll need to rename the embedding file after your trained token (e.g. `mycat1.pt`) and point to its directory. Note that the token is hardcoded in the file, so you can't change it afterwards. 96 | ``` 97 | python src/_sdrun.py -t "cosmic beast" --embeds train 98 | ``` 99 | * Generate an image with an embedding from [custom diffusion]. You'll need to explicitly mention your new token (so you can name it differently here) and the path to the trained delta file: 100 | ``` 101 | python src/_sdrun.py -t "cosmic beast" --token_mod mycat1 --delta_ckpt train/delta-xxx.ckpt 102 | ``` 103 | You can also run `python src/latwalk.py ...` with such arguments to make animations. 104 | 105 | 106 | ## Credits 107 | 108 | It's quite hard to mention all those who made the current revolution in visual creativity possible. Check the inline links above for some of the sources. 
109 | Huge respect to the people behind [Stable Diffusion], [InvokeAI], [Deforum] and the whole open-source movement. 110 | 111 | [Stable Diffusion]: 112 | [CompVis]: 113 | [Stability AI]: 114 | [InvokeAI]: 115 | [Deforum]: 116 | [OpenCLIP]: 117 | [CLIPseg]: 118 | [textual inversion]: 119 | [custom diffusion]: 120 | -------------------------------------------------------------------------------- /_in/pix/6458524847_2f4c361183_k.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/6458524847_2f4c361183_k.jpg -------------------------------------------------------------------------------- /_in/pix/8399166846_f6fb4e4b8e_k.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/8399166846_f6fb4e4b8e_k.jpg -------------------------------------------------------------------------------- /_in/pix/alex-iby-G_Pk4D9rMLs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/alex-iby-G_Pk4D9rMLs.jpg -------------------------------------------------------------------------------- /_in/pix/bench2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/bench2.jpg -------------------------------------------------------------------------------- /_in/pix/bertrand-gabioud-CpuFzIsHYJ0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/bertrand-gabioud-CpuFzIsHYJ0.jpg -------------------------------------------------------------------------------- /_in/pix/billow926-12-Wc-Zgx6Y.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/billow926-12-Wc-Zgx6Y.jpg -------------------------------------------------------------------------------- /_in/pix/mask/6458524847_2f4c361183_k_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/6458524847_2f4c361183_k_mask.jpg -------------------------------------------------------------------------------- /_in/pix/mask/8399166846_f6fb4e4b8e_k_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/8399166846_f6fb4e4b8e_k_mask.jpg -------------------------------------------------------------------------------- /_in/pix/mask/alex-iby-G_Pk4D9rMLs_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/alex-iby-G_Pk4D9rMLs_mask.jpg -------------------------------------------------------------------------------- /_in/pix/mask/bench2_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/bench2_mask.jpg 
-------------------------------------------------------------------------------- /_in/pix/mask/bertrand-gabioud-CpuFzIsHYJ0_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/bertrand-gabioud-CpuFzIsHYJ0_mask.jpg -------------------------------------------------------------------------------- /_in/pix/mask/billow926-12-Wc-Zgx6Y_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/billow926-12-Wc-Zgx6Y_mask.jpg -------------------------------------------------------------------------------- /_in/pix/mask/overture-creations-5sI6fQgYIuo_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/overture-creations-5sI6fQgYIuo_mask.jpg -------------------------------------------------------------------------------- /_in/pix/mask/photo-1583445095369-9c651e7e5d34_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/photo-1583445095369-9c651e7e5d34_mask.jpg -------------------------------------------------------------------------------- /_in/pix/overture-creations-5sI6fQgYIuo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/overture-creations-5sI6fQgYIuo.jpg -------------------------------------------------------------------------------- /_in/pix/photo-1583445095369-9c651e7e5d34.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/photo-1583445095369-9c651e7e5d34.jpg -------------------------------------------------------------------------------- /_in/something.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/something.jpg -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import urllib.request 4 | 5 | def download_model(url: str, root: str = "./models"): 6 | os.makedirs(root, exist_ok=True) 7 | filename = os.path.basename(url.split('?')[0]) 8 | download_target = os.path.join(root, filename) 9 | 10 | if os.path.exists(download_target) and not os.path.isfile(download_target): 11 | raise RuntimeError(f"{download_target} exists and is not a regular file") 12 | 13 | if os.path.isfile(download_target): 14 | return download_target 15 | 16 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 17 | with tqdm(total=int(source.info().get("Content-Length")), ncols=64, unit='iB', unit_scale=True) as loop: 18 | while True: 19 | buffer = source.read(8192) 20 | if not buffer: 21 | break 22 | output.write(buffer) 23 | loop.update(len(buffer)) 24 | 25 | return download_target 26 | 27 | # print(' downloading SD 1.4 model') 28 | # download_model("https://www.dropbox.com/s/2b3w0vysf485tpc/sd-v14-512-fp16.ckpt?dl=1", 
'models') 29 | print(' downloading SD 1.5 model') 30 | download_model("https://www.dropbox.com/s/k9odmzadgyo9gdl/sd-v15-512-fp16.ckpt?dl=1", 'models') 31 | print(' downloading SD 1.5-inpainting model') 32 | download_model("https://www.dropbox.com/s/cc5usmoik43alcc/sd-v15-512-inpaint-fp16.ckpt?dl=1", 'models') 33 | print(' downloading SD 2-inpainting model') 34 | download_model("https://www.dropbox.com/s/kn9jhrkofsfqsae/sd-v2-512-inpaint-fp16.ckpt?dl=1", 'models') 35 | print(' downloading SD 2-depth model') 36 | download_model("https://www.dropbox.com/s/zrx5qfesb9jstsg/sd-v2-512-depth-fp16.ckpt?dl=1", 'models') 37 | print(' downloading SD 2.1 model') 38 | download_model("https://www.dropbox.com/s/m4v36h8tksqa2lk/sd-v21-512-fp16.ckpt?dl=1", 'models') 39 | print(' downloading SD 2.1-v model') 40 | download_model("https://www.dropbox.com/s/wjzh3l1szauz5ww/sd-v21v-768-fp16.ckpt?dl=1", 'models') 41 | 42 | print(' downloading OpenCLIP ViT-H-14-laion2B-s32B-b79K model') 43 | download_model("https://www.dropbox.com/s/7smohfi2ijdy1qm/laion2b_s32b_b79k-vit-h14.pt?dl=1", 'models/openclip') 44 | # download_model("https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin", 'models/openclip') 45 | 46 | print(' downloading SD VAE-ema model') 47 | download_model("https://www.dropbox.com/s/dv836z05lblkvkc/vae-ft-ema-560000.ckpt?dl=1", 'models') 48 | print(' downloading SD VAE-mse model') 49 | download_model("https://www.dropbox.com/s/jmxksbzyk9fls1y/vae-ft-mse-840000.ckpt?dl=1", 'models') 50 | 51 | print(' downloading CLIPseg model') 52 | download_model("https://www.dropbox.com/s/c0tduhr4g0al1cq/rd64-uni.pth?dl=1", 'models/clipseg') 53 | print(' downloading MiDaS depth model') 54 | download_model("https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt", 'models/depth') 55 | 56 | -------------------------------------------------------------------------------- /img.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set KMP_DUPLICATE_LIB_OK=TRUE 3 | if [%1]==[] goto help 4 | echo .. %1 .. %2 5 | 6 | if exist _in\%~n1\* goto proc 7 | set seq=ok 8 | echo .. making source sequence 9 | mkdir _in\%~n1 10 | ffmpeg -y -v warning -i _in\%1 -q:v 2 _in\%~n1\%%06d.jpg 11 | 12 | :proc 13 | echo .. processing 14 | python src/_sdrun.py -v -im _in/%~n1 -o _out/%~n1 -t %2 ^ 15 | %3 %4 %5 %6 %7 %8 %9 16 | 17 | if %seq%==ok goto seq 18 | goto end 19 | 20 | :seq 21 | ffmpeg -y -v warning -i _out\%~n1\%%06d.jpg _out\%~n1-%2-%3%4%5%6%7%8%9.mp4 22 | goto end 23 | 24 | :help 25 | echo Usage: img imagedir "text prompt" [...] 26 | echo or: img videofile "text prompt" [...] 27 | :end 28 | -------------------------------------------------------------------------------- /inpaint.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set KMP_DUPLICATE_LIB_OK=TRUE 3 | if [%1]==[] goto help 4 | echo .. %1 .. %2 .. %3 5 | 6 | python src/_sdrun.py -v -im %1 -o _out/%~n1 --mask %2 -t %3 ^ 7 | %4 %5 %6 %7 %8 %9 8 | 9 | goto end 10 | 11 | :help 12 | echo Usage: inpaint imagedir masksdir "text prompt" [...] 
13 | echo e.g.: inpaint _in/pix _in/pix/mask "steampunk fantasy" 14 | echo or: inpaint _in/pix "human figure" "steampunk fantasy" -m 15i 15 | :end 16 | -------------------------------------------------------------------------------- /model_half.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import collections 4 | import torch 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--dir", '-d', default='./', help="directory with models") 8 | parser.add_argument("--ext", '-e', default=['ckpt','pt'], help="model extensions") 9 | a = parser.parse_args() 10 | 11 | def basename(file): 12 | return os.path.splitext(os.path.basename(file))[0] 13 | 14 | def file_list(path, ext=None): 15 | files = [os.path.join(path, f) for f in os.listdir(path)] 16 | if ext is not None: 17 | if isinstance(ext, list): 18 | files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ext] 19 | elif isinstance(ext, str): 20 | files = [f for f in files if f.endswith(ext)] 21 | else: 22 | print(' Unknown extension/type for file list!') 23 | return sorted([f for f in files if os.path.isfile(f)]) 24 | 25 | def float2half(data): 26 | for k in data: 27 | if isinstance(data[k], collections.abc.Mapping): 28 | data[k] = float2half(data[k]) 29 | elif isinstance(data[k], list): 30 | data[k] = [float2half(x) for x in data[k]] 31 | else: 32 | if data[k] is not None and torch.is_tensor(data[k]) and data[k].type() == 'torch.FloatTensor': 33 | data[k] = data[k].half() 34 | return data 35 | 36 | models = file_list(a.dir, a.ext) 37 | 38 | for model_path in models: 39 | model = torch.load(model_path) # dict? 40 | model = float2half(model) 41 | file_out = basename(model_path) + '-half' + os.path.splitext(model_path)[-1] 42 | torch.save(model, file_out) 43 | 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | einops 4 | kornia 5 | omegaconf 6 | torchmetrics==0.6.0 7 | transformers 8 | requests 9 | opencv-python 10 | pytorch-lightning==1.4.2 11 | open-clip-torch==2.7.0 12 | git+https://github.com/openai/CLIP.git@main#egg=clip 13 | timm 14 | scikit-image 15 | jsonmerge 16 | clean-fid 17 | resize_right 18 | torchdiffeq 19 | torchsde 20 | ipywidgets 21 | 22 | -------------------------------------------------------------------------------- /src/custom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/custom/__init__.py -------------------------------------------------------------------------------- /src/custom/compress.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Adobe Research. All rights reserved. 2 | # To view a copy of the license, visit LICENSE.md. 
3 | 4 | import torch 5 | import argparse 6 | 7 | 8 | def compress(delta_ckpt, ckpt, diffuser=False, compression_ratio=0.6, device='cuda'): 9 | st = torch.load(f'{delta_ckpt}') 10 | 11 | if not diffuser: 12 | compressed_key = 'state_dict' 13 | compressed_st = {compressed_key: {}} 14 | pretrained_st = torch.load(ckpt)['state_dict'] 15 | if 'embed' in st['state_dict']: 16 | compressed_st['state_dict']['embed'] = st['state_dict']['embed'] 17 | del st['state_dict']['embed'] 18 | 19 | st = st['state_dict'] 20 | else: 21 | from diffusers import StableDiffusionPipeline 22 | compressed_key = 'unet' 23 | compressed_st = {compressed_key: {}} 24 | pretrained_st = StableDiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda") 25 | pretrained_st = pretrained_st.unet.state_dict() 26 | if 'modifier_token' in st: 27 | compressed_st['modifier_token'] = st['modifier_token'] 28 | st = st['unet'] 29 | 30 | print("getting compression") 31 | layers = list(st.keys()) 32 | for name in layers: 33 | if 'to_k' in name or 'to_v' in name: 34 | W = st[name].to(device) 35 | Wpretrain = pretrained_st[name].clone().to(device) 36 | deltaW = W-Wpretrain 37 | 38 | u, s, vt = torch.linalg.svd(deltaW.clone()) 39 | 40 | explain = 0 41 | all_ = (s).sum() 42 | for i, t in enumerate(s): 43 | explain += t/(all_) 44 | if explain > compression_ratio: 45 | break 46 | 47 | compressed_st[compressed_key][f'{name}'] = {} 48 | compressed_st[compressed_key][f'{name}']['u'] = (u[:, :i]@torch.diag(s)[:i, :i]).clone() 49 | compressed_st[compressed_key][f'{name}']['v'] = vt[:i].clone() 50 | else: 51 | compressed_st[compressed_key][f'{name}'] = st[name] 52 | 53 | name = delta_ckpt.replace('delta', 'compressed_delta') 54 | torch.save(compressed_st, f'{name}') 55 | 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser('', add_help=False) 59 | parser.add_argument('--delta_ckpt', help='path of checkpoint to compress', 60 | type=str) 61 | parser.add_argument('--ckpt', help='path of pretrained model checkpoint', 62 | type=str) 63 | parser.add_argument("--diffuser", action='store_true') 64 | return parser.parse_args() 65 | 66 | 67 | if __name__ == "__main__": 68 | args = parse_args() 69 | compress(args.delta_ckpt, args.ckpt, args.diffuser) 70 | -------------------------------------------------------------------------------- /src/custom/get_deltas.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Adobe Research. All rights reserved. 2 | # To view a copy of the license, visit LICENSE.md. 
3 | import os 4 | import argparse 5 | import glob 6 | import torch 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser('', add_help=False) 10 | parser.add_argument('-i', '--path', help='path of folder to checkpoints', type=str) 11 | parser.add_argument('-n', '--newtoken', help='number of new tokens in the checkpoint', default=1, type=int) 12 | return parser.parse_args() 13 | 14 | def save_delta(ckpt_dir, newtoken=1): 15 | assert newtoken > 0, 'No new tokens found' 16 | layers = [] 17 | for ckptfile in glob.glob(f'{ckpt_dir}/*.ckpt', recursive=True): 18 | if 'delta' not in ckptfile: 19 | st = torch.load(ckptfile)["state_dict"] 20 | if len(layers) == 0: 21 | for key in list(st.keys()): 22 | if 'attn2.to_k' in key or 'attn2.to_v' in key: 23 | layers.append(key) 24 | st_delta = {'state_dict': {}} 25 | for each in layers: 26 | st_delta['state_dict'][each] = st[each].clone() 27 | num_tokens = st['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'].shape[0] 28 | st_delta['state_dict']['embed'] = st['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'][-newtoken:].clone() 29 | filepath = os.path.join(ckpt_dir, 'delta-' + os.path.basename(ckptfile)) 30 | torch.save(st_delta, filepath) 31 | os.remove(ckptfile) 32 | print('.. saved embedding', st_delta['state_dict']['embed'].cpu().numpy().shape, num_tokens, '=>', filepath) 33 | 34 | 35 | if __name__ == "__main__": 36 | args = parse_args() 37 | save_delta(args.path, args.newtoken) 38 | -------------------------------------------------------------------------------- /src/custom/modules.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Huggingface repository: https://github.com/huggingface/transformers/tree/main/src/transformers/models/clip. 2 | # Copyright 2018- The Hugging Face team. All rights reserved. 
3 | 4 | import os, sys 5 | from packaging import version 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import transformers 11 | from transformers import CLIPTokenizer, CLIPTextModel 12 | 13 | class AbstractEncoder(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | # https://github.com/adobe-research/custom-diffusion 20 | class FrozenCLIPEmbedderWrapper(AbstractEncoder): 21 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 22 | def __init__(self, modifier_token, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, model_dir='models'): 23 | super().__init__() 24 | self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir=os.path.join(model_dir, os.path.basename(version)), local_files_only=False) 25 | self.transformer = CLIPTextModel.from_pretrained(version, cache_dir=os.path.join(model_dir, os.path.basename(version)), local_files_only=False) 26 | self.device = device 27 | self.max_length = max_length 28 | self.modifier_token = modifier_token 29 | if '+' in self.modifier_token: 30 | self.modifier_token = self.modifier_token.split('+') 31 | else: 32 | self.modifier_token = [self.modifier_token] 33 | 34 | self.add_token() 35 | self.freeze() 36 | 37 | def add_token(self): 38 | self.modifier_token_id = [] 39 | token_embeds1 = self.transformer.get_input_embeddings().weight.data 40 | for each_modifier_token in self.modifier_token: 41 | num_added_tokens = self.tokenizer.add_tokens(each_modifier_token) 42 | modifier_token_id = self.tokenizer.convert_tokens_to_ids(each_modifier_token) 43 | self.modifier_token_id.append(modifier_token_id) # .., 49408, 49409 44 | 45 | self.transformer.resize_token_embeddings(len(self.tokenizer)) 46 | token_embeds = self.transformer.get_input_embeddings().weight.data # [49410, 768] 47 | # print(' modifier_token_id', self.modifier_token_id, '.. 
token_embeds', token_embeds.shape) 48 | token_embeds[self.modifier_token_id[-1]] = torch.nn.Parameter(token_embeds[42170], requires_grad=True) 49 | if len(self.modifier_token) == 2: 50 | token_embeds[self.modifier_token_id[-2]] = torch.nn.Parameter(token_embeds[47629], requires_grad=True) 51 | if len(self.modifier_token) == 3: 52 | token_embeds[self.modifier_token_id[-3]] = torch.nn.Parameter(token_embeds[43514], requires_grad=True) 53 | 54 | def custom_forward(self, hidden_states, input_ids): 55 | input_shape = hidden_states.size() 56 | bsz, seq_len = input_shape[:2] 57 | if version.parse(transformers.__version__) >= version.parse('4.21'): 58 | causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(hidden_states.device) 59 | else: 60 | causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device) 61 | 62 | encoder_outputs = self.transformer.text_model.encoder(inputs_embeds=hidden_states, causal_attention_mask=causal_attention_mask) 63 | 64 | last_hidden_state = encoder_outputs[0] 65 | last_hidden_state = self.transformer.text_model.final_layer_norm(last_hidden_state) 66 | 67 | return last_hidden_state 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | for param in self.transformer.text_model.encoder.parameters(): 72 | param.requires_grad = False 73 | for param in self.transformer.text_model.final_layer_norm.parameters(): 74 | param.requires_grad = False 75 | for param in self.transformer.text_model.embeddings.position_embedding.parameters(): 76 | param.requires_grad = False 77 | 78 | def forward(self, text): 79 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 80 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 81 | tokens = batch_encoding["input_ids"].to(self.device) 82 | 83 | if len(self.modifier_token) == 3: 84 | indices = ((tokens == self.modifier_token_id[-1]) | (tokens == self.modifier_token_id[-2]) | (tokens == self.modifier_token_id[-3]))*1 85 | elif len(self.modifier_token) == 2: 86 | indices = ((tokens == self.modifier_token_id[-1]) | (tokens == self.modifier_token_id[-2]))*1 87 | else: 88 | indices = (tokens == self.modifier_token_id[-1])*1 89 | 90 | indices = indices.unsqueeze(-1) 91 | 92 | input_shape = tokens.size() 93 | tokens = tokens.view(-1, input_shape[-1]) 94 | 95 | hidden_states = self.transformer.text_model.embeddings(input_ids=tokens) 96 | hidden_states = (1-indices)*hidden_states.detach() + indices*hidden_states 97 | 98 | z = self.custom_forward(hidden_states, tokens) 99 | 100 | return z 101 | 102 | def encode(self, text): 103 | return self(text) 104 | 105 | -------------------------------------------------------------------------------- /src/ldm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/__init__.py -------------------------------------------------------------------------------- /src/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/data/__init__.py -------------------------------------------------------------------------------- /src/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc 
import abstractmethod 2 | 3 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | """ Define an interface to make the IterableDatasets for text2img data chainable """ 7 | 8 | def __init__(self, num_records=0, valid_ids=None, size=256): 9 | super().__init__() 10 | self.num_records = num_records 11 | self.valid_ids = valid_ids 12 | self.sample_ids = valid_ids 13 | self.size = size 14 | 15 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 16 | 17 | def __len__(self): 18 | return self.num_records 19 | 20 | @abstractmethod 21 | def __iter__(self): 22 | pass 23 | -------------------------------------------------------------------------------- /src/ldm/data/personalized.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | templates_smallest = ['a photo of a {}'] 10 | 11 | templates = [ 12 | 'a photo of a {}', 13 | 'a rendering of a {}', 14 | 'a cropped photo of the {}', 15 | 'the photo of a {}', 16 | 'a photo of a clean {}', 17 | 'a photo of a dirty {}', 18 | 'a dark photo of the {}', 19 | 'a photo of my {}', 20 | 'a photo of the cool {}', 21 | 'a close-up photo of a {}', 22 | 'a bright photo of the {}', 23 | 'a cropped photo of a {}', 24 | 'a photo of the {}', 25 | 'a good photo of the {}', 26 | 'a photo of one {}', 27 | 'a close-up photo of the {}', 28 | 'a rendition of the {}', 29 | 'a photo of the clean {}', 30 | 'a rendition of a {}', 31 | 'a photo of a nice {}', 32 | 'a good photo of a {}', 33 | 'a photo of the nice {}', 34 | 'a photo of the small {}', 35 | 'a photo of the weird {}', 36 | 'a photo of the large {}', 37 | 'a photo of a cool {}', 38 | 'a photo of a small {}', 39 | ] 40 | dual_templates = [ 41 | 'a photo of a {} with {}', 42 | 'a rendering of a {} with {}', 43 | 'a cropped photo of the {} with {}', 44 | 'the photo of a {} with {}', 45 | 'a photo of a clean {} with {}', 46 | 'a photo of a dirty {} with {}', 47 | 'a dark photo of the {} with {}', 48 | 'a photo of my {} with {}', 49 | 'a photo of the cool {} with {}', 50 | 'a close-up photo of a {} with {}', 51 | 'a bright photo of the {} with {}', 52 | 'a cropped photo of a {} with {}', 53 | 'a photo of the {} with {}', 54 | 'a good photo of the {} with {}', 55 | 'a photo of one {} with {}', 56 | 'a close-up photo of the {} with {}', 57 | 'a rendition of the {} with {}', 58 | 'a photo of the clean {} with {}', 59 | 'a rendition of a {} with {}', 60 | 'a photo of a nice {} with {}', 61 | 'a good photo of a {} with {}', 62 | 'a photo of the nice {} with {}', 63 | 'a photo of the small {} with {}', 64 | 'a photo of the weird {} with {}', 65 | 'a photo of the large {} with {}', 66 | 'a photo of a cool {} with {}', 67 | 'a photo of a small {} with {}', 68 | ] 69 | templates_style = [ 70 | 'a painting in the style of {}', 71 | 'a rendering in the style of {}', 72 | 'a cropped painting in the style of {}', 73 | 'the painting in the style of {}', 74 | 'a clean painting in the style of {}', 75 | 'a dirty painting in the style of {}', 76 | 'a dark painting in the style of {}', 77 | 'a picture in the style of {}', 78 | 'a cool painting in the style of {}', 79 | 'a close-up painting in the style of {}', 80 | 'a bright painting in the style of {}', 81 | 'a cropped painting in the style of {}', 82 | 'a good painting in 
the style of {}', 83 | 'a close-up painting in the style of {}', 84 | 'a rendition in the style of {}', 85 | 'a nice painting in the style of {}', 86 | 'a small painting in the style of {}', 87 | 'a weird painting in the style of {}', 88 | 'a large painting in the style of {}', 89 | ] 90 | dual_templates_style = [ 91 | 'a painting in the style of {} with {}', 92 | 'a rendering in the style of {} with {}', 93 | 'a cropped painting in the style of {} with {}', 94 | 'the painting in the style of {} with {}', 95 | 'a clean painting in the style of {} with {}', 96 | 'a dirty painting in the style of {} with {}', 97 | 'a dark painting in the style of {} with {}', 98 | 'a cool painting in the style of {} with {}', 99 | 'a close-up painting in the style of {} with {}', 100 | 'a bright painting in the style of {} with {}', 101 | 'a cropped painting in the style of {} with {}', 102 | 'a good painting in the style of {} with {}', 103 | 'a painting of one {} in the style of {}', 104 | 'a nice painting in the style of {} with {}', 105 | 'a small painting in the style of {} with {}', 106 | 'a weird painting in the style of {} with {}', 107 | 'a large painting in the style of {} with {}', 108 | ] 109 | per_img_token_list = ['א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת'] 110 | 111 | class PersonalizedBase(Dataset): 112 | def __init__(self, data_root, size=None, repeats=100, interpolation='bicubic', flip_p=0.5, set='train', placeholder_token='*', 113 | per_image_tokens=False, center_crop=False, style=False, mixing_prob=0.25, coarse_class_text=None): 114 | self.data_root = data_root 115 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 116 | 117 | self.num_images = len(self.image_paths) 118 | self._length = self.num_images 119 | self.placeholder_token = placeholder_token 120 | self.per_image_tokens = per_image_tokens 121 | self.center_crop = center_crop 122 | self.templates = templates_style if style else templates 123 | self.dual_templates = dual_templates_style if style else dual_templates 124 | self.mixing_prob = mixing_prob 125 | self.coarse_class_text = coarse_class_text 126 | 127 | if per_image_tokens: 128 | assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 
129 | 130 | if set == 'train': 131 | self._length = self.num_images * repeats 132 | 133 | self.size = size 134 | self.interpolation = {'linear': Image.LINEAR, 'bilinear': Image.BILINEAR, 'bicubic': Image.BICUBIC, 'lanczos': Image.LANCZOS}[interpolation] 135 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 136 | 137 | def __len__(self): 138 | return self._length 139 | 140 | def __getitem__(self, i): 141 | example = {} 142 | image = Image.open(self.image_paths[i % self.num_images]) 143 | 144 | if not image.mode == 'RGB': 145 | image = image.convert('RGB') 146 | 147 | placeholder_string = self.placeholder_token 148 | if self.coarse_class_text: 149 | placeholder_string = f'{placeholder_string} {self.coarse_class_text}' 150 | 151 | if self.per_image_tokens and np.random.uniform() < self.mixing_prob: 152 | text = random.choice(self.dual_templates).format(placeholder_string, per_img_token_list[i % self.num_images]) 153 | else: 154 | text = random.choice(self.templates).format(placeholder_string) 155 | 156 | example['caption'] = text 157 | 158 | # default to score-sde preprocessing 159 | img = np.array(image).astype(np.uint8) 160 | 161 | if self.center_crop: 162 | crop = min(img.shape[0], img.shape[1]) 163 | h, w, = img.shape[0], img.shape[1] 164 | img = img[(h-crop)//2 : (h+crop)//2, (w-crop)//2 : (w+crop)//2] 165 | 166 | image = Image.fromarray(img) 167 | if self.size is not None: 168 | image = image.resize((self.size, self.size), resample=self.interpolation) 169 | 170 | image = self.flip(image) 171 | image = np.array(image).astype(np.uint8) 172 | example['image'] = (image / 127.5 - 1.0).astype(np.float32) 173 | return example 174 | -------------------------------------------------------------------------------- /src/ldm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/models/__init__.py -------------------------------------------------------------------------------- /src/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class 
AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi 68 | + self.logvar 69 | + torch.pow(sample - self.mean, 2) / self.var, 70 | dim=dims, 71 | ) 72 | 73 | def mode(self): 74 | return self.mean 75 | 76 | 77 | def normal_kl(mean1, logvar1, mean2, logvar2): 78 | """ 79 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 80 | Compute the KL divergence between two gaussians. 81 | Shapes are automatically broadcasted, so batches can be compared to 82 | scalars, among other use cases. 83 | """ 84 | tensor = None 85 | for obj in (mean1, logvar1, mean2, logvar2): 86 | if isinstance(obj, torch.Tensor): 87 | tensor = obj 88 | break 89 | assert tensor is not None, 'at least one argument must be a Tensor' 90 | 91 | # Force variances to be Tensors. Broadcasting helps convert scalars to 92 | # Tensors, but it does not work for torch.exp(). 
93 | logvar1, logvar2 = [ 94 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 95 | for x in (logvar1, logvar2) 96 | ] 97 | 98 | return 0.5 * ( 99 | -1.0 100 | + logvar2 101 | - logvar1 102 | + torch.exp(logvar1 - logvar2) 103 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 104 | ) 105 | -------------------------------------------------------------------------------- /src/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | 'num_updates', 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace('.', '') 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min( 35 | self.decay, (1 + self.num_updates) / (10 + self.num_updates) 36 | ) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as( 48 | m_param[key] 49 | ) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_( 62 | shadow_params[self.m_name2s_name[key]].data 63 | ) 64 | else: 65 | assert not key in self.m_name2s_name 66 | 67 | def store(self, parameters): 68 | """ 69 | Save the current parameters for restoring later. 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | temporarily stored. 73 | """ 74 | self.collected_params = [param.clone() for param in parameters] 75 | 76 | def restore(self, parameters): 77 | """ 78 | Restore the parameters stored with the `store` method. 79 | Useful to validate the model with EMA parameters without affecting the 80 | original optimization process. Store the parameters before the 81 | `copy_to` method. After validation (or model saving), use this to 82 | restore the former parameters. 83 | Args: 84 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 85 | updated with the stored parameters. 
86 | """ 87 | for c_param, param in zip(self.collected_params, parameters): 88 | param.data.copy_(c_param.data) 89 | -------------------------------------------------------------------------------- /src/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "models/depth/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "models/depth/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = 
NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | 
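A minimal usage sketch (illustrative only, not a file from this repository) for the DPTDepthModel defined above; it mirrors what load_model() in api.py does for the "dpt_hybrid" type, and assumes src/ is on the import path and the repo's timm dependency is installed:

import torch
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel

# path=None builds the network without loading weights; to load real weights, pass the
# checkpoint path listed in ISL_PATHS in api.py (e.g. "models/depth/dpt_hybrid-midas-501f0c75.pt")
# once that file has been downloaded
model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()

x = torch.randn(1, 3, 384, 384)  # stands in for a normalized RGB batch from load_midas_transform("dpt_hybrid")
with torch.no_grad():
    depth = model(x)             # (1, H', W') relative inverse-depth map; MiDaSInference later resizes it back to the input size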
-------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /src/ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 
60 |  61 |     Args: 62 |         path (str): path to file 63 |         image (array): data 64 |         scale (int, optional): Scale. Defaults to 1. 65 |     """ 66 |  67 |     with open(path, "wb") as file: 68 |         color = None 69 |  70 |         if image.dtype.name != "float32": 71 |             raise Exception("Image dtype must be float32.") 72 |  73 |         image = np.flipud(image) 74 |  75 |         if len(image.shape) == 3 and image.shape[2] == 3:  # color image 76 |             color = True 77 |         elif ( 78 |             len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 |         ):  # greyscale 80 |             color = False 81 |         else: 82 |             raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 |  84 |         file.write(("PF\n" if color else "Pf\n").encode()) 85 |         file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 |  87 |         endian = image.dtype.byteorder 88 |  89 |         if endian == "<" or endian == "=" and sys.byteorder == "little": 90 |             scale = -scale 91 |  92 |         file.write("%f\n".encode() % scale) 93 |  94 |         image.tofile(file) 95 |  96 |  97 | def read_image(path): 98 |     """Read image and output RGB image (0-1). 99 |  100 |     Args: 101 |         path (str): path to file 102 |  103 |     Returns: 104 |         array: RGB image (0-1) 105 |     """ 106 |     img = cv2.imread(path) 107 |  108 |     if img.ndim == 2: 109 |         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 |  111 |     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 |  113 |     return img 114 |  115 |  116 | def resize_image(img): 117 |     """Resize image and make it fit for network. 118 |  119 |     Args: 120 |         img (array): image 121 |  122 |     Returns: 123 |         tensor: data ready for network 124 |     """ 125 |     height_orig = img.shape[0] 126 |     width_orig = img.shape[1] 127 |  128 |     if width_orig > height_orig: 129 |         scale = width_orig / 384 130 |     else: 131 |         scale = height_orig / 384 132 |  133 |     height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 |     width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 |  136 |     img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 |  138 |     img_resized = ( 139 |         torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 |     ) 141 |     img_resized = img_resized.unsqueeze(0) 142 |  143 |     return img_resized 144 |  145 |  146 | def resize_depth(depth, width, height): 147 |     """Resize depth map and bring to CPU (numpy). 148 |  149 |     Args: 150 |         depth (tensor): depth 151 |         width (int): image width 152 |         height (int): image height 153 |  154 |     Returns: 155 |         array: processed depth 156 |     """ 157 |     depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 |  159 |     depth_resized = cv2.resize( 160 |         depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 |     ) 162 |  163 |     return depth_resized 164 |  165 | def write_depth(path, depth, bits=1): 166 |     """Write depth map to pfm and png file.
167 |  168 |     Args: 169 |         path (str): filepath without extension 170 |         depth (array): depth 171 |     """ 172 |     write_pfm(path + ".pfm", depth.astype(np.float32)) 173 |  174 |     depth_min = depth.min() 175 |     depth_max = depth.max() 176 |  177 |     max_val = (2**(8*bits))-1 178 |  179 |     if depth_max - depth_min > np.finfo("float").eps: 180 |         out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 |     else: 182 |         out = np.zeros(depth.shape, dtype=depth.dtype) 183 |  184 |     if bits == 1: 185 |         cv2.imwrite(path + ".png", out.astype("uint8")) 186 |     elif bits == 2: 187 |         cv2.imwrite(path + ".png", out.astype("uint16")) 188 |  189 |     return 190 | -------------------------------------------------------------------------------- /src/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 |  3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 |  9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 |  13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 |  16 |  17 | def log_txt_as_img(wh, xc, size=10): 18 |     # wh a tuple of (width, height) 19 |     # xc a list of captions to plot 20 |     b = len(xc) 21 |     txts = list() 22 |     for bi in range(b): 23 |         txt = Image.new('RGB', wh, color='white') 24 |         draw = ImageDraw.Draw(txt) 25 |         font = ImageFont.load_default() 26 |         nc = int(40 * (wh[0] / 256)) 27 |         lines = '\n'.join( 28 |             xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 29 |         ) 30 |  31 |         try: 32 |             draw.text((0, 0), lines, fill='black', font=font) 33 |         except UnicodeEncodeError: 34 |             print("Can't encode string for logging. Skipping.") 35 |  36 |         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 37 |         txts.append(txt) 38 |     txts = np.stack(txts) 39 |     txts = torch.tensor(txts) 40 |     return txts 41 |  42 |  43 | def ismap(x): 44 |     if not isinstance(x, torch.Tensor): 45 |         return False 46 |     return (len(x.shape) == 4) and (x.shape[1] > 3) 47 |  48 |  49 | def isimage(x): 50 |     if not isinstance(x, torch.Tensor): 51 |         return False 52 |     return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 53 |  54 |  55 | def exists(x): 56 |     return x is not None 57 |  58 |  59 | def default(val, d): 60 |     if exists(val): 61 |         return val 62 |     return d() if isfunction(d) else d 63 |  64 |  65 | def mean_flat(tensor): 66 |     """ 67 |     https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 68 |     Take the mean over all non-batch dimensions. 69 |     """ 70 |     return tensor.mean(dim=list(range(1, len(tensor.shape)))) 71 |  72 |  73 | def count_params(model, verbose=False): 74 |     total_params = sum(p.numel() for p in model.parameters()) 75 |     if verbose: 76 |         print( 77 |             f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
78 | ) 79 | return total_params 80 | 81 | 82 | def instantiate_from_config(config, **kwargs): 83 | if not 'target' in config: 84 | if config == '__is_first_stage__': 85 | return None 86 | elif config == '__is_unconditional__': 87 | return None 88 | raise KeyError('Expected key `target` to instantiate.') 89 | return get_obj_from_str(config['target'])( 90 | **config.get('params', dict()), **kwargs 91 | ) 92 | 93 | 94 | def get_obj_from_str(string, reload=False): 95 | module, cls = string.rsplit('.', 1) 96 | if reload: 97 | module_imp = importlib.import_module(module) 98 | importlib.reload(module_imp) 99 | return getattr(importlib.import_module(module, package=None), cls) 100 | 101 | 102 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 103 | # create dummy dataset instance 104 | 105 | # run prefetching 106 | if idx_to_fn: 107 | res = func(data, worker_id=idx) 108 | else: 109 | res = func(data) 110 | Q.put([idx, res]) 111 | Q.put('Done') 112 | 113 | 114 | def parallel_data_prefetch( 115 | func: callable, 116 | data, 117 | n_proc, 118 | target_data_type='ndarray', 119 | cpu_intensive=True, 120 | use_worker_id=False, 121 | ): 122 | # if target_data_type not in ["ndarray", "list"]: 123 | # raise ValueError( 124 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 125 | # ) 126 | if isinstance(data, np.ndarray) and target_data_type == 'list': 127 | raise ValueError('list expected but function got ndarray.') 128 | elif isinstance(data, abc.Iterable): 129 | if isinstance(data, dict): 130 | print( 131 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 132 | ) 133 | data = list(data.values()) 134 | if target_data_type == 'ndarray': 135 | data = np.asarray(data) 136 | else: 137 | data = list(data) 138 | else: 139 | raise TypeError( 140 | f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.' 141 | ) 142 | 143 | if cpu_intensive: 144 | Q = mp.Queue(1000) 145 | proc = mp.Process 146 | else: 147 | Q = Queue(1000) 148 | proc = Thread 149 | # spawn processes 150 | if target_data_type == 'ndarray': 151 | arguments = [ 152 | [func, Q, part, i, use_worker_id] 153 | for i, part in enumerate(np.array_split(data, n_proc)) 154 | ] 155 | else: 156 | step = ( 157 | int(len(data) / n_proc + 1) 158 | if len(data) % n_proc != 0 159 | else int(len(data) / n_proc) 160 | ) 161 | arguments = [ 162 | [func, Q, part, i, use_worker_id] 163 | for i, part in enumerate( 164 | [data[i : i + step] for i in range(0, len(data), step)] 165 | ) 166 | ] 167 | processes = [] 168 | for i in range(n_proc): 169 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 170 | processes += [p] 171 | 172 | # start processes 173 | print(f'Start prefetching...') 174 | import time 175 | 176 | start = time.time() 177 | gather_res = [[] for _ in range(n_proc)] 178 | try: 179 | for p in processes: 180 | p.start() 181 | 182 | k = 0 183 | while k < n_proc: 184 | # get result 185 | res = Q.get() 186 | if res == 'Done': 187 | k += 1 188 | else: 189 | gather_res[res[0]] = res[1] 190 | 191 | except Exception as e: 192 | print('Exception: ', e) 193 | for p in processes: 194 | p.terminate() 195 | 196 | raise e 197 | finally: 198 | for p in processes: 199 | p.join() 200 | print(f'Prefetching complete. 
[{time.time() - start} sec.]') 201 | 202 | if target_data_type == 'ndarray': 203 | if not isinstance(gather_res[0], np.ndarray): 204 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 205 | 206 | # order outputs 207 | return np.concatenate(gather_res, axis=0) 208 | elif target_data_type == 'list': 209 | out = [] 210 | for r in gather_res: 211 | out.extend(r) 212 | return out 213 | else: 214 | return gather_res 215 | -------------------------------------------------------------------------------- /src/txt2mask.py: -------------------------------------------------------------------------------- 1 | '''Makes available the Txt2Mask class, which assists in the automatic 2 | assignment of masks via text prompt using clipseg. 3 | 4 | Here is typical usage: 5 | 6 | from txt2mask import Txt2Mask # SegmentedGrayscale 7 | from PIL import Image 8 | 9 | txt2mask = Txt2Mask(self.device) 10 | segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel') 11 | 12 | # this will return a grayscale Image of the segmented data 13 | grayscale = segmented.to_grayscale() 14 | 15 | # this will return a semi-transparent image in which the 16 | # selected object(s) are opaque and the rest is at various 17 | # levels of transparency 18 | transparent = segmented.to_transparent() 19 | 20 | # this will return a masked image suitable for use in inpainting: 21 | mask = segmented.to_mask(threshold=0.5) 22 | 23 | The threshold used in the call to to_mask() selects pixels for use in 24 | the mask that exceed the indicated confidence threshold. Values range 25 | from 0.0 to 1.0. The higher the threshold, the more confident the 26 | algorithm is. In limited testing, I have found that values around 0.5 27 | work fine. 28 | ''' 29 | 30 | import os, sys 31 | import numpy as np 32 | from einops import rearrange, repeat 33 | from PIL import Image, ImageOps 34 | 35 | import torch 36 | from torchvision import transforms 37 | 38 | CLIP_VERSION = 'ViT-B/16' 39 | CLIPSEG_SIZE = 352 40 | 41 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'xtra')) 42 | 43 | from clipseg.clipseg import CLIPDensePredT 44 | 45 | class SegmentedGrayscale(object): 46 | def __init__(self, image:Image, heatmap:torch.Tensor): 47 | self.heatmap = heatmap 48 | self.image = image 49 | 50 | def to_grayscale(self,invert:bool=False)->Image: 51 | return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255))) 52 | 53 | def to_mask(self,threshold:float=0.5)->Image: 54 | discrete_heatmap = self.heatmap.lt(threshold).int() 55 | return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')) 56 | 57 | def to_transparent(self,invert:bool=False)->Image: 58 | transparent_image = self.image.copy() 59 | # For img2img, we want the selected regions to be transparent, 60 | # but to_grayscale() returns the opposite. Thus invert. 61 | gs = self.to_grayscale(not invert) 62 | transparent_image.putalpha(gs) 63 | return transparent_image 64 | 65 | # unscales and uncrops the 352x352 heatmap so that it matches the image again 66 | def _rescale(self, heatmap:Image)->Image: 67 | size = self.image.width if (self.image.width > self.image.height) else self.image.height 68 | resized_image = heatmap.resize((size,size), resample=Image.Resampling.LANCZOS) 69 | return resized_image.crop((0,0,self.image.width,self.image.height)) 70 | 71 | class Txt2Mask(object): 72 | ''' Create new Txt2Mask object. 
The optional device argument can be one of 'cuda', 'mps' or 'cpu' ''' 73 | def __init__(self, model_path='models/clipseg/rd64-uni.pth', device='cpu', refined=False): 74 | # print('>> Initializing clipseg model for text to mask inference') 75 | self.device = device 76 | self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined) 77 | self.model.eval() 78 | # initially we keep everything in cpu to conserve space 79 | self.model.to('cpu') 80 | self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False) 81 | 82 | @torch.no_grad() 83 | def segment(self, image, prompt:str) -> SegmentedGrayscale: 84 | ''' 85 | Given a prompt string such as "a bagel", tries to identify the object in the 86 | provided image and returns a SegmentedGrayscale object in which the brighter 87 | pixels indicate where the object is inferred to be. 88 | ''' 89 | self._to_device(self.device) 90 | prompts = [prompt] # right now we operate on just a single prompt at a time 91 | 92 | transform = transforms.Compose([ 93 | transforms.ToTensor(), 94 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 95 | transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64... 96 | ]) 97 | 98 | if type(image) is str: 99 | image = Image.open(image).convert('RGB') 100 | 101 | image = ImageOps.exif_transpose(image) 102 | img = self._scale_and_crop(image) 103 | img = transform(img).unsqueeze(0) 104 | 105 | preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0] 106 | heatmap = torch.sigmoid(preds[0][0]).cpu() 107 | self._to_device('cpu') 108 | return SegmentedGrayscale(image, heatmap) 109 | 110 | def _to_device(self, device): 111 | self.model.to(device) 112 | 113 | def _scale_and_crop(self, image:Image)->Image: 114 | scaled_image = Image.new('RGB', (CLIPSEG_SIZE, CLIPSEG_SIZE)) 115 | if image.width > image.height: # width is constraint 116 | scale = CLIPSEG_SIZE / image.width 117 | else: 118 | scale = CLIPSEG_SIZE / image.height 119 | scaled_image.paste(image.resize((int(scale * image.width), int(scale * image.height)), resample=Image.Resampling.LANCZOS),box=(0,0)) 120 | return scaled_image 121 | -------------------------------------------------------------------------------- /src/xtra/k_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import augmentation, config, evaluation, external, gns, layers, models, sampling, utils 2 | from .layers import Denoiser 3 | -------------------------------------------------------------------------------- /src/xtra/k_diffusion/augmentation.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import math 3 | import operator 4 | 5 | import numpy as np 6 | from skimage import transform 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def translate2d(tx, ty): 12 | mat = [[1, 0, tx], 13 | [0, 1, ty], 14 | [0, 0, 1]] 15 | return torch.tensor(mat, dtype=torch.float32) 16 | 17 | 18 | def scale2d(sx, sy): 19 | mat = [[sx, 0, 0], 20 | [ 0, sy, 0], 21 | [ 0, 0, 1]] 22 | return torch.tensor(mat, dtype=torch.float32) 23 | 24 | 25 | def rotate2d(theta): 26 | mat = [[torch.cos(theta), torch.sin(-theta), 0], 27 | [torch.sin(theta), torch.cos(theta), 0], 28 | [ 0, 0, 1]] 29 | return torch.tensor(mat, dtype=torch.float32) 30 | 31 | 32 | class KarrasAugmentationPipeline: 33 | def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): 34 | self.a_prob = a_prob 35 | self.a_scale = a_scale 36 | self.a_aniso = a_aniso 37 | self.a_trans = a_trans 38 | 39 | def __call__(self, image): 40 | h, w = image.size 41 | mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] 42 | 43 | # x-flip 44 | a0 = torch.randint(2, []).float() 45 | mats.append(scale2d(1 - 2 * a0, 1)) 46 | # y-flip 47 | do = (torch.rand([]) < self.a_prob).float() 48 | a1 = torch.randint(2, []).float() * do 49 | mats.append(scale2d(1, 1 - 2 * a1)) 50 | # scaling 51 | do = (torch.rand([]) < self.a_prob).float() 52 | a2 = torch.randn([]) * do 53 | mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) 54 | # rotation 55 | do = (torch.rand([]) < self.a_prob).float() 56 | a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do 57 | mats.append(rotate2d(-a3)) 58 | # anisotropy 59 | do = (torch.rand([]) < self.a_prob).float() 60 | a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do 61 | a5 = torch.randn([]) * do 62 | mats.append(rotate2d(a4)) 63 | mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) 64 | mats.append(rotate2d(-a4)) 65 | # translation 66 | do = (torch.rand([]) < self.a_prob).float() 67 | a6 = torch.randn([]) * do 68 | a7 = torch.randn([]) * do 69 | mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) 70 | 71 | # form the transformation matrix and conditioning vector 72 | mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) 73 | mat = reduce(operator.matmul, mats) 74 | cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) 75 | 76 | # apply the transformation 77 | image_orig = np.array(image, dtype=np.float32) / 255 78 | if image_orig.ndim == 2: 79 | image_orig = image_orig[..., None] 80 | tf = transform.AffineTransform(mat.numpy()) 81 | image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) 82 | image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 83 | image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 84 | return image, image_orig, cond 85 | 86 | 87 | class KarrasAugmentWrapper(nn.Module): 88 | def __init__(self, model): 89 | super().__init__() 90 | self.inner_model = model 91 | 92 | def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): 93 | if aug_cond is None: 94 | aug_cond = input.new_zeros([input.shape[0], 9]) 95 | if mapping_cond is None: 96 | mapping_cond = aug_cond 97 | else: 98 | 
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1) 99 | return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) 100 | 101 | def set_skip_stages(self, skip_stages): 102 | return self.inner_model.set_skip_stages(skip_stages) 103 | 104 | def set_patch_size(self, patch_size): 105 | return self.inner_model.set_patch_size(patch_size) 106 | -------------------------------------------------------------------------------- /src/xtra/k_diffusion/config.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import json 3 | import math 4 | import warnings 5 | 6 | from jsonmerge import merge 7 | 8 | from . import augmentation, layers, models, utils 9 | 10 | 11 | def load_config(file): 12 | defaults = { 13 | 'model': { 14 | 'sigma_data': 1., 15 | 'patch_size': 1, 16 | 'dropout_rate': 0., 17 | 'augment_wrapper': True, 18 | 'augment_prob': 0., 19 | 'mapping_cond_dim': 0, 20 | 'unet_cond_dim': 0, 21 | 'cross_cond_dim': 0, 22 | 'cross_attn_depths': None, 23 | 'skip_stages': 0, 24 | 'has_variance': False, 25 | }, 26 | 'dataset': { 27 | 'type': 'imagefolder', 28 | }, 29 | 'optimizer': { 30 | 'type': 'adamw', 31 | 'lr': 1e-4, 32 | 'betas': [0.95, 0.999], 33 | 'eps': 1e-6, 34 | 'weight_decay': 1e-3, 35 | }, 36 | 'lr_sched': { 37 | 'type': 'inverse', 38 | 'inv_gamma': 20000., 39 | 'power': 1., 40 | 'warmup': 0.99, 41 | }, 42 | 'ema_sched': { 43 | 'type': 'inverse', 44 | 'power': 0.6667, 45 | 'max_value': 0.9999 46 | }, 47 | } 48 | config = json.load(file) 49 | return merge(defaults, config) 50 | 51 | 52 | def make_model(config): 53 | config = config['model'] 54 | assert config['type'] == 'image_v1' 55 | model = models.ImageDenoiserModelV1( 56 | config['input_channels'], 57 | config['mapping_out'], 58 | config['depths'], 59 | config['channels'], 60 | config['self_attn_depths'], 61 | config['cross_attn_depths'], 62 | patch_size=config['patch_size'], 63 | dropout_rate=config['dropout_rate'], 64 | mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0), 65 | unet_cond_dim=config['unet_cond_dim'], 66 | cross_cond_dim=config['cross_cond_dim'], 67 | skip_stages=config['skip_stages'], 68 | has_variance=config['has_variance'], 69 | ) 70 | if config['augment_wrapper']: 71 | model = augmentation.KarrasAugmentWrapper(model) 72 | return model 73 | 74 | 75 | def make_denoiser_wrapper(config): 76 | config = config['model'] 77 | sigma_data = config.get('sigma_data', 1.) 78 | has_variance = config.get('has_variance', False) 79 | if not has_variance: 80 | return partial(layers.Denoiser, sigma_data=sigma_data) 81 | return partial(layers.DenoiserWithVariance, sigma_data=sigma_data) 82 | 83 | 84 | def make_sample_density(config): 85 | sd_config = config['sigma_sample_density'] 86 | sigma_data = config['sigma_data'] 87 | if sd_config['type'] == 'lognormal': 88 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] 89 | scale = sd_config['std'] if 'std' in sd_config else sd_config['scale'] 90 | return partial(utils.rand_log_normal, loc=loc, scale=scale) 91 | if sd_config['type'] == 'loglogistic': 92 | loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data) 93 | scale = sd_config['scale'] if 'scale' in sd_config else 0.5 94 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. 
95 | max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') 96 | return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) 97 | if sd_config['type'] == 'loguniform': 98 | min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min'] 99 | max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max'] 100 | return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) 101 | if sd_config['type'] == 'v-diffusion': 102 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. 103 | max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') 104 | return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value) 105 | if sd_config['type'] == 'split-lognormal': 106 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] 107 | scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1'] 108 | scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2'] 109 | return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2) 110 | raise ValueError('Unknown sample density type') 111 | -------------------------------------------------------------------------------- /src/xtra/k_diffusion/evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | 5 | from cleanfid.inception_torchscript import InceptionV3W 6 | import clip 7 | from resize_right import resize 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torchvision import transforms 12 | from tqdm.auto import trange 13 | 14 | from . 
import utils 15 | 16 | 17 | class InceptionV3FeatureExtractor(nn.Module): 18 | def __init__(self, device='cpu'): 19 | super().__init__() 20 | path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' 21 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 22 | digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' 23 | utils.download_file(path / 'inception-2015-12-05.pt', url, digest) 24 | self.model = InceptionV3W(str(path), resize_inside=False).to(device) 25 | self.size = (299, 299) 26 | 27 | def forward(self, x): 28 | if x.shape[2:4] != self.size: 29 | x = resize(x, out_shape=self.size, pad_mode='reflect') 30 | if x.shape[1] == 1: 31 | x = torch.cat([x] * 3, dim=1) 32 | x = (x * 127.5 + 127.5).clamp(0, 255) 33 | return self.model(x) 34 | 35 | 36 | class CLIPFeatureExtractor(nn.Module): 37 | def __init__(self, name='ViT-L/14@336px', device='cpu'): 38 | super().__init__() 39 | self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) 40 | self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 41 | std=(0.26862954, 0.26130258, 0.27577711)) 42 | self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) 43 | 44 | def forward(self, x): 45 | if x.shape[2:4] != self.size: 46 | x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) 47 | x = self.normalize(x) 48 | x = self.model.encode_image(x).float() 49 | x = F.normalize(x) * x.shape[1] ** 0.5 50 | return x 51 | 52 | 53 | def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): 54 | n_per_proc = math.ceil(n / accelerator.num_processes) 55 | feats_all = [] 56 | try: 57 | for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): 58 | cur_batch_size = min(n - i, batch_size) 59 | samples = sample_fn(cur_batch_size)[:cur_batch_size] 60 | feats_all.append(accelerator.gather(extractor_fn(samples))) 61 | except StopIteration: 62 | pass 63 | return torch.cat(feats_all)[:n] 64 | 65 | 66 | def polynomial_kernel(x, y): 67 | d = x.shape[-1] 68 | dot = x @ y.transpose(-2, -1) 69 | return (dot / d + 1) ** 3 70 | 71 | 72 | def squared_mmd(x, y, kernel=polynomial_kernel): 73 | m = x.shape[-2] 74 | n = y.shape[-2] 75 | kxx = kernel(x, x) 76 | kyy = kernel(y, y) 77 | kxy = kernel(x, y) 78 | kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1) 79 | kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1) 80 | kxy_sum = kxy.sum([-1, -2]) 81 | term_1 = kxx_sum / m / (m - 1) 82 | term_2 = kyy_sum / n / (n - 1) 83 | term_3 = kxy_sum * 2 / m / n 84 | return term_1 + term_2 - term_3 85 | 86 | 87 | @utils.tf32_mode(matmul=False) 88 | def kid(x, y, max_size=5000): 89 | x_size, y_size = x.shape[0], y.shape[0] 90 | n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) 91 | total_mmd = x.new_zeros([]) 92 | for i in range(n_partitions): 93 | cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)] 94 | cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)] 95 | total_mmd = total_mmd + squared_mmd(cur_x, cur_y) 96 | return total_mmd / n_partitions 97 | 98 | 99 | class _MatrixSquareRootEig(torch.autograd.Function): 100 | @staticmethod 101 | def forward(ctx, a): 102 | vals, vecs = torch.linalg.eigh(a) 103 | ctx.save_for_backward(vals, vecs) 104 | return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) 105 | 106 | @staticmethod 107 
| def backward(ctx, grad_output): 108 | vals, vecs = ctx.saved_tensors 109 | d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) 110 | vecs_t = vecs.transpose(-2, -1) 111 | return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t 112 | 113 | 114 | def sqrtm_eig(a): 115 | if a.ndim < 2: 116 | raise RuntimeError('tensor of matrices must have at least 2 dimensions') 117 | if a.shape[-2] != a.shape[-1]: 118 | raise RuntimeError('tensor must be batches of square matrices') 119 | return _MatrixSquareRootEig.apply(a) 120 | 121 | 122 | @utils.tf32_mode(matmul=False) 123 | def fid(x, y, eps=1e-8): 124 | x_mean = x.mean(dim=0) 125 | y_mean = y.mean(dim=0) 126 | mean_term = (x_mean - y_mean).pow(2).sum() 127 | x_cov = torch.cov(x.T) 128 | y_cov = torch.cov(y.T) 129 | eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps 130 | x_cov = x_cov + eps_eye 131 | y_cov = y_cov + eps_eye 132 | x_cov_sqrt = sqrtm_eig(x_cov) 133 | cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) 134 | return mean_term + cov_term 135 | -------------------------------------------------------------------------------- /src/xtra/k_diffusion/external.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from . import sampling, utils 7 | 8 | 9 | class VDenoiser(nn.Module): 10 | """A v-diffusion-pytorch model wrapper for k-diffusion.""" 11 | 12 | def __init__(self, inner_model): 13 | super().__init__() 14 | self.inner_model = inner_model 15 | self.sigma_data = 1. 16 | 17 | def get_scalings(self, sigma): 18 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 19 | c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 20 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 21 | return c_skip, c_out, c_in 22 | 23 | def sigma_to_t(self, sigma): 24 | return sigma.atan() / math.pi * 2 25 | 26 | def t_to_sigma(self, t): 27 | return (t * math.pi / 2).tan() 28 | 29 | def loss(self, input, noise, sigma, **kwargs): 30 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 31 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 32 | model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 33 | target = (input - c_skip * noised_input) / c_out 34 | return (model_output - target).pow(2).flatten(1).mean(1) 35 | 36 | def forward(self, input, sigma, **kwargs): 37 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 38 | return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip 39 | 40 | 41 | class DiscreteSchedule(nn.Module): 42 | """A mapping between continuous noise levels (sigmas) and a list of discrete noise 43 | levels.""" 44 | 45 | def __init__(self, sigmas, quantize): 46 | super().__init__() 47 | self.register_buffer('sigmas', sigmas) 48 | self.register_buffer('log_sigmas', sigmas.log()) 49 | self.quantize = quantize 50 | 51 | @property 52 | def sigma_min(self): 53 | return self.sigmas[0] 54 | 55 | @property 56 | def sigma_max(self): 57 | return self.sigmas[-1] 58 | 59 | def get_sigmas(self, n=None): 60 | if n is None: 61 | return sampling.append_zero(self.sigmas.flip(0)) 62 | t_max = len(self.sigmas) - 1 63 | t = torch.linspace(t_max, 0, n, device=self.sigmas.device) 64 | return sampling.append_zero(self.t_to_sigma(t)) 65 | 66 | def 
sigma_to_t(self, sigma, quantize=None): 67 | quantize = self.quantize if quantize is None else quantize 68 | log_sigma = sigma.log() 69 | dists = log_sigma - self.log_sigmas[:, None] 70 | if quantize: 71 | return dists.abs().argmin(dim=0).view(sigma.shape) 72 | low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) 73 | high_idx = low_idx + 1 74 | low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] 75 | w = (low - log_sigma) / (low - high) 76 | w = w.clamp(0, 1) 77 | t = (1 - w) * low_idx + w * high_idx 78 | return t.view(sigma.shape) 79 | 80 | def t_to_sigma(self, t): 81 | t = t.float() 82 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() 83 | log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] 84 | return log_sigma.exp() 85 | 86 | 87 | class DiscreteEpsDDPMDenoiser(DiscreteSchedule): 88 | """A wrapper for discrete schedule DDPM models that output eps (the predicted 89 | noise).""" 90 | 91 | def __init__(self, model, alphas_cumprod, quantize): 92 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) 93 | self.inner_model = model 94 | self.sigma_data = 1. 95 | 96 | def get_scalings(self, sigma): 97 | c_out = -sigma 98 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 99 | return c_out, c_in 100 | 101 | def get_eps(self, *args, **kwargs): 102 | return self.inner_model(*args, **kwargs) 103 | 104 | def loss(self, input, noise, sigma, **kwargs): 105 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 106 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 107 | eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 108 | return (eps - noise).pow(2).flatten(1).mean(1) 109 | 110 | def forward(self, input, sigma, **kwargs): 111 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 112 | eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) 113 | return input + eps * c_out 114 | 115 | 116 | class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): 117 | """A wrapper for OpenAI diffusion models.""" 118 | 119 | def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): 120 | alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) 121 | super().__init__(model, alphas_cumprod, quantize=quantize) 122 | self.has_learned_sigmas = has_learned_sigmas 123 | 124 | def get_eps(self, *args, **kwargs): 125 | model_output = self.inner_model(*args, **kwargs) 126 | if self.has_learned_sigmas: 127 | return model_output.chunk(2, dim=1)[0] 128 | return model_output 129 | 130 | 131 | class CompVisDenoiser(DiscreteEpsDDPMDenoiser): 132 | """A wrapper for CompVis diffusion models.""" 133 | 134 | def __init__(self, model, quantize=False, device='cpu'): 135 | super().__init__(model, model.alphas_cumprod, quantize=quantize) 136 | 137 | def get_eps(self, *args, **kwargs): 138 | return self.inner_model.apply_model(*args, **kwargs) 139 | 140 | 141 | class DiscreteVDDPMDenoiser(DiscreteSchedule): 142 | """A wrapper for discrete schedule DDPM models that output v.""" 143 | 144 | def __init__(self, model, alphas_cumprod, quantize): 145 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) 146 | self.inner_model = model 147 | self.sigma_data = 1. 
148 | 149 | def get_scalings(self, sigma): 150 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 151 | c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 152 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 153 | return c_skip, c_out, c_in 154 | 155 | def get_v(self, *args, **kwargs): 156 | return self.inner_model(*args, **kwargs) 157 | 158 | def loss(self, input, noise, sigma, **kwargs): 159 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 160 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 161 | model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 162 | target = (input - c_skip * noised_input) / c_out 163 | return (model_output - target).pow(2).flatten(1).mean(1) 164 | 165 | def forward(self, input, sigma, **kwargs): 166 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 167 | return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip 168 | 169 | 170 | class CompVisVDenoiser(DiscreteVDDPMDenoiser): 171 | """A wrapper for CompVis diffusion models that output v.""" 172 | 173 | def __init__(self, model, quantize=False, device='cpu'): 174 | super().__init__(model, model.alphas_cumprod, quantize=quantize) 175 | 176 | def get_v(self, x, t, cond, **kwargs): 177 | return self.inner_model.apply_model(x, t, cond) 178 | -------------------------------------------------------------------------------- /src/xtra/k_diffusion/gns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DDPGradientStatsHook: 6 | def __init__(self, ddp_module): 7 | try: 8 | ddp_module.register_comm_hook(self, self._hook_fn) 9 | except AttributeError: 10 | raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') 11 | self._clear_state() 12 | 13 | def _clear_state(self): 14 | self.bucket_sq_norms_small_batch = [] 15 | self.bucket_sq_norms_large_batch = [] 16 | 17 | @staticmethod 18 | def _hook_fn(self, bucket): 19 | buf = bucket.buffer() 20 | self.bucket_sq_norms_small_batch.append(buf.pow(2).sum()) 21 | fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() 22 | def callback(fut): 23 | buf = fut.value()[0] 24 | self.bucket_sq_norms_large_batch.append(buf.pow(2).sum()) 25 | return buf 26 | return fut.then(callback) 27 | 28 | def get_stats(self): 29 | sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch) 30 | sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch) 31 | self._clear_state() 32 | stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch]) 33 | torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG) 34 | return stats[0].item(), stats[1].item() 35 | 36 | 37 | class GradientNoiseScale: 38 | """Calculates the gradient noise scale (1 / SNR), or critical batch size, 39 | from _An Empirical Model of Large-Batch Training_, 40 | https://arxiv.org/abs/1812.06162). 41 | 42 | Args: 43 | beta (float): The decay factor for the exponential moving averages used to 44 | calculate the gradient noise scale. 45 | Default: 0.9998 46 | eps (float): Added for numerical stability. 47 | Default: 1e-8 48 | """ 49 | 50 | def __init__(self, beta=0.9998, eps=1e-8): 51 | self.beta = beta 52 | self.eps = eps 53 | self.ema_sq_norm = 0. 54 | self.ema_var = 0. 55 | self.beta_cumprod = 1. 
56 | self.gradient_noise_scale = float('nan') 57 | 58 | def state_dict(self): 59 | """Returns the state of the object as a :class:`dict`.""" 60 | return dict(self.__dict__.items()) 61 | 62 | def load_state_dict(self, state_dict): 63 | """Loads the object's state. 64 | Args: 65 | state_dict (dict): object state. Should be an object returned 66 | from a call to :meth:`state_dict`. 67 | """ 68 | self.__dict__.update(state_dict) 69 | 70 | def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): 71 | """Updates the state with a new batch's gradient statistics, and returns the 72 | current gradient noise scale. 73 | 74 | Args: 75 | sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or 76 | per sample gradients. 77 | sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or 78 | per sample gradients. 79 | n_small_batch (int): The batch size of the individual microbatch or per sample 80 | gradients (1 if per sample). 81 | n_large_batch (int): The total batch size of the mean of the microbatch or 82 | per sample gradients. 83 | """ 84 | est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch) 85 | est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch) 86 | self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm 87 | self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var 88 | self.beta_cumprod *= self.beta 89 | self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) 90 | return self.gradient_noise_scale 91 | 92 | def get_gns(self): 93 | """Returns the current gradient noise scale.""" 94 | return self.gradient_noise_scale 95 | 96 | def get_stats(self): 97 | """Returns the current (debiased) estimates of the squared mean gradient 98 | and gradient variance.""" 99 | return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) 100 | -------------------------------------------------------------------------------- /src/xtra/k_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_v1 import ImageDenoiserModelV1 2 | -------------------------------------------------------------------------------- /src/xtra/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss 3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from .tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform 9 | -------------------------------------------------------------------------------- /src/xtra/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/xtra/open_clip/factory.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | 12 | from .model import CLIP, convert_weights_to_fp16, resize_pos_embed 13 | from .openai import load_openai_model 14 | from .pretrained import get_pretrained_url, download_pretrained 15 | from .transform import image_transform 16 | 17 | 18 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 19 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 20 | 21 | 22 | def _natural_key(string_): 23 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 24 | 25 | 26 | def _rescan_model_configs(): 27 | global _MODEL_CONFIGS 28 | 29 | config_ext = ('.json',) 30 | config_files = [] 31 | for config_path in _MODEL_CONFIG_PATHS: 32 | if config_path.is_file() and config_path.suffix in config_ext: 33 | config_files.append(config_path) 34 | elif config_path.is_dir(): 35 | for ext in config_ext: 36 | config_files.extend(config_path.glob(f'*{ext}')) 37 | 38 | for cf in config_files: 39 | with open(cf, 'r') as f: 40 | model_cfg = json.load(f) 41 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 42 | _MODEL_CONFIGS[cf.stem] = model_cfg 43 | 44 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 45 | 46 | 47 | _rescan_model_configs() # initial populate of model config registry 48 | 49 | 50 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 51 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 52 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 53 | state_dict = checkpoint['state_dict'] 54 | else: 55 | state_dict = checkpoint 56 | if next(iter(state_dict.items()))[0].startswith('module'): 57 | state_dict = {k[7:]: v for k, v in state_dict.items()} 58 | return state_dict 59 | 60 | 61 | def load_checkpoint(model, checkpoint_path, strict=True): 62 | state_dict = load_state_dict(checkpoint_path) 63 | resize_pos_embed(state_dict, model) 64 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 65 | return incompatible_keys 66 | 67 | 68 | def create_model( 69 | model_name: str, 70 | pretrained: str = '', 71 | precision: str = 'fp32', 72 | device: torch.device = torch.device('cpu'), 73 | jit: bool = False, 74 | force_quick_gelu: bool = False, 75 | pretrained_image: bool = False, 76 | ): 77 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 78 | 79 | if pretrained.lower() == 'openai': 80 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 81 | model = load_openai_model(model_name, device=device, jit=jit) 82 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 83 | if precision == "amp" or precision == "fp32": 84 | model = model.float() 85 | else: 86 | if model_name in _MODEL_CONFIGS: 87 | logging.info(f'Loading {model_name} model config.') 88 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) 89 | else: 90 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.') 91 | raise RuntimeError(f'Model config for {model_name} not found.') 92 | 93 | if force_quick_gelu: 94 | # override for use of QuickGELU on non-OpenAI transformer models 95 | model_cfg["quick_gelu"] = True 96 | 97 | if 
pretrained_image: 98 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 99 | # pretrained weight loading for timm models set via vision_cfg 100 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 101 | else: 102 | assert False, 'pretrained image towers currently only supported for timm models' 103 | 104 | model = CLIP(**model_cfg) 105 | 106 | if pretrained: 107 | checkpoint_path = '' 108 | url = get_pretrained_url(model_name, pretrained) 109 | if url: 110 | checkpoint_path = download_pretrained(url) 111 | elif os.path.exists(pretrained): 112 | checkpoint_path = pretrained 113 | 114 | if checkpoint_path: 115 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 116 | load_checkpoint(model, checkpoint_path) 117 | else: 118 | logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 119 | raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 120 | 121 | model.to(device=device) 122 | if precision == "fp16": 123 | assert device.type != 'cpu' 124 | convert_weights_to_fp16(model) 125 | 126 | if jit: 127 | model = torch.jit.script(model) 128 | 129 | return model 130 | 131 | 132 | def create_model_and_transforms( 133 | model_name: str, 134 | pretrained: str = '', 135 | precision: str = 'fp32', 136 | device: torch.device = torch.device('cpu'), 137 | jit: bool = False, 138 | force_quick_gelu: bool = False, 139 | pretrained_image: bool = False, 140 | mean: Optional[Tuple[float, ...]] = None, 141 | std: Optional[Tuple[float, ...]] = None, 142 | ): 143 | model = create_model( 144 | model_name, pretrained, precision, device, jit, 145 | force_quick_gelu=force_quick_gelu, 146 | pretrained_image=pretrained_image) 147 | preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std) 148 | preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std) 149 | return model, preprocess_train, preprocess_val 150 | 151 | 152 | def list_models(): 153 | """ enumerate available model architectures based on config files """ 154 | return list(_MODEL_CONFIGS.keys()) 155 | 156 | 157 | def add_model_config(path): 158 | """ add model config path or file and update registry """ 159 | if not isinstance(path, Path): 160 | path = Path(path) 161 | _MODEL_CONFIG_PATHS.append(path) 162 | _rescan_model_configs() 163 | -------------------------------------------------------------------------------- /src/xtra/open_clip/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | try: 6 | import torch.distributed.nn 7 | from torch import distributed as dist 8 | has_distributed = True 9 | except ImportError: 10 | has_distributed = False 11 | 12 | try: 13 | import horovod.torch as hvd 14 | except ImportError: 15 | hvd = None 16 | 17 | 18 | def gather_features( 19 | image_features, 20 | text_features, 21 | local_loss=False, 22 | gather_with_grad=False, 23 | rank=0, 24 | world_size=1, 25 | use_horovod=False 26 | ): 27 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
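    # Two gather paths follow: Horovod allgather vs. torch.distributed all_gather. In the non-differentiable branches (gather_with_grad=False) the local rank's grad-carrying features are spliced back in when local_loss is False, so gradients still reach the local batch.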
28 | if use_horovod: 29 | assert hvd is not None, 'Please install horovod' 30 | if gather_with_grad: 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | else: 34 | with torch.no_grad(): 35 | all_image_features = hvd.allgather(image_features) 36 | all_text_features = hvd.allgather(text_features) 37 | if not local_loss: 38 | # ensure grads for local rank when all_* features don't have a gradient 39 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 40 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 41 | gathered_image_features[rank] = image_features 42 | gathered_text_features[rank] = text_features 43 | all_image_features = torch.cat(gathered_image_features, dim=0) 44 | all_text_features = torch.cat(gathered_text_features, dim=0) 45 | else: 46 | # We gather tensors from all gpus 47 | if gather_with_grad: 48 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 49 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 50 | else: 51 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 52 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 53 | dist.all_gather(gathered_image_features, image_features) 54 | dist.all_gather(gathered_text_features, text_features) 55 | if not local_loss: 56 | # ensure grads for local rank when all_* features don't have a gradient 57 | gathered_image_features[rank] = image_features 58 | gathered_text_features[rank] = text_features 59 | all_image_features = torch.cat(gathered_image_features, dim=0) 60 | all_text_features = torch.cat(gathered_text_features, dim=0) 61 | 62 | return all_image_features, all_text_features 63 | 64 | 65 | class ClipLoss(nn.Module): 66 | 67 | def __init__( 68 | self, 69 | local_loss=False, 70 | gather_with_grad=False, 71 | cache_labels=False, 72 | rank=0, 73 | world_size=1, 74 | use_horovod=False, 75 | ): 76 | super().__init__() 77 | self.local_loss = local_loss 78 | self.gather_with_grad = gather_with_grad 79 | self.cache_labels = cache_labels 80 | self.rank = rank 81 | self.world_size = world_size 82 | self.use_horovod = use_horovod 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, all_text_features = gather_features( 92 | image_features, text_features, 93 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 94 | 95 | if self.local_loss: 96 | logits_per_image = logit_scale * image_features @ all_text_features.T 97 | logits_per_text = logit_scale * text_features @ all_image_features.T 98 | else: 99 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 100 | logits_per_text = logits_per_image.T 101 | else: 102 | logits_per_image = logit_scale * image_features @ text_features.T 103 | logits_per_text = logit_scale * text_features @ image_features.T 104 | 105 | # calculated ground-truth and cache if enabled 106 | num_logits = logits_per_image.shape[0] 107 | if self.prev_num_logits != num_logits or device not in self.labels: 108 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 109 | if self.world_size > 1 and self.local_loss: 110 | labels = labels + num_logits * self.rank 111 | if self.cache_labels: 112 | self.labels[device] = 
labels 113 | self.prev_num_logits = num_logits 114 | else: 115 | labels = self.labels[device] 116 | 117 | total_loss = ( 118 | F.cross_entropy(logits_per_image, labels) + 119 | F.cross_entropy(logits_per_text, labels) 120 | ) / 2 121 | return total_loss 122 | -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- 
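The JSON files under model_configs/ are the registry that factory.py scans at import time: _rescan_model_configs() keys each config by its file stem, and create_model() resolves a requested architecture name against that registry whenever the 'openai' pretrained path is not used. A minimal usage sketch, assuming the package imports as open_clip exactly as test_simple.py (further below) does; weights stay random unless a pretrained tag such as its laion400m_e32 is passed:

import torch
import open_clip

# Architecture names are the config file stems, e.g. 'ViT-B-32-quickgelu'.
print(open_clip.list_models())

# Build the model from model_configs/ViT-B-32-quickgelu.json, plus train/val transforms.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('ViT-B-32-quickgelu')

with torch.no_grad():
    feats = model.encode_image(torch.zeros(1, 3, 224, 224))  # 224 = image_size in the config
print(feats.shape)  # torch.Size([1, 512]); 512 = embed_dim in the config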
/src/xtra/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-H-16.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-efficientnetv2_rw_s.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "efficientnetv2_rw_s", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 288 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-resnet50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnet50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-resnetaa50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetaa50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-resnetblur50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetblur50", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-vit_base_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-vit_base_patch32_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | 
"vision_cfg": { 4 | "timm_model_name": "vit_base_patch32_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/model_configs/timm-vit_small_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_small_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/xtra/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 26 | jit=True, 27 | ): 28 | """Load a CLIP model 29 | 30 | Parameters 31 | ---------- 32 | name : str 33 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 34 | device : Union[str, torch.device] 35 | The device to put the loaded model 36 | jit : bool 37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 38 | 39 | Returns 40 | ------- 41 | model : torch.nn.Module 42 | The CLIP model 43 | preprocess : Callable[[PIL.Image], torch.Tensor] 44 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 45 | """ 46 | if get_pretrained_url(name, 'openai'): 47 | model_path = download_pretrained(get_pretrained_url(name, 'openai')) 48 | elif os.path.isfile(name): 49 | model_path = name 50 | else: 51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 52 | 53 | try: 54 | # loading JIT archive 55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 56 | state_dict = None 57 | except RuntimeError: 58 | # loading saved state dict 59 | if jit: 60 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 61 | jit = False 62 | state_dict = torch.load(model_path, map_location="cpu") 63 | 64 | if not jit: 65 | try: 66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device) 67 | except KeyError: 68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 69 | model = build_model_from_openai_state_dict(sd).to(device) 70 | 71 | if str(device) == "cpu": 72 | model.float() 73 | return model 74 | 75 | # patch the device names 76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 78 | 79 | def patch_device(module): 80 | try: 81 | graphs = [module.graph] if hasattr(module, "graph") else [] 82 | except RuntimeError: 83 | graphs = [] 84 | 85 | if hasattr(module, "forward1"): 86 | graphs.append(module.forward1.graph) 87 | 88 | for graph in graphs: 89 | for node in graph.findAllNodes("prim::Constant"): 90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 91 | node.copyAttributes(device_node) 92 | 93 | model.apply(patch_device) 94 | patch_device(model.encode_image) 95 | patch_device(model.encode_text) 96 | 97 | # patch dtype to float32 on CPU 98 | if str(device) == "cpu": 99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 101 | float_node = float_input.node() 102 | 103 | def patch_float(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("aten::to"): 114 | inputs = list(node.inputs()) 115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 116 | if inputs[i].node()["value"] == 5: 117 | inputs[i].node().copyAttributes(float_node) 118 | 119 | model.apply(patch_float) 120 | patch_float(model.encode_image) 121 | patch_float(model.encode_text) 122 | model.float() 123 | 124 | # ensure image_size attr available at consistent location for both jit and non-jit 125 | model.visual.image_size = model.input_resolution.item() 126 | return model 127 | -------------------------------------------------------------------------------- /src/xtra/open_clip/src.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/open_clip/src.zip -------------------------------------------------------------------------------- /src/xtra/open_clip/test_simple.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from PIL import Image 4 | from open_clip import tokenizer 5 | import open_clip 6 | import os 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 8 | 9 | def test_inference(): 10 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32') 11 | 12 | current_dir = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) 15 | text = tokenizer.tokenize(["a diagram", "a dog", "a cat"]) 16 | 17 | with torch.no_grad(): 18 | image_features = model.encode_image(image) 19 | text_features = model.encode_text(text) 20 | 21 
| text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 22 | 23 | assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] -------------------------------------------------------------------------------- /src/xtra/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.' 
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /src/xtra/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns list of utf-8 byte and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a signficant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 
46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'</w>' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | if not special_tokens: 80 | special_tokens = ['<start_of_text>', '<end_of_text>'] 81 | else: 82 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 83 | vocab.extend(special_tokens) 84 | self.encoder = dict(zip(vocab, range(len(vocab)))) 85 | self.decoder = {v: k for k, v in self.encoder.items()} 86 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 87 | self.cache = {t:t for t in special_tokens} 88 | special = "|".join(special_tokens) 89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 90 | 91 | self.vocab_size = len(self.encoder) 92 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 93 | 94 | def bpe(self, token): 95 | if token in self.cache: 96 | return self.cache[token] 97 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 98 | pairs = get_pairs(word) 99 | 100 | if not pairs: 101 | return token+'</w>' 102 | 103 | while True: 104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 105 | if bigram not in self.bpe_ranks: 106 | break 107 | first, second = bigram 108 | new_word = [] 109 | i = 0 110 | while i < len(word): 111 | try: 112 | j = word.index(first, i) 113 | new_word.extend(word[i:j]) 114 | i = j 115 | except: 116 | new_word.extend(word[i:]) 117 | break 118 | 119 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 120 | new_word.append(first+second) 121 | i += 2 122 | else: 123 | new_word.append(word[i]) 124 | i += 1 125 | new_word = tuple(new_word) 126 | word = new_word 127 | if len(word) == 1: 128 | break 129 | else: 130 | pairs = get_pairs(word) 131 | word = ' '.join(word) 132 | self.cache[token] = word 133 | return word 134 | 135 | def encode(self, text): 136 | bpe_tokens = [] 137 | text = whitespace_clean(basic_clean(text)).lower() 138 | for token in re.findall(self.pat, text): 139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 146 | return text 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | 151 | 152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 153 | """ 154 | Returns the tokenized representation of given input string(s) 155 | 156 | Parameters
157 | ---------- 158 | texts : Union[str, List[str]] 159 | An input string or a list of input strings to tokenize 160 | context_length : int 161 | The context length to use; all CLIP models use 77 as the context length 162 | 163 | Returns 164 | ------- 165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 166 | """ 167 | if isinstance(texts, str): 168 | texts = [texts] 169 | 170 | sot_token = _tokenizer.encoder["<start_of_text>"] 171 | eot_token = _tokenizer.encoder["<end_of_text>"] 172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 174 | 175 | for i, tokens in enumerate(all_tokens): 176 | if len(tokens) > context_length: 177 | tokens = tokens[:context_length] # Truncate 178 | tokens[-1] = eot_token 179 | result[i, :len(tokens)] = torch.tensor(tokens) 180 | 181 | return result 182 | -------------------------------------------------------------------------------- /src/xtra/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | 8 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 9 | CenterCrop 10 | 11 | 12 | class ResizeMaxSize(nn.Module): 13 | 14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 15 | super().__init__() 16 | if not isinstance(max_size, int): 17 | raise TypeError(f"Size should be int. Got {type(max_size)}") 18 | self.max_size = max_size 19 | self.interpolation = interpolation 20 | self.fn = min if fn == 'min' else min 21 | self.fill = fill 22 | 23 | def forward(self, img): 24 | if isinstance(img, torch.Tensor): 25 | height, width = img.shape[:2] 26 | else: 27 | width, height = img.size 28 | scale = self.max_size / float(max(height, width)) 29 | if scale != 1.0: 30 | new_size = tuple(round(dim * scale) for dim in (height, width)) 31 | img = F.resize(img, new_size, self.interpolation) 32 | pad_h = self.max_size - new_size[0] 33 | pad_w = self.max_size - new_size[1] 34 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 35 | return img 36 | 37 | 38 | def _convert_to_rgb(image): 39 | return image.convert('RGB') 40 | 41 | 42 | def image_transform( 43 | image_size: int, 44 | is_train: bool, 45 | mean: Optional[Tuple[float, ...]] = None, 46 | std: Optional[Tuple[float, ...]] = None, 47 | resize_longest_max: bool = False, 48 | fill_color: int = 0, 49 | ): 50 | mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean 51 | std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std 52 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 53 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 54 | image_size = image_size[0] 55 | 56 | normalize = Normalize(mean=mean, std=std) 57 | if is_train: 58 | return Compose([ 59 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 60 | _convert_to_rgb, 61 | ToTensor(), 62 | normalize, 63 | ]) 64 | else: 65 | if resize_longest_max: 66 | transforms = [ 67 | ResizeMaxSize(image_size, fill=fill_color) 68 | ] 69 | else: 70 | transforms = [ 71 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 72 |
CenterCrop(image_size), 73 | ] 74 | transforms.extend([ 75 | _convert_to_rgb, 76 | ToTensor(), 77 | normalize, 78 | ]) 79 | return Compose(transforms) 80 | -------------------------------------------------------------------------------- /src/xtra/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /src/xtra/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.0' 2 | -------------------------------------------------------------------------------- /src/xtra/taming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/data.zip -------------------------------------------------------------------------------- /src/xtra/taming/lr_scheduler.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /src/xtra/taming/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/models/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/modules/discriminator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/discriminator/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = 
m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | 
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = 
vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | 
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log 137 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/misc/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/transformer/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, 
logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 
77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /src/xtra/taming/modules/vqvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/vqvae/__init__.py -------------------------------------------------------------------------------- /src/xtra/taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not 
found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /src/yaml/v1-finetune-custom.yaml: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion. 2 | # Adobe’s modifications are licensed under the Adobe Research License. 
3 | 4 | model: 5 | base_learning_rate: 1.0e-05 6 | target: custom.model.CustomDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: "image" 14 | cond_stage_key: "caption" 15 | image_size: 64 16 | channels: 4 17 | cond_stage_trainable: true # Note: different from the one we trained before 18 | add_token: True 19 | freeze_model: "crossattn-kv" 20 | conditioning_key: crossattn 21 | monitor: val/loss_simple_ema 22 | scale_factor: 0.18215 23 | use_ema: False 24 | 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 # unused 29 | in_channels: 4 30 | out_channels: 4 31 | model_channels: 320 32 | attention_resolutions: [ 4, 2, 1 ] 33 | num_res_blocks: 2 34 | channel_mult: [ 1, 2, 4, 4 ] 35 | num_heads: 8 36 | use_spatial_transformer: True 37 | transformer_depth: 1 38 | context_dim: 768 39 | use_checkpoint: False 40 | legacy: False 41 | 42 | first_stage_config: 43 | target: ldm.models.autoencoder.AutoencoderKL 44 | params: 45 | embed_dim: 4 46 | monitor: val/rec_loss 47 | ddconfig: 48 | double_z: true 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 4 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: [] 61 | dropout: 0.0 62 | lossconfig: 63 | target: torch.nn.Identity 64 | 65 | cond_stage_config: 66 | target: custom.modules.FrozenCLIPEmbedderWrapper 67 | params: 68 | modifier_token: 69 | 70 | data: 71 | target: train.DataModuleFromConfig 72 | params: 73 | batch_size: 1 74 | num_workers: 2 75 | wrap: false 76 | train: 77 | target: custom.finetune_data.MaskBase 78 | params: 79 | size: 512 80 | train2: 81 | target: custom.finetune_data.MaskBase 82 | params: 83 | size: 512 84 | 85 | lightning: 86 | modelcheckpoint: # from usual finetune 87 | params: 88 | verbose: false 89 | save_last: true 90 | callbacks: 91 | image_logger: 92 | target: train.ImageLogger 93 | params: 94 | batch_frequency: 500 95 | max_images: 4 96 | clamp: True 97 | increase_log_steps: false 98 | 99 | trainer: 100 | max_steps: 1000 # for gpu=1 batch=1 [orig was 300] 101 | find_unused_parameters: False 102 | -------------------------------------------------------------------------------- /src/yaml/v1-finetune.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-03 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: true # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | 21 | personalization_config: 22 | target: ldm.modules.embedding_manager.EmbeddingManager 23 | params: 24 | placeholder_strings: ["*"] 25 | initializer_words: ["sculpture"] 26 | per_image_tokens: false 27 | num_vectors_per_token: 1 28 | progressive_words: False 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | image_size: 32 # unused 34 | in_channels: 4 35 | out_channels: 4 36 | model_channels: 320 37 | attention_resolutions: [ 4, 2, 1 ] 38 | num_res_blocks: 2 39 | channel_mult: [ 1, 2, 4, 4 ] 40 | 
num_heads: 8 41 | use_spatial_transformer: True 42 | transformer_depth: 1 43 | context_dim: 768 44 | use_checkpoint: True 45 | legacy: False 46 | 47 | first_stage_config: 48 | target: ldm.models.autoencoder.AutoencoderKL 49 | params: 50 | embed_dim: 4 51 | monitor: val/rec_loss 52 | ddconfig: 53 | double_z: true 54 | z_channels: 4 55 | resolution: 256 56 | in_channels: 3 57 | out_ch: 3 58 | ch: 128 59 | ch_mult: 60 | - 1 61 | - 2 62 | - 4 63 | - 4 64 | num_res_blocks: 2 65 | attn_resolutions: [] 66 | dropout: 0.0 67 | lossconfig: 68 | target: torch.nn.Identity 69 | 70 | cond_stage_config: 71 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 72 | 73 | data: 74 | target: train.DataModuleFromConfig 75 | params: 76 | batch_size: 1 77 | num_workers: 4 78 | wrap: false 79 | train: 80 | target: ldm.data.personalized.PersonalizedBase 81 | params: 82 | size: 512 83 | set: train 84 | per_image_tokens: false 85 | repeats: 10 # 100 86 | validation: 87 | target: ldm.data.personalized.PersonalizedBase 88 | params: 89 | size: 512 90 | set: val 91 | per_image_tokens: false 92 | # repeats: 10 93 | 94 | lightning: 95 | modelcheckpoint: 96 | params: 97 | every_n_train_steps: 500 98 | verbose: false 99 | callbacks: 100 | image_logger: 101 | target: train.ImageLogger 102 | params: 103 | batch_frequency: 500 104 | max_images: 2 105 | increase_log_steps: False 106 | 107 | trainer: 108 | benchmark: True 109 | max_steps: 5000 110 | find_unused_parameters: False 111 | -------------------------------------------------------------------------------- /src/yaml/v1-finetune_style.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-03 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: true # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | 21 | personalization_config: 22 | target: ldm.modules.embedding_manager.EmbeddingManager 23 | params: 24 | placeholder_strings: ["*"] 25 | initializer_words: ["painting"] 26 | per_image_tokens: false 27 | num_vectors_per_token: 1 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | 72 | data: 73 | target: main.DataModuleFromConfig 74 | params: 75 | batch_size: 2 76 | num_workers: 16 77 | wrap: false 78 | train: 79 | target: 
ldm.data.personalized_style.PersonalizedBase 80 | params: 81 | size: 512 82 | set: train 83 | per_image_tokens: false 84 | repeats: 100 85 | validation: 86 | target: ldm.data.personalized_style.PersonalizedBase 87 | params: 88 | size: 512 89 | set: val 90 | per_image_tokens: false 91 | repeats: 10 92 | 93 | lightning: 94 | callbacks: 95 | image_logger: 96 | target: main.ImageLogger 97 | params: 98 | batch_frequency: 500 99 | max_images: 8 100 | increase_log_steps: False 101 | 102 | trainer: 103 | benchmark: True -------------------------------------------------------------------------------- /src/yaml/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | personalization_config: 30 | target: ldm.modules.embedding_manager.EmbeddingManager 31 | params: 32 | placeholder_strings: ["*"] 33 | initializer_words: [""] 34 | per_image_tokens: false 35 | num_vectors_per_token: 1 36 | progressive_words: False 37 | 38 | unet_config: 39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 40 | params: 41 | image_size: 32 # unused 42 | in_channels: 4 43 | out_channels: 4 44 | model_channels: 320 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | first_stage_config: 56 | target: ldm.models.autoencoder.AutoencoderKL 57 | params: 58 | embed_dim: 4 59 | monitor: val/rec_loss 60 | ddconfig: 61 | double_z: true 62 | z_channels: 4 63 | resolution: 256 64 | in_channels: 3 65 | out_ch: 3 66 | ch: 128 67 | ch_mult: 68 | - 1 69 | - 2 70 | - 4 71 | - 4 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | cond_stage_config: 79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 80 | -------------------------------------------------------------------------------- /src/yaml/v1-inpainting.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | 
params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | personalization_config: 30 | target: ldm.modules.embedding_manager.EmbeddingManager 31 | params: 32 | placeholder_strings: ["*"] 33 | initializer_words: [""] 34 | per_image_tokens: false 35 | num_vectors_per_token: 1 36 | progressive_words: False 37 | 38 | unet_config: 39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 40 | params: 41 | image_size: 32 # unused 42 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 43 | out_channels: 4 44 | model_channels: 320 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | first_stage_config: 56 | target: ldm.models.autoencoder.AutoencoderKL 57 | params: 58 | embed_dim: 4 59 | monitor: val/rec_loss 60 | ddconfig: 61 | double_z: true 62 | z_channels: 4 63 | resolution: 256 64 | in_channels: 3 65 | out_ch: 3 66 | ch: 128 67 | ch_mult: 68 | - 1 69 | - 2 70 | - 4 71 | - 4 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | cond_stage_config: 79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 80 | -------------------------------------------------------------------------------- /src/yaml/v2-inference-v.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "jpg" 12 | cond_stage_key: "txt" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false # Note: different from the one we trained before 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | use_ema: False 20 | 21 | scheduler_config: # 10000 warmup steps 22 | target: ldm.lr_scheduler.LambdaLinearScheduler 23 | params: 24 | warm_up_steps: [ 10000 ] 25 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 26 | f_start: [ 1.e-6 ] 27 | f_max: [ 1. ] 28 | f_min: [ 1. 
] 29 | 30 | personalization_config: 31 | target: ldm.modules.embedding_manager.EmbeddingManager 32 | params: 33 | placeholder_strings: ["*"] 34 | initializer_words: [""] 35 | per_image_tokens: false 36 | num_vectors_per_token: 1 37 | progressive_words: False 38 | 39 | unet_config: 40 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 41 | params: 42 | use_checkpoint: True 43 | use_fp16: True 44 | image_size: 32 # unused 45 | in_channels: 4 46 | out_channels: 4 47 | model_channels: 320 48 | attention_resolutions: [ 4, 2, 1 ] 49 | num_res_blocks: 2 50 | channel_mult: [ 1, 2, 4, 4 ] 51 | num_head_channels: 64 # need to fix for flash-attn 52 | use_spatial_transformer: True 53 | use_linear_in_transformer: True 54 | transformer_depth: 1 55 | context_dim: 1024 56 | legacy: False 57 | 58 | first_stage_config: 59 | target: ldm.models.autoencoder.AutoencoderKL 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | #attn_type: "vanilla-xformers" 65 | double_z: true 66 | z_channels: 4 67 | resolution: 256 68 | in_channels: 3 69 | out_ch: 3 70 | ch: 128 71 | ch_mult: 72 | - 1 73 | - 2 74 | - 4 75 | - 4 76 | num_res_blocks: 2 77 | attn_resolutions: [] 78 | dropout: 0.0 79 | lossconfig: 80 | target: torch.nn.Identity 81 | 82 | cond_stage_config: 83 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 84 | params: 85 | freeze: True 86 | layer: "penultimate" 87 | 88 | -------------------------------------------------------------------------------- /src/yaml/v2-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | personalization_config: 30 | target: ldm.modules.embedding_manager.EmbeddingManager 31 | params: 32 | placeholder_strings: ["*"] 33 | initializer_words: [""] 34 | per_image_tokens: false 35 | num_vectors_per_token: 1 36 | progressive_words: False 37 | 38 | unet_config: 39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 40 | params: 41 | use_checkpoint: True 42 | use_fp16: True 43 | image_size: 32 # unused 44 | in_channels: 4 45 | out_channels: 4 46 | model_channels: 320 47 | attention_resolutions: [ 4, 2, 1 ] 48 | num_res_blocks: 2 49 | channel_mult: [ 1, 2, 4, 4 ] 50 | num_head_channels: 64 # need to fix for flash-attn 51 | use_spatial_transformer: True 52 | use_linear_in_transformer: True 53 | transformer_depth: 1 54 | context_dim: 1024 55 | legacy: False 56 | 57 | first_stage_config: 58 | target: ldm.models.autoencoder.AutoencoderKL 59 | params: 60 | embed_dim: 4 61 | monitor: val/rec_loss 62 | ddconfig: 63 | #attn_type: "vanilla-xformers" 64 | double_z: true 65 | z_channels: 4 66 | resolution: 256 67 | in_channels: 3 68 | out_ch: 3 69 | ch: 128 70 | ch_mult: 71 | - 1 72 | - 2 73 | - 4 74 | - 4 75 | num_res_blocks: 2 76 | attn_resolutions: [] 77 | dropout: 0.0 78 | lossconfig: 79 | target: torch.nn.Identity 80 | 81 | cond_stage_config: 82 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 83 | params: 84 | freeze: True 85 | layer: "penultimate" 86 | 87 | -------------------------------------------------------------------------------- /src/yaml/v2-inpainting.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | personalization_config: 30 | target: ldm.modules.embedding_manager.EmbeddingManager 31 | params: 32 | placeholder_strings: ["*"] 33 | initializer_words: [""] 34 | per_image_tokens: false 35 | num_vectors_per_token: 1 36 | progressive_words: False 37 | 38 | unet_config: 39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 40 | params: 41 | use_checkpoint: True 42 | image_size: 32 # unused 43 | in_channels: 9 44 | out_channels: 4 45 | model_channels: 320 46 | attention_resolutions: [ 4, 2, 1 ] 47 | num_res_blocks: 2 48 | channel_mult: [ 1, 2, 4, 4 ] 49 | num_head_channels: 64 # need to fix for flash-attn 50 | use_spatial_transformer: True 51 | use_linear_in_transformer: True 52 | transformer_depth: 1 53 | context_dim: 1024 54 | legacy: False 55 | 56 | first_stage_config: 57 | target: ldm.models.autoencoder.AutoencoderKL 58 | params: 59 | embed_dim: 4 60 | monitor: val/rec_loss 61 | ddconfig: 62 | #attn_type: "vanilla-xformers" 63 | double_z: true 64 | z_channels: 4 65 | resolution: 256 66 | in_channels: 3 67 | out_ch: 3 68 | ch: 128 69 | ch_mult: 70 | - 1 71 | - 2 72 | - 4 73 | - 4 74 | num_res_blocks: 2 75 | attn_resolutions: [ ] 76 | dropout: 0.0 77 | lossconfig: 78 | target: torch.nn.Identity 79 | 80 | cond_stage_config: 81 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 82 | params: 83 | freeze: True 84 | layer: "penultimate" 85 | -------------------------------------------------------------------------------- /src/yaml/v2-midas.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-07 3 | target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | scale_factor: 0.18215 17 | monitor: val/loss_simple_ema 18 | finetune_keys: null 19 | use_ema: False 20 | 21 | depth_stage_config: 22 | target: ldm.modules.midas.api.MiDaSInference 23 | params: 24 | model_type: "dpt_hybrid" 25 | 26 | personalization_config: 27 | target: ldm.modules.embedding_manager.EmbeddingManager 28 | params: 29 | placeholder_strings: ["*"] 30 | initializer_words: [""] 31 | per_image_tokens: false 32 | num_vectors_per_token: 1 33 | progressive_words: False 34 | 35 | unet_config: 36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 37 | params: 38 | use_checkpoint: True 39 | image_size: 32 # unused 40 | in_channels: 5 41 | out_channels: 4 42 | model_channels: 320 43 | attention_resolutions: [ 4, 2, 1 ] 44 | num_res_blocks: 2 45 | channel_mult: [ 1, 2, 4, 4 ] 46 | num_head_channels: 64 # need to fix for flash-attn 47 | use_spatial_transformer: True 48 | use_linear_in_transformer: True 49 | transformer_depth: 1 50 | context_dim: 1024 51 | legacy: False 52 | 53 | first_stage_config: 54 | target: ldm.models.autoencoder.AutoencoderKL 55 | params: 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | #attn_type: "vanilla-xformers" 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [ ] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: 78 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 79 | params: 80 | freeze: 
True 81 | layer: "penultimate" 82 | 83 | 84 | -------------------------------------------------------------------------------- /train.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | if [%1]==[] goto help 3 | echo .. %1 4 | 5 | python src/train.py --token "%1" --term "%2" --data data/%1 ^ 6 | --reg_data data/%2 ^ 7 | %3 %4 %5 %6 %7 %8 %9 8 | 9 | goto end 10 | 11 | :help 12 | echo Usage: train "<token>" category 13 | echo e.g.: train "<token>" lady 14 | echo or: train "<token>" pattern --style 15 | :end 16 | -------------------------------------------------------------------------------- /txt.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set KMP_DUPLICATE_LIB_OK=TRUE 3 | if [%1]==[] goto help 4 | echo .. %1 5 | 6 | python src/_sdrun.py -v -t %1 ^ 7 | %2 %3 %4 %5 %6 %7 %8 %9 8 | 9 | goto end 10 | 11 | :help 12 | echo Usage: txt "text prompt" [...] 13 | echo or: txt textfile [...] 14 | :end 15 | -------------------------------------------------------------------------------- /walk.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set KMP_DUPLICATE_LIB_OK=TRUE 3 | if [%1]==[] goto help 4 | echo .. %1 5 | 6 | python src/latwalk.py -v -t %1 ^ 7 | %2 %3 %4 %5 %6 %7 %8 %9 8 | 9 | ffmpeg -v warning -y -i _out\%~n1\%%06d.jpg _out\%~n1-%2%3%4%5%6%7%8%9.mp4 10 | goto end 11 | 12 | :help 13 | echo Usage: walk textfile [...] 14 | :end 15 | --------------------------------------------------------------------------------
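A minimal usage sketch of the LPIPS perceptual metric defined in src/xtra/taming/modules/losses/lpips.py, assuming the taming package is importable (e.g. src/xtra on the Python path) and the pretrained VGG weights can be downloaded on first use; the tensor names and shapes below are illustrative only, not part of the repo:

import torch
from taming.modules.losses.lpips import LPIPS

# two batches of RGB images scaled to [-1, 1], shape (N, 3, H, W)
img_a = torch.rand(1, 3, 256, 256) * 2 - 1
img_b = torch.rand(1, 3, 256, 256) * 2 - 1

lpips = LPIPS().eval()  # fetches vgg.pth via taming.util.get_ckpt_path on first call
with torch.no_grad():
    dist = lpips(img_a, img_b)  # shape (N, 1, 1, 1); larger value = more perceptually different
print(dist.flatten().item())

This mirrors how VQLPIPSWithDiscriminator in vqperceptual.py uses the class internally: it builds LPIPS().eval() once and adds self.perceptual_loss(inputs, reconstructions) to the pixelwise reconstruction loss, weighted by perceptual_weight.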