├── triplets.csv ├── dataset └── data │ ├── 00000.jpg │ ├── 00001.JPG │ └── 00002.jpg ├── .gitmodules ├── config ├── analogy_params.yaml └── parameter_estimation.yaml ├── .gitignore ├── requirements.txt ├── LICENSE ├── process_new_data.py ├── visualize_tokens.py ├── precompute_noises_and_conditionings.py ├── estimate_input_noise.py ├── README.md ├── analogy_creator.py ├── estimate_CLIP_features.py ├── do_analogies.py ├── utils.py ├── DiffusionImageAnalogies.ipynb ├── ddim_invertor.py └── modified_clip_transformers.py /triplets.csv: -------------------------------------------------------------------------------- 1 | 00000 00001 00002 -------------------------------------------------------------------------------- /dataset/data/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/subrtadel/DIA/HEAD/dataset/data/00000.jpg -------------------------------------------------------------------------------- /dataset/data/00001.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/subrtadel/DIA/HEAD/dataset/data/00001.JPG -------------------------------------------------------------------------------- /dataset/data/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/subrtadel/DIA/HEAD/dataset/data/00002.jpg -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "stable-diffusion"] 2 | path = stable-diffusion 3 | url = https://github.com/CompVis/stable-diffusion.git 4 | -------------------------------------------------------------------------------- /config/analogy_params.yaml: -------------------------------------------------------------------------------- 1 | add_orig_row: True 2 | guidance_scales: [1., 2., 3., 5., 7., 9., 12.] 3 | analogy_strength: [0, 1.5, 20] # start, end, number of samples -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | dataset/data 3 | 4 | __pycache__/ 5 | 6 | 7 | .DS_Store 8 | 9 | *.jpg 10 | *.png 11 | *.jpeg 12 | *.JPG 13 | *.JPEG 14 | 15 | *.pdf 16 | 17 | 18 | estimate_input_noise_tmp_from_prompt.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | diffusers 3 | clip 4 | pudb==2019.2 5 | invisible-watermark 6 | imageio==2.9.0 7 | imageio-ffmpeg==0.4.2 8 | pytorch-lightning==1.4.2 9 | omegaconf==2.1.1 10 | test-tube>=0.7.5 11 | streamlit>=0.73.1 12 | einops==0.3.0 13 | torch-fidelity==0.3.0 14 | transformers==4.19.2 15 | torchmetrics==0.6.0 16 | kornia==0.6 17 | Pillow==9.3.0 18 | matplotlib==3.6.2 19 | ipython==8.7.0 20 | imageio==2.9.0 21 | tqdm==4.62.3 22 | -------------------------------------------------------------------------------- /config/parameter_estimation.yaml: -------------------------------------------------------------------------------- 1 | sufficient_loss: 0.0008 2 | ddim_steps: 10 3 | ddim_eta: 0. 4 | noise_optimization: 5 | opt_iters: 10 6 | log_every: 1 7 | lr: 0.1 8 | batch_size: 8 9 | uncond_guidance_scale: 1. 
10 | 11 | conditioning_optimization: 12 | opt_iters: 10 13 | log_every: 1 14 | lr: 0.01 15 | N_tokens: 10 16 | batch_size: 8 17 | uncond_guidance_scale: 1. 18 | fixed_timesteps: True # non-deterministic results with False 19 | 20 | uncond_guidance_scale: 1. 21 | path2save_prefix: './results/experiments/inversion/' 22 | shape: [4,64,64] 23 | device: 'cuda:0' 24 | batch_size: 1 25 | f: 8 26 | save_reconstruction: True -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Adéla Šubrtová and contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /process_new_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | path2raw = './dataset/raw_data/' 5 | path2clean = './dataset/data/' 6 | 7 | os.makedirs(path2clean, exist_ok=True) 8 | 9 | def separate_filenames(path): 10 | all_raw_files = os.listdir(path) 11 | raw_files_to_rename = [fn for fn in all_raw_files if re.match('^[0-9]{5}\.', fn) is None] 12 | raw_files_okay = [fn for fn in all_raw_files if re.match('^[0-9]{5}\.', fn) is not None] 13 | return raw_files_to_rename, raw_files_okay 14 | 15 | def determine_file_count(raw_files_okay): 16 | okay_file_names_numbers = [int(rfn.split('.')[0]) for rfn in raw_files_okay] 17 | old_file_count = 0 18 | file_count = 0 19 | if len(okay_file_names_numbers) != 0: 20 | old_file_count = max(okay_file_names_numbers) 21 | file_count = 1 + old_file_count 22 | return file_count 23 | 24 | raw_files_to_rename, raw_files_okay = separate_filenames(path2raw) 25 | 26 | file_count = determine_file_count(raw_files_okay) 27 | 28 | for fn in raw_files_to_rename: 29 | # rename new files to predefined format 30 | suffix = fn.split('.')[-1] 31 | new_file_name = f'{file_count:05d}.{suffix}' 32 | os.rename(os.path.join(path2raw,fn), os.path.join(path2raw, new_file_name)) 33 | file_count += 1 34 | 35 | # pad images 36 | os.system(f'convert {os.path.join(path2raw, new_file_name)} -virtual-pixel black -set option:distort:viewport "%[fx:max(w,h)]x%[fx:max(w,h)]-%[fx:max((h-w)/2,0)]-%[fx:max((w-h)/2,0)]" -filter point -distort SRT 0 +repage {os.path.join(path2clean, new_file_name)}') 37 | # resize 38 | os.system(f'convert {os.path.join(path2clean, new_file_name)} -resize 512x512 
{os.path.join(path2clean, new_file_name)}') 39 | 40 | 41 | print(f'Done. {len(raw_files_to_rename)} new files were added.') -------------------------------------------------------------------------------- /visualize_tokens.py: -------------------------------------------------------------------------------- 1 | 2 | from omegaconf import OmegaConf 3 | import numpy as np 4 | import os 5 | 6 | 7 | from ldm.models.diffusion.ddim import DDIMSampler 8 | import utils 9 | 10 | 11 | def fetch_cond_matrix(file_id, ddim_sampler, config): 12 | cond_out = utils.load_estimated_cond(file_id, token_subfolder=token_subfolder) 13 | cond_out = ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = cond_out.unsqueeze(0))['last_hidden_state'] 14 | return cond_out.to(config.device) 15 | 16 | 17 | path_to_data = './dataset/data/' 18 | experiment_root = './results' 19 | 20 | config = OmegaConf.load('./config/parameter_estimation.yaml') 21 | 22 | 23 | print('Loading model...') 24 | model = utils.prepare_default_model() 25 | print('Model loaded') 26 | 27 | 28 | 29 | token_subfolder = 'tokens' 30 | subfolder = 'noise' 31 | 32 | 33 | ddim_sampler = DDIMSampler(model) 34 | 35 | # analogy_creator = AnalogyCreator(config, ddim_sampler, subfolder, token_subfolder, os.path.join(experiment_root, 'analogy_results', out_subfolder)) 36 | file_id_A = '00009' 37 | file_id_A_prime = '00008' 38 | file_id_B = '00010' 39 | 40 | 41 | 42 | cA = fetch_cond_matrix(file_id_A, ddim_sampler, config) 43 | cAprime = fetch_cond_matrix(file_id_A_prime, ddim_sampler, config) 44 | cB = fetch_cond_matrix(file_id_B, ddim_sampler, config) 45 | 46 | 47 | os.makedirs(os.path.join(experiment_root,'token_visualization'),exist_ok=True) 48 | def gen_random_samples(cond, file_id): 49 | tokens_,_ = ddim_sampler.sample( 50 | config.ddim_steps, 51 | 8, 52 | config.shape, 53 | conditioning = cond.expand(8,-1,-1), 54 | eta=config.ddim_eta, 55 | unconditional_guidance_scale=1., 56 | unconditional_conditioning=ddim_sampler.model.get_learned_conditioning(['']).expand(8,-1,-1), 57 | ) 58 | utils.save_latent_as_image( 59 | ddim_sampler.model, 60 | tokens_, 61 | os.path.join(experiment_root,'token_visualization',f'{file_id}.jpg') 62 | ) 63 | 64 | gen_random_samples(cA, file_id_A) 65 | gen_random_samples(cAprime, file_id_A_prime) 66 | gen_random_samples(cB, file_id_B) 67 | -------------------------------------------------------------------------------- /precompute_noises_and_conditionings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import utils 3 | from argparse import ArgumentParser 4 | 5 | 6 | 7 | parser = ArgumentParser() 8 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml', 9 | help='path to config file') 10 | 11 | parser.add_argument('--inversion_subfolder', dest='subfolder', type=str, default = 'noise', 12 | help='inversion subfolder name') 13 | 14 | parser.add_argument('--token_subfolder', dest='token_subfolder', type=str, default = 'tokens', 15 | help='Token inversion subfolder name') 16 | 17 | parser.add_argument('--triplet_file', dest='triplet_file', type=str, default='triplets.csv', 18 | help='file with triplets') 19 | 20 | 21 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/', 22 | help='root path to data') 23 | 24 | args = parser.parse_args() 25 | args = vars(args) 26 | 27 | 28 | 29 | with open(args['triplet_file'], 'r') as f: 30 | file_lines = f.readlines() 31 | 32 | 33 | 
conditioning_inversion_names = [] 34 | noise_inversion_names = [] 35 | for line in file_lines: 36 | clean_line = line.replace('\n','').split(' ') 37 | A_name = utils.file_id2im_path(clean_line[0]) 38 | Aprime_name = utils.file_id2im_path(clean_line[1]) 39 | B_name = utils.file_id2im_path(clean_line[2]) 40 | 41 | conditioning_inversion_names.extend([A_name ,Aprime_name, B_name]) 42 | noise_inversion_names.append(B_name) 43 | 44 | 45 | with open('tmp_clip_inversion.txt','w') as f: 46 | for fn in set(conditioning_inversion_names): 47 | f.write(f'{fn}\n') 48 | 49 | with open('tmp_noise_inversion.txt','w') as f: 50 | for fn in set(noise_inversion_names): 51 | f.write(f'{fn}\n') 52 | 53 | 54 | os.system(f'python estimate_CLIP_features.py --config {args["config"]} --subfolder {args["token_subfolder"]} --input_img tmp_clip_inversion.txt --data_path {args["data_path"]}') 55 | 56 | 57 | os.system(f'python estimate_input_noise.py --config {args["config"]} --input_img tmp_noise_inversion.txt --token_subfolder {args["token_subfolder"]} --subfolder {args["subfolder"]} --data_path {args["data_path"]}') 58 | 59 | 60 | os.remove('tmp_clip_inversion.txt') 61 | os.remove('tmp_noise_inversion.txt') -------------------------------------------------------------------------------- /estimate_input_noise.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | import pickle as pkl 5 | import torch 6 | import os 7 | 8 | from ddim_invertor import DDIMInvertor 9 | import utils 10 | 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml', 14 | help='path to config file') 15 | 16 | parser.add_argument('--input_img', dest='input_img', type=str, default = None, 17 | help='path to image or text files with image names') 18 | 19 | parser.add_argument('--subfolder', dest='subfolder', type=str, default = 'noise', 20 | help='subfolder name') 21 | 22 | parser.add_argument('--token_subfolder', dest='token_subfolder', type=str, default = 'tokens', 23 | help='token subfolder name') 24 | 25 | 26 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/', 27 | help='root path to data') 28 | 29 | args = parser.parse_args() 30 | args = vars(args) 31 | 32 | assert os.path.isfile(args['input_img']), '--input_img is not a file' 33 | 34 | if args['input_img'].endswith('.txt'): 35 | with open(args['input_img'], 'r') as f: 36 | file_lines = f.readlines() 37 | clean_file_lines = [os.path.join(args['data_path'],x.replace('\n','')) for x in file_lines] 38 | elif args['input_img'].endswith(('.png','.jpeg','.jpg', 'JPG', 'JPEG')): 39 | clean_file_lines = [args['input_img']] 40 | 41 | config = OmegaConf.load(f"{args['config']}") 42 | config.args = args 43 | config.token_subfolder = args['token_subfolder'] 44 | 45 | 46 | print('Loading model...') 47 | model = utils.prepare_default_model() 48 | model = model.to(config.device) 49 | print('Model loaded') 50 | 51 | 52 | invertor = DDIMInvertor(config, model) 53 | 54 | 55 | for file_name in clean_file_lines: 56 | print(f'Processing file: {file_name}') 57 | if not os.path.exists(file_name): 58 | print(f'Path {file_name} does not exist. 
Skipping') 59 | continue 60 | # load & prepare image 61 | file_id = utils.extract_file_id_from_path(file_name) 62 | export_path = os.path.join(config.path2save_prefix, file_id, args['subfolder']) 63 | if os.path.exists(os.path.join(export_path, 'results.pkl')): 64 | print(f'The inversion for {file_id} seems to be done already. Skipping...') 65 | continue 66 | os.makedirs(export_path, exist_ok=True) 67 | 68 | 69 | print('Performing inversion...') 70 | outputs = invertor.perform_inversion(file_name, cond = None, init_noise_init = None, loss_weights= {'latents': 1. , 'pixels':1.} ) 71 | 72 | 73 | outputs['token_subfolder'] = args['token_subfolder'] 74 | 75 | if config.save_reconstruction: 76 | img = utils.load_pil(file_name) 77 | img.save(os.path.join(export_path,f'target.png')) 78 | 79 | rec_img_torch = utils.latent2img(model, outputs['reconstruction']) 80 | rec_img_pil = utils.torch2pil(rec_img_torch)[0] 81 | rec_img_pil.save(os.path.join(export_path, 'reconstruction.png')) 82 | 83 | print(f'Saving results to {export_path}') 84 | utils.save_results2pickle(export_path, outputs) 85 | 86 | 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Image Analogies 2 |
3 | Adéla Šubrtová, 4 | Michal Lukáč, 5 | Jan Čech, 6 | David Futschik, 7 | Eli Shechtman, 8 | Daniel Sýkora, 9 |
10 | 11 | ![DIA_Teaser](https://github.com/subrtadel/DIA/assets/129282989/4e5ab11d-851a-4d9a-a6f8-d3769e994e33) 12 | This is the official repository for the Diffusion Image Analogies paper published at the SIGGRAPH 2023 Conference Proceedings. 13 | 14 | *** 15 | 16 | ## Installation 17 | 18 | 1. Clone the repo 19 | ```sh 20 | git clone --recurse-submodules https://github.com/subrtadel/DIA.git 21 | cd ./DIA 22 | ``` 23 | 2. Create environment 24 | ``` 25 | conda create -n dia_env 26 | conda activate dia_env 27 | conda install python=3.8.5 pip=20.3 cudatoolkit=11.3 pytorch=1.11.0 torchvision=0.12.0 numpy=1.19.2 -c pytorch -c nvidia -c conda-forge -c defaults 28 | ``` 29 | 3. Install packages 30 | ```sh 31 | pip install -r requirements.txt 32 | cd ./stable-diffusion/ 33 | pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 34 | pip install -e . 35 | ``` 36 | 4. Download the [sd-v1-4.ckpt model](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) and put it into correct folder 37 | ``` 38 | mkdir -p ./models/ldm/stable-diffusion-v1/ 39 | 40 | ``` 41 | 5. Install [Image Magick](https://imagemagick.org). 42 | 43 |
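Step 4 only creates the target folder; the loader in `utils.py` (`prepare_default_model`) reads the weights from `./stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt`, so the downloaded checkpoint has to end up under that exact name. A minimal way to do this from the repository root (the source path below is a placeholder for wherever `sd-v1-4.ckpt` was downloaded):

```sh
mkdir -p ./stable-diffusion/models/ldm/stable-diffusion-v1/
mv /path/to/sd-v1-4.ckpt ./stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt
```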

(back to top)

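As an optional sanity check of the installation, the default model loader can be called directly. This is only a sketch; it assumes the checkpoint was placed as described above, that a CUDA device is available, and that it is run from the repository root:

```python
# loads ./stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt and the CLIP text model
import utils

model = utils.prepare_default_model()
print('Model loaded:', type(model).__name__)
```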
44 | 45 | *** 46 | 47 | 48 | ## Usage 49 | 50 | 1. Upload images into the `./dataset/raw_data/` folder (create it if it does not exist). 51 |
52 | 53 | 2. Run `process_new_data.py`. The script pads each image to a square, resizes it to 512x512 via Image Magick, writes the processed copy to `./dataset/data/`, and assigns it a `file_id` in `%05d` format. 54 |
55 | 56 | 3. Define the triplets in a `.csv` file. Refer to the images by their `file_id`. 57 | An example file is `triplets.csv`. The first column specifies the `A` input, the second the `A'` input, and the third the `B` input; the columns are separated by single spaces. Either with or without filename suffixes is fine. 58 |
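For reference, the bundled `triplets.csv` defines a single analogy over the three example images:

```
00000 00001 00002
```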
59 | 60 | 4. Run the `precompute_noises_and_conditionings.py` script. It estimates the CLIP conditionings for all three images of each triplet and the input noise for each `B` image, so it may take a while. 61 | ```sh 62 | python precompute_noises_and_conditionings.py \ 63 | --config ./config/parameter_estimation.yaml \ 64 | --inversion_subfolder noise \ 65 | --token_subfolder tokens \ 66 | --triplet_file triplets.csv \ 67 | --data_path ./dataset/data/ 68 | ``` 69 |
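With the default `parameter_estimation.yaml`, the estimated parameters are written to `./results/experiments/inversion/<file_id>/<subfolder>/results.pkl`, and both estimation scripts skip any image whose `results.pkl` already exists. For the example triplet the layout looks roughly like this (token inversions for `A`, `A'` and `B`, a noise inversion only for `B`):

```
results/experiments/inversion/
├── 00000/tokens/results.pkl
├── 00001/tokens/results.pkl
└── 00002/
    ├── tokens/results.pkl
    └── noise/results.pkl        # plus target.png and reconstruction.png
```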
70 | 71 | 5. Review `./config/analogy_params.yaml` and adjust the guidance scales and the analogy-strength range if needed. 72 |
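Note that `analogy_strength` is a `(start, stop, num)` triple, not a list of strengths: `do_analogies.py` expands it with `np.linspace` and renders one grid cell per (guidance scale, strength) pair, combining the inverted conditionings as `cB + lambda * (cA' - cA)`. A small sketch of how the default values are consumed:

```python
import numpy as np

# values from ./config/analogy_params.yaml
guidance_scales = [1., 2., 3., 5., 7., 9., 12.]   # one grid row per guidance scale
analogy_strength = [0, 1.5, 20]                   # start, stop, number of samples

lambdas = np.linspace(*analogy_strength)          # 20 analogy strengths, one grid column each

# conditioning used for a single grid cell (same formula as in do_analogies.py)
analogy_cond = lambda cA, cAprime, cB, lam: cB + lam * (cAprime - cA)
```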
73 | 74 | 6. Run the `do_analogies.py` script. 75 | ```python do_analogies.py 76 | python do_analogies.py \ 77 | --config ./config/parameter_estimation.yaml \ 78 | --inversion_subfolder noise \ 79 | --token_subfolder tokens \ 80 | --output_subfolder analogies \ 81 | --triplet_file triplets.csv \ 82 | --data_path ./dataset/data/ 83 | ``` 84 | 85 | 86 | 87 | *** 88 | 89 | ## BibTeX 90 | 91 | @inproceedings{Subrtova2023DIA, 92 | title = {Diffusion Image Analogies}, 93 | author = {A. \v{S}ubrtov\'{a} and M. Luk\'{a}\v{c} and J. \v{C}ech and D. Futschik and E. Shechtman and D. S\'{y}kora}, 94 | booktitle = {ACM SIGGRAPH 2023 Conference Proceedings}, 95 | year = {2023} 96 | } 97 | -------------------------------------------------------------------------------- /analogy_creator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | from ddim_invertor import DDIMInvertor 6 | import utils 7 | 8 | class AnalogyCreator(): 9 | def __init__(self, config, ddim_sampler, inversion_subfolder, token_subfolder, output_root, data_path) -> None: 10 | self.config = config 11 | self.ddim_sampler = ddim_sampler 12 | self.subfolder = inversion_subfolder 13 | self.token_subfolder = token_subfolder 14 | self.output_root = output_root 15 | self.uc = self.ddim_sampler.model.get_learned_conditioning(['']) 16 | self.data_path = data_path 17 | 18 | 19 | def fetch_cond_matrix(self, file_id): 20 | cond_out = utils.load_estimated_cond(file_id, token_subfolder=self.token_subfolder) 21 | cond_out = self.ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = cond_out.unsqueeze(0))['last_hidden_state'] 22 | return cond_out.to(self.config.device) 23 | 24 | def __load_B_noise(self, imB): 25 | fileid_B = utils.extract_file_id_from_path(imB) 26 | _,_,_, resB = utils.load_inversion_result_dict(fileid_B, self.subfolder, return_result_dict=True) 27 | return resB['estimated_input_noise'] 28 | 29 | def make_analogy(self, triplet, steps, uc_scales, analogy_func = None, **analogy_func_kwargs): 30 | print(f'Make analogy inputs: {triplet}, {uc_scales}, {steps}') 31 | triplet_code = '_'.join([utils.extract_file_id_from_path(t) for t in triplet]) 32 | noise = self.__load_B_noise(triplet[-1]) 33 | 34 | cA = self.fetch_cond_matrix(utils.extract_file_id_from_path(triplet[0])) 35 | cAprime = self.fetch_cond_matrix(utils.extract_file_id_from_path(triplet[1])) 36 | 37 | cB = self.fetch_cond_matrix(utils.extract_file_id_from_path(triplet[2])) 38 | self.make_analogy_from_args(triplet_code, cA, cAprime, cB, noise, steps, uc_scales, analogy_func, **analogy_func_kwargs) 39 | 40 | 41 | def make_analogy_from_args(self, triplet_code, cA, cAprime, cB, noise, steps, uc_scales, analogy_func = None, **analogy_func_kwargs): 42 | 43 | os.makedirs(os.path.join(self.output_root, triplet_code), exist_ok=True) 44 | os.makedirs(os.path.join(self.output_root,'grids'), exist_ok=True) 45 | if analogy_func is None: 46 | analogy_func = lambda cA, cAprime, cB, st: cB + st * (cAprime - cA) 47 | 48 | rows = [] 49 | for sc in uc_scales: 50 | cols = [] 51 | for st in steps: 52 | analogy_res,_ = self.ddim_sampler.sample( 53 | self.config.ddim_steps, 54 | 1, 55 | self.config.shape, 56 | conditioning = analogy_func(cA, cAprime, cB, st, **analogy_func_kwargs), 57 | eta=self.config.ddim_eta, 58 | x_T=noise, 59 | unconditional_guidance_scale=sc, 60 | unconditional_conditioning=self.uc, 61 | ) 62 | img = utils.save_latent_as_image( 63 | self.ddim_sampler.model, 64 | 
analogy_res, 65 | os.path.join(self.output_root, triplet_code, f'analogy_sc={sc}_shift_strength={st}.jpg'), 66 | return_pil=True 67 | ) 68 | cols.append(np.array(img)) 69 | 70 | rows.append(np.concatenate(cols, axis = 1)) 71 | grid = np.concatenate(rows, axis = 0) 72 | Image.fromarray(grid).save(os.path.join(self.output_root, 73 | 'grids', 74 | f'{triplet_code}_analogy_grid.jpg' )) 75 | 76 | -------------------------------------------------------------------------------- /estimate_CLIP_features.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | import pickle as pkl 5 | import numpy as np 6 | import sys 7 | import os 8 | sys.path.append('./stable-diffusion/') 9 | 10 | from ddim_invertor import DDIMInvertor 11 | import utils 12 | 13 | 14 | 15 | parser = ArgumentParser() 16 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml', 17 | help='path to config file') 18 | parser.add_argument('--input_img', dest='input_img', type=str, required = True, 19 | help='path to image or text files with image names') 20 | 21 | parser.add_argument('--subfolder', dest='subfolder', type=str, default = 'tokens', 22 | help='subfolder name') 23 | 24 | 25 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/', 26 | help='root path to data') 27 | 28 | parser.add_argument('--regenerate_tokens', dest='regenerate', action='store_true', 29 | help='Will regenerate images with random noise and the output conditioning') 30 | 31 | args = parser.parse_args() 32 | args = vars(args) 33 | 34 | assert os.path.isfile(args['input_img']), '--input_img is not a file' 35 | 36 | if args['input_img'].endswith('.txt'): 37 | with open(args['input_img'], 'r') as f: 38 | file_lines = f.readlines() 39 | clean_file_lines = [os.path.join(args['data_path'],x.replace('\n','')) for x in file_lines] 40 | elif args['input_img'].endswith(('.png','.jpeg','.jpg')): 41 | clean_file_lines = [args['input_img']] 42 | 43 | config = OmegaConf.load(f"{args['config']}") 44 | config.args = args 45 | 46 | 47 | print('Loading model...') 48 | model = utils.prepare_default_model() 49 | model = model.to(config.device) 50 | print('Model loaded') 51 | 52 | 53 | invertor = DDIMInvertor(config, model) 54 | 55 | for file_path in clean_file_lines: 56 | if not os.path.exists(file_path): 57 | print(f'Path {file_path} does not exist. Skipping') 58 | continue 59 | 60 | file_id = utils.extract_file_id_from_path(file_path) 61 | if os.path.exists(os.path.join(config.path2save_prefix, file_id, args['subfolder'],'results.pkl')): 62 | print(f'Inversion for file_id {file_id} is already done... 
Skipping') 63 | continue 64 | 65 | output = invertor.perform_cond_inversion_individual_timesteps(file_path, None, optimize_tokens=True) 66 | 67 | 68 | export_path = os.path.join(config.path2save_prefix, file_id, args['subfolder']) 69 | 70 | print(f'Saving results to {export_path}') 71 | utils.save_results2pickle(export_path, output) 72 | # os.makedirs(export_path, exist_ok=True) 73 | 74 | # with open(os.path.join(export_path, 'results.pkl') ,'wb') as f: 75 | # pkl.dump(output, f) 76 | 77 | if args["regenerate"]: 78 | c_ = model.cond_stage_model.transformer(inputs_embeds = output['estimated_conditioning'].unsqueeze(0))['last_hidden_state'] 79 | res, _ = invertor.ddim_sampler.sample(config.ddim_steps, 80 | config.conditioning_optimization.batch_size, 81 | config.shape, 82 | conditioning=c_.expand(config.conditioning_optimization.batch_size, -1,-1), 83 | eta=0., 84 | unconditional_guidance_scale=5., 85 | unconditional_conditioning=invertor.uc.expand(config.conditioning_optimization.batch_size, -1,-1)) 86 | 87 | 88 | 89 | img = utils.save_latent_as_image(model, res, os.path.join(export_path,'token_regeneration.png'),return_pil=True) 90 | orig = np.array(Image.open(file_path).convert("RGB")) 91 | row = np.concatenate((orig,np.zeros((orig.shape[0], 20,3)),np.array(img)), axis = 1).astype(np.uint8) 92 | Image.fromarray(row).save(os.path.join(export_path,'token_regeneration_with_ref.png')) 93 | 94 | del output 95 | 96 | -------------------------------------------------------------------------------- /do_analogies.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | from omegaconf import OmegaConf 4 | 5 | from ldm.models.diffusion.ddim import DDIMSampler 6 | from analogy_creator import AnalogyCreator 7 | import utils 8 | from PIL import Image 9 | import pickle as pkl 10 | import torch 11 | import numpy as np 12 | 13 | 14 | 15 | parser = ArgumentParser() 16 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml', 17 | help='path to config file') 18 | 19 | parser.add_argument('--inversion_subfolder', dest='subfolder', type=str, default = 'noise', 20 | help='inversion subfolder name') 21 | 22 | parser.add_argument('--token_subfolder', dest='token_subfolder', type=str, default = 'tokens', 23 | help='Token inversion subfolder name') 24 | 25 | parser.add_argument('--output_subfolder', dest='out_subfolder', type=str, default = 'analogies', 26 | help='Output subfolder name') 27 | 28 | parser.add_argument('--triplet_file', dest='triplet_file', type=str, 29 | help='file with image paths') 30 | 31 | 32 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/', 33 | help='root path to data') 34 | 35 | args = parser.parse_args() 36 | args = vars(args) 37 | 38 | 39 | with open(args['triplet_file'], 'r') as f: 40 | file_lines = f.readlines() 41 | 42 | 43 | clean_file_triplets = [] 44 | for line in file_lines: 45 | clean_line = line.replace('\n','').split(' ') 46 | pathA = utils.file_id2im_path(clean_line[0], data_path=args['data_path'],absolute=True) 47 | pathAprime = utils.file_id2im_path(clean_line[1], data_path=args['data_path'],absolute=True) 48 | pathB = utils.file_id2im_path(clean_line[2], data_path=args['data_path'],absolute=True) 49 | clean_file_triplets.append((pathA, pathAprime, pathB)) 50 | 51 | 52 | config = OmegaConf.load(f"{args['config']}") 53 | config.args = args 54 | # log_config = 
OmegaConf.load('./config/logging_config.yaml') 55 | experiment_root = './results/' 56 | 57 | 58 | subfolder = args['subfolder'] 59 | token_subfolder = args['token_subfolder'] 60 | 61 | print('Loading model...') 62 | model = utils.prepare_default_model() 63 | model = model.to(config.device) 64 | print('Model loaded') 65 | 66 | ddim_sampler = DDIMSampler(model) 67 | 68 | export_path = os.path.join(experiment_root, 'analogy_results', args['out_subfolder']) 69 | 70 | analogy_creator = AnalogyCreator(config, ddim_sampler, subfolder, token_subfolder, export_path, data_path= args['data_path']) 71 | 72 | 73 | analogy_config = OmegaConf.load(f"./config/analogy_params.yaml") 74 | add_orig_row = analogy_config.add_orig_row 75 | scales = analogy_config.guidance_scales 76 | steps = np.linspace(*analogy_config.analogy_strength) 77 | 78 | analogy_func = lambda cA, cAprime, cB, st: cB + st * (cAprime - cA) 79 | 80 | 81 | for triplet in clean_file_triplets: 82 | print(f'Processing triplet: {triplet}') 83 | if os.path.exists(os.path.join(export_path, utils.tuple2triplet_name(triplet))): 84 | print(f'This ({utils.tuple2triplet_name(triplet)}) analogy is precomputed... Skipping') 85 | continue 86 | 87 | if not utils.check_inversion_done(os.path.join(args['data_path'], triplet[2]), subfolder): 88 | print(f'Inversion for image {triplet[2]} not found... Skipping') 89 | print(f'p at : {os.path.join(args["data_path"], triplet[2], subfolder)}') 90 | continue 91 | if (not utils.check_inversion_done(os.path.join(args['data_path'], triplet[0]), token_subfolder) or \ 92 | not utils.check_inversion_done(os.path.join(args['data_path'], triplet[1]), token_subfolder) or \ 93 | not utils.check_inversion_done(os.path.join(args['data_path'], triplet[2]), subfolder)): 94 | print(f'Inversion not found... 
Skipping') 95 | continue 96 | 97 | analogy_creator.make_analogy(triplet, steps, scales, analogy_func = analogy_func) 98 | 99 | if add_orig_row: 100 | triplet_code = '_'.join([utils.extract_file_id_from_path(t) for t in triplet]) 101 | grid = np.array(Image.open(os.path.join(export_path, 102 | 'grids', 103 | f'{triplet_code}_analogy_grid.jpg' ))) 104 | first_row = utils.join_images(triplet, out_PIL=False) 105 | n_pad = grid.shape[1] - first_row.shape[1] 106 | pad = np.zeros((first_row.shape[0], n_pad, 3)) 107 | first_row = np.concatenate((first_row, pad), axis = 1) 108 | final_grid = np.uint8(np.concatenate((first_row, grid), axis = 0)) 109 | Image.fromarray(final_grid).save(os.path.join(export_path,'grids',f'{triplet_code}_analogy_grid.jpg')) 110 | 111 | 112 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import seed_everything 2 | from omegaconf import OmegaConf 3 | import numpy as np 4 | import pickle as pkl 5 | import torch 6 | import fnmatch 7 | import PIL 8 | import gc 9 | import os 10 | 11 | 12 | from ldm.util import instantiate_from_config 13 | from modified_clip_transformers import ModifiedCLIPTextModel 14 | import importlib 15 | 16 | # importlib.import_module("/home/subrtade/analogies/DiffusionImageAnalogies/stable-diffusion") 17 | 18 | ######################################################################## Model prep 19 | 20 | 21 | 22 | # taken from the stable-diffusion project script txt2img.py 23 | def load_model_from_config(config, ckpt, verbose=False, device='cuda'): 24 | print(f"Loading model from {ckpt}") 25 | pl_sd = torch.load(ckpt, map_location="cpu") 26 | if "global_step" in pl_sd: 27 | print(f"Global Step: {pl_sd['global_step']}") 28 | sd = pl_sd["state_dict"] 29 | model = instantiate_from_config(config.model) 30 | m, u = model.load_state_dict(sd, strict=False) 31 | if len(m) > 0 and verbose: 32 | print("missing keys:") 33 | print(m) 34 | if len(u) > 0 and verbose: 35 | print("unexpected keys:") 36 | print(u) 37 | 38 | model.to(device) 39 | model.eval() 40 | return model 41 | 42 | 43 | def prepare_default_model(default_seed = 42): 44 | default_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml" 45 | default_ckpt_path = "./stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt" 46 | 47 | seed_everything(default_seed) 48 | 49 | config = OmegaConf.load(f"{default_config_path}") 50 | model = load_model_from_config(config, default_ckpt_path, True) 51 | 52 | del model.cond_stage_model.transformer 53 | print(f'GC COLLECT RETURN VALUE: {gc.collect()}') 54 | 55 | 56 | model.cond_stage_model.transformer = ModifiedCLIPTextModel.from_pretrained('openai/clip-vit-large-patch14').to(model.device) 57 | 58 | 59 | return model 60 | 61 | 62 | ####################################################################### 63 | # Data prep 64 | 65 | def extract_file_id_from_path(file_name): 66 | return os.path.basename(file_name).split('.')[0] 67 | 68 | def load_all_image_names(path = './dataset/data/', suffixes = ['jpg', 'jpeg', 'png', 'JPG', 'JPEG']): 69 | all_image_files = os.listdir(path) 70 | extracted_files = [] 71 | for suf in suffixes: 72 | extracted_files += fnmatch.filter(all_image_files , f'*.{suf}') 73 | 74 | return extracted_files 75 | 76 | 77 | def file_id2im_path(file_id, data_path = './dataset/data', absolute=False): 78 | if not file_id.endswith(('png', 'jpg', 'jpeg', 'JPG', 'JPEG')): 79 | 
image_names = load_all_image_names(path=data_path) 80 | image_name = fnmatch.filter(image_names, f'{file_id}.*')[0] 81 | else: 82 | image_name = file_id 83 | if absolute: 84 | return os.path.join(data_path, image_name) 85 | return image_name 86 | 87 | 88 | def extract_triplet_from_tuple(tuple_): 89 | fids = [extract_file_id_from_path(pth) for pth in tuple_] 90 | return fids 91 | 92 | def tuple2triplet_name(triplet_tuple): 93 | file_ids = extract_triplet_from_tuple(triplet_tuple) 94 | return '_'.join(file_ids) 95 | 96 | 97 | def join_images(list_of_image_paths, dim=1, path_prefix = '',out_PIL = True): 98 | """Given list of image paths, the function puts the images side by side in 'dim'. 99 | 100 | Args: 101 | list_of_image_paths (list): List that contains paths to the images 102 | dim (int, optional): In which dimension are the images joined. Defaults to 1. 103 | path_prefix (str, optional): Path to the images. Defaults to ''. 104 | out_PIL (bool, optional): The output is PIL image if True, otherwise the ouptut is np.ndarray. Defaults to True. 105 | 106 | Returns: 107 | _type_: _description_ 108 | """ 109 | imgs = [] 110 | for im_name in list_of_image_paths: 111 | img = np.array(PIL.Image.open(os.path.join(path_prefix,im_name))) 112 | if len(img.shape) == 2: 113 | img = np.stack((img,img,img), axis = -1) 114 | if img.shape[-1] == 4: 115 | img = img[:,:,:3] 116 | imgs.append(img) 117 | return join_array_of_np_images(imgs, dim, out_PIL) 118 | 119 | def join_array_of_np_images(array_of_imgs, dim = 1, out_PIL = True): 120 | if out_PIL: 121 | return PIL.Image.fromarray(np.concatenate(array_of_imgs, axis = dim)) 122 | return np.concatenate(array_of_imgs, axis = dim) 123 | 124 | 125 | 126 | 127 | def img2latent(model, img_torch): 128 | return model.get_first_stage_encoding(model.encode_first_stage(img_torch.to(model.first_stage_model.device))) 129 | 130 | def latent2img(model, latent): 131 | images = model.decode_first_stage(latent.to(model.first_stage_model.device)) 132 | return images 133 | 134 | def load_pil(img): 135 | return PIL.Image.open(img).convert('RGB') 136 | 137 | 138 | def pil2torch(pilimg, to_range = True, device = 'cuda:0'): 139 | w, h = pilimg.size 140 | w, h = w - w%32, h - h%32 141 | pilimg.resize((w,h), resample=PIL.Image.LANCZOS) 142 | im_np = np.array(pilimg).astype(np.float32) / 255. 143 | im_np = im_np[np.newaxis].transpose((0, 3, 1, 2)) 144 | im_torch = torch.from_numpy(im_np).to(device) 145 | if to_range: 146 | im_torch = 2*im_torch - 1 147 | return im_torch 148 | 149 | def pil2torch_batch(list_of_ims, to_range=True, device= 'cuda:0'): 150 | batch = [] 151 | for b in range(len(list_of_ims)): 152 | batch.append(pil2torch(list_of_ims[b], to_range, device)) 153 | return torch.cat(batch, dim=0) 154 | 155 | def torch2pil(images, from_range = True): 156 | if from_range: 157 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) 158 | pil_images = [] 159 | b = images.shape[0] 160 | for i in range(b): 161 | img_np = np.array(images[i].detach().cpu()) 162 | img_np = np.uint8(img_np.transpose((1,2,0)) * 255) 163 | img_pil = PIL.Image.fromarray(img_np) 164 | pil_images.append(img_pil) 165 | return pil_images 166 | 167 | 168 | def save_latent_as_image(model, latent, path, return_pil=False): 169 | """Generates the output image from given latent and saves it to path. 170 | 171 | Args: 172 | model (_type_): stable diffusion. 173 | latent (_type_): Latent of the image. 174 | path (_type_): Path to save the image. 
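        return_pil (bool, optional): If True, the decoded PIL image is also returned. Defaults to False.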
175 | """ 176 | rec_img_torch = latent2img(model, latent) 177 | rec_img_pil = torch2pil(torch.cat(list(rec_img_torch), dim = -1).unsqueeze(0))[0] 178 | rec_img_pil.save(path) 179 | if return_pil: 180 | return rec_img_pil 181 | 182 | ####################################################################### 183 | # Optimization utils 184 | 185 | def pixel_space_loss(model, latent1, real_image, loss_fn): 186 | """computes loss loss_fn in pixel space 187 | 188 | Args: 189 | model (_type_): stable diffusion model 190 | latent1 (_type_): latent of the generated image 191 | real_image (_type_): target image 192 | loss_fn (_type_): torch functional loss 193 | 194 | Returns: 195 | _type_: Value of loss_fn between real_image and the image generated from latent1. 196 | """ 197 | image1 = model.differentiable_decode_first_stage(latent1.to(model.first_stage_model.device)) 198 | return loss_fn(image1, real_image.to(image1.device)) 199 | 200 | 201 | ####################################################################### 202 | # Results manipulation 203 | 204 | def load_estimated_cond(file_id, token_subfolder = 'tokens', inversion_path_root = './results/experiments/inversion/' ): 205 | if not os.path.exists(os.path.join(inversion_path_root, f'{file_id}/{token_subfolder}/results.pkl')): 206 | return None 207 | with open(os.path.join(inversion_path_root,f'{file_id}/{token_subfolder}/results.pkl'),'rb') as f: 208 | results = pkl.load(f) 209 | return results['estimated_conditioning'] 210 | 211 | 212 | 213 | def load_inversion_result_dict(file_id, subfolder, return_result_dict = False, inversion_root_folder='./results/experiments/inversion/'): 214 | """Loads the results of inversion for given file_id and experiment. 215 | 216 | Args: 217 | file_id (str (ex. 000001)): File id of the inverted image. 218 | subfolder (str): Name of the inversion experiment. 219 | return_result_dict (bool, optional): If yes returns the whole result dict. Defaults to False. 220 | 221 | Returns: 222 | _type_: collection of (noise, conditioning matrix, unconditional guidance scale, [result dict]) 223 | """ 224 | assert os.path.exists(os.path.join(inversion_root_folder, file_id, subfolder,'results.pkl')) , f'This ({file_id}/{subfolder}) experiment does not exist.' 
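    # results.pkl is the dict written by save_results2pickle(): a noise inversion stores
    # 'estimated_input_noise', a token inversion stores 'estimated_conditioning', and both
    # record the 'guidance_scale' used during optimization (missing keys fall back to None below)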
225 | 226 | with open(os.path.join(inversion_root_folder, file_id, subfolder,'results.pkl'), 'rb') as f: 227 | results = pkl.load(f) 228 | 229 | noise = results['estimated_input_noise'] if 'estimated_input_noise' in results.keys() else None 230 | cond = results['estimated_conditioning'] if 'estimated_conditioning' in results.keys() else None 231 | cond_scale = results['guidance_scale'] if 'guidance_scale' in results.keys() else None 232 | 233 | output = (noise, cond, cond_scale) 234 | if return_result_dict: 235 | output = (*output, results) 236 | return output 237 | 238 | 239 | 240 | def check_inversion_done(path_to_image_or_file_id, subfolder, inversion_root_folder = "./results/experiments/inversion/"): 241 | if path_to_image_or_file_id.endswith(('.jpg','.png','.jpeg', 'JPG', 'JPEG')): 242 | file_id = extract_file_id_from_path(path_to_image_or_file_id) 243 | else: 244 | file_id = path_to_image_or_file_id 245 | print(f'Checking: {os.path.join(inversion_root_folder, file_id, subfolder,"results.pkl")}') 246 | return os.path.exists(os.path.join(inversion_root_folder, file_id, subfolder,'results.pkl')) 247 | 248 | ####################################################################### 249 | # Others 250 | 251 | 252 | def save_results2pickle(path2save, results): 253 | os.makedirs(path2save, exist_ok=True) 254 | with open(os.path.join(path2save, 'results.pkl') ,'wb') as f: 255 | pkl.dump(results, f) 256 | 257 | 258 | 259 | def check_and_run_inversion(model, file_id, subfolder, config, tokens = True): 260 | if not check_inversion_done(file_id, subfolder): 261 | from ddim_invertor import DDIMInvertor 262 | invertor = DDIMInvertor(config, model) 263 | if tokens: 264 | output = invertor.perform_cond_inversion_individual_timesteps(file_id2im_path(file_id), None, optimize_tokens=True) 265 | else: 266 | output = invertor.perform_inversion(file_id2im_path(file_id), None, init_noise_init = None, loss_weights= {'latents': 1. 
, 'pixels':1.} ) 267 | 268 | export_path = os.path.join(config.path2save_prefix, file_id, subfolder) 269 | save_results2pickle(export_path, output) 270 | print(f'Inversion done') -------------------------------------------------------------------------------- /DiffusionImageAnalogies.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "fd6ef73c", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/subrtade/.local/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n", 14 | " warn(f\"Failed to load image Python extension: {e}\")\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "from omegaconf import OmegaConf\n", 20 | "import numpy as np\n", 21 | "import os\n", 22 | "\n", 23 | "\n", 24 | "from ldm.models.diffusion.ddim import DDIMSampler\n", 25 | "from analogy_creator import AnalogyCreator\n", 26 | "import utils" 27 | ] 28 | }, 29 | { 30 | "attachments": {}, 31 | "cell_type": "markdown", 32 | "id": "69519f0c", 33 | "metadata": {}, 34 | "source": [ 35 | "# Presets\n" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "id": "44029400", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "path_to_data = './dataset/data/'\n", 46 | "out_subfolder = 'notebook_analogies'\n", 47 | "experiment_root = './results'\n", 48 | "visualize_tokens = True\n", 49 | "# guidance scales\n", 50 | "scales = [1.,2., 3., 5., 7., 9., 12.]\n", 51 | "# analogy strength step\n", 52 | "steps = np.linspace(0, 3, 20)" 53 | ] 54 | }, 55 | { 56 | "attachments": {}, 57 | "cell_type": "markdown", 58 | "id": "65fc40a3", 59 | "metadata": {}, 60 | "source": [ 61 | "# Initialization" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "id": "bb5faf5b", 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stderr", 72 | "output_type": "stream", 73 | "text": [ 74 | "Global seed set to 42\n" 75 | ] 76 | }, 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "Loading model...\n", 82 | "Loading model from stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt\n", 83 | "Global Step: 194366\n", 84 | "LatentDiffusion: Running in eps-prediction mode\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "config = OmegaConf.load('./config/parameter_estimation.yaml')\n", 90 | "\n", 91 | "\n", 92 | "print('Loading model...')\n", 93 | "model = utils.prepare_default_model()\n", 94 | "print('Model loaded')\n", 95 | "\n", 96 | "\n", 97 | "export_path = os.path.join(experiment_root, 'analogy_results', out_subfolder)\n", 98 | "os.makedirs(export_path, exist_ok = True)\n", 99 | "\n", 100 | "token_subfolder = 'tokens_dia_test'\n", 101 | "subfolder = 'noise_dia_test'\n", 102 | "\n", 103 | "\n", 104 | "ddim_sampler = DDIMSampler(model)\n", 105 | "\n", 106 | "analogy_creator = AnalogyCreator(config, ddim_sampler, subfolder, token_subfolder, os.path.join(log_config.experiment_root, 'analogy_results', out_subfolder))\n" 107 | ] 108 | }, 109 | { 110 | "attachments": {}, 111 | "cell_type": "markdown", 112 | "id": "e4a7a064", 113 | "metadata": {}, 114 | "source": [ 115 | "# Dataset" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "7738b083", 122 | "metadata": {}, 123 | "outputs": [], 124 | 
"source": [ 125 | "import ipyplot\n", 126 | "\n", 127 | "im_path = lambda x: os.path.join(path_to_data, x)\n", 128 | "\n", 129 | "\n", 130 | "all_images = sorted(utils.load_all_image_names())\n", 131 | "labels = [utils.extract_file_id_from_path(x) for x in all_images] \n", 132 | "\n", 133 | "\n", 134 | "images_list = [im_path(x) for x in all_images]\n", 135 | "ipyplot.plot_images(images_list, labels = labels, max_images=200, img_width=100)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "id": "5627667a", 141 | "metadata": {}, 142 | "source": [ 143 | "### Triplet Selection" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "c1035cde", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "file_id_A = '000203'\n", 154 | "file_id_A_prime = '000204'\n", 155 | "file_id_B = '000226'\n", 156 | "\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "id": "0c40cee5", 162 | "metadata": {}, 163 | "source": [ 164 | "# Analogies" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "5c6e0b3f", 170 | "metadata": {}, 171 | "source": [ 172 | "### Check for inverted CLIP features and noise\n", 173 | "The parameters will be estimated, if missing." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "89d6aec6", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "\n", 184 | "# check for CLIP features\n", 185 | "utils.check_and_run_inversion(model, file_id_A, token_subfolder, config, tokens=True)\n", 186 | "\n", 187 | "utils.check_and_run_inversion(model,file_id_A_prime, token_subfolder, config, tokens=True)\n", 188 | "\n", 189 | "utils.check_and_run_inversion(model,file_id_B, token_subfolder, config, tokens=True)\n", 190 | "\n", 191 | "\n", 192 | "# check for noise of image B\n", 193 | "config.token_subfolder = token_subfolder\n", 194 | "utils.check_and_run_inversion(model, file_id_B, subfolder, config, tokens=False)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "id": "0a49d5fb", 200 | "metadata": {}, 201 | "source": [ 202 | "### Performing analogies\n", 203 | "Generating grid of analogies for given triplet, list of guidance scales \\sigma_i and analogy strengths \\lambda_j." 
204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "ca17cd4f", 210 | "metadata": { 211 | "scrolled": true 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "\n", 216 | "\n", 217 | "\n", 218 | "print(f'Processing triplet: ({file_id_A}, {file_id_A_prime}, {file_id_B})')\n", 219 | "\n", 220 | "triplet = (utils.file_id2im_path(file_id_A), utils.file_id2im_path(file_id_A_prime), utils.file_id2im_path(file_id_B))\n", 221 | "triplet_code = f'{file_id_A}_{file_id_A_prime}_{file_id_B}'\n", 222 | "\n", 223 | "# cA = model.get_learned_conditioning('manual prompt for A')\n", 224 | "# cAprime = model.get_learned_conditioning(\"manual prompt for A'\")\n", 225 | "\n", 226 | "cA = analogy_creator.fetch_cond_matrix(file_id_A)\n", 227 | "cAprime = analogy_creator.fetch_cond_matrix(file_id_A_prime)\n", 228 | "\n", 229 | "noiseB,_,_ = utils.load_inversion_result_dict(file_id_B, subfolder, return_result_dict=False)\n", 230 | "\n", 231 | "cB = analogy_creator.fetch_cond_matrix(file_id_B)\n", 232 | "\n", 233 | "\n", 234 | "analogy_creator.make_analogy_from_args(triplet_code, cA, cAprime, cB, noiseB, steps, scales)\n", 235 | "\n" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "id": "1eec85b4", 241 | "metadata": {}, 242 | "source": [ 243 | "### Analogy Results\n", 244 | "Generated grid of analogies is shown. Each row corresponds to the guidance scale \\sigma_i, each column corresponds to analogy strength \\lambda_j." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "id": "c51a5c97", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "from IPython.display import Image\n", 255 | "Image(filename=os.path.join(export_path,\n", 256 | " 'grids',\n", 257 | " f'{triplet_code}_analogy_grid.jpg' )) " 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "id": "ae0a069f", 263 | "metadata": {}, 264 | "source": [ 265 | "#### Visualize particular result for given \\sigma_i and \\lambda_j." 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "765c51aa", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "row = 3 # [0-len(scales)]\n", 276 | "col = 10 # [0-len(steps)]\n", 277 | "current_scale = scales[row]\n", 278 | "current_step = steps[col]\n", 279 | "print(f'scale: {current_scale} | step: {current_step}')\n", 280 | "Image(filename=os.path.join(export_path,\n", 281 | " triplet_code,\n", 282 | " f'analogy_sc={current_scale}_shift_strength={current_step}.jpg' )) " 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "id": "adc55236", 288 | "metadata": {}, 289 | "source": [ 290 | "### Token visualization\n", 291 | "Estimated CLIP features are paired with random noise images (and transformed via the reverse diffusion process) to visualize the captured concepts. 
" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "id": "0ea2cb29", 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "\n", 302 | "if visualize_tokens:\n", 303 | " from torchvision.utils import save_image\n", 304 | " os.makedirs(os.path.join(experiment_root,'token_visualization'),exist_ok=True)\n", 305 | " def gen_random_samples(cond, file_id):\n", 306 | " tokens_,_ = analogy_creator.ddim_sampler.sample(\n", 307 | " analogy_creator.config.ddim_steps,\n", 308 | " 8,\n", 309 | " analogy_creator.config.shape,\n", 310 | " conditioning = cond.expand(8,-1,-1),\n", 311 | " eta=analogy_creator.config.ddim_eta,\n", 312 | " unconditional_guidance_scale=sc,\n", 313 | " unconditional_conditioning=analogy_creator.uc.expand(8,-1,-1),\n", 314 | " )\n", 315 | " utils.save_latent_as_image(\n", 316 | " analogy_creator.ddim_sampler.model, \n", 317 | " tokens_,\n", 318 | " os.path.join(experiment_root,'token_visualization',f'{file_id}.jpg')\n", 319 | " )\n", 320 | "\n", 321 | " from IPython.display import Image\n", 322 | " gen_random_samples(cA, file_id_A)\n", 323 | " Image(filename=os.path.join(experiment_root,'token_visualization',f'{file_id_A}.jpg')) \n", 324 | " gen_random_samples(cAprime, file_id_A_prime)\n", 325 | " Image(filename=os.path.join(experiment_root,'token_visualization',f'{file_id_A_prime}.jpg')) \n", 326 | " gen_random_samples(cB, file_id_B)\n", 327 | " Image(filename=os.path.join(experiment_root,'token_visualization',f'{file_id_B}.jpg')) \n", 328 | "\n" 329 | ] 330 | } 331 | ], 332 | "metadata": { 333 | "kernelspec": { 334 | "display_name": "analogies", 335 | "language": "python", 336 | "name": "analogies" 337 | }, 338 | "language_info": { 339 | "codemirror_mode": { 340 | "name": "ipython", 341 | "version": 3 342 | }, 343 | "file_extension": ".py", 344 | "mimetype": "text/x-python", 345 | "name": "python", 346 | "nbconvert_exporter": "python", 347 | "pygments_lexer": "ipython3", 348 | "version": "3.10.4" 349 | } 350 | }, 351 | "nbformat": 4, 352 | "nbformat_minor": 5 353 | } 354 | -------------------------------------------------------------------------------- /ddim_invertor.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTokenizer, CLIPModel, CLIPProcessor 2 | from ldm.modules.diffusionmodules.util import noise_like 3 | from ldm.models.diffusion.ddim import DDIMSampler 4 | import torchvision.transforms as T 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | import gc 9 | 10 | from modified_clip_transformers import ModifiedCLIPTextModel 11 | import utils 12 | 13 | 14 | class DDIMInvertor(): 15 | def __init__(self, config, model, tokenizer=None) -> None: 16 | self.config = config 17 | self.ddim_sampler = DDIMSampler(model) 18 | self.ddim_sampler.make_schedule(self.config.ddim_steps, ddim_eta=self.config.ddim_eta, verbose=False) 19 | self.uc = self.ddim_sampler.model.get_learned_conditioning(['']) 20 | self.tokenizer = tokenizer 21 | 22 | 23 | def __sample_differentiable(self, cond, shape, 24 | x_T=None, ddim_use_original_steps=False, 25 | timesteps=None, unconditional_guidance_scale=1., 26 | unconditional_conditioning=None, 27 | ): 28 | b = cond.shape[0] 29 | if x_T is None: 30 | img = torch.randn(shape, device=self.config.device) 31 | else: 32 | img = x_T 33 | if timesteps is None: 34 | timesteps = self.ddim_sampler.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_sampler.ddim_timesteps 35 | elif timesteps is not None and not 
ddim_use_original_steps: 36 | subset_end = int(min(timesteps / self.ddim_sampler.ddim_timesteps.shape[0], 1) * self.ddim_sampler.ddim_timesteps.shape[0]) - 1 37 | timesteps = self.ddim_sampler.ddim_timesteps[:subset_end] 38 | 39 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 40 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 41 | 42 | for i, step in enumerate(time_range): 43 | index = total_steps - i - 1 44 | ts = torch.full((b,), step, device=self.config.device, dtype=torch.long) 45 | 46 | 47 | outs = self.__step_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 48 | unconditional_guidance_scale=unconditional_guidance_scale, 49 | unconditional_conditioning=unconditional_conditioning) 50 | img, pred_x0 = outs 51 | return img 52 | 53 | 54 | def __step_ddim(self, x, c, t, index, use_original_steps=False, 55 | unconditional_guidance_scale=1., unconditional_conditioning=None): 56 | b, *_, device = *x.shape, x.device 57 | 58 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 59 | e_t = self.ddim_sampler.model.apply_model(x, t, c) 60 | else: 61 | x_in = torch.cat([x] * 2) 62 | t_in = torch.cat([t] * 2) 63 | c_in = torch.cat([unconditional_conditioning, c]) 64 | e_t_uncond, e_t = self.ddim_sampler.model.apply_model(x_in, t_in, c_in).chunk(2) 65 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 66 | 67 | 68 | alphas = self.ddim_sampler.model.alphas_cumprod if use_original_steps else self.ddim_sampler.ddim_alphas 69 | alphas_prev = self.ddim_sampler.model.alphas_cumprod_prev if use_original_steps else self.ddim_sampler.ddim_alphas_prev 70 | sqrt_one_minus_alphas = self.ddim_sampler.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sampler.ddim_sqrt_one_minus_alphas 71 | sigmas = self.ddim_sampler.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sampler.ddim_sigmas 72 | # select parameters corresponding to the currently considered timestep 73 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 74 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 75 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 76 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 77 | 78 | # current prediction for x_0 79 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 80 | 81 | # direction pointing to x_t 82 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 83 | noise = sigma_t * noise_like(x.shape, device, False) 84 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 85 | return x_prev, pred_x0 86 | 87 | 88 | def perform_inversion(self, image, cond, init_noise_init = None, loss_weights = {'latents': 1. , 'pixels':1.} ): 89 | if cond is None: 90 | with torch.no_grad(): 91 | cond_out = utils.load_estimated_cond(utils.extract_file_id_from_path(image), token_subfolder=self.config.token_subfolder) 92 | assert cond_out is not None, 'Token inversion was not found...' 
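            # the conditioning is reused from the token inversion (results.pkl in the token subfolder)
            # and mapped through the CLIP text transformer before the input noise is optimized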
93 | cond = self.__tokens2conditioning(cond_out) 94 | 95 | target_img = utils.load_pil(image) 96 | target_img = target_img.resize((self.config.shape[-2] * self.config.f, self.config.shape[-1] * self.config.f)) 97 | target_img = utils.pil2torch(target_img) 98 | target_latent = utils.img2latent(self.ddim_sampler.model, target_img) 99 | target_latent = target_latent.to(self.ddim_sampler.model.device) 100 | target_img = target_img.to(self.ddim_sampler.model.device) 101 | 102 | 103 | if init_noise_init is None: 104 | alpha_t = torch.tensor([self.ddim_sampler.ddim_alphas[-1]]).cuda() 105 | init_noise = torch.sqrt(alpha_t) * target_latent + torch.sqrt(1. - alpha_t) * torch.randn_like(target_latent).to(target_latent.device) 106 | 107 | uc_scale = self.config.noise_optimization.uncond_guidance_scale 108 | 109 | init_noise.requires_grad = True 110 | 111 | lbfgs = torch.optim.LBFGS(params = [init_noise], lr = self.config.noise_optimization.lr) 112 | loss_fn = torch.nn.functional.mse_loss 113 | 114 | shape = [self.config.noise_optimization.batch_size, * self.config.shape] 115 | 116 | 117 | progress = {'loss':[]} 118 | progress['noise'] = [] 119 | 120 | pbar = tqdm(range(self.config.noise_optimization.opt_iters)) 121 | for i in pbar: 122 | def closure_(): 123 | lbfgs.zero_grad() 124 | x0_prediction = self.__sample_differentiable(cond, shape, 125 | x_T=init_noise, unconditional_guidance_scale=uc_scale, 126 | unconditional_conditioning= self.uc) 127 | 128 | loss = loss_weights['latents'] * loss_fn(x0_prediction, target_latent, reduction = 'mean') 129 | if loss_weights['pixels'] != 0: 130 | loss += loss_weights['pixels'] * utils.pixel_space_loss(self.ddim_sampler.model, x0_prediction, target_img, loss_fn) 131 | loss.backward() 132 | return loss.detach().item() 133 | 134 | 135 | x0_prediction = self.__sample_differentiable(cond, shape, 136 | x_T=init_noise, unconditional_guidance_scale=uc_scale, 137 | unconditional_conditioning= self.uc) 138 | 139 | loss = loss_weights['latents'] * loss_fn(x0_prediction, target_latent, reduction = 'mean') 140 | if loss_weights['pixels'] != 0: 141 | loss += loss_weights['pixels'] * utils.pixel_space_loss(self.ddim_sampler.model, x0_prediction, target_img, loss_fn) 142 | 143 | if i % self.config.noise_optimization.log_every == 0: 144 | progress['loss'].append(loss.item()) 145 | progress['noise'].append(init_noise.detach().cpu()) 146 | 147 | pbar.set_postfix({'loss': loss.cpu().item()}) 148 | 149 | if loss.item() < self.config.sufficient_loss: 150 | print(f'Ending computation with {loss.item()} done {i} steps.') 151 | break 152 | 153 | lbfgs.zero_grad() 154 | loss.backward() 155 | lbfgs.step(closure_) 156 | 157 | outputs = { 158 | 'estimated_input_noise': init_noise.detach(), 159 | 'estimated_conditioning': cond , 160 | 'initial_noise': init_noise_init, 161 | 'target_image_latent': target_latent, 162 | 'path2img': image, 163 | 'config_dict': self.config, 164 | 'reconstruction': x0_prediction.detach(), 165 | 'progress': progress, 166 | 'guidance_scale': uc_scale , 167 | } 168 | 169 | return outputs 170 | 171 | 172 | # taken from stable diffusion 173 | def add_noise(self, x0, noise, timestep_indices, ddim_use_original_steps=False): 174 | device= x0.device 175 | 176 | alphas_cumprod = self.ddim_sampler.model.ddim_alphas if ddim_use_original_steps else self.ddim_sampler.ddim_alphas 177 | sqrt_one_minus_alphas = self.ddim_sampler.model.ddim_sqrt_one_minus_alphas if ddim_use_original_steps else self.ddim_sampler.ddim_sqrt_one_minus_alphas 178 | sqrt_alphas_cumprod = 
torch.sqrt(alphas_cumprod).to(device) 179 | sqrt_one_minus_alphas = sqrt_one_minus_alphas.to(device) 180 | 181 | 182 | timestep_indices = timestep_indices.to(device) 183 | noise = noise.to(device) 184 | 185 | sqrt_at = torch.index_select(sqrt_alphas_cumprod, 0, timestep_indices).view(-1, 1, 1, 1).to(device) 186 | sqrt_one_minus_at = torch.index_select(sqrt_one_minus_alphas, 0, timestep_indices).view(-1, 1, 1, 1).to(device) 187 | 188 | noisy_samples = sqrt_at * x0.expand_as(noise) + sqrt_one_minus_at * noise 189 | return noisy_samples 190 | 191 | 192 | def ___prepare_batch_for_im(self, image): 193 | target_img = utils.load_pil(image) 194 | target_img = target_img.resize((self.config.shape[-2] * self.config.f, self.config.shape[-1] * self.config.f)) 195 | # # create batch 196 | hflipper = T.RandomHorizontalFlip(p=1) 197 | resize_cropper = T.RandomResizedCrop(size=(512, 512), scale = (0.85, 0.99),ratio=(1,1)) 198 | resized_crops = [resize_cropper(target_img) for _ in range(max(0, self.config.conditioning_optimization.batch_size - 2))] 199 | if self.config.conditioning_optimization.batch_size == 1: 200 | transformed_imgs = [target_img] 201 | else: 202 | transformed_imgs = [target_img, hflipper(target_img), *resized_crops] 203 | 204 | target_img = utils.pil2torch_batch(transformed_imgs) 205 | target_latent = utils.img2latent(self.ddim_sampler.model, target_img) 206 | return target_img, target_latent 207 | 208 | def __load_tokenizer_and_text_model(self, init_caption, tokenizer = None): 209 | version = 'openai/clip-vit-large-patch14' 210 | if tokenizer is None: 211 | tokenizer = CLIPTokenizer.from_pretrained(version) 212 | self.tokenizer = tokenizer 213 | 214 | if init_caption is None: 215 | return tokenizer, None 216 | batch_encoding = tokenizer(init_caption, truncation=True, max_length=77, return_length=True, 217 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 218 | 219 | 220 | embeddings = self.ddim_sampler.model.cond_stage_model.transformer.get_input_embeddings().weight.data[batch_encoding['input_ids'][0]] 221 | text_tokens = embeddings.clone() 222 | text_tokens.requires_grad = True 223 | 224 | return tokenizer, text_tokens 225 | 226 | def __tokens2conditioning(self, tokens): 227 | conditioning = self.ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = tokens.unsqueeze(0))['last_hidden_state'] 228 | return conditioning 229 | 230 | 231 | 232 | def perform_cond_inversion_individual_timesteps(self, image_path, cond_init , optimize_tokens = True): 233 | self.config['optimize_tokens'] = optimize_tokens 234 | with torch.no_grad(): 235 | _, target_latent = self.___prepare_batch_for_im(image_path) 236 | timesteps = torch.tensor(self.ddim_sampler.ddim_timesteps) 237 | 238 | if optimize_tokens: 239 | tokenizer, text_tokens = self.__load_tokenizer_and_text_model('', tokenizer = self.tokenizer) 240 | if cond_init is not None: 241 | text_tokens = cond_init.squeeze(0) 242 | 243 | prompt_repre = text_tokens.detach().clone() 244 | 245 | grad_mask = torch.zeros_like(prompt_repre) 246 | grad_mask[:self.config.conditioning_optimization.N_tokens,:] = 1. 
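# only the first N_tokens rows of the prompt embedding are optimized; the mask zeroes the remaining gradients after backward()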
247 | grad_mask = grad_mask.to(self.ddim_sampler.model.device) 248 | fetch_cond_init = lambda x: self.ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = x.unsqueeze(0))['last_hidden_state'] 249 | prompt_repre.requires_grad = True 250 | 251 | uc_scale = self.config.conditioning_optimization.uncond_guidance_scale 252 | 253 | adam = torch.optim.AdamW(params = [prompt_repre], lr = self.config.conditioning_optimization.lr) 254 | loss_fn = torch.nn.functional.mse_loss 255 | 256 | 257 | progress = {'loss':[], 'indices':[]} 258 | progress['cond'] = [] 259 | 260 | timestep_indices = torch.randperm(self.config.conditioning_optimization.batch_size).view(-1).long() 261 | print(f'Selected timesteps: {timestep_indices}') 262 | 263 | 264 | pbar = tqdm(range(self.config.conditioning_optimization.opt_iters)) 265 | for i in pbar: 266 | 267 | noise_ = torch.randn_like(target_latent) 268 | 269 | if not self.config.conditioning_optimization.fixed_timesteps: 270 | timestep_indices = torch.randint(low=0, high=self.config.ddim_steps, size=(self.config.conditioning_optimization.batch_size,1) ).view(-1) 271 | 272 | noisy_samples = self.add_noise(target_latent, noise_, timestep_indices, ddim_use_original_steps=False) 273 | 274 | steps_in = torch.index_select(timesteps, 0, timestep_indices).to(self.config.device) 275 | cond_init = fetch_cond_init(prompt_repre) 276 | 277 | 278 | noise_prediction = self.ddim_sampler.model.apply_model(noisy_samples, steps_in, cond_init.expand(self.config.conditioning_optimization.batch_size, -1 , -1)) 279 | 280 | loss = loss_fn(noise_prediction, noise_, reduction = 'none').mean((1,2,3)).mean() 281 | 282 | 283 | if i % self.config.conditioning_optimization.log_every == 0: 284 | progress['indices'].append(timestep_indices) 285 | progress['loss'].append(loss.item()) 286 | progress['cond'].append(prompt_repre.detach().cpu()) 287 | 288 | 289 | pbar.set_postfix({'loss': loss.cpu().item(), 'indices':timestep_indices}) 290 | 291 | if loss.item() < self.config.sufficient_loss: 292 | print(f'Ending computation with {loss.item()} done {i} steps.') 293 | break 294 | 295 | adam.zero_grad() 296 | loss.backward() 297 | prompt_repre.grad *= grad_mask 298 | adam.step() 299 | 300 | outputs = { 301 | 'estimated_conditioning': prompt_repre.detach(), 302 | 'target_image_latent': target_latent, 303 | 'config_dict': self.config, 304 | 'optimize_tokens': optimize_tokens, 305 | 'progress': progress, 306 | 'guidance_scale': uc_scale , 307 | } 308 | 309 | return outputs 310 | 311 | 312 | -------------------------------------------------------------------------------- /modified_clip_transformers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | import torch 3 | import torch.nn as nn 4 | from transformers.models.clip.configuration_clip import CLIPTextConfig, CLIPConfig, CLIPVisionConfig 5 | from transformers.modeling_outputs import BaseModelOutputWithPooling 6 | from transformers.models.clip.modeling_clip import _expand_mask, CLIPTextEmbeddings, CLIPEncoder, CLIPPreTrainedModel, CLIPOutput,CLIPTextTransformer, CLIPVisionTransformer 7 | 8 | 9 | # code taken from hugging face transformers library 10 | # https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/clip/modeling_clip.py 11 | # modified lines are marked 12 | 13 | class ModifiedCLIPTextTransformer(CLIPTextTransformer): 14 | def __init__(self, config: CLIPTextConfig): 15 | super().__init__(config) 16 | 17 | 
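# forward() below follows CLIPTextTransformer.forward from the pinned transformers version (4.19.2), with two changes:
# it accepts precomputed inputs_embeds so optimized token embeddings can bypass the token-id lookup,
# and it returns pooler_output=None instead of pooling at the EOT token.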
def forward( 18 | self, 19 | input_ids: Optional[torch.Tensor] = None, 20 | attention_mask: Optional[torch.Tensor] = None, 21 | position_ids: Optional[torch.Tensor] = None, 22 | output_attentions: Optional[bool] = None, 23 | output_hidden_states: Optional[bool] = None, 24 | inputs_embeds: Optional[torch.Tensor] = None, 25 | return_dict: Optional[bool] = None, 26 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 27 | r""" 28 | Returns: 29 | """ 30 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 31 | output_hidden_states = ( 32 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 33 | ) 34 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 35 | 36 | # MODIFIED 37 | if input_ids is None and inputs_embeds is None: 38 | raise ValueError("You have to specify either input_ids or inputs_embeds") 39 | 40 | if input_ids is None: 41 | input_shape = inputs_embeds.size() 42 | bsz, seq_len, dim = input_shape 43 | else: 44 | input_shape = input_ids.size() 45 | bsz, seq_len = input_shape 46 | input_ids = input_ids.view(-1, input_shape[-1]) 47 | 48 | 49 | hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds ) 50 | ########## 51 | 52 | 53 | # CLIP's text model uses causal mask, prepare it here. 54 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 55 | causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( 56 | hidden_states.device 57 | ) 58 | # expand attention_mask 59 | if attention_mask is not None: 60 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 61 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype) 62 | 63 | encoder_outputs = self.encoder( 64 | inputs_embeds=hidden_states, 65 | attention_mask=attention_mask, 66 | causal_attention_mask=causal_attention_mask, 67 | output_attentions=output_attentions, 68 | output_hidden_states=output_hidden_states, 69 | return_dict=return_dict, 70 | ) 71 | 72 | last_hidden_state = encoder_outputs[0] 73 | last_hidden_state = self.final_layer_norm(last_hidden_state) 74 | 75 | # text_embeds.shape = [batch_size, sequence_length, transformer.width] 76 | # take features from the eot embedding (eot_token is the highest number in each sequence) 77 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 78 | pooled_output = None 79 | 80 | if not return_dict: 81 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 82 | 83 | return BaseModelOutputWithPooling( 84 | last_hidden_state=last_hidden_state, 85 | pooler_output=pooled_output, 86 | hidden_states=encoder_outputs.hidden_states, 87 | attentions=encoder_outputs.attentions, 88 | ) 89 | 90 | def _build_causal_attention_mask(self, bsz, seq_len, dtype): 91 | # lazily create causal attention mask, with full attention between the vision tokens 92 | # pytorch uses additive attention mask; fill with -inf 93 | mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) 94 | mask.fill_(torch.tensor(torch.finfo(dtype).min)) 95 | mask.triu_(1) # zero out the lower diagonal 96 | mask = mask.unsqueeze(1) # expand mask 97 | return mask 98 | 99 | 100 | 101 | class ModifiedCLIPTextModel(CLIPPreTrainedModel): 102 | config_class = CLIPTextConfig 103 | 104 | _no_split_modules = ["CLIPEncoderLayer"] 105 | 106 | def __init__(self, config: CLIPTextConfig): 107 | super().__init__(config) 108 | # MODIFIED 109 | 
self.text_model = ModifiedCLIPTextTransformer(config) 110 | ########## 111 | # Initialize weights and apply final processing 112 | self.post_init() 113 | 114 | def get_input_embeddings(self) -> nn.Module: 115 | return self.text_model.embeddings.token_embedding 116 | 117 | def set_input_embeddings(self, value): 118 | self.text_model.embeddings.token_embedding = value 119 | 120 | 121 | def forward( 122 | self, 123 | input_ids: Optional[torch.Tensor] = None, 124 | attention_mask: Optional[torch.Tensor] = None, 125 | position_ids: Optional[torch.Tensor] = None, 126 | output_attentions: Optional[bool] = None, 127 | output_hidden_states: Optional[bool] = None, 128 | inputs_embeds: Optional[torch.Tensor] = None, 129 | return_dict: Optional[bool] = None, 130 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 131 | r""" 132 | Returns: 133 | Examples: 134 | ```python 135 | >>> from transformers import CLIPTokenizer, CLIPTextModel 136 | >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") 137 | >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") 138 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 139 | >>> outputs = model(**inputs) 140 | >>> last_hidden_state = outputs.last_hidden_state 141 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states 142 | ```""" 143 | 144 | 145 | # MODIFIED 146 | return self.text_model( 147 | input_ids=input_ids, 148 | attention_mask=attention_mask, 149 | position_ids=position_ids, 150 | output_attentions=output_attentions, 151 | output_hidden_states=output_hidden_states, 152 | return_dict=return_dict, 153 | inputs_embeds=inputs_embeds, 154 | ) 155 | ########## 156 | 157 | 158 | 159 | 160 | 161 | 162 | # class ModifiedCLIPModel(CLIPPreTrainedModel): 163 | # config_class = CLIPConfig 164 | 165 | # def __init__(self, config: CLIPConfig): 166 | # super().__init__(config) 167 | 168 | # if not isinstance(config.text_config, CLIPTextConfig): 169 | # raise ValueError( 170 | # "config.text_config is expected to be of type CLIPTextConfig but is of type" 171 | # f" {type(config.text_config)}." 172 | # ) 173 | 174 | # if not isinstance(config.vision_config, CLIPVisionConfig): 175 | # raise ValueError( 176 | # "config.vision_config is expected to be of type CLIPVisionConfig but is of type" 177 | # f" {type(config.vision_config)}." 
178 | # ) 179 | 180 | # text_config = config.text_config 181 | # vision_config = config.vision_config 182 | 183 | # self.projection_dim = config.projection_dim 184 | # self.text_embed_dim = text_config.hidden_size 185 | # self.vision_embed_dim = vision_config.hidden_size 186 | 187 | # self.text_model = ModifiedCLIPTextTransformer(text_config) 188 | # self.vision_model = CLIPVisionTransformer(vision_config) 189 | 190 | # self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) 191 | # self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) 192 | # self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) 193 | 194 | # # Initialize weights and apply final processing 195 | # self.post_init() 196 | 197 | # def get_text_features( 198 | # self, 199 | # input_ids: Optional[torch.Tensor] = None, 200 | # attention_mask: Optional[torch.Tensor] = None, 201 | # position_ids: Optional[torch.Tensor] = None, 202 | # output_attentions: Optional[bool] = None, 203 | # output_hidden_states: Optional[bool] = None, 204 | # return_dict: Optional[bool] = None, 205 | # ) -> torch.FloatTensor: 206 | # r""" 207 | # Returns: 208 | # text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by 209 | # applying the projection layer to the pooled output of [`CLIPTextModel`]. 210 | # Examples: 211 | # ```python 212 | # >>> from transformers import CLIPTokenizer, CLIPModel 213 | # >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 214 | # >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") 215 | # >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 216 | # >>> text_features = model.get_text_features(**inputs) 217 | # ```""" 218 | # # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 219 | # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 220 | # output_hidden_states = ( 221 | # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 222 | # ) 223 | # return_dict = return_dict if return_dict is not None else self.config.use_return_dict 224 | 225 | # text_outputs = self.text_model( 226 | # input_ids=input_ids, 227 | # attention_mask=attention_mask, 228 | # position_ids=position_ids, 229 | # output_attentions=output_attentions, 230 | # output_hidden_states=output_hidden_states, 231 | # return_dict=return_dict, 232 | # ) 233 | 234 | # # TODO: fail pooled output je None 235 | # pooled_output = text_outputs[1] 236 | # text_features = self.text_projection(pooled_output) 237 | 238 | # return text_features 239 | 240 | # def get_image_features( 241 | # self, 242 | # pixel_values: Optional[torch.FloatTensor] = None, 243 | # output_attentions: Optional[bool] = None, 244 | # output_hidden_states: Optional[bool] = None, 245 | # return_dict: Optional[bool] = None, 246 | # ) -> torch.FloatTensor: 247 | # r""" 248 | # Returns: 249 | # image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by 250 | # applying the projection layer to the pooled output of [`CLIPVisionModel`]. 
251 | # Examples: 252 | # ```python 253 | # >>> from PIL import Image 254 | # >>> import requests 255 | # >>> from transformers import CLIPProcessor, CLIPModel 256 | # >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 257 | # >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 258 | # >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 259 | # >>> image = Image.open(requests.get(url, stream=True).raw) 260 | # >>> inputs = processor(images=image, return_tensors="pt") 261 | # >>> image_features = model.get_image_features(**inputs) 262 | # ```""" 263 | # # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 264 | # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 265 | # output_hidden_states = ( 266 | # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 267 | # ) 268 | # return_dict = return_dict if return_dict is not None else self.config.use_return_dict 269 | 270 | # vision_outputs = self.vision_model( 271 | # pixel_values=pixel_values, 272 | # output_attentions=output_attentions, 273 | # output_hidden_states=output_hidden_states, 274 | # return_dict=return_dict, 275 | # ) 276 | 277 | # pooled_output = vision_outputs[1] # pooled_output 278 | # image_features = self.visual_projection(pooled_output) 279 | 280 | # return image_features 281 | 282 | # # MODIFIED 283 | 284 | # def forward( 285 | # self, 286 | # input_ids: Optional[torch.LongTensor] = None, 287 | # pixel_values: Optional[torch.FloatTensor] = None, 288 | # attention_mask: Optional[torch.Tensor] = None, 289 | # position_ids: Optional[torch.LongTensor] = None, 290 | # return_loss: Optional[bool] = None, 291 | # output_attentions: Optional[bool] = None, 292 | # output_hidden_states: Optional[bool] = None, 293 | # return_dict: Optional[bool] = None, 294 | # inputs_embeds: Optional[torch.FloatTensor] = None, 295 | # ) -> Union[Tuple, CLIPOutput]: 296 | # r""" 297 | # Returns: 298 | # Examples: 299 | # ```python 300 | # >>> from PIL import Image 301 | # >>> import requests 302 | # >>> from transformers import CLIPProcessor, CLIPModel 303 | # >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 304 | # >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 305 | # >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 306 | # >>> image = Image.open(requests.get(url, stream=True).raw) 307 | # >>> inputs = processor( 308 | # ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True 309 | # ... ) 310 | # >>> outputs = model(**inputs) 311 | # >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score 312 | # >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities 313 | # ```""" 314 | # # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
315 | # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 316 | # output_hidden_states = ( 317 | # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 318 | # ) 319 | # return_dict = return_dict if return_dict is not None else self.config.use_return_dict 320 | 321 | # vision_outputs = self.vision_model( 322 | # pixel_values=pixel_values, 323 | # output_attentions=output_attentions, 324 | # output_hidden_states=output_hidden_states, 325 | # return_dict=return_dict, 326 | # ) 327 | 328 | # text_outputs = self.text_model( 329 | # input_ids=input_ids, 330 | # attention_mask=attention_mask, 331 | # position_ids=position_ids, 332 | # output_attentions=output_attentions, 333 | # output_hidden_states=output_hidden_states, 334 | # return_dict=return_dict, 335 | # inputs_embeds=inputs_embeds, 336 | 337 | # ) 338 | 339 | # image_embeds = vision_outputs[1] 340 | 341 | # image_embeds = self.visual_projection(image_embeds) 342 | 343 | # text_embeds = text_outputs[0].squeeze(0) 344 | # text_embeds = self.text_projection(text_embeds) 345 | 346 | # # normalized features 347 | # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) 348 | # text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) 349 | 350 | 351 | # # cosine similarity as logits 352 | # logit_scale = self.logit_scale.exp() 353 | # logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale 354 | # print(logits_per_text.shape) 355 | # logits_per_image = logits_per_text.t() 356 | 357 | # loss = None 358 | # if return_loss: 359 | # loss = clip_loss(logits_per_text) 360 | 361 | # if not return_dict: 362 | # output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) 363 | # return ((loss,) + output) if loss is not None else output 364 | 365 | # return CLIPOutput( 366 | # loss=loss, 367 | # logits_per_image=logits_per_image, 368 | # logits_per_text=logits_per_text, 369 | # text_embeds=text_embeds, 370 | # image_embeds=image_embeds, 371 | # text_model_output=text_outputs, 372 | # vision_model_output=vision_outputs, 373 | # ) 374 | # ########## --------------------------------------------------------------------------------
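A minimal sketch of how the modified text encoder above can be driven with precomputed token embeddings instead of token ids, mirroring __load_tokenizer_and_text_model() and __tokens2conditioning() in ddim_invertor.py. It assumes transformers==4.19.2 as pinned in requirements.txt; loading via from_pretrained and the variable names are illustrative only, not the repository's own wiring (the repository presumably swaps the modified transformer into Stable Diffusion's cond_stage_model).

import torch
from transformers import CLIPTokenizer
from modified_clip_transformers import ModifiedCLIPTextModel

version = 'openai/clip-vit-large-patch14'
tokenizer = CLIPTokenizer.from_pretrained(version)
text_model = ModifiedCLIPTextModel.from_pretrained(version)  # illustrative; in the repo the patched transformer replaces cond_stage_model.transformer

# embed an empty prompt, then treat the embeddings as free parameters
enc = tokenizer('', truncation=True, max_length=77, return_length=True,
                return_overflowing_tokens=False, padding='max_length', return_tensors='pt')
tokens = text_model.get_input_embeddings().weight.data[enc['input_ids'][0]].clone()
tokens.requires_grad = True

# forward pass with inputs_embeds; the modified model returns pooler_output=None
out = text_model(inputs_embeds=tokens.unsqueeze(0))
conditioning = out.last_hidden_state  # (1, 77, 768) cross-attention conditioning for the U-Net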