├── triplets.csv
├── dataset
│   └── data
│       ├── 00000.jpg
│       ├── 00001.JPG
│       └── 00002.jpg
├── .gitmodules
├── config
│   ├── analogy_params.yaml
│   └── parameter_estimation.yaml
├── .gitignore
├── requirements.txt
├── LICENSE
├── process_new_data.py
├── visualize_tokens.py
├── precompute_noises_and_conditionings.py
├── estimate_input_noise.py
├── README.md
├── analogy_creator.py
├── estimate_CLIP_features.py
├── do_analogies.py
├── utils.py
├── DiffusionImageAnalogies.ipynb
├── ddim_invertor.py
└── modified_clip_transformers.py
/triplets.csv:
--------------------------------------------------------------------------------
1 | 00000 00001 00002
--------------------------------------------------------------------------------
/dataset/data/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/subrtadel/DIA/HEAD/dataset/data/00000.jpg
--------------------------------------------------------------------------------
/dataset/data/00001.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/subrtadel/DIA/HEAD/dataset/data/00001.JPG
--------------------------------------------------------------------------------
/dataset/data/00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/subrtadel/DIA/HEAD/dataset/data/00002.jpg
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "stable-diffusion"]
2 | path = stable-diffusion
3 | url = https://github.com/CompVis/stable-diffusion.git
4 |
--------------------------------------------------------------------------------
/config/analogy_params.yaml:
--------------------------------------------------------------------------------
1 | add_orig_row: True
2 | guidance_scales: [1., 2., 3., 5., 7., 9., 12.]
3 | analogy_strength: [0, 1.5, 20] # start, end, number of samples
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | dataset/data
3 |
4 | __pycache__/
5 |
6 |
7 | .DS_Store
8 |
9 | *.jpg
10 | *.png
11 | *.jpeg
12 | *.JPG
13 | *.JPEG
14 |
15 | *.pdf
16 |
17 |
18 | estimate_input_noise_tmp_from_prompt.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==0.4.3
2 | diffusers
3 | clip
4 | pudb==2019.2
5 | invisible-watermark
6 | imageio==2.9.0
7 | imageio-ffmpeg==0.4.2
8 | pytorch-lightning==1.4.2
9 | omegaconf==2.1.1
10 | test-tube>=0.7.5
11 | streamlit>=0.73.1
12 | einops==0.3.0
13 | torch-fidelity==0.3.0
14 | transformers==4.19.2
15 | torchmetrics==0.6.0
16 | kornia==0.6
17 | Pillow==9.3.0
18 | matplotlib==3.6.2
19 | ipython==8.7.0
20 | imageio==2.9.0
21 | tqdm==4.62.3
22 |
--------------------------------------------------------------------------------
/config/parameter_estimation.yaml:
--------------------------------------------------------------------------------
1 | sufficient_loss: 0.0008
2 | ddim_steps: 10
3 | ddim_eta: 0.
4 | noise_optimization:
5 |   opt_iters: 10
6 |   log_every: 1
7 |   lr: 0.1
8 |   batch_size: 8
9 |   uncond_guidance_scale: 1.
10 |
11 | conditioning_optimization:
12 |   opt_iters: 10
13 |   log_every: 1
14 |   lr: 0.01
15 |   N_tokens: 10
16 |   batch_size: 8
17 |   uncond_guidance_scale: 1.
18 |   fixed_timesteps: True # non-deterministic results with False
19 |
20 | uncond_guidance_scale: 1.
21 | path2save_prefix: './results/experiments/inversion/'
22 | shape: [4,64,64]
23 | device: 'cuda:0'
24 | batch_size: 1
25 | f: 8
26 | save_reconstruction: True
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Adéla Šubrtová and contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/process_new_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | path2raw = './dataset/raw_data/'
5 | path2clean = './dataset/data/'
6 |
7 | os.makedirs(path2clean, exist_ok=True)
8 |
9 | def separate_filenames(path):
10 | all_raw_files = os.listdir(path)
11 | raw_files_to_rename = [fn for fn in all_raw_files if re.match('^[0-9]{5}\.', fn) is None]
12 | raw_files_okay = [fn for fn in all_raw_files if re.match('^[0-9]{5}\.', fn) is not None]
13 | return raw_files_to_rename, raw_files_okay
14 |
15 | def determine_file_count(raw_files_okay):
16 | okay_file_names_numbers = [int(rfn.split('.')[0]) for rfn in raw_files_okay]
17 | old_file_count = 0
18 | file_count = 0
19 | if len(okay_file_names_numbers) != 0:
20 | old_file_count = max(okay_file_names_numbers)
21 | file_count = 1 + old_file_count
22 | return file_count
23 |
24 | raw_files_to_rename, raw_files_okay = separate_filenames(path2raw)
25 |
26 | file_count = determine_file_count(raw_files_okay)
27 |
28 | for fn in raw_files_to_rename:
29 | # rename new files to predefined format
30 | suffix = fn.split('.')[-1]
31 | new_file_name = f'{file_count:05d}.{suffix}'
32 | os.rename(os.path.join(path2raw,fn), os.path.join(path2raw, new_file_name))
33 | file_count += 1
34 |
35 | # pad images
36 | os.system(f'convert {os.path.join(path2raw, new_file_name)} -virtual-pixel black -set option:distort:viewport "%[fx:max(w,h)]x%[fx:max(w,h)]-%[fx:max((h-w)/2,0)]-%[fx:max((w-h)/2,0)]" -filter point -distort SRT 0 +repage {os.path.join(path2clean, new_file_name)}')
37 | # resize
38 | os.system(f'convert {os.path.join(path2clean, new_file_name)} -resize 512x512 {os.path.join(path2clean, new_file_name)}')
39 |
40 |
41 | print(f'Done. {len(raw_files_to_rename)} new files were added.')
--------------------------------------------------------------------------------
/visualize_tokens.py:
--------------------------------------------------------------------------------
1 |
2 | from omegaconf import OmegaConf
3 | import numpy as np
4 | import os
5 |
6 |
7 | from ldm.models.diffusion.ddim import DDIMSampler
8 | import utils
9 |
10 |
11 | def fetch_cond_matrix(file_id, ddim_sampler, config):
12 | cond_out = utils.load_estimated_cond(file_id, token_subfolder=token_subfolder)
13 | cond_out = ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = cond_out.unsqueeze(0))['last_hidden_state']
14 | return cond_out.to(config.device)
15 |
16 |
17 | path_to_data = './dataset/data/'
18 | experiment_root = './results'
19 |
20 | config = OmegaConf.load('./config/parameter_estimation.yaml')
21 |
22 |
23 | print('Loading model...')
24 | model = utils.prepare_default_model()
25 | print('Model loaded')
26 |
27 |
28 |
29 | token_subfolder = 'tokens'
30 | subfolder = 'noise'
31 |
32 |
33 | ddim_sampler = DDIMSampler(model)
34 |
35 | # analogy_creator = AnalogyCreator(config, ddim_sampler, subfolder, token_subfolder, os.path.join(experiment_root, 'analogy_results', out_subfolder))
36 | file_id_A = '00009'
37 | file_id_A_prime = '00008'
38 | file_id_B = '00010'
39 |
40 |
41 |
42 | cA = fetch_cond_matrix(file_id_A, ddim_sampler, config)
43 | cAprime = fetch_cond_matrix(file_id_A_prime, ddim_sampler, config)
44 | cB = fetch_cond_matrix(file_id_B, ddim_sampler, config)
45 |
46 |
47 | os.makedirs(os.path.join(experiment_root,'token_visualization'),exist_ok=True)
48 | def gen_random_samples(cond, file_id):
49 | tokens_,_ = ddim_sampler.sample(
50 | config.ddim_steps,
51 | 8,
52 | config.shape,
53 | conditioning = cond.expand(8,-1,-1),
54 | eta=config.ddim_eta,
55 | unconditional_guidance_scale=1.,
56 | unconditional_conditioning=ddim_sampler.model.get_learned_conditioning(['']).expand(8,-1,-1),
57 | )
58 | utils.save_latent_as_image(
59 | ddim_sampler.model,
60 | tokens_,
61 | os.path.join(experiment_root,'token_visualization',f'{file_id}.jpg')
62 | )
63 |
64 | gen_random_samples(cA, file_id_A)
65 | gen_random_samples(cAprime, file_id_A_prime)
66 | gen_random_samples(cB, file_id_B)
67 |
--------------------------------------------------------------------------------
/precompute_noises_and_conditionings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import utils
3 | from argparse import ArgumentParser
4 |
5 |
6 |
7 | parser = ArgumentParser()
8 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml',
9 | help='path to config file')
10 |
11 | parser.add_argument('--inversion_subfolder', dest='subfolder', type=str, default = 'noise',
12 | help='inversion subfolder name')
13 |
14 | parser.add_argument('--token_subfolder', dest='token_subfolder', type=str, default = 'tokens',
15 | help='Token inversion subfolder name')
16 |
17 | parser.add_argument('--triplet_file', dest='triplet_file', type=str, default='triplets.csv',
18 | help='file with triplets')
19 |
20 |
21 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/',
22 | help='root path to data')
23 |
24 | args = parser.parse_args()
25 | args = vars(args)
26 |
27 |
28 |
29 | with open(args['triplet_file'], 'r') as f:
30 | file_lines = f.readlines()
31 |
32 |
33 | conditioning_inversion_names = []
34 | noise_inversion_names = []
35 | for line in file_lines:
36 | clean_line = line.replace('\n','').split(' ')
37 | A_name = utils.file_id2im_path(clean_line[0])
38 | Aprime_name = utils.file_id2im_path(clean_line[1])
39 | B_name = utils.file_id2im_path(clean_line[2])
40 |
41 | conditioning_inversion_names.extend([A_name ,Aprime_name, B_name])
42 | noise_inversion_names.append(B_name)
43 |
44 |
45 | with open('tmp_clip_inversion.txt','w') as f:
46 | for fn in set(conditioning_inversion_names):
47 | f.write(f'{fn}\n')
48 |
49 | with open('tmp_noise_inversion.txt','w') as f:
50 | for fn in set(noise_inversion_names):
51 | f.write(f'{fn}\n')
52 |
53 |
54 | os.system(f'python estimate_CLIP_features.py --config {args["config"]} --subfolder {args["token_subfolder"]} --input_img tmp_clip_inversion.txt --data_path {args["data_path"]}')
55 |
56 |
57 | os.system(f'python estimate_input_noise.py --config {args["config"]} --input_img tmp_noise_inversion.txt --token_subfolder {args["token_subfolder"]} --subfolder {args["subfolder"]} --data_path {args["data_path"]}')
58 |
59 |
60 | os.remove('tmp_clip_inversion.txt')
61 | os.remove('tmp_noise_inversion.txt')
--------------------------------------------------------------------------------
/estimate_input_noise.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | import pickle as pkl
5 | import torch
6 | import os
7 |
8 | from ddim_invertor import DDIMInvertor
9 | import utils
10 |
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml',
14 | help='path to config file')
15 |
16 | parser.add_argument('--input_img', dest='input_img', type=str, default = None,
17 | help='path to image or text files with image names')
18 |
19 | parser.add_argument('--subfolder', dest='subfolder', type=str, default = 'noise',
20 | help='subfolder name')
21 |
22 | parser.add_argument('--token_subfolder', dest='token_subfolder', type=str, default = 'tokens',
23 | help='token subfolder name')
24 |
25 |
26 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/',
27 | help='root path to data')
28 |
29 | args = parser.parse_args()
30 | args = vars(args)
31 |
32 | assert os.path.isfile(args['input_img']), '--input_img is not a file'
33 |
34 | if args['input_img'].endswith('.txt'):
35 | with open(args['input_img'], 'r') as f:
36 | file_lines = f.readlines()
37 | clean_file_lines = [os.path.join(args['data_path'],x.replace('\n','')) for x in file_lines]
38 | elif args['input_img'].endswith(('.png','.jpeg','.jpg', 'JPG', 'JPEG')):
39 | clean_file_lines = [args['input_img']]
40 |
41 | config = OmegaConf.load(f"{args['config']}")
42 | config.args = args
43 | config.token_subfolder = args['token_subfolder']
44 |
45 |
46 | print('Loading model...')
47 | model = utils.prepare_default_model()
48 | model = model.to(config.device)
49 | print('Model loaded')
50 |
51 |
52 | invertor = DDIMInvertor(config, model)
53 |
54 |
55 | for file_name in clean_file_lines:
56 | print(f'Processing file: {file_name}')
57 | if not os.path.exists(file_name):
58 | print(f'Path {file_name} does not exist. Skipping')
59 | continue
60 | # load & prepare image
61 | file_id = utils.extract_file_id_from_path(file_name)
62 | export_path = os.path.join(config.path2save_prefix, file_id, args['subfolder'])
63 | if os.path.exists(os.path.join(export_path, 'results.pkl')):
64 | print(f'The inversion for {file_id} seems to be done already. Skipping...')
65 | continue
66 | os.makedirs(export_path, exist_ok=True)
67 |
68 |
69 | print('Performing inversion...')
70 | outputs = invertor.perform_inversion(file_name, cond = None, init_noise_init = None, loss_weights= {'latents': 1. , 'pixels':1.} )
71 |
72 |
73 | outputs['token_subfolder'] = args['token_subfolder']
74 |
75 | if config.save_reconstruction:
76 | img = utils.load_pil(file_name)
77 | img.save(os.path.join(export_path,f'target.png'))
78 |
79 | rec_img_torch = utils.latent2img(model, outputs['reconstruction'])
80 | rec_img_pil = utils.torch2pil(rec_img_torch)[0]
81 | rec_img_pil.save(os.path.join(export_path, 'reconstruction.png'))
82 |
83 | print(f'Saving results to {export_path}')
84 | utils.save_results2pickle(export_path, outputs)
85 |
86 |
87 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Image Analogies
2 |
10 |
11 | 
12 | This is the official repository for the Diffusion Image Analogies paper, published in the ACM SIGGRAPH 2023 Conference Proceedings.
13 |
14 | ***
15 |
16 | ## Installation
17 |
18 | 1. Clone the repo
19 | ```sh
20 | git clone --recurse-submodules https://github.com/subrtadel/DIA.git
21 | cd ./DIA
22 | ```
23 | 2. Create environment
24 | ```
25 | conda create -n dia_env
26 | conda activate dia_env
27 | conda install python=3.8.5 pip=20.3 cudatoolkit=11.3 pytorch=1.11.0 torchvision=0.12.0 numpy=1.19.2 -c pytorch -c nvidia -c conda-forge -c defaults
28 | ```
29 | 3. Install packages
30 | ```sh
31 | pip install -r requirements.txt
32 | cd ./stable-diffusion/
33 | pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
34 | pip install -e .
35 | ```
36 | 4. Download the [sd-v1-4.ckpt model](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) and place it in the correct folder:
37 | ```
38 | mkdir -p ./models/ldm/stable-diffusion-v1/
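# Place (or symlink) the downloaded checkpoint here as model.ckpt -- utils.prepare_default_model()
# loads ./stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt relative to the repository root.
# For example (adjust the source path to wherever you downloaded the checkpoint):
# ln -s /path/to/sd-v1-4.ckpt ./models/ldm/stable-diffusion-v1/model.ckpt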
39 |
40 | ```
41 | 5. Install [Image Magick](https://imagemagick.org).
42 |
44 |
45 | ***
46 |
47 |
48 | ## Usage
49 |
50 | 1. Upload the input images into the `./dataset/raw_data/` folder.
51 |
52 |
53 | 2. Run `process_new_data.py`. The images are assigned `file_id`s in a `%05d` format.
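The script renames the raw images to sequential `file_id`s, pads them to a square and resizes them to 512x512 with ImageMagick's `convert` (installation step 5), and writes the results to `./dataset/data/`:
```sh
python process_new_data.py
```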
54 |
55 |
56 | 3. Define the triplets in a `.csv` file. Refer to the images by their `file_id`.
57 | An example file is `triplets.csv`. The first column specifies the `A` input, the second the `A'` input, and the third the `B` input; the columns are separated by spaces. The `file_id`s may be given with or without filename suffixes.
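Each line holds one `A A' B` triplet. The bundled `triplets.csv` contains a single example:
```
00000 00001 00002
```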
58 |
59 |
60 | 4. Run the `precompute_noises_and_conditionings.py` script. This may take a while.
61 | ```sh
62 | python precompute_noises_and_conditionings.py \
63 | --config ./config/parameter_estimation.yaml \
64 | --inversion_subfolder noise \
65 | --token_subfolder tokens \
66 | --triplet_file triplets.csv \
67 | --data_path ./dataset/data/
68 | ```
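Both estimation scripts cache their outputs as `results.pkl` under the `path2save_prefix` directory from `./config/parameter_estimation.yaml`, so images that were already inverted are skipped on re-runs. With the default settings the layout looks roughly like this:
```
results/experiments/inversion/<file_id>/
├── tokens/results.pkl   # estimated CLIP conditioning (estimate_CLIP_features.py)
└── noise/results.pkl    # estimated input noise (estimate_input_noise.py)
```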
69 |
70 |
71 | 5. Check `./config/analogy_params.yaml` and adjust the guidance scales and analogy strengths if needed.
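For reference, the default parameters shipped with the repository are shown below. Each value in `guidance_scales` produces one row of the output grid, and `analogy_strength` is expanded by `do_analogies.py` via `np.linspace(start, end, num)` into the per-column analogy strengths:
```yaml
add_orig_row: True                               # prepend a row with the original A, A', B images
guidance_scales: [1., 2., 3., 5., 7., 9., 12.]   # one grid row per guidance scale
analogy_strength: [0, 1.5, 20]                   # start, end, number of samples for np.linspace
```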
72 |
73 |
74 | 6. Run the `do_analogies.py` script.
75 | ```sh
76 | python do_analogies.py \
77 | --config ./config/parameter_estimation.yaml \
78 | --inversion_subfolder noise \
79 | --token_subfolder tokens \
80 | --output_subfolder analogies \
81 | --triplet_file triplets.csv \
82 | --data_path ./dataset/data/
83 | ```
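Results are written under `./results/analogy_results/<output_subfolder>/`. A sketch of the layout produced by `analogy_creator.py` (individual images per triplet plus a combined grid of scales × strengths):
```
results/analogy_results/analogies/
├── <A>_<Aprime>_<B>/
│   └── analogy_sc=<scale>_shift_strength=<strength>.jpg
└── grids/
    └── <A>_<Aprime>_<B>_analogy_grid.jpg
```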
84 |
85 |
86 |
87 | ***
88 |
89 | ## BibTeX
90 |
91 |     @inproceedings{Subrtova2023DIA,
92 |       title = {Diffusion Image Analogies},
93 |       author = {A. \v{S}ubrtov\'{a} and M. Luk\'{a}\v{c} and J. \v{C}ech and D. Futschik and E. Shechtman and D. S\'{y}kora},
94 |       booktitle = {ACM SIGGRAPH 2023 Conference Proceedings},
95 |       year = {2023}
96 |     }
97 |
--------------------------------------------------------------------------------
/analogy_creator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | from ddim_invertor import DDIMInvertor
6 | import utils
7 |
8 | class AnalogyCreator():
9 | def __init__(self, config, ddim_sampler, inversion_subfolder, token_subfolder, output_root, data_path) -> None:
10 | self.config = config
11 | self.ddim_sampler = ddim_sampler
12 | self.subfolder = inversion_subfolder
13 | self.token_subfolder = token_subfolder
14 | self.output_root = output_root
15 | self.uc = self.ddim_sampler.model.get_learned_conditioning([''])
16 | self.data_path = data_path
17 |
18 |
19 | def fetch_cond_matrix(self, file_id):
20 | cond_out = utils.load_estimated_cond(file_id, token_subfolder=self.token_subfolder)
21 | cond_out = self.ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = cond_out.unsqueeze(0))['last_hidden_state']
22 | return cond_out.to(self.config.device)
23 |
24 | def __load_B_noise(self, imB):
25 | fileid_B = utils.extract_file_id_from_path(imB)
26 | _,_,_, resB = utils.load_inversion_result_dict(fileid_B, self.subfolder, return_result_dict=True)
27 | return resB['estimated_input_noise']
28 |
29 | def make_analogy(self, triplet, steps, uc_scales, analogy_func = None, **analogy_func_kwargs):
30 | print(f'Make analogy inputs: {triplet}, {uc_scales}, {steps}')
31 | triplet_code = '_'.join([utils.extract_file_id_from_path(t) for t in triplet])
32 | noise = self.__load_B_noise(triplet[-1])
33 |
34 | cA = self.fetch_cond_matrix(utils.extract_file_id_from_path(triplet[0]))
35 | cAprime = self.fetch_cond_matrix(utils.extract_file_id_from_path(triplet[1]))
36 |
37 | cB = self.fetch_cond_matrix(utils.extract_file_id_from_path(triplet[2]))
38 | self.make_analogy_from_args(triplet_code, cA, cAprime, cB, noise, steps, uc_scales, analogy_func, **analogy_func_kwargs)
39 |
40 |
41 | def make_analogy_from_args(self, triplet_code, cA, cAprime, cB, noise, steps, uc_scales, analogy_func = None, **analogy_func_kwargs):
42 |
43 | os.makedirs(os.path.join(self.output_root, triplet_code), exist_ok=True)
44 | os.makedirs(os.path.join(self.output_root,'grids'), exist_ok=True)
45 | if analogy_func is None:
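# default analogy in conditioning space: c_B' = c_B + strength * (c_A' - c_A)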
46 | analogy_func = lambda cA, cAprime, cB, st: cB + st * (cAprime - cA)
47 |
48 | rows = []
49 | for sc in uc_scales:
50 | cols = []
51 | for st in steps:
52 | analogy_res,_ = self.ddim_sampler.sample(
53 | self.config.ddim_steps,
54 | 1,
55 | self.config.shape,
56 | conditioning = analogy_func(cA, cAprime, cB, st, **analogy_func_kwargs),
57 | eta=self.config.ddim_eta,
58 | x_T=noise,
59 | unconditional_guidance_scale=sc,
60 | unconditional_conditioning=self.uc,
61 | )
62 | img = utils.save_latent_as_image(
63 | self.ddim_sampler.model,
64 | analogy_res,
65 | os.path.join(self.output_root, triplet_code, f'analogy_sc={sc}_shift_strength={st}.jpg'),
66 | return_pil=True
67 | )
68 | cols.append(np.array(img))
69 |
70 | rows.append(np.concatenate(cols, axis = 1))
71 | grid = np.concatenate(rows, axis = 0)
72 | Image.fromarray(grid).save(os.path.join(self.output_root,
73 | 'grids',
74 | f'{triplet_code}_analogy_grid.jpg' ))
75 |
76 |
--------------------------------------------------------------------------------
/estimate_CLIP_features.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | import pickle as pkl
5 | import numpy as np
6 | import sys
7 | import os
8 | sys.path.append('./stable-diffusion/')
9 |
10 | from ddim_invertor import DDIMInvertor
11 | import utils
12 |
13 |
14 |
15 | parser = ArgumentParser()
16 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml',
17 | help='path to config file')
18 | parser.add_argument('--input_img', dest='input_img', type=str, required = True,
19 | help='path to image or text files with image names')
20 |
21 | parser.add_argument('--subfolder', dest='subfolder', type=str, default = 'tokens',
22 | help='subfolder name')
23 |
24 |
25 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/',
26 | help='root path to data')
27 |
28 | parser.add_argument('--regenerate_tokens', dest='regenerate', action='store_true',
29 | help='Will regenerate images with random noise and the output conditioning')
30 |
31 | args = parser.parse_args()
32 | args = vars(args)
33 |
34 | assert os.path.isfile(args['input_img']), '--input_img is not a file'
35 |
36 | if args['input_img'].endswith('.txt'):
37 | with open(args['input_img'], 'r') as f:
38 | file_lines = f.readlines()
39 | clean_file_lines = [os.path.join(args['data_path'],x.replace('\n','')) for x in file_lines]
40 | elif args['input_img'].endswith(('.png','.jpeg','.jpg', 'JPG', 'JPEG')):
41 | clean_file_lines = [args['input_img']]
42 |
43 | config = OmegaConf.load(f"{args['config']}")
44 | config.args = args
45 |
46 |
47 | print('Loading model...')
48 | model = utils.prepare_default_model()
49 | model = model.to(config.device)
50 | print('Model loaded')
51 |
52 |
53 | invertor = DDIMInvertor(config, model)
54 |
55 | for file_path in clean_file_lines:
56 | if not os.path.exists(file_path):
57 | print(f'Path {file_path} does not exist. Skipping')
58 | continue
59 |
60 | file_id = utils.extract_file_id_from_path(file_path)
61 | if os.path.exists(os.path.join(config.path2save_prefix, file_id, args['subfolder'],'results.pkl')):
62 | print(f'Inversion for file_id {file_id} is already done... Skipping')
63 | continue
64 |
65 | output = invertor.perform_cond_inversion_individual_timesteps(file_path, None, optimize_tokens=True)
66 |
67 |
68 | export_path = os.path.join(config.path2save_prefix, file_id, args['subfolder'])
69 |
70 | print(f'Saving results to {export_path}')
71 | utils.save_results2pickle(export_path, output)
72 | # os.makedirs(export_path, exist_ok=True)
73 |
74 | # with open(os.path.join(export_path, 'results.pkl') ,'wb') as f:
75 | # pkl.dump(output, f)
76 |
77 | if args["regenerate"]:
78 | c_ = model.cond_stage_model.transformer(inputs_embeds = output['estimated_conditioning'].unsqueeze(0))['last_hidden_state']
79 | res, _ = invertor.ddim_sampler.sample(config.ddim_steps,
80 | config.conditioning_optimization.batch_size,
81 | config.shape,
82 | conditioning=c_.expand(config.conditioning_optimization.batch_size, -1,-1),
83 | eta=0.,
84 | unconditional_guidance_scale=5.,
85 | unconditional_conditioning=invertor.uc.expand(config.conditioning_optimization.batch_size, -1,-1))
86 |
87 |
88 |
89 | img = utils.save_latent_as_image(model, res, os.path.join(export_path,'token_regeneration.png'),return_pil=True)
90 | orig = np.array(Image.open(file_path).convert("RGB"))
91 | row = np.concatenate((orig,np.zeros((orig.shape[0], 20,3)),np.array(img)), axis = 1).astype(np.uint8)
92 | Image.fromarray(row).save(os.path.join(export_path,'token_regeneration_with_ref.png'))
93 |
94 | del output
95 |
96 |
--------------------------------------------------------------------------------
/do_analogies.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | from omegaconf import OmegaConf
4 |
5 | from ldm.models.diffusion.ddim import DDIMSampler
6 | from analogy_creator import AnalogyCreator
7 | import utils
8 | from PIL import Image
9 | import pickle as pkl
10 | import torch
11 | import numpy as np
12 |
13 |
14 |
15 | parser = ArgumentParser()
16 | parser.add_argument('--config', dest='config', type=str, default='./config/parameter_estimation.yaml',
17 | help='path to config file')
18 |
19 | parser.add_argument('--inversion_subfolder', dest='subfolder', type=str, default = 'noise',
20 | help='inversion subfolder name')
21 |
22 | parser.add_argument('--token_subfolder', dest='token_subfolder', type=str, default = 'tokens',
23 | help='Token inversion subfolder name')
24 |
25 | parser.add_argument('--output_subfolder', dest='out_subfolder', type=str, default = 'analogies',
26 | help='Output subfolder name')
27 |
28 | parser.add_argument('--triplet_file', dest='triplet_file', type=str,
29 | help='file with image paths')
30 |
31 |
32 | parser.add_argument('--data_path', dest='data_path', type=str, default = './dataset/data/',
33 | help='root path to data')
34 |
35 | args = parser.parse_args()
36 | args = vars(args)
37 |
38 |
39 | with open(args['triplet_file'], 'r') as f:
40 | file_lines = f.readlines()
41 |
42 |
43 | clean_file_triplets = []
44 | for line in file_lines:
45 | clean_line = line.replace('\n','').split(' ')
46 | pathA = utils.file_id2im_path(clean_line[0], data_path=args['data_path'],absolute=True)
47 | pathAprime = utils.file_id2im_path(clean_line[1], data_path=args['data_path'],absolute=True)
48 | pathB = utils.file_id2im_path(clean_line[2], data_path=args['data_path'],absolute=True)
49 | clean_file_triplets.append((pathA, pathAprime, pathB))
50 |
51 |
52 | config = OmegaConf.load(f"{args['config']}")
53 | config.args = args
54 | # log_config = OmegaConf.load('./config/logging_config.yaml')
55 | experiment_root = './results/'
56 |
57 |
58 | subfolder = args['subfolder']
59 | token_subfolder = args['token_subfolder']
60 |
61 | print('Loading model...')
62 | model = utils.prepare_default_model()
63 | model = model.to(config.device)
64 | print('Model loaded')
65 |
66 | ddim_sampler = DDIMSampler(model)
67 |
68 | export_path = os.path.join(experiment_root, 'analogy_results', args['out_subfolder'])
69 |
70 | analogy_creator = AnalogyCreator(config, ddim_sampler, subfolder, token_subfolder, export_path, data_path= args['data_path'])
71 |
72 |
73 | analogy_config = OmegaConf.load(f"./config/analogy_params.yaml")
74 | add_orig_row = analogy_config.add_orig_row
75 | scales = analogy_config.guidance_scales
76 | steps = np.linspace(*analogy_config.analogy_strength)
77 |
78 | analogy_func = lambda cA, cAprime, cB, st: cB + st * (cAprime - cA)
79 |
80 |
81 | for triplet in clean_file_triplets:
82 | print(f'Processing triplet: {triplet}')
83 | if os.path.exists(os.path.join(export_path, utils.tuple2triplet_name(triplet))):
84 | print(f'This ({utils.tuple2triplet_name(triplet)}) analogy is precomputed... Skipping')
85 | continue
86 |
87 | if not utils.check_inversion_done(os.path.join(args['data_path'], triplet[2]), subfolder):
88 | print(f'Inversion for image {triplet[2]} not found... Skipping')
89 | print(f'p at : {os.path.join(args["data_path"], triplet[2], subfolder)}')
90 | continue
91 | if (not utils.check_inversion_done(os.path.join(args['data_path'], triplet[0]), token_subfolder) or \
92 | not utils.check_inversion_done(os.path.join(args['data_path'], triplet[1]), token_subfolder) or \
93 | not utils.check_inversion_done(os.path.join(args['data_path'], triplet[2]), subfolder)):
94 | print(f'Inversion not found... Skipping')
95 | continue
96 |
97 | analogy_creator.make_analogy(triplet, steps, scales, analogy_func = analogy_func)
98 |
99 | if add_orig_row:
100 | triplet_code = '_'.join([utils.extract_file_id_from_path(t) for t in triplet])
101 | grid = np.array(Image.open(os.path.join(export_path,
102 | 'grids',
103 | f'{triplet_code}_analogy_grid.jpg' )))
104 | first_row = utils.join_images(triplet, out_PIL=False)
105 | n_pad = grid.shape[1] - first_row.shape[1]
106 | pad = np.zeros((first_row.shape[0], n_pad, 3))
107 | first_row = np.concatenate((first_row, pad), axis = 1)
108 | final_grid = np.uint8(np.concatenate((first_row, grid), axis = 0))
109 | Image.fromarray(final_grid).save(os.path.join(export_path,'grids',f'{triplet_code}_analogy_grid.jpg'))
110 |
111 |
112 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import seed_everything
2 | from omegaconf import OmegaConf
3 | import numpy as np
4 | import pickle as pkl
5 | import torch
6 | import fnmatch
7 | import PIL
8 | import gc
9 | import os
10 |
11 |
12 | from ldm.util import instantiate_from_config
13 | from modified_clip_transformers import ModifiedCLIPTextModel
14 | import importlib
15 |
16 | # importlib.import_module("/home/subrtade/analogies/DiffusionImageAnalogies/stable-diffusion")
17 |
18 | ######################################################################## Model prep
19 |
20 |
21 |
22 | # taken from the stable-diffusion project script txt2img.py
23 | def load_model_from_config(config, ckpt, verbose=False, device='cuda'):
24 | print(f"Loading model from {ckpt}")
25 | pl_sd = torch.load(ckpt, map_location="cpu")
26 | if "global_step" in pl_sd:
27 | print(f"Global Step: {pl_sd['global_step']}")
28 | sd = pl_sd["state_dict"]
29 | model = instantiate_from_config(config.model)
30 | m, u = model.load_state_dict(sd, strict=False)
31 | if len(m) > 0 and verbose:
32 | print("missing keys:")
33 | print(m)
34 | if len(u) > 0 and verbose:
35 | print("unexpected keys:")
36 | print(u)
37 |
38 | model.to(device)
39 | model.eval()
40 | return model
41 |
42 |
43 | def prepare_default_model(default_seed = 42):
44 | default_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
45 | default_ckpt_path = "./stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt"
46 |
47 | seed_everything(default_seed)
48 |
49 | config = OmegaConf.load(f"{default_config_path}")
50 | model = load_model_from_config(config, default_ckpt_path, True)
51 |
52 | del model.cond_stage_model.transformer
53 | print(f'GC COLLECT RETURN VALUE: {gc.collect()}')
54 |
55 |
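# swap in ModifiedCLIPTextModel so that optimized token embeddings can be passed via inputs_embeds
# (see estimate_CLIP_features.py and analogy_creator.py)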
56 | model.cond_stage_model.transformer = ModifiedCLIPTextModel.from_pretrained('openai/clip-vit-large-patch14').to(model.device)
57 |
58 |
59 | return model
60 |
61 |
62 | #######################################################################
63 | # Data prep
64 |
65 | def extract_file_id_from_path(file_name):
66 | return os.path.basename(file_name).split('.')[0]
67 |
68 | def load_all_image_names(path = './dataset/data/', suffixes = ['jpg', 'jpeg', 'png', 'JPG', 'JPEG']):
69 | all_image_files = os.listdir(path)
70 | extracted_files = []
71 | for suf in suffixes:
72 | extracted_files += fnmatch.filter(all_image_files , f'*.{suf}')
73 |
74 | return extracted_files
75 |
76 |
77 | def file_id2im_path(file_id, data_path = './dataset/data', absolute=False):
78 | if not file_id.endswith(('png', 'jpg', 'jpeg', 'JPG', 'JPEG')):
79 | image_names = load_all_image_names(path=data_path)
80 | image_name = fnmatch.filter(image_names, f'{file_id}.*')[0]
81 | else:
82 | image_name = file_id
83 | if absolute:
84 | return os.path.join(data_path, image_name)
85 | return image_name
86 |
87 |
88 | def extract_triplet_from_tuple(tuple_):
89 | fids = [extract_file_id_from_path(pth) for pth in tuple_]
90 | return fids
91 |
92 | def tuple2triplet_name(triplet_tuple):
93 | file_ids = extract_triplet_from_tuple(triplet_tuple)
94 | return '_'.join(file_ids)
95 |
96 |
97 | def join_images(list_of_image_paths, dim=1, path_prefix = '',out_PIL = True):
98 | """Given list of image paths, the function puts the images side by side in 'dim'.
99 |
100 | Args:
101 | list_of_image_paths (list): List that contains paths to the images
102 | dim (int, optional): In which dimension are the images joined. Defaults to 1.
103 | path_prefix (str, optional): Path to the images. Defaults to ''.
104 | out_PIL (bool, optional): The output is a PIL image if True, otherwise the output is an np.ndarray. Defaults to True.
105 |
106 | Returns:
107 | PIL.Image.Image or np.ndarray: The images concatenated along `dim`.
108 | """
109 | imgs = []
110 | for im_name in list_of_image_paths:
111 | img = np.array(PIL.Image.open(os.path.join(path_prefix,im_name)))
112 | if len(img.shape) == 2:
113 | img = np.stack((img,img,img), axis = -1)
114 | if img.shape[-1] == 4:
115 | img = img[:,:,:3]
116 | imgs.append(img)
117 | return join_array_of_np_images(imgs, dim, out_PIL)
118 |
119 | def join_array_of_np_images(array_of_imgs, dim = 1, out_PIL = True):
120 | if out_PIL:
121 | return PIL.Image.fromarray(np.concatenate(array_of_imgs, axis = dim))
122 | return np.concatenate(array_of_imgs, axis = dim)
123 |
124 |
125 |
126 |
127 | def img2latent(model, img_torch):
128 | return model.get_first_stage_encoding(model.encode_first_stage(img_torch.to(model.first_stage_model.device)))
129 |
130 | def latent2img(model, latent):
131 | images = model.decode_first_stage(latent.to(model.first_stage_model.device))
132 | return images
133 |
134 | def load_pil(img):
135 | return PIL.Image.open(img).convert('RGB')
136 |
137 |
138 | def pil2torch(pilimg, to_range = True, device = 'cuda:0'):
139 | w, h = pilimg.size
140 | w, h = w - w%32, h - h%32
141 | pilimg.resize((w,h), resample=PIL.Image.LANCZOS)
142 | im_np = np.array(pilimg).astype(np.float32) / 255.
143 | im_np = im_np[np.newaxis].transpose((0, 3, 1, 2))
144 | im_torch = torch.from_numpy(im_np).to(device)
145 | if to_range:
146 | im_torch = 2*im_torch - 1
147 | return im_torch
148 |
149 | def pil2torch_batch(list_of_ims, to_range=True, device= 'cuda:0'):
150 | batch = []
151 | for b in range(len(list_of_ims)):
152 | batch.append(pil2torch(list_of_ims[b], to_range, device))
153 | return torch.cat(batch, dim=0)
154 |
155 | def torch2pil(images, from_range = True):
156 | if from_range:
157 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
158 | pil_images = []
159 | b = images.shape[0]
160 | for i in range(b):
161 | img_np = np.array(images[i].detach().cpu())
162 | img_np = np.uint8(img_np.transpose((1,2,0)) * 255)
163 | img_pil = PIL.Image.fromarray(img_np)
164 | pil_images.append(img_pil)
165 | return pil_images
166 |
167 |
168 | def save_latent_as_image(model, latent, path, return_pil=False):
169 | """Generates the output image from given latent and saves it to path.
170 |
171 | Args:
172 | model (_type_): stable diffusion.
173 | latent (_type_): Latent of the image.
174 | path (_type_): Path to save the image.
175 | """
176 | rec_img_torch = latent2img(model, latent)
177 | rec_img_pil = torch2pil(torch.cat(list(rec_img_torch), dim = -1).unsqueeze(0))[0]
178 | rec_img_pil.save(path)
179 | if return_pil:
180 | return rec_img_pil
181 |
182 | #######################################################################
183 | # Optimization utils
184 |
185 | def pixel_space_loss(model, latent1, real_image, loss_fn):
186 | """computes loss loss_fn in pixel space
187 |
188 | Args:
189 | model (_type_): stable diffusion model
190 | latent1 (_type_): latent of the generated image
191 | real_image (_type_): target image
192 | loss_fn (_type_): torch functional loss
193 |
194 | Returns:
195 | _type_: Value of loss_fn between real_image and the image generated from latent1.
196 | """
197 | image1 = model.differentiable_decode_first_stage(latent1.to(model.first_stage_model.device))
198 | return loss_fn(image1, real_image.to(image1.device))
199 |
200 |
201 | #######################################################################
202 | # Results manipulation
203 |
204 | def load_estimated_cond(file_id, token_subfolder = 'tokens', inversion_path_root = './results/experiments/inversion/' ):
205 | if not os.path.exists(os.path.join(inversion_path_root, f'{file_id}/{token_subfolder}/results.pkl')):
206 | return None
207 | with open(os.path.join(inversion_path_root,f'{file_id}/{token_subfolder}/results.pkl'),'rb') as f:
208 | results = pkl.load(f)
209 | return results['estimated_conditioning']
210 |
211 |
212 |
213 | def load_inversion_result_dict(file_id, subfolder, return_result_dict = False, inversion_root_folder='./results/experiments/inversion/'):
214 | """Loads the results of inversion for given file_id and experiment.
215 |
216 | Args:
217 | file_id (str (ex. 000001)): File id of the inverted image.
218 | subfolder (str): Name of the inversion experiment.
219 | return_result_dict (bool, optional): If yes returns the whole result dict. Defaults to False.
220 |
221 | Returns:
222 | _type_: collection of (noise, conditioning matrix, unconditional guidance scale, [result dict])
223 | """
224 | assert os.path.exists(os.path.join(inversion_root_folder, file_id, subfolder,'results.pkl')) , f'This ({file_id}/{subfolder}) experiment does not exist.'
225 |
226 | with open(os.path.join(inversion_root_folder, file_id, subfolder,'results.pkl'), 'rb') as f:
227 | results = pkl.load(f)
228 |
229 | noise = results['estimated_input_noise'] if 'estimated_input_noise' in results.keys() else None
230 | cond = results['estimated_conditioning'] if 'estimated_conditioning' in results.keys() else None
231 | cond_scale = results['guidance_scale'] if 'guidance_scale' in results.keys() else None
232 |
233 | output = (noise, cond, cond_scale)
234 | if return_result_dict:
235 | output = (*output, results)
236 | return output
237 |
238 |
239 |
240 | def check_inversion_done(path_to_image_or_file_id, subfolder, inversion_root_folder = "./results/experiments/inversion/"):
241 | if path_to_image_or_file_id.endswith(('.jpg','.png','.jpeg', 'JPG', 'JPEG')):
242 | file_id = extract_file_id_from_path(path_to_image_or_file_id)
243 | else:
244 | file_id = path_to_image_or_file_id
245 | print(f'Checking: {os.path.join(inversion_root_folder, file_id, subfolder,"results.pkl")}')
246 | return os.path.exists(os.path.join(inversion_root_folder, file_id, subfolder,'results.pkl'))
247 |
248 | #######################################################################
249 | # Others
250 |
251 |
252 | def save_results2pickle(path2save, results):
253 | os.makedirs(path2save, exist_ok=True)
254 | with open(os.path.join(path2save, 'results.pkl') ,'wb') as f:
255 | pkl.dump(results, f)
256 |
257 |
258 |
259 | def check_and_run_inversion(model, file_id, subfolder, config, tokens = True):
260 | if not check_inversion_done(file_id, subfolder):
261 | from ddim_invertor import DDIMInvertor
262 | invertor = DDIMInvertor(config, model)
263 | if tokens:
264 | output = invertor.perform_cond_inversion_individual_timesteps(file_id2im_path(file_id), None, optimize_tokens=True)
265 | else:
266 | output = invertor.perform_inversion(file_id2im_path(file_id), None, init_noise_init = None, loss_weights= {'latents': 1. , 'pixels':1.} )
267 |
268 | export_path = os.path.join(config.path2save_prefix, file_id, subfolder)
269 | save_results2pickle(export_path, output)
270 | print(f'Inversion done')
--------------------------------------------------------------------------------
/DiffusionImageAnalogies.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "fd6ef73c",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/home/subrtade/.local/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
14 | " warn(f\"Failed to load image Python extension: {e}\")\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "from omegaconf import OmegaConf\n",
20 | "import numpy as np\n",
21 | "import os\n",
22 | "\n",
23 | "\n",
24 | "from ldm.models.diffusion.ddim import DDIMSampler\n",
25 | "from analogy_creator import AnalogyCreator\n",
26 | "import utils"
27 | ]
28 | },
29 | {
30 | "attachments": {},
31 | "cell_type": "markdown",
32 | "id": "69519f0c",
33 | "metadata": {},
34 | "source": [
35 | "# Presets\n"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "id": "44029400",
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "path_to_data = './dataset/data/'\n",
46 | "out_subfolder = 'notebook_analogies'\n",
47 | "experiment_root = './results'\n",
48 | "visualize_tokens = True\n",
49 | "# guidance scales\n",
50 | "scales = [1.,2., 3., 5., 7., 9., 12.]\n",
51 | "# analogy strength step\n",
52 | "steps = np.linspace(0, 3, 20)"
53 | ]
54 | },
55 | {
56 | "attachments": {},
57 | "cell_type": "markdown",
58 | "id": "65fc40a3",
59 | "metadata": {},
60 | "source": [
61 | "# Initialization"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "id": "bb5faf5b",
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stderr",
72 | "output_type": "stream",
73 | "text": [
74 | "Global seed set to 42\n"
75 | ]
76 | },
77 | {
78 | "name": "stdout",
79 | "output_type": "stream",
80 | "text": [
81 | "Loading model...\n",
82 | "Loading model from stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt\n",
83 | "Global Step: 194366\n",
84 | "LatentDiffusion: Running in eps-prediction mode\n"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "config = OmegaConf.load('./config/parameter_estimation.yaml')\n",
90 | "\n",
91 | "\n",
92 | "print('Loading model...')\n",
93 | "model = utils.prepare_default_model()\n",
94 | "print('Model loaded')\n",
95 | "\n",
96 | "\n",
97 | "export_path = os.path.join(experiment_root, 'analogy_results', out_subfolder)\n",
98 | "os.makedirs(export_path, exist_ok = True)\n",
99 | "\n",
100 | "token_subfolder = 'tokens_dia_test'\n",
101 | "subfolder = 'noise_dia_test'\n",
102 | "\n",
103 | "\n",
104 | "ddim_sampler = DDIMSampler(model)\n",
105 | "\n",
106 | "analogy_creator = AnalogyCreator(config, ddim_sampler, subfolder, token_subfolder, os.path.join(log_config.experiment_root, 'analogy_results', out_subfolder))\n"
107 | ]
108 | },
109 | {
110 | "attachments": {},
111 | "cell_type": "markdown",
112 | "id": "e4a7a064",
113 | "metadata": {},
114 | "source": [
115 | "# Dataset"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "id": "7738b083",
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "import ipyplot\n",
126 | "\n",
127 | "im_path = lambda x: os.path.join(path_to_data, x)\n",
128 | "\n",
129 | "\n",
130 | "all_images = sorted(utils.load_all_image_names())\n",
131 | "labels = [utils.extract_file_id_from_path(x) for x in all_images] \n",
132 | "\n",
133 | "\n",
134 | "images_list = [im_path(x) for x in all_images]\n",
135 | "ipyplot.plot_images(images_list, labels = labels, max_images=200, img_width=100)"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "id": "5627667a",
141 | "metadata": {},
142 | "source": [
143 | "### Triplet Selection"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "id": "c1035cde",
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "file_id_A = '000203'\n",
154 | "file_id_A_prime = '000204'\n",
155 | "file_id_B = '000226'\n",
156 | "\n"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "id": "0c40cee5",
162 | "metadata": {},
163 | "source": [
164 | "# Analogies"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "id": "5c6e0b3f",
170 | "metadata": {},
171 | "source": [
172 | "### Check for inverted CLIP features and noise\n",
173 | "The parameters will be estimated, if missing."
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "id": "89d6aec6",
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "\n",
184 | "# check for CLIP features\n",
185 | "utils.check_and_run_inversion(model, file_id_A, token_subfolder, config, tokens=True)\n",
186 | "\n",
187 | "utils.check_and_run_inversion(model,file_id_A_prime, token_subfolder, config, tokens=True)\n",
188 | "\n",
189 | "utils.check_and_run_inversion(model,file_id_B, token_subfolder, config, tokens=True)\n",
190 | "\n",
191 | "\n",
192 | "# check for noise of image B\n",
193 | "config.token_subfolder = token_subfolder\n",
194 | "utils.check_and_run_inversion(model, file_id_B, subfolder, config, tokens=False)"
195 | ]
196 | },
197 | {
198 | "cell_type": "markdown",
199 | "id": "0a49d5fb",
200 | "metadata": {},
201 | "source": [
202 | "### Performing analogies\n",
203 | "Generating grid of analogies for given triplet, list of guidance scales \\sigma_i and analogy strengths \\lambda_j."
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "id": "ca17cd4f",
210 | "metadata": {
211 | "scrolled": true
212 | },
213 | "outputs": [],
214 | "source": [
215 | "\n",
216 | "\n",
217 | "\n",
218 | "print(f'Processing triplet: ({file_id_A}, {file_id_A_prime}, {file_id_B})')\n",
219 | "\n",
220 | "triplet = (utils.file_id2im_path(file_id_A), utils.file_id2im_path(file_id_A_prime), utils.file_id2im_path(file_id_B))\n",
221 | "triplet_code = f'{file_id_A}_{file_id_A_prime}_{file_id_B}'\n",
222 | "\n",
223 | "# cA = model.get_learned_conditioning('manual prompt for A')\n",
224 | "# cAprime = model.get_learned_conditioning(\"manual prompt for A'\")\n",
225 | "\n",
226 | "cA = analogy_creator.fetch_cond_matrix(file_id_A)\n",
227 | "cAprime = analogy_creator.fetch_cond_matrix(file_id_A_prime)\n",
228 | "\n",
229 | "noiseB,_,_ = utils.load_inversion_result_dict(file_id_B, subfolder, return_result_dict=False)\n",
230 | "\n",
231 | "cB = analogy_creator.fetch_cond_matrix(file_id_B)\n",
232 | "\n",
233 | "\n",
234 | "analogy_creator.make_analogy_from_args(triplet_code, cA, cAprime, cB, noiseB, steps, scales)\n",
235 | "\n"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "id": "1eec85b4",
241 | "metadata": {},
242 | "source": [
243 | "### Analogy Results\n",
244 | "Generated grid of analogies is shown. Each row corresponds to the guidance scale \\sigma_i, each column corresponds to analogy strength \\lambda_j."
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "id": "c51a5c97",
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "from IPython.display import Image\n",
255 | "Image(filename=os.path.join(export_path,\n",
256 | " 'grids',\n",
257 | " f'{triplet_code}_analogy_grid.jpg' )) "
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "id": "ae0a069f",
263 | "metadata": {},
264 | "source": [
265 | "#### Visualize particular result for given \\sigma_i and \\lambda_j."
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "id": "765c51aa",
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "row = 3 # [0-len(scales)]\n",
276 | "col = 10 # [0-len(steps)]\n",
277 | "current_scale = scales[row]\n",
278 | "current_step = steps[col]\n",
279 | "print(f'scale: {current_scale} | step: {current_step}')\n",
280 | "Image(filename=os.path.join(export_path,\n",
281 | " triplet_code,\n",
282 | " f'analogy_sc={current_scale}_shift_strength={current_step}.jpg' )) "
283 | ]
284 | },
285 | {
286 | "cell_type": "markdown",
287 | "id": "adc55236",
288 | "metadata": {},
289 | "source": [
290 | "### Token visualization\n",
291 | "Estimated CLIP features are paired with random noise images (and transformed via the reverse diffusion process) to visualize the captured concepts. "
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "id": "0ea2cb29",
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "\n",
302 | "if visualize_tokens:\n",
303 | " from torchvision.utils import save_image\n",
304 | " os.makedirs(os.path.join(experiment_root,'token_visualization'),exist_ok=True)\n",
305 | " def gen_random_samples(cond, file_id):\n",
306 | " tokens_,_ = analogy_creator.ddim_sampler.sample(\n",
307 | " analogy_creator.config.ddim_steps,\n",
308 | " 8,\n",
309 | " analogy_creator.config.shape,\n",
310 | " conditioning = cond.expand(8,-1,-1),\n",
311 | " eta=analogy_creator.config.ddim_eta,\n",
312 | " unconditional_guidance_scale=sc,\n",
313 | " unconditional_conditioning=analogy_creator.uc.expand(8,-1,-1),\n",
314 | " )\n",
315 | " utils.save_latent_as_image(\n",
316 | " analogy_creator.ddim_sampler.model, \n",
317 | " tokens_,\n",
318 | " os.path.join(experiment_root,'token_visualization',f'{file_id}.jpg')\n",
319 | " )\n",
320 | "\n",
321 | " from IPython.display import Image\n",
322 | " gen_random_samples(cA, file_id_A)\n",
323 | " Image(filename=os.path.join(experiment_root,'token_visualization',f'{file_id_A}.jpg')) \n",
324 | " gen_random_samples(cAprime, file_id_A_prime)\n",
325 | " Image(filename=os.path.join(experiment_root,'token_visualization',f'{file_id_A_prime}.jpg')) \n",
326 | " gen_random_samples(cB, file_id_B)\n",
327 | " Image(filename=os.path.join(experiment_root,'token_visualization',f'{file_id_B}.jpg')) \n",
328 | "\n"
329 | ]
330 | }
331 | ],
332 | "metadata": {
333 | "kernelspec": {
334 | "display_name": "analogies",
335 | "language": "python",
336 | "name": "analogies"
337 | },
338 | "language_info": {
339 | "codemirror_mode": {
340 | "name": "ipython",
341 | "version": 3
342 | },
343 | "file_extension": ".py",
344 | "mimetype": "text/x-python",
345 | "name": "python",
346 | "nbconvert_exporter": "python",
347 | "pygments_lexer": "ipython3",
348 | "version": "3.10.4"
349 | }
350 | },
351 | "nbformat": 4,
352 | "nbformat_minor": 5
353 | }
354 |
--------------------------------------------------------------------------------
/ddim_invertor.py:
--------------------------------------------------------------------------------
1 | from transformers import CLIPTokenizer, CLIPModel, CLIPProcessor
2 | from ldm.modules.diffusionmodules.util import noise_like
3 | from ldm.models.diffusion.ddim import DDIMSampler
4 | import torchvision.transforms as T
5 | from tqdm import tqdm
6 | import numpy as np
7 | import torch
8 | import gc
9 |
10 | from modified_clip_transformers import ModifiedCLIPTextModel
11 | import utils
12 |
13 |
14 | class DDIMInvertor():
15 | def __init__(self, config, model, tokenizer=None) -> None:
16 | self.config = config
17 | self.ddim_sampler = DDIMSampler(model)
18 | self.ddim_sampler.make_schedule(self.config.ddim_steps, ddim_eta=self.config.ddim_eta, verbose=False)
19 | self.uc = self.ddim_sampler.model.get_learned_conditioning([''])
20 | self.tokenizer = tokenizer
21 |
22 |
23 | def __sample_differentiable(self, cond, shape,
24 | x_T=None, ddim_use_original_steps=False,
25 | timesteps=None, unconditional_guidance_scale=1.,
26 | unconditional_conditioning=None,
27 | ):
28 | b = cond.shape[0]
29 | if x_T is None:
30 | img = torch.randn(shape, device=self.config.device)
31 | else:
32 | img = x_T
33 | if timesteps is None:
34 | timesteps = self.ddim_sampler.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_sampler.ddim_timesteps
35 | elif timesteps is not None and not ddim_use_original_steps:
36 | subset_end = int(min(timesteps / self.ddim_sampler.ddim_timesteps.shape[0], 1) * self.ddim_sampler.ddim_timesteps.shape[0]) - 1
37 | timesteps = self.ddim_sampler.ddim_timesteps[:subset_end]
38 |
39 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
40 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
41 |
42 | for i, step in enumerate(time_range):
43 | index = total_steps - i - 1
44 | ts = torch.full((b,), step, device=self.config.device, dtype=torch.long)
45 |
46 |
47 | outs = self.__step_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
48 | unconditional_guidance_scale=unconditional_guidance_scale,
49 | unconditional_conditioning=unconditional_conditioning)
50 | img, pred_x0 = outs
51 | return img
52 |
53 |
54 | def __step_ddim(self, x, c, t, index, use_original_steps=False,
55 | unconditional_guidance_scale=1., unconditional_conditioning=None):
56 | b, *_, device = *x.shape, x.device
57 |
58 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
59 | e_t = self.ddim_sampler.model.apply_model(x, t, c)
60 | else:
61 | x_in = torch.cat([x] * 2)
62 | t_in = torch.cat([t] * 2)
63 | c_in = torch.cat([unconditional_conditioning, c])
64 | e_t_uncond, e_t = self.ddim_sampler.model.apply_model(x_in, t_in, c_in).chunk(2)
65 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
66 |
67 |
68 | alphas = self.ddim_sampler.model.alphas_cumprod if use_original_steps else self.ddim_sampler.ddim_alphas
69 | alphas_prev = self.ddim_sampler.model.alphas_cumprod_prev if use_original_steps else self.ddim_sampler.ddim_alphas_prev
70 | sqrt_one_minus_alphas = self.ddim_sampler.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sampler.ddim_sqrt_one_minus_alphas
71 | sigmas = self.ddim_sampler.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sampler.ddim_sigmas
72 | # select parameters corresponding to the currently considered timestep
73 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
74 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
75 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
76 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
77 |
78 | # current prediction for x_0
79 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
80 |
81 | # direction pointing to x_t
82 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
83 | noise = sigma_t * noise_like(x.shape, device, False)
84 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
85 | return x_prev, pred_x0
86 |
87 |
88 | def perform_inversion(self, image, cond, init_noise_init = None, loss_weights = {'latents': 1. , 'pixels':1.} ):
89 | if cond is None:
90 | with torch.no_grad():
91 | cond_out = utils.load_estimated_cond(utils.extract_file_id_from_path(image), token_subfolder=self.config.token_subfolder)
92 | assert cond_out is not None, 'Token inversion was not found...'
93 | cond = self.__tokens2conditioning(cond_out)
94 |
95 | target_img = utils.load_pil(image)
96 | target_img = target_img.resize((self.config.shape[-2] * self.config.f, self.config.shape[-1] * self.config.f))
97 | target_img = utils.pil2torch(target_img)
98 | target_latent = utils.img2latent(self.ddim_sampler.model, target_img)
99 | target_latent = target_latent.to(self.ddim_sampler.model.device)
100 | target_img = target_img.to(self.ddim_sampler.model.device)
101 |
102 |
103 | if init_noise_init is None:
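# initialize the estimate by diffusing the target latent to the last DDIM timestep:
# x_T = sqrt(alpha_T) * x_0 + sqrt(1 - alpha_T) * eps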
104 | alpha_t = torch.tensor([self.ddim_sampler.ddim_alphas[-1]]).cuda()
105 | init_noise = torch.sqrt(alpha_t) * target_latent + torch.sqrt(1. - alpha_t) * torch.randn_like(target_latent).to(target_latent.device)
106 |
107 | uc_scale = self.config.noise_optimization.uncond_guidance_scale
108 |
109 | init_noise.requires_grad = True
110 |
111 | lbfgs = torch.optim.LBFGS(params = [init_noise], lr = self.config.noise_optimization.lr)
112 | loss_fn = torch.nn.functional.mse_loss
113 |
114 | shape = [self.config.noise_optimization.batch_size, * self.config.shape]
115 |
116 |
117 | progress = {'loss':[]}
118 | progress['noise'] = []
119 |
120 | pbar = tqdm(range(self.config.noise_optimization.opt_iters))
121 | for i in pbar:
122 | def closure_():
123 | lbfgs.zero_grad()
124 | x0_prediction = self.__sample_differentiable(cond, shape,
125 | x_T=init_noise, unconditional_guidance_scale=uc_scale,
126 | unconditional_conditioning= self.uc)
127 |
128 | loss = loss_weights['latents'] * loss_fn(x0_prediction, target_latent, reduction = 'mean')
129 | if loss_weights['pixels'] != 0:
130 | loss += loss_weights['pixels'] * utils.pixel_space_loss(self.ddim_sampler.model, x0_prediction, target_img, loss_fn)
131 | loss.backward()
132 | return loss.detach().item()
133 |
134 |
135 | x0_prediction = self.__sample_differentiable(cond, shape,
136 | x_T=init_noise, unconditional_guidance_scale=uc_scale,
137 | unconditional_conditioning= self.uc)
138 |
139 | loss = loss_weights['latents'] * loss_fn(x0_prediction, target_latent, reduction = 'mean')
140 | if loss_weights['pixels'] != 0:
141 | loss += loss_weights['pixels'] * utils.pixel_space_loss(self.ddim_sampler.model, x0_prediction, target_img, loss_fn)
142 |
143 | if i % self.config.noise_optimization.log_every == 0:
144 | progress['loss'].append(loss.item())
145 | progress['noise'].append(init_noise.detach().cpu())
146 |
147 | pbar.set_postfix({'loss': loss.cpu().item()})
148 |
149 | if loss.item() < self.config.sufficient_loss:
150 | print(f'Stopping early: loss {loss.item():.6f} reached after {i} steps.')
151 | break
152 |
153 | lbfgs.zero_grad()
154 | loss.backward()
155 | lbfgs.step(closure_)
156 |
157 | outputs = {
158 | 'estimated_input_noise': init_noise.detach(),
159 | 'estimated_conditioning': cond ,
160 | 'initial_noise': init_noise_init,
161 | 'target_image_latent': target_latent,
162 | 'path2img': image,
163 | 'config_dict': self.config,
164 | 'reconstruction': x0_prediction.detach(),
165 | 'progress': progress,
166 | 'guidance_scale': uc_scale ,
167 | }
168 |
169 | return outputs
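# A minimal usage sketch (illustrative only; `invertor` and the image path are placeholders,
# not names defined in this file). With cond=None, a previously estimated token inversion is
# loaded from disk, as done at the top of this method:
#
#     out = invertor.perform_inversion('path/to/image.jpg', cond=None)
#     x_T = out['estimated_input_noise']     # optimized starting noise for DDIM re-synthesis
#     recon = out['reconstruction']          # latent-space reconstruction of the target image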
170 |
171 |
172 | # taken from stable diffusion
173 | def add_noise(self, x0, noise, timestep_indices, ddim_use_original_steps=False):
174 | device = x0.device
175 | 
176 | alphas_cumprod = self.ddim_sampler.model.alphas_cumprod if ddim_use_original_steps else self.ddim_sampler.ddim_alphas
177 | sqrt_one_minus_alphas = self.ddim_sampler.model.sqrt_one_minus_alphas_cumprod if ddim_use_original_steps else self.ddim_sampler.ddim_sqrt_one_minus_alphas
178 | sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).to(device)
179 | sqrt_one_minus_alphas = sqrt_one_minus_alphas.to(device)
180 |
181 |
182 | timestep_indices = timestep_indices.to(device)
183 | noise = noise.to(device)
184 |
185 | sqrt_at = torch.index_select(sqrt_alphas_cumprod, 0, timestep_indices).view(-1, 1, 1, 1).to(device)
186 | sqrt_one_minus_at = torch.index_select(sqrt_one_minus_alphas, 0, timestep_indices).view(-1, 1, 1, 1).to(device)
187 |
188 | noisy_samples = sqrt_at * x0.expand_as(noise) + sqrt_one_minus_at * noise
189 | return noisy_samples
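# The expression above is the closed-form forward diffusion q(x_t | x_0),
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
# evaluated at the DDIM timesteps selected by timestep_indices (one index per batch element).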
190 |
191 |
192 | def ___prepare_batch_for_im(self, image):
193 | target_img = utils.load_pil(image)
194 | target_img = target_img.resize((self.config.shape[-2] * self.config.f, self.config.shape[-1] * self.config.f))
195 | # create a batch of augmented views: the original image, a horizontal flip, and random resized crops
196 | hflipper = T.RandomHorizontalFlip(p=1)
197 | resize_cropper = T.RandomResizedCrop(size=(512, 512), scale = (0.85, 0.99),ratio=(1,1))
198 | resized_crops = [resize_cropper(target_img) for _ in range(max(0, self.config.conditioning_optimization.batch_size - 2))]
199 | if self.config.conditioning_optimization.batch_size == 1:
200 | transformed_imgs = [target_img]
201 | else:
202 | transformed_imgs = [target_img, hflipper(target_img), *resized_crops]
203 |
204 | target_img = utils.pil2torch_batch(transformed_imgs)
205 | target_latent = utils.img2latent(self.ddim_sampler.model, target_img)
206 | return target_img, target_latent
207 |
208 | def __load_tokenizer_and_text_model(self, init_caption, tokenizer = None):
209 | version = 'openai/clip-vit-large-patch14'
210 | if tokenizer is None:
211 | tokenizer = CLIPTokenizer.from_pretrained(version)
212 | self.tokenizer = tokenizer
213 |
214 | if init_caption is None:
215 | return tokenizer, None
216 | batch_encoding = tokenizer(init_caption, truncation=True, max_length=77, return_length=True,
217 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
218 |
219 |
220 | embeddings = self.ddim_sampler.model.cond_stage_model.transformer.get_input_embeddings().weight.data[batch_encoding['input_ids'][0]]
221 | text_tokens = embeddings.clone()
222 | text_tokens.requires_grad = True
223 |
224 | return tokenizer, text_tokens
225 |
226 | def __tokens2conditioning(self, tokens):
227 | conditioning = self.ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = tokens.unsqueeze(0))['last_hidden_state']
228 | return conditioning
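# Note: passing inputs_embeds (rather than token ids) requires a text encoder whose forward()
# accepts precomputed embeddings -- this is what ModifiedCLIPTextTransformer in
# modified_clip_transformers.py provides.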
229 |
230 |
231 |
232 | def perform_cond_inversion_individual_timesteps(self, image_path, cond_init , optimize_tokens = True):
233 | self.config['optimize_tokens'] = optimize_tokens
234 | with torch.no_grad():
235 | _, target_latent = self.___prepare_batch_for_im(image_path)
236 | timesteps = torch.tensor(self.ddim_sampler.ddim_timesteps)
237 |
238 | if optimize_tokens:
239 | tokenizer, text_tokens = self.__load_tokenizer_and_text_model('', tokenizer = self.tokenizer)
240 | if cond_init is not None:
241 | text_tokens = cond_init.squeeze(0)
242 |
243 | prompt_repre = text_tokens.detach().clone()
244 |
245 | grad_mask = torch.zeros_like(prompt_repre)
246 | grad_mask[:self.config.conditioning_optimization.N_tokens,:] = 1.
247 | grad_mask = grad_mask.to(self.ddim_sampler.model.device)
248 | fetch_cond_init = lambda x: self.ddim_sampler.model.cond_stage_model.transformer(inputs_embeds = x.unsqueeze(0))['last_hidden_state']
249 | prompt_repre.requires_grad = True
250 |
251 | uc_scale = self.config.conditioning_optimization.uncond_guidance_scale
252 |
253 | adam = torch.optim.AdamW(params = [prompt_repre], lr = self.config.conditioning_optimization.lr)
254 | loss_fn = torch.nn.functional.mse_loss
255 |
256 |
257 | progress = {'loss':[], 'indices':[]}
258 | progress['cond'] = []
259 |
260 | timestep_indices = torch.randperm(self.config.conditioning_optimization.batch_size).view(-1).long()
261 | print(f'Selected timesteps: {timestep_indices}')
262 |
263 |
264 | pbar = tqdm(range(self.config.conditioning_optimization.opt_iters))
265 | for i in pbar:
266 |
267 | noise_ = torch.randn_like(target_latent)
268 |
269 | if not self.config.conditioning_optimization.fixed_timesteps:
270 | timestep_indices = torch.randint(low=0, high=self.config.ddim_steps, size=(self.config.conditioning_optimization.batch_size,1) ).view(-1)
271 |
272 | noisy_samples = self.add_noise(target_latent, noise_, timestep_indices, ddim_use_original_steps=False)
273 |
274 | steps_in = torch.index_select(timesteps, 0, timestep_indices).to(self.config.device)
275 | cond_init = fetch_cond_init(prompt_repre)
276 |
277 |
278 | noise_prediction = self.ddim_sampler.model.apply_model(noisy_samples, steps_in, cond_init.expand(self.config.conditioning_optimization.batch_size, -1 , -1))
279 |
280 | loss = loss_fn(noise_prediction, noise_, reduction = 'none').mean((1,2,3)).mean()
281 |
282 |
283 | if i % self.config.conditioning_optimization.log_every == 0:
284 | progress['indices'].append(timestep_indices)
285 | progress['loss'].append(loss.item())
286 | progress['cond'].append(prompt_repre.detach().cpu())
287 |
288 |
289 | pbar.set_postfix({'loss': loss.cpu().item(), 'indices':timestep_indices})
290 |
291 | if loss.item() < self.config.sufficient_loss:
292 | print(f'Stopping early: loss {loss.item():.6f} reached after {i} steps.')
293 | break
294 |
295 | adam.zero_grad()
296 | loss.backward()
297 | prompt_repre.grad *= grad_mask
298 | adam.step()
299 |
300 | outputs = {
301 | 'estimated_conditioning': prompt_repre.detach(),
302 | 'target_image_latent': target_latent,
303 | 'config_dict': self.config,
304 | 'optimize_tokens': optimize_tokens,
305 | 'progress': progress,
306 | 'guidance_scale': uc_scale ,
307 | }
308 |
309 | return outputs
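# Summary: this routine performs a textual-inversion-style optimization of the prompt token
# embeddings. At each iteration the target latent is noised to random DDIM timesteps (or to a
# fixed subset when conditioning_optimization.fixed_timesteps is set) and the MSE between the
# model's noise prediction and the injected noise is minimized w.r.t. prompt_repre, with
# grad_mask restricting the updates to the first N_tokens token embeddings.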
310 |
311 |
312 |
--------------------------------------------------------------------------------
/modified_clip_transformers.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 | import torch
3 | import torch.nn as nn
4 | from transformers.models.clip.configuration_clip import CLIPTextConfig, CLIPConfig, CLIPVisionConfig
5 | from transformers.modeling_outputs import BaseModelOutputWithPooling
6 | from transformers.models.clip.modeling_clip import _expand_mask, CLIPTextEmbeddings, CLIPEncoder, CLIPPreTrainedModel, CLIPOutput,CLIPTextTransformer, CLIPVisionTransformer
7 |
8 |
9 | # code adapted from the Hugging Face Transformers library
10 | # https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/clip/modeling_clip.py
11 | # modified lines are marked
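# In short, the modifications are:
#   * forward() additionally accepts precomputed `inputs_embeds`, so optimized token
#     embeddings can be fed to the text encoder without going through the tokenizer;
#   * the EOT-token pooling is dropped (pooler_output is None), since there are no
#     input_ids from which to locate the EOT position when only embeddings are given.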
12 |
13 | class ModifiedCLIPTextTransformer(CLIPTextTransformer):
14 | def __init__(self, config: CLIPTextConfig):
15 | super().__init__(config)
16 |
17 | def forward(
18 | self,
19 | input_ids: Optional[torch.Tensor] = None,
20 | attention_mask: Optional[torch.Tensor] = None,
21 | position_ids: Optional[torch.Tensor] = None,
22 | output_attentions: Optional[bool] = None,
23 | output_hidden_states: Optional[bool] = None,
24 | inputs_embeds: Optional[torch.Tensor] = None,
25 | return_dict: Optional[bool] = None,
26 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
27 | r"""
28 | Returns:
29 | """
30 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
31 | output_hidden_states = (
32 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
33 | )
34 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
35 |
36 | # MODIFIED
37 | if input_ids is None and inputs_embeds is None:
38 | raise ValueError("You have to specify either input_ids or inputs_embeds")
39 |
40 | if input_ids is None:
41 | input_shape = inputs_embeds.size()
42 | bsz, seq_len, dim = input_shape
43 | else:
44 | input_shape = input_ids.size()
45 | bsz, seq_len = input_shape
46 | input_ids = input_ids.view(-1, input_shape[-1])
47 |
48 |
49 | hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds )
50 | ##########
51 |
52 |
53 | # CLIP's text model uses causal mask, prepare it here.
54 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
55 | causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
56 | hidden_states.device
57 | )
58 | # expand attention_mask
59 | if attention_mask is not None:
60 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
61 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
62 |
63 | encoder_outputs = self.encoder(
64 | inputs_embeds=hidden_states,
65 | attention_mask=attention_mask,
66 | causal_attention_mask=causal_attention_mask,
67 | output_attentions=output_attentions,
68 | output_hidden_states=output_hidden_states,
69 | return_dict=return_dict,
70 | )
71 |
72 | last_hidden_state = encoder_outputs[0]
73 | last_hidden_state = self.final_layer_norm(last_hidden_state)
74 |
75 | # MODIFIED
76 | # the upstream implementation pools the hidden state at the EOT token position via an argmax over
77 | # input_ids; that is not possible when only inputs_embeds are provided, so no pooled output is computed
78 | pooled_output = None
79 | ##########
80 | if not return_dict:
81 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
82 |
83 | return BaseModelOutputWithPooling(
84 | last_hidden_state=last_hidden_state,
85 | pooler_output=pooled_output,
86 | hidden_states=encoder_outputs.hidden_states,
87 | attentions=encoder_outputs.attentions,
88 | )
89 |
90 | def _build_causal_attention_mask(self, bsz, seq_len, dtype):
91 | # lazily create causal attention mask, with full attention between the vision tokens
92 | # pytorch uses additive attention mask; fill with -inf
93 | mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
94 | mask.fill_(torch.tensor(torch.finfo(dtype).min))
95 | mask.triu_(1) # zero out the lower diagonal
96 | mask = mask.unsqueeze(1) # expand mask
97 | return mask
98 |
99 |
100 |
101 | class ModifiedCLIPTextModel(CLIPPreTrainedModel):
102 | config_class = CLIPTextConfig
103 |
104 | _no_split_modules = ["CLIPEncoderLayer"]
105 |
106 | def __init__(self, config: CLIPTextConfig):
107 | super().__init__(config)
108 | # MODIFIED
109 | self.text_model = ModifiedCLIPTextTransformer(config)
110 | ##########
111 | # Initialize weights and apply final processing
112 | self.post_init()
113 |
114 | def get_input_embeddings(self) -> nn.Module:
115 | return self.text_model.embeddings.token_embedding
116 |
117 | def set_input_embeddings(self, value):
118 | self.text_model.embeddings.token_embedding = value
119 |
120 |
121 | def forward(
122 | self,
123 | input_ids: Optional[torch.Tensor] = None,
124 | attention_mask: Optional[torch.Tensor] = None,
125 | position_ids: Optional[torch.Tensor] = None,
126 | output_attentions: Optional[bool] = None,
127 | output_hidden_states: Optional[bool] = None,
128 | inputs_embeds: Optional[torch.Tensor] = None,
129 | return_dict: Optional[bool] = None,
130 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
131 | r"""
132 | Returns:
133 | Examples:
134 | ```python
135 | >>> from transformers import CLIPTokenizer, CLIPTextModel
136 | >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
137 | >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
138 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
139 | >>> outputs = model(**inputs)
140 | >>> last_hidden_state = outputs.last_hidden_state
141 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
142 | ```"""
143 |
144 |
145 | # MODIFIED
146 | return self.text_model(
147 | input_ids=input_ids,
148 | attention_mask=attention_mask,
149 | position_ids=position_ids,
150 | output_attentions=output_attentions,
151 | output_hidden_states=output_hidden_states,
152 | return_dict=return_dict,
153 | inputs_embeds=inputs_embeds,
154 | )
155 | ##########
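# A minimal usage sketch (illustrative only; the checkpoint name matches the text encoder used
# by the invertor, everything else is a placeholder):
#
#     from transformers import CLIPTokenizer
#     model = ModifiedCLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
#     tok = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
#     ids = tok(['a photo of a cat'], padding='max_length', max_length=77,
#               truncation=True, return_tensors='pt')['input_ids']
#     embeds = model.get_input_embeddings()(ids)   # (1, 77, hidden_size) token embeddings
#     out = model(inputs_embeds=embeds)            # no input_ids needed
#     cond = out.last_hidden_state                 # conditioning sequence; pooler_output is None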
156 |
157 |
158 |
159 |
160 |
161 |
162 | # class ModifiedCLIPModel(CLIPPreTrainedModel):
163 | # config_class = CLIPConfig
164 |
165 | # def __init__(self, config: CLIPConfig):
166 | # super().__init__(config)
167 |
168 | # if not isinstance(config.text_config, CLIPTextConfig):
169 | # raise ValueError(
170 | # "config.text_config is expected to be of type CLIPTextConfig but is of type"
171 | # f" {type(config.text_config)}."
172 | # )
173 |
174 | # if not isinstance(config.vision_config, CLIPVisionConfig):
175 | # raise ValueError(
176 | # "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
177 | # f" {type(config.vision_config)}."
178 | # )
179 |
180 | # text_config = config.text_config
181 | # vision_config = config.vision_config
182 |
183 | # self.projection_dim = config.projection_dim
184 | # self.text_embed_dim = text_config.hidden_size
185 | # self.vision_embed_dim = vision_config.hidden_size
186 |
187 | # self.text_model = ModifiedCLIPTextTransformer(text_config)
188 | # self.vision_model = CLIPVisionTransformer(vision_config)
189 |
190 | # self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
191 | # self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
192 | # self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
193 |
194 | # # Initialize weights and apply final processing
195 | # self.post_init()
196 |
197 | # def get_text_features(
198 | # self,
199 | # input_ids: Optional[torch.Tensor] = None,
200 | # attention_mask: Optional[torch.Tensor] = None,
201 | # position_ids: Optional[torch.Tensor] = None,
202 | # output_attentions: Optional[bool] = None,
203 | # output_hidden_states: Optional[bool] = None,
204 | # return_dict: Optional[bool] = None,
205 | # ) -> torch.FloatTensor:
206 | # r"""
207 | # Returns:
208 | # text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
209 | # applying the projection layer to the pooled output of [`CLIPTextModel`].
210 | # Examples:
211 | # ```python
212 | # >>> from transformers import CLIPTokenizer, CLIPModel
213 | # >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
214 | # >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
215 | # >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
216 | # >>> text_features = model.get_text_features(**inputs)
217 | # ```"""
218 | # # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
219 | # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
220 | # output_hidden_states = (
221 | # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
222 | # )
223 | # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
224 |
225 | # text_outputs = self.text_model(
226 | # input_ids=input_ids,
227 | # attention_mask=attention_mask,
228 | # position_ids=position_ids,
229 | # output_attentions=output_attentions,
230 | # output_hidden_states=output_hidden_states,
231 | # return_dict=return_dict,
232 | # )
233 |
234 | # # TODO: this fails because pooled_output is None in the modified text transformer
235 | # pooled_output = text_outputs[1]
236 | # text_features = self.text_projection(pooled_output)
237 |
238 | # return text_features
239 |
240 | # def get_image_features(
241 | # self,
242 | # pixel_values: Optional[torch.FloatTensor] = None,
243 | # output_attentions: Optional[bool] = None,
244 | # output_hidden_states: Optional[bool] = None,
245 | # return_dict: Optional[bool] = None,
246 | # ) -> torch.FloatTensor:
247 | # r"""
248 | # Returns:
249 | # image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
250 | # applying the projection layer to the pooled output of [`CLIPVisionModel`].
251 | # Examples:
252 | # ```python
253 | # >>> from PIL import Image
254 | # >>> import requests
255 | # >>> from transformers import CLIPProcessor, CLIPModel
256 | # >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
257 | # >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
258 | # >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
259 | # >>> image = Image.open(requests.get(url, stream=True).raw)
260 | # >>> inputs = processor(images=image, return_tensors="pt")
261 | # >>> image_features = model.get_image_features(**inputs)
262 | # ```"""
263 | # # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
264 | # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
265 | # output_hidden_states = (
266 | # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
267 | # )
268 | # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
269 |
270 | # vision_outputs = self.vision_model(
271 | # pixel_values=pixel_values,
272 | # output_attentions=output_attentions,
273 | # output_hidden_states=output_hidden_states,
274 | # return_dict=return_dict,
275 | # )
276 |
277 | # pooled_output = vision_outputs[1] # pooled_output
278 | # image_features = self.visual_projection(pooled_output)
279 |
280 | # return image_features
281 |
282 | # # MODIFIED
283 |
284 | # def forward(
285 | # self,
286 | # input_ids: Optional[torch.LongTensor] = None,
287 | # pixel_values: Optional[torch.FloatTensor] = None,
288 | # attention_mask: Optional[torch.Tensor] = None,
289 | # position_ids: Optional[torch.LongTensor] = None,
290 | # return_loss: Optional[bool] = None,
291 | # output_attentions: Optional[bool] = None,
292 | # output_hidden_states: Optional[bool] = None,
293 | # return_dict: Optional[bool] = None,
294 | # inputs_embeds: Optional[torch.FloatTensor] = None,
295 | # ) -> Union[Tuple, CLIPOutput]:
296 | # r"""
297 | # Returns:
298 | # Examples:
299 | # ```python
300 | # >>> from PIL import Image
301 | # >>> import requests
302 | # >>> from transformers import CLIPProcessor, CLIPModel
303 | # >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
304 | # >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
305 | # >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
306 | # >>> image = Image.open(requests.get(url, stream=True).raw)
307 | # >>> inputs = processor(
308 | # ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
309 | # ... )
310 | # >>> outputs = model(**inputs)
311 | # >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
312 | # >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
313 | # ```"""
314 | # # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
315 | # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
316 | # output_hidden_states = (
317 | # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
318 | # )
319 | # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
320 |
321 | # vision_outputs = self.vision_model(
322 | # pixel_values=pixel_values,
323 | # output_attentions=output_attentions,
324 | # output_hidden_states=output_hidden_states,
325 | # return_dict=return_dict,
326 | # )
327 |
328 | # text_outputs = self.text_model(
329 | # input_ids=input_ids,
330 | # attention_mask=attention_mask,
331 | # position_ids=position_ids,
332 | # output_attentions=output_attentions,
333 | # output_hidden_states=output_hidden_states,
334 | # return_dict=return_dict,
335 | # inputs_embeds=inputs_embeds,
336 |
337 | # )
338 |
339 | # image_embeds = vision_outputs[1]
340 |
341 | # image_embeds = self.visual_projection(image_embeds)
342 |
343 | # text_embeds = text_outputs[0].squeeze(0)
344 | # text_embeds = self.text_projection(text_embeds)
345 |
346 | # # normalized features
347 | # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
348 | # text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
349 |
350 |
351 | # # cosine similarity as logits
352 | # logit_scale = self.logit_scale.exp()
353 | # logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
354 | # print(logits_per_text.shape)
355 | # logits_per_image = logits_per_text.t()
356 |
357 | # loss = None
358 | # if return_loss:
359 | # loss = clip_loss(logits_per_text)
360 |
361 | # if not return_dict:
362 | # output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
363 | # return ((loss,) + output) if loss is not None else output
364 |
365 | # return CLIPOutput(
366 | # loss=loss,
367 | # logits_per_image=logits_per_image,
368 | # logits_per_text=logits_per_text,
369 | # text_embeds=text_embeds,
370 | # image_embeds=image_embeds,
371 | # text_model_output=text_outputs,
372 | # vision_model_output=vision_outputs,
373 | # )
374 | # ##########
--------------------------------------------------------------------------------