├── .gitattributes
├── LICENSE
├── README.md
├── asset
│   ├── docs
│   │   └── badge-website.svg
│   └── images
│       └── teasor.png
├── datasets
│   └── tide_uwdense.py
├── diffusion
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── datasets
│   │   │   ├── Dreambooth.py
│   │   │   ├── InternalData.py
│   │   │   ├── InternalData_ms.py
│   │   │   ├── SA.py
│   │   │   ├── __init__.py
│   │   │   ├── pixart_control.py
│   │   │   └── utils.py
│   │   └── transforms.py
│   ├── dpm_solver.py
│   ├── iddpm.py
│   ├── lcm_scheduler.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── diffusion_utils.py
│   │   ├── dpm_solver.py
│   │   ├── edm_sample.py
│   │   ├── gaussian_diffusion.py
│   │   ├── hed.py
│   │   ├── llava
│   │   │   ├── __init__.py
│   │   │   ├── llava_mpt.py
│   │   │   └── mpt
│   │   │       ├── attention.py
│   │   │       ├── blocks.py
│   │   │       ├── configuration_mpt.py
│   │   │       ├── modeling_mpt.py
│   │   │       ├── norm.py
│   │   │       └── param_init_fns.py
│   │   ├── nets
│   │   │   ├── PixArt.py
│   │   │   ├── PixArtMS.py
│   │   │   ├── PixArt_blocks.py
│   │   │   ├── __init__.py
│   │   │   └── pixart_controlnet.py
│   │   ├── respace.py
│   │   ├── sa_solver.py
│   │   ├── t5.py
│   │   ├── timestep_sampler.py
│   │   └── utils.py
│   ├── sa_sampler.py
│   ├── sa_solver_diffusers.py
│   └── utils
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── data_sampler.py
│       ├── dist_utils.py
│       ├── logger.py
│       ├── lr_scheduler.py
│       ├── misc.py
│       └── optimizer.py
├── inference.py
├── requirements.txt
└── tide
    ├── __init__.py
    ├── config.py
    ├── pipeline
    │   ├── layers.py
    │   ├── pipeline_tide.py
    │   ├── tide_transformer.py
    │   ├── transformer_attentions.py
    │   └── transformer_blocks.py
    ├── train_tide_hf.py
    └── utils
        ├── __init__.py
        ├── dataset_palette.py
        ├── mask_process.py
        └── prompt_process.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2025] A Unified Image-Dense Annotation Generation Model for Underwater Scenes
2 |
3 |
4 | [](https://hongklin.github.io/TIDE/)
5 | [](https://arxiv.org/abs/2503.21771)
6 | [](https://www.apache.org/licenses/LICENSE-2.0)
7 |
8 | ## 🌊 **Introduction**
9 | We present TIDE, a unified underwater image-dense annotation generation model. Its core lies in the shared layout information and the natural complementarity between multimodal features. Derived from a pre-trained text-to-image model and fine-tuned with underwater data, TIDE generates highly consistent underwater image-dense annotations from text conditions alone.
10 |
11 | 
12 | ---
13 | ## 🐚 **News**
14 | - 2025-4-8: The training data and the SynTIDE dataset are available [here](https://huggingface.co/datasets/hongk1998/TIDE/tree/main)!
15 | - 2025-3-28: The training and inference code is now available!
16 | - 2025-2-27: Our TIDE is accepted to CVPR 2025!
17 | ---
18 |
19 | ## 🪸 Dependencies and Installation
20 |
21 | - Python >= 3.9 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
22 | - [PyTorch >= 2.0.1+cu11.7](https://pytorch.org/)
23 | ```bash
24 | conda create -n TIDE python=3.9
25 | conda activate TIDE
26 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
27 |
28 | git clone https://github.com/HongkLin/TIDE
29 | cd TIDE
30 | pip install -r requirements.txt
31 | ```
32 |
33 | ## 🐬 Inference
34 | Download the pre-trained [PixArt-α](https://huggingface.co/PixArt-alpha/PixArt-XL-2-512x512), [MiniTransformer](https://github.com/HongkLin/TIDE/releases/download/model_weights/TIDE_MiniTransformer.zip), and [TIDE checkpoint](https://github.com/HongkLin/TIDE/releases/download/model_weights/TIDE_r32_64_b4_200k.zip), then modify the model weights path.
35 | ```bash
36 | python inference.py --model_weights_dir ./model_weights --text_prompt "A large school of fish swimming in a circle." --output ./outputs
37 | ```
38 |
39 | ## 🐢 Training
40 |
41 | ### 🏖️ Training Data Preparation
42 | - Download [SUIM](https://github.com/xahidbuffon/SUIM), [UIIS](https://github.com/LiamLian0727/WaterMask), [USIS10K](https://github.com/LiamLian0727/USIS10K) datasets.
43 | - The semantic segmentation annotations are obtained by merging instances with the same semantics.
44 | - The depth annotations are obtained by [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), and the inverse depth results are saved as npy files.
45 | - The image captions and the JSON file used to organize the training data can follow [Atlantis](https://github.com/zkawfanx/Atlantis), as we do; an example entry is shown after the directory layout below.
46 |
47 | The final dataset should be organized as follows:
48 | ```
49 | datasets/
50 | UWD_triplets/
51 | images/
52 | train_05543.jpg
53 | ...
54 | semseg_annotations/
55 | train_05543.jpg
56 | ...
57 | depth_annotations/
58 | train_05543_raw_depth_meter.npy
59 | ...
60 | TrainTIDE_Caption.json
61 | ```
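For reference, the metadata is read as JSON lines with `text`, `image`, and `conditioning_image` fields (see `datasets/tide_uwdense.py`); a single illustrative entry, with placeholder values, could look like:
```
{"text": "A large school of fish swimming in a circle.", "image": "train_05543.jpg", "conditioning_image": "train_05543.jpg"}
```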
62 | If you have prepared the training data and environment, you can run the following script to start training:
63 | ```bash
64 | accelerate launch --num_processes=4 --main_process_port=36666 ./tide/train_tide_hf.py \
65 | --max_train_steps 200000 --learning_rate=1e-4 --train_batch_size=1 \
66 | --gradient_accumulation_steps=1 --seed=42 --dataloader_num_workers=4 --validation_steps 10000 \
67 | --wandb_name=tide_r32_64_b4_200k --output_dir=./outputs/tide_r32_64_b4_200k
68 | ```
69 |
70 | # 🤗Acknowledgements
71 | - Thanks to [Diffusers](https://github.com/huggingface/diffusers) for their wonderful technical support and awesome collaboration!
72 | - Thanks to [Hugging Face](https://github.com/huggingface) for sponsoring the nice demo!
73 | - Thanks to [DiT](https://github.com/facebookresearch/DiT) for their wonderful work and codebase!
74 | - Thanks to [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) for their wonderful work and codebase!
75 |
76 | ## 📖BibTeX
77 | ```
78 | @inproceedings{lin2025tide,
79 | title={A Unified Image-Dense Annotation Generation Model for Underwater Scenes},
80 | author={Lin, Hongkai and Liang, Dingkang and Qi, Zhenghao and Bai, Xiang},
81 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
82 | year={2025},
83 | }
84 | ```
--------------------------------------------------------------------------------
/asset/docs/badge-website.svg:
--------------------------------------------------------------------------------
1 |
2 |
130 |
--------------------------------------------------------------------------------
/asset/images/teasor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HongkLin/TIDE/d0cba53604b3dd9e16e4f81ce24d7f4be3ba0d4e/asset/images/teasor.png
--------------------------------------------------------------------------------
/datasets/tide_uwdense.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import datasets
4 | import pandas as pd
5 | import numpy as np
6 |
7 | _VERSION = datasets.Version("0.0.1")
8 |
9 | _DESCRIPTION = "TODO"
10 | _HOMEPAGE = "TODO"
11 | _LICENSE = "TODO"
12 | _CITATION = "TODO"
13 |
14 | _FEATURES = datasets.Features(
15 | {
16 | "text": datasets.Value("string"),
17 | "image": datasets.Image(),
18 | "depth_image": datasets.Value("string"),
19 | "semantic_image": datasets.Image(),
20 | },
21 | )
22 | _root = os.getenv("DETECTRON2_DATASETS", "path_to/datasets")
23 | METADATA_PATH = os.path.join(_root, "UWD_triplets/UWDense.json")
24 | IMAGES_DIR = os.path.join(_root, "UWD_triplets/images")
25 | DEPTH_IMAGES_DIR = os.path.join(_root, "UWD_triplets/disparity_depth")
26 | SEMANTIC_IMAGES_DIR = os.path.join(_root, "UWD_triplets/annotations")
27 |
28 | _DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
29 |
30 | class Depth2Underwater(datasets.GeneratorBasedBuilder):
31 | BUILDER_CONFIGS = [_DEFAULT_CONFIG]
32 | DEFAULT_CONFIG_NAME = "default"
33 |
34 | def _info(self):
35 | return datasets.DatasetInfo(
36 | description=_DESCRIPTION,
37 | features=_FEATURES,
38 | supervised_keys=None,
39 | homepage=_HOMEPAGE,
40 | license=_LICENSE,
41 | citation=_CITATION,
42 | )
43 |
44 | def _split_generators(self, dl_manager):
45 | metadata_path = METADATA_PATH
46 | images_dir = IMAGES_DIR
47 | depth_images_dir = DEPTH_IMAGES_DIR
48 | semantic_images_dir = SEMANTIC_IMAGES_DIR
49 | return [
50 | datasets.SplitGenerator(
51 | name=datasets.Split.TRAIN,
52 | # These kwargs will be passed to _generate_examples
53 | gen_kwargs={
54 | "metadata_path": metadata_path,
55 | "images_dir": images_dir,
56 | "depth_images_dir": depth_images_dir,
57 | "semantic_images_dir": semantic_images_dir,
58 | },
59 | ),
60 | ]
61 |
62 | def _generate_examples(self, metadata_path, images_dir, depth_images_dir, semantic_images_dir):
63 | metadata = pd.read_json(metadata_path, lines=True)
64 |
65 | for _, row in metadata.iterrows():
66 | text = row["text"]
67 |
68 | image_path = row["image"]
69 | image_path = os.path.join(images_dir, image_path)
70 | image = open(image_path, "rb").read()
71 |
72 | if '.jpg' in row["conditioning_image"]:
73 | depth_image_path = row["conditioning_image"].replace('.jpg', '_raw_depth_meter.npy')
74 | semantic_image_path = row["conditioning_image"].replace('.jpg', '.png')
75 | else:
76 | assert '.png' in row["conditioning_image"]
77 | depth_image_path = row["conditioning_image"].replace('.png', '_raw_depth_meter.npy')
78 | semantic_image_path = row["conditioning_image"]
79 | depth_image_path = os.path.join(
80 | depth_images_dir, depth_image_path
81 | )
82 | semantic_image_path = os.path.join(
83 | semantic_images_dir, semantic_image_path
84 | )
85 | semantic_image = open(semantic_image_path, "rb").read()
86 |
87 | yield row["image"], {
88 | "text": text,
89 | "image": {
90 | "path": image_path,
91 | "bytes": image,
92 | },
93 | "depth_image": depth_image_path,
94 | "semantic_image": semantic_image
95 | }
--------------------------------------------------------------------------------
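A minimal usage sketch for the loading script above (not part of the repository): it assumes `DETECTRON2_DATASETS` points at the directory containing `UWD_triplets/`, and that your `datasets` version still supports local loading scripts (newer versions may require `trust_remote_code=True`).

```python
# Hypothetical usage of datasets/tide_uwdense.py; paths are placeholders.
import os
from datasets import load_dataset

# The script resolves its data root from this variable at import time.
os.environ["DETECTRON2_DATASETS"] = "/path/to/datasets"

ds = load_dataset("datasets/tide_uwdense.py", split="train")
sample = ds[0]
print(sample["text"])              # caption string
print(sample["image"].size)        # decoded PIL image
print(sample["depth_image"])       # path to the *_raw_depth_meter.npy file
print(sample["semantic_image"].size)
```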
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from .iddpm import IDDPM
7 | from .dpm_solver import DPMS
8 | from .sa_sampler import SASolverSampler
9 |
--------------------------------------------------------------------------------
/diffusion/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .transforms import get_transform
3 |
--------------------------------------------------------------------------------
/diffusion/data/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | from mmcv import Registry, build_from_cfg
5 | from torch.utils.data import DataLoader
6 |
7 | from diffusion.data.transforms import get_transform
8 | from diffusion.utils.logger import get_root_logger
9 |
10 | DATASETS = Registry('datasets')
11 |
12 | DATA_ROOT = '/cache/data'
13 |
14 |
15 | def set_data_root(data_root):
16 | global DATA_ROOT
17 | DATA_ROOT = data_root
18 |
19 |
20 | def get_data_path(data_dir):
21 | if os.path.isabs(data_dir):
22 | return data_dir
23 | global DATA_ROOT
24 | return os.path.join(DATA_ROOT, data_dir)
25 |
26 |
27 | def build_dataset(cfg, resolution=224, **kwargs):
28 | logger = get_root_logger()
29 |
30 | dataset_type = cfg.get('type')
31 | logger.info(f"Constructing dataset {dataset_type}...")
32 | t = time.time()
33 | transform = cfg.pop('transform', 'default_train')
34 | transform = get_transform(transform, resolution)
35 | dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
36 | logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}")
37 | return dataset
38 |
39 |
40 | def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs):
41 | return (
42 | DataLoader(
43 | dataset,
44 | batch_sampler=kwargs['batch_sampler'],
45 | num_workers=num_workers,
46 | pin_memory=True,
47 | )
48 | if 'batch_sampler' in kwargs
49 | else DataLoader(
50 | dataset,
51 | batch_size=batch_size,
52 | shuffle=shuffle,
53 | num_workers=num_workers,
54 | pin_memory=True,
55 | **kwargs
56 | )
57 | )
58 |
--------------------------------------------------------------------------------
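A sketch of how the two builders above fit together, assuming a dataset class registered in `diffusion/data/datasets`; the config values here are illustrative, not taken from the repo.

```python
from diffusion.data.builder import set_data_root, build_dataset, build_dataloader

set_data_root("/path/to/data")  # relative roots in configs resolve against this

cfg = dict(
    type="InternalData",              # any class decorated with @DATASETS.register_module()
    root="InternData",                # joined with the data root by get_data_path()
    image_list_json="data_info.json",
    transform="default_train",        # popped and resolved through get_transform()
    load_vae_feat=False,
)
dataset = build_dataset(cfg, resolution=512, max_length=120)
loader = build_dataloader(dataset, batch_size=32, num_workers=4, shuffle=True)
```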
/diffusion/data/datasets/Dreambooth.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import torch
4 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
5 | from torch.utils.data import Dataset
6 | from diffusers.utils.torch_utils import randn_tensor
7 | from torchvision import transforms as T
8 | import pathlib
9 | from diffusers.models import AutoencoderKL
10 |
11 | from diffusion.data.builder import get_data_path, DATASETS
12 | from diffusion.data.datasets.utils import *
13 |
14 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp', 'JPEG'}
15 |
16 |
17 | @DATASETS.register_module()
18 | class DreamBooth(Dataset):
19 | def __init__(self,
20 | root,
21 | transform=None,
22 | resolution=1024,
23 | **kwargs):
24 | self.root = get_data_path(root)
25 | path = pathlib.Path(self.root)
26 | self.transform = transform
27 | self.resolution = resolution
28 | self.img_samples = sorted(
29 | [file for ext in IMAGE_EXTENSIONS for file in path.glob(f'*.{ext}')]
30 | )
31 | self.ori_imgs_nums = len(self)
32 | self.loader = default_loader
33 | self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
34 | self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio
35 | self.ratio_nums = {}
36 | for k, v in self.aspect_ratio.items():
37 | self.ratio_nums[float(k)] = 0 # used for batch-sampler
38 | self.data_info = {'img_hw': torch.tensor([resolution, resolution], dtype=torch.float32), 'aspect_ratio': 1.}
39 |
40 | # image related
41 | with torch.inference_mode():
42 | vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema")
43 | imgs = []
44 | for img_path in self.img_samples:
45 | img = self.loader(img_path)
46 | self.ratio_nums[1.0] += 1
47 | if self.transform is not None:
48 | imgs.append(self.transform(img))
49 | imgs = torch.stack(imgs, dim=0)
50 | self.img_vae = vae.encode(imgs).latent_dist.sample()
51 | del vae
52 |
53 | def __getitem__(self, index):
54 | return self.img_vae[index], self.data_info
55 |
56 | @staticmethod
57 | def vae_feat_loader(path):
58 | # [mean, std]
59 | mean, std = torch.from_numpy(np.load(path)).chunk(2)
60 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
61 | return mean + std * sample
62 |
63 | def load_ori_img(self, img_path):
64 |         # Load the image and convert it to a tensor
65 | transform = T.Compose([
66 | T.Resize(256), # Image.BICUBIC
67 | T.CenterCrop(256),
68 | T.ToTensor(),
69 | ])
70 | return transform(Image.open(img_path))
71 |
72 | def __len__(self):
73 | return len(self.img_samples)
74 |
75 | def __getattr__(self, name):
76 | if name == "set_epoch":
77 | return lambda epoch: None
78 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
79 |
80 | def get_data_info(self, idx):
81 | return {'height': self.resolution, 'width': self.resolution}
82 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/InternalData.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
7 | from torch.utils.data import Dataset
8 | from diffusers.utils.torch_utils import randn_tensor
9 | from torchvision import transforms as T
10 | from diffusion.data.builder import get_data_path, DATASETS
11 | from diffusion.utils.logger import get_root_logger
12 |
13 | import json
14 |
15 |
16 | @DATASETS.register_module()
17 | class InternalData(Dataset):
18 | def __init__(self,
19 | root,
20 | image_list_json='data_info.json',
21 | transform=None,
22 | resolution=256,
23 | sample_subset=None,
24 | load_vae_feat=False,
25 | input_size=32,
26 | patch_size=2,
27 | mask_ratio=0.0,
28 | load_mask_index=False,
29 | max_length=120,
30 | config=None,
31 | **kwargs):
32 | self.root = get_data_path(root)
33 | self.transform = transform
34 | self.load_vae_feat = load_vae_feat
35 | self.ori_imgs_nums = 0
36 | self.resolution = resolution
37 | self.N = int(resolution // (input_size // patch_size))
38 | self.mask_ratio = mask_ratio
39 | self.load_mask_index = load_mask_index
40 | self.max_lenth = max_length
41 | self.meta_data_clean = []
42 | self.img_samples = []
43 | self.txt_feat_samples = []
44 | self.vae_feat_samples = []
45 | self.mask_index_samples = []
46 | self.prompt_samples = []
47 |
48 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
49 | for json_file in image_list_json:
50 | meta_data = self.load_json(os.path.join(self.root, 'partition', json_file))
51 | self.ori_imgs_nums += len(meta_data)
52 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
53 | self.meta_data_clean.extend(meta_data_clean)
54 | self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean])
55 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean])
56 | self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_features_{resolution}resolution/noflip', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean])
57 | self.prompt_samples.extend([item['prompt'] for item in meta_data_clean])
58 |
59 | # Set loader and extensions
60 | if load_vae_feat:
61 | self.transform = None
62 | self.loader = self.vae_feat_loader
63 | else:
64 | self.loader = default_loader
65 |
66 | if sample_subset is not None:
67 | self.sample_subset(sample_subset) # sample dataset for local debug
68 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
69 | logger.info(f"T5 max token length: {self.max_lenth}")
70 |
71 | def getdata(self, index):
72 | img_path = self.img_samples[index]
73 | npz_path = self.txt_feat_samples[index]
74 | npy_path = self.vae_feat_samples[index]
75 | prompt = self.prompt_samples[index]
76 | data_info = {
77 |             'img_hw': torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
78 | 'aspect_ratio': torch.tensor(1.)
79 | }
80 |
81 | img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path)
82 | txt_info = np.load(npz_path)
83 | txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096
84 | attention_mask = torch.ones(1, 1, txt_fea.shape[1]) # 1x1xT
85 | if 'attention_mask' in txt_info.keys():
86 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
87 | if txt_fea.shape[1] != self.max_lenth:
88 | txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1)
89 | attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)
90 |
91 | if self.transform:
92 | img = self.transform(img)
93 |
94 | data_info['prompt'] = prompt
95 | return img, txt_fea, attention_mask, data_info
96 |
97 | def __getitem__(self, idx):
98 | for _ in range(20):
99 | try:
100 | return self.getdata(idx)
101 | except Exception as e:
102 | print(f"Error details: {str(e)}")
103 | idx = np.random.randint(len(self))
104 | raise RuntimeError('Too many bad data.')
105 |
106 | def get_data_info(self, idx):
107 | data_info = self.meta_data_clean[idx]
108 | return {'height': data_info['height'], 'width': data_info['width']}
109 |
110 | @staticmethod
111 | def vae_feat_loader(path):
112 | # [mean, std]
113 | mean, std = torch.from_numpy(np.load(path)).chunk(2)
114 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
115 | return mean + std * sample
116 |
117 | def load_ori_img(self, img_path):
118 |         # Load the image and convert it to a tensor
119 | transform = T.Compose([
120 | T.Resize(256), # Image.BICUBIC
121 | T.CenterCrop(256),
122 | T.ToTensor(),
123 | ])
124 | return transform(Image.open(img_path))
125 |
126 | def load_json(self, file_path):
127 | with open(file_path, 'r') as f:
128 | meta_data = json.load(f)
129 |
130 | return meta_data
131 |
132 | def sample_subset(self, ratio):
133 | sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio))
134 | self.img_samples = [self.img_samples[i] for i in sampled_idx]
135 |
136 | def __len__(self):
137 | return len(self.img_samples)
138 |
139 | def __getattr__(self, name):
140 | if name == "set_epoch":
141 | return lambda epoch: None
142 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
143 |
144 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/InternalData_ms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import random
5 | from torchvision.datasets.folder import default_loader
6 | from diffusion.data.datasets.InternalData import InternalData
7 | from diffusion.data.builder import get_data_path, DATASETS
8 | from diffusion.utils.logger import get_root_logger
9 | import torchvision.transforms as T
10 | from torchvision.transforms.functional import InterpolationMode
11 | from diffusion.data.datasets.utils import *
12 |
13 | def get_closest_ratio(height: float, width: float, ratios: dict):
14 | aspect_ratio = height / width
15 | closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
16 | return ratios[closest_ratio], float(closest_ratio)
17 |
18 |
19 | @DATASETS.register_module()
20 | class InternalDataMS(InternalData):
21 | def __init__(self,
22 | root,
23 | image_list_json='data_info.json',
24 | transform=None,
25 | resolution=256,
26 | sample_subset=None,
27 | load_vae_feat=False,
28 | input_size=32,
29 | patch_size=2,
30 | mask_ratio=0.0,
31 | mask_type='null',
32 | load_mask_index=False,
33 | max_length=120,
34 | config=None,
35 | **kwargs):
36 | self.root = get_data_path(root)
37 | self.transform = transform
38 | self.load_vae_feat = load_vae_feat
39 | self.ori_imgs_nums = 0
40 | self.resolution = resolution
41 | self.N = int(resolution // (input_size // patch_size))
42 | self.mask_ratio = mask_ratio
43 | self.load_mask_index = load_mask_index
44 | self.mask_type = mask_type
45 | self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
46 | self.max_lenth = max_length
47 | self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio
48 | self.meta_data_clean = []
49 | self.img_samples = []
50 | self.txt_feat_samples = []
51 | self.vae_feat_samples = []
52 | self.mask_index_samples = []
53 | self.ratio_index = {}
54 | self.ratio_nums = {}
55 | for k, v in self.aspect_ratio.items():
56 | self.ratio_index[float(k)] = [] # used for self.getitem
57 | self.ratio_nums[float(k)] = 0 # used for batch-sampler
58 |
59 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
60 | for json_file in image_list_json:
61 | meta_data = self.load_json(os.path.join(self.root, 'partition_filter', json_file))
62 | self.ori_imgs_nums += len(meta_data)
63 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
64 | self.meta_data_clean.extend(meta_data_clean)
65 | self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean])
66 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean])
67 | self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean])
68 |
69 | # Set loader and extensions
70 | if load_vae_feat:
71 | self.transform = None
72 | self.loader = self.vae_feat_loader
73 | else:
74 | self.loader = default_loader
75 |
76 | if sample_subset is not None:
77 | self.sample_subset(sample_subset) # sample dataset for local debug
78 |
79 | # scan the dataset for ratio static
80 | for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]):
81 | ori_h, ori_w = info['height'], info['width']
82 | closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
83 | self.ratio_nums[closest_ratio] += 1
84 | if len(self.ratio_index[closest_ratio]) == 0:
85 | self.ratio_index[closest_ratio].append(i)
86 | # print(self.ratio_nums)
87 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
88 | logger.info(f"T5 max token length: {self.max_lenth}")
89 |
90 | def getdata(self, index):
91 | img_path = self.img_samples[index]
92 | npz_path = self.txt_feat_samples[index]
93 | npy_path = self.vae_feat_samples[index]
94 | ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width']
95 |
96 | # Calculate the closest aspect ratio and resize & crop image[w, h]
97 | closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
98 | closest_size = list(map(lambda x: int(x), closest_size))
99 | self.closest_ratio = closest_ratio
100 |
101 | if self.load_vae_feat:
102 | try:
103 | img = self.loader(npy_path)
104 | if index not in self.ratio_index[closest_ratio]:
105 | self.ratio_index[closest_ratio].append(index)
106 | except Exception:
107 | index = random.choice(self.ratio_index[closest_ratio])
108 | return self.getdata(index)
109 | h, w = (img.shape[1], img.shape[2])
110 |             assert (h, w) == (ori_h // 8, ori_w // 8)
111 | else:
112 | img = self.loader(img_path)
113 | h, w = (img.size[1], img.size[0])
114 |             assert (h, w) == (ori_h, ori_w)
115 |
116 | data_info = {'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32)}
117 | data_info['aspect_ratio'] = closest_ratio
118 | data_info["mask_type"] = self.mask_type
119 |
120 | txt_info = np.load(npz_path)
121 | txt_fea = torch.from_numpy(txt_info['caption_feature'])
122 | attention_mask = torch.ones(1, 1, txt_fea.shape[1])
123 | if 'attention_mask' in txt_info.keys():
124 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
125 |
126 | if not self.load_vae_feat:
127 | if closest_size[0] / ori_h > closest_size[1] / ori_w:
128 | resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
129 | else:
130 | resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
131 | self.transform = T.Compose([
132 | T.Lambda(lambda img: img.convert('RGB')),
133 | T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC
134 | T.CenterCrop(closest_size),
135 | T.ToTensor(),
136 | T.Normalize([.5], [.5]),
137 | ])
138 |
139 | if self.transform:
140 | img = self.transform(img)
141 |
142 | return img, txt_fea, attention_mask, data_info
143 |
144 | def __getitem__(self, idx):
145 | for _ in range(20):
146 | try:
147 | return self.getdata(idx)
148 | except Exception as e:
149 | print(f"Error details: {str(e)}")
150 | idx = random.choice(self.ratio_index[self.closest_ratio])
151 | raise RuntimeError('Too many bad data.')
152 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/SA.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
8 | from torch.utils.data import Dataset
9 | from diffusers.utils.torch_utils import randn_tensor
10 |
11 | from diffusion.data.builder import get_data_path, DATASETS
12 |
13 |
14 | @DATASETS.register_module()
15 | class SAM(Dataset):
16 | def __init__(self,
17 | root,
18 | image_list_txt='part0.txt',
19 | transform=None,
20 | resolution=256,
21 | sample_subset=None,
22 | load_vae_feat=False,
23 | mask_ratio=0.0,
24 | mask_type='null',
25 | **kwargs):
26 | self.root = get_data_path(root)
27 | self.transform = transform
28 | self.load_vae_feat = load_vae_feat
29 | self.mask_type = mask_type
30 | self.mask_ratio = mask_ratio
31 | self.resolution = resolution
32 | self.img_samples = []
33 | self.txt_feat_samples = []
34 | self.vae_feat_samples = []
35 |         image_list_txt = image_list_txt if isinstance(image_list_txt, list) or image_list_txt == 'all' else [image_list_txt]
36 | if image_list_txt == 'all':
37 | image_list_txts = os.listdir(os.path.join(self.root, 'partition'))
38 | for txt in image_list_txts:
39 | image_list = os.path.join(self.root, 'partition', txt)
40 | with open(image_list, 'r') as f:
41 | lines = [line.strip() for line in f.readlines()]
42 | self.img_samples.extend([os.path.join(self.root, 'images', i+'.jpg') for i in lines])
43 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', i+'.npz') for i in lines])
44 | elif isinstance(image_list_txt, list):
45 | for txt in image_list_txt:
46 | image_list = os.path.join(self.root, 'partition', txt)
47 | with open(image_list, 'r') as f:
48 | lines = [line.strip() for line in f.readlines()]
49 | self.img_samples.extend([os.path.join(self.root, 'images', i + '.jpg') for i in lines])
50 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', i + '.npz') for i in lines])
51 | self.vae_feat_samples.extend([os.path.join(self.root, 'img_vae_feature/train_vae_256/noflip', i + '.npy') for i in lines])
52 |
53 | self.ori_imgs_nums = len(self)
54 | # self.img_samples = self.img_samples[:10000]
55 | # Set loader and extensions
56 | if load_vae_feat:
57 | self.transform = None
58 | self.loader = self.vae_feat_loader
59 | else:
60 | self.loader = default_loader
61 |
62 | if sample_subset is not None:
63 | self.sample_subset(sample_subset) # sample dataset for local debug
64 |
65 | def getdata(self, idx):
66 | img_path = self.img_samples[idx]
67 | npz_path = self.txt_feat_samples[idx]
68 | npy_path = self.vae_feat_samples[idx]
69 | data_info = {'img_hw': torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
70 | 'aspect_ratio': torch.tensor(1.)}
71 |
72 | img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path)
73 | npz_info = np.load(npz_path)
74 | txt_fea = torch.from_numpy(npz_info['caption_feature'])
75 | attention_mask = torch.ones(1, 1, txt_fea.shape[1])
76 | if 'attention_mask' in npz_info.keys():
77 | attention_mask = torch.from_numpy(npz_info['attention_mask'])[None]
78 |
79 | if self.transform:
80 | img = self.transform(img)
81 |
82 | data_info["mask_type"] = self.mask_type
83 |
84 | return img, txt_fea, attention_mask, data_info
85 |
86 | def __getitem__(self, idx):
87 | for _ in range(20):
88 | try:
89 | return self.getdata(idx)
90 | except Exception:
91 | print(self.img_samples[idx], ' info is not correct')
92 | idx = np.random.randint(len(self))
93 | raise RuntimeError('Too many bad data.')
94 |
95 | @staticmethod
96 | def vae_feat_loader(path):
97 | # [mean, std]
98 | mean, std = torch.from_numpy(np.load(path)).chunk(2)
99 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
100 | return mean + std * sample
101 | # return mean
102 |
103 | def sample_subset(self, ratio):
104 | sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio))
105 | self.img_samples = [self.img_samples[i] for i in sampled_idx]
106 | self.txt_feat_samples = [self.txt_feat_samples[i] for i in sampled_idx]
107 |
108 | def __len__(self):
109 | return len(self.img_samples)
110 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .SA import SAM
2 | from .InternalData import InternalData
3 | from .InternalData_ms import InternalDataMS
4 | from .Dreambooth import DreamBooth
5 | from .pixart_control import InternalDataHed
6 | from .utils import *
7 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | ASPECT_RATIO_1024 = {
4 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
5 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
6 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
7 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
8 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
9 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
10 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
11 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
12 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
13 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
14 | }
15 |
16 | ASPECT_RATIO_512 = {
17 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
18 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
19 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
20 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
21 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
22 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
23 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
24 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
25 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
26 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
27 | }
28 |
29 | ASPECT_RATIO_256 = {
30 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
31 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
32 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
33 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
34 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
35 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
36 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
37 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
38 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
39 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
40 | }
41 |
42 | ASPECT_RATIO_256_TEST = {
43 | '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
44 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
45 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
46 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
47 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
48 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
49 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
50 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
51 | '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
52 | '4.0': [512.0, 128.0]
53 | }
54 |
55 | ASPECT_RATIO_512_TEST = {
56 | '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
57 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
58 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
59 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
60 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
61 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
62 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
63 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
64 | '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
65 | '4.0': [1024.0, 256.0]
66 | }
67 |
68 | ASPECT_RATIO_1024_TEST = {
69 | '0.25': [512., 2048.], '0.28': [512., 1856.],
70 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
71 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
72 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
73 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
74 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
75 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
76 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
77 | '2.5': [1600., 640.], '3.0': [1728., 576.],
78 | '4.0': [2048., 512.],
79 | }
80 |
81 |
82 | def get_chunks(lst, n):
83 | for i in range(0, len(lst), n):
84 | yield lst[i:i + n]
85 |
--------------------------------------------------------------------------------
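A small sketch (not from the repo) showing how these bucket tables are meant to be used: pick the bucket whose key is closest to `height / width`, mirroring the `get_closest_ratio` helper in `InternalData_ms.py`.

```python
from diffusion.data.datasets.utils import ASPECT_RATIO_512, get_chunks

def closest_bucket(height, width, ratios=ASPECT_RATIO_512):
    aspect_ratio = height / width
    key = min(ratios.keys(), key=lambda r: abs(float(r) - aspect_ratio))
    return ratios[key], float(key)

size, ratio = closest_bucket(720, 1280)       # a 16:9 image
print(size, ratio)                            # [384.0, 672.0] 0.57

# get_chunks simply yields fixed-size slices of a list.
print(list(get_chunks(list(range(7)), 3)))    # [[0, 1, 2], [3, 4, 5], [6]]
```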
/diffusion/data/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 |
3 | TRANSFORMS = {}
4 |
5 |
6 | def register_transform(transform):
7 | name = transform.__name__
8 | if name in TRANSFORMS:
9 |         raise RuntimeError(f'Transform {name} has already been registered.')
10 | TRANSFORMS.update({name: transform})
11 |
12 |
13 | def get_transform(type, resolution):
14 | transform = TRANSFORMS[type](resolution)
15 | transform = T.Compose(transform)
16 | transform.image_size = resolution
17 | return transform
18 |
19 |
20 | @register_transform
21 | def default_train(n_px):
22 | return [
23 | T.Lambda(lambda img: img.convert('RGB')),
24 | T.Resize(n_px), # Image.BICUBIC
25 | T.CenterCrop(n_px),
26 | # T.RandomHorizontalFlip(),
27 | T.ToTensor(),
28 | T.Normalize([0.5], [0.5]),
29 | ]
30 |
--------------------------------------------------------------------------------
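A sketch of how an additional pipeline could be registered through the same mechanism (the transform name is hypothetical; the repo itself only registers `default_train`).

```python
import torchvision.transforms as T
from diffusion.data.transforms import register_transform, get_transform

@register_transform
def train_with_hflip(n_px):
    return [
        T.Lambda(lambda img: img.convert('RGB')),
        T.Resize(n_px),
        T.CenterCrop(n_px),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ]

# Retrieved by name from the registry, e.g. via transform='train_with_hflip' in a dataset config.
transform = get_transform('train_with_hflip', 512)
```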
/diffusion/dpm_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .model import gaussian_diffusion as gd
3 | from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
4 |
5 |
6 | def DPMS(model, condition, uncondition, cfg_scale, model_type='noise', noise_schedule="linear", guidance_type='classifier-free', model_kwargs=None, diffusion_steps=1000):
7 | if model_kwargs is None:
8 | model_kwargs = {}
9 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
10 |
11 | ## 1. Define the noise schedule.
12 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
13 |
14 | ## 2. Convert your discrete-time `model` to the continuous-time
15 | ## noise prediction model. Here is an example for a diffusion model
16 |     ## `model` with the noise prediction type ("noise").
17 | model_fn = model_wrapper(
18 | model,
19 | noise_schedule,
20 | model_type=model_type,
21 | model_kwargs=model_kwargs,
22 | guidance_type=guidance_type,
23 | condition=condition,
24 | unconditional_condition=uncondition,
25 | guidance_scale=cfg_scale,
26 | )
27 | ## 3. Define dpm-solver and sample by multistep DPM-Solver.
28 | return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
--------------------------------------------------------------------------------
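A minimal sampling sketch under stated assumptions: the model here is a stand-in callable (the real pipeline passes the TIDE/PixArt transformer), the condition tensors are placeholders, and the `sample` arguments follow the standard DPM-Solver++ multistep interface used by PixArt.

```python
import torch
from diffusion import DPMS

# Stand-in noise-prediction network: anything mapping (x, t, cond, ...) to a
# tensor shaped like x works; the zeros here are dummy values.
def dummy_model(x, t, *args, **kwargs):
    return torch.zeros_like(x)

cond = torch.zeros(1, 1, 120, 4096)     # placeholder text-condition features
uncond = torch.zeros_like(cond)         # placeholder null-condition features

z = torch.randn(1, 4, 64, 64)           # initial latent noise
dpm_solver = DPMS(dummy_model, condition=cond, uncondition=uncond, cfg_scale=4.5)
samples = dpm_solver.sample(
    z, steps=20, order=2, skip_type="time_uniform", method="multistep",
)
print(samples.shape)                    # torch.Size([1, 4, 64, 64])
```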
/diffusion/iddpm.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 | from diffusion.model.respace import SpacedDiffusion, space_timesteps
6 | from .model import gaussian_diffusion as gd
7 |
8 |
9 | def IDDPM(
10 | timestep_respacing,
11 | noise_schedule="linear",
12 | use_kl=False,
13 | sigma_small=False,
14 | predict_xstart=False,
15 | learn_sigma=True,
16 | pred_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000,
19 | snr=False,
20 | return_startx=False,
21 | ):
22 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
23 | if use_kl:
24 | loss_type = gd.LossType.RESCALED_KL
25 | elif rescale_learned_sigmas:
26 | loss_type = gd.LossType.RESCALED_MSE
27 | else:
28 | loss_type = gd.LossType.MSE
29 | if timestep_respacing is None or timestep_respacing == "":
30 | timestep_respacing = [diffusion_steps]
31 | return SpacedDiffusion(
32 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
33 | betas=betas,
34 | model_mean_type=(
35 | gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
36 | ),
37 | model_var_type=(
38 | (gd.ModelVarType.LEARNED_RANGE if learn_sigma else (
39 | gd.ModelVarType.FIXED_LARGE
40 | if not sigma_small
41 | else gd.ModelVarType.FIXED_SMALL
42 | )
43 | )
44 | if pred_sigma
45 | else None
46 | ),
47 | loss_type=loss_type,
48 | snr=snr,
49 | return_startx=return_startx,
50 | # rescale_timesteps=rescale_timesteps,
51 | )
--------------------------------------------------------------------------------
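A rough training-side sketch assuming the standard OpenAI-style `training_losses` interface that this module adapts; the model is a dummy that outputs noise plus variance channels (needed because `learn_sigma=True` by default), and the shapes are illustrative.

```python
import torch
from diffusion import IDDPM

def dummy_model(x, t, **kwargs):
    # Predicts epsilon and the variance-interpolation values: 2*C output channels.
    return torch.cat([torch.zeros_like(x), torch.zeros_like(x)], dim=1)

diffusion = IDDPM(timestep_respacing="")                 # full 1000-step schedule
latents = torch.randn(4, 4, 64, 64)                      # x_0 batch, e.g. VAE latents
t = torch.randint(0, diffusion.num_timesteps, (latents.shape[0],))

loss_dict = diffusion.training_losses(dummy_model, latents, t)
print(loss_dict["loss"].mean())
```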
/diffusion/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .nets import *
2 |
--------------------------------------------------------------------------------
/diffusion/model/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv import Registry
2 |
3 | from diffusion.model.utils import set_grad_checkpoint
4 |
5 | MODELS = Registry('models')
6 |
7 |
8 | def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
9 | if isinstance(cfg, str):
10 | cfg = dict(type=cfg)
11 | model = MODELS.build(cfg, default_args=kwargs)
12 | if use_grad_checkpoint:
13 | set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
14 | return model
15 |
--------------------------------------------------------------------------------
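An illustrative call (the model name and keyword arguments are assumptions based on the PixArt nets registered under `diffusion/model/nets`; they are not prescribed by this file).

```python
from diffusion.model.builder import build_model

# A string config is promoted to dict(type=...) and built from the MODELS registry.
model = build_model(
    "PixArt_XL_2",              # assumed registered network name
    use_grad_checkpoint=True,   # wraps blocks via set_grad_checkpoint()
    input_size=64,              # latent resolution, e.g. 512px images / 8
    lewei_scale=1.0,
)
```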
/diffusion/model/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = next(
17 | (
18 | obj
19 | for obj in (mean1, logvar1, mean2, logvar2)
20 | if isinstance(obj, th.Tensor)
21 | ),
22 | None,
23 | )
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a continuous Gaussian distribution.
53 | :param x: the targets
54 | :param means: the Gaussian mean Tensor.
55 | :param log_scales: the Gaussian log stddev Tensor.
56 | :return: a tensor like x of log probabilities (in nats).
57 | """
58 | centered_x = x - means
59 | inv_stdv = th.exp(-log_scales)
60 | normalized_x = centered_x * inv_stdv
61 | return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
62 | normalized_x
63 | )
64 |
65 |
66 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
67 | """
68 | Compute the log-likelihood of a Gaussian distribution discretizing to a
69 | given image.
70 | :param x: the target images. It is assumed that this was uint8 values,
71 | rescaled to the range [-1, 1].
72 | :param means: the Gaussian mean Tensor.
73 | :param log_scales: the Gaussian log stddev Tensor.
74 | :return: a tensor like x of log probabilities (in nats).
75 | """
76 | assert x.shape == means.shape == log_scales.shape
77 | centered_x = x - means
78 | inv_stdv = th.exp(-log_scales)
79 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
80 | cdf_plus = approx_standard_normal_cdf(plus_in)
81 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
82 | cdf_min = approx_standard_normal_cdf(min_in)
83 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
84 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
85 | cdf_delta = cdf_plus - cdf_min
86 | log_probs = th.where(
87 | x < -0.999,
88 | log_cdf_plus,
89 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
90 | )
91 | assert log_probs.shape == x.shape
92 | return log_probs
93 |
--------------------------------------------------------------------------------
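
A small, hedged sanity check of the helpers above: the KL term vanishes when both Gaussians share the same parameters, and the discretized log-likelihood returns one log-probability per element. The tensor shapes are illustrative only.

    # Illustrative check (not part of the original file).
    import torch as th
    from diffusion.model.diffusion_utils import (
        normal_kl,
        discretized_gaussian_log_likelihood,
    )

    mean = th.zeros(2, 3, 8, 8)
    logvar = th.zeros(2, 3, 8, 8)

    # KL(N(0, 1) || N(0, 1)) is zero element-wise.
    kl = normal_kl(mean, logvar, mean, logvar)
    assert th.allclose(kl, th.zeros_like(kl))

    # Log-likelihood of targets in [-1, 1] under a matching unit-variance Gaussian.
    x = th.rand(2, 3, 8, 8) * 2 - 1
    log_probs = discretized_gaussian_log_likelihood(x, means=mean, log_scales=logvar)
    assert log_probs.shape == x.shape  # per-element log-probabilities in nats
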
/diffusion/model/edm_sample.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 | from diffusion.model.utils import *
6 |
7 |
8 | # ----------------------------------------------------------------------------
9 | # Proposed EDM sampler (Algorithm 2).
10 |
11 | def edm_sampler(
12 | net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
13 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
14 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
15 | ):
16 | # Adjust noise levels based on what's supported by the network.
17 | sigma_min = max(sigma_min, net.sigma_min)
18 | sigma_max = min(sigma_max, net.sigma_max)
19 |
20 | # Time step discretization.
21 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
22 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
23 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
24 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
25 |
26 | # Main sampling loop.
27 | x_next = latents.to(torch.float64) * t_steps[0]
28 | for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1
29 | x_cur = x_next
30 |
31 | # Increase noise temporarily.
32 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
33 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
34 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
35 |
36 | # Euler step.
37 | denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
38 | d_cur = (x_hat - denoised) / t_hat
39 | x_next = x_hat + (t_next - t_hat) * d_cur
40 |
41 | # Apply 2nd order correction.
42 | if i < num_steps - 1:
43 | denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
44 | d_prime = (x_next - denoised) / t_next
45 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
46 |
47 | return x_next
48 |
49 |
50 | # ----------------------------------------------------------------------------
51 | # Generalized ablation sampler, representing the superset of all sampling
52 | # methods discussed in the paper.
53 |
54 | def ablation_sampler(
55 | net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
56 | num_steps=18, sigma_min=None, sigma_max=None, rho=7,
57 | solver='heun', discretization='edm', schedule='linear', scaling='none',
58 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
59 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
60 | ):
61 | assert solver in ['euler', 'heun']
62 | assert discretization in ['vp', 've', 'iddpm', 'edm']
63 | assert schedule in ['vp', 've', 'linear']
64 | assert scaling in ['vp', 'none']
65 |
66 | # Helper functions for VP & VE noise level schedules.
67 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
68 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
69 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
70 | sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
71 | ve_sigma = lambda t: t.sqrt()
72 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
73 | ve_sigma_inv = lambda sigma: sigma ** 2
74 |
75 | # Select default noise level range based on the specified time step discretization.
76 | if sigma_min is None:
77 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
78 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
79 | if sigma_max is None:
80 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
81 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
82 |
83 | # Adjust noise levels based on what's supported by the network.
84 | sigma_min = max(sigma_min, net.sigma_min)
85 | sigma_max = min(sigma_max, net.sigma_max)
86 |
87 | # Compute corresponding betas for VP.
88 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
89 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
90 |
91 | # Define time steps in terms of noise level.
92 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
93 | if discretization == 'vp':
94 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
95 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
96 | elif discretization == 've':
97 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
98 | sigma_steps = ve_sigma(orig_t_steps)
99 | elif discretization == 'iddpm':
100 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
101 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
102 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
103 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
104 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
105 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
106 | else:
107 | assert discretization == 'edm'
108 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
109 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
110 |
111 | # Define noise level schedule.
112 | if schedule == 'vp':
113 | sigma = vp_sigma(vp_beta_d, vp_beta_min)
114 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
115 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
116 | elif schedule == 've':
117 | sigma = ve_sigma
118 | sigma_deriv = ve_sigma_deriv
119 | sigma_inv = ve_sigma_inv
120 | else:
121 | assert schedule == 'linear'
122 | sigma = lambda t: t
123 | sigma_deriv = lambda t: 1
124 | sigma_inv = lambda sigma: sigma
125 |
126 | # Define scaling schedule.
127 | if scaling == 'vp':
128 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
129 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
130 | else:
131 | assert scaling == 'none'
132 | s = lambda t: 1
133 | s_deriv = lambda t: 0
134 |
135 | # Compute final time steps based on the corresponding noise levels.
136 | t_steps = sigma_inv(net.round_sigma(sigma_steps))
137 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
138 |
139 | # Main sampling loop.
140 | t_next = t_steps[0]
141 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
142 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
143 | x_cur = x_next
144 |
145 | # Increase noise temporarily.
146 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
147 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
148 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
149 | t_hat) * S_noise * randn_like(x_cur)
150 |
151 | # Euler step.
152 | h = t_next - t_hat
153 | denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
154 | torch.float64)
155 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
156 | t_hat) / sigma(t_hat) * denoised
157 | x_prime = x_hat + alpha * h * d_cur
158 | t_prime = t_hat + alpha * h
159 |
160 | # Apply 2nd order correction.
161 | if solver == 'euler' or i == num_steps - 1:
162 | x_next = x_hat + h * d_cur
163 | else:
164 | assert solver == 'heun'
165 | denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
166 | torch.float64)
167 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
168 | t_prime) * s(t_prime) / sigma(t_prime) * denoised
169 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
170 |
171 | return x_next
172 |
--------------------------------------------------------------------------------
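
A hedged sketch of the interface `edm_sampler` expects from `net`: `sigma_min` / `sigma_max` attributes, a `round_sigma()` method, and a call that returns a dict with key `'x'`. `DummyDenoiser` is a stand-in for illustration, not a model from this repo.

    # Hedged sketch; DummyDenoiser is illustrative only.
    import torch
    from diffusion.model.edm_sample import edm_sampler

    class DummyDenoiser:
        sigma_min, sigma_max = 0.002, 80.0

        def round_sigma(self, sigma):
            return torch.as_tensor(sigma)

        def __call__(self, x, sigma, class_labels, cfg_scale, **kwargs):
            return {'x': torch.zeros_like(x)}  # pretend the denoised estimate is 0

    latents = torch.randn(1, 4, 32, 32)
    sample = edm_sampler(DummyDenoiser(), latents, num_steps=10)
    print(sample.shape)  # torch.Size([1, 4, 32, 32])
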
/diffusion/model/hed.py:
--------------------------------------------------------------------------------
1 | # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2 | # Please use this implementation in your products
3 | # This implementation may produce slightly different results from Saining Xie's official implementations,
4 | # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5 | # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6 | # and in this way it works better for gradio's RGB protocol
7 | import sys
8 | from pathlib import Path
9 | current_file_path = Path(__file__).resolve()
10 | sys.path.insert(0, str(current_file_path.parent.parent.parent))
11 | from torch import nn
12 | import torch
13 | import numpy as np
14 | from torchvision import transforms as T
15 | from tqdm import tqdm
16 | from torch.utils.data import Dataset, DataLoader
17 | import json
18 | from PIL import Image
19 | import torchvision.transforms.functional as TF
20 | from accelerate import Accelerator
21 | from diffusers.models import AutoencoderKL
22 | import os
23 |
24 | image_resize = 1024
25 |
26 |
27 | class DoubleConvBlock(nn.Module):
28 | def __init__(self, input_channel, output_channel, layer_number):
29 | super().__init__()
30 | self.convs = torch.nn.Sequential()
31 | self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
32 | for i in range(1, layer_number):
33 | self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
34 | self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
35 |
36 | def forward(self, x, down_sampling=False):
37 | h = x
38 | if down_sampling:
39 | h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
40 | for conv in self.convs:
41 | h = conv(h)
42 | h = torch.nn.functional.relu(h)
43 | return h, self.projection(h)
44 |
45 |
46 | class ControlNetHED_Apache2(nn.Module):
47 | def __init__(self):
48 | super().__init__()
49 | self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
50 | self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
51 | self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
52 | self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
53 | self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
54 | self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
55 |
56 | def forward(self, x):
57 | h = x - self.norm
58 | h, projection1 = self.block1(h)
59 | h, projection2 = self.block2(h, down_sampling=True)
60 | h, projection3 = self.block3(h, down_sampling=True)
61 | h, projection4 = self.block4(h, down_sampling=True)
62 | h, projection5 = self.block5(h, down_sampling=True)
63 | return projection1, projection2, projection3, projection4, projection5
64 |
65 |
66 | class InternData(Dataset):
67 | def __init__(self):
68 | ####
69 | with open('data/InternData/partition/data_info.json', 'r') as f:
70 | self.j = json.load(f)
71 | self.transform = T.Compose([
72 | T.Lambda(lambda img: img.convert('RGB')),
73 | T.Resize(image_resize), # Image.BICUBIC
74 | T.CenterCrop(image_resize),
75 | T.ToTensor(),
76 | ])
77 |
78 | def __len__(self):
79 | return len(self.j)
80 |
81 | def getdata(self, idx):
82 |
83 | path = self.j[idx]['path']
84 | image = Image.open("data/InternImgs/" + path)
85 | image = self.transform(image)
86 | return image, path
87 |
88 | def __getitem__(self, idx):
89 | for i in range(20):
90 | try:
91 | data = self.getdata(idx)
92 | return data
93 | except Exception as e:
94 | print(f"Error details: {str(e)}")
95 | idx = np.random.randint(len(self))
96 | raise RuntimeError('Too many bad data.')
97 |
98 | class HEDdetector(nn.Module):
99 | def __init__(self, feature=True, vae=None):
100 | super().__init__()
101 | self.model = ControlNetHED_Apache2()
102 | self.model.load_state_dict(torch.load('output/pretrained_models/ControlNetHED.pth', map_location='cpu'))
103 | self.model.eval()
104 | self.model.requires_grad_(False)
105 | if feature:
106 | if vae is None:
107 | self.vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema")
108 | else:
109 | self.vae = vae
110 | self.vae.eval()
111 | self.vae.requires_grad_(False)
112 | else:
113 | self.vae = None
114 |
115 | def forward(self, input_image):
116 | B, C, H, W = input_image.shape
117 | with torch.inference_mode():
118 | edges = self.model(input_image * 255.)
119 | edges = torch.cat([TF.resize(e, [H, W]) for e in edges], dim=1)
120 | edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1, keepdim=True)))
121 | edge.clip_(0, 1)
122 | if self.vae:
123 | edge = TF.normalize(edge, [.5], [.5])
124 | edge = edge.repeat(1, 3, 1, 1)
125 | posterior = self.vae.encode(edge).latent_dist
126 | edge = torch.cat([posterior.mean, posterior.std], dim=1).cpu().numpy()
127 | return edge
128 |
129 |
130 | def main():
131 | dataset = InternData()
132 | dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=8, pin_memory=True)
133 | hed = HEDdetector()
134 |
135 | accelerator = Accelerator()
136 | hed, dataloader = accelerator.prepare(hed, dataloader)
137 |
138 |
139 | for img, path in tqdm(dataloader):
140 | out = hed(img.cuda())
141 | for p, o in zip(path, out):
142 | save = f'data/InternalData/hed_feature_{image_resize}/' + p.replace('.png', '.npz')
143 | if os.path.exists(save):
144 | continue
145 | os.makedirs(os.path.dirname(save), exist_ok=True)
146 | np.savez_compressed(save, o)
147 |
148 |
149 | if __name__ == "__main__":
150 | main()
151 |
--------------------------------------------------------------------------------
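
A hedged sketch of running the detector above on a single image. It assumes the checkpoint paths hard-coded in `HEDdetector.__init__` exist locally; `example.jpg` is a placeholder filename.

    # Hedged sketch; 'example.jpg' is a placeholder and the checkpoints must exist
    # at the paths hard-coded in HEDdetector.__init__.
    from PIL import Image
    from torchvision import transforms as T
    from diffusion.model.hed import HEDdetector

    hed = HEDdetector(feature=False)  # skip the VAE branch, return the raw edge map
    img = T.ToTensor()(Image.open('example.jpg').convert('RGB')).unsqueeze(0)
    edge = hed(img)                   # (1, 1, H, W) soft edge map in [0, 1]
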
/diffusion/model/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
--------------------------------------------------------------------------------
/diffusion/model/llava/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 | from .attention import ATTN_CLASS_REGISTRY
6 | from .norm import NORM_CLASS_REGISTRY
7 |
8 | class MPTMLP(nn.Module):
9 |
10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11 | super().__init__()
12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13 | self.act = nn.GELU(approximate='none')
14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15 | self.down_proj._is_residual = True
16 |
17 | def forward(self, x):
18 | return self.down_proj(self.act(self.up_proj(x)))
19 |
20 | class MPTBlock(nn.Module):
21 |
22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict = None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23 | if attn_config is None:
24 | attn_config = {
25 | 'attn_type': 'multihead_attention',
26 | 'attn_pdrop': 0.0,
27 | 'attn_impl': 'triton',
28 | 'qk_ln': False,
29 | 'clip_qkv': None,
30 | 'softmax_scale': None,
31 | 'prefix_lm': False,
32 | 'attn_uses_sequence_id': False,
33 | 'alibi': False,
34 | 'alibi_bias_max': 8,
35 | }
36 | del kwargs
37 | super().__init__()
38 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
39 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
40 | self.norm_1 = norm_class(d_model, device=device)
41 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
42 | self.norm_2 = norm_class(d_model, device=device)
43 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
44 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
45 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
46 |
47 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
48 | a = self.norm_1(x)
49 | (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
50 | x = x + self.resid_attn_dropout(b)
51 | m = self.norm_2(x)
52 | n = self.ffn(m)
53 | x = x + self.resid_ffn_dropout(n)
54 | return (x, past_key_value)
--------------------------------------------------------------------------------
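
A hedged sketch of running one block in isolation with the plain `torch` attention path (the default `attn_impl` is `'triton'`, which requires the triton kernels to be installed); the hyperparameters are illustrative only.

    # Hedged sketch; hyperparameters are illustrative, attn_impl switched to 'torch'.
    import torch
    from diffusion.model.llava.mpt.blocks import MPTBlock

    block = MPTBlock(
        d_model=256, n_heads=8, expansion_ratio=4,
        attn_config={
            'attn_type': 'multihead_attention', 'attn_pdrop': 0.0,
            'attn_impl': 'torch', 'qk_ln': False, 'clip_qkv': None,
            'softmax_scale': None, 'prefix_lm': False,
            'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8,
        },
    )
    x = torch.randn(2, 16, 256)   # (batch, seq_len, d_model)
    out, past_kv = block(x)       # causal self-attention + MLP with residuals
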
/diffusion/model/llava/mpt/configuration_mpt.py:
--------------------------------------------------------------------------------
1 | """A HuggingFace-style model configuration."""
2 | from typing import Dict, Optional, Union
3 | from transformers import PretrainedConfig
4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6 |
7 | class MPTConfig(PretrainedConfig):
8 | model_type = 'mpt'
9 |
10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
11 | """The MPT configuration class.
12 |
13 | Args:
14 | d_model (int): The size of the embedding dimension of the model.
15 | n_heads (int): The number of attention heads.
16 | n_layers (int): The number of layers in the model.
17 | expansion_ratio (int): The ratio of the up/down scale in the MLP.
18 | max_seq_len (int): The maximum sequence length of the model.
19 | vocab_size (int): The size of the vocabulary.
20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
21 | emb_pdrop (float): The dropout probability for the embedding layer.
22 | learned_pos_emb (bool): Whether to use learned positional embeddings
23 | attn_config (Dict): A dictionary used to configure the model's attention module:
24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
25 | attn_pdrop (float): The dropout probability for the attention layers.
26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
29 | this value.
30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
31 | use the default scale of ``1/sqrt(d_keys)``.
32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
37 | which sub-sequence each token belongs to.
38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
39 | alibi (bool): Whether to use the alibi bias instead of position embeddings.
40 | alibi_bias_max (int): The maximum value of the alibi bias.
41 | init_device (str): The device to use for parameter initialization.
42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
43 | no_bias (bool): Whether to use bias in all layers.
44 | verbose (int): The verbosity level. 0 is silent.
45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
46 | norm_type (str): choose type of norm to use
47 | multiquery_attention (bool): Whether to use multiquery attention implementation.
48 | use_cache (bool): Whether or not the model should return the last key/values attentions
49 | init_config (Dict): A dictionary used to configure the model initialization:
50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
57 | init_std (float): The standard deviation of the normal distribution used to initialize the model,
58 | if using the baseline_ parameter initialization scheme.
59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
62 | ---
63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
64 | """
65 | self.d_model = d_model
66 | self.n_heads = n_heads
67 | self.n_layers = n_layers
68 | self.expansion_ratio = expansion_ratio
69 | self.max_seq_len = max_seq_len
70 | self.vocab_size = vocab_size
71 | self.resid_pdrop = resid_pdrop
72 | self.emb_pdrop = emb_pdrop
73 | self.learned_pos_emb = learned_pos_emb
74 | self.attn_config = attn_config
75 | self.init_device = init_device
76 | self.logit_scale = logit_scale
77 | self.no_bias = no_bias
78 | self.verbose = verbose
79 | self.embedding_fraction = embedding_fraction
80 | self.norm_type = norm_type
81 | self.use_cache = use_cache
82 | self.init_config = init_config
83 | if 'name' in kwargs:
84 | del kwargs['name']
85 | if 'loss_fn' in kwargs:
86 | del kwargs['loss_fn']
87 | super().__init__(**kwargs)
88 | self._validate_config()
89 |
90 | def _set_config_defaults(self, config, config_defaults):
91 | for (k, v) in config_defaults.items():
92 | if k not in config:
93 | config[k] = v
94 | return config
95 |
96 | def _validate_config(self):
97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
99 | if self.d_model % self.n_heads != 0:
100 | raise ValueError('d_model must be divisible by n_heads')
101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
108 | raise NotImplementedError('alibi only implemented with torch and triton attention.')
109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
115 | if self.init_config.get('name', None) is None:
116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
117 | if not self.learned_pos_emb and (not self.attn_config['alibi']):
118 | raise ValueError(
119 | 'Positional information must be provided to the model using either learned_pos_emb or alibi.'
120 | )
--------------------------------------------------------------------------------
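
A hedged sketch of instantiating the config above; `_validate_config()` runs at the end of `__init__` and rejects inconsistent settings such as a `d_model` that is not divisible by `n_heads`. The values below are illustrative, not a recommended configuration.

    # Hedged sketch; values are illustrative.
    from diffusion.model.llava.mpt.configuration_mpt import MPTConfig

    cfg = MPTConfig(
        d_model=1024,
        n_heads=16,
        n_layers=12,
        max_seq_len=2048,
        vocab_size=50368,
        norm_type='low_precision_layernorm',
    )
    assert cfg.d_model % cfg.n_heads == 0  # enforced by _validate_config()
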
/diffusion/model/llava/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _cast_if_autocast_enabled(tensor):
4 | if torch.is_autocast_enabled():
5 | if tensor.device.type == 'cuda':
6 | dtype = torch.get_autocast_gpu_dtype()
7 | elif tensor.device.type == 'cpu':
8 | dtype = torch.get_autocast_cpu_dtype()
9 | else:
10 | raise NotImplementedError()
11 | return tensor.to(dtype=dtype)
12 | return tensor
13 |
14 | class LPLayerNorm(torch.nn.LayerNorm):
15 |
16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18 |
19 | def forward(self, x):
20 | module_device = x.device
21 | downcast_x = _cast_if_autocast_enabled(x)
22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24 | with torch.autocast(enabled=False, device_type=module_device.type):
25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26 |
27 | def rms_norm(x, weight=None, eps=1e-05):
28 |     output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29 | return output * weight if weight is not None else output
30 |
31 | class RMSNorm(torch.nn.Module):
32 |
33 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
34 | super().__init__()
35 | self.eps = eps
36 | if weight:
37 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
38 | else:
39 | self.register_parameter('weight', None)
40 |
41 | def forward(self, x):
42 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
43 |
44 | class LPRMSNorm(RMSNorm):
45 |
46 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
47 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
48 |
49 | def forward(self, x):
50 | downcast_x = _cast_if_autocast_enabled(x)
51 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
52 | with torch.autocast(enabled=False, device_type=x.device.type):
53 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
54 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
--------------------------------------------------------------------------------
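
A hedged sketch of looking up a norm class by name, mirroring how `MPTBlock` resolves `norm_type` through `NORM_CLASS_REGISTRY`; the feature size is illustrative.

    # Hedged sketch; shapes are illustrative.
    import torch
    from diffusion.model.llava.mpt.norm import NORM_CLASS_REGISTRY

    norm_cls = NORM_CLASS_REGISTRY['rmsnorm']   # plain RMSNorm (no autocast cast)
    norm = norm_cls(512)
    y = norm(torch.randn(4, 512))               # RMS-normalized over the last dim
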
/diffusion/model/llava/mpt/param_init_fns.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from collections.abc import Sequence
4 | from functools import partial
5 | from typing import Optional, Tuple, Union
6 | import torch
7 | from torch import nn
8 | from .norm import NORM_CLASS_REGISTRY
9 |
10 | def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
11 | del kwargs
12 | if verbose > 1:
13 | warnings.warn("Initializing network using module's reset_parameters attribute")
14 | if hasattr(module, 'reset_parameters'):
15 | module.reset_parameters()
16 |
17 | def fused_init_helper_(module: nn.Module, init_fn_):
18 | _fused = getattr(module, '_fused', None)
19 | if _fused is None:
20 | raise RuntimeError('Internal logic error')
21 | (dim, splits) = _fused
22 | splits = (0, *splits, module.weight.size(dim))
23 | for (s, e) in zip(splits[:-1], splits[1:]):
24 | slice_indices = [slice(None)] * module.weight.ndim
25 | slice_indices[dim] = slice(s, e)
26 | init_fn_(module.weight[slice_indices])
27 |
28 | def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
29 | del kwargs
30 | if verbose > 1:
31 | warnings.warn('If model has bias parameters they are initialized to 0.')
32 | init_div_is_residual = init_div_is_residual
33 | if init_div_is_residual is False:
34 | div_is_residual = 1.0
35 | elif init_div_is_residual is True:
36 | div_is_residual = math.sqrt(2 * n_layers)
37 | elif isinstance(init_div_is_residual, (float, int)):
38 | div_is_residual = init_div_is_residual
39 | elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
40 | div_is_residual = float(init_div_is_residual)
41 | else:
42 | div_is_residual = 1.0
43 | raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
44 | if init_div_is_residual is not False and verbose > 1:
45 | warnings.warn(
46 | f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. Set `init_div_is_residual: false` in init config to disable this.'
47 | )
48 | if isinstance(module, nn.Linear):
49 | if hasattr(module, '_fused'):
50 | fused_init_helper_(module, init_fn_)
51 | else:
52 | init_fn_(module.weight)
53 | if module.bias is not None:
54 | torch.nn.init.zeros_(module.bias)
55 | if init_div_is_residual is not False and getattr(module, '_is_residual', False):
56 | with torch.no_grad():
57 | module.weight.div_(div_is_residual)
58 | elif isinstance(module, nn.Embedding):
59 | if emb_init_std is not None:
60 | std = emb_init_std
61 | if std == 0:
62 | warnings.warn('Embedding layer initialized to 0.')
63 | emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
64 | if verbose > 1:
65 | warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
66 | elif emb_init_uniform_lim is not None:
67 | lim = emb_init_uniform_lim
68 | if isinstance(lim, Sequence):
69 | if len(lim) > 2:
70 | raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
71 | if lim[0] == lim[1]:
72 | warnings.warn(f'Embedding layer initialized to {lim[0]}.')
73 | else:
74 | if lim == 0:
75 | warnings.warn('Embedding layer initialized to 0.')
76 | lim = [-lim, lim]
77 | (a, b) = lim
78 | emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
79 | if verbose > 1:
80 | warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
81 | else:
82 | emb_init_fn_ = init_fn_
83 | emb_init_fn_(module.weight)
84 | elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
85 | if verbose > 1:
86 | warnings.warn(
87 | 'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.'
88 | )
89 | if hasattr(module, 'weight') and module.weight is not None:
90 | torch.nn.init.ones_(module.weight)
91 | if hasattr(module, 'bias') and module.bias is not None:
92 | torch.nn.init.zeros_(module.bias)
93 | elif isinstance(module, nn.MultiheadAttention):
94 | if module._qkv_same_embed_dim:
95 | _extracted_from_generic_param_init_fn__69(module, d_model, init_fn_)
96 | else:
97 | assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
98 | assert module.in_proj_weight is None
99 | init_fn_(module.q_proj_weight)
100 | init_fn_(module.k_proj_weight)
101 | init_fn_(module.v_proj_weight)
102 | if module.in_proj_bias is not None:
103 | torch.nn.init.zeros_(module.in_proj_bias)
104 | if module.bias_k is not None:
105 | torch.nn.init.zeros_(module.bias_k)
106 | if module.bias_v is not None:
107 | torch.nn.init.zeros_(module.bias_v)
108 | init_fn_(module.out_proj.weight)
109 | if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
110 | with torch.no_grad():
111 | module.out_proj.weight.div_(div_is_residual)
112 | if module.out_proj.bias is not None:
113 | torch.nn.init.zeros_(module.out_proj.bias)
114 | else:
115 | for _ in module.parameters(recurse=False):
116 | raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
117 |
118 |
119 | # TODO Rename this here and in `generic_param_init_fn_`
120 | def _extracted_from_generic_param_init_fn__69(module, d_model, init_fn_):
121 | assert module.in_proj_weight is not None
122 | assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
123 | assert d_model is not None
124 | _d = d_model
125 | splits = (0, _d, 2 * _d, 3 * _d)
126 | for (s, e) in zip(splits[:-1], splits[1:]):
127 | init_fn_(module.in_proj_weight[s:e])
128 |
129 | def _normal_init_(std, mean=0.0):
130 | return partial(torch.nn.init.normal_, mean=mean, std=std)
131 |
132 | def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
133 | del kwargs
134 | init_fn_ = _normal_init_(std=std)
135 | if verbose > 1:
136 | warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
137 | generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
138 |
139 | def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
140 | del kwargs
141 | if init_std is None:
142 | raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
143 | _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
144 |
145 | def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
146 | del kwargs
147 | std = math.sqrt(2 / (5 * d_model))
148 | _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
149 |
150 | def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
151 | """From section 2.3.1 of GPT-NeoX-20B:
152 |
153 |     An Open-Source Autoregressive Language Model — Black et al. (2022)
154 | see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
155 | and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
156 | """
157 | del kwargs
158 | residual_div = n_layers / math.sqrt(10)
159 | if verbose > 1:
160 | warnings.warn(f'setting init_div_is_residual to {residual_div}')
161 | small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
162 |
163 | def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
164 | del kwargs
165 | if verbose > 1:
166 | warnings.warn(
167 | f'Using nn.init.kaiming_uniform_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
168 | )
169 | kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
170 | generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
171 |
172 | def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
173 | del kwargs
174 | if verbose > 1:
175 | warnings.warn(
176 | f'Using nn.init.kaiming_normal_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
177 | )
178 | kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
179 | generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
180 |
181 | def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
182 | del kwargs
183 | xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
184 | if verbose > 1:
185 | warnings.warn(
186 | f'Using torch.nn.init.xavier_uniform_ init fn with parameters: gain={init_gain}'
187 | )
188 | generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
189 |
190 | def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
191 | xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
192 | if verbose > 1:
193 | warnings.warn(
194 | f'Using torch.nn.init.xavier_normal_ init fn with parameters: gain={init_gain}'
195 | )
196 | generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
197 | MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
--------------------------------------------------------------------------------
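
A hedged sketch of applying one of the registered initializers to a small module tree via `nn.Module.apply`; the layer sizes are illustrative only.

    # Hedged sketch; layer sizes are illustrative.
    from functools import partial
    import torch.nn as nn
    from diffusion.model.llava.mpt.param_init_fns import MODEL_INIT_REGISTRY

    model = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    init_fn = partial(
        MODEL_INIT_REGISTRY['kaiming_normal_'],
        n_layers=2, d_model=64, verbose=0,
    )
    model.apply(init_fn)  # Kaiming-normal Linear weights, biases zeroed
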
/diffusion/model/nets/PixArtMS.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 | import torch
12 | import torch.nn as nn
13 | from timm.models.layers import DropPath
14 | from timm.models.vision_transformer import Mlp
15 |
16 | from diffusion.model.builder import MODELS
17 | from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
18 | from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder
19 | from diffusion.model.nets.PixArt import PixArt, get_2d_sincos_pos_embed
20 |
21 |
22 | class PatchEmbed(nn.Module):
23 | """ 2D Image to Patch Embedding
24 | """
25 | def __init__(
26 | self,
27 | patch_size=16,
28 | in_chans=3,
29 | embed_dim=768,
30 | norm_layer=None,
31 | flatten=True,
32 | bias=True,
33 | ):
34 | super().__init__()
35 | patch_size = to_2tuple(patch_size)
36 | self.patch_size = patch_size
37 | self.flatten = flatten
38 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
39 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
40 |
41 | def forward(self, x):
42 | x = self.proj(x)
43 | if self.flatten:
44 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
45 | x = self.norm(x)
46 | return x
47 |
48 |
49 | class PixArtMSBlock(nn.Module):
50 | """
51 | A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
52 | """
53 |
54 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs):
55 | super().__init__()
56 | self.hidden_size = hidden_size
57 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
58 | self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
59 | input_size=input_size if window_size == 0 else (window_size, window_size),
60 | use_rel_pos=use_rel_pos, **block_kwargs)
61 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
62 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
63 |         # for compatibility with older PyTorch versions
64 | approx_gelu = lambda: nn.GELU(approximate="tanh")
65 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
66 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
67 | self.window_size = window_size
68 | self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
69 |
70 | def forward(self, x, y, t, mask=None, **kwargs):
71 | B, N, C = x.shape
72 |
73 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
74 | x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)))
75 | x = x + self.cross_attn(x, y, mask)
76 | x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
77 |
78 | return x
79 |
80 |
81 | #################################################################################
82 | #                               Core PixArt Model                               #
83 | #################################################################################
84 | @MODELS.register_module()
85 | class PixArtMS(PixArt):
86 | """
87 | Diffusion model with a Transformer backbone.
88 | """
89 |
90 | def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, learn_sigma=True, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, caption_channels=4096, lewei_scale=1., config=None, model_max_length=120, **kwargs):
91 | if window_block_indexes is None:
92 | window_block_indexes = []
93 | super().__init__(
94 | input_size=input_size,
95 | patch_size=patch_size,
96 | in_channels=in_channels,
97 | hidden_size=hidden_size,
98 | depth=depth,
99 | num_heads=num_heads,
100 | mlp_ratio=mlp_ratio,
101 | class_dropout_prob=class_dropout_prob,
102 | learn_sigma=learn_sigma,
103 | pred_sigma=pred_sigma,
104 | drop_path=drop_path,
105 | window_size=window_size,
106 | window_block_indexes=window_block_indexes,
107 | use_rel_pos=use_rel_pos,
108 | lewei_scale=lewei_scale,
109 | config=config,
110 | model_max_length=model_max_length,
111 | **kwargs,
112 | )
113 | self.h = self.w = 0
114 | approx_gelu = lambda: nn.GELU(approximate="tanh")
115 | self.t_block = nn.Sequential(
116 | nn.SiLU(),
117 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
118 | )
119 | self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True)
120 | self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length)
121 | self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed
122 | self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed
123 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
124 | self.blocks = nn.ModuleList([
125 | PixArtMSBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
126 | input_size=(input_size // patch_size, input_size // patch_size),
127 | window_size=window_size if i in window_block_indexes else 0,
128 | use_rel_pos=use_rel_pos if i in window_block_indexes else False)
129 | for i in range(depth)
130 | ])
131 | self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
132 |
133 | self.initialize()
134 |
135 | def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
136 | """
137 | Forward pass of PixArt.
138 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
139 | t: (N,) tensor of diffusion timesteps
140 |         y: (N, 1, 120, C) tensor of caption embeddings
141 | """
142 | bs = x.shape[0]
143 | x = x.to(self.dtype)
144 | timestep = timestep.to(self.dtype)
145 | y = y.to(self.dtype)
146 | c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
147 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
148 | pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(x.device).to(self.dtype)
149 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
150 | t = self.t_embedder(timestep) # (N, D)
151 | csize = self.csize_embedder(c_size, bs) # (N, D)
152 | ar = self.ar_embedder(ar, bs) # (N, D)
153 | t = t + torch.cat([csize, ar], dim=1)
154 | t0 = self.t_block(t)
155 | y = self.y_embedder(y, self.training) # (N, D)
156 | if mask is not None:
157 | if mask.shape[0] != y.shape[0]:
158 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
159 | mask = mask.squeeze(1).squeeze(1)
160 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
161 | y_lens = mask.sum(dim=1).tolist()
162 | else:
163 | y_lens = [y.shape[2]] * y.shape[0]
164 | y = y.squeeze(1).view(1, -1, x.shape[-1])
165 | for block in self.blocks:
166 | x = auto_grad_checkpoint(block, x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint
167 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
168 | x = self.unpatchify(x) # (N, out_channels, H, W)
169 | return x
170 |
171 | def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs):
172 | """
173 |         DPM-Solver does not need variance prediction
174 | """
175 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
176 | model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
177 | return model_out.chunk(2, dim=1)[0]
178 |
179 | def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, **kwargs):
180 | """
181 | Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
182 | """
183 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
184 | half = x[: len(x) // 2]
185 | combined = torch.cat([half, half], dim=0)
186 | model_out = self.forward(combined, timestep, y, data_info=data_info)
187 | eps, rest = model_out[:, :3], model_out[:, 3:]
188 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
189 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
190 | eps = torch.cat([half_eps, half_eps], dim=0)
191 | return torch.cat([eps, rest], dim=1)
192 |
193 | def unpatchify(self, x):
194 | """
195 | x: (N, T, patch_size**2 * C)
196 |         imgs: (N, C, H, W)
197 | """
198 | c = self.out_channels
199 | p = self.x_embedder.patch_size[0]
200 | assert self.h * self.w == x.shape[1]
201 |
202 | x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
203 | x = torch.einsum('nhwpqc->nchpwq', x)
204 | return x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
205 |
206 | def initialize(self):
207 | # Initialize transformer layers:
208 | def _basic_init(module):
209 | if isinstance(module, nn.Linear):
210 | torch.nn.init.xavier_uniform_(module.weight)
211 | if module.bias is not None:
212 | nn.init.constant_(module.bias, 0)
213 |
214 | self.apply(_basic_init)
215 |
216 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
217 | w = self.x_embedder.proj.weight.data
218 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
219 |
220 | # Initialize timestep embedding MLP:
221 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
222 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
223 | nn.init.normal_(self.t_block[1].weight, std=0.02)
224 | nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02)
225 | nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02)
226 | nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02)
227 | nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02)
228 |
229 | # Initialize caption embedding MLP:
230 | nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
231 | nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
232 |
233 | # Zero-out adaLN modulation layers in PixArt blocks:
234 | for block in self.blocks:
235 | nn.init.constant_(block.cross_attn.proj.weight, 0)
236 | nn.init.constant_(block.cross_attn.proj.bias, 0)
237 |
238 | # Zero-out output layers:
239 | nn.init.constant_(self.final_layer.linear.weight, 0)
240 | nn.init.constant_(self.final_layer.linear.bias, 0)
241 |
242 |
243 | #################################################################################
244 | # PixArt Configs #
245 | #################################################################################
246 | @MODELS.register_module()
247 | def PixArtMS_XL_2(**kwargs):
248 | return PixArtMS(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
249 |
--------------------------------------------------------------------------------
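
A hedged sketch of a single forward pass through a deliberately tiny `PixArtMS` (the registered `PixArtMS_XL_2` config above is far larger). It assumes the attention kernels used by `PixArt_blocks` (e.g. xformers) are installed; shapes follow the forward docstring, and `data_info` carries the size and aspect-ratio conditioning consumed by `csize_embedder` and `ar_embedder`.

    # Hedged sketch; a deliberately tiny configuration, shapes are illustrative.
    import torch
    from diffusion.model.nets.PixArtMS import PixArtMS

    model = PixArtMS(input_size=32, patch_size=2, in_channels=4, hidden_size=144,
                     depth=2, num_heads=4, caption_channels=4096,
                     model_max_length=120).eval()
    x = torch.randn(2, 4, 32, 32)                 # latent inputs
    t = torch.randint(0, 1000, (2,))              # diffusion timesteps
    y = torch.randn(2, 1, 120, 4096)              # T5 caption embeddings
    data_info = {
        'img_hw': torch.tensor([[256., 256.], [256., 256.]]),
        'aspect_ratio': torch.tensor([[1.], [1.]]),
    }
    with torch.no_grad():
        out = model(x, t, y, data_info=data_info)
    print(out.shape)  # (2, 8, 32, 32): predicted eps plus variance channels
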
/diffusion/model/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .PixArt import PixArt, PixArt_XL_2
2 | from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock
3 | from .pixart_controlnet import ControlPixArtHalf, ControlPixArtMSHalf
--------------------------------------------------------------------------------
/diffusion/model/nets/pixart_controlnet.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 |
5 | from copy import deepcopy
6 | from torch import Tensor
7 | from torch.nn import Module, Linear, init
8 | from typing import Any, Mapping, Union
9 |
10 | from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt
11 | from diffusion.model.nets.PixArt import get_2d_sincos_pos_embed
12 | from diffusion.model.utils import auto_grad_checkpoint
13 |
14 |
15 | # The implementation of the ControlNet-Half architecture
16 | # https://github.com/lllyasviel/ControlNet/discussions/188
17 | class ControlT2IDitBlockHalf(Module):
18 |     def __init__(self, base_block: PixArtMSBlock, block_index: int) -> None:
19 | super().__init__()
20 | self.copied_block = deepcopy(base_block)
21 | self.block_index = block_index
22 |
23 | for p in self.copied_block.parameters():
24 | p.requires_grad_(True)
25 |
26 | self.copied_block.load_state_dict(base_block.state_dict())
27 | self.copied_block.train()
28 |
29 | self.hidden_size = hidden_size = base_block.hidden_size
30 | if self.block_index == 0:
31 | self.before_proj = Linear(hidden_size, hidden_size)
32 | init.zeros_(self.before_proj.weight)
33 | init.zeros_(self.before_proj.bias)
34 | self.after_proj = Linear(hidden_size, hidden_size)
35 | init.zeros_(self.after_proj.weight)
36 | init.zeros_(self.after_proj.bias)
37 |
38 | def forward(self, x, y, t, mask=None, c=None):
39 |
40 | if self.block_index == 0:
41 | # the first block
42 | c = self.before_proj(c)
43 | c = self.copied_block(x + c, y, t, mask)
44 | c_skip = self.after_proj(c)
45 | else:
46 | # load from previous c and produce the c for skip connection
47 | c = self.copied_block(c, y, t, mask)
48 | c_skip = self.after_proj(c)
49 |
50 | return c, c_skip
51 |
52 |
53 | # The implementation of ControlPixArtHalf net
54 | class ControlPixArtHalf(Module):
55 | # only supports single-resolution models
56 | def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None:
57 | super().__init__()
58 | self.base_model = base_model.eval()
59 | self.controlnet = []
60 | self.copy_blocks_num = copy_blocks_num
61 | self.total_blocks_num = len(base_model.blocks)
62 | for p in self.base_model.parameters():
63 | p.requires_grad_(False)
64 |
65 | # Copy first copy_blocks_num block
66 | for i in range(copy_blocks_num):
67 | self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
68 | self.controlnet = nn.ModuleList(self.controlnet)
69 |
70 | def __getattr__(self, name: str) -> Union[Tensor, Module]:
71 | if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
72 | return self.__dict__[name]
73 | elif name in ['base_model', 'controlnet']:
74 | return super().__getattr__(name)
75 | else:
76 | return getattr(self.base_model, name)
77 |
78 | def forward_c(self, c):
79 | self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
80 | pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
81 | return self.x_embedder(c) + pos_embed if c is not None else c
82 |
83 | # def forward(self, x, t, c, **kwargs):
84 | # return self.base_model(x, t, c=self.forward_c(c), **kwargs)
85 | def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
86 | # modified from the original PixArt forward function
87 | if c is not None:
88 | c = c.to(self.dtype)
89 | c = self.forward_c(c)
90 | """
91 | Forward pass of PixArt.
92 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
93 | t: (N,) tensor of diffusion timesteps
94 | y: (N, 1, 120, C) tensor of caption (text) embeddings
95 | """
96 | x = x.to(self.dtype)
97 | timestep = timestep.to(self.dtype)
98 | y = y.to(self.dtype)
99 | pos_embed = self.pos_embed.to(self.dtype)
100 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
101 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
102 | t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
103 | t0 = self.t_block(t)
104 | y = self.y_embedder(y, self.training) # (N, 1, L, D)
105 | if mask is not None:
106 | if mask.shape[0] != y.shape[0]:
107 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
108 | mask = mask.squeeze(1).squeeze(1)
109 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
110 | y_lens = mask.sum(dim=1).tolist()
111 | else:
112 | y_lens = [y.shape[2]] * y.shape[0]
113 | y = y.squeeze(1).view(1, -1, x.shape[-1])
114 |
115 | # run the first base block
116 | x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint
117 |
118 | if c is not None:
119 | # update c
120 | for index in range(1, self.copy_blocks_num + 1):
121 | c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
122 | x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
123 |
124 | # update x
125 | for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
126 | x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
127 | else:
128 | for index in range(1, self.total_blocks_num):
129 | x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
130 |
131 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
132 | x = self.unpatchify(x) # (N, out_channels, H, W)
133 | return x
134 |
135 | def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
136 | model_out = self.forward(x, t, y, data_info=data_info, c=c, **kwargs)
137 | return model_out.chunk(2, dim=1)[0]
138 |
139 | # def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
140 | # return self.base_model.forward_with_dpmsolver(x, t, y, data_info=data_info, c=self.forward_c(c), **kwargs)
141 |
142 | def forward_with_cfg(self, x, t, y, cfg_scale, data_info, c, **kwargs):
143 | return self.base_model.forward_with_cfg(x, t, y, cfg_scale, data_info, c=self.forward_c(c), **kwargs)
144 |
145 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
146 | if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
147 | return super().load_state_dict(state_dict, strict)
148 | else:
149 | new_key = {}
150 | for k in state_dict.keys():
151 | new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
152 | for k, v in new_key.items():
153 | if k != v:
154 | print(f"replace {k} to {v}")
155 | state_dict[v] = state_dict.pop(k)
156 |
157 | return self.base_model.load_state_dict(state_dict, strict)
158 |
159 | def unpatchify(self, x):
160 | """
161 | x: (N, T, patch_size**2 * C)
162 | imgs: (N, H, W, C)
163 | """
164 | c = self.out_channels
165 | p = self.x_embedder.patch_size[0]
166 | assert self.h * self.w == x.shape[1]
167 |
168 | x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
169 | x = torch.einsum('nhwpqc->nchpwq', x)
170 | imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
171 | return imgs
172 |
173 | @property
174 | def dtype(self):
175 | # Return the dtype of the model parameters
176 | return next(self.parameters()).dtype
177 |
178 |
179 | # The implementation for PixArtMS_Half + 1024 resolution
180 | class ControlPixArtMSHalf(ControlPixArtHalf):
181 | # supports the multi-scale-resolution model (which can also be used for single-resolution training & inference)
182 | def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
183 | super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)
184 |
185 | def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
186 | # modified from the original PixArtMS forward function
187 | """
188 | Forward pass of PixArt.
189 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
190 | t: (N,) tensor of diffusion timesteps
191 | y: (N, 1, 120, C) tensor of caption (text) embeddings
192 | """
193 | if c is not None:
194 | c = c.to(self.dtype)
195 | c = self.forward_c(c)
196 | bs = x.shape[0]
197 | x = x.to(self.dtype)
198 | timestep = timestep.to(self.dtype)
199 | y = y.to(self.dtype)
200 | c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
201 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
202 |
203 | pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(x.device).to(self.dtype)
204 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
205 | t = self.t_embedder(timestep) # (N, D)
206 | csize = self.csize_embedder(c_size, bs) # (N, D)
207 | ar = self.ar_embedder(ar, bs) # (N, D)
208 | t = t + torch.cat([csize, ar], dim=1)
209 | t0 = self.t_block(t)
210 | y = self.y_embedder(y, self.training)  # (N, 1, L, D)
211 | if mask is not None:
212 | if mask.shape[0] != y.shape[0]:
213 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
214 | mask = mask.squeeze(1).squeeze(1)
215 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
216 | y_lens = mask.sum(dim=1).tolist()
217 | else:
218 | y_lens = [y.shape[2]] * y.shape[0]
219 | y = y.squeeze(1).view(1, -1, x.shape[-1])
220 |
221 | # run the first base block
222 | x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint
223 |
224 | if c is not None:
225 | # update c
226 | for index in range(1, self.copy_blocks_num + 1):
227 | c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
228 | x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
229 |
230 | # update x
231 | for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
232 | x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
233 | else:
234 | for index in range(1, self.total_blocks_num):
235 | x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
236 |
237 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
238 | x = self.unpatchify(x) # (N, out_channels, H, W)
239 | return x
240 |
--------------------------------------------------------------------------------
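The ControlNet-Half pattern implemented above keeps the base model frozen, copies the first `copy_blocks_num` blocks into a trainable control branch, and ties that branch back into the base path through zero-initialized projections, so the branch is a no-op at initialization. The sketch below illustrates only that wiring; it uses a stand-in `nn.Linear` in place of a PixArtMSBlock and arbitrary toy sizes, and is not the repository's class.

    import torch
    import torch.nn as nn
    from copy import deepcopy

    class ToyControlHalf(nn.Module):
        """Toy stand-in for ControlT2IDitBlockHalf: a trainable copy of a frozen base
        block, bracketed by zero-initialized projections so that at initialization the
        control branch adds nothing to the base path."""
        def __init__(self, base_block: nn.Module, hidden: int, first: bool = True):
            super().__init__()
            self.copied = deepcopy(base_block)
            for p in self.copied.parameters():
                p.requires_grad_(True)                    # the copy is trainable
            self.first = first
            if first:
                self.before = nn.Linear(hidden, hidden)   # injects the control signal
                nn.init.zeros_(self.before.weight)
                nn.init.zeros_(self.before.bias)
            self.after = nn.Linear(hidden, hidden)        # produces the skip connection
            nn.init.zeros_(self.after.weight)
            nn.init.zeros_(self.after.bias)

        def forward(self, x, c):
            c = self.copied(x + self.before(c)) if self.first else self.copied(c)
            return c, self.after(c)                       # (control state, skip added to x)

    hidden = 8
    base_block = nn.Linear(hidden, hidden)                # frozen "transformer block"
    for p in base_block.parameters():
        p.requires_grad_(False)

    ctrl = ToyControlHalf(base_block, hidden, first=True)
    x, c = torch.randn(2, hidden), torch.randn(2, hidden)
    _, c_skip = ctrl(x, c)
    print(torch.allclose(c_skip, torch.zeros_like(c_skip)))   # True: no-op at init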
/diffusion/model/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there are 300 timesteps and the section counts are [10,15,20],
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
52 | cur_idx = 0.0
53 | taken_steps = []
54 | for _ in range(section_count):
55 | taken_steps.append(start_idx + round(cur_idx))
56 | cur_idx += frac_stride
57 | all_steps += taken_steps
58 | start_idx += size
59 | return set(all_steps)
60 |
61 |
62 | class SpacedDiffusion(GaussianDiffusion):
63 | """
64 | A diffusion process which can skip steps in a base diffusion process.
65 | :param use_timesteps: a collection (sequence or set) of timesteps from the
66 | original diffusion process to retain.
67 | :param kwargs: the kwargs to create the base diffusion process.
68 | """
69 |
70 | def __init__(self, use_timesteps, **kwargs):
71 | self.use_timesteps = set(use_timesteps)
72 | self.timestep_map = []
73 | self.original_num_steps = len(kwargs["betas"])
74 |
75 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
76 | last_alpha_cumprod = 1.0
77 | new_betas = []
78 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
79 | if i in self.use_timesteps:
80 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
81 | last_alpha_cumprod = alpha_cumprod
82 | self.timestep_map.append(i)
83 | kwargs["betas"] = np.array(new_betas)
84 | super().__init__(**kwargs)
85 |
86 | def p_mean_variance(
87 | self, model, *args, **kwargs
88 | ): # pylint: disable=signature-differs
89 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
90 |
91 | def training_losses(
92 | self, model, *args, **kwargs
93 | ): # pylint: disable=signature-differs
94 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
95 |
96 | def training_losses_diffusers(
97 | self, model, *args, **kwargs
98 | ): # pylint: disable=signature-differs
99 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)
100 |
101 | def condition_mean(self, cond_fn, *args, **kwargs):
102 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def condition_score(self, cond_fn, *args, **kwargs):
105 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
106 |
107 | def _wrap_model(self, model):
108 | if isinstance(model, _WrappedModel):
109 | return model
110 | return _WrappedModel(
111 | model, self.timestep_map, self.original_num_steps
112 | )
113 |
114 | def _scale_timesteps(self, t):
115 | # Scaling is done by the wrapped model.
116 | return t
117 |
118 |
119 | class _WrappedModel:
120 | def __init__(self, model, timestep_map, original_num_steps):
121 | self.model = model
122 | self.timestep_map = timestep_map
123 | # self.rescale_timesteps = rescale_timesteps
124 | self.original_num_steps = original_num_steps
125 |
126 | def __call__(self, x, timestep, **kwargs):
127 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
128 | new_ts = map_tensor[timestep]
129 | # if self.rescale_timesteps:
130 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
131 | return self.model(x, timestep=new_ts, **kwargs)
132 |
--------------------------------------------------------------------------------
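A quick, self-contained check of `space_timesteps` (the import path is assumed from this repository's layout; the printed values follow directly from the striding rules in the docstring). `SpacedDiffusion` would then be built with `use_timesteps=space_timesteps(...)` plus the usual `betas`/loss kwargs, which are omitted here.

    from diffusion.model.respace import space_timesteps

    # DDIM-style spacing: 25 evenly strided steps out of a 1000-step process.
    ddim_steps = space_timesteps(1000, "ddim25")
    print(len(ddim_steps), min(ddim_steps), max(ddim_steps))   # 25 0 960

    # Section-based spacing: 300 original steps, three equal sections reduced to
    # 10, 15 and 20 retained steps respectively (45 steps in total).
    sectioned = space_timesteps(300, [10, 15, 20])
    print(len(sectioned))                                      # 45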
/diffusion/model/t5.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import re
4 | import html
5 | import urllib.parse as ul
6 |
7 | import ftfy
8 | import torch
9 | from bs4 import BeautifulSoup
10 | from transformers import T5EncoderModel, AutoTokenizer
11 | from huggingface_hub import hf_hub_download
12 |
13 | class T5Embedder:
14 |
15 | available_models = ['t5-v1_1-xxl']
16 | bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
17 |
18 | def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
19 | t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
20 | self.device = torch.device(device)
21 | self.torch_dtype = torch_dtype or torch.bfloat16
22 | if t5_model_kwargs is None:
23 | t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
24 | if use_offload_folder is not None:
25 | t5_model_kwargs['offload_folder'] = use_offload_folder
26 | t5_model_kwargs['device_map'] = {
27 | 'shared': self.device,
28 | 'encoder.embed_tokens': self.device,
29 | 'encoder.block.0': self.device,
30 | 'encoder.block.1': self.device,
31 | 'encoder.block.2': self.device,
32 | 'encoder.block.3': self.device,
33 | 'encoder.block.4': self.device,
34 | 'encoder.block.5': self.device,
35 | 'encoder.block.6': self.device,
36 | 'encoder.block.7': self.device,
37 | 'encoder.block.8': self.device,
38 | 'encoder.block.9': self.device,
39 | 'encoder.block.10': self.device,
40 | 'encoder.block.11': self.device,
41 | 'encoder.block.12': 'disk',
42 | 'encoder.block.13': 'disk',
43 | 'encoder.block.14': 'disk',
44 | 'encoder.block.15': 'disk',
45 | 'encoder.block.16': 'disk',
46 | 'encoder.block.17': 'disk',
47 | 'encoder.block.18': 'disk',
48 | 'encoder.block.19': 'disk',
49 | 'encoder.block.20': 'disk',
50 | 'encoder.block.21': 'disk',
51 | 'encoder.block.22': 'disk',
52 | 'encoder.block.23': 'disk',
53 | 'encoder.final_layer_norm': 'disk',
54 | 'encoder.dropout': 'disk',
55 | }
56 | else:
57 | t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
58 |
59 | self.use_text_preprocessing = use_text_preprocessing
60 | self.hf_token = hf_token
61 | self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
62 | self.dir_or_name = dir_or_name
63 | tokenizer_path, path = dir_or_name, dir_or_name
64 | if local_cache:
65 | cache_dir = os.path.join(self.cache_dir, dir_or_name)
66 | tokenizer_path, path = cache_dir, cache_dir
67 | elif dir_or_name in self.available_models:
68 | cache_dir = os.path.join(self.cache_dir, dir_or_name)
69 | for filename in [
70 | 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
71 | 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
72 | ]:
73 | hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
74 | force_filename=filename, token=self.hf_token)
75 | tokenizer_path, path = cache_dir, cache_dir
76 | else:
77 | cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
78 | for filename in [
79 | 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
80 | ]:
81 | hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
82 | force_filename=filename, token=self.hf_token)
83 | tokenizer_path = cache_dir
84 |
85 | print(tokenizer_path)
86 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
87 | self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
88 | self.model_max_length = model_max_length
89 |
90 | def get_text_embeddings(self, texts):
91 | texts = [self.text_preprocessing(text) for text in texts]
92 |
93 | text_tokens_and_mask = self.tokenizer(
94 | texts,
95 | max_length=self.model_max_length,
96 | padding='max_length',
97 | truncation=True,
98 | return_attention_mask=True,
99 | add_special_tokens=True,
100 | return_tensors='pt'
101 | )
102 |
103 | text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
104 | text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
105 |
106 | with torch.no_grad():
107 | text_encoder_embs = self.model(
108 | input_ids=text_tokens_and_mask['input_ids'].to(self.device),
109 | attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
110 | )['last_hidden_state'].detach()
111 | return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
112 |
113 | def text_preprocessing(self, text):
114 | if self.use_text_preprocessing:
115 | # The exact text cleaning as was in the training stage:
116 | text = self.clean_caption(text)
117 | text = self.clean_caption(text)
118 | return text
119 | else:
120 | return text.lower().strip()
121 |
122 | @staticmethod
123 | def basic_clean(text):
124 | text = ftfy.fix_text(text)
125 | text = html.unescape(html.unescape(text))
126 | return text.strip()
127 |
128 | def clean_caption(self, caption):
129 | caption = str(caption)
130 | caption = ul.unquote_plus(caption)
131 | caption = caption.strip().lower()
132 | caption = re.sub('<person>', 'person', caption)
133 | # urls:
134 | caption = re.sub(
135 | r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
136 | '', caption) # regex for urls
137 | caption = re.sub(
138 | r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
139 | '', caption) # regex for urls
140 | # html:
141 | caption = BeautifulSoup(caption, features='html.parser').text
142 |
143 | # @
144 | caption = re.sub(r'@[\w\d]+\b', '', caption)
145 |
146 | # 31C0—31EF CJK Strokes
147 | # 31F0—31FF Katakana Phonetic Extensions
148 | # 3200—32FF Enclosed CJK Letters and Months
149 | # 3300—33FF CJK Compatibility
150 | # 3400—4DBF CJK Unified Ideographs Extension A
151 | # 4DC0—4DFF Yijing Hexagram Symbols
152 | # 4E00—9FFF CJK Unified Ideographs
153 | caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
154 | caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
155 | caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
156 | caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
157 | caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
158 | caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
159 | caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
160 | #######################################################
161 |
162 | # all types of dash --> "-"
163 | caption = re.sub(
164 | r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
165 | '-', caption)
166 |
167 | # normalize all quotation marks to one standard
168 | caption = re.sub(r'[`´«»“”¨]', '"', caption)
169 | caption = re.sub(r'[‘’]', "'", caption)
170 |
171 | # "
172 | caption = re.sub(r'"?', '', caption)
173 | # &
174 | caption = re.sub(r'&', '', caption)
175 |
176 | # ip addresses:
177 | caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
178 |
179 | # article ids:
180 | caption = re.sub(r'\d:\d\d\s+$', '', caption)
181 |
182 | # \n
183 | caption = re.sub(r'\\n', ' ', caption)
184 |
185 | # "#123"
186 | caption = re.sub(r'#\d{1,3}\b', '', caption)
187 | # "#12345.."
188 | caption = re.sub(r'#\d{5,}\b', '', caption)
189 | # "123456.."
190 | caption = re.sub(r'\b\d{6,}\b', '', caption)
191 | # filenames:
192 | caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
193 |
194 | #
195 | caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
196 | caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
197 |
198 | caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
199 | caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
200 |
201 | # this-is-my-cute-cat / this_is_my_cute_cat
202 | regex2 = re.compile(r'(?:\-|\_)')
203 | if len(re.findall(regex2, caption)) > 3:
204 | caption = re.sub(regex2, ' ', caption)
205 |
206 | caption = self.basic_clean(caption)
207 |
208 | caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
209 | caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
210 | caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
211 |
212 | caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
213 | caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
214 | caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
215 | caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
216 | caption = re.sub(r'\bpage\s+\d+\b', '', caption)
217 |
218 | caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
219 |
220 | caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
221 |
222 | caption = re.sub(r'\b\s+\:\s+', r': ', caption)
223 | caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
224 | caption = re.sub(r'\s+', ' ', caption)
225 |
226 | caption = caption.strip()
227 |
228 | caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
229 | caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
230 | caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
231 | caption = re.sub(r'^\.\S+$', '', caption)
232 |
233 | return caption.strip()
234 |
--------------------------------------------------------------------------------
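A usage sketch for `T5Embedder`. Instantiating it downloads the DeepFloyd `t5-v1_1-xxl` weights from the Hugging Face Hub (tens of gigabytes), and the prompt below is only a placeholder, so treat this as illustrative rather than a quick test.

    import torch
    from diffusion.model.t5 import T5Embedder

    embedder = T5Embedder(
        device="cuda" if torch.cuda.is_available() else "cpu",
        dir_or_name="t5-v1_1-xxl",
        torch_dtype=torch.bfloat16,
        model_max_length=120,
    )
    # clean_caption() is applied internally before tokenization.
    embs, mask = embedder.get_text_embeddings(["A coral reef with dense schools of fish."])
    print(embs.shape, mask.shape)   # (1, 120, 4096) and (1, 120) for the XXL encoder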
/diffusion/model/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs, device=local_ts.device) for _ in batch_sizes]
96 | loss_batches = [th.zeros(max_bs, device=local_losses.device) for _ in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)  # np.int was removed in NumPy >= 1.24
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
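The schedule samplers above only require a `num_timesteps` attribute on the diffusion object, so a tiny stand-in is enough to exercise them; in training they would receive the full `SpacedDiffusion`/`GaussianDiffusion` instance instead. The import path is assumed from this repository's layout.

    import torch
    from types import SimpleNamespace
    from diffusion.model.timestep_sampler import create_named_schedule_sampler

    diffusion = SimpleNamespace(num_timesteps=1000)   # stand-in for GaussianDiffusion
    sampler = create_named_schedule_sampler("uniform", diffusion)

    t, weights = sampler.sample(batch_size=8, device=torch.device("cpu"))
    print(t.shape, weights.shape)   # torch.Size([8]) torch.Size([8])
    print(weights)                  # all ones: uniform sampling is already unbiased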
/diffusion/sa_sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver
7 | from .model import gaussian_diffusion as gd
8 |
9 |
10 | class SASolverSampler(object):
11 | def __init__(self, model,
12 | noise_schedule="linear",
13 | diffusion_steps=1000,
14 | device='cpu',
15 | ):
16 | super().__init__()
17 | self.model = model
18 | self.device = device
19 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)
20 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
21 | alphas = 1.0 - betas
22 | self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0)))
23 |
24 | def register_buffer(self, name, attr):
25 | if type(attr) == torch.Tensor and attr.device != torch.device("cuda"):
26 | attr = attr.to(torch.device("cuda"))
27 | setattr(self, name, attr)
28 |
29 | @torch.no_grad()
30 | def sample(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, model_kwargs=None, **kwargs):
31 | if model_kwargs is None:
32 | model_kwargs = {}
33 | if conditioning is not None:
34 | if isinstance(conditioning, dict):
35 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
36 | if cbs != batch_size:
37 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
38 | elif conditioning.shape[0] != batch_size:
39 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
40 |
41 | # sampling
42 | C, H, W = shape
43 | size = (batch_size, C, H, W)
44 |
45 | device = self.device
46 | img = torch.randn(size, device=device) if x_T is None else x_T
47 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
48 |
49 | model_fn = model_wrapper(
50 | self.model,
51 | ns,
52 | model_type="noise",
53 | guidance_type="classifier-free",
54 | condition=conditioning,
55 | unconditional_condition=unconditional_conditioning,
56 | guidance_scale=unconditional_guidance_scale,
57 | model_kwargs=model_kwargs,
58 | )
59 |
60 | sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction")
61 |
62 | tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0
63 |
64 | x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False)
65 |
66 | return x.to(device), None
--------------------------------------------------------------------------------
/diffusion/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HongkLin/TIDE/d0cba53604b3dd9e16e4f81ce24d7f4be3ba0d4e/diffusion/utils/__init__.py
--------------------------------------------------------------------------------
/diffusion/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 |
5 | from diffusion.utils.logger import get_root_logger
6 |
7 |
8 | def save_checkpoint(work_dir,
9 | epoch,
10 | model,
11 | model_ema=None,
12 | optimizer=None,
13 | lr_scheduler=None,
14 | keep_last=False,
15 | step=None,
16 | ):
17 | os.makedirs(work_dir, exist_ok=True)
18 | state_dict = dict(state_dict=model.state_dict())
19 | if model_ema is not None:
20 | state_dict['state_dict_ema'] = model_ema.state_dict()
21 | if optimizer is not None:
22 | state_dict['optimizer'] = optimizer.state_dict()
23 | if lr_scheduler is not None:
24 | state_dict['scheduler'] = lr_scheduler.state_dict()
25 | if epoch is not None:
26 | state_dict['epoch'] = epoch
27 | file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
28 | if step is not None:
29 | file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
30 | logger = get_root_logger()
31 | torch.save(state_dict, file_path)
32 | logger.info(f'Saved checkpoint of epoch {epoch} to {file_path}.')
33 | if keep_last:
34 | for i in range(epoch):
35 | previous_ckpt = os.path.join(work_dir, f"epoch_{i}.pth")  # file_path is an f-string, so .format() was a no-op
36 | if os.path.exists(previous_ckpt):
37 | os.remove(previous_ckpt)
38 |
39 |
40 | def load_checkpoint(checkpoint,
41 | model,
42 | model_ema=None,
43 | optimizer=None,
44 | lr_scheduler=None,
45 | load_ema=False,
46 | resume_optimizer=True,
47 | resume_lr_scheduler=True
48 | ):
49 | assert isinstance(checkpoint, str)
50 | ckpt_file = checkpoint
51 | checkpoint = torch.load(ckpt_file, map_location="cpu")
52 |
53 | state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']
54 | for key in state_dict_keys:
55 | if key in checkpoint['state_dict']:
56 | del checkpoint['state_dict'][key]
57 | if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']:
58 | del checkpoint['state_dict_ema'][key]
59 | break
60 |
61 | if load_ema:
62 | state_dict = checkpoint['state_dict_ema']
63 | else:
64 | state_dict = checkpoint.get('state_dict', checkpoint) # to be compatible with the official checkpoint
65 | # model.load_state_dict(state_dict)
66 | missing, unexpect = model.load_state_dict(state_dict, strict=False)
67 | if model_ema is not None:
68 | model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
69 | if optimizer is not None and resume_optimizer:
70 | optimizer.load_state_dict(checkpoint['optimizer'])
71 | if lr_scheduler is not None and resume_lr_scheduler:
72 | lr_scheduler.load_state_dict(checkpoint['scheduler'])
73 | logger = get_root_logger()
74 | if optimizer is not None:
75 | epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*.pth', ckpt_file).group(1))
76 | logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
77 | f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
78 | return epoch, missing, unexpect
79 | logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
80 | return missing, unexpect
81 |
--------------------------------------------------------------------------------
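A minimal save/load round trip for the helpers above, using a toy `nn.Linear` and a placeholder output directory. `load_checkpoint` drops any `pos_embed` entries it finds and, when no optimizer is passed, returns the `(missing, unexpected)` key lists from `load_state_dict`.

    import os
    import torch.nn as nn
    from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint

    work_dir = "./tmp_ckpt"                       # placeholder output directory
    model = nn.Linear(8, 8)
    save_checkpoint(work_dir, epoch=1, model=model, step=100)

    ckpt_path = os.path.join(work_dir, "epoch_1_step_100.pth")
    missing, unexpected = load_checkpoint(ckpt_path, nn.Linear(8, 8))
    print(missing, unexpected)                    # [] [] when the architectures match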
/diffusion/utils/data_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | from typing import Sequence
4 | from torch.utils.data import BatchSampler, Sampler, Dataset
5 | from random import shuffle, choice
6 | from copy import deepcopy
7 | from diffusion.utils.logger import get_root_logger
8 |
9 |
10 | class AspectRatioBatchSampler(BatchSampler):
11 | """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
12 |
13 | Args:
14 | sampler (Sampler): Base sampler.
15 | dataset (Dataset): Dataset providing data information.
16 | batch_size (int): Size of mini-batch.
17 | drop_last (bool): If ``True``, the sampler will drop the last batch if
18 | its size would be less than ``batch_size``.
19 | aspect_ratios (dict): The predefined aspect ratios.
20 | """
21 |
22 | def __init__(self,
23 | sampler: Sampler,
24 | dataset: Dataset,
25 | batch_size: int,
26 | aspect_ratios: dict,
27 | drop_last: bool = False,
28 | config=None,
29 | valid_num=0, # an aspect ratio is kept only if its sample count >= valid_num
30 | **kwargs) -> None:
31 | if not isinstance(sampler, Sampler):
32 | raise TypeError('sampler should be an instance of ``Sampler``, '
33 | f'but got {sampler}')
34 | if not isinstance(batch_size, int) or batch_size <= 0:
35 | raise ValueError('batch_size should be a positive integer value, '
36 | f'but got batch_size={batch_size}')
37 | self.sampler = sampler
38 | self.dataset = dataset
39 | self.batch_size = batch_size
40 | self.aspect_ratios = aspect_ratios
41 | self.drop_last = drop_last
42 | self.ratio_nums_gt = kwargs.get('ratio_nums', None)
43 | self.config = config
44 | assert self.ratio_nums_gt
45 | # buckets for each aspect ratio
46 | self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
47 | self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num]
48 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
49 | logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
50 |
51 | def __iter__(self) -> Sequence[int]:
52 | for idx in self.sampler:
53 | data_info = self.dataset.get_data_info(idx)
54 | height, width = data_info['height'], data_info['width']
55 | ratio = height / width
56 | # find the closest aspect ratio
57 | closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
58 | if closest_ratio not in self.current_available_bucket_keys:
59 | continue
60 | bucket = self._aspect_ratio_buckets[closest_ratio]
61 | bucket.append(idx)
62 | # yield a batch of indices in the same aspect ratio group
63 | if len(bucket) == self.batch_size:
64 | yield bucket[:]
65 | del bucket[:]
66 |
67 | # yield the rest data and reset the buckets
68 | for bucket in self._aspect_ratio_buckets.values():
69 | while len(bucket) > 0:
70 | if len(bucket) <= self.batch_size:
71 | if not self.drop_last:
72 | yield bucket[:]
73 | bucket = []
74 | else:
75 | yield bucket[:self.batch_size]
76 | bucket = bucket[self.batch_size:]
77 |
78 |
79 | class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler):
80 | def __init__(self, *args, **kwargs):
81 | super().__init__(*args, **kwargs)
82 | # Assign samples to each bucket
83 | self.ratio_nums_gt = kwargs.get('ratio_nums', None)
84 | assert self.ratio_nums_gt
85 | self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()}
86 | self.original_buckets = {}
87 | self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000]
88 | self.all_available_keys = deepcopy(self.current_available_bucket_keys)
89 | self.exhausted_bucket_keys = []
90 | self.total_batches = len(self.sampler) // self.batch_size
91 | self._aspect_ratio_count = {}
92 | for k in self.all_available_keys:
93 | self._aspect_ratio_count[float(k)] = 0
94 | self.original_buckets[float(k)] = []
95 | logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log'))
96 | logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
97 |
98 | def __iter__(self) -> Sequence[int]:
99 | i = 0
100 | for idx in self.sampler:
101 | data_info = self.dataset.get_data_info(idx)
102 | height, width = data_info['height'], data_info['width']
103 | ratio = height / width
104 | closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
105 | if closest_ratio not in self.all_available_keys:
106 | continue
107 | if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
108 | self._aspect_ratio_count[closest_ratio] += 1
109 | self._aspect_ratio_buckets[closest_ratio].append(idx)
110 | self.original_buckets[closest_ratio].append(idx) # Save the original samples for each bucket
111 | if not self.current_available_bucket_keys:
112 | self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []
113 |
114 | if closest_ratio not in self.current_available_bucket_keys:
115 | continue
116 | key = closest_ratio
117 | bucket = self._aspect_ratio_buckets[key]
118 | if len(bucket) == self.batch_size:
119 | yield bucket[:self.batch_size]
120 | del bucket[:self.batch_size]
121 | i += 1
122 | self.exhausted_bucket_keys.append(key)
123 | self.current_available_bucket_keys.remove(key)
124 |
125 | for _ in range(self.total_batches - i):
126 | key = choice(self.all_available_keys)
127 | bucket = self._aspect_ratio_buckets[key]
128 | if len(bucket) >= self.batch_size:
129 | yield bucket[:self.batch_size]
130 | del bucket[:self.batch_size]
131 |
132 | # If a bucket is exhausted
133 | if not bucket:
134 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
135 | shuffle(self._aspect_ratio_buckets[key])
136 | else:
137 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
138 | shuffle(self._aspect_ratio_buckets[key])
139 |
--------------------------------------------------------------------------------
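A self-contained sketch of `AspectRatioBatchSampler` with a toy dataset. The real datasets in this repository expose `get_data_info()`; here a stub supplies `height`/`width`, and the `aspect_ratios` / `ratio_nums` dictionaries are illustrative stand-ins for the predefined ratio tables (only the keys matter for bucketing).

    import torch
    from torch.utils.data import DataLoader, Dataset, RandomSampler
    from diffusion.utils.data_sampler import AspectRatioBatchSampler

    class ToyDataset(Dataset):
        """Stub dataset exposing get_data_info(), as the sampler expects."""
        def __init__(self):
            self.sizes = [(512, 512), (512, 1024), (1024, 512), (512, 512)] * 4
        def __len__(self):
            return len(self.sizes)
        def __getitem__(self, idx):
            h, w = self.sizes[idx]
            return torch.zeros(3, 8, 8), {'height': h, 'width': w}
        def get_data_info(self, idx):
            h, w = self.sizes[idx]
            return {'height': h, 'width': w}

    dataset = ToyDataset()
    aspect_ratios = {'0.5': [512, 1024], '1.0': [512, 512], '2.0': [1024, 512]}
    ratio_nums = {'0.5': 4, '1.0': 8, '2.0': 4}   # sample counts per aspect ratio

    batch_sampler = AspectRatioBatchSampler(
        sampler=RandomSampler(dataset), dataset=dataset, batch_size=4,
        aspect_ratios=aspect_ratios, drop_last=False,
        ratio_nums=ratio_nums, valid_num=0)

    loader = DataLoader(dataset, batch_sampler=batch_sampler)
    for imgs, info in loader:
        print(imgs.shape, info['height'].tolist())   # each batch shares one aspect ratio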
/diffusion/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains primitives for multi-gpu communication.
3 | This is useful when doing distributed training.
4 | """
5 | import os
6 | import pickle
7 | import shutil
8 |
9 | import gc
10 | import mmcv
11 | import torch
12 | import torch.distributed as dist
13 | from mmcv.runner import get_dist_info
14 |
15 |
16 | def is_distributed():
17 | return get_world_size() > 1
18 |
19 |
20 | def get_world_size():
21 | if not dist.is_available():
22 | return 1
23 | return dist.get_world_size() if dist.is_initialized() else 1
24 |
25 |
26 | def get_rank():
27 | if not dist.is_available():
28 | return 0
29 | return dist.get_rank() if dist.is_initialized() else 0
30 |
31 |
32 | def get_local_rank():
33 | if not dist.is_available():
34 | return 0
35 | return int(os.getenv('LOCAL_RANK', 0)) if dist.is_initialized() else 0
36 |
37 |
38 | def is_master():
39 | return get_rank() == 0
40 |
41 |
42 | def is_local_master():
43 | return get_local_rank() == 0
44 |
45 |
46 | def get_local_proc_group(group_size=8):
47 | world_size = get_world_size()
48 | if world_size <= group_size or group_size == 1:
49 | return None
50 | assert world_size % group_size == 0, f'world size ({world_size}) should be evenly divided by group size ({group_size}).'
51 | process_groups = getattr(get_local_proc_group, 'process_groups', {})
52 | if group_size not in process_groups:
53 | num_groups = dist.get_world_size() // group_size
54 | groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]
55 | process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]})
56 | get_local_proc_group.process_groups = process_groups
57 |
58 | group_idx = get_rank() // group_size
59 | return get_local_proc_group.process_groups.get(group_size)[group_idx]
60 |
61 |
62 | def synchronize():
63 | """
64 | Helper function to synchronize (barrier) among all processes when
65 | using distributed training
66 | """
67 | if not dist.is_available():
68 | return
69 | if not dist.is_initialized():
70 | return
71 | world_size = dist.get_world_size()
72 | if world_size == 1:
73 | return
74 | dist.barrier()
75 |
76 |
77 | def all_gather(data):
78 | """
79 | Run all_gather on arbitrary picklable data (not necessarily tensors)
80 | Args:
81 | data: any picklable object
82 | Returns:
83 | list[data]: list of data gathered from each rank
84 | """
85 | to_device = torch.device("cuda")
86 | # to_device = torch.device("cpu")
87 |
88 | world_size = get_world_size()
89 | if world_size == 1:
90 | return [data]
91 |
92 | # serialized to a Tensor
93 | buffer = pickle.dumps(data)
94 | storage = torch.ByteStorage.from_buffer(buffer)
95 | tensor = torch.ByteTensor(storage).to(to_device)
96 |
97 | # obtain Tensor size of each rank
98 | local_size = torch.LongTensor([tensor.numel()]).to(to_device)
99 | size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)]
100 | dist.all_gather(size_list, local_size)
101 | size_list = [int(size.item()) for size in size_list]
102 | max_size = max(size_list)
103 |
104 | tensor_list = [
105 | torch.ByteTensor(size=(max_size,)).to(to_device) for _ in size_list
106 | ]
107 | if local_size != max_size:
108 | padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device)
109 | tensor = torch.cat((tensor, padding), dim=0)
110 | dist.all_gather(tensor_list, tensor)
111 |
112 | data_list = []
113 | for size, tensor in zip(size_list, tensor_list):
114 | buffer = tensor.cpu().numpy().tobytes()[:size]
115 | data_list.append(pickle.loads(buffer))
116 |
117 | return data_list
118 |
119 |
120 | def reduce_dict(input_dict, average=True):
121 | """
122 | Args:
123 | input_dict (dict): all the values will be reduced
124 | average (bool): whether to do average or sum
125 | Reduce the values in the dictionary from all processes so that process with rank
126 | 0 has the averaged results. Returns a dict with the same fields as
127 | input_dict, after reduction.
128 | """
129 | world_size = get_world_size()
130 | if world_size < 2:
131 | return input_dict
132 | with torch.no_grad():
133 | reduced_dict = _extracted_from_reduce_dict_14(input_dict, average, world_size)
134 | return reduced_dict
135 |
136 |
137 | # TODO Rename this here and in `reduce_dict`
138 | def _extracted_from_reduce_dict_14(input_dict, average, world_size):
139 | names = []
140 | values = []
141 | # sort the keys so that they are consistent across processes
142 | for k in sorted(input_dict.keys()):
143 | names.append(k)
144 | values.append(input_dict[k])
145 | values = torch.stack(values, dim=0)
146 | dist.reduce(values, dst=0)
147 | if dist.get_rank() == 0 and average:
148 | # only main process gets accumulated, so only divide by
149 | # world_size in this case
150 | values /= world_size
151 | return dict(zip(names, values))
152 |
153 |
154 | def broadcast(data, **kwargs):
155 | if get_world_size() == 1:
156 | return data
157 | data = [data]
158 | dist.broadcast_object_list(data, **kwargs)
159 | return data[0]
160 |
161 |
162 | def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True):
163 | rank, world_size = get_dist_info()
164 | if tmpdir is None:
165 | tmpdir = './tmp'
166 | if rank == 0:
167 | mmcv.mkdir_or_exist(tmpdir)
168 | synchronize()
169 | # dump the part result to the dir
170 | mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
171 | synchronize()
172 | if collect_by_master and rank != 0:
173 | return None
174 | # load results of all parts from tmp dir
175 | results = []
176 | for i in range(world_size):
177 | part_file = os.path.join(tmpdir, f'part_{i}.pkl')
178 | results.append(mmcv.load(part_file))
179 | if not collect_by_master:
180 | synchronize()
181 | # remove tmp dir
182 | if rank == 0:
183 | shutil.rmtree(tmpdir)
184 | return results
185 |
186 | def all_gather_tensor(tensor, group_size=None, group=None):
187 | if group_size is None:
188 | group_size = get_world_size()
189 | if group_size == 1:
190 | output = [tensor]
191 | else:
192 | output = [torch.zeros_like(tensor) for _ in range(group_size)]
193 | dist.all_gather(output, tensor, group=group)
194 | return output
195 |
196 |
197 | def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None):
198 | world_size = get_world_size()
199 | if world_size == 1:
200 | return feat if concat else [feat]
201 | num_samples, *feat_dim = feat.size()
202 | # padding to max number of samples
203 | feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim))
204 | feat_padding[:num_samples] = feat
205 | # gather
206 | feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size)
207 | for r, num in enumerate(num_samples_list):
208 | feat_gather[r] = feat_gather[r][:num]
209 | if concat:
210 | feat_gather = torch.cat(feat_gather)
211 | return feat_gather
212 |
213 |
214 | class GatherLayer(torch.autograd.Function):
215 | '''Gather tensors from all processes, supporting backward propagation.
216 | '''
217 |
218 | @staticmethod
219 | def forward(ctx, input):
220 | ctx.save_for_backward(input)
221 | num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device)
222 | ctx.num_samples_list = all_gather_tensor(num_samples)
223 | output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False)
224 | return tuple(output)
225 |
226 | @staticmethod
227 | def backward(ctx, *grads): # tuple(output)'s grad
228 | input, = ctx.saved_tensors
229 | num_samples_list = ctx.num_samples_list
230 | rank = get_rank()
231 | start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1])
232 | grads = torch.cat(grads)
233 | if is_distributed():
234 | dist.all_reduce(grads)
235 | grad_out = torch.zeros_like(input)
236 | grad_out[:] = grads[start:end]
237 | return grad_out, None, None
238 |
239 |
240 | class GatherLayerWithGroup(torch.autograd.Function):
241 | '''Gather tensors from all processes, supporting backward propagation.
242 | '''
243 |
244 | @staticmethod
245 | def forward(ctx, input, group, group_size):
246 | ctx.save_for_backward(input)
247 | ctx.group_size = group_size
248 | output = all_gather_tensor(input, group=group, group_size=group_size)
249 | return tuple(output)
250 |
251 | @staticmethod
252 | def backward(ctx, *grads): # tuple(output)'s grad
253 | input, = ctx.saved_tensors
254 | grads = torch.stack(grads)
255 | if is_distributed():
256 | dist.all_reduce(grads)
257 | grad_out = torch.zeros_like(input)
258 | grad_out[:] = grads[get_rank() % ctx.group_size]
259 | return grad_out, None, None
260 |
261 |
262 | def gather_layer_with_group(data, group=None, group_size=None):
263 | if group_size is None:
264 | group_size = get_world_size()
265 | return GatherLayer.apply(data, group, group_size)
266 |
267 | from typing import Union
268 | import math
269 | # from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm
270 |
271 | @torch.no_grad()
272 | def clip_grad_norm_(
273 | self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
274 | ) -> None:
275 | self._lazy_init()
276 | self._wait_for_previous_optim_step()
277 | assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
278 | self._assert_state(TrainingState_.IDLE)
279 |
280 | max_norm = float(max_norm)
281 | norm_type = float(norm_type)
282 | # Computes the norm of this shard's gradients and syncs across workers
283 | local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda() # type: ignore[arg-type]
284 | if norm_type == math.inf:
285 | total_norm = local_norm
286 | dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
287 | else:
288 | total_norm = local_norm ** norm_type
289 | dist.all_reduce(total_norm, group=self.process_group)
290 | total_norm = total_norm ** (1.0 / norm_type)
291 |
292 | clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
293 | if clip_coef < 1:
294 | # multiply by clip_coef, aka, (max_norm/total_norm).
295 | for p in self.params_with_grad:
296 | assert p.grad is not None
297 | p.grad.detach().mul_(clip_coef.to(p.grad.device))
298 | return total_norm
299 |
300 |
301 | def flush():
302 | gc.collect()
303 | torch.cuda.empty_cache()
304 |
--------------------------------------------------------------------------------
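The gather/reduce helpers above all fall back to cheap no-ops when only one process is running, which makes them easy to sanity-check outside a distributed launch; under a real multi-GPU run the same calls synchronize across ranks. Import path and values below are placeholders.

    import torch
    from diffusion.utils.dist_utils import all_gather, get_world_size, reduce_dict

    losses = {'loss_diff': torch.tensor(0.42), 'loss_kl': torch.tensor(0.03)}
    print(get_world_size())          # 1 outside of a torch.distributed launch
    print(reduce_dict(losses))       # returned unchanged on a single process
    print(all_gather({'step': 1}))   # [{'step': 1}] on a single process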
/diffusion/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch.distributed as dist
4 | from datetime import datetime
5 | from .dist_utils import is_local_master
6 | from mmcv.utils.logging import logger_initialized
7 |
8 |
9 | def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'):
10 | """Get root logger.
11 |
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 | name (str): logger name
17 | Returns:
18 | :obj:`logging.Logger`: The obtained logger
19 | """
20 | if log_file is None:
21 | log_file = '/dev/null'
22 | return get_logger(name=name, log_file=log_file, log_level=log_level)
23 |
24 |
25 | def get_logger(name, log_file=None, log_level=logging.INFO):
26 | """Initialize and get a logger by name.
27 |
28 | If the logger has not been initialized, this method will initialize the
29 | logger by adding one or two handlers, otherwise the initialized logger will
30 | be directly returned. During initialization, a StreamHandler will always be
31 | added. If `log_file` is specified and the process rank is 0, a FileHandler
32 | will also be added.
33 |
34 | Args:
35 | name (str): Logger name.
36 | log_file (str | None): The log filename. If specified, a FileHandler
37 | will be added to the logger.
38 | log_level (int): The logger level. Note that only the process of
39 | rank 0 is affected, and other processes will set the level to
40 | "Error" thus be silent most of the time.
41 |
42 | Returns:
43 | logging.Logger: The expected logger.
44 | """
45 | logger = logging.getLogger(name)
46 | logger.propagate = False # disable root logger to avoid duplicate logging
47 |
48 | if name in logger_initialized:
49 | return logger
50 | # handle hierarchical names
51 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
52 | # initialization since it is a child of "a".
53 | for logger_name in logger_initialized:
54 | if name.startswith(logger_name):
55 | return logger
56 |
57 | stream_handler = logging.StreamHandler()
58 | handlers = [stream_handler]
59 |
60 | rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
61 | # only rank 0 will add a FileHandler
62 | if rank == 0 and log_file is not None:
63 | file_handler = logging.FileHandler(log_file, 'w')
64 | handlers.append(file_handler)
65 |
66 | formatter = logging.Formatter(
67 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
68 | for handler in handlers:
69 | handler.setFormatter(formatter)
70 | handler.setLevel(log_level)
71 | logger.addHandler(handler)
72 |
73 | # only rank0 for each node will print logs
74 | log_level = log_level if is_local_master() else logging.ERROR
75 | logger.setLevel(log_level)
76 |
77 | logger_initialized[name] = True
78 |
79 | return logger
80 |
81 | def rename_file_with_creation_time(file_path):
82 | # Get the file's creation time
83 | creation_time = os.path.getctime(file_path)
84 | creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')
85 |
86 | # Build the new file name
87 | dir_name, file_name = os.path.split(file_path)
88 | name, ext = os.path.splitext(file_name)
89 | new_file_name = f"{name}_{creation_time_str}{ext}"
90 | new_file_path = os.path.join(dir_name, new_file_name)
91 |
92 | # Rename the file
93 | os.rename(file_path, new_file_path)
94 | print(f"File renamed to: {new_file_path}")
95 |
--------------------------------------------------------------------------------
/diffusion/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
2 | from torch.optim import Optimizer
3 | from torch.optim.lr_scheduler import LambdaLR
4 | import math
5 |
6 | from diffusion.utils.logger import get_root_logger
7 |
8 |
9 | def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
10 | if not config.get('lr_schedule_args', None):
11 | config.lr_schedule_args = {}
12 | if config.get('lr_warmup_steps', None):
13 | config['num_warmup_steps'] = config.get('lr_warmup_steps') # for compatibility with old version
14 |
15 | logger = get_root_logger()
16 | logger.info(
17 | f'Lr schedule: {config.lr_schedule}, ' + ",".join(
18 | [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
19 | if config.lr_schedule == 'cosine':
20 | lr_scheduler = get_cosine_schedule_with_warmup(
21 | optimizer=optimizer,
22 | **config.lr_schedule_args,
23 | num_training_steps=(len(train_dataloader) * config.num_epochs),
24 | )
25 | elif config.lr_schedule == 'constant':
26 | lr_scheduler = get_constant_schedule_with_warmup(
27 | optimizer=optimizer,
28 | **config.lr_schedule_args,
29 | )
30 | elif config.lr_schedule == 'cosine_decay_to_constant':
31 | assert lr_scale_ratio >= 1
32 | lr_scheduler = get_cosine_decay_to_constant_with_warmup(
33 | optimizer=optimizer,
34 | **config.lr_schedule_args,
35 | final_lr=1 / lr_scale_ratio,
36 | num_training_steps=(len(train_dataloader) * config.num_epochs),
37 | )
38 | else:
39 | raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
40 | return lr_scheduler
41 |
42 |
43 | def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
44 | num_warmup_steps: int,
45 | num_training_steps: int,
46 | final_lr: float = 0.0,
47 | num_decay: float = 0.667,
48 | num_cycles: float = 0.5,
49 | last_epoch: int = -1
50 | ):
51 | """
52 | Create a schedule with a cosine annealing lr followed by a constant lr.
53 |
54 | Args:
55 | optimizer ([`~torch.optim.Optimizer`]):
56 | The optimizer for which to schedule the learning rate.
57 | num_warmup_steps (`int`):
58 | The number of steps for the warmup phase.
59 | num_training_steps (`int`):
60 | The number of total training steps.
61 |         final_lr (`float`, *optional*, defaults to 0.0):
62 |             The final constant lr, expressed as a fraction of the base lr, kept after the cosine decay.
63 |         num_decay (`float`, *optional*, defaults to 0.667):
64 |             The fraction of the total training steps over which the cosine decay is applied.
65 | last_epoch (`int`, *optional*, defaults to -1):
66 | The index of the last epoch when resuming training.
67 |
68 | Return:
69 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
70 | """
71 |
72 | def lr_lambda(current_step):
73 | if current_step < num_warmup_steps:
74 | return float(current_step) / float(max(1, num_warmup_steps))
75 |
76 | num_decay_steps = int(num_training_steps * num_decay)
77 | if current_step > num_decay_steps:
78 | return final_lr
79 |
80 | progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
81 | return (
82 | max(
83 | 0.0,
84 | 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)),
85 | )
86 | * (1 - final_lr)
87 | ) + final_lr
88 |
89 | return LambdaLR(optimizer, lr_lambda, last_epoch)
90 |
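91 | # Schedule sketch for get_cosine_decay_to_constant_with_warmup (hypothetical numbers):
92 | #   with num_warmup_steps=500, num_training_steps=100_000, num_decay=0.667 and final_lr=0.1,
93 | #   the lr multiplier ramps linearly from 0 to 1 over the first 500 steps, follows a cosine
94 | #   from 1.0 down to 0.1 until step 66_700, and then stays constant at 0.1.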
--------------------------------------------------------------------------------
/diffusion/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from mmcv import Config
4 | from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \
5 | OPTIMIZERS
6 | from mmcv.utils import _BatchNorm, _InstanceNorm
7 | from torch.nn import GroupNorm, LayerNorm
8 |
9 | from .logger import get_root_logger
10 |
11 | from typing import Tuple, Optional, Callable
12 |
13 | import torch
14 | from torch.optim.optimizer import Optimizer
15 |
16 |
17 | def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
18 | assert rule in ['linear', 'sqrt']
19 | logger = get_root_logger()
20 | # scale by world size
21 | if rule == 'sqrt':
22 | scale_ratio = math.sqrt(effective_bs / base_batch_size)
23 | elif rule == 'linear':
24 | scale_ratio = effective_bs / base_batch_size
25 | optimizer_cfg['lr'] *= scale_ratio
26 | logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.7f} (using {rule} scaling rule).')
27 | return scale_ratio
28 |
29 |
30 | @OPTIMIZER_BUILDERS.register_module()
31 | class MyOptimizerConstructor(DefaultOptimizerConstructor):
32 |
33 | def add_params(self, params, module, prefix='', is_dcn_module=None):
34 | """Add all parameters of module to the params list.
35 |
36 | The parameters of the given module will be added to the list of param
37 | groups, with specific rules defined by paramwise_cfg.
38 |
39 | Args:
40 | params (list[dict]): A list of param groups, it will be modified
41 | in place.
42 | module (nn.Module): The module to be added.
43 | prefix (str): The prefix of the module
44 |
45 | """
46 | # get param-wise options
47 | custom_keys = self.paramwise_cfg.get('custom_keys', {})
48 | # first sort with alphabet order and then sort with reversed len of str
49 | # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
50 |
51 | bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
52 | bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
53 | norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
54 | bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
55 |
56 | # special rules for norm layers and depth-wise conv layers
57 | is_norm = isinstance(module,
58 | (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
59 |
60 | for name, param in module.named_parameters(recurse=False):
61 | base_lr = self.base_lr
62 | if name == 'bias' and not is_norm and not is_dcn_module:
63 | base_lr *= bias_lr_mult
64 |
65 | # apply weight decay policies
66 | base_wd = self.base_wd
67 | # norm decay
68 | if is_norm:
69 | if self.base_wd is not None:
70 | base_wd *= norm_decay_mult
71 | elif name == 'bias' and not is_dcn_module:
72 | if self.base_wd is not None:
73 |                     # TODO: current bias_decay_mult will have an effect on DCN
74 | base_wd *= bias_decay_mult
75 |
76 | param_group = {'params': [param]}
77 | if not param.requires_grad:
78 | param_group['requires_grad'] = False
79 | params.append(param_group)
80 | continue
81 | if bypass_duplicate and self._is_in(param_group, params):
82 | logger = get_root_logger()
83 |                 logger.warning(f'{prefix} is duplicate. It is skipped since '
84 | f'bypass_duplicate={bypass_duplicate}')
85 | continue
86 | # if the parameter match one of the custom keys, ignore other rules
87 | is_custom = False
88 | for key in custom_keys:
89 | scope, key_name = key if isinstance(key, tuple) else (None, key)
90 | if scope is not None and scope not in f'{prefix}':
91 | continue
92 | if key_name in f'{prefix}.{name}':
93 | is_custom = True
94 | if 'lr_mult' in custom_keys[key]:
95 | # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
96 | # param_group['lr'] = self.base_lr
97 | # else:
98 | param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
99 | elif 'lr' not in param_group:
100 | param_group['lr'] = base_lr
101 | if self.base_wd is not None:
102 | if 'decay_mult' in custom_keys[key]:
103 | param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
104 | elif 'weight_decay' not in param_group:
105 | param_group['weight_decay'] = base_wd
106 |
107 | if not is_custom:
108 | # bias_lr_mult affects all bias parameters
109 | # except for norm.bias dcn.conv_offset.bias
110 | if base_lr != self.base_lr:
111 | param_group['lr'] = base_lr
112 | if base_wd != self.base_wd:
113 | param_group['weight_decay'] = base_wd
114 | params.append(param_group)
115 |
116 | for child_name, child_mod in module.named_children():
117 | child_prefix = f'{prefix}.{child_name}' if prefix else child_name
118 | self.add_params(
119 | params,
120 | child_mod,
121 | prefix=child_prefix,
122 | is_dcn_module=is_dcn_module)
123 |
124 |
125 | def build_optimizer(model, optimizer_cfg):
126 | # default parameter-wise config
127 | logger = get_root_logger()
128 |
129 | if hasattr(model, 'module'):
130 | model = model.module
131 | # set optimizer constructor
132 | optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
133 | # parameter-wise setting: cancel weight decay for some specific modules
134 | custom_keys = dict()
135 | for name, module in model.named_modules():
136 | if hasattr(module, 'zero_weight_decay'):
137 | custom_keys |= {
138 | (name, key): dict(decay_mult=0)
139 | for key in module.zero_weight_decay
140 | }
141 |
142 | paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
143 | if given_cfg := optimizer_cfg.get('paramwise_cfg'):
144 | paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
145 | optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg
146 | # build optimizer
147 | optimizer = mm_build_optimizer(model, optimizer_cfg)
148 |
149 | weight_decay_groups = dict()
150 | lr_groups = dict()
151 | for group in optimizer.param_groups:
152 | if not group.get('requires_grad', True): continue
153 | lr_groups.setdefault(group['lr'], []).append(group)
154 | weight_decay_groups.setdefault(group['weight_decay'], []).append(group)
155 |
156 | learnable_count, fix_count = 0, 0
157 | for p in model.parameters():
158 | if p.requires_grad:
159 | learnable_count += 1
160 | else:
161 | fix_count += 1
162 |     fix_info = f"{learnable_count} are learnable, {fix_count} are fixed"
163 | lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
164 | wd_info = "Weight decay group: " + ", ".join(
165 | [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
166 | opt_info = f"Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
167 | logger.info(opt_info)
168 |
169 | return optimizer
170 |
171 |
172 | @OPTIMIZERS.register_module()
173 | class Lion(Optimizer):
174 | def __init__(
175 | self,
176 | params,
177 | lr: float = 1e-4,
178 | betas: Tuple[float, float] = (0.9, 0.99),
179 | weight_decay: float = 0.0,
180 | ):
181 | assert lr > 0.
182 | assert all(0. <= beta <= 1. for beta in betas)
183 |
184 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
185 |
186 | super().__init__(params, defaults)
187 |
188 | @staticmethod
189 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
190 | # stepweight decay
191 | p.data.mul_(1 - lr * wd)
192 |
193 | # weight update
194 | update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
195 | p.add_(update, alpha=-lr)
196 |
197 | # decay the momentum running average coefficient
198 | exp_avg.lerp_(grad, 1 - beta2)
199 |
200 | @staticmethod
201 | def exists(val):
202 | return val is not None
203 |
204 | @torch.no_grad()
205 | def step(
206 | self,
207 | closure: Optional[Callable] = None
208 | ):
209 |
210 | loss = None
211 | if self.exists(closure):
212 | with torch.enable_grad():
213 | loss = closure()
214 |
215 | for group in self.param_groups:
216 | for p in filter(lambda p: self.exists(p.grad), group['params']):
217 |
218 | grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
219 | self.state[p]
220 |
221 | # init state - exponential moving average of gradient values
222 | if len(state) == 0:
223 | state['exp_avg'] = torch.zeros_like(p)
224 |
225 | exp_avg = state['exp_avg']
226 |
227 | self.update_fn(
228 | p,
229 | grad,
230 | exp_avg,
231 | lr,
232 | wd,
233 | beta1,
234 | beta2
235 | )
236 |
237 | return loss
238 |
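239 | # Notes (illustrative values, not project defaults):
240 | #   auto_scale_lr: with effective_bs=1024 and base_batch_size=256, 'linear' multiplies lr by 4
241 | #   and 'sqrt' multiplies it by 2.
242 | #   Lion usage sketch:
243 | #       optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01)
244 | #       loss.backward(); optimizer.step(); optimizer.zero_grad()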
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 | import os
5 | import copy
6 | import argparse
7 | from matplotlib import pyplot as plt
8 |
9 | from peft import PeftModel
10 |
11 | from tide.utils.dataset_palette import USIS10K_COLORS_PALETTE
12 | from tide.utils import mask_postprocess
13 |
14 | from tide.pipeline.tide_transformer import (
15 | PixArtSpecialAttnTransformerModel,
16 | TIDETransformerModel,
17 | TIDE_TANs,
18 | MiniTransformerModel
19 | )
20 | from tide.pipeline.pipeline_tide import TIDEPipeline
21 |
22 | cmap = plt.get_cmap('Spectral_r')
23 |
24 |
25 | def colour_and_vis_id_map(id_map, palette, out_path):
26 | id_map = id_map.astype(np.uint8)
27 | id_map -= 1
28 | ids = np.unique(id_map)
29 | valid_ids = np.delete(ids, np.where(ids == 255))
30 |
31 | colour_layout = np.zeros((id_map.shape[0], id_map.shape[1], 3), dtype=np.uint8)
32 | for id in valid_ids:
33 | colour_layout[id_map == id, :] = palette[id].reshape(1, 3)
34 | colour_layout = Image.fromarray(colour_layout)
35 | colour_layout.save(out_path)
36 |
37 |
38 | def main(args):
39 | pretrained_t2i_model = os.path.join(args.model_weights_dir, "PixArt-XL-2-512x512")
40 | mini_transformer_dir = os.path.join(args.model_weights_dir, "TIDE_MiniTransformer")
41 | tide_weight_dir = os.path.join(args.model_weights_dir, "TIDE_r32_64_b4_200k")
42 |
43 | generator = torch.manual_seed(50)
44 | palette = np.array([[0, 0, 0]] + USIS10K_COLORS_PALETTE)
45 |
46 | # model definitions
47 | transformer = PixArtSpecialAttnTransformerModel.from_pretrained(
48 | pretrained_t2i_model,
49 | subfolder="transformer", torch_dtype=torch.float16
50 | )
51 | transformer.requires_grad_(False)
52 |
53 | depth_transformer = MiniTransformerModel.from_config(
54 | mini_transformer_dir,
55 | subfolder="mini_transformer",
56 | torch_dtype=torch.float16
57 | )
58 | _state_dict = torch.load(
59 | os.path.join(mini_transformer_dir, 'mini_transformer/diffusion_pytorch_model.pth'),
60 | map_location='cpu'
61 | )
62 | depth_transformer.load_state_dict(_state_dict)
63 | depth_transformer.half()
64 | depth_transformer.requires_grad_(False)
65 | del _state_dict
66 |
67 | mask_transformer = copy.deepcopy(depth_transformer)
68 |
69 | image_transformer = PeftModel.from_pretrained(
70 | transformer, os.path.join(tide_weight_dir, 'image_transformer_lora')
71 | )
72 | depth_transformer = PeftModel.from_pretrained(
73 | depth_transformer, os.path.join(tide_weight_dir, 'depth_transformer_lora')
74 | )
75 | mask_transformer = PeftModel.from_pretrained(
76 | mask_transformer, os.path.join(tide_weight_dir, 'mask_transformer_lora')
77 | )
78 |
79 | tan_modules = TIDE_TANs.from_pretrained(
80 | os.path.join(tide_weight_dir, 'tan_modules'), torch_dtype=torch.float16
81 | )
82 |
83 | tide_transformer = TIDETransformerModel(image_transformer, depth_transformer, mask_transformer, tan_modules)
84 | del image_transformer, depth_transformer, mask_transformer, tan_modules
85 |
86 | model = TIDEPipeline.from_pretrained(
87 | pretrained_t2i_model,
88 | transformer=tide_transformer,
89 | torch_dtype=torch.float16,
90 | use_safetensors=True
91 | ).to("cuda")
92 |
93 | # generate image, depth map, semantic mask
94 | target_image, depth_image, mask_image = model(
95 | prompt=args.text_prompt,
96 | num_inference_steps=20,
97 | generator=generator,
98 | guidance_scale=2.0,
99 | )
100 | target_image = target_image.images[0]
101 | depth_image = depth_image.images[0]
102 | mask_image = mask_image.images[0]
103 |
104 | target_image.save(os.path.join(args.output, "image.jpg"))
105 |
106 |     depth_image = np.mean(np.array(depth_image), axis=-1) / 255.0  # normalize to [0, 1] for the colormap
107 | vis_depth_image = (cmap(depth_image) * 255).astype(np.uint8)
108 | Image.fromarray(vis_depth_image).save(os.path.join(args.output, "depth.png"))
109 |
110 | id_map = mask_postprocess(mask_image, palette)
111 | colour_and_vis_id_map(id_map, palette[1:], os.path.join(args.output, "mask.png"))
112 |
113 |
114 | if __name__ == "__main__":
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument("--model_weights_dir", type=str, default="./model_weights")
117 | parser.add_argument('--text_prompt', type=str, default="A large school of fish swimming in a circle.")
118 | parser.add_argument('--output', type=str, default="./outputs")
119 | args = parser.parse_args()
120 |
121 | os.makedirs(args.output, exist_ok=True)
122 |
123 | main(args)
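124 |
125 | # Example invocation (a sketch; directory names follow the argparse defaults above):
126 | #   python inference.py --model_weights_dir ./model_weights \
127 | #       --text_prompt "A large school of fish swimming in a circle." --output ./outputs
128 | # Writes image.jpg, depth.png and mask.png into the output directory.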
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | mmcv==1.7.0
2 | diffusers==0.29.2
3 | timm==0.6.12
4 | accelerate==0.31.0
5 | tensorboard==2.17.0
6 | tensorboardX==2.6.2.2
7 | transformers==4.41.2
8 | ftfy
9 | protobuf==3.20.2
10 | gradio==4.1.1
11 | yapf==0.40.1
12 | bs4
13 | einops
14 | optimum
15 | xformers==0.0.20
16 | Pillow==10.2.0
17 | sentencepiece~=0.1.99
18 | peft==0.10.0
19 | beautifulsoup4
20 | git+https://github.com/facebookresearch/detectron2.git
21 | git+https://github.com/lucasb-eyer/pydensecrf.git
22 |
--------------------------------------------------------------------------------
/tide/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HongkLin/TIDE/d0cba53604b3dd9e16e4f81ce24d7f4be3ba0d4e/tide/__init__.py
--------------------------------------------------------------------------------
/tide/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from detectron2.config import CfgNode as CN
3 |
4 |
5 | def add_tide_config(cfg):
6 |     # TIDE
7 | cfg.TIDE = CN()
8 | cfg.TIDE.SEED = 42
9 | cfg.TIDE.PRETRAINED_DIFFUSION_MODEL_WEIGHT = "pretrained_model/PixArt-XL-2-512x512"
10 |
11 |     # instruction
12 | cfg.TIDE.INSTRUCT = CN()
13 | cfg.TIDE.NUM_IMAGE_PER_PROMPT = 1
14 | cfg.TIDE.REGION_FILTER_TH = 100
15 | cfg.TIDE.DISTANCE_MAP_TH = None
16 | cfg.TIDE.RETURN_CRF_REFINE = True
17 | cfg.TIDE.TEMPERATURE = 40
18 |
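19 | # Usage sketch (assuming a detectron2-style config object, as imported above):
20 | #   from detectron2.config import get_cfg
21 | #   cfg = get_cfg()
22 | #   add_tide_config(cfg)  # cfg.TIDE.* defaults become available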
--------------------------------------------------------------------------------
/tide/pipeline/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class MLP(nn.Module):
7 | """Very simple multi-layer perceptron (also called FFN)"""
8 |
9 | def __init__(
10 | self,
11 | input_dim,
12 | hidden_dim,
13 | output_dim,
14 | num_layers,
15 | sigmoid_output: bool = False,
16 | affine_func=nn.Linear,
17 | ):
18 | super().__init__()
19 | self.num_layers = num_layers
20 | h = [hidden_dim] * (num_layers - 1)
21 | self.layers = nn.ModuleList(
22 | affine_func(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
23 | )
24 | self.sigmoid_output = sigmoid_output
25 |
26 | def forward(self, x: torch.Tensor):
27 | for i, layer in enumerate(self.layers):
28 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
29 | if self.sigmoid_output:
30 | x = F.sigmoid(x)
31 | return x
32 |
33 | class ResidualMLP(nn.Module):
34 |
35 | def __init__(
36 | self,
37 | input_dim,
38 | hidden_dim,
39 | output_dim,
40 | num_mlp,
41 | num_layer_per_mlp,
42 | sigmoid_output: bool = False,
43 | affine_func=nn.Linear,
44 | ):
45 | super().__init__()
46 | self.num_mlp = num_mlp
47 | self.in2hidden_dim = affine_func(input_dim, hidden_dim)
48 | self.hidden2out_dim = affine_func(hidden_dim, output_dim)
49 | self.mlp_list = nn.ModuleList(
50 | MLP(
51 | hidden_dim,
52 | hidden_dim,
53 | hidden_dim,
54 | num_layer_per_mlp,
55 | affine_func=affine_func,
56 | ) for _ in range(num_mlp)
57 | )
58 | self.sigmoid_output = sigmoid_output
59 |
60 | def forward(self, x: torch.Tensor):
61 | x = self.in2hidden_dim(x)
62 | for mlp in self.mlp_list:
63 | out = mlp(x)
64 | x = x + out
65 | out = self.hidden2out_dim(x)
66 | return out
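67 |
68 | # Shape sketch (illustrative sizes): MLP(256, 512, 128, num_layers=3) maps (B, N, 256) to
69 | # (B, N, 128) through two ReLU hidden layers; ResidualMLP applies `num_mlp` such blocks with
70 | # residual connections at the hidden width before the final output projection.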
--------------------------------------------------------------------------------
/tide/pipeline/transformer_attentions.py:
--------------------------------------------------------------------------------
1 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2 | # See the License for the specific language governing permissions and
3 | # limitations under the License.
4 | from torch import einsum
5 | from typing import Callable, List, Optional, Union
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 |
11 | from einops import rearrange, repeat
12 |
13 | from diffusers.utils.import_utils import is_xformers_available
14 | from diffusers.models.attention_processor import Attention, SpatialNorm, AttnProcessor2_0, AttnProcessor
15 | from diffusers.models.attention import logger
16 | from diffusers.utils import deprecate
17 |
18 | class ILSAttnProcessor:
19 | r"""
20 |     Processor for scaled dot-product attention that additionally returns the cross-attention map, and that can reuse a supplied map instead of recomputing attention.
21 | """
22 |
23 | def __init__(self):
24 | if not hasattr(F, "scaled_dot_product_attention"):
25 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
26 |
27 | def __call__(
28 | self,
29 | attn,
30 | hidden_states,
31 | encoder_hidden_states,
32 | attention_mask: Optional[torch.Tensor] = None,
33 | cross_attn_map: Optional[torch.Tensor] = None,
34 | temb: Optional[torch.Tensor] = None,
35 | *args,
36 | **kwargs,
37 | ) -> torch.Tensor:
38 | context = encoder_hidden_states
39 | if len(args) > 0 or kwargs.get("scale", None) is not None:
40 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
41 | deprecate("scale", "1.0.0", deprecation_message)
42 |
43 | residual = hidden_states
44 | if attn.spatial_norm is not None:
45 | hidden_states = attn.spatial_norm(hidden_states, temb)
46 |
47 | input_ndim = hidden_states.ndim
48 |
49 | if input_ndim == 4:
50 | batch_size, channel, height, width = hidden_states.shape
51 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
52 |
53 | batch_size, sequence_length, _ = (
54 | hidden_states.shape if context is None else context.shape
55 | )
56 |
57 | if attention_mask is not None:
58 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
59 | # scaled_dot_product_attention expects attention_mask shape to be
60 | # (batch, heads, source_length, target_length)
61 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
62 |
63 | if attn.group_norm is not None:
64 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
65 |
66 | query = attn.to_q(hidden_states)
67 |
68 | if context is None:
69 | context = hidden_states
70 | elif attn.norm_cross:
71 | context = attn.norm_context(context)
72 |
73 | key = attn.to_k(context)
74 | value = attn.to_v(context)
75 |
76 | inner_dim = key.shape[-1]
77 | head_dim = inner_dim // attn.heads
78 |
79 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
80 |
81 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
82 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
83 |
84 | if cross_attn_map is not None:
85 | hidden_states = einsum('b h l n, b h n c -> b h l c', cross_attn_map, value)
86 | else:
87 | hidden_states, cross_attn_map = self.scaled_dot_product_attention(
88 | attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0
89 | )
90 |
91 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
92 | hidden_states = hidden_states.to(query.dtype)
93 |
94 | # linear proj
95 | hidden_states = attn.to_out[0](hidden_states)
96 | # dropout
97 | hidden_states = attn.to_out[1](hidden_states)
98 |
99 | if input_ndim == 4:
100 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
101 |
102 | if attn.residual_connection:
103 | hidden_states = hidden_states + residual
104 |
105 | hidden_states = hidden_states / attn.rescale_output_factor
106 |
107 | return hidden_states, cross_attn_map
108 |
109 | def scaled_dot_product_attention(self, attn, query, key, value, attn_mask=None, dropout_p=0.0):
110 | query = query * attn.scale
111 | attn_map = einsum('b h l c, b h n c -> b h l n', query, key)
112 |
113 | if attn_mask is not None:
114 |             attn_mask = attn_mask.masked_fill(attn_mask.logical_not(), float('-inf')) if attn_mask.dtype == torch.bool else attn_mask
115 | attn_map += attn_mask
116 |
117 | attn_weight = F.softmax(attn_map, dim=-1)
118 |
119 | if dropout_p > 0.0:
120 | attn_weight = F.dropout(attn_weight, p=dropout_p)
121 |
122 | hidden_states = einsum('b h l n, b h n c -> b h l c', attn_weight, value)
123 |
124 | return hidden_states, attn_weight
125 |
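126 | # Note: unlike the stock AttnProcessor2_0, this processor returns the attention map alongside
127 | # the hidden states, and when `cross_attn_map` is supplied it reuses that map to aggregate the
128 | # values instead of recomputing attention (presumably so the text-conditioned layout can be
129 | # shared across the image/depth/mask branches).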
--------------------------------------------------------------------------------
/tide/pipeline/transformer_blocks.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from diffusers.models.attention import logger, _chunked_feed_forward, BasicTransformerBlock
7 | from diffusers.models.attention_processor import Attention
8 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous
9 | from .transformer_attentions import ILSAttnProcessor
10 |
11 | class BasicTransformerTIDEBlock(BasicTransformerBlock):
12 | def __init__(
13 | self,
14 | dim: int,
15 | num_attention_heads: int,
16 | attention_head_dim: int,
17 | dropout=0.0,
18 | cross_attention_dim: Optional[int] = None,
19 | activation_fn: str = "geglu",
20 | num_embeds_ada_norm: Optional[int] = None,
21 | attention_bias: bool = False,
22 | only_cross_attention: bool = False,
23 | double_self_attention: bool = False,
24 | upcast_attention: bool = False,
25 | norm_elementwise_affine: bool = True,
26 | norm_type: str = "layer_norm",
27 | norm_eps: float = 1e-5,
28 | final_dropout: bool = False,
29 | attention_type: str = "default",
30 | positional_embeddings: Optional[str] = None,
31 | num_positional_embeddings: Optional[int] = None,
32 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
33 | ada_norm_bias: Optional[int] = None,
34 | ff_inner_dim: Optional[int] = None,
35 | ff_bias: bool = True,
36 | attention_out_bias: bool = True,
37 | ):
38 | super().__init__(
39 | dim,
40 | num_attention_heads,
41 | attention_head_dim,
42 | dropout,
43 | cross_attention_dim,
44 | activation_fn,
45 | num_embeds_ada_norm,
46 | attention_bias,
47 | only_cross_attention,
48 | double_self_attention,
49 | upcast_attention,
50 | norm_elementwise_affine,
51 | norm_type,
52 | norm_eps,
53 | final_dropout,
54 | attention_type,
55 | positional_embeddings,
56 | num_positional_embeddings,
57 | ada_norm_continous_conditioning_embedding_dim,
58 | ada_norm_bias,
59 | ff_inner_dim,
60 | ff_bias,
61 | attention_out_bias,
62 | )
63 |
64 |
65 | # 2. Cross-Attn
66 | if cross_attention_dim is not None or double_self_attention:
67 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
68 |             # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
69 | # the second cross attention block.
70 | if norm_type == "ada_norm":
71 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
72 | elif norm_type == "ada_norm_continuous":
73 | self.norm2 = AdaLayerNormContinuous(
74 | dim,
75 | ada_norm_continous_conditioning_embedding_dim,
76 | norm_elementwise_affine,
77 | norm_eps,
78 | ada_norm_bias,
79 | "rms_norm",
80 | )
81 | else:
82 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
83 |
84 | self.attn2 = Attention(
85 | query_dim=dim,
86 | cross_attention_dim=cross_attention_dim if not double_self_attention else None,
87 | heads=num_attention_heads,
88 | dim_head=attention_head_dim,
89 | dropout=dropout,
90 | bias=attention_bias,
91 | upcast_attention=upcast_attention,
92 | out_bias=attention_out_bias,
93 | processor=ILSAttnProcessor()
94 | ) # is self-attn if encoder_hidden_states is none
95 | else:
96 | self.norm2 = None
97 | self.attn2 = None
98 |
99 |
100 | def forward(
101 | self,
102 | hidden_states: torch.Tensor,
103 | attention_mask: Optional[torch.Tensor] = None,
104 | encoder_hidden_states: Optional[torch.Tensor] = None,
105 | encoder_attention_mask: Optional[torch.Tensor] = None,
106 | timestep: Optional[torch.LongTensor] = None,
107 | cross_attn_map: Optional[torch.Tensor] = None,
108 | cross_attention_kwargs: Dict[str, Any] = None,
109 | class_labels: Optional[torch.LongTensor] = None,
110 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
111 | ) -> torch.Tensor:
112 | if cross_attention_kwargs is not None:
113 | if cross_attention_kwargs.get("scale", None) is not None:
114 | logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
115 |
116 | # Notice that normalization is always applied before the real computation in the following blocks.
117 | # 0. Self-Attention
118 | batch_size = hidden_states.shape[0]
119 |
120 | if self.norm_type == "ada_norm":
121 | norm_hidden_states = self.norm1(hidden_states, timestep)
122 | elif self.norm_type == "ada_norm_zero":
123 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
124 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
125 | )
126 | elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
127 | norm_hidden_states = self.norm1(hidden_states)
128 | elif self.norm_type == "ada_norm_continuous":
129 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
130 | elif self.norm_type == "ada_norm_single":
131 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
132 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
133 | ).chunk(6, dim=1)
134 | norm_hidden_states = self.norm1(hidden_states)
135 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
136 | norm_hidden_states = norm_hidden_states.squeeze(1)
137 | else:
138 | raise ValueError("Incorrect norm used")
139 |
140 | if self.pos_embed is not None:
141 | norm_hidden_states = self.pos_embed(norm_hidden_states)
142 |
143 | # 1. Prepare GLIGEN inputs
144 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
145 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
146 |
147 | attn_output = self.attn1(
148 | norm_hidden_states,
149 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
150 | attention_mask=attention_mask,
151 | # **cross_attention_kwargs,
152 | )
153 | if self.norm_type == "ada_norm_zero":
154 | attn_output = gate_msa.unsqueeze(1) * attn_output
155 | elif self.norm_type == "ada_norm_single":
156 | attn_output = gate_msa * attn_output
157 |
158 | hidden_states = attn_output + hidden_states
159 | if hidden_states.ndim == 4:
160 | hidden_states = hidden_states.squeeze(1)
161 |
162 | # 1.2 GLIGEN Control
163 | if gligen_kwargs is not None:
164 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
165 |
166 | # 3. Cross-Attention
167 | if self.attn2 is not None:
168 | if self.norm_type == "ada_norm":
169 | norm_hidden_states = self.norm2(hidden_states, timestep)
170 | elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
171 | norm_hidden_states = self.norm2(hidden_states)
172 | elif self.norm_type == "ada_norm_single":
173 | # For PixArt norm2 isn't applied here:
174 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
175 | norm_hidden_states = hidden_states
176 | elif self.norm_type == "ada_norm_continuous":
177 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
178 | else:
179 | raise ValueError("Incorrect norm")
180 |
181 | if self.pos_embed is not None and self.norm_type != "ada_norm_single":
182 | norm_hidden_states = self.pos_embed(norm_hidden_states)
183 |
184 | attn_output, cross_attn_map = self.attn2(
185 | norm_hidden_states,
186 | encoder_hidden_states=encoder_hidden_states,
187 | attention_mask=encoder_attention_mask,
188 | cross_attn_map=cross_attn_map,
189 | **cross_attention_kwargs,
190 | )
191 | hidden_states = attn_output + hidden_states
192 |
193 | # 4. Feed-forward
194 | # i2vgen doesn't have this norm 🤷♂️
195 | if self.norm_type == "ada_norm_continuous":
196 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
197 | elif not self.norm_type == "ada_norm_single":
198 | norm_hidden_states = self.norm3(hidden_states)
199 |
200 | if self.norm_type == "ada_norm_zero":
201 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
202 |
203 | if self.norm_type == "ada_norm_single":
204 | norm_hidden_states = self.norm2(hidden_states)
205 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
206 |
207 | if self._chunk_size is not None:
208 | # "feed_forward_chunk_size" can be used to save memory
209 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
210 | else:
211 | ff_output = self.ff(norm_hidden_states)
212 |
213 | if self.norm_type == "ada_norm_zero":
214 | ff_output = gate_mlp.unsqueeze(1) * ff_output
215 | elif self.norm_type == "ada_norm_single":
216 | ff_output = gate_mlp * ff_output
217 |
218 | hidden_states = ff_output + hidden_states
219 | if hidden_states.ndim == 4:
220 | hidden_states = hidden_states.squeeze(1)
221 |
222 | return hidden_states, cross_attn_map
223 |
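224 | # Note: forward() mirrors diffusers' BasicTransformerBlock, but threads `cross_attn_map` through
225 | # attn2 (which uses ILSAttnProcessor) and returns it, so a map computed in one branch can be fed
226 | # into another block's cross-attention on a later call.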
--------------------------------------------------------------------------------
/tide/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt_process import IDColour
2 | from .mask_process import mask_postprocess
--------------------------------------------------------------------------------
/tide/utils/dataset_palette.py:
--------------------------------------------------------------------------------
1 | USIS10K_CLASS_NAMES = (
2 | "wrecks",
3 | "fish",
4 | "reefs",
5 | "aquatic plants",
6 | "human divers",
7 | "robots",
8 | "sea-floor"
9 | )
10 | USIS10K_COLORS_PALETTE = [[0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 255, 255]]
--------------------------------------------------------------------------------
/tide/utils/mask_process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 | import cv2
5 | import pydensecrf.densecrf as dcrf
6 | from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian
7 |
8 |
9 |
10 | def apply_crf(probability_matrix, num_classes):
11 | h, w = probability_matrix.shape[-2:]
12 |
13 | dcrf_model = dcrf.DenseCRF(h * w, num_classes)
14 |
15 | unary = unary_from_softmax(probability_matrix)
16 | unary = np.ascontiguousarray(unary)
17 | dcrf_model.setUnaryEnergy(unary)
18 |
19 | feats = create_pairwise_gaussian(sdims=(10, 10), shape=(h, w))
20 |
21 | dcrf_model.addPairwiseEnergy(feats, compat=3,
22 | kernel=dcrf.DIAG_KERNEL,
23 | normalization=dcrf.NORMALIZE_SYMMETRIC)
24 |
25 | Q = dcrf_model.inference(5)
26 |
27 | MAP = np.argmax(Q, axis=0).reshape((h, w))
28 | return MAP
29 |
30 |
31 | def colour_and_vis_id_map(id_map, palette, out_path):
32 | palette = np.array(palette)
33 |
34 | id_map = id_map.astype(np.uint8)
35 | id_map -= 1
36 | ids = np.unique(id_map)
37 | valid_ids = np.delete(ids, np.where(ids == 255))
38 |
39 | colour_layout = np.zeros((id_map.shape[0], id_map.shape[1], 3), dtype=np.uint8)
40 | for id in valid_ids:
41 | colour_layout[id_map == id, :] = palette[id].reshape(1, 3)
42 | colour_layout = Image.fromarray(colour_layout)
43 | colour_layout.save(out_path)
44 |
45 |
46 | def compute_distance(matrix_A, matrix_B):
47 | matrix_A = matrix_A[:, np.newaxis, :]
48 |
49 | diff_squared = (matrix_A - matrix_B) ** 2
50 |
51 | distances = np.sqrt(np.sum(diff_squared, axis=2))
52 | return distances
53 |
54 |
55 | def rgb2id(matrix_A, matrix_B, th=None, return_crf_refine=True, t=40):
56 | h, w = matrix_A.shape[:2]
57 | matrix_A = matrix_A.reshape(-1, 3)
58 |
59 | distances = compute_distance(matrix_A, matrix_B)
60 | min_distance_indices = np.argmin(distances, axis=1)
61 |
62 | if th is not None:
63 | min_distances = np.min(distances, axis=1)
64 | min_distance_indices[min_distances > th] = len(matrix_B) - 1
65 |
66 | crf_results = None
67 | if return_crf_refine:
68 | prob = 1 - distances / distances.sum(1)[:, None]
69 | prob = torch.tensor(prob).reshape(h, w, -1).permute(2, 0, 1)
70 | prob = torch.nn.functional.softmax(t * prob, dim=0).numpy()
71 | crf_results = apply_crf(prob, num_classes=len(matrix_B))
72 |     return min_distance_indices.reshape(h, w), crf_results
73 |
74 |
75 | def filter_small_regions(arr, min_area=100, ignore_class=0):
76 | unique_labels = np.unique(arr)
77 | result = arr.copy()
78 |
79 | for label in unique_labels:
80 | if label == ignore_class:
81 | continue
82 | binary_image = np.where(arr == label, 1, 0).astype('uint8')
83 | num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_image, connectivity=8)
84 | for i in range(1, num_labels):
85 | if stats[i, cv2.CC_STAT_AREA] < min_area:
86 | result[labels == i] = ignore_class
87 |
88 | return result
89 |
90 |
91 | def mask_postprocess(mask_image, palette, ignore_class=0):
92 | mask_image = np.array(mask_image)
93 | _, refine_id_map = rgb2id(mask_image, palette)
94 | id_map = filter_small_regions(refine_id_map)
95 | id_map[id_map == 7] = ignore_class # ignore sea-floor
96 | return id_map
--------------------------------------------------------------------------------
/tide/utils/prompt_process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .dataset_palette import (
4 | USIS10K_COLORS_PALETTE,
5 | )
6 |
7 | class IDColour:
8 | def __init__(self, max_id_num=8, color_palette=USIS10K_COLORS_PALETTE):
9 | palette = color_palette
10 | self.palette = np.array(palette)
11 | self.background_id = 255
12 |
13 | def __call__(self, id_map):
14 | id_map = np.array(id_map) - 1
15 | ids = np.unique(id_map)
16 | valid_ids = np.delete(ids, np.where(ids == self.background_id))
17 |
18 | mask_pixel_values = np.zeros((id_map.shape[0], id_map.shape[1], 3), dtype=np.uint8)
19 | for id in valid_ids:
20 | mask_pixel_values[id_map == id, :] = self.palette[id].reshape(1, 3)
21 | return mask_pixel_values
22 |
23 | def generate_rgb_values(self, num_values):
24 | step = (255 ** 3) / (num_values - 1)
25 |
26 | rgb_values = []
27 |
28 | for i in range(num_values):
29 | value_i = round(i * step)
30 | red_i = value_i // (256 ** 2)
31 | green_i = (value_i % (256 ** 2)) // 256
32 | blue_i = value_i % 256
33 | rgb_values.append((red_i, green_i, blue_i))
34 |
35 | return rgb_values
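36 |
37 | # Sketch: IDColour turns a 1-based semantic id map into an RGB mask using the USIS10K palette
38 | # (ids are shifted down by one; the background id 255 is skipped). generate_rgb_values is a
39 | # helper that spreads `num_values` colours roughly evenly over the 24-bit RGB range.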
--------------------------------------------------------------------------------