├── src ├── __init__.py ├── get_deltas.py ├── DETEX_sampler.py ├── DETEX_data.py ├── DETEX_modules.py └── model.py ├── .gitignore ├── data ├── dog7 │ ├── 01.png │ ├── 02.png │ ├── 03.png │ └── 04.png ├── dog7_fg │ ├── 01.png │ ├── 02.png │ ├── 03.png │ └── 04.png └── dog7_mask │ ├── 01.png │ ├── 02.png │ ├── 03.png │ └── 04.png ├── assets ├── pipeline.png └── results.png ├── train.sh ├── configs └── DETEX │ └── finetune.yaml ├── README.md ├── sample.py └── train.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | DETEX.code-workspace 2 | -------------------------------------------------------------------------------- /data/dog7/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7/01.png -------------------------------------------------------------------------------- /data/dog7/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7/02.png -------------------------------------------------------------------------------- /data/dog7/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7/03.png -------------------------------------------------------------------------------- /data/dog7/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7/04.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/assets/pipeline.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/assets/results.png -------------------------------------------------------------------------------- /data/dog7_fg/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_fg/01.png -------------------------------------------------------------------------------- /data/dog7_fg/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_fg/02.png -------------------------------------------------------------------------------- /data/dog7_fg/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_fg/03.png -------------------------------------------------------------------------------- /data/dog7_fg/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_fg/04.png -------------------------------------------------------------------------------- /data/dog7_mask/01.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_mask/01.png -------------------------------------------------------------------------------- /data/dog7_mask/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_mask/02.png -------------------------------------------------------------------------------- /data/dog7_mask/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_mask/03.png -------------------------------------------------------------------------------- /data/dog7_mask/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrototypeNx/DETEX/HEAD/data/dog7_mask/04.png -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python -u train.py \ 2 | --base configs/DETEX/finetune.yaml \ 3 | -t --gpus 0,1,2,3 \ 4 | --resume-from-checkpoint-custom /data1/yfcai/sd-v1-4.ckpt \ 5 | --caption " dog with

pose in background" \ 6 | --num_imgs 4 \ 7 | --datapath data/dog7 \ 8 | --reg_datapath data/dog_samples/samples \ 9 | --mask_path data/dog7_fg\ 10 | --mask_path2 data/dog7_mask\ 11 | --reg_caption "dog" \ 12 | --modifier_token "++++++++" \ 13 | --name dog7 -------------------------------------------------------------------------------- /src/get_deltas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import glob 4 | import torch 5 | 6 | os.environ['CUDA_VISIBLE_DEVICES'] ="0" 7 | 8 | def main(path, newtoken=0): 9 | layers = [] 10 | for files in glob.glob(f'{path}/checkpoints/*'): 11 | if ('=' in files or '_' in files) and 'delta' not in files: 12 | print(files) 13 | if '=' in files: 14 | epoch_number = files.split('=')[1].split('.ckpt')[0] 15 | elif '_' in files: 16 | epoch_number = files.split('/')[-1].split('.ckpt')[0] 17 | 18 | st = torch.load(files, map_location='cuda:0')["state_dict"] 19 | if len(layers) == 0: 20 | for key in list(st.keys()): 21 | if 'attn2.to_k' in key or 'attn2.to_v' in key or 'cond_stage_model.mapper' in key: 22 | layers.append(key) 23 | print(layers) 24 | st_delta = {'state_dict': {}} 25 | for each in layers: 26 | st_delta['state_dict'][each] = st[each].clone() 27 | print('/'.join(files.split('/')[:-1]) + f'/delta_epoch={epoch_number}.ckpt') 28 | 29 | num_tokens = st['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'].shape[0] 30 | 31 | if newtoken > 0: 32 | print("saving the optimized embedding") 33 | st_delta['state_dict']['embed'] = st['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'][-newtoken:].clone() 34 | print(st_delta['state_dict']['embed'].shape, num_tokens) 35 | 36 | torch.save(st_delta, '/'.join(files.split('/')[:-1]) + f'/delta_epoch_{epoch_number}.ckpt') 37 | os.remove(files) 38 | 39 | 40 | def parse_args(): 41 | parser = argparse.ArgumentParser('', add_help=False) 42 | parser.add_argument('--path', help='path of folder to checkpoints', 43 | type=str) 44 | parser.add_argument('--newtoken', help='number of new tokens in the checkpoint', default=5, 45 | type=int) 46 | return parser.parse_args() 47 | 48 | 49 | if __name__ == "__main__": 50 | args = parse_args() 51 | path = args.path 52 | main(path, args.newtoken) 53 | -------------------------------------------------------------------------------- /configs/DETEX/finetune.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-05 3 | target: src.model.DETEX 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image" 11 | cond_stage_key: "caption" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: True # Note: different from the one we trained before 15 | add_token: True 16 | freeze_model: "crossattn-kv" 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.18215 20 | use_ema: False 21 | 22 | unet_config: 23 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 24 | params: 25 | image_size: 64 # unused 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_heads: 8 33 | use_spatial_transformer: True 34 | transformer_depth: 1 35 | context_dim: 768 36 | use_checkpoint: False 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: 
ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | double_z: true 46 | z_channels: 4 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 2 54 | - 4 55 | - 4 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | lossconfig: 60 | target: torch.nn.Identity 61 | 62 | cond_stage_config: 63 | target: src.DETEX_modules.CLIPEmbedderWrapper 64 | params: 65 | modifier_token: 66 | num_imgs: 4 67 | 68 | data: 69 | target: train.DataModuleFromConfig 70 | params: 71 | batch_size: 1 72 | num_workers: 1 73 | wrap: false 74 | train: 75 | target: src.DETEX_data.AllBase 76 | params: 77 | size: 512 78 | aug: False 79 | train2: 80 | target: src.DETEX_data.AllBase 81 | params: 82 | size: 512 83 | aug: False 84 | 85 | 86 | lightning: 87 | callbacks: 88 | image_logger: 89 | target: train.ImageLogger 90 | params: 91 | batch_frequency: 5000 92 | max_images: 8 93 | increase_log_steps: False 94 | 95 | trainer: 96 | max_steps: 800 -------------------------------------------------------------------------------- /src/DETEX_sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from functools import partial 6 | 7 | from src.model import ret_attn 8 | from ldm.models.diffusion.ddim import DDIMSampler as DDIMSampler 9 | 10 | class DETEXSampler(DDIMSampler): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__(model, schedule="linear", **kwargs) 13 | 14 | @torch.no_grad() 15 | def ddim_sampling(self, cond, shape, 16 | x_T=None, ddim_use_original_steps=False, 17 | callback=None, timesteps=None, quantize_denoised=False, 18 | mask=None, x0=None, img_callback=None, log_every_t=100, 19 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 20 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 21 | device = self.model.betas.device 22 | b = shape[0] 23 | if x_T is None: 24 | img = torch.randn(shape, device=device) 25 | else: 26 | img = x_T 27 | 28 | if timesteps is None: 29 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 30 | elif timesteps is not None and not ddim_use_original_steps: 31 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 32 | timesteps = self.ddim_timesteps[:subset_end] 33 | 34 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 35 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 36 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 37 | print(f"Running DDIM Sampling with {total_steps} timesteps") 38 | 39 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 40 | 41 | for i, step in enumerate(iterator): 42 | ret_attn.clear() 43 | index = total_steps - i - 1 44 | ts = torch.full((b,), step, device=device, dtype=torch.long) 45 | 46 | if mask is not None: 47 | assert x0 is not None 48 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 49 | img = img_orig * mask + (1. 
- mask) * img 50 | 51 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 52 | quantize_denoised=quantize_denoised, temperature=temperature, 53 | noise_dropout=noise_dropout, score_corrector=score_corrector, 54 | corrector_kwargs=corrector_kwargs, 55 | unconditional_guidance_scale=unconditional_guidance_scale, 56 | unconditional_conditioning=unconditional_conditioning) 57 | img, pred_x0 = outs 58 | if callback: callback(i) 59 | if img_callback: img_callback(pred_x0, i) 60 | 61 | if index % log_every_t == 0 or index == total_steps - 1: 62 | intermediates['x_inter'].append(img) 63 | intermediates['pred_x0'].append(pred_x0) 64 | 65 | return img, intermediates -------------------------------------------------------------------------------- /src/DETEX_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | templates_small = [ 9 | 'Photo of a {}', 10 | ] 11 | 12 | def isimage(path): 13 | if 'png' in path.lower() or 'jpg' in path.lower() or 'jpeg' in path.lower(): 14 | return True 15 | 16 | class AllBase(Dataset): 17 | def __init__(self, 18 | datapath, 19 | mask_path=None, 20 | mask_path2=None, 21 | modifier=None, 22 | reg_datapath=None, 23 | caption=None, 24 | reg_caption=None, 25 | ID_loss_weight=None, 26 | size=512, 27 | interpolation="bicubic", 28 | flip_p=0.5, 29 | aug=False, 30 | style=False, 31 | repeat=0. 32 | 33 | ): 34 | 35 | self.modifier = modifier.split("+") 36 | self.aug = aug 37 | 38 | self.repeat = repeat 39 | self.style = style 40 | self.templates_small = templates_small 41 | if self.style: 42 | self.templates_small = templates_small_style 43 | if os.path.isdir(datapath): 44 | self.image_paths1 = [(os.path.join(datapath, file_path), os.path.join(mask_path, file_path), os.path.join(mask_path2, file_path)) for file_path in sorted(os.listdir(datapath)) if isimage(file_path)] 45 | else: 46 | with open(datapath, "r") as f: 47 | self.image_paths1 = f.read().splitlines() 48 | print(self.image_paths1) 49 | 50 | self._length1 = len(self.image_paths1) 51 | 52 | self.image_paths2 = [] 53 | self._length2 = 0 54 | if reg_datapath is not None: 55 | if os.path.isdir(reg_datapath): 56 | self.image_paths2 = [os.path.join(reg_datapath, file_path) for file_path in os.listdir(reg_datapath) if isimage(file_path)] 57 | else: 58 | with open(reg_datapath, "r") as f: 59 | self.image_paths2 = f.read().splitlines() 60 | self._length2 = len(self.image_paths2) 61 | 62 | self.labels = { 63 | "relative_file_path1_": [x for x in self.image_paths1], 64 | "relative_file_path2_": [x for x in self.image_paths2], 65 | } 66 | 67 | self.size = size 68 | self.interpolation = {"linear": PIL.Image.LINEAR, 69 | "bilinear": PIL.Image.BILINEAR, 70 | "bicubic": PIL.Image.BICUBIC, 71 | "lanczos": PIL.Image.LANCZOS, 72 | }[interpolation] 73 | self.flip = transforms.RandomHorizontalFlip(p=1) 74 | self.caption = caption 75 | 76 | if os.path.exists(self.caption): 77 | self.caption = [x.strip() for x in open(caption, 'r').readlines()] 78 | 79 | self.reg_caption = reg_caption 80 | if os.path.exists(self.reg_caption): 81 | self.reg_caption = [x.strip() for x in open(reg_caption, 'r').readlines()] 82 | 83 | def __len__(self): 84 | if self._length2 > 0: 85 | return 2*self._length2 86 | elif self.repeat > 0: 87 | return self._length1*self.repeat 88 | else: 89 | return self._length1 90 | 91 | 
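# Note on __getitem__ below: indices past the regularization range yield subject
# samples; with ~50% probability the full image is kept and the image-specific
# pose/background tokens are substituted into the caption, otherwise the
# foreground-only image is used and the "in ... background" clause is dropped.
# Each sample also carries the (size // 16)-resolution foreground/background
# attention masks used by the cross-attention loss; regularization samples
# get all-ones masks.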
def __getitem__(self, i): 92 | example = {} 93 | if i > self._length2 or self._length2 == 0: # train data 94 | img_id = i % self._length1 95 | image = Image.open(self.labels["relative_file_path1_"][img_id][0]) 96 | fg_img = Image.open(self.labels["relative_file_path1_"][img_id][1]) 97 | mask = Image.open(self.labels["relative_file_path1_"][img_id][2]) 98 | if isinstance(self.caption, str): 99 | example["caption"] = np.random.choice(self.templates_small).format(self.caption) 100 | else: 101 | example["caption"] = self.caption[i % min(self._length1, len(self.caption)) ] 102 | else: # reg data 103 | image = Image.open(self.labels["relative_file_path2_"][i % self._length2]) 104 | if isinstance(self.reg_caption, str): 105 | example["caption"] = np.random.choice(self.templates_small).format(self.reg_caption) 106 | else: 107 | example["caption"] = self.reg_caption[i % self._length2] 108 | 109 | if not image.mode == "RGB": 110 | image = image.convert("RGB") 111 | 112 | if i > self._length2 or self._length2 == 0: # train data 113 | if np.random.randint(0, 100) < 50: # with background 114 | example["caption"] = example["caption"].replace("

", self.modifier[img_id + 1]) 115 | example["caption"] = example["caption"].replace("", self.modifier[img_id + self._length1 + 1]) 116 | else: # only fg 117 | image = fg_img 118 | if not image.mode == "RGB": 119 | image = image.convert("RGB") 120 | example["caption"] = example["caption"].split(" in")[0] 121 | example["caption"] = example["caption"].replace("

", self.modifier[img_id + 1]) 122 | 123 | image = image.resize((self.size, self.size), resample=self.interpolation) 124 | input_image1 = np.array(image).astype(np.uint8) 125 | input_image1 = (input_image1 / 127.5 - 1.0).astype(np.float32) 126 | fg_attn_mask = mask.resize((self.size // 16, self.size // 16), resample=PIL.Image.NEAREST) 127 | fg_attn_mask = np.array(fg_attn_mask).astype(np.bool_).astype(np.float32) 128 | bg_attn_mask = np.ones((self.size // 16, self.size // 16)) 129 | bg_attn_mask = (~fg_attn_mask.astype(bool)).astype(np.uint8).astype(np.float32) 130 | 131 | mask = np.ones((self.size // 8, self.size // 8)) 132 | 133 | else: # reg data 134 | if self.size is not None: 135 | image = image.resize((self.size, self.size), resample=self.interpolation) 136 | input_image1 = np.array(image).astype(np.uint8) 137 | input_image1 = (input_image1 / 127.5 - 1.0).astype(np.float32) 138 | mask = np.ones((self.size // 8, self.size // 8)) 139 | fg_attn_mask = np.ones((self.size // 16, self.size // 16)) 140 | bg_attn_mask = np.ones((self.size // 16, self.size // 16)) 141 | 142 | example["image"] = input_image1 143 | example["mask"] = mask 144 | 145 | example["fg_attn_mask"] = fg_attn_mask 146 | example["bg_attn_mask"] = bg_attn_mask 147 | 148 | return example 149 | 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decoupled Textual Embeddings for Customized Image Generation 2 | 3 | 4 | 5 | We propose a customized image generation method DETEX that utilizes multiple tokens to alleviate the issue of overfitting and entanglement between the target concept and unrelated information. Our DETEX enables more precise and efficient control over preserving input image content in the generated results during inference by selectively utilizing different tokens. 6 | 7 |

<img src="assets/results.png"> 8 | 9 | 10 | 11 | 12 | ## Method Details 13 | 14 | <img src="assets/pipeline.png"> 15 | 16 | 17 |

18 | 19 | Framework of our DETEX. Left: Our DETEX represents each image with multiple decoupled textual embeddings, $i.e.$, an image-shared subject embedding $v$ and two image-specific subject-unrelated embeddings (pose $v^p_i$ and background $v^b_i$). Right: To learn the target concept, we initialize the subject embedding $v$ as a learnable vector, and adopt two attribute mappers to project the input image into the pose and background embeddings. During training, we jointly finetune the embeddings with the K, V mapping parameters in the cross-attention layers. A cross-attention loss is further introduced to facilitate the disentanglement. 20 | 21 | ## Getting Started 22 | 23 | ### Environment Setup 24 | 25 | ``` 26 | git clone https://github.com/PrototypeNx/DETEX.git 27 | cd DETEX 28 | git clone https://github.com/CompVis/stable-diffusion.git 29 | cd stable-diffusion 30 | conda env create -f environment.yaml 31 | conda activate ldm 32 | pip install clip-retrieval tqdm 33 | ``` 34 | 35 | Our code was developed on commit `#21f890f9da3cfbeaba8e2ac3c425ee9e998d5229` of [stable-diffusion](https://github.com/CompVis/stable-diffusion). Download the stable-diffusion model checkpoint: 36 | `wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt` 37 | 38 | The pretrained CLIP model can be downloaded automatically. If that doesn't work, you can download the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) manually and place it in the appropriate [config](./configs/clip-vit-large-patch14) folder. 39 | 40 | ### Preparing Dataset 41 | 42 | We provide some processed example images in [data](./data), which contains the original images and the corresponding processed foreground images and masks mentioned in the paper. 43 | 44 | For a custom dataset, you should prepare the original images `SubjectName` belonging to a specific concept, the corresponding masks `SubjectName_mask`, and the corresponding foreground images `SubjectName_fg`. Note that the mask and foreground image files should have the same file name `xxx0n.png` as their corresponding original image. 45 | 46 | We recommend using [SAM](https://github.com/facebookresearch/segment-anything) to easily obtain the foreground masks and the corresponding foreground images. 47 | 48 | In addition, it is necessary to prepare a regularization dataset which contains images belonging to the same category as the input subject. You can retrieve the images from the web or just generate them with vanilla SD using a prompt like `'Photo of a '`. We recommend preparing at least 200 regularization images for each category to achieve better performance. More details about regularization can be found in [DreamBooth](https://arxiv.org/abs/2208.12242). 
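If you prefer to generate the regularization images rather than retrieve them, the snippet below is a minimal sketch of one possible way to do so with the `diffusers` library. It is not part of the DETEX codebase; the category name and output folder are placeholders that should be adapted to your subject (the output folder should match `--reg_datapath`).

```
# Minimal sketch (assumes the diffusers library; not part of this repo).
# Generates ~200 class images with vanilla SD v1-4 for regularization.
import os
import torch
from diffusers import StableDiffusionPipeline

category = "dog"                      # placeholder: your subject's category
outdir = "data/dog_samples/samples"   # placeholder: should match --reg_datapath
os.makedirs(outdir, exist_ok=True)

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

count = 0
while count < 200:
    for img in pipe(f"Photo of a {category}", num_images_per_prompt=4).images:
        img.save(os.path.join(outdir, f"{count:03d}.png"))
        count += 1
```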
49 | 50 | The data structure should be like this: 51 | 52 | ``` 53 | data 54 | ├── SubjectName 55 | │ ├── xxx01.png 56 | │ ├── xxx02.png 57 | │ ├── xxx03.png 58 | │ ├── xxx04.png 59 | ├── SubjectName_fg 60 | │ ├── xxx01.png 61 | │ ├── xxx02.png 62 | │ ├── xxx03.png 63 | │ ├── xxx04.png 64 | ├── SubjectName_mask 65 | │ ├── xxx01.png 66 | │ ├── xxx02.png 67 | │ ├── xxx03.png 68 | │ ├── xxx04.png 69 | ├── Subject_samples 70 | │ ├── 001.png 71 | │ ├── 002.png 72 | │ ├── .... 73 | │ ├── 199.png 74 | │ ├── 200.png 75 | ``` 76 | 77 | ### Training 78 | 79 | You can run the scripts below to train with the example data. 80 | 81 | ``` 82 | ## run training (on 4 GPUs) 83 | python -u train.py \ 84 | --base configs/DETEX/finetune.yaml \ 85 | -t --gpus 0,1,2,3 \ 86 | --resume-from-checkpoint-custom \ 87 | --caption " dog with

pose in background" \ 88 | --num_imgs 4 \ 89 | --datapath data/dog7 \ 90 | --reg_datapath data/dog_samples/samples \ 91 | --mask_path data/dog7_fg\ 92 | --mask_path2 data/dog7_mask\ 93 | --reg_caption "dog" \ 94 | --modifier_token "++++++++" \ 95 | --name dog7 96 | ``` 97 | 98 | The modifier tokens `~` and `~` represent the corresponding pose and background of the 4 input imgs respectively. Please refer to the paper for more details about the unrelated tokens. 99 | 100 | Note that the parameter `modifier_token` should be arranged in the form `++...+++...+`. Do not change the input order of ``, `

` and ``. 101 | 102 | If you don't have a sufficient number of GPUs, we recommend training with a lower learning rate for more iterations. 103 | 104 | ### Save Updated Checkpoint 105 | 106 | After training, run the following script to only save the updated weights. 107 | 108 | ``` 109 | python src/get_deltas.py --path logs//checkpoints/last.ckpt --newtoken 9 110 | ``` 111 | 112 | ### Generation 113 | 114 | Run the following script to generate with the target concept subject ``. 115 | 116 | ``` 117 | python sample.py --delta_ckpt logs//checkpoints/delta_epoch_last.ckpt \ 118 | --ckpt --scale 6 --n_samples 3 --n_iter 2 --ddim_steps 50 \ 119 | --prompt "photo of a dog" 120 | ``` 121 | 122 | If you use unrelated token `

` or `` in the prompt, a reference img path should be added in the script to get the unrelated embedding through mapper. 123 | 124 | ``` 125 | python sample.py --delta_ckpt logs//checkpoints/delta_epoch_last.ckpt \ 126 | --ckpt --scale 6 --n_samples 3 --n_iter 2 --ddim_steps 50 \ 127 | --ref data/dog7/02.png \ 128 | --prompt "photo of a dog running in background" 129 | ``` 130 | 131 | The generated images are saved in `logs/`. 132 | 133 | ## Citation 134 | 135 | ``` 136 | @article{cai2023DETEX, 137 | title={Decoupled Textual Embeddings for Customized Image Generation}, 138 | author={Yufei Cai and Yuxiang Wei and Zhilong Ji and Jinfeng Bai and Hu Han and Wangmeng Zuo}, 139 | journal={arXiv preprint arXiv:2312.11826}, 140 | year={2023} 141 | } 142 | ``` 143 | -------------------------------------------------------------------------------- /src/DETEX_modules.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from packaging import version 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | import torch.nn as nn 7 | import transformers 8 | from transformers import CLIPTokenizer, CLIPTextModel, CLIPModel, CLIPProcessor 9 | from ldm.modules.attention import CrossAttention 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class Mapper(nn.Module): 20 | def __init__(self, dim): 21 | super().__init__() 22 | self.net = nn.Sequential( 23 | nn.Linear(dim, dim), 24 | nn.ReLU(), 25 | nn.Linear(dim, dim), 26 | nn.ReLU(), 27 | nn.Linear(dim, dim), 28 | ) 29 | def forward(self, x, context=None): 30 | x1 = self.net(x) 31 | return x1 32 | 33 | class CLIPEmbedderWrapper(AbstractEncoder): 34 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 35 | def __init__(self, modifier_token, num_imgs, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 36 | super().__init__() 37 | self.tokenizer = CLIPTokenizer.from_pretrained("./configs/clip-vit-large-patch14") 38 | self.transformer = CLIPTextModel.from_pretrained("./configs/clip-vit-large-patch14") 39 | self.image_encoder = CLIPModel.from_pretrained("./configs/clip-vit-large-patch14") 40 | self.processor = CLIPProcessor.from_pretrained("./configs/clip-vit-large-patch14") 41 | self.num_imgs = num_imgs 42 | 43 | for i in range(num_imgs): 44 | setattr(self, f"mapperp{i+1}", Mapper(dim=768)) 45 | setattr(self, f"mapperb{i+1}", Mapper(dim=768)) 46 | 47 | self.device = device 48 | self.max_length = max_length 49 | self.modifier_token = modifier_token 50 | if '+' in self.modifier_token: 51 | self.modifier_token = self.modifier_token.split('+') 52 | else: 53 | self.modifier_token = [self.modifier_token] 54 | 55 | self.add_token() 56 | self.freeze() 57 | 58 | def add_token(self): 59 | self.modifier_token_id = [] 60 | token_embeds1 = self.transformer.get_input_embeddings().weight.data 61 | for each_modifier_token in self.modifier_token: 62 | num_added_tokens = self.tokenizer.add_tokens(each_modifier_token) 63 | modifier_token_id = self.tokenizer.convert_tokens_to_ids(each_modifier_token) 64 | self.modifier_token_id.append(modifier_token_id) 65 | 66 | self.transformer.resize_token_embeddings(len(self.tokenizer)) 67 | token_embeds = self.transformer.get_input_embeddings().weight.data 68 | token_embeds[self.modifier_token_id[0]] = torch.nn.Parameter(token_embeds[42170], requires_grad=True) 69 | 70 | def custom_forward(self, 
hidden_states, input_ids): 71 | r""" 72 | Returns: 73 | """ 74 | input_shape = hidden_states.size() 75 | bsz, seq_len = input_shape[:2] 76 | if version.parse(transformers.__version__) >= version.parse('4.21'): 77 | causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( 78 | hidden_states.device 79 | ) 80 | else: 81 | causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len).to( 82 | hidden_states.device 83 | ) 84 | 85 | encoder_outputs = self.transformer.text_model.encoder( 86 | inputs_embeds=hidden_states, 87 | causal_attention_mask=causal_attention_mask, 88 | ) 89 | 90 | last_hidden_state = encoder_outputs[0] 91 | last_hidden_state = self.transformer.text_model.final_layer_norm(last_hidden_state) 92 | 93 | return last_hidden_state 94 | 95 | def freeze(self): 96 | self.transformer = self.transformer.eval() 97 | for param in self.transformer.text_model.encoder.parameters(): 98 | param.requires_grad = False 99 | for param in self.transformer.text_model.final_layer_norm.parameters(): 100 | param.requires_grad = False 101 | for param in self.transformer.text_model.embeddings.position_embedding.parameters(): 102 | param.requires_grad = False 103 | 104 | def forward(self, text, input_img=None): 105 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 106 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 107 | tokens = batch_encoding["input_ids"].to(self.device) 108 | bs=tokens.shape[0] 109 | 110 | input_shape = tokens.size() 111 | hidden_states = self.transformer.text_model.embeddings(input_ids=tokens.view(-1, input_shape[-1])) 112 | 113 | # get CLIP img embedding 114 | if input_img is not None: 115 | image_embeds = [] 116 | for index in range(bs): 117 | img = input_img[index].permute(1,2,0) 118 | img = (img+1.)*127.5 119 | img = img.cpu().numpy().astype(np.uint8) 120 | image = Image.fromarray(img) 121 | img = self.processor(text=["a"], images=image, return_tensors="pt", padding=True) 122 | image_embeds.append(self.image_encoder(**img.to(self.device)).image_embeds) 123 | image_embeds = torch.stack(image_embeds).to(self.device) 124 | 125 | indices = tokens == self.modifier_token_id[0] # optimize 126 | num_img = (len(self.modifier_token_id) - 1) // 2 127 | num = 0 128 | for token_id in self.modifier_token_id: 129 | indices |= tokens == token_id 130 | if input_img is not None and token_id != self.modifier_token_id[0]: 131 | idx = torch.where(token_id == tokens) 132 | if min(idx[0].shape) != 0: 133 | t = [] 134 | for i in range(bs): 135 | if num <= num_img: 136 | p = getattr(self, f'mapperp{num}') 137 | # print(p, text) 138 | t.append(p(image_embeds[i].view(1,768), image_embeds[i].view(1,768)).view(768)) 139 | else: 140 | b = getattr(self, f'mapperb{num - num_img}') 141 | # print(b, text) 142 | t.append(b(image_embeds[i].view(1,768), image_embeds[i].view(1,768)).view(768)) 143 | t = torch.stack(t).view(bs, 768) 144 | hidden_states[idx] = t 145 | num += 1 146 | 147 | indices = (indices*1).unsqueeze(-1) 148 | 149 | hidden_states = (1-indices)*hidden_states.detach() + indices*hidden_states # detach 150 | 151 | z = self.custom_forward(hidden_states, tokens) 152 | 153 | return z 154 | 155 | def encode(self, text, input_img): 156 | return self(text, input_img) 157 | 158 | def get_index(self, text): 159 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 160 | 
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 161 | tokens = batch_encoding["input_ids"].to(self.device) 162 | index = [] 163 | for token_id in self.modifier_token_id: 164 | if token_id in tokens: 165 | idx = torch.where(token_id == tokens) 166 | index.append(int(idx[1])) 167 | return index 168 | 169 | 170 | def return_parameters(self): 171 | token_embeds = self.transformer.get_input_embeddings().weight.data 172 | param = list(itertools.chain(token_embeds[self.modifier_token_id[0]])) 173 | for i in range(self.num_imgs): 174 | param += itertools.chain(getattr(self, f'mapperp{i+1}').parameters()) 175 | for i in range(self.num_imgs): 176 | param += itertools.chain(getattr(self, f'mapperb{i+1}').parameters()) 177 | 178 | print(type(param)) 179 | print(len(param)) 180 | return param 181 | 182 | 183 | if __name__ == "__main__": 184 | from ldm.util import count_params 185 | model = FrozenCLIPEmbedderWrapper() 186 | count_params(model, verbose=True) 187 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | 3 | sys.path.append('stable-diffusion') 4 | import torch 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | from PIL import Image 8 | import PIL 9 | from tqdm import tqdm, trange 10 | from einops import rearrange 11 | from torchvision.utils import make_grid 12 | from pytorch_lightning import seed_everything 13 | from torch import autocast 14 | from contextlib import contextmanager, nullcontext 15 | from einops import rearrange, repeat 16 | 17 | from ldm.util import instantiate_from_config 18 | from ldm.models.diffusion.ddim import DDIMSampler 19 | from ldm.models.diffusion.plms import PLMSSampler 20 | from src.DETEX_sampler import DETEXSampler 21 | import wandb 22 | 23 | os.environ['CUDA_VISIBLE_DEVICES'] ="0" 24 | 25 | def load_model_from_config(config, ckpt, verbose=False): 26 | print(f"Loading model from {ckpt}") 27 | pl_sd = torch.load(ckpt, map_location="cpu") 28 | if "global_step" in pl_sd: 29 | print(f"Global Step: {pl_sd['global_step']}") 30 | sd = pl_sd["state_dict"] 31 | model = instantiate_from_config(config.model) 32 | 33 | token_weights = sd["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] 34 | del sd["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] 35 | m, u = model.load_state_dict(sd, strict=False) 36 | model.cond_stage_model.transformer.text_model.embeddings.token_embedding.weight.data[ 37 | :token_weights.shape[0]] = token_weights 38 | if len(m) > 0 and verbose: 39 | print("missing keys:") 40 | print(m) 41 | if len(u) > 0 and verbose: 42 | print("unexpected keys:") 43 | print(u) 44 | 45 | model.cuda() 46 | model.eval() 47 | return model 48 | 49 | def load_img(path): 50 | image = Image.open(path).convert("RGB") 51 | w, h = image.size 52 | print(f"loaded input image of size ({w}, {h}) from {path}") 53 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 54 | image = image.resize((512, 512), resample=PIL.Image.LANCZOS) 55 | image = np.array(image).astype(np.float32) / 255.0 56 | image = image[None].transpose(0, 3, 1, 2) 57 | image = torch.from_numpy(image) 58 | return 2.*image - 1. 
59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser() 63 | 64 | parser.add_argument( 65 | "--prompt", 66 | type=str, 67 | nargs="?", 68 | default="a painting of a virus monster playing guitar", 69 | help="the prompt to render" 70 | ) 71 | parser.add_argument( 72 | "--outdir", 73 | type=str, 74 | nargs="?", 75 | help="dir to write results to", 76 | default="outputs/txt2img-samples" 77 | ) 78 | parser.add_argument( 79 | "--skip_grid", 80 | action='store_true', 81 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", 82 | ) 83 | parser.add_argument( 84 | "--skip_save", 85 | action='store_true', 86 | help="do not save individual samples. For speed measurements.", 87 | ) 88 | parser.add_argument( 89 | "--ddim_steps", 90 | type=int, 91 | default=200, 92 | help="number of ddim sampling steps", 93 | ) 94 | parser.add_argument( 95 | "--plms", 96 | action='store_true', 97 | help="use plms sampling", 98 | ) 99 | parser.add_argument( 100 | "--laion400m", 101 | action='store_true', 102 | help="uses the LAION400M model", 103 | ) 104 | parser.add_argument( 105 | "--fixed_code", 106 | action='store_true', 107 | help="if enabled, uses the same starting code across samples ", 108 | ) 109 | parser.add_argument( 110 | "--ddim_eta", 111 | type=float, 112 | default=1.0, 113 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 114 | ) 115 | parser.add_argument( 116 | "--n_iter", 117 | type=int, 118 | default=1, 119 | help="sample this often", 120 | ) 121 | parser.add_argument( 122 | "--H", 123 | type=int, 124 | default=512, 125 | help="image height, in pixel space", 126 | ) 127 | parser.add_argument( 128 | "--W", 129 | type=int, 130 | default=512, 131 | help="image width, in pixel space", 132 | ) 133 | parser.add_argument( 134 | "--C", 135 | type=int, 136 | default=4, 137 | help="latent channels", 138 | ) 139 | parser.add_argument( 140 | "--f", 141 | type=int, 142 | default=8, 143 | help="downsampling factor", 144 | ) 145 | parser.add_argument( 146 | "--n_samples", 147 | type=int, 148 | default=6, 149 | help="how many samples to produce for each given prompt. A.k.a. 
batch size", 150 | ) 151 | parser.add_argument( 152 | "--n_rows", 153 | type=int, 154 | default=6, 155 | help="rows in the grid (default: n_samples)", 156 | ) 157 | parser.add_argument( 158 | "--scale", 159 | type=float, 160 | default=6., 161 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 162 | ) 163 | parser.add_argument( 164 | "--from-file", 165 | type=str, 166 | help="if specified, load prompts from this file", 167 | ) 168 | parser.add_argument( 169 | "--config", 170 | type=str, 171 | default="configs/DETEX/finetune.yaml", 172 | help="path to config which constructs model", 173 | ) 174 | parser.add_argument( 175 | "--ckpt", 176 | type=str, 177 | required=True, 178 | help="path to checkpoint of the pre-trained model", 179 | ) 180 | parser.add_argument( 181 | "--delta_ckpt", 182 | type=str, 183 | default=None, 184 | help="path to delta checkpoint of fine-tuned block", 185 | ) 186 | parser.add_argument( 187 | "--seed", 188 | type=int, 189 | default=42, 190 | help="the seed (for reproducible sampling)", 191 | ) 192 | parser.add_argument( 193 | "--precision", 194 | type=str, 195 | help="evaluate at this precision", 196 | choices=["full", "autocast"], 197 | default="full" 198 | ) 199 | parser.add_argument( 200 | "--wandb_log", 201 | action='store_true', 202 | help="save grid images to wandb.", 203 | ) 204 | parser.add_argument( 205 | "--compress", 206 | action='store_true', 207 | help="delta path provided is a compressed checkpoint.", 208 | ) 209 | parser.add_argument( 210 | "--ref", 211 | type=str, 212 | help="path to ref image", 213 | default=None 214 | ) 215 | opt = parser.parse_args() 216 | 217 | if opt.wandb_log: 218 | if opt.delta_ckpt is not None: 219 | name = opt.delta_ckpt.split('/')[-3] 220 | elif 'checkpoints' in opt.ckpt: 221 | name = opt.ckpt.split('/')[-3] 222 | else: 223 | name = opt.ckpt.split('/')[-1] 224 | wandb.init(project="DETEX", entity="cmu-gil", name=name) 225 | 226 | if opt.delta_ckpt is not None: 227 | if len(glob.glob(os.path.join(opt.delta_ckpt.split('checkpoints')[0], "configs/*.yaml"))) > 0: 228 | opt.config = sorted(glob.glob(os.path.join(opt.delta_ckpt.split('checkpoints')[0], "configs/*.yaml")))[-1] 229 | else: 230 | if len(glob.glob(os.path.join(opt.ckpt.split('checkpoints')[0], "configs/*.yaml"))) > 0: 231 | opt.config = sorted(glob.glob(os.path.join(opt.ckpt.split('checkpoints')[0], "configs/*.yaml")))[-1] 232 | 233 | seed_everything(opt.seed) 234 | ###################################################################################### 235 | config = OmegaConf.load(f"{opt.config}") 236 | model = load_model_from_config(config, f"{opt.ckpt}") 237 | 238 | if opt.delta_ckpt is not None: 239 | delta_st = torch.load(opt.delta_ckpt, map_location='cuda:0') 240 | embed = None 241 | if 'embed' in delta_st['state_dict']: 242 | embed = delta_st['state_dict']['embed'].reshape(-1, 768) 243 | del delta_st['state_dict']['embed'] 244 | print(embed.shape) 245 | delta_st = delta_st['state_dict'] 246 | if opt.compress: 247 | for name in delta_st.keys(): 248 | if 'to_k' in name or 'to_v' in name: 249 | delta_st[name] = model.state_dict()[name] + delta_st[name]['u'] @ delta_st[name]['v'] 250 | model.load_state_dict(delta_st, strict=False) 251 | else: 252 | model.load_state_dict(delta_st, strict=False) 253 | if embed is not None: 254 | print("loading new embedding") 255 | print(model.cond_stage_model.transformer.text_model.embeddings.token_embedding.weight.data.shape) 256 | 
model.cond_stage_model.transformer.text_model.embeddings.token_embedding.weight.data[ 257 | -embed.shape[0]:] = embed 258 | 259 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 260 | model = model.to(device) 261 | 262 | if opt.plms: 263 | sampler = PLMSSampler(model) 264 | else: 265 | sampler = DETEXSampler(model) 266 | 267 | if opt.delta_ckpt is not None: 268 | outpath = os.path.dirname(os.path.dirname(opt.delta_ckpt)) 269 | else: 270 | os.makedirs(opt.outdir, exist_ok=True) 271 | outpath = opt.outdir 272 | 273 | batch_size = opt.n_samples 274 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size 275 | if not opt.from_file: 276 | prompt = opt.prompt 277 | assert prompt is not None 278 | data = [batch_size * [prompt]] 279 | else: 280 | print(f"reading prompts from {opt.from_file}") 281 | with open(opt.from_file, "r") as f: 282 | data = f.read().splitlines() 283 | data = [batch_size * [prompt] for prompt in data] 284 | 285 | sample_path = os.path.join(outpath, "samples") 286 | 287 | os.makedirs(sample_path, exist_ok=True) 288 | base_count = len(os.listdir(sample_path)) 289 | grid_count = len(os.listdir(outpath)) - 1 290 | 291 | start_code = None 292 | if opt.fixed_code: 293 | start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) 294 | 295 | ref_image = None 296 | if opt.ref is not None: 297 | ref_image = load_img(opt.ref).to(device) 298 | ref_image = repeat(ref_image, '1 ... -> b ...', b=batch_size) 299 | 300 | precision_scope = autocast if opt.precision == "autocast" else nullcontext 301 | with torch.no_grad(): 302 | with precision_scope("cuda"): 303 | with model.ema_scope(): 304 | for prompts in tqdm(data, desc="data"): 305 | all_samples = list() 306 | for n in trange(opt.n_iter, desc="Sampling"): 307 | print(prompts[0]) 308 | uc = None 309 | if opt.scale != 1.0: 310 | uc = model.get_learned_conditioning(batch_size * [""]) 311 | if isinstance(prompts, tuple): 312 | prompts = list(prompts) 313 | c = model.get_learned_conditioning(prompts, ref_image) 314 | shape = [opt.C, opt.H // opt.f, opt.W // opt.f] 315 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps, 316 | conditioning=c, 317 | batch_size=opt.n_samples, 318 | shape=shape, 319 | verbose=False, 320 | unconditional_guidance_scale=opt.scale, 321 | unconditional_conditioning=uc, 322 | eta=opt.ddim_eta, 323 | x_T=start_code) 324 | # print(samples_ddim.size()) 325 | x_samples_ddim = model.decode_first_stage(samples_ddim) 326 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 327 | x_samples_ddim = x_samples_ddim.cpu() 328 | 329 | if not opt.skip_save: 330 | for idx, x_sample in enumerate(x_samples_ddim): 331 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 332 | img = Image.fromarray(x_sample.astype(np.uint8)) 333 | img.save(os.path.join(sample_path, f"{base_count:05}.png")) 334 | 335 | base_count += 1 336 | 337 | if not opt.skip_grid: 338 | all_samples.append(x_samples_ddim) 339 | 340 | if not opt.skip_grid: 341 | # additionally, save as grid 342 | grid = torch.stack(all_samples, 0) 343 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 344 | grid = make_grid(grid, nrow=n_rows) 345 | 346 | # to image 347 | grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() 348 | img = Image.fromarray(grid.astype(np.uint8)) 349 | sampling_method = 'plms' if opt.plms else 'ddim' 350 | img.save(os.path.join(outpath, 351 | f'{prompts[0].replace(" ", "-").replace("<", "_").replace(">", "_")}_{opt.scale}_{sampling_method}_{opt.ddim_steps}_{opt.ddim_eta}.png')) 352 | if opt.wandb_log: 353 | wandb.log({ 354 | f'{prompts[0].replace(" ", "-")}_{opt.scale}_{sampling_method}_{opt.ddim_steps}_{opt.ddim_eta}.png': [ 355 | wandb.Image(img)]}) 356 | grid_count += 1 357 | 358 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n" 359 | f" \nEnjoy.") 360 | 361 | 362 | if __name__ == "__main__": 363 | main() 364 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from torch import nn, einsum 6 | 7 | from ldm.models.diffusion.ddpm import LatentDiffusion as LatentDiffusion 8 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 9 | # from ldm.models.diffusion.ddim import ret_attn 10 | from ldm.util import default 11 | from ldm.modules.attention import BasicTransformerBlock as BasicTransformerBlock 12 | from ldm.modules.attention import CrossAttention as CrossAttention 13 | from ldm.util import log_txt_as_img, exists, ismap, isimage, mean_flat, count_params, instantiate_from_config 14 | from torchvision.utils import make_grid 15 | from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL 16 | import numpy as np 17 | 18 | ret_attn = [] 19 | 20 | def extract_into_tensor(a, t, x_shape): 21 | b, *_ = t.shape 22 | out = a.gather(-1, t) 23 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 24 | 25 | 26 | def get_average_attn(attn, res): 27 | attn_maps = [] 28 | for item in attn: 29 | attnmap = item.reshape(-1, 8, res, res, item.shape[-1]) 30 | attn_maps.append(attnmap) 31 | attn_maps = torch.cat(attn_maps, dim=1) 32 | average_attn = attn_maps.sum(1) / attn_maps.shape[1] 33 | 34 | return average_attn 35 | 36 | 37 | __conditioning_keys__ = {'concat': 'c_concat', 38 | 'crossattn': 'c_crossattn', 39 | 'adm': 'y'} 40 | 41 | class DETEX(LatentDiffusion): 42 | def __init__(self, 43 | freeze_model='crossattn-kv', 44 | cond_stage_trainable=False, 45 | add_token=False, 46 | *args, **kwargs): 47 | 48 | self.attn_res = 32 49 | self.freeze_model = freeze_model 50 | self.add_token = add_token 51 | self.cond_stage_trainable = cond_stage_trainable 52 | super().__init__(cond_stage_trainable=cond_stage_trainable, *args, **kwargs) 53 | 54 | self.cnt = np.zeros(2) 55 | 56 | if self.freeze_model == 'crossattn-kv': 57 | for x in self.model.diffusion_model.named_parameters(): 58 | if 'transformer_blocks' not in x[0]: 59 | x[1].requires_grad = False 60 | elif not ('attn2.to_k' in x[0] or 'attn2.to_v' in x[0]): 61 | x[1].requires_grad = False 62 | else: 63 | x[1].requires_grad = True 64 | elif self.freeze_model == 'crossattn': 65 | for x in self.model.diffusion_model.named_parameters(): 66 | if 'transformer_blocks' not in x[0]: 67 | x[1].requires_grad = False 68 | elif not 'attn2' in x[0]: 69 | x[1].requires_grad = False 70 | else: 71 | x[1].requires_grad = True 72 | 73 | def change_checkpoint(model): 74 | for layer in model.children(): 75 | if type(layer) == BasicTransformerBlock: 76 | layer.checkpoint = False 77 | else: 78 | 
change_checkpoint(layer) 79 | 80 | change_checkpoint(self.model.diffusion_model) 81 | 82 | def new_forward(self, x, context=None, mask=None): 83 | h = self.heads 84 | crossattn = False 85 | if context is not None: 86 | crossattn = True 87 | q = self.to_q(x) 88 | context = default(context, x) 89 | k = self.to_k(context) 90 | v = self.to_v(context) 91 | 92 | if crossattn: 93 | modifier = torch.ones_like(k) 94 | modifier[:, :1, :] = modifier[:, :1, :] * 0. 95 | k = modifier * k + (1 - modifier) * k.detach() 96 | v = modifier * v + (1 - modifier) * v.detach() 97 | 98 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 99 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 100 | attn = sim.softmax(dim=-1) 101 | 102 | if crossattn: 103 | global ret_attn 104 | if attn.shape[1] == 32 * 32: 105 | ret_attn += [attn] 106 | 107 | out = einsum('b i j, b j d -> b i d', attn, v) 108 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 109 | return self.to_out(out) 110 | 111 | def change_forward(model): 112 | for layer in model.children(): 113 | if type(layer) == CrossAttention: 114 | bound_method = new_forward.__get__(layer, layer.__class__) 115 | setattr(layer, 'forward', bound_method) 116 | else: 117 | change_forward(layer) 118 | 119 | change_forward(self.model.diffusion_model) 120 | 121 | def configure_optimizers(self): 122 | lr = self.learning_rate 123 | params = [] 124 | if self.freeze_model == 'crossattn-kv': 125 | for x in self.model.diffusion_model.named_parameters(): 126 | if 'transformer_blocks' in x[0]: 127 | if 'attn2.to_k' in x[0] or 'attn2.to_v' in x[0]: 128 | params += [x[1]] 129 | print(x[0]) 130 | elif self.freeze_model == 'crossattn': 131 | for x in self.model.diffusion_model.named_parameters(): 132 | if 'transformer_blocks' in x[0]: 133 | if 'attn2' in x[0]: 134 | params += [x[1]] 135 | print(x[0]) 136 | else: 137 | params = list(self.model.parameters()) 138 | 139 | if self.cond_stage_trainable: 140 | if self.add_token: 141 | print(f"{self.__class__.__name__}: Also optimizing conditioner params!") 142 | params = params + self.cond_stage_model.return_parameters() # optimize mapper and embeddings 143 | else: 144 | params = params + list(self.cond_stage_model.parameters()) 145 | 146 | # cond_params = [] 147 | # if self.cond_stage_trainable: 148 | # if self.add_token: 149 | # print(f"{self.__class__.__name__}: Also optimizing conditioner params!") 150 | # cond_params = cond_params + list( 151 | # self.cond_stage_model.return_parameters()) 152 | # else: 153 | # cond_params = cond_params + list(self.cond_stage_model.parameters()) 154 | 155 | if self.learn_logvar: 156 | print('Diffusion model optimizing logvar') 157 | params.append(self.logvar) 158 | 159 | opt = torch.optim.AdamW(params, lr=lr) 160 | # opt = torch.optim.AdamW([ 161 | # {'params': params, 'lr': lr}, 162 | # {'params': cond_params, 'lr': lr} 163 | # ]) 164 | 165 | if self.use_scheduler: 166 | assert 'target' in self.scheduler_config 167 | scheduler = instantiate_from_config(self.scheduler_config) 168 | 169 | print("Setting up LambdaLR scheduler...") 170 | scheduler = [ 171 | { 172 | 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 173 | 'interval': 'step', 174 | 'frequency': 1 175 | }] 176 | return [opt], scheduler 177 | return opt 178 | 179 | def get_learned_conditioning(self, c, x=None): 180 | if self.cond_stage_forward is None: 181 | if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): 182 | c = self.cond_stage_model.encode(c, input_img=x) 183 
| if isinstance(c, DiagonalGaussianDistribution): 184 | c = c.mode() 185 | else: 186 | c = self.cond_stage_model(c, input_img=x) 187 | else: 188 | assert hasattr(self.cond_stage_model, self.cond_stage_forward) 189 | c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, input_img=x) 190 | return c 191 | 192 | @torch.no_grad() 193 | def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, 194 | cond_key=None, return_original_cond=False, bs=None): 195 | x = super().get_input(batch, k) 196 | if bs is not None: 197 | x = x[:bs] 198 | x = x.to(self.device) 199 | encoder_posterior = self.encode_first_stage(x) 200 | z = self.get_first_stage_encoding(encoder_posterior).detach() 201 | 202 | if self.model.conditioning_key is not None: 203 | if cond_key is None: 204 | cond_key = self.cond_stage_key 205 | if cond_key != self.first_stage_key: 206 | if cond_key in ['caption', 'coordinates_bbox']: 207 | xc = batch[cond_key] 208 | elif cond_key == 'class_label': 209 | xc = batch 210 | else: 211 | xc = super().get_input(batch, cond_key).to(self.device) 212 | else: 213 | xc = x 214 | if not self.cond_stage_trainable or force_c_encode: 215 | if isinstance(xc, dict) or isinstance(xc, list): 216 | # import pudb; pudb.set_trace() 217 | c = self.get_learned_conditioning(xc, x) 218 | else: 219 | c = self.get_learned_conditioning(xc.to(self.device), x) 220 | else: 221 | c = xc 222 | if bs is not None: 223 | c = c[:bs] 224 | 225 | if self.use_positional_encodings: 226 | pos_x, pos_y = self.compute_latent_shifts(batch) 227 | ckey = __conditioning_keys__[self.model.conditioning_key] 228 | c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} 229 | 230 | else: 231 | c = None 232 | xc = None 233 | if self.use_positional_encodings: 234 | pos_x, pos_y = self.compute_latent_shifts(batch) 235 | c = {'pos_x': pos_x, 'pos_y': pos_y} 236 | out = [z, c] 237 | if return_first_stage_outputs: 238 | xrec = self.decode_first_stage(z) 239 | out.extend([x, xrec]) 240 | if return_original_cond: 241 | out.append(xc) 242 | return out 243 | 244 | 245 | def p_losses(self, x_start, cond, t, mask=None, noise=None, img=None, caption=None, fg_attn_mask=None, bg_attn_mask=None): 246 | global ret_attn 247 | ret_attn.clear() 248 | 249 | noise = default(noise, lambda: torch.randn_like(x_start)) 250 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 251 | model_output = self.apply_model(x_noisy, t, cond) 252 | 253 | average_attn = get_average_attn(ret_attn, self.attn_res) 254 | 255 | loss_dict = {} 256 | prefix = 'train' if self.training else 'val' 257 | 258 | if self.parameterization == "x0": 259 | target = x_start 260 | elif self.parameterization == "eps": 261 | target = noise 262 | else: 263 | raise NotImplementedError() 264 | 265 | loss_simple = self.get_loss(model_output, target, mean=False) 266 | if mask is not None: 267 | loss_simple = (loss_simple * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3]) 268 | else: 269 | loss_simple = loss_simple.mean([1, 2, 3]) 270 | 271 | loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) 272 | 273 | index = self.cond_stage_model.get_index(caption) 274 | 275 | if ' b c h w') 319 | mask = mask.to(memory_format=torch.contiguous_format).float() 320 | out += [mask] 321 | 322 | img = batch["image"] 323 | img = rearrange(img, 'b h w c -> b c h w') 324 | img = img.to(memory_format=torch.contiguous_format).float() 325 | out += [img] 326 | 327 | caption = batch["caption"] 328 | out += [caption] 329 | 330 | fg_attn_mask = batch["fg_attn_mask"] 331 | if 
len(fg_attn_mask.shape) == 3: 332 | fg_attn_mask = fg_attn_mask[..., None] 333 | fg_attn_mask = rearrange(fg_attn_mask, 'b h w c -> b c h w') 334 | fg_attn_mask = fg_attn_mask.to(memory_format=torch.contiguous_format).float() 335 | out += [fg_attn_mask] 336 | 337 | bg_attn_mask = batch["bg_attn_mask"] 338 | if len(bg_attn_mask.shape) == 3: 339 | bg_attn_mask = bg_attn_mask[..., None] 340 | bg_attn_mask = rearrange(bg_attn_mask, 'b h w c -> b c h w') 341 | bg_attn_mask = bg_attn_mask.to(memory_format=torch.contiguous_format).float() 342 | out += [bg_attn_mask] 343 | 344 | return out 345 | 346 | def training_step(self, batch, batch_idx): 347 | if isinstance(batch, list): 348 | train_batch = batch[0] 349 | train2_batch = batch[1] 350 | loss_train, loss_dict = self.shared_step(train_batch) 351 | loss_train2, _ = self.shared_step(train2_batch) 352 | loss = loss_train + loss_train2 353 | else: 354 | train_batch = batch 355 | loss, loss_dict = self.shared_step(train_batch) 356 | 357 | self.log_dict(loss_dict, prog_bar=True, 358 | logger=True, on_step=True, on_epoch=True) 359 | 360 | self.log("global_step", self.global_step, 361 | prog_bar=True, logger=True, on_step=True, on_epoch=False) 362 | 363 | if self.use_scheduler: 364 | lr = self.optimizers().param_groups[0]['lr'] 365 | self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) 366 | 367 | return loss 368 | 369 | def shared_step(self, batch, **kwargs): 370 | x, c, mask, img, caption, fg_attn_mask, bg_attn_mask, = self.get_input_withmask(batch, **kwargs) 371 | loss = self(x, c, mask=mask, img=img, caption=caption, fg_attn_mask=fg_attn_mask, bg_attn_mask=bg_attn_mask) 372 | return loss 373 | 374 | def forward(self, x, c, *args, **kwargs): 375 | t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() 376 | if self.model.conditioning_key is not None: 377 | assert c is not None 378 | if self.cond_stage_trainable: 379 | c = self.get_learned_conditioning(c, x=kwargs['img']) 380 | if self.shorten_cond_schedule: # TODO: drop this option 381 | tc = self.cond_ids[t].to(self.device) 382 | c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) 383 | return self.p_losses(x, c, t, *args, **kwargs) 384 | 385 | @torch.no_grad() 386 | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, 387 | quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, 388 | plot_diffusion_rows=True, **kwargs): 389 | 390 | use_ddim = ddim_steps is not None 391 | 392 | log = dict() 393 | if isinstance(batch, list): 394 | batch = batch[0] 395 | z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, 396 | return_first_stage_outputs=True, 397 | force_c_encode=True, 398 | return_original_cond=True, 399 | bs=N) 400 | N = min(x.shape[0], N) 401 | n_row = min(x.shape[0], n_row) 402 | log["inputs"] = x 403 | log["reconstruction"] = xrec 404 | if self.model.conditioning_key is not None: 405 | if hasattr(self.cond_stage_model, "decode"): 406 | xc = self.cond_stage_model.decode(c) 407 | log["conditioning"] = xc 408 | elif self.cond_stage_key in ["caption"]: 409 | xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) 410 | log["conditioning"] = xc 411 | elif self.cond_stage_key == 'class_label': 412 | xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 413 | log['conditioning'] = xc 414 | elif isimage(xc): 415 | log["conditioning"] = xc 416 | if ismap(xc): 417 | log["original_conditioning"] = self.to_rgb(xc) 418 
| 419 | if plot_diffusion_rows: 420 | # get diffusion row 421 | diffusion_row = list() 422 | z_start = z[:n_row] 423 | for t in range(self.num_timesteps): 424 | if t % self.log_every_t == 0 or t == self.num_timesteps - 1: 425 | t = repeat(torch.tensor([t]), '1 -> b', b=n_row) 426 | t = t.to(self.device).long() 427 | noise = torch.randn_like(z_start) 428 | z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) 429 | diffusion_row.append(self.decode_first_stage(z_noisy)) 430 | 431 | diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W 432 | diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') 433 | diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') 434 | diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) 435 | log["diffusion_row"] = diffusion_grid 436 | 437 | if sample: 438 | # get denoise row 439 | with self.ema_scope("Plotting"): 440 | unconditional_guidance_scale = 6. 441 | unconditional_conditioning = self.get_learned_conditioning(len(c) * [""]) 442 | samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, 443 | ddim_steps=ddim_steps, eta=ddim_eta, 444 | unconditional_conditioning=unconditional_conditioning, 445 | unconditional_guidance_scale=unconditional_guidance_scale) 446 | # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) 447 | x_samples = self.decode_first_stage(samples) 448 | log["samples_scaled"] = x_samples 449 | if plot_denoise_rows: 450 | denoise_grid = self._get_denoise_row_from_list(z_denoise_row) 451 | log["denoise_row"] = denoise_grid 452 | 453 | if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( 454 | self.first_stage_model, IdentityFirstStage): 455 | # also display when quantizing x0 while sampling 456 | with self.ema_scope("Plotting Quantized Denoised"): 457 | samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, 458 | ddim_steps=ddim_steps, eta=ddim_eta, 459 | quantize_denoised=True) 460 | # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, 461 | # quantize_denoised=True) 462 | x_samples = self.decode_first_stage(samples.to(self.device)) 463 | log["samples_x0_quantized"] = x_samples 464 | 465 | if inpaint: 466 | # make a simple center square 467 | b, h, w = z.shape[0], z.shape[2], z.shape[3] 468 | mask = torch.ones(N, h, w).to(self.device) 469 | # zeros will be filled in 470 | mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. 471 | mask = mask[:, None, ...] 
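# Mask semantics for the masked sampling below: 1 keeps the encoded input
# latents, 0 marks the centre square that the sampler is asked to fill in.
# Schematically, when x0 and mask are passed to sample_log, the sampler
# blends the re-noised reference with its current estimate at every step,
# roughly:
#     img = mask * q_sample(x0, t) + (1. - mask) * img
# (illustrative comment only; the actual blending is done inside the sampler).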
472 | with self.ema_scope("Plotting Inpaint"): 473 | samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, 474 | ddim_steps=ddim_steps, x0=z[:N], mask=mask) 475 | x_samples = self.decode_first_stage(samples.to(self.device)) 476 | log["samples_inpainting"] = x_samples 477 | log["mask"] = mask 478 | 479 | # outpaint 480 | with self.ema_scope("Plotting Outpaint"): 481 | samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, 482 | ddim_steps=ddim_steps, x0=z[:N], mask=mask) 483 | x_samples = self.decode_first_stage(samples.to(self.device)) 484 | log["samples_outpainting"] = x_samples 485 | 486 | if plot_progressive_rows: 487 | with self.ema_scope("Plotting Progressives"): 488 | img, progressives = self.progressive_denoising(c, 489 | shape=(self.channels, self.image_size, self.image_size), 490 | batch_size=N) 491 | prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") 492 | log["progressive_row"] = prog_row 493 | 494 | if return_keys: 495 | if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: 496 | return log 497 | else: 498 | return {key: log[key] for key in return_keys} 499 | return log 500 | 501 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime, glob 2 | sys.path.append('stable-diffusion') 3 | import numpy as np 4 | import time 5 | import torch 6 | import torchvision 7 | import pytorch_lightning as pl 8 | 9 | from packaging import version 10 | from omegaconf import OmegaConf 11 | from torch.utils.data import DataLoader, Dataset 12 | from functools import partial 13 | from PIL import Image 14 | 15 | from pytorch_lightning import seed_everything 16 | from pytorch_lightning.trainer import Trainer 17 | from pytorch_lightning.callbacks import Callback, LearningRateMonitor 18 | from pytorch_lightning.utilities.distributed import rank_zero_only 19 | from pytorch_lightning.utilities import rank_zero_info 20 | 21 | from ldm.data.base import Txt2ImgIterableBaseDataset 22 | from ldm.util import instantiate_from_config 23 | 24 | 25 | 26 | def get_parser(**parser_kwargs): 27 | def str2bool(v): 28 | if isinstance(v, bool): 29 | return v 30 | if v.lower() in ("yes", "true", "t", "y", "1"): 31 | return True 32 | elif v.lower() in ("no", "false", "f", "n", "0"): 33 | return False 34 | else: 35 | raise argparse.ArgumentTypeError("Boolean value expected.") 36 | 37 | parser = argparse.ArgumentParser(**parser_kwargs) 38 | parser.add_argument( 39 | "-n", 40 | "--name", 41 | type=str, 42 | const=True, 43 | default="", 44 | nargs="?", 45 | help="postfix for logdir", 46 | ) 47 | parser.add_argument( 48 | "-r", 49 | "--resume", 50 | type=str, 51 | const=True, 52 | default="", 53 | nargs="?", 54 | help="resume from logdir or checkpoint in logdir", 55 | ) 56 | parser.add_argument( 57 | "-rc", 58 | "--resume-from-checkpoint-custom", 59 | type=str, 60 | const=True, 61 | default="", 62 | nargs="?", 63 | help="resume from logdir or checkpoint in logdir", 64 | ) 65 | parser.add_argument( 66 | "--delta-ckpt", 67 | type=str, 68 | const=True, 69 | default=None, 70 | nargs="?", 71 | help="resume from logdir or checkpoint in logdir", 72 | ) 73 | parser.add_argument( 74 | "-b", 75 | "--base", 76 | nargs="*", 77 | metavar="base_config.yaml", 78 | help="paths to base configs. Loaded from left-to-right. 
" 79 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 80 | default=list(), 81 | ) 82 | parser.add_argument( 83 | "-t", 84 | "--train", 85 | type=str2bool, 86 | const=True, 87 | default=False, 88 | nargs="?", 89 | help="train", 90 | ) 91 | parser.add_argument( 92 | "--no-test", 93 | type=str2bool, 94 | const=True, 95 | default=False, 96 | nargs="?", 97 | help="disable test", 98 | ) 99 | parser.add_argument( 100 | "-p", 101 | "--project", 102 | help="name of new or path to existing project" 103 | ) 104 | parser.add_argument( 105 | "-d", 106 | "--debug", 107 | type=str2bool, 108 | nargs="?", 109 | const=True, 110 | default=False, 111 | help="enable post-mortem debugging", 112 | ) 113 | parser.add_argument( 114 | "-s", 115 | "--seed", 116 | type=int, 117 | default=23, 118 | help="seed for seed_everything", 119 | ) 120 | parser.add_argument( 121 | "-f", 122 | "--postfix", 123 | type=str, 124 | default="", 125 | help="post-postfix for default name", 126 | ) 127 | parser.add_argument( 128 | "-l", 129 | "--logdir", 130 | type=str, 131 | default="logs", 132 | help="directory for logging dat shit", 133 | ) 134 | parser.add_argument( 135 | "--scale_lr", 136 | type=str2bool, 137 | nargs="?", 138 | const=True, 139 | default=True, 140 | help="scale base-lr by ngpu * batch_size * n_accumulate", 141 | ) 142 | parser.add_argument( 143 | "--datapath", 144 | type=str, 145 | default="", 146 | help="path to target images", 147 | ) 148 | parser.add_argument( 149 | "--num_imgs", 150 | type=int, 151 | default=4, 152 | help="num of input images", 153 | ) 154 | parser.add_argument( 155 | "--reg_datapath", 156 | type=str, 157 | default=None, 158 | help="path to regularization images", 159 | ) 160 | parser.add_argument( 161 | "--mask_path", 162 | type=str, 163 | default="", 164 | help="path to image fg", 165 | ) 166 | parser.add_argument( 167 | "--mask_path2", 168 | type=str, 169 | default="", 170 | help="path to image masks", 171 | ) 172 | parser.add_argument( 173 | "--caption", 174 | type=str, 175 | default="", 176 | help="path to target images", 177 | ) 178 | parser.add_argument( 179 | "--reg_caption", 180 | type=str, 181 | default="", 182 | help="path to target images", 183 | ) 184 | parser.add_argument( 185 | "--datapath2", 186 | type=str, 187 | default="", 188 | help="path to target images", 189 | ) 190 | parser.add_argument( 191 | "--reg_datapath2", 192 | type=str, 193 | default=None, 194 | help="path to regularization images", 195 | ) 196 | parser.add_argument( 197 | "--caption2", 198 | type=str, 199 | default="", 200 | help="path to target images", 201 | ) 202 | parser.add_argument( 203 | "--reg_caption2", 204 | type=str, 205 | default="", 206 | help="path to regularization images' caption", 207 | ) 208 | parser.add_argument( 209 | "--modifier_token", 210 | type=str, 211 | default=None, 212 | help="token added before cateogry word for personalization use case", 213 | ) 214 | parser.add_argument( 215 | "--freeze_model", 216 | type=str, 217 | default=None, 218 | help="crossattn to enable fine-tuning of all key, value, query matrices", 219 | ) 220 | parser.add_argument( 221 | "--repeat", 222 | type=int, 223 | default=0, 224 | help="repeat the target dataset by how many times. 
Used when training without regularization", 225 | ) 226 | parser.add_argument( 227 | "--batch_size", 228 | type=int, 229 | default=None, 230 | help="overwrite batch size", 231 | ) 232 | return parser 233 | 234 | 235 | def nondefault_trainer_args(opt): 236 | parser = argparse.ArgumentParser() 237 | parser = Trainer.add_argparse_args(parser) 238 | args = parser.parse_args([]) 239 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 240 | 241 | 242 | class WrappedDataset(Dataset): 243 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 244 | 245 | def __init__(self, dataset): 246 | self.data = dataset 247 | 248 | def __len__(self): 249 | return len(self.data) 250 | 251 | def __getitem__(self, idx): 252 | return self.data[idx] 253 | 254 | 255 | def worker_init_fn(_): 256 | worker_info = torch.utils.data.get_worker_info() 257 | 258 | dataset = worker_info.dataset 259 | worker_id = worker_info.id 260 | 261 | if isinstance(dataset, Txt2ImgIterableBaseDataset): 262 | split_size = dataset.num_records // worker_info.num_workers 263 | # reset num_records to the true number to retain reliable length information 264 | dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 265 | current_id = np.random.choice(len(np.random.get_state()[1]), 1) 266 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 267 | else: 268 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 269 | 270 | 271 | class ConcatDataset(Dataset): 272 | def __init__(self, *datasets): 273 | self.datasets = datasets 274 | 275 | def __getitem__(self, idx): 276 | return tuple(d[idx] for d in self.datasets) 277 | 278 | def __len__(self): 279 | return min(len(d) for d in self.datasets) 280 | 281 | 282 | class DataModuleFromConfig(pl.LightningDataModule): 283 | def __init__(self, batch_size, train=None, train2=None, validation=None, test=None, predict=None, 284 | wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, 285 | shuffle_val_dataloader=False): 286 | super().__init__() 287 | self.batch_size = batch_size 288 | self.dataset_configs = dict() 289 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 290 | self.use_worker_init_fn = use_worker_init_fn 291 | 292 | if train is not None: 293 | self.dataset_configs["train"] = train 294 | self.train_dataloader = self._train_dataloader 295 | if validation is not None: 296 | self.dataset_configs["validation"] = validation 297 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) 298 | if test is not None: 299 | self.dataset_configs["test"] = test 300 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) 301 | if predict is not None: 302 | self.dataset_configs["predict"] = predict 303 | self.predict_dataloader = self._predict_dataloader 304 | self.wrap = wrap 305 | 306 | def prepare_data(self): 307 | for data_cfg in self.dataset_configs.values(): 308 | instantiate_from_config(data_cfg) 309 | 310 | def setup(self, stage=None): 311 | self.datasets = dict( 312 | (k, instantiate_from_config(self.dataset_configs[k])) 313 | for k in self.dataset_configs) 314 | if self.wrap: 315 | for k in self.datasets: 316 | self.datasets[k] = WrappedDataset(self.datasets[k]) 317 | 318 | def _train_dataloader(self): 319 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 320 | if is_iterable_dataset or self.use_worker_init_fn: 321 | init_fn = 
worker_init_fn 322 | else: 323 | init_fn = None 324 | 325 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 326 | num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, 327 | worker_init_fn=init_fn) 328 | 329 | def _val_dataloader(self, shuffle=False): 330 | if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 331 | init_fn = worker_init_fn 332 | else: 333 | init_fn = None 334 | return DataLoader(self.datasets["validation"], 335 | batch_size=self.batch_size, 336 | num_workers=self.num_workers, 337 | worker_init_fn=init_fn, 338 | shuffle=shuffle) 339 | 340 | def _test_dataloader(self, shuffle=False): 341 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 342 | if is_iterable_dataset or self.use_worker_init_fn: 343 | init_fn = worker_init_fn 344 | else: 345 | init_fn = None 346 | 347 | # do not shuffle dataloader for iterable dataset 348 | shuffle = shuffle and (not is_iterable_dataset) 349 | 350 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 351 | num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) 352 | 353 | def _predict_dataloader(self, shuffle=False): 354 | if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 355 | init_fn = worker_init_fn 356 | else: 357 | init_fn = None 358 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size, 359 | num_workers=self.num_workers, worker_init_fn=init_fn) 360 | 361 | 362 | class SetupCallback(Callback): 363 | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): 364 | super().__init__() 365 | self.resume = resume 366 | self.now = now 367 | self.logdir = logdir 368 | self.ckptdir = ckptdir 369 | self.cfgdir = cfgdir 370 | self.config = config 371 | self.lightning_config = lightning_config 372 | 373 | def on_keyboard_interrupt(self, trainer, pl_module): 374 | if trainer.global_rank == 0: 375 | print("Summoning checkpoint.") 376 | ckpt_path = os.path.join(self.ckptdir, "last.ckpt") 377 | trainer.save_checkpoint(ckpt_path) 378 | 379 | def on_pretrain_routine_start(self, trainer, pl_module): 380 | if trainer.global_rank == 0: 381 | # Create logdirs and save configs 382 | os.makedirs(self.logdir, exist_ok=True) 383 | os.makedirs(self.ckptdir, exist_ok=True) 384 | os.makedirs(self.cfgdir, exist_ok=True) 385 | 386 | if "callbacks" in self.lightning_config: 387 | if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: 388 | os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) 389 | print("Project config") 390 | print(OmegaConf.to_yaml(self.config)) 391 | OmegaConf.save(self.config, 392 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) 393 | 394 | print("Lightning config") 395 | print(OmegaConf.to_yaml(self.lightning_config)) 396 | OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), 397 | os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) 398 | 399 | else: 400 | # ModelCheckpoint callback created log directory --- remove it 401 | if not self.resume and os.path.exists(self.logdir): 402 | dst, name = os.path.split(self.logdir) 403 | dst = os.path.join(dst, "child_runs", name) 404 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 405 | try: 406 | os.rename(self.logdir, dst) 407 | except FileNotFoundError: 408 | pass 409 | 410 | 411 | class ImageLogger(Callback): 412 | def __init__(self, batch_frequency, 
max_images, clamp=True, increase_log_steps=True, 413 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 414 | log_images_kwargs=None): 415 | super().__init__() 416 | self.rescale = rescale 417 | self.batch_freq = batch_frequency 418 | self.max_images = max_images 419 | self.save_freq = 250 420 | self.logger_log_images = { 421 | pl.loggers.TestTubeLogger: self._testtube, 422 | } 423 | self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] 424 | if not increase_log_steps: 425 | self.log_steps = [self.batch_freq] 426 | self.clamp = clamp 427 | self.disabled = disabled 428 | self.log_on_batch_idx = log_on_batch_idx 429 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 430 | self.log_first_step = log_first_step 431 | 432 | @rank_zero_only 433 | def _testtube(self, pl_module, images, batch_idx, split): 434 | for k in images: 435 | grid = torchvision.utils.make_grid(images[k]) 436 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 437 | 438 | tag = f"{split}/{k}" 439 | pl_module.logger.experiment.add_image( 440 | tag, grid, 441 | global_step=pl_module.global_step) 442 | 443 | @rank_zero_only 444 | def log_local(self, save_dir, split, images, 445 | global_step, current_epoch, batch_idx): 446 | root = os.path.join(save_dir, "images", split) 447 | for k in images: 448 | grid = torchvision.utils.make_grid(images[k], nrow=4) 449 | if self.rescale: 450 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 451 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 452 | grid = grid.numpy() 453 | grid = (grid * 255).astype(np.uint8) 454 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( 455 | k, 456 | global_step, 457 | current_epoch, 458 | batch_idx) 459 | path = os.path.join(root, filename) 460 | os.makedirs(os.path.split(path)[0], exist_ok=True) 461 | Image.fromarray(grid).save(path) 462 | 463 | def log_img(self, pl_module, batch, batch_idx, split="train"): 464 | check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step 465 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 466 | hasattr(pl_module, "log_images") and 467 | callable(pl_module.log_images) and 468 | self.max_images > 0): 469 | logger = type(pl_module.logger) 470 | 471 | is_train = pl_module.training 472 | if is_train: 473 | pl_module.eval() 474 | 475 | with torch.no_grad(): 476 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 477 | 478 | for k in images: 479 | N = min(images[k].shape[0], self.max_images) 480 | images[k] = images[k][:N] 481 | if isinstance(images[k], torch.Tensor): 482 | images[k] = images[k].detach().cpu() 483 | if self.clamp: 484 | images[k] = torch.clamp(images[k], -1., 1.) 
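# Clamping to [-1, 1] here keeps the tensors consistent with log_local's
# rescale ((x + 1) / 2), which maps them onto [0, 1] before the uint8
# conversion and PNG export.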
485 | 486 | self.log_local(pl_module.logger.save_dir, split, images, 487 | pl_module.global_step, pl_module.current_epoch, batch_idx) 488 | 489 | logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) 490 | logger_log_images(pl_module, images, pl_module.global_step, split) 491 | 492 | if is_train: 493 | pl_module.train() 494 | 495 | def check_frequency(self, check_idx): 496 | if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( 497 | check_idx > 0 or self.log_first_step): 498 | try: 499 | self.log_steps.pop(0) 500 | except IndexError as e: 501 | print(e) 502 | pass 503 | return True 504 | return False 505 | 506 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 507 | if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): 508 | self.log_img(pl_module, batch, batch_idx, split="train") 509 | # if self.save_freq is not None: 510 | # epoch = trainer.current_epoch 511 | # global_step = trainer.global_step 512 | # if global_step % self.save_freq == 0: 513 | # filename = f'{epoch}_{global_step}.ckpt' 514 | # ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename) 515 | # trainer.save_checkpoint(ckpt_path) 516 | 517 | 518 | class CUDACallback(Callback): 519 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py 520 | def on_train_epoch_start(self, trainer, pl_module): 521 | # Reset the memory use counter 522 | torch.cuda.reset_peak_memory_stats(trainer.root_gpu) 523 | torch.cuda.synchronize(trainer.root_gpu) 524 | self.start_time = time.time() 525 | 526 | def on_train_epoch_end(self, trainer, pl_module): 527 | torch.cuda.synchronize(trainer.root_gpu) 528 | max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 529 | epoch_time = time.time() - self.start_time 530 | 531 | try: 532 | max_memory = trainer.training_type_plugin.reduce(max_memory) 533 | epoch_time = trainer.training_type_plugin.reduce(epoch_time) 534 | 535 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 536 | rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") 537 | except AttributeError: 538 | pass 539 | 540 | 541 | if __name__ == "__main__": 542 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 543 | 544 | # add cwd for convenience and to make classes in this file available when 545 | # running as `python main.py` 546 | # (in particular `main.DataModuleFromConfig`) 547 | sys.path.append(os.getcwd()) 548 | 549 | parser = get_parser() 550 | parser = Trainer.add_argparse_args(parser) 551 | 552 | opt, unknown = parser.parse_known_args() 553 | if opt.name and opt.resume: 554 | raise ValueError( 555 | "-n/--name and -r/--resume cannot be specified both." 
556 | "If you want to resume training in a new log folder, " 557 | "use -n/--name in combination with --resume_from_checkpoint" 558 | ) 559 | if opt.resume: 560 | if not os.path.exists(opt.resume): 561 | raise ValueError("Cannot find {}".format(opt.resume)) 562 | if os.path.isfile(opt.resume): 563 | paths = opt.resume.split("/") 564 | # idx = len(paths)-paths[::-1].index("logs")+1 565 | # logdir = "/".join(paths[:idx]) 566 | logdir = "/".join(paths[:-2]) 567 | ckpt = opt.resume 568 | else: 569 | assert os.path.isdir(opt.resume), opt.resume 570 | logdir = opt.resume.rstrip("/") 571 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") 572 | 573 | opt.resume_from_checkpoint = ckpt 574 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) 575 | opt.base = base_configs + opt.base 576 | _tmp = logdir.split("/") 577 | nowname = _tmp[-1] 578 | else: 579 | if opt.name: 580 | name = "_" + opt.name 581 | elif opt.base: 582 | cfg_fname = os.path.split(opt.base[0])[-1] 583 | cfg_name = os.path.splitext(cfg_fname)[0] 584 | name = "_" + cfg_name 585 | else: 586 | name = "" 587 | nowname = now + name + opt.postfix 588 | # nowname = name + opt.postfix 589 | logdir = os.path.join(opt.logdir, nowname) 590 | 591 | ckptdir = os.path.join(logdir, "checkpoints") 592 | cfgdir = os.path.join(logdir, "configs") 593 | seed_everything(opt.seed) 594 | ##################################################################################### 595 | try: 596 | # init and save configs 597 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 598 | cli = OmegaConf.from_dotlist(unknown) 599 | config = OmegaConf.merge(*configs, cli) 600 | lightning_config = config.pop("lightning", OmegaConf.create()) 601 | # merge trainer cli with config 602 | trainer_config = lightning_config.get("trainer", OmegaConf.create()) 603 | # default to ddp 604 | trainer_config["accelerator"] = "ddp" 605 | for k in nondefault_trainer_args(opt): 606 | trainer_config[k] = getattr(opt, k) 607 | if not ("gpus" in trainer_config): 608 | del trainer_config["accelerator"] 609 | cpu = True 610 | else: 611 | gpuinfo = trainer_config["gpus"] 612 | print(f"Running on GPUs {gpuinfo}") 613 | cpu = False 614 | trainer_opt = argparse.Namespace(**trainer_config) 615 | lightning_config.trainer = trainer_config 616 | 617 | # model 618 | config.data.params.train.params.caption = opt.caption 619 | config.data.params.train.params.reg_caption = opt.reg_caption 620 | config.data.params.train.params.datapath = opt.datapath 621 | config.data.params.train.params.reg_datapath = opt.reg_datapath 622 | config.data.params.train.params.modifier = opt.modifier_token 623 | config.data.params.train.params.mask_path = opt.mask_path 624 | config.data.params.train.params.mask_path2 = opt.mask_path2 625 | 626 | if opt.batch_size is not None: 627 | config.data.params.batch_size = opt.batch_size 628 | if opt.modifier_token is not None: 629 | config.model.params.cond_stage_config.params.modifier_token = opt.modifier_token 630 | if opt.num_imgs is not None: 631 | config.model.params.cond_stage_config.params.num_imgs = opt.num_imgs 632 | if opt.repeat > 0: 633 | config.data.params.train.params.repeat = opt.repeat 634 | 635 | if opt.resume_from_checkpoint_custom: 636 | config.model.params.ckpt_path = None 637 | if opt.freeze_model is not None: 638 | config.model.params.freeze_model = opt.freeze_model 639 | 640 | model = instantiate_from_config(config.model) 641 | if opt.resume_from_checkpoint_custom: 642 | st = torch.load(opt.resume_from_checkpoint_custom, 
map_location='cpu')["state_dict"] 643 | token_weights = st["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] 644 | del st["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] 645 | model.load_state_dict(st, strict=False) 646 | model.cond_stage_model.transformer.text_model.embeddings.token_embedding.weight.data[:token_weights.shape[0]] = token_weights 647 | if opt.delta_ckpt is not None: 648 | st = torch.load(opt.delta_ckpt) 649 | embed = None 650 | if 'embed' in st: 651 | embed = st['embed'].reshape(-1, 768) 652 | if 'state_dict' in st: 653 | st = st['state_dict'] 654 | print("restroting from delta model from previous version") 655 | st1 = model.state_dict() 656 | for each in st1.keys(): 657 | if each in st.keys(): 658 | print("found common", each) 659 | model.load_state_dict(st, strict=False) 660 | if embed is not None: 661 | print("restoring embedding") 662 | model.cond_stage_model.transformer.text_model.embeddings.token_embedding.weight.data[token_weights.shape[0]: token_weights.shape[0] + embed.shape[0]] = embed 663 | 664 | # trainer and callbacks 665 | trainer_kwargs = dict() 666 | 667 | # default logger configs 668 | default_logger_cfgs = { 669 | "wandb": { 670 | "target": "pytorch_lightning.loggers.WandbLogger", 671 | "params": { 672 | "name": nowname, 673 | "save_dir": logdir, 674 | "offline": opt.debug, 675 | "id": nowname, 676 | } 677 | }, 678 | "testtube": { 679 | "target": "pytorch_lightning.loggers.TestTubeLogger", 680 | "params": { 681 | "name": "testtube", 682 | "save_dir": logdir, 683 | } 684 | }, 685 | } 686 | default_logger_cfg = default_logger_cfgs["testtube"] 687 | if "logger" in lightning_config: 688 | logger_cfg = lightning_config.logger 689 | else: 690 | logger_cfg = OmegaConf.create() 691 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) 692 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) 693 | 694 | # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to 695 | # specify which metric is used to determine best models 696 | default_modelckpt_cfg = { 697 | "target": "pytorch_lightning.callbacks.ModelCheckpoint", 698 | "params": { 699 | "dirpath": ckptdir, 700 | "filename": "{epoch:06}", 701 | "verbose": True, 702 | "save_last": True, 703 | } 704 | } 705 | if hasattr(model, "monitor"): 706 | print(f"Monitoring {model.monitor} as checkpoint metric.") 707 | default_modelckpt_cfg["params"]["monitor"] = model.monitor 708 | default_modelckpt_cfg["params"]["save_top_k"] = -1 709 | default_modelckpt_cfg["params"]["every_n_epochs"] = 1 710 | 711 | if "modelcheckpoint" in lightning_config: 712 | modelckpt_cfg = lightning_config.modelcheckpoint 713 | else: 714 | modelckpt_cfg = OmegaConf.create() 715 | modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) 716 | print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") 717 | if version.parse(pl.__version__) < version.parse('1.4.0'): 718 | trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) 719 | 720 | # add callback which sets up log directory 721 | default_callbacks_cfg = { 722 | "setup_callback": { 723 | "target": "train.SetupCallback", 724 | "params": { 725 | "resume": opt.resume, 726 | "now": now, 727 | "logdir": logdir, 728 | "ckptdir": ckptdir, 729 | "cfgdir": cfgdir, 730 | "config": config, 731 | "lightning_config": lightning_config, 732 | } 733 | }, 734 | "image_logger": { 735 | "target": "train.ImageLogger", 736 | "params": { 737 | "batch_frequency": 750, 738 | "max_images": 4, 739 | "clamp": 
True 740 | } 741 | }, 742 | "learning_rate_logger": { 743 | "target": "train.LearningRateMonitor", 744 | "params": { 745 | "logging_interval": "step", 746 | # "log_momentum": True 747 | } 748 | }, 749 | "cuda_callback": { 750 | "target": "train.CUDACallback" 751 | }, 752 | } 753 | if version.parse(pl.__version__) >= version.parse('1.4.0'): 754 | default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) 755 | 756 | if "callbacks" in lightning_config: 757 | callbacks_cfg = lightning_config.callbacks 758 | else: 759 | callbacks_cfg = OmegaConf.create() 760 | 761 | if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: 762 | print( 763 | 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') 764 | default_metrics_over_trainsteps_ckpt_dict = { 765 | 'metrics_over_trainsteps_checkpoint': 766 | {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', 767 | 'params': { 768 | "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), 769 | "filename": "{epoch:06}-{step:09}", 770 | "verbose": True, 771 | 'save_top_k': -1, 772 | 'every_n_train_steps': 50, 773 | 'save_weights_only': True 774 | } 775 | } 776 | } 777 | default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) 778 | 779 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) 780 | if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): 781 | callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint 782 | elif 'ignore_keys_callback' in callbacks_cfg: 783 | del callbacks_cfg['ignore_keys_callback'] 784 | 785 | trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] 786 | 787 | trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) 788 | trainer.logdir = logdir 789 | 790 | # data 791 | data = instantiate_from_config(config.data) 792 | # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html 793 | # calling these ourselves should not be necessary but it is. 
794 | # lightning still takes care of proper multiprocessing though 795 | data.prepare_data() 796 | data.setup() 797 | print("#### Data #####") 798 | for k in data.datasets: 799 | print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") 800 | 801 | # configure learning rate 802 | bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate 803 | if not cpu: 804 | ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) 805 | else: 806 | ngpu = 1 807 | if 'accumulate_grad_batches' in lightning_config.trainer: 808 | accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches 809 | else: 810 | accumulate_grad_batches = 1 811 | print(f"accumulate_grad_batches = {accumulate_grad_batches}") 812 | lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches 813 | if opt.scale_lr: 814 | model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr 815 | print( 816 | "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( 817 | model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) 818 | else: 819 | model.learning_rate = base_lr 820 | print("++++ NOT USING LR SCALING ++++") 821 | print(f"Setting learning rate to {model.learning_rate:.2e}") 822 | 823 | # allow checkpointing via USR1 824 | def melk(*args, **kwargs): 825 | # run all checkpoint hooks 826 | if trainer.global_rank == 0: 827 | print("Summoning checkpoint.") 828 | ckpt_path = os.path.join(ckptdir, "last.ckpt") 829 | trainer.save_checkpoint(ckpt_path) 830 | 831 | def divein(*args, **kwargs): 832 | if trainer.global_rank == 0: 833 | import pudb 834 | pudb.set_trace() 835 | 836 | import signal 837 | 838 | signal.signal(signal.SIGUSR1, melk) 839 | signal.signal(signal.SIGUSR2, divein) 840 | 841 | # run 842 | if opt.train: 843 | try: 844 | trainer.fit(model, data) 845 | except Exception: 846 | melk() 847 | raise 848 | if not opt.no_test and not trainer.interrupted: 849 | trainer.test(model, data) 850 | except Exception: 851 | if opt.debug and trainer.global_rank == 0: 852 | try: 853 | import pudb as debugger 854 | except ImportError: 855 | import pdb as debugger 856 | debugger.post_mortem() 857 | raise 858 | finally: 859 | # move newly created debug project to debug_runs 860 | if opt.debug and not opt.resume and trainer.global_rank == 0: 861 | dst, name = os.path.split(logdir) 862 | dst = os.path.join(dst, "debug_runs", name) 863 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 864 | os.rename(logdir, dst) 865 | if trainer.global_rank == 0: 866 | print(trainer.profiler.summary()) 867 | --------------------------------------------------------------------------------
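When --scale_lr is left at its default, train.py sets the effective learning rate to the product of gradient-accumulation steps, GPU count, per-GPU batch size and the config's model.base_learning_rate. A minimal sketch of that arithmetic, with illustrative values (the real numbers come from the YAML config and the trainer flags):

# Illustrative only: mirrors the scale_lr computation at the end of train.py.
accumulate_grad_batches = 1   # lightning.trainer.accumulate_grad_batches (default)
ngpu = 4                      # number of GPUs passed via --gpus
bs = 1                        # data.params.batch_size (assumed value)
base_lr = 1.0e-06             # model.base_learning_rate (assumed value)

effective_lr = accumulate_grad_batches * ngpu * bs * base_lr
print(f"Setting learning rate to {effective_lr:.2e}")  # 4.00e-06 for these numbers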