├── .gitignore
├── LICENSE
├── README.md
├── dataset.py
├── generate.py
├── generate_t2i.py
├── loss.py
├── models
│   ├── clip_vit.py
│   ├── jepa.py
│   ├── mae_vit.py
│   ├── mmdit.py
│   ├── mocov3_vit.py
│   └── sit.py
├── preprocessing
│   ├── README.md
│   ├── dataset_tools.py
│   ├── dnnlib
│   │   ├── __init__.py
│   │   └── util.py
│   ├── encoders.py
│   └── torch_utils
│       ├── __init__.py
│       ├── distributed.py
│       ├── misc.py
│       ├── persistence.py
│       └── training_stats.py
├── requirements.txt
├── samplers.py
├── samplers_t2i.py
├── train.py
├── train_t2i.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sihyun Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Representation Alignment for Generation:
Training Diffusion Transformers Is Easier Than You Think 2 |

3 | 4 | [![arXiv](https://img.shields.io/badge/arXiv%20paper-2410.06940-b31b1b.svg)](https://arxiv.org/abs/2410.06940)  5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/representation-alignment-for-generation/image-generation-on-imagenet-256x256)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=representation-alignment-for-generation) 6 | 7 |
8 | Sihyun Yu1·   9 | Sangkyung Kwak1·   10 | Huiwon Jang1·   11 | Jongheon Jeong2 12 |
13 | Jonathan Huang3·   14 | Jinwoo Shin1*·   15 | Saining Xie4*
16 | 1 KAIST   2Korea University   3Scaled Foundations   4New York University  
17 | *Equal Advising  
18 |
19 |

[project page] [arXiv]

20 |
21 | 22 | Summary: We propose REPresentation Alignment (REPA), a method that aligns noisy input states in diffusion models with representations from pretrained visual encoders. This significantly improves training efficiency and generation quality. REPA speeds up SiT training by 17.5x and achieves a state-of-the-art FID of 1.42. 23 | 24 | ### 1. Environment setup 25 | 26 | ```bash 27 | conda create -n repa python=3.9 -y 28 | conda activate repa 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | ### 2. Dataset 33 | 34 | #### Dataset download 35 | 36 | Currently, we provide experiments for [ImageNet](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data). You can place the data wherever you want and specify its location via the `--data-dir` argument in the training scripts. Please refer to our [preprocessing guide](https://github.com/sihyun-yu/REPA/tree/master/preprocessing). 37 | 38 | ### 3. Training 39 | 40 | ```bash 41 | accelerate launch train.py \ 42 | --report-to="wandb" \ 43 | --allow-tf32 \ 44 | --mixed-precision="fp16" \ 45 | --seed=0 \ 46 | --path-type="linear" \ 47 | --prediction="v" \ 48 | --weighting="uniform" \ 49 | --model="SiT-XL/2" \ 50 | --enc-type="dinov2-vit-b" \ 51 | --proj-coeff=0.5 \ 52 | --encoder-depth=8 \ 53 | --output-dir="exps" \ 54 | --exp-name="linear-dinov2-b-enc8" \ 55 | --data-dir=[YOUR_DATA_PATH] 56 | ``` 57 | 58 | This script automatically creates a folder under `exps` to save logs and checkpoints. You can adjust the following options: 59 | 60 | - `--model`: `[SiT-B/2, SiT-L/2, SiT-XL/2]` 61 | - `--enc-type`: `[dinov2-vit-b, dinov2-vit-l, dinov2-vit-g, dinov1-vit-b, mocov3-vit-b, mocov3-vit-l, clip-vit-L, jepa-vit-h, mae-vit-l]` 62 | - `--proj-coeff`: Any value larger than 0 (this weights the alignment term sketched below) 63 | - `--encoder-depth`: Any value between 1 and the depth of the model 64 | - `--output-dir`: Any directory where you want to save checkpoints and logs 65 | - `--exp-name`: Any string name (the folder will be created under `output-dir`) 66 | 67 | DINOv2 models are downloaded automatically from `torch.hub`, and CLIP models are downloaded automatically from the CLIP repository. For the other pretrained visual encoders, please download the model weights from the links below and place them at the following paths: 68 | 69 | - `dinov1`: Download the ViT-B/16 model from the [`DINO`](https://github.com/facebookresearch/dino) repository and place it as `./ckpts/dinov1_vitb.pth` 70 | - `mocov3`: Download the ViT-B/16 or ViT-L/16 model from the [`RCG`](https://github.com/LTH14/rcg) repository and place them as `./ckpts/mocov3_vitb.pth` or `./ckpts/mocov3_vitl.pth` 71 | - `jepa`: Download the ViT-H/14 model (ImageNet-1K) from the [`I-JEPA`](https://github.com/facebookresearch/ijepa) repository and place it as `./ckpts/ijepa_vith.pth` 72 | - `mae`: Download the ViT-L model from the [`MAE`](https://github.com/facebookresearch/mae) repository and place it as `./ckpts/mae_vitl.pth` 73 | 74 | **[12/17/2024]**: We also support training at 512x512 resolution (ImageNet) and text-to-image generation on MS-COCO.
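For reference, the alignment term weighted by `--proj-coeff` (referenced in the options above) is a negative cosine similarity between the projected hidden states of the chosen SiT block and the frozen encoder's patch features, averaged over patch tokens; `loss.py` contains the actual implementation. The sketch below is illustrative only; the function name and the tensor shapes are assumptions rather than the repository's API.

```python
import torch
import torch.nn.functional as F

def repa_alignment_loss(z_tilde: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Negative cosine similarity between projected SiT hidden states (z_tilde)
    and frozen-encoder patch features (z), averaged over tokens and the batch.
    Shapes (B, N, D) are illustrative; see loss.py for the exact computation."""
    z_tilde = F.normalize(z_tilde, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(z * z_tilde).sum(dim=-1).mean()

# Toy usage: the training objective is denoising_loss + proj_coeff * alignment_loss.
B, N, D = 4, 256, 768              # batch size, patch tokens, feature dim (e.g. DINOv2-B)
z_tilde = torch.randn(B, N, D)     # projector output taken after SiT block --encoder-depth
z = torch.randn(B, N, D)           # pretrained-encoder features of the clean image
print(repa_alignment_loss(z_tilde, z))   # scalar; scaled by --proj-coeff during training
```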
75 | 76 | For ImageNet 512x512, please use the following script: 77 | 78 | ```bash 79 | accelerate launch train.py \ 80 | --report-to="wandb" \ 81 | --allow-tf32 \ 82 | --mixed-precision="fp16" \ 83 | --seed=0 \ 84 | --path-type="linear" \ 85 | --prediction="v" \ 86 | --weighting="uniform" \ 87 | --model="SiT-XL/2" \ 88 | --enc-type="dinov2-vit-b" \ 89 | --proj-coeff=0.5 \ 90 | --encoder-depth=8 \ 91 | --output-dir="exps" \ 92 | --exp-name="linear-dinov2-b-enc8-in512" \ 93 | --resolution=512 \ 94 | --data-dir=[YOUR_DATA_PATH] 95 | ``` 96 | 97 | You also need to preprocess the data again, resizing each image to 512x512 resolution and encoding it into 64x64 latent vectors (using the stable-diffusion VAE). This script is also provided in our preprocessing guide. 98 | 99 | For text-to-image generation, please follow the data preprocessing protocol in [U-ViT](https://github.com/baofff/U-ViT/tree/main/scripts) before launching experiments. After that, you should be able to launch an experiment with the following script: 100 | 101 | ```bash 102 | accelerate launch train_t2i.py \ 103 | --report-to="wandb" \ 104 | --allow-tf32 \ 105 | --mixed-precision="fp16" \ 106 | --seed=0 \ 107 | --path-type="linear" \ 108 | --prediction="v" \ 109 | --weighting="uniform" \ 110 | --enc-type="dinov2-vit-b" \ 111 | --proj-coeff=0.5 \ 112 | --encoder-depth=8 \ 113 | --output-dir="exps" \ 114 | --exp-name="t2i_repa" \ 115 | --data-dir=[YOUR_DATA_PATH] 116 | ``` 117 | 118 | 119 | ### 4. Evaluation 120 | 121 | You can generate images (and an `.npz` file that can be used with the [ADM evaluation](https://github.com/openai/guided-diffusion/tree/main/evaluations) suite) with the following script: 122 | 123 | ```bash 124 | torchrun --nnodes=1 --nproc_per_node=8 generate.py \ 125 | --model SiT-XL/2 \ 126 | --num-fid-samples 50000 \ 127 | --ckpt YOUR_CHECKPOINT_PATH \ 128 | --path-type=linear \ 129 | --encoder-depth=8 \ 130 | --projector-embed-dims=768 \ 131 | --per-proc-batch-size=64 \ 132 | --mode=sde \ 133 | --num-steps=250 \ 134 | --cfg-scale=1.8 \ 135 | --guidance-high=0.7 136 | ``` 137 | 138 | We also provide the SiT-XL/2 checkpoint (trained for 4M iterations) used in the final evaluation. It will be downloaded automatically if you do not specify `--ckpt`. 139 | 140 | ### Note 141 | 142 | It is possible that this code does not exactly reproduce the results reported in the paper due to potential human errors during the preparation and cleaning of the code for release. If you encounter any difficulties reproducing our findings, please do not hesitate to let us know. We will also make an effort to run sanity-check experiments in the near future. 143 | 144 | ## Acknowledgement 145 | 146 | This code is mainly built upon the [DiT](https://github.com/facebookresearch/DiT), [SiT](https://github.com/willisma/SiT), [edm2](https://github.com/NVlabs/edm2), and [RCG](https://github.com/LTH14/rcg) repositories.\ 147 | We also appreciate [Kyungmin Lee](https://kyungmnlee.github.io/) for providing the initial version of the implementation.
148 | 149 | ## BibTeX 150 | 151 | ```bibtex 152 | @inproceedings{yu2025repa, 153 | title={Representation Alignment for Generation: Training Diffusion Transformers Is Easier Than You Think}, 154 | author={Sihyun Yu and Sangkyung Kwak and Huiwon Jang and Jongheon Jeong and Jonathan Huang and Jinwoo Shin and Saining Xie}, 155 | year={2025}, 156 | booktitle={International Conference on Learning Representations}, 157 | } 158 | ``` 159 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import glob 5 | import torch 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | 9 | from PIL import Image 10 | import PIL.Image 11 | try: 12 | import pyspng 13 | except ImportError: 14 | pyspng = None 15 | 16 | 17 | class CustomDataset(Dataset): 18 | def __init__(self, data_dir): 19 | PIL.Image.init() 20 | supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'} 21 | 22 | self.images_dir = os.path.join(data_dir, 'images') 23 | self.features_dir = os.path.join(data_dir, 'vae-sd') 24 | 25 | # images 26 | self._image_fnames = { 27 | os.path.relpath(os.path.join(root, fname), start=self.images_dir) 28 | for root, _dirs, files in os.walk(self.images_dir) for fname in files 29 | } 30 | self.image_fnames = sorted( 31 | fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext 32 | ) 33 | # features 34 | self._feature_fnames = { 35 | os.path.relpath(os.path.join(root, fname), start=self.features_dir) 36 | for root, _dirs, files in os.walk(self.features_dir) for fname in files 37 | } 38 | self.feature_fnames = sorted( 39 | fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext 40 | ) 41 | # labels 42 | fname = 'dataset.json' 43 | with open(os.path.join(self.features_dir, fname), 'rb') as f: 44 | labels = json.load(f)['labels'] 45 | labels = dict(labels) 46 | labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames] 47 | labels = np.array(labels) 48 | self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 49 | 50 | 51 | def _file_ext(self, fname): 52 | return os.path.splitext(fname)[1].lower() 53 | 54 | def __len__(self): 55 | assert len(self.image_fnames) == len(self.feature_fnames), \ 56 | "Number of feature files and label files should be same" 57 | return len(self.feature_fnames) 58 | 59 | def __getitem__(self, idx): 60 | image_fname = self.image_fnames[idx] 61 | feature_fname = self.feature_fnames[idx] 62 | image_ext = self._file_ext(image_fname) 63 | with open(os.path.join(self.images_dir, image_fname), 'rb') as f: 64 | if image_ext == '.npy': 65 | image = np.load(f) 66 | image = image.reshape(-1, *image.shape[-2:]) 67 | elif image_ext == '.png' and pyspng is not None: 68 | image = pyspng.load(f.read()) 69 | image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1) 70 | else: 71 | image = np.array(PIL.Image.open(f)) 72 | image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1) 73 | 74 | features = np.load(os.path.join(self.features_dir, feature_fname)) 75 | return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx]) 76 | 77 | def get_feature_dir_info(root): 78 | files = glob.glob(os.path.join(root, '*.npy')) 79 | files_caption = glob.glob(os.path.join(root, '*_*.npy')) 80 | num_data = len(files) - len(files_caption) 81 | n_captions = {k: 0 for k in range(num_data)} 82 | for f in files_caption: 83 | name = 
os.path.split(f)[-1] 84 | k1, k2 = os.path.splitext(name)[0].split('_') 85 | n_captions[int(k1)] += 1 86 | return num_data, n_captions 87 | 88 | 89 | class DatasetFactory(object): 90 | 91 | def __init__(self): 92 | self.train = None 93 | self.test = None 94 | 95 | def get_split(self, split, labeled=False): 96 | if split == "train": 97 | dataset = self.train 98 | elif split == "test": 99 | dataset = self.test 100 | else: 101 | raise ValueError 102 | 103 | if self.has_label: 104 | return dataset #if labeled else UnlabeledDataset(dataset) 105 | else: 106 | assert not labeled 107 | return dataset 108 | 109 | def unpreprocess(self, v): # to B C H W and [0, 1] 110 | v = 0.5 * (v + 1.) 111 | v.clamp_(0., 1.) 112 | return v 113 | 114 | @property 115 | def has_label(self): 116 | return True 117 | 118 | @property 119 | def data_shape(self): 120 | raise NotImplementedError 121 | 122 | @property 123 | def data_dim(self): 124 | return int(np.prod(self.data_shape)) 125 | 126 | @property 127 | def fid_stat(self): 128 | return None 129 | 130 | def sample_label(self, n_samples, device): 131 | raise NotImplementedError 132 | 133 | def label_prob(self, k): 134 | raise NotImplementedError 135 | 136 | class MSCOCOFeatureDataset(Dataset): 137 | # the image features are got through sample 138 | def __init__(self, root): 139 | self.root = root 140 | self.num_data, self.n_captions = get_feature_dir_info(root) 141 | 142 | def __len__(self): 143 | return self.num_data 144 | 145 | def __getitem__(self, index): 146 | with open(os.path.join(self.root, f'{index}.png'), 'rb') as f: 147 | x = np.array(PIL.Image.open(f)) 148 | x = x.reshape(*x.shape[:2], -1).transpose(2, 0, 1) 149 | 150 | z = np.load(os.path.join(self.root, f'{index}.npy')) 151 | k = random.randint(0, self.n_captions[index] - 1) 152 | c = np.load(os.path.join(self.root, f'{index}_{k}.npy')) 153 | return x, z, c 154 | 155 | 156 | class CFGDataset(Dataset): # for classifier free guidance 157 | def __init__(self, dataset, p_uncond, empty_token): 158 | self.dataset = dataset 159 | self.p_uncond = p_uncond 160 | self.empty_token = empty_token 161 | 162 | def __len__(self): 163 | return len(self.dataset) 164 | 165 | def __getitem__(self, item): 166 | x, z, y = self.dataset[item] 167 | if random.random() < self.p_uncond: 168 | y = self.empty_token 169 | return x, z, y 170 | 171 | class MSCOCO256Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip 172 | def __init__(self, path, cfg=True, p_uncond=0.1, mode='train'): 173 | super().__init__() 174 | print('Prepare dataset...') 175 | if mode == 'val': 176 | self.test = MSCOCOFeatureDataset(os.path.join(path, 'val')) 177 | assert len(self.test) == 40504 178 | self.empty_context = np.load(os.path.join(path, 'empty_context.npy')) 179 | else: 180 | self.train = MSCOCOFeatureDataset(os.path.join(path, 'train')) 181 | assert len(self.train) == 82783 182 | self.empty_context = np.load(os.path.join(path, 'empty_context.npy')) 183 | 184 | if cfg: # classifier free guidance 185 | assert p_uncond is not None 186 | print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}') 187 | self.train = CFGDataset(self.train, p_uncond, self.empty_context) 188 | 189 | @property 190 | def data_shape(self): 191 | return 4, 32, 32 192 | 193 | @property 194 | def fid_stat(self): 195 | return f'assets/fid_stats/fid_stats_mscoco256_val.npz' -------------------------------------------------------------------------------- /generate.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Samples a large number of images from a pre-trained SiT model using DDP. 9 | Subsequently saves a .npz file that can be used to compute FID and other 10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations 11 | 12 | For a simple single-GPU/CPU sampling script, see sample.py. 13 | """ 14 | import torch 15 | import torch.distributed as dist 16 | from models.sit import SiT_models 17 | from diffusers.models import AutoencoderKL 18 | from tqdm import tqdm 19 | import os 20 | from PIL import Image 21 | import numpy as np 22 | import math 23 | import argparse 24 | from samplers import euler_sampler, euler_maruyama_sampler 25 | from utils import load_legacy_checkpoints, download_model 26 | 27 | def create_npz_from_sample_folder(sample_dir, num=50_000): 28 | """ 29 | Builds a single .npz file from a folder of .png samples. 30 | """ 31 | samples = [] 32 | for i in tqdm(range(num), desc="Building .npz file from samples"): 33 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 34 | sample_np = np.asarray(sample_pil).astype(np.uint8) 35 | samples.append(sample_np) 36 | samples = np.stack(samples) 37 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 38 | npz_path = f"{sample_dir}.npz" 39 | np.savez(npz_path, arr_0=samples) 40 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 41 | return npz_path 42 | 43 | 44 | def main(args): 45 | """ 46 | Run sampling. 47 | """ 48 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences 49 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" 50 | torch.set_grad_enabled(False) 51 | 52 | # Setup DDP:cd 53 | dist.init_process_group("nccl") 54 | rank = dist.get_rank() 55 | device = rank % torch.cuda.device_count() 56 | seed = args.global_seed * dist.get_world_size() + rank 57 | torch.manual_seed(seed) 58 | torch.cuda.set_device(device) 59 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 60 | 61 | # Load model: 62 | block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm} 63 | latent_size = args.resolution // 8 64 | model = SiT_models[args.model]( 65 | input_size=latent_size, 66 | num_classes=args.num_classes, 67 | use_cfg = True, 68 | z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')], 69 | encoder_depth=args.encoder_depth, 70 | **block_kwargs, 71 | ).to(device) 72 | # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py: 73 | ckpt_path = args.ckpt 74 | if ckpt_path is None: 75 | args.ckpt = 'SiT-XL-2-256x256.pt' 76 | assert args.model == 'SiT-XL/2' 77 | assert len(args.projector_embed_dims.split(',')) == 1 78 | assert int(args.projector_embed_dims.split(',')[0]) == 768 79 | state_dict = download_model('last.pt') 80 | else: 81 | state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema'] 82 | if args.legacy: 83 | state_dict = load_legacy_checkpoints( 84 | state_dict=state_dict, encoder_depth=args.encoder_depth 85 | ) 86 | model.load_state_dict(state_dict) 87 | model.eval() # important! 
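# Note on checkpoint loading: when --ckpt is omitted, the released SiT-XL/2 256x256 checkpoint
# ('last.pt') is downloaded automatically via download_model(); checkpoints saved by train.py are
# expected to store their EMA weights under the 'ema' key, which is what gets loaded and evaluated.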
88 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 89 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" 90 | using_cfg = args.cfg_scale > 1.0 91 | 92 | # Create folder to save samples: 93 | model_string_name = args.model.replace("/", "-") 94 | ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained" 95 | folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \ 96 | f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}" 97 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 98 | if rank == 0: 99 | os.makedirs(sample_folder_dir, exist_ok=True) 100 | print(f"Saving .png samples at {sample_folder_dir}") 101 | dist.barrier() 102 | 103 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 104 | n = args.per_proc_batch_size 105 | global_batch_size = n * dist.get_world_size() 106 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 107 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) 108 | if rank == 0: 109 | print(f"Total number of images that will be sampled: {total_samples}") 110 | print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 111 | print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}") 112 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" 113 | samples_needed_this_gpu = int(total_samples // dist.get_world_size()) 114 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" 115 | iterations = int(samples_needed_this_gpu // n) 116 | pbar = range(iterations) 117 | pbar = tqdm(pbar) if rank == 0 else pbar 118 | total = 0 119 | for _ in pbar: 120 | # Sample inputs: 121 | z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device) 122 | y = torch.randint(0, args.num_classes, (n,), device=device) 123 | 124 | # Sample images: 125 | sampling_kwargs = dict( 126 | model=model, 127 | latents=z, 128 | y=y, 129 | num_steps=args.num_steps, 130 | heun=args.heun, 131 | cfg_scale=args.cfg_scale, 132 | guidance_low=args.guidance_low, 133 | guidance_high=args.guidance_high, 134 | path_type=args.path_type, 135 | ) 136 | with torch.no_grad(): 137 | if args.mode == "sde": 138 | samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32) 139 | elif args.mode == "ode": 140 | samples = euler_sampler(**sampling_kwargs).to(torch.float32) 141 | else: 142 | raise NotImplementedError() 143 | 144 | latents_scale = torch.tensor( 145 | [0.18215, 0.18215, 0.18215, 0.18215, ] 146 | ).view(1, 4, 1, 1).to(device) 147 | latents_bias = -torch.tensor( 148 | [0., 0., 0., 0.,] 149 | ).view(1, 4, 1, 1).to(device) 150 | samples = vae.decode((samples - latents_bias) / latents_scale).sample 151 | samples = (samples + 1) / 2. 152 | samples = torch.clamp( 153 | 255. 
* samples, 0, 255 154 | ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 155 | 156 | # Save samples to disk as individual .png files 157 | for i, sample in enumerate(samples): 158 | index = i * dist.get_world_size() + rank + total 159 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 160 | total += global_batch_size 161 | 162 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 163 | dist.barrier() 164 | if rank == 0: 165 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) 166 | print("Done.") 167 | dist.barrier() 168 | dist.destroy_process_group() 169 | 170 | 171 | if __name__ == "__main__": 172 | parser = argparse.ArgumentParser() 173 | # seed 174 | parser.add_argument("--global-seed", type=int, default=0) 175 | 176 | # precision 177 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, 178 | help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.") 179 | 180 | # logging/saving: 181 | parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.") 182 | parser.add_argument("--sample-dir", type=str, default="samples") 183 | 184 | # model 185 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") 186 | parser.add_argument("--num-classes", type=int, default=1000) 187 | parser.add_argument("--encoder-depth", type=int, default=8) 188 | parser.add_argument("--resolution", type=int, choices=[256, 512], default=256) 189 | parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False) 190 | parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False) 191 | 192 | # vae 193 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") 194 | 195 | # number of samples 196 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 197 | parser.add_argument("--num-fid-samples", type=int, default=50_000) 198 | 199 | # sampling related hyperparameters 200 | parser.add_argument("--mode", type=str, default="ode") 201 | parser.add_argument("--cfg-scale", type=float, default=1.5) 202 | parser.add_argument("--projector-embed-dims", type=str, default="768,1024") 203 | parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"]) 204 | parser.add_argument("--num-steps", type=int, default=50) 205 | parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode 206 | parser.add_argument("--guidance-low", type=float, default=0.) 207 | parser.add_argument("--guidance-high", type=float, default=1.) 208 | 209 | # will be deprecated 210 | parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode 211 | 212 | 213 | args = parser.parse_args() 214 | main(args) 215 | -------------------------------------------------------------------------------- /generate_t2i.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Samples a large number of images from a pre-trained SiT model using DDP. 
9 | Subsequently saves a .npz file that can be used to compute FID and other 10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations 11 | 12 | For a simple single-GPU/CPU sampling script, see sample.py. 13 | """ 14 | import torch 15 | import torch.distributed as dist 16 | from models.mmdit import MMDiT 17 | from diffusers.models import AutoencoderKL 18 | from tqdm import tqdm 19 | import os 20 | from PIL import Image 21 | import numpy as np 22 | import math 23 | import argparse 24 | from samplers_t2i import euler_sampler, euler_maruyama_sampler 25 | from utils import load_legacy_checkpoints, download_model 26 | 27 | from dataset import MSCOCO256Features 28 | from torch.utils.data import DataLoader 29 | 30 | from accelerate import Accelerator 31 | from accelerate.logging import get_logger 32 | from accelerate.utils import ProjectConfiguration, set_seed 33 | 34 | 35 | def create_npz_from_sample_folder(sample_dir, num=50_000): 36 | """ 37 | Builds a single .npz file from a folder of .png samples. 38 | """ 39 | samples = [] 40 | for i in tqdm(range(num), desc="Building .npz file from samples"): 41 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 42 | sample_np = np.asarray(sample_pil).astype(np.uint8) 43 | samples.append(sample_np) 44 | samples = np.stack(samples) 45 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 46 | npz_path = f"{sample_dir}.npz" 47 | np.savez(npz_path, arr_0=samples) 48 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 49 | return npz_path 50 | 51 | 52 | def main(args): 53 | """ 54 | Run sampling. 55 | """ 56 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences 57 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" 58 | torch.set_grad_enabled(False) 59 | 60 | accelerator = Accelerator(mixed_precision=None) 61 | device = accelerator.device 62 | if args.global_seed is not None: 63 | set_seed(args.global_seed + accelerator.process_index) 64 | # Load model: 65 | block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm} 66 | latent_size = args.resolution // 8 67 | model = MMDiT( 68 | input_size=latent_size, 69 | z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')], 70 | encoder_depth=args.encoder_depth, 71 | ).to(device) 72 | 73 | # Setup data: 74 | all_dataset = MSCOCO256Features(path='../data/coco256_features', mode='val', ret_caption=True) 75 | val_dataset = all_dataset.test 76 | y_null = torch.from_numpy(all_dataset.empty_context).to(device).unsqueeze(0) 77 | local_batch_size = args.per_proc_batch_size 78 | val_dataloader = DataLoader( 79 | val_dataset, 80 | batch_size=local_batch_size, 81 | shuffle=False, 82 | num_workers=4, 83 | pin_memory=True, 84 | drop_last=False 85 | ) 86 | 87 | val_dataloader = accelerator.prepare(val_dataloader) 88 | # Load a custom checkpoint from train_t2i.py: 89 | ckpt_path = args.ckpt 90 | state_dict = torch.load(ckpt_path, map_location="cpu")['model'] 91 | model.load_state_dict(state_dict) 92 | model.eval() # important!
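# Note on checkpoint/data loading: this script expects an explicit --ckpt and loads the raw
# 'model' weights (there is no auto-download or 'ema' entry here, unlike generate.py); captions
# and latents come from the MS-COCO val features prepared with the U-ViT protocol (see README),
# and the empty-context embedding serves as the unconditional input for classifier-free guidance.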
93 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 94 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" 95 | using_cfg = args.cfg_scale > 1.0 96 | 97 | # Create folder to save samples: 98 | if args.prefix == "": 99 | folder_name = f"coco-size-{args.resolution}-vae-{args.vae}-" \ 100 | f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}" 101 | else: 102 | folder_name = f"{args.prefix}-coco-size-{args.resolution}-vae-{args.vae}-" \ 103 | f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}" 104 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 105 | real_sample_folder_dir = f"{args.sample_dir}/{folder_name}_real" 106 | if accelerator.is_main_process: 107 | os.makedirs(sample_folder_dir, exist_ok=True) 108 | os.makedirs(real_sample_folder_dir, exist_ok=True) 109 | print(f"Saving .png samples at {sample_folder_dir}") 110 | dist.barrier() 111 | 112 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 113 | n = args.per_proc_batch_size 114 | global_batch_size = n * dist.get_world_size() 115 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 116 | total_samples = 40192 117 | if accelerator.is_main_process: 118 | print(f"Total number of images that will be sampled: {total_samples}") 119 | print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 120 | print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}") 121 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" 122 | samples_needed_this_gpu = int(total_samples // dist.get_world_size()) 123 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" 124 | iterations = int(samples_needed_this_gpu // n) 125 | pbar = range(iterations) 126 | pbar = tqdm(pbar) if accelerator.is_main_process else pbar 127 | total = 0 128 | clipsim_sum = 0. 129 | from utils import ClipSimilarity 130 | clipsim_fn = ClipSimilarity(device=device) 131 | 132 | for raw_image, _, context, raw_captions in val_dataloader: 133 | # Sample inputs: 134 | z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device) 135 | 136 | # Sample images: 137 | sampling_kwargs = dict( 138 | model=model, 139 | latents=z, 140 | y=context, 141 | y_null=y_null.repeat(context.shape[0], 1, 1), 142 | num_steps=args.num_steps, 143 | heun=args.heun, 144 | cfg_scale=args.cfg_scale, 145 | guidance_low=args.guidance_low, 146 | guidance_high=args.guidance_high, 147 | path_type=args.path_type, 148 | ) 149 | with torch.no_grad(): 150 | if args.mode == "sde": 151 | samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32) 152 | elif args.mode == "ode": 153 | samples = euler_sampler(**sampling_kwargs).to(torch.float32) 154 | else: 155 | raise NotImplementedError() 156 | latents_scale = torch.tensor( 157 | [0.18215, 0.18215, 0.18215, 0.18215, ] 158 | ).view(1, 4, 1, 1).to(device) 159 | latents_bias = -torch.tensor( 160 | [0., 0., 0., 0.,] 161 | ).view(1, 4, 1, 1).to(device) 162 | samples = vae.decode((samples - latents_bias) / latents_scale).sample 163 | samples = (samples + 1) / 2. 164 | samples = torch.clamp( 165 | 255. 
* samples, 0, 255 166 | ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 167 | # real_samples = (raw_image).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 168 | 169 | # Save samples to disk as individual .png files 170 | for i, sample in enumerate(samples): 171 | index = i * accelerator.num_processes + accelerator.local_process_index + total 172 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 173 | # Image.fromarray(real_sample).save(f"{real_sample_folder_dir}/{index:06d}.png") 174 | batch_clipsim = clipsim_fn( 175 | torch.from_numpy(samples/255.).to(device).permute(0, 3, 1, 2), raw_captions 176 | ) 177 | total += global_batch_size 178 | gather_clipsim_sum = [ 179 | torch.zeros_like(batch_clipsim) for _ in range(4) 180 | ] 181 | torch.distributed.all_gather(gather_clipsim_sum, batch_clipsim) 182 | gather_clipsim_sum = torch.cat(gather_clipsim_sum).sum() 183 | clipsim_sum += gather_clipsim_sum 184 | if accelerator.is_main_process: 185 | print(f"{total}: {clipsim_sum / total}") 186 | if accelerator.is_main_process: 187 | pbar.update(1) 188 | 189 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 190 | dist.barrier() 191 | if accelerator.is_main_process: 192 | create_npz_from_sample_folder(sample_folder_dir, 40192) 193 | # create_npz_from_sample_folder(real_sample_folder_dir, 40192) 194 | print("Done.") 195 | dist.barrier() 196 | dist.destroy_process_group() 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | # seed 201 | parser.add_argument("--global-seed", type=int, default=0) 202 | 203 | # precision 204 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, 205 | help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.") 206 | 207 | # logging/saving: 208 | parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.") 209 | parser.add_argument("--sample-dir", type=str, default="samples") 210 | parser.add_argument("--prefix", type=str, default="") 211 | 212 | # model 213 | parser.add_argument("--num-classes", type=int, default=1000) 214 | parser.add_argument("--encoder-depth", type=int, default=8) 215 | parser.add_argument("--resolution", type=int, choices=[256, 512], default=256) 216 | parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False) 217 | parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False) 218 | 219 | # vae 220 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") 221 | 222 | # number of samples 223 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 224 | parser.add_argument("--num-fid-samples", type=int, default=50_000) 225 | 226 | # sampling related hyperparameters 227 | parser.add_argument("--mode", type=str, default="ode") 228 | parser.add_argument("--cfg-scale", type=float, default=1.5) 229 | parser.add_argument("--projector-embed-dims", type=str, default="768,1024") 230 | parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"]) 231 | parser.add_argument("--num-steps", type=int, default=50) 232 | parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode 233 | parser.add_argument("--guidance-low", type=float, default=0.) 234 | parser.add_argument("--guidance-high", type=float, default=1.) 
235 | 236 | # will be deprecated 237 | parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode 238 | 239 | args = parser.parse_args() 240 | main(args) -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | def mean_flat(x): 6 | """ 7 | Take the mean over all non-batch dimensions. 8 | """ 9 | return torch.mean(x, dim=list(range(1, len(x.size())))) 10 | 11 | def sum_flat(x): 12 | """ 13 | Take the sum over all non-batch dimensions. 14 | """ 15 | return torch.sum(x, dim=list(range(1, len(x.size())))) 16 | 17 | class SILoss: 18 | def __init__( 19 | self, 20 | prediction='v', 21 | path_type="linear", 22 | weighting="uniform", 23 | encoders=[], 24 | accelerator=None, 25 | latents_scale=None, 26 | latents_bias=None, 27 | ): 28 | self.prediction = prediction 29 | self.weighting = weighting 30 | self.path_type = path_type 31 | self.encoders = encoders 32 | self.accelerator = accelerator 33 | self.latents_scale = latents_scale 34 | self.latents_bias = latents_bias 35 | 36 | def interpolant(self, t): 37 | if self.path_type == "linear": 38 | alpha_t = 1 - t 39 | sigma_t = t 40 | d_alpha_t = -1 41 | d_sigma_t = 1 42 | elif self.path_type == "cosine": 43 | alpha_t = torch.cos(t * np.pi / 2) 44 | sigma_t = torch.sin(t * np.pi / 2) 45 | d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2) 46 | d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2) 47 | else: 48 | raise NotImplementedError() 49 | 50 | return alpha_t, sigma_t, d_alpha_t, d_sigma_t 51 | 52 | def __call__(self, model, images, model_kwargs=None, zs=None): 53 | if model_kwargs is None: 54 | model_kwargs = {} 55 | # sample timesteps 56 | if self.weighting == "uniform": 57 | time_input = torch.rand((images.shape[0], 1, 1, 1)) 58 | elif self.weighting == "lognormal": 59 | # sample timestep according to log-normal distribution of sigmas following EDM 60 | rnd_normal = torch.randn((images.shape[0], 1, 1, 1)) 61 | sigma = rnd_normal.exp() 62 | if self.path_type == "linear": 63 | time_input = sigma / (1 + sigma) 64 | elif self.path_type == "cosine": 65 | time_input = 2 / np.pi * torch.atan(sigma) 66 | 67 | time_input = time_input.to(device=images.device, dtype=images.dtype) 68 | 69 | noises = torch.randn_like(images) 70 | alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input) 71 | 72 | model_input = alpha_t * images + sigma_t * noises 73 | if self.prediction == 'v': 74 | model_target = d_alpha_t * images + d_sigma_t * noises 75 | else: 76 | raise NotImplementedError() # TODO: add x or eps prediction 77 | model_output, zs_tilde = model(model_input, time_input.flatten(), **model_kwargs) 78 | denoising_loss = mean_flat((model_output - model_target) ** 2) 79 | 80 | # projection loss 81 | proj_loss = 0.
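# The loop below accumulates the REPA alignment term: for each external encoder, the target
# patch features z (from the frozen encoder) and the projected SiT hidden states z_tilde are
# L2-normalized along the feature dimension, the negative cosine similarity is averaged over
# each sample's patch tokens, and the accumulated sum is divided by (num_encoders * batch size).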
82 | bsz = zs[0].shape[0] 83 | for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)): 84 | for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)): 85 | z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1) 86 | z_j = torch.nn.functional.normalize(z_j, dim=-1) 87 | proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1)) 88 | proj_loss /= (len(zs) * bsz) 89 | 90 | return denoising_loss, proj_loss 91 | -------------------------------------------------------------------------------- /models/clip_vit.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | import clip 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1): 16 | super().__init__() 17 | 18 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu1 = nn.ReLU(inplace=True) 22 | 23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.relu2 = nn.ReLU(inplace=True) 26 | 27 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 28 | 29 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 30 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 31 | self.relu3 = nn.ReLU(inplace=True) 32 | 33 | self.downsample = None 34 | self.stride = stride 35 | 36 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 37 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 38 | self.downsample = nn.Sequential(OrderedDict([ 39 | ("-1", nn.AvgPool2d(stride)), 40 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 41 | ("1", nn.BatchNorm2d(planes * self.expansion)) 42 | ])) 43 | 44 | def forward(self, x: torch.Tensor): 45 | identity = x 46 | 47 | out = self.relu1(self.bn1(self.conv1(x))) 48 | out = self.relu2(self.bn2(self.conv2(out))) 49 | out = self.avgpool(out) 50 | out = self.bn3(self.conv3(out)) 51 | 52 | if self.downsample is not None: 53 | identity = self.downsample(x) 54 | 55 | out += identity 56 | out = self.relu3(out) 57 | return out 58 | 59 | 60 | class AttentionPool2d(nn.Module): 61 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 62 | super().__init__() 63 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 64 | self.k_proj = nn.Linear(embed_dim, embed_dim) 65 | self.q_proj = nn.Linear(embed_dim, embed_dim) 66 | self.v_proj = nn.Linear(embed_dim, embed_dim) 67 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 68 | self.num_heads = num_heads 69 | 70 | def forward(self, x): 71 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 72 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 73 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 74 | x, _ = F.multi_head_attention_forward( 75 | query=x[:1], key=x, value=x, 76 | embed_dim_to_check=x.shape[-1], 77 | num_heads=self.num_heads, 78 | q_proj_weight=self.q_proj.weight, 79 | k_proj_weight=self.k_proj.weight, 80 | v_proj_weight=self.v_proj.weight, 81 | in_proj_weight=None, 82 | in_proj_bias=torch.cat([self.q_proj.bias, 
self.k_proj.bias, self.v_proj.bias]), 83 | bias_k=None, 84 | bias_v=None, 85 | add_zero_attn=False, 86 | dropout_p=0, 87 | out_proj_weight=self.c_proj.weight, 88 | out_proj_bias=self.c_proj.bias, 89 | use_separate_proj_weight=True, 90 | training=self.training, 91 | need_weights=False 92 | ) 93 | return x.squeeze(0) 94 | 95 | 96 | class ModifiedResNet(nn.Module): 97 | """ 98 | A ResNet class that is similar to torchvision's but contains the following changes: 99 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 100 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 101 | - The final pooling layer is a QKV attention instead of an average pool 102 | """ 103 | 104 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 105 | super().__init__() 106 | self.output_dim = output_dim 107 | self.input_resolution = input_resolution 108 | 109 | # the 3-layer stem 110 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(width // 2) 112 | self.relu1 = nn.ReLU(inplace=True) 113 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 114 | self.bn2 = nn.BatchNorm2d(width // 2) 115 | self.relu2 = nn.ReLU(inplace=True) 116 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 117 | self.bn3 = nn.BatchNorm2d(width) 118 | self.relu3 = nn.ReLU(inplace=True) 119 | self.avgpool = nn.AvgPool2d(2) 120 | 121 | # residual layers 122 | self._inplanes = width # this is a *mutable* variable used during construction 123 | self.layer1 = self._make_layer(width, layers[0]) 124 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 125 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 126 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 127 | 128 | embed_dim = width * 32 # the ResNet feature dimension 129 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 130 | 131 | def _make_layer(self, planes, blocks, stride=1): 132 | layers = [Bottleneck(self._inplanes, planes, stride)] 133 | 134 | self._inplanes = planes * Bottleneck.expansion 135 | for _ in range(1, blocks): 136 | layers.append(Bottleneck(self._inplanes, planes)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | def stem(x): 142 | x = self.relu1(self.bn1(self.conv1(x))) 143 | x = self.relu2(self.bn2(self.conv2(x))) 144 | x = self.relu3(self.bn3(self.conv3(x))) 145 | x = self.avgpool(x) 146 | return x 147 | 148 | x = x.type(self.conv1.weight.dtype) 149 | x = stem(x) 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x = self.layer3(x) 153 | x = self.layer4(x) 154 | x = self.attnpool(x) 155 | 156 | return x 157 | 158 | 159 | class LayerNorm(nn.LayerNorm): 160 | """Subclass torch's LayerNorm to handle fp16.""" 161 | 162 | def forward(self, x: torch.Tensor): 163 | orig_type = x.dtype 164 | ret = super().forward(x.type(torch.float32)) 165 | return ret.type(orig_type) 166 | 167 | 168 | class QuickGELU(nn.Module): 169 | def forward(self, x: torch.Tensor): 170 | return x * torch.sigmoid(1.702 * x) 171 | 172 | 173 | class ResidualAttentionBlock(nn.Module): 174 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 175 | super().__init__() 176 | 177 | self.attn = nn.MultiheadAttention(d_model, n_head) 178 | self.ln_1 = LayerNorm(d_model) 179 | self.mlp = nn.Sequential(OrderedDict([ 
180 | ("c_fc", nn.Linear(d_model, d_model * 4)), 181 | ("gelu", QuickGELU()), 182 | ("c_proj", nn.Linear(d_model * 4, d_model)) 183 | ])) 184 | self.ln_2 = LayerNorm(d_model) 185 | self.attn_mask = attn_mask 186 | 187 | def attention(self, x: torch.Tensor): 188 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 189 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 190 | 191 | def forward(self, x: torch.Tensor): 192 | x = x + self.attention(self.ln_1(x)) 193 | x = x + self.mlp(self.ln_2(x)) 194 | return x 195 | 196 | 197 | class Transformer(nn.Module): 198 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 199 | super().__init__() 200 | self.width = width 201 | self.layers = layers 202 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 203 | 204 | def forward(self, x: torch.Tensor): 205 | return self.resblocks(x) 206 | 207 | 208 | class UpdatedVisionTransformer(nn.Module): 209 | def __init__(self, model): 210 | super().__init__() 211 | self.model = model 212 | 213 | def forward(self, x: torch.Tensor): 214 | x = self.model.conv1(x) # shape = [*, width, grid, grid] 215 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 216 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 217 | x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 218 | x = x + self.model.positional_embedding.to(x.dtype) 219 | x = self.model.ln_pre(x) 220 | 221 | x = x.permute(1, 0, 2) # NLD -> LND 222 | x = self.model.transformer(x) 223 | x = x.permute(1, 0, 2)[:, 1:] # LND -> NLD 224 | 225 | # x = self.ln_post(x[:, 0, :]) 226 | 227 | # if self.proj is not None: 228 | # x = x @ self.proj 229 | 230 | return x 231 | 232 | 233 | class CLIP(nn.Module): 234 | def __init__(self, 235 | embed_dim: int, 236 | # vision 237 | image_resolution: int, 238 | vision_layers: Union[Tuple[int, int, int, int], int], 239 | vision_width: int, 240 | vision_patch_size: int, 241 | # text 242 | context_length: int, 243 | vocab_size: int, 244 | transformer_width: int, 245 | transformer_heads: int, 246 | transformer_layers: int 247 | ): 248 | super().__init__() 249 | 250 | self.context_length = context_length 251 | 252 | if isinstance(vision_layers, (tuple, list)): 253 | vision_heads = vision_width * 32 // 64 254 | self.visual = ModifiedResNet( 255 | layers=vision_layers, 256 | output_dim=embed_dim, 257 | heads=vision_heads, 258 | input_resolution=image_resolution, 259 | width=vision_width 260 | ) 261 | else: 262 | vision_heads = vision_width // 64 263 | self.visual = UpdatedVisionTransformer( 264 | input_resolution=image_resolution, 265 | patch_size=vision_patch_size, 266 | width=vision_width, 267 | layers=vision_layers, 268 | heads=vision_heads, 269 | output_dim=embed_dim 270 | ) 271 | 272 | self.transformer = Transformer( 273 | width=transformer_width, 274 | layers=transformer_layers, 275 | heads=transformer_heads, 276 | attn_mask=self.build_attention_mask() 277 | ) 278 | 279 | self.vocab_size = vocab_size 280 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 281 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 282 | self.ln_final = LayerNorm(transformer_width) 283 | 284 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 285 
| self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 286 | 287 | self.initialize_parameters() 288 | 289 | def initialize_parameters(self): 290 | nn.init.normal_(self.token_embedding.weight, std=0.02) 291 | nn.init.normal_(self.positional_embedding, std=0.01) 292 | 293 | if isinstance(self.visual, ModifiedResNet): 294 | if self.visual.attnpool is not None: 295 | std = self.visual.attnpool.c_proj.in_features ** -0.5 296 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 297 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 298 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 299 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 300 | 301 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 302 | for name, param in resnet_block.named_parameters(): 303 | if name.endswith("bn3.weight"): 304 | nn.init.zeros_(param) 305 | 306 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 307 | attn_std = self.transformer.width ** -0.5 308 | fc_std = (2 * self.transformer.width) ** -0.5 309 | for block in self.transformer.resblocks: 310 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 311 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 312 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 313 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 314 | 315 | if self.text_projection is not None: 316 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 317 | 318 | def build_attention_mask(self): 319 | # lazily create causal attention mask, with full attention between the vision tokens 320 | # pytorch uses additive attention mask; fill with -inf 321 | mask = torch.empty(self.context_length, self.context_length) 322 | mask.fill_(float("-inf")) 323 | mask.triu_(1) # zero out the lower diagonal 324 | return mask 325 | 326 | @property 327 | def dtype(self): 328 | return self.visual.conv1.weight.dtype 329 | 330 | def encode_image(self, image): 331 | return self.visual(image.type(self.dtype)) 332 | 333 | def encode_text(self, text): 334 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 335 | 336 | x = x + self.positional_embedding.type(self.dtype) 337 | x = x.permute(1, 0, 2) # NLD -> LND 338 | x = self.transformer(x) 339 | x = x.permute(1, 0, 2) # LND -> NLD 340 | x = self.ln_final(x).type(self.dtype) 341 | 342 | # x.shape = [batch_size, n_ctx, transformer.width] 343 | # take features from the eot embedding (eot_token is the highest number in each sequence) 344 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 345 | 346 | return x 347 | 348 | def forward(self, image, text): 349 | image_features = self.encode_image(image) 350 | text_features = self.encode_text(text) 351 | 352 | # normalized features 353 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 354 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 355 | 356 | # cosine similarity as logits 357 | logit_scale = self.logit_scale.exp() 358 | logits_per_image = logit_scale * image_features @ text_features.t() 359 | logits_per_text = logits_per_image.t() 360 | 361 | # shape = [global_batch_size, global_batch_size] 362 | return logits_per_image, logits_per_text 363 | 364 | 365 | def convert_weights(model: nn.Module): 366 | """Convert applicable model parameters to fp16""" 367 | 368 | def _convert_weights_to_fp16(l): 369 | if isinstance(l, (nn.Conv1d, 
nn.Conv2d, nn.Linear)): 370 | l.weight.data = l.weight.data.half() 371 | if l.bias is not None: 372 | l.bias.data = l.bias.data.half() 373 | 374 | if isinstance(l, nn.MultiheadAttention): 375 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 376 | tensor = getattr(l, attr) 377 | if tensor is not None: 378 | tensor.data = tensor.data.half() 379 | 380 | for name in ["text_projection", "proj"]: 381 | if hasattr(l, name): 382 | attr = getattr(l, name) 383 | if attr is not None: 384 | attr.data = attr.data.half() 385 | 386 | model.apply(_convert_weights_to_fp16) 387 | 388 | 389 | def build_model(state_dict: dict): 390 | vit = "visual.proj" in state_dict 391 | 392 | if vit: 393 | vision_width = state_dict["visual.conv1.weight"].shape[0] 394 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 395 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 396 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 397 | image_resolution = vision_patch_size * grid_size 398 | else: 399 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 400 | vision_layers = tuple(counts) 401 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 402 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 403 | vision_patch_size = None 404 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 405 | image_resolution = output_width * 32 406 | 407 | embed_dim = state_dict["text_projection"].shape[1] 408 | context_length = state_dict["positional_embedding"].shape[0] 409 | vocab_size = state_dict["token_embedding.weight"].shape[0] 410 | transformer_width = state_dict["ln_final.weight"].shape[0] 411 | transformer_heads = transformer_width // 64 412 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 413 | 414 | model = CLIP( 415 | embed_dim, 416 | image_resolution, vision_layers, vision_width, vision_patch_size, 417 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 418 | ) 419 | 420 | for key in ["input_resolution", "context_length", "vocab_size"]: 421 | if key in state_dict: 422 | del state_dict[key] 423 | 424 | convert_weights(model) 425 | model.load_state_dict(state_dict) 426 | return model.eval() -------------------------------------------------------------------------------- /models/mae_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | x = x[:, 1:, :] #.mean(dim=1) # global pool without cls token 47 | 48 | return x 49 | 50 | 51 | def vit_base_patch16(**kwargs): 52 | model = VisionTransformer( 53 | num_classes=0, 54 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 55 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 56 | return model 57 | 58 | 59 | def vit_large_patch16(**kwargs): 60 | model = VisionTransformer( 61 | num_classes=0, 62 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 64 | return model 65 | 66 | 67 | def vit_huge_patch14(**kwargs): 68 | model = VisionTransformer( 69 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 70 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 71 | return model -------------------------------------------------------------------------------- /models/mocov3_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | from functools import partial, reduce 11 | from operator import mul 12 | 13 | from timm.layers.helpers import to_2tuple 14 | from timm.models.vision_transformer import VisionTransformer, _cfg 15 | from timm.models.vision_transformer import PatchEmbed 16 | 17 | __all__ = [ 18 | 'vit_small', 19 | 'vit_base', 20 | 'vit_large', 21 | 'vit_conv_small', 22 | 'vit_conv_base', 23 | ] 24 | 25 | 26 | def patchify_avg(input_tensor, patch_size): 27 | # Ensure input tensor is 4D: (batch_size, channels, height, width) 28 | if input_tensor.dim() != 4: 29 | raise ValueError("Input tensor must be 4D (batch_size, channels, height, width)") 30 | 31 | # Get input tensor dimensions 32 | batch_size, channels, height, width = input_tensor.shape 33 | 34 | # Ensure patch_size is valid 35 | patch_height, patch_width = patch_size, patch_size 36 | if height % patch_height != 0 or width % patch_width != 0: 37 | raise ValueError("Input tensor dimensions must be divisible by patch_size") 38 | 39 | # Use unfold to create patches 40 | patches = input_tensor.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width) 41 | 42 | # Reshape patches to desired format: (batch_size, num_patches, channels) 43 | patches = patches.contiguous().view( 44 | batch_size, channels, -1, patch_height, patch_width 45 | ).mean(dim=-1).mean(dim=-1) 46 | patches = patches.permute(0, 2, 1).contiguous() 47 | 48 | return patches 49 | 50 | 51 | 52 | class VisionTransformerMoCo(VisionTransformer): 53 | def __init__(self, stop_grad_conv1=False, **kwargs): 54 | super().__init__(**kwargs) 55 | # Use fixed 2D sin-cos position embedding 56 | self.build_2d_sincos_position_embedding() 57 | 58 | # weight initialization 59 | for name, m in self.named_modules(): 60 | if isinstance(m, nn.Linear): 61 | if 'qkv' in name: 62 | # treat the weights of Q, K, V separately 63 | val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) 64 | nn.init.uniform_(m.weight, -val, val) 65 | else: 66 | nn.init.xavier_uniform_(m.weight) 67 | nn.init.zeros_(m.bias) 68 | nn.init.normal_(self.cls_token, std=1e-6) 69 | 70 | if isinstance(self.patch_embed, PatchEmbed): 71 | # xavier_uniform initialization 72 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) 73 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val) 74 | nn.init.zeros_(self.patch_embed.proj.bias) 75 | 76 | if stop_grad_conv1: 77 | self.patch_embed.proj.weight.requires_grad = False 78 | self.patch_embed.proj.bias.requires_grad = False 79 | 80 | def build_2d_sincos_position_embedding(self, temperature=10000.): 81 | h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0] 82 | w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1] 83 | grid_w = torch.arange(w, dtype=torch.float32) 84 | grid_h = torch.arange(h, dtype=torch.float32) 85 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h) 86 | assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 87 | pos_dim = self.embed_dim // 4 88 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 89 | omega = 1. 
/ (temperature**omega) 90 | out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) 91 | out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) 92 | pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] 93 | 94 | # assert self.num_tokens == 1, 'Assuming one and only one token, [cls]' 95 | pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) 96 | self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) 97 | self.pos_embed.requires_grad = False 98 | 99 | def forward_diffusion_output(self, x): 100 | x = x.reshape(*x.shape[0:2], -1).permute(0, 2, 1) 101 | x = self._pos_embed(x) 102 | x = self.patch_drop(x) 103 | x = self.norm_pre(x) 104 | x = self.blocks(x) 105 | x = self.norm(x) 106 | return x 107 | 108 | class ConvStem(nn.Module): 109 | """ 110 | ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881 111 | """ 112 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 113 | super().__init__() 114 | 115 | assert patch_size == 16, 'ConvStem only supports patch size of 16' 116 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' 117 | 118 | img_size = to_2tuple(img_size) 119 | patch_size = to_2tuple(patch_size) 120 | self.img_size = img_size 121 | self.patch_size = patch_size 122 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 123 | self.num_patches = self.grid_size[0] * self.grid_size[1] 124 | self.flatten = flatten 125 | 126 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881 127 | stem = [] 128 | input_dim, output_dim = 3, embed_dim // 8 129 | for l in range(4): 130 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 131 | stem.append(nn.BatchNorm2d(output_dim)) 132 | stem.append(nn.ReLU(inplace=True)) 133 | input_dim = output_dim 134 | output_dim *= 2 135 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 136 | self.proj = nn.Sequential(*stem) 137 | 138 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 139 | 140 | def forward(self, x): 141 | B, C, H, W = x.shape 142 | assert H == self.img_size[0] and W == self.img_size[1], \ 143 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
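# self.proj: four stride-2 3x3 convs (each followed by BatchNorm and ReLU) plus a final 1x1 conv, mapping (B, 3, H, W) -> (B, embed_dim, H/16, W/16).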
144 | x = self.proj(x) 145 | if self.flatten: 146 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 147 | x = self.norm(x) 148 | return x 149 | 150 | 151 | def vit_small(**kwargs): 152 | model = VisionTransformerMoCo( 153 | img_size=256, 154 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 155 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 156 | model.default_cfg = _cfg() 157 | return model 158 | 159 | def vit_base(**kwargs): 160 | model = VisionTransformerMoCo( 161 | img_size=256, 162 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 163 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 164 | model.default_cfg = _cfg() 165 | return model 166 | 167 | def vit_large(**kwargs): 168 | model = VisionTransformerMoCo( 169 | img_size=256, 170 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 171 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 172 | model.default_cfg = _cfg() 173 | return model 174 | 175 | def vit_conv_small(**kwargs): 176 | # minus one ViT block 177 | model = VisionTransformerMoCo( 178 | patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 179 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 180 | model.default_cfg = _cfg() 181 | return model 182 | 183 | def vit_conv_base(**kwargs): 184 | # minus one ViT block 185 | model = VisionTransformerMoCo( 186 | patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 187 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 188 | model.default_cfg = _cfg() 189 | return model 190 | 191 | def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True): 192 | mlp = [] 193 | for l in range(num_layers): 194 | dim1 = input_dim if l == 0 else mlp_dim 195 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 196 | 197 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 198 | 199 | if l < num_layers - 1: 200 | mlp.append(nn.BatchNorm1d(dim2)) 201 | mlp.append(nn.ReLU(inplace=True)) 202 | elif last_bn: 203 | # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 204 | # for simplicity, we further removed gamma in BN 205 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 206 | 207 | return nn.Sequential(*mlp) -------------------------------------------------------------------------------- /models/sit.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | # -------------------------------------------------------- 4 | # References: 5 | # GLIDE: https://github.com/openai/glide-text2im 6 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | import math 13 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 14 | 15 | 16 | def build_mlp(hidden_size, projector_dim, z_dim): 17 | return nn.Sequential( 18 | nn.Linear(hidden_size, projector_dim), 19 | nn.SiLU(), 20 | nn.Linear(projector_dim, projector_dim), 21 | nn.SiLU(), 22 | nn.Linear(projector_dim, z_dim), 23 | ) 24 | 25 | def modulate(x, shift, scale): 26 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 27 | 28 | ################################################################################# 29 | # Embedding Layers for Timesteps and Class Labels # 30 | ################################################################################# 31 | class TimestepEmbedder(nn.Module): 32 | """ 33 | Embeds scalar timesteps into vector representations. 34 | """ 35 | def __init__(self, hidden_size, frequency_embedding_size=256): 36 | super().__init__() 37 | self.mlp = nn.Sequential( 38 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 39 | nn.SiLU(), 40 | nn.Linear(hidden_size, hidden_size, bias=True), 41 | ) 42 | self.frequency_embedding_size = frequency_embedding_size 43 | 44 | @staticmethod 45 | def positional_embedding(t, dim, max_period=10000): 46 | """ 47 | Create sinusoidal timestep embeddings. 48 | :param t: a 1-D Tensor of N indices, one per batch element. 49 | These may be fractional. 50 | :param dim: the dimension of the output. 51 | :param max_period: controls the minimum frequency of the embeddings. 52 | :return: an (N, D) Tensor of positional embeddings. 53 | """ 54 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 55 | half = dim // 2 56 | freqs = torch.exp( 57 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 58 | ).to(device=t.device) 59 | args = t[:, None].float() * freqs[None] 60 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 61 | if dim % 2: 62 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 63 | return embedding 64 | 65 | def forward(self, t): 66 | self.timestep_embedding = self.positional_embedding 67 | t_freq = self.timestep_embedding(t, dim=self.frequency_embedding_size).to(t.dtype) 68 | t_emb = self.mlp(t_freq) 69 | return t_emb 70 | 71 | 72 | class LabelEmbedder(nn.Module): 73 | """ 74 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 75 | """ 76 | def __init__(self, num_classes, hidden_size, dropout_prob): 77 | super().__init__() 78 | use_cfg_embedding = dropout_prob > 0 79 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 80 | self.num_classes = num_classes 81 | self.dropout_prob = dropout_prob 82 | 83 | def token_drop(self, labels, force_drop_ids=None): 84 | """ 85 | Drops labels to enable classifier-free guidance. 
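Dropped labels are replaced with the extra index `num_classes`, whose embedding acts as the unconditional (null) token for classifier-free guidance.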
86 | """ 87 | if force_drop_ids is None: 88 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 89 | else: 90 | drop_ids = force_drop_ids == 1 91 | labels = torch.where(drop_ids, self.num_classes, labels) 92 | return labels 93 | 94 | def forward(self, labels, train, force_drop_ids=None): 95 | use_dropout = self.dropout_prob > 0 96 | if (train and use_dropout) or (force_drop_ids is not None): 97 | labels = self.token_drop(labels, force_drop_ids) 98 | embeddings = self.embedding_table(labels) 99 | return embeddings 100 | 101 | 102 | ################################################################################# 103 | # Core SiT Model # 104 | ################################################################################# 105 | 106 | class SiTBlock(nn.Module): 107 | """ 108 | A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 109 | """ 110 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 111 | super().__init__() 112 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 113 | self.attn = Attention( 114 | hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"] 115 | ) 116 | if "fused_attn" in block_kwargs.keys(): 117 | self.attn.fused_attn = block_kwargs["fused_attn"] 118 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 119 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 120 | approx_gelu = lambda: nn.GELU(approximate="tanh") 121 | self.mlp = Mlp( 122 | in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0 123 | ) 124 | self.adaLN_modulation = nn.Sequential( 125 | nn.SiLU(), 126 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 127 | ) 128 | 129 | def forward(self, x, c): 130 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 131 | self.adaLN_modulation(c).chunk(6, dim=-1) 132 | ) 133 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 134 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 135 | 136 | return x 137 | 138 | 139 | class FinalLayer(nn.Module): 140 | """ 141 | The final layer of SiT. 142 | """ 143 | def __init__(self, hidden_size, patch_size, out_channels): 144 | super().__init__() 145 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 146 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 147 | self.adaLN_modulation = nn.Sequential( 148 | nn.SiLU(), 149 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 150 | ) 151 | 152 | def forward(self, x, c): 153 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 154 | x = modulate(self.norm_final(x), shift, scale) 155 | x = self.linear(x) 156 | 157 | return x 158 | 159 | 160 | class SiT(nn.Module): 161 | """ 162 | Diffusion model with a Transformer backbone. 
163 | """ 164 | def __init__( 165 | self, 166 | path_type='edm', 167 | input_size=32, 168 | patch_size=2, 169 | in_channels=4, 170 | hidden_size=1152, 171 | decoder_hidden_size=768, 172 | encoder_depth=8, 173 | depth=28, 174 | num_heads=16, 175 | mlp_ratio=4.0, 176 | class_dropout_prob=0.1, 177 | num_classes=1000, 178 | use_cfg=False, 179 | z_dims=[768], 180 | projector_dim=2048, 181 | **block_kwargs # fused_attn 182 | ): 183 | super().__init__() 184 | self.path_type = path_type 185 | self.in_channels = in_channels 186 | self.out_channels = in_channels 187 | self.patch_size = patch_size 188 | self.num_heads = num_heads 189 | self.use_cfg = use_cfg 190 | self.num_classes = num_classes 191 | self.z_dims = z_dims 192 | self.encoder_depth = encoder_depth 193 | 194 | self.x_embedder = PatchEmbed( 195 | input_size, patch_size, in_channels, hidden_size, bias=True 196 | ) 197 | self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type 198 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 199 | num_patches = self.x_embedder.num_patches 200 | # Will use fixed sin-cos embedding: 201 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 202 | 203 | self.blocks = nn.ModuleList([ 204 | SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth) 205 | ]) 206 | self.projectors = nn.ModuleList([ 207 | build_mlp(hidden_size, projector_dim, z_dim) for z_dim in z_dims 208 | ]) 209 | self.final_layer = FinalLayer(decoder_hidden_size, patch_size, self.out_channels) 210 | self.initialize_weights() 211 | 212 | def initialize_weights(self): 213 | # Initialize transformer layers: 214 | def _basic_init(module): 215 | if isinstance(module, nn.Linear): 216 | torch.nn.init.xavier_uniform_(module.weight) 217 | if module.bias is not None: 218 | nn.init.constant_(module.bias, 0) 219 | self.apply(_basic_init) 220 | 221 | # Initialize (and freeze) pos_embed by sin-cos embedding: 222 | pos_embed = get_2d_sincos_pos_embed( 223 | self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5) 224 | ) 225 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 226 | 227 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 228 | w = self.x_embedder.proj.weight.data 229 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 230 | nn.init.constant_(self.x_embedder.proj.bias, 0) 231 | 232 | # Initialize label embedding table: 233 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 234 | 235 | # Initialize timestep embedding MLP: 236 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 237 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 238 | 239 | # Zero-out adaLN modulation layers in SiT blocks: 240 | for block in self.blocks: 241 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 242 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 243 | 244 | # Zero-out output layers: 245 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 246 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 247 | nn.init.constant_(self.final_layer.linear.weight, 0) 248 | nn.init.constant_(self.final_layer.linear.bias, 0) 249 | 250 | def unpatchify(self, x, patch_size=None): 251 | """ 252 | x: (N, T, patch_size**2 * C) 253 | imgs: (N, C, H, W) 254 | """ 255 | c = self.out_channels 256 | p = self.x_embedder.patch_size[0] if patch_size is None else patch_size 257 | h = w = int(x.shape[1] ** 0.5) 258 | assert h * w == x.shape[1] 259 | 
260 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 261 | x = torch.einsum('nhwpqc->nchpwq', x) 262 | imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) 263 | return imgs 264 | 265 | def forward(self, x, t, y, return_logvar=False): 266 | """ 267 | Forward pass of SiT. 268 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 269 | t: (N,) tensor of diffusion timesteps 270 | y: (N,) tensor of class labels 271 | """ 272 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 273 | N, T, D = x.shape 274 | 275 | # timestep and class embedding 276 | t_embed = self.t_embedder(t) # (N, D) 277 | y = self.y_embedder(y, self.training) # (N, D) 278 | c = t_embed + y # (N, D) 279 | 280 | for i, block in enumerate(self.blocks): 281 | x = block(x, c) # (N, T, D) 282 | if (i + 1) == self.encoder_depth: 283 | zs = [projector(x.reshape(-1, D)).reshape(N, T, -1) for projector in self.projectors] 284 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 285 | x = self.unpatchify(x) # (N, out_channels, H, W) 286 | 287 | return x, zs 288 | 289 | 290 | ################################################################################# 291 | # Sine/Cosine Positional Embedding Functions # 292 | ################################################################################# 293 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 294 | 295 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 296 | """ 297 | grid_size: int of the grid height and width 298 | return: 299 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 300 | """ 301 | grid_h = np.arange(grid_size, dtype=np.float32) 302 | grid_w = np.arange(grid_size, dtype=np.float32) 303 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 304 | grid = np.stack(grid, axis=0) 305 | 306 | grid = grid.reshape([2, 1, grid_size, grid_size]) 307 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 308 | if cls_token and extra_tokens > 0: 309 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 310 | return pos_embed 311 | 312 | 313 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 314 | assert embed_dim % 2 == 0 315 | 316 | # use half of dimensions to encode grid_h 317 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 318 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 319 | 320 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 321 | return emb 322 | 323 | 324 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 325 | """ 326 | embed_dim: output dimension for each position 327 | pos: a list of positions to be encoded: size (M,) 328 | out: (M, D) 329 | """ 330 | assert embed_dim % 2 == 0 331 | omega = np.arange(embed_dim // 2, dtype=np.float64) 332 | omega /= embed_dim / 2. 333 | omega = 1. 
/ 10000**omega # (D/2,) 334 | 335 | pos = pos.reshape(-1) # (M,) 336 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 337 | 338 | emb_sin = np.sin(out) # (M, D/2) 339 | emb_cos = np.cos(out) # (M, D/2) 340 | 341 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 342 | return emb 343 | 344 | 345 | ################################################################################# 346 | # SiT Configs # 347 | ################################################################################# 348 | 349 | def SiT_XL_2(**kwargs): 350 | return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 351 | 352 | def SiT_XL_4(**kwargs): 353 | return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 354 | 355 | def SiT_XL_8(**kwargs): 356 | return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 357 | 358 | def SiT_L_2(**kwargs): 359 | return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 360 | 361 | def SiT_L_4(**kwargs): 362 | return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 363 | 364 | def SiT_L_8(**kwargs): 365 | return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 366 | 367 | def SiT_B_2(**kwargs): 368 | return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=2, num_heads=12, **kwargs) 369 | 370 | def SiT_B_4(**kwargs): 371 | return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=4, num_heads=12, **kwargs) 372 | 373 | def SiT_B_8(**kwargs): 374 | return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=8, num_heads=12, **kwargs) 375 | 376 | def SiT_S_2(**kwargs): 377 | return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 378 | 379 | def SiT_S_4(**kwargs): 380 | return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 381 | 382 | def SiT_S_8(**kwargs): 383 | return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 384 | 385 | 386 | SiT_models = { 387 | 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8, 388 | 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8, 389 | 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8, 390 | 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8, 391 | } 392 | 393 | -------------------------------------------------------------------------------- /preprocessing/README.md: -------------------------------------------------------------------------------- 1 |
# Preprocessing Guide 2 |
3 | 4 | #### Dataset download 5 | 6 | We follow the preprocessing code used in [edm2](https://github.com/NVlabs/edm2). In this code, we made several edits: (1) we removed parts unnecessary for preprocessing, since this code is used only for preprocessing, (2) we use the [-1, 1] range for inputs to the Stable Diffusion VAE (as in DiT and SiT), unlike edm2, which uses the [0, 1] range, and (3) we preprocess to 256x256 (or 512x512) resolution. 7 | 8 | After downloading ImageNet, please run the following scripts (change 256x256 to 512x512 if you want to run experiments at 512x512 resolution): 9 | 10 | ```bash 11 | # Convert raw ImageNet data to a ZIP archive at 256x256 resolution 12 | python dataset_tools.py convert --source=[YOUR_DOWNLOAD_PATH]/ILSVRC/Data/CLS-LOC/train \ 13 | --dest=[TARGET_PATH]/images --resolution=256x256 --transform=center-crop-dhariwal 14 | ``` 15 | 16 | ```bash 17 | # Convert the pixel data to VAE latents 18 | python dataset_tools.py encode --source=[TARGET_PATH]/images \ 19 | --dest=[TARGET_PATH]/vae-sd 20 | ``` 21 | 22 | Here, `YOUR_DOWNLOAD_PATH` is the directory where you downloaded the dataset, and `TARGET_PATH` is the directory where the preprocessed images and the corresponding compressed latent vectors will be saved. This directory is the one you will pass to the experiment scripts. 23 | 24 | ## Acknowledgement 25 | 26 | This code is mainly built upon the [edm2](https://github.com/NVlabs/edm2) repository. 27 | -------------------------------------------------------------------------------- /preprocessing/dataset_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Tool for creating ZIP/PNG based datasets.""" 9 | 10 | from collections.abc import Iterator 11 | from dataclasses import dataclass 12 | import functools 13 | import io 14 | import json 15 | import os 16 | import re 17 | import zipfile 18 | from pathlib import Path 19 | from typing import Callable, Optional, Tuple, Union 20 | import click 21 | import numpy as np 22 | import PIL.Image 23 | import torch 24 | from tqdm import tqdm 25 | 26 | from encoders import StabilityVAEEncoder 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | @dataclass 31 | class ImageEntry: 32 | img: np.ndarray 33 | label: Optional[int] 34 | 35 | #---------------------------------------------------------------------------- 36 | # Parse a 'M,N' or 'MxN' integer tuple. 
37 | # Example: '4x2' returns (4,2) 38 | 39 | def parse_tuple(s: str) -> Tuple[int, int]: 40 | m = re.match(r'^(\d+)[x,](\d+)$', s) 41 | if m: 42 | return int(m.group(1)), int(m.group(2)) 43 | raise click.ClickException(f'cannot parse tuple {s}') 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def maybe_min(a: int, b: Optional[int]) -> int: 48 | if b is not None: 49 | return min(a, b) 50 | return a 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | def file_ext(name: Union[str, Path]) -> str: 55 | return str(name).split('.')[-1] 56 | 57 | #---------------------------------------------------------------------------- 58 | 59 | def is_image_ext(fname: Union[str, Path]) -> bool: 60 | ext = file_ext(fname).lower() 61 | return f'.{ext}' in PIL.Image.EXTENSION 62 | 63 | #---------------------------------------------------------------------------- 64 | 65 | def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]: 66 | input_images = [] 67 | def _recurse_dirs(root: str): # workaround Path().rglob() slowness 68 | with os.scandir(root) as it: 69 | for e in it: 70 | if e.is_file(): 71 | input_images.append(os.path.join(root, e.name)) 72 | elif e.is_dir(): 73 | _recurse_dirs(os.path.join(root, e.name)) 74 | _recurse_dirs(source_dir) 75 | input_images = sorted([f for f in input_images if is_image_ext(f)]) 76 | 77 | arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images} 78 | max_idx = maybe_min(len(input_images), max_images) 79 | 80 | # Load labels. 81 | labels = dict() 82 | meta_fname = os.path.join(source_dir, 'dataset.json') 83 | if os.path.isfile(meta_fname): 84 | with open(meta_fname, 'r') as file: 85 | data = json.load(file)['labels'] 86 | if data is not None: 87 | labels = {x[0]: x[1] for x in data} 88 | 89 | # No labels available => determine from top-level directory names. 90 | if len(labels) == 0: 91 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()} 92 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))} 93 | if len(toplevel_indices) > 1: 94 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()} 95 | 96 | def iterate_images(): 97 | for idx, fname in enumerate(input_images): 98 | img = np.array(PIL.Image.open(fname).convert('RGB')) 99 | yield ImageEntry(img=img, label=labels.get(arch_fnames[fname])) 100 | if idx >= max_idx - 1: 101 | break 102 | return max_idx, iterate_images() 103 | 104 | #---------------------------------------------------------------------------- 105 | 106 | def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]: 107 | with zipfile.ZipFile(source, mode='r') as z: 108 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] 109 | max_idx = maybe_min(len(input_images), max_images) 110 | 111 | # Load labels. 
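# 'dataset.json', if present, stores [filename, class_index] pairs under its 'labels' key (see the convert() docstring below); images without an entry get label None.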
112 | labels = dict() 113 | if 'dataset.json' in z.namelist(): 114 | with z.open('dataset.json', 'r') as file: 115 | data = json.load(file)['labels'] 116 | if data is not None: 117 | labels = {x[0]: x[1] for x in data} 118 | 119 | def iterate_images(): 120 | with zipfile.ZipFile(source, mode='r') as z: 121 | for idx, fname in enumerate(input_images): 122 | with z.open(fname, 'r') as file: 123 | img = np.array(PIL.Image.open(file).convert('RGB')) 124 | yield ImageEntry(img=img, label=labels.get(fname)) 125 | if idx >= max_idx - 1: 126 | break 127 | return max_idx, iterate_images() 128 | 129 | #---------------------------------------------------------------------------- 130 | 131 | def make_transform( 132 | transform: Optional[str], 133 | output_width: Optional[int], 134 | output_height: Optional[int] 135 | ) -> Callable[[np.ndarray], Optional[np.ndarray]]: 136 | def scale(width, height, img): 137 | w = img.shape[1] 138 | h = img.shape[0] 139 | if width == w and height == h: 140 | return img 141 | img = PIL.Image.fromarray(img, 'RGB') 142 | ww = width if width is not None else w 143 | hh = height if height is not None else h 144 | img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS) 145 | return np.array(img) 146 | 147 | def center_crop(width, height, img): 148 | crop = np.min(img.shape[:2]) 149 | img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] 150 | img = PIL.Image.fromarray(img, 'RGB') 151 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 152 | return np.array(img) 153 | 154 | def center_crop_wide(width, height, img): 155 | ch = int(np.round(width * img.shape[0] / img.shape[1])) 156 | if img.shape[1] < width or ch < height: 157 | return None 158 | 159 | img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] 160 | img = PIL.Image.fromarray(img, 'RGB') 161 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 162 | img = np.array(img) 163 | 164 | canvas = np.zeros([width, width, 3], dtype=np.uint8) 165 | canvas[(width - height) // 2 : (width + height) // 2, :] = img 166 | return canvas 167 | 168 | def center_crop_imagenet(image_size: int, arr: np.ndarray): 169 | """ 170 | Center cropping implementation from ADM. 
171 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 172 | """ 173 | pil_image = PIL.Image.fromarray(arr) 174 | while min(*pil_image.size) >= 2 * image_size: 175 | new_size = tuple(x // 2 for x in pil_image.size) 176 | assert len(new_size) == 2 177 | pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX) 178 | 179 | scale = image_size / min(*pil_image.size) 180 | new_size = tuple(round(x * scale) for x in pil_image.size) 181 | assert len(new_size) == 2 182 | pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC) 183 | 184 | arr = np.array(pil_image) 185 | crop_y = (arr.shape[0] - image_size) // 2 186 | crop_x = (arr.shape[1] - image_size) // 2 187 | return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size] 188 | 189 | if transform is None: 190 | return functools.partial(scale, output_width, output_height) 191 | if transform == 'center-crop': 192 | if output_width is None or output_height is None: 193 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + 'transform') 194 | return functools.partial(center_crop, output_width, output_height) 195 | if transform == 'center-crop-wide': 196 | if output_width is None or output_height is None: 197 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform') 198 | return functools.partial(center_crop_wide, output_width, output_height) 199 | if transform == 'center-crop-dhariwal': 200 | if output_width is None or output_height is None: 201 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform') 202 | if output_width != output_height: 203 | raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform') 204 | return functools.partial(center_crop_imagenet, output_width) 205 | assert False, 'unknown transform' 206 | 207 | #---------------------------------------------------------------------------- 208 | 209 | def open_dataset(source, *, max_images: Optional[int]): 210 | if os.path.isdir(source): 211 | return open_image_folder(source, max_images=max_images) 212 | elif os.path.isfile(source): 213 | if file_ext(source) == 'zip': 214 | return open_image_zip(source, max_images=max_images) 215 | else: 216 | raise click.ClickException(f'Only zip archives are supported: {source}') 217 | else: 218 | raise click.ClickException(f'Missing input file or directory: {source}') 219 | 220 | #---------------------------------------------------------------------------- 221 | 222 | def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]: 223 | dest_ext = file_ext(dest) 224 | 225 | if dest_ext == 'zip': 226 | if os.path.dirname(dest) != '': 227 | os.makedirs(os.path.dirname(dest), exist_ok=True) 228 | zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED) 229 | def zip_write_bytes(fname: str, data: Union[bytes, str]): 230 | zf.writestr(fname, data) 231 | return '', zip_write_bytes, zf.close 232 | else: 233 | # If the output folder already exists, check that is is 234 | # empty. 235 | # 236 | # Note: creating the output directory is not strictly 237 | # necessary as folder_write_bytes() also mkdirs, but it's better 238 | # to give an error message earlier in case the dest folder 239 | # somehow cannot be created. 
240 | if os.path.isdir(dest) and len(os.listdir(dest)) != 0: 241 | raise click.ClickException('--dest folder must be empty') 242 | os.makedirs(dest, exist_ok=True) 243 | 244 | def folder_write_bytes(fname: str, data: Union[bytes, str]): 245 | os.makedirs(os.path.dirname(fname), exist_ok=True) 246 | with open(fname, 'wb') as fout: 247 | if isinstance(data, str): 248 | data = data.encode('utf8') 249 | fout.write(data) 250 | return dest, folder_write_bytes, lambda: None 251 | 252 | #---------------------------------------------------------------------------- 253 | 254 | @click.group() 255 | def cmdline(): 256 | '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.''' 257 | if os.environ.get('WORLD_SIZE', '1') != '1': 258 | raise click.ClickException('Distributed execution is not supported.') 259 | 260 | #---------------------------------------------------------------------------- 261 | 262 | @cmdline.command() 263 | @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True) 264 | @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True) 265 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int) 266 | @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-dhariwal'])) 267 | @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple) 268 | 269 | def convert( 270 | source: str, 271 | dest: str, 272 | max_images: Optional[int], 273 | transform: Optional[str], 274 | resolution: Optional[Tuple[int, int]] 275 | ): 276 | """Convert an image dataset into archive format for training. 277 | 278 | Specifying the input images: 279 | 280 | \b 281 | --source path/ Recursively load all images from path/ 282 | --source dataset.zip Load all images from dataset.zip 283 | 284 | Specifying the output format and path: 285 | 286 | \b 287 | --dest /path/to/dir Save output files under /path/to/dir 288 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip 289 | 290 | The output dataset format can be either an image folder or an uncompressed zip archive. 291 | Zip archives make it easier to move datasets around file servers and clusters, and may 292 | offer better training performance on network file systems. 293 | 294 | Images within the dataset archive will be stored as uncompressed PNG. 295 | Uncompressed PNGs can be efficiently decoded in the training loop. 296 | 297 | Class labels are stored in a file called 'dataset.json' that is stored at the 298 | dataset root folder. This file has the following structure: 299 | 300 | \b 301 | { 302 | "labels": [ 303 | ["00000/img00000000.png",6], 304 | ["00000/img00000001.png",9], 305 | ... repeated for every image in the dataset 306 | ["00049/img00049999.png",1] 307 | ] 308 | } 309 | 310 | If the 'dataset.json' file cannot be found, class labels are determined from 311 | top-level directory names. 312 | 313 | Image scale/crop and resolution requirements: 314 | 315 | Output images must be square-shaped and they must all have the same power-of-two 316 | dimensions. 317 | 318 | To scale arbitrary input image sizes to a specific width and height, use the 319 | --resolution option. Output resolution will be either the original 320 | input resolution (if resolution was not specified) or the one specified with 321 | --resolution option. 
322 | 323 | The --transform=center-crop-dhariwal selects a crop/rescale mode that is intended 324 | to exactly match the results obtained for ImageNet in common diffusion model literature: 325 | 326 | \b 327 | python dataset_tools.py convert --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \\ 328 | --dest=datasets/img64.zip --resolution=64x64 --transform=center-crop-dhariwal 329 | """ 330 | PIL.Image.init() 331 | if dest == '': 332 | raise click.ClickException('--dest output filename or directory must not be an empty string') 333 | 334 | num_files, input_iter = open_dataset(source, max_images=max_images) 335 | archive_root_dir, save_bytes, close_dest = open_dest(dest) 336 | transform_image = make_transform(transform, *resolution if resolution is not None else (None, None)) 337 | dataset_attrs = None 338 | 339 | labels = [] 340 | for idx, image in tqdm(enumerate(input_iter), total=num_files): 341 | idx_str = f'{idx:08d}' 342 | archive_fname = f'{idx_str[:5]}/img{idx_str}.png' 343 | 344 | # Apply crop and resize. 345 | img = transform_image(image.img) 346 | if img is None: 347 | continue 348 | 349 | # Error check to require uniform image attributes across 350 | # the whole dataset. 351 | assert img.ndim == 3 352 | cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0]} 353 | if dataset_attrs is None: 354 | dataset_attrs = cur_image_attrs 355 | width = dataset_attrs['width'] 356 | height = dataset_attrs['height'] 357 | if width != height: 358 | raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}') 359 | if width != 2 ** int(np.floor(np.log2(width))): 360 | raise click.ClickException('Image width/height after scale and crop are required to be power-of-two') 361 | elif dataset_attrs != cur_image_attrs: 362 | err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] 363 | raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err)) 364 | 365 | # Save the image as an uncompressed PNG. 
366 | img = PIL.Image.fromarray(img) 367 | image_bits = io.BytesIO() 368 | img.save(image_bits, format='png', compress_level=0, optimize=False) 369 | save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer()) 370 | labels.append([archive_fname, image.label] if image.label is not None else None) 371 | 372 | metadata = {'labels': labels if all(x is not None for x in labels) else None} 373 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) 374 | close_dest() 375 | 376 | #---------------------------------------------------------------------------- 377 | 378 | @cmdline.command() 379 | @click.option('--model-url', help='VAE encoder model', metavar='URL', type=str, default='stabilityai/sd-vae-ft-mse', show_default=True) 380 | @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True) 381 | @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True) 382 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int) 383 | 384 | def encode( 385 | model_url: str, 386 | source: str, 387 | dest: str, 388 | max_images: Optional[int], 389 | ): 390 | """Encode pixel data to VAE latents.""" 391 | PIL.Image.init() 392 | if dest == '': 393 | raise click.ClickException('--dest output filename or directory must not be an empty string') 394 | 395 | vae = StabilityVAEEncoder(vae_name=model_url, batch_size=1) 396 | num_files, input_iter = open_dataset(source, max_images=max_images) 397 | archive_root_dir, save_bytes, close_dest = open_dest(dest) 398 | labels = [] 399 | 400 | for idx, image in tqdm(enumerate(input_iter), total=num_files): 401 | img_tensor = torch.tensor(image.img).to('cuda').permute(2, 0, 1).unsqueeze(0) 402 | mean_std = vae.encode_pixels(img_tensor)[0].cpu() 403 | idx_str = f'{idx:08d}' 404 | archive_fname = f'{idx_str[:5]}/img-mean-std-{idx_str}.npy' 405 | 406 | f = io.BytesIO() 407 | np.save(f, mean_std) 408 | save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue()) 409 | labels.append([archive_fname, image.label] if image.label is not None else None) 410 | 411 | metadata = {'labels': labels if all(x is not None for x in labels) else None} 412 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) 413 | close_dest() 414 | 415 | if __name__ == "__main__": 416 | cmdline() 417 | 418 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /preprocessing/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /preprocessing/dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 
5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import tempfile 27 | import urllib 28 | import urllib.parse 29 | import uuid 30 | 31 | from typing import Any, Callable, BinaryIO, List, Tuple, Union, Optional 32 | 33 | # Util classes 34 | # ------------------------------------------------------------------------------------------ 35 | 36 | 37 | class EasyDict(dict): 38 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 39 | 40 | def __getattr__(self, name: str) -> Any: 41 | try: 42 | return self[name] 43 | except KeyError: 44 | raise AttributeError(name) 45 | 46 | def __setattr__(self, name: str, value: Any) -> None: 47 | self[name] = value 48 | 49 | def __delattr__(self, name: str) -> None: 50 | del self[name] 51 | 52 | 53 | class Logger(object): 54 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 55 | 56 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): 57 | self.file = None 58 | 59 | if file_name is not None: 60 | self.file = open(file_name, file_mode) 61 | 62 | self.should_flush = should_flush 63 | self.stdout = sys.stdout 64 | self.stderr = sys.stderr 65 | 66 | sys.stdout = self 67 | sys.stderr = self 68 | 69 | def __enter__(self) -> "Logger": 70 | return self 71 | 72 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 73 | self.close() 74 | 75 | def write(self, text: Union[str, bytes]) -> None: 76 | """Write text to stdout (and a file) and optionally flush.""" 77 | if isinstance(text, bytes): 78 | text = text.decode() 79 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 80 | return 81 | 82 | if self.file is not None: 83 | self.file.write(text) 84 | 85 | self.stdout.write(text) 86 | 87 | if self.should_flush: 88 | self.flush() 89 | 90 | def flush(self) -> None: 91 | """Flush written text to both stdout and a file, if open.""" 92 | if self.file is not None: 93 | self.file.flush() 94 | 95 | self.stdout.flush() 96 | 97 | def close(self) -> None: 98 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 99 | self.flush() 100 | 101 | # if using multiple loggers, prevent closing in wrong order 102 | if sys.stdout is self: 103 | sys.stdout = self.stdout 104 | if sys.stderr is self: 105 | sys.stderr = self.stderr 106 | 107 | if self.file is not None: 108 | self.file.close() 109 | self.file = None 110 | 111 | 112 | # Cache directories 113 | # ------------------------------------------------------------------------------------------ 114 | 115 | _dnnlib_cache_dir = None 116 | 117 | def set_cache_dir(path: str) -> None: 118 | global _dnnlib_cache_dir 119 | _dnnlib_cache_dir = path 120 | 121 | def make_cache_dir_path(*paths: str) -> str: 122 | if _dnnlib_cache_dir is not None: 123 | return os.path.join(_dnnlib_cache_dir, *paths) 124 | if 'DNNLIB_CACHE_DIR' in os.environ: 125 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 126 | if 'HOME' in os.environ: 
127 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 128 | if 'USERPROFILE' in os.environ: 129 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 130 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 131 | 132 | # Small util functions 133 | # ------------------------------------------------------------------------------------------ 134 | 135 | 136 | def format_time(seconds: Union[int, float]) -> str: 137 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 138 | s = int(np.rint(seconds)) 139 | 140 | if s < 60: 141 | return "{0}s".format(s) 142 | elif s < 60 * 60: 143 | return "{0}m {1:02}s".format(s // 60, s % 60) 144 | elif s < 24 * 60 * 60: 145 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 146 | else: 147 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 148 | 149 | 150 | def format_time_brief(seconds: Union[int, float]) -> str: 151 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 152 | s = int(np.rint(seconds)) 153 | 154 | if s < 60: 155 | return "{0}s".format(s) 156 | elif s < 60 * 60: 157 | return "{0}m {1:02}s".format(s // 60, s % 60) 158 | elif s < 24 * 60 * 60: 159 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 160 | else: 161 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 162 | 163 | 164 | def tuple_product(t: Tuple) -> Any: 165 | """Calculate the product of the tuple elements.""" 166 | result = 1 167 | 168 | for v in t: 169 | result *= v 170 | 171 | return result 172 | 173 | 174 | _str_to_ctype = { 175 | "uint8": ctypes.c_ubyte, 176 | "uint16": ctypes.c_uint16, 177 | "uint32": ctypes.c_uint32, 178 | "uint64": ctypes.c_uint64, 179 | "int8": ctypes.c_byte, 180 | "int16": ctypes.c_int16, 181 | "int32": ctypes.c_int32, 182 | "int64": ctypes.c_int64, 183 | "float32": ctypes.c_float, 184 | "float64": ctypes.c_double 185 | } 186 | 187 | 188 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 189 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 190 | type_str = None 191 | 192 | if isinstance(type_obj, str): 193 | type_str = type_obj 194 | elif hasattr(type_obj, "__name__"): 195 | type_str = type_obj.__name__ 196 | elif hasattr(type_obj, "name"): 197 | type_str = type_obj.name 198 | else: 199 | raise RuntimeError("Cannot infer type name from input") 200 | 201 | assert type_str in _str_to_ctype.keys() 202 | 203 | my_dtype = np.dtype(type_str) 204 | my_ctype = _str_to_ctype[type_str] 205 | 206 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 207 | 208 | return my_dtype, my_ctype 209 | 210 | 211 | def is_pickleable(obj: Any) -> bool: 212 | try: 213 | with io.BytesIO() as stream: 214 | pickle.dump(obj, stream) 215 | return True 216 | except: 217 | return False 218 | 219 | 220 | # Functionality to import modules/objects by name, and call functions by name 221 | # ------------------------------------------------------------------------------------------ 222 | 223 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 224 | """Searches for the underlying module behind the name to some python object. 
225 | Returns the module and the object name (original name with module part removed).""" 226 | 227 | # allow convenience shorthands, substitute them by full names 228 | obj_name = re.sub("^np.", "numpy.", obj_name) 229 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 230 | 231 | # list alternatives for (module_name, local_obj_name) 232 | parts = obj_name.split(".") 233 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 234 | 235 | # try each alternative in turn 236 | for module_name, local_obj_name in name_pairs: 237 | try: 238 | module = importlib.import_module(module_name) # may raise ImportError 239 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 240 | return module, local_obj_name 241 | except: 242 | pass 243 | 244 | # maybe some of the modules themselves contain errors? 245 | for module_name, _local_obj_name in name_pairs: 246 | try: 247 | importlib.import_module(module_name) # may raise ImportError 248 | except ImportError: 249 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 250 | raise 251 | 252 | # maybe the requested attribute is missing? 253 | for module_name, local_obj_name in name_pairs: 254 | try: 255 | module = importlib.import_module(module_name) # may raise ImportError 256 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 257 | except ImportError: 258 | pass 259 | 260 | # we are out of luck, but we have no idea why 261 | raise ImportError(obj_name) 262 | 263 | 264 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 265 | """Traverses the object name and returns the last (rightmost) python object.""" 266 | if obj_name == '': 267 | return module 268 | obj = module 269 | for part in obj_name.split("."): 270 | obj = getattr(obj, part) 271 | return obj 272 | 273 | 274 | def get_obj_by_name(name: str) -> Any: 275 | """Finds the python object with the given name.""" 276 | module, obj_name = get_module_from_obj_name(name) 277 | return get_obj_from_module(module, obj_name) 278 | 279 | 280 | def call_func_by_name(*args, func_name: Union[str, Callable], **kwargs) -> Any: 281 | """Finds the python object with the given name and calls it as a function.""" 282 | assert func_name is not None 283 | func_obj = get_obj_by_name(func_name) if isinstance(func_name, str) else func_name 284 | assert callable(func_obj) 285 | return func_obj(*args, **kwargs) 286 | 287 | 288 | def construct_class_by_name(*args, class_name: Union[str, type], **kwargs) -> Any: 289 | """Finds the python class with the given name and constructs it with the given arguments.""" 290 | return call_func_by_name(*args, func_name=class_name, **kwargs) 291 | 292 | 293 | def get_module_dir_by_obj_name(obj_name: str) -> str: 294 | """Get the directory path of the module containing the given object name.""" 295 | module, _ = get_module_from_obj_name(obj_name) 296 | return os.path.dirname(inspect.getfile(module)) 297 | 298 | 299 | def is_top_level_function(obj: Any) -> bool: 300 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 301 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 302 | 303 | 304 | def get_top_level_function_name(obj: Any) -> str: 305 | """Return the fully-qualified name of a top-level function.""" 306 | assert is_top_level_function(obj) 307 | module = obj.__module__ 308 | if module == '__main__': 309 | fname = sys.modules[module].__file__ 310 | assert fname is not None 
311 | module = os.path.splitext(os.path.basename(fname))[0] 312 | return module + "." + obj.__name__ 313 | 314 | 315 | # File system helpers 316 | # ------------------------------------------------------------------------------------------ 317 | 318 | def list_dir_recursively_with_ignore(dir_path: str, ignores: Optional[List[str]] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 319 | """List all files recursively in a given directory while ignoring given file and directory names. 320 | Returns list of tuples containing both absolute and relative paths.""" 321 | assert os.path.isdir(dir_path) 322 | base_name = os.path.basename(os.path.normpath(dir_path)) 323 | 324 | if ignores is None: 325 | ignores = [] 326 | 327 | result = [] 328 | 329 | for root, dirs, files in os.walk(dir_path, topdown=True): 330 | for ignore_ in ignores: 331 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 332 | 333 | # dirs need to be edited in-place 334 | for d in dirs_to_remove: 335 | dirs.remove(d) 336 | 337 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 338 | 339 | absolute_paths = [os.path.join(root, f) for f in files] 340 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 341 | 342 | if add_base_to_relative: 343 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 344 | 345 | assert len(absolute_paths) == len(relative_paths) 346 | result += zip(absolute_paths, relative_paths) 347 | 348 | return result 349 | 350 | 351 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 352 | """Takes in a list of tuples of (src, dst) paths and copies files. 353 | Will create all necessary directories.""" 354 | for file in files: 355 | target_dir_name = os.path.dirname(file[1]) 356 | 357 | # will create all intermediate-level directories 358 | os.makedirs(target_dir_name, exist_ok=True) 359 | shutil.copyfile(file[0], file[1]) 360 | 361 | 362 | # URL helpers 363 | # ------------------------------------------------------------------------------------------ 364 | 365 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 366 | """Determine whether the given object is a valid URL string.""" 367 | if not isinstance(obj, str) or not "://" in obj: 368 | return False 369 | if allow_file_urls and obj.startswith('file://'): 370 | return True 371 | try: 372 | res = urllib.parse.urlparse(obj) 373 | if not res.scheme or not res.netloc or not "." in res.netloc: 374 | return False 375 | res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/")) 376 | if not res.scheme or not res.netloc or not "." in res.netloc: 377 | return False 378 | except: 379 | return False 380 | return True 381 | 382 | # Note on static typing: a better API would be to split 'open_url' to 'openl_url' and 383 | # 'download_url' with separate return types (BinaryIO, str). As the `return_filename=True` 384 | # case is somewhat uncommon, we just pretend like this function never returns a string 385 | # and type ignore return value for those cases. 386 | def open_url(url: str, cache_dir: Optional[str] = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> BinaryIO: 387 | """Download the given URL and return a binary-mode file object to access the data.""" 388 | assert num_attempts >= 1 389 | assert not (return_filename and (not cache)) 390 | 391 | # Doesn't look like an URL scheme so interpret it as a local filename. 
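# (e.g. a plain local path such as './data/images.zip' is opened directly, with no caching or download)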
392 | if not re.match('^[a-z]+://', url): 393 | return url if return_filename else open(url, "rb") # type: ignore 394 | 395 | # Handle file URLs. This code handles unusual file:// patterns that 396 | # arise on Windows: 397 | # 398 | # file:///c:/foo.txt 399 | # 400 | # which would translate to a local '/c:/foo.txt' filename that's 401 | # invalid. Drop the forward slash for such pathnames. 402 | # 403 | # If you touch this code path, you should test it on both Linux and 404 | # Windows. 405 | # 406 | # Some internet resources suggest using urllib.request.url2pathname() 407 | # but that converts forward slashes to backslashes and this causes 408 | # its own set of problems. 409 | if url.startswith('file://'): 410 | filename = urllib.parse.urlparse(url).path 411 | if re.match(r'^/[a-zA-Z]:', filename): 412 | filename = filename[1:] 413 | return filename if return_filename else open(filename, "rb") # type: ignore 414 | 415 | assert is_url(url) 416 | 417 | # Lookup from cache. 418 | if cache_dir is None: 419 | cache_dir = make_cache_dir_path('downloads') 420 | 421 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 422 | if cache: 423 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 424 | if len(cache_files) == 1: 425 | filename = cache_files[0] 426 | return filename if return_filename else open(filename, "rb") # type: ignore 427 | 428 | # Download. 429 | url_name = None 430 | url_data = None 431 | with requests.Session() as session: 432 | if verbose: 433 | print("Downloading %s ..." % url, end="", flush=True) 434 | for attempts_left in reversed(range(num_attempts)): 435 | try: 436 | with session.get(url) as res: 437 | res.raise_for_status() 438 | if len(res.content) == 0: 439 | raise IOError("No data received") 440 | 441 | if len(res.content) < 8192: 442 | content_str = res.content.decode("utf-8") 443 | if "download_warning" in res.headers.get("Set-Cookie", ""): 444 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 445 | if len(links) == 1: 446 | url = urllib.parse.urljoin(url, links[0]) 447 | raise IOError("Google Drive virus checker nag") 448 | if "Google Drive - Quota exceeded" in content_str: 449 | raise IOError("Google Drive download quota exceeded -- please try again later") 450 | 451 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 452 | url_name = match[1] if match else url 453 | url_data = res.content 454 | if verbose: 455 | print(" done") 456 | break 457 | except KeyboardInterrupt: 458 | raise 459 | except: 460 | if not attempts_left: 461 | if verbose: 462 | print(" failed") 463 | raise 464 | if verbose: 465 | print(".", end="", flush=True) 466 | 467 | assert url_data is not None 468 | 469 | # Save to cache. 470 | if cache: 471 | assert url_name is not None 472 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 473 | safe_name = safe_name[:min(len(safe_name), 128)] 474 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 475 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 476 | os.makedirs(cache_dir, exist_ok=True) 477 | with open(temp_file, "wb") as f: 478 | f.write(url_data) 479 | os.replace(temp_file, cache_file) # atomic 480 | if return_filename: 481 | return cache_file # type: ignore 482 | 483 | # Return data as file object. 
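`open_url()` above transparently handles plain local paths, `file://` URLs (including the Windows drive-letter quirk), and HTTP(S) downloads with retries and an on-disk cache keyed by the MD5 of the URL. A hedged usage sketch; the URL and cache directory are placeholders:

```python
# Illustrative only: the first call downloads, later calls hit the cache.
from dnnlib.util import open_url

url = 'https://example.com/pretrained_weights.pth'       # placeholder
with open_url(url, cache_dir='./cache', num_attempts=3) as f:
    blob = f.read()

# With return_filename=True (and caching enabled) the cached file path is returned instead.
local_path = open_url(url, cache_dir='./cache', return_filename=True)
```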
484 | assert not return_filename 485 | return io.BytesIO(url_data) 486 | -------------------------------------------------------------------------------- /preprocessing/encoders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Converting between pixel and latent representations of image data.""" 9 | 10 | import os 11 | import warnings 12 | import numpy as np 13 | import torch 14 | from torch_utils import persistence 15 | from torch_utils import misc 16 | 17 | warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.') 18 | warnings.filterwarnings('ignore', '`resume_download` is deprecated') 19 | 20 | #---------------------------------------------------------------------------- 21 | # Abstract base class for encoders/decoders that convert back and forth 22 | # between pixel and latent representations of image data. 23 | # 24 | # Logically, "raw pixels" are first encoded into "raw latents" that are 25 | # then further encoded into "final latents". Decoding, on the other hand, 26 | # goes directly from the final latents to raw pixels. The final latents are 27 | # used as inputs and outputs of the model, whereas the raw latents are 28 | # stored in the dataset. This separation provides added flexibility in terms 29 | # of performing just-in-time adjustments, such as data whitening, without 30 | # having to construct a new dataset. 31 | # 32 | # All image data is represented as PyTorch tensors in NCHW order. 33 | # Raw pixels are represented as 3-channel uint8. 34 | 35 | @persistence.persistent_class 36 | class Encoder: 37 | def __init__(self): 38 | pass 39 | 40 | def init(self, device): # force lazy init to happen now 41 | pass 42 | 43 | def __getstate__(self): 44 | return self.__dict__ 45 | 46 | def encode_pixels(self, x): # raw pixels => raw latents 47 | raise NotImplementedError # to be overridden by subclass 48 | #---------------------------------------------------------------------------- 49 | # Pre-trained VAE encoder from Stability AI. 50 | 51 | @persistence.persistent_class 52 | class StabilityVAEEncoder(Encoder): 53 | def __init__(self, 54 | vae_name = 'stabilityai/sd-vae-ft-mse', # Name of the VAE to use. 55 | batch_size = 8, # Batch size to use when running the VAE. 
56 | ): 57 | super().__init__() 58 | self.vae_name = vae_name 59 | self.batch_size = int(batch_size) 60 | self._vae = None 61 | 62 | def init(self, device): # force lazy init to happen now 63 | super().init(device) 64 | if self._vae is None: 65 | self._vae = load_stability_vae(self.vae_name, device=device) 66 | else: 67 | self._vae.to(device) 68 | 69 | def __getstate__(self): 70 | return dict(super().__getstate__(), _vae=None) # do not pickle the vae 71 | 72 | def _run_vae_encoder(self, x): 73 | d = self._vae.encode(x)['latent_dist'] 74 | return torch.cat([d.mean, d.std], dim=1) 75 | 76 | def encode_pixels(self, x): # raw pixels => raw latents 77 | self.init(x.device) 78 | x = x.to(torch.float32) / 127.5 - 1 79 | x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)]) 80 | return x 81 | 82 | #---------------------------------------------------------------------------- 83 | 84 | def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')): 85 | import dnnlib 86 | cache_dir = dnnlib.make_cache_dir_path('diffusers') 87 | os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1' 88 | os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' 89 | os.environ['HF_HOME'] = cache_dir 90 | 91 | import diffusers # pip install diffusers # pyright: ignore [reportMissingImports] 92 | try: 93 | # First try with local_files_only to avoid consulting tfhub metadata if the model is already in cache. 94 | vae = diffusers.models.AutoencoderKL.from_pretrained( 95 | vae_name, cache_dir=cache_dir, local_files_only=True 96 | ) 97 | except: 98 | # Could not load the model from cache; try without local_files_only. 99 | vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir) 100 | return vae.eval().requires_grad_(False).to(device) 101 | 102 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /preprocessing/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /preprocessing/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import re 10 | import socket 11 | import torch 12 | import torch.distributed 13 | from . import training_stats 14 | 15 | _sync_device = None 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def init(): 20 | global _sync_device 21 | 22 | if not torch.distributed.is_initialized(): 23 | # Setup some reasonable defaults for env-based distributed init if 24 | # not set by the running environment. 
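`StabilityVAEEncoder` above converts raw uint8 NCHW images into SD-VAE latents; `encode_pixels` returns the posterior mean and std concatenated along the channel axis so that sampling can be deferred to training time. A rough shape sketch, assuming `diffusers` is installed, a GPU is available, and the snippet is run from the `preprocessing` directory:

```python
# Illustrative only: encode a small batch of raw pixels into raw latents.
import torch
from encoders import StabilityVAEEncoder

enc = StabilityVAEEncoder(vae_name='stabilityai/sd-vae-ft-mse', batch_size=8)
pixels = torch.randint(0, 256, (4, 3, 256, 256), dtype=torch.uint8, device='cuda')
raw_latents = enc.encode_pixels(pixels)
print(raw_latents.shape)   # expected [4, 8, 32, 32]: 4 mean + 4 std channels
```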
25 | if 'MASTER_ADDR' not in os.environ: 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | if 'MASTER_PORT' not in os.environ: 28 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 29 | s.bind(('', 0)) 30 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 31 | os.environ['MASTER_PORT'] = str(s.getsockname()[1]) 32 | s.close() 33 | if 'RANK' not in os.environ: 34 | os.environ['RANK'] = '0' 35 | if 'LOCAL_RANK' not in os.environ: 36 | os.environ['LOCAL_RANK'] = '0' 37 | if 'WORLD_SIZE' not in os.environ: 38 | os.environ['WORLD_SIZE'] = '1' 39 | backend = 'gloo' if os.name == 'nt' else 'nccl' 40 | torch.distributed.init_process_group(backend=backend, init_method='env://') 41 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 42 | 43 | _sync_device = torch.device('cuda') if get_world_size() > 1 else None 44 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device) 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | def get_rank(): 49 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 50 | 51 | #---------------------------------------------------------------------------- 52 | 53 | def get_world_size(): 54 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | def should_stop(): 59 | return False 60 | 61 | #---------------------------------------------------------------------------- 62 | 63 | def should_suspend(): 64 | return False 65 | 66 | #---------------------------------------------------------------------------- 67 | 68 | def request_suspend(): 69 | pass 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | def update_progress(cur, total): 74 | pass 75 | 76 | #---------------------------------------------------------------------------- 77 | 78 | def print0(*args, **kwargs): 79 | if get_rank() == 0: 80 | print(*args, **kwargs) 81 | 82 | #---------------------------------------------------------------------------- 83 | 84 | class CheckpointIO: 85 | def __init__(self, **kwargs): 86 | self._state_objs = kwargs 87 | 88 | def save(self, pt_path, verbose=True): 89 | if verbose: 90 | print0(f'Saving {pt_path} ... ', end='', flush=True) 91 | data = dict() 92 | for name, obj in self._state_objs.items(): 93 | if obj is None: 94 | data[name] = None 95 | elif isinstance(obj, dict): 96 | data[name] = obj 97 | elif hasattr(obj, 'state_dict'): 98 | data[name] = obj.state_dict() 99 | elif hasattr(obj, '__getstate__'): 100 | data[name] = obj.__getstate__() 101 | elif hasattr(obj, '__dict__'): 102 | data[name] = obj.__dict__ 103 | else: 104 | raise ValueError(f'Invalid state object of type {type(obj).__name__}') 105 | if get_rank() == 0: 106 | torch.save(data, pt_path) 107 | if verbose: 108 | print0('done') 109 | 110 | def load(self, pt_path, verbose=True): 111 | if verbose: 112 | print0(f'Loading {pt_path} ... 
', end='', flush=True) 113 | data = torch.load(pt_path, map_location=torch.device('cpu')) 114 | for name, obj in self._state_objs.items(): 115 | if obj is None: 116 | pass 117 | elif isinstance(obj, dict): 118 | obj.clear() 119 | obj.update(data[name]) 120 | elif hasattr(obj, 'load_state_dict'): 121 | obj.load_state_dict(data[name]) 122 | elif hasattr(obj, '__setstate__'): 123 | obj.__setstate__(data[name]) 124 | elif hasattr(obj, '__dict__'): 125 | obj.__dict__.clear() 126 | obj.__dict__.update(data[name]) 127 | else: 128 | raise ValueError(f'Invalid state object of type {type(obj).__name__}') 129 | if verbose: 130 | print0('done') 131 | 132 | def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True): 133 | fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)] 134 | if len(fnames) == 0: 135 | return None 136 | pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1)))) 137 | self.load(pt_path, verbose=verbose) 138 | return pt_path 139 | 140 | #---------------------------------------------------------------------------- 141 | -------------------------------------------------------------------------------- /preprocessing/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import functools 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Re-seed torch & numpy random generators based on the given arguments. 18 | 19 | def set_random_seed(*args): 20 | seed = hash(args) % (1 << 31) 21 | torch.manual_seed(seed) 22 | np.random.seed(seed) 23 | 24 | #---------------------------------------------------------------------------- 25 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 26 | # same constant is used multiple times. 27 | 28 | _constant_cache = dict() 29 | 30 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 31 | value = np.asarray(value) 32 | if shape is not None: 33 | shape = tuple(shape) 34 | if dtype is None: 35 | dtype = torch.get_default_dtype() 36 | if device is None: 37 | device = torch.device('cpu') 38 | if memory_format is None: 39 | memory_format = torch.contiguous_format 40 | 41 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 42 | tensor = _constant_cache.get(key, None) 43 | if tensor is None: 44 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 45 | if shape is not None: 46 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 47 | tensor = tensor.contiguous(memory_format=memory_format) 48 | _constant_cache[key] = tensor 49 | return tensor 50 | 51 | #---------------------------------------------------------------------------- 52 | # Variant of constant() that inherits dtype and device from the given 53 | # reference tensor by default. 
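`CheckpointIO` above bundles heterogeneous training state (modules, optimizers, plain dicts) behind one save/load interface, and `load_latest()` resumes from the highest-numbered `training-state-*.pt` in a run directory. A single-process sketch with illustrative object names, assuming it is run from the `preprocessing` directory:

```python
# Illustrative only: round-trip a module, an optimizer, and a plain dict.
import torch
from torch_utils.distributed import CheckpointIO

net = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(net.parameters())
progress = dict(step=100)

ckpt = CheckpointIO(net=net, opt=opt, progress=progress)
ckpt.save('training-state-000100.pt')   # state_dict() for net/opt, the dict is stored as-is
ckpt.load('training-state-000100.pt')   # restores everything in place
```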
54 | 55 | def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None): 56 | if dtype is None: 57 | dtype = ref.dtype 58 | if device is None: 59 | device = ref.device 60 | return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format) 61 | 62 | #---------------------------------------------------------------------------- 63 | # Cached construction of temporary tensors in pinned CPU memory. 64 | 65 | @functools.lru_cache(None) 66 | def pinned_buf(shape, dtype): 67 | return torch.empty(shape, dtype=dtype).pin_memory() 68 | 69 | #---------------------------------------------------------------------------- 70 | # Symbolic assert. 71 | 72 | try: 73 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 74 | except AttributeError: 75 | symbolic_assert = torch.Assert # 1.7.0 76 | 77 | #---------------------------------------------------------------------------- 78 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 79 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 80 | 81 | @contextlib.contextmanager 82 | def suppress_tracer_warnings(): 83 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 84 | warnings.filters.insert(0, flt) 85 | yield 86 | warnings.filters.remove(flt) 87 | 88 | #---------------------------------------------------------------------------- 89 | # Assert that the shape of a tensor matches the given list of integers. 90 | # None indicates that the size of a dimension is allowed to vary. 91 | # Performs symbolic assertion when used in torch.jit.trace(). 92 | 93 | def assert_shape(tensor, ref_shape): 94 | if tensor.ndim != len(ref_shape): 95 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 96 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 97 | if ref_size is None: 98 | pass 99 | elif isinstance(ref_size, torch.Tensor): 100 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 101 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 102 | elif isinstance(size, torch.Tensor): 103 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 104 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 105 | elif size != ref_size: 106 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 107 | 108 | #---------------------------------------------------------------------------- 109 | # Function decorator that calls torch.autograd.profiler.record_function(). 110 | 111 | def profiled_function(fn): 112 | def decorator(*args, **kwargs): 113 | with torch.autograd.profiler.record_function(fn.__name__): 114 | return fn(*args, **kwargs) 115 | decorator.__name__ = fn.__name__ 116 | return decorator 117 | 118 | #---------------------------------------------------------------------------- 119 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 120 | # indefinitely, shuffling items as it goes. 
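`assert_shape` above accepts `None` for dimensions that are allowed to vary and falls back to a symbolic assert inside `torch.jit.trace()`, while `const_like` builds cached constants matching a reference tensor's dtype and device. A small sketch:

```python
# Illustrative only.
import torch
from torch_utils.misc import assert_shape, const_like

x = torch.randn(8, 4, 32, 32)
assert_shape(x, [None, 4, 32, 32])             # batch dimension may vary
one = const_like(x, 1.0, shape=[1, 4, 1, 1])   # cached constant on x's dtype/device
y = x * one
```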
121 | 122 | class InfiniteSampler(torch.utils.data.Sampler): 123 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0): 124 | assert len(dataset) > 0 125 | assert num_replicas > 0 126 | assert 0 <= rank < num_replicas 127 | warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed') 128 | super().__init__(dataset) 129 | self.dataset_size = len(dataset) 130 | self.start_idx = start_idx + rank 131 | self.stride = num_replicas 132 | self.shuffle = shuffle 133 | self.seed = seed 134 | 135 | def __iter__(self): 136 | idx = self.start_idx 137 | epoch = None 138 | while True: 139 | if epoch != idx // self.dataset_size: 140 | epoch = idx // self.dataset_size 141 | order = np.arange(self.dataset_size) 142 | if self.shuffle: 143 | np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order) 144 | yield int(order[idx % self.dataset_size]) 145 | idx += self.stride 146 | 147 | #---------------------------------------------------------------------------- 148 | # Utilities for operating with torch.nn.Module parameters and buffers. 149 | 150 | def params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.parameters()) + list(module.buffers()) 153 | 154 | def named_params_and_buffers(module): 155 | assert isinstance(module, torch.nn.Module) 156 | return list(module.named_parameters()) + list(module.named_buffers()) 157 | 158 | @torch.no_grad() 159 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 160 | assert isinstance(src_module, torch.nn.Module) 161 | assert isinstance(dst_module, torch.nn.Module) 162 | src_tensors = dict(named_params_and_buffers(src_module)) 163 | for name, tensor in named_params_and_buffers(dst_module): 164 | assert (name in src_tensors) or (not require_all) 165 | if name in src_tensors: 166 | tensor.copy_(src_tensors[name]) 167 | 168 | #---------------------------------------------------------------------------- 169 | # Context manager for easily enabling/disabling DistributedDataParallel 170 | # synchronization. 171 | 172 | @contextlib.contextmanager 173 | def ddp_sync(module, sync): 174 | assert isinstance(module, torch.nn.Module) 175 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 176 | yield 177 | else: 178 | with module.no_sync(): 179 | yield 180 | 181 | #---------------------------------------------------------------------------- 182 | # Check DistributedDataParallel consistency across processes. 183 | 184 | def check_ddp_consistency(module, ignore_regex=None): 185 | assert isinstance(module, torch.nn.Module) 186 | for name, tensor in named_params_and_buffers(module): 187 | fullname = type(module).__name__ + '.' + name 188 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 189 | continue 190 | tensor = tensor.detach() 191 | if tensor.is_floating_point(): 192 | tensor = torch.nan_to_num(tensor) 193 | other = tensor.clone() 194 | torch.distributed.broadcast(tensor=other, src=0) 195 | assert (tensor == other).all(), fullname 196 | 197 | #---------------------------------------------------------------------------- 198 | # Print summary table of module hierarchy. 199 | 200 | @torch.no_grad() 201 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 202 | assert isinstance(module, torch.nn.Module) 203 | assert not isinstance(module, torch.jit.ScriptModule) 204 | assert isinstance(inputs, (tuple, list)) 205 | 206 | # Register hooks. 
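`InfiniteSampler` above yields dataset indices forever, reshuffling once per pass and striding by `num_replicas` so that each rank sees a disjoint slice; paired with a plain `DataLoader` it removes epoch boundaries from the training loop. A sketch with a toy dataset (the dataset and batch size are placeholders):

```python
# Illustrative only: an endless, shuffled loader for rank 0 of 2.
import torch
from torch_utils.misc import InfiniteSampler

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = InfiniteSampler(dataset, rank=0, num_replicas=2, shuffle=True, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4))
for _ in range(3):
    print(next(loader))   # never raises StopIteration
```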
207 | entries = [] 208 | nesting = [0] 209 | def pre_hook(_mod, _inputs): 210 | nesting[0] += 1 211 | def post_hook(mod, _inputs, outputs): 212 | nesting[0] -= 1 213 | if nesting[0] <= max_nesting: 214 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 215 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 216 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 217 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 218 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 219 | 220 | # Run module. 221 | outputs = module(*inputs) 222 | for hook in hooks: 223 | hook.remove() 224 | 225 | # Identify unique outputs, parameters, and buffers. 226 | tensors_seen = set() 227 | for e in entries: 228 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 229 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 230 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 231 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 232 | 233 | # Filter out redundant entries. 234 | if skip_redundant: 235 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 236 | 237 | # Construct table. 238 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 239 | rows += [['---'] * len(rows[0])] 240 | param_total = 0 241 | buffer_total = 0 242 | submodule_names = {mod: name for name, mod in module.named_modules()} 243 | for e in entries: 244 | name = '' if e.mod is module else submodule_names[e.mod] 245 | param_size = sum(t.numel() for t in e.unique_params) 246 | buffer_size = sum(t.numel() for t in e.unique_buffers) 247 | output_shapes = [str(list(t.shape)) for t in e.outputs] 248 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 249 | rows += [[ 250 | name + (':0' if len(e.outputs) >= 2 else ''), 251 | str(param_size) if param_size else '-', 252 | str(buffer_size) if buffer_size else '-', 253 | (output_shapes + ['-'])[0], 254 | (output_dtypes + ['-'])[0], 255 | ]] 256 | for idx in range(1, len(e.outputs)): 257 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 258 | param_total += param_size 259 | buffer_total += buffer_size 260 | rows += [['---'] * len(rows[0])] 261 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 262 | 263 | # Print table. 264 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 265 | print() 266 | for row in rows: 267 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 268 | print() 269 | 270 | #---------------------------------------------------------------------------- 271 | # Tile a batch of images into a 2D grid. 272 | 273 | def tile_images(x, w, h): 274 | assert x.ndim == 4 # NCHW => CHW 275 | return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3]) 276 | 277 | #---------------------------------------------------------------------------- 278 | -------------------------------------------------------------------------------- /preprocessing/torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import functools 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. This feature can be disabled on a per-instance basis 83 | by setting `self._record_init_args = False` in the constructor. 
84 | 85 | A typical use case is to first unpickle a previous instance of a 86 | persistent class, and then upgrade it to use the latest version of 87 | the source code: 88 | 89 | with open('old_pickle.pkl', 'rb') as f: 90 | old_net = pickle.load(f) 91 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 92 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 93 | """ 94 | assert isinstance(orig_class, type) 95 | if is_persistent(orig_class): 96 | return orig_class 97 | 98 | assert orig_class.__module__ in sys.modules 99 | orig_module = sys.modules[orig_class.__module__] 100 | orig_module_src = _module_to_src(orig_module) 101 | 102 | @functools.wraps(orig_class, updated=()) 103 | class Decorator(orig_class): 104 | _orig_module_src = orig_module_src 105 | _orig_class_name = orig_class.__name__ 106 | 107 | def __init__(self, *args, **kwargs): 108 | super().__init__(*args, **kwargs) 109 | record_init_args = getattr(self, '_record_init_args', True) 110 | self._init_args = copy.deepcopy(args) if record_init_args else None 111 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 112 | assert orig_class.__name__ in orig_module.__dict__ 113 | _check_pickleable(self.__reduce__()) 114 | 115 | @property 116 | def init_args(self): 117 | assert self._init_args is not None 118 | return copy.deepcopy(self._init_args) 119 | 120 | @property 121 | def init_kwargs(self): 122 | assert self._init_kwargs is not None 123 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 124 | 125 | def __reduce__(self): 126 | fields = list(super().__reduce__()) 127 | fields += [None] * max(3 - len(fields), 0) 128 | if fields[0] is not _reconstruct_persistent_obj: 129 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 130 | fields[0] = _reconstruct_persistent_obj # reconstruct func 131 | fields[1] = (meta,) # reconstruct args 132 | fields[2] = None # state dict 133 | return tuple(fields) 134 | 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object. 170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. 
I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /preprocessing/torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet? 28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. NaNs and Infs are 58 | ignored. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | square = elems.square() 88 | finite = square.isfinite() 89 | moments = torch.stack([ 90 | finite.sum(dtype=_reduce_dtype), 91 | torch.where(finite, elems, 0).sum(), 92 | torch.where(finite, square, 0).sum(), 93 | ]) 94 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 95 | moments = moments.to(_counter_dtype) 96 | 97 | device = moments.device 98 | if device not in _counters[name]: 99 | _counters[name][device] = torch.zeros_like(moments) 100 | _counters[name][device].add_(moments) 101 | return value 102 | 103 | #---------------------------------------------------------------------------- 104 | 105 | def report0(name, value): 106 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 107 | but ignores any scalars provided by the other processes. 108 | See `report()` for further details. 109 | """ 110 | report(name, value if _rank == 0 else []) 111 | return value 112 | 113 | #---------------------------------------------------------------------------- 114 | 115 | class Collector: 116 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 117 | computes their long-term averages (mean and standard deviation) over 118 | user-defined periods of time. 119 | 120 | The averages are first collected into internal counters that are not 121 | directly visible to the user. They are then copied to the user-visible 122 | state as a result of calling `update()` and can then be queried using 123 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 124 | internal counters for the next round, so that the user-visible state 125 | effectively reflects averages collected between the last two calls to 126 | `update()`. 127 | 128 | Args: 129 | regex: Regular expression defining which statistics to 130 | collect. The default is to collect everything. 131 | keep_previous: Whether to retain the previous averages if no 132 | scalars were collected on a given round 133 | (default: False). 134 | """ 135 | def __init__(self, regex='.*', keep_previous=False): 136 | self._regex = re.compile(regex) 137 | self._keep_previous = keep_previous 138 | self._cumulative = dict() 139 | self._moments = dict() 140 | self.update() 141 | self._moments.clear() 142 | 143 | def names(self): 144 | r"""Returns the names of all statistics broadcasted so far that 145 | match the regular expression specified at construction time. 146 | """ 147 | return [name for name in _counters if self._regex.fullmatch(name)] 148 | 149 | def update(self): 150 | r"""Copies current values of the internal counters to the 151 | user-visible state and resets them for the next round. 152 | 153 | If `keep_previous=True` was specified at construction time, the 154 | operation is skipped for statistics that have received no scalars 155 | since the last update, retaining their previous averages. 156 | 157 | This method performs a number of GPU-to-CPU transfers and one 158 | `torch.distributed.all_reduce()`. It is intended to be called 159 | periodically in the main training loop, typically once every 160 | N training steps. 
161 | """ 162 | if not self._keep_previous: 163 | self._moments.clear() 164 | for name, cumulative in _sync(self.names()): 165 | if name not in self._cumulative: 166 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 167 | delta = cumulative - self._cumulative[name] 168 | self._cumulative[name].copy_(cumulative) 169 | if float(delta[0]) != 0: 170 | self._moments[name] = delta 171 | 172 | def _get_delta(self, name): 173 | r"""Returns the raw moments that were accumulated for the given 174 | statistic between the last two calls to `update()`, or zero if 175 | no scalars were collected. 176 | """ 177 | assert self._regex.fullmatch(name) 178 | if name not in self._moments: 179 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 180 | return self._moments[name] 181 | 182 | def num(self, name): 183 | r"""Returns the number of scalars that were accumulated for the given 184 | statistic between the last two calls to `update()`, or zero if 185 | no scalars were collected. 186 | """ 187 | delta = self._get_delta(name) 188 | return int(delta[0]) 189 | 190 | def mean(self, name): 191 | r"""Returns the mean of the scalars that were accumulated for the 192 | given statistic between the last two calls to `update()`, or NaN if 193 | no scalars were collected. 194 | """ 195 | delta = self._get_delta(name) 196 | if int(delta[0]) == 0: 197 | return float('nan') 198 | return float(delta[1] / delta[0]) 199 | 200 | def std(self, name): 201 | r"""Returns the standard deviation of the scalars that were 202 | accumulated for the given statistic between the last two calls to 203 | `update()`, or NaN if no scalars were collected. 204 | """ 205 | delta = self._get_delta(name) 206 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 207 | return float('nan') 208 | if int(delta[0]) == 1: 209 | return float(0) 210 | mean = float(delta[1] / delta[0]) 211 | raw_var = float(delta[2] / delta[0]) 212 | return np.sqrt(max(raw_var - np.square(mean), 0)) 213 | 214 | def as_dict(self): 215 | r"""Returns the averages accumulated between the last two calls to 216 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 217 | 218 | dnnlib.EasyDict( 219 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 220 | ... 221 | ) 222 | """ 223 | stats = dnnlib.EasyDict() 224 | for name in self.names(): 225 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 226 | return stats 227 | 228 | def __getitem__(self, name): 229 | r"""Convenience getter. 230 | `collector[name]` is a synonym for `collector.mean(name)`. 231 | """ 232 | return self.mean(name) 233 | 234 | #---------------------------------------------------------------------------- 235 | 236 | def _sync(names): 237 | r"""Synchronize the global cumulative counters across devices and 238 | processes. Called internally by `Collector.update()`. 239 | """ 240 | if len(names) == 0: 241 | return [] 242 | global _sync_called 243 | _sync_called = True 244 | 245 | # Check that all ranks have the same set of names. 246 | if _sync_device is not None: 247 | value = hash(tuple(tuple(ord(char) for char in name) for name in names)) 248 | other = torch.as_tensor(value, dtype=torch.int64, device=_sync_device) 249 | torch.distributed.broadcast(tensor=other, src=0) 250 | if value != int(other.cpu()): 251 | raise ValueError('Training statistics are inconsistent between ranks') 252 | 253 | # Collect deltas within current rank. 
254 | deltas = [] 255 | device = _sync_device if _sync_device is not None else torch.device('cpu') 256 | for name in names: 257 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 258 | for counter in _counters[name].values(): 259 | delta.add_(counter.to(device)) 260 | counter.copy_(torch.zeros_like(counter)) 261 | deltas.append(delta) 262 | deltas = torch.stack(deltas) 263 | 264 | # Sum deltas across ranks. 265 | if _sync_device is not None: 266 | torch.distributed.all_reduce(deltas) 267 | 268 | # Update cumulative values. 269 | deltas = deltas.cpu() 270 | for idx, name in enumerate(names): 271 | if name not in _cumulative: 272 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 273 | _cumulative[name].add_(deltas[idx]) 274 | 275 | # Return name-value pairs. 276 | return [(name, _cumulative[name]) for name in names] 277 | 278 | #---------------------------------------------------------------------------- 279 | # Convenience. 280 | 281 | default_collector = Collector() 282 | 283 | #---------------------------------------------------------------------------- 284 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | diffusers 4 | transformers 5 | timm 6 | tqdm 7 | accelerate 8 | wandb 9 | requests 10 | git+https://github.com/openai/CLIP.git 11 | ftfy 12 | regex 13 | einops -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def expand_t_like_x(t, x_cur): 6 | """Function to reshape time t to broadcastable dimension of x 7 | Args: 8 | t: [batch_dim,], time vector 9 | x: [batch_dim,...], data point 10 | """ 11 | dims = [1] * (len(x_cur.size()) - 1) 12 | t = t.view(t.size(0), *dims) 13 | return t 14 | 15 | def get_score_from_velocity(vt, xt, t, path_type="linear"): 16 | """Wrapper function: transfrom velocity prediction model to score 17 | Args: 18 | velocity: [batch_dim, ...] shaped tensor; velocity model output 19 | x: [batch_dim, ...] 
shaped tensor; x_t data point 20 | t: [batch_dim,] time tensor 21 | """ 22 | t = expand_t_like_x(t, xt) 23 | if path_type == "linear": 24 | alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt, device=xt.device) * -1 25 | sigma_t, d_sigma_t = t, torch.ones_like(xt, device=xt.device) 26 | elif path_type == "cosine": 27 | alpha_t = torch.cos(t * np.pi / 2) 28 | sigma_t = torch.sin(t * np.pi / 2) 29 | d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2) 30 | d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2) 31 | else: 32 | raise NotImplementedError 33 | 34 | mean = xt 35 | reverse_alpha_ratio = alpha_t / d_alpha_t 36 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t 37 | score = (reverse_alpha_ratio * vt - mean) / var 38 | 39 | return score 40 | 41 | 42 | def compute_diffusion(t_cur): 43 | return 2 * t_cur 44 | 45 | 46 | def euler_sampler( 47 | model, 48 | latents, 49 | y, 50 | num_steps=20, 51 | heun=False, 52 | cfg_scale=1.0, 53 | guidance_low=0.0, 54 | guidance_high=1.0, 55 | path_type="linear", # not used, just for compatability 56 | ): 57 | # setup conditioning 58 | if cfg_scale > 1.0: 59 | y_null = torch.tensor([1000] * y.size(0), device=y.device) 60 | _dtype = latents.dtype 61 | t_steps = torch.linspace(1, 0, num_steps+1, dtype=torch.float64) 62 | x_next = latents.to(torch.float64) 63 | device = x_next.device 64 | 65 | with torch.no_grad(): 66 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): 67 | x_cur = x_next 68 | if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: 69 | model_input = torch.cat([x_cur] * 2, dim=0) 70 | y_cur = torch.cat([y, y_null], dim=0) 71 | else: 72 | model_input = x_cur 73 | y_cur = y 74 | kwargs = dict(y=y_cur) 75 | time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur 76 | d_cur = model( 77 | model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs 78 | )[0].to(torch.float64) 79 | if cfg_scale > 1. 
and t_cur <= guidance_high and t_cur >= guidance_low: 80 | d_cur_cond, d_cur_uncond = d_cur.chunk(2) 81 | d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) 82 | x_next = x_cur + (t_next - t_cur) * d_cur 83 | if heun and (i < num_steps - 1): 84 | if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: 85 | model_input = torch.cat([x_next] * 2) 86 | y_cur = torch.cat([y, y_null], dim=0) 87 | else: 88 | model_input = x_next 89 | y_cur = y 90 | kwargs = dict(y=y_cur) 91 | time_input = torch.ones(model_input.size(0)).to( 92 | device=model_input.device, dtype=torch.float64 93 | ) * t_next 94 | d_prime = model( 95 | model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs 96 | )[0].to(torch.float64) 97 | if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: 98 | d_prime_cond, d_prime_uncond = d_prime.chunk(2) 99 | d_prime = d_prime_uncond + cfg_scale * (d_prime_cond - d_prime_uncond) 100 | x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime) 101 | 102 | return x_next 103 | 104 | 105 | def euler_maruyama_sampler( 106 | model, 107 | latents, 108 | y, 109 | num_steps=20, 110 | heun=False, # not used, just for compatability 111 | cfg_scale=1.0, 112 | guidance_low=0.0, 113 | guidance_high=1.0, 114 | path_type="linear", 115 | ): 116 | # setup conditioning 117 | if cfg_scale > 1.0: 118 | y_null = torch.tensor([1000] * y.size(0), device=y.device) 119 | 120 | _dtype = latents.dtype 121 | 122 | t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64) 123 | t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)]) 124 | x_next = latents.to(torch.float64) 125 | device = x_next.device 126 | 127 | with torch.no_grad(): 128 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])): 129 | dt = t_next - t_cur 130 | x_cur = x_next 131 | if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: 132 | model_input = torch.cat([x_cur] * 2, dim=0) 133 | y_cur = torch.cat([y, y_null], dim=0) 134 | else: 135 | model_input = x_cur 136 | y_cur = y 137 | kwargs = dict(y=y_cur) 138 | time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur 139 | diffusion = compute_diffusion(t_cur) 140 | eps_i = torch.randn_like(x_cur).to(device) 141 | deps = eps_i * torch.sqrt(torch.abs(dt)) 142 | 143 | # compute drift 144 | v_cur = model( 145 | model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs 146 | )[0].to(torch.float64) 147 | s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type) 148 | d_cur = v_cur - 0.5 * diffusion * s_cur 149 | if cfg_scale > 1. 
and t_cur <= guidance_high and t_cur >= guidance_low: 150 | d_cur_cond, d_cur_uncond = d_cur.chunk(2) 151 | d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) 152 | 153 | x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps 154 | 155 | # last step 156 | t_cur, t_next = t_steps[-2], t_steps[-1] 157 | dt = t_next - t_cur 158 | x_cur = x_next 159 | if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: 160 | model_input = torch.cat([x_cur] * 2, dim=0) 161 | y_cur = torch.cat([y, y_null], dim=0) 162 | else: 163 | model_input = x_cur 164 | y_cur = y 165 | kwargs = dict(y=y_cur) 166 | time_input = torch.ones(model_input.size(0)).to( 167 | device=device, dtype=torch.float64 168 | ) * t_cur 169 | 170 | # compute drift 171 | v_cur = model( 172 | model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs 173 | )[0].to(torch.float64) 174 | s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type) 175 | diffusion = compute_diffusion(t_cur) 176 | d_cur = v_cur - 0.5 * diffusion * s_cur 177 | if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: 178 | d_cur_cond, d_cur_uncond = d_cur.chunk(2) 179 | d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) 180 | 181 | mean_x = x_cur + dt * d_cur 182 | 183 | return mean_x 184 | -------------------------------------------------------------------------------- /samplers_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def expand_t_like_x(t, x_cur): 6 | """Function to reshape time t to broadcastable dimension of x 7 | Args: 8 | t: [batch_dim,], time vector 9 | x: [batch_dim,...], data point 10 | """ 11 | dims = [1] * (len(x_cur.size()) - 1) 12 | t = t.view(t.size(0), *dims) 13 | return t 14 | 15 | def get_score_from_velocity(vt, xt, t, path_type="linear"): 16 | """Wrapper function: transfrom velocity prediction model to score 17 | Args: 18 | velocity: [batch_dim, ...] shaped tensor; velocity model output 19 | x: [batch_dim, ...] 
shaped tensor; x_t data point 20 | t: [batch_dim,] time tensor 21 | """ 22 | t = expand_t_like_x(t, xt) 23 | if path_type == "linear": 24 | alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt, device=xt.device) * -1 25 | sigma_t, d_sigma_t = t, torch.ones_like(xt, device=xt.device) 26 | elif path_type == "cosine": 27 | alpha_t = torch.cos(t * np.pi / 2) 28 | sigma_t = torch.sin(t * np.pi / 2) 29 | d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2) 30 | d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2) 31 | else: 32 | raise NotImplementedError 33 | 34 | mean = xt 35 | reverse_alpha_ratio = alpha_t / d_alpha_t 36 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t 37 | score = (reverse_alpha_ratio * vt - mean) / var 38 | 39 | return score 40 | 41 | 42 | def compute_diffusion(t_cur): 43 | return 2 * t_cur 44 | 45 | 46 | def euler_sampler( 47 | model, 48 | latents, 49 | y, 50 | y_null, 51 | num_steps=20, 52 | heun=False, 53 | cfg_scale=1.0, 54 | guidance_low=0.0, 55 | guidance_high=1.0, 56 | path_type="linear", # not used, just for compatability 57 | ): 58 | # setup conditioning 59 | _dtype = latents.dtype 60 | t_steps = torch.linspace(1, 0, num_steps+1, dtype=torch.float64) 61 | x_next = latents.to(torch.float64) 62 | device = x_next.device 63 | 64 | with torch.no_grad(): 65 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): 66 | x_cur = x_next 67 | if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: 68 | model_input = torch.cat([x_cur] * 2, dim=0) 69 | y_cur = torch.cat([y, y_null], dim=0) 70 | else: 71 | model_input = x_cur 72 | y_cur = y 73 | kwargs = dict(context=y_cur) 74 | time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur 75 | d_cur = model( 76 | model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs 77 | )[0].to(torch.float64) 78 | if cfg_scale > 1. 
79 |                 d_cur_cond, d_cur_uncond = d_cur.chunk(2)
80 |                 d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
81 |             x_next = x_cur + (t_next - t_cur) * d_cur
82 |             if heun and (i < num_steps - 1):
83 |                 if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
84 |                     model_input = torch.cat([x_next] * 2)
85 |                     y_cur = torch.cat([y, y_null], dim=0)
86 |                 else:
87 |                     model_input = x_next
88 |                     y_cur = y
89 |                 kwargs = dict(context=y_cur)
90 |                 time_input = torch.ones(model_input.size(0)).to(
91 |                     device=model_input.device, dtype=torch.float64
92 |                 ) * t_next
93 |                 d_prime = model(
94 |                     model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
95 |                 )[0].to(torch.float64)
96 |                 if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
97 |                     d_prime_cond, d_prime_uncond = d_prime.chunk(2)
98 |                     d_prime = d_prime_uncond + cfg_scale * (d_prime_cond - d_prime_uncond)
99 |                 x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime)
100 | 
101 |     return x_next
102 | 
103 | 
104 | def euler_maruyama_sampler(
105 |     model,
106 |     latents,
107 |     y,
108 |     y_null,
109 |     num_steps=20,
110 |     heun=False, # not used, just for compatibility
111 |     cfg_scale=1.0,
112 |     guidance_low=0.0,
113 |     guidance_high=1.0,
114 |     path_type="linear",
115 | ):
116 |     # setup conditioning
117 |     _dtype = latents.dtype
118 | 
119 |     t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64)
120 |     t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)])
121 |     x_next = latents.to(torch.float64)
122 |     device = x_next.device
123 | 
124 |     with torch.no_grad():
125 |         for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
126 |             dt = t_next - t_cur
127 |             x_cur = x_next
128 |             if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
129 |                 model_input = torch.cat([x_cur] * 2, dim=0)
130 |                 y_cur = torch.cat([y, y_null], dim=0)
131 |             else:
132 |                 model_input = x_cur
133 |                 y_cur = y
134 |             kwargs = dict(context=y_cur)
135 |             time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
136 |             diffusion = compute_diffusion(t_cur)
137 |             eps_i = torch.randn_like(x_cur).to(device)
138 |             deps = eps_i * torch.sqrt(torch.abs(dt))
139 | 
140 |             # compute drift
141 |             v_cur = model(
142 |                 model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
143 |             )[0].to(torch.float64)
144 |             s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
145 |             d_cur = v_cur - 0.5 * diffusion * s_cur
146 |             if cfg_scale > 1.
and t_cur <= guidance_high and t_cur >= guidance_low: 147 | d_cur_cond, d_cur_uncond = d_cur.chunk(2) 148 | d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) 149 | 150 | x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps 151 | 152 | # last step 153 | t_cur, t_next = t_steps[-2], t_steps[-1] 154 | dt = t_next - t_cur 155 | x_cur = x_next 156 | if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: 157 | model_input = torch.cat([x_cur] * 2, dim=0) 158 | y_cur = torch.cat([y, y_null], dim=0) 159 | else: 160 | model_input = x_cur 161 | y_cur = y 162 | kwargs = dict(context=y_cur) 163 | time_input = torch.ones(model_input.size(0)).to( 164 | device=device, dtype=torch.float64 165 | ) * t_cur 166 | 167 | # compute drift 168 | v_cur = model( 169 | model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs 170 | )[0].to(torch.float64) 171 | s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type) 172 | diffusion = compute_diffusion(t_cur) 173 | d_cur = v_cur - 0.5 * diffusion * s_cur 174 | if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: 175 | d_cur_cond, d_cur_uncond = d_cur.chunk(2) 176 | d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) 177 | 178 | mean_x = x_cur + dt * d_cur 179 | 180 | return mean_x -------------------------------------------------------------------------------- /train_t2i.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | from copy import deepcopy 4 | import logging 5 | import os 6 | from pathlib import Path 7 | from collections import OrderedDict 8 | import json 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint 14 | from tqdm.auto import tqdm 15 | from torch.utils.data import DataLoader 16 | 17 | from accelerate import Accelerator 18 | from accelerate.logging import get_logger 19 | from accelerate.utils import ProjectConfiguration, set_seed 20 | 21 | from models.mmdit import MMDiT 22 | from loss import SILoss 23 | from utils import load_encoders 24 | 25 | from dataset import MSCOCO256Features 26 | from diffusers.models import AutoencoderKL 27 | # import wandb_utils 28 | import wandb 29 | import math 30 | from torchvision.utils import make_grid 31 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 32 | from torchvision.transforms import Normalize 33 | 34 | logger = get_logger(__name__) 35 | 36 | CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073) 37 | CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711) 38 | 39 | def preprocess_raw_image(x, enc_type, resolution=256): 40 | if 'clip' in enc_type: 41 | x = x / 255. 42 | x = torch.nn.functional.interpolate(x, 224, mode='bicubic') 43 | x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x) 44 | elif 'mocov3' in enc_type or 'mae' in enc_type: 45 | x = x / 255. 46 | x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) 47 | elif 'dinov2' in enc_type: 48 | x = x / 255. 49 | x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) 50 | x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic') 51 | elif 'dinov1' in enc_type: 52 | x = x / 255. 53 | x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) 54 | elif 'jepa' in enc_type: 55 | x = x / 255. 
56 | x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) 57 | x = torch.nn.functional.interpolate(x, 224, mode='bicubic') 58 | 59 | return x 60 | 61 | 62 | def array2grid(x): 63 | nrow = round(math.sqrt(x.size(0))) 64 | x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1)) 65 | x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 66 | return x 67 | 68 | 69 | @torch.no_grad() 70 | def sample_posterior(moments, latents_scale=1., latents_bias=0.): 71 | device = moments.device 72 | mean, logvar = torch.chunk(moments, 2, dim=1) 73 | logvar = torch.clamp(logvar, -30.0, 20.0) 74 | std = torch.exp(0.5 * logvar) 75 | z = mean + std * torch.randn_like(mean) 76 | z = (z * latents_scale + latents_bias) 77 | return z 78 | 79 | 80 | @torch.no_grad() 81 | def update_ema(ema_model, model, decay=0.9999): 82 | """ 83 | Step the EMA model towards the current model. 84 | """ 85 | ema_params = OrderedDict(ema_model.named_parameters()) 86 | model_params = OrderedDict(model.named_parameters()) 87 | 88 | for name, param in model_params.items(): 89 | name = name.replace("module.", "") 90 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 91 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 92 | 93 | 94 | def create_logger(logging_dir): 95 | """ 96 | Create a logger that writes to a log file and stdout. 97 | """ 98 | logging.basicConfig( 99 | level=logging.INFO, 100 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 101 | datefmt='%Y-%m-%d %H:%M:%S', 102 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 103 | ) 104 | logger = logging.getLogger(__name__) 105 | return logger 106 | 107 | 108 | def requires_grad(model, flag=True): 109 | """ 110 | Set requires_grad flag for all parameters in a model. 
111 | """ 112 | for p in model.parameters(): 113 | p.requires_grad = flag 114 | 115 | 116 | ################################################################################# 117 | # Training Loop # 118 | ################################################################################# 119 | 120 | def main(args): 121 | # set accelerator 122 | logging_dir = Path(args.output_dir, args.logging_dir) 123 | accelerator_project_config = ProjectConfiguration( 124 | project_dir=args.output_dir, logging_dir=logging_dir 125 | ) 126 | 127 | accelerator = Accelerator( 128 | gradient_accumulation_steps=args.gradient_accumulation_steps, 129 | mixed_precision=args.mixed_precision, 130 | log_with=args.report_to, 131 | project_config=accelerator_project_config, 132 | ) 133 | 134 | if accelerator.is_main_process: 135 | os.makedirs(args.output_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) 136 | save_dir = os.path.join(args.output_dir, args.exp_name) 137 | os.makedirs(save_dir, exist_ok=True) 138 | args_dict = vars(args) 139 | # Save to a JSON file 140 | json_dir = os.path.join(save_dir, "args.json") 141 | with open(json_dir, 'w') as f: 142 | json.dump(args_dict, f, indent=4) 143 | checkpoint_dir = f"{save_dir}/checkpoints" # Stores saved model checkpoints 144 | os.makedirs(checkpoint_dir, exist_ok=True) 145 | logger = create_logger(save_dir) 146 | logger.info(f"Experiment directory created at {save_dir}") 147 | device = accelerator.device 148 | if torch.backends.mps.is_available(): 149 | accelerator.native_amp = False 150 | if args.seed is not None: 151 | set_seed(args.seed + accelerator.process_index) 152 | 153 | # Create model: 154 | assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 155 | latent_size = args.resolution // 8 156 | 157 | if args.enc_type != 'None': 158 | encoders, encoder_types, architectures = load_encoders(args.enc_type, device) 159 | else: 160 | encoders, encoder_types, architectures = [None], [None], [None] 161 | z_dims = [encoder.embed_dim for encoder in encoders] if args.enc_type != 'None' else [0] 162 | #block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm} 163 | model = MMDiT( 164 | input_size=latent_size, 165 | z_dims = z_dims, 166 | encoder_depth=args.encoder_depth, 167 | ) 168 | 169 | model = model.to(device) 170 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training 171 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-mse").to(device) 172 | requires_grad(ema, False) 173 | 174 | latents_scale = torch.tensor( 175 | [0.18215, 0.18215, 0.18215, 0.18215] 176 | ).view(1, 4, 1, 1).to(device) 177 | latents_bias = torch.tensor( 178 | [0., 0., 0., 0.] 
179 | ).view(1, 4, 1, 1).to(device) 180 | 181 | # create loss function 182 | loss_fn = SILoss( 183 | prediction=args.prediction, 184 | path_type=args.path_type, 185 | encoders=encoders, 186 | accelerator=accelerator, 187 | latents_scale=latents_scale, 188 | latents_bias=latents_bias, 189 | weighting=args.weighting 190 | ) 191 | if accelerator.is_main_process: 192 | logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 193 | 194 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): 195 | if args.allow_tf32: 196 | torch.backends.cuda.matmul.allow_tf32 = True 197 | torch.backends.cudnn.allow_tf32 = True 198 | 199 | optimizer = torch.optim.AdamW( 200 | model.parameters(), 201 | lr=args.learning_rate, 202 | betas=(args.adam_beta1, args.adam_beta2), 203 | weight_decay=args.adam_weight_decay, 204 | eps=args.adam_epsilon, 205 | ) 206 | 207 | # Setup data: 208 | train_dataset = MSCOCO256Features(path=args.data_dir).train 209 | local_batch_size = int(args.batch_size // accelerator.num_processes) 210 | train_dataloader = DataLoader( 211 | train_dataset, 212 | batch_size=local_batch_size, 213 | shuffle=True, 214 | num_workers=args.num_workers, 215 | pin_memory=True, 216 | drop_last=True 217 | ) 218 | if accelerator.is_main_process: 219 | logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})") 220 | 221 | # Prepare models for training: 222 | update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights 223 | model.train() # important! This enables embedding dropout for classifier-free guidance 224 | ema.eval() # EMA model should always be in eval mode 225 | 226 | # resume: 227 | global_step = 0 228 | if args.resume_step > 0: 229 | ckpt_name = str(args.resume_step).zfill(7) +'.pt' 230 | ckpt = torch.load( 231 | f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}', 232 | map_location='cpu', 233 | ) 234 | model.load_state_dict(ckpt['model']) 235 | ema.load_state_dict(ckpt['ema']) 236 | optimizer.load_state_dict(ckpt['opt']) 237 | global_step = ckpt['steps'] 238 | 239 | model, optimizer, train_dataloader = accelerator.prepare( 240 | model, optimizer, train_dataloader 241 | ) 242 | 243 | if accelerator.is_main_process: 244 | tracker_config = vars(copy.deepcopy(args)) 245 | accelerator.init_trackers( 246 | project_name="REPA", 247 | config=tracker_config, 248 | init_kwargs={ 249 | "wandb": {"name": f"{args.exp_name}"} 250 | }, 251 | ) 252 | 253 | progress_bar = tqdm( 254 | range(0, args.max_train_steps), 255 | initial=global_step, 256 | desc="Steps", 257 | # Only show the progress bar once on each machine. 
258 |         disable=not accelerator.is_local_main_process,
259 |     )
260 | 
261 |     # Caption embeddings used to condition the visualization samples (feel free to change):
262 |     sample_batch_size = 64 // accelerator.num_processes
263 |     _, gt_xs, _ = next(iter(train_dataloader))
264 |     gt_xs = gt_xs[:sample_batch_size]
265 |     gt_xs = sample_posterior(
266 |         gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
267 |     )
268 |     # Create sampling noise:
269 |     xT = torch.randn((sample_batch_size, 4, latent_size, latent_size), device=device)
270 | 
271 |     for epoch in range(args.epochs):
272 |         model.train()
273 |         for raw_image, x, context, raw_captions in train_dataloader:
274 |             if global_step == 0:
275 |                 ys = context[:sample_batch_size].to(device) # hard-coded
276 |             raw_image = raw_image.to(device)
277 |             x = x.squeeze(dim=1).to(device)
278 |             context = context.to(device)
279 |             z = None
280 |             with torch.no_grad():
281 |                 x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
282 |                 zs = []
283 |                 with accelerator.autocast():
284 |                     for encoder, encoder_type, arch in zip(encoders, encoder_types, architectures):
285 |                         raw_image_ = preprocess_raw_image(
286 |                             raw_image, encoder_type, resolution=args.resolution
287 |                         )
288 |                         z = encoder.forward_features(raw_image_)
289 |                         if 'mocov3' in encoder_type: z = z[:, 1:]  # drop the [CLS] token
290 |                         if 'dinov2' in encoder_type: z = z['x_norm_patchtokens']
291 |                         zs.append(z)
292 | 
293 |             with accelerator.accumulate(model):
294 |                 model_kwargs = dict(context=context)
295 |                 loss, proj_loss = loss_fn(model, x, model_kwargs, zs=zs)
296 |                 loss_mean = loss.mean()
297 |                 proj_loss_mean = proj_loss.mean()
298 |                 loss = loss_mean + proj_loss_mean * args.proj_coeff
299 | 
300 |                 ## optimization
301 |                 accelerator.backward(loss)
302 |                 if accelerator.sync_gradients:
303 |                     params_to_clip = model.parameters()
304 |                     grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
305 |                 optimizer.step()
306 |                 optimizer.zero_grad(set_to_none=True)
307 | 
308 |             if accelerator.sync_gradients:
309 |                 update_ema(ema, model)  # update EMA weights once per optimizer step
310 | 
311 |             # bookkeeping: progress bar, checkpointing, and periodic sampling
312 |             if accelerator.sync_gradients:
313 |                 progress_bar.update(1)
314 |                 global_step += 1
315 |                 if global_step % args.checkpointing_steps == 0 and global_step > 0:
316 |                     if accelerator.is_main_process:
317 |                         checkpoint = {
318 |                             "model": model.module.state_dict(),
319 |                             "ema": ema.state_dict(),
320 |                             "opt": optimizer.state_dict(),
321 |                             "args": args,
322 |                             "steps": global_step,
323 |                         }
324 |                         checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
325 |                         torch.save(checkpoint, checkpoint_path)
326 |                         logger.info(f"Saved checkpoint to {checkpoint_path}")
327 | 
328 |                 if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
329 |                     from samplers_t2i import euler_sampler
330 |                     with torch.no_grad():
331 |                         samples = euler_sampler(
332 |                             model,
333 |                             xT,
334 |                             ys,
335 |                             y_null=torch.tensor(
336 |                                 train_dataset.empty_token
337 |                             ).to(device).unsqueeze(0).repeat(ys.shape[0], 1, 1),
338 |                             num_steps=50,
339 |                             cfg_scale=4.0,
340 |                             guidance_low=0.,
341 |                             guidance_high=1.,
342 |                             path_type=args.path_type,
343 |                             heun=False,
344 |                         ).to(torch.float32)
345 |                         samples = vae.decode((samples - latents_bias) / latents_scale).sample
346 |                         gt_samples = vae.decode((gt_xs - latents_bias) / latents_scale).sample
347 |                         samples = (samples + 1) / 2.
348 |                         gt_samples = (gt_samples + 1) / 2.
349 | out_samples = accelerator.gather(samples.to(torch.float32)) 350 | gt_samples = accelerator.gather(gt_samples.to(torch.float32)) 351 | accelerator.log({"samples": wandb.Image(array2grid(out_samples)), 352 | "gt_samples": wandb.Image(array2grid(gt_samples))}) 353 | logging.info("Generating EMA samples done.") 354 | 355 | logs = { 356 | "loss": accelerator.gather(loss_mean).mean().detach().item(), 357 | "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(), 358 | "grad_norm": accelerator.gather(grad_norm).mean().detach().item() 359 | } 360 | progress_bar.set_postfix(**logs) 361 | accelerator.log(logs, step=global_step) 362 | 363 | if global_step >= args.max_train_steps: 364 | break 365 | if global_step >= args.max_train_steps: 366 | break 367 | 368 | model.eval() # important! This disables randomized embedding dropout 369 | # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... 370 | 371 | accelerator.wait_for_everyone() 372 | if accelerator.is_main_process: 373 | logger.info("Done!") 374 | accelerator.end_training() 375 | 376 | def parse_args(input_args=None): 377 | parser = argparse.ArgumentParser(description="Training") 378 | 379 | # logging: 380 | parser.add_argument("--output-dir", type=str, default="exps") 381 | parser.add_argument("--exp-name", type=str, required=True) 382 | parser.add_argument("--logging-dir", type=str, default="logs") 383 | parser.add_argument("--report-to", type=str, default="wandb") 384 | parser.add_argument("--sampling-steps", type=int, default=10000) 385 | parser.add_argument("--resume-step", type=int, default=0) 386 | 387 | # model 388 | parser.add_argument("--encoder-depth", type=int, default=8) 389 | parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True) 390 | parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False) 391 | 392 | # dataset 393 | parser.add_argument("--data-dir", type=str, default="../data/coco256_features") 394 | parser.add_argument("--resolution", type=int, choices=[256, 512], default=256) 395 | parser.add_argument("--batch-size", type=int, default=256) 396 | 397 | # precision 398 | parser.add_argument("--allow-tf32", action="store_true") 399 | parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"]) 400 | 401 | # optimization 402 | parser.add_argument("--epochs", type=int, default=1400) 403 | parser.add_argument("--max-train-steps", type=int, default=400000) 404 | parser.add_argument("--checkpointing-steps", type=int, default=50000) 405 | parser.add_argument("--gradient-accumulation-steps", type=int, default=1) 406 | parser.add_argument("--learning-rate", type=float, default=1e-4) 407 | parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 408 | parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 409 | parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.") 410 | parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 411 | parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.") 412 | 413 | # seed 414 | parser.add_argument("--seed", type=int, default=0) 415 | 416 | # cpu 417 | parser.add_argument("--num-workers", type=int, default=4) 418 | 419 | # loss 420 | parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"]) 
421 |     parser.add_argument("--prediction", type=str, default="v", choices=["v"]) # currently we only support v-prediction
422 |     parser.add_argument("--cfg-prob", type=float, default=0.1)
423 |     parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
424 |     parser.add_argument("--proj-coeff", type=float, default=0.5)
425 |     parser.add_argument("--weighting", default="uniform", type=str, help="Weighting of the diffusion loss over time (default: uniform).")
426 |     parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
427 | 
428 |     if input_args is not None:
429 |         args = parser.parse_args(input_args)
430 |     else:
431 |         args = parser.parse_args()
432 | 
433 |     return args
434 | 
435 | if __name__ == "__main__":
436 |     args = parse_args()
437 | 
438 |     main(args)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision.datasets.utils import download_url
3 | import torch
4 | import torchvision.models as torchvision_models
5 | import timm
6 | from models import mocov3_vit
7 | import math
8 | import warnings
9 | 
10 | 
11 | # code from SiT repository
12 | pretrained_models = {'last.pt'}
13 | 
14 | def download_model(model_name):
15 |     """
16 |     Downloads a pre-trained SiT model from the web.
17 |     """
18 |     assert model_name in pretrained_models
19 |     local_path = f'pretrained_models/{model_name}'
20 |     if not os.path.isfile(local_path):
21 |         os.makedirs('pretrained_models', exist_ok=True)
22 |         web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
23 |         download_url(web_path, 'pretrained_models', filename=model_name)
24 |     model = torch.load(local_path, map_location=lambda storage, loc: storage)
25 |     return model
26 | 
27 | def fix_mocov3_state_dict(state_dict):
28 |     for k in list(state_dict.keys()):
29 |         # retain only base_encoder up to before the embedding layer
30 |         if k.startswith('module.base_encoder'):
31 |             # fix naming bug in checkpoint
32 |             new_k = k[len("module.base_encoder."):]
33 |             if "blocks.13.norm13" in new_k:
34 |                 new_k = new_k.replace("norm13", "norm1")
35 |             if "blocks.13.mlp.fc13" in k:
36 |                 new_k = new_k.replace("fc13", "fc1")
37 |             if "blocks.14.norm14" in k:
38 |                 new_k = new_k.replace("norm14", "norm2")
39 |             if "blocks.14.mlp.fc14" in k:
40 |                 new_k = new_k.replace("fc14", "fc2")
41 |             # remove prefix
42 |             if 'head' not in new_k and new_k.split('.')[0] != 'fc':
43 |                 state_dict[new_k] = state_dict[k]
44 |         # delete renamed or unused k
45 |         del state_dict[k]
46 |     if 'pos_embed' in state_dict.keys():
47 |         state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
48 |             state_dict['pos_embed'], [16, 16],
49 |         )
50 |     return state_dict
51 | 
52 | @torch.no_grad()
53 | def load_encoders(enc_type, device, resolution=256):
54 |     assert (resolution == 256) or (resolution == 512)
55 | 
56 |     enc_names = enc_type.split(',')
57 |     encoders, architectures, encoder_types = [], [], []
58 |     for enc_name in enc_names:
59 |         encoder_type, architecture, model_config = enc_name.split('-')
60 |         # Currently, we only support 512x512 experiments with DINOv2 encoders.
61 |         if resolution == 512:
62 |             if encoder_type != 'dinov2':
63 |                 raise NotImplementedError(
64 |                     "Currently, we only support 512x512 experiments with DINOv2 encoders."
65 | ) 66 | 67 | architectures.append(architecture) 68 | encoder_types.append(encoder_type) 69 | if encoder_type == 'mocov3': 70 | if architecture == 'vit': 71 | if model_config == 's': 72 | encoder = mocov3_vit.vit_small() 73 | elif model_config == 'b': 74 | encoder = mocov3_vit.vit_base() 75 | elif model_config == 'l': 76 | encoder = mocov3_vit.vit_large() 77 | ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth') 78 | state_dict = fix_mocov3_state_dict(ckpt['state_dict']) 79 | del encoder.head 80 | encoder.load_state_dict(state_dict, strict=True) 81 | encoder.head = torch.nn.Identity() 82 | elif architecture == 'resnet': 83 | raise NotImplementedError() 84 | 85 | encoder = encoder.to(device) 86 | encoder.eval() 87 | 88 | elif 'dinov2' in encoder_type: 89 | import timm 90 | if 'reg' in encoder_type: 91 | encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg') 92 | else: 93 | encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14') 94 | del encoder.head 95 | patch_resolution = 16 * (resolution // 256) 96 | encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed( 97 | encoder.pos_embed.data, [patch_resolution, patch_resolution], 98 | ) 99 | encoder.head = torch.nn.Identity() 100 | encoder = encoder.to(device) 101 | encoder.eval() 102 | 103 | elif 'dinov1' == encoder_type: 104 | import timm 105 | from models import dinov1 106 | encoder = dinov1.vit_base() 107 | ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth') 108 | if 'pos_embed' in ckpt.keys(): 109 | ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed( 110 | ckpt['pos_embed'], [16, 16], 111 | ) 112 | del encoder.head 113 | encoder.head = torch.nn.Identity() 114 | encoder.load_state_dict(ckpt, strict=True) 115 | encoder = encoder.to(device) 116 | encoder.forward_features = encoder.forward 117 | encoder.eval() 118 | 119 | elif encoder_type == 'clip': 120 | import clip 121 | from models.clip_vit import UpdatedVisionTransformer 122 | encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual 123 | encoder = UpdatedVisionTransformer(encoder_).to(device) 124 | #.to(device) 125 | encoder.embed_dim = encoder.model.transformer.width 126 | encoder.forward_features = encoder.forward 127 | encoder.eval() 128 | 129 | elif encoder_type == 'mae': 130 | from models.mae_vit import vit_large_patch16 131 | import timm 132 | kwargs = dict(img_size=256) 133 | encoder = vit_large_patch16(**kwargs).to(device) 134 | with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f: 135 | state_dict = torch.load(f) 136 | if 'pos_embed' in state_dict["model"].keys(): 137 | state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed( 138 | state_dict["model"]['pos_embed'], [16, 16], 139 | ) 140 | encoder.load_state_dict(state_dict["model"]) 141 | 142 | encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed( 143 | encoder.pos_embed.data, [16, 16], 144 | ) 145 | 146 | elif encoder_type == 'jepa': 147 | from models.jepa import vit_huge 148 | kwargs = dict(img_size=[224, 224], patch_size=14) 149 | encoder = vit_huge(**kwargs).to(device) 150 | with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f: 151 | state_dict = torch.load(f, map_location=device) 152 | new_state_dict = dict() 153 | for key, value in state_dict['encoder'].items(): 154 | new_state_dict[key[7:]] = value 155 | encoder.load_state_dict(new_state_dict) 156 | encoder.forward_features = encoder.forward 157 | 158 | encoders.append(encoder) 159 | 160 | return encoders, 
encoder_types, architectures 161 | 162 | 163 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 164 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 165 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 166 | def norm_cdf(x): 167 | # Computes standard normal cumulative distribution function 168 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 169 | 170 | if (mean < a - 2 * std) or (mean > b + 2 * std): 171 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 172 | "The distribution of values may be incorrect.", 173 | stacklevel=2) 174 | 175 | with torch.no_grad(): 176 | # Values are generated by using a truncated uniform distribution and 177 | # then using the inverse CDF for the normal distribution. 178 | # Get upper and lower cdf values 179 | l = norm_cdf((a - mean) / std) 180 | u = norm_cdf((b - mean) / std) 181 | 182 | # Uniformly fill tensor with values from [l, u], then translate to 183 | # [2l-1, 2u-1]. 184 | tensor.uniform_(2 * l - 1, 2 * u - 1) 185 | 186 | # Use inverse cdf transform for normal distribution to get truncated 187 | # standard normal 188 | tensor.erfinv_() 189 | 190 | # Transform to proper mean, std 191 | tensor.mul_(std * math.sqrt(2.)) 192 | tensor.add_(mean) 193 | 194 | # Clamp to ensure it's in the proper range 195 | tensor.clamp_(min=a, max=b) 196 | return tensor 197 | 198 | 199 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 200 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 201 | 202 | 203 | def load_legacy_checkpoints(state_dict, encoder_depth): 204 | new_state_dict = dict() 205 | for key, value in state_dict.items(): 206 | if 'decoder_blocks' in key: 207 | parts =key.split('.') 208 | new_idx = int(parts[1]) + encoder_depth 209 | parts[0] = 'blocks' 210 | parts[1] = str(new_idx) 211 | new_key = '.'.join(parts) 212 | new_state_dict[new_key] = value 213 | else: 214 | new_state_dict[key] = value 215 | return new_state_dict --------------------------------------------------------------------------------
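
The reverse-SDE drift in `euler_maruyama_sampler` (in both `samplers.py` and `samplers_t2i.py`) depends on converting the model's velocity prediction into a score via `get_score_from_velocity`. The standalone sketch below is not part of the repository; it assumes only PyTorch and mirrors the linear-path branch of that function to check the algebra numerically: for `x_t = (1 - t) * x0 + t * eps`, the exact velocity is `eps - x0`, and the conversion should recover the conditional score `-eps / t`.

```python
# Standalone sanity check (not part of the repository; assumes only PyTorch).
# It repeats the linear-path algebra of get_score_from_velocity() and verifies
# that an exact velocity is converted into the exact conditional score.
import torch

torch.manual_seed(0)
x0  = torch.randn(4, 3, dtype=torch.float64)   # "clean" data point
eps = torch.randn(4, 3, dtype=torch.float64)   # Gaussian noise
t   = torch.full((4, 1), 0.7, dtype=torch.float64)

# linear path: alpha_t = 1 - t, sigma_t = t (and their time derivatives)
alpha_t, d_alpha_t = 1 - t, -torch.ones_like(t)
sigma_t, d_sigma_t = t, torch.ones_like(t)

xt = alpha_t * x0 + sigma_t * eps              # interpolant x_t
vt = d_alpha_t * x0 + d_sigma_t * eps          # exact velocity = eps - x0

# same algebra as get_score_from_velocity()
reverse_alpha_ratio = alpha_t / d_alpha_t
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
score = (reverse_alpha_ratio * vt - xt) / var

# conditional score of N(alpha_t * x0, sigma_t^2 I) evaluated at x_t
assert torch.allclose(score, -eps / sigma_t)
print("velocity-to-score conversion matches the exact conditional score")
```

The same identity underlies the class-conditional sampler; the two sampler files differ only in the conditioning keyword passed to the model (`y` versus `context`).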