├── .gitignore
├── LICENSE
├── README.md
├── _in
│   ├── pix
│   │   ├── 6458524847_2f4c361183_k.jpg
│   │   ├── 8399166846_f6fb4e4b8e_k.jpg
│   │   ├── alex-iby-G_Pk4D9rMLs.jpg
│   │   ├── bench2.jpg
│   │   ├── bertrand-gabioud-CpuFzIsHYJ0.jpg
│   │   ├── billow926-12-Wc-Zgx6Y.jpg
│   │   ├── mask
│   │   │   ├── 6458524847_2f4c361183_k_mask.jpg
│   │   │   ├── 8399166846_f6fb4e4b8e_k_mask.jpg
│   │   │   ├── alex-iby-G_Pk4D9rMLs_mask.jpg
│   │   │   ├── bench2_mask.jpg
│   │   │   ├── bertrand-gabioud-CpuFzIsHYJ0_mask.jpg
│   │   │   ├── billow926-12-Wc-Zgx6Y_mask.jpg
│   │   │   ├── overture-creations-5sI6fQgYIuo_mask.jpg
│   │   │   └── photo-1583445095369-9c651e7e5d34_mask.jpg
│   │   ├── overture-creations-5sI6fQgYIuo.jpg
│   │   └── photo-1583445095369-9c651e7e5d34.jpg
│   └── something.jpg
├── download.py
├── img.bat
├── inpaint.bat
├── model_half.py
├── requirements.txt
├── src
│   ├── _sdrun.py
│   ├── custom
│   │   ├── __init__.py
│   │   ├── composenW.py
│   │   ├── compress.py
│   │   ├── convert.py
│   │   ├── finetune_data.py
│   │   ├── get_deltas.py
│   │   ├── model.py
│   │   └── modules.py
│   ├── latwalk.py
│   ├── ldm
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── personalized.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion
│   │   │       ├── __init__.py
│   │   │       ├── ddim.py
│   │   │       └── ddpm.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── embedding_manager.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   └── midas
│   │   │       ├── __init__.py
│   │   │       ├── api.py
│   │   │       ├── midas
│   │   │       │   ├── __init__.py
│   │   │       │   ├── base_model.py
│   │   │       │   ├── blocks.py
│   │   │       │   ├── dpt_depth.py
│   │   │       │   ├── midas_net.py
│   │   │       │   ├── midas_net_custom.py
│   │   │       │   ├── transforms.py
│   │   │       │   └── vit.py
│   │   │       └── utils.py
│   │   └── util.py
│   ├── sampling.py
│   ├── train.py
│   ├── txt2mask.py
│   ├── utils.py
│   ├── xtra
│   │   ├── clipseg
│   │   │   ├── clipseg.py
│   │   │   └── vitseg.py
│   │   ├── k_diffusion
│   │   │   ├── __init__.py
│   │   │   ├── augmentation.py
│   │   │   ├── config.py
│   │   │   ├── evaluation.py
│   │   │   ├── external.py
│   │   │   ├── gns.py
│   │   │   ├── layers.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   └── image_v1.py
│   │   │   ├── sampling.py
│   │   │   └── utils.py
│   │   ├── open_clip
│   │   │   ├── __init__.py
│   │   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   │   ├── factory.py
│   │   │   ├── loss.py
│   │   │   ├── model.py
│   │   │   ├── model_configs
│   │   │   │   ├── RN101-quickgelu.json
│   │   │   │   ├── RN101.json
│   │   │   │   ├── RN50-quickgelu.json
│   │   │   │   ├── RN50.json
│   │   │   │   ├── RN50x16.json
│   │   │   │   ├── RN50x4.json
│   │   │   │   ├── ViT-B-16-plus-240.json
│   │   │   │   ├── ViT-B-16-plus.json
│   │   │   │   ├── ViT-B-16.json
│   │   │   │   ├── ViT-B-32-plus-256.json
│   │   │   │   ├── ViT-B-32-quickgelu.json
│   │   │   │   ├── ViT-B-32.json
│   │   │   │   ├── ViT-H-14.json
│   │   │   │   ├── ViT-H-16.json
│   │   │   │   ├── ViT-L-14-280.json
│   │   │   │   ├── ViT-L-14-336.json
│   │   │   │   ├── ViT-L-14.json
│   │   │   │   ├── ViT-L-16-320.json
│   │   │   │   ├── ViT-L-16.json
│   │   │   │   ├── ViT-g-14.json
│   │   │   │   ├── timm-efficientnetv2_rw_s.json
│   │   │   │   ├── timm-resnet50d.json
│   │   │   │   ├── timm-resnetaa50d.json
│   │   │   │   ├── timm-resnetblur50.json
│   │   │   │   ├── timm-swin_base_patch4_window7_224.json
│   │   │   │   ├── timm-vit_base_patch16_224.json
│   │   │   │   ├── timm-vit_base_patch32_224.json
│   │   │   │   └── timm-vit_small_patch16_224.json
│   │   │   ├── openai.py
│   │   │   ├── pretrained.py
│   │   │   ├── src.zip
│   │   │   ├── test_simple.py
│   │   │   ├── timm_model.py
│   │   │   ├── tokenizer.py
│   │   │   ├── transform.py
│   │   │   ├── utils.py
│   │   │   └── version.py
│   │   └── taming
│   │       ├── __init__.py
│   │       ├── data.zip
│   │       ├── lr_scheduler.py
│   │       ├── models
│   │       │   ├── __init__.py
│   │       │   ├── cond_transformer.py
│   │       │   ├── dummy_cond_stage.py
│   │       │   └── vqgan.py
│   │       ├── modules
│   │       │   ├── __init__.py
│   │       │   ├── diffusionmodules
│   │       │   │   ├── __init__.py
│   │       │   │   └── model.py
│   │       │   ├── discriminator
│   │       │   │   ├── __init__.py
│   │       │   │   └── model.py
│   │       │   ├── losses
│   │       │   │   ├── __init__.py
│   │       │   │   ├── lpips.py
│   │       │   │   ├── segmentation.py
│   │       │   │   └── vqperceptual.py
│   │       │   ├── misc
│   │       │   │   ├── __init__.py
│   │       │   │   └── coord.py
│   │       │   ├── transformer
│   │       │   │   ├── __init__.py
│   │       │   │   ├── mingpt.py
│   │       │   │   └── permuter.py
│   │       │   ├── util.py
│   │       │   └── vqvae
│   │       │       ├── __init__.py
│   │       │       └── quantize.py
│   │       └── util.py
│   └── yaml
│       ├── v1-finetune-custom.yaml
│       ├── v1-finetune.yaml
│       ├── v1-finetune_style.yaml
│       ├── v1-inference.yaml
│       ├── v1-inpainting.yaml
│       ├── v2-inference-v.yaml
│       ├── v2-inference.yaml
│       ├── v2-inpainting.yaml
│       └── v2-midas.yaml
├── train.bat
├── txt.bat
└── walk.bat
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | pip-wheel-metadata/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94 | __pypackages__/
95 |
96 | # Celery stuff
97 | celerybeat-schedule
98 | celerybeat.pid
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | embed/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 vadim epstein
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion for studies
2 |
3 |
4 |
5 | This is yet another Stable Diffusion compilation, intended to be functional, clean & compact enough for various experiments. There's no GUI here, as the target audience is creative coders rather than post-Photoshop users. The latter may check out [InvokeAI] or [AUTOMATIC1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) as convenient production tools, or [Deforum] for precisely controlled animations.
6 |
7 | The code is based on the [CompVis] and [Stability AI] libraries and heavily borrows from [this repo](https://github.com/AmericanPresidentJimmyCarter/stable-diffusion), with occasional additions from [InvokeAI] and [Deforum], as well as the others mentioned below. The following codebases are partially included here (to ensure compatibility and the ease of setup): [k-diffusion](https://github.com/crowsonkb/k-diffusion), [Taming Transformers](https://github.com/CompVis/taming-transformers), [OpenCLIP], [CLIPseg].
8 | **There is also a [similar repo](https://github.com/eps696/SD), based on the [diffusers] library, which is more logical and up-to-date.**
9 |
10 | Current functions:
11 | * Text to image
12 | * Image re- and in-painting
13 | * Latent interpolations (with text prompts and images)
14 |
15 | Fine-tuning with your images:
16 | * Add subject (new token) with [textual inversion]
17 | * Add subject (prompt embedding + Unet delta) with [custom diffusion]
18 |
19 | Other features:
20 | * Memory-efficient with `xformers` (high resolution on a 6 GB VRAM GPU)
21 | * Use of special depth/inpainting and v2 models
22 | * Masking with text via [CLIPseg]
23 | * Weighted multi-prompts
24 | * to be continued..
25 |
26 | More details and a Colab version will follow.
27 |
28 | ## Setup
29 |
30 | Install CUDA 11.6. Set up the Conda environment:
31 | ```
32 | conda create -n SD python=3.10 numpy pillow
33 | activate SD
34 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
35 | pip install -r requirements.txt
36 | ```
37 | Install the `xformers` library to increase performance. It makes it possible to run SD at any resolution on lower-grade hardware (e.g. video cards with 6 GB VRAM). If you're on Windows, first ensure that Visual Studio 2019 is installed.
38 | ```
39 | pip install git+https://github.com/facebookresearch/xformers.git
40 | ```
41 | Download the Stable Diffusion checkpoints ([1.5](https://huggingface.co/CompVis/stable-diffusion), [1.5-inpaint](https://huggingface.co/runwayml/stable-diffusion-inpainting), [2-inpaint](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting), [2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth), [2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base), [2.1-v](https://huggingface.co/stabilityai/stable-diffusion-2-1)), along with the [OpenCLIP], [custom VAE](https://huggingface.co/stabilityai/sd-vae-ft-ema-original), [CLIPseg] and [MiDaS](https://github.com/isl-org/MiDaS) models (mostly converted to `float16` for faster loading), with the command below. Licensing info is available on their webpages.
42 | ```
43 | python download.py
44 | ```
45 |
46 | ## Operations
47 |
48 | Examples of usage:
49 |
50 | * Generate an image from the text prompt:
51 | ```
52 | python src/_sdrun.py -t "hello world" --size 1024-576
53 | ```
54 | * Redraw an image with existing style embedding:
55 | ```
56 | python src/_sdrun.py -im _in/something.jpg -t ""
57 | ```
58 | * Redraw a directory of images, keeping the basic forms intact:
59 | ```
60 | python src/_sdrun.py -im _in/pix -t "neon light glow" --model v2d
61 | ```
62 | * Inpaint a directory of images with the RunwayML model, turning humans into robots:
63 | ```
64 | python src/_sdrun.py -im _in/pix --mask "human, person" -t "steampunk robot" --model 15i
65 | ```
66 | * Make a video, interpolating between the lines of the text file:
67 | ```
68 | python src/latwalk.py -t yourfile.txt --size 1024-576
69 | ```
70 | * Same, with drawing over a masked image:
71 | ```
72 | python src/latwalk.py -t yourfile.txt -im _in/pix/bench2.jpg --mask _in/pix/mask/bench2_mask.jpg
73 | ```
74 | Check other options by running these scripts with the `--help` option; try out various models, samplers, noisers, etc.
75 | Text prompts may include either special tokens (e.g. ``) or weights (like `good prompt :1 | also good prompt :1 | bad prompt :-0.5`, see the example below). The latter may degrade overall accuracy, though.
76 | Interpolated videos may be further smoothed out with [FILM](https://github.com/google-research/frame-interpolation).
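   |
   | A minimal sketch of a weighted multi-prompt run (the prompt and weights here are hypothetical, chosen only for illustration):
   | ```
   | python src/_sdrun.py -t "ancient temple :1 | sunset light :0.5 | people :-0.5" --size 1024-576
   | ```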
77 |
78 | There are also Windows bat-files, which slightly simplify and automate the commands above.
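   |
   | For example, a couple of hypothetical invocations, following the usage hints printed by the bat files themselves (the image folder is assumed to sit under `_in`):
   | ```
   | img pix "neon light glow" --model v2d
   | inpaint _in/pix _in/pix/mask "steampunk fantasy"
   | ```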
79 |
80 | ## Fine-tuning
81 |
82 | * Train prompt embedding for a specific subject (e.g. cat) with [textual inversion]:
83 | ```
84 | python src/train.py --token mycat1 --term cat --data data/mycat1
85 | ```
86 | * Do the same with [custom diffusion]:
87 | ```
88 | python src/train.py --token mycat1 --term cat --data data/mycat1 --reg_data data/cat
89 | ```
90 | Results of the training runs above will be saved under the `train` directory.
91 |
92 | Custom diffusion trains faster and can achieve impressive reproduction quality with simple prompts similar to the training ones, but it can entirely miss the point if the prompt is too complex or too far from the original category. The result file is 73 MB (it can be compressed to ~16 MB). Note that in this case you'll need both target reference images (`data/mycat1`) and more random images of similar subjects (`data/cat`). Apparently, you can generate the latter with SD itself.
93 | Textual inversion is more generic, yet stable. Its embeddings can also be easily combined without additional retraining. The result file is ~5 KB.
94 |
95 | * Generate an image with an embedding from [textual inversion]. You'll need to rename the embedding file to match your trained token (e.g. `mycat1.pt`) and point the path to its directory. Note that the token is hardcoded in the file, so you can't change it afterwards.
96 | ```
97 | python src/_sdrun.py -t "cosmic beast" --embeds train
98 | ```
99 | * Generate an image with an embedding from [custom diffusion]. You'll need to explicitly mention your new token (so you can name it differently here) and the path to the trained delta file:
100 | ```
101 | python src/_sdrun.py -t "cosmic beast" --token_mod mycat1 --delta_ckpt train/delta-xxx.ckpt
102 | ```
103 | You can also run `python src/latwalk.py ...` with such arguments to make animations; see the sketch below.
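   |
   | A hypothetical example, combining the interpolation script with the custom-diffusion arguments above (file names are placeholders):
   | ```
   | python src/latwalk.py -t yourfile.txt --token_mod mycat1 --delta_ckpt train/delta-xxx.ckpt --size 1024-576
   | ```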
104 |
105 |
106 | ## Credits
107 |
108 | It's quite hard to mention all those who made the current revolution in visual creativity possible. Check the inline links above for some of the sources.
109 | Huge respect to the people behind [Stable Diffusion], [InvokeAI], [Deforum] and the whole open-source movement.
110 |
111 | [Stable Diffusion]:
112 | [CompVis]:
113 | [Stability AI]:
114 | [InvokeAI]:
115 | [Deforum]:
116 | [OpenCLIP]:
117 | [CLIPseg]:
118 | [textual inversion]:
119 | [custom diffusion]:
120 |
--------------------------------------------------------------------------------
/_in/pix/6458524847_2f4c361183_k.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/6458524847_2f4c361183_k.jpg
--------------------------------------------------------------------------------
/_in/pix/8399166846_f6fb4e4b8e_k.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/8399166846_f6fb4e4b8e_k.jpg
--------------------------------------------------------------------------------
/_in/pix/alex-iby-G_Pk4D9rMLs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/alex-iby-G_Pk4D9rMLs.jpg
--------------------------------------------------------------------------------
/_in/pix/bench2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/bench2.jpg
--------------------------------------------------------------------------------
/_in/pix/bertrand-gabioud-CpuFzIsHYJ0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/bertrand-gabioud-CpuFzIsHYJ0.jpg
--------------------------------------------------------------------------------
/_in/pix/billow926-12-Wc-Zgx6Y.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/billow926-12-Wc-Zgx6Y.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/6458524847_2f4c361183_k_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/6458524847_2f4c361183_k_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/8399166846_f6fb4e4b8e_k_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/8399166846_f6fb4e4b8e_k_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/alex-iby-G_Pk4D9rMLs_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/alex-iby-G_Pk4D9rMLs_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/bench2_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/bench2_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/bertrand-gabioud-CpuFzIsHYJ0_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/bertrand-gabioud-CpuFzIsHYJ0_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/billow926-12-Wc-Zgx6Y_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/billow926-12-Wc-Zgx6Y_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/overture-creations-5sI6fQgYIuo_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/overture-creations-5sI6fQgYIuo_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/mask/photo-1583445095369-9c651e7e5d34_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/mask/photo-1583445095369-9c651e7e5d34_mask.jpg
--------------------------------------------------------------------------------
/_in/pix/overture-creations-5sI6fQgYIuo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/overture-creations-5sI6fQgYIuo.jpg
--------------------------------------------------------------------------------
/_in/pix/photo-1583445095369-9c651e7e5d34.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/pix/photo-1583445095369-9c651e7e5d34.jpg
--------------------------------------------------------------------------------
/_in/something.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/_in/something.jpg
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import urllib.request
4 |
5 | def download_model(url: str, root: str = "./models"):
6 | os.makedirs(root, exist_ok=True)
7 | filename = os.path.basename(url.split('?')[0])
8 | download_target = os.path.join(root, filename)
9 |
10 | if os.path.exists(download_target) and not os.path.isfile(download_target):
11 | raise RuntimeError(f"{download_target} exists and is not a regular file")
12 |
13 | if os.path.isfile(download_target):
14 | return download_target
15 |
16 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
17 | with tqdm(total=int(source.info().get("Content-Length")), ncols=64, unit='iB', unit_scale=True) as loop:
18 | while True:
19 | buffer = source.read(8192)
20 | if not buffer:
21 | break
22 | output.write(buffer)
23 | loop.update(len(buffer))
24 |
25 | return download_target
26 |
27 | # print(' downloading SD 1.4 model')
28 | # download_model("https://www.dropbox.com/s/2b3w0vysf485tpc/sd-v14-512-fp16.ckpt?dl=1", 'models')
29 | print(' downloading SD 1.5 model')
30 | download_model("https://www.dropbox.com/s/k9odmzadgyo9gdl/sd-v15-512-fp16.ckpt?dl=1", 'models')
31 | print(' downloading SD 1.5-inpainting model')
32 | download_model("https://www.dropbox.com/s/cc5usmoik43alcc/sd-v15-512-inpaint-fp16.ckpt?dl=1", 'models')
33 | print(' downloading SD 2-inpainting model')
34 | download_model("https://www.dropbox.com/s/kn9jhrkofsfqsae/sd-v2-512-inpaint-fp16.ckpt?dl=1", 'models')
35 | print(' downloading SD 2-depth model')
36 | download_model("https://www.dropbox.com/s/zrx5qfesb9jstsg/sd-v2-512-depth-fp16.ckpt?dl=1", 'models')
37 | print(' downloading SD 2.1 model')
38 | download_model("https://www.dropbox.com/s/m4v36h8tksqa2lk/sd-v21-512-fp16.ckpt?dl=1", 'models')
39 | print(' downloading SD 2.1-v model')
40 | download_model("https://www.dropbox.com/s/wjzh3l1szauz5ww/sd-v21v-768-fp16.ckpt?dl=1", 'models')
41 |
42 | print(' downloading OpenCLIP ViT-H-14-laion2B-s32B-b79K model')
43 | download_model("https://www.dropbox.com/s/7smohfi2ijdy1qm/laion2b_s32b_b79k-vit-h14.pt?dl=1", 'models/openclip')
44 | # download_model("https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin", 'models/openclip')
45 |
46 | print(' downloading SD VAE-ema model')
47 | download_model("https://www.dropbox.com/s/dv836z05lblkvkc/vae-ft-ema-560000.ckpt?dl=1", 'models')
48 | print(' downloading SD VAE-mse model')
49 | download_model("https://www.dropbox.com/s/jmxksbzyk9fls1y/vae-ft-mse-840000.ckpt?dl=1", 'models')
50 |
51 | print(' downloading CLIPseg model')
52 | download_model("https://www.dropbox.com/s/c0tduhr4g0al1cq/rd64-uni.pth?dl=1", 'models/clipseg')
53 | print(' downloading MiDaS depth model')
54 | download_model("https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt", 'models/depth')
55 |
56 |
--------------------------------------------------------------------------------
/img.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | set KMP_DUPLICATE_LIB_OK=TRUE
3 | if [%1]==[] goto help
4 | echo .. %1 .. %2
5 |
6 | if exist _in\%~n1\* goto proc
7 | set seq=ok
8 | echo .. making source sequence
9 | mkdir _in\%~n1
10 | ffmpeg -y -v warning -i _in\%1 -q:v 2 _in\%~n1\%%06d.jpg
11 |
12 | :proc
13 | echo .. processing
14 | python src/_sdrun.py -v -im _in/%~n1 -o _out/%~n1 -t %2 ^
15 | %3 %4 %5 %6 %7 %8 %9
16 |
17 | if "%seq%"=="ok" goto seq
18 | goto end
19 |
20 | :seq
21 | ffmpeg -y -v warning -i _out\%~n1\%%06d.jpg _out\%~n1-%2-%3%4%5%6%7%8%9.mp4
22 | goto end
23 |
24 | :help
25 | echo Usage: img imagedir "text prompt" [...]
26 | echo or: img videofile "text prompt" [...]
27 | :end
28 |
--------------------------------------------------------------------------------
/inpaint.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | set KMP_DUPLICATE_LIB_OK=TRUE
3 | if [%1]==[] goto help
4 | echo .. %1 .. %2 .. %3
5 |
6 | python src/_sdrun.py -v -im %1 -o _out/%~n1 --mask %2 -t %3 ^
7 | %4 %5 %6 %7 %8 %9
8 |
9 | goto end
10 |
11 | :help
12 | echo Usage: inpaint imagedir masksdir "text prompt" [...]
13 | echo e.g.: inpaint _in/pix _in/pix/mask "steampunk fantasy"
14 | echo or: inpaint _in/pix "human figure" "steampunk fantasy" -m 15i
15 | :end
16 |
--------------------------------------------------------------------------------
/model_half.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import collections
4 | import torch
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--dir", '-d', default='./', help="directory with models")
8 | parser.add_argument("--ext", '-e', default=['ckpt','pt'], help="model extensions")
9 | a = parser.parse_args()
10 |
11 | def basename(file):
12 | return os.path.splitext(os.path.basename(file))[0]
13 |
14 | def file_list(path, ext=None):
15 | files = [os.path.join(path, f) for f in os.listdir(path)]
16 | if ext is not None:
17 | if isinstance(ext, list):
18 | files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ext]
19 | elif isinstance(ext, str):
20 | files = [f for f in files if f.endswith(ext)]
21 | else:
22 | print(' Unknown extension/type for file list!')
23 | return sorted([f for f in files if os.path.isfile(f)])
24 |
25 | def float2half(data):
26 | for k in data:
27 | if isinstance(data[k], collections.abc.Mapping):
28 | data[k] = float2half(data[k])
29 | elif isinstance(data[k], list):
30 | data[k] = [float2half(x) for x in data[k]]
31 | else:
32 | if data[k] is not None and torch.is_tensor(data[k]) and data[k].type() == 'torch.FloatTensor':
33 | data[k] = data[k].half()
34 | return data
35 |
36 | models = file_list(a.dir, a.ext)
37 |
38 | for model_path in models:
39 | model = torch.load(model_path, map_location='cpu') # full checkpoint dict; load on CPU to avoid GPU allocation
40 | model = float2half(model)
41 | file_out = basename(model_path) + '-half' + os.path.splitext(model_path)[-1]
42 | torch.save(model, file_out)
43 |
44 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | einops
4 | kornia
5 | omegaconf
6 | torchmetrics==0.6.0
7 | transformers
8 | requests
9 | opencv-python
10 | pytorch-lightning==1.4.2
11 | open-clip-torch==2.7.0
12 | git+https://github.com/openai/CLIP.git@main#egg=clip
13 | timm
14 | scikit-image
15 | jsonmerge
16 | clean-fid
17 | resize_right
18 | torchdiffeq
19 | torchsde
20 | ipywidgets
21 |
22 |
--------------------------------------------------------------------------------
/src/custom/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/custom/__init__.py
--------------------------------------------------------------------------------
/src/custom/compress.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Adobe Research. All rights reserved.
2 | # To view a copy of the license, visit LICENSE.md.
3 |
4 | import torch
5 | import argparse
6 |
7 |
8 | def compress(delta_ckpt, ckpt, diffuser=False, compression_ratio=0.6, device='cuda'):
9 | st = torch.load(f'{delta_ckpt}')
10 |
11 | if not diffuser:
12 | compressed_key = 'state_dict'
13 | compressed_st = {compressed_key: {}}
14 | pretrained_st = torch.load(ckpt)['state_dict']
15 | if 'embed' in st['state_dict']:
16 | compressed_st['state_dict']['embed'] = st['state_dict']['embed']
17 | del st['state_dict']['embed']
18 |
19 | st = st['state_dict']
20 | else:
21 | from diffusers import StableDiffusionPipeline
22 | compressed_key = 'unet'
23 | compressed_st = {compressed_key: {}}
24 | pretrained_st = StableDiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")
25 | pretrained_st = pretrained_st.unet.state_dict()
26 | if 'modifier_token' in st:
27 | compressed_st['modifier_token'] = st['modifier_token']
28 | st = st['unet']
29 |
30 | print("getting compression")
31 | layers = list(st.keys())
32 | for name in layers:
33 | if 'to_k' in name or 'to_v' in name:
34 | W = st[name].to(device)
35 | Wpretrain = pretrained_st[name].clone().to(device)
36 | deltaW = W-Wpretrain
37 |
38 | u, s, vt = torch.linalg.svd(deltaW.clone())
39 |
40 | explain = 0
41 | all_ = (s).sum()
42 | for i, t in enumerate(s):
43 | explain += t/(all_)
44 | if explain > compression_ratio:
45 | break
46 |
47 | compressed_st[compressed_key][f'{name}'] = {}
48 | compressed_st[compressed_key][f'{name}']['u'] = (u[:, :i]@torch.diag(s)[:i, :i]).clone()
49 | compressed_st[compressed_key][f'{name}']['v'] = vt[:i].clone()
50 | else:
51 | compressed_st[compressed_key][f'{name}'] = st[name]
52 |
53 | name = delta_ckpt.replace('delta', 'compressed_delta')
54 | torch.save(compressed_st, f'{name}')
55 |
56 |
57 | def parse_args():
58 | parser = argparse.ArgumentParser('', add_help=False)
59 | parser.add_argument('--delta_ckpt', help='path of checkpoint to compress',
60 | type=str)
61 | parser.add_argument('--ckpt', help='path of pretrained model checkpoint',
62 | type=str)
63 | parser.add_argument("--diffuser", action='store_true')
64 | return parser.parse_args()
65 |
66 |
67 | if __name__ == "__main__":
68 | args = parse_args()
69 | compress(args.delta_ckpt, args.ckpt, args.diffuser)
70 |
--------------------------------------------------------------------------------
/src/custom/get_deltas.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Adobe Research. All rights reserved.
2 | # To view a copy of the license, visit LICENSE.md.
3 | import os
4 | import argparse
5 | import glob
6 | import torch
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser('', add_help=False)
10 | parser.add_argument('-i', '--path', help='path of folder to checkpoints', type=str)
11 | parser.add_argument('-n', '--newtoken', help='number of new tokens in the checkpoint', default=1, type=int)
12 | return parser.parse_args()
13 |
14 | def save_delta(ckpt_dir, newtoken=1):
15 | assert newtoken > 0, 'No new tokens found'
16 | layers = []
17 | for ckptfile in glob.glob(f'{ckpt_dir}/*.ckpt', recursive=True):
18 | if 'delta' not in ckptfile:
19 | st = torch.load(ckptfile)["state_dict"]
20 | if len(layers) == 0:
21 | for key in list(st.keys()):
22 | if 'attn2.to_k' in key or 'attn2.to_v' in key:
23 | layers.append(key)
24 | st_delta = {'state_dict': {}}
25 | for each in layers:
26 | st_delta['state_dict'][each] = st[each].clone()
27 | num_tokens = st['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'].shape[0]
28 | st_delta['state_dict']['embed'] = st['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'][-newtoken:].clone()
29 | filepath = os.path.join(ckpt_dir, 'delta-' + os.path.basename(ckptfile))
30 | torch.save(st_delta, filepath)
31 | os.remove(ckptfile)
32 | print('.. saved embedding', st_delta['state_dict']['embed'].cpu().numpy().shape, num_tokens, '=>', filepath)
33 |
34 |
35 | if __name__ == "__main__":
36 | args = parse_args()
37 | save_delta(args.path, args.newtoken)
38 |
--------------------------------------------------------------------------------
/src/custom/modules.py:
--------------------------------------------------------------------------------
1 | # This code is built from the Huggingface repository: https://github.com/huggingface/transformers/tree/main/src/transformers/models/clip.
2 | # Copyright 2018- The Hugging Face team. All rights reserved.
3 |
4 | import os, sys
5 | from packaging import version
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import transformers
11 | from transformers import CLIPTokenizer, CLIPTextModel
12 |
13 | class AbstractEncoder(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 | # https://github.com/adobe-research/custom-diffusion
20 | class FrozenCLIPEmbedderWrapper(AbstractEncoder):
21 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
22 | def __init__(self, modifier_token, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, model_dir='models'):
23 | super().__init__()
24 | self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir=os.path.join(model_dir, os.path.basename(version)), local_files_only=False)
25 | self.transformer = CLIPTextModel.from_pretrained(version, cache_dir=os.path.join(model_dir, os.path.basename(version)), local_files_only=False)
26 | self.device = device
27 | self.max_length = max_length
28 | self.modifier_token = modifier_token
29 | if '+' in self.modifier_token:
30 | self.modifier_token = self.modifier_token.split('+')
31 | else:
32 | self.modifier_token = [self.modifier_token]
33 |
34 | self.add_token()
35 | self.freeze()
36 |
37 | def add_token(self):
38 | self.modifier_token_id = []
39 | token_embeds1 = self.transformer.get_input_embeddings().weight.data
40 | for each_modifier_token in self.modifier_token:
41 | num_added_tokens = self.tokenizer.add_tokens(each_modifier_token)
42 | modifier_token_id = self.tokenizer.convert_tokens_to_ids(each_modifier_token)
43 | self.modifier_token_id.append(modifier_token_id) # .., 49408, 49409
44 |
45 | self.transformer.resize_token_embeddings(len(self.tokenizer))
46 | token_embeds = self.transformer.get_input_embeddings().weight.data # [49410, 768]
47 | # print(' modifier_token_id', self.modifier_token_id, '.. token_embeds', token_embeds.shape)
48 | token_embeds[self.modifier_token_id[-1]] = torch.nn.Parameter(token_embeds[42170], requires_grad=True)
49 | if len(self.modifier_token) == 2:
50 | token_embeds[self.modifier_token_id[-2]] = torch.nn.Parameter(token_embeds[47629], requires_grad=True)
51 | if len(self.modifier_token) == 3:
52 | token_embeds[self.modifier_token_id[-3]] = torch.nn.Parameter(token_embeds[43514], requires_grad=True)
53 |
54 | def custom_forward(self, hidden_states, input_ids):
55 | input_shape = hidden_states.size()
56 | bsz, seq_len = input_shape[:2]
57 | if version.parse(transformers.__version__) >= version.parse('4.21'):
58 | causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(hidden_states.device)
59 | else:
60 | causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
61 |
62 | encoder_outputs = self.transformer.text_model.encoder(inputs_embeds=hidden_states, causal_attention_mask=causal_attention_mask)
63 |
64 | last_hidden_state = encoder_outputs[0]
65 | last_hidden_state = self.transformer.text_model.final_layer_norm(last_hidden_state)
66 |
67 | return last_hidden_state
68 |
69 | def freeze(self):
70 | self.transformer = self.transformer.eval()
71 | for param in self.transformer.text_model.encoder.parameters():
72 | param.requires_grad = False
73 | for param in self.transformer.text_model.final_layer_norm.parameters():
74 | param.requires_grad = False
75 | for param in self.transformer.text_model.embeddings.position_embedding.parameters():
76 | param.requires_grad = False
77 |
78 | def forward(self, text):
79 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
80 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
81 | tokens = batch_encoding["input_ids"].to(self.device)
82 |
83 | if len(self.modifier_token) == 3:
84 | indices = ((tokens == self.modifier_token_id[-1]) | (tokens == self.modifier_token_id[-2]) | (tokens == self.modifier_token_id[-3]))*1
85 | elif len(self.modifier_token) == 2:
86 | indices = ((tokens == self.modifier_token_id[-1]) | (tokens == self.modifier_token_id[-2]))*1
87 | else:
88 | indices = (tokens == self.modifier_token_id[-1])*1
89 |
90 | indices = indices.unsqueeze(-1)
91 |
92 | input_shape = tokens.size()
93 | tokens = tokens.view(-1, input_shape[-1])
94 |
95 | hidden_states = self.transformer.text_model.embeddings(input_ids=tokens)
96 | hidden_states = (1-indices)*hidden_states.detach() + indices*hidden_states
97 |
98 | z = self.custom_forward(hidden_states, tokens)
99 |
100 | return z
101 |
102 | def encode(self, text):
103 | return self(text)
104 |
105 |
--------------------------------------------------------------------------------
/src/ldm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/__init__.py
--------------------------------------------------------------------------------
/src/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/data/__init__.py
--------------------------------------------------------------------------------
/src/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | """ Define an interface to make the IterableDatasets for text2img data chainable """
7 |
8 | def __init__(self, num_records=0, valid_ids=None, size=256):
9 | super().__init__()
10 | self.num_records = num_records
11 | self.valid_ids = valid_ids
12 | self.sample_ids = valid_ids
13 | self.size = size
14 |
15 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
16 |
17 | def __len__(self):
18 | return self.num_records
19 |
20 | @abstractmethod
21 | def __iter__(self):
22 | pass
23 |
--------------------------------------------------------------------------------
/src/ldm/data/personalized.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image
5 |
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 |
9 | templates_smallest = ['a photo of a {}']
10 |
11 | templates = [
12 | 'a photo of a {}',
13 | 'a rendering of a {}',
14 | 'a cropped photo of the {}',
15 | 'the photo of a {}',
16 | 'a photo of a clean {}',
17 | 'a photo of a dirty {}',
18 | 'a dark photo of the {}',
19 | 'a photo of my {}',
20 | 'a photo of the cool {}',
21 | 'a close-up photo of a {}',
22 | 'a bright photo of the {}',
23 | 'a cropped photo of a {}',
24 | 'a photo of the {}',
25 | 'a good photo of the {}',
26 | 'a photo of one {}',
27 | 'a close-up photo of the {}',
28 | 'a rendition of the {}',
29 | 'a photo of the clean {}',
30 | 'a rendition of a {}',
31 | 'a photo of a nice {}',
32 | 'a good photo of a {}',
33 | 'a photo of the nice {}',
34 | 'a photo of the small {}',
35 | 'a photo of the weird {}',
36 | 'a photo of the large {}',
37 | 'a photo of a cool {}',
38 | 'a photo of a small {}',
39 | ]
40 | dual_templates = [
41 | 'a photo of a {} with {}',
42 | 'a rendering of a {} with {}',
43 | 'a cropped photo of the {} with {}',
44 | 'the photo of a {} with {}',
45 | 'a photo of a clean {} with {}',
46 | 'a photo of a dirty {} with {}',
47 | 'a dark photo of the {} with {}',
48 | 'a photo of my {} with {}',
49 | 'a photo of the cool {} with {}',
50 | 'a close-up photo of a {} with {}',
51 | 'a bright photo of the {} with {}',
52 | 'a cropped photo of a {} with {}',
53 | 'a photo of the {} with {}',
54 | 'a good photo of the {} with {}',
55 | 'a photo of one {} with {}',
56 | 'a close-up photo of the {} with {}',
57 | 'a rendition of the {} with {}',
58 | 'a photo of the clean {} with {}',
59 | 'a rendition of a {} with {}',
60 | 'a photo of a nice {} with {}',
61 | 'a good photo of a {} with {}',
62 | 'a photo of the nice {} with {}',
63 | 'a photo of the small {} with {}',
64 | 'a photo of the weird {} with {}',
65 | 'a photo of the large {} with {}',
66 | 'a photo of a cool {} with {}',
67 | 'a photo of a small {} with {}',
68 | ]
69 | templates_style = [
70 | 'a painting in the style of {}',
71 | 'a rendering in the style of {}',
72 | 'a cropped painting in the style of {}',
73 | 'the painting in the style of {}',
74 | 'a clean painting in the style of {}',
75 | 'a dirty painting in the style of {}',
76 | 'a dark painting in the style of {}',
77 | 'a picture in the style of {}',
78 | 'a cool painting in the style of {}',
79 | 'a close-up painting in the style of {}',
80 | 'a bright painting in the style of {}',
81 | 'a cropped painting in the style of {}',
82 | 'a good painting in the style of {}',
83 | 'a close-up painting in the style of {}',
84 | 'a rendition in the style of {}',
85 | 'a nice painting in the style of {}',
86 | 'a small painting in the style of {}',
87 | 'a weird painting in the style of {}',
88 | 'a large painting in the style of {}',
89 | ]
90 | dual_templates_style = [
91 | 'a painting in the style of {} with {}',
92 | 'a rendering in the style of {} with {}',
93 | 'a cropped painting in the style of {} with {}',
94 | 'the painting in the style of {} with {}',
95 | 'a clean painting in the style of {} with {}',
96 | 'a dirty painting in the style of {} with {}',
97 | 'a dark painting in the style of {} with {}',
98 | 'a cool painting in the style of {} with {}',
99 | 'a close-up painting in the style of {} with {}',
100 | 'a bright painting in the style of {} with {}',
101 | 'a cropped painting in the style of {} with {}',
102 | 'a good painting in the style of {} with {}',
103 | 'a painting of one {} in the style of {}',
104 | 'a nice painting in the style of {} with {}',
105 | 'a small painting in the style of {} with {}',
106 | 'a weird painting in the style of {} with {}',
107 | 'a large painting in the style of {} with {}',
108 | ]
109 | per_img_token_list = ['א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת']
110 |
111 | class PersonalizedBase(Dataset):
112 | def __init__(self, data_root, size=None, repeats=100, interpolation='bicubic', flip_p=0.5, set='train', placeholder_token='*',
113 | per_image_tokens=False, center_crop=False, style=False, mixing_prob=0.25, coarse_class_text=None):
114 | self.data_root = data_root
115 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
116 |
117 | self.num_images = len(self.image_paths)
118 | self._length = self.num_images
119 | self.placeholder_token = placeholder_token
120 | self.per_image_tokens = per_image_tokens
121 | self.center_crop = center_crop
122 | self.templates = templates_style if style else templates
123 | self.dual_templates = dual_templates_style if style else dual_templates
124 | self.mixing_prob = mixing_prob
125 | self.coarse_class_text = coarse_class_text
126 |
127 | if per_image_tokens:
128 | assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
129 |
130 | if set == 'train':
131 | self._length = self.num_images * repeats
132 |
133 | self.size = size
134 | self.interpolation = {'linear': Image.LINEAR, 'bilinear': Image.BILINEAR, 'bicubic': Image.BICUBIC, 'lanczos': Image.LANCZOS}[interpolation]
135 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
136 |
137 | def __len__(self):
138 | return self._length
139 |
140 | def __getitem__(self, i):
141 | example = {}
142 | image = Image.open(self.image_paths[i % self.num_images])
143 |
144 | if not image.mode == 'RGB':
145 | image = image.convert('RGB')
146 |
147 | placeholder_string = self.placeholder_token
148 | if self.coarse_class_text:
149 | placeholder_string = f'{placeholder_string} {self.coarse_class_text}'
150 |
151 | if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
152 | text = random.choice(self.dual_templates).format(placeholder_string, per_img_token_list[i % self.num_images])
153 | else:
154 | text = random.choice(self.templates).format(placeholder_string)
155 |
156 | example['caption'] = text
157 |
158 | # default to score-sde preprocessing
159 | img = np.array(image).astype(np.uint8)
160 |
161 | if self.center_crop:
162 | crop = min(img.shape[0], img.shape[1])
163 | h, w, = img.shape[0], img.shape[1]
164 | img = img[(h-crop)//2 : (h+crop)//2, (w-crop)//2 : (w+crop)//2]
165 |
166 | image = Image.fromarray(img)
167 | if self.size is not None:
168 | image = image.resize((self.size, self.size), resample=self.interpolation)
169 |
170 | image = self.flip(image)
171 | image = np.array(image).astype(np.uint8)
172 | example['image'] = (image / 127.5 - 1.0).astype(np.float32)
173 | return example
174 |
--------------------------------------------------------------------------------
/src/ldm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/models/__init__.py
--------------------------------------------------------------------------------
/src/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/src/ldm/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/__init__.py
--------------------------------------------------------------------------------
/src/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/src/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/src/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self):
38 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
39 | device=self.parameters.device
40 | )
41 | return x
42 |
43 | def kl(self, other=None):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | else:
47 | if other is None:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50 | dim=[1, 2, 3],
51 | )
52 | else:
53 | return 0.5 * torch.sum(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var
56 | - 1.0
57 | - self.logvar
58 | + other.logvar,
59 | dim=[1, 2, 3],
60 | )
61 |
62 | def nll(self, sample, dims=[1, 2, 3]):
63 | if self.deterministic:
64 | return torch.Tensor([0.0])
65 | logtwopi = np.log(2.0 * np.pi)
66 | return 0.5 * torch.sum(
67 | logtwopi
68 | + self.logvar
69 | + torch.pow(sample - self.mean, 2) / self.var,
70 | dim=dims,
71 | )
72 |
73 | def mode(self):
74 | return self.mean
75 |
76 |
77 | def normal_kl(mean1, logvar1, mean2, logvar2):
78 | """
79 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
80 | Compute the KL divergence between two gaussians.
81 | Shapes are automatically broadcasted, so batches can be compared to
82 | scalars, among other use cases.
83 | """
84 | tensor = None
85 | for obj in (mean1, logvar1, mean2, logvar2):
86 | if isinstance(obj, torch.Tensor):
87 | tensor = obj
88 | break
89 | assert tensor is not None, 'at least one argument must be a Tensor'
90 |
91 | # Force variances to be Tensors. Broadcasting helps convert scalars to
92 | # Tensors, but it does not work for torch.exp().
93 | logvar1, logvar2 = [
94 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
95 | for x in (logvar1, logvar2)
96 | ]
97 |
98 | return 0.5 * (
99 | -1.0
100 | + logvar2
101 | - logvar1
102 | + torch.exp(logvar1 - logvar2)
103 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
104 | )
105 |
--------------------------------------------------------------------------------
/src/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | 'num_updates',
15 | torch.tensor(0, dtype=torch.int)
16 | if use_num_upates
17 | else torch.tensor(-1, dtype=torch.int),
18 | )
19 |
20 | for name, p in model.named_parameters():
21 | if p.requires_grad:
22 | # remove as '.'-character is not allowed in buffers
23 | s_name = name.replace('.', '')
24 | self.m_name2s_name.update({name: s_name})
25 | self.register_buffer(s_name, p.clone().detach().data)
26 |
27 | self.collected_params = []
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(
35 | self.decay, (1 + self.num_updates) / (10 + self.num_updates)
36 | )
37 |
38 | one_minus_decay = 1.0 - decay
39 |
40 | with torch.no_grad():
41 | m_param = dict(model.named_parameters())
42 | shadow_params = dict(self.named_buffers())
43 |
44 | for key in m_param:
45 | if m_param[key].requires_grad:
46 | sname = self.m_name2s_name[key]
47 | shadow_params[sname] = shadow_params[sname].type_as(
48 | m_param[key]
49 | )
50 | shadow_params[sname].sub_(
51 | one_minus_decay * (shadow_params[sname] - m_param[key])
52 | )
53 | else:
54 | assert not key in self.m_name2s_name
55 |
56 | def copy_to(self, model):
57 | m_param = dict(model.named_parameters())
58 | shadow_params = dict(self.named_buffers())
59 | for key in m_param:
60 | if m_param[key].requires_grad:
61 | m_param[key].data.copy_(
62 | shadow_params[self.m_name2s_name[key]].data
63 | )
64 | else:
65 | assert not key in self.m_name2s_name
66 |
67 | def store(self, parameters):
68 | """
69 | Save the current parameters for restoring later.
70 | Args:
71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
72 | temporarily stored.
73 | """
74 | self.collected_params = [param.clone() for param in parameters]
75 |
76 | def restore(self, parameters):
77 | """
78 | Restore the parameters stored with the `store` method.
79 | Useful to validate the model with EMA parameters without affecting the
80 | original optimization process. Store the parameters before the
81 | `copy_to` method. After validation (or model saving), use this to
82 | restore the former parameters.
83 | Args:
84 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
85 | updated with the stored parameters.
86 | """
87 | for c_param, param in zip(self.collected_params, parameters):
88 | param.data.copy_(c_param.data)
89 |
--------------------------------------------------------------------------------
/src/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/src/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/midas/__init__.py
--------------------------------------------------------------------------------
/src/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "models/depth/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "models/depth/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 |         # x in 0..1, as produced by applying the MiDaS transform (see load_midas_transform) to a 0..1 float numpy array
159 | # NOTE: we expect that the correct transform has been called during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
--------------------------------------------------------------------------------
/src/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/ldm/modules/midas/midas/__init__.py
--------------------------------------------------------------------------------
/src/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/src/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False, # Set to true if you want to train from scratch, uses ImageNet weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
/src/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             non_negative (bool, optional): If True, apply a final ReLU so the predicted depth is non-negative. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/src/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
/src/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
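A short sketch of how these helpers chain together (file names and the model call are placeholders, not from this repository):

    from ldm.modules.midas.utils import read_image, resize_image, resize_depth, write_depth

    img = read_image("input.jpg")     # H x W x 3 RGB floats in 0..1
    batch = resize_image(img)         # 1 x 3 x h x w tensor, longer side ~384, both sides multiples of 32

    # pred = some_midas_model(batch)                                       # 1 x 1 x h x w depth tensor
    # depth = resize_depth(pred, width=img.shape[1], height=img.shape[0])
    # write_depth("out", depth, bits=2)                                    # writes out.pfm and a 16-bit out.png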
--------------------------------------------------------------------------------
/src/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new('RGB', wh, color='white')
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.load_default()
26 | nc = int(40 * (wh[0] / 256))
27 | lines = '\n'.join(
28 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
29 | )
30 |
31 | try:
32 | draw.text((0, 0), lines, fill='black', font=font)
33 | except UnicodeEncodeError:
34 |             print('Cannot encode string for logging. Skipping.')
35 |
36 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
37 | txts.append(txt)
38 | txts = np.stack(txts)
39 | txts = torch.tensor(txts)
40 | return txts
41 |
42 |
43 | def ismap(x):
44 | if not isinstance(x, torch.Tensor):
45 | return False
46 | return (len(x.shape) == 4) and (x.shape[1] > 3)
47 |
48 |
49 | def isimage(x):
50 | if not isinstance(x, torch.Tensor):
51 | return False
52 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
53 |
54 |
55 | def exists(x):
56 | return x is not None
57 |
58 |
59 | def default(val, d):
60 | if exists(val):
61 | return val
62 | return d() if isfunction(d) else d
63 |
64 |
65 | def mean_flat(tensor):
66 | """
67 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
68 | Take the mean over all non-batch dimensions.
69 | """
70 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
71 |
72 |
73 | def count_params(model, verbose=False):
74 | total_params = sum(p.numel() for p in model.parameters())
75 | if verbose:
76 | print(
77 | f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
78 | )
79 | return total_params
80 |
81 |
82 | def instantiate_from_config(config, **kwargs):
83 | if not 'target' in config:
84 | if config == '__is_first_stage__':
85 | return None
86 | elif config == '__is_unconditional__':
87 | return None
88 | raise KeyError('Expected key `target` to instantiate.')
89 | return get_obj_from_str(config['target'])(
90 | **config.get('params', dict()), **kwargs
91 | )
92 |
93 |
94 | def get_obj_from_str(string, reload=False):
95 | module, cls = string.rsplit('.', 1)
96 | if reload:
97 | module_imp = importlib.import_module(module)
98 | importlib.reload(module_imp)
99 | return getattr(importlib.import_module(module, package=None), cls)
100 |
101 |
102 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
103 | # create dummy dataset instance
104 |
105 | # run prefetching
106 | if idx_to_fn:
107 | res = func(data, worker_id=idx)
108 | else:
109 | res = func(data)
110 | Q.put([idx, res])
111 | Q.put('Done')
112 |
113 |
114 | def parallel_data_prefetch(
115 | func: callable,
116 | data,
117 | n_proc,
118 | target_data_type='ndarray',
119 | cpu_intensive=True,
120 | use_worker_id=False,
121 | ):
122 | # if target_data_type not in ["ndarray", "list"]:
123 | # raise ValueError(
124 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
125 | # )
126 | if isinstance(data, np.ndarray) and target_data_type == 'list':
127 | raise ValueError('list expected but function got ndarray.')
128 | elif isinstance(data, abc.Iterable):
129 | if isinstance(data, dict):
130 | print(
131 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
132 | )
133 | data = list(data.values())
134 | if target_data_type == 'ndarray':
135 | data = np.asarray(data)
136 | else:
137 | data = list(data)
138 | else:
139 | raise TypeError(
140 | f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
141 | )
142 |
143 | if cpu_intensive:
144 | Q = mp.Queue(1000)
145 | proc = mp.Process
146 | else:
147 | Q = Queue(1000)
148 | proc = Thread
149 | # spawn processes
150 | if target_data_type == 'ndarray':
151 | arguments = [
152 | [func, Q, part, i, use_worker_id]
153 | for i, part in enumerate(np.array_split(data, n_proc))
154 | ]
155 | else:
156 | step = (
157 | int(len(data) / n_proc + 1)
158 | if len(data) % n_proc != 0
159 | else int(len(data) / n_proc)
160 | )
161 | arguments = [
162 | [func, Q, part, i, use_worker_id]
163 | for i, part in enumerate(
164 | [data[i : i + step] for i in range(0, len(data), step)]
165 | )
166 | ]
167 | processes = []
168 | for i in range(n_proc):
169 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
170 | processes += [p]
171 |
172 | # start processes
173 | print(f'Start prefetching...')
174 | import time
175 |
176 | start = time.time()
177 | gather_res = [[] for _ in range(n_proc)]
178 | try:
179 | for p in processes:
180 | p.start()
181 |
182 | k = 0
183 | while k < n_proc:
184 | # get result
185 | res = Q.get()
186 | if res == 'Done':
187 | k += 1
188 | else:
189 | gather_res[res[0]] = res[1]
190 |
191 | except Exception as e:
192 | print('Exception: ', e)
193 | for p in processes:
194 | p.terminate()
195 |
196 | raise e
197 | finally:
198 | for p in processes:
199 | p.join()
200 | print(f'Prefetching complete. [{time.time() - start} sec.]')
201 |
202 | if target_data_type == 'ndarray':
203 | if not isinstance(gather_res[0], np.ndarray):
204 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
205 |
206 | # order outputs
207 | return np.concatenate(gather_res, axis=0)
208 | elif target_data_type == 'list':
209 | out = []
210 | for r in gather_res:
211 | out.extend(r)
212 | return out
213 | else:
214 | return gather_res
215 |
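A small illustration of instantiate_from_config (the config dict below is made up for the example; real configs come from the project's YAML files):

    from ldm.util import instantiate_from_config

    cfg = {
        'target': 'torch.nn.Conv2d',
        'params': {'in_channels': 3, 'out_channels': 8, 'kernel_size': 3},
    }
    conv = instantiate_from_config(cfg)   # same as torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)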
--------------------------------------------------------------------------------
/src/txt2mask.py:
--------------------------------------------------------------------------------
1 | '''Makes available the Txt2Mask class, which assists in the automatic
2 | assignment of masks via text prompt using clipseg.
3 |
4 | Here is typical usage:
5 |
6 | from txt2mask import Txt2Mask # SegmentedGrayscale
7 | from PIL import Image
8 |
9 | txt2mask = Txt2Mask(self.device)
10 | segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel')
11 |
12 | # this will return a grayscale Image of the segmented data
13 | grayscale = segmented.to_grayscale()
14 |
15 | # this will return a semi-transparent image in which the
16 | # selected object(s) are opaque and the rest is at various
17 | # levels of transparency
18 | transparent = segmented.to_transparent()
19 |
20 | # this will return a masked image suitable for use in inpainting:
21 | mask = segmented.to_mask(threshold=0.5)
22 |
23 | The threshold used in the call to to_mask() selects pixels for use in
24 | the mask that exceed the indicated confidence threshold. Values range
25 | from 0.0 to 1.0. The higher the threshold, the more confident the
26 | algorithm must be about a pixel before it is selected. In limited
27 | testing, I have found that values around 0.5 work fine.
28 | '''
29 |
30 | import os, sys
31 | import numpy as np
32 | from einops import rearrange, repeat
33 | from PIL import Image, ImageOps
34 |
35 | import torch
36 | from torchvision import transforms
37 |
38 | CLIP_VERSION = 'ViT-B/16'
39 | CLIPSEG_SIZE = 352
40 |
41 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'xtra'))
42 |
43 | from clipseg.clipseg import CLIPDensePredT
44 |
45 | class SegmentedGrayscale(object):
46 | def __init__(self, image:Image, heatmap:torch.Tensor):
47 | self.heatmap = heatmap
48 | self.image = image
49 |
50 | def to_grayscale(self,invert:bool=False)->Image:
51 | return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
52 |
53 | def to_mask(self,threshold:float=0.5)->Image:
54 | discrete_heatmap = self.heatmap.lt(threshold).int()
55 | return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
56 |
57 | def to_transparent(self,invert:bool=False)->Image:
58 | transparent_image = self.image.copy()
59 | # For img2img, we want the selected regions to be transparent,
60 | # but to_grayscale() returns the opposite. Thus invert.
61 | gs = self.to_grayscale(not invert)
62 | transparent_image.putalpha(gs)
63 | return transparent_image
64 |
65 | # unscales and uncrops the 352x352 heatmap so that it matches the image again
66 | def _rescale(self, heatmap:Image)->Image:
67 | size = self.image.width if (self.image.width > self.image.height) else self.image.height
68 | resized_image = heatmap.resize((size,size), resample=Image.Resampling.LANCZOS)
69 | return resized_image.crop((0,0,self.image.width,self.image.height))
70 |
71 | class Txt2Mask(object):
72 | ''' Create new Txt2Mask object. The optional device argument can be one of 'cuda', 'mps' or 'cpu' '''
73 | def __init__(self, model_path='models/clipseg/rd64-uni.pth', device='cpu', refined=False):
74 | # print('>> Initializing clipseg model for text to mask inference')
75 | self.device = device
76 | self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
77 | self.model.eval()
78 |         # initially we keep everything on the CPU to conserve GPU memory
79 | self.model.to('cpu')
80 | self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
81 |
82 | @torch.no_grad()
83 | def segment(self, image, prompt:str) -> SegmentedGrayscale:
84 | '''
85 | Given a prompt string such as "a bagel", tries to identify the object in the
86 | provided image and returns a SegmentedGrayscale object in which the brighter
87 | pixels indicate where the object is inferred to be.
88 | '''
89 | self._to_device(self.device)
90 | prompts = [prompt] # right now we operate on just a single prompt at a time
91 |
92 | transform = transforms.Compose([
93 | transforms.ToTensor(),
94 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
95 | transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
96 | ])
97 |
98 | if type(image) is str:
99 | image = Image.open(image).convert('RGB')
100 |
101 | image = ImageOps.exif_transpose(image)
102 | img = self._scale_and_crop(image)
103 | img = transform(img).unsqueeze(0)
104 |
105 | preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
106 | heatmap = torch.sigmoid(preds[0][0]).cpu()
107 | self._to_device('cpu')
108 | return SegmentedGrayscale(image, heatmap)
109 |
110 | def _to_device(self, device):
111 | self.model.to(device)
112 |
113 | def _scale_and_crop(self, image:Image)->Image:
114 | scaled_image = Image.new('RGB', (CLIPSEG_SIZE, CLIPSEG_SIZE))
115 | if image.width > image.height: # width is constraint
116 | scale = CLIPSEG_SIZE / image.width
117 | else:
118 | scale = CLIPSEG_SIZE / image.height
119 | scaled_image.paste(image.resize((int(scale * image.width), int(scale * image.height)), resample=Image.Resampling.LANCZOS),box=(0,0))
120 | return scaled_image
121 |
--------------------------------------------------------------------------------
/src/xtra/k_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from . import augmentation, config, evaluation, external, gns, layers, models, sampling, utils
2 | from .layers import Denoiser
3 |
--------------------------------------------------------------------------------
/src/xtra/k_diffusion/augmentation.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import math
3 | import operator
4 |
5 | import numpy as np
6 | from skimage import transform
7 | import torch
8 | from torch import nn
9 |
10 |
11 | def translate2d(tx, ty):
12 | mat = [[1, 0, tx],
13 | [0, 1, ty],
14 | [0, 0, 1]]
15 | return torch.tensor(mat, dtype=torch.float32)
16 |
17 |
18 | def scale2d(sx, sy):
19 | mat = [[sx, 0, 0],
20 | [ 0, sy, 0],
21 | [ 0, 0, 1]]
22 | return torch.tensor(mat, dtype=torch.float32)
23 |
24 |
25 | def rotate2d(theta):
26 | mat = [[torch.cos(theta), torch.sin(-theta), 0],
27 | [torch.sin(theta), torch.cos(theta), 0],
28 | [ 0, 0, 1]]
29 | return torch.tensor(mat, dtype=torch.float32)
30 |
31 |
32 | class KarrasAugmentationPipeline:
33 | def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
34 | self.a_prob = a_prob
35 | self.a_scale = a_scale
36 | self.a_aniso = a_aniso
37 | self.a_trans = a_trans
38 |
39 | def __call__(self, image):
40 | h, w = image.size
41 | mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
42 |
43 | # x-flip
44 | a0 = torch.randint(2, []).float()
45 | mats.append(scale2d(1 - 2 * a0, 1))
46 | # y-flip
47 | do = (torch.rand([]) < self.a_prob).float()
48 | a1 = torch.randint(2, []).float() * do
49 | mats.append(scale2d(1, 1 - 2 * a1))
50 | # scaling
51 | do = (torch.rand([]) < self.a_prob).float()
52 | a2 = torch.randn([]) * do
53 | mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
54 | # rotation
55 | do = (torch.rand([]) < self.a_prob).float()
56 | a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
57 | mats.append(rotate2d(-a3))
58 | # anisotropy
59 | do = (torch.rand([]) < self.a_prob).float()
60 | a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
61 | a5 = torch.randn([]) * do
62 | mats.append(rotate2d(a4))
63 | mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
64 | mats.append(rotate2d(-a4))
65 | # translation
66 | do = (torch.rand([]) < self.a_prob).float()
67 | a6 = torch.randn([]) * do
68 | a7 = torch.randn([]) * do
69 | mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
70 |
71 | # form the transformation matrix and conditioning vector
72 | mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
73 | mat = reduce(operator.matmul, mats)
74 | cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
75 |
76 | # apply the transformation
77 | image_orig = np.array(image, dtype=np.float32) / 255
78 | if image_orig.ndim == 2:
79 | image_orig = image_orig[..., None]
80 | tf = transform.AffineTransform(mat.numpy())
81 | image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
82 | image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
83 | image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
84 | return image, image_orig, cond
85 |
86 |
87 | class KarrasAugmentWrapper(nn.Module):
88 | def __init__(self, model):
89 | super().__init__()
90 | self.inner_model = model
91 |
92 | def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
93 | if aug_cond is None:
94 | aug_cond = input.new_zeros([input.shape[0], 9])
95 | if mapping_cond is None:
96 | mapping_cond = aug_cond
97 | else:
98 | mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
99 | return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
100 |
101 | def set_skip_stages(self, skip_stages):
102 | return self.inner_model.set_skip_stages(skip_stages)
103 |
104 | def set_patch_size(self, patch_size):
105 | return self.inner_model.set_patch_size(patch_size)
106 |
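A minimal sketch of the augmentation pipeline applied to a single PIL image (the file name is a placeholder):

    from PIL import Image
    from k_diffusion.augmentation import KarrasAugmentationPipeline

    aug = KarrasAugmentationPipeline(a_prob=0.12)
    image = Image.open("input.png").convert("RGB")
    aug_img, orig_img, aug_cond = aug(image)   # two CHW tensors scaled to [-1, 1] and a 9-dim conditioning vector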
--------------------------------------------------------------------------------
/src/xtra/k_diffusion/config.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import json
3 | import math
4 | import warnings
5 |
6 | from jsonmerge import merge
7 |
8 | from . import augmentation, layers, models, utils
9 |
10 |
11 | def load_config(file):
12 | defaults = {
13 | 'model': {
14 | 'sigma_data': 1.,
15 | 'patch_size': 1,
16 | 'dropout_rate': 0.,
17 | 'augment_wrapper': True,
18 | 'augment_prob': 0.,
19 | 'mapping_cond_dim': 0,
20 | 'unet_cond_dim': 0,
21 | 'cross_cond_dim': 0,
22 | 'cross_attn_depths': None,
23 | 'skip_stages': 0,
24 | 'has_variance': False,
25 | },
26 | 'dataset': {
27 | 'type': 'imagefolder',
28 | },
29 | 'optimizer': {
30 | 'type': 'adamw',
31 | 'lr': 1e-4,
32 | 'betas': [0.95, 0.999],
33 | 'eps': 1e-6,
34 | 'weight_decay': 1e-3,
35 | },
36 | 'lr_sched': {
37 | 'type': 'inverse',
38 | 'inv_gamma': 20000.,
39 | 'power': 1.,
40 | 'warmup': 0.99,
41 | },
42 | 'ema_sched': {
43 | 'type': 'inverse',
44 | 'power': 0.6667,
45 | 'max_value': 0.9999
46 | },
47 | }
48 | config = json.load(file)
49 | return merge(defaults, config)
50 |
51 |
52 | def make_model(config):
53 | config = config['model']
54 | assert config['type'] == 'image_v1'
55 | model = models.ImageDenoiserModelV1(
56 | config['input_channels'],
57 | config['mapping_out'],
58 | config['depths'],
59 | config['channels'],
60 | config['self_attn_depths'],
61 | config['cross_attn_depths'],
62 | patch_size=config['patch_size'],
63 | dropout_rate=config['dropout_rate'],
64 | mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
65 | unet_cond_dim=config['unet_cond_dim'],
66 | cross_cond_dim=config['cross_cond_dim'],
67 | skip_stages=config['skip_stages'],
68 | has_variance=config['has_variance'],
69 | )
70 | if config['augment_wrapper']:
71 | model = augmentation.KarrasAugmentWrapper(model)
72 | return model
73 |
74 |
75 | def make_denoiser_wrapper(config):
76 | config = config['model']
77 | sigma_data = config.get('sigma_data', 1.)
78 | has_variance = config.get('has_variance', False)
79 | if not has_variance:
80 | return partial(layers.Denoiser, sigma_data=sigma_data)
81 | return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
82 |
83 |
84 | def make_sample_density(config):
85 | sd_config = config['sigma_sample_density']
86 | sigma_data = config['sigma_data']
87 | if sd_config['type'] == 'lognormal':
88 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
89 | scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
90 | return partial(utils.rand_log_normal, loc=loc, scale=scale)
91 | if sd_config['type'] == 'loglogistic':
92 | loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
93 | scale = sd_config['scale'] if 'scale' in sd_config else 0.5
94 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
95 | max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
96 | return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
97 | if sd_config['type'] == 'loguniform':
98 | min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
99 | max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
100 | return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
101 | if sd_config['type'] == 'v-diffusion':
102 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
103 | max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
104 | return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
105 | if sd_config['type'] == 'split-lognormal':
106 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
107 | scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
108 | scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
109 | return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
110 | raise ValueError('Unknown sample density type')
111 |
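A sketch of a config dict that make_model() and make_denoiser_wrapper() accept; in normal use load_config() merges a JSON file over the defaults above, so only the non-default fields would be spelled out. Field values here are illustrative:

    from k_diffusion import config as K_config   # note: importing k_diffusion pulls in clip, cleanfid, jsonmerge, etc.

    cfg = {'model': {
        'type': 'image_v1', 'input_channels': 3, 'mapping_out': 256,
        'depths': [2, 2, 4], 'channels': [128, 256, 512],
        'self_attn_depths': [False, False, True], 'cross_attn_depths': None,
        'patch_size': 1, 'dropout_rate': 0.0, 'augment_wrapper': True,
        'mapping_cond_dim': 0, 'unet_cond_dim': 0, 'cross_cond_dim': 0,
        'skip_stages': 0, 'has_variance': False, 'sigma_data': 1.0,
    }}
    model = K_config.make_model(cfg)                        # ImageDenoiserModelV1 inside a KarrasAugmentWrapper
    denoiser = K_config.make_denoiser_wrapper(cfg)(model)   # layers.Denoiser wrapped around the model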
--------------------------------------------------------------------------------
/src/xtra/k_diffusion/evaluation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from pathlib import Path
4 |
5 | from cleanfid.inception_torchscript import InceptionV3W
6 | import clip
7 | from resize_right import resize
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F
11 | from torchvision import transforms
12 | from tqdm.auto import trange
13 |
14 | from . import utils
15 |
16 |
17 | class InceptionV3FeatureExtractor(nn.Module):
18 | def __init__(self, device='cpu'):
19 | super().__init__()
20 | path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
21 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
22 | digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
23 | utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
24 | self.model = InceptionV3W(str(path), resize_inside=False).to(device)
25 | self.size = (299, 299)
26 |
27 | def forward(self, x):
28 | if x.shape[2:4] != self.size:
29 | x = resize(x, out_shape=self.size, pad_mode='reflect')
30 | if x.shape[1] == 1:
31 | x = torch.cat([x] * 3, dim=1)
32 | x = (x * 127.5 + 127.5).clamp(0, 255)
33 | return self.model(x)
34 |
35 |
36 | class CLIPFeatureExtractor(nn.Module):
37 | def __init__(self, name='ViT-L/14@336px', device='cpu'):
38 | super().__init__()
39 | self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
40 | self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
41 | std=(0.26862954, 0.26130258, 0.27577711))
42 | self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
43 |
44 | def forward(self, x):
45 | if x.shape[2:4] != self.size:
46 | x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
47 | x = self.normalize(x)
48 | x = self.model.encode_image(x).float()
49 | x = F.normalize(x) * x.shape[1] ** 0.5
50 | return x
51 |
52 |
53 | def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
54 | n_per_proc = math.ceil(n / accelerator.num_processes)
55 | feats_all = []
56 | try:
57 | for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
58 | cur_batch_size = min(n - i, batch_size)
59 | samples = sample_fn(cur_batch_size)[:cur_batch_size]
60 | feats_all.append(accelerator.gather(extractor_fn(samples)))
61 | except StopIteration:
62 | pass
63 | return torch.cat(feats_all)[:n]
64 |
65 |
66 | def polynomial_kernel(x, y):
67 | d = x.shape[-1]
68 | dot = x @ y.transpose(-2, -1)
69 | return (dot / d + 1) ** 3
70 |
71 |
72 | def squared_mmd(x, y, kernel=polynomial_kernel):
73 | m = x.shape[-2]
74 | n = y.shape[-2]
75 | kxx = kernel(x, x)
76 | kyy = kernel(y, y)
77 | kxy = kernel(x, y)
78 | kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
79 | kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
80 | kxy_sum = kxy.sum([-1, -2])
81 | term_1 = kxx_sum / m / (m - 1)
82 | term_2 = kyy_sum / n / (n - 1)
83 | term_3 = kxy_sum * 2 / m / n
84 | return term_1 + term_2 - term_3
85 |
86 |
87 | @utils.tf32_mode(matmul=False)
88 | def kid(x, y, max_size=5000):
89 | x_size, y_size = x.shape[0], y.shape[0]
90 | n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
91 | total_mmd = x.new_zeros([])
92 | for i in range(n_partitions):
93 | cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
94 | cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
95 | total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
96 | return total_mmd / n_partitions
97 |
98 |
99 | class _MatrixSquareRootEig(torch.autograd.Function):
100 | @staticmethod
101 | def forward(ctx, a):
102 | vals, vecs = torch.linalg.eigh(a)
103 | ctx.save_for_backward(vals, vecs)
104 | return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
105 |
106 | @staticmethod
107 | def backward(ctx, grad_output):
108 | vals, vecs = ctx.saved_tensors
109 | d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
110 | vecs_t = vecs.transpose(-2, -1)
111 | return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
112 |
113 |
114 | def sqrtm_eig(a):
115 | if a.ndim < 2:
116 | raise RuntimeError('tensor of matrices must have at least 2 dimensions')
117 | if a.shape[-2] != a.shape[-1]:
118 | raise RuntimeError('tensor must be batches of square matrices')
119 | return _MatrixSquareRootEig.apply(a)
120 |
121 |
122 | @utils.tf32_mode(matmul=False)
123 | def fid(x, y, eps=1e-8):
124 | x_mean = x.mean(dim=0)
125 | y_mean = y.mean(dim=0)
126 | mean_term = (x_mean - y_mean).pow(2).sum()
127 | x_cov = torch.cov(x.T)
128 | y_cov = torch.cov(y.T)
129 | eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
130 | x_cov = x_cov + eps_eye
131 | y_cov = y_cov + eps_eye
132 | x_cov_sqrt = sqrtm_eig(x_cov)
133 | cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
134 | return mean_term + cov_term
135 |
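kid() and fid() operate on feature matrices with one row per sample; a toy call with random features, purely to show the call shapes (and assuming the module's cleanfid/clip/resize_right dependencies are installed):

    import torch
    from k_diffusion.evaluation import kid, fid

    feats_real = torch.randn(2048, 768)
    feats_fake = torch.randn(2048, 768)
    print('KID:', kid(feats_real, feats_fake).item())
    print('FID:', fid(feats_real, feats_fake).item())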
--------------------------------------------------------------------------------
/src/xtra/k_diffusion/external.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from . import sampling, utils
7 |
8 |
9 | class VDenoiser(nn.Module):
10 | """A v-diffusion-pytorch model wrapper for k-diffusion."""
11 |
12 | def __init__(self, inner_model):
13 | super().__init__()
14 | self.inner_model = inner_model
15 | self.sigma_data = 1.
16 |
17 | def get_scalings(self, sigma):
18 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
19 | c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
20 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
21 | return c_skip, c_out, c_in
22 |
23 | def sigma_to_t(self, sigma):
24 | return sigma.atan() / math.pi * 2
25 |
26 | def t_to_sigma(self, t):
27 | return (t * math.pi / 2).tan()
28 |
29 | def loss(self, input, noise, sigma, **kwargs):
30 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
31 | noised_input = input + noise * utils.append_dims(sigma, input.ndim)
32 | model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
33 | target = (input - c_skip * noised_input) / c_out
34 | return (model_output - target).pow(2).flatten(1).mean(1)
35 |
36 | def forward(self, input, sigma, **kwargs):
37 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
38 | return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
39 |
40 |
41 | class DiscreteSchedule(nn.Module):
42 | """A mapping between continuous noise levels (sigmas) and a list of discrete noise
43 | levels."""
44 |
45 | def __init__(self, sigmas, quantize):
46 | super().__init__()
47 | self.register_buffer('sigmas', sigmas)
48 | self.register_buffer('log_sigmas', sigmas.log())
49 | self.quantize = quantize
50 |
51 | @property
52 | def sigma_min(self):
53 | return self.sigmas[0]
54 |
55 | @property
56 | def sigma_max(self):
57 | return self.sigmas[-1]
58 |
59 | def get_sigmas(self, n=None):
60 | if n is None:
61 | return sampling.append_zero(self.sigmas.flip(0))
62 | t_max = len(self.sigmas) - 1
63 | t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
64 | return sampling.append_zero(self.t_to_sigma(t))
65 |
66 | def sigma_to_t(self, sigma, quantize=None):
67 | quantize = self.quantize if quantize is None else quantize
68 | log_sigma = sigma.log()
69 | dists = log_sigma - self.log_sigmas[:, None]
70 | if quantize:
71 | return dists.abs().argmin(dim=0).view(sigma.shape)
72 | low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
73 | high_idx = low_idx + 1
74 | low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
75 | w = (low - log_sigma) / (low - high)
76 | w = w.clamp(0, 1)
77 | t = (1 - w) * low_idx + w * high_idx
78 | return t.view(sigma.shape)
79 |
80 | def t_to_sigma(self, t):
81 | t = t.float()
82 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
83 | log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
84 | return log_sigma.exp()
85 |
86 |
87 | class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
88 | """A wrapper for discrete schedule DDPM models that output eps (the predicted
89 | noise)."""
90 |
91 | def __init__(self, model, alphas_cumprod, quantize):
92 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
93 | self.inner_model = model
94 | self.sigma_data = 1.
95 |
96 | def get_scalings(self, sigma):
97 | c_out = -sigma
98 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
99 | return c_out, c_in
100 |
101 | def get_eps(self, *args, **kwargs):
102 | return self.inner_model(*args, **kwargs)
103 |
104 | def loss(self, input, noise, sigma, **kwargs):
105 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
106 | noised_input = input + noise * utils.append_dims(sigma, input.ndim)
107 | eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
108 | return (eps - noise).pow(2).flatten(1).mean(1)
109 |
110 | def forward(self, input, sigma, **kwargs):
111 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
112 | eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
113 | return input + eps * c_out
114 |
115 |
116 | class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
117 | """A wrapper for OpenAI diffusion models."""
118 |
119 | def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
120 | alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
121 | super().__init__(model, alphas_cumprod, quantize=quantize)
122 | self.has_learned_sigmas = has_learned_sigmas
123 |
124 | def get_eps(self, *args, **kwargs):
125 | model_output = self.inner_model(*args, **kwargs)
126 | if self.has_learned_sigmas:
127 | return model_output.chunk(2, dim=1)[0]
128 | return model_output
129 |
130 |
131 | class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
132 | """A wrapper for CompVis diffusion models."""
133 |
134 | def __init__(self, model, quantize=False, device='cpu'):
135 | super().__init__(model, model.alphas_cumprod, quantize=quantize)
136 |
137 | def get_eps(self, *args, **kwargs):
138 | return self.inner_model.apply_model(*args, **kwargs)
139 |
140 |
141 | class DiscreteVDDPMDenoiser(DiscreteSchedule):
142 | """A wrapper for discrete schedule DDPM models that output v."""
143 |
144 | def __init__(self, model, alphas_cumprod, quantize):
145 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
146 | self.inner_model = model
147 | self.sigma_data = 1.
148 |
149 | def get_scalings(self, sigma):
150 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
151 | c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
152 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
153 | return c_skip, c_out, c_in
154 |
155 | def get_v(self, *args, **kwargs):
156 | return self.inner_model(*args, **kwargs)
157 |
158 | def loss(self, input, noise, sigma, **kwargs):
159 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
160 | noised_input = input + noise * utils.append_dims(sigma, input.ndim)
161 | model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
162 | target = (input - c_skip * noised_input) / c_out
163 | return (model_output - target).pow(2).flatten(1).mean(1)
164 |
165 | def forward(self, input, sigma, **kwargs):
166 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
167 | return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
168 |
169 |
170 | class CompVisVDenoiser(DiscreteVDDPMDenoiser):
171 | """A wrapper for CompVis diffusion models that output v."""
172 |
173 | def __init__(self, model, quantize=False, device='cpu'):
174 | super().__init__(model, model.alphas_cumprod, quantize=quantize)
175 |
176 | def get_v(self, x, t, cond, **kwargs):
177 | return self.inner_model.apply_model(x, t, cond)
178 |
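A hedged sketch of wrapping a loaded CompVis model for k-diffusion sampling; ldm_model and cond are assumed to exist already (a LatentDiffusion instance exposing .alphas_cumprod and .apply_model, plus its conditioning):

    import torch
    from k_diffusion import external, sampling

    denoiser = external.CompVisDenoiser(ldm_model)
    sigmas = denoiser.get_sigmas(50)        # 51 values: 50 noise levels from high to low, then a final 0
    x = torch.randn(1, 4, 64, 64, device=sigmas.device) * sigmas[0]
    # samples = sampling.sample_euler_ancestral(denoiser, x, sigmas, extra_args={'cond': cond})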
--------------------------------------------------------------------------------
/src/xtra/k_diffusion/gns.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class DDPGradientStatsHook:
6 | def __init__(self, ddp_module):
7 | try:
8 | ddp_module.register_comm_hook(self, self._hook_fn)
9 | except AttributeError:
10 | raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
11 | self._clear_state()
12 |
13 | def _clear_state(self):
14 | self.bucket_sq_norms_small_batch = []
15 | self.bucket_sq_norms_large_batch = []
16 |
17 | @staticmethod
18 | def _hook_fn(self, bucket):
19 | buf = bucket.buffer()
20 | self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
21 | fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
22 | def callback(fut):
23 | buf = fut.value()[0]
24 | self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
25 | return buf
26 | return fut.then(callback)
27 |
28 | def get_stats(self):
29 | sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
30 | sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
31 | self._clear_state()
32 | stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
33 | torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
34 | return stats[0].item(), stats[1].item()
35 |
36 |
37 | class GradientNoiseScale:
38 | """Calculates the gradient noise scale (1 / SNR), or critical batch size,
39 | from _An Empirical Model of Large-Batch Training_,
40 |     https://arxiv.org/abs/1812.06162.
41 |
42 | Args:
43 | beta (float): The decay factor for the exponential moving averages used to
44 | calculate the gradient noise scale.
45 | Default: 0.9998
46 | eps (float): Added for numerical stability.
47 | Default: 1e-8
48 | """
49 |
50 | def __init__(self, beta=0.9998, eps=1e-8):
51 | self.beta = beta
52 | self.eps = eps
53 | self.ema_sq_norm = 0.
54 | self.ema_var = 0.
55 | self.beta_cumprod = 1.
56 | self.gradient_noise_scale = float('nan')
57 |
58 | def state_dict(self):
59 | """Returns the state of the object as a :class:`dict`."""
60 | return dict(self.__dict__.items())
61 |
62 | def load_state_dict(self, state_dict):
63 | """Loads the object's state.
64 | Args:
65 | state_dict (dict): object state. Should be an object returned
66 | from a call to :meth:`state_dict`.
67 | """
68 | self.__dict__.update(state_dict)
69 |
70 | def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
71 | """Updates the state with a new batch's gradient statistics, and returns the
72 | current gradient noise scale.
73 |
74 | Args:
75 | sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
76 | per sample gradients.
77 | sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
78 | per sample gradients.
79 | n_small_batch (int): The batch size of the individual microbatch or per sample
80 | gradients (1 if per sample).
81 | n_large_batch (int): The total batch size of the mean of the microbatch or
82 | per sample gradients.
83 | """
84 | est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
85 | est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
86 | self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
87 | self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
88 | self.beta_cumprod *= self.beta
89 | self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
90 | return self.gradient_noise_scale
91 |
92 | def get_gns(self):
93 | """Returns the current gradient noise scale."""
94 | return self.gradient_noise_scale
95 |
96 | def get_stats(self):
97 | """Returns the current (debiased) estimates of the squared mean gradient
98 | and gradient variance."""
99 | return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
100 |
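A numeric sketch of the update rule (the squared-norm inputs are made-up numbers; in training they come from DDPGradientStatsHook.get_stats()):

    from k_diffusion.gns import GradientNoiseScale

    gns = GradientNoiseScale(beta=0.9998)
    sq_norm_small, sq_norm_large = 2.5, 1.0      # mean ||g_micro||^2 and ||mean g||^2
    scale = gns.update(sq_norm_small, sq_norm_large, n_small_batch=8, n_large_batch=64)
    print(scale, gns.get_stats())                # noise-scale estimate and debiased (sq_norm, variance)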
--------------------------------------------------------------------------------
/src/xtra/k_diffusion/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_v1 import ImageDenoiserModelV1
2 |
--------------------------------------------------------------------------------
/src/xtra/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config
2 | from .loss import ClipLoss
3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model
4 | from .openai import load_openai_model, list_openai_models
5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
6 | get_pretrained_url, download_pretrained
7 | from .tokenizer import SimpleTokenizer, tokenize
8 | from .transform import image_transform
9 |
--------------------------------------------------------------------------------
/src/xtra/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/xtra/open_clip/factory.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | from copy import deepcopy
7 | from pathlib import Path
8 | from typing import Optional, Tuple
9 |
10 | import torch
11 |
12 | from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
13 | from .openai import load_openai_model
14 | from .pretrained import get_pretrained_url, download_pretrained
15 | from .transform import image_transform
16 |
17 |
18 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
19 | _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
20 |
21 |
22 | def _natural_key(string_):
23 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
24 |
25 |
26 | def _rescan_model_configs():
27 | global _MODEL_CONFIGS
28 |
29 | config_ext = ('.json',)
30 | config_files = []
31 | for config_path in _MODEL_CONFIG_PATHS:
32 | if config_path.is_file() and config_path.suffix in config_ext:
33 | config_files.append(config_path)
34 | elif config_path.is_dir():
35 | for ext in config_ext:
36 | config_files.extend(config_path.glob(f'*{ext}'))
37 |
38 | for cf in config_files:
39 | with open(cf, 'r') as f:
40 | model_cfg = json.load(f)
41 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
42 | _MODEL_CONFIGS[cf.stem] = model_cfg
43 |
44 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
45 |
46 |
47 | _rescan_model_configs() # initial populate of model config registry
48 |
49 |
50 | def load_state_dict(checkpoint_path: str, map_location='cpu'):
51 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
52 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
53 | state_dict = checkpoint['state_dict']
54 | else:
55 | state_dict = checkpoint
56 | if next(iter(state_dict.items()))[0].startswith('module'):
57 | state_dict = {k[7:]: v for k, v in state_dict.items()}
58 | return state_dict
59 |
60 |
61 | def load_checkpoint(model, checkpoint_path, strict=True):
62 | state_dict = load_state_dict(checkpoint_path)
63 | resize_pos_embed(state_dict, model)
64 | incompatible_keys = model.load_state_dict(state_dict, strict=strict)
65 | return incompatible_keys
66 |
67 |
68 | def create_model(
69 | model_name: str,
70 | pretrained: str = '',
71 | precision: str = 'fp32',
72 | device: torch.device = torch.device('cpu'),
73 | jit: bool = False,
74 | force_quick_gelu: bool = False,
75 | pretrained_image: bool = False,
76 | ):
77 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
78 |
79 | if pretrained.lower() == 'openai':
80 | logging.info(f'Loading pretrained {model_name} from OpenAI.')
81 | model = load_openai_model(model_name, device=device, jit=jit)
82 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
83 | if precision == "amp" or precision == "fp32":
84 | model = model.float()
85 | else:
86 | if model_name in _MODEL_CONFIGS:
87 | logging.info(f'Loading {model_name} model config.')
88 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
89 | else:
90 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
91 | raise RuntimeError(f'Model config for {model_name} not found.')
92 |
93 | if force_quick_gelu:
94 | # override for use of QuickGELU on non-OpenAI transformer models
95 | model_cfg["quick_gelu"] = True
96 |
97 | if pretrained_image:
98 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
99 | # pretrained weight loading for timm models set via vision_cfg
100 | model_cfg['vision_cfg']['timm_model_pretrained'] = True
101 | else:
102 | assert False, 'pretrained image towers currently only supported for timm models'
103 |
104 | model = CLIP(**model_cfg)
105 |
106 | if pretrained:
107 | checkpoint_path = ''
108 | url = get_pretrained_url(model_name, pretrained)
109 | if url:
110 | checkpoint_path = download_pretrained(url)
111 | elif os.path.exists(pretrained):
112 | checkpoint_path = pretrained
113 |
114 | if checkpoint_path:
115 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
116 | load_checkpoint(model, checkpoint_path)
117 | else:
118 | logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
119 | raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
120 |
121 | model.to(device=device)
122 | if precision == "fp16":
123 | assert device.type != 'cpu'
124 | convert_weights_to_fp16(model)
125 |
126 | if jit:
127 | model = torch.jit.script(model)
128 |
129 | return model
130 |
131 |
132 | def create_model_and_transforms(
133 | model_name: str,
134 | pretrained: str = '',
135 | precision: str = 'fp32',
136 | device: torch.device = torch.device('cpu'),
137 | jit: bool = False,
138 | force_quick_gelu: bool = False,
139 | pretrained_image: bool = False,
140 | mean: Optional[Tuple[float, ...]] = None,
141 | std: Optional[Tuple[float, ...]] = None,
142 | ):
143 | model = create_model(
144 | model_name, pretrained, precision, device, jit,
145 | force_quick_gelu=force_quick_gelu,
146 | pretrained_image=pretrained_image)
147 | preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std)
148 | preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std)
149 | return model, preprocess_train, preprocess_val
150 |
151 |
152 | def list_models():
153 | """ enumerate available model architectures based on config files """
154 | return list(_MODEL_CONFIGS.keys())
155 |
156 |
157 | def add_model_config(path):
158 | """ add model config path or file and update registry """
159 | if not isinstance(path, Path):
160 | path = Path(path)
161 | _MODEL_CONFIG_PATHS.append(path)
162 | _rescan_model_configs()
163 |
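As a usage note for create_model_and_transforms above, a minimal sketch follows; it assumes src/xtra is on sys.path and builds a randomly initialized ViT-B-32 (pass a tag such as 'openai' as `pretrained` to download weights instead). The prompt text is illustrative.

    import torch
    from open_clip import create_model_and_transforms, tokenize

    model, preprocess_train, preprocess_val = create_model_and_transforms(
        'ViT-B-32', pretrained='', device=torch.device('cpu'))
    with torch.no_grad():
        text_features = model.encode_text(tokenize(["a photo of a cat"]))
    print(text_features.shape)  # torch.Size([1, 512]) -- embed_dim of the ViT-B-32 config

preprocess_train and preprocess_val are the torchvision pipelines built by image_transform for training and evaluation respectively.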
--------------------------------------------------------------------------------
/src/xtra/open_clip/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | try:
6 | import torch.distributed.nn
7 | from torch import distributed as dist
8 | has_distributed = True
9 | except ImportError:
10 | has_distributed = False
11 |
12 | try:
13 | import horovod.torch as hvd
14 | except ImportError:
15 | hvd = None
16 |
17 |
18 | def gather_features(
19 | image_features,
20 | text_features,
21 | local_loss=False,
22 | gather_with_grad=False,
23 | rank=0,
24 | world_size=1,
25 | use_horovod=False
26 | ):
27 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
28 | if use_horovod:
29 | assert hvd is not None, 'Please install horovod'
30 | if gather_with_grad:
31 | all_image_features = hvd.allgather(image_features)
32 | all_text_features = hvd.allgather(text_features)
33 | else:
34 | with torch.no_grad():
35 | all_image_features = hvd.allgather(image_features)
36 | all_text_features = hvd.allgather(text_features)
37 | if not local_loss:
38 | # ensure grads for local rank when all_* features don't have a gradient
39 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
40 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
41 | gathered_image_features[rank] = image_features
42 | gathered_text_features[rank] = text_features
43 | all_image_features = torch.cat(gathered_image_features, dim=0)
44 | all_text_features = torch.cat(gathered_text_features, dim=0)
45 | else:
46 | # We gather tensors from all gpus
47 | if gather_with_grad:
48 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
49 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
50 | else:
51 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
52 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
53 | dist.all_gather(gathered_image_features, image_features)
54 | dist.all_gather(gathered_text_features, text_features)
55 | if not local_loss:
56 | # ensure grads for local rank when all_* features don't have a gradient
57 | gathered_image_features[rank] = image_features
58 | gathered_text_features[rank] = text_features
59 | all_image_features = torch.cat(gathered_image_features, dim=0)
60 | all_text_features = torch.cat(gathered_text_features, dim=0)
61 |
62 | return all_image_features, all_text_features
63 |
64 |
65 | class ClipLoss(nn.Module):
66 |
67 | def __init__(
68 | self,
69 | local_loss=False,
70 | gather_with_grad=False,
71 | cache_labels=False,
72 | rank=0,
73 | world_size=1,
74 | use_horovod=False,
75 | ):
76 | super().__init__()
77 | self.local_loss = local_loss
78 | self.gather_with_grad = gather_with_grad
79 | self.cache_labels = cache_labels
80 | self.rank = rank
81 | self.world_size = world_size
82 | self.use_horovod = use_horovod
83 |
84 | # cache state
85 | self.prev_num_logits = 0
86 | self.labels = {}
87 |
88 | def forward(self, image_features, text_features, logit_scale):
89 | device = image_features.device
90 | if self.world_size > 1:
91 | all_image_features, all_text_features = gather_features(
92 | image_features, text_features,
93 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
94 |
95 | if self.local_loss:
96 | logits_per_image = logit_scale * image_features @ all_text_features.T
97 | logits_per_text = logit_scale * text_features @ all_image_features.T
98 | else:
99 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
100 | logits_per_text = logits_per_image.T
101 | else:
102 | logits_per_image = logit_scale * image_features @ text_features.T
103 | logits_per_text = logit_scale * text_features @ image_features.T
104 |
105 | # calculate ground-truth labels and cache them if enabled
106 | num_logits = logits_per_image.shape[0]
107 | if self.prev_num_logits != num_logits or device not in self.labels:
108 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
109 | if self.world_size > 1 and self.local_loss:
110 | labels = labels + num_logits * self.rank
111 | if self.cache_labels:
112 | self.labels[device] = labels
113 | self.prev_num_logits = num_logits
114 | else:
115 | labels = self.labels[device]
116 |
117 | total_loss = (
118 | F.cross_entropy(logits_per_image, labels) +
119 | F.cross_entropy(logits_per_text, labels)
120 | ) / 2
121 | return total_loss
122 |
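A minimal single-process sketch of ClipLoss above (world_size=1, so no gather is involved); the feature tensors are random stand-ins for encoder outputs:

    import torch
    import torch.nn.functional as F
    from open_clip.loss import ClipLoss

    loss_fn = ClipLoss()  # defaults: local_loss=False, rank=0, world_size=1
    image_features = F.normalize(torch.randn(4, 512), dim=-1)
    text_features = F.normalize(torch.randn(4, 512), dim=-1)
    logit_scale = torch.tensor(100.0)
    print(loss_fn(image_features, text_features, logit_scale))  # symmetric cross-entropy over the 4x4 similarity matrix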
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-B-16-plus-240.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 240,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-B-16-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-B-32-plus-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 256,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-H-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 16
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-L-14-280.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 280,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-L-16-320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 320,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-L-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/ViT-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1024,
15 | "heads": 16,
16 | "layers": 24
17 | }
18 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-efficientnetv2_rw_s.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "efficientnetv2_rw_s",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 288
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-resnet50d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnet50d",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-resnetaa50d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnetaa50d",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-resnetblur50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnetblur50",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-swin_base_patch4_window7_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "swin_base_patch4_window7_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-vit_base_patch16_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_base_patch16_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-vit_base_patch32_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_base_patch32_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/model_configs/timm-vit_small_patch16_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_small_patch16_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/xtra/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import Union, List
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict
13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_tag_models('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
26 | jit=True,
27 | ):
28 | """Load a CLIP model
29 |
30 | Parameters
31 | ----------
32 | name : str
33 | A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict
34 | device : Union[str, torch.device]
35 | The device to put the loaded model
36 | jit : bool
37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
38 |
39 | Returns
40 | -------
41 | model : torch.nn.Module
42 | The CLIP model
45 | """
46 | if get_pretrained_url(name, 'openai'):
47 | model_path = download_pretrained(get_pretrained_url(name, 'openai'))
48 | elif os.path.isfile(name):
49 | model_path = name
50 | else:
51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
52 |
53 | try:
54 | # loading JIT archive
55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
56 | state_dict = None
57 | except RuntimeError:
58 | # loading saved state dict
59 | if jit:
60 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
61 | jit = False
62 | state_dict = torch.load(model_path, map_location="cpu")
63 |
64 | if not jit:
65 | try:
66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device)
67 | except KeyError:
68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
69 | model = build_model_from_openai_state_dict(sd).to(device)
70 |
71 | if str(device) == "cpu":
72 | model.float()
73 | return model
74 |
75 | # patch the device names
76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
78 |
79 | def patch_device(module):
80 | try:
81 | graphs = [module.graph] if hasattr(module, "graph") else []
82 | except RuntimeError:
83 | graphs = []
84 |
85 | if hasattr(module, "forward1"):
86 | graphs.append(module.forward1.graph)
87 |
88 | for graph in graphs:
89 | for node in graph.findAllNodes("prim::Constant"):
90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
91 | node.copyAttributes(device_node)
92 |
93 | model.apply(patch_device)
94 | patch_device(model.encode_image)
95 | patch_device(model.encode_text)
96 |
97 | # patch dtype to float32 on CPU
98 | if str(device) == "cpu":
99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
101 | float_node = float_input.node()
102 |
103 | def patch_float(module):
104 | try:
105 | graphs = [module.graph] if hasattr(module, "graph") else []
106 | except RuntimeError:
107 | graphs = []
108 |
109 | if hasattr(module, "forward1"):
110 | graphs.append(module.forward1.graph)
111 |
112 | for graph in graphs:
113 | for node in graph.findAllNodes("aten::to"):
114 | inputs = list(node.inputs())
115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
116 | if inputs[i].node()["value"] == 5:
117 | inputs[i].node().copyAttributes(float_node)
118 |
119 | model.apply(patch_float)
120 | patch_float(model.encode_image)
121 | patch_float(model.encode_text)
122 | model.float()
123 |
124 | # ensure image_size attr available at consistent location for both jit and non-jit
125 | model.visual.image_size = model.input_resolution.item()
126 | return model
127 |
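For reference, a small sketch of the helpers above (assuming src/xtra is importable); listing models needs no download, while load_openai_model fetches the checkpoint on first use, so it is left commented out:

    from open_clip.openai import list_openai_models, load_openai_model

    print(list_openai_models())  # e.g. ['RN50', 'RN101', ..., 'ViT-B-32', 'ViT-L-14']
    # model = load_openai_model('ViT-B-32', device='cpu', jit=False)  # downloads the OpenAI checkpoint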
--------------------------------------------------------------------------------
/src/xtra/open_clip/src.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/open_clip/src.zip
--------------------------------------------------------------------------------
/src/xtra/open_clip/test_simple.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from PIL import Image
4 | from open_clip import tokenizer
5 | import open_clip
6 | import os
7 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
8 |
9 | def test_inference():
10 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
11 |
12 | current_dir = os.path.dirname(os.path.realpath(__file__))
13 |
14 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0)
15 | text = tokenizer.tokenize(["a diagram", "a dog", "a cat"])
16 |
17 | with torch.no_grad():
18 | image_features = model.encode_image(image)
19 | text_features = model.encode_text(text)
20 |
21 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
22 |
23 | assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0]
--------------------------------------------------------------------------------
/src/xtra/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch.nn as nn
8 |
9 | try:
10 | import timm
11 | from timm.models.layers import Mlp, to_2tuple
12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
14 | except ImportError as e:
15 | timm = None
16 |
17 | from .utils import freeze_batch_norm_2d
18 |
19 |
20 | class TimmModel(nn.Module):
21 | """ timm model adapter
22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
23 | """
24 |
25 | def __init__(
26 | self,
27 | model_name,
28 | embed_dim,
29 | image_size=224,
30 | pool='avg',
31 | proj='linear',
32 | drop=0.,
33 | pretrained=False):
34 | super().__init__()
35 | if timm is None:
36 | raise RuntimeError("Please `pip install timm` to use timm models.")
37 |
38 | self.image_size = to_2tuple(image_size)
39 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
40 | feat_size = self.trunk.default_cfg.get('pool_size', None)
41 | feature_ndim = 1 if not feat_size else 2
42 | if pool in ('abs_attn', 'rot_attn'):
43 | assert feature_ndim == 2
44 | # if attn pooling used, remove both classifier and default pool
45 | self.trunk.reset_classifier(0, global_pool='')
46 | else:
47 | # reset global pool if pool config set, otherwise leave as network default
48 | reset_kwargs = dict(global_pool=pool) if pool else {}
49 | self.trunk.reset_classifier(0, **reset_kwargs)
50 | prev_chs = self.trunk.num_features
51 |
52 | head_layers = OrderedDict()
53 | if pool == 'abs_attn':
54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
55 | prev_chs = embed_dim
56 | elif pool == 'rot_attn':
57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
58 | prev_chs = embed_dim
59 | else:
60 | assert proj, 'projection layer needed if non-attention pooling is used.'
61 |
62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
63 | if proj == 'linear':
64 | head_layers['drop'] = nn.Dropout(drop)
65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim)
66 | elif proj == 'mlp':
67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
68 |
69 | self.head = nn.Sequential(head_layers)
70 |
71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
72 | """ lock modules
73 | Args:
74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
75 | """
76 | if not unlocked_groups:
77 | # lock full model
78 | for param in self.trunk.parameters():
79 | param.requires_grad = False
80 | if freeze_bn_stats:
81 | freeze_batch_norm_2d(self.trunk)
82 | else:
83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
84 | try:
85 | # FIXME import here until API stable and in an official release
86 | from timm.models.helpers import group_parameters, group_modules
87 | except ImportError:
88 | raise RuntimeError(
89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
90 | matcher = self.trunk.group_matcher()
91 | gparams = group_parameters(self.trunk, matcher)
92 | max_layer_id = max(gparams.keys())
93 | max_layer_id = max_layer_id - unlocked_groups
94 | for group_idx in range(max_layer_id + 1):
95 | group = gparams[group_idx]
96 | for param in group:
97 | self.trunk.get_parameter(param).requires_grad = False
98 | if freeze_bn_stats:
99 | gmodules = group_modules(self.trunk, matcher, reverse=True)
100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
101 | freeze_batch_norm_2d(self.trunk, gmodules)
102 |
103 | def forward(self, x):
104 | x = self.trunk(x)
105 | x = self.head(x)
106 | return x
107 |
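A minimal sketch of TimmModel above as a stand-alone vision tower; it assumes a reasonably recent `timm` is installed (otherwise the constructor raises RuntimeError) and uses an untrained resnet50d backbone:

    import torch
    from open_clip.timm_model import TimmModel

    tower = TimmModel('resnet50d', embed_dim=1024, image_size=224, pool='avg', proj='linear')
    feats = tower(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # torch.Size([1, 1024]) -- backbone features projected to embed_dim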
--------------------------------------------------------------------------------
/src/xtra/open_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 |
16 | @lru_cache()
17 | def default_bpe():
18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
19 |
20 |
21 | @lru_cache()
22 | def bytes_to_unicode():
23 | """
24 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
25 | The reversible bpe codes work on unicode strings.
26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
28 | This is a significant percentage of your normal, say, 32K bpe vocab.
29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
30 | It also avoids mapping to whitespace/control characters that the bpe code barfs on.
31 | """
32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
33 | cs = bs[:]
34 | n = 0
35 | for b in range(2**8):
36 | if b not in bs:
37 | bs.append(b)
38 | cs.append(2**8+n)
39 | n += 1
40 | cs = [chr(n) for n in cs]
41 | return dict(zip(bs, cs))
42 |
43 |
44 | def get_pairs(word):
45 | """Return set of symbol pairs in a word.
46 | Word is represented as tuple of symbols (symbols being variable-length strings).
47 | """
48 | pairs = set()
49 | prev_char = word[0]
50 | for char in word[1:]:
51 | pairs.add((prev_char, char))
52 | prev_char = char
53 | return pairs
54 |
55 |
56 | def basic_clean(text):
57 | text = ftfy.fix_text(text)
58 | text = html.unescape(html.unescape(text))
59 | return text.strip()
60 |
61 |
62 | def whitespace_clean(text):
63 | text = re.sub(r'\s+', ' ', text)
64 | text = text.strip()
65 | return text
66 |
67 |
68 | class SimpleTokenizer(object):
69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
70 | self.byte_encoder = bytes_to_unicode()
71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
73 | merges = merges[1:49152-256-2+1]
74 | merges = [tuple(merge.split()) for merge in merges]
75 | vocab = list(bytes_to_unicode().values())
76 | vocab = vocab + [v+'</w>' for v in vocab]
77 | for merge in merges:
78 | vocab.append(''.join(merge))
79 | if not special_tokens:
80 | special_tokens = ['<start_of_text>', '<end_of_text>']
81 | else:
82 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
83 | vocab.extend(special_tokens)
84 | self.encoder = dict(zip(vocab, range(len(vocab))))
85 | self.decoder = {v: k for k, v in self.encoder.items()}
86 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
87 | self.cache = {t:t for t in special_tokens}
88 | special = "|".join(special_tokens)
89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
90 |
91 | self.vocab_size = len(self.encoder)
92 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
93 |
94 | def bpe(self, token):
95 | if token in self.cache:
96 | return self.cache[token]
97 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
98 | pairs = get_pairs(word)
99 |
100 | if not pairs:
101 | return token+'</w>'
102 |
103 | while True:
104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
105 | if bigram not in self.bpe_ranks:
106 | break
107 | first, second = bigram
108 | new_word = []
109 | i = 0
110 | while i < len(word):
111 | try:
112 | j = word.index(first, i)
113 | new_word.extend(word[i:j])
114 | i = j
115 | except:
116 | new_word.extend(word[i:])
117 | break
118 |
119 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
120 | new_word.append(first+second)
121 | i += 2
122 | else:
123 | new_word.append(word[i])
124 | i += 1
125 | new_word = tuple(new_word)
126 | word = new_word
127 | if len(word) == 1:
128 | break
129 | else:
130 | pairs = get_pairs(word)
131 | word = ' '.join(word)
132 | self.cache[token] = word
133 | return word
134 |
135 | def encode(self, text):
136 | bpe_tokens = []
137 | text = whitespace_clean(basic_clean(text)).lower()
138 | for token in re.findall(self.pat, text):
139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
141 | return bpe_tokens
142 |
143 | def decode(self, tokens):
144 | text = ''.join([self.decoder[token] for token in tokens])
145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
146 | return text
147 |
148 |
149 | _tokenizer = SimpleTokenizer()
150 |
151 |
152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
153 | """
154 | Returns the tokenized representation of given input string(s)
155 |
156 | Parameters
157 | ----------
158 | texts : Union[str, List[str]]
159 | An input string or a list of input strings to tokenize
160 | context_length : int
161 | The context length to use; all CLIP models use 77 as the context length
162 |
163 | Returns
164 | -------
165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
166 | """
167 | if isinstance(texts, str):
168 | texts = [texts]
169 |
170 | sot_token = _tokenizer.encoder["<start_of_text>"]
171 | eot_token = _tokenizer.encoder["<end_of_text>"]
172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
174 |
175 | for i, tokens in enumerate(all_tokens):
176 | if len(tokens) > context_length:
177 | tokens = tokens[:context_length] # Truncate
178 | tokens[-1] = eot_token
179 | result[i, :len(tokens)] = torch.tensor(tokens)
180 |
181 | return result
182 |
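A short sketch of the tokenizer above (assumes ftfy and regex are installed and the bundled bpe_simple_vocab_16e6.txt.gz sits next to the module):

    from open_clip.tokenizer import SimpleTokenizer, tokenize

    tokens = tokenize(["a diagram", "a dog"])  # LongTensor of shape (2, 77), zero-padded
    print(tokens.shape)
    tok = SimpleTokenizer()
    print(tok.decode([t for t in tokens[0].tolist() if t]))  # "<start_of_text>a diagram <end_of_text>" (padding stripped)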
--------------------------------------------------------------------------------
/src/xtra/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 |
7 |
8 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
9 | CenterCrop
10 |
11 |
12 | class ResizeMaxSize(nn.Module):
13 |
14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
15 | super().__init__()
16 | if not isinstance(max_size, int):
17 | raise TypeError(f"Size should be int. Got {type(max_size)}")
18 | self.max_size = max_size
19 | self.interpolation = interpolation
20 | self.fn = min if fn == 'min' else max
21 | self.fill = fill
22 |
23 | def forward(self, img):
24 | if isinstance(img, torch.Tensor):
25 | height, width = img.shape[:2]
26 | else:
27 | width, height = img.size
28 | scale = self.max_size / float(max(height, width))
29 | if scale != 1.0:
30 | new_size = tuple(round(dim * scale) for dim in (height, width))
31 | img = F.resize(img, new_size, self.interpolation)
32 | pad_h = self.max_size - new_size[0]
33 | pad_w = self.max_size - new_size[1]
34 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
35 | return img
36 |
37 |
38 | def _convert_to_rgb(image):
39 | return image.convert('RGB')
40 |
41 |
42 | def image_transform(
43 | image_size: int,
44 | is_train: bool,
45 | mean: Optional[Tuple[float, ...]] = None,
46 | std: Optional[Tuple[float, ...]] = None,
47 | resize_longest_max: bool = False,
48 | fill_color: int = 0,
49 | ):
50 | mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean
51 | std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std
52 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
53 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
54 | image_size = image_size[0]
55 |
56 | normalize = Normalize(mean=mean, std=std)
57 | if is_train:
58 | return Compose([
59 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
60 | _convert_to_rgb,
61 | ToTensor(),
62 | normalize,
63 | ])
64 | else:
65 | if resize_longest_max:
66 | transforms = [
67 | ResizeMaxSize(image_size, fill=fill_color)
68 | ]
69 | else:
70 | transforms = [
71 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
72 | CenterCrop(image_size),
73 | ]
74 | transforms.extend([
75 | _convert_to_rgb,
76 | ToTensor(),
77 | normalize,
78 | ])
79 | return Compose(transforms)
80 |
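A quick sketch of image_transform above in evaluation mode; the blank PIL image is just a placeholder input:

    from PIL import Image
    from open_clip.transform import image_transform

    preprocess = image_transform(224, is_train=False)  # Resize + CenterCrop + ToTensor + Normalize
    x = preprocess(Image.new('RGB', (640, 480)))
    print(x.shape)  # torch.Size([3, 224, 224])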
--------------------------------------------------------------------------------
/src/xtra/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | from torch import nn as nn
5 | from torchvision.ops.misc import FrozenBatchNorm2d
6 |
7 |
8 | def freeze_batch_norm_2d(module, module_match={}, name=''):
9 | """
10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
12 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
13 |
14 | Args:
15 | module (torch.nn.Module): Any PyTorch module.
16 | module_match (dict): Dictionary of full module names to freeze (all if empty)
17 | name (str): Full module name (prefix)
18 |
19 | Returns:
20 | torch.nn.Module: Resulting module
21 |
22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
23 | """
24 | res = module
25 | is_match = True
26 | if module_match:
27 | is_match = name in module_match
28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
29 | res = FrozenBatchNorm2d(module.num_features)
30 | res.num_features = module.num_features
31 | res.affine = module.affine
32 | if module.affine:
33 | res.weight.data = module.weight.data.clone().detach()
34 | res.bias.data = module.bias.data.clone().detach()
35 | res.running_mean.data = module.running_mean.data
36 | res.running_var.data = module.running_var.data
37 | res.eps = module.eps
38 | else:
39 | for child_name, child in module.named_children():
40 | full_child_name = '.'.join([name, child_name]) if name else child_name
41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
42 | if new_child is not child:
43 | res.add_module(child_name, new_child)
44 | return res
45 |
46 |
47 | # From PyTorch internals
48 | def _ntuple(n):
49 | def parse(x):
50 | if isinstance(x, collections.abc.Iterable):
51 | return x
52 | return tuple(repeat(x, n))
53 | return parse
54 |
55 |
56 | to_1tuple = _ntuple(1)
57 | to_2tuple = _ntuple(2)
58 | to_3tuple = _ntuple(3)
59 | to_4tuple = _ntuple(4)
60 | to_ntuple = lambda n, x: _ntuple(n)(x)
61 |
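A minimal sketch of freeze_batch_norm_2d above on a toy module; the layer sizes are arbitrary:

    import torch.nn as nn
    from open_clip.utils import freeze_batch_norm_2d

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    net = freeze_batch_norm_2d(net)
    print(net[1])  # now a FrozenBatchNorm2d carrying the original affine parameters and running stats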
--------------------------------------------------------------------------------
/src/xtra/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.3.0'
2 |
--------------------------------------------------------------------------------
/src/xtra/taming/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/data.zip
--------------------------------------------------------------------------------
/src/xtra/taming/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n):
33 | return self.schedule(n)
34 |
35 |
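A small sketch of the scheduler above, used as an LR multiplier with base_lr 1.0 as its docstring notes (e.g. via torch.optim.lr_scheduler.LambdaLR); it assumes src/xtra is on sys.path and the step counts are illustrative:

    from taming.lr_scheduler import LambdaWarmUpCosineScheduler

    sched = LambdaWarmUpCosineScheduler(warm_up_steps=100, lr_min=0.01, lr_max=1.0,
                                        lr_start=0.0, max_decay_steps=1000)
    for n in (0, 50, 100, 1000):
        print(n, sched(n))  # linear warm-up 0.0 -> 1.0 over 100 steps, then cosine decay to 0.01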
--------------------------------------------------------------------------------
/src/xtra/taming/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/models/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/models/dummy_cond_stage.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | class DummyCondStage:
5 | def __init__(self, conditional_key):
6 | self.conditional_key = conditional_key
7 | self.train = None
8 |
9 | def eval(self):
10 | return self
11 |
12 | @staticmethod
13 | def encode(c: Tensor):
14 | return c, None, (None, None, c)
15 |
16 | @staticmethod
17 | def decode(c: Tensor):
18 | return c
19 |
20 | @staticmethod
21 | def to_rgb(c: Tensor):
22 | return c
23 |
--------------------------------------------------------------------------------
/src/xtra/taming/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/modules/discriminator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/discriminator/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from taming.modules.util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | norm_layer -- normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
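A minimal sketch of the PatchGAN discriminator above (assumes src/xtra is on sys.path so `taming` imports); the 256x256 input is arbitrary:

    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
    logits = disc(torch.randn(1, 3, 256, 256))
    print(logits.shape)  # one logit per patch, e.g. torch.Size([1, 1, 30, 30])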
--------------------------------------------------------------------------------
/src/xtra/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 |
3 |
--------------------------------------------------------------------------------
/src/xtra/taming/modules/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | from taming.util import get_ckpt_path
9 |
10 |
11 | class LPIPS(nn.Module):
12 | # Learned perceptual metric
13 | def __init__(self, use_dropout=True):
14 | super().__init__()
15 | self.scaling_layer = ScalingLayer()
16 | self.chns = [64, 128, 256, 512, 512] # vgg16 features
17 | self.net = vgg16(pretrained=True, requires_grad=False)
18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23 | self.load_from_pretrained()
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def load_from_pretrained(self, name="vgg_lpips"):
28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
31 |
32 | @classmethod
33 | def from_pretrained(cls, name="vgg_lpips"):
34 | if name != "vgg_lpips":
35 | raise NotImplementedError
36 | model = cls()
37 | ckpt = get_ckpt_path(name)
38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39 | return model
40 |
41 | def forward(self, input, target):
42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
44 | feats0, feats1, diffs = {}, {}, {}
45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46 | for kk in range(len(self.chns)):
47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49 |
50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51 | val = res[0]
52 | for l in range(1, len(self.chns)):
53 | val += res[l]
54 | return val
55 |
56 |
57 | class ScalingLayer(nn.Module):
58 | def __init__(self):
59 | super(ScalingLayer, self).__init__()
60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62 |
63 | def forward(self, inp):
64 | return (inp - self.shift) / self.scale
65 |
66 |
67 | class NetLinLayer(nn.Module):
68 | """ A single linear layer which does a 1x1 conv """
69 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
70 | super(NetLinLayer, self).__init__()
71 | layers = [nn.Dropout(), ] if (use_dropout) else []
72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73 | self.model = nn.Sequential(*layers)
74 |
75 |
76 | class vgg16(torch.nn.Module):
77 | def __init__(self, requires_grad=False, pretrained=True):
78 | super(vgg16, self).__init__()
79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80 | self.slice1 = torch.nn.Sequential()
81 | self.slice2 = torch.nn.Sequential()
82 | self.slice3 = torch.nn.Sequential()
83 | self.slice4 = torch.nn.Sequential()
84 | self.slice5 = torch.nn.Sequential()
85 | self.N_slices = 5
86 | for x in range(4):
87 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
88 | for x in range(4, 9):
89 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
90 | for x in range(9, 16):
91 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
92 | for x in range(16, 23):
93 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
94 | for x in range(23, 30):
95 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
96 | if not requires_grad:
97 | for param in self.parameters():
98 | param.requires_grad = False
99 |
100 | def forward(self, X):
101 | h = self.slice1(X)
102 | h_relu1_2 = h
103 | h = self.slice2(h)
104 | h_relu2_2 = h
105 | h = self.slice3(h)
106 | h_relu3_3 = h
107 | h = self.slice4(h)
108 | h_relu4_3 = h
109 | h = self.slice5(h)
110 | h_relu5_3 = h
111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113 | return out
114 |
115 |
116 | def normalize_tensor(x,eps=1e-10):
117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118 | return x/(norm_factor+eps)
119 |
120 |
121 | def spatial_average(x, keepdim=True):
122 | return x.mean([2,3],keepdim=keepdim)
123 |
124 |
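
A minimal usage sketch of LPIPS as a perceptual distance between two image batches, assuming the taming package is importable and that the vgg_lpips checkpoint and torchvision's VGG16 weights can be downloaded on first use:

    import torch
    from taming.modules.losses.lpips import LPIPS

    lpips = LPIPS().eval()
    x = torch.rand(1, 3, 256, 256) * 2 - 1   # inputs are expected in [-1, 1]
    y = torch.rand(1, 3, 256, 256) * 2 - 1
    with torch.no_grad():
        d = lpips(x, y)                      # shape (1, 1, 1, 1); larger means perceptually more different
    print(d.item())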
--------------------------------------------------------------------------------
/src/xtra/taming/modules/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 | def forward(self, prediction, target):
7 | loss = F.binary_cross_entropy_with_logits(prediction,target)
8 | return loss, {}
9 |
10 |
11 | class BCELossWithQuant(nn.Module):
12 | def __init__(self, codebook_weight=1.):
13 | super().__init__()
14 | self.codebook_weight = codebook_weight
15 |
16 | def forward(self, qloss, target, prediction, split):
17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18 | loss = bce_loss + self.codebook_weight*qloss
19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
21 | "{}/quant_loss".format(split): qloss.detach().mean()
22 | }
23 |
--------------------------------------------------------------------------------
/src/xtra/taming/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from taming.modules.losses.lpips import LPIPS
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 |
8 |
9 | class DummyLoss(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 |
14 | def adopt_weight(weight, global_step, threshold=0, value=0.):
15 | if global_step < threshold:
16 | weight = value
17 | return weight
18 |
19 |
20 | def hinge_d_loss(logits_real, logits_fake):
21 | loss_real = torch.mean(F.relu(1. - logits_real))
22 | loss_fake = torch.mean(F.relu(1. + logits_fake))
23 | d_loss = 0.5 * (loss_real + loss_fake)
24 | return d_loss
25 |
26 |
27 | def vanilla_d_loss(logits_real, logits_fake):
28 | d_loss = 0.5 * (
29 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
30 | torch.mean(torch.nn.functional.softplus(logits_fake)))
31 | return d_loss
32 |
33 |
34 | class VQLPIPSWithDiscriminator(nn.Module):
35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38 | disc_ndf=64, disc_loss="hinge"):
39 | super().__init__()
40 | assert disc_loss in ["hinge", "vanilla"]
41 | self.codebook_weight = codebook_weight
42 | self.pixel_weight = pixelloss_weight
43 | self.perceptual_loss = LPIPS().eval()
44 | self.perceptual_weight = perceptual_weight
45 |
46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47 | n_layers=disc_num_layers,
48 | use_actnorm=use_actnorm,
49 | ndf=disc_ndf
50 | ).apply(weights_init)
51 | self.discriminator_iter_start = disc_start
52 | if disc_loss == "hinge":
53 | self.disc_loss = hinge_d_loss
54 | elif disc_loss == "vanilla":
55 | self.disc_loss = vanilla_d_loss
56 | else:
57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59 | self.disc_factor = disc_factor
60 | self.discriminator_weight = disc_weight
61 | self.disc_conditional = disc_conditional
62 |
63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64 | if last_layer is not None:
65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67 | else:
68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70 |
71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73 | d_weight = d_weight * self.discriminator_weight
74 | return d_weight
75 |
76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77 | global_step, last_layer=None, cond=None, split="train"):
78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79 | if self.perceptual_weight > 0:
80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81 | rec_loss = rec_loss + self.perceptual_weight * p_loss
82 | else:
83 | p_loss = torch.tensor([0.0])
84 |
85 | nll_loss = rec_loss
86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87 | nll_loss = torch.mean(nll_loss)
88 |
89 | # now the GAN part
90 | if optimizer_idx == 0:
91 | # generator update
92 | if cond is None:
93 | assert not self.disc_conditional
94 | logits_fake = self.discriminator(reconstructions.contiguous())
95 | else:
96 | assert self.disc_conditional
97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98 | g_loss = -torch.mean(logits_fake)
99 |
100 | try:
101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102 | except RuntimeError:
103 | assert not self.training
104 | d_weight = torch.tensor(0.0)
105 |
106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108 |
109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
112 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
113 | "{}/p_loss".format(split): p_loss.detach().mean(),
114 | "{}/d_weight".format(split): d_weight.detach(),
115 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
116 | "{}/g_loss".format(split): g_loss.detach().mean(),
117 | }
118 | return loss, log
119 |
120 | if optimizer_idx == 1:
121 | # second pass for discriminator update
122 | if cond is None:
123 | logits_real = self.discriminator(inputs.contiguous().detach())
124 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
125 | else:
126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128 |
129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131 |
132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133 | "{}/logits_real".format(split): logits_real.detach().mean(),
134 | "{}/logits_fake".format(split): logits_fake.detach().mean()
135 | }
136 | return d_loss, log
137 |
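
A minimal sketch of the two-pass call pattern this loss expects during VQGAN training: optimizer_idx=0 returns the autoencoder loss (reconstruction + LPIPS + adaptively weighted adversarial term), optimizer_idx=1 the discriminator loss; both GAN terms stay zero until global_step reaches disc_start. The stand-in conv weight for the decoder's last layer and all tensor shapes are illustrative only, and pretrained LPIPS/VGG weights are downloaded on first use:

    import torch
    import torch.nn.functional as F
    from taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator

    loss_fn = VQLPIPSWithDiscriminator(disc_start=10000)
    last_layer = torch.nn.Parameter(torch.randn(3, 3, 3, 3))    # stand-in for the decoder's final conv weight
    x = torch.rand(1, 3, 256, 256) * 2 - 1                      # "inputs"
    xrec = torch.tanh(F.conv2d(torch.rand(1, 3, 256, 256), last_layer, padding=1))  # "reconstructions" with a graph
    qloss = torch.tensor(0.1)                                   # codebook loss from the quantizer

    # generator/autoencoder pass
    g_loss, g_log = loss_fn(qloss, x, xrec, optimizer_idx=0, global_step=0, last_layer=last_layer)
    # discriminator pass (real/fake logits computed on detached tensors)
    d_loss, d_log = loss_fn(qloss, x, xrec, optimizer_idx=1, global_step=0)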
--------------------------------------------------------------------------------
/src/xtra/taming/modules/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/misc/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/modules/misc/coord.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CoordStage(object):
4 | def __init__(self, n_embed, down_factor):
5 | self.n_embed = n_embed
6 | self.down_factor = down_factor
7 |
8 | def eval(self):
9 | return self
10 |
11 | def encode(self, c):
12 | """fake vqmodel interface"""
13 | assert 0.0 <= c.min() and c.max() <= 1.0
14 | b,ch,h,w = c.shape
15 | assert ch == 1
16 |
17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18 | mode="area")
19 | c = c.clamp(0.0, 1.0)
20 | c = self.n_embed*c
21 | c_quant = c.round()
22 | c_ind = c_quant.to(dtype=torch.long)
23 |
24 | info = None, None, c_ind
25 | return c_quant, None, info
26 |
27 | def decode(self, c):
28 | c = c/self.n_embed
29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30 | mode="nearest")
31 | return c
32 |
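
A small round-trip sketch of CoordStage's fake VQ interface (values are illustrative): encode() quantizes a single-channel map in [0, 1] to n_embed levels at 1/down_factor resolution, decode() upsamples it back:

    import torch
    from taming.modules.misc.coord import CoordStage

    stage = CoordStage(n_embed=1024, down_factor=16)
    c = torch.rand(1, 1, 256, 256)                    # coordinate map in [0, 1]
    c_quant, _, (_, _, c_ind) = stage.encode(c)       # quantized map and integer indices at 16x16
    c_rec = stage.decode(c_quant)                     # back to 1x1x256x256, values again roughly in [0, 1]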
--------------------------------------------------------------------------------
/src/xtra/taming/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/transformer/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
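
A minimal sketch of ActNorm's data-dependent initialization: on the first training-mode forward pass, loc and scale are set from the batch statistics so the output is roughly zero-mean and unit-variance per channel, and reverse=True inverts the affine map (shapes below are illustrative):

    import torch
    from taming.modules.util import ActNorm

    norm = ActNorm(num_features=8)
    x = torch.randn(4, 8, 16, 16) * 3.0 + 2.0     # arbitrary scale/shift to normalize away
    y = norm(x)                                   # first call runs initialize(), then applies scale * (x + loc)
    print(y.mean().item(), y.std().item())        # roughly 0 and 1
    x_back = norm(y, reverse=True)                # inverse transform recovers the input
    print(torch.allclose(x, x_back, atol=1e-4))   # True (up to float rounding)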
--------------------------------------------------------------------------------
/src/xtra/taming/modules/vqvae/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/SD/36517ca1ce91b5bb7077e8d9b291325f2ac37780/src/xtra/taming/modules/vqvae/__init__.py
--------------------------------------------------------------------------------
/src/xtra/taming/util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 |
9 | CKPT_MAP = {
10 | "vgg_lpips": "vgg.pth"
11 | }
12 |
13 | MD5_MAP = {
14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 | pbar.update(len(data))
28 |
29 |
30 | def md5_hash(path):
31 | with open(path, "rb") as f:
32 | content = f.read()
33 | return hashlib.md5(content).hexdigest()
34 |
35 |
36 | def get_ckpt_path(name, root, check=False):
37 | assert name in URL_MAP
38 | path = os.path.join(root, CKPT_MAP[name])
39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 | download(URL_MAP[name], path)
42 | md5 = md5_hash(path)
43 | assert md5 == MD5_MAP[name], md5
44 | return path
45 |
46 |
47 | class KeyNotFoundError(Exception):
48 | def __init__(self, cause, keys=None, visited=None):
49 | self.cause = cause
50 | self.keys = keys
51 | self.visited = visited
52 | messages = list()
53 | if keys is not None:
54 | messages.append("Key not found: {}".format(keys))
55 | if visited is not None:
56 | messages.append("Visited: {}".format(visited))
57 | messages.append("Cause:\n{}".format(cause))
58 | message = "\n".join(messages)
59 | super().__init__(message)
60 |
61 |
62 | def retrieve(
63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64 | ):
65 | """Given a nested list or dict return the desired value at key expanding
66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67 | is done in-place.
68 |
69 | Parameters
70 | ----------
71 | list_or_dict : list or dict
72 | Possibly nested list or dictionary.
73 | key : str
74 | key/to/value, path like string describing all keys necessary to
75 | consider to get to the desired value. List indices can also be
76 | passed here.
77 | splitval : str
78 | String that defines the delimiter between keys of the
79 | different depth levels in `key`.
80 | default : obj
81 | Value returned if :attr:`key` is not found.
82 | expand : bool
83 | Whether to expand callable nodes on the path or not.
84 |
85 | Returns
86 | -------
87 | The desired value or if :attr:`default` is not ``None`` and the
88 | :attr:`key` is not found returns ``default``.
89 |
90 | Raises
91 | ------
92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93 | ``None``.
94 | """
95 |
96 | keys = key.split(splitval)
97 |
98 | success = True
99 | try:
100 | visited = []
101 | parent = None
102 | last_key = None
103 | for key in keys:
104 | if callable(list_or_dict):
105 | if not expand:
106 | raise KeyNotFoundError(
107 | ValueError(
108 | "Trying to get past callable node with expand=False."
109 | ),
110 | keys=keys,
111 | visited=visited,
112 | )
113 | list_or_dict = list_or_dict()
114 | parent[last_key] = list_or_dict
115 |
116 | last_key = key
117 | parent = list_or_dict
118 |
119 | try:
120 | if isinstance(list_or_dict, dict):
121 | list_or_dict = list_or_dict[key]
122 | else:
123 | list_or_dict = list_or_dict[int(key)]
124 | except (KeyError, IndexError, ValueError) as e:
125 | raise KeyNotFoundError(e, keys=keys, visited=visited)
126 |
127 | visited += [key]
128 | # final expansion of retrieved value
129 | if expand and callable(list_or_dict):
130 | list_or_dict = list_or_dict()
131 | parent[last_key] = list_or_dict
132 | except KeyNotFoundError as e:
133 | if default is None:
134 | raise e
135 | else:
136 | list_or_dict = default
137 | success = False
138 |
139 | if not pass_success:
140 | return list_or_dict
141 | else:
142 | return list_or_dict, success
143 |
144 |
145 | if __name__ == "__main__":
146 | config = {"keya": "a",
147 | "keyb": "b",
148 | "keyc":
149 | {"cc1": 1,
150 | "cc2": 2,
151 | }
152 | }
153 | from omegaconf import OmegaConf
154 | config = OmegaConf.create(config)
155 | print(config)
156 | retrieve(config, "keya")
157 |
158 |
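
A small sketch of retrieve() on a plain nested structure (the keys are illustrative): path segments are split on splitval ('/'), integer segments index into lists, and passing default suppresses KeyNotFoundError:

    from taming.util import retrieve

    cfg = {"model": {"params": {"ch_mult": [1, 2, 4, 4]}}}
    print(retrieve(cfg, "model/params/ch_mult/2"))            # -> 4
    print(retrieve(cfg, "model/params/missing", default=0))   # -> 0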
--------------------------------------------------------------------------------
/src/yaml/v1-finetune-custom.yaml:
--------------------------------------------------------------------------------
1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion.
2 | # Adobe’s modifications are licensed under the Adobe Research License.
3 |
4 | model:
5 | base_learning_rate: 1.0e-05
6 | target: custom.model.CustomDiffusion
7 | params:
8 | linear_start: 0.00085
9 | linear_end: 0.0120
10 | num_timesteps_cond: 1
11 | log_every_t: 200
12 | timesteps: 1000
13 | first_stage_key: "image"
14 | cond_stage_key: "caption"
15 | image_size: 64
16 | channels: 4
17 | cond_stage_trainable: true # Note: different from the one we trained before
18 | add_token: True
19 | freeze_model: "crossattn-kv"
20 | conditioning_key: crossattn
21 | monitor: val/loss_simple_ema
22 | scale_factor: 0.18215
23 | use_ema: False
24 |
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64 # unused
29 | in_channels: 4
30 | out_channels: 4
31 | model_channels: 320
32 | attention_resolutions: [ 4, 2, 1 ]
33 | num_res_blocks: 2
34 | channel_mult: [ 1, 2, 4, 4 ]
35 | num_heads: 8
36 | use_spatial_transformer: True
37 | transformer_depth: 1
38 | context_dim: 768
39 | use_checkpoint: False
40 | legacy: False
41 |
42 | first_stage_config:
43 | target: ldm.models.autoencoder.AutoencoderKL
44 | params:
45 | embed_dim: 4
46 | monitor: val/rec_loss
47 | ddconfig:
48 | double_z: true
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 4
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions: []
61 | dropout: 0.0
62 | lossconfig:
63 | target: torch.nn.Identity
64 |
65 | cond_stage_config:
66 | target: custom.modules.FrozenCLIPEmbedderWrapper
67 | params:
68 | modifier_token:
69 |
70 | data:
71 | target: train.DataModuleFromConfig
72 | params:
73 | batch_size: 1
74 | num_workers: 2
75 | wrap: false
76 | train:
77 | target: custom.finetune_data.MaskBase
78 | params:
79 | size: 512
80 | train2:
81 | target: custom.finetune_data.MaskBase
82 | params:
83 | size: 512
84 |
85 | lightning:
86 | modelcheckpoint: # from usual finetune
87 | params:
88 | verbose: false
89 | save_last: true
90 | callbacks:
91 | image_logger:
92 | target: train.ImageLogger
93 | params:
94 | batch_frequency: 500
95 | max_images: 4
96 | clamp: True
97 | increase_log_steps: false
98 |
99 | trainer:
100 | max_steps: 1000 # for gpu=1 batch=1 [orig was 300]
101 | find_unused_parameters: False
102 |
--------------------------------------------------------------------------------
/src/yaml/v1-finetune.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-03
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: true # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 | embedding_reg_weight: 0.0
20 |
21 | personalization_config:
22 | target: ldm.modules.embedding_manager.EmbeddingManager
23 | params:
24 | placeholder_strings: ["*"]
25 | initializer_words: ["sculpture"]
26 | per_image_tokens: false
27 | num_vectors_per_token: 1
28 | progressive_words: False
29 |
30 | unet_config:
31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32 | params:
33 | image_size: 32 # unused
34 | in_channels: 4
35 | out_channels: 4
36 | model_channels: 320
37 | attention_resolutions: [ 4, 2, 1 ]
38 | num_res_blocks: 2
39 | channel_mult: [ 1, 2, 4, 4 ]
40 | num_heads: 8
41 | use_spatial_transformer: True
42 | transformer_depth: 1
43 | context_dim: 768
44 | use_checkpoint: True
45 | legacy: False
46 |
47 | first_stage_config:
48 | target: ldm.models.autoencoder.AutoencoderKL
49 | params:
50 | embed_dim: 4
51 | monitor: val/rec_loss
52 | ddconfig:
53 | double_z: true
54 | z_channels: 4
55 | resolution: 256
56 | in_channels: 3
57 | out_ch: 3
58 | ch: 128
59 | ch_mult:
60 | - 1
61 | - 2
62 | - 4
63 | - 4
64 | num_res_blocks: 2
65 | attn_resolutions: []
66 | dropout: 0.0
67 | lossconfig:
68 | target: torch.nn.Identity
69 |
70 | cond_stage_config:
71 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72 |
73 | data:
74 | target: train.DataModuleFromConfig
75 | params:
76 | batch_size: 1
77 | num_workers: 4
78 | wrap: false
79 | train:
80 | target: ldm.data.personalized.PersonalizedBase
81 | params:
82 | size: 512
83 | set: train
84 | per_image_tokens: false
85 | repeats: 10 # 100
86 | validation:
87 | target: ldm.data.personalized.PersonalizedBase
88 | params:
89 | size: 512
90 | set: val
91 | per_image_tokens: false
92 | # repeats: 10
93 |
94 | lightning:
95 | modelcheckpoint:
96 | params:
97 | every_n_train_steps: 500
98 | verbose: false
99 | callbacks:
100 | image_logger:
101 | target: train.ImageLogger
102 | params:
103 | batch_frequency: 500
104 | max_images: 2
105 | increase_log_steps: False
106 |
107 | trainer:
108 | benchmark: True
109 | max_steps: 5000
110 | find_unused_parameters: False
111 |
--------------------------------------------------------------------------------
/src/yaml/v1-finetune_style.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-03
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: true # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 | embedding_reg_weight: 0.0
20 |
21 | personalization_config:
22 | target: ldm.modules.embedding_manager.EmbeddingManager
23 | params:
24 | placeholder_strings: ["*"]
25 | initializer_words: ["painting"]
26 | per_image_tokens: false
27 | num_vectors_per_token: 1
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
72 | data:
73 | target: main.DataModuleFromConfig
74 | params:
75 | batch_size: 2
76 | num_workers: 16
77 | wrap: false
78 | train:
79 | target: ldm.data.personalized_style.PersonalizedBase
80 | params:
81 | size: 512
82 | set: train
83 | per_image_tokens: false
84 | repeats: 100
85 | validation:
86 | target: ldm.data.personalized_style.PersonalizedBase
87 | params:
88 | size: 512
89 | set: val
90 | per_image_tokens: false
91 | repeats: 10
92 |
93 | lightning:
94 | callbacks:
95 | image_logger:
96 | target: main.ImageLogger
97 | params:
98 | batch_frequency: 500
99 | max_images: 8
100 | increase_log_steps: False
101 |
102 | trainer:
103 | benchmark: True
--------------------------------------------------------------------------------
/src/yaml/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | personalization_config:
30 | target: ldm.modules.embedding_manager.EmbeddingManager
31 | params:
32 | placeholder_strings: ["*"]
33 | initializer_words: [""]
34 | per_image_tokens: false
35 | num_vectors_per_token: 1
36 | progressive_words: False
37 |
38 | unet_config:
39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40 | params:
41 | image_size: 32 # unused
42 | in_channels: 4
43 | out_channels: 4
44 | model_channels: 320
45 | attention_resolutions: [ 4, 2, 1 ]
46 | num_res_blocks: 2
47 | channel_mult: [ 1, 2, 4, 4 ]
48 | num_heads: 8
49 | use_spatial_transformer: True
50 | transformer_depth: 1
51 | context_dim: 768
52 | use_checkpoint: True
53 | legacy: False
54 |
55 | first_stage_config:
56 | target: ldm.models.autoencoder.AutoencoderKL
57 | params:
58 | embed_dim: 4
59 | monitor: val/rec_loss
60 | ddconfig:
61 | double_z: true
62 | z_channels: 4
63 | resolution: 256
64 | in_channels: 3
65 | out_ch: 3
66 | ch: 128
67 | ch_mult:
68 | - 1
69 | - 2
70 | - 4
71 | - 4
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | cond_stage_config:
79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
80 |
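
A minimal sketch of how a config in this target/params format is typically turned into a model, assuming ldm.util provides instantiate_from_config as in the upstream Stable Diffusion code; the checkpoint path is hypothetical, and model construction downloads the CLIP text-encoder weights on first use:

    import torch
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config   # assumed helper, as in upstream Stable Diffusion

    config = OmegaConf.load("src/yaml/v1-inference.yaml")
    model = instantiate_from_config(config.model)                 # builds the target class with its params
    state = torch.load("models/sd-v1.ckpt", map_location="cpu")   # hypothetical checkpoint path
    model.load_state_dict(state.get("state_dict", state), strict=False)
    model.eval()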
--------------------------------------------------------------------------------
/src/yaml/v1-inpainting.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 7.5e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid # important
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | personalization_config:
30 | target: ldm.modules.embedding_manager.EmbeddingManager
31 | params:
32 | placeholder_strings: ["*"]
33 | initializer_words: [""]
34 | per_image_tokens: false
35 | num_vectors_per_token: 1
36 | progressive_words: False
37 |
38 | unet_config:
39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40 | params:
41 | image_size: 32 # unused
42 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask
43 | out_channels: 4
44 | model_channels: 320
45 | attention_resolutions: [ 4, 2, 1 ]
46 | num_res_blocks: 2
47 | channel_mult: [ 1, 2, 4, 4 ]
48 | num_heads: 8
49 | use_spatial_transformer: True
50 | transformer_depth: 1
51 | context_dim: 768
52 | use_checkpoint: True
53 | legacy: False
54 |
55 | first_stage_config:
56 | target: ldm.models.autoencoder.AutoencoderKL
57 | params:
58 | embed_dim: 4
59 | monitor: val/rec_loss
60 | ddconfig:
61 | double_z: true
62 | z_channels: 4
63 | resolution: 256
64 | in_channels: 3
65 | out_ch: 3
66 | ch: 128
67 | ch_mult:
68 | - 1
69 | - 2
70 | - 4
71 | - 4
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | cond_stage_config:
79 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
80 |
--------------------------------------------------------------------------------
/src/yaml/v2-inference-v.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | parameterization: "v"
6 | linear_start: 0.00085
7 | linear_end: 0.0120
8 | num_timesteps_cond: 1
9 | log_every_t: 200
10 | timesteps: 1000
11 | first_stage_key: "jpg"
12 | cond_stage_key: "txt"
13 | image_size: 64
14 | channels: 4
15 | cond_stage_trainable: false # Note: different from the one we trained before
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple_ema
18 | scale_factor: 0.18215
19 | use_ema: False
20 |
21 | scheduler_config: # 10000 warmup steps
22 | target: ldm.lr_scheduler.LambdaLinearScheduler
23 | params:
24 | warm_up_steps: [ 10000 ]
25 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26 | f_start: [ 1.e-6 ]
27 | f_max: [ 1. ]
28 | f_min: [ 1. ]
29 |
30 | personalization_config:
31 | target: ldm.modules.embedding_manager.EmbeddingManager
32 | params:
33 | placeholder_strings: ["*"]
34 | initializer_words: [""]
35 | per_image_tokens: false
36 | num_vectors_per_token: 1
37 | progressive_words: False
38 |
39 | unet_config:
40 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
41 | params:
42 | use_checkpoint: True
43 | use_fp16: True
44 | image_size: 32 # unused
45 | in_channels: 4
46 | out_channels: 4
47 | model_channels: 320
48 | attention_resolutions: [ 4, 2, 1 ]
49 | num_res_blocks: 2
50 | channel_mult: [ 1, 2, 4, 4 ]
51 | num_head_channels: 64 # need to fix for flash-attn
52 | use_spatial_transformer: True
53 | use_linear_in_transformer: True
54 | transformer_depth: 1
55 | context_dim: 1024
56 | legacy: False
57 |
58 | first_stage_config:
59 | target: ldm.models.autoencoder.AutoencoderKL
60 | params:
61 | embed_dim: 4
62 | monitor: val/rec_loss
63 | ddconfig:
64 | #attn_type: "vanilla-xformers"
65 | double_z: true
66 | z_channels: 4
67 | resolution: 256
68 | in_channels: 3
69 | out_ch: 3
70 | ch: 128
71 | ch_mult:
72 | - 1
73 | - 2
74 | - 4
75 | - 4
76 | num_res_blocks: 2
77 | attn_resolutions: []
78 | dropout: 0.0
79 | lossconfig:
80 | target: torch.nn.Identity
81 |
82 | cond_stage_config:
83 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
84 | params:
85 | freeze: True
86 | layer: "penultimate"
87 |
88 |
--------------------------------------------------------------------------------
/src/yaml/v2-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | personalization_config:
30 | target: ldm.modules.embedding_manager.EmbeddingManager
31 | params:
32 | placeholder_strings: ["*"]
33 | initializer_words: [""]
34 | per_image_tokens: false
35 | num_vectors_per_token: 1
36 | progressive_words: False
37 |
38 | unet_config:
39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40 | params:
41 | use_checkpoint: True
42 | use_fp16: True
43 | image_size: 32 # unused
44 | in_channels: 4
45 | out_channels: 4
46 | model_channels: 320
47 | attention_resolutions: [ 4, 2, 1 ]
48 | num_res_blocks: 2
49 | channel_mult: [ 1, 2, 4, 4 ]
50 | num_head_channels: 64 # need to fix for flash-attn
51 | use_spatial_transformer: True
52 | use_linear_in_transformer: True
53 | transformer_depth: 1
54 | context_dim: 1024
55 | legacy: False
56 |
57 | first_stage_config:
58 | target: ldm.models.autoencoder.AutoencoderKL
59 | params:
60 | embed_dim: 4
61 | monitor: val/rec_loss
62 | ddconfig:
63 | #attn_type: "vanilla-xformers"
64 | double_z: true
65 | z_channels: 4
66 | resolution: 256
67 | in_channels: 3
68 | out_ch: 3
69 | ch: 128
70 | ch_mult:
71 | - 1
72 | - 2
73 | - 4
74 | - 4
75 | num_res_blocks: 2
76 | attn_resolutions: []
77 | dropout: 0.0
78 | lossconfig:
79 | target: torch.nn.Identity
80 |
81 | cond_stage_config:
82 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
83 | params:
84 | freeze: True
85 | layer: "penultimate"
86 |
87 |
--------------------------------------------------------------------------------
/src/yaml/v2-inpainting.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 7.5e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid # important
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | personalization_config:
30 | target: ldm.modules.embedding_manager.EmbeddingManager
31 | params:
32 | placeholder_strings: ["*"]
33 | initializer_words: [""]
34 | per_image_tokens: false
35 | num_vectors_per_token: 1
36 | progressive_words: False
37 |
38 | unet_config:
39 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40 | params:
41 | use_checkpoint: True
42 | image_size: 32 # unused
43 | in_channels: 9
44 | out_channels: 4
45 | model_channels: 320
46 | attention_resolutions: [ 4, 2, 1 ]
47 | num_res_blocks: 2
48 | channel_mult: [ 1, 2, 4, 4 ]
49 | num_head_channels: 64 # need to fix for flash-attn
50 | use_spatial_transformer: True
51 | use_linear_in_transformer: True
52 | transformer_depth: 1
53 | context_dim: 1024
54 | legacy: False
55 |
56 | first_stage_config:
57 | target: ldm.models.autoencoder.AutoencoderKL
58 | params:
59 | embed_dim: 4
60 | monitor: val/rec_loss
61 | ddconfig:
62 | #attn_type: "vanilla-xformers"
63 | double_z: true
64 | z_channels: 4
65 | resolution: 256
66 | in_channels: 3
67 | out_ch: 3
68 | ch: 128
69 | ch_mult:
70 | - 1
71 | - 2
72 | - 4
73 | - 4
74 | num_res_blocks: 2
75 | attn_resolutions: [ ]
76 | dropout: 0.0
77 | lossconfig:
78 | target: torch.nn.Identity
79 |
80 | cond_stage_config:
81 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
82 | params:
83 | freeze: True
84 | layer: "penultimate"
85 |
--------------------------------------------------------------------------------
/src/yaml/v2-midas.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-07
3 | target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: hybrid
16 | scale_factor: 0.18215
17 | monitor: val/loss_simple_ema
18 | finetune_keys: null
19 | use_ema: False
20 |
21 | depth_stage_config:
22 | target: ldm.modules.midas.api.MiDaSInference
23 | params:
24 | model_type: "dpt_hybrid"
25 |
26 | personalization_config:
27 | target: ldm.modules.embedding_manager.EmbeddingManager
28 | params:
29 | placeholder_strings: ["*"]
30 | initializer_words: [""]
31 | per_image_tokens: false
32 | num_vectors_per_token: 1
33 | progressive_words: False
34 |
35 | unet_config:
36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
37 | params:
38 | use_checkpoint: True
39 | image_size: 32 # unused
40 | in_channels: 5
41 | out_channels: 4
42 | model_channels: 320
43 | attention_resolutions: [ 4, 2, 1 ]
44 | num_res_blocks: 2
45 | channel_mult: [ 1, 2, 4, 4 ]
46 | num_head_channels: 64 # need to fix for flash-attn
47 | use_spatial_transformer: True
48 | use_linear_in_transformer: True
49 | transformer_depth: 1
50 | context_dim: 1024
51 | legacy: False
52 |
53 | first_stage_config:
54 | target: ldm.models.autoencoder.AutoencoderKL
55 | params:
56 | embed_dim: 4
57 | monitor: val/rec_loss
58 | ddconfig:
59 | #attn_type: "vanilla-xformers"
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: [ ]
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config:
78 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
79 | params:
80 | freeze: True
81 | layer: "penultimate"
82 |
83 |
84 |
--------------------------------------------------------------------------------
/train.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | if [%1]==[] goto help
3 | echo .. %1
4 |
5 | python src/train.py --token "%1" --term "%2" --data data/%1 ^
6 | --reg_data data/%2 ^
7 | %3 %4 %5 %6 %7 %8 %9
8 |
9 | goto end
10 |
11 | :help
12 | echo Usage: train "" category
13 | echo e.g.: train "" lady
14 | echo or: train "" pattern --style
15 | :end
16 |
--------------------------------------------------------------------------------
/txt.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | set KMP_DUPLICATE_LIB_OK=TRUE
3 | if [%1]==[] goto help
4 | echo .. %1
5 |
6 | python src/_sdrun.py -v -t %1 ^
7 | %2 %3 %4 %5 %6 %7 %8 %9
8 |
9 | goto end
10 |
11 | :help
12 | echo Usage: txt "text prompt" [...]
13 | echo or: txt textfile [...]
14 | :end
15 |
--------------------------------------------------------------------------------
/walk.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | set KMP_DUPLICATE_LIB_OK=TRUE
3 | if [%1]==[] goto help
4 | echo .. %1
5 |
6 | python src/latwalk.py -v -t %1 ^
7 | %2 %3 %4 %5 %6 %7 %8 %9
8 |
9 | ffmpeg -v warning -y -i _out\%~n1\%%06d.jpg _out\%~n1-%2%3%4%5%6%7%8%9.mp4
10 | goto end
11 |
12 | :help
13 | echo Usage: walk textfile [...]
14 | :end
15 |
--------------------------------------------------------------------------------