├── .gitignore ├── README.md ├── accelerate_config.yaml ├── assets ├── ade20_decomp.png ├── ade20k_training_daam.png ├── art_decomposition.png ├── demo.gif ├── imagenet_decomposition.png ├── supp_ade20k_decomposition.png ├── supp_external_composition.png └── teaser.png ├── compose_models.py ├── create_accelerate_config.py ├── daam_ddim_visualize.py ├── daam_visualize_generation.py ├── datasets.py ├── eval.py ├── inference.py ├── main.py ├── requirements.txt └── scripts ├── evaluate_classification.sh ├── sample.sh ├── train_ADE20K.sh ├── train_art.sh └── train_imagenet.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Compositional Concepts Discovery
(ICCV 2023) 2 | 3 | 4 | 5 | 6 | https://github.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/assets/45443761/ca2504a2-2186-4edd-b37c-2b4e9c503a1b 7 | > Text-to-image generative models have enabled high-resolution image synthesis across different domains, but require users to specify the content they wish to generate. In this paper, we consider the inverse problem -- given a collection of different images, can we discover the generative concepts that represent each image? We present an unsupervised approach to discover generative concepts from a collection of images, disentangling different art styles in paintings, objects, and lighting from kitchen scenes, and discovering image classes given ImageNet images. We show how such generative concepts can accurately represent the content of images, be recombined and composed to generate new artistic and hybrid images, and be further used as a representation for downstream classification tasks. 8 | 9 | [Unsupervised Compositional Concepts Discovery with Text-to-Image Generative Models](https://energy-based-model.github.io/unsupervised-concept-discovery) 10 |
11 | [Nan Liu](https://nanliu.io) ¹*, 12 | [Yilun Du](https://yilundu.github.io) ²*, 13 | [Shuang Li](https://people.csail.mit.edu/lishuang) ²*, 14 | [Joshua B. Tenenbaum](https://mitibmwatsonailab.mit.edu/people/joshua-tenenbaum/) ², 15 | [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab/) ² 16 |
17 | * Equal Contribution 18 |
19 | ¹UIUC, ²MIT CSAIL 20 |
21 | ICCV 2023 22 |
23 | 24 | 25 | ## Setup 26 | 27 | Run the following to create a conda environment and activate it: 28 | 29 | conda create --name decomp python=3.8 30 | conda activate decomp 31 | 32 | Install PyTorch 1.11, the version we have tested with: 33 | 34 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 35 | 36 | Next, install the required packages: 37 | 38 | pip install -r requirements.txt 39 | 40 | 41 | ## Training 42 | 43 | After downloading images (e.g., ImageNet class folders) into the repository directory, first specify the data arguments: 44 | 45 | model_path="stabilityai/stable-diffusion-2-1-base" 46 | train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet" 47 | placeholder_tokens=",,,," 48 | class_folder_names="n09288635,n02085620,n02481823,n04204347,n03788195" 49 | learnable_property="object,object,object,object,object" 50 | output_dir=$OUTPUT_DIR 51 | 52 | Then you can train a model on ImageNet $S_1$ by running: 53 | 54 | DEVICE=$CUDA_VISIBLE_DEVICES 55 | python create_accelerate_config.py --gpu_id "${DEVICE}" 56 | accelerate launch --config_file accelerate_config.yaml main.py \ 57 | --pretrained_model_name_or_path "${model_path}" \ 58 | --train_data_dir ${train_data_dir} \ 59 | --placeholder_tokens ${placeholder_tokens} \ 60 | --resolution=512 --class_folder_names ${class_folder_names} \ 61 | --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \ 62 | --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \ 63 | --lr_warmup_steps=0 --output_dir ${output_dir} \ 64 | --learnable_property "${learnable_property}" \ 65 | --checkpointing_steps 1000 --mse_coeff 1 --seed 0 \ 66 | --add_weight_per_score \ 67 | --use_conj_score --init_weight 5 \ 68 | --validation_step 1000 \ 69 | --num_iters_per_image 120 --num_images_per_class 5 70 | 71 | See the ```scripts``` folder for more training details. 72 | 73 | The training dataset for artistic paintings can be found [here](https://www.dropbox.com/sh/g91atmeuy2ihkgn/AAATmXz6zI9H4fLCnsEG-CRka?dl=0). 74 | 75 | ## Inference 76 | ![teaser](assets/teaser.png) 77 | 78 | Once the model is trained, we can sample from each concept using the following command: 79 | 80 | python inference.py --model_path ${output_dir} --prompts "a photo of " "a photo of " "a photo of " "a photo of " "a photo of " --num_images 64 --bsz 8 81 | 82 | ## Visualization 83 | 84 | You need to install the required packages from [DAAM](https://github.com/castorini/daam).
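If DAAM is not already available in your environment, it can typically be installed from PyPI (this assumes the `daam` package released by the DAAM authors; pick a version compatible with your `diffusers` installation):

    pip install daam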
85 | 86 | DDIM + DAAM (Training Data) | DAAM (Generated Data) 87 | :-------------------------:|:-------------------------: 88 | ![decomposition_ade20k_training](assets/ade20k_training_daam.png) | ![decomposition_ade20k](assets/ade20_decomp.png) 89 | 90 | Once concepts are learned, we can generate cross-attention heat maps between a training image and the learned concepts using diffusion attentive attribution maps ([DAAM](https://github.com/castorini/daam)) for visualization: 91 | 92 | python daam_ddim_visualize.py --model_path ${output_dir} --image_path $IMG_PATH --prompt "a photo of " --keyword "" --scale 7.5 --seed 0 93 | 94 | Generate images along with heat-map visualizations associated with each learned concept using DAAM: 95 | 96 | python daam_visualize_generation.py --model_path ${output_dir} --prompt "a photo of " --keyword "" --scale 7.5 --num_images 1 --seed 0 97 | 98 | ## Evaluation 99 | 100 | After generating 64 images per concept, we can run the following commands to evaluate classification accuracy and KL divergence using a pre-trained ResNet-50 or CLIP with the specified threshold values: 101 | 102 | # CLIP 103 | python eval.py --model_path ${output_dir} --evaluation_metric clip --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold 0.3 104 | # ResNet-50 105 | python eval.py --model_path ${output_dir} --evaluation_metric resnet --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold 10 106 | 107 | 108 | ## Unsupervised Decomposition 109 | 110 | ### ImageNet Objects 111 | Our proposed method can discover different object categories from a set of unlabeled images. 112 | ![Imagenet_decomposition](assets/imagenet_decomposition.png) 113 | 114 | ### Scene Concepts 115 | We demonstrate that our method can decompose kitchen scenes into multiple sets of factors. 116 | 117 | ![](assets/supp_ade20k_decomposition.png) 118 | 119 | ### Art Concepts 120 | Our method allows unsupervised concept decomposition from just a few paintings. 121 | 122 | ![art decomposition](assets/art_decomposition.png) 123 | 124 | ## Concept Composition 125 | 126 | Discovered concepts can be further combined with the existing knowledge (i.e., text prompts) of the pretrained generative model; an example command for composing concepts across separately trained models is sketched at the end of this README. 127 | 128 | ![concept composition](assets/supp_external_composition.png) 129 | 130 | ## Todo 131 | 132 | - [ ] Add support for other available models such as DeepFloyd IF and Stable Diffusion XL. 133 | 134 | 135 | 136 | ## Citation 137 | @InProceedings{Liu_2023_ICCV, 138 | author = {Liu, Nan and Du, Yilun and Li, Shuang and Tenenbaum, Joshua B. and Torralba, Antonio}, 139 | title = {Unsupervised Compositional Concepts Discovery with Text-to-Image Generative Models}, 140 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 141 | month = {October}, 142 | year = {2023}, 143 | pages = {2085-2095} 144 | } 145 | 146 | Feel free to let us know if we are missing any relevant citations.
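As referenced in the Concept Composition section, concepts learned in separately trained models can also be composed at sampling time with `compose_models.py`, which averages the models' unconditional scores and adds each model's weighted (conditional - unconditional) difference. The invocation below is only an illustrative sketch; the output directories, prompts, and guidance scales are placeholders rather than values taken from the paper:

    python compose_models.py --model_paths ${output_dir_1} ${output_dir_2} \
        --prompts "a photo of " "a painting in the style of " \
        --scales 7.5 7.5 --steps 50 --num_images 8 --bsz 4 --seed 0 --folder composed_samples

Images are written to the `--folder` directory, one file per sample.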
147 | -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: {} 5 | distributed_type: 'NO' 6 | downcast_bf16: 'no' 7 | dynamo_backend: 'NO' 8 | fsdp_config: {} 9 | gpu_ids: '6' 10 | machine_rank: 0 11 | main_process_ip: null 12 | main_process_port: null 13 | main_training_function: main 14 | megatron_lm_config: {} 15 | mixed_precision: 'no' 16 | num_machines: 1 17 | num_processes: 1 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_name: null 21 | tpu_zone: null 22 | -------------------------------------------------------------------------------- /assets/ade20_decomp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/ade20_decomp.png -------------------------------------------------------------------------------- /assets/ade20k_training_daam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/ade20k_training_daam.png -------------------------------------------------------------------------------- /assets/art_decomposition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/art_decomposition.png -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/demo.gif -------------------------------------------------------------------------------- /assets/imagenet_decomposition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/imagenet_decomposition.png -------------------------------------------------------------------------------- /assets/supp_ade20k_decomposition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/supp_ade20k_decomposition.png -------------------------------------------------------------------------------- /assets/supp_external_composition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/supp_external_composition.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/teaser.png 
-------------------------------------------------------------------------------- /compose_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import inspect 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | from PIL import Image 8 | 9 | from diffusers import ( 10 | AutoencoderKL, 11 | DDIMScheduler, 12 | UNet2DConditionModel, 13 | ) 14 | 15 | from transformers import CLIPTextModel, CLIPTokenizer 16 | from typing import List, Optional, Tuple, Union 17 | 18 | 19 | def randn_tensor( 20 | shape: Union[Tuple, List], 21 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 22 | device: Optional["torch.device"] = None, 23 | dtype: Optional["torch.dtype"] = None, 24 | layout: Optional["torch.layout"] = None, 25 | ): 26 | """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When 27 | passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor 28 | will always be created on CPU. 29 | """ 30 | # device on which tensor is created defaults to device 31 | rand_device = device 32 | batch_size = shape[0] 33 | 34 | layout = layout or torch.strided 35 | device = device or torch.device("cpu") 36 | 37 | if generator is not None: 38 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 39 | if gen_device_type != device.type and gen_device_type == "cpu": 40 | rand_device = "cpu" 41 | # if device != "mps": 42 | # logger.info( 43 | # f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 44 | # f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 45 | # f" slighly speed up this function by passing a generator that was created on the {device} device." 
46 | # ) 47 | elif gen_device_type != device.type and gen_device_type == "cuda": 48 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 49 | 50 | if isinstance(generator, list): 51 | shape = (1,) + shape[1:] 52 | latents = [ 53 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 54 | for i in range(batch_size) 55 | ] 56 | latents = torch.cat(latents, dim=0).to(device) 57 | else: 58 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 59 | 60 | return latents 61 | 62 | 63 | def get_batched_text_embeddings(tokenizer, text_encoder, prompt, batch_size): 64 | device = text_encoder.device 65 | text_inputs = tokenizer( 66 | prompt, 67 | padding="max_length", 68 | max_length=tokenizer.model_max_length, 69 | truncation=True, 70 | return_tensors="pt", 71 | ) 72 | text_input_ids = text_inputs.input_ids 73 | text_embeddings = text_encoder(text_input_ids.to(device))[0] 74 | bs_embed, seq_len, _ = text_embeddings.shape 75 | text_embeddings = text_embeddings.repeat(1, batch_size, 1).view(bs_embed * batch_size, seq_len, -1) 76 | return text_embeddings 77 | 78 | 79 | def prepare_latents(vae_scale_factor, init_noise_sigma, batch_size, 80 | num_channels_latents, height, width, dtype, device, generator, latents=None): 81 | shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor) 82 | if isinstance(generator, list) and len(generator) != batch_size: 83 | raise ValueError( 84 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 85 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 86 | ) 87 | 88 | if latents is None: 89 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 90 | else: 91 | latents = latents.to(device) 92 | 93 | # scale the initial noise by the standard deviation required by the scheduler 94 | latents = latents * init_noise_sigma 95 | return latents 96 | 97 | 98 | def prepare_extra_step_kwargs(generator, scheduler, eta): 99 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 100 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 101 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 102 | # and should be between [0, 1] 103 | 104 | accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys()) 105 | extra_step_kwargs = {} 106 | if accepts_eta: 107 | extra_step_kwargs["eta"] = eta 108 | 109 | # check if the scheduler accepts generator 110 | accepts_generator = "generator" in set(inspect.signature(scheduler.step).parameters.keys()) 111 | if accepts_generator: 112 | extra_step_kwargs["generator"] = generator 113 | return extra_step_kwargs 114 | 115 | 116 | def decode_latents(vae, latents): 117 | latents = 1 / 0.18215 * latents 118 | image = vae.decode(latents).sample 119 | image = (image / 2 + 0.5).clamp(0, 1) 120 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 121 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 122 | return image 123 | 124 | 125 | def numpy_to_pil(images): 126 | """ 127 | Convert a numpy image or a batch of images to a PIL image. 128 | """ 129 | if images.ndim == 3: 130 | images = images[None, ...] 
131 | images = (images * 255).round().astype("uint8") 132 | pil_images = [Image.fromarray(image) for image in images] 133 | 134 | return pil_images 135 | 136 | 137 | def main(): 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument("--model_paths", type=str, nargs="+") 140 | parser.add_argument("--prompts", type=str, nargs="+") 141 | parser.add_argument("--bsz", type=int, default=1) 142 | parser.add_argument("--num_images", type=int, default=1) 143 | parser.add_argument("--steps", type=int, default=50) 144 | parser.add_argument("--scales", type=float, nargs="+") 145 | parser.add_argument("--eta", type=float, default=0) 146 | parser.add_argument("--seed", type=int, default=0) 147 | parser.add_argument("--folder", type=str) 148 | args = parser.parse_args() 149 | 150 | # load models 151 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 152 | generator = torch.Generator(device).manual_seed(args.seed) 153 | 154 | tokenizers, text_encoders, vaes, unets = [], [], [], [] 155 | noise_scheduler = DDIMScheduler.from_pretrained(args.model_paths[0], subfolder="scheduler") 156 | 157 | for model_path in args.model_paths: 158 | tokenizers.append(CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")) 159 | text_encoders.append(CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(device)) 160 | vaes.append(AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device)) 161 | unets.append(UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(device)) 162 | print(f"finished loading from {model_path}") 163 | 164 | # sampling 165 | batch_size = args.bsz 166 | num_batches = args.num_images // args.bsz 167 | steps = args.steps 168 | scales = args.scales 169 | eta = args.eta 170 | vae_scale_factor = 2 ** (len(vaes[0].config.block_out_channels) - 1) 171 | init_noise_sigma = noise_scheduler.init_noise_sigma 172 | num_channels_latents = unets[0].in_channels 173 | height = unets[0].config.sample_size * vae_scale_factor 174 | width = unets[0].config.sample_size * vae_scale_factor 175 | image_folder = args.folder 176 | os.makedirs(image_folder, exist_ok=True) 177 | 178 | with torch.no_grad(): 179 | for batch_num in range(num_batches): 180 | # 1. set the noise scheduler 181 | noise_scheduler.set_timesteps(args.steps, device=device) 182 | timesteps = noise_scheduler.timesteps 183 | # 2. initialize the latents 184 | latents = prepare_latents( 185 | vae_scale_factor, 186 | init_noise_sigma, 187 | batch_size, 188 | num_channels_latents, 189 | height, 190 | width, 191 | text_encoders[0].dtype, 192 | device, 193 | generator, 194 | latents=None 195 | ) 196 | 197 | frames = [] 198 | num_warmup_steps = len(timesteps) - steps * noise_scheduler.order 199 | with tqdm(total=steps) as progress_bar: 200 | for i, t in enumerate(timesteps): 201 | latent_model_input = torch.cat([latents] * 2) 202 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 203 | 204 | uncond_scores, cond_scores = [], [] 205 | for prompt, tokenizer, text_encoder, unet, vae in zip(args.prompts, tokenizers, text_encoders, 206 | unets, vaes): 207 | # 3. 
get the input embeddings 208 | text_embeddings = get_batched_text_embeddings(tokenizer, text_encoder, prompt, batch_size) 209 | null_embeddings = get_batched_text_embeddings(tokenizer, text_encoder, "", batch_size) 210 | input_embeddings = torch.cat((null_embeddings, text_embeddings), dim=0) 211 | 212 | # predict the noise residual 213 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_embeddings).sample 214 | uncond_pred_noise, cond_pred_noise = noise_pred.chunk(2) 215 | # save predicted scores 216 | uncond_scores.append(uncond_pred_noise) 217 | cond_scores.append(cond_pred_noise) 218 | 219 | # apply compositional score 220 | composed_noise_pred = sum(uncond_scores) / len(uncond_scores) + sum( 221 | scale * (cond_score - uncond_score) for scale, cond_score, uncond_score in 222 | zip(scales, cond_scores, uncond_scores)) 223 | 224 | # compute the previous noisy sample x_t -> x_t-1 225 | extra_step_kwargs = prepare_extra_step_kwargs(generator, noise_scheduler, eta) 226 | latents = noise_scheduler.step(composed_noise_pred, t, latents, **extra_step_kwargs).prev_sample 227 | 228 | # save intermediate results 229 | decoded_images = decode_latents(vae, latents) 230 | frames.append(decoded_images) 231 | 232 | images = decode_latents(vae, latents) 233 | images = numpy_to_pil(images) 234 | 235 | # call the callback, if provided 236 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % noise_scheduler.order == 0): 237 | progress_bar.update() 238 | 239 | images = decode_latents(vae, latents) 240 | images = numpy_to_pil(images) 241 | 242 | for j, img in enumerate(images): 243 | img_path = os.path.join(image_folder, 244 | f"{args.prompts}_{batch_num * batch_size + j}_{scales}_{args.seed}.png") 245 | img.save(img_path) 246 | 247 | 248 | if __name__ == "__main__": 249 | main() 250 | -------------------------------------------------------------------------------- /create_accelerate_config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--gpu_id", type=str, required=True) 6 | parser.add_argument("--distributed", action="store_true", default=False) 7 | args = parser.parse_args() 8 | 9 | dict_file = { 10 | 'command_file': None, 11 | 'commands': None, 12 | 'compute_environment': 'LOCAL_MACHINE', 13 | 'deepspeed_config': {}, 14 | 'distributed_type': 'NO', 15 | 'downcast_bf16': 'no', 16 | 'dynamo_backend': 'NO', 17 | 'fsdp_config': {}, 18 | 'gpu_ids': args.gpu_id, 19 | 'machine_rank': 0, 20 | 'main_process_ip': None, 21 | 'main_process_port': None, 22 | 'main_training_function': 'main', 23 | 'megatron_lm_config': {}, 24 | 'mixed_precision': 'no', 25 | 'num_machines': 1, 26 | 'num_processes': 1, 27 | 'rdzv_backend': 'static', 28 | 'same_network': True, 29 | 'tpu_name': None, 30 | 'tpu_zone': None, 31 | } 32 | 33 | dist_dict_file = { 34 | 'command_file': None, 35 | 'commands': None, 36 | 'compute_environment': 'LOCAL_MACHINE', 37 | 'deepspeed_config': {}, 38 | 'distributed_type': 'MULTI_GPU', 39 | 'downcast_bf16': 'no', 40 | 'dynamo_backend': 'NO', 41 | 'fsdp_config': {}, 42 | 'gpu_ids': args.gpu_id, 43 | 'machine_rank': 0, 44 | 'main_process_ip': None, 45 | 'main_process_port': None, 46 | 'main_training_function': 'main', 47 | 'megatron_lm_config': {}, 48 | 'mixed_precision': 'fp16', 49 | 'num_machines': 1, 50 | 'num_processes': len(args.gpu_id.split(',')), 51 | 'rdzv_backend': 'static', 52 | 'same_network': True, 53 | 'tpu_name': 
None, 54 | 'tpu_zone': None, 55 | 'use_cpu': False, 56 | } 57 | 58 | if args.distributed: 59 | with open('accelerate_config.yaml', 'w') as file: 60 | documents = yaml.dump(dist_dict_file, file) 61 | else: 62 | with open('accelerate_config.yaml', 'w') as file: 63 | documents = yaml.dump(dict_file, file) 64 | -------------------------------------------------------------------------------- /daam_ddim_visualize.py: -------------------------------------------------------------------------------- 1 | """https://github.com/cccntu/efficient-prompt-to-prompt/blob/main/ddim-inversion.ipynb""" 2 | from functools import partial 3 | from diffusers import StableDiffusionPipeline 4 | 5 | from PIL import Image 6 | from torchvision import transforms 7 | 8 | from typing import Callable, List, Optional, Union 9 | 10 | import torch 11 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 12 | 13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 14 | from diffusers.pipeline_utils import DiffusionPipeline 15 | from diffusers.pipelines.stable_diffusion.safety_checker import \ 16 | StableDiffusionSafetyChecker 17 | from diffusers.schedulers import DDIMScheduler,PNDMScheduler, LMSDiscreteScheduler 18 | from diffusers.utils import logging 19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 20 | import argparse 21 | from daam import trace, set_seed 22 | 23 | 24 | def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt): 25 | """ from noise to image""" 26 | return ( 27 | alpha_tm1**0.5 28 | * ( 29 | (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t 30 | + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt 31 | ) 32 | + x_t 33 | ) 34 | 35 | 36 | def forward_ddim(x_t, alpha_t, alpha_tp1, eps_xt): 37 | """ from image to noise, it's the same as backward_ddim""" 38 | return backward_ddim(x_t, alpha_t, alpha_tp1, eps_xt) 39 | 40 | 41 | class DDIMPipeline(DiffusionPipeline): 42 | def __init__( 43 | self, 44 | vae: AutoencoderKL, 45 | text_encoder: CLIPTextModel, 46 | tokenizer: CLIPTokenizer, 47 | unet: UNet2DConditionModel, 48 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], 49 | safety_checker: StableDiffusionSafetyChecker = None, 50 | feature_extractor: CLIPFeatureExtractor = None, 51 | ): 52 | super().__init__() 53 | 54 | self.register_modules( 55 | vae=vae, 56 | text_encoder=text_encoder, 57 | tokenizer=tokenizer, 58 | unet=unet, 59 | scheduler=scheduler, 60 | safety_checker=safety_checker, 61 | feature_extractor=feature_extractor, 62 | ) 63 | self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 64 | self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True) 65 | 66 | def run_safety_checker(self, image, device, dtype): 67 | if self.safety_checker is not None: 68 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) 69 | image, has_nsfw_concept = self.safety_checker( 70 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 71 | ) 72 | else: 73 | has_nsfw_concept = None 74 | return image, has_nsfw_concept 75 | 76 | @torch.inference_mode() 77 | def get_text_embedding(self, prompt): 78 | text_input_ids = self.tokenizer( 79 | prompt, 80 | padding="max_length", 81 | truncation=True, 82 | max_length=self.tokenizer.model_max_length, 83 | return_tensors="pt", 84 | ).input_ids 85 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] 86 | return text_embeddings 87 | 88 | @torch.inference_mode() 89 | def 
get_image_latents(self, image, sample=True, rng_generator=None): 90 | encoding_dist = self.vae.encode(image).latent_dist 91 | if sample: 92 | encoding = encoding_dist.sample(generator=rng_generator) 93 | else: 94 | encoding = encoding_dist.mode() 95 | latents = encoding * 0.18215 96 | return latents 97 | 98 | @torch.inference_mode() 99 | def backward_diffusion( 100 | self, 101 | use_old_emb_i=25, 102 | prompt=None, 103 | text_embeddings=None, 104 | old_text_embeddings=None, 105 | new_text_embeddings=None, 106 | latents: Optional[torch.FloatTensor] = None, 107 | num_inference_steps: int = 50, 108 | guidance_scale: float = 7.5, 109 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 110 | callback_steps: Optional[int] = 1, 111 | reverse_process: True = False, 112 | **kwargs, 113 | ): 114 | """ Generate image from text prompt and latents 115 | """ 116 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 117 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 118 | # corresponds to doing no classifier free guidance. 119 | do_classifier_free_guidance = guidance_scale > 1.0 120 | if text_embeddings is None: 121 | text_embeddings = self._encode_prompt(prompt, device=self.device, 122 | num_images_per_prompt=1, 123 | do_classifier_free_guidance=do_classifier_free_guidance) 124 | # set timesteps 125 | self.scheduler.set_timesteps(num_inference_steps) 126 | # Some schedulers like PNDM have timesteps as arrays 127 | # It's more optimized to move all timesteps to correct device beforehand 128 | timesteps_tensor = self.scheduler.timesteps.to(self.device) 129 | # scale the initial noise by the standard deviation required by the scheduler 130 | latents = latents * self.scheduler.init_noise_sigma 131 | 132 | if old_text_embeddings is not None and new_text_embeddings is not None: 133 | prompt_to_prompt = True 134 | else: 135 | prompt_to_prompt = False 136 | 137 | for i, t in enumerate( 138 | self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))): 139 | if prompt_to_prompt: 140 | if i < use_old_emb_i: 141 | text_embeddings = old_text_embeddings 142 | else: 143 | text_embeddings = new_text_embeddings 144 | 145 | # expand the latents if we are doing classifier free guidance 146 | latent_model_input = ( 147 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 148 | ) 149 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 150 | 151 | # predict the noise residual 152 | noise_pred = self.unet( 153 | latent_model_input, t, encoder_hidden_states=text_embeddings 154 | ).sample 155 | 156 | # perform guidance 157 | if do_classifier_free_guidance: 158 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 159 | noise_pred = noise_pred_uncond + guidance_scale * ( 160 | noise_pred_text - noise_pred_uncond 161 | ) 162 | 163 | prev_timestep = ( 164 | t 165 | - self.scheduler.config.num_train_timesteps 166 | // self.scheduler.num_inference_steps 167 | ) 168 | # call the callback, if provided 169 | if callback is not None and i % callback_steps == 0: 170 | callback(i, t, latents) 171 | 172 | # ddim 173 | alpha_prod_t = self.scheduler.alphas_cumprod[t] 174 | alpha_prod_t_prev = ( 175 | self.scheduler.alphas_cumprod[prev_timestep] 176 | if prev_timestep >= 0 177 | else self.scheduler.final_alpha_cumprod 178 | ) 179 | if reverse_process: 180 | alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t 181 | latents = backward_ddim( 182 | 
x_t=latents, 183 | alpha_t=alpha_prod_t, 184 | alpha_tm1=alpha_prod_t_prev, 185 | eps_xt=noise_pred, 186 | ) 187 | return latents 188 | 189 | @torch.inference_mode() 190 | def decode_image(self, latents: torch.FloatTensor, **kwargs) -> List["PIL_IMAGE"]: 191 | scaled_latents = 1 / 0.18215 * latents 192 | image = [ 193 | self.vae.decode(scaled_latents[i: i + 1]).sample for i in range(len(latents)) 194 | ] 195 | image = torch.cat(image, dim=0) 196 | return image 197 | 198 | @torch.inference_mode() 199 | def torch_to_numpy(self, image) -> List["PIL_IMAGE"]: 200 | image = (image / 2 + 0.5).clamp(0, 1) 201 | image = image.cpu().permute(0, 2, 3, 1).numpy() 202 | return image 203 | 204 | def _encode_prompt( 205 | self, 206 | prompt, 207 | device, 208 | num_images_per_prompt, 209 | do_classifier_free_guidance, 210 | negative_prompt=None, 211 | prompt_embeds: Optional[torch.FloatTensor] = None, 212 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 213 | ): 214 | r""" 215 | Encodes the prompt into text encoder hidden states. 216 | Args: 217 | prompt (`str` or `List[str]`, *optional*): 218 | prompt to be encoded 219 | device: (`torch.device`): 220 | torch device 221 | num_images_per_prompt (`int`): 222 | number of images that should be generated per prompt 223 | do_classifier_free_guidance (`bool`): 224 | whether to use classifier free guidance or not 225 | negative_prompt (`str` or `List[str]`, *optional*): 226 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 227 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 228 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 229 | prompt_embeds (`torch.FloatTensor`, *optional*): 230 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 231 | provided, text embeddings will be generated from `prompt` input argument. 232 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 233 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 234 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 235 | argument. 
236 | """ 237 | if prompt is not None and isinstance(prompt, str): 238 | batch_size = 1 239 | elif prompt is not None and isinstance(prompt, list): 240 | batch_size = len(prompt) 241 | else: 242 | batch_size = prompt_embeds.shape[0] 243 | 244 | if prompt_embeds is None: 245 | text_inputs = self.tokenizer( 246 | prompt, 247 | padding="max_length", 248 | max_length=self.tokenizer.model_max_length, 249 | truncation=True, 250 | return_tensors="pt", 251 | ) 252 | text_input_ids = text_inputs.input_ids 253 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 254 | 255 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 256 | text_input_ids, untruncated_ids 257 | ): 258 | removed_text = self.tokenizer.batch_decode( 259 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] 260 | ) 261 | logger.warning( 262 | "The following part of your input was truncated because CLIP can only handle sequences up to" 263 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 264 | ) 265 | 266 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 267 | attention_mask = text_inputs.attention_mask.to(device) 268 | else: 269 | attention_mask = None 270 | 271 | prompt_embeds = self.text_encoder( 272 | text_input_ids.to(device), 273 | attention_mask=attention_mask, 274 | ) 275 | prompt_embeds = prompt_embeds[0] 276 | 277 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 278 | 279 | bs_embed, seq_len, _ = prompt_embeds.shape 280 | # duplicate text embeddings for each generation per prompt, using mps friendly method 281 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 282 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 283 | 284 | # get unconditional embeddings for classifier free guidance 285 | if do_classifier_free_guidance and negative_prompt_embeds is None: 286 | uncond_tokens: List[str] 287 | if negative_prompt is None: 288 | uncond_tokens = [""] * batch_size 289 | elif type(prompt) is not type(negative_prompt): 290 | raise TypeError( 291 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 292 | f" {type(prompt)}." 293 | ) 294 | elif isinstance(negative_prompt, str): 295 | uncond_tokens = [negative_prompt] 296 | elif batch_size != len(negative_prompt): 297 | raise ValueError( 298 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 299 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 300 | " the batch size of `prompt`." 
301 | ) 302 | else: 303 | uncond_tokens = negative_prompt 304 | 305 | max_length = prompt_embeds.shape[1] 306 | uncond_input = self.tokenizer( 307 | uncond_tokens, 308 | padding="max_length", 309 | max_length=max_length, 310 | truncation=True, 311 | return_tensors="pt", 312 | ) 313 | 314 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 315 | attention_mask = uncond_input.attention_mask.to(device) 316 | else: 317 | attention_mask = None 318 | 319 | negative_prompt_embeds = self.text_encoder( 320 | uncond_input.input_ids.to(device), 321 | attention_mask=attention_mask, 322 | ) 323 | negative_prompt_embeds = negative_prompt_embeds[0] 324 | 325 | if do_classifier_free_guidance: 326 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 327 | seq_len = negative_prompt_embeds.shape[1] 328 | 329 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 330 | 331 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 332 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 333 | 334 | # For classifier free guidance, we need to do two forward passes. 335 | # Here we concatenate the unconditional and text embeddings into a single batch 336 | # to avoid doing two forward passes 337 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 338 | 339 | return prompt_embeds 340 | 341 | 342 | def load_img(path, target_size=512): 343 | """Load an image, resize and output -1..1""" 344 | image = Image.open(path).convert("RGB") 345 | 346 | tform = transforms.Compose( 347 | [ 348 | transforms.Resize(target_size), 349 | transforms.CenterCrop(target_size), 350 | transforms.ToTensor(), 351 | ] 352 | ) 353 | image = tform(image) 354 | return 2.0 * image - 1.0 355 | 356 | 357 | def latents_to_imgs(latents): 358 | x = pipe.decode_image(latents) 359 | x = pipe.torch_to_numpy(x) 360 | x = pipe.numpy_to_pil(x) 361 | return x 362 | 363 | 364 | if __name__ == '__main__': 365 | parser = argparse.ArgumentParser() 366 | parser.add_argument("--model_path", type=str) 367 | parser.add_argument("--image_path", type=str) 368 | parser.add_argument("--prompt", type=str) 369 | parser.add_argument("--keyword", type=str) 370 | parser.add_argument("--seed", type=int, default=0) 371 | parser.add_argument("--scale", type=float, default=1) 372 | args = parser.parse_args() 373 | 374 | pipe = StableDiffusionPipeline.from_pretrained(args.model_path) 375 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 376 | pipe = pipe.to("cuda") 377 | 378 | img = load_img(args.image_path).unsqueeze(0).to("cuda") 379 | pipe2 = DDIMPipeline( 380 | vae=pipe.vae, 381 | text_encoder=pipe.text_encoder, 382 | tokenizer=pipe.tokenizer, 383 | unet=pipe.unet, 384 | scheduler=pipe.scheduler, 385 | safety_checker=pipe.safety_checker 386 | ) 387 | pipe = pipe2 388 | 389 | prompt = args.prompt 390 | text_embeddings = pipe.get_text_embedding(prompt) 391 | image_latents = pipe.get_image_latents(img, rng_generator=torch.Generator(device=pipe.device).manual_seed(args.seed)) 392 | reversed_latents = pipe.forward_diffusion( 393 | latents=image_latents, 394 | text_embeddings=text_embeddings, 395 | guidance_scale=1, 396 | num_inference_steps=50, 397 | ) 398 | 399 | with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad(): 400 | with trace(pipe) as tc: 401 | reconstructed_latents = pipe.backward_diffusion( 402 | 
latents=reversed_latents, 403 | prompt=prompt, 404 | guidance_scale=args.scale, 405 | num_inference_steps=50, 406 | ) 407 | img = latents_to_imgs(reconstructed_latents)[0] 408 | img.save(f"{prompt}_{args.seed}_reconstruction.png") 409 | heat_map = tc.compute_global_heat_map() 410 | heat_map = heat_map.compute_word_heat_map(args.keyword) 411 | heat_map.plot_overlay(img, out_file=f"{args.keyword}_heatmap_{args.seed}.png", word=None) 412 | img.save(f"{prompt}_{args.seed}.png") 413 | -------------------------------------------------------------------------------- /daam_visualize_generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import argparse 5 | 6 | from PIL import Image 7 | from daam import trace, set_seed 8 | from matplotlib import pyplot as plt 9 | from diffusers import StableDiffusionPipeline, DDIMScheduler 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--model_path", type=str) 14 | parser.add_argument("--prompt", type=str) 15 | parser.add_argument("--keyword", type=str) 16 | parser.add_argument("--scale", type=float, default=7.5) 17 | parser.add_argument("--num_images", type=int, default=1) 18 | args = parser.parse_args() 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | pipe = StableDiffusionPipeline.from_pretrained(args.model_path).to(device) 22 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 23 | pipe.safety_checker = None 24 | pipe = pipe.to(device) 25 | 26 | prompt = args.prompt 27 | keyword = args.keyword 28 | gen = set_seed(0) # for reproducibility 29 | 30 | 31 | def numpy_to_pil(images): 32 | """ 33 | Convert a numpy image or a batch of images to a PIL image. 34 | """ 35 | if images.ndim == 3: 36 | images = images[None, ...] 
37 | images = (images * 255).round().astype("uint8") 38 | pil_images = [Image.fromarray(image) for image in images] 39 | 40 | return pil_images 41 | 42 | 43 | def daam_visualize(): 44 | with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad(): 45 | with trace(pipe) as tc: 46 | for i in range(args.num_images): 47 | out, _ = pipe(prompt, num_inference_steps=50, generator=gen, guidance_scale=args.scale) 48 | img = out.images[0] 49 | heat_map = tc.compute_global_heat_map() 50 | heat_map = heat_map.compute_word_heat_map(keyword) 51 | heat_map.plot_overlay(img, out_file=f"{keyword}_heatmap_{i}.png", word=None) 52 | img.save(f"{prompt}_{i}.png") 53 | 54 | 55 | if __name__ == '__main__': 56 | daam_visualize() 57 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import torch 5 | import numpy as np 6 | import os.path 7 | import pickle 8 | 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | from transformers import CLIPTokenizer 13 | 14 | from typing import Any, Callable, Optional, Tuple 15 | from torchvision.datasets.vision import VisionDataset 16 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 17 | 18 | PIL_INTERPOLATION = { 19 | "linear": Image.Resampling.BILINEAR, 20 | "bilinear": Image.Resampling.BILINEAR, 21 | "bicubic": Image.Resampling.BICUBIC, 22 | "lanczos": Image.Resampling.LANCZOS, 23 | "nearest": Image.Resampling.NEAREST, 24 | } 25 | 26 | imagenet_templates_small = [ 27 | "a photo of a {}", 28 | "a rendering of a {}", 29 | "a cropped photo of the {}", 30 | "the photo of a {}", 31 | "a photo of a clean {}", 32 | "a photo of a dirty {}", 33 | "a dark photo of the {}", 34 | "a photo of my {}", 35 | "a photo of the cool {}", 36 | "a close-up photo of a {}", 37 | "a bright photo of the {}", 38 | "a cropped photo of a {}", 39 | "a photo of the {}", 40 | "a good photo of the {}", 41 | "a photo of one {}", 42 | "a close-up photo of the {}", 43 | "a rendition of the {}", 44 | "a photo of the clean {}", 45 | "a rendition of a {}", 46 | "a photo of a nice {}", 47 | "a good photo of a {}", 48 | "a photo of the nice {}", 49 | "a photo of the small {}", 50 | "a photo of the weird {}", 51 | "a photo of the large {}", 52 | "a photo of a cool {}", 53 | "a photo of a small {}", 54 | ] 55 | 56 | imagenet_style_templates_small = [ 57 | "a painting in the style of {}", 58 | "a rendering in the style of {}", 59 | "a cropped painting in the style of {}", 60 | "the painting in the style of {}", 61 | "a clean painting in the style of {}", 62 | "a dirty painting in the style of {}", 63 | "a dark painting in the style of {}", 64 | "a picture in the style of {}", 65 | "a cool painting in the style of {}", 66 | "a close-up painting in the style of {}", 67 | "a bright painting in the style of {}", 68 | "a cropped painting in the style of {}", 69 | "a good painting in the style of {}", 70 | "a close-up painting in the style of {}", 71 | "a rendition in the style of {}", 72 | "a nice painting in the style of {}", 73 | "a small painting in the style of {}", 74 | "a weird painting in the style of {}", 75 | "a large painting in the style of {}", 76 | ] 77 | 78 | 79 | class ComposableDataset(Dataset): 80 | def __init__( 81 | self, 82 | data_root, 83 | tokenizer, 84 | size=512, 85 | repeats=100, 86 | interpolation="bicubic", 87 | 
flip_p=0.5, 88 | set="train", 89 | placeholder_tokens="", 90 | center_crop=False, 91 | num_images_per_class=-1, 92 | class_folder_names="", 93 | learnable_property="", 94 | ): 95 | self.data_root = [x.strip() for x in data_root.split(",")] 96 | self.class_folder_names = [x.strip() for x in class_folder_names.split(",")] 97 | 98 | self.tokenizer = tokenizer 99 | self.size = size 100 | self.placeholder_tokens = [x.strip() for x in placeholder_tokens.split(",")] 101 | self.placeholder_tokens_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens) 102 | 103 | self.center_crop = center_crop 104 | self.flip_p = flip_p 105 | 106 | # use textual inversion template - assume objects 107 | self.learnable_property = (x.strip() for x in learnable_property.split(",")) 108 | self.templates = [imagenet_templates_small if x == "object" else imagenet_style_templates_small 109 | for x in self.learnable_property] 110 | self.use_template = learnable_property != "" 111 | 112 | # combine all folders into a single folder 113 | self.image_paths, self.classes = [], [] 114 | total_images = max(len(self.placeholder_tokens) * num_images_per_class, 115 | len(self.class_folder_names) * num_images_per_class) 116 | 117 | images_per_folder = total_images // len(self.class_folder_names) 118 | 119 | for class_id, class_name in enumerate(self.class_folder_names): 120 | folder = os.path.join(self.data_root[class_id], class_name) 121 | folder_image_paths = [os.path.join(folder, file_name) for file_name in os.listdir(folder)] 122 | # reduce the size of images from each category if specified 123 | if num_images_per_class != -1: 124 | train_image_path = folder_image_paths[:images_per_folder] 125 | test_image_path = folder_image_paths[images_per_folder:2 * images_per_folder] 126 | if set == "train": 127 | folder_image_paths = train_image_path 128 | else: 129 | folder_image_paths = test_image_path 130 | 131 | self.image_paths.extend(folder_image_paths) 132 | self.classes.extend([class_id] * len(folder_image_paths)) 133 | 134 | # size is the total images from different folders 135 | self.num_images = len(self.image_paths) 136 | self._length = self.num_images 137 | 138 | print("placeholder_tokens: ", self.placeholder_tokens) 139 | print("placeholder_tokens_ids: ", self.placeholder_tokens_ids) 140 | print("the number of images in this dataset: ", self.num_images) 141 | print("the flag for using the template: ", self.use_template) 142 | 143 | if set == "train": 144 | self._length = self.num_images * repeats 145 | 146 | self.interpolation = { 147 | "linear": PIL_INTERPOLATION["linear"], 148 | "bilinear": PIL_INTERPOLATION["bilinear"], 149 | "bicubic": PIL_INTERPOLATION["bicubic"], 150 | "lanczos": PIL_INTERPOLATION["lanczos"], 151 | }[interpolation] 152 | 153 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 154 | 155 | def __len__(self): 156 | return self._length 157 | 158 | def __getitem__(self, i): 159 | idx = i % self.num_images 160 | example = dict() 161 | image = Image.open(self.image_paths[idx]) 162 | # image.save(f"{self.placeholder_tokens[self.classes[idx]]}_{i}.png") 163 | 164 | if not image.mode == "RGB": 165 | image = image.convert("RGB") 166 | 167 | if self.use_template: 168 | text = [random.choice(self.templates[self.classes[idx]]).format(x) for x in self.placeholder_tokens] 169 | else: 170 | text = self.placeholder_tokens # use token itself as the caption (unsupervised) 171 | 172 | # encode all classes since we will use all of them to compute composed score 173 | example["input_ids"] = 
self.tokenizer( 174 | text, 175 | padding="max_length", 176 | truncation=True, 177 | max_length=self.tokenizer.model_max_length, 178 | return_tensors="pt", 179 | ).input_ids 180 | 181 | # default to score-sde preprocessing 182 | img = np.array(image).astype(np.uint8) 183 | 184 | if self.center_crop: 185 | crop = min(img.shape[0], img.shape[1]) 186 | h, w, = ( 187 | img.shape[0], 188 | img.shape[1], 189 | ) 190 | img = img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2] 191 | 192 | image = Image.fromarray(img) 193 | image = image.resize((self.size, self.size), resample=self.interpolation) 194 | 195 | image = self.flip_transform(image) 196 | image = np.array(image).astype(np.uint8) 197 | image = (image / 127.5 - 1.0).astype(np.float32) 198 | 199 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 200 | example["gt_weight_id"] = idx 201 | example["image_path"] = self.image_paths[idx] 202 | example["image_index"] = idx 203 | return example 204 | 205 | 206 | class ClassificationDataset(Dataset): 207 | def __init__( 208 | self, 209 | image_path_map, 210 | learned_weights, 211 | image_path_to_class, 212 | tokenizer, 213 | encoder, 214 | size=512, 215 | repeats=100, 216 | interpolation="bicubic", 217 | flip_p=0.5, 218 | set="train", 219 | placeholder_tokens="", 220 | center_crop=False, 221 | learnable_property="", 222 | ): 223 | self.image_path_map = image_path_map 224 | self.learned_weights = learned_weights 225 | self.image_path_to_class = image_path_to_class 226 | 227 | self.tokenizer = tokenizer 228 | self.encoder = encoder 229 | 230 | self.size = size 231 | self.placeholder_tokens = [x.strip() for x in placeholder_tokens.split(",")] 232 | self.placeholder_tokens_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens) 233 | 234 | self.center_crop = center_crop 235 | self.flip_p = flip_p 236 | 237 | # use textual inversion template - assume objects 238 | self.learnable_property = (x.strip() for x in learnable_property.split(",")) 239 | self.templates = [imagenet_templates_small if x == "object" else imagenet_style_templates_small 240 | for x in self.learnable_property] 241 | self.use_template = learnable_property != "" 242 | 243 | # size is the total images from different folders 244 | self.num_images = len(self.image_path_map) 245 | self._length = self.num_images 246 | 247 | print("placeholder_tokens: ", self.placeholder_tokens) 248 | print("placeholder_tokens_ids: ", self.placeholder_tokens_ids) 249 | print("the number of images in this dataset: ", self.num_images) 250 | print("the flag for using the template: ", self.use_template) 251 | 252 | if set == "train": 253 | self._length = self.num_images * repeats 254 | 255 | self.interpolation = { 256 | "linear": PIL_INTERPOLATION["linear"], 257 | "bilinear": PIL_INTERPOLATION["bilinear"], 258 | "bicubic": PIL_INTERPOLATION["bicubic"], 259 | "lanczos": PIL_INTERPOLATION["lanczos"], 260 | }[interpolation] 261 | 262 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 263 | 264 | def __len__(self): 265 | return self._length 266 | 267 | def __getitem__(self, i): 268 | idx = i % self.num_images 269 | image = Image.open(self.image_path_map[idx]) 270 | 271 | if not image.mode == "RGB": 272 | image = image.convert("RGB") 273 | 274 | if self.use_template: 275 | text = [random.choice(self.templates[self.classes[idx]]).format(x) for x in self.placeholder_tokens] 276 | else: 277 | text = self.placeholder_tokens # use token itself as the caption (unsupervised) 278 | 279 | # default to score-sde 
preprocessing 280 | img = np.array(image).astype(np.uint8) 281 | 282 | if self.center_crop: 283 | crop = min(img.shape[0], img.shape[1]) 284 | h, w, = ( 285 | img.shape[0], 286 | img.shape[1], 287 | ) 288 | img = img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2] 289 | 290 | image = Image.fromarray(img) 291 | image = image.resize((self.size, self.size), resample=self.interpolation) 292 | 293 | image = self.flip_transform(image) 294 | image = np.array(image).astype(np.uint8) 295 | image = (image / 127.5 - 1.0).astype(np.float32) 296 | 297 | # encode all classes since we will use all of them to compute composed score 298 | input_ids = self.tokenizer( 299 | text, 300 | padding="max_length", 301 | truncation=True, 302 | max_length=self.tokenizer.model_max_length, 303 | return_tensors="pt", 304 | ).input_ids 305 | 306 | example = {} 307 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 308 | example["embeddings"] = self.encoder(input_ids)[0] 309 | example["weights"] = self.learned_weights[idx] 310 | example["class"] = self.image_path_to_class[idx] 311 | return example 312 | 313 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified based upon https://github.com/rinongal/textual_inversion/blob/main/evaluation/clip_eval.py 3 | """ 4 | import argparse 5 | import os 6 | 7 | import clip 8 | import torch 9 | from torchvision import transforms 10 | from resnet import resnet50 11 | from PIL import Image 12 | 13 | from typing import List 14 | from torchmetrics.functional import kl_divergence 15 | 16 | 17 | class CLIPEvalutor: 18 | def __init__(self, target_classes_names: List, device, clip_model='ViT-B/32'): 19 | self.device = device 20 | self.model, self.clip_preprocess = clip.load(clip_model, device=self.device) 21 | self.texts = [f"a photo of {class_name}" for class_name in target_classes_names] 22 | 23 | def tokenize(self, strings: list): 24 | return clip.tokenize(strings).to(self.device) 25 | 26 | @torch.no_grad() 27 | def encode_text(self, tokens: list) -> torch.Tensor: 28 | return self.model.encode_text(tokens) 29 | 30 | @torch.no_grad() 31 | def get_text_features(self, text, norm: bool = True) -> torch.Tensor: 32 | tokens = clip.tokenize(text).to(self.device) 33 | text_features = self.encode_text(tokens) 34 | if norm: 35 | text_features /= text_features.norm(dim=-1, keepdim=True) 36 | return text_features 37 | 38 | @torch.no_grad() 39 | def encode_images(self, image) -> torch.Tensor: 40 | images = self.clip_preprocess(image).to(self.device).unsqueeze(dim=0) 41 | return self.model.encode_image(images) 42 | 43 | def evaluate(self, img_folder, threshold=0.8): 44 | images_path = [os.path.join(img_folder, filename) for filename in os.listdir(img_folder)] 45 | with torch.no_grad(): 46 | image_features = torch.cat([self.encode_images(Image.open(path)) for path in images_path], dim=0) 47 | image_features /= image_features.norm(dim=-1, keepdim=True) 48 | text_features = self.get_text_features(self.texts, norm=True) 49 | similarity = image_features @ text_features.T 50 | 51 | max_vals, indices = torch.max(similarity, dim=1) 52 | # bincount only computes frequency for non-negative values 53 | counts = torch.bincount(indices, minlength=text_features.shape[0]) 54 | coverages = [] 55 | # when computing the coverage, we don't threshold the values 56 | for i, caption in enumerate(self.texts): 57 | # print(f"class: {caption} | " 58 | # 
f"coverage: {counts[i] / image_features.shape[0] * 100}%") 59 | coverages.append(counts[i] / image_features.shape[0]) 60 | # if the probability of predicted class is < threshold, count it as misclassified 61 | num_misclassified_examples = torch.sum(max_vals < threshold) 62 | # p = target distribution, q = modeled distribution 63 | p = torch.tensor([1 / len(coverages)] * len(coverages)) 64 | q = torch.tensor(coverages) 65 | acc = (sum(counts) - num_misclassified_examples) / image_features.shape[0] 66 | kl_entropy = kl_divergence(p[None], q[None]) 67 | print(f"{img_folder}, Avg Acc: {100 * acc.item():.2f}, KL entropy: {kl_entropy.item():.4f}") 68 | 69 | 70 | class ResNetEvaluator: 71 | def __init__(self, target_classes_names: List, device): 72 | # replace batch norm with identity function 73 | self.device = device 74 | self.model = resnet50(pretrained=True, progress=True).to(self.device) 75 | # disable batch norm 76 | self.model.eval() 77 | self.transform = transforms.Compose([ 78 | transforms.CenterCrop(224), 79 | transforms.ToTensor(), 80 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 81 | std=[0.229, 0.224, 0.225]), 82 | ]) 83 | 84 | with open('imagenet_classes.txt') as f: 85 | self.labels = [line.strip() for line in f.readlines()] 86 | 87 | # if given class names, logits will be selected based on the given class names for evaluations 88 | if target_classes_names: 89 | self.target_labels = [] 90 | self.target_label_names = [] 91 | for i, label_name in enumerate(self.labels): 92 | for target_class in target_classes_names: 93 | if target_class.lower() in label_name.lower(): 94 | self.target_labels.append(i) 95 | self.target_label_names.append(label_name) 96 | 97 | assert len(self.target_labels) == len(target_classes_names), \ 98 | "the number of found labels are not the same as the given ones" 99 | print(self.target_labels) 100 | print(self.target_label_names) 101 | self.target_labels = sorted(self.target_labels) 102 | else: 103 | self.target_labels = list(range(1000)) 104 | 105 | def evaluate(self, img_folder, threshold=0.8): 106 | images_path = [os.path.join(img_folder, filename) for filename in os.listdir(img_folder)] 107 | images = torch.stack([self.transform(Image.open(path)) for path in images_path], dim=0).to(self.device) 108 | # predictions 109 | with torch.no_grad(): 110 | pred = self.model(images)[:, self.target_labels] 111 | max_vals, indices = torch.max(pred, dim=1) 112 | # bincount only computes frequency for non-negative values 113 | counts = torch.bincount(indices, minlength=len(self.target_labels)) 114 | coverages = [] 115 | # when computing the coverage, we don't threshold the values 116 | for i, index in enumerate(self.target_labels): 117 | # print(f"class: {self.labels[index]} | " 118 | # f"coverage: {counts[i] / pred.shape[0] * 100}%") 119 | coverages.append((counts[i] + 1e-8) / pred.shape[0]) 120 | 121 | # if the probability of predicted class is < threshold, count it as misclassified 122 | num_misclassified_examples = torch.sum(max_vals < threshold) 123 | # p = target distribution, q = modeled distribution 124 | p = torch.tensor([1 / len(coverages)] * len(coverages)) 125 | q = torch.tensor(coverages) 126 | acc = (sum(counts) - num_misclassified_examples) / pred.shape[0] 127 | kl_entropy = kl_divergence(p[None], q[None]) 128 | print(f"{img_folder}, Avg Acc: {100 * acc.item():.2f}, KL entropy: {kl_entropy.item():.4f}") 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument("--model_path", type=str) 134 | 
parser.add_argument("--seed", type=int, default=0) 135 | parser.add_argument("--scale", type=float, default=7.5) 136 | parser.add_argument("--steps", type=int, default=50) 137 | parser.add_argument("--samples", type=int, default=64) 138 | parser.add_argument("--model", type=str, choices=["textual_inversion", "ours"]) 139 | parser.add_argument("--evaluation_metric", choices=["clip", 'resnet']) 140 | parser.add_argument("--class_names", type=str, nargs="+") 141 | parser.add_argument("--logit_threshold", type=float) 142 | args = parser.parse_args() 143 | 144 | image_folder = os.path.join(args.model_path, "samples") 145 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 146 | if args.evaluation_metric == "clip": 147 | # build a ground truth captions for clip score 148 | evaluator = CLIPEvalutor(args.class_names, device=device, clip_model='ViT-B/32') 149 | # create folder where generated images are saved 150 | save_img_folder = os.path.join(args.model_path, "samples") 151 | evaluator.evaluate(image_folder, threshold=args.logit_threshold) 152 | else: 153 | evaluator = ResNetEvaluator(args.class_names, device=device) 154 | evaluator.evaluate(image_folder, threshold=args.logit_threshold) 155 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | 5 | from diffusers import StableDiffusionPipeline, DDIMScheduler 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--model_path", type=str) 11 | parser.add_argument("--prompts", type=str, nargs="+") 12 | parser.add_argument("--seed", type=int, default=0) 13 | parser.add_argument("--weights", type=str, default="7.5") 14 | parser.add_argument("--num_images", type=int, default=1) 15 | parser.add_argument("--bsz", type=int, default=1) 16 | parser.add_argument("--scale", type=float, default=7.5) 17 | parser.add_argument("--folder_name", type=str, default="samples") 18 | args = parser.parse_args() 19 | 20 | model_id = args.model_path 21 | pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda") 22 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 23 | pipe.safety_checker = None 24 | 25 | folder = os.path.join(args.model_path, args.folder_name) 26 | os.makedirs(folder, exist_ok=True) 27 | 28 | generator = torch.Generator("cuda").manual_seed(args.seed) 29 | prompts = args.prompts 30 | weights = args.weights 31 | 32 | if prompts: 33 | batch_size = args.bsz 34 | num_batches = args.num_images // batch_size 35 | for prompt in prompts: 36 | for i in range(num_batches): 37 | image_list = pipe(prompt, num_inference_steps=50, guidance_scale=args.scale, 38 | generator=generator, num_images_per_prompt=batch_size) 39 | images = image_list.images 40 | for j, img in enumerate(images): 41 | img.save(os.path.join(folder, f"{prompt}_{weights}_{args.seed}_{i * batch_size + j}.png")) 42 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import math 4 | import os 5 | import random 6 | import json 7 | 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint 16 | from 
torch.utils.data import Dataset 17 | 18 | import PIL 19 | from accelerate import Accelerator 20 | from accelerate.logging import get_logger 21 | from accelerate.utils import set_seed 22 | from diffusers.optimization import get_scheduler 23 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 24 | from diffusers.utils.import_utils import is_xformers_available 25 | from huggingface_hub import HfFolder, Repository, whoami 26 | 27 | from diffusers import ( 28 | AutoencoderKL, 29 | DDPMScheduler, 30 | DiffusionPipeline, 31 | DDIMScheduler, 32 | StableDiffusionPipeline, 33 | UNet2DConditionModel, 34 | ) 35 | 36 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released 37 | from packaging import version 38 | from PIL import Image 39 | from torchvision import transforms 40 | from tqdm.auto import tqdm 41 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 42 | 43 | from datasets import ComposableDataset 44 | 45 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 46 | PIL_INTERPOLATION = { 47 | "linear": PIL.Image.Resampling.BILINEAR, 48 | "bilinear": PIL.Image.Resampling.BILINEAR, 49 | "bicubic": PIL.Image.Resampling.BICUBIC, 50 | "lanczos": PIL.Image.Resampling.LANCZOS, 51 | "nearest": PIL.Image.Resampling.NEAREST, 52 | } 53 | else: 54 | PIL_INTERPOLATION = { 55 | "linear": PIL.Image.LINEAR, 56 | "bilinear": PIL.Image.BILINEAR, 57 | "bicubic": PIL.Image.BICUBIC, 58 | "lanczos": PIL.Image.LANCZOS, 59 | "nearest": PIL.Image.NEAREST, 60 | } 61 | # ------------------------------------------------------------------------------ 62 | 63 | 64 | logger = get_logger(__name__) 65 | 66 | 67 | def save_progress(text_encoder, placeholder_token_id, accelerator, args): 68 | logger.info("Saving embeddings") 69 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] 70 | learned_embeds_dict = {args.placeholder_tokens: learned_embeds.detach().cpu()} 71 | if args.test: 72 | embed_path = os.path.join(args.output_dir, "test_learned_embeds.bin") 73 | else: 74 | embed_path = os.path.join(args.output_dir, "learned_embeds.bin") 75 | torch.save(learned_embeds_dict, embed_path) 76 | 77 | 78 | def save_weights(weights, args): 79 | logger.info("Saving embeddings") 80 | learned_weights_dict = {"weights": weights.detach().cpu()} 81 | if args.test: 82 | weight_path = os.path.join(args.output_dir, "test_weights.bin") 83 | else: 84 | weight_path = os.path.join(args.output_dir, "weights.bin") 85 | torch.save(learned_weights_dict, weight_path) 86 | 87 | 88 | def parse_args(): 89 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 90 | parser.add_argument( 91 | "--save_steps", 92 | type=int, 93 | default=500, 94 | help="Save learned_embeds.bin every X updates steps.", 95 | ) 96 | parser.add_argument( 97 | "--pretrained_model_name_or_path", 98 | type=str, 99 | default=None, 100 | required=True, 101 | help="Path to pretrained model or model identifier from huggingface.co/models.", 102 | ) 103 | parser.add_argument( 104 | "--tokenizer_name", 105 | type=str, 106 | default=None, 107 | help="Pretrained tokenizer name or path if not the same as model_name", 108 | ) 109 | parser.add_argument( 110 | "--train_data_dir", type=str, required=True, 111 | help="A list of folders containing the training data for each token provided." 
112 | ) 113 | parser.add_argument( 114 | "--placeholder_tokens", 115 | type=str, 116 | required=True, 117 | help="A list of tokens to use as placeholders for all the concepts, separated by comma", 118 | ) 119 | parser.add_argument( 120 | "--initializer_tokens", type=str, default="", 121 | help="A list of tokens to use as initializer words, separated by comma" 122 | ) 123 | parser.add_argument("--learnable_property", type=str, default="", 124 | help="a list of properties for all the tokens needed to be learned, separated by comma") 125 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") 126 | parser.add_argument( 127 | "--output_dir", 128 | type=str, 129 | default="checkpoints", 130 | help="The output directory where the model predictions and checkpoints will be written.", 131 | ) 132 | parser.add_argument( 133 | "--resume_dir", 134 | type=str, 135 | default="", 136 | help="The output directory where the model predictions and checkpoints will be written.", 137 | ) 138 | parser.add_argument("--softmax_weights", action="store_true", default=False) 139 | parser.add_argument("--reuse_weights", action="store_true", default=False) 140 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 141 | parser.add_argument( 142 | "--resolution", 143 | type=int, 144 | default=512, 145 | help=( 146 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 147 | " resolution" 148 | ), 149 | ) 150 | parser.add_argument( 151 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 152 | ) 153 | parser.add_argument( 154 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 155 | ) 156 | parser.add_argument("--num_train_epochs", type=int, default=100) 157 | parser.add_argument( 158 | "--max_train_steps", 159 | type=int, 160 | default=None, 161 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 162 | ) 163 | parser.add_argument( 164 | "--gradient_accumulation_steps", 165 | type=int, 166 | default=1, 167 | help="Number of updates steps to accumulate before performing a backward/update pass.", 168 | ) 169 | parser.add_argument( 170 | "--gradient_checkpointing", 171 | action="store_true", 172 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 173 | ) 174 | parser.add_argument( 175 | "--learning_rate", 176 | type=float, 177 | default=1e-4, 178 | help="Initial learning rate (after the potential warmup period) to use.", 179 | ) 180 | parser.add_argument( 181 | "--scale_lr", 182 | action="store_true", 183 | default=True, 184 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 185 | ) 186 | parser.add_argument( 187 | "--lr_scheduler", 188 | type=str, 189 | default="constant", 190 | help=( 191 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 192 | ' "constant", "constant_with_warmup"]' 193 | ), 194 | ) 195 | parser.add_argument( 196 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
197 | ) 198 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 199 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 200 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 201 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 202 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 203 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 204 | parser.add_argument( 205 | "--hub_model_id", 206 | type=str, 207 | default=None, 208 | help="The name of the repository to keep in sync with the local `output_dir`.", 209 | ) 210 | parser.add_argument( 211 | "--logging_dir", 212 | type=str, 213 | default="logs", 214 | help=( 215 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 216 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 217 | ), 218 | ) 219 | parser.add_argument( 220 | "--mixed_precision", 221 | type=str, 222 | default="no", 223 | choices=["no", "fp16", "bf16"], 224 | help=( 225 | "Whether to use mixed precision. Choose" 226 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 227 | "and an Nvidia Ampere GPU." 228 | ), 229 | ) 230 | parser.add_argument( 231 | "--allow_tf32", 232 | action="store_true", 233 | help=( 234 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 235 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 236 | ), 237 | ) 238 | parser.add_argument( 239 | "--report_to", 240 | type=str, 241 | default="tensorboard", 242 | help=( 243 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 244 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 245 | ), 246 | ) 247 | parser.add_argument( 248 | "--resume_from_checkpoint", 249 | type=str, 250 | default=None, 251 | help=( 252 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 253 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 254 | ), 255 | ) 256 | parser.add_argument( 257 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 258 | ) 259 | parser.add_argument( 260 | "--use_composed_score", action="store_true", default=False, 261 | help="whether to use composed score for textual inversion." 262 | ) 263 | parser.add_argument( 264 | "--use_orthogonal_loss", action="store_true", default=False, 265 | help="should be enabled to get a better performance when using composed scores to invert text." 266 | ) 267 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 268 | parser.add_argument( 269 | "--checkpointing_steps", 270 | type=int, 271 | default=500, 272 | help=( 273 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 274 | " training using `--resume_from_checkpoint`." 
275 | ), 276 | ) 277 | 278 | parser.add_argument("--data", type=str, default="imagenet") 279 | parser.add_argument("--class_folder_names", type=str, 280 | help="a list of imagenet data folders for each class, seperate by comma") 281 | 282 | parser.add_argument("--add_weight_per_score", action="store_true", default=False) 283 | parser.add_argument("--freeze_weights", action="store_true", default=False) 284 | parser.add_argument("--init_weight", type=float, default=1) 285 | parser.add_argument("--use_conj_score", action="store_true", default=False) 286 | 287 | parser.add_argument("--orthogonal_coeff", type=float, default=0.1) 288 | parser.add_argument("--squared_orthogonal_loss", action="store_true", default=False) 289 | parser.add_argument("--mse_coeff", type=float, default=1) 290 | parser.add_argument("--num_images_per_class", type=int, default=-1, help="-1 means all images considered") 291 | 292 | parser.add_argument("--weighted_sampling", action="store_true", default=False) 293 | parser.add_argument("--flip_weights", action="store_true", default=False) 294 | 295 | parser.add_argument("--text_loss", action="store_true", default=False) 296 | parser.add_argument("--text_angle_loss", action="store_true", default=False) 297 | parser.add_argument("--text_repulsion_loss", action="store_true", default=False) 298 | parser.add_argument("--text_repulsion_similarity_loss", action="store_true", default=False) 299 | parser.add_argument("--text_repulsion_coeff", type=float, default=0) 300 | 301 | parser.add_argument("--euclidean_dist_loss", action="store_true", default=False) 302 | parser.add_argument("--euclidean_dist_coeff", type=float, default=0) 303 | 304 | parser.add_argument("--use_similarity", action="store_true", default=False, 305 | help="Dot product between scores as the orthogonal loss") 306 | parser.add_argument("--use_euclidean_mhe", action="store_true", default=False, 307 | help="Minimum Hyperspherical Energy as the orthogonal loss") 308 | parser.add_argument("--log_mhe", action="store_true", default=False) 309 | parser.add_argument("--use_acos_mhe", action="store_true", default=False) 310 | parser.add_argument("--normalize_score", action="store_true", default=False) 311 | parser.add_argument("--use_weighted_score", action="store_true", default=False) 312 | 313 | parser.add_argument("--use_l2_norm_regularization", action="store_true", default=False) 314 | parser.add_argument("--l2_norm_coeff", type=float, default=0) 315 | 316 | parser.add_argument("--normalize_word", action="store_true", default=False) 317 | parser.add_argument("--num_iters_per_image", type=int, default=50) 318 | parser.add_argument("--hsic_loss", action="store_true", default=False) 319 | parser.add_argument("--test", action="store_true", default=False, 320 | help="enable this by only optimizing weights using existing models.") 321 | 322 | parser.add_argument( 323 | "--validation_step", 324 | type=int, 325 | default=100, 326 | help=( 327 | "Run validation every X steps. Validation consists of running the prompt" 328 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 329 | " and logging the images." 
330 | ), 331 | ) 332 | 333 | parser.add_argument( 334 | "--revision", 335 | type=str, 336 | default=None, 337 | required=False, 338 | help="Revision of pretrained model identifier from huggingface.co/models.", 339 | ) 340 | 341 | 342 | args = parser.parse_args() 343 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 344 | if env_local_rank != -1 and env_local_rank != args.local_rank: 345 | args.local_rank = env_local_rank 346 | 347 | if args.train_data_dir is None: 348 | raise ValueError("You must specify a train data directory.") 349 | 350 | return args 351 | 352 | 353 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 354 | if token is None: 355 | token = HfFolder.get_token() 356 | if organization is None: 357 | username = whoami(token)["name"] 358 | return f"{username}/{model_id}" 359 | else: 360 | return f"{organization}/{model_id}" 361 | 362 | 363 | def freeze_params(params): 364 | for param in params: 365 | param.requires_grad = False 366 | 367 | 368 | def main(): 369 | args = parse_args() 370 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 371 | 372 | accelerator = Accelerator( 373 | gradient_accumulation_steps=args.gradient_accumulation_steps, 374 | mixed_precision=args.mixed_precision, 375 | log_with="tensorboard", 376 | logging_dir=logging_dir, 377 | ) 378 | 379 | # If passed along, set the training seed now. 380 | if args.seed is not None: 381 | set_seed(args.seed) 382 | 383 | # Handle the repository creation 384 | if accelerator.is_main_process: 385 | if args.push_to_hub: 386 | if args.hub_model_id is None: 387 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 388 | else: 389 | repo_name = args.hub_model_id 390 | repo = Repository(args.output_dir, clone_from=repo_name) 391 | 392 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 393 | if "step_*" not in gitignore: 394 | gitignore.write("step_*\n") 395 | if "epoch_*" not in gitignore: 396 | gitignore.write("epoch_*\n") 397 | elif args.output_dir is not None: 398 | os.makedirs(args.output_dir, exist_ok=True) 399 | 400 | if args.resume_from_checkpoint: 401 | args.pretrained_model_name_or_path = args.resume_dir 402 | print(f"resume everything from {args.pretrained_model_name_or_path}") 403 | 404 | # Load the tokenizer and add the placeholder token as a additional special token 405 | if args.tokenizer_name: 406 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 407 | elif args.pretrained_model_name_or_path: 408 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 409 | 410 | # Add the placeholder token in tokenizer 411 | placeholder_tokens = [x.strip() for x in args.placeholder_tokens.split(",")] 412 | num_added_tokens = tokenizer.add_tokens(placeholder_tokens) 413 | 414 | if num_added_tokens != 0 and num_added_tokens != len(placeholder_tokens): 415 | raise ValueError( 416 | f"The tokenizer already contains at least one of the tokens in {placeholder_tokens}. " 417 | f"Please pass a different placeholder_token` that is not already in the tokenizer." 
418 | ) 419 | 420 | # Convert the initializer_token, placeholder_token to ids 421 | if args.initializer_tokens != "": 422 | initializer_tokens = [x.strip() for x in args.initializer_tokens.split(",")] 423 | else: 424 | initializer_tokens = [] 425 | 426 | if len(initializer_tokens) == 0: 427 | if args.resume_from_checkpoint: 428 | logger.info("* Resume the embeddings of placeholder tokens *") 429 | print("* Resume the embeddings of placeholder tokens *") 430 | else: 431 | logger.info("* Initialize the newly added placeholder token with the random embeddings *") 432 | print("* Initialize the newly added placeholder token with the random embeddings *") 433 | token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) 434 | else: 435 | logger.info("* Initialize the newly added placeholder token with the embeddings of the initializer token *") 436 | print("* Initialize the newly added placeholder token with the embeddings of the initializer token *") 437 | token_ids = tokenizer.encode(initializer_tokens, add_special_tokens=False) 438 | # Check if initializer_token is a single token or a sequence of tokens 439 | if len(token_ids) > len(initializer_tokens): 440 | raise ValueError("The initializer token must be a single token.") 441 | 442 | initializer_token_ids = token_ids 443 | placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) 444 | 445 | # Load models and create wrapper for stable diffusion 446 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 447 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 448 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 449 | 450 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 451 | text_encoder.resize_token_embeddings(len(tokenizer)) 452 | 453 | # Initialise the newly added placeholder token with the embeddings of the initializer token 454 | token_embeds = text_encoder.get_input_embeddings().weight.data 455 | token_embeds[placeholder_token_ids] = token_embeds[initializer_token_ids] 456 | if args.normalize_word: 457 | token_embeds[placeholder_token_ids] = F.normalize(token_embeds[placeholder_token_ids], dim=1, p=2) 458 | 459 | # Freeze vae and unet 460 | freeze_params(vae.parameters()) 461 | freeze_params(unet.parameters()) 462 | # Freeze all parameters except for the token embeddings in text encoder 463 | params_to_freeze = itertools.chain( 464 | text_encoder.text_model.encoder.parameters(), 465 | text_encoder.text_model.final_layer_norm.parameters(), 466 | text_encoder.text_model.embeddings.position_embedding.parameters(), 467 | text_encoder.get_input_embeddings().parameters() if args.test else [] 468 | ) 469 | freeze_params(params_to_freeze) 470 | 471 | if args.gradient_checkpointing: 472 | # Keep unet in train mode if we are using gradient checkpointing to save memory. 473 | # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. 474 | unet.train() 475 | text_encoder.gradient_checkpointing_enable() 476 | unet.enable_gradient_checkpointing() 477 | 478 | if args.enable_xformers_memory_efficient_attention: 479 | if is_xformers_available(): 480 | unet.enable_xformers_memory_efficient_attention() 481 | else: 482 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 483 | 484 | # Enable TF32 for faster training on Ampere GPUs, 485 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 486 | if args.allow_tf32: 487 | torch.backends.cuda.matmul.allow_tf32 = True 488 | 489 | if args.scale_lr: 490 | args.learning_rate = ( 491 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 492 | ) 493 | 494 | train_dataset = ComposableDataset( 495 | data_root=args.train_data_dir, 496 | tokenizer=tokenizer, 497 | size=args.resolution, 498 | repeats=args.repeats, 499 | center_crop=args.center_crop, 500 | placeholder_tokens=args.placeholder_tokens, 501 | num_images_per_class=args.num_images_per_class, 502 | class_folder_names=args.class_folder_names, 503 | learnable_property=args.learnable_property, 504 | set="train" if not args.test else "val", 505 | ) 506 | 507 | if args.add_weight_per_score: 508 | # Add a learnable weight for each token 509 | if args.resume_from_checkpoint and args.reuse_weights: 510 | weight_path = os.path.join(args.resume_dir, "weights.bin") 511 | concept_weights = torch.load(weight_path)["weights"] 512 | concept_weights.requires_grad = not args.freeze_weights 513 | concept_weights = torch.nn.Parameter(concept_weights, requires_grad=not args.freeze_weights) 514 | print('reusing the weights...') 515 | else: 516 | num_tokens = len(placeholder_token_ids) 517 | # create weight matrix NxMx1x1x1 where D is the number of images and M is the number of classes 518 | concept_weights = torch.tensor([args.init_weight] * num_tokens).reshape(1, -1, 1, 1, 1).float() 519 | if args.softmax_weights: 520 | concept_weights = F.softmax(concept_weights, dim=1) 521 | concept_weights = concept_weights.repeat(train_dataset.num_images, 1, 1, 1, 1) 522 | concept_weights.requires_grad = not args.freeze_weights 523 | concept_weights = torch.nn.Parameter(concept_weights, requires_grad=not args.freeze_weights) 524 | 525 | # Initialize the optimizer 526 | optimizer = torch.optim.AdamW( 527 | itertools.chain( 528 | text_encoder.get_input_embeddings().parameters() if not args.test else [], 529 | [concept_weights] if args.add_weight_per_score else [] 530 | ), 531 | lr=args.learning_rate, 532 | betas=(args.adam_beta1, args.adam_beta2), 533 | weight_decay=args.adam_weight_decay, 534 | eps=args.adam_epsilon, 535 | ) 536 | 537 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) 538 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 539 | 540 | # Scheduler and math around the number of training steps. 
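# Note (added explanation): the --max_train_steps value derived just below (user-provided, or num_train_epochs * ceil(len(train_dataloader) / gradient_accumulation_steps)) is what the LR scheduler is constructed with; the loop horizon is recomputed again after accelerator.prepare(), so the scheduler horizon and the actual number of optimization steps need not coincide.
# Worked example (illustrative assumption, not from the repo: 25 training images, --train_batch_size 2, --gradient_accumulation_steps 8): len(train_dataloader) = ceil(25 / 2) = 13 batches, so num_update_steps_per_epoch = ceil(13 / 8) = 2 optimizer updates per epoch.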
541 | overrode_max_train_steps = False 542 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 543 | if args.max_train_steps is None: 544 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 545 | overrode_max_train_steps = True 546 | 547 | lr_scheduler = get_scheduler( 548 | args.lr_scheduler, 549 | optimizer=optimizer, 550 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 551 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 552 | ) 553 | 554 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 555 | text_encoder, optimizer, train_dataloader, lr_scheduler 556 | ) 557 | 558 | weight_dtype = torch.float32 559 | if accelerator.mixed_precision == "fp16": 560 | weight_dtype = torch.float16 561 | elif accelerator.mixed_precision == "bf16": 562 | weight_dtype = torch.bfloat16 563 | 564 | # Move vae and unet to device 565 | vae.to(accelerator.device, dtype=weight_dtype) 566 | unet.to(accelerator.device, dtype=weight_dtype) 567 | 568 | args.max_train_steps = train_dataset.num_images * args.num_iters_per_image 569 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 570 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 571 | if overrode_max_train_steps: 572 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 573 | # Afterwards we recalculate our number of training epochs 574 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 575 | 576 | # We need to initialize the trackers we use, and also store our configuration. 577 | # The trackers initializes automatically on the main process. 578 | if accelerator.is_main_process: 579 | accelerator.init_trackers("checkpoints", config=vars(args)) 580 | 581 | # Train! 582 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 583 | print(f'total_batch_size: {total_batch_size}') 584 | 585 | logger.info("***** Running training *****") 586 | logger.info(f" Num examples = {len(train_dataset)}") 587 | logger.info(f" Num Epochs = {args.num_train_epochs}") 588 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 589 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 590 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 591 | logger.info(f" Total optimization steps = {args.max_train_steps}") 592 | global_step = 0 593 | first_epoch = 0 594 | 595 | # Potentially load in the weights and states from a previous save 596 | if args.resume_from_checkpoint: 597 | if args.resume_from_checkpoint != "latest": 598 | path = os.path.basename(args.resume_from_checkpoint) 599 | else: 600 | # Get the most recent checkpoint 601 | dirs = os.listdir(args.resume_dir) 602 | dirs = [d for d in dirs if d.startswith("checkpoint")] 603 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 604 | path = dirs[-1] if len(dirs) > 0 else None 605 | 606 | if path is None: 607 | accelerator.print( 608 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
609 | ) 610 | args.resume_from_checkpoint = None 611 | args.max_train_steps = train_dataset.num_images * args.num_iters_per_image 612 | else: 613 | if not args.test: 614 | accelerator.print(f"Resuming from checkpoint {path}") 615 | accelerator.load_state(os.path.join(args.resume_dir, path)) 616 | 617 | global_step = int(path.split("-")[1]) 618 | 619 | resume_global_step = global_step * args.gradient_accumulation_steps 620 | first_epoch = global_step // num_update_steps_per_epoch 621 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 622 | # update the number of iterations 623 | args.max_train_steps = global_step + train_dataset.num_images * args.num_iters_per_image 624 | # Afterwards we recalculate our number of training epochs 625 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 626 | 627 | # Only show the progress bar once on each machine. 628 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 629 | progress_bar.set_description("Steps") 630 | 631 | # keep original embeddings as reference 632 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() 633 | 634 | # iterate through the data and save dataset info 635 | dataset_info = {} 636 | for step, batch in enumerate(train_dataloader): 637 | image_path = batch["image_path"] 638 | image_idx = batch["image_index"] 639 | for i in range(len(image_path)): 640 | dataset_info[image_idx[i].item()] = image_path[i] 641 | 642 | if args.test: 643 | path = os.path.join(args.output_dir, "test_dataset_info.json") 644 | else: 645 | path = os.path.join(args.output_dir, "dataset_info.json") 646 | 647 | with open(path, "w") as f: 648 | json.dump(dataset_info, f) 649 | 650 | for epoch in range(first_epoch, args.num_train_epochs): 651 | text_encoder.train() 652 | for step, batch in enumerate(train_dataloader): 653 | # # Skip steps until we reach the resumed step 654 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 655 | if step % args.gradient_accumulation_steps == 0: 656 | progress_bar.update(1) 657 | continue 658 | 659 | mse_loss, orthogonal_loss, repulsion_loss, word_norm_loss, euclidean_dist_loss = 0, 0, 0, 0, 0 660 | with accelerator.accumulate(text_encoder): 661 | # image shape: Bx3xHxW 662 | # input_ids shape: BxMxD where M is the number of classes, D is the text dims 663 | pixel_value, input_ids = batch["pixel_values"], batch["input_ids"] 664 | weight_id = batch["gt_weight_id"] 665 | # split input ids into a list of BxD 666 | input_ids_list = [y.squeeze(dim=1) for y in input_ids.chunk(chunks=input_ids.shape[1], dim=1)] 667 | 668 | if args.use_composed_score: 669 | noise, uncond_noise_pred, noise_preds = None, None, [] 670 | for input_ids in input_ids_list: 671 | # Convert images to latent space 672 | latents = vae.encode(pixel_value).latent_dist.sample().detach() 673 | latents = latents * 0.18215 674 | 675 | # Sample noise that we'll add to the latents 676 | if noise is None: 677 | noise = torch.randn(latents.shape).to(latents.device) 678 | bsz = latents.shape[0] 679 | 680 | # Sample a random timestep for each image 681 | if args.weighted_sampling: 682 | weights = torch.arange(1, noise_scheduler.config.num_train_timesteps + 1).float() 683 | if args.flip_weights: 684 | weights = weights.flip(dims=(0,)) 685 | timesteps = torch.multinomial(weights, bsz).to(latents.device) 686 | else: 687 | timesteps = torch.randint( 688 | 0, 
noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device 689 | ).long() 690 | 691 | # Add noise to the latents according to the noise magnitude at each timestep 692 | # (this is the forward diffusion process) 693 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 694 | 695 | # Get the text embedding for conditioning 696 | encoder_hidden_states = text_encoder(input_ids)[0] 697 | # Predict the noise residual 698 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 699 | noise_preds.append(noise_pred) 700 | 701 | if uncond_noise_pred is None and args.use_conj_score: 702 | # precompute the unconditional text hidden states 703 | uncond_text_ids = tokenizer( 704 | "", 705 | padding="max_length", 706 | truncation=True, 707 | max_length=tokenizer.model_max_length, 708 | return_tensors="pt", 709 | ).input_ids.to(latents.device) 710 | B = noisy_latents.shape[0] 711 | uncond_encoder_hidden_states = text_encoder(uncond_text_ids)[0].repeat(B, 1, 1) 712 | uncond_noise_pred = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample 713 | 714 | noise_preds_stack = torch.stack(noise_preds, dim=1) # BxMx4x64x64 715 | elif args.use_conj_score: 716 | # latents 717 | latents = vae.encode(pixel_value).latent_dist.sample().detach() 718 | latents = latents * 0.18215 719 | bsz = latents.shape[0] 720 | noise = torch.randn_like(latents) 721 | 722 | timesteps = torch.randint( 723 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device 724 | ).long() 725 | # Add noise to the latents according to the noise magnitude at each timestep 726 | # (this is the forward diffusion process) 727 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 728 | 729 | weights = concept_weights[weight_id] 730 | 731 | cond_scores = [] 732 | for input_ids in input_ids_list: 733 | encoder_hidden_state = text_encoder(input_ids)[0].to(dtype=weight_dtype) 734 | cond_scores.append(unet(noisy_latents, timesteps, encoder_hidden_state).sample) 735 | cond_scores = torch.stack(cond_scores, dim=1) 736 | uncond_text_ids = tokenizer( 737 | "", 738 | padding="max_length", 739 | truncation=True, 740 | max_length=tokenizer.model_max_length, 741 | return_tensors="pt", 742 | ).input_ids.to(latents.device) 743 | uncond_encoder_hidden_states = text_encoder(uncond_text_ids)[0].repeat(bsz, 1, 1) 744 | uncond_score = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample 745 | 746 | # compute initial compositional score 747 | composed_score = uncond_score + torch.sum(weights.to(latents.device) * (cond_scores - uncond_score[:, None]), dim=1) 748 | # encoder_hidden_states = torch.stack(encoder_hidden_states_list, dim=1) 749 | # gt_encoder_states = (encoder_hidden_states * weights.to(latents.device)).sum(dim=1) 750 | mse_loss = args.mse_coeff * F.mse_loss(noise, composed_score.float(), reduction="mean") 751 | 752 | # orthogonal loss 753 | if args.use_orthogonal_loss: 754 | if args.use_similarity: 755 | B, M, C, H, W = cond_scores.shape 756 | ortho_scores_view = cond_scores.view(B, M, -1) 757 | prod_matrix = torch.bmm(ortho_scores_view, ortho_scores_view.transpose(2, 1)) / (C * H * W) 758 | # only compute the upper triangular matrices (exclude the diagonal) 759 | r, c = torch.triu_indices(M, M, offset=1) 760 | orthogonal_loss = args.orthogonal_coeff * (prod_matrix[:, r, c] ** 2).sum().sqrt() 761 | elif args.use_euclidean_mhe: 762 | B, M, C, H, W = cond_scores.shape 763 | ortho_scores_view = cond_scores.view(B, M, -1) 764 | batch_pair_wise_l2_dist 
= torch.cdist(ortho_scores_view, ortho_scores_view, p=2.0) 765 | # only compute the upper triangular matrices (exclude the diagonal) 766 | energy_matrix = torch.triu(batch_pair_wise_l2_dist, diagonal=1) 767 | energy_matrix = energy_matrix[energy_matrix != 0] 768 | if args.log_mhe: 769 | orthogonal_loss = torch.log(1 / energy_matrix).mean() 770 | else: 771 | orthogonal_loss = (1 / energy_matrix).mean() 772 | orthogonal_loss *= args.orthogonal_coeff 773 | 774 | if args.text_loss: 775 | word_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 776 | placeholder_token_ids] 777 | if args.text_repulsion_loss: 778 | if args.use_l2_norm_regularization: 779 | # word embeds are unnormalized. 780 | word_norm_loss = args.l2_norm_coeff * torch.norm(word_embeds, dim=1).mean(dim=0) 781 | if args.normalize_score: 782 | word_embeds = F.normalize(word_embeds, p=2, dim=1) 783 | word_dist_matrix = F.pdist(word_embeds, p=2) 784 | repulsion_loss = args.text_repulsion_coeff * torch.log(1 / word_dist_matrix).mean() 785 | elif args.text_repulsion_similarity_loss: 786 | num_words = word_embeds.shape[0] 787 | similarity = word_embeds @ word_embeds.T 788 | similarity = similarity[torch.triu_indices(num_words, num_words, offset=1).unbind()] ** 2. 789 | repulsion_loss = args.text_repulsion_coeff * similarity.sum().sqrt() 790 | 791 | if args.use_composed_score: 792 | # extract the corresponding weights for the batch of images 793 | if args.add_weight_per_score: 794 | weights = concept_weights[weight_id] 795 | if args.softmax_weights: 796 | weights = F.softmax(concept_weights[weight_id], dim=1) 797 | weighted_scores = noise_preds_stack * weights.to(latents.device) 798 | else: 799 | weighted_scores = noise_preds_stack / noise_preds_stack.shape[1] 800 | 801 | if args.use_conj_score: 802 | uncond_noise_pred = uncond_noise_pred[:, None] 803 | score = uncond_noise_pred + weights.to(latents.device) * (noise_preds_stack - uncond_noise_pred) 804 | composed_score = score.sum(dim=1) 805 | else: 806 | composed_score = weighted_scores.sum(dim=1) 807 | 808 | # TODO: MSE between classifier free score and noise doesn't make sense?? 
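# Added explanation of the composition trained in this branch: with --use_conj_score each concept contributes
#   uncond_pred + w_i * (concept_pred_i - uncond_pred)
# and the contributions are summed over concepts; without it, the composed score is simply the weighted sum of the per-concept predictions (with implicit weights 1/M when no per-score weights are learned). The MSE below regresses this composed prediction onto the Gaussian noise added to the latent, which is what drives the placeholder embeddings (and per-image weights) to decompose each image into concepts.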
809 | mse_loss = args.mse_coeff * F.mse_loss(composed_score, noise, reduction="mean") 810 | # compute sum of pair wise dot product as the orthogonal loss 811 | 812 | if args.use_orthogonal_loss: 813 | if args.use_weighted_score: 814 | ortho_scores = weighted_scores 815 | else: 816 | ortho_scores = noise_preds_stack 817 | # assume number of classes: B > 1 818 | B, M, C, H, W = ortho_scores.shape 819 | ortho_scores_view = ortho_scores.view(B, M, -1) 820 | if args.normalize_score: 821 | ortho_scores_view = F.normalize(ortho_scores_view, p=2, dim=2) 822 | 823 | if args.use_similarity: 824 | prod_matrix = torch.bmm(ortho_scores_view, ortho_scores_view.transpose(2, 1)) 825 | # only compute the upper triangular matrices (exclude the diagonal) 826 | num_pairs = math.factorial(M) / (math.factorial(2) * math.factorial(M - 2)) 827 | ortho_matrix = (torch.triu(prod_matrix, diagonal=1) / (B * C * H * W)) 828 | orthogonal_loss = args.orthogonal_coeff * ortho_matrix.sum() / num_pairs 829 | # r, c = torch.triu_indices(M, M, offset=1).unbind() 830 | # orthogonal_loss = args.orthogonal_coeff * (prod_matrix[:, r, c]).mean() 831 | elif args.use_euclidean_mhe: 832 | # Minimum Hyperspherical energy: norm 833 | # scale * sum_i^{N}sum_j^{N}_{i!=j} log(||w_i - w_j|| ** - 834 | batch_pair_wise_l2_dist = torch.cdist(ortho_scores_view, ortho_scores_view, p=2.0) 835 | # only compute the upper triangular matrices (exclude the diagonal) 836 | energy_matrix = torch.triu(batch_pair_wise_l2_dist, diagonal=1) 837 | energy_matrix = energy_matrix[energy_matrix != 0] 838 | orthogonal_loss = torch.log(1 / energy_matrix).mean() 839 | orthogonal_loss *= args.orthogonal_coeff 840 | elif args.use_acos_mhe: 841 | prod_matrix_1 = torch.bmm(ortho_scores_view, ortho_scores_view.transpose(2, 1)) 842 | energy_matrix = torch.triu(prod_matrix_1, diagonal=1) 843 | energy_matrix = energy_matrix[energy_matrix != 0][..., None] 844 | energy_matrix = torch.acos(energy_matrix) 845 | orthogonal_loss = torch.log(1 / energy_matrix).sum(dim=energy_matrix.shape[1:]).mean(dim=0) 846 | orthogonal_loss *= args.orthogonal_coeff 847 | else: 848 | raise NotImplementedError 849 | 850 | if args.text_loss: 851 | word_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 852 | placeholder_token_ids] 853 | if args.text_repulsion_loss: 854 | if args.use_l2_norm_regularization: 855 | # word embeds are unnormalized. 856 | word_norm_loss = args.l2_norm_coeff * torch.norm(word_embeds, dim=1).mean(dim=0) 857 | if args.normalize_score: 858 | word_embeds = F.normalize(word_embeds, p=2, dim=1) 859 | word_dist_matrix = F.pdist(word_embeds, p=2) 860 | repulsion_loss = args.text_repulsion_coeff * torch.log(1 / word_dist_matrix).mean() 861 | elif args.text_repulsion_similarity_loss: 862 | if args.use_l2_norm_regularization: 863 | # word embeds are unnormalized. 864 | word_norm_loss = args.l2_norm_coeff * torch.norm(word_embeds, dim=1).mean(dim=0) 865 | N = word_embeds.shape[0] 866 | similarity = word_embeds @ word_embeds.T 867 | similarity = similarity[torch.triu_indices(N, N, offset=1).unbind()] ** 2. 
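# Added note: this repulsion term averages the squared pairwise dot products between the learned word embeddings (upper triangle only, so each pair is counted once), pushing the placeholder embeddings away from one another so that different concepts occupy different directions in embedding space.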
868 | repulsion_loss = args.text_repulsion_coeff * similarity.mean() 869 | 870 | if args.euclidean_dist_loss: 871 | word_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 872 | placeholder_token_ids] 873 | euclidean_dist_loss = args.euclidean_dist_coeff * (1 / F.pdist(word_embeds, p=2)).mean() 874 | 875 | loss = mse_loss + orthogonal_loss + repulsion_loss + word_norm_loss + euclidean_dist_loss 876 | 877 | accelerator.backward(loss) 878 | optimizer.step() 879 | lr_scheduler.step() 880 | optimizer.zero_grad() 881 | 882 | # Let's make sure we don't update any embedding weights besides the newly added token 883 | if not args.test: 884 | index_no_updates = torch.ones(len(tokenizer), dtype=torch.bool) 885 | index_no_updates[placeholder_token_ids] = False 886 | if accelerator.num_processes > 1: 887 | grads = text_encoder.module.get_input_embeddings().weight.grad 888 | else: 889 | grads = text_encoder.get_input_embeddings().weight.grad 890 | # optimize all newly added tokens 891 | grads.data[index_no_updates, :] = grads.data[index_no_updates, :].fill_(0) 892 | 893 | with torch.no_grad(): 894 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 895 | index_no_updates 896 | ] = orig_embeds_params[index_no_updates] 897 | 898 | # Checks if the accelerator has performed an optimization step behind the scenes 899 | if accelerator.sync_gradients: 900 | progress_bar.update(1) 901 | global_step += 1 902 | if global_step % args.save_steps == 0: 903 | save_progress(text_encoder, initializer_token_ids, accelerator, args) 904 | if args.add_weight_per_score: 905 | save_weights(concept_weights, args) 906 | 907 | if global_step % args.checkpointing_steps == 0: 908 | if accelerator.is_main_process: 909 | if not args.test: 910 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 911 | accelerator.save_state(save_path) 912 | logger.info(f"Saved state to {save_path}") 913 | 914 | logs = {"loss": loss.detach().item(), 915 | "lr": lr_scheduler.get_last_lr()[0], 916 | "mse_loss": mse_loss.item(), 917 | "ortho_loss": orthogonal_loss.item() if args.use_orthogonal_loss else 0, 918 | "word_repulsion_loss": repulsion_loss.item() if args.text_loss else 0, 919 | "euclidean_dist_loss": euclidean_dist_loss.item() if args.euclidean_dist_loss else 0, 920 | "word_norm_regularization": word_norm_loss.item() if args.use_l2_norm_regularization else 0, 921 | } 922 | progress_bar.set_postfix(**logs) 923 | accelerator.log(logs, step=global_step) 924 | 925 | if global_step >= args.max_train_steps: 926 | break 927 | 928 | if accelerator.sync_gradients and global_step % args.validation_step == 0: 929 | folder = os.path.join(args.output_dir, f'generated_samples_{global_step}') 930 | os.makedirs(folder, exist_ok=True) 931 | logger.info( 932 | f"Running validation..." 
933 | ) 934 | # create pipeline (note: unet and vae are loaded again in float32) 935 | pipeline = DiffusionPipeline.from_pretrained( 936 | args.pretrained_model_name_or_path, 937 | text_encoder=accelerator.unwrap_model(text_encoder), 938 | tokenizer=tokenizer, 939 | unet=unet, 940 | vae=vae, 941 | revision=args.revision, 942 | torch_dtype=weight_dtype, 943 | ) 944 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) 945 | pipeline = pipeline.to(accelerator.device) 946 | pipeline.set_progress_bar_config(disable=True) 947 | 948 | # run inference 949 | generator = ( 950 | None if args.seed is None else 951 | torch.Generator(device=accelerator.device).manual_seed(args.seed) 952 | ) 953 | images = [] 954 | prompts = [] 955 | if args.learnable_property != "": 956 | properties = [x.strip() for x in args.learnable_property.split(",")] 957 | else: 958 | properties = [] 959 | 960 | if properties: 961 | for p, placeholder in zip(properties, placeholder_tokens): 962 | if p == "object": 963 | prompts.append(f"a photo of {placeholder}") 964 | else: 965 | prompts.append(f"a painting in the style of {placeholder}") 966 | else: 967 | for placeholder in placeholder_tokens: 968 | prompts.append(f"{placeholder}") 969 | 970 | for prompt in prompts: 971 | image_list = pipeline(prompt, guidance_scale=7.5, 972 | num_inference_steps=50, generator=generator) 973 | image = image_list.images[0] 974 | image.save(os.path.join(folder, f'{prompt}.png')) 975 | images.append(image) 976 | 977 | for tracker in accelerator.trackers: 978 | if tracker.name == "tensorboard": 979 | np_images = np.stack([np.asarray(img) for img in images]) 980 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 981 | del pipeline 982 | torch.cuda.empty_cache() 983 | 984 | accelerator.wait_for_everyone() 985 | 986 | # Create the pipeline using the trained modules and save it. 
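# Added note: this block saves the full Stable Diffusion pipeline plus the learned embeddings (and, if enabled, the per-image concept weights) whenever global_step lands on a multiple of --checkpointing_steps. It is skipped entirely in --test mode, where only the weights are optimized and the saved text encoder would be unchanged.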
987 | if accelerator.is_main_process and global_step % args.checkpointing_steps == 0 and not args.test: 988 | pipeline = StableDiffusionPipeline.from_pretrained( 989 | args.pretrained_model_name_or_path, 990 | text_encoder=accelerator.unwrap_model(text_encoder), 991 | vae=vae, 992 | unet=unet, 993 | tokenizer=tokenizer, 994 | ) 995 | pipeline.save_pretrained(args.output_dir) 996 | # Also save the newly trained embeddings 997 | save_progress(text_encoder, initializer_token_ids, accelerator, args) 998 | if args.add_weight_per_score: 999 | save_weights(concept_weights, args) 1000 | 1001 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1002 | accelerator.save_state(save_path) 1003 | logger.info(f"Saved state to {save_path}") 1004 | 1005 | if args.push_to_hub: 1006 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 1007 | 1008 | accelerator.end_training() 1009 | 1010 | 1011 | if __name__ == "__main__": 1012 | main() 1013 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.15.0 2 | transformers==4.26.0 3 | diffusers==0.10.2 4 | ftfy 5 | regex 6 | tqdm 7 | tensorboard 8 | modelcards 9 | git+https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch 10 | git+https://github.com/openai/CLIP.git -------------------------------------------------------------------------------- /scripts/evaluate_classification.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | clip_logit=0.3 4 | resnet_logit=10 5 | 6 | python eval.py --model_path output/5_class_compositional_score_1_seed_0/ --evaluation_metric clip --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold $clip_logit 7 | # python eval.py --model_path output/5_class_compositional_score_2_seed_0/ --evaluation_metric clip --class_names "guinea pig" "warplane" "castle" "llama" "volcano" --logit_threshold $clip_logit 8 | # python eval.py --model_path output/5_class_compositional_score_3_seed_0/ --evaluation_metric clip --class_names "convertible" "starfish" "studio couch" "african elephant" "teddy" --logit_threshold $clip_logit 9 | # python eval.py --model_path output/5_class_compositional_score_4_seed_0/ --evaluation_metric clip --class_names "koala" "ice bear" "zebra" "tiger" "panda" --logit_threshold $clip_logit 10 | 11 | python eval.py --model_path output/5_class_compositional_score_1_seed_0/ --evaluation_metric resnet --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold $resnet_logit 12 | # python eval.py --model_path output/5_class_compositional_score_2_seed_0/ --evaluation_metric resnet --class_names "guinea pig" "warplane" "castle" "llama" "volcano" --logit_threshold $resnet_logit 13 | # python eval.py --model_path output/5_class_compositional_score_3_seed_0/ --evaluation_metric resnet --class_names "convertible" "starfish" "studio couch" "african elephant" "teddy" --logit_threshold $resnet_logit 14 | # python eval.py --model_path output/5_class_compositional_score_4_seed_0/ --evaluation_metric resnet --class_names "koala" "ice bear" "zebra" "tiger, Panthera tigris" "giant panda" --logit_threshold $resnet_logit -------------------------------------------------------------------------------- /scripts/sample.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python inference.py --model_path output/5_class_compositional_score_1/ --prompts "a photo of " "a photo of " "a photo of " "a photo of " "a photo of " --num_images 64 --bsz 8 -------------------------------------------------------------------------------- /scripts/train_ADE20K.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_path="stabilityai/stable-diffusion-2-1-base" 4 | train_data_dir="ADE20K/home_or_hotel" 5 | placeholder_tokens=",,,," 6 | class_folder_names="kitchen" 7 | output_dir="output/ADE20K_compositional_kitchen_5_concepts" 8 | learnable_property="object,object,object,object,object" 9 | 10 | 11 | DEVICE=$CUDA_VISIBLE_DEVICES 12 | python create_accelerate_config.py --gpu_id "${DEVICE}" 13 | accelerate launch --config_file accelerate_config.yaml main.py \ 14 | --pretrained_model_name_or_path "${model_path}" \ 15 | --train_data_dir ${train_data_dir} \ 16 | --placeholder_tokens ${placeholder_tokens} \ 17 | --resolution=512 --class_folder_names ${class_folder_names} \ 18 | --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \ 19 | --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \ 20 | --lr_warmup_steps=0 --output_dir ${output_dir} \ 21 | --learnable_property ${learnable_property} \ 22 | --data "imagenet" --checkpointing_steps 500 \ 23 | --mse_coeff 1 --seed 1 \ 24 | --add_weight_per_score \ 25 | --use_conj_score --init_weight 0.2 \ 26 | --validation_step 500 \ 27 | --num_iters_per_image 100 --num_images_per_class 5 \ -------------------------------------------------------------------------------- /scripts/train_art.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_path="stabilityai/stable-diffusion-2-1-base" 4 | train_data_dir="style" 5 | placeholder_tokens=",,,," 6 | class_folder_names="van_gogh" 7 | learnable_property="" 8 | output_dir="output/van_gogh" 9 | 10 | model_path="stabilityai/stable-diffusion-2-1-base" 11 | train_data_dir="style" 12 | placeholder_tokens=",,,," 13 | class_folder_names="claude_monet_paintings" 14 | learnable_property="" 15 | output_dir="output/claude_monet_paintings" 16 | 17 | model_path="stabilityai/stable-diffusion-2-1-base" 18 | train_data_dir="style" 19 | placeholder_tokens=",,,," 20 | class_folder_names="picasso" 21 | learnable_property="" 22 | output_dir="output/picasso_paintings" 23 | 24 | 25 | DEVICE=$CUDA_VISIBLE_DEVICES 26 | python create_accelerate_config.py --gpu_id "${DEVICE}" 27 | accelerate launch --config_file accelerate_config.yaml main.py \ 28 | --pretrained_model_name_or_path "${model_path}" \ 29 | --train_data_dir ${train_data_dir} \ 30 | --placeholder_tokens ${placeholder_tokens} \ 31 | --resolution=512 --class_folder_names ${class_folder_names} \ 32 | --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \ 33 | --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \ 34 | --lr_warmup_steps=0 --output_dir ${output_dir} \ 35 | --learnable_property "${learnable_property}" \ 36 | --data "imagenet" --checkpointing_steps 500 --mse_coeff 1 --seed 1 \ 37 | --add_weight_per_score \ 38 | --use_conj_score --init_weight 0.2 \ 39 | --validation_step 500 \ 40 | --num_iters_per_image 1000 --num_images_per_class -1 -------------------------------------------------------------------------------- /scripts/train_imagenet.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_path="stabilityai/stable-diffusion-2-1-base" 4 | train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet" 5 | placeholder_tokens=",,,," 6 | class_folder_names="n09288635,n02085620,n02481823,n04204347,n03788195" 7 | learnable_property="object,object,object,object,object" 8 | output_dir="output/5_class_compositional_score_s1" 9 | 10 | model_path="stabilityai/stable-diffusion-2-1-base" 11 | train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet" 12 | placeholder_tokens=",,,," 13 | class_folder_names="n02364673,n04552348,n02980441,n02437616,n09472597" 14 | output_dir="output/5_class_compositional_score_s2" 15 | learnable_property="object,object,object,object,object" 16 | 17 | model_path="stabilityai/stable-diffusion-2-1-base" 18 | train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet" 19 | placeholder_tokens=",,,," 20 | class_folder_names="n03100240,n02317335,n04344873,n02504458,n04399382" 21 | output_dir="output/5_class_compositional_score_s3" 22 | learnable_property="object,object,object,object,object" 23 | 24 | model_path="stabilityai/stable-diffusion-2-1-base" 25 | train_data_dir="../improved_composable_diffusion/imagenet,../improved_composable_diffusion/imagenet,../improved_composable_diffusion/imagenet,../improved_composable_diffusion/imagenet,../improved_composable_diffusion/imagenet" 26 | placeholder_tokens=",,,," 27 | class_folder_names="n01882714,n02134084,n02391049,n02129604,n02510455" 28 | learnable_property="object,object,object,object,object" 29 | output_dir="output/5_class_compositional_score_s4" 30 | 31 | 32 | DEVICE=$CUDA_VISIBLE_DEVICES 33 | python create_accelerate_config.py --gpu_id "${DEVICE}" 34 | accelerate launch --config_file accelerate_config.yaml main.py \ 35 | --pretrained_model_name_or_path "${model_path}" \ 36 | --train_data_dir ${train_data_dir} \ 37 | --placeholder_tokens ${placeholder_tokens} \ 38 | --resolution=512 --class_folder_names ${class_folder_names} \ 39 | --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \ 40 | --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \ 41 | --lr_warmup_steps=0 --output_dir ${output_dir} \ 42 | --learnable_property "${learnable_property}" \ 43 | --checkpointing_steps 1000 --mse_coeff 1 --seed 4 \ 44 | --add_weight_per_score \ 45 | --use_conj_score --init_weight 5 \ 46 | --validation_step 1000 \ 47 | --num_iters_per_image 120 --num_images_per_class 5 --------------------------------------------------------------------------------
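The scripts above only launch training. For completeness, here is a minimal, hedged sketch (not part of the repository, which ships compose_models.py and inference.py for this purpose) of how the discovered concepts could be recombined at sampling time with a conjunction score analogous to the one used in main.py. The checkpoint path, placeholder prompts, and weights below are illustrative assumptions, and the diffusers calls follow the pinned 0.10.x API.

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

# All names below are illustrative assumptions, not values taken from the repository.
model_path = "output/5_class_compositional_score_s1"          # a trained checkpoint directory
prompts = ["a photo of <concept1>", "a photo of <concept2>"]  # hypothetical placeholder tokens
weights = [1.0, 1.0]                                          # per-concept weights

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)


@torch.no_grad()
def compose(prompts, weights, steps=50, scale=7.5, seed=0):
    def encode(text):
        # Encode a prompt with the (fine-tuned) text encoder.
        ids = pipe.tokenizer(text, padding="max_length", truncation=True,
                             max_length=pipe.tokenizer.model_max_length,
                             return_tensors="pt").input_ids.to(device)
        return pipe.text_encoder(ids)[0]

    uncond = encode("")
    conds = [encode(p) for p in prompts]

    generator = torch.Generator(device).manual_seed(seed)
    latents = torch.randn((1, pipe.unet.config.in_channels, 64, 64),
                          generator=generator, device=device)
    pipe.scheduler.set_timesteps(steps)
    latents = latents * pipe.scheduler.init_noise_sigma

    for t in pipe.scheduler.timesteps:
        eps_uncond = pipe.unet(latents, t, encoder_hidden_states=uncond).sample
        # Conjunction of the per-concept scores, scaled like classifier-free guidance.
        eps = eps_uncond + scale * sum(
            w * (pipe.unet(latents, t, encoder_hidden_states=c).sample - eps_uncond)
            for w, c in zip(weights, conds)
        )
        latents = pipe.scheduler.step(eps, t, latents).prev_sample

    image = pipe.vae.decode(latents / 0.18215).sample
    return (image / 2 + 0.5).clamp(0, 1)  # 1x3x512x512 tensor in [0, 1]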