├── .gitattributes ├── LICENSE ├── README.md ├── asset ├── docs │ └── badge-website.svg └── images │ └── teasor.png ├── datasets └── tide_uwdense.py ├── diffusion ├── __init__.py ├── data │ ├── __init__.py │ ├── builder.py │ ├── datasets │ │ ├── Dreambooth.py │ │ ├── InternalData.py │ │ ├── InternalData_ms.py │ │ ├── SA.py │ │ ├── __init__.py │ │ ├── pixart_control.py │ │ └── utils.py │ └── transforms.py ├── dpm_solver.py ├── iddpm.py ├── lcm_scheduler.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── diffusion_utils.py │ ├── dpm_solver.py │ ├── edm_sample.py │ ├── gaussian_diffusion.py │ ├── hed.py │ ├── llava │ │ ├── __init__.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── nets │ │ ├── PixArt.py │ │ ├── PixArtMS.py │ │ ├── PixArt_blocks.py │ │ ├── __init__.py │ │ └── pixart_controlnet.py │ ├── respace.py │ ├── sa_solver.py │ ├── t5.py │ ├── timestep_sampler.py │ └── utils.py ├── sa_sampler.py ├── sa_solver_diffusers.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── data_sampler.py │ ├── dist_utils.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── misc.py │ └── optimizer.py ├── inference.py ├── requirements.txt └── tide ├── __init__.py ├── config.py ├── pipeline ├── layers.py ├── pipeline_tide.py ├── tide_transformer.py ├── transformer_attentions.py └── transformer_blocks.py ├── train_tide_hf.py └── utils ├── __init__.py ├── dataset_palette.py ├── mask_process.py └── prompt_process.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR 2025] A Unified Image-Dense Annotation Generation Model for Underwater Scenes 2 | 3 | 4 | [![Website](asset/docs/badge-website.svg)](https://hongklin.github.io/TIDE/) 5 | [![arXiv](https://img.shields.io/badge/Arxiv-2503.21771-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2503.21771) 6 | [![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0) 7 | 8 | ## 🌊 **Introduction** 9 | We present TIDE, a unified underwater image-dense annotation generation model. Its core lies in the shared layout information and the natural complementarity between multimodal features. Our model, derived from a pre-trained text-to-image model and fine-tuned with underwater data, enables the generation of highly consistent underwater image-dense annotations from text conditions alone. 10 | 11 | ![TIDE_demo.](asset/images/teasor.png) 12 | --- 13 | ## 🐚 **News** 14 | - 2025-4-8: The training data and the SynTIDE dataset are available [here](https://huggingface.co/datasets/hongk1998/TIDE/tree/main)! 15 | - 2025-3-28: The training and inference code is now available! 16 | - 2025-2-27: Our TIDE is accepted to CVPR 2025! 17 | --- 18 | 19 | ## 🪸 Dependencies and Installation 20 | 21 | - Python >= 3.9 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 22 | - [PyTorch >= 2.0.1+cu11.7](https://pytorch.org/) 23 | ```bash 24 | conda create -n TIDE python=3.9 25 | conda activate TIDE 26 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia 27 | 28 | git clone https://github.com/HongkLin/TIDE 29 | cd TIDE 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ## 🐬 Inference 34 | Download the pre-trained [PixArt-α](https://huggingface.co/PixArt-alpha/PixArt-XL-2-512x512), [MiniTransformer](https://github.com/HongkLin/TIDE/releases/download/model_weights/TIDE_MiniTransformer.zip), and [TIDE checkpoint](https://github.com/HongkLin/TIDE/releases/download/model_weights/TIDE_r32_64_b4_200k.zip), then modify the model weights path accordingly. 35 | ```bash 36 | python inference.py --model_weights_dir ./model_weights --text_prompt "A large school of fish swimming in a circle." --output ./outputs 37 | ``` 38 | 39 | ## 🐢 Training 40 | 41 | ### 🏖️ Training Data Preparation 42 | - Download the [SUIM](https://github.com/xahidbuffon/SUIM), [UIIS](https://github.com/LiamLian0727/WaterMask), and [USIS10K](https://github.com/LiamLian0727/USIS10K) datasets. 43 | - The semantic segmentation annotations are obtained by merging instances with the same semantics. 44 | - The depth annotations are obtained with [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), and the inverse depth results are saved as npy files. 45 | - The image captions and the JSON file organizing the training data can follow [Atlantis](https://github.com/zkawfanx/Atlantis), which is the convention we adopt. 46 | 47 | The final dataset should be organized as follows (a minimal preprocessing sketch is given after the listing): 48 | ``` 49 | datasets/ 50 | UWD_triplets/ 51 | images/ 52 | train_05543.jpg 53 | ... 54 | semseg_annotations/ 55 | train_05543.jpg 56 | ... 57 | depth_annotations/ 58 | train_05543_raw_depth_meter.npy 59 | ... 60 | TrainTIDE_Caption.json 61 | ```
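The following is a minimal, hypothetical sketch of the two annotation steps above (merging instances into a semantic map and storing inverse depth as `.npy`). The helper names and the instance-mask format are assumptions made purely for illustration; they are not utilities shipped with this repository.
```python
# Hypothetical sketch, not part of TIDE: build one semantic mask per image by merging
# instance masks that share a class id, and save the inverse-depth prediction as .npy
# following the file naming shown in the listing above.
import os
import numpy as np
from PIL import Image


def merge_instances_to_semseg(instance_masks, class_ids, height, width):
    """instance_masks: list of HxW boolean arrays; class_ids: one class id per mask."""
    semseg = np.zeros((height, width), dtype=np.uint8)  # 0 is treated as background
    for mask, class_id in zip(instance_masks, class_ids):
        semseg[mask] = class_id  # instances with the same semantics collapse into one region
    return semseg


def save_triplet_annotations(stem, semseg, inverse_depth, root="datasets/UWD_triplets"):
    # e.g. stem = "train_05543"; adjust the semantic-mask extension to match your metadata
    os.makedirs(os.path.join(root, "semseg_annotations"), exist_ok=True)
    os.makedirs(os.path.join(root, "depth_annotations"), exist_ok=True)
    Image.fromarray(semseg).save(os.path.join(root, "semseg_annotations", f"{stem}.png"))
    np.save(os.path.join(root, "depth_annotations", f"{stem}_raw_depth_meter.npy"),
            inverse_depth.astype(np.float32))
```
Note that the module-level constants in `datasets/tide_uwdense.py` point at `UWD_triplets/disparity_depth` and `UWD_triplets/annotations`, so align the folder names with whichever loader you use.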
62 | If you have prepared the training data and environment, you can run the following script to start the training: 63 | ```bash 64 | accelerate launch --num_processes=4 --main_process_port=36666 ./tide/train_tide_hf.py \ 65 | --max_train_steps 200000 --learning_rate=1e-4 --train_batch_size=1 \ 66 | --gradient_accumulation_steps=1 --seed=42 --dataloader_num_workers=4 --validation_steps 10000 \ 67 | --wandb_name=tide_r32_64_b4_200k --output_dir=./outputs/tide_r32_64_b4_200k 68 | ``` 69 | 70 | ## 🤗Acknowledgements 71 | - Thanks to [Diffusers](https://github.com/huggingface/diffusers) for their wonderful technical support and awesome collaboration! 72 | - Thanks to [Hugging Face](https://github.com/huggingface) for sponsoring the nice demo! 73 | - Thanks to [DiT](https://github.com/facebookresearch/DiT) for their wonderful work and codebase! 74 | - Thanks to [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) for their wonderful work and codebase! 75 | 76 | ## 📖BibTeX 77 | ``` 78 | @inproceedings{lin2025tide, 79 | title={A Unified Image-Dense Annotation Generation Model for Underwater Scenes}, 80 | author={Lin, Hongkai and Liang, Dingkang and Qi, Zhenghao and Bai, Xiang}, 81 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 82 | year={2025}, 83 | } 84 | ``` -------------------------------------------------------------------------------- /asset/docs/badge-website.svg: -------------------------------------------------------------------------------- [SVG badge linking to the project website (two segments reading "Project" and "Website"); vector markup omitted.] -------------------------------------------------------------------------------- /asset/images/teasor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HongkLin/TIDE/d0cba53604b3dd9e16e4f81ce24d7f4be3ba0d4e/asset/images/teasor.png -------------------------------------------------------------------------------- /datasets/tide_uwdense.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import datasets 4 | import pandas as pd 5 | import numpy as np 6 | 7 | _VERSION = datasets.Version("0.0.1") 8 | 9 | _DESCRIPTION = "TODO" 10 | _HOMEPAGE = "TODO" 11 | _LICENSE = "TODO" 12 | _CITATION = "TODO" 13 | 14 | _FEATURES = datasets.Features( 15 | { 16 | "text": datasets.Value("string"), 17 | "image": datasets.Image(), 18 | "depth_image": datasets.Value("string"), 19 | "semantic_image": datasets.Image(), 20 | }, 21 | ) 22 | _root = os.getenv("DETECTRON2_DATASETS", "path_to/datasets") 23 | METADATA_PATH = os.path.join(_root, "UWD_triplets/UWDense.json") 24 | IMAGES_DIR = os.path.join(_root, "UWD_triplets/images") 25 | DEPTH_IMAGES_DIR = os.path.join(_root, "UWD_triplets/disparity_depth") 26 | SEMANTIC_IMAGES_DIR = os.path.join(_root, "UWD_triplets/annotations") 27 | 28 | _DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION) 29 | 30 | class Depth2Underwater(datasets.GeneratorBasedBuilder): 31 | BUILDER_CONFIGS = [_DEFAULT_CONFIG] 32 | DEFAULT_CONFIG_NAME = "default" 33 | 34 | def _info(self): 35 | return datasets.DatasetInfo( 36 | description=_DESCRIPTION, 37 | features=_FEATURES, 38 | supervised_keys=None, 39 | homepage=_HOMEPAGE, 40 | license=_LICENSE, 41 | citation=_CITATION, 42 | ) 43 | 44 | def _split_generators(self, dl_manager): 45 |
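        # All paths are resolved from the module-level constants above (DETECTRON2_DATASETS root); only a single TRAIN split is produced.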
metadata_path = METADATA_PATH 46 | images_dir = IMAGES_DIR 47 | depth_images_dir = DEPTH_IMAGES_DIR 48 | semantic_images_dir = SEMANTIC_IMAGES_DIR 49 | return [ 50 | datasets.SplitGenerator( 51 | name=datasets.Split.TRAIN, 52 | # These kwargs will be passed to _generate_examples 53 | gen_kwargs={ 54 | "metadata_path": metadata_path, 55 | "images_dir": images_dir, 56 | "depth_images_dir": depth_images_dir, 57 | "semantic_images_dir": semantic_images_dir, 58 | }, 59 | ), 60 | ] 61 | 62 | def _generate_examples(self, metadata_path, images_dir, depth_images_dir, semantic_images_dir): 63 | metadata = pd.read_json(metadata_path, lines=True) 64 | 65 | for _, row in metadata.iterrows(): 66 | text = row["text"] 67 | 68 | image_path = row["image"] 69 | image_path = os.path.join(images_dir, image_path) 70 | image = open(image_path, "rb").read() 71 | 72 | if '.jpg' in row["conditioning_image"]: 73 | depth_image_path = row["conditioning_image"].replace('.jpg', '_raw_depth_meter.npy') 74 | semantic_image_path = row["conditioning_image"].replace('.jpg', '.png') 75 | else: 76 | assert '.png' in row["conditioning_image"] 77 | depth_image_path = row["conditioning_image"].replace('.png', '_raw_depth_meter.npy') 78 | semantic_image_path = row["conditioning_image"] 79 | depth_image_path = os.path.join( 80 | depth_images_dir, depth_image_path 81 | ) 82 | semantic_image_path = os.path.join( 83 | semantic_images_dir, semantic_image_path 84 | ) 85 | semantic_image = open(semantic_image_path, "rb").read() 86 | 87 | yield row["image"], { 88 | "text": text, 89 | "image": { 90 | "path": image_path, 91 | "bytes": image, 92 | }, 93 | "depth_image": depth_image_path, 94 | "semantic_image": semantic_image 95 | } -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from .iddpm import IDDPM 7 | from .dpm_solver import DPMS 8 | from .sa_sampler import SASolverSampler 9 | -------------------------------------------------------------------------------- /diffusion/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .transforms import get_transform 3 | -------------------------------------------------------------------------------- /diffusion/data/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from mmcv import Registry, build_from_cfg 5 | from torch.utils.data import DataLoader 6 | 7 | from diffusion.data.transforms import get_transform 8 | from diffusion.utils.logger import get_root_logger 9 | 10 | DATASETS = Registry('datasets') 11 | 12 | DATA_ROOT = '/cache/data' 13 | 14 | 15 | def set_data_root(data_root): 16 | global DATA_ROOT 17 | DATA_ROOT = data_root 18 | 19 | 20 | def get_data_path(data_dir): 21 | if os.path.isabs(data_dir): 22 | return data_dir 23 | global DATA_ROOT 24 | return os.path.join(DATA_ROOT, data_dir) 25 | 26 | 27 | def build_dataset(cfg, resolution=224, **kwargs): 28 | logger = get_root_logger() 29 | 30 | dataset_type = cfg.get('type') 31 | logger.info(f"Constructing dataset 
{dataset_type}...") 32 | t = time.time() 33 | transform = cfg.pop('transform', 'default_train') 34 | transform = get_transform(transform, resolution) 35 | dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs)) 36 | logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}") 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs): 41 | return ( 42 | DataLoader( 43 | dataset, 44 | batch_sampler=kwargs['batch_sampler'], 45 | num_workers=num_workers, 46 | pin_memory=True, 47 | ) 48 | if 'batch_sampler' in kwargs 49 | else DataLoader( 50 | dataset, 51 | batch_size=batch_size, 52 | shuffle=shuffle, 53 | num_workers=num_workers, 54 | pin_memory=True, 55 | **kwargs 56 | ) 57 | ) 58 | -------------------------------------------------------------------------------- /diffusion/data/datasets/Dreambooth.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 5 | from torch.utils.data import Dataset 6 | from diffusers.utils.torch_utils import randn_tensor 7 | from torchvision import transforms as T 8 | import pathlib 9 | from diffusers.models import AutoencoderKL 10 | 11 | from diffusion.data.builder import get_data_path, DATASETS 12 | from diffusion.data.datasets.utils import * 13 | 14 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp', 'JPEG'} 15 | 16 | 17 | @DATASETS.register_module() 18 | class DreamBooth(Dataset): 19 | def __init__(self, 20 | root, 21 | transform=None, 22 | resolution=1024, 23 | **kwargs): 24 | self.root = get_data_path(root) 25 | path = pathlib.Path(self.root) 26 | self.transform = transform 27 | self.resolution = resolution 28 | self.img_samples = sorted( 29 | [file for ext in IMAGE_EXTENSIONS for file in path.glob(f'*.{ext}')] 30 | ) 31 | self.ori_imgs_nums = len(self) 32 | self.loader = default_loader 33 | self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1]) 34 | self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio 35 | self.ratio_nums = {} 36 | for k, v in self.aspect_ratio.items(): 37 | self.ratio_nums[float(k)] = 0 # used for batch-sampler 38 | self.data_info = {'img_hw': torch.tensor([resolution, resolution], dtype=torch.float32), 'aspect_ratio': 1.} 39 | 40 | # image related 41 | with torch.inference_mode(): 42 | vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema") 43 | imgs = [] 44 | for img_path in self.img_samples: 45 | img = self.loader(img_path) 46 | self.ratio_nums[1.0] += 1 47 | if self.transform is not None: 48 | imgs.append(self.transform(img)) 49 | imgs = torch.stack(imgs, dim=0) 50 | self.img_vae = vae.encode(imgs).latent_dist.sample() 51 | del vae 52 | 53 | def __getitem__(self, index): 54 | return self.img_vae[index], self.data_info 55 | 56 | @staticmethod 57 | def vae_feat_loader(path): 58 | # [mean, std] 59 | mean, std = torch.from_numpy(np.load(path)).chunk(2) 60 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 61 | return mean + std * sample 62 | 63 | def load_ori_img(self, img_path): 64 | # 加载图像并转换为Tensor 65 | transform = T.Compose([ 66 | T.Resize(256), # Image.BICUBIC 67 | T.CenterCrop(256), 68 | T.ToTensor(), 69 | ]) 70 | return 
transform(Image.open(img_path)) 71 | 72 | def __len__(self): 73 | return len(self.img_samples) 74 | 75 | def __getattr__(self, name): 76 | if name == "set_epoch": 77 | return lambda epoch: None 78 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") 79 | 80 | def get_data_info(self, idx): 81 | return {'height': self.resolution, 'width': self.resolution} 82 | -------------------------------------------------------------------------------- /diffusion/data/datasets/InternalData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 7 | from torch.utils.data import Dataset 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from torchvision import transforms as T 10 | from diffusion.data.builder import get_data_path, DATASETS 11 | from diffusion.utils.logger import get_root_logger 12 | 13 | import json 14 | 15 | 16 | @DATASETS.register_module() 17 | class InternalData(Dataset): 18 | def __init__(self, 19 | root, 20 | image_list_json='data_info.json', 21 | transform=None, 22 | resolution=256, 23 | sample_subset=None, 24 | load_vae_feat=False, 25 | input_size=32, 26 | patch_size=2, 27 | mask_ratio=0.0, 28 | load_mask_index=False, 29 | max_length=120, 30 | config=None, 31 | **kwargs): 32 | self.root = get_data_path(root) 33 | self.transform = transform 34 | self.load_vae_feat = load_vae_feat 35 | self.ori_imgs_nums = 0 36 | self.resolution = resolution 37 | self.N = int(resolution // (input_size // patch_size)) 38 | self.mask_ratio = mask_ratio 39 | self.load_mask_index = load_mask_index 40 | self.max_lenth = max_length 41 | self.meta_data_clean = [] 42 | self.img_samples = [] 43 | self.txt_feat_samples = [] 44 | self.vae_feat_samples = [] 45 | self.mask_index_samples = [] 46 | self.prompt_samples = [] 47 | 48 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] 49 | for json_file in image_list_json: 50 | meta_data = self.load_json(os.path.join(self.root, 'partition', json_file)) 51 | self.ori_imgs_nums += len(meta_data) 52 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4] 53 | self.meta_data_clean.extend(meta_data_clean) 54 | self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean]) 55 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean]) 56 | self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_features_{resolution}resolution/noflip', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean]) 57 | self.prompt_samples.extend([item['prompt'] for item in meta_data_clean]) 58 | 59 | # Set loader and extensions 60 | if load_vae_feat: 61 | self.transform = None 62 | self.loader = self.vae_feat_loader 63 | else: 64 | self.loader = default_loader 65 | 66 | if sample_subset is not None: 67 | self.sample_subset(sample_subset) # sample dataset for local debug 68 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) 69 | logger.info(f"T5 max token length: {self.max_lenth}") 70 | 71 | def getdata(self, index): 72 | img_path = self.img_samples[index] 73 | npz_path = self.txt_feat_samples[index] 74 | npy_path = 
self.vae_feat_samples[index] 75 | prompt = self.prompt_samples[index] 76 | data_info = { 77 | 'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32), 78 | 'aspect_ratio': torch.tensor(1.) 79 | } 80 | 81 | img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path) 82 | txt_info = np.load(npz_path) 83 | txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096 84 | attention_mask = torch.ones(1, 1, txt_fea.shape[1]) # 1x1xT 85 | if 'attention_mask' in txt_info.keys(): 86 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] 87 | if txt_fea.shape[1] != self.max_lenth: 88 | txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1) 89 | attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1) 90 | 91 | if self.transform: 92 | img = self.transform(img) 93 | 94 | data_info['prompt'] = prompt 95 | return img, txt_fea, attention_mask, data_info 96 | 97 | def __getitem__(self, idx): 98 | for _ in range(20): 99 | try: 100 | return self.getdata(idx) 101 | except Exception as e: 102 | print(f"Error details: {str(e)}") 103 | idx = np.random.randint(len(self)) 104 | raise RuntimeError('Too many bad data.') 105 | 106 | def get_data_info(self, idx): 107 | data_info = self.meta_data_clean[idx] 108 | return {'height': data_info['height'], 'width': data_info['width']} 109 | 110 | @staticmethod 111 | def vae_feat_loader(path): 112 | # [mean, std] 113 | mean, std = torch.from_numpy(np.load(path)).chunk(2) 114 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 115 | return mean + std * sample 116 | 117 | def load_ori_img(self, img_path): 118 | # 加载图像并转换为Tensor 119 | transform = T.Compose([ 120 | T.Resize(256), # Image.BICUBIC 121 | T.CenterCrop(256), 122 | T.ToTensor(), 123 | ]) 124 | return transform(Image.open(img_path)) 125 | 126 | def load_json(self, file_path): 127 | with open(file_path, 'r') as f: 128 | meta_data = json.load(f) 129 | 130 | return meta_data 131 | 132 | def sample_subset(self, ratio): 133 | sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio)) 134 | self.img_samples = [self.img_samples[i] for i in sampled_idx] 135 | 136 | def __len__(self): 137 | return len(self.img_samples) 138 | 139 | def __getattr__(self, name): 140 | if name == "set_epoch": 141 | return lambda epoch: None 142 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") 143 | 144 | -------------------------------------------------------------------------------- /diffusion/data/datasets/InternalData_ms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import random 5 | from torchvision.datasets.folder import default_loader 6 | from diffusion.data.datasets.InternalData import InternalData 7 | from diffusion.data.builder import get_data_path, DATASETS 8 | from diffusion.utils.logger import get_root_logger 9 | import torchvision.transforms as T 10 | from torchvision.transforms.functional import InterpolationMode 11 | from diffusion.data.datasets.utils import * 12 | 13 | def get_closest_ratio(height: float, width: float, ratios: dict): 14 | aspect_ratio = height / width 15 | closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) 16 | return ratios[closest_ratio], float(closest_ratio) 17 | 18 | 19 | @DATASETS.register_module() 20 | 
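# Multi-scale variant of InternalData: each sample is bucketed to the closest entry of the
# ASPECT_RATIO_* tables (diffusion/data/datasets/utils.py), so a batch shares one target resolution.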
class InternalDataMS(InternalData): 21 | def __init__(self, 22 | root, 23 | image_list_json='data_info.json', 24 | transform=None, 25 | resolution=256, 26 | sample_subset=None, 27 | load_vae_feat=False, 28 | input_size=32, 29 | patch_size=2, 30 | mask_ratio=0.0, 31 | mask_type='null', 32 | load_mask_index=False, 33 | max_length=120, 34 | config=None, 35 | **kwargs): 36 | self.root = get_data_path(root) 37 | self.transform = transform 38 | self.load_vae_feat = load_vae_feat 39 | self.ori_imgs_nums = 0 40 | self.resolution = resolution 41 | self.N = int(resolution // (input_size // patch_size)) 42 | self.mask_ratio = mask_ratio 43 | self.load_mask_index = load_mask_index 44 | self.mask_type = mask_type 45 | self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1]) 46 | self.max_lenth = max_length 47 | self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio 48 | self.meta_data_clean = [] 49 | self.img_samples = [] 50 | self.txt_feat_samples = [] 51 | self.vae_feat_samples = [] 52 | self.mask_index_samples = [] 53 | self.ratio_index = {} 54 | self.ratio_nums = {} 55 | for k, v in self.aspect_ratio.items(): 56 | self.ratio_index[float(k)] = [] # used for self.getitem 57 | self.ratio_nums[float(k)] = 0 # used for batch-sampler 58 | 59 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] 60 | for json_file in image_list_json: 61 | meta_data = self.load_json(os.path.join(self.root, 'partition_filter', json_file)) 62 | self.ori_imgs_nums += len(meta_data) 63 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4] 64 | self.meta_data_clean.extend(meta_data_clean) 65 | self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean]) 66 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean]) 67 | self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean]) 68 | 69 | # Set loader and extensions 70 | if load_vae_feat: 71 | self.transform = None 72 | self.loader = self.vae_feat_loader 73 | else: 74 | self.loader = default_loader 75 | 76 | if sample_subset is not None: 77 | self.sample_subset(sample_subset) # sample dataset for local debug 78 | 79 | # scan the dataset for ratio static 80 | for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]): 81 | ori_h, ori_w = info['height'], info['width'] 82 | closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) 83 | self.ratio_nums[closest_ratio] += 1 84 | if len(self.ratio_index[closest_ratio]) == 0: 85 | self.ratio_index[closest_ratio].append(i) 86 | # print(self.ratio_nums) 87 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) 88 | logger.info(f"T5 max token length: {self.max_lenth}") 89 | 90 | def getdata(self, index): 91 | img_path = self.img_samples[index] 92 | npz_path = self.txt_feat_samples[index] 93 | npy_path = self.vae_feat_samples[index] 94 | ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width'] 95 | 96 | # Calculate the closest aspect ratio and resize & crop image[w, h] 97 | closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) 98 | closest_size = list(map(lambda x: int(x), 
closest_size)) 99 | self.closest_ratio = closest_ratio 100 | 101 | if self.load_vae_feat: 102 | try: 103 | img = self.loader(npy_path) 104 | if index not in self.ratio_index[closest_ratio]: 105 | self.ratio_index[closest_ratio].append(index) 106 | except Exception: 107 | index = random.choice(self.ratio_index[closest_ratio]) 108 | return self.getdata(index) 109 | h, w = (img.shape[1], img.shape[2]) 110 | assert h, w == (ori_h//8, ori_w//8) 111 | else: 112 | img = self.loader(img_path) 113 | h, w = (img.size[1], img.size[0]) 114 | assert h, w == (ori_h, ori_w) 115 | 116 | data_info = {'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32)} 117 | data_info['aspect_ratio'] = closest_ratio 118 | data_info["mask_type"] = self.mask_type 119 | 120 | txt_info = np.load(npz_path) 121 | txt_fea = torch.from_numpy(txt_info['caption_feature']) 122 | attention_mask = torch.ones(1, 1, txt_fea.shape[1]) 123 | if 'attention_mask' in txt_info.keys(): 124 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] 125 | 126 | if not self.load_vae_feat: 127 | if closest_size[0] / ori_h > closest_size[1] / ori_w: 128 | resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h) 129 | else: 130 | resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1] 131 | self.transform = T.Compose([ 132 | T.Lambda(lambda img: img.convert('RGB')), 133 | T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC 134 | T.CenterCrop(closest_size), 135 | T.ToTensor(), 136 | T.Normalize([.5], [.5]), 137 | ]) 138 | 139 | if self.transform: 140 | img = self.transform(img) 141 | 142 | return img, txt_fea, attention_mask, data_info 143 | 144 | def __getitem__(self, idx): 145 | for _ in range(20): 146 | try: 147 | return self.getdata(idx) 148 | except Exception as e: 149 | print(f"Error details: {str(e)}") 150 | idx = random.choice(self.ratio_index[self.closest_ratio]) 151 | raise RuntimeError('Too many bad data.') 152 | -------------------------------------------------------------------------------- /diffusion/data/datasets/SA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 8 | from torch.utils.data import Dataset 9 | from diffusers.utils.torch_utils import randn_tensor 10 | 11 | from diffusion.data.builder import get_data_path, DATASETS 12 | 13 | 14 | @DATASETS.register_module() 15 | class SAM(Dataset): 16 | def __init__(self, 17 | root, 18 | image_list_txt='part0.txt', 19 | transform=None, 20 | resolution=256, 21 | sample_subset=None, 22 | load_vae_feat=False, 23 | mask_ratio=0.0, 24 | mask_type='null', 25 | **kwargs): 26 | self.root = get_data_path(root) 27 | self.transform = transform 28 | self.load_vae_feat = load_vae_feat 29 | self.mask_type = mask_type 30 | self.mask_ratio = mask_ratio 31 | self.resolution = resolution 32 | self.img_samples = [] 33 | self.txt_feat_samples = [] 34 | self.vae_feat_samples = [] 35 | image_list_txt = image_list_txt if isinstance(image_list_txt, list) else [image_list_txt] 36 | if image_list_txt == 'all': 37 | image_list_txts = os.listdir(os.path.join(self.root, 'partition')) 38 | for txt in image_list_txts: 39 | image_list = os.path.join(self.root, 'partition', txt) 40 | with open(image_list, 'r') as f: 41 | lines = [line.strip() for line in f.readlines()] 42 | self.img_samples.extend([os.path.join(self.root, 'images', i+'.jpg') for i in 
lines]) 43 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', i+'.npz') for i in lines]) 44 | elif isinstance(image_list_txt, list): 45 | for txt in image_list_txt: 46 | image_list = os.path.join(self.root, 'partition', txt) 47 | with open(image_list, 'r') as f: 48 | lines = [line.strip() for line in f.readlines()] 49 | self.img_samples.extend([os.path.join(self.root, 'images', i + '.jpg') for i in lines]) 50 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', i + '.npz') for i in lines]) 51 | self.vae_feat_samples.extend([os.path.join(self.root, 'img_vae_feature/train_vae_256/noflip', i + '.npy') for i in lines]) 52 | 53 | self.ori_imgs_nums = len(self) 54 | # self.img_samples = self.img_samples[:10000] 55 | # Set loader and extensions 56 | if load_vae_feat: 57 | self.transform = None 58 | self.loader = self.vae_feat_loader 59 | else: 60 | self.loader = default_loader 61 | 62 | if sample_subset is not None: 63 | self.sample_subset(sample_subset) # sample dataset for local debug 64 | 65 | def getdata(self, idx): 66 | img_path = self.img_samples[idx] 67 | npz_path = self.txt_feat_samples[idx] 68 | npy_path = self.vae_feat_samples[idx] 69 | data_info = {'img_hw': torch.tensor([self.resolution, self.resolution], dtype=torch.float32), 70 | 'aspect_ratio': torch.tensor(1.)} 71 | 72 | img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path) 73 | npz_info = np.load(npz_path) 74 | txt_fea = torch.from_numpy(npz_info['caption_feature']) 75 | attention_mask = torch.ones(1, 1, txt_fea.shape[1]) 76 | if 'attention_mask' in npz_info.keys(): 77 | attention_mask = torch.from_numpy(npz_info['attention_mask'])[None] 78 | 79 | if self.transform: 80 | img = self.transform(img) 81 | 82 | data_info["mask_type"] = self.mask_type 83 | 84 | return img, txt_fea, attention_mask, data_info 85 | 86 | def __getitem__(self, idx): 87 | for _ in range(20): 88 | try: 89 | return self.getdata(idx) 90 | except Exception: 91 | print(self.img_samples[idx], ' info is not correct') 92 | idx = np.random.randint(len(self)) 93 | raise RuntimeError('Too many bad data.') 94 | 95 | @staticmethod 96 | def vae_feat_loader(path): 97 | # [mean, std] 98 | mean, std = torch.from_numpy(np.load(path)).chunk(2) 99 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 100 | return mean + std * sample 101 | # return mean 102 | 103 | def sample_subset(self, ratio): 104 | sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio)) 105 | self.img_samples = [self.img_samples[i] for i in sampled_idx] 106 | self.txt_feat_samples = [self.txt_feat_samples[i] for i in sampled_idx] 107 | 108 | def __len__(self): 109 | return len(self.img_samples) 110 | -------------------------------------------------------------------------------- /diffusion/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .SA import SAM 2 | from .InternalData import InternalData 3 | from .InternalData_ms import InternalDataMS 4 | from .Dreambooth import DreamBooth 5 | from .pixart_control import InternalDataHed 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /diffusion/data/datasets/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ASPECT_RATIO_1024 = { 4 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.], 5 | '0.32': [576., 1792.], 
'0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 6 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 7 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 8 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 9 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 10 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 11 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 12 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.], 13 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.], 14 | } 15 | 16 | ASPECT_RATIO_512 = { 17 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], 18 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 19 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 20 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 21 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 22 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 23 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 24 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 25 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], 26 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] 27 | } 28 | 29 | ASPECT_RATIO_256 = { 30 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0], 31 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], 32 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], 33 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], 34 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], 35 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], 36 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], 37 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], 38 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0], 39 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0] 40 | } 41 | 42 | ASPECT_RATIO_256_TEST = { 43 | '0.25': [128.0, 512.0], '0.28': [128.0, 464.0], 44 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], 45 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], 46 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], 47 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], 48 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], 49 | '1.29': [288.0, 224.0], '1.38': [288.0, 
208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], 50 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], 51 | '2.5': [400.0, 160.0], '3.0': [432.0, 144.0], 52 | '4.0': [512.0, 128.0] 53 | } 54 | 55 | ASPECT_RATIO_512_TEST = { 56 | '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0], 57 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 58 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 59 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 60 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 61 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 62 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 63 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 64 | '2.5': [800.0, 320.0], '3.0': [864.0, 288.0], 65 | '4.0': [1024.0, 256.0] 66 | } 67 | 68 | ASPECT_RATIO_1024_TEST = { 69 | '0.25': [512., 2048.], '0.28': [512., 1856.], 70 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 71 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 72 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 73 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 74 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 75 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 76 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 77 | '2.5': [1600., 640.], '3.0': [1728., 576.], 78 | '4.0': [2048., 512.], 79 | } 80 | 81 | 82 | def get_chunks(lst, n): 83 | for i in range(0, len(lst), n): 84 | yield lst[i:i + n] 85 | -------------------------------------------------------------------------------- /diffusion/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | 3 | TRANSFORMS = {} 4 | 5 | 6 | def register_transform(transform): 7 | name = transform.__name__ 8 | if name in TRANSFORMS: 9 | raise RuntimeError(f'Transform {name} has already registered.') 10 | TRANSFORMS.update({name: transform}) 11 | 12 | 13 | def get_transform(type, resolution): 14 | transform = TRANSFORMS[type](resolution) 15 | transform = T.Compose(transform) 16 | transform.image_size = resolution 17 | return transform 18 | 19 | 20 | @register_transform 21 | def default_train(n_px): 22 | return [ 23 | T.Lambda(lambda img: img.convert('RGB')), 24 | T.Resize(n_px), # Image.BICUBIC 25 | T.CenterCrop(n_px), 26 | # T.RandomHorizontalFlip(), 27 | T.ToTensor(), 28 | T.Normalize([0.5], [0.5]), 29 | ] 30 | -------------------------------------------------------------------------------- /diffusion/dpm_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .model import gaussian_diffusion as gd 3 | from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP 4 | 5 | 6 | def DPMS(model, condition, uncondition, cfg_scale, model_type='noise', noise_schedule="linear", guidance_type='classifier-free', model_kwargs=None, diffusion_steps=1000): 7 | if model_kwargs is 
None: 8 | model_kwargs = {} 9 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) 10 | 11 | ## 1. Define the noise schedule. 12 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas) 13 | 14 | ## 2. Convert your discrete-time `model` to the continuous-time 15 | ## noise prediction model. Here is an example for a diffusion model 16 | ## `model` with the noise prediction type ("noise") . 17 | model_fn = model_wrapper( 18 | model, 19 | noise_schedule, 20 | model_type=model_type, 21 | model_kwargs=model_kwargs, 22 | guidance_type=guidance_type, 23 | condition=condition, 24 | unconditional_condition=uncondition, 25 | guidance_scale=cfg_scale, 26 | ) 27 | ## 3. Define dpm-solver and sample by multistep DPM-Solver. 28 | return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") -------------------------------------------------------------------------------- /diffusion/iddpm.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | from diffusion.model.respace import SpacedDiffusion, space_timesteps 6 | from .model import gaussian_diffusion as gd 7 | 8 | 9 | def IDDPM( 10 | timestep_respacing, 11 | noise_schedule="linear", 12 | use_kl=False, 13 | sigma_small=False, 14 | predict_xstart=False, 15 | learn_sigma=True, 16 | pred_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000, 19 | snr=False, 20 | return_startx=False, 21 | ): 22 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 23 | if use_kl: 24 | loss_type = gd.LossType.RESCALED_KL 25 | elif rescale_learned_sigmas: 26 | loss_type = gd.LossType.RESCALED_MSE 27 | else: 28 | loss_type = gd.LossType.MSE 29 | if timestep_respacing is None or timestep_respacing == "": 30 | timestep_respacing = [diffusion_steps] 31 | return SpacedDiffusion( 32 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 33 | betas=betas, 34 | model_mean_type=( 35 | gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON 36 | ), 37 | model_var_type=( 38 | (gd.ModelVarType.LEARNED_RANGE if learn_sigma else ( 39 | gd.ModelVarType.FIXED_LARGE 40 | if not sigma_small 41 | else gd.ModelVarType.FIXED_SMALL 42 | ) 43 | ) 44 | if pred_sigma 45 | else None 46 | ), 47 | loss_type=loss_type, 48 | snr=snr, 49 | return_startx=return_startx, 50 | # rescale_timesteps=rescale_timesteps, 51 | ) -------------------------------------------------------------------------------- /diffusion/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .nets import * 2 | -------------------------------------------------------------------------------- /diffusion/model/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv import Registry 2 | 3 | from diffusion.model.utils import set_grad_checkpoint 4 | 5 | MODELS = Registry('models') 6 | 7 | 8 | def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs): 9 | if isinstance(cfg, str): 10 | cfg = dict(type=cfg) 11 | model = MODELS.build(cfg, default_args=kwargs) 12 | if use_grad_checkpoint: 13 | set_grad_checkpoint(model, 
use_fp32_attention=use_fp32_attention, gc_step=gc_step) 14 | return model 15 | -------------------------------------------------------------------------------- /diffusion/model/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = next( 17 | ( 18 | obj 19 | for obj in (mean1, logvar1, mean2, logvar2) 20 | if isinstance(obj, th.Tensor) 21 | ), 22 | None, 23 | ) 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a continuous Gaussian distribution. 53 | :param x: the targets 54 | :param means: the Gaussian mean Tensor. 55 | :param log_scales: the Gaussian log stddev Tensor. 56 | :return: a tensor like x of log probabilities (in nats). 57 | """ 58 | centered_x = x - means 59 | inv_stdv = th.exp(-log_scales) 60 | normalized_x = centered_x * inv_stdv 61 | return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob( 62 | normalized_x 63 | ) 64 | 65 | 66 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 67 | """ 68 | Compute the log-likelihood of a Gaussian distribution discretizing to a 69 | given image. 70 | :param x: the target images. It is assumed that this was uint8 values, 71 | rescaled to the range [-1, 1]. 72 | :param means: the Gaussian mean Tensor. 73 | :param log_scales: the Gaussian log stddev Tensor. 74 | :return: a tensor like x of log probabilities (in nats). 
75 | """ 76 | assert x.shape == means.shape == log_scales.shape 77 | centered_x = x - means 78 | inv_stdv = th.exp(-log_scales) 79 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 80 | cdf_plus = approx_standard_normal_cdf(plus_in) 81 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 82 | cdf_min = approx_standard_normal_cdf(min_in) 83 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 84 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 85 | cdf_delta = cdf_plus - cdf_min 86 | log_probs = th.where( 87 | x < -0.999, 88 | log_cdf_plus, 89 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 90 | ) 91 | assert log_probs.shape == x.shape 92 | return log_probs 93 | -------------------------------------------------------------------------------- /diffusion/model/edm_sample.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | from diffusion.model.utils import * 6 | 7 | 8 | # ---------------------------------------------------------------------------- 9 | # Proposed EDM sampler (Algorithm 2). 10 | 11 | def edm_sampler( 12 | net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like, 13 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 14 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs 15 | ): 16 | # Adjust noise levels based on what's supported by the network. 17 | sigma_min = max(sigma_min, net.sigma_min) 18 | sigma_max = min(sigma_max, net.sigma_max) 19 | 20 | # Time step discretization. 21 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 22 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( 23 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 24 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 25 | 26 | # Main sampling loop. 27 | x_next = latents.to(torch.float64) * t_steps[0] 28 | for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1 29 | x_cur = x_next 30 | 31 | # Increase noise temporarily. 32 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 33 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 34 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 35 | 36 | # Euler step. 37 | denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) 38 | d_cur = (x_hat - denoised) / t_hat 39 | x_next = x_hat + (t_next - t_hat) * d_cur 40 | 41 | # Apply 2nd order correction. 42 | if i < num_steps - 1: 43 | denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) 44 | d_prime = (x_next - denoised) / t_next 45 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 46 | 47 | return x_next 48 | 49 | 50 | # ---------------------------------------------------------------------------- 51 | # Generalized ablation sampler, representing the superset of all sampling 52 | # methods discussed in the paper. 
53 | 54 | def ablation_sampler( 55 | net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like, 56 | num_steps=18, sigma_min=None, sigma_max=None, rho=7, 57 | solver='heun', discretization='edm', schedule='linear', scaling='none', 58 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, 59 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 60 | ): 61 | assert solver in ['euler', 'heun'] 62 | assert discretization in ['vp', 've', 'iddpm', 'edm'] 63 | assert schedule in ['vp', 've', 'linear'] 64 | assert scaling in ['vp', 'none'] 65 | 66 | # Helper functions for VP & VE noise level schedules. 67 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 68 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) 69 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * ( 70 | sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d 71 | ve_sigma = lambda t: t.sqrt() 72 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt() 73 | ve_sigma_inv = lambda sigma: sigma ** 2 74 | 75 | # Select default noise level range based on the specified time step discretization. 76 | if sigma_min is None: 77 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) 78 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] 79 | if sigma_max is None: 80 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) 81 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] 82 | 83 | # Adjust noise levels based on what's supported by the network. 84 | sigma_min = max(sigma_min, net.sigma_min) 85 | sigma_max = min(sigma_max, net.sigma_max) 86 | 87 | # Compute corresponding betas for VP. 88 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) 89 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d 90 | 91 | # Define time steps in terms of noise level. 92 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 93 | if discretization == 'vp': 94 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) 95 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) 96 | elif discretization == 've': 97 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) 98 | sigma_steps = ve_sigma(orig_t_steps) 99 | elif discretization == 'iddpm': 100 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) 101 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 102 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 103 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() 104 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] 105 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] 106 | else: 107 | assert discretization == 'edm' 108 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( 109 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 110 | 111 | # Define noise level schedule. 
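# (For reference: 'vp' corresponds to sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1),
#  've' to sigma(t) = sqrt(t), and 'linear' to the identity sigma(t) = t; the matching
#  sigma_deriv / sigma_inv helpers are the lambdas defined at the top of this function.)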
112 | if schedule == 'vp': 113 | sigma = vp_sigma(vp_beta_d, vp_beta_min) 114 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) 115 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) 116 | elif schedule == 've': 117 | sigma = ve_sigma 118 | sigma_deriv = ve_sigma_deriv 119 | sigma_inv = ve_sigma_inv 120 | else: 121 | assert schedule == 'linear' 122 | sigma = lambda t: t 123 | sigma_deriv = lambda t: 1 124 | sigma_inv = lambda sigma: sigma 125 | 126 | # Define scaling schedule. 127 | if scaling == 'vp': 128 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() 129 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) 130 | else: 131 | assert scaling == 'none' 132 | s = lambda t: 1 133 | s_deriv = lambda t: 0 134 | 135 | # Compute final time steps based on the corresponding noise levels. 136 | t_steps = sigma_inv(net.round_sigma(sigma_steps)) 137 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 138 | 139 | # Main sampling loop. 140 | t_next = t_steps[0] 141 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) 142 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 143 | x_cur = x_next 144 | 145 | # Increase noise temporarily. 146 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 147 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) 148 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s( 149 | t_hat) * S_noise * randn_like(x_cur) 150 | 151 | # Euler step. 152 | h = t_next - t_hat 153 | denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to( 154 | torch.float64) 155 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s( 156 | t_hat) / sigma(t_hat) * denoised 157 | x_prime = x_hat + alpha * h * d_cur 158 | t_prime = t_hat + alpha * h 159 | 160 | # Apply 2nd order correction. 161 | if solver == 'euler' or i == num_steps - 1: 162 | x_next = x_hat + h * d_cur 163 | else: 164 | assert solver == 'heun' 165 | denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to( 166 | torch.float64) 167 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv( 168 | t_prime) * s(t_prime) / sigma(t_prime) * denoised 169 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) 170 | 171 | return x_next 172 | -------------------------------------------------------------------------------- /diffusion/model/hed.py: -------------------------------------------------------------------------------- 1 | # This is an improved version and model of HED edge detection with Apache License, Version 2.0. 2 | # Please use this implementation in your products 3 | # This implementation may produce slightly different results from Saining Xie's official implementations, 4 | # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. 
5 | # Different from official models and other implementations, this is an RGB-input model (rather than BGR) 6 | # and in this way it works better for gradio's RGB protocol 7 | import sys 8 | from pathlib import Path 9 | current_file_path = Path(__file__).resolve() 10 | sys.path.insert(0, str(current_file_path.parent.parent.parent)) 11 | from torch import nn 12 | import torch 13 | import numpy as np 14 | from torchvision import transforms as T 15 | from tqdm import tqdm 16 | from torch.utils.data import Dataset, DataLoader 17 | import json 18 | from PIL import Image 19 | import torchvision.transforms.functional as TF 20 | from accelerate import Accelerator 21 | from diffusers.models import AutoencoderKL 22 | import os 23 | 24 | image_resize = 1024 25 | 26 | 27 | class DoubleConvBlock(nn.Module): 28 | def __init__(self, input_channel, output_channel, layer_number): 29 | super().__init__() 30 | self.convs = torch.nn.Sequential() 31 | self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) 32 | for i in range(1, layer_number): 33 | self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) 34 | self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) 35 | 36 | def forward(self, x, down_sampling=False): 37 | h = x 38 | if down_sampling: 39 | h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) 40 | for conv in self.convs: 41 | h = conv(h) 42 | h = torch.nn.functional.relu(h) 43 | return h, self.projection(h) 44 | 45 | 46 | class ControlNetHED_Apache2(nn.Module): 47 | def __init__(self): 48 | super().__init__() 49 | self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) 50 | self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) 51 | self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) 52 | self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) 53 | self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) 54 | self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) 55 | 56 | def forward(self, x): 57 | h = x - self.norm 58 | h, projection1 = self.block1(h) 59 | h, projection2 = self.block2(h, down_sampling=True) 60 | h, projection3 = self.block3(h, down_sampling=True) 61 | h, projection4 = self.block4(h, down_sampling=True) 62 | h, projection5 = self.block5(h, down_sampling=True) 63 | return projection1, projection2, projection3, projection4, projection5 64 | 65 | 66 | class InternData(Dataset): 67 | def __init__(self): 68 | #### 69 | with open('data/InternData/partition/data_info.json', 'r') as f: 70 | self.j = json.load(f) 71 | self.transform = T.Compose([ 72 | T.Lambda(lambda img: img.convert('RGB')), 73 | T.Resize(image_resize), # Image.BICUBIC 74 | T.CenterCrop(image_resize), 75 | T.ToTensor(), 76 | ]) 77 | 78 | def __len__(self): 79 | return len(self.j) 80 | 81 | def getdata(self, idx): 82 | 83 | path = self.j[idx]['path'] 84 | image = Image.open("data/InternImgs/" + path) 85 | image = self.transform(image) 86 | return image, path 87 | 88 | def __getitem__(self, idx): 89 | for i in range(20): 90 | try: 91 | data = self.getdata(idx) 92 | return data 93 | except Exception as e: 94 | print(f"Error details: {str(e)}") 95 | idx = np.random.randint(len(self)) 96 | raise RuntimeError('Too many bad 
data.') 97 | 98 | class HEDdetector(nn.Module): 99 | def __init__(self, feature=True, vae=None): 100 | super().__init__() 101 | self.model = ControlNetHED_Apache2() 102 | self.model.load_state_dict(torch.load('output/pretrained_models/ControlNetHED.pth', map_location='cpu')) 103 | self.model.eval() 104 | self.model.requires_grad_(False) 105 | if feature: 106 | if vae is None: 107 | self.vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema") 108 | else: 109 | self.vae = vae 110 | self.vae.eval() 111 | self.vae.requires_grad_(False) 112 | else: 113 | self.vae = None 114 | 115 | def forward(self, input_image): 116 | B, C, H, W = input_image.shape 117 | with torch.inference_mode(): 118 | edges = self.model(input_image * 255.) 119 | edges = torch.cat([TF.resize(e, [H, W]) for e in edges], dim=1) 120 | edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1, keepdim=True))) 121 | edge.clip_(0, 1) 122 | if self.vae: 123 | edge = TF.normalize(edge, [.5], [.5]) 124 | edge = edge.repeat(1, 3, 1, 1) 125 | posterior = self.vae.encode(edge).latent_dist 126 | edge = torch.cat([posterior.mean, posterior.std], dim=1).cpu().numpy() 127 | return edge 128 | 129 | 130 | def main(): 131 | dataset = InternData() 132 | dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=8, pin_memory=True) 133 | hed = HEDdetector() 134 | 135 | accelerator = Accelerator() 136 | hed, dataloader = accelerator.prepare(hed, dataloader) 137 | 138 | 139 | for img, path in tqdm(dataloader): 140 | out = hed(img.cuda()) 141 | for p, o in zip(path, out): 142 | save = f'data/InternalData/hed_feature_{image_resize}/' + p.replace('.png', '.npz') 143 | if os.path.exists(save): 144 | continue 145 | os.makedirs(os.path.dirname(save), exist_ok=True) 146 | np.savez_compressed(save, o) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /diffusion/model/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict = None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs): 23 | if attn_config is None: 24 | attn_config = { 25 | 'attn_type': 'multihead_attention', 26 | 'attn_pdrop': 0.0, 27 | 'attn_impl': 'triton', 28 | 'qk_ln': False, 29 | 'clip_qkv': None, 30 | 'softmax_scale': None, 31 | 'prefix_lm': False, 32 | 
'attn_uses_sequence_id': False, 33 | 'alibi': False, 34 | 'alibi_bias_max': 8, 35 | } 36 | del kwargs 37 | super().__init__() 38 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 39 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 40 | self.norm_1 = norm_class(d_model, device=device) 41 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device) 42 | self.norm_2 = norm_class(d_model, device=device) 43 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 44 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 45 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 46 | 47 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 48 | a = self.norm_1(x) 49 | (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 50 | x = x + self.resid_attn_dropout(b) 51 | m = self.norm_2(x) 52 | n = self.ffn(m) 53 | x = x + self.resid_ffn_dropout(n) 54 | return (x, past_key_value) -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} 5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'} 6 | 7 | class MPTConfig(PretrainedConfig): 8 | model_type = 'mpt' 9 | 10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): 11 | """The MPT configuration class. 12 | 13 | Args: 14 | d_model (int): The size of the embedding dimension of the model. 15 | n_heads (int): The number of attention heads. 16 | n_layers (int): The number of layers in the model. 17 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 18 | max_seq_len (int): The maximum sequence length of the model. 19 | vocab_size (int): The size of the vocabulary. 20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 21 | emb_pdrop (float): The dropout probability for the embedding layer. 22 | learned_pos_emb (bool): Whether to use learned positional embeddings 23 | attn_config (Dict): A dictionary used to configure the model's attention module: 24 | attn_type (str): type of attention to use. 
Options: multihead_attention, multiquery_attention 25 | attn_pdrop (float): The dropout probability for the attention layers. 26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 29 | this value. 30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 31 | use the default scale of ``1/sqrt(d_keys)``. 32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 37 | which sub-sequence each token belongs to. 38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 39 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 40 | alibi_bias_max (int): The maximum value of the alibi bias. 41 | init_device (str): The device to use for parameter initialization. 42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 43 | no_bias (bool): Whether to use bias in all layers. 44 | verbose (int): The verbosity level. 0 is silent. 45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 46 | norm_type (str): choose type of norm to use 47 | multiquery_attention (bool): Whether to use multiquery attention implementation. 48 | use_cache (bool): Whether or not the model should return the last key/values attentions 49 | init_config (Dict): A dictionary used to configure the model initialization: 50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 57 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 58 | if using the baseline_ parameter initialization scheme. 59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
62 | --- 63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 64 | """ 65 | self.d_model = d_model 66 | self.n_heads = n_heads 67 | self.n_layers = n_layers 68 | self.expansion_ratio = expansion_ratio 69 | self.max_seq_len = max_seq_len 70 | self.vocab_size = vocab_size 71 | self.resid_pdrop = resid_pdrop 72 | self.emb_pdrop = emb_pdrop 73 | self.learned_pos_emb = learned_pos_emb 74 | self.attn_config = attn_config 75 | self.init_device = init_device 76 | self.logit_scale = logit_scale 77 | self.no_bias = no_bias 78 | self.verbose = verbose 79 | self.embedding_fraction = embedding_fraction 80 | self.norm_type = norm_type 81 | self.use_cache = use_cache 82 | self.init_config = init_config 83 | if 'name' in kwargs: 84 | del kwargs['name'] 85 | if 'loss_fn' in kwargs: 86 | del kwargs['loss_fn'] 87 | super().__init__(**kwargs) 88 | self._validate_config() 89 | 90 | def _set_config_defaults(self, config, config_defaults): 91 | for (k, v) in config_defaults.items(): 92 | if k not in config: 93 | config[k] = v 94 | return config 95 | 96 | def _validate_config(self): 97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) 98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) 99 | if self.d_model % self.n_heads != 0: 100 | raise ValueError('d_model must be divisible by n_heads') 101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): 102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") 103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: 104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') 107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 108 | raise NotImplementedError('alibi only implemented with torch and triton attention.') 109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') 111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') 113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': 114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 115 | if self.init_config.get('name', None) is None: 116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") 117 | if not self.learned_pos_emb and (not self.attn_config['alibi']): 118 | raise ValueError( 119 | 'Positional information must be provided to the model using either learned_pos_emb or alibi.' 
120 | ) -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | return output * weight if weight is not None else output 30 | 31 | class RMSNorm(torch.nn.Module): 32 | 33 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 34 | super().__init__() 35 | self.eps = eps 36 | if weight: 37 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 38 | else: 39 | self.register_parameter('weight', None) 40 | 41 | def forward(self, x): 42 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 43 | 44 | class LPRMSNorm(RMSNorm): 45 | 46 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 47 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 48 | 49 | def forward(self, x): 50 | downcast_x = _cast_if_autocast_enabled(x) 51 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 52 | with torch.autocast(enabled=False, device_type=x.device.type): 53 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 54 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/param_init_fns.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from collections.abc import Sequence 4 | from functools import partial 5 | from typing import Optional, Tuple, Union 6 | import torch 7 | from torch import nn 8 | from .norm import NORM_CLASS_REGISTRY 9 | 10 | def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs): 11 | del kwargs 12 | if verbose > 1: 13 | warnings.warn("Initializing network using module's reset_parameters attribute") 14 | if hasattr(module, 'reset_parameters'): 15 | module.reset_parameters() 16 | 17 | def fused_init_helper_(module: nn.Module, init_fn_): 18 | _fused = 
getattr(module, '_fused', None) 19 | if _fused is None: 20 | raise RuntimeError('Internal logic error') 21 | (dim, splits) = _fused 22 | splits = (0, *splits, module.weight.size(dim)) 23 | for (s, e) in zip(splits[:-1], splits[1:]): 24 | slice_indices = [slice(None)] * module.weight.ndim 25 | slice_indices[dim] = slice(s, e) 26 | init_fn_(module.weight[slice_indices]) 27 | 28 | def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 29 | del kwargs 30 | if verbose > 1: 31 | warnings.warn('If model has bias parameters they are initialized to 0.') 32 | init_div_is_residual = init_div_is_residual 33 | if init_div_is_residual is False: 34 | div_is_residual = 1.0 35 | elif init_div_is_residual is True: 36 | div_is_residual = math.sqrt(2 * n_layers) 37 | elif isinstance(init_div_is_residual, (float, int)): 38 | div_is_residual = init_div_is_residual 39 | elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): 40 | div_is_residual = float(init_div_is_residual) 41 | else: 42 | div_is_residual = 1.0 43 | raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') 44 | if init_div_is_residual is not False and verbose > 1: 45 | warnings.warn( 46 | f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. Set `init_div_is_residual: false` in init config to disable this.' 47 | ) 48 | if isinstance(module, nn.Linear): 49 | if hasattr(module, '_fused'): 50 | fused_init_helper_(module, init_fn_) 51 | else: 52 | init_fn_(module.weight) 53 | if module.bias is not None: 54 | torch.nn.init.zeros_(module.bias) 55 | if init_div_is_residual is not False and getattr(module, '_is_residual', False): 56 | with torch.no_grad(): 57 | module.weight.div_(div_is_residual) 58 | elif isinstance(module, nn.Embedding): 59 | if emb_init_std is not None: 60 | std = emb_init_std 61 | if std == 0: 62 | warnings.warn('Embedding layer initialized to 0.') 63 | emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) 64 | if verbose > 1: 65 | warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.') 66 | elif emb_init_uniform_lim is not None: 67 | lim = emb_init_uniform_lim 68 | if isinstance(lim, Sequence): 69 | if len(lim) > 2: 70 | raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.') 71 | if lim[0] == lim[1]: 72 | warnings.warn(f'Embedding layer initialized to {lim[0]}.') 73 | else: 74 | if lim == 0: 75 | warnings.warn('Embedding layer initialized to 0.') 76 | lim = [-lim, lim] 77 | (a, b) = lim 78 | emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) 79 | if verbose > 1: 80 | warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.') 81 | else: 82 | emb_init_fn_ = init_fn_ 83 | emb_init_fn_(module.weight) 84 | elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): 85 | if verbose > 1: 86 | warnings.warn( 87 | 'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.' 
88 | ) 89 | if hasattr(module, 'weight') and module.weight is not None: 90 | torch.nn.init.ones_(module.weight) 91 | if hasattr(module, 'bias') and module.bias is not None: 92 | torch.nn.init.zeros_(module.bias) 93 | elif isinstance(module, nn.MultiheadAttention): 94 | if module._qkv_same_embed_dim: 95 | _extracted_from_generic_param_init_fn__69(module, d_model, init_fn_) 96 | else: 97 | assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) 98 | assert module.in_proj_weight is None 99 | init_fn_(module.q_proj_weight) 100 | init_fn_(module.k_proj_weight) 101 | init_fn_(module.v_proj_weight) 102 | if module.in_proj_bias is not None: 103 | torch.nn.init.zeros_(module.in_proj_bias) 104 | if module.bias_k is not None: 105 | torch.nn.init.zeros_(module.bias_k) 106 | if module.bias_v is not None: 107 | torch.nn.init.zeros_(module.bias_v) 108 | init_fn_(module.out_proj.weight) 109 | if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False): 110 | with torch.no_grad(): 111 | module.out_proj.weight.div_(div_is_residual) 112 | if module.out_proj.bias is not None: 113 | torch.nn.init.zeros_(module.out_proj.bias) 114 | else: 115 | for _ in module.parameters(recurse=False): 116 | raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.') 117 | 118 | 119 | # TODO Rename this here and in `generic_param_init_fn_` 120 | def _extracted_from_generic_param_init_fn__69(module, d_model, init_fn_): 121 | assert module.in_proj_weight is not None 122 | assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) 123 | assert d_model is not None 124 | _d = d_model 125 | splits = (0, _d, 2 * _d, 3 * _d) 126 | for (s, e) in zip(splits[:-1], splits[1:]): 127 | init_fn_(module.in_proj_weight[s:e]) 128 | 129 | def _normal_init_(std, mean=0.0): 130 | return partial(torch.nn.init.normal_, mean=mean, std=std) 131 | 132 | def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 133 | del kwargs 134 | init_fn_ = _normal_init_(std=std) 135 | if verbose > 1: 136 | warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') 137 | generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 138 | 139 | def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 140 | del kwargs 141 | if init_std is None: 142 | raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.") 143 | _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 144 | 145 | def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: 
Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 146 | del kwargs 147 | std = math.sqrt(2 / (5 * d_model)) 148 | _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 149 | 150 | def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 151 | """From section 2.3.1 of GPT-NeoX-20B: 152 | 153 | An Open-Source AutoregressiveLanguage Model — Black et. al. (2022) 154 | see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151 155 | and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py 156 | """ 157 | del kwargs 158 | residual_div = n_layers / math.sqrt(10) 159 | if verbose > 1: 160 | warnings.warn(f'setting init_div_is_residual to {residual_div}') 161 | small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 162 | 163 | def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 164 | del kwargs 165 | if verbose > 1: 166 | warnings.warn( 167 | f'Using nn.init.kaiming_uniform_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}' 168 | ) 169 | kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 170 | generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 171 | 172 | def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 173 | del kwargs 174 | if verbose > 1: 175 | warnings.warn( 176 | f'Using nn.init.kaiming_normal_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}' 177 | ) 178 | kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 179 | generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 180 | 181 | def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): 182 | del kwargs 183 | 
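# Build a partially-applied xavier_uniform_ initializer with the requested gain and hand it
# to generic_param_init_fn_ (defined above), which walks Linear, Embedding, norm, and
# MultiheadAttention modules and applies it module by module.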
xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) 184 | if verbose > 1: 185 | warnings.warn( 186 | f'Using torch.nn.init.xavier_uniform_ init fn with parameters: gain={init_gain}' 187 | ) 188 | generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 189 | 190 | def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): 191 | xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) 192 | if verbose > 1: 193 | warnings.warn( 194 | f'Using torch.nn.init.xavier_normal_ init fn with parameters: gain={init_gain}' 195 | ) 196 | generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 197 | MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_} -------------------------------------------------------------------------------- /diffusion/model/nets/PixArtMS.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | import torch 12 | import torch.nn as nn 13 | from timm.models.layers import DropPath 14 | from timm.models.vision_transformer import Mlp 15 | 16 | from diffusion.model.builder import MODELS 17 | from diffusion.model.utils import auto_grad_checkpoint, to_2tuple 18 | from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder 19 | from diffusion.model.nets.PixArt import PixArt, get_2d_sincos_pos_embed 20 | 21 | 22 | class PatchEmbed(nn.Module): 23 | """ 2D Image to Patch Embedding 24 | """ 25 | def __init__( 26 | self, 27 | patch_size=16, 28 | in_chans=3, 29 | embed_dim=768, 30 | norm_layer=None, 31 | flatten=True, 32 | bias=True, 33 | ): 34 | super().__init__() 35 | patch_size = to_2tuple(patch_size) 36 | self.patch_size = patch_size 37 | self.flatten = flatten 38 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 39 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 40 | 41 | def forward(self, x): 42 | x = self.proj(x) 43 | if self.flatten: 44 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 45 | x = self.norm(x) 46 | return x 47 | 48 | 49 | class PixArtMSBlock(nn.Module): 50 | """ 51 | A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. 52 | """ 53 | 54 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs): 55 | super().__init__() 56 | self.hidden_size = hidden_size 57 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 58 | self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, 59 | input_size=input_size if window_size == 0 else (window_size, window_size), 60 | use_rel_pos=use_rel_pos, **block_kwargs) 61 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) 62 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 63 | # to be compatible with lower version pytorch 64 | approx_gelu = lambda: nn.GELU(approximate="tanh") 65 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) 66 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 67 | self.window_size = window_size 68 | self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) 69 | 70 | def forward(self, x, y, t, mask=None, **kwargs): 71 | B, N, C = x.shape 72 | 73 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) 74 | x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa))) 75 | x = x + self.cross_attn(x, y, mask) 76 | x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) 77 | 78 | return x 79 | 80 | 81 | ############################################################################# 82 | # Core PixArt Model # 83 | ################################################################################# 84 | @MODELS.register_module() 85 | class PixArtMS(PixArt): 86 | """ 87 | Diffusion model with a Transformer backbone. 
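Compared with the base PixArt class, this multi-scale variant rebuilds the patch and
caption embedders, adds SizeEmbedder modules for the image resolution and aspect ratio,
and produces the six adaLN-Zero parameters through a shared t_block applied to the
combined timestep / size embedding.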
88 | """ 89 | 90 | def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, learn_sigma=True, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, caption_channels=4096, lewei_scale=1., config=None, model_max_length=120, **kwargs): 91 | if window_block_indexes is None: 92 | window_block_indexes = [] 93 | super().__init__( 94 | input_size=input_size, 95 | patch_size=patch_size, 96 | in_channels=in_channels, 97 | hidden_size=hidden_size, 98 | depth=depth, 99 | num_heads=num_heads, 100 | mlp_ratio=mlp_ratio, 101 | class_dropout_prob=class_dropout_prob, 102 | learn_sigma=learn_sigma, 103 | pred_sigma=pred_sigma, 104 | drop_path=drop_path, 105 | window_size=window_size, 106 | window_block_indexes=window_block_indexes, 107 | use_rel_pos=use_rel_pos, 108 | lewei_scale=lewei_scale, 109 | config=config, 110 | model_max_length=model_max_length, 111 | **kwargs, 112 | ) 113 | self.h = self.w = 0 114 | approx_gelu = lambda: nn.GELU(approximate="tanh") 115 | self.t_block = nn.Sequential( 116 | nn.SiLU(), 117 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 118 | ) 119 | self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True) 120 | self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length) 121 | self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed 122 | self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed 123 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule 124 | self.blocks = nn.ModuleList([ 125 | PixArtMSBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], 126 | input_size=(input_size // patch_size, input_size // patch_size), 127 | window_size=window_size if i in window_block_indexes else 0, 128 | use_rel_pos=use_rel_pos if i in window_block_indexes else False) 129 | for i in range(depth) 130 | ]) 131 | self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) 132 | 133 | self.initialize() 134 | 135 | def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): 136 | """ 137 | Forward pass of PixArt. 
138 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 139 | t: (N,) tensor of diffusion timesteps 140 | y: (N, 1, 120, C) tensor of class labels 141 | """ 142 | bs = x.shape[0] 143 | x = x.to(self.dtype) 144 | timestep = timestep.to(self.dtype) 145 | y = y.to(self.dtype) 146 | c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype) 147 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size 148 | pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(x.device).to(self.dtype) 149 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 150 | t = self.t_embedder(timestep) # (N, D) 151 | csize = self.csize_embedder(c_size, bs) # (N, D) 152 | ar = self.ar_embedder(ar, bs) # (N, D) 153 | t = t + torch.cat([csize, ar], dim=1) 154 | t0 = self.t_block(t) 155 | y = self.y_embedder(y, self.training) # (N, D) 156 | if mask is not None: 157 | if mask.shape[0] != y.shape[0]: 158 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1) 159 | mask = mask.squeeze(1).squeeze(1) 160 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) 161 | y_lens = mask.sum(dim=1).tolist() 162 | else: 163 | y_lens = [y.shape[2]] * y.shape[0] 164 | y = y.squeeze(1).view(1, -1, x.shape[-1]) 165 | for block in self.blocks: 166 | x = auto_grad_checkpoint(block, x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint 167 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) 168 | x = self.unpatchify(x) # (N, out_channels, H, W) 169 | return x 170 | 171 | def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs): 172 | """ 173 | dpm solver donnot need variance prediction 174 | """ 175 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 176 | model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs) 177 | return model_out.chunk(2, dim=1)[0] 178 | 179 | def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, **kwargs): 180 | """ 181 | Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. 
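Only the first half of the batch is treated as distinct inputs; it is duplicated,
and the guided prediction uncond_eps + cfg_scale * (cond_eps - uncond_eps) is formed
from the first three (epsilon) output channels, while the remaining predicted-variance
channels are passed through unchanged.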
182 | """ 183 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 184 | half = x[: len(x) // 2] 185 | combined = torch.cat([half, half], dim=0) 186 | model_out = self.forward(combined, timestep, y, data_info=data_info) 187 | eps, rest = model_out[:, :3], model_out[:, 3:] 188 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 189 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 190 | eps = torch.cat([half_eps, half_eps], dim=0) 191 | return torch.cat([eps, rest], dim=1) 192 | 193 | def unpatchify(self, x): 194 | """ 195 | x: (N, T, patch_size**2 * C) 196 | imgs: (N, H, W, C) 197 | """ 198 | c = self.out_channels 199 | p = self.x_embedder.patch_size[0] 200 | assert self.h * self.w == x.shape[1] 201 | 202 | x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c)) 203 | x = torch.einsum('nhwpqc->nchpwq', x) 204 | return x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p)) 205 | 206 | def initialize(self): 207 | # Initialize transformer layers: 208 | def _basic_init(module): 209 | if isinstance(module, nn.Linear): 210 | torch.nn.init.xavier_uniform_(module.weight) 211 | if module.bias is not None: 212 | nn.init.constant_(module.bias, 0) 213 | 214 | self.apply(_basic_init) 215 | 216 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 217 | w = self.x_embedder.proj.weight.data 218 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 219 | 220 | # Initialize timestep embedding MLP: 221 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 222 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 223 | nn.init.normal_(self.t_block[1].weight, std=0.02) 224 | nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02) 225 | nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02) 226 | nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02) 227 | nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02) 228 | 229 | # Initialize caption embedding MLP: 230 | nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) 231 | nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) 232 | 233 | # Zero-out adaLN modulation layers in PixArt blocks: 234 | for block in self.blocks: 235 | nn.init.constant_(block.cross_attn.proj.weight, 0) 236 | nn.init.constant_(block.cross_attn.proj.bias, 0) 237 | 238 | # Zero-out output layers: 239 | nn.init.constant_(self.final_layer.linear.weight, 0) 240 | nn.init.constant_(self.final_layer.linear.bias, 0) 241 | 242 | 243 | ################################################################################# 244 | # PixArt Configs # 245 | ################################################################################# 246 | @MODELS.register_module() 247 | def PixArtMS_XL_2(**kwargs): 248 | return PixArtMS(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 249 | -------------------------------------------------------------------------------- /diffusion/model/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .PixArt import PixArt, PixArt_XL_2 2 | from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock 3 | from .pixart_controlnet import ControlPixArtHalf, ControlPixArtMSHalf -------------------------------------------------------------------------------- /diffusion/model/nets/pixart_controlnet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | 5 | from copy import deepcopy 6 | from torch import Tensor 7 | from torch.nn 
import Module, Linear, init 8 | from typing import Any, Mapping 9 | 10 | from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt 11 | from diffusion.model.nets.PixArt import get_2d_sincos_pos_embed 12 | from diffusion.model.utils import auto_grad_checkpoint 13 | 14 | 15 | # The implementation of ControlNet-Half architrecture 16 | # https://github.com/lllyasviel/ControlNet/discussions/188 17 | class ControlT2IDitBlockHalf(Module): 18 | def __init__(self, base_block: PixArtMSBlock, block_index: 0) -> None: 19 | super().__init__() 20 | self.copied_block = deepcopy(base_block) 21 | self.block_index = block_index 22 | 23 | for p in self.copied_block.parameters(): 24 | p.requires_grad_(True) 25 | 26 | self.copied_block.load_state_dict(base_block.state_dict()) 27 | self.copied_block.train() 28 | 29 | self.hidden_size = hidden_size = base_block.hidden_size 30 | if self.block_index == 0: 31 | self.before_proj = Linear(hidden_size, hidden_size) 32 | init.zeros_(self.before_proj.weight) 33 | init.zeros_(self.before_proj.bias) 34 | self.after_proj = Linear(hidden_size, hidden_size) 35 | init.zeros_(self.after_proj.weight) 36 | init.zeros_(self.after_proj.bias) 37 | 38 | def forward(self, x, y, t, mask=None, c=None): 39 | 40 | if self.block_index == 0: 41 | # the first block 42 | c = self.before_proj(c) 43 | c = self.copied_block(x + c, y, t, mask) 44 | c_skip = self.after_proj(c) 45 | else: 46 | # load from previous c and produce the c for skip connection 47 | c = self.copied_block(c, y, t, mask) 48 | c_skip = self.after_proj(c) 49 | 50 | return c, c_skip 51 | 52 | 53 | # The implementation of ControlPixArtHalf net 54 | class ControlPixArtHalf(Module): 55 | # only support single res model 56 | def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None: 57 | super().__init__() 58 | self.base_model = base_model.eval() 59 | self.controlnet = [] 60 | self.copy_blocks_num = copy_blocks_num 61 | self.total_blocks_num = len(base_model.blocks) 62 | for p in self.base_model.parameters(): 63 | p.requires_grad_(False) 64 | 65 | # Copy first copy_blocks_num block 66 | for i in range(copy_blocks_num): 67 | self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i)) 68 | self.controlnet = nn.ModuleList(self.controlnet) 69 | 70 | def __getattr__(self, name: str) -> Tensor or Module: 71 | if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']: 72 | return self.__dict__[name] 73 | elif name in ['base_model', 'controlnet']: 74 | return super().__getattr__(name) 75 | else: 76 | return getattr(self.base_model, name) 77 | 78 | def forward_c(self, c): 79 | self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size 80 | pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype) 81 | return self.x_embedder(c) + pos_embed if c is not None else c 82 | 83 | # def forward(self, x, t, c, **kwargs): 84 | # return self.base_model(x, t, c=self.forward_c(c), **kwargs) 85 | def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs): 86 | # modify the original PixArtMS forward function 87 | if c is not None: 88 | c = c.to(self.dtype) 89 | c = self.forward_c(c) 90 | """ 91 | Forward pass of PixArt. 
92 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 93 | t: (N,) tensor of diffusion timesteps 94 | y: (N, 1, 120, C) tensor of class labels 95 | """ 96 | x = x.to(self.dtype) 97 | timestep = timestep.to(self.dtype) 98 | y = y.to(self.dtype) 99 | pos_embed = self.pos_embed.to(self.dtype) 100 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size 101 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 102 | t = self.t_embedder(timestep.to(x.dtype)) # (N, D) 103 | t0 = self.t_block(t) 104 | y = self.y_embedder(y, self.training) # (N, 1, L, D) 105 | if mask is not None: 106 | if mask.shape[0] != y.shape[0]: 107 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1) 108 | mask = mask.squeeze(1).squeeze(1) 109 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) 110 | y_lens = mask.sum(dim=1).tolist() 111 | else: 112 | y_lens = [y.shape[2]] * y.shape[0] 113 | y = y.squeeze(1).view(1, -1, x.shape[-1]) 114 | 115 | # define the first layer 116 | x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint 117 | 118 | if c is not None: 119 | # update c 120 | for index in range(1, self.copy_blocks_num + 1): 121 | c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs) 122 | x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs) 123 | 124 | # update x 125 | for index in range(self.copy_blocks_num + 1, self.total_blocks_num): 126 | x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs) 127 | else: 128 | for index in range(1, self.total_blocks_num): 129 | x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs) 130 | 131 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) 132 | x = self.unpatchify(x) # (N, out_channels, H, W) 133 | return x 134 | 135 | def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs): 136 | model_out = self.forward(x, t, y, data_info=data_info, c=c, **kwargs) 137 | return model_out.chunk(2, dim=1)[0] 138 | 139 | # def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs): 140 | # return self.base_model.forward_with_dpmsolver(x, t, y, data_info=data_info, c=self.forward_c(c), **kwargs) 141 | 142 | def forward_with_cfg(self, x, t, y, cfg_scale, data_info, c, **kwargs): 143 | return self.base_model.forward_with_cfg(x, t, y, cfg_scale, data_info, c=self.forward_c(c), **kwargs) 144 | 145 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): 146 | if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()): 147 | return super().load_state_dict(state_dict, strict) 148 | else: 149 | new_key = {} 150 | for k in state_dict.keys(): 151 | new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k) 152 | for k, v in new_key.items(): 153 | if k != v: 154 | print(f"replace {k} to {v}") 155 | state_dict[v] = state_dict.pop(k) 156 | 157 | return self.base_model.load_state_dict(state_dict, strict) 158 | 159 | def unpatchify(self, x): 160 | """ 161 | x: (N, T, patch_size**2 * C) 162 | imgs: (N, H, W, C) 163 | """ 164 | c = self.out_channels 165 | p = self.x_embedder.patch_size[0] 166 | assert self.h * self.w == x.shape[1] 167 | 168 | x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c)) 169 | x = torch.einsum('nhwpqc->nchpwq', x) 170 | imgs = x.reshape(shape=(x.shape[0], c, self.h * p, 
self.w * p))
171 |         return imgs
172 | 
173 |     @property
174 |     def dtype(self):
175 |         # return the data type of the model parameters
176 |         return next(self.parameters()).dtype
177 | 
178 | 
179 | # The implementation for PixArtMS_Half + 1024 resolution
180 | class ControlPixArtMSHalf(ControlPixArtHalf):
181 |     # supports multi-scale resolution models (a multi-scale model can also be used for single-resolution training & inference)
182 |     def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
183 |         super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)
184 | 
185 |     def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
186 |         # modified from the original PixArtMS forward function
187 |         """
188 |         Forward pass of PixArt.
189 |         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
190 |         t: (N,) tensor of diffusion timesteps
191 |         y: (N, 1, 120, C) tensor of class labels
192 |         """
193 |         if c is not None:
194 |             c = c.to(self.dtype)
195 |             c = self.forward_c(c)
196 |         bs = x.shape[0]
197 |         x = x.to(self.dtype)
198 |         timestep = timestep.to(self.dtype)
199 |         y = y.to(self.dtype)
200 |         c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
201 |         self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
202 | 
203 |         pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(x.device).to(self.dtype)
204 |         x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
205 |         t = self.t_embedder(timestep)  # (N, D)
206 |         csize = self.csize_embedder(c_size, bs)  # (N, D)
207 |         ar = self.ar_embedder(ar, bs)  # (N, D)
208 |         t = t + torch.cat([csize, ar], dim=1)
209 |         t0 = self.t_block(t)
210 |         y = self.y_embedder(y, self.training)  # (N, D)
211 |         if mask is not None:
212 |             if mask.shape[0] != y.shape[0]:
213 |                 mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
214 |             mask = mask.squeeze(1).squeeze(1)
215 |             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
216 |             y_lens = mask.sum(dim=1).tolist()
217 |         else:
218 |             y_lens = [y.shape[2]] * y.shape[0]
219 |             y = y.squeeze(1).view(1, -1, x.shape[-1])
220 | 
221 |         # define the first layer
222 |         x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs)  # (N, T, D) #support grad checkpoint
223 | 
224 |         if c is not None:
225 |             # update c
226 |             for index in range(1, self.copy_blocks_num + 1):
227 |                 c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
228 |                 x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
229 | 
230 |             # update x
231 |             for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
232 |                 x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
233 |         else:
234 |             for index in range(1, self.total_blocks_num):
235 |                 x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
236 | 
237 |         x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
238 |         x = self.unpatchify(x)  # (N, out_channels, H, W)
239 |         return x
240 | 
--------------------------------------------------------------------------------
/diffusion/model/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM:
https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1) 52 | cur_idx = 0.0 53 | taken_steps = [] 54 | for _ in range(section_count): 55 | taken_steps.append(start_idx + round(cur_idx)) 56 | cur_idx += frac_stride 57 | all_steps += taken_steps 58 | start_idx += size 59 | return set(all_steps) 60 | 61 | 62 | class SpacedDiffusion(GaussianDiffusion): 63 | """ 64 | A diffusion process which can skip steps in a base diffusion process. 65 | :param use_timesteps: a collection (sequence or set) of timesteps from the 66 | original diffusion process to retain. 67 | :param kwargs: the kwargs to create the base diffusion process. 
68 | """ 69 | 70 | def __init__(self, use_timesteps, **kwargs): 71 | self.use_timesteps = set(use_timesteps) 72 | self.timestep_map = [] 73 | self.original_num_steps = len(kwargs["betas"]) 74 | 75 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 76 | last_alpha_cumprod = 1.0 77 | new_betas = [] 78 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 79 | if i in self.use_timesteps: 80 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 81 | last_alpha_cumprod = alpha_cumprod 82 | self.timestep_map.append(i) 83 | kwargs["betas"] = np.array(new_betas) 84 | super().__init__(**kwargs) 85 | 86 | def p_mean_variance( 87 | self, model, *args, **kwargs 88 | ): # pylint: disable=signature-differs 89 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 90 | 91 | def training_losses( 92 | self, model, *args, **kwargs 93 | ): # pylint: disable=signature-differs 94 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 95 | 96 | def training_losses_diffusers( 97 | self, model, *args, **kwargs 98 | ): # pylint: disable=signature-differs 99 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs) 100 | 101 | def condition_mean(self, cond_fn, *args, **kwargs): 102 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def condition_score(self, cond_fn, *args, **kwargs): 105 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 106 | 107 | def _wrap_model(self, model): 108 | if isinstance(model, _WrappedModel): 109 | return model 110 | return _WrappedModel( 111 | model, self.timestep_map, self.original_num_steps 112 | ) 113 | 114 | def _scale_timesteps(self, t): 115 | # Scaling is done by the wrapped model. 
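        # (the spaced indices are mapped back to the original diffusion timesteps in
        # _WrappedModel.__call__ below, via timestep_map, so t is returned unchanged here)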
116 | return t 117 | 118 | 119 | class _WrappedModel: 120 | def __init__(self, model, timestep_map, original_num_steps): 121 | self.model = model 122 | self.timestep_map = timestep_map 123 | # self.rescale_timesteps = rescale_timesteps 124 | self.original_num_steps = original_num_steps 125 | 126 | def __call__(self, x, timestep, **kwargs): 127 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype) 128 | new_ts = map_tensor[timestep] 129 | # if self.rescale_timesteps: 130 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 131 | return self.model(x, timestep=new_ts, **kwargs) 132 | -------------------------------------------------------------------------------- /diffusion/model/t5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | import html 5 | import urllib.parse as ul 6 | 7 | import ftfy 8 | import torch 9 | from bs4 import BeautifulSoup 10 | from transformers import T5EncoderModel, AutoTokenizer 11 | from huggingface_hub import hf_hub_download 12 | 13 | class T5Embedder: 14 | 15 | available_models = ['t5-v1_1-xxl'] 16 | bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa 17 | 18 | def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True, 19 | t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120): 20 | self.device = torch.device(device) 21 | self.torch_dtype = torch_dtype or torch.bfloat16 22 | if t5_model_kwargs is None: 23 | t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} 24 | if use_offload_folder is not None: 25 | t5_model_kwargs['offload_folder'] = use_offload_folder 26 | t5_model_kwargs['device_map'] = { 27 | 'shared': self.device, 28 | 'encoder.embed_tokens': self.device, 29 | 'encoder.block.0': self.device, 30 | 'encoder.block.1': self.device, 31 | 'encoder.block.2': self.device, 32 | 'encoder.block.3': self.device, 33 | 'encoder.block.4': self.device, 34 | 'encoder.block.5': self.device, 35 | 'encoder.block.6': self.device, 36 | 'encoder.block.7': self.device, 37 | 'encoder.block.8': self.device, 38 | 'encoder.block.9': self.device, 39 | 'encoder.block.10': self.device, 40 | 'encoder.block.11': self.device, 41 | 'encoder.block.12': 'disk', 42 | 'encoder.block.13': 'disk', 43 | 'encoder.block.14': 'disk', 44 | 'encoder.block.15': 'disk', 45 | 'encoder.block.16': 'disk', 46 | 'encoder.block.17': 'disk', 47 | 'encoder.block.18': 'disk', 48 | 'encoder.block.19': 'disk', 49 | 'encoder.block.20': 'disk', 50 | 'encoder.block.21': 'disk', 51 | 'encoder.block.22': 'disk', 52 | 'encoder.block.23': 'disk', 53 | 'encoder.final_layer_norm': 'disk', 54 | 'encoder.dropout': 'disk', 55 | } 56 | else: 57 | t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} 58 | 59 | self.use_text_preprocessing = use_text_preprocessing 60 | self.hf_token = hf_token 61 | self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_') 62 | self.dir_or_name = dir_or_name 63 | tokenizer_path, path = dir_or_name, dir_or_name 64 | if local_cache: 65 | cache_dir = os.path.join(self.cache_dir, dir_or_name) 66 | tokenizer_path, path = cache_dir, cache_dir 67 | elif dir_or_name in self.available_models: 68 | cache_dir = os.path.join(self.cache_dir, dir_or_name) 69 | for filename in [ 70 | 'config.json', 'special_tokens_map.json', 'spiece.model', 
'tokenizer_config.json',
71 |                 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
72 |             ]:
73 |                 hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
74 |                                 force_filename=filename, token=self.hf_token)
75 |             tokenizer_path, path = cache_dir, cache_dir
76 |         else:
77 |             cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
78 |             for filename in [
79 |                 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
80 |             ]:
81 |                 hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
82 |                                 force_filename=filename, token=self.hf_token)
83 |             tokenizer_path = cache_dir
84 | 
85 |         print(tokenizer_path)
86 |         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
87 |         self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
88 |         self.model_max_length = model_max_length
89 | 
90 |     def get_text_embeddings(self, texts):
91 |         texts = [self.text_preprocessing(text) for text in texts]
92 | 
93 |         text_tokens_and_mask = self.tokenizer(
94 |             texts,
95 |             max_length=self.model_max_length,
96 |             padding='max_length',
97 |             truncation=True,
98 |             return_attention_mask=True,
99 |             add_special_tokens=True,
100 |             return_tensors='pt'
101 |         )
102 | 
103 |         text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
104 |         text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
105 | 
106 |         with torch.no_grad():
107 |             text_encoder_embs = self.model(
108 |                 input_ids=text_tokens_and_mask['input_ids'].to(self.device),
109 |                 attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
110 |             )['last_hidden_state'].detach()
111 |         return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
112 | 
113 |     def text_preprocessing(self, text):
114 |         if self.use_text_preprocessing:
115 |             # The exact text cleaning as used in the training stage:
116 |             text = self.clean_caption(text)
117 |             text = self.clean_caption(text)
118 |             return text
119 |         else:
120 |             return text.lower().strip()
121 | 
122 |     @staticmethod
123 |     def basic_clean(text):
124 |         text = ftfy.fix_text(text)
125 |         text = html.unescape(html.unescape(text))
126 |         return text.strip()
127 | 
128 |     def clean_caption(self, caption):
129 |         caption = str(caption)
130 |         caption = ul.unquote_plus(caption)
131 |         caption = caption.strip().lower()
132 |         caption = re.sub('<person>', 'person', caption)
133 |         # urls:
134 |         caption = re.sub(
135 |             r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
136 |             '', caption)  # regex for urls
137 |         caption = re.sub(
138 |             r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
139 |             '', caption)  # regex for urls
140 |         # html:
141 |         caption = BeautifulSoup(caption, features='html.parser').text
142 | 
143 |         # @
144 |         caption = re.sub(r'@[\w\d]+\b', '', caption)
145 | 
146 |         # 31C0—31EF CJK Strokes
147 |         # 31F0—31FF Katakana Phonetic Extensions
148 |         # 3200—32FF Enclosed CJK Letters and Months
149 |         # 3300—33FF CJK Compatibility
150 |         # 3400—4DBF CJK Unified Ideographs Extension A
151 |         # 4DC0—4DFF Yijing Hexagram Symbols
152 |         # 4E00—9FFF CJK Unified Ideographs
153 |         caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
154 |         caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
155 |         caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
156 |         caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
157 |         caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
158 |         caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
159 |         caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
160 |         #######################################################
161 | 
162 |         # all types of dash --> "-"
163 |         caption = re.sub(
164 |             r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
165 |             '-', caption)
166 | 
167 |         # normalize quotation marks to a single standard
168 |         caption = re.sub(r'[`´«»“”¨]', '"', caption)
169 |         caption = re.sub(r'[‘’]', "'", caption)
170 | 
171 |         # &quot;
172 |         caption = re.sub(r'&quot;?', '', caption)
173 |         # &amp
174 |         caption = re.sub(r'&amp', '', caption)
175 | 
176 |         # IP addresses:
177 |         caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
178 | 
179 |         # article ids:
180 |         caption = re.sub(r'\d:\d\d\s+$', '', caption)
181 | 
182 |         # \n
183 |         caption = re.sub(r'\\n', ' ', caption)
184 | 
185 |         # "#123"
186 |         caption = re.sub(r'#\d{1,3}\b', '', caption)
187 |         # "#12345.."
188 |         caption = re.sub(r'#\d{5,}\b', '', caption)
189 |         # "123456.."
190 |         caption = re.sub(r'\b\d{6,}\b', '', caption)
191 |         # filenames:
192 |         caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
193 | 
194 |         #
195 |         caption = re.sub(r'[\"\']{2,}', r'"', caption)  # """AUSVERKAUFT"""
196 |         caption = re.sub(r'[\.]{2,}', r' ', caption)  # """AUSVERKAUFT"""
197 | 
198 |         caption = re.sub(self.bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
199 |         caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "
200 | 
201 |         # this-is-my-cute-cat / this_is_my_cute_cat
202 |         regex2 = re.compile(r'(?:\-|\_)')
203 |         if len(re.findall(regex2, caption)) > 3:
204 |             caption = re.sub(regex2, ' ', caption)
205 | 
206 |         caption = self.basic_clean(caption)
207 | 
208 |         caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
209 |         caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
210 |         caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231
211 | 
212 |         caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
213 |         caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
214 |         caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
215 |         caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
216 |         caption = re.sub(r'\bpage\s+\d+\b', '', caption)
217 | 
218 |         caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...
219 | 220 | caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) 221 | 222 | caption = re.sub(r'\b\s+\:\s+', r': ', caption) 223 | caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) 224 | caption = re.sub(r'\s+', ' ', caption) 225 | 226 | caption.strip() 227 | 228 | caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) 229 | caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) 230 | caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) 231 | caption = re.sub(r'^\.\S+$', '', caption) 232 | 233 | return caption.strip() 234 | -------------------------------------------------------------------------------- /diffusion/model/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 
77 |         This method will perform synchronization to make sure all of the ranks
78 |         maintain the exact same reweighting.
79 |         :param local_ts: an integer Tensor of timesteps.
80 |         :param local_losses: a 1D Tensor of losses.
81 |         """
82 |         batch_sizes = [
83 |             th.tensor([0], dtype=th.int32, device=local_ts.device)
84 |             for _ in range(dist.get_world_size())
85 |         ]
86 |         dist.all_gather(
87 |             batch_sizes,
88 |             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 |         )
90 | 
91 |         # Pad all_gather batches to be the maximum batch size.
92 |         batch_sizes = [x.item() for x in batch_sizes]
93 |         max_bs = max(batch_sizes)
94 | 
95 |         timestep_batches = [th.zeros(max_bs, device=local_ts.device) for _ in batch_sizes]
96 |         loss_batches = [th.zeros(max_bs, device=local_losses.device) for _ in batch_sizes]
97 |         dist.all_gather(timestep_batches, local_ts)
98 |         dist.all_gather(loss_batches, local_losses)
99 |         timesteps = [
100 |             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 |         ]
102 |         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 |         self.update_with_all_losses(timesteps, losses)
104 | 
105 |     @abstractmethod
106 |     def update_with_all_losses(self, ts, losses):
107 |         """
108 |         Update the reweighting using losses from a model.
109 |         Sub-classes should override this method to update the reweighting
110 |         using losses from the model.
111 |         This method directly updates the reweighting without synchronizing
112 |         between workers. It is called by update_with_local_losses from all
113 |         ranks with identical arguments. Thus, it should have deterministic
114 |         behavior to maintain state across workers.
115 |         :param ts: a list of int timesteps.
116 |         :param losses: a list of float losses, one per timestep.
117 |         """
118 | 
119 | 
120 | class LossSecondMomentResampler(LossAwareSampler):
121 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 |         self.diffusion = diffusion
123 |         self.history_per_term = history_per_term
124 |         self.uniform_prob = uniform_prob
125 |         self._loss_history = np.zeros(
126 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 |         )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)
129 | 
130 |     def weights(self):
131 |         if not self._warmed_up():
132 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 |         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 |         weights /= np.sum(weights)
135 |         weights *= 1 - self.uniform_prob
136 |         weights += self.uniform_prob / len(weights)
137 |         return weights
138 | 
139 |     def update_with_all_losses(self, ts, losses):
140 |         for t, loss in zip(ts, losses):
141 |             if self._loss_counts[t] == self.history_per_term:
142 |                 # Shift out the oldest loss term.
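                # (each timestep keeps a fixed-size history of its last `history_per_term`
                # losses, so once full the buffer behaves like a FIFO)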
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /diffusion/sa_sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver 7 | from .model import gaussian_diffusion as gd 8 | 9 | 10 | class SASolverSampler(object): 11 | def __init__(self, model, 12 | noise_schedule="linear", 13 | diffusion_steps=1000, 14 | device='cpu', 15 | ): 16 | super().__init__() 17 | self.model = model 18 | self.device = device 19 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) 20 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) 21 | alphas = 1.0 - betas 22 | self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0))) 23 | 24 | def register_buffer(self, name, attr): 25 | if type(attr) == torch.Tensor and attr.device != torch.device("cuda"): 26 | attr = attr.to(torch.device("cuda")) 27 | setattr(self, name, attr) 28 | 29 | @torch.no_grad() 30 | def sample(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, model_kwargs=None, **kwargs): 31 | if model_kwargs is None: 32 | model_kwargs = {} 33 | if conditioning is not None: 34 | if isinstance(conditioning, dict): 35 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 36 | if cbs != batch_size: 37 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 38 | elif conditioning.shape[0] != batch_size: 39 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 40 | 41 | # sampling 42 | C, H, W = shape 43 | size = (batch_size, C, H, W) 44 | 45 | device = self.device 46 | img = torch.randn(size, device=device) if x_T is None else x_T 47 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 48 | 49 | model_fn = model_wrapper( 50 | self.model, 51 | ns, 52 | model_type="noise", 53 | guidance_type="classifier-free", 54 | condition=conditioning, 55 | unconditional_condition=unconditional_conditioning, 56 | guidance_scale=unconditional_guidance_scale, 57 | model_kwargs=model_kwargs, 58 | ) 59 | 60 | sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction") 61 | 62 | tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0 63 | 64 | x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False) 65 | 66 | return x.to(device), None -------------------------------------------------------------------------------- /diffusion/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HongkLin/TIDE/d0cba53604b3dd9e16e4f81ce24d7f4be3ba0d4e/diffusion/utils/__init__.py -------------------------------------------------------------------------------- 
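A minimal usage sketch for the SASolverSampler defined in diffusion/sa_sampler.py above (illustrative only: `model`, `cond`, `uncond`, and `model_kwargs` stand in for a PixArt-style denoiser and its conditioning inputs and are not provided by this file):

    from diffusion.sa_sampler import SASolverSampler

    sampler = SASolverSampler(model, noise_schedule="linear", diffusion_steps=1000, device="cuda")
    # S sampling steps for a (C, H, W) latent; classifier-free guidance is applied inside model_wrapper
    latents, _ = sampler.sample(
        S=25,
        batch_size=1,
        shape=(4, 128, 128),
        conditioning=cond,
        unconditional_conditioning=uncond,
        unconditional_guidance_scale=4.5,
        model_kwargs=model_kwargs,
    )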
/diffusion/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 | 
5 | from diffusion.utils.logger import get_root_logger
6 | 
7 | 
8 | def save_checkpoint(work_dir,
9 |                     epoch,
10 |                     model,
11 |                     model_ema=None,
12 |                     optimizer=None,
13 |                     lr_scheduler=None,
14 |                     keep_last=False,
15 |                     step=None,
16 |                     ):
17 |     os.makedirs(work_dir, exist_ok=True)
18 |     state_dict = dict(state_dict=model.state_dict())
19 |     if model_ema is not None:
20 |         state_dict['state_dict_ema'] = model_ema.state_dict()
21 |     if optimizer is not None:
22 |         state_dict['optimizer'] = optimizer.state_dict()
23 |     if lr_scheduler is not None:
24 |         state_dict['scheduler'] = lr_scheduler.state_dict()
25 |     if epoch is not None:
26 |         state_dict['epoch'] = epoch
27 |         file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
28 |         if step is not None:
29 |             file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
30 |     logger = get_root_logger()
31 |     torch.save(state_dict, file_path)
32 |     logger.info(f'Saved checkpoint of epoch {epoch} to {file_path}.')
33 |     if keep_last:
34 |         for i in range(epoch):
35 |             previous_ckpt = os.path.join(work_dir, f"epoch_{i}.pth")
36 |             if os.path.exists(previous_ckpt):
37 |                 os.remove(previous_ckpt)
38 | 
39 | 
40 | def load_checkpoint(checkpoint,
41 |                     model,
42 |                     model_ema=None,
43 |                     optimizer=None,
44 |                     lr_scheduler=None,
45 |                     load_ema=False,
46 |                     resume_optimizer=True,
47 |                     resume_lr_scheduler=True
48 |                     ):
49 |     assert isinstance(checkpoint, str)
50 |     ckpt_file = checkpoint
51 |     checkpoint = torch.load(ckpt_file, map_location="cpu")
52 | 
53 |     state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']
54 |     for key in state_dict_keys:
55 |         if key in checkpoint['state_dict']:
56 |             del checkpoint['state_dict'][key]
57 |             if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']:
58 |                 del checkpoint['state_dict_ema'][key]
59 |             break
60 | 
61 |     if load_ema:
62 |         state_dict = checkpoint['state_dict_ema']
63 |     else:
64 |         state_dict = checkpoint.get('state_dict', checkpoint)  # to be compatible with the official checkpoint
65 |     # model.load_state_dict(state_dict)
66 |     missing, unexpect = model.load_state_dict(state_dict, strict=False)
67 |     if model_ema is not None:
68 |         model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
69 |     if optimizer is not None and resume_optimizer:
70 |         optimizer.load_state_dict(checkpoint['optimizer'])
71 |     if lr_scheduler is not None and resume_lr_scheduler:
72 |         lr_scheduler.load_state_dict(checkpoint['scheduler'])
73 |     logger = get_root_logger()
74 |     if optimizer is not None:
75 |         epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*\.pth', ckpt_file).group(1))
76 |         logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
77 |                     f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
78 |         return epoch, missing, unexpect
79 |     logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
80 |     return missing, unexpect
81 | 
--------------------------------------------------------------------------------
/diffusion/utils/data_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
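# Batch samplers that group images with similar aspect ratios into the same batch,
# used for multi-scale / multi-aspect-ratio training.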
2 | import os 3 | from typing import Sequence 4 | from torch.utils.data import BatchSampler, Sampler, Dataset 5 | from random import shuffle, choice 6 | from copy import deepcopy 7 | from diffusion.utils.logger import get_root_logger 8 | 9 | 10 | class AspectRatioBatchSampler(BatchSampler): 11 | """A sampler wrapper for grouping images with similar aspect ratio into a same batch. 12 | 13 | Args: 14 | sampler (Sampler): Base sampler. 15 | dataset (Dataset): Dataset providing data information. 16 | batch_size (int): Size of mini-batch. 17 | drop_last (bool): If ``True``, the sampler will drop the last batch if 18 | its size would be less than ``batch_size``. 19 | aspect_ratios (dict): The predefined aspect ratios. 20 | """ 21 | 22 | def __init__(self, 23 | sampler: Sampler, 24 | dataset: Dataset, 25 | batch_size: int, 26 | aspect_ratios: dict, 27 | drop_last: bool = False, 28 | config=None, 29 | valid_num=0, # take as valid aspect-ratio when sample number >= valid_num 30 | **kwargs) -> None: 31 | if not isinstance(sampler, Sampler): 32 | raise TypeError('sampler should be an instance of ``Sampler``, ' 33 | f'but got {sampler}') 34 | if not isinstance(batch_size, int) or batch_size <= 0: 35 | raise ValueError('batch_size should be a positive integer value, ' 36 | f'but got batch_size={batch_size}') 37 | self.sampler = sampler 38 | self.dataset = dataset 39 | self.batch_size = batch_size 40 | self.aspect_ratios = aspect_ratios 41 | self.drop_last = drop_last 42 | self.ratio_nums_gt = kwargs.get('ratio_nums', None) 43 | self.config = config 44 | assert self.ratio_nums_gt 45 | # buckets for each aspect ratio 46 | self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} 47 | self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num] 48 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) 49 | logger.warning(f"Using valid_num={valid_num} in config file. 
Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}") 50 | 51 | def __iter__(self) -> Sequence[int]: 52 | for idx in self.sampler: 53 | data_info = self.dataset.get_data_info(idx) 54 | height, width = data_info['height'], data_info['width'] 55 | ratio = height / width 56 | # find the closest aspect ratio 57 | closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) 58 | if closest_ratio not in self.current_available_bucket_keys: 59 | continue 60 | bucket = self._aspect_ratio_buckets[closest_ratio] 61 | bucket.append(idx) 62 | # yield a batch of indices in the same aspect ratio group 63 | if len(bucket) == self.batch_size: 64 | yield bucket[:] 65 | del bucket[:] 66 | 67 | # yield the rest data and reset the buckets 68 | for bucket in self._aspect_ratio_buckets.values(): 69 | while len(bucket) > 0: 70 | if len(bucket) <= self.batch_size: 71 | if not self.drop_last: 72 | yield bucket[:] 73 | bucket = [] 74 | else: 75 | yield bucket[:self.batch_size] 76 | bucket = bucket[self.batch_size:] 77 | 78 | 79 | class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler): 80 | def __init__(self, *args, **kwargs): 81 | super().__init__(*args, **kwargs) 82 | # Assign samples to each bucket 83 | self.ratio_nums_gt = kwargs.get('ratio_nums', None) 84 | assert self.ratio_nums_gt 85 | self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()} 86 | self.original_buckets = {} 87 | self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000] 88 | self.all_available_keys = deepcopy(self.current_available_bucket_keys) 89 | self.exhausted_bucket_keys = [] 90 | self.total_batches = len(self.sampler) // self.batch_size 91 | self._aspect_ratio_count = {} 92 | for k in self.all_available_keys: 93 | self._aspect_ratio_count[float(k)] = 0 94 | self.original_buckets[float(k)] = [] 95 | logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log')) 96 | logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}") 97 | 98 | def __iter__(self) -> Sequence[int]: 99 | i = 0 100 | for idx in self.sampler: 101 | data_info = self.dataset.get_data_info(idx) 102 | height, width = data_info['height'], data_info['width'] 103 | ratio = height / width 104 | closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))) 105 | if closest_ratio not in self.all_available_keys: 106 | continue 107 | if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]: 108 | self._aspect_ratio_count[closest_ratio] += 1 109 | self._aspect_ratio_buckets[closest_ratio].append(idx) 110 | self.original_buckets[closest_ratio].append(idx) # Save the original samples for each bucket 111 | if not self.current_available_bucket_keys: 112 | self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, [] 113 | 114 | if closest_ratio not in self.current_available_bucket_keys: 115 | continue 116 | key = closest_ratio 117 | bucket = self._aspect_ratio_buckets[key] 118 | if len(bucket) == self.batch_size: 119 | yield bucket[:self.batch_size] 120 | del bucket[:self.batch_size] 121 | i += 1 122 | self.exhausted_bucket_keys.append(key) 123 | self.current_available_bucket_keys.remove(key) 124 | 125 | for _ in range(self.total_batches - i): 126 | key = choice(self.all_available_keys) 127 | bucket = self._aspect_ratio_buckets[key] 128 | if len(bucket) >= self.batch_size: 129 | 
yield bucket[:self.batch_size] 130 | del bucket[:self.batch_size] 131 | 132 | # If a bucket is exhausted 133 | if not bucket: 134 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:]) 135 | shuffle(self._aspect_ratio_buckets[key]) 136 | else: 137 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:]) 138 | shuffle(self._aspect_ratio_buckets[key]) 139 | -------------------------------------------------------------------------------- /diffusion/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | import os 6 | import pickle 7 | import shutil 8 | 9 | import gc 10 | import mmcv 11 | import torch 12 | import torch.distributed as dist 13 | from mmcv.runner import get_dist_info 14 | 15 | 16 | def is_distributed(): 17 | return get_world_size() > 1 18 | 19 | 20 | def get_world_size(): 21 | if not dist.is_available(): 22 | return 1 23 | return dist.get_world_size() if dist.is_initialized() else 1 24 | 25 | 26 | def get_rank(): 27 | if not dist.is_available(): 28 | return 0 29 | return dist.get_rank() if dist.is_initialized() else 0 30 | 31 | 32 | def get_local_rank(): 33 | if not dist.is_available(): 34 | return 0 35 | return int(os.getenv('LOCAL_RANK', 0)) if dist.is_initialized() else 0 36 | 37 | 38 | def is_master(): 39 | return get_rank() == 0 40 | 41 | 42 | def is_local_master(): 43 | return get_local_rank() == 0 44 | 45 | 46 | def get_local_proc_group(group_size=8): 47 | world_size = get_world_size() 48 | if world_size <= group_size or group_size == 1: 49 | return None 50 | assert world_size % group_size == 0, f'world size ({world_size}) should be evenly divided by group size ({group_size}).' 
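    # create the process groups once and cache them on the function object so that
    # later calls with the same group_size reuse the already-created groups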
51 | process_groups = getattr(get_local_proc_group, 'process_groups', {}) 52 | if group_size not in process_groups: 53 | num_groups = dist.get_world_size() // group_size 54 | groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)] 55 | process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]}) 56 | get_local_proc_group.process_groups = process_groups 57 | 58 | group_idx = get_rank() // group_size 59 | return get_local_proc_group.process_groups.get(group_size)[group_idx] 60 | 61 | 62 | def synchronize(): 63 | """ 64 | Helper function to synchronize (barrier) among all processes when 65 | using distributed training 66 | """ 67 | if not dist.is_available(): 68 | return 69 | if not dist.is_initialized(): 70 | return 71 | world_size = dist.get_world_size() 72 | if world_size == 1: 73 | return 74 | dist.barrier() 75 | 76 | 77 | def all_gather(data): 78 | """ 79 | Run all_gather on arbitrary picklable data (not necessarily tensors) 80 | Args: 81 | data: any picklable object 82 | Returns: 83 | list[data]: list of data gathered from each rank 84 | """ 85 | to_device = torch.device("cuda") 86 | # to_device = torch.device("cpu") 87 | 88 | world_size = get_world_size() 89 | if world_size == 1: 90 | return [data] 91 | 92 | # serialized to a Tensor 93 | buffer = pickle.dumps(data) 94 | storage = torch.ByteStorage.from_buffer(buffer) 95 | tensor = torch.ByteTensor(storage).to(to_device) 96 | 97 | # obtain Tensor size of each rank 98 | local_size = torch.LongTensor([tensor.numel()]).to(to_device) 99 | size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)] 100 | dist.all_gather(size_list, local_size) 101 | size_list = [int(size.item()) for size in size_list] 102 | max_size = max(size_list) 103 | 104 | tensor_list = [ 105 | torch.ByteTensor(size=(max_size,)).to(to_device) for _ in size_list 106 | ] 107 | if local_size != max_size: 108 | padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device) 109 | tensor = torch.cat((tensor, padding), dim=0) 110 | dist.all_gather(tensor_list, tensor) 111 | 112 | data_list = [] 113 | for size, tensor in zip(size_list, tensor_list): 114 | buffer = tensor.cpu().numpy().tobytes()[:size] 115 | data_list.append(pickle.loads(buffer)) 116 | 117 | return data_list 118 | 119 | 120 | def reduce_dict(input_dict, average=True): 121 | """ 122 | Args: 123 | input_dict (dict): all the values will be reduced 124 | average (bool): whether to do average or sum 125 | Reduce the values in the dictionary from all processes so that process with rank 126 | 0 has the averaged results. Returns a dict with the same fields as 127 | input_dict, after reduction. 
128 | """ 129 | world_size = get_world_size() 130 | if world_size < 2: 131 | return input_dict 132 | with torch.no_grad(): 133 | reduced_dict = _extracted_from_reduce_dict_14(input_dict, average, world_size) 134 | return reduced_dict 135 | 136 | 137 | # TODO Rename this here and in `reduce_dict` 138 | def _extracted_from_reduce_dict_14(input_dict, average, world_size): 139 | names = [] 140 | values = [] 141 | # sort the keys so that they are consistent across processes 142 | for k in sorted(input_dict.keys()): 143 | names.append(k) 144 | values.append(input_dict[k]) 145 | values = torch.stack(values, dim=0) 146 | dist.reduce(values, dst=0) 147 | if dist.get_rank() == 0 and average: 148 | # only main process gets accumulated, so only divide by 149 | # world_size in this case 150 | values /= world_size 151 | return dict(zip(names, values)) 152 | 153 | 154 | def broadcast(data, **kwargs): 155 | if get_world_size() == 1: 156 | return data 157 | data = [data] 158 | dist.broadcast_object_list(data, **kwargs) 159 | return data[0] 160 | 161 | 162 | def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True): 163 | rank, world_size = get_dist_info() 164 | if tmpdir is None: 165 | tmpdir = './tmp' 166 | if rank == 0: 167 | mmcv.mkdir_or_exist(tmpdir) 168 | synchronize() 169 | # dump the part result to the dir 170 | mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl')) 171 | synchronize() 172 | if collect_by_master and rank != 0: 173 | return None 174 | # load results of all parts from tmp dir 175 | results = [] 176 | for i in range(world_size): 177 | part_file = os.path.join(tmpdir, f'part_{i}.pkl') 178 | results.append(mmcv.load(part_file)) 179 | if not collect_by_master: 180 | synchronize() 181 | # remove tmp dir 182 | if rank == 0: 183 | shutil.rmtree(tmpdir) 184 | return results 185 | 186 | def all_gather_tensor(tensor, group_size=None, group=None): 187 | if group_size is None: 188 | group_size = get_world_size() 189 | if group_size == 1: 190 | output = [tensor] 191 | else: 192 | output = [torch.zeros_like(tensor) for _ in range(group_size)] 193 | dist.all_gather(output, tensor, group=group) 194 | return output 195 | 196 | 197 | def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None): 198 | world_size = get_world_size() 199 | if world_size == 1: 200 | return feat if concat else [feat] 201 | num_samples, *feat_dim = feat.size() 202 | # padding to max number of samples 203 | feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim)) 204 | feat_padding[:num_samples] = feat 205 | # gather 206 | feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size) 207 | for r, num in enumerate(num_samples_list): 208 | feat_gather[r] = feat_gather[r][:num] 209 | if concat: 210 | feat_gather = torch.cat(feat_gather) 211 | return feat_gather 212 | 213 | 214 | class GatherLayer(torch.autograd.Function): 215 | '''Gather tensors from all process, supporting backward propagation. 
216 | ''' 217 | 218 | @staticmethod 219 | def forward(ctx, input): 220 | ctx.save_for_backward(input) 221 | num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device) 222 | ctx.num_samples_list = all_gather_tensor(num_samples) 223 | output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False) 224 | return tuple(output) 225 | 226 | @staticmethod 227 | def backward(ctx, *grads): # tuple(output)'s grad 228 | input, = ctx.saved_tensors 229 | num_samples_list = ctx.num_samples_list 230 | rank = get_rank() 231 | start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1]) 232 | grads = torch.cat(grads) 233 | if is_distributed(): 234 | dist.all_reduce(grads) 235 | grad_out = torch.zeros_like(input) 236 | grad_out[:] = grads[start:end] 237 | return grad_out, None, None 238 | 239 | 240 | class GatherLayerWithGroup(torch.autograd.Function): 241 | '''Gather tensors from all process, supporting backward propagation. 242 | ''' 243 | 244 | @staticmethod 245 | def forward(ctx, input, group, group_size): 246 | ctx.save_for_backward(input) 247 | ctx.group_size = group_size 248 | output = all_gather_tensor(input, group=group, group_size=group_size) 249 | return tuple(output) 250 | 251 | @staticmethod 252 | def backward(ctx, *grads): # tuple(output)'s grad 253 | input, = ctx.saved_tensors 254 | grads = torch.stack(grads) 255 | if is_distributed(): 256 | dist.all_reduce(grads) 257 | grad_out = torch.zeros_like(input) 258 | grad_out[:] = grads[get_rank() % ctx.group_size] 259 | return grad_out, None, None 260 | 261 | 262 | def gather_layer_with_group(data, group=None, group_size=None): 263 | if group_size is None: 264 | group_size = get_world_size() 265 | return GatherLayer.apply(data, group, group_size) 266 | 267 | from typing import Union 268 | import math 269 | # from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm 270 | 271 | @torch.no_grad() 272 | def clip_grad_norm_( 273 | self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0 274 | ) -> None: 275 | self._lazy_init() 276 | self._wait_for_previous_optim_step() 277 | assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance" 278 | self._assert_state(TrainingState_.IDLE) 279 | 280 | max_norm = float(max_norm) 281 | norm_type = float(norm_type) 282 | # Computes the max norm for this shard's gradients and sync's across workers 283 | local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda() # type: ignore[arg-type] 284 | if norm_type == math.inf: 285 | total_norm = local_norm 286 | dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group) 287 | else: 288 | total_norm = local_norm ** norm_type 289 | dist.all_reduce(total_norm, group=self.process_group) 290 | total_norm = total_norm ** (1.0 / norm_type) 291 | 292 | clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6) 293 | if clip_coef < 1: 294 | # multiply by clip_coef, aka, (max_norm/total_norm). 
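        # (this branch only runs when total_norm exceeds max_norm, i.e. clip_coef < 1)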
295 | for p in self.params_with_grad: 296 | assert p.grad is not None 297 | p.grad.detach().mul_(clip_coef.to(p.grad.device)) 298 | return total_norm 299 | 300 | 301 | def flush(): 302 | gc.collect() 303 | torch.cuda.empty_cache() 304 | -------------------------------------------------------------------------------- /diffusion/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch.distributed as dist 4 | from datetime import datetime 5 | from .dist_utils import is_local_master 6 | from mmcv.utils.logging import logger_initialized 7 | 8 | 9 | def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'): 10 | """Get root logger. 11 | 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str): logger name 17 | Returns: 18 | :obj:`logging.Logger`: The obtained logger 19 | """ 20 | if log_file is None: 21 | log_file = '/dev/null' 22 | return get_logger(name=name, log_file=log_file, log_level=log_level) 23 | 24 | 25 | def get_logger(name, log_file=None, log_level=logging.INFO): 26 | """Initialize and get a logger by name. 27 | 28 | If the logger has not been initialized, this method will initialize the 29 | logger by adding one or two handlers, otherwise the initialized logger will 30 | be directly returned. During initialization, a StreamHandler will always be 31 | added. If `log_file` is specified and the process rank is 0, a FileHandler 32 | will also be added. 33 | 34 | Args: 35 | name (str): Logger name. 36 | log_file (str | None): The log filename. If specified, a FileHandler 37 | will be added to the logger. 38 | log_level (int): The logger level. Note that only the process of 39 | rank 0 is affected, and other processes will set the level to 40 | "Error" thus be silent most of the time. 41 | 42 | Returns: 43 | logging.Logger: The expected logger. 44 | """ 45 | logger = logging.getLogger(name) 46 | logger.propagate = False # disable root logger to avoid duplicate logging 47 | 48 | if name in logger_initialized: 49 | return logger 50 | # handle hierarchical names 51 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 52 | # initialization since it is a child of "a". 
53 |     for logger_name in logger_initialized:
54 |         if name.startswith(logger_name):
55 |             return logger
56 | 
57 |     stream_handler = logging.StreamHandler()
58 |     handlers = [stream_handler]
59 | 
60 |     rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
61 |     # only rank 0 will add a FileHandler
62 |     if rank == 0 and log_file is not None:
63 |         file_handler = logging.FileHandler(log_file, 'w')
64 |         handlers.append(file_handler)
65 | 
66 |     formatter = logging.Formatter(
67 |         '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
68 |     for handler in handlers:
69 |         handler.setFormatter(formatter)
70 |         handler.setLevel(log_level)
71 |         logger.addHandler(handler)
72 | 
73 |     # only rank 0 of each node will print logs
74 |     log_level = log_level if is_local_master() else logging.ERROR
75 |     logger.setLevel(log_level)
76 | 
77 |     logger_initialized[name] = True
78 | 
79 |     return logger
80 | 
81 | def rename_file_with_creation_time(file_path):
82 |     # get the file's creation time
83 |     creation_time = os.path.getctime(file_path)
84 |     creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')
85 | 
86 |     # build the new file name
87 |     dir_name, file_name = os.path.split(file_path)
88 |     name, ext = os.path.splitext(file_name)
89 |     new_file_name = f"{name}_{creation_time_str}{ext}"
90 |     new_file_path = os.path.join(dir_name, new_file_name)
91 | 
92 |     # rename the file
93 |     os.rename(file_path, new_file_path)
94 |     print(f"File renamed to: {new_file_path}")
95 | 
--------------------------------------------------------------------------------
/diffusion/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
2 | from torch.optim import Optimizer
3 | from torch.optim.lr_scheduler import LambdaLR
4 | import math
5 | 
6 | from diffusion.utils.logger import get_root_logger
7 | 
8 | 
9 | def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
10 |     if not config.get('lr_schedule_args', None):
11 |         config.lr_schedule_args = {}
12 |     if config.get('lr_warmup_steps', None):
13 |         config['num_warmup_steps'] = config.get('lr_warmup_steps')  # for compatibility with old version
14 | 
15 |     logger = get_root_logger()
16 |     logger.info(
17 |         f'Lr schedule: {config.lr_schedule}, ' + ",".join(
18 |             [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
19 |     if config.lr_schedule == 'cosine':
20 |         lr_scheduler = get_cosine_schedule_with_warmup(
21 |             optimizer=optimizer,
22 |             **config.lr_schedule_args,
23 |             num_training_steps=(len(train_dataloader) * config.num_epochs),
24 |         )
25 |     elif config.lr_schedule == 'constant':
26 |         lr_scheduler = get_constant_schedule_with_warmup(
27 |             optimizer=optimizer,
28 |             **config.lr_schedule_args,
29 |         )
30 |     elif config.lr_schedule == 'cosine_decay_to_constant':
31 |         assert lr_scale_ratio >= 1
32 |         lr_scheduler = get_cosine_decay_to_constant_with_warmup(
33 |             optimizer=optimizer,
34 |             **config.lr_schedule_args,
35 |             final_lr=1 / lr_scale_ratio,
36 |             num_training_steps=(len(train_dataloader) * config.num_epochs),
37 |         )
38 |     else:
39 |         raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
40 |     return lr_scheduler
41 | 
42 | 
43 | def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
44 |                                              num_warmup_steps: int,
45 |                                              num_training_steps: int,
46 |                                              final_lr: float = 0.0,
47 |                                              num_decay: float = 0.667,
48 |                                              num_cycles: float = 0.5,
49 |                                              last_epoch: int = -1
50 |                                              ):
51 |     """
52 |     Create a schedule with a
cosine annealing lr followed by a constant lr. 53 | 54 | Args: 55 | optimizer ([`~torch.optim.Optimizer`]): 56 | The optimizer for which to schedule the learning rate. 57 | num_warmup_steps (`int`): 58 | The number of steps for the warmup phase. 59 | num_training_steps (`int`): 60 | The number of total training steps. 61 | final_lr (`int`): 62 | The final constant lr after cosine decay. 63 | num_decay (`int`): 64 | The 65 | last_epoch (`int`, *optional*, defaults to -1): 66 | The index of the last epoch when resuming training. 67 | 68 | Return: 69 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 70 | """ 71 | 72 | def lr_lambda(current_step): 73 | if current_step < num_warmup_steps: 74 | return float(current_step) / float(max(1, num_warmup_steps)) 75 | 76 | num_decay_steps = int(num_training_steps * num_decay) 77 | if current_step > num_decay_steps: 78 | return final_lr 79 | 80 | progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps)) 81 | return ( 82 | max( 83 | 0.0, 84 | 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)), 85 | ) 86 | * (1 - final_lr) 87 | ) + final_lr 88 | 89 | return LambdaLR(optimizer, lr_lambda, last_epoch) 90 | -------------------------------------------------------------------------------- /diffusion/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from mmcv import Config 4 | from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \ 5 | OPTIMIZERS 6 | from mmcv.utils import _BatchNorm, _InstanceNorm 7 | from torch.nn import GroupNorm, LayerNorm 8 | 9 | from .logger import get_root_logger 10 | 11 | from typing import Tuple, Optional, Callable 12 | 13 | import torch 14 | from torch.optim.optimizer import Optimizer 15 | 16 | 17 | def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256): 18 | assert rule in ['linear', 'sqrt'] 19 | logger = get_root_logger() 20 | # scale by world size 21 | if rule == 'sqrt': 22 | scale_ratio = math.sqrt(effective_bs / base_batch_size) 23 | elif rule == 'linear': 24 | scale_ratio = effective_bs / base_batch_size 25 | optimizer_cfg['lr'] *= scale_ratio 26 | logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.7f} (using {rule} scaling rule).') 27 | return scale_ratio 28 | 29 | 30 | @OPTIMIZER_BUILDERS.register_module() 31 | class MyOptimizerConstructor(DefaultOptimizerConstructor): 32 | 33 | def add_params(self, params, module, prefix='', is_dcn_module=None): 34 | """Add all parameters of module to the params list. 35 | 36 | The parameters of the given module will be added to the list of param 37 | groups, with specific rules defined by paramwise_cfg. 38 | 39 | Args: 40 | params (list[dict]): A list of param groups, it will be modified 41 | in place. 42 | module (nn.Module): The module to be added. 43 | prefix (str): The prefix of the module 44 | 45 | """ 46 | # get param-wise options 47 | custom_keys = self.paramwise_cfg.get('custom_keys', {}) 48 | # first sort with alphabet order and then sort with reversed len of str 49 | # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) 50 | 51 | bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.) 52 | bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.) 53 | norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.) 
54 | bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) 55 | 56 | # special rules for norm layers and depth-wise conv layers 57 | is_norm = isinstance(module, 58 | (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) 59 | 60 | for name, param in module.named_parameters(recurse=False): 61 | base_lr = self.base_lr 62 | if name == 'bias' and not is_norm and not is_dcn_module: 63 | base_lr *= bias_lr_mult 64 | 65 | # apply weight decay policies 66 | base_wd = self.base_wd 67 | # norm decay 68 | if is_norm: 69 | if self.base_wd is not None: 70 | base_wd *= norm_decay_mult 71 | elif name == 'bias' and not is_dcn_module: 72 | if self.base_wd is not None: 73 | # TODO: current bias_decay_mult will have affect on DCN 74 | base_wd *= bias_decay_mult 75 | 76 | param_group = {'params': [param]} 77 | if not param.requires_grad: 78 | param_group['requires_grad'] = False 79 | params.append(param_group) 80 | continue 81 | if bypass_duplicate and self._is_in(param_group, params): 82 | logger = get_root_logger() 83 | logger.warn(f'{prefix} is duplicate. It is skipped since ' 84 | f'bypass_duplicate={bypass_duplicate}') 85 | continue 86 | # if the parameter match one of the custom keys, ignore other rules 87 | is_custom = False 88 | for key in custom_keys: 89 | scope, key_name = key if isinstance(key, tuple) else (None, key) 90 | if scope is not None and scope not in f'{prefix}': 91 | continue 92 | if key_name in f'{prefix}.{name}': 93 | is_custom = True 94 | if 'lr_mult' in custom_keys[key]: 95 | # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}': 96 | # param_group['lr'] = self.base_lr 97 | # else: 98 | param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult'] 99 | elif 'lr' not in param_group: 100 | param_group['lr'] = base_lr 101 | if self.base_wd is not None: 102 | if 'decay_mult' in custom_keys[key]: 103 | param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult'] 104 | elif 'weight_decay' not in param_group: 105 | param_group['weight_decay'] = base_wd 106 | 107 | if not is_custom: 108 | # bias_lr_mult affects all bias parameters 109 | # except for norm.bias dcn.conv_offset.bias 110 | if base_lr != self.base_lr: 111 | param_group['lr'] = base_lr 112 | if base_wd != self.base_wd: 113 | param_group['weight_decay'] = base_wd 114 | params.append(param_group) 115 | 116 | for child_name, child_mod in module.named_children(): 117 | child_prefix = f'{prefix}.{child_name}' if prefix else child_name 118 | self.add_params( 119 | params, 120 | child_mod, 121 | prefix=child_prefix, 122 | is_dcn_module=is_dcn_module) 123 | 124 | 125 | def build_optimizer(model, optimizer_cfg): 126 | # default parameter-wise config 127 | logger = get_root_logger() 128 | 129 | if hasattr(model, 'module'): 130 | model = model.module 131 | # set optimizer constructor 132 | optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor') 133 | # parameter-wise setting: cancel weight decay for some specific modules 134 | custom_keys = dict() 135 | for name, module in model.named_modules(): 136 | if hasattr(module, 'zero_weight_decay'): 137 | custom_keys |= { 138 | (name, key): dict(decay_mult=0) 139 | for key in module.zero_weight_decay 140 | } 141 | 142 | paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys))) 143 | if given_cfg := optimizer_cfg.get('paramwise_cfg'): 144 | paramwise_cfg.merge_from_dict(dict(cfg=given_cfg)) 145 | optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg 146 | # build optimizer 147 | optimizer = mm_build_optimizer(model, optimizer_cfg) 148 
| 149 | weight_decay_groups = dict() 150 | lr_groups = dict() 151 | for group in optimizer.param_groups: 152 | if not group.get('requires_grad', True): continue 153 | lr_groups.setdefault(group['lr'], []).append(group) 154 | weight_decay_groups.setdefault(group['weight_decay'], []).append(group) 155 | 156 | learnable_count, fix_count = 0, 0 157 | for p in model.parameters(): 158 | if p.requires_grad: 159 | learnable_count += 1 160 | else: 161 | fix_count += 1 162 | fix_info = f"{learnable_count} are learnable, {fix_count} are fix" 163 | lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()]) 164 | wd_info = "Weight decay group: " + ", ".join( 165 | [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()]) 166 | opt_info = f"Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}." 167 | logger.info(opt_info) 168 | 169 | return optimizer 170 | 171 | 172 | @OPTIMIZERS.register_module() 173 | class Lion(Optimizer): 174 | def __init__( 175 | self, 176 | params, 177 | lr: float = 1e-4, 178 | betas: Tuple[float, float] = (0.9, 0.99), 179 | weight_decay: float = 0.0, 180 | ): 181 | assert lr > 0. 182 | assert all(0. <= beta <= 1. for beta in betas) 183 | 184 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 185 | 186 | super().__init__(params, defaults) 187 | 188 | @staticmethod 189 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): 190 | # stepweight decay 191 | p.data.mul_(1 - lr * wd) 192 | 193 | # weight update 194 | update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_() 195 | p.add_(update, alpha=-lr) 196 | 197 | # decay the momentum running average coefficient 198 | exp_avg.lerp_(grad, 1 - beta2) 199 | 200 | @staticmethod 201 | def exists(val): 202 | return val is not None 203 | 204 | @torch.no_grad() 205 | def step( 206 | self, 207 | closure: Optional[Callable] = None 208 | ): 209 | 210 | loss = None 211 | if self.exists(closure): 212 | with torch.enable_grad(): 213 | loss = closure() 214 | 215 | for group in self.param_groups: 216 | for p in filter(lambda p: self.exists(p.grad), group['params']): 217 | 218 | grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \ 219 | self.state[p] 220 | 221 | # init state - exponential moving average of gradient values 222 | if len(state) == 0: 223 | state['exp_avg'] = torch.zeros_like(p) 224 | 225 | exp_avg = state['exp_avg'] 226 | 227 | self.update_fn( 228 | p, 229 | grad, 230 | exp_avg, 231 | lr, 232 | wd, 233 | beta1, 234 | beta2 235 | ) 236 | 237 | return loss 238 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import os 5 | import copy 6 | import argparse 7 | from matplotlib import pyplot as plt 8 | 9 | from peft import PeftModel 10 | 11 | from tide.utils.dataset_palette import USIS10K_COLORS_PALETTE 12 | from tide.utils import mask_postprocess 13 | 14 | from tide.pipeline.tide_transformer import ( 15 | PixArtSpecialAttnTransformerModel, 16 | TIDETransformerModel, 17 | TIDE_TANs, 18 | MiniTransformerModel 19 | ) 20 | from tide.pipeline.pipeline_tide import TIDEPipeline 21 | 22 | cmap = plt.get_cmap('Spectral_r') 23 | 24 | 25 | def colour_and_vis_id_map(id_map, palette, out_path): 26 | id_map = id_map.astype(np.uint8) 27 | id_map -= 1 28 | ids = 
np.unique(id_map) 29 | valid_ids = np.delete(ids, np.where(ids == 255)) 30 | 31 | colour_layout = np.zeros((id_map.shape[0], id_map.shape[1], 3), dtype=np.uint8) 32 | for id in valid_ids: 33 | colour_layout[id_map == id, :] = palette[id].reshape(1, 3) 34 | colour_layout = Image.fromarray(colour_layout) 35 | colour_layout.save(out_path) 36 | 37 | 38 | def main(args): 39 | pretrained_t2i_model = os.path.join(args.model_weights_dir, "PixArt-XL-2-512x512") 40 | mini_transformer_dir = os.path.join(args.model_weights_dir, "TIDE_MiniTransformer") 41 | tide_weight_dir = os.path.join(args.model_weights_dir, "TIDE_r32_64_b4_200k") 42 | 43 | generator = torch.manual_seed(50) 44 | palette = np.array([[0, 0, 0]] + USIS10K_COLORS_PALETTE) 45 | 46 | # model definitions 47 | transformer = PixArtSpecialAttnTransformerModel.from_pretrained( 48 | pretrained_t2i_model, 49 | subfolder="transformer", torch_dtype=torch.float16 50 | ) 51 | transformer.requires_grad_(False) 52 | 53 | depth_transformer = MiniTransformerModel.from_config( 54 | mini_transformer_dir, 55 | subfolder="mini_transformer", 56 | torch_dtype=torch.float16 57 | ) 58 | _state_dict = torch.load( 59 | os.path.join(mini_transformer_dir, 'mini_transformer/diffusion_pytorch_model.pth'), 60 | map_location='cpu' 61 | ) 62 | depth_transformer.load_state_dict(_state_dict) 63 | depth_transformer.half() 64 | depth_transformer.requires_grad_(False) 65 | del _state_dict 66 | 67 | mask_transformer = copy.deepcopy(depth_transformer) 68 | 69 | image_transformer = PeftModel.from_pretrained( 70 | transformer, os.path.join(tide_weight_dir, 'image_transformer_lora') 71 | ) 72 | depth_transformer = PeftModel.from_pretrained( 73 | depth_transformer, os.path.join(tide_weight_dir, 'depth_transformer_lora') 74 | ) 75 | mask_transformer = PeftModel.from_pretrained( 76 | mask_transformer, os.path.join(tide_weight_dir, 'mask_transformer_lora') 77 | ) 78 | 79 | tan_modules = TIDE_TANs.from_pretrained( 80 | os.path.join(tide_weight_dir, 'tan_modules'), torch_dtype=torch.float16 81 | ) 82 | 83 | tide_transformer = TIDETransformerModel(image_transformer, depth_transformer, mask_transformer, tan_modules) 84 | del image_transformer, depth_transformer, mask_transformer, tan_modules 85 | 86 | model = TIDEPipeline.from_pretrained( 87 | pretrained_t2i_model, 88 | transformer=tide_transformer, 89 | torch_dtype=torch.float16, 90 | use_safetensors=True 91 | ).to("cuda") 92 | 93 | # generate image, depth map, semantic mask 94 | target_image, depth_image, mask_image = model( 95 | prompt=args.text_prompt, 96 | num_inference_steps=20, 97 | generator=generator, 98 | guidance_scale=2.0, 99 | ) 100 | target_image = target_image.images[0] 101 | depth_image = depth_image.images[0] 102 | mask_image = mask_image.images[0] 103 | 104 | target_image.save(os.path.join(args.output, "image.jpg")) 105 | 106 | depth_image = np.mean(depth_image, axis=-1) 107 | vis_depth_image = (cmap(depth_image) * 255).astype(np.uint8) 108 | Image.fromarray(vis_depth_image).save(os.path.join(args.output, "depth.png")) 109 | 110 | id_map = mask_postprocess(mask_image, palette) 111 | colour_and_vis_id_map(id_map, palette[1:], os.path.join(args.output, "mask.png")) 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("--model_weights_dir", type=str, default="./model_weights") 117 | parser.add_argument('--text_prompt', type=str, default="A large school of fish swimming in a circle.") 118 | parser.add_argument('--output', type=str, default="./outputs") 119 | args = 
parser.parse_args() 120 | 121 | os.makedirs(args.output, exist_ok=True) 122 | 123 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv==1.7.0 2 | diffusers==0.29.2 3 | timm==0.6.12 4 | accelerate==0.31.0 5 | tensorboard==2.17.0 6 | tensorboardX==2.6.2.2 7 | transformers==4.41.2 8 | ftfy 9 | protobuf==3.20.2 10 | gradio==4.1.1 11 | yapf==0.40.1 12 | bs4 13 | einops 14 | optimum 15 | xformers==0.0.20 16 | Pillow==10.2.0 17 | sentencepiece~=0.1.99 18 | peft==0.10.0 19 | beautifulsoup4 20 | git+https://github.com/facebookresearch/detectron2.git 21 | git+https://github.com/lucasb-eyer/pydensecrf.git 22 | -------------------------------------------------------------------------------- /tide/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HongkLin/TIDE/d0cba53604b3dd9e16e4f81ce24d7f4be3ba0d4e/tide/__init__.py -------------------------------------------------------------------------------- /tide/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from detectron2.config import CfgNode as CN 3 | 4 | 5 | def add_tide_config(cfg): 6 | #TIDE 7 | cfg.TIDE = CN() 8 | cfg.TIDE.SEED = 42 9 | cfg.TIDE.PRETRAINED_DIFFUSION_MODEL_WEIGHT = "pretrained_model/PixArt-XL-2-512x512" 10 | 11 | #instruction 12 | cfg.TIDE.INSTRUCT = CN() 13 | cfg.TIDE.NUM_IMAGE_PER_PROMPT = 1 14 | cfg.TIDE.REGION_FILTER_TH = 100 15 | cfg.TIDE.DISTANCE_MAP_TH = None 16 | cfg.TIDE.RETURN_CRF_REFINE = True 17 | cfg.TIDE.TEMPERATURE = 40 18 | -------------------------------------------------------------------------------- /tide/pipeline/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class MLP(nn.Module): 7 | """Very simple multi-layer perceptron (also called FFN)""" 8 | 9 | def __init__( 10 | self, 11 | input_dim, 12 | hidden_dim, 13 | output_dim, 14 | num_layers, 15 | sigmoid_output: bool = False, 16 | affine_func=nn.Linear, 17 | ): 18 | super().__init__() 19 | self.num_layers = num_layers 20 | h = [hidden_dim] * (num_layers - 1) 21 | self.layers = nn.ModuleList( 22 | affine_func(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 23 | ) 24 | self.sigmoid_output = sigmoid_output 25 | 26 | def forward(self, x: torch.Tensor): 27 | for i, layer in enumerate(self.layers): 28 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 29 | if self.sigmoid_output: 30 | x = F.sigmoid(x) 31 | return x 32 | 33 | class ResidualMLP(nn.Module): 34 | 35 | def __init__( 36 | self, 37 | input_dim, 38 | hidden_dim, 39 | output_dim, 40 | num_mlp, 41 | num_layer_per_mlp, 42 | sigmoid_output: bool = False, 43 | affine_func=nn.Linear, 44 | ): 45 | super().__init__() 46 | self.num_mlp = num_mlp 47 | self.in2hidden_dim = affine_func(input_dim, hidden_dim) 48 | self.hidden2out_dim = affine_func(hidden_dim, output_dim) 49 | self.mlp_list = nn.ModuleList( 50 | MLP( 51 | hidden_dim, 52 | hidden_dim, 53 | hidden_dim, 54 | num_layer_per_mlp, 55 | affine_func=affine_func, 56 | ) for _ in range(num_mlp) 57 | ) 58 | self.sigmoid_output = sigmoid_output 59 | 60 | def forward(self, x: torch.Tensor): 61 | x = self.in2hidden_dim(x) 62 | for mlp in self.mlp_list: 63 | out = mlp(x) 64 | x = x + out 65 | out = 
self.hidden2out_dim(x)
66 |         return out
--------------------------------------------------------------------------------
/tide/pipeline/transformer_attentions.py:
--------------------------------------------------------------------------------
1 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2 | # See the License for the specific language governing permissions and
3 | # limitations under the License.
4 | from torch import einsum
5 | from typing import Callable, List, Optional, Union
6 | 
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | 
11 | from einops import rearrange, repeat
12 | 
13 | from diffusers.utils.import_utils import is_xformers_available
14 | from diffusers.models.attention_processor import Attention, SpatialNorm, AttnProcessor2_0, AttnProcessor
15 | from diffusers.models.attention import logger
16 | from diffusers.utils import deprecate
17 | 
18 | class ILSAttnProcessor:
19 |     r"""
20 |     Attention processor that computes scaled dot-product attention explicitly so the cross-attention map can be returned, and that can reuse a map passed in via `cross_attn_map` (requires PyTorch 2.0).
21 |     """
22 | 
23 |     def __init__(self):
24 |         if not hasattr(F, "scaled_dot_product_attention"):
25 |             raise ImportError("ILSAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
26 | 
27 |     def __call__(
28 |         self,
29 |         attn,
30 |         hidden_states,
31 |         encoder_hidden_states,
32 |         attention_mask: Optional[torch.Tensor] = None,
33 |         cross_attn_map: Optional[torch.Tensor] = None,
34 |         temb: Optional[torch.Tensor] = None,
35 |         *args,
36 |         **kwargs,
37 |     ) -> torch.Tensor:
38 |         context = encoder_hidden_states
39 |         if len(args) > 0 or kwargs.get("scale", None) is not None:
40 |             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
41 | deprecate("scale", "1.0.0", deprecation_message) 42 | 43 | residual = hidden_states 44 | if attn.spatial_norm is not None: 45 | hidden_states = attn.spatial_norm(hidden_states, temb) 46 | 47 | input_ndim = hidden_states.ndim 48 | 49 | if input_ndim == 4: 50 | batch_size, channel, height, width = hidden_states.shape 51 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 52 | 53 | batch_size, sequence_length, _ = ( 54 | hidden_states.shape if context is None else context.shape 55 | ) 56 | 57 | if attention_mask is not None: 58 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 59 | # scaled_dot_product_attention expects attention_mask shape to be 60 | # (batch, heads, source_length, target_length) 61 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 62 | 63 | if attn.group_norm is not None: 64 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 65 | 66 | query = attn.to_q(hidden_states) 67 | 68 | if context is None: 69 | context = hidden_states 70 | elif attn.norm_cross: 71 | context = attn.norm_context(context) 72 | 73 | key = attn.to_k(context) 74 | value = attn.to_v(context) 75 | 76 | inner_dim = key.shape[-1] 77 | head_dim = inner_dim // attn.heads 78 | 79 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 80 | 81 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 82 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 83 | 84 | if cross_attn_map is not None: 85 | hidden_states = einsum('b h l n, b h n c -> b h l c', cross_attn_map, value) 86 | else: 87 | hidden_states, cross_attn_map = self.scaled_dot_product_attention( 88 | attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0 89 | ) 90 | 91 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 92 | hidden_states = hidden_states.to(query.dtype) 93 | 94 | # linear proj 95 | hidden_states = attn.to_out[0](hidden_states) 96 | # dropout 97 | hidden_states = attn.to_out[1](hidden_states) 98 | 99 | if input_ndim == 4: 100 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 101 | 102 | if attn.residual_connection: 103 | hidden_states = hidden_states + residual 104 | 105 | hidden_states = hidden_states / attn.rescale_output_factor 106 | 107 | return hidden_states, cross_attn_map 108 | 109 | def scaled_dot_product_attention(self, attn, query, key, value, attn_mask=None, dropout_p=0.0): 110 | query = query * attn.scale 111 | attn_map = einsum('b h l c, b h n c -> b h l n', query, key) 112 | 113 | if attn_mask is not None: 114 | attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype == torch.bool else attn_mask 115 | attn_map += attn_mask 116 | 117 | attn_weight = F.softmax(attn_map, dim=-1) 118 | 119 | if dropout_p > 0.0: 120 | attn_weight = F.dropout(attn_weight, p=dropout_p) 121 | 122 | hidden_states = einsum('b h l n, b h n c -> b h l c', attn_weight, value) 123 | 124 | return hidden_states, attn_weight 125 | -------------------------------------------------------------------------------- /tide/pipeline/transformer_blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.models.attention import logger, _chunked_feed_forward, BasicTransformerBlock 7 | from 
diffusers.models.attention_processor import Attention 8 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous 9 | from .transformer_attentions import ILSAttnProcessor 10 | 11 | class BasicTransformerTIDEBlock(BasicTransformerBlock): 12 | def __init__( 13 | self, 14 | dim: int, 15 | num_attention_heads: int, 16 | attention_head_dim: int, 17 | dropout=0.0, 18 | cross_attention_dim: Optional[int] = None, 19 | activation_fn: str = "geglu", 20 | num_embeds_ada_norm: Optional[int] = None, 21 | attention_bias: bool = False, 22 | only_cross_attention: bool = False, 23 | double_self_attention: bool = False, 24 | upcast_attention: bool = False, 25 | norm_elementwise_affine: bool = True, 26 | norm_type: str = "layer_norm", 27 | norm_eps: float = 1e-5, 28 | final_dropout: bool = False, 29 | attention_type: str = "default", 30 | positional_embeddings: Optional[str] = None, 31 | num_positional_embeddings: Optional[int] = None, 32 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 33 | ada_norm_bias: Optional[int] = None, 34 | ff_inner_dim: Optional[int] = None, 35 | ff_bias: bool = True, 36 | attention_out_bias: bool = True, 37 | ): 38 | super().__init__( 39 | dim, 40 | num_attention_heads, 41 | attention_head_dim, 42 | dropout, 43 | cross_attention_dim, 44 | activation_fn, 45 | num_embeds_ada_norm, 46 | attention_bias, 47 | only_cross_attention, 48 | double_self_attention, 49 | upcast_attention, 50 | norm_elementwise_affine, 51 | norm_type, 52 | norm_eps, 53 | final_dropout, 54 | attention_type, 55 | positional_embeddings, 56 | num_positional_embeddings, 57 | ada_norm_continous_conditioning_embedding_dim, 58 | ada_norm_bias, 59 | ff_inner_dim, 60 | ff_bias, 61 | attention_out_bias, 62 | ) 63 | 64 | 65 | # 2. Cross-Attn 66 | if cross_attention_dim is not None or double_self_attention: 67 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 68 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 69 | # the second cross attention block. 
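            # Note: the parent class already created `norm2`/`attn2`; they are rebuilt below so that
            # `attn2` uses ILSAttnProcessor, which returns the text cross-attention map and can reuse
            # a map handed in through `cross_attn_map` (see `forward`), presumably so one branch's map
            # can be shared with the other TIDE branches.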
70 | if norm_type == "ada_norm": 71 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 72 | elif norm_type == "ada_norm_continuous": 73 | self.norm2 = AdaLayerNormContinuous( 74 | dim, 75 | ada_norm_continous_conditioning_embedding_dim, 76 | norm_elementwise_affine, 77 | norm_eps, 78 | ada_norm_bias, 79 | "rms_norm", 80 | ) 81 | else: 82 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 83 | 84 | self.attn2 = Attention( 85 | query_dim=dim, 86 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 87 | heads=num_attention_heads, 88 | dim_head=attention_head_dim, 89 | dropout=dropout, 90 | bias=attention_bias, 91 | upcast_attention=upcast_attention, 92 | out_bias=attention_out_bias, 93 | processor=ILSAttnProcessor() 94 | ) # is self-attn if encoder_hidden_states is none 95 | else: 96 | self.norm2 = None 97 | self.attn2 = None 98 | 99 | 100 | def forward( 101 | self, 102 | hidden_states: torch.Tensor, 103 | attention_mask: Optional[torch.Tensor] = None, 104 | encoder_hidden_states: Optional[torch.Tensor] = None, 105 | encoder_attention_mask: Optional[torch.Tensor] = None, 106 | timestep: Optional[torch.LongTensor] = None, 107 | cross_attn_map: Optional[torch.Tensor] = None, 108 | cross_attention_kwargs: Dict[str, Any] = None, 109 | class_labels: Optional[torch.LongTensor] = None, 110 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 111 | ) -> torch.Tensor: 112 | if cross_attention_kwargs is not None: 113 | if cross_attention_kwargs.get("scale", None) is not None: 114 | logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") 115 | 116 | # Notice that normalization is always applied before the real computation in the following blocks. 117 | # 0. Self-Attention 118 | batch_size = hidden_states.shape[0] 119 | 120 | if self.norm_type == "ada_norm": 121 | norm_hidden_states = self.norm1(hidden_states, timestep) 122 | elif self.norm_type == "ada_norm_zero": 123 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 124 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 125 | ) 126 | elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: 127 | norm_hidden_states = self.norm1(hidden_states) 128 | elif self.norm_type == "ada_norm_continuous": 129 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 130 | elif self.norm_type == "ada_norm_single": 131 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 132 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 133 | ).chunk(6, dim=1) 134 | norm_hidden_states = self.norm1(hidden_states) 135 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 136 | norm_hidden_states = norm_hidden_states.squeeze(1) 137 | else: 138 | raise ValueError("Incorrect norm used") 139 | 140 | if self.pos_embed is not None: 141 | norm_hidden_states = self.pos_embed(norm_hidden_states) 142 | 143 | # 1. 
Prepare GLIGEN inputs 144 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 145 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 146 | 147 | attn_output = self.attn1( 148 | norm_hidden_states, 149 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 150 | attention_mask=attention_mask, 151 | # **cross_attention_kwargs, 152 | ) 153 | if self.norm_type == "ada_norm_zero": 154 | attn_output = gate_msa.unsqueeze(1) * attn_output 155 | elif self.norm_type == "ada_norm_single": 156 | attn_output = gate_msa * attn_output 157 | 158 | hidden_states = attn_output + hidden_states 159 | if hidden_states.ndim == 4: 160 | hidden_states = hidden_states.squeeze(1) 161 | 162 | # 1.2 GLIGEN Control 163 | if gligen_kwargs is not None: 164 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 165 | 166 | # 3. Cross-Attention 167 | if self.attn2 is not None: 168 | if self.norm_type == "ada_norm": 169 | norm_hidden_states = self.norm2(hidden_states, timestep) 170 | elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: 171 | norm_hidden_states = self.norm2(hidden_states) 172 | elif self.norm_type == "ada_norm_single": 173 | # For PixArt norm2 isn't applied here: 174 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 175 | norm_hidden_states = hidden_states 176 | elif self.norm_type == "ada_norm_continuous": 177 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 178 | else: 179 | raise ValueError("Incorrect norm") 180 | 181 | if self.pos_embed is not None and self.norm_type != "ada_norm_single": 182 | norm_hidden_states = self.pos_embed(norm_hidden_states) 183 | 184 | attn_output, cross_attn_map = self.attn2( 185 | norm_hidden_states, 186 | encoder_hidden_states=encoder_hidden_states, 187 | attention_mask=encoder_attention_mask, 188 | cross_attn_map=cross_attn_map, 189 | **cross_attention_kwargs, 190 | ) 191 | hidden_states = attn_output + hidden_states 192 | 193 | # 4. 
Feed-forward 194 | # i2vgen doesn't have this norm 🤷‍♂️ 195 | if self.norm_type == "ada_norm_continuous": 196 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 197 | elif not self.norm_type == "ada_norm_single": 198 | norm_hidden_states = self.norm3(hidden_states) 199 | 200 | if self.norm_type == "ada_norm_zero": 201 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 202 | 203 | if self.norm_type == "ada_norm_single": 204 | norm_hidden_states = self.norm2(hidden_states) 205 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 206 | 207 | if self._chunk_size is not None: 208 | # "feed_forward_chunk_size" can be used to save memory 209 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 210 | else: 211 | ff_output = self.ff(norm_hidden_states) 212 | 213 | if self.norm_type == "ada_norm_zero": 214 | ff_output = gate_mlp.unsqueeze(1) * ff_output 215 | elif self.norm_type == "ada_norm_single": 216 | ff_output = gate_mlp * ff_output 217 | 218 | hidden_states = ff_output + hidden_states 219 | if hidden_states.ndim == 4: 220 | hidden_states = hidden_states.squeeze(1) 221 | 222 | return hidden_states, cross_attn_map 223 | -------------------------------------------------------------------------------- /tide/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_process import IDColour 2 | from .mask_process import mask_postprocess -------------------------------------------------------------------------------- /tide/utils/dataset_palette.py: -------------------------------------------------------------------------------- 1 | USIS10K_CLASS_NAMES = ( 2 | "wrecks", 3 | "fish", 4 | "reefs", 5 | "aquatic plants", 6 | "human divers", 7 | "robots", 8 | "sea-floor" 9 | ) 10 | USIS10K_COLORS_PALETTE = [[0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 255, 255]] -------------------------------------------------------------------------------- /tide/utils/mask_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import cv2 5 | import pydensecrf.densecrf as dcrf 6 | from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian 7 | 8 | 9 | 10 | def apply_crf(probability_matrix, num_classes): 11 | h, w = probability_matrix.shape[-2:] 12 | 13 | dcrf_model = dcrf.DenseCRF(h * w, num_classes) 14 | 15 | unary = unary_from_softmax(probability_matrix) 16 | unary = np.ascontiguousarray(unary) 17 | dcrf_model.setUnaryEnergy(unary) 18 | 19 | feats = create_pairwise_gaussian(sdims=(10, 10), shape=(h, w)) 20 | 21 | dcrf_model.addPairwiseEnergy(feats, compat=3, 22 | kernel=dcrf.DIAG_KERNEL, 23 | normalization=dcrf.NORMALIZE_SYMMETRIC) 24 | 25 | Q = dcrf_model.inference(5) 26 | 27 | MAP = np.argmax(Q, axis=0).reshape((h, w)) 28 | return MAP 29 | 30 | 31 | def colour_and_vis_id_map(id_map, palette, out_path): 32 | palette = np.array(palette) 33 | 34 | id_map = id_map.astype(np.uint8) 35 | id_map -= 1 36 | ids = np.unique(id_map) 37 | valid_ids = np.delete(ids, np.where(ids == 255)) 38 | 39 | colour_layout = np.zeros((id_map.shape[0], id_map.shape[1], 3), dtype=np.uint8) 40 | for id in valid_ids: 41 | colour_layout[id_map == id, :] = palette[id].reshape(1, 3) 42 | colour_layout = Image.fromarray(colour_layout) 43 | colour_layout.save(out_path) 44 | 45 | 46 | def 
compute_distance(matrix_A, matrix_B): 47 | matrix_A = matrix_A[:, np.newaxis, :] 48 | 49 | diff_squared = (matrix_A - matrix_B) ** 2 50 | 51 | distances = np.sqrt(np.sum(diff_squared, axis=2)) 52 | return distances 53 | 54 | 55 | def rgb2id(matrix_A, matrix_B, th=None, return_crf_refine=True, t=40): 56 | h, w = matrix_A.shape[:2] 57 | matrix_A = matrix_A.reshape(-1, 3) 58 | 59 | distances = compute_distance(matrix_A, matrix_B) 60 | min_distance_indices = np.argmin(distances, axis=1) 61 | 62 | if th is not None: 63 | min_distances = np.min(distances, axis=1) 64 | min_distance_indices[min_distances > th] = len(matrix_B) - 1 65 | 66 | crf_results = None 67 | if return_crf_refine: 68 | prob = 1 - distances / distances.sum(1)[:, None] 69 | prob = torch.tensor(prob).reshape(h, w, -1).permute(2, 0, 1) 70 | prob = torch.nn.functional.softmax(t * prob, dim=0).numpy() 71 | crf_results = apply_crf(prob, num_classes=len(matrix_B)) 72 | return min_distance_indices.reshape(w, h), crf_results 73 | 74 | 75 | def filter_small_regions(arr, min_area=100, ignore_class=0): 76 | unique_labels = np.unique(arr) 77 | result = arr.copy() 78 | 79 | for label in unique_labels: 80 | if label == ignore_class: 81 | continue 82 | binary_image = np.where(arr == label, 1, 0).astype('uint8') 83 | num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_image, connectivity=8) 84 | for i in range(1, num_labels): 85 | if stats[i, cv2.CC_STAT_AREA] < min_area: 86 | result[labels == i] = ignore_class 87 | 88 | return result 89 | 90 | 91 | def mask_postprocess(mask_image, palette, ignore_class=0): 92 | mask_image = np.array(mask_image) 93 | _, refine_id_map = rgb2id(mask_image, palette) 94 | id_map = filter_small_regions(refine_id_map) 95 | id_map[id_map == 7] = ignore_class # ignore sea-floor 96 | return id_map -------------------------------------------------------------------------------- /tide/utils/prompt_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .dataset_palette import ( 4 | USIS10K_COLORS_PALETTE, 5 | ) 6 | 7 | class IDColour: 8 | def __init__(self, max_id_num=8, color_palette=USIS10K_COLORS_PALETTE): 9 | palette = color_palette 10 | self.palette = np.array(palette) 11 | self.background_id = 255 12 | 13 | def __call__(self, id_map): 14 | id_map = np.array(id_map) - 1 15 | ids = np.unique(id_map) 16 | valid_ids = np.delete(ids, np.where(ids == self.background_id)) 17 | 18 | mask_pixel_values = np.zeros((id_map.shape[0], id_map.shape[1], 3), dtype=np.uint8) 19 | for id in valid_ids: 20 | mask_pixel_values[id_map == id, :] = self.palette[id].reshape(1, 3) 21 | return mask_pixel_values 22 | 23 | def generate_rgb_values(self, num_values): 24 | step = (255 ** 3) / (num_values - 1) 25 | 26 | rgb_values = [] 27 | 28 | for i in range(num_values): 29 | value_i = round(i * step) 30 | red_i = value_i // (256 ** 2) 31 | green_i = (value_i % (256 ** 2)) // 256 32 | blue_i = value_i % 256 33 | rgb_values.append((red_i, green_i, blue_i)) 34 | 35 | return rgb_values --------------------------------------------------------------------------------
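A minimal usage sketch of the mask utilities above; the toy ID map, its size, and the chosen class ids are illustrative only, while IDColour, mask_postprocess and USIS10K_COLORS_PALETTE are the objects defined in this repository:

    import numpy as np
    from tide.utils import IDColour, mask_postprocess
    from tide.utils.dataset_palette import USIS10K_COLORS_PALETTE

    # Toy ID map: 0 = unlabeled, 1..7 follow USIS10K_CLASS_NAMES ("wrecks", "fish", ..., "sea-floor").
    id_map = np.zeros((64, 64), dtype=np.uint8)
    id_map[8:32, 8:32] = 2      # "fish"
    id_map[40:60, 10:50] = 7    # "sea-floor"

    # IDColour paints (id - 1) with the 7-colour class palette and leaves id 0 black.
    rgb_mask = IDColour()(id_map)  # (64, 64, 3) uint8 colour mask

    # mask_postprocess maps RGB back to class ids (nearest palette colour + CRF refinement),
    # removes connected regions smaller than 100 px and zeroes out the sea-floor class.
    palette = np.array([[0, 0, 0]] + USIS10K_COLORS_PALETTE)
    recovered_ids = mask_postprocess(rgb_mask, palette)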