├── .gitignore
├── README.md
├── accelerate_config.yaml
├── assets
│   ├── ade20_decomp.png
│   ├── ade20k_training_daam.png
│   ├── art_decomposition.png
│   ├── demo.gif
│   ├── imagenet_decomposition.png
│   ├── supp_ade20k_decomposition.png
│   ├── supp_external_composition.png
│   └── teaser.png
├── compose_models.py
├── create_accelerate_config.py
├── daam_ddim_visualize.py
├── daam_visualize_generation.py
├── datasets.py
├── eval.py
├── inference.py
├── main.py
├── requirements.txt
└── scripts
    ├── evaluate_classification.sh
    ├── sample.sh
    ├── train_ADE20K.sh
    ├── train_art.sh
    └── train_imagenet.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Compositional Concepts Discovery (ICCV 2023)
2 |
3 |
4 |
5 |
6 | https://github.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/assets/45443761/ca2504a2-2186-4edd-b37c-2b4e9c503a1b
7 | > Text-to-image generative models have enabled high-resolution image synthesis across different domains, but require users to specify the content they wish to generate. In this paper, we consider the inverse problem -- given a collection of different images, can we discover the generative concepts that represent each image? We present an unsupervised approach to discover generative concepts from a collection of images, disentangling different art styles in paintings, objects, and lighting from kitchen scenes, and discovering image classes given ImageNet images. We show how such generative concepts can accurately represent the content of images, be recombined and composed to generate new artistic and hybrid images, and be further used as a representation for downstream classification tasks.
8 |
9 | [Unsupervised Compositional Concepts Discovery with Text-to-Image Generative Models](https://energy-based-model.github.io/unsupervised-concept-discovery)
10 |
11 | [Nan Liu](https://nanliu.io)<sup>1*</sup>,
12 | [Yilun Du](https://yilundu.github.io)<sup>2*</sup>,
13 | [Shuang Li](https://people.csail.mit.edu/lishuang)<sup>2*</sup>,
14 | [Joshua B. Tenenbaum](https://mitibmwatsonailab.mit.edu/people/joshua-tenenbaum/)<sup>2</sup>,
15 | [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab/)<sup>2</sup>
16 |
17 | <sup>*</sup>Equal Contribution
18 |
19 | <sup>1</sup>UIUC, <sup>2</sup>MIT CSAIL
20 |
21 | ICCV 2023
22 |
23 |
24 |
25 | ## Setup
26 |
27 | Run the following to create a conda environment and activate it:
28 |
29 |     conda create --name decomp python=3.8
30 |     conda activate decomp
31 |
32 | Install PyTorch 1.11, the version we have tested on:
33 |
34 |     pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
35 |
36 | Next, install the required packages:
37 |
38 |     pip install -r requirements.txt
39 |
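To sanity-check the setup (optional), a minimal snippet that verifies the installed PyTorch build can see your GPU:

```python
import torch

print(torch.__version__)           # expect 1.11.0+cu113 with the command above
print(torch.cuda.is_available())   # should print True before launching training
```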
40 |
41 | ## Training
42 |
43 | After downloading images (e.g., ImageNet class folders) into the repo directory, first specify the data arguments:
44 |
45 |     model_path="stabilityai/stable-diffusion-2-1-base"
46 |     train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet"
47 |     placeholder_tokens=",,,,"
48 |     class_folder_names="n09288635,n02085620,n02481823,n04204347,n03788195"
49 |     learnable_property="object,object,object,object,object"
50 |     output_dir=$OUTPUT_DIR
51 |
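These comma-separated lists are parsed element-wise by `datasets.py`: the i-th entries of `train_data_dir` and `class_folder_names` are joined into the folder that holds the images for the i-th concept, with typically one placeholder token per class folder. A minimal sketch (using the example values above, and assuming the ImageNet synset folders have already been downloaded) that mirrors this logic and checks the folders exist:

```python
import os

train_data_dir = "imagenet,imagenet,imagenet,imagenet,imagenet"
class_folder_names = "n09288635,n02085620,n02481823,n04204347,n03788195"

roots = [x.strip() for x in train_data_dir.split(",")]
names = [x.strip() for x in class_folder_names.split(",")]
for root, name in zip(roots, names):
    folder = os.path.join(root, name)   # same join used by ComposableDataset in datasets.py
    assert os.path.isdir(folder), f"missing class folder: {folder}"
    print(folder, len(os.listdir(folder)), "files")
```
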
52 | Then you can train a model on ImageNet $S_1$ by running:
53 |
54 |     DEVICE=$CUDA_VISIBLE_DEVICES
55 |     python create_accelerate_config.py --gpu_id "${DEVICE}"
56 |     accelerate launch --config_file accelerate_config.yaml main.py \
57 |         --pretrained_model_name_or_path "${model_path}" \
58 |         --train_data_dir ${train_data_dir} \
59 |         --placeholder_tokens ${placeholder_tokens} \
60 |         --resolution=512 --class_folder_names ${class_folder_names} \
61 |         --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \
62 |         --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \
63 |         --lr_warmup_steps=0 --output_dir ${output_dir} \
64 |         --learnable_property "${learnable_property}" \
65 |         --checkpointing_steps 1000 --mse_coeff 1 --seed 0 \
66 |         --add_weight_per_score \
67 |         --use_conj_score --init_weight 5 \
68 |         --validation_step 1000 \
69 |         --num_iters_per_image 120 --num_images_per_class 5
70 |
71 | See the ```scripts``` folder for more training details.
72 |
73 | The training dataset for artistic paintings can be found [here](https://www.dropbox.com/sh/g91atmeuy2ihkgn/AAATmXz6zI9H4fLCnsEG-CRka?dl=0).
74 |
75 | ## Inference
76 | 
77 |
78 | Once the model is trained, we can sample from each concept using the following command:
79 |
80 |     python inference.py --model_path ${output_dir} --prompts "a photo of " "a photo of " "a photo of " "a photo of " "a photo of " --num_images 64 --bsz 8
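
`inference.py` handles batching and saving for all concepts at once. For a quick look at a single learned concept, note that `output_dir` is saved as a standard diffusers pipeline (the DAAM scripts below load it the same way). A minimal sketch, where `<t1>` is a hypothetical stand-in for one of your learned placeholder tokens:

```python
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

output_dir = "path/to/output_dir"  # the --output_dir used during training
pipe = StableDiffusionPipeline.from_pretrained(output_dir).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# "<t1>" is a placeholder name for illustration; substitute one of your learned tokens
image = pipe("a photo of <t1>", num_inference_steps=50, guidance_scale=7.5,
             generator=torch.Generator("cuda").manual_seed(0)).images[0]
image.save("concept_sample.png")
```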
81 |
82 | ## Visualization
83 |
84 | You need to install [DAAM](https://github.com/castorini/daam) and its required packages (e.g., `pip install daam`).
85 |
86 | DDIM + DAAM (Training Data) | DAAM (Generated Data)
87 | :-------------------------:|:-------------------------:
88 |  | 
89 |
90 | Once concepts are learned, we can generate cross-attention heat maps between training images and the learned concepts using diffusion attentive attribution maps ([DAAM](https://github.com/castorini/daam)) for visualization:
91 |
92 |     python daam_ddim_visualize.py --model_path ${output_dir} --image_path $IMG_PATH --prompt "a photo of " --keyword "" --scale 7.5 --seed 0
93 |
94 | Generate images along with heat map visualizations associated with each learned concept using DAAM:
95 |
96 |     python daam_visualize_generation.py --model_path ${output_dir} --prompt "a photo of " --keyword "" --scale 7.5 --num_images 1
97 |
98 | ## Evaluation
99 |
100 | After generating 64 images per concept, we can run the following commands to evaluate classification accuracy and KL divergence using a pre-trained ResNet-50 and CLIP with the specified threshold values:
101 |
102 |     # CLIP
103 |     python eval.py --model_path ${output_dir} --evaluation_metric clip --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold 0.3
104 |     # ResNet-50
105 |     python eval.py --model_path ${output_dir} --evaluation_metric resnet --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold 10
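
Concretely, `eval.py` assigns each generated image to its nearest class prompt under CLIP, counts how many images land in each class (the per-class coverage), treats images whose top similarity falls below `--logit_threshold` as misclassified, and reports the KL divergence between a uniform target distribution and the coverage distribution. A small worked sketch of how those two numbers are computed (the counts are made up for illustration):

```python
import torch
from torchmetrics.functional import kl_divergence

# hypothetical assignment counts for 64 generated images over 5 classes
counts = torch.tensor([14.0, 13.0, 12.0, 13.0, 12.0])
num_below_threshold = 4                     # images whose top CLIP similarity < --logit_threshold

coverage = counts / counts.sum()            # q: modeled class distribution
target = torch.full_like(coverage, 1 / len(coverage))  # p: uniform target distribution

acc = (counts.sum() - num_below_threshold) / counts.sum()
kl = kl_divergence(target[None], coverage[None])        # KL(p || q), as in eval.py
print(f"Avg Acc: {100 * acc.item():.2f}, KL: {kl.item():.4f}")
```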
106 |
107 |
108 | ## Unsupervised Decomposition
109 |
110 | ### ImageNet Objects
111 | Our proposed method can discover different object categories from a set of unlabeled images.
112 | 
113 |
114 | ### Scene Concepts
115 | We demonstrate our method can decompose kitchen scenes into multiple sets of factors.
116 |
117 | 
118 |
119 | ### Art Concepts
120 | Our method allows unsupervised concept decomposition from just a few paintings.
121 |
122 | 
123 |
124 | ## Concept Composition
125 |
126 | Discovered concepts can be further combined with the existing knowledge (i.e., text prompts) of pretrained generative models.
127 |
128 | 
129 |
130 | ## Todo
131 |
132 | - [ ] Add support for other available models such as DeepFloyd IF and Stable Diffusion XL.
133 |
134 |
135 |
136 | ## Citation
137 |     @InProceedings{Liu_2023_ICCV,
138 |         author = {Liu, Nan and Du, Yilun and Li, Shuang and Tenenbaum, Joshua B. and Torralba, Antonio},
139 |         title = {Unsupervised Compositional Concepts Discovery with Text-to-Image Generative Models},
140 |         booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
141 |         month = {October},
142 |         year = {2023},
143 |         pages = {2085-2095}
144 |     }
145 |
146 | Feel free to let us know if we are missing any relevant citations.
147 |
--------------------------------------------------------------------------------
/accelerate_config.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | compute_environment: LOCAL_MACHINE
4 | deepspeed_config: {}
5 | distributed_type: 'NO'
6 | downcast_bf16: 'no'
7 | dynamo_backend: 'NO'
8 | fsdp_config: {}
9 | gpu_ids: '6'
10 | machine_rank: 0
11 | main_process_ip: null
12 | main_process_port: null
13 | main_training_function: main
14 | megatron_lm_config: {}
15 | mixed_precision: 'no'
16 | num_machines: 1
17 | num_processes: 1
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_name: null
21 | tpu_zone: null
22 |
--------------------------------------------------------------------------------
/assets/ade20_decomp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/ade20_decomp.png
--------------------------------------------------------------------------------
/assets/ade20k_training_daam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/ade20k_training_daam.png
--------------------------------------------------------------------------------
/assets/art_decomposition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/art_decomposition.png
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/demo.gif
--------------------------------------------------------------------------------
/assets/imagenet_decomposition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/imagenet_decomposition.png
--------------------------------------------------------------------------------
/assets/supp_ade20k_decomposition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/supp_ade20k_decomposition.png
--------------------------------------------------------------------------------
/assets/supp_external_composition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/supp_external_composition.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/03f557e747a1608b3ffc77abe1ef691baf2ea054/assets/teaser.png
--------------------------------------------------------------------------------
/compose_models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import inspect
4 | import argparse
5 |
6 | from tqdm import tqdm
7 | from PIL import Image
8 |
9 | from diffusers import (
10 | AutoencoderKL,
11 | DDIMScheduler,
12 | UNet2DConditionModel,
13 | )
14 |
15 | from transformers import CLIPTextModel, CLIPTokenizer
16 | from typing import List, Optional, Tuple, Union
17 |
18 |
19 | def randn_tensor(
20 | shape: Union[Tuple, List],
21 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
22 | device: Optional["torch.device"] = None,
23 | dtype: Optional["torch.dtype"] = None,
24 | layout: Optional["torch.layout"] = None,
25 | ):
26 | """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
27 | passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
28 | will always be created on CPU.
29 | """
30 | # device on which tensor is created defaults to device
31 | rand_device = device
32 | batch_size = shape[0]
33 |
34 | layout = layout or torch.strided
35 | device = device or torch.device("cpu")
36 |
37 | if generator is not None:
38 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
39 | if gen_device_type != device.type and gen_device_type == "cpu":
40 | rand_device = "cpu"
41 | # if device != "mps":
42 | # logger.info(
43 | # f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
44 | # f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
45 | # f" slighly speed up this function by passing a generator that was created on the {device} device."
46 | # )
47 | elif gen_device_type != device.type and gen_device_type == "cuda":
48 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
49 |
50 | if isinstance(generator, list):
51 | shape = (1,) + shape[1:]
52 | latents = [
53 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
54 | for i in range(batch_size)
55 | ]
56 | latents = torch.cat(latents, dim=0).to(device)
57 | else:
58 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
59 |
60 | return latents
61 |
62 |
63 | def get_batched_text_embeddings(tokenizer, text_encoder, prompt, batch_size):
64 | device = text_encoder.device
65 | text_inputs = tokenizer(
66 | prompt,
67 | padding="max_length",
68 | max_length=tokenizer.model_max_length,
69 | truncation=True,
70 | return_tensors="pt",
71 | )
72 | text_input_ids = text_inputs.input_ids
73 | text_embeddings = text_encoder(text_input_ids.to(device))[0]
74 | bs_embed, seq_len, _ = text_embeddings.shape
75 | text_embeddings = text_embeddings.repeat(1, batch_size, 1).view(bs_embed * batch_size, seq_len, -1)
76 | return text_embeddings
77 |
78 |
79 | def prepare_latents(vae_scale_factor, init_noise_sigma, batch_size,
80 | num_channels_latents, height, width, dtype, device, generator, latents=None):
81 | shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
82 | if isinstance(generator, list) and len(generator) != batch_size:
83 | raise ValueError(
84 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
85 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
86 | )
87 |
88 | if latents is None:
89 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
90 | else:
91 | latents = latents.to(device)
92 |
93 | # scale the initial noise by the standard deviation required by the scheduler
94 | latents = latents * init_noise_sigma
95 | return latents
96 |
97 |
98 | def prepare_extra_step_kwargs(generator, scheduler, eta):
99 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
100 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
101 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
102 | # and should be between [0, 1]
103 |
104 | accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
105 | extra_step_kwargs = {}
106 | if accepts_eta:
107 | extra_step_kwargs["eta"] = eta
108 |
109 | # check if the scheduler accepts generator
110 | accepts_generator = "generator" in set(inspect.signature(scheduler.step).parameters.keys())
111 | if accepts_generator:
112 | extra_step_kwargs["generator"] = generator
113 | return extra_step_kwargs
114 |
115 |
116 | def decode_latents(vae, latents):
117 | latents = 1 / 0.18215 * latents
118 | image = vae.decode(latents).sample
119 | image = (image / 2 + 0.5).clamp(0, 1)
120 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
121 | image = image.cpu().permute(0, 2, 3, 1).float().numpy()
122 | return image
123 |
124 |
125 | def numpy_to_pil(images):
126 | """
127 | Convert a numpy image or a batch of images to a PIL image.
128 | """
129 | if images.ndim == 3:
130 | images = images[None, ...]
131 | images = (images * 255).round().astype("uint8")
132 | pil_images = [Image.fromarray(image) for image in images]
133 |
134 | return pil_images
135 |
136 |
137 | def main():
138 | parser = argparse.ArgumentParser()
139 | parser.add_argument("--model_paths", type=str, nargs="+")
140 | parser.add_argument("--prompts", type=str, nargs="+")
141 | parser.add_argument("--bsz", type=int, default=1)
142 | parser.add_argument("--num_images", type=int, default=1)
143 | parser.add_argument("--steps", type=int, default=50)
144 | parser.add_argument("--scales", type=float, nargs="+")
145 | parser.add_argument("--eta", type=float, default=0)
146 | parser.add_argument("--seed", type=int, default=0)
147 | parser.add_argument("--folder", type=str)
148 | args = parser.parse_args()
149 |
150 | # load models
151 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152 | generator = torch.Generator(device).manual_seed(args.seed)
153 |
154 | tokenizers, text_encoders, vaes, unets = [], [], [], []
155 | noise_scheduler = DDIMScheduler.from_pretrained(args.model_paths[0], subfolder="scheduler")
156 |
157 | for model_path in args.model_paths:
158 | tokenizers.append(CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer"))
159 | text_encoders.append(CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(device))
160 | vaes.append(AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device))
161 | unets.append(UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(device))
162 | print(f"finished loading from {model_path}")
163 |
164 | # sampling
165 | batch_size = args.bsz
166 | num_batches = args.num_images // args.bsz
167 | steps = args.steps
168 | scales = args.scales
169 | eta = args.eta
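    # the VAE downsamples by a factor of 2 ** (num_blocks - 1); e.g., 8 for the Stable Diffusion VAE,
    # so 512x512 images correspond to 64x64 latents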
170 | vae_scale_factor = 2 ** (len(vaes[0].config.block_out_channels) - 1)
171 | init_noise_sigma = noise_scheduler.init_noise_sigma
172 | num_channels_latents = unets[0].in_channels
173 | height = unets[0].config.sample_size * vae_scale_factor
174 | width = unets[0].config.sample_size * vae_scale_factor
175 | image_folder = args.folder
176 | os.makedirs(image_folder, exist_ok=True)
177 |
178 | with torch.no_grad():
179 | for batch_num in range(num_batches):
180 | # 1. set the noise scheduler
181 | noise_scheduler.set_timesteps(args.steps, device=device)
182 | timesteps = noise_scheduler.timesteps
183 | # 2. initialize the latents
184 | latents = prepare_latents(
185 | vae_scale_factor,
186 | init_noise_sigma,
187 | batch_size,
188 | num_channels_latents,
189 | height,
190 | width,
191 | text_encoders[0].dtype,
192 | device,
193 | generator,
194 | latents=None
195 | )
196 |
197 | frames = []
198 | num_warmup_steps = len(timesteps) - steps * noise_scheduler.order
199 | with tqdm(total=steps) as progress_bar:
200 | for i, t in enumerate(timesteps):
201 | latent_model_input = torch.cat([latents] * 2)
202 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
203 |
204 | uncond_scores, cond_scores = [], []
205 | for prompt, tokenizer, text_encoder, unet, vae in zip(args.prompts, tokenizers, text_encoders,
206 | unets, vaes):
207 | # 3. get the input embeddings
208 | text_embeddings = get_batched_text_embeddings(tokenizer, text_encoder, prompt, batch_size)
209 | null_embeddings = get_batched_text_embeddings(tokenizer, text_encoder, "", batch_size)
210 | input_embeddings = torch.cat((null_embeddings, text_embeddings), dim=0)
211 |
212 | # predict the noise residual
213 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_embeddings).sample
214 | uncond_pred_noise, cond_pred_noise = noise_pred.chunk(2)
215 | # save predicted scores
216 | uncond_scores.append(uncond_pred_noise)
217 | cond_scores.append(cond_pred_noise)
218 |
219 | # apply compositional score
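                    # i.e., eps_hat = mean_i(eps_uncond_i) + sum_i scale_i * (eps_cond_i - eps_uncond_i):
                    # each concept contributes its own classifier-free-guidance direction, weighted by its
                    # scale, on top of the averaged unconditional score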
220 | composed_noise_pred = sum(uncond_scores) / len(uncond_scores) + sum(
221 | scale * (cond_score - uncond_score) for scale, cond_score, uncond_score in
222 | zip(scales, cond_scores, uncond_scores))
223 |
224 | # compute the previous noisy sample x_t -> x_t-1
225 | extra_step_kwargs = prepare_extra_step_kwargs(generator, noise_scheduler, eta)
226 | latents = noise_scheduler.step(composed_noise_pred, t, latents, **extra_step_kwargs).prev_sample
227 |
228 | # save intermediate results
229 | decoded_images = decode_latents(vae, latents)
230 | frames.append(decoded_images)
231 |
232 | images = decode_latents(vae, latents)
233 | images = numpy_to_pil(images)
234 |
235 | # call the callback, if provided
236 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % noise_scheduler.order == 0):
237 | progress_bar.update()
238 |
239 | images = decode_latents(vae, latents)
240 | images = numpy_to_pil(images)
241 |
242 | for j, img in enumerate(images):
243 | img_path = os.path.join(image_folder,
244 | f"{args.prompts}_{batch_num * batch_size + j}_{scales}_{args.seed}.png")
245 | img.save(img_path)
246 |
247 |
248 | if __name__ == "__main__":
249 | main()
250 |
--------------------------------------------------------------------------------
/create_accelerate_config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument("--gpu_id", type=str, required=True)
6 | parser.add_argument("--distributed", action="store_true", default=False)
7 | args = parser.parse_args()
8 |
9 | dict_file = {
10 | 'command_file': None,
11 | 'commands': None,
12 | 'compute_environment': 'LOCAL_MACHINE',
13 | 'deepspeed_config': {},
14 | 'distributed_type': 'NO',
15 | 'downcast_bf16': 'no',
16 | 'dynamo_backend': 'NO',
17 | 'fsdp_config': {},
18 | 'gpu_ids': args.gpu_id,
19 | 'machine_rank': 0,
20 | 'main_process_ip': None,
21 | 'main_process_port': None,
22 | 'main_training_function': 'main',
23 | 'megatron_lm_config': {},
24 | 'mixed_precision': 'no',
25 | 'num_machines': 1,
26 | 'num_processes': 1,
27 | 'rdzv_backend': 'static',
28 | 'same_network': True,
29 | 'tpu_name': None,
30 | 'tpu_zone': None,
31 | }
32 |
33 | dist_dict_file = {
34 | 'command_file': None,
35 | 'commands': None,
36 | 'compute_environment': 'LOCAL_MACHINE',
37 | 'deepspeed_config': {},
38 | 'distributed_type': 'MULTI_GPU',
39 | 'downcast_bf16': 'no',
40 | 'dynamo_backend': 'NO',
41 | 'fsdp_config': {},
42 | 'gpu_ids': args.gpu_id,
43 | 'machine_rank': 0,
44 | 'main_process_ip': None,
45 | 'main_process_port': None,
46 | 'main_training_function': 'main',
47 | 'megatron_lm_config': {},
48 | 'mixed_precision': 'fp16',
49 | 'num_machines': 1,
50 | 'num_processes': len(args.gpu_id.split(',')),
51 | 'rdzv_backend': 'static',
52 | 'same_network': True,
53 | 'tpu_name': None,
54 | 'tpu_zone': None,
55 | 'use_cpu': False,
56 | }
57 |
58 | if args.distributed:
59 | with open('accelerate_config.yaml', 'w') as file:
60 | documents = yaml.dump(dist_dict_file, file)
61 | else:
62 | with open('accelerate_config.yaml', 'w') as file:
63 | documents = yaml.dump(dict_file, file)
64 |
--------------------------------------------------------------------------------
/daam_ddim_visualize.py:
--------------------------------------------------------------------------------
1 | """https://github.com/cccntu/efficient-prompt-to-prompt/blob/main/ddim-inversion.ipynb"""
2 | from functools import partial
3 | from diffusers import StableDiffusionPipeline
4 |
5 | from PIL import Image
6 | from torchvision import transforms
7 |
8 | from typing import Callable, List, Optional, Union
9 |
10 | import torch
11 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
12 |
13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
14 | from diffusers.pipeline_utils import DiffusionPipeline
15 | from diffusers.pipelines.stable_diffusion.safety_checker import \
16 | StableDiffusionSafetyChecker
17 | from diffusers.schedulers import DDIMScheduler,PNDMScheduler, LMSDiscreteScheduler
18 | from diffusers.utils import logging
19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20 | import argparse
21 | from daam import trace, set_seed
22 |
23 |
24 | def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt):
25 | """ from noise to image"""
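    # Deterministic DDIM update: x_{t-1} = sqrt(alpha_{t-1}) * x0_hat + sqrt(1 - alpha_{t-1}) * eps,
    # where x0_hat = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t) and the alphas are cumulative
    # products; the return expression below is an algebraically equivalent rearrangement of this update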
26 | return (
27 | alpha_tm1**0.5
28 | * (
29 | (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
30 | + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
31 | )
32 | + x_t
33 | )
34 |
35 |
36 | def forward_ddim(x_t, alpha_t, alpha_tp1, eps_xt):
37 | """ from image to noise, it's the same as backward_ddim"""
38 | return backward_ddim(x_t, alpha_t, alpha_tp1, eps_xt)
39 |
40 |
41 | class DDIMPipeline(DiffusionPipeline):
42 | def __init__(
43 | self,
44 | vae: AutoencoderKL,
45 | text_encoder: CLIPTextModel,
46 | tokenizer: CLIPTokenizer,
47 | unet: UNet2DConditionModel,
48 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
49 | safety_checker: StableDiffusionSafetyChecker = None,
50 | feature_extractor: CLIPFeatureExtractor = None,
51 | ):
52 | super().__init__()
53 |
54 | self.register_modules(
55 | vae=vae,
56 | text_encoder=text_encoder,
57 | tokenizer=tokenizer,
58 | unet=unet,
59 | scheduler=scheduler,
60 | safety_checker=safety_checker,
61 | feature_extractor=feature_extractor,
62 | )
63 | self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
64 | self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)
65 |
66 | def run_safety_checker(self, image, device, dtype):
67 | if self.safety_checker is not None:
68 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
69 | image, has_nsfw_concept = self.safety_checker(
70 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
71 | )
72 | else:
73 | has_nsfw_concept = None
74 | return image, has_nsfw_concept
75 |
76 | @torch.inference_mode()
77 | def get_text_embedding(self, prompt):
78 | text_input_ids = self.tokenizer(
79 | prompt,
80 | padding="max_length",
81 | truncation=True,
82 | max_length=self.tokenizer.model_max_length,
83 | return_tensors="pt",
84 | ).input_ids
85 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
86 | return text_embeddings
87 |
88 | @torch.inference_mode()
89 | def get_image_latents(self, image, sample=True, rng_generator=None):
90 | encoding_dist = self.vae.encode(image).latent_dist
91 | if sample:
92 | encoding = encoding_dist.sample(generator=rng_generator)
93 | else:
94 | encoding = encoding_dist.mode()
95 | latents = encoding * 0.18215
96 | return latents
97 |
98 | @torch.inference_mode()
99 | def backward_diffusion(
100 | self,
101 | use_old_emb_i=25,
102 | prompt=None,
103 | text_embeddings=None,
104 | old_text_embeddings=None,
105 | new_text_embeddings=None,
106 | latents: Optional[torch.FloatTensor] = None,
107 | num_inference_steps: int = 50,
108 | guidance_scale: float = 7.5,
109 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
110 | callback_steps: Optional[int] = 1,
111 |         reverse_process: bool = False,
112 | **kwargs,
113 | ):
114 | """ Generate image from text prompt and latents
115 | """
116 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
117 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
118 | # corresponds to doing no classifier free guidance.
119 | do_classifier_free_guidance = guidance_scale > 1.0
120 | if text_embeddings is None:
121 | text_embeddings = self._encode_prompt(prompt, device=self.device,
122 | num_images_per_prompt=1,
123 | do_classifier_free_guidance=do_classifier_free_guidance)
124 | # set timesteps
125 | self.scheduler.set_timesteps(num_inference_steps)
126 | # Some schedulers like PNDM have timesteps as arrays
127 | # It's more optimized to move all timesteps to correct device beforehand
128 | timesteps_tensor = self.scheduler.timesteps.to(self.device)
129 | # scale the initial noise by the standard deviation required by the scheduler
130 | latents = latents * self.scheduler.init_noise_sigma
131 |
132 | if old_text_embeddings is not None and new_text_embeddings is not None:
133 | prompt_to_prompt = True
134 | else:
135 | prompt_to_prompt = False
136 |
137 | for i, t in enumerate(
138 | self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))):
139 | if prompt_to_prompt:
140 | if i < use_old_emb_i:
141 | text_embeddings = old_text_embeddings
142 | else:
143 | text_embeddings = new_text_embeddings
144 |
145 | # expand the latents if we are doing classifier free guidance
146 | latent_model_input = (
147 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
148 | )
149 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
150 |
151 | # predict the noise residual
152 | noise_pred = self.unet(
153 | latent_model_input, t, encoder_hidden_states=text_embeddings
154 | ).sample
155 |
156 | # perform guidance
157 | if do_classifier_free_guidance:
158 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
159 | noise_pred = noise_pred_uncond + guidance_scale * (
160 | noise_pred_text - noise_pred_uncond
161 | )
162 |
163 | prev_timestep = (
164 | t
165 | - self.scheduler.config.num_train_timesteps
166 | // self.scheduler.num_inference_steps
167 | )
168 | # call the callback, if provided
169 | if callback is not None and i % callback_steps == 0:
170 | callback(i, t, latents)
171 |
172 | # ddim
173 | alpha_prod_t = self.scheduler.alphas_cumprod[t]
174 | alpha_prod_t_prev = (
175 | self.scheduler.alphas_cumprod[prev_timestep]
176 | if prev_timestep >= 0
177 | else self.scheduler.final_alpha_cumprod
178 | )
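                # for DDIM inversion, swapping the two cumulative alphas runs the same update in the
                # opposite direction (image -> noise); forward_diffusion uses this via reverse_process=True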
179 | if reverse_process:
180 | alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
181 | latents = backward_ddim(
182 | x_t=latents,
183 | alpha_t=alpha_prod_t,
184 | alpha_tm1=alpha_prod_t_prev,
185 | eps_xt=noise_pred,
186 | )
187 | return latents
188 |
189 | @torch.inference_mode()
190 | def decode_image(self, latents: torch.FloatTensor, **kwargs) -> List["PIL_IMAGE"]:
191 | scaled_latents = 1 / 0.18215 * latents
192 | image = [
193 | self.vae.decode(scaled_latents[i: i + 1]).sample for i in range(len(latents))
194 | ]
195 | image = torch.cat(image, dim=0)
196 | return image
197 |
198 | @torch.inference_mode()
199 | def torch_to_numpy(self, image) -> List["PIL_IMAGE"]:
200 | image = (image / 2 + 0.5).clamp(0, 1)
201 | image = image.cpu().permute(0, 2, 3, 1).numpy()
202 | return image
203 |
204 | def _encode_prompt(
205 | self,
206 | prompt,
207 | device,
208 | num_images_per_prompt,
209 | do_classifier_free_guidance,
210 | negative_prompt=None,
211 | prompt_embeds: Optional[torch.FloatTensor] = None,
212 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
213 | ):
214 | r"""
215 | Encodes the prompt into text encoder hidden states.
216 | Args:
217 | prompt (`str` or `List[str]`, *optional*):
218 | prompt to be encoded
219 | device: (`torch.device`):
220 | torch device
221 | num_images_per_prompt (`int`):
222 | number of images that should be generated per prompt
223 | do_classifier_free_guidance (`bool`):
224 | whether to use classifier free guidance or not
225 | negative_prompt (`str` or `List[str]`, *optional*):
226 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
227 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
228 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
229 | prompt_embeds (`torch.FloatTensor`, *optional*):
230 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
231 | provided, text embeddings will be generated from `prompt` input argument.
232 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
233 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
234 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
235 | argument.
236 | """
237 | if prompt is not None and isinstance(prompt, str):
238 | batch_size = 1
239 | elif prompt is not None and isinstance(prompt, list):
240 | batch_size = len(prompt)
241 | else:
242 | batch_size = prompt_embeds.shape[0]
243 |
244 | if prompt_embeds is None:
245 | text_inputs = self.tokenizer(
246 | prompt,
247 | padding="max_length",
248 | max_length=self.tokenizer.model_max_length,
249 | truncation=True,
250 | return_tensors="pt",
251 | )
252 | text_input_ids = text_inputs.input_ids
253 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
254 |
255 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
256 | text_input_ids, untruncated_ids
257 | ):
258 | removed_text = self.tokenizer.batch_decode(
259 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
260 | )
261 | logger.warning(
262 | "The following part of your input was truncated because CLIP can only handle sequences up to"
263 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
264 | )
265 |
266 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
267 | attention_mask = text_inputs.attention_mask.to(device)
268 | else:
269 | attention_mask = None
270 |
271 | prompt_embeds = self.text_encoder(
272 | text_input_ids.to(device),
273 | attention_mask=attention_mask,
274 | )
275 | prompt_embeds = prompt_embeds[0]
276 |
277 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
278 |
279 | bs_embed, seq_len, _ = prompt_embeds.shape
280 | # duplicate text embeddings for each generation per prompt, using mps friendly method
281 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
282 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
283 |
284 | # get unconditional embeddings for classifier free guidance
285 | if do_classifier_free_guidance and negative_prompt_embeds is None:
286 | uncond_tokens: List[str]
287 | if negative_prompt is None:
288 | uncond_tokens = [""] * batch_size
289 | elif type(prompt) is not type(negative_prompt):
290 | raise TypeError(
291 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
292 | f" {type(prompt)}."
293 | )
294 | elif isinstance(negative_prompt, str):
295 | uncond_tokens = [negative_prompt]
296 | elif batch_size != len(negative_prompt):
297 | raise ValueError(
298 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
299 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
300 | " the batch size of `prompt`."
301 | )
302 | else:
303 | uncond_tokens = negative_prompt
304 |
305 | max_length = prompt_embeds.shape[1]
306 | uncond_input = self.tokenizer(
307 | uncond_tokens,
308 | padding="max_length",
309 | max_length=max_length,
310 | truncation=True,
311 | return_tensors="pt",
312 | )
313 |
314 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
315 | attention_mask = uncond_input.attention_mask.to(device)
316 | else:
317 | attention_mask = None
318 |
319 | negative_prompt_embeds = self.text_encoder(
320 | uncond_input.input_ids.to(device),
321 | attention_mask=attention_mask,
322 | )
323 | negative_prompt_embeds = negative_prompt_embeds[0]
324 |
325 | if do_classifier_free_guidance:
326 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
327 | seq_len = negative_prompt_embeds.shape[1]
328 |
329 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
330 |
331 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
332 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
333 |
334 | # For classifier free guidance, we need to do two forward passes.
335 | # Here we concatenate the unconditional and text embeddings into a single batch
336 | # to avoid doing two forward passes
337 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
338 |
339 | return prompt_embeds
340 |
341 |
342 | def load_img(path, target_size=512):
343 | """Load an image, resize and output -1..1"""
344 | image = Image.open(path).convert("RGB")
345 |
346 | tform = transforms.Compose(
347 | [
348 | transforms.Resize(target_size),
349 | transforms.CenterCrop(target_size),
350 | transforms.ToTensor(),
351 | ]
352 | )
353 | image = tform(image)
354 | return 2.0 * image - 1.0
355 |
356 |
357 | def latents_to_imgs(latents):
358 | x = pipe.decode_image(latents)
359 | x = pipe.torch_to_numpy(x)
360 | x = pipe.numpy_to_pil(x)
361 | return x
362 |
363 |
364 | if __name__ == '__main__':
365 | parser = argparse.ArgumentParser()
366 | parser.add_argument("--model_path", type=str)
367 | parser.add_argument("--image_path", type=str)
368 | parser.add_argument("--prompt", type=str)
369 | parser.add_argument("--keyword", type=str)
370 | parser.add_argument("--seed", type=int, default=0)
371 | parser.add_argument("--scale", type=float, default=1)
372 | args = parser.parse_args()
373 |
374 | pipe = StableDiffusionPipeline.from_pretrained(args.model_path)
375 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
376 | pipe = pipe.to("cuda")
377 |
378 | img = load_img(args.image_path).unsqueeze(0).to("cuda")
379 | pipe2 = DDIMPipeline(
380 | vae=pipe.vae,
381 | text_encoder=pipe.text_encoder,
382 | tokenizer=pipe.tokenizer,
383 | unet=pipe.unet,
384 | scheduler=pipe.scheduler,
385 | safety_checker=pipe.safety_checker
386 | )
387 | pipe = pipe2
388 |
389 | prompt = args.prompt
390 | text_embeddings = pipe.get_text_embedding(prompt)
391 | image_latents = pipe.get_image_latents(img, rng_generator=torch.Generator(device=pipe.device).manual_seed(args.seed))
392 | reversed_latents = pipe.forward_diffusion(
393 | latents=image_latents,
394 | text_embeddings=text_embeddings,
395 | guidance_scale=1,
396 | num_inference_steps=50,
397 | )
398 |
399 | with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
400 | with trace(pipe) as tc:
401 | reconstructed_latents = pipe.backward_diffusion(
402 | latents=reversed_latents,
403 | prompt=prompt,
404 | guidance_scale=args.scale,
405 | num_inference_steps=50,
406 | )
407 | img = latents_to_imgs(reconstructed_latents)[0]
408 | img.save(f"{prompt}_{args.seed}_reconstruction.png")
409 | heat_map = tc.compute_global_heat_map()
410 |         heat_map = heat_map.compute_word_heat_map(args.keyword)
411 |         heat_map.plot_overlay(img, out_file=f"{args.keyword}_heatmap_{args.seed}.png", word=None)
412 | img.save(f"{prompt}_{args.seed}.png")
413 |
--------------------------------------------------------------------------------
/daam_visualize_generation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import argparse
5 |
6 | from PIL import Image
7 | from daam import trace, set_seed
8 | from matplotlib import pyplot as plt
9 | from diffusers import StableDiffusionPipeline, DDIMScheduler
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--model_path", type=str)
14 | parser.add_argument("--prompt", type=str)
15 | parser.add_argument("--keyword", type=str)
16 | parser.add_argument("--scale", type=float, default=7.5)
17 | parser.add_argument("--num_images", type=int, default=1)
18 | args = parser.parse_args()
19 |
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 | pipe = StableDiffusionPipeline.from_pretrained(args.model_path).to(device)
22 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
23 | pipe.safety_checker = None
24 | pipe = pipe.to(device)
25 |
26 | prompt = args.prompt
27 | keyword = args.keyword
28 | gen = set_seed(0) # for reproducibility
29 |
30 |
31 | def numpy_to_pil(images):
32 | """
33 | Convert a numpy image or a batch of images to a PIL image.
34 | """
35 | if images.ndim == 3:
36 | images = images[None, ...]
37 | images = (images * 255).round().astype("uint8")
38 | pil_images = [Image.fromarray(image) for image in images]
39 |
40 | return pil_images
41 |
42 |
43 | def daam_visualize():
44 | with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
45 | with trace(pipe) as tc:
46 | for i in range(args.num_images):
47 |                 out = pipe(prompt, num_inference_steps=50, generator=gen, guidance_scale=args.scale)
48 |                 img = out.images[0]  # pipeline output object; .images holds the PIL images
49 | heat_map = tc.compute_global_heat_map()
50 | heat_map = heat_map.compute_word_heat_map(keyword)
51 | heat_map.plot_overlay(img, out_file=f"{keyword}_heatmap_{i}.png", word=None)
52 | img.save(f"{prompt}_{i}.png")
53 |
54 |
55 | if __name__ == '__main__':
56 | daam_visualize()
57 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | import torch
5 | import numpy as np
6 | import os.path
7 | import pickle
8 |
9 | from PIL import Image
10 | from torch.utils.data import Dataset
11 | from torchvision import transforms
12 | from transformers import CLIPTokenizer
13 |
14 | from typing import Any, Callable, Optional, Tuple
15 | from torchvision.datasets.vision import VisionDataset
16 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive
17 |
18 | PIL_INTERPOLATION = {
19 | "linear": Image.Resampling.BILINEAR,
20 | "bilinear": Image.Resampling.BILINEAR,
21 | "bicubic": Image.Resampling.BICUBIC,
22 | "lanczos": Image.Resampling.LANCZOS,
23 | "nearest": Image.Resampling.NEAREST,
24 | }
25 |
26 | imagenet_templates_small = [
27 | "a photo of a {}",
28 | "a rendering of a {}",
29 | "a cropped photo of the {}",
30 | "the photo of a {}",
31 | "a photo of a clean {}",
32 | "a photo of a dirty {}",
33 | "a dark photo of the {}",
34 | "a photo of my {}",
35 | "a photo of the cool {}",
36 | "a close-up photo of a {}",
37 | "a bright photo of the {}",
38 | "a cropped photo of a {}",
39 | "a photo of the {}",
40 | "a good photo of the {}",
41 | "a photo of one {}",
42 | "a close-up photo of the {}",
43 | "a rendition of the {}",
44 | "a photo of the clean {}",
45 | "a rendition of a {}",
46 | "a photo of a nice {}",
47 | "a good photo of a {}",
48 | "a photo of the nice {}",
49 | "a photo of the small {}",
50 | "a photo of the weird {}",
51 | "a photo of the large {}",
52 | "a photo of a cool {}",
53 | "a photo of a small {}",
54 | ]
55 |
56 | imagenet_style_templates_small = [
57 | "a painting in the style of {}",
58 | "a rendering in the style of {}",
59 | "a cropped painting in the style of {}",
60 | "the painting in the style of {}",
61 | "a clean painting in the style of {}",
62 | "a dirty painting in the style of {}",
63 | "a dark painting in the style of {}",
64 | "a picture in the style of {}",
65 | "a cool painting in the style of {}",
66 | "a close-up painting in the style of {}",
67 | "a bright painting in the style of {}",
68 | "a cropped painting in the style of {}",
69 | "a good painting in the style of {}",
70 | "a close-up painting in the style of {}",
71 | "a rendition in the style of {}",
72 | "a nice painting in the style of {}",
73 | "a small painting in the style of {}",
74 | "a weird painting in the style of {}",
75 | "a large painting in the style of {}",
76 | ]
77 |
78 |
79 | class ComposableDataset(Dataset):
80 | def __init__(
81 | self,
82 | data_root,
83 | tokenizer,
84 | size=512,
85 | repeats=100,
86 | interpolation="bicubic",
87 | flip_p=0.5,
88 | set="train",
89 | placeholder_tokens="",
90 | center_crop=False,
91 | num_images_per_class=-1,
92 | class_folder_names="",
93 | learnable_property="",
94 | ):
95 | self.data_root = [x.strip() for x in data_root.split(",")]
96 | self.class_folder_names = [x.strip() for x in class_folder_names.split(",")]
97 |
98 | self.tokenizer = tokenizer
99 | self.size = size
100 | self.placeholder_tokens = [x.strip() for x in placeholder_tokens.split(",")]
101 | self.placeholder_tokens_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
102 |
103 | self.center_crop = center_crop
104 | self.flip_p = flip_p
105 |
106 | # use textual inversion template - assume objects
107 | self.learnable_property = (x.strip() for x in learnable_property.split(","))
108 | self.templates = [imagenet_templates_small if x == "object" else imagenet_style_templates_small
109 | for x in self.learnable_property]
110 | self.use_template = learnable_property != ""
111 |
112 | # combine all folders into a single folder
113 | self.image_paths, self.classes = [], []
114 | total_images = max(len(self.placeholder_tokens) * num_images_per_class,
115 | len(self.class_folder_names) * num_images_per_class)
116 |
117 | images_per_folder = total_images // len(self.class_folder_names)
118 |
119 | for class_id, class_name in enumerate(self.class_folder_names):
120 | folder = os.path.join(self.data_root[class_id], class_name)
121 | folder_image_paths = [os.path.join(folder, file_name) for file_name in os.listdir(folder)]
122 | # reduce the size of images from each category if specified
123 | if num_images_per_class != -1:
124 | train_image_path = folder_image_paths[:images_per_folder]
125 | test_image_path = folder_image_paths[images_per_folder:2 * images_per_folder]
126 | if set == "train":
127 | folder_image_paths = train_image_path
128 | else:
129 | folder_image_paths = test_image_path
130 |
131 | self.image_paths.extend(folder_image_paths)
132 | self.classes.extend([class_id] * len(folder_image_paths))
133 |
134 | # size is the total images from different folders
135 | self.num_images = len(self.image_paths)
136 | self._length = self.num_images
137 |
138 | print("placeholder_tokens: ", self.placeholder_tokens)
139 | print("placeholder_tokens_ids: ", self.placeholder_tokens_ids)
140 | print("the number of images in this dataset: ", self.num_images)
141 | print("the flag for using the template: ", self.use_template)
142 |
143 | if set == "train":
144 | self._length = self.num_images * repeats
145 |
146 | self.interpolation = {
147 | "linear": PIL_INTERPOLATION["linear"],
148 | "bilinear": PIL_INTERPOLATION["bilinear"],
149 | "bicubic": PIL_INTERPOLATION["bicubic"],
150 | "lanczos": PIL_INTERPOLATION["lanczos"],
151 | }[interpolation]
152 |
153 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
154 |
155 | def __len__(self):
156 | return self._length
157 |
158 | def __getitem__(self, i):
159 | idx = i % self.num_images
160 | example = dict()
161 | image = Image.open(self.image_paths[idx])
162 | # image.save(f"{self.placeholder_tokens[self.classes[idx]]}_{i}.png")
163 |
164 | if not image.mode == "RGB":
165 | image = image.convert("RGB")
166 |
167 | if self.use_template:
168 | text = [random.choice(self.templates[self.classes[idx]]).format(x) for x in self.placeholder_tokens]
169 | else:
170 | text = self.placeholder_tokens # use token itself as the caption (unsupervised)
171 |
172 | # encode all classes since we will use all of them to compute composed score
173 | example["input_ids"] = self.tokenizer(
174 | text,
175 | padding="max_length",
176 | truncation=True,
177 | max_length=self.tokenizer.model_max_length,
178 | return_tensors="pt",
179 | ).input_ids
180 |
181 | # default to score-sde preprocessing
182 | img = np.array(image).astype(np.uint8)
183 |
184 | if self.center_crop:
185 | crop = min(img.shape[0], img.shape[1])
186 | h, w, = (
187 | img.shape[0],
188 | img.shape[1],
189 | )
190 | img = img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2]
191 |
192 | image = Image.fromarray(img)
193 | image = image.resize((self.size, self.size), resample=self.interpolation)
194 |
195 | image = self.flip_transform(image)
196 | image = np.array(image).astype(np.uint8)
197 | image = (image / 127.5 - 1.0).astype(np.float32)
198 |
199 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
200 | example["gt_weight_id"] = idx
201 | example["image_path"] = self.image_paths[idx]
202 | example["image_index"] = idx
203 | return example
204 |
205 |
206 | class ClassificationDataset(Dataset):
207 | def __init__(
208 | self,
209 | image_path_map,
210 | learned_weights,
211 | image_path_to_class,
212 | tokenizer,
213 | encoder,
214 | size=512,
215 | repeats=100,
216 | interpolation="bicubic",
217 | flip_p=0.5,
218 | set="train",
219 | placeholder_tokens="",
220 | center_crop=False,
221 | learnable_property="",
222 | ):
223 | self.image_path_map = image_path_map
224 | self.learned_weights = learned_weights
225 | self.image_path_to_class = image_path_to_class
226 |
227 | self.tokenizer = tokenizer
228 | self.encoder = encoder
229 |
230 | self.size = size
231 | self.placeholder_tokens = [x.strip() for x in placeholder_tokens.split(",")]
232 | self.placeholder_tokens_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
233 |
234 | self.center_crop = center_crop
235 | self.flip_p = flip_p
236 |
237 | # use textual inversion template - assume objects
238 | self.learnable_property = (x.strip() for x in learnable_property.split(","))
239 | self.templates = [imagenet_templates_small if x == "object" else imagenet_style_templates_small
240 | for x in self.learnable_property]
241 | self.use_template = learnable_property != ""
242 |
243 | # size is the total images from different folders
244 | self.num_images = len(self.image_path_map)
245 | self._length = self.num_images
246 |
247 | print("placeholder_tokens: ", self.placeholder_tokens)
248 | print("placeholder_tokens_ids: ", self.placeholder_tokens_ids)
249 | print("the number of images in this dataset: ", self.num_images)
250 | print("the flag for using the template: ", self.use_template)
251 |
252 | if set == "train":
253 | self._length = self.num_images * repeats
254 |
255 | self.interpolation = {
256 | "linear": PIL_INTERPOLATION["linear"],
257 | "bilinear": PIL_INTERPOLATION["bilinear"],
258 | "bicubic": PIL_INTERPOLATION["bicubic"],
259 | "lanczos": PIL_INTERPOLATION["lanczos"],
260 | }[interpolation]
261 |
262 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
263 |
264 | def __len__(self):
265 | return self._length
266 |
267 | def __getitem__(self, i):
268 | idx = i % self.num_images
269 | image = Image.open(self.image_path_map[idx])
270 |
271 | if not image.mode == "RGB":
272 | image = image.convert("RGB")
273 |
274 | if self.use_template:
275 |             text = [random.choice(self.templates[self.image_path_to_class[idx]]).format(x) for x in self.placeholder_tokens]
276 | else:
277 | text = self.placeholder_tokens # use token itself as the caption (unsupervised)
278 |
279 | # default to score-sde preprocessing
280 | img = np.array(image).astype(np.uint8)
281 |
282 | if self.center_crop:
283 | crop = min(img.shape[0], img.shape[1])
284 | h, w, = (
285 | img.shape[0],
286 | img.shape[1],
287 | )
288 | img = img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2]
289 |
290 | image = Image.fromarray(img)
291 | image = image.resize((self.size, self.size), resample=self.interpolation)
292 |
293 | image = self.flip_transform(image)
294 | image = np.array(image).astype(np.uint8)
295 | image = (image / 127.5 - 1.0).astype(np.float32)
296 |
297 | # encode all classes since we will use all of them to compute composed score
298 | input_ids = self.tokenizer(
299 | text,
300 | padding="max_length",
301 | truncation=True,
302 | max_length=self.tokenizer.model_max_length,
303 | return_tensors="pt",
304 | ).input_ids
305 |
306 | example = {}
307 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
308 | example["embeddings"] = self.encoder(input_ids)[0]
309 | example["weights"] = self.learned_weights[idx]
310 | example["class"] = self.image_path_to_class[idx]
311 | return example
312 |
313 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | modified based upon https://github.com/rinongal/textual_inversion/blob/main/evaluation/clip_eval.py
3 | """
4 | import argparse
5 | import os
6 |
7 | import clip
8 | import torch
9 | from torchvision import transforms
10 | from resnet import resnet50
11 | from PIL import Image
12 |
13 | from typing import List
14 | from torchmetrics.functional import kl_divergence
15 |
16 |
17 | class CLIPEvalutor:
18 | def __init__(self, target_classes_names: List, device, clip_model='ViT-B/32'):
19 | self.device = device
20 | self.model, self.clip_preprocess = clip.load(clip_model, device=self.device)
21 | self.texts = [f"a photo of {class_name}" for class_name in target_classes_names]
22 |
23 | def tokenize(self, strings: list):
24 | return clip.tokenize(strings).to(self.device)
25 |
26 | @torch.no_grad()
27 |     def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
28 | return self.model.encode_text(tokens)
29 |
30 | @torch.no_grad()
31 | def get_text_features(self, text, norm: bool = True) -> torch.Tensor:
32 | tokens = clip.tokenize(text).to(self.device)
33 | text_features = self.encode_text(tokens)
34 | if norm:
35 | text_features /= text_features.norm(dim=-1, keepdim=True)
36 | return text_features
37 |
38 | @torch.no_grad()
39 | def encode_images(self, image) -> torch.Tensor:
40 | images = self.clip_preprocess(image).to(self.device).unsqueeze(dim=0)
41 | return self.model.encode_image(images)
42 |
43 | def evaluate(self, img_folder, threshold=0.8):
44 | images_path = [os.path.join(img_folder, filename) for filename in os.listdir(img_folder)]
45 | with torch.no_grad():
46 | image_features = torch.cat([self.encode_images(Image.open(path)) for path in images_path], dim=0)
47 | image_features /= image_features.norm(dim=-1, keepdim=True)
48 | text_features = self.get_text_features(self.texts, norm=True)
49 | similarity = image_features @ text_features.T
50 |
51 | max_vals, indices = torch.max(similarity, dim=1)
52 | # bincount only computes frequency for non-negative values
53 | counts = torch.bincount(indices, minlength=text_features.shape[0])
54 | coverages = []
55 | # when computing the coverage, we don't threshold the values
56 | for i, caption in enumerate(self.texts):
57 | # print(f"class: {caption} | "
58 | # f"coverage: {counts[i] / image_features.shape[0] * 100}%")
59 | coverages.append(counts[i] / image_features.shape[0])
60 |         # if the similarity of the predicted class is below the threshold, count the sample as misclassified
61 | num_misclassified_examples = torch.sum(max_vals < threshold)
62 | # p = target distribution, q = modeled distribution
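        # KL(p || q) against a uniform target p measures how evenly the samples cover the target classes (lower is better)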
63 | p = torch.tensor([1 / len(coverages)] * len(coverages))
64 | q = torch.tensor(coverages)
65 | acc = (sum(counts) - num_misclassified_examples) / image_features.shape[0]
66 | kl_entropy = kl_divergence(p[None], q[None])
67 | print(f"{img_folder}, Avg Acc: {100 * acc.item():.2f}, KL entropy: {kl_entropy.item():.4f}")
68 |
69 |
70 | class ResNetEvaluator:
71 | def __init__(self, target_classes_names: List, device):
72 | # replace batch norm with identity function
73 | self.device = device
74 | self.model = resnet50(pretrained=True, progress=True).to(self.device)
75 | # disable batch norm
76 | self.model.eval()
77 | self.transform = transforms.Compose([
78 | transforms.CenterCrop(224),
79 | transforms.ToTensor(),
80 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
81 | std=[0.229, 0.224, 0.225]),
82 | ])
83 |
84 | with open('imagenet_classes.txt') as f:
85 | self.labels = [line.strip() for line in f.readlines()]
86 |
87 |         # if class names are given, only the logits of the matching ImageNet classes are used for evaluation
88 | if target_classes_names:
89 | self.target_labels = []
90 | self.target_label_names = []
91 | for i, label_name in enumerate(self.labels):
92 | for target_class in target_classes_names:
93 | if target_class.lower() in label_name.lower():
94 | self.target_labels.append(i)
95 | self.target_label_names.append(label_name)
96 |
97 | assert len(self.target_labels) == len(target_classes_names), \
98 |                 "the number of found labels does not match the number of given class names"
99 | print(self.target_labels)
100 | print(self.target_label_names)
101 | self.target_labels = sorted(self.target_labels)
102 | else:
103 | self.target_labels = list(range(1000))
104 |
105 | def evaluate(self, img_folder, threshold=0.8):
106 | images_path = [os.path.join(img_folder, filename) for filename in os.listdir(img_folder)]
107 | images = torch.stack([self.transform(Image.open(path)) for path in images_path], dim=0).to(self.device)
108 | # predictions
109 | with torch.no_grad():
110 | pred = self.model(images)[:, self.target_labels]
111 | max_vals, indices = torch.max(pred, dim=1)
112 | # bincount only computes frequency for non-negative values
113 | counts = torch.bincount(indices, minlength=len(self.target_labels))
114 | coverages = []
115 | # when computing the coverage, we don't threshold the values
116 | for i, index in enumerate(self.target_labels):
117 | # print(f"class: {self.labels[index]} | "
118 | # f"coverage: {counts[i] / pred.shape[0] * 100}%")
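            # the small epsilon keeps classes with zero predictions from producing an infinite KL divergence below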
119 | coverages.append((counts[i] + 1e-8) / pred.shape[0])
120 |
121 |             # if the logit of the predicted class is below the threshold, count the sample as misclassified
122 | num_misclassified_examples = torch.sum(max_vals < threshold)
123 | # p = target distribution, q = modeled distribution
124 | p = torch.tensor([1 / len(coverages)] * len(coverages))
125 | q = torch.tensor(coverages)
126 | acc = (sum(counts) - num_misclassified_examples) / pred.shape[0]
127 | kl_entropy = kl_divergence(p[None], q[None])
128 | print(f"{img_folder}, Avg Acc: {100 * acc.item():.2f}, KL entropy: {kl_entropy.item():.4f}")
129 |
130 |
131 | if __name__ == '__main__':
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument("--model_path", type=str)
134 | parser.add_argument("--seed", type=int, default=0)
135 | parser.add_argument("--scale", type=float, default=7.5)
136 | parser.add_argument("--steps", type=int, default=50)
137 | parser.add_argument("--samples", type=int, default=64)
138 | parser.add_argument("--model", type=str, choices=["textual_inversion", "ours"])
139 | parser.add_argument("--evaluation_metric", choices=["clip", 'resnet'])
140 | parser.add_argument("--class_names", type=str, nargs="+")
141 | parser.add_argument("--logit_threshold", type=float)
142 | args = parser.parse_args()
143 |
144 | image_folder = os.path.join(args.model_path, "samples")
145 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
146 | if args.evaluation_metric == "clip":
147 | # build a ground truth captions for clip score
148 |         evaluator = CLIPEvaluator(args.class_names, device=device, clip_model='ViT-B/32')
149 |         # generated images are read from the samples folder produced by inference.py
150 | save_img_folder = os.path.join(args.model_path, "samples")
151 | evaluator.evaluate(image_folder, threshold=args.logit_threshold)
152 | else:
153 | evaluator = ResNetEvaluator(args.class_names, device=device)
154 | evaluator.evaluate(image_folder, threshold=args.logit_threshold)
155 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 |
5 | from diffusers import StableDiffusionPipeline, DDIMScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--model_path", type=str)
11 | parser.add_argument("--prompts", type=str, nargs="+")
12 | parser.add_argument("--seed", type=int, default=0)
13 | parser.add_argument("--weights", type=str, default="7.5")
14 | parser.add_argument("--num_images", type=int, default=1)
15 | parser.add_argument("--bsz", type=int, default=1)
16 | parser.add_argument("--scale", type=float, default=7.5)
17 | parser.add_argument("--folder_name", type=str, default="samples")
18 | args = parser.parse_args()
19 |
20 | model_id = args.model_path
21 | pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
22 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
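    # disable the NSFW safety checker so generated samples are returned unfiltered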
23 | pipe.safety_checker = None
24 |
25 | folder = os.path.join(args.model_path, args.folder_name)
26 | os.makedirs(folder, exist_ok=True)
27 |
28 | generator = torch.Generator("cuda").manual_seed(args.seed)
29 | prompts = args.prompts
30 | weights = args.weights
31 |
32 | if prompts:
33 | batch_size = args.bsz
34 | num_batches = args.num_images // batch_size
35 | for prompt in prompts:
36 | for i in range(num_batches):
37 | image_list = pipe(prompt, num_inference_steps=50, guidance_scale=args.scale,
38 | generator=generator, num_images_per_prompt=batch_size)
39 | images = image_list.images
40 | for j, img in enumerate(images):
41 | img.save(os.path.join(folder, f"{prompt}_{weights}_{args.seed}_{i * batch_size + j}.png"))
42 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import math
4 | import os
5 | import random
6 | import json
7 |
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional
14 | import torch.nn.functional as F
15 | import torch.utils.checkpoint
16 | from torch.utils.data import Dataset
17 |
18 | import PIL
19 | from accelerate import Accelerator
20 | from accelerate.logging import get_logger
21 | from accelerate.utils import set_seed
22 | from diffusers.optimization import get_scheduler
23 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
24 | from diffusers.utils.import_utils import is_xformers_available
25 | from huggingface_hub import HfFolder, Repository, whoami
26 |
27 | from diffusers import (
28 | AutoencoderKL,
29 | DDPMScheduler,
30 | DiffusionPipeline,
31 | DDIMScheduler,
32 | StableDiffusionPipeline,
33 | UNet2DConditionModel,
34 | )
35 |
36 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released
37 | from packaging import version
38 | from PIL import Image
39 | from torchvision import transforms
40 | from tqdm.auto import tqdm
41 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
42 |
43 | from datasets import ComposableDataset
44 |
45 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
46 | PIL_INTERPOLATION = {
47 | "linear": PIL.Image.Resampling.BILINEAR,
48 | "bilinear": PIL.Image.Resampling.BILINEAR,
49 | "bicubic": PIL.Image.Resampling.BICUBIC,
50 | "lanczos": PIL.Image.Resampling.LANCZOS,
51 | "nearest": PIL.Image.Resampling.NEAREST,
52 | }
53 | else:
54 | PIL_INTERPOLATION = {
55 | "linear": PIL.Image.LINEAR,
56 | "bilinear": PIL.Image.BILINEAR,
57 | "bicubic": PIL.Image.BICUBIC,
58 | "lanczos": PIL.Image.LANCZOS,
59 | "nearest": PIL.Image.NEAREST,
60 | }
61 | # ------------------------------------------------------------------------------
62 |
63 |
64 | logger = get_logger(__name__)
65 |
66 |
67 | def save_progress(text_encoder, placeholder_token_id, accelerator, args):
68 | logger.info("Saving embeddings")
69 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
70 | learned_embeds_dict = {args.placeholder_tokens: learned_embeds.detach().cpu()}
71 | if args.test:
72 | embed_path = os.path.join(args.output_dir, "test_learned_embeds.bin")
73 | else:
74 | embed_path = os.path.join(args.output_dir, "learned_embeds.bin")
75 | torch.save(learned_embeds_dict, embed_path)
76 |
77 |
78 | def save_weights(weights, args):
79 |     logger.info("Saving weights")
80 | learned_weights_dict = {"weights": weights.detach().cpu()}
81 | if args.test:
82 | weight_path = os.path.join(args.output_dir, "test_weights.bin")
83 | else:
84 | weight_path = os.path.join(args.output_dir, "weights.bin")
85 | torch.save(learned_weights_dict, weight_path)
86 |
87 |
88 | def parse_args():
89 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
90 | parser.add_argument(
91 | "--save_steps",
92 | type=int,
93 | default=500,
94 | help="Save learned_embeds.bin every X updates steps.",
95 | )
96 | parser.add_argument(
97 | "--pretrained_model_name_or_path",
98 | type=str,
99 | default=None,
100 | required=True,
101 | help="Path to pretrained model or model identifier from huggingface.co/models.",
102 | )
103 | parser.add_argument(
104 | "--tokenizer_name",
105 | type=str,
106 | default=None,
107 | help="Pretrained tokenizer name or path if not the same as model_name",
108 | )
109 | parser.add_argument(
110 | "--train_data_dir", type=str, required=True,
111 | help="A list of folders containing the training data for each token provided."
112 | )
113 | parser.add_argument(
114 | "--placeholder_tokens",
115 | type=str,
116 | required=True,
117 | help="A list of tokens to use as placeholders for all the concepts, separated by comma",
118 | )
119 | parser.add_argument(
120 | "--initializer_tokens", type=str, default="",
121 | help="A list of tokens to use as initializer words, separated by comma"
122 | )
123 | parser.add_argument("--learnable_property", type=str, default="",
124 |                         help="a list of properties (e.g. object or style), one per token, separated by commas")
125 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
126 | parser.add_argument(
127 | "--output_dir",
128 | type=str,
129 | default="checkpoints",
130 | help="The output directory where the model predictions and checkpoints will be written.",
131 | )
132 | parser.add_argument(
133 | "--resume_dir",
134 | type=str,
135 | default="",
136 |         help="The directory of a previous run from which training should be resumed.",
137 | )
138 | parser.add_argument("--softmax_weights", action="store_true", default=False)
139 | parser.add_argument("--reuse_weights", action="store_true", default=False)
140 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
141 | parser.add_argument(
142 | "--resolution",
143 | type=int,
144 | default=512,
145 | help=(
146 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
147 | " resolution"
148 | ),
149 | )
150 | parser.add_argument(
151 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
152 | )
153 | parser.add_argument(
154 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
155 | )
156 | parser.add_argument("--num_train_epochs", type=int, default=100)
157 | parser.add_argument(
158 | "--max_train_steps",
159 | type=int,
160 | default=None,
161 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
162 | )
163 | parser.add_argument(
164 | "--gradient_accumulation_steps",
165 | type=int,
166 | default=1,
167 | help="Number of updates steps to accumulate before performing a backward/update pass.",
168 | )
169 | parser.add_argument(
170 | "--gradient_checkpointing",
171 | action="store_true",
172 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
173 | )
174 | parser.add_argument(
175 | "--learning_rate",
176 | type=float,
177 | default=1e-4,
178 | help="Initial learning rate (after the potential warmup period) to use.",
179 | )
180 | parser.add_argument(
181 | "--scale_lr",
182 | action="store_true",
183 | default=True,
184 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
185 | )
186 | parser.add_argument(
187 | "--lr_scheduler",
188 | type=str,
189 | default="constant",
190 | help=(
191 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
192 | ' "constant", "constant_with_warmup"]'
193 | ),
194 | )
195 | parser.add_argument(
196 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
197 | )
198 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
199 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
200 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
201 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
202 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
203 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
204 | parser.add_argument(
205 | "--hub_model_id",
206 | type=str,
207 | default=None,
208 | help="The name of the repository to keep in sync with the local `output_dir`.",
209 | )
210 | parser.add_argument(
211 | "--logging_dir",
212 | type=str,
213 | default="logs",
214 | help=(
215 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
216 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
217 | ),
218 | )
219 | parser.add_argument(
220 | "--mixed_precision",
221 | type=str,
222 | default="no",
223 | choices=["no", "fp16", "bf16"],
224 | help=(
225 |             "Whether to use mixed precision. Choose"
226 |             " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
227 |             " and an Nvidia Ampere GPU."
228 | ),
229 | )
230 | parser.add_argument(
231 | "--allow_tf32",
232 | action="store_true",
233 | help=(
234 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
235 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
236 | ),
237 | )
238 | parser.add_argument(
239 | "--report_to",
240 | type=str,
241 | default="tensorboard",
242 | help=(
243 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
244 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
245 | ),
246 | )
247 | parser.add_argument(
248 | "--resume_from_checkpoint",
249 | type=str,
250 | default=None,
251 | help=(
252 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
253 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
254 | ),
255 | )
256 | parser.add_argument(
257 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
258 | )
259 | parser.add_argument(
260 | "--use_composed_score", action="store_true", default=False,
261 | help="whether to use composed score for textual inversion."
262 | )
263 | parser.add_argument(
264 | "--use_orthogonal_loss", action="store_true", default=False,
265 | help="should be enabled to get a better performance when using composed scores to invert text."
266 | )
267 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
268 | parser.add_argument(
269 | "--checkpointing_steps",
270 | type=int,
271 | default=500,
272 | help=(
273 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
274 | " training using `--resume_from_checkpoint`."
275 | ),
276 | )
277 |
278 | parser.add_argument("--data", type=str, default="imagenet")
279 | parser.add_argument("--class_folder_names", type=str,
280 |                         help="a list of imagenet data folders, one for each class, separated by commas")
281 |
282 | parser.add_argument("--add_weight_per_score", action="store_true", default=False)
283 | parser.add_argument("--freeze_weights", action="store_true", default=False)
284 | parser.add_argument("--init_weight", type=float, default=1)
285 | parser.add_argument("--use_conj_score", action="store_true", default=False)
286 |
287 | parser.add_argument("--orthogonal_coeff", type=float, default=0.1)
288 | parser.add_argument("--squared_orthogonal_loss", action="store_true", default=False)
289 | parser.add_argument("--mse_coeff", type=float, default=1)
290 | parser.add_argument("--num_images_per_class", type=int, default=-1, help="-1 means all images considered")
291 |
292 | parser.add_argument("--weighted_sampling", action="store_true", default=False)
293 | parser.add_argument("--flip_weights", action="store_true", default=False)
294 |
295 | parser.add_argument("--text_loss", action="store_true", default=False)
296 | parser.add_argument("--text_angle_loss", action="store_true", default=False)
297 | parser.add_argument("--text_repulsion_loss", action="store_true", default=False)
298 | parser.add_argument("--text_repulsion_similarity_loss", action="store_true", default=False)
299 | parser.add_argument("--text_repulsion_coeff", type=float, default=0)
300 |
301 | parser.add_argument("--euclidean_dist_loss", action="store_true", default=False)
302 | parser.add_argument("--euclidean_dist_coeff", type=float, default=0)
303 |
304 | parser.add_argument("--use_similarity", action="store_true", default=False,
305 | help="Dot product between scores as the orthogonal loss")
306 | parser.add_argument("--use_euclidean_mhe", action="store_true", default=False,
307 | help="Minimum Hyperspherical Energy as the orthogonal loss")
308 | parser.add_argument("--log_mhe", action="store_true", default=False)
309 | parser.add_argument("--use_acos_mhe", action="store_true", default=False)
310 | parser.add_argument("--normalize_score", action="store_true", default=False)
311 | parser.add_argument("--use_weighted_score", action="store_true", default=False)
312 |
313 | parser.add_argument("--use_l2_norm_regularization", action="store_true", default=False)
314 | parser.add_argument("--l2_norm_coeff", type=float, default=0)
315 |
316 | parser.add_argument("--normalize_word", action="store_true", default=False)
317 | parser.add_argument("--num_iters_per_image", type=int, default=50)
318 | parser.add_argument("--hsic_loss", action="store_true", default=False)
319 | parser.add_argument("--test", action="store_true", default=False,
320 |                         help="enable this to only optimize the per-concept weights with an existing trained model.")
321 |
322 | parser.add_argument(
323 | "--validation_step",
324 | type=int,
325 | default=100,
326 | help=(
327 |             "Run validation every X steps. Validation consists of generating one sample"
328 |             " image per placeholder token with the current embeddings"
329 |             " and logging the images."
330 | ),
331 | )
332 |
333 | parser.add_argument(
334 | "--revision",
335 | type=str,
336 | default=None,
337 | required=False,
338 | help="Revision of pretrained model identifier from huggingface.co/models.",
339 | )
340 |
341 |
342 | args = parser.parse_args()
343 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
344 | if env_local_rank != -1 and env_local_rank != args.local_rank:
345 | args.local_rank = env_local_rank
346 |
347 | if args.train_data_dir is None:
348 | raise ValueError("You must specify a train data directory.")
349 |
350 | return args
351 |
352 |
353 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
354 | if token is None:
355 | token = HfFolder.get_token()
356 | if organization is None:
357 | username = whoami(token)["name"]
358 | return f"{username}/{model_id}"
359 | else:
360 | return f"{organization}/{model_id}"
361 |
362 |
363 | def freeze_params(params):
364 | for param in params:
365 | param.requires_grad = False
366 |
367 |
368 | def main():
369 | args = parse_args()
370 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
371 |
372 | accelerator = Accelerator(
373 | gradient_accumulation_steps=args.gradient_accumulation_steps,
374 | mixed_precision=args.mixed_precision,
375 | log_with="tensorboard",
376 | logging_dir=logging_dir,
377 | )
378 |
379 | # If passed along, set the training seed now.
380 | if args.seed is not None:
381 | set_seed(args.seed)
382 |
383 | # Handle the repository creation
384 | if accelerator.is_main_process:
385 | if args.push_to_hub:
386 | if args.hub_model_id is None:
387 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
388 | else:
389 | repo_name = args.hub_model_id
390 | repo = Repository(args.output_dir, clone_from=repo_name)
391 |
392 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
393 | if "step_*" not in gitignore:
394 | gitignore.write("step_*\n")
395 | if "epoch_*" not in gitignore:
396 | gitignore.write("epoch_*\n")
397 | elif args.output_dir is not None:
398 | os.makedirs(args.output_dir, exist_ok=True)
399 |
400 | if args.resume_from_checkpoint:
401 | args.pretrained_model_name_or_path = args.resume_dir
402 | print(f"resume everything from {args.pretrained_model_name_or_path}")
403 |
404 |     # Load the tokenizer and add the placeholder tokens as additional tokens
405 | if args.tokenizer_name:
406 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
407 | elif args.pretrained_model_name_or_path:
408 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
409 |
410 | # Add the placeholder token in tokenizer
411 | placeholder_tokens = [x.strip() for x in args.placeholder_tokens.split(",")]
412 | num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
413 |
414 | if num_added_tokens != 0 and num_added_tokens != len(placeholder_tokens):
415 | raise ValueError(
416 | f"The tokenizer already contains at least one of the tokens in {placeholder_tokens}. "
417 |             "Please pass different `placeholder_tokens` that are not already in the tokenizer."
418 | )
419 |
420 | # Convert the initializer_token, placeholder_token to ids
421 | if args.initializer_tokens != "":
422 | initializer_tokens = [x.strip() for x in args.initializer_tokens.split(",")]
423 | else:
424 | initializer_tokens = []
425 |
426 | if len(initializer_tokens) == 0:
427 | if args.resume_from_checkpoint:
428 | logger.info("* Resume the embeddings of placeholder tokens *")
429 | print("* Resume the embeddings of placeholder tokens *")
430 | else:
431 | logger.info("* Initialize the newly added placeholder token with the random embeddings *")
432 | print("* Initialize the newly added placeholder token with the random embeddings *")
433 | token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
434 | else:
435 | logger.info("* Initialize the newly added placeholder token with the embeddings of the initializer token *")
436 | print("* Initialize the newly added placeholder token with the embeddings of the initializer token *")
437 | token_ids = tokenizer.encode(initializer_tokens, add_special_tokens=False)
438 | # Check if initializer_token is a single token or a sequence of tokens
439 | if len(token_ids) > len(initializer_tokens):
440 |             raise ValueError("Each initializer token must map to a single tokenizer token.")
441 |
442 | initializer_token_ids = token_ids
443 | placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
444 |
445 | # Load models and create wrapper for stable diffusion
446 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
447 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
448 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
449 |
450 | # Resize the token embeddings as we are adding new special tokens to the tokenizer
451 | text_encoder.resize_token_embeddings(len(tokenizer))
452 |
453 | # Initialise the newly added placeholder token with the embeddings of the initializer token
454 | token_embeds = text_encoder.get_input_embeddings().weight.data
455 | token_embeds[placeholder_token_ids] = token_embeds[initializer_token_ids]
456 | if args.normalize_word:
457 | token_embeds[placeholder_token_ids] = F.normalize(token_embeds[placeholder_token_ids], dim=1, p=2)
458 |
459 | # Freeze vae and unet
460 | freeze_params(vae.parameters())
461 | freeze_params(unet.parameters())
462 | # Freeze all parameters except for the token embeddings in text encoder
463 | params_to_freeze = itertools.chain(
464 | text_encoder.text_model.encoder.parameters(),
465 | text_encoder.text_model.final_layer_norm.parameters(),
466 | text_encoder.text_model.embeddings.position_embedding.parameters(),
467 | text_encoder.get_input_embeddings().parameters() if args.test else []
468 | )
469 | freeze_params(params_to_freeze)
470 |
471 | if args.gradient_checkpointing:
472 | # Keep unet in train mode if we are using gradient checkpointing to save memory.
473 | # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
474 | unet.train()
475 | text_encoder.gradient_checkpointing_enable()
476 | unet.enable_gradient_checkpointing()
477 |
478 | if args.enable_xformers_memory_efficient_attention:
479 | if is_xformers_available():
480 | unet.enable_xformers_memory_efficient_attention()
481 | else:
482 | raise ValueError("xformers is not available. Make sure it is installed correctly")
483 |
484 | # Enable TF32 for faster training on Ampere GPUs,
485 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
486 | if args.allow_tf32:
487 | torch.backends.cuda.matmul.allow_tf32 = True
488 |
489 | if args.scale_lr:
490 | args.learning_rate = (
491 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
492 | )
493 |
494 | train_dataset = ComposableDataset(
495 | data_root=args.train_data_dir,
496 | tokenizer=tokenizer,
497 | size=args.resolution,
498 | repeats=args.repeats,
499 | center_crop=args.center_crop,
500 | placeholder_tokens=args.placeholder_tokens,
501 | num_images_per_class=args.num_images_per_class,
502 | class_folder_names=args.class_folder_names,
503 | learnable_property=args.learnable_property,
504 | set="train" if not args.test else "val",
505 | )
506 |
507 | if args.add_weight_per_score:
508 | # Add a learnable weight for each token
509 | if args.resume_from_checkpoint and args.reuse_weights:
510 | weight_path = os.path.join(args.resume_dir, "weights.bin")
511 | concept_weights = torch.load(weight_path)["weights"]
512 | concept_weights.requires_grad = not args.freeze_weights
513 | concept_weights = torch.nn.Parameter(concept_weights, requires_grad=not args.freeze_weights)
514 | print('reusing the weights...')
515 | else:
516 | num_tokens = len(placeholder_token_ids)
517 |             # create a weight tensor of shape NxMx1x1x1, where N is the number of images and M is the number of concepts
518 | concept_weights = torch.tensor([args.init_weight] * num_tokens).reshape(1, -1, 1, 1, 1).float()
519 | if args.softmax_weights:
520 | concept_weights = F.softmax(concept_weights, dim=1)
521 | concept_weights = concept_weights.repeat(train_dataset.num_images, 1, 1, 1, 1)
522 | concept_weights.requires_grad = not args.freeze_weights
523 | concept_weights = torch.nn.Parameter(concept_weights, requires_grad=not args.freeze_weights)
524 |
525 | # Initialize the optimizer
526 | optimizer = torch.optim.AdamW(
527 | itertools.chain(
528 | text_encoder.get_input_embeddings().parameters() if not args.test else [],
529 | [concept_weights] if args.add_weight_per_score else []
530 | ),
531 | lr=args.learning_rate,
532 | betas=(args.adam_beta1, args.adam_beta2),
533 | weight_decay=args.adam_weight_decay,
534 | eps=args.adam_epsilon,
535 | )
536 |
537 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
538 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
539 |
540 | # Scheduler and math around the number of training steps.
541 | overrode_max_train_steps = False
542 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
543 | if args.max_train_steps is None:
544 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
545 | overrode_max_train_steps = True
546 |
547 | lr_scheduler = get_scheduler(
548 | args.lr_scheduler,
549 | optimizer=optimizer,
550 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
551 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
552 | )
553 |
554 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
555 | text_encoder, optimizer, train_dataloader, lr_scheduler
556 | )
557 |
558 | weight_dtype = torch.float32
559 | if accelerator.mixed_precision == "fp16":
560 | weight_dtype = torch.float16
561 | elif accelerator.mixed_precision == "bf16":
562 | weight_dtype = torch.bfloat16
563 |
564 | # Move vae and unet to device
565 | vae.to(accelerator.device, dtype=weight_dtype)
566 | unet.to(accelerator.device, dtype=weight_dtype)
567 |
568 | args.max_train_steps = train_dataset.num_images * args.num_iters_per_image
569 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
570 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
571 | if overrode_max_train_steps:
572 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
573 | # Afterwards we recalculate our number of training epochs
574 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
575 |
576 | # We need to initialize the trackers we use, and also store our configuration.
577 |     # The trackers are initialized automatically on the main process.
578 | if accelerator.is_main_process:
579 | accelerator.init_trackers("checkpoints", config=vars(args))
580 |
581 | # Train!
582 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
583 | print(f'total_batch_size: {total_batch_size}')
584 |
585 | logger.info("***** Running training *****")
586 | logger.info(f" Num examples = {len(train_dataset)}")
587 | logger.info(f" Num Epochs = {args.num_train_epochs}")
588 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
589 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
590 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
591 | logger.info(f" Total optimization steps = {args.max_train_steps}")
592 | global_step = 0
593 | first_epoch = 0
594 |
595 | # Potentially load in the weights and states from a previous save
596 | if args.resume_from_checkpoint:
597 | if args.resume_from_checkpoint != "latest":
598 | path = os.path.basename(args.resume_from_checkpoint)
599 | else:
600 | # Get the most recent checkpoint
601 | dirs = os.listdir(args.resume_dir)
602 | dirs = [d for d in dirs if d.startswith("checkpoint")]
603 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
604 | path = dirs[-1] if len(dirs) > 0 else None
605 |
606 | if path is None:
607 | accelerator.print(
608 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
609 | )
610 | args.resume_from_checkpoint = None
611 | args.max_train_steps = train_dataset.num_images * args.num_iters_per_image
612 | else:
613 | if not args.test:
614 | accelerator.print(f"Resuming from checkpoint {path}")
615 | accelerator.load_state(os.path.join(args.resume_dir, path))
616 |
617 | global_step = int(path.split("-")[1])
618 |
619 | resume_global_step = global_step * args.gradient_accumulation_steps
620 | first_epoch = global_step // num_update_steps_per_epoch
621 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
622 | # update the number of iterations
623 | args.max_train_steps = global_step + train_dataset.num_images * args.num_iters_per_image
624 | # Afterwards we recalculate our number of training epochs
625 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
626 |
627 | # Only show the progress bar once on each machine.
628 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
629 | progress_bar.set_description("Steps")
630 |
631 | # keep original embeddings as reference
632 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
633 |
634 | # iterate through the data and save dataset info
635 | dataset_info = {}
636 | for step, batch in enumerate(train_dataloader):
637 | image_path = batch["image_path"]
638 | image_idx = batch["image_index"]
639 | for i in range(len(image_path)):
640 | dataset_info[image_idx[i].item()] = image_path[i]
641 |
642 | if args.test:
643 | path = os.path.join(args.output_dir, "test_dataset_info.json")
644 | else:
645 | path = os.path.join(args.output_dir, "dataset_info.json")
646 |
647 | with open(path, "w") as f:
648 | json.dump(dataset_info, f)
649 |
650 | for epoch in range(first_epoch, args.num_train_epochs):
651 | text_encoder.train()
652 | for step, batch in enumerate(train_dataloader):
653 |             # Skip steps until we reach the resumed step
654 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
655 | if step % args.gradient_accumulation_steps == 0:
656 | progress_bar.update(1)
657 | continue
658 |
659 | mse_loss, orthogonal_loss, repulsion_loss, word_norm_loss, euclidean_dist_loss = 0, 0, 0, 0, 0
660 | with accelerator.accumulate(text_encoder):
661 | # image shape: Bx3xHxW
662 | # input_ids shape: BxMxD where M is the number of classes, D is the text dims
663 | pixel_value, input_ids = batch["pixel_values"], batch["input_ids"]
664 | weight_id = batch["gt_weight_id"]
665 | # split input ids into a list of BxD
666 | input_ids_list = [y.squeeze(dim=1) for y in input_ids.chunk(chunks=input_ids.shape[1], dim=1)]
667 |
668 | if args.use_composed_score:
669 | noise, uncond_noise_pred, noise_preds = None, None, []
670 | for input_ids in input_ids_list:
671 | # Convert images to latent space
672 | latents = vae.encode(pixel_value).latent_dist.sample().detach()
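                    # 0.18215 is the latent scaling factor of the Stable Diffusion VAE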
673 | latents = latents * 0.18215
674 |
675 | # Sample noise that we'll add to the latents
676 | if noise is None:
677 | noise = torch.randn(latents.shape).to(latents.device)
678 | bsz = latents.shape[0]
679 |
680 | # Sample a random timestep for each image
681 | if args.weighted_sampling:
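                        # bias timestep sampling toward later (noisier) timesteps; --flip_weights favors earlier timesteps instead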
682 | weights = torch.arange(1, noise_scheduler.config.num_train_timesteps + 1).float()
683 | if args.flip_weights:
684 | weights = weights.flip(dims=(0,))
685 | timesteps = torch.multinomial(weights, bsz).to(latents.device)
686 | else:
687 | timesteps = torch.randint(
688 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
689 | ).long()
690 |
691 | # Add noise to the latents according to the noise magnitude at each timestep
692 | # (this is the forward diffusion process)
693 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
694 |
695 | # Get the text embedding for conditioning
696 | encoder_hidden_states = text_encoder(input_ids)[0]
697 | # Predict the noise residual
698 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
699 | noise_preds.append(noise_pred)
700 |
701 | if uncond_noise_pred is None and args.use_conj_score:
702 | # precompute the unconditional text hidden states
703 | uncond_text_ids = tokenizer(
704 | "",
705 | padding="max_length",
706 | truncation=True,
707 | max_length=tokenizer.model_max_length,
708 | return_tensors="pt",
709 | ).input_ids.to(latents.device)
710 | B = noisy_latents.shape[0]
711 | uncond_encoder_hidden_states = text_encoder(uncond_text_ids)[0].repeat(B, 1, 1)
712 | uncond_noise_pred = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
713 |
714 | noise_preds_stack = torch.stack(noise_preds, dim=1) # BxMx4x64x64
715 | elif args.use_conj_score:
716 | # latents
717 | latents = vae.encode(pixel_value).latent_dist.sample().detach()
718 | latents = latents * 0.18215
719 | bsz = latents.shape[0]
720 | noise = torch.randn_like(latents)
721 |
722 | timesteps = torch.randint(
723 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
724 | ).long()
725 | # Add noise to the latents according to the noise magnitude at each timestep
726 | # (this is the forward diffusion process)
727 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
728 |
729 | weights = concept_weights[weight_id]
730 |
731 | cond_scores = []
732 | for input_ids in input_ids_list:
733 | encoder_hidden_state = text_encoder(input_ids)[0].to(dtype=weight_dtype)
734 | cond_scores.append(unet(noisy_latents, timesteps, encoder_hidden_state).sample)
735 | cond_scores = torch.stack(cond_scores, dim=1)
736 | uncond_text_ids = tokenizer(
737 | "",
738 | padding="max_length",
739 | truncation=True,
740 | max_length=tokenizer.model_max_length,
741 | return_tensors="pt",
742 | ).input_ids.to(latents.device)
743 | uncond_encoder_hidden_states = text_encoder(uncond_text_ids)[0].repeat(bsz, 1, 1)
744 | uncond_score = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
745 |
746 | # compute initial compositional score
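                # conjunction of concepts: eps_hat = eps_uncond + sum_i w_i * (eps_cond_i - eps_uncond)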
747 | composed_score = uncond_score + torch.sum(weights.to(latents.device) * (cond_scores - uncond_score[:, None]), dim=1)
748 | # encoder_hidden_states = torch.stack(encoder_hidden_states_list, dim=1)
749 | # gt_encoder_states = (encoder_hidden_states * weights.to(latents.device)).sum(dim=1)
750 | mse_loss = args.mse_coeff * F.mse_loss(noise, composed_score.float(), reduction="mean")
751 |
752 | # orthogonal loss
753 | if args.use_orthogonal_loss:
754 | if args.use_similarity:
755 | B, M, C, H, W = cond_scores.shape
756 | ortho_scores_view = cond_scores.view(B, M, -1)
757 | prod_matrix = torch.bmm(ortho_scores_view, ortho_scores_view.transpose(2, 1)) / (C * H * W)
758 | # only compute the upper triangular matrices (exclude the diagonal)
759 | r, c = torch.triu_indices(M, M, offset=1)
760 | orthogonal_loss = args.orthogonal_coeff * (prod_matrix[:, r, c] ** 2).sum().sqrt()
761 | elif args.use_euclidean_mhe:
762 | B, M, C, H, W = cond_scores.shape
763 | ortho_scores_view = cond_scores.view(B, M, -1)
764 | batch_pair_wise_l2_dist = torch.cdist(ortho_scores_view, ortho_scores_view, p=2.0)
765 | # only compute the upper triangular matrices (exclude the diagonal)
766 | energy_matrix = torch.triu(batch_pair_wise_l2_dist, diagonal=1)
767 | energy_matrix = energy_matrix[energy_matrix != 0]
768 | if args.log_mhe:
769 | orthogonal_loss = torch.log(1 / energy_matrix).mean()
770 | else:
771 | orthogonal_loss = (1 / energy_matrix).mean()
772 | orthogonal_loss *= args.orthogonal_coeff
773 |
774 | if args.text_loss:
775 | word_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
776 | placeholder_token_ids]
777 | if args.text_repulsion_loss:
778 | if args.use_l2_norm_regularization:
779 | # word embeds are unnormalized.
780 | word_norm_loss = args.l2_norm_coeff * torch.norm(word_embeds, dim=1).mean(dim=0)
781 | if args.normalize_score:
782 | word_embeds = F.normalize(word_embeds, p=2, dim=1)
783 | word_dist_matrix = F.pdist(word_embeds, p=2)
784 | repulsion_loss = args.text_repulsion_coeff * torch.log(1 / word_dist_matrix).mean()
785 | elif args.text_repulsion_similarity_loss:
786 | num_words = word_embeds.shape[0]
787 | similarity = word_embeds @ word_embeds.T
788 | similarity = similarity[torch.triu_indices(num_words, num_words, offset=1).unbind()] ** 2.
789 | repulsion_loss = args.text_repulsion_coeff * similarity.sum().sqrt()
790 |
791 | if args.use_composed_score:
792 | # extract the corresponding weights for the batch of images
793 | if args.add_weight_per_score:
794 | weights = concept_weights[weight_id]
795 | if args.softmax_weights:
796 | weights = F.softmax(concept_weights[weight_id], dim=1)
797 | weighted_scores = noise_preds_stack * weights.to(latents.device)
798 | else:
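                        # no learned weights: give every concept an equal weight of 1/M before summing over concepts below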
799 | weighted_scores = noise_preds_stack / noise_preds_stack.shape[1]
800 |
801 | if args.use_conj_score:
802 | uncond_noise_pred = uncond_noise_pred[:, None]
803 | score = uncond_noise_pred + weights.to(latents.device) * (noise_preds_stack - uncond_noise_pred)
804 | composed_score = score.sum(dim=1)
805 | else:
806 | composed_score = weighted_scores.sum(dim=1)
807 |
808 | # TODO: MSE between classifier free score and noise doesn't make sense??
809 | mse_loss = args.mse_coeff * F.mse_loss(composed_score, noise, reduction="mean")
810 | # compute sum of pair wise dot product as the orthogonal loss
811 |
812 | if args.use_orthogonal_loss:
813 | if args.use_weighted_score:
814 | ortho_scores = weighted_scores
815 | else:
816 | ortho_scores = noise_preds_stack
817 |                     # assumes the number of concepts M > 1
818 | B, M, C, H, W = ortho_scores.shape
819 | ortho_scores_view = ortho_scores.view(B, M, -1)
820 | if args.normalize_score:
821 | ortho_scores_view = F.normalize(ortho_scores_view, p=2, dim=2)
822 |
823 | if args.use_similarity:
824 | prod_matrix = torch.bmm(ortho_scores_view, ortho_scores_view.transpose(2, 1))
825 | # only compute the upper triangular matrices (exclude the diagonal)
826 | num_pairs = math.factorial(M) / (math.factorial(2) * math.factorial(M - 2))
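                        # num_pairs = C(M, 2), the number of unordered concept pairs in the upper triangle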
827 | ortho_matrix = (torch.triu(prod_matrix, diagonal=1) / (B * C * H * W))
828 | orthogonal_loss = args.orthogonal_coeff * ortho_matrix.sum() / num_pairs
829 | # r, c = torch.triu_indices(M, M, offset=1).unbind()
830 | # orthogonal_loss = args.orthogonal_coeff * (prod_matrix[:, r, c]).mean()
831 | elif args.use_euclidean_mhe:
832 |                         # Minimum Hyperspherical Energy with a log kernel:
833 |                         # orthogonal_coeff * mean_{i != j} log(1 / ||s_i - s_j||_2)
834 | batch_pair_wise_l2_dist = torch.cdist(ortho_scores_view, ortho_scores_view, p=2.0)
835 | # only compute the upper triangular matrices (exclude the diagonal)
836 | energy_matrix = torch.triu(batch_pair_wise_l2_dist, diagonal=1)
837 | energy_matrix = energy_matrix[energy_matrix != 0]
838 | orthogonal_loss = torch.log(1 / energy_matrix).mean()
839 | orthogonal_loss *= args.orthogonal_coeff
840 | elif args.use_acos_mhe:
841 | prod_matrix_1 = torch.bmm(ortho_scores_view, ortho_scores_view.transpose(2, 1))
842 | energy_matrix = torch.triu(prod_matrix_1, diagonal=1)
843 | energy_matrix = energy_matrix[energy_matrix != 0][..., None]
844 | energy_matrix = torch.acos(energy_matrix)
845 | orthogonal_loss = torch.log(1 / energy_matrix).sum(dim=energy_matrix.shape[1:]).mean(dim=0)
846 | orthogonal_loss *= args.orthogonal_coeff
847 | else:
848 | raise NotImplementedError
849 |
850 | if args.text_loss:
851 | word_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
852 | placeholder_token_ids]
853 | if args.text_repulsion_loss:
854 | if args.use_l2_norm_regularization:
855 | # word embeds are unnormalized.
856 | word_norm_loss = args.l2_norm_coeff * torch.norm(word_embeds, dim=1).mean(dim=0)
857 | if args.normalize_score:
858 | word_embeds = F.normalize(word_embeds, p=2, dim=1)
859 | word_dist_matrix = F.pdist(word_embeds, p=2)
860 | repulsion_loss = args.text_repulsion_coeff * torch.log(1 / word_dist_matrix).mean()
861 | elif args.text_repulsion_similarity_loss:
862 | if args.use_l2_norm_regularization:
863 | # word embeds are unnormalized.
864 | word_norm_loss = args.l2_norm_coeff * torch.norm(word_embeds, dim=1).mean(dim=0)
865 | N = word_embeds.shape[0]
866 | similarity = word_embeds @ word_embeds.T
867 | similarity = similarity[torch.triu_indices(N, N, offset=1).unbind()] ** 2.
868 | repulsion_loss = args.text_repulsion_coeff * similarity.mean()
869 |
870 | if args.euclidean_dist_loss:
871 | word_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
872 | placeholder_token_ids]
873 | euclidean_dist_loss = args.euclidean_dist_coeff * (1 / F.pdist(word_embeds, p=2)).mean()
874 |
875 | loss = mse_loss + orthogonal_loss + repulsion_loss + word_norm_loss + euclidean_dist_loss
876 |
877 | accelerator.backward(loss)
878 | optimizer.step()
879 | lr_scheduler.step()
880 | optimizer.zero_grad()
881 |
882 | # Let's make sure we don't update any embedding weights besides the newly added token
883 | if not args.test:
884 | index_no_updates = torch.ones(len(tokenizer), dtype=torch.bool)
885 | index_no_updates[placeholder_token_ids] = False
886 | if accelerator.num_processes > 1:
887 | grads = text_encoder.module.get_input_embeddings().weight.grad
888 | else:
889 | grads = text_encoder.get_input_embeddings().weight.grad
890 |                 # zero the gradients of every token embedding except the newly added placeholder tokens
891 | grads.data[index_no_updates, :] = grads.data[index_no_updates, :].fill_(0)
892 |
893 | with torch.no_grad():
894 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
895 | index_no_updates
896 | ] = orig_embeds_params[index_no_updates]
897 |
898 | # Checks if the accelerator has performed an optimization step behind the scenes
899 | if accelerator.sync_gradients:
900 | progress_bar.update(1)
901 | global_step += 1
902 | if global_step % args.save_steps == 0:
903 |                     save_progress(text_encoder, placeholder_token_ids, accelerator, args)
904 | if args.add_weight_per_score:
905 | save_weights(concept_weights, args)
906 |
907 | if global_step % args.checkpointing_steps == 0:
908 | if accelerator.is_main_process:
909 | if not args.test:
910 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
911 | accelerator.save_state(save_path)
912 | logger.info(f"Saved state to {save_path}")
913 |
914 | logs = {"loss": loss.detach().item(),
915 | "lr": lr_scheduler.get_last_lr()[0],
916 | "mse_loss": mse_loss.item(),
917 | "ortho_loss": orthogonal_loss.item() if args.use_orthogonal_loss else 0,
918 | "word_repulsion_loss": repulsion_loss.item() if args.text_loss else 0,
919 | "euclidean_dist_loss": euclidean_dist_loss.item() if args.euclidean_dist_loss else 0,
920 | "word_norm_regularization": word_norm_loss.item() if args.use_l2_norm_regularization else 0,
921 | }
922 | progress_bar.set_postfix(**logs)
923 | accelerator.log(logs, step=global_step)
924 |
925 | if global_step >= args.max_train_steps:
926 | break
927 |
928 | if accelerator.sync_gradients and global_step % args.validation_step == 0:
929 | folder = os.path.join(args.output_dir, f'generated_samples_{global_step}')
930 | os.makedirs(folder, exist_ok=True)
931 | logger.info(
932 | f"Running validation..."
933 | )
934 | # create pipeline (note: unet and vae are loaded again in float32)
935 | pipeline = DiffusionPipeline.from_pretrained(
936 | args.pretrained_model_name_or_path,
937 | text_encoder=accelerator.unwrap_model(text_encoder),
938 | tokenizer=tokenizer,
939 | unet=unet,
940 | vae=vae,
941 | revision=args.revision,
942 | torch_dtype=weight_dtype,
943 | )
944 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
945 | pipeline = pipeline.to(accelerator.device)
946 | pipeline.set_progress_bar_config(disable=True)
947 |
948 | # run inference
949 | generator = (
950 | None if args.seed is None else
951 | torch.Generator(device=accelerator.device).manual_seed(args.seed)
952 | )
953 | images = []
954 | prompts = []
955 | if args.learnable_property != "":
956 | properties = [x.strip() for x in args.learnable_property.split(",")]
957 | else:
958 | properties = []
959 |
960 | if properties:
961 | for p, placeholder in zip(properties, placeholder_tokens):
962 | if p == "object":
963 | prompts.append(f"a photo of {placeholder}")
964 | else:
965 | prompts.append(f"a painting in the style of {placeholder}")
966 | else:
967 | for placeholder in placeholder_tokens:
968 | prompts.append(f"{placeholder}")
969 |
970 | for prompt in prompts:
971 | image_list = pipeline(prompt, guidance_scale=7.5,
972 | num_inference_steps=50, generator=generator)
973 | image = image_list.images[0]
974 | image.save(os.path.join(folder, f'{prompt}.png'))
975 | images.append(image)
976 |
977 | for tracker in accelerator.trackers:
978 | if tracker.name == "tensorboard":
979 | np_images = np.stack([np.asarray(img) for img in images])
980 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
981 | del pipeline
982 | torch.cuda.empty_cache()
983 |
984 | accelerator.wait_for_everyone()
985 |
986 | # Create the pipeline using the trained modules and save it.
987 | if accelerator.is_main_process and global_step % args.checkpointing_steps == 0 and not args.test:
988 | pipeline = StableDiffusionPipeline.from_pretrained(
989 | args.pretrained_model_name_or_path,
990 | text_encoder=accelerator.unwrap_model(text_encoder),
991 | vae=vae,
992 | unet=unet,
993 | tokenizer=tokenizer,
994 | )
995 | pipeline.save_pretrained(args.output_dir)
996 | # Also save the newly trained embeddings
997 |         save_progress(text_encoder, placeholder_token_ids, accelerator, args)
998 | if args.add_weight_per_score:
999 | save_weights(concept_weights, args)
1000 |
1001 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1002 | accelerator.save_state(save_path)
1003 | logger.info(f"Saved state to {save_path}")
1004 |
1005 | if args.push_to_hub:
1006 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1007 |
1008 | accelerator.end_training()
1009 |
1010 |
1011 | if __name__ == "__main__":
1012 | main()
1013 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.15.0
2 | transformers==4.26.0
3 | diffusers==0.10.2
4 | ftfy
5 | regex
6 | tqdm
7 | tensorboard
8 | modelcards
9 | git+https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
10 | git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------
/scripts/evaluate_classification.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | clip_logit=0.3
4 | resnet_logit=10
5 |
6 | python eval.py --model_path output/5_class_compositional_score_1_seed_0/ --evaluation_metric clip --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold $clip_logit
7 | # python eval.py --model_path output/5_class_compositional_score_2_seed_0/ --evaluation_metric clip --class_names "guinea pig" "warplane" "castle" "llama" "volcano" --logit_threshold $clip_logit
8 | # python eval.py --model_path output/5_class_compositional_score_3_seed_0/ --evaluation_metric clip --class_names "convertible" "starfish" "studio couch" "african elephant" "teddy" --logit_threshold $clip_logit
9 | # python eval.py --model_path output/5_class_compositional_score_4_seed_0/ --evaluation_metric clip --class_names "koala" "ice bear" "zebra" "tiger" "panda" --logit_threshold $clip_logit
10 |
11 | python eval.py --model_path output/5_class_compositional_score_1_seed_0/ --evaluation_metric resnet --class_names "geyser" "chihuahua" "chimpanzee" "shopping cart" "mosque" --logit_threshold $resnet_logit
12 | # python eval.py --model_path output/5_class_compositional_score_2_seed_0/ --evaluation_metric resnet --class_names "guinea pig" "warplane" "castle" "llama" "volcano" --logit_threshold $resnet_logit
13 | # python eval.py --model_path output/5_class_compositional_score_3_seed_0/ --evaluation_metric resnet --class_names "convertible" "starfish" "studio couch" "african elephant" "teddy" --logit_threshold $resnet_logit
14 | # python eval.py --model_path output/5_class_compositional_score_4_seed_0/ --evaluation_metric resnet --class_names "koala" "ice bear" "zebra" "tiger, Panthera tigris" "giant panda" --logit_threshold $resnet_logit
--------------------------------------------------------------------------------
/scripts/sample.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python inference.py --model_path output/5_class_compositional_score_1/ --prompts "a photo of " "a photo of " "a photo of " "a photo of " "a photo of " --num_images 64 --bsz 8
--------------------------------------------------------------------------------
/scripts/train_ADE20K.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_path="stabilityai/stable-diffusion-2-1-base"
4 | train_data_dir="ADE20K/home_or_hotel"
5 | placeholder_tokens=",,,,"
6 | class_folder_names="kitchen"
7 | output_dir="output/ADE20K_compositional_kitchen_5_concepts"
8 | learnable_property="object,object,object,object,object"
9 |
10 |
11 | DEVICE=$CUDA_VISIBLE_DEVICES
12 | python create_accelerate_config.py --gpu_id "${DEVICE}"
13 | accelerate launch --config_file accelerate_config.yaml main.py \
14 | --pretrained_model_name_or_path "${model_path}" \
15 | --train_data_dir ${train_data_dir} \
16 | --placeholder_tokens ${placeholder_tokens} \
17 | --resolution=512 --class_folder_names ${class_folder_names} \
18 | --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \
19 | --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \
20 | --lr_warmup_steps=0 --output_dir ${output_dir} \
21 | --learnable_property ${learnable_property} \
22 | --data "imagenet" --checkpointing_steps 500 \
23 | --mse_coeff 1 --seed 1 \
24 | --add_weight_per_score \
25 | --use_conj_score --init_weight 0.2 \
26 | --validation_step 500 \
27 |     --num_iters_per_image 100 --num_images_per_class 5
--------------------------------------------------------------------------------
/scripts/train_art.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Pick one artist configuration; the commented blocks are alternatives (picasso is active below).
3 | # model_path="stabilityai/stable-diffusion-2-1-base"
4 | # train_data_dir="style"
5 | # placeholder_tokens=",,,,"
6 | # class_folder_names="van_gogh"
7 | # learnable_property=""
8 | # output_dir="output/van_gogh"
9 |
10 | # model_path="stabilityai/stable-diffusion-2-1-base"
11 | # train_data_dir="style"
12 | # placeholder_tokens=",,,,"
13 | # class_folder_names="claude_monet_paintings"
14 | # learnable_property=""
15 | # output_dir="output/claude_monet_paintings"
16 |
17 | model_path="stabilityai/stable-diffusion-2-1-base"
18 | train_data_dir="style"
19 | placeholder_tokens=",,,,"
20 | class_folder_names="picasso"
21 | learnable_property=""
22 | output_dir="output/picasso_paintings"
23 |
24 |
25 | DEVICE=$CUDA_VISIBLE_DEVICES
26 | python create_accelerate_config.py --gpu_id "${DEVICE}"
27 | accelerate launch --config_file accelerate_config.yaml main.py \
28 | --pretrained_model_name_or_path "${model_path}" \
29 | --train_data_dir ${train_data_dir} \
30 | --placeholder_tokens ${placeholder_tokens} \
31 | --resolution=512 --class_folder_names ${class_folder_names} \
32 | --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \
33 | --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \
34 | --lr_warmup_steps=0 --output_dir ${output_dir} \
35 | --learnable_property "${learnable_property}" \
36 | --data "imagenet" --checkpointing_steps 500 --mse_coeff 1 --seed 1 \
37 | --add_weight_per_score \
38 | --use_conj_score --init_weight 0.2 \
39 | --validation_step 500 \
40 | --num_iters_per_image 1000 --num_images_per_class -1
--------------------------------------------------------------------------------
/scripts/train_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Pick one of the four 5-class ImageNet concept sets; set 4 is active below.
3 | # model_path="stabilityai/stable-diffusion-2-1-base"
4 | # train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet"
5 | # placeholder_tokens=",,,,"
6 | # class_folder_names="n09288635,n02085620,n02481823,n04204347,n03788195"
7 | # learnable_property="object,object,object,object,object"
8 | # output_dir="output/5_class_compositional_score_s1"
9 |
10 | # model_path="stabilityai/stable-diffusion-2-1-base"
11 | # train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet"
12 | # placeholder_tokens=",,,,"
13 | # class_folder_names="n02364673,n04552348,n02980441,n02437616,n09472597"
14 | # output_dir="output/5_class_compositional_score_s2"
15 | # learnable_property="object,object,object,object,object"
16 |
17 | # model_path="stabilityai/stable-diffusion-2-1-base"
18 | # train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet"
19 | # placeholder_tokens=",,,,"
20 | # class_folder_names="n03100240,n02317335,n04344873,n02504458,n04399382"
21 | # output_dir="output/5_class_compositional_score_s3"
22 | # learnable_property="object,object,object,object,object"
23 |
24 | model_path="stabilityai/stable-diffusion-2-1-base"
25 | train_data_dir="imagenet,imagenet,imagenet,imagenet,imagenet"
26 | placeholder_tokens=",,,,"
27 | class_folder_names="n01882714,n02134084,n02391049,n02129604,n02510455"
28 | learnable_property="object,object,object,object,object"
29 | output_dir="output/5_class_compositional_score_s4"
30 |
31 |
32 | DEVICE=$CUDA_VISIBLE_DEVICES
33 | python create_accelerate_config.py --gpu_id "${DEVICE}"
34 | accelerate launch --config_file accelerate_config.yaml main.py \
35 | --pretrained_model_name_or_path "${model_path}" \
36 | --train_data_dir ${train_data_dir} \
37 | --placeholder_tokens ${placeholder_tokens} \
38 | --resolution=512 --class_folder_names ${class_folder_names} \
39 | --train_batch_size=2 --gradient_accumulation_steps=8 --repeats 1 \
40 | --learning_rate=5.0e-03 --scale_lr --lr_scheduler="constant" --max_train_steps 3000 \
41 | --lr_warmup_steps=0 --output_dir ${output_dir} \
42 | --learnable_property "${learnable_property}" \
43 | --checkpointing_steps 1000 --mse_coeff 1 --seed 4 \
44 | --add_weight_per_score \
45 | --use_conj_score --init_weight 5 \
46 | --validation_step 1000 \
47 | --num_iters_per_image 120 --num_images_per_class 5
--------------------------------------------------------------------------------
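For reference, the WordNet synset IDs in the four class_folder_names sets line up, in order, with the --class_names passed to eval.py in scripts/evaluate_classification.sh. The lookup table below is inferred from that correspondence and is illustrative only, not part of the repository:

```python
# Synset-ID -> class-name mapping inferred by lining up class_folder_names in
# scripts/train_imagenet.sh with --class_names in scripts/evaluate_classification.sh.
IMAGENET_SETS = {
    "set_1": {"n09288635": "geyser", "n02085620": "chihuahua", "n02481823": "chimpanzee",
              "n04204347": "shopping cart", "n03788195": "mosque"},
    "set_2": {"n02364673": "guinea pig", "n04552348": "warplane", "n02980441": "castle",
              "n02437616": "llama", "n09472597": "volcano"},
    "set_3": {"n03100240": "convertible", "n02317335": "starfish", "n04344873": "studio couch",
              "n02504458": "african elephant", "n04399382": "teddy"},
    "set_4": {"n01882714": "koala", "n02134084": "ice bear", "n02391049": "zebra",
              "n02129604": "tiger", "n02510455": "giant panda"},
}
```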