├── .gitignore
├── LICENSE
├── README.md
├── __main__.py
├── data
│   ├── VQA
│   │   ├── Annotations
│   │   │   └── prompt_meta.json
│   │   └── Meta
│   │       ├── meta_test.json
│   │       ├── meta_train.json
│   │       ├── meta_traintest.json
│   │       └── ques_ans_count.json
│   ├── __init__.py
│   ├── dataloader
│   │   ├── __init__.py
│   │   ├── clip_vqa.py
│   │   └── coco_tasks.py
│   └── helper
│       ├── PythonHelperTools
│       │   ├── __init__.py
│       │   ├── generate_metadata.py
│       │   ├── prompt_meta.txt
│       │   ├── questions.txt
│       │   ├── unique_question_train.txt
│       │   ├── unique_question_val.txt
│       │   └── vqaTools
│       │       ├── __init__.py
│       │       └── vqa.py
│       ├── QuestionTypes
│       │   ├── abstract_v002_question_types.txt
│       │   └── mscoco_question_types.txt
│       └── __init__.py
├── environment.yml
├── features
│   ├── image_features.py
│   ├── ques_features.py
│   └── text_features.py
├── model
│   ├── custom_hnet.py
│   ├── hyperclip.py
│   └── latent_diffuser.py
├── scripts
│   ├── __init__.py
│   ├── hyperclip_classification_test.py
│   ├── precompute_adaptation.py
│   ├── precompute_image_features.py
│   ├── precompute_ques_features.py
│   ├── precompute_text_features.py
│   ├── train_few_shot.py
│   ├── train_hyperclip.py
│   ├── train_latent_diffusion.py
│   └── train_vae.py
├── setup.py
├── training
│   ├── conditional_model_learn.py
│   ├── hyperclip_learn.py
│   ├── latent_diffusion_learn.py
│   ├── maml_learn.py
│   ├── store_few_shot_latent.py
│   └── vae_learn.py
└── utils
    ├── __init__.py
    ├── build_opt.py
    ├── clip_utils.py
    ├── diffusion_utils.py
    ├── init_utils.py
    ├── misc_utils.py
    ├── train_utils.py
    └── wandb_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .ipynb_checkpoints
3 | *.npy
4 | *.pyc
5 | .idea
6 | .vscode
7 | wandb/
8 | *.egg-info/
9 | *.swp
10 |
11 | data/VQA/Features/
12 | data/VQA/Images/val2014/
13 | data/VQA/Images/train2014/
14 | data/VQA/Questions/
15 | data/VQA/Annotations/
16 | data/helper/Results/
17 | data/VQA/Features/train2014/
18 | data/VQA/Features/val2014/
19 | evaluation/ckp_adapter/
20 | evaluation/results/old/
21 | evaluation/ckp_reptile/
22 | evaluation/results/
23 | evaluation/precompute_adaptation/
24 | evaluation/diffusion/
25 | evaluation/hyperclip/
26 | evaluation/hypergan/
27 | evaluation/maml/
28 | evaluation/few_shot/
29 | evaluation/vae/
30 | scripts/scratch/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Elvis Nava and Seijin Kobayashi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Learning via Classifier(-free) Diffusion Guidance
2 |
3 | [arxiv](https://arxiv.org/abs/2210.08942) | [BibTeX](#citation)
4 |
5 | **Meta-Learning via Classifier(-free) Diffusion Guidance**
6 | [Elvis Nava](https://github.com/elvisnava)\*,
7 | [Seijin Kobayashi](https://github.com/seijin-kobayashi)\*,
8 | Yifei Yin,
9 | Robert K. Katzschmann,
10 | Benjamin F. Grewe
11 | \* equal contribution
12 |
13 | # Installation
14 |
15 | The `hyperclip` conda environment can be created with the following commands:
16 | ```
17 | conda env create -f environment.yml
18 | conda activate hyperclip
19 | pip install git+https://github.com/openai/CLIP.git
20 | conda install pytorch cudatoolkit=11.3 -c pytorch
21 | pip install -e .
22 | ```
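
As a quick sanity check (a suggested step, not part of the repository), you can verify from within the new environment that CLIP and the CUDA-enabled PyTorch build are importable:
```
# optional check that the environment is set up correctly
import clip
import torch

print(torch.__version__, torch.cuda.is_available())
print(clip.available_models())  # should list e.g. "ViT-L/14@336px"
```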
23 |
24 | To set up Weights & Biases, run
25 | ```
26 | wandb login
27 | ```
28 | and paste your W&B API key.
29 |
30 | # Meta-VQA Dataset
31 |
32 | To re-compute the Meta-VQA dataset, first download the [original VQA v2 dataset](https://visualqa.org/download.html), place it in the `data/VQA/` folder, and then run (from within the `hyperclip` environment):
33 | ```
34 | python scripts/precompute_image_features.py
35 | python scripts/precompute_ques_features.py
36 | python scripts/precompute_text_features.py
37 | ```
38 | to re-generate the pre-computed CLIP embeddings for images, task questions and answers.
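
For reference, here is a minimal sketch (not one of the repository scripts) of how the precomputed features can be loaded and combined into a single task dataset. It assumes the commands above have completed, that it is run from the repository root with the root on the Python path, and that the CLIP model name matches the default used elsewhere in the code:
```
import json

from torch.utils.data import DataLoader

from data.dataloader.clip_vqa import CLIP_VQA
from features.image_features import load_image_features
from features.ques_features import load_ques_features
from features.text_features import load_text_features

clip_model = "ViT-L/14@336px"

# Meta-VQA task definitions (one task per unique question)
with open("data/VQA/Meta/meta_train.json") as f:
    meta_train = json.load(f)

# precomputed CLIP embeddings generated by the scripts above
image_features = load_image_features(clip_model)
text_features = load_text_features(clip_model)
ques_emb = load_ques_features(clip_model)

task = next(iter(meta_train))  # each task is a VQA question string
dataset = CLIP_VQA(meta_data=meta_train, dataSubType="train", task=task,
                   image_features=image_features, text_features=text_features,
                   ques_emb=ques_emb)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
```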
39 |
40 | # Experiment scripts
41 |
42 | To train multitask/MAML baselines or an unconditional Hypernetwork generative model (to later use as basis for conditional generation), use the script:
43 | ```
44 | python scripts/train_few_shot.py [...]
45 | ```
46 |
47 | To train several of our models, we first need to prepare a precomputed "dataset" of fine-tuned networks / HNET latents / VAE latents. This can be done with the script:
48 | ```
49 | python scripts/precompute_adaptation.py (--few_shot_checkpoint | --vae_checkpoint ) [...]
50 | ```
51 |
52 | To train the unconditional VAE hypernetwork (an alternative to the HNET above as a basis for conditional generation methods), use the script:
53 | ```
54 | python scripts/train_vae.py --precompute_checkpoint [...]
55 | ```
56 |
57 | To train the HyperCLIP encoder (either from precomputed VAE/HNET fine-tunings, a VAE, or an HNET), use the script:
58 | ```
59 | python scripts/train_hyperclip.py (--precompute_checkpoint | --vae_checkpoint | --few_shot_checkpoint ) [...]
60 | ```
61 |
62 | To train a hypernetwork latent diffusion model (HyperLDM), use the script:
63 | ```
64 | python scripts/train_latent_diffusion.py (--precompute_checkpoint | --vae_checkpoint | --few_shot_checkpoint ) [...]
65 | ```
66 |
67 |
68 | # Citation
69 | ```
70 | @misc{nava_meta-learning_2022,
71 | title = {Meta-{Learning} via {Classifier}(-free) {Diffusion} {Guidance}},
72 | url = {http://arxiv.org/abs/2210.08942},
73 | doi = {10.48550/arXiv.2210.08942},
74 | publisher = {arXiv},
75 | author = {Nava, Elvis and Kobayashi, Seijin and Yin, Yifei and Katzschmann, Robert K. and Grewe, Benjamin F.},
76 | month = oct,
77 | year = {2022},
78 | note = {arXiv:2210.08942 [cs]},
79 | keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences},
80 | copyright = {arXiv.org perpetual, non-exclusive license}
81 | }
82 | ```
83 |
--------------------------------------------------------------------------------
/__main__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/__main__.py
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/data/dataloader/__init__.py
--------------------------------------------------------------------------------
/data/dataloader/clip_vqa.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 |
3 | import math
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class CLIP_VQA(Dataset):
11 |
12 | def __init__(self, meta_data, dataSubType, task , image_features, text_features, ques_emb, n_shot=None, n_shot_seed=42,
13 | data_idx=None):
14 | """
15 | Args:
16 |             meta_data (dict): loaded meta-learning data (one entry per task question)
17 |             dataSubType (string): train/test/traintest/random/random50
18 |             task (string): the task question
19 |             image_features (dict): image features dict loaded from features.image_features
20 |             text_features (dict): text features dict loaded from features.text_features
21 |             ques_emb (dict): question feature dict loaded from features.ques_features
22 |             n_shot (int): number of examples per class (None or "full" keeps all examples)
23 |             n_shot_seed (int): random seed for n-shot sampling
24 | """
25 | self.dataSubType = dataSubType
26 | self.task = task
27 | self.answers = meta_data[self.task]["answers"]
28 | self.image_features = image_features
29 | self.text_features = text_features
30 | self.ques_emb = ques_emb
31 | if data_idx is not None:
32 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"]
33 | self.data = [self.data[i] for i in data_idx]
34 | elif dataSubType in ["train","test"]:
35 | self.data = meta_data[self.task][self.dataSubType]
36 | elif dataSubType == "traintest":
37 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"]
38 | elif dataSubType == "random":
39 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"]
40 | frac = torch.rand(())/3.+2./3.
41 | self.data = [self.data[i] for i in np.random.permutation(len(self.data))[:math.ceil(len(self.data)*frac)]]
42 | elif dataSubType == "random50":
43 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"]
44 | frac = 0.5
45 | self.data = [self.data[i] for i in np.random.permutation(len(self.data))[:math.ceil(len(self.data)*frac)]]
46 | else:
47 | raise ValueError
48 |
49 | if n_shot is not None and n_shot != "full":
50 | all_answers = np.array([a for [_,a] in self.data])
51 | classes_idx = [np.arange(len(self.data))[np.array(all_answers) == self.answers[i]] for i in range(len(self.answers))]
52 | classes_idx = [np.random.RandomState(seed=n_shot_seed).permutation(i)[:n_shot] for i in classes_idx]
53 | self.data = [self.data[i] for i in np.concatenate(classes_idx)]
54 |
55 | def __len__(self):
56 | return len(self.data)
57 |
58 | def __getitem__(self, idx):
59 | if torch.is_tensor(idx):
60 | idx = idx.tolist()
61 |
62 | [image_name, answer] = self.data[idx]
63 | image_features = self.image_features[image_name]
64 | text_features = self.text_features[self.task]
65 | ques_emb = self.ques_emb[self.task]
66 |
67 |
68 | sample = {'ques_emb': ques_emb, 'image_features': image_features, 'text_features': text_features, "label": self.answers.index(answer)}
69 |
70 | return sample
71 |
--------------------------------------------------------------------------------
/data/dataloader/coco_tasks.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 |
3 | import json
4 | import os
5 | import random
6 |
7 | import clip
8 | import numpy as np
9 | import torch
10 | import tqdm
11 | from features.image_features import load_image_features
12 | from slugify import slugify
13 | from torch.utils.data import Dataset
14 |
15 | base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
16 | device = "cuda" if torch.cuda.is_available() else "cpu"
17 |
18 | def compute_coco_tasks():
19 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
20 | train_data = json.load(file)
21 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
22 | test_data = json.load(file)
23 |
24 | data_types = ['val2014', 'train2014']
25 | coco_data = {}
26 | for dt in data_types:
27 | with open(base_path + f"/data/VQA/Annotations/annotations/instances_{dt}.json") as file:
28 | tmp = json.load(file)
29 | if coco_data == {}:
30 | coco_data = tmp
31 | else:
32 | coco_data['images'] += tmp['images']
33 | coco_data['annotations'] += tmp['annotations']
34 |
35 | imgs_test = []
36 | for task in test_data:
37 | for dataSubType in ["train", "test"]:
38 | for [image_name, _] in test_data[task][dataSubType]:
39 | imgs_test.append(image_name)
40 |
41 | cat_id_to_name = {cat['id']: cat['name'] for cat in coco_data['categories']}
42 | imgs_id_to_name = {i['id']: i['file_name'] for i in coco_data['images']}
43 |
44 | vanilla_coco_categories = {}
45 | for ann in tqdm.tqdm(coco_data["annotations"]):
46 | if ann["area"] > 20000 and not ann["iscrowd"]:
47 | cat = cat_id_to_name[ann["category_id"]]
48 | img_name = imgs_id_to_name[ann['image_id']].split(".")[0]
49 | if img_name not in imgs_test:
50 | if cat not in vanilla_coco_categories.keys():
51 | vanilla_coco_categories[cat] = []
52 | if img_name not in vanilla_coco_categories[cat]:
53 | vanilla_coco_categories[cat].append(img_name)
54 |
55 | np.save(base_path+"/data/Attributes/vanilla_coco_categories.npy", vanilla_coco_categories)
56 |
57 |
58 | def compute_coco_answer_features(model='ViT-L/14@336px'):
59 | print("Load CLIP {} model".format(model))
60 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model
61 | clip_model = clip_model.to(device)
62 |
63 | categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item()
64 |
65 | categories_features = {}
66 | for cat in tqdm.tqdm(categories.keys()):
67 | prompt = f"A picture of a {cat}"
68 | prompt = clip.tokenize(prompt).to(device)
69 |
70 | with torch.no_grad():
71 | text_feature = clip_model.encode_text(prompt)
72 |
73 | text_feature = text_feature.float().cpu().numpy()
74 |
75 | categories_features[cat] = text_feature
76 |
77 | np.save(base_path+f'/data/Attributes/coco_answer_features_{slugify(model)}.npy',categories_features)
78 |
79 | def load_coco_answer_features(model='ViT-L/14@336px'):
80 | coco_answer_features = np.load(base_path+f'/data/Attributes/coco_answer_features_{slugify(model)}.npy', allow_pickle=True).item()
81 | for cat in coco_answer_features.keys():
82 | coco_answer_features[cat] = torch.from_numpy(coco_answer_features[cat]).to(device)
83 |
84 | return coco_answer_features
85 |
86 | def filter_categories(categories, min_size=5):
87 | filtered_categories = {}
88 | for cat in categories.keys():
89 | if len(categories[cat]) >= min_size:
90 | filtered_categories[cat] = categories[cat]
91 | return filtered_categories
92 |
93 | class COCO_Tasks(Dataset):
94 |
95 | def __init__(self, categories, dataSubType, image_features, coco_answer_features, n_way=5, train_size=5, test_size=5, task_seed=42):
96 | self.categories = categories
97 | self.image_features = image_features
98 | self.coco_answer_features = coco_answer_features
99 | self.n_way = n_way
100 | self.train_size = train_size
101 | self.test_size = test_size
102 | self.task_seed = task_seed
103 |
104 | random.seed(self.task_seed)
105 | self.answers = random.sample(list(self.categories.keys()), n_way)
106 | self.text_features = torch.cat([self.coco_answer_features[answer] for answer in self.answers], dim=0)
107 |
108 | train_data = []
109 | test_data = []
110 | for answer in self.answers:
111 | data_for_answer = random.sample(self.categories[answer], self.train_size+self.test_size)
112 | for i, image_name in enumerate(data_for_answer):
113 | if i < self.train_size:
114 | train_data.append([image_name, answer])
115 | else:
116 | test_data.append([image_name, answer])
117 |
118 | if dataSubType == "train":
119 | self.data = train_data
120 | elif dataSubType == "test":
121 | self.data = test_data
122 | elif dataSubType == "traintest":
123 | self.data = train_data + test_data
124 |
125 | random.shuffle(self.data)
126 |
127 | def __len__(self):
128 | return len(self.data)
129 |
130 | def __getitem__(self, idx):
131 | if torch.is_tensor(idx):
132 | idx = idx.tolist()
133 |
134 | [image_name, answer] = self.data[idx]
135 | image_features = self.image_features[image_name]
136 | text_features = self.text_features
137 |
138 |
139 | sample = {'ques_emb': [0.], 'image_features': image_features, 'text_features': text_features, "label": self.answers.index(answer)}
140 |
141 | return sample
142 |
143 | if __name__ == "__main__":
144 | #compute_coco_tasks()
145 | #compute_coco_answer_features()
146 | coco_categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item()
147 | coco_answer_features = load_coco_answer_features(model='ViT-L/14@336px')
148 | image_features = load_image_features(model='ViT-L/14@336px')
149 | dataset = COCO_Tasks(coco_categories, "train", image_features, coco_answer_features, n_way=5)
150 | for sample in dataset:
151 | print(sample)
152 |
153 |
--------------------------------------------------------------------------------
/data/helper/PythonHelperTools/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/helper/PythonHelperTools/generate_metadata.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | from vqaTools.vqa import VQA
6 |
7 | dataDir ='../../VQA'
8 | versionType ='v2_' # this should be '' when using the VQA v1.0 dataset
9 | taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
10 | dataType ='mscoco' # 'mscoco' only for v2.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
11 | dataSubType ='train2014'
12 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
13 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
14 | imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
15 |
16 | # initialize VQA api for QA annotations
17 | vqa_train=VQA(annFile, quesFile)
18 |
19 | unique_questions_train = {}
20 |
21 | for question_id in vqa_train.qa.keys():
22 | if vqa_train.qa[question_id]["answer_type"] in ["number","other"]:
23 | ques = vqa_train.qqa[question_id]['question']
24 | if ques not in unique_questions_train.keys():
25 | unique_questions_train[ques] = {"count":1,"ids":[question_id]}
26 | else:
27 | unique_questions_train[ques]["count"] +=1
28 | unique_questions_train[ques]["ids"].append(question_id)
29 |
30 | dataSubType ='val2014'
31 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
32 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
33 | imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
34 |
35 | # initialize VQA api for QA annotations
36 | vqa_val=VQA(annFile, quesFile)
37 |
38 | unique_questions_val = {}
39 |
40 | for question_id in vqa_val.qa.keys():
41 | if vqa_val.qa[question_id]["answer_type"] in ["number","other"]:
42 | ques = vqa_val.qqa[question_id]['question']
43 | if ques not in unique_questions_val.keys():
44 | unique_questions_val[ques] = {"count":1,"ids":[question_id]}
45 | else:
46 | unique_questions_val[ques]["count"] +=1
47 | unique_questions_val[ques]["ids"].append(question_id)
48 |
49 | unique_questions_trainval = unique_questions_train
50 | for i, item in unique_questions_val.items():
51 | if i in unique_questions_train:
52 | unique_questions_trainval[i]['count'] += item['count']
53 | unique_questions_trainval[i]['ids'] += item['ids']
54 | else:
55 | unique_questions_trainval[i] = {'count':item['count'], 'ids': item['ids']}
56 |
57 | vqa_trainval_qa = {**vqa_train.qa, **vqa_val.qa}
58 | vqa_trainval_qqa = {**vqa_train.qqa, **vqa_val.qqa}
59 |
60 | unique_questions_trainval_selected ={}
61 | for key, item in unique_questions_trainval.items():
62 | image_ans = {}
63 | image_list = []
64 | for i in item["ids"]:
65 | a = vqa_trainval_qa[i]['multiple_choice_answer'] # multiple_choice_answer: most frequent ground-truth answer.
66 | # answers[a] = answers.get(a,0) + 1
67 | image_id = vqa_trainval_qa[i]['image_id']
68 | if i in vqa_train.qa:
69 | image_name = 'COCO_train2014_'+ str(image_id).zfill(12)
70 | elif i in vqa_val.qa:
71 | image_name = 'COCO_val2014_'+ str(image_id).zfill(12)
72 |
73 |         if image_name not in image_list: # avoid the same image appearing more than once in the same task
74 | image_list.append(image_name)
75 | if a not in image_ans:
76 | image_ans[a] = [(image_name,i)]
77 | else:
78 | image_ans[a].append((image_name,i))
79 |
80 | # delete the answer that only appears once
81 | image_ans = {key:val for key, val in image_ans.items() if len(val) > 1}
82 | answers = {key:len(val) for key, val in image_ans.items()}
83 | unique_questions_trainval_selected[key] = {"count":sum(answers.values()), "answers_count": len(answers), "answers": answers ,"image_ans":image_ans}
84 |
85 | ques_ans_count = {}
86 | ques_image_ans = {}
87 |
88 | count = 0
89 | for key, item in unique_questions_trainval_selected.items():
90 |     cond1 = item["count"] > 20 # question appears at least 20 times
91 | cond2 = item["answers_count"] > 1 # question contains multiple answers
92 | cond3 = "or" not in key.split() # question not in "choose from" form
93 | if cond1 and cond2 and cond3:
94 | ques_ans_count[key] = item["answers"]
95 | ques_image_ans[key] = item["image_ans"]
96 | count += item["count"]
97 |
98 | color = 0
99 | count = 0
100 | for ques in ques_ans_count.keys():
101 | if 'color' in ques: # question about color
102 | color += 1
103 | if 'How many' in ques: # question about counting
104 | count += 1
105 |
106 | meta_traintest = {}
107 |
108 | for ques, ans in ques_image_ans.items():
109 | train = []
110 | test = []
111 | for a, data in ans.items():
112 | split = round(len(data) * 0.7)
113 | shuffled_data = data.copy()
114 | random.Random(2021).shuffle(shuffled_data)
115 | train += [(i[0],a) for i in shuffled_data[:split]]
116 | test += [(i[0],a) for i in shuffled_data[split:]]
117 |
118 | meta_traintest[ques] = {"train" : train, "test": test, "answers": list(ans.keys())}
119 |
120 | tasks = list(meta_traintest.keys())
121 |
122 | meta_train_tasks = []
123 | meta_test_tasks = []
124 | split = round(len(tasks) * 0.7)
125 | shuffled_data = tasks.copy()
126 | random.Random(2021).shuffle(shuffled_data)
127 | meta_train_tasks += [i for i in shuffled_data[:split]]
128 | meta_test_tasks += [i for i in shuffled_data[split:]]
129 |
130 | meta_test = {}
131 |
132 | for task in meta_test_tasks:
133 | meta_test[task] = meta_traintest[task]
134 |
135 | with open(os.path.join(dataDir,"Meta/meta_test.json"),"w") as file:
136 | json.dump(meta_test,file)
137 |
138 | meta_train = {}
139 |
140 | for task in meta_train_tasks:
141 | meta_train[task] = meta_traintest[task]
142 |
143 | with open(os.path.join(dataDir,"Meta/meta_train.json"),"w") as file:
144 | json.dump(meta_train,file)
145 |
--------------------------------------------------------------------------------
/data/helper/PythonHelperTools/vqaTools/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 |
--------------------------------------------------------------------------------
/data/helper/PythonHelperTools/vqaTools/vqa.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 | __version__ = '0.9'
3 |
4 | # Interface for accessing the VQA dataset.
5 |
6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
8 |
9 | # The following functions are defined:
10 | # VQA - VQA class that loads VQA annotation file and prepares data structures.
11 | # getQuesIds - Get question ids that satisfy given filter conditions.
12 | # getImgIds - Get image ids that satisfy given filter conditions.
13 | # loadQA - Load questions and answers with the specified question ids.
14 | # showQA - Display the specified questions and answers.
15 | # loadRes - Load result file and create result object.
16 |
17 | # Help on each function can be accessed by: "help(VQA.function)"
18 |
19 | import json
20 | import datetime
21 | import copy
22 |
23 | class VQA:
24 | def __init__(self, annotation_file=None, question_file=None):
25 | """
26 | Constructor of VQA helper class for reading and visualizing questions and answers.
27 | :param annotation_file (str): location of VQA annotation file
28 | :return:
29 | """
30 | # load dataset
31 | self.dataset = {}
32 | self.questions = {}
33 | self.qa = {}
34 | self.qqa = {}
35 | self.imgToQA = {}
36 | if not annotation_file == None and not question_file == None:
37 | print('loading VQA annotations and questions into memory...')
38 | time_t = datetime.datetime.utcnow()
39 | dataset = json.load(open(annotation_file, 'r'))
40 | questions = json.load(open(question_file, 'r'))
41 | print(datetime.datetime.utcnow() - time_t)
42 | self.dataset = dataset
43 | self.questions = questions
44 | self.createIndex()
45 |
46 | def createIndex(self):
47 | # create index
48 | print('creating index...')
49 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
50 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
51 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
52 | for ann in self.dataset['annotations']:
53 | imgToQA[ann['image_id']] += [ann]
54 | qa[ann['question_id']] = ann
55 | for ques in self.questions['questions']:
56 | qqa[ques['question_id']] = ques
57 | print('index created!')
58 |
59 | # create class members
60 | self.qa = qa
61 | self.qqa = qqa
62 | self.imgToQA = imgToQA
63 |
64 | def info(self):
65 | """
66 | Print information about the VQA annotation file.
67 | :return:
68 | """
69 |         for key, value in self.dataset['info'].items():
70 | print('%s: %s'%(key, value))
71 |
72 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
73 | """
74 | Get question ids that satisfy given filter conditions. default skips that filter
75 | :param imgIds (int array) : get question ids for given imgs
76 | quesTypes (str array) : get question ids for given question types
77 | ansTypes (str array) : get question ids for given answer types
78 | :return: ids (int array) : integer array of question ids
79 | """
80 | imgIds = imgIds if type(imgIds) == list else [imgIds]
81 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
82 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
83 |
84 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
85 | anns = self.dataset['annotations']
86 | else:
87 | if not len(imgIds) == 0:
88 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[])
89 | else:
90 | anns = self.dataset['annotations']
91 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
92 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
93 | ids = [ann['question_id'] for ann in anns]
94 | return ids
95 |
96 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
97 | """
98 | Get image ids that satisfy given filter conditions. default skips that filter
99 | :param quesIds (int array) : get image ids for given question ids
100 | quesTypes (str array) : get image ids for given question types
101 | ansTypes (str array) : get image ids for given answer types
102 | :return: ids (int array) : integer array of image ids
103 | """
104 | quesIds = quesIds if type(quesIds) == list else [quesIds]
105 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
106 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
107 |
108 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
109 | anns = self.dataset['annotations']
110 | else:
111 | if not len(quesIds) == 0:
112 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[])
113 | else:
114 | anns = self.dataset['annotations']
115 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
116 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
117 | ids = [ann['image_id'] for ann in anns]
118 | return ids
119 |
120 | def loadQA(self, ids=[]):
121 | """
122 | Load questions and answers with the specified question ids.
123 | :param ids (int array) : integer ids specifying question ids
124 | :return: qa (object array) : loaded qa objects
125 | """
126 | if type(ids) == list:
127 | return [self.qa[id] for id in ids]
128 | elif type(ids) == int:
129 | return [self.qa[ids]]
130 |
131 | def showQA(self, anns):
132 | """
133 | Display the specified annotations.
134 | :param anns (array of object): annotations to display
135 | :return: None
136 | """
137 | if len(anns) == 0:
138 | return 0
139 | for ann in anns:
140 | quesId = ann['question_id']
141 | print("Question: %s" %(self.qqa[quesId]['question']))
142 | for ans in ann['answers']:
143 | print("Answer %d: %s" %(ans['answer_id'], ans['answer']))
144 |
145 | def loadRes(self, resFile, quesFile):
146 | """
147 | Load result file and return a result object.
148 | :param resFile (str) : file name of result file
149 | :return: res (obj) : result api object
150 | """
151 | res = VQA()
152 | res.questions = json.load(open(quesFile))
153 | res.dataset['info'] = copy.deepcopy(self.questions['info'])
154 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
155 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
156 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
157 | res.dataset['license'] = copy.deepcopy(self.questions['license'])
158 |
159 | print('Loading and preparing results... ')
160 | time_t = datetime.datetime.utcnow()
161 | anns = json.load(open(resFile))
162 | assert type(anns) == list, 'results is not an array of objects'
163 | annsQuesIds = [ann['question_id'] for ann in anns]
164 | # assert set(annsQuesIds) == set(self.getQuesIds()), \
165 | # 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
166 | for ann in anns:
167 | quesId = ann['question_id']
168 | if res.dataset['task_type'] == 'Multiple Choice':
169 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices'
170 | qaAnn = self.qa[quesId]
171 | ann['image_id'] = qaAnn['image_id']
172 | ann['question_type'] = qaAnn['question_type']
173 | ann['answer_type'] = qaAnn['answer_type']
174 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))
175 |
176 | res.dataset['annotations'] = anns
177 | res.createIndex()
178 | return res
--------------------------------------------------------------------------------
/data/helper/QuestionTypes/abstract_v002_question_types.txt:
--------------------------------------------------------------------------------
1 | how many
2 | what color is the
3 | is the
4 | where is the
5 | what
6 | what is
7 | are the
8 | what is the
9 | is there a
10 | does the
11 | is the woman
12 | is the man
13 | what is on the
14 | is it
15 | is the girl
16 | is the boy
17 | is the dog
18 | are they
19 | who is
20 | what kind of
21 | what color are the
22 | what is in the
23 | what is the man
24 | is there
25 | what is the woman
26 | what are the
27 | what is the boy
28 | are there
29 | what is the girl
30 | is this
31 | how
32 | which
33 | how many people are
34 | is the cat
35 | why is the
36 | are
37 | will the
38 | what type of
39 | what is the dog
40 | do
41 | is she
42 | does
43 | do the
44 | is
45 | is the baby
46 | are there any
47 | is the lady
48 | can
49 | what animal is
50 | where are the
51 | is the sun
52 | what are they
53 | did the
54 | what is the cat
55 | what is the lady
56 | how many clouds are
57 | is that
58 | is the little girl
59 | is he
60 | are these
61 | how many trees are
62 | how many pillows
63 | are the people
64 | why
65 | is the young
66 | how many windows are
67 | is this a
68 | what is the little
69 | is the tv
70 | how many animals are
71 | who
72 | how many pictures
73 | how many plants are
74 | how many birds are
75 | what color is
76 | what is the baby
77 | is anyone
78 | what color
79 | how many bushes
80 | is the old man
81 | none of the above
82 |
--------------------------------------------------------------------------------
/data/helper/QuestionTypes/mscoco_question_types.txt:
--------------------------------------------------------------------------------
1 | how many
2 | is the
3 | what
4 | what color is the
5 | what is the
6 | is this
7 | is this a
8 | what is
9 | are the
10 | what kind of
11 | is there a
12 | what type of
13 | is it
14 | what are the
15 | where is the
16 | is there
17 | does the
18 | what color are the
19 | are these
20 | are there
21 | which
22 | is
23 | what is the man
24 | is the man
25 | are
26 | how
27 | does this
28 | what is on the
29 | what does the
30 | how many people are
31 | what is in the
32 | what is this
33 | do
34 | what are
35 | are they
36 | what time
37 | what sport is
38 | are there any
39 | is he
40 | what color is
41 | why
42 | where are the
43 | what color
44 | who is
45 | what animal is
46 | is the woman
47 | is this an
48 | do you
49 | how many people are in
50 | what room is
51 | has
52 | is this person
53 | what is the woman
54 | can you
55 | why is the
56 | is the person
57 | what is the color of the
58 | what is the person
59 | could
60 | was
61 | is that a
62 | what number is
63 | what is the name
64 | what brand
65 | none of the above
66 |
--------------------------------------------------------------------------------
/data/helper/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: hyperclip
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - argon2-cffi=21.3.0=pyhd3eb1b0_0
11 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0
12 | - asttokens=2.0.5=pyhd3eb1b0_0
13 | - attrs=21.4.0=pyhd3eb1b0_0
14 | - backcall=0.2.0=pyhd3eb1b0_0
15 | - beautifulsoup4=4.11.1=py39h06a4308_0
16 | - blas=1.0=mkl
17 | - bleach=4.1.0=pyhd3eb1b0_0
18 | - blosc=1.21.0=h8c45485_0
19 | - bottleneck=1.3.4=py39hce1f21e_0
20 | - brotli=1.0.9=he6710b0_2
21 | - brotlipy=0.7.0=py39h27cfd23_1003
22 | - brunsli=0.1=h2531618_0
23 | - bzip2=1.0.8=h7b6447c_0
24 | - c-ares=1.18.1=h7f8727e_0
25 | - ca-certificates=2022.6.15=ha878542_0
26 | - certifi=2022.6.15=pyhd8ed1ab_1
27 | - cffi=1.15.0=py39hd667e15_1
28 | - cfitsio=3.470=hf0d0db6_6
29 | - charls=2.2.0=h2531618_0
30 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
31 | - cloudpickle=2.0.0=pyhd3eb1b0_0
32 | - cryptography=36.0.0=py39h9ce1e76_0
33 | - cudatoolkit=11.3.1=ha36c431_9
34 | - cycler=0.11.0=pyhd3eb1b0_0
35 | - cytoolz=0.11.0=py39h27cfd23_0
36 | - dask-core=2022.2.1=pyhd3eb1b0_0
37 | - dbus=1.13.18=hb2f20db_0
38 | - debugpy=1.5.1=py39h295c915_0
39 | - decorator=5.1.1=pyhd3eb1b0_0
40 | - defusedxml=0.7.1=pyhd3eb1b0_0
41 | - einops=0.4.1=pyhd8ed1ab_0
42 | - entrypoints=0.4=py39h06a4308_0
43 | - executing=0.8.3=pyhd3eb1b0_0
44 | - expat=2.4.4=h295c915_0
45 | - ffmpeg=4.2.2=h20bf706_0
46 | - fontconfig=2.13.1=h6c09931_0
47 | - fonttools=4.25.0=pyhd3eb1b0_0
48 | - freetype=2.11.0=h70c0345_0
49 | - fsspec=2022.2.0=pyhd3eb1b0_0
50 | - giflib=5.2.1=h7b6447c_0
51 | - glib=2.69.1=h4ff587b_1
52 | - gmp=6.2.1=h2531618_2
53 | - gnutls=3.6.15=he1e5248_0
54 | - gst-plugins-base=1.14.0=h8213a91_2
55 | - gstreamer=1.14.0=h28cd5cc_2
56 | - icu=58.2=he6710b0_3
57 | - idna=3.3=pyhd3eb1b0_0
58 | - imagecodecs=2021.8.26=py39h4cda21f_0
59 | - imageio=2.9.0=pyhd3eb1b0_0
60 | - intel-openmp=2021.4.0=h06a4308_3561
61 | - ipykernel=6.9.1=py39h06a4308_0
62 | - ipython=8.3.0=py39h06a4308_0
63 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
64 | - ipywidgets=7.6.5=pyhd3eb1b0_1
65 | - jedi=0.18.1=py39h06a4308_1
66 | - jinja2=3.0.3=pyhd3eb1b0_0
67 | - joblib=1.1.0=pyhd3eb1b0_0
68 | - jpeg=9e=h7f8727e_0
69 | - jsonschema=4.4.0=py39h06a4308_0
70 | - jupyter_client=7.2.2=py39h06a4308_0
71 | - jupyter_core=4.10.0=py39h06a4308_0
72 | - jupyterlab_pygments=0.1.2=py_0
73 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
74 | - jxrlib=1.1=h7b6447c_2
75 | - kiwisolver=1.3.2=py39h295c915_0
76 | - krb5=1.19.2=hac12032_0
77 | - lame=3.100=h7b6447c_0
78 | - lcms2=2.12=h3be6417_0
79 | - ld_impl_linux-64=2.38=h1181459_1
80 | - lerc=3.0=h295c915_0
81 | - libaec=1.0.4=he6710b0_1
82 | - libcurl=7.82.0=h0b77cf5_0
83 | - libdeflate=1.8=h7f8727e_5
84 | - libedit=3.1.20210910=h7f8727e_0
85 | - libev=4.33=h7f8727e_1
86 | - libffi=3.3=he6710b0_2
87 | - libgcc-ng=11.2.0=h1234567_1
88 | - libgfortran-ng=7.5.0=ha8ba4b0_17
89 | - libgfortran4=7.5.0=ha8ba4b0_17
90 | - libgomp=11.2.0=h1234567_1
91 | - libidn2=2.3.2=h7f8727e_0
92 | - libnghttp2=1.46.0=hce63b2e_0
93 | - libopus=1.3.1=h7b6447c_0
94 | - libpng=1.6.37=hbc83047_0
95 | - libsodium=1.0.18=h7b6447c_0
96 | - libssh2=1.10.0=h8f2d780_0
97 | - libstdcxx-ng=11.2.0=h1234567_1
98 | - libtasn1=4.16.0=h27cfd23_0
99 | - libtiff=4.2.0=h85742a9_0
100 | - libunistring=0.9.10=h27cfd23_0
101 | - libuuid=1.0.3=h7f8727e_2
102 | - libuv=1.40.0=h7b6447c_0
103 | - libvpx=1.7.0=h439df22_0
104 | - libwebp=1.2.2=h55f646e_0
105 | - libwebp-base=1.2.2=h7f8727e_0
106 | - libxcb=1.14=h7b6447c_0
107 | - libxml2=2.9.12=h03d6c58_0
108 | - libzopfli=1.0.3=he6710b0_0
109 | - locket=0.2.1=py39h06a4308_2
110 | - lz4-c=1.9.3=h295c915_1
111 | - markupsafe=2.1.1=py39h7f8727e_0
112 | - matplotlib=3.5.1=py39hf3d152e_0
113 | - matplotlib-base=3.5.1=py39ha18d171_1
114 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
115 | - mistune=0.8.4=py39h27cfd23_1000
116 | - mkl=2021.4.0=h06a4308_640
117 | - mkl-service=2.4.0=py39h7f8727e_0
118 | - mkl_fft=1.3.1=py39hd3c417c_0
119 | - mkl_random=1.2.2=py39h51133e4_0
120 | - munkres=1.1.4=py_0
121 | - nbclient=0.5.13=py39h06a4308_0
122 | - nbconvert=6.4.4=py39h06a4308_0
123 | - nbformat=5.3.0=py39h06a4308_0
124 | - ncurses=6.3=h7f8727e_2
125 | - nest-asyncio=1.5.5=py39h06a4308_0
126 | - nettle=3.7.3=hbbd107a_1
127 | - networkx=2.7.1=pyhd3eb1b0_0
128 | - notebook=6.4.11=py39h06a4308_0
129 | - numexpr=2.8.1=py39h6abb31d_0
130 | - numpy=1.21.5=py39he7a7128_1
131 | - numpy-base=1.21.5=py39hf524024_1
132 | - openh264=2.1.1=h4ff587b_0
133 | - openjpeg=2.4.0=h3ad879b_0
134 | - openssl=1.1.1q=h7f8727e_0
135 | - packaging=21.3=pyhd3eb1b0_0
136 | - pandas=1.4.2=py39h295c915_0
137 | - pandocfilters=1.5.0=pyhd3eb1b0_0
138 | - parso=0.8.3=pyhd3eb1b0_0
139 | - partd=1.2.0=pyhd3eb1b0_1
140 | - pcre=8.45=h295c915_0
141 | - pexpect=4.8.0=pyhd3eb1b0_3
142 | - pickleshare=0.7.5=pyhd3eb1b0_1003
143 | - pillow=9.0.1=py39h22f2fdc_0
144 | - pip=21.2.4=py39h06a4308_0
145 | - prometheus_client=0.13.1=pyhd3eb1b0_0
146 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
147 | - ptyprocess=0.7.0=pyhd3eb1b0_2
148 | - pure_eval=0.2.2=pyhd3eb1b0_0
149 | - pycparser=2.21=pyhd3eb1b0_0
150 | - pygments=2.11.2=pyhd3eb1b0_0
151 | - pyopenssl=22.0.0=pyhd3eb1b0_0
152 | - pyparsing=3.0.4=pyhd3eb1b0_0
153 | - pyqt=5.9.2=py39h2531618_6
154 | - pyrsistent=0.18.0=py39heee7806_0
155 | - pysocks=1.7.1=py39h06a4308_0
156 | - python=3.9.12=h12debd9_1
157 | - python-dateutil=2.8.2=pyhd3eb1b0_0
158 | - python-fastjsonschema=2.15.1=pyhd3eb1b0_0
159 | - python-slugify=6.1.2=pyhd8ed1ab_0
160 | - python_abi=3.9=2_cp39
161 | - pytorch=1.12.0=py3.9_cuda11.3_cudnn8.3.2_0
162 | - pytorch-mutex=1.0=cuda
163 | - pytz=2021.3=pyhd3eb1b0_0
164 | - pywavelets=1.3.0=py39h7f8727e_0
165 | - pyyaml=6.0=py39h7f8727e_1
166 | - pyzmq=22.3.0=py39h295c915_2
167 | - qt=5.9.7=h5867ecd_1
168 | - readline=8.1.2=h7f8727e_1
169 | - requests=2.27.1=pyhd3eb1b0_0
170 | - scikit-image=0.19.2=py39h51133e4_0
171 | - scikit-learn=1.0.2=py39h51133e4_1
172 | - scipy=1.7.3=py39hc147768_0
173 | - send2trash=1.8.0=pyhd3eb1b0_1
174 | - setuptools=61.2.0=py39h06a4308_0
175 | - sip=4.19.13=py39h295c915_0
176 | - six=1.16.0=pyhd3eb1b0_1
177 | - snappy=1.1.9=h295c915_0
178 | - soupsieve=2.3.1=pyhd3eb1b0_0
179 | - sqlite=3.38.3=hc218d9a_0
180 | - stack_data=0.2.0=pyhd3eb1b0_0
181 | - terminado=0.13.1=py39h06a4308_0
182 | - testpath=0.5.0=pyhd3eb1b0_0
183 | - text-unidecode=1.3=pyhd3eb1b0_0
184 | - threadpoolctl=2.2.0=pyh0d69192_0
185 | - tifffile=2021.7.2=pyhd3eb1b0_2
186 | - tk=8.6.12=h1ccaba5_0
187 | - toolz=0.11.2=pyhd3eb1b0_0
188 | - tornado=6.1=py39h27cfd23_0
189 | - traitlets=5.1.1=pyhd3eb1b0_0
190 | - typing-extensions=4.1.1=hd3eb1b0_0
191 | - typing_extensions=4.1.1=pyh06a4308_0
192 | - tzdata=2022a=hda174b7_0
193 | - unidecode=1.2.0=pyhd3eb1b0_0
194 | - urllib3=1.26.9=py39h06a4308_0
195 | - wcwidth=0.2.5=pyhd3eb1b0_0
196 | - webencodings=0.5.1=py39h06a4308_1
197 | - wheel=0.37.1=pyhd3eb1b0_0
198 | - widgetsnbextension=3.5.2=py39h06a4308_0
199 | - x264=1!157.20191217=h7b6447c_0
200 | - xz=5.2.5=h7f8727e_1
201 | - yacs=0.1.8=pyhd8ed1ab_0
202 | - yaml=0.2.5=h7b6447c_0
203 | - zeromq=4.3.4=h2531618_0
204 | - zfp=0.5.5=h295c915_6
205 | - zlib=1.2.12=h7f8727e_2
206 | - zstd=1.4.9=haebb681_0
207 | - pip:
208 | - blobfile==1.3.3
209 | - click==8.1.2
210 | - docker-pycreds==0.4.0
211 | - filelock==3.8.0
212 | - ftfy==6.1.1
213 | - gitdb==4.0.9
214 | - gitpython==3.1.27
215 | - h5py==3.7.0
216 | - ordered-set==4.1.0
217 | - pathtools==0.1.2
218 | - promise==2.3
219 | - protobuf==3.19.4
220 | - psutil==5.9.0
221 | - pycryptodomex==3.15.0
222 | - regex==2022.4.24
223 | - sentry-sdk==1.5.10
224 | - setproctitle==1.2.3
225 | - shortuuid==1.0.8
226 | - smmap==5.0.0
227 | - torchmeta==1.8.0
228 | - torchvision==0.10.1
229 | - tqdm==4.64.0
230 | - wandb==0.12.15
231 | - xmltodict==0.12.0
232 |
--------------------------------------------------------------------------------
/features/image_features.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import clip
4 | import numpy as np
5 | import torch
6 | import tqdm
7 | from PIL import Image
8 | from slugify import slugify
9 | from torchvision.transforms import Compose, Normalize, Resize, ToTensor
10 | from utils import clip_utils
11 |
12 | try:
13 | from torchvision.transforms import InterpolationMode
14 | BICUBIC = InterpolationMode.BICUBIC
15 | except ImportError:
16 | BICUBIC = Image.BICUBIC
17 |
18 | base_path = os.path.dirname(os.path.dirname(__file__))
19 | device = "cuda" if torch.cuda.is_available() else "cpu"
20 |
21 | def _convert_image_to_rgb(image):
22 | return image.convert("RGB")
23 |
24 | def compute_image_features(model):
25 |
26 | print("Load CLIP {} model".format(model))
27 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model
28 | clip_model = clip_model.to(device)
29 | image_resolution = clip_utils.image_resolution[model]
30 | preprocess = Compose([Resize((image_resolution,image_resolution), interpolation=BICUBIC),
31 | _convert_image_to_rgb,
32 | ToTensor(),
33 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
34 | ])
35 |
36 | img_features = {}
37 |
38 | for data_subtype in ['train2014', 'val2014']:
39 |
40 | img_dir = base_path + f'/data/VQA/Images/{data_subtype}/'
41 |
42 | for filename in tqdm.tqdm(os.listdir(img_dir)):
43 |
44 | img_name = filename.replace(".jpg", "")
45 | image = preprocess(Image.open(img_dir + filename)).unsqueeze(0).to(device)
46 |
47 | with torch.no_grad():
48 | features = clip_model.encode_image(image)
49 |
50 | img_features[img_name] = features.float().cpu().numpy()
51 |
52 | np.save(base_path+f'/data/VQA/Features/ImageFeatures/image_features_{slugify(model)}.npy',img_features)
53 |
54 | def load_image_features(model):
55 | img_features = np.load(base_path + f"/data/VQA/Features/ImageFeatures/image_features_{slugify(model)}.npy", allow_pickle=True).item()
56 |
57 | for img_name in img_features.keys():
58 | img_features[img_name] = torch.from_numpy(img_features[img_name]).to(device)
59 |
60 | return img_features
61 |
--------------------------------------------------------------------------------
/features/ques_features.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import clip
5 | import numpy as np
6 | import torch
7 | import tqdm
8 | from slugify import slugify
9 |
10 | base_path = os.path.dirname(os.path.dirname(__file__))
11 | device = "cuda" if torch.cuda.is_available() else "cpu"
12 |
13 | def compute_ques_features(model):
14 | print("Load CLIP {} model".format(model))
15 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model
16 | clip_model = clip_model.to(device)
17 |
18 | with open(base_path+"/data/VQA/Meta/ques_ans_count.json") as file:
19 | qa = json.load(file)
20 |
21 | ques_features = {}
22 | for ques in tqdm.tqdm(qa.keys()):
23 | text = clip.tokenize(ques).to(device)
24 |
25 | with torch.no_grad():
26 | features = clip_model.encode_text(text)
27 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True)
28 |
29 | features = features.float().cpu().numpy() #np.squeeze(features.float().cpu().numpy())
30 |
31 | ques_features[ques] = features
32 |
33 | # with open(base_path+"/data/VQA/Features/TextFeatures/text_features.json", "w") as outfile:
34 | # json.dump(text_features_val, outfile)
35 | np.save(base_path+f'/data/VQA/Features/QuesFeatures/ques_features_{slugify(model)}.npy',ques_features)
36 |
37 | def load_ques_features(model):
38 | ques_features = np.load(base_path + f"/data/VQA/Features/QuesFeatures/ques_features_{slugify(model)}.npy", allow_pickle=True).item()
39 |
40 | for ques in ques_features.keys():
41 | ques_features[ques] = torch.from_numpy(ques_features[ques]).to(device)
42 |
43 | return ques_features
44 |
--------------------------------------------------------------------------------
/features/text_features.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import clip
5 | import numpy as np
6 | import torch
7 | import tqdm
8 | from slugify import slugify
9 |
10 | base_path = os.path.dirname(os.path.dirname(__file__))
11 | device = "cuda" if torch.cuda.is_available() else "cpu"
12 |
13 | def compute_text_features(model):
14 | print("Load CLIP {} model".format(model))
15 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model
16 | clip_model = clip_model.to(device)
17 |
18 | with open(base_path+"/data/VQA/Meta/ques_ans_count.json") as file:
19 | qa = json.load(file)
20 |
21 | with open(base_path+"/data/VQA/Annotations/prompt_meta.json") as file:
22 | prompt_meta = json.load(file)
23 |
24 | text_features_meta = {}
25 | for ques in tqdm.tqdm(qa.keys()):
26 | temp = prompt_meta[ques]
27 | answers = list(qa[ques].keys())
28 | prompts = [temp.format(a.replace("_", " ")) for a in answers]
29 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
30 | prompts = prompts.to(device)
31 |
32 | with torch.no_grad():
33 | text_features = clip_model.encode_text(prompts)
34 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True)
35 |
36 | text_features = text_features.float().cpu().numpy()
37 |
38 | text_features_meta[ques] = text_features
39 |
40 | # with open(base_path+"/data/VQA/Features/TextFeatures/text_features.json", "w") as outfile:
41 | # json.dump(text_features_val, outfile)
42 | np.save(base_path+f'/data/VQA/Features/TextFeatures/text_features_{slugify(model)}.npy',text_features_meta)
43 |
44 | def load_text_features(model):
45 | text_features = np.load(base_path + f"/data/VQA/Features/TextFeatures/text_features_{slugify(model)}.npy", allow_pickle=True).item()
46 |
47 | for ques in text_features.keys():
48 | text_features[ques] = torch.from_numpy(text_features[ques]).to(device)
49 |
50 | return text_features
51 |
--------------------------------------------------------------------------------
/model/custom_hnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from torchmeta.modules import MetaLinear, MetaModule, MetaSequential
9 | from utils import clip_utils
10 |
11 |
12 | def get_from_dict_or_default(module, dict, key):
13 | if key in dict:
14 | return dict[key]
15 | return getattr(module, key)
16 |
17 | class Mlp(MetaModule):
18 | def __init__(self, input_dim, hidden_dims, output_dim, nonlin='relu'):
19 | super().__init__()
20 | self.input_dim=input_dim
21 | self.hidden_dims=hidden_dims
22 | self.output_dim=output_dim
23 |
24 | linear = []
25 | prev_layer_dim = self.input_dim
26 | for dim in hidden_dims:
27 | linear.append(MetaLinear(prev_layer_dim, dim))
28 | torch.nn.init.kaiming_normal_(linear[-1].weight, nonlinearity="relu")
29 | linear[-1].bias.data *= 0
30 | if nonlin == 'relu':
31 | linear.append(nn.ReLU())
32 | elif nonlin == 'tanh':
33 | linear.append(nn.Tanh())
34 | elif nonlin == 'softplus':
35 | linear.append(nn.Softplus())
36 | else:
37 | assert False
38 |
39 | prev_layer_dim = dim
40 |
41 | linear.append(MetaLinear(prev_layer_dim, output_dim))
42 | torch.nn.init.kaiming_normal_(linear[-1].weight, nonlinearity="linear")
43 | linear[-1].bias.data *= 0
44 | self.mlp = MetaSequential(*linear)
45 |
46 | def forward(self, uncond_input, params = None):
47 | if len(uncond_input.shape) == 1:
48 | uncond_input = uncond_input.unsqueeze(0)
49 | return self.mlp(uncond_input, params = self.get_subdict(params, "mlp"))
50 |
51 | class HyperGenerator(Mlp):
52 | def __init__(self, mnet, e_dim, hidden_dims, normalize=False):
53 | super().__init__(e_dim, hidden_dims, mnet.get_parameter_vector().shape[0])
54 | self.normalize=normalize
55 | if self.normalize:
56 | self.register_buffer("feature_mean", torch.zeros(self.output_dim))
57 | self.register_buffer("feature_std", torch.ones(self.output_dim))
58 |
59 | def set_stats(self, mean, std):
60 | self.feature_mean.data = mean.detach()
61 | self.feature_std.data = std.detach()
62 |
63 | def forward(self, uncond_input, params = None):
64 | res = super().forward(uncond_input, params)
65 | if self.normalize:
66 | res = res * self.feature_std
67 | res = res + self.feature_mean
68 | return res
69 |
70 |
71 | class HyperEncoder(Mlp):
72 | def __init__(self, mnet, e_dim, hidden_dims, normalize=False):
73 | super().__init__(mnet.get_parameter_vector().shape[0], hidden_dims, 2*e_dim)
74 | self.normalize=normalize
75 | if self.normalize:
76 | self.register_buffer("feature_mean", torch.zeros(self.input_dim))
77 | self.register_buffer("feature_std", torch.ones(self.input_dim))
78 |
79 | def set_stats(self, mean, std):
80 | self.feature_mean.data = mean.detach()
81 | self.feature_std.data = std.detach()
82 |
83 | def forward(self, uncond_input, params = None):
84 | if self.normalize:
85 | uncond_input = uncond_input - self.feature_mean
86 | uncond_input = uncond_input / self.feature_std
87 | return super().forward(uncond_input, params)
88 |
89 |
90 | class HyperDiscriminator(Mlp):
91 | def __init__(self, mnet, hidden_dims):
92 | super().__init__(mnet.get_parameter_vector().shape[0], hidden_dims, 1)
93 |
94 | class CLIPAdapter(MetaModule):
95 | def __init__(self, e_dim, hidden_layers, use_bias, no_weights=False, straight_through=False, ignore_passed_weights=False):
96 | super().__init__()
97 | assert len(hidden_layers) == 1, "Architecture supports a single hidden layer."
98 | hidden_size = hidden_layers[0]
99 | self.e_dim=e_dim
100 | self.hidden_size=hidden_size
101 | self.use_bias=use_bias
102 | self.no_weights=no_weights
103 | self.no_weight=no_weights
104 | self.straight_through=straight_through
105 | self.ignore_passed_weights = ignore_passed_weights
106 |
107 | if no_weights:
108 | self.register_buffer("W1", torch.randn(hidden_size, e_dim).requires_grad_())
109 | self.register_buffer("b1", torch.randn(hidden_size).requires_grad_() if use_bias else None)
110 | self.register_buffer("W2", torch.randn(e_dim, hidden_size).requires_grad_())
111 | self.register_buffer("b2", torch.randn(e_dim).requires_grad_() if use_bias else None)
112 | else:
113 | norm_W1=math.sqrt(self.e_dim) if self.straight_through else 1
114 | norm_W2=math.sqrt(self.hidden_size) if self.straight_through else 1
115 |
116 | self.W1=torch.nn.Parameter((torch.randn(hidden_size, e_dim)/norm_W1).requires_grad_())
117 | self.b1=torch.nn.Parameter(torch.randn(hidden_size).requires_grad_()) if use_bias else None
118 | self.W2=torch.nn.Parameter((torch.randn(e_dim, hidden_size)/norm_W2).requires_grad_())
119 | self.b2=torch.nn.Parameter(torch.randn(e_dim).requires_grad_()) if use_bias else None
120 |
121 | def get_parameter_vector(self, params=None):
122 | if params is not None:
123 |             return torch.cat([params["W1"].flatten(), params["W2"].flatten()] + ([params["b1"], params["b2"]] if self.use_bias else []))
124 |         return torch.cat([self.W1.flatten(), self.W2.flatten()] + ([self.b1, self.b2] if self.use_bias else []))
125 |
126 | def get_gradient_vector(self):
127 |         return torch.cat([self.W1.grad.flatten(), self.W2.grad.flatten()] + ([self.b1.grad, self.b2.grad] if self.use_bias else []))
128 |
129 | def load_from_vector(self, vector):
130 |
131 | def get_last_elements(v, shape):
132 | numel = np.prod(shape)
133 | res = v[-numel:]
134 | v = v[:-numel]
135 | return res.reshape(*shape), v
136 |
137 | params = OrderedDict()
138 | vector = vector.flatten()
139 | if self.use_bias:
140 | params["b2"], vector = get_last_elements(vector, [self.e_dim])
141 | params["b1"], vector = get_last_elements(vector, [self.hidden_size])
142 | params["W2"], vector = get_last_elements(vector, [self.e_dim, self.hidden_size])
143 | params["W1"], vector = get_last_elements(vector, [self.hidden_size, self.e_dim])
144 | assert len(vector) == 0
145 |
146 | return params
147 |
148 | def forward(self, image_features, text_features, weights=None, params=None):
149 | param_dict = OrderedDict(params) if params is not None else OrderedDict()
150 |
151 | if weights is not None and not self.ignore_passed_weights:
152 | param_dict.update(self.load_from_vector(weights))
153 |
154 | W1 = get_from_dict_or_default(self, param_dict, "W1")
155 | W2 = get_from_dict_or_default(self, param_dict, "W2")
156 | b1 = get_from_dict_or_default(self, param_dict, "b1")
157 | b2 = get_from_dict_or_default(self, param_dict, "b2")
158 |
159 | normalized_W1 = W1
160 | normalized_W2 = W2
161 | if not self.straight_through:
162 | normalized_W1 = W1/math.sqrt(self.e_dim)
163 | normalized_W2 = W2/math.sqrt(self.hidden_size)
164 |
165 | identity = image_features
166 | out = F.linear(image_features, normalized_W1, b1)
167 | out = F.relu(out)
168 | out = F.linear(out, normalized_W2, b2)
169 | out += identity
170 | adapted_image_features = out / out.norm(dim=-1, keepdim=True)
171 |
172 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
173 | text_features = text_features.to(torch.float32)
174 |         logits = 100.0 * adapted_image_features @ torch.transpose(text_features, 1, 2)  # CLIP-style cosine similarity with a fixed logit scale of 100
175 | logits = torch.squeeze(logits,1)
176 |
177 | return logits
178 |
179 | class EmbeddingModule(torch.nn.Module):
180 | def __init__(self, input_dim):
181 | super().__init__()
182 | self.embedding = torch.nn.Parameter(torch.randn([1, input_dim]))
183 |
184 | def reset(self, embedding=None):
185 | self.embedding.data *= 0
186 | if embedding is None:
187 | embedding = torch.randn_like(self.embedding)
188 | self.embedding.data += embedding
189 |
190 |
191 | def forward(self, params=None):
192 | param_dict = OrderedDict(params) if params is not None else OrderedDict()
193 | embedding = get_from_dict_or_default(self, param_dict, "embedding")
194 |
195 | return embedding
196 |
197 | class MetaModel(MetaModule):
198 | def _init(self, mnet=None, hnet=None, enet=None, alpha=1):
199 | super().__init__()
200 | self.mnet = mnet
201 | self.hnet = hnet
202 | self.enet = enet
203 | self.alpha=alpha
204 | if self.hnet is not None:
205 | if self.enet is not None:
206 | self.meta_params = list(self.hnet.parameters())+list(self.enet.parameters())
207 | self.inner_module = self.enet
208 | else:
209 | self.meta_params = list(self.hnet.parameters())
210 | self.inner_module = self.hnet
211 | else:
212 | self.meta_params = list(self.mnet.parameters())
213 | self.inner_module = self.mnet
214 |
215 | self.inner_params= self._get_inner_params()
216 |
217 | def __init__(self, mnet=None, hnet=None, enet=None, inner_param=None, mainnet_use_bias=None, mainnet_hidden_dim=None, hypernet_hidden_dim=None, embedding_dim=None, straight_through=False, config=None):
218 | super().__init__()
219 | if "alpha" not in config:
220 | alpha=1
221 | else:
222 | alpha = config["alpha"]
223 |
224 | if mnet is not None or hnet is not None or enet is not None:
225 | self._init(mnet, hnet, enet, alpha)
226 | if inner_param == "enet":
227 | self.mnet = CLIPAdapter(clip_utils.embedding_size[config["clip_model"]],
228 | mainnet_hidden_dim, use_bias=mainnet_use_bias, no_weights=True, straight_through=straight_through)
229 | self.hnet = HyperGenerator(self.mnet, e_dim=embedding_dim,
230 | hidden_dims=hypernet_hidden_dim, )
231 | self.enet = EmbeddingModule(embedding_dim)
232 | self.meta_params = list(self.hnet.parameters())+list(self.enet.parameters())
233 | self.inner_module = self.enet
234 |
235 | elif inner_param == "hnet":
236 | self.mnet = CLIPAdapter(clip_utils.embedding_size[config["clip_model"]],
237 | mainnet_hidden_dim, use_bias=mainnet_use_bias, no_weights=True, straight_through=straight_through)
238 | self.hnet = HyperGenerator(self.mnet, e_dim=clip_utils.embedding_size[config["clip_model"]],
239 | hidden_dims=hypernet_hidden_dim, )
240 | self.enet = None
241 | self.meta_params = list(self.hnet.parameters())
242 | self.inner_module = self.hnet
243 |
244 | elif inner_param == "mnet":
245 | self.mnet = CLIPAdapter(clip_utils.embedding_size[config["clip_model"]],
246 | mainnet_hidden_dim, use_bias=mainnet_use_bias, no_weights=False, straight_through=straight_through)
247 | self.hnet = None
248 | self.enet = None
249 | self.meta_params = list(self.mnet.parameters())
250 | self.inner_module = self.mnet
251 | self.alpha=alpha
252 | self.inner_params= self._get_inner_params()
253 |
254 | def _get_inner_params(self):
255 | params = OrderedDict()
256 | for (name, param) in self.named_parameters():
257 | if any([id(param) == id(b) for b in self.inner_module.parameters()]):
258 | params[name] = param
259 | return params
260 |
261 | def get_inner_params(self):
262 | return self.inner_params
263 |
264 | def _forward(self, sample_image_features, sample_text_features, sample_ques_emb, params=None):
265 | if self.hnet is not None:
266 | if self.enet is None:
267 | weights = self.hnet.forward(uncond_input=sample_ques_emb, params=self.get_subdict(params, "hnet"))
268 | else:
269 | weights = self.hnet.forward(uncond_input=self.enet(params=self.get_subdict(params, "enet")))
270 | similarity = self.mnet(sample_image_features, sample_text_features, weights=weights)
271 | else:
272 | similarity = self.mnet(sample_image_features, sample_text_features, params=self.get_subdict(params, "mnet"))
273 |
274 | return similarity
275 |
276 | def forward(self, *args, params=None, **kwargs):
277 | if self.alpha == 1:
278 | return self._forward( *args, params=params, **kwargs)
279 | init_output = self._forward( *args, **kwargs)
280 | adapted_output = self._forward( *args, params = params, **kwargs)
281 | return self.alpha * (adapted_output - init_output) + init_output
282 |
283 | def get_mainnet_weights(self, ques_emb = None, params=None):
284 | if self.hnet is not None:
285 | if self.enet is None:
286 | return self.hnet.forward(uncond_input=ques_emb, params=self.get_subdict(params, "hnet"))
287 | else:
288 | return self.hnet.forward(uncond_input=self.enet(params=self.get_subdict(params, "enet")))
289 | else:
290 | return self.mnet.get_parameter_vector(params=self.get_subdict(params, "mnet"))
291 |
--------------------------------------------------------------------------------
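
Example (not repository code): CLIPAdapter.load_from_vector above pops slices off the end of the flat vector, which implicitly fixes the packing order of get_parameter_vector/grad_to_vector as [W1, W2, b1, b2]. A minimal self-contained sketch of that convention in plain PyTorch, with an illustrative unpack helper:

import math
import torch
from collections import OrderedDict

def unpack(vector, hidden_size, e_dim, use_bias=True):
    # Mirrors the slicing logic of load_from_vector: pieces are taken from the END
    # of the flat vector, so it must have been packed as [W1, W2, b1, b2].
    def get_last(v, shape):
        numel = math.prod(shape)
        res, v = v[-numel:], v[:-numel]
        return res.reshape(*shape), v

    params = OrderedDict()
    vector = vector.flatten()
    if use_bias:
        params["b2"], vector = get_last(vector, [e_dim])
        params["b1"], vector = get_last(vector, [hidden_size])
    params["W2"], vector = get_last(vector, [e_dim, hidden_size])
    params["W1"], vector = get_last(vector, [hidden_size, e_dim])
    assert len(vector) == 0
    return params

e_dim, hidden = 8, 4
flat = torch.cat([torch.randn(hidden * e_dim),   # W1, flattened
                  torch.randn(e_dim * hidden),   # W2, flattened
                  torch.randn(hidden),           # b1
                  torch.randn(e_dim)])           # b2
print({k: tuple(v.shape) for k, v in unpack(flat, hidden, e_dim).items()})
# {'b2': (8,), 'b1': (4,), 'W2': (8, 4), 'W1': (4, 8)}
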
/model/hyperclip.py:
--------------------------------------------------------------------------------
1 | # Credit: modification of code from https://github.com/AndreyGuzhov/AudioCLIP
2 |
3 | from typing import List, Optional, Tuple, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from clip.model import CLIP
8 |
9 | from model.custom_hnet import Mlp
10 |
11 | ClipFeatures = Tuple[
12 | Optional[torch.Tensor], # text
13 | Optional[torch.Tensor], # image
14 | Optional[torch.Tensor] # hyper
15 | ]
16 |
17 |
18 | ClipLogits = Tuple[
19 | Optional[torch.Tensor], # hyper x image
20 | Optional[torch.Tensor], # hyper x text
21 | Optional[torch.Tensor] # image x text
22 | ]
23 |
24 |
25 | ClipOutput = Tuple[
26 | Tuple[ClipFeatures, ClipLogits],
27 | Optional[torch.Tensor] # loss
28 | ]
29 |
30 | class HyperCLIP(CLIP):
31 |
32 | def __init__(self,
33 | embed_dim: int = 1024,
34 | # vision
35 | image_resolution: int = 224,
36 | vision_layers: Union[Tuple[int, int, int, int], int] = (3, 4, 6, 3),
37 | vision_width: int = 64,
38 | vision_patch_size: Optional[int] = None,
39 | # text
40 | context_length: int = 77,
41 | vocab_size: int = 49408,
42 | transformer_width: int = 512,
43 | transformer_heads: int = 8,
44 | transformer_layers: int = 12,
45 | # hyper
46 | hyper_model: str = "mlp", # choose between "mlp" and "embedder_hypernet"
47 | mainnet_param_count: int = 1024*256*2,
48 | hyper_hidden_dims: List[int] = [512],
49 | # pretrained model
50 | pretrained_it_location: Optional[str] = None,
51 | pretrained_hyper_location: Optional[str] = None):
52 |
53 | super(HyperCLIP, self).__init__(
54 | embed_dim=embed_dim,
55 | image_resolution=image_resolution,
56 | vision_layers=vision_layers,
57 | vision_width=vision_width,
58 | vision_patch_size=vision_patch_size,
59 | context_length=context_length,
60 | vocab_size=vocab_size,
61 | transformer_width=transformer_width,
62 | transformer_heads=transformer_heads,
63 | transformer_layers=transformer_layers
64 | )
65 |
66 | self.embed_dim = embed_dim
67 | self.hyper_model = hyper_model
68 |
69 | self.pretrained_it_location = pretrained_it_location
70 | self.pretrained_hyper_location = pretrained_hyper_location
71 |
72 | self.logit_scale_hi = torch.nn.Parameter(torch.log(torch.ones([]) * 100))
73 | self.logit_scale_ht = torch.nn.Parameter(torch.log(torch.ones([]) * 100))
74 |
75 | if pretrained_it_location is not None:
76 | self.load_state_dict(torch.jit.load(self.pretrained_it_location, map_location='cpu').state_dict(), strict=False)
77 | print('Image & Text weights loaded')
78 |
79 | if self.hyper_model == "mlp":
80 | self.hyper = Mlp(mainnet_param_count, hyper_hidden_dims, embed_dim, 'softplus')
81 | else:
82 | raise ValueError(f"Unsupported hyper model {self.hyper_model}")
83 |
84 | if pretrained_hyper_location is not None:
85 | self.hyper.load_state_dict(torch.load(self.pretrained_hyper_location, map_location='cpu'), strict=False)
86 | print('Hyper weights loaded')
87 |
88 | @property
89 | def device(self):
90 | return self.visual.conv1.weight.device
91 |
92 | def encode_hyper(self, weights: torch.Tensor) -> torch.Tensor:
93 | return self.hyper(weights.to(self.device))
94 |
95 | def forward(self,
96 | hyper: Optional[torch.Tensor] = None,
97 | image: Optional[torch.Tensor] = None,
98 | text: Optional[Union[List[List[str]],torch.Tensor]] = None,
99 | batch_indices: Optional[torch.Tensor] = None,
100 | # precomputed embeddings
101 | precomputed_it_embs: bool = False) -> ClipOutput:
102 |
103 | hyper_features = None
104 | image_features = None
105 | text_features = None
106 | sample_weights = None
107 |
108 | if hyper is not None:
109 | hyper_features = self.encode_hyper(hyper)
110 | hyper_features = hyper_features / hyper_features.norm(dim=-1, keepdim=True)
111 |
112 | if image is not None:
113 | if precomputed_it_embs:
114 | image_features = image
115 | else:
116 | image_features = self.encode_image(image)
117 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
118 |
119 | if text is not None:
120 | if precomputed_it_embs:
121 | text_features = text
122 | else:
123 | if batch_indices is None:
124 | batch_indices = torch.arange(len(text), dtype=torch.int64, device=self.device)
125 | text_features = self.encode_text(text, '{}', batch_indices)
126 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
127 |
128 | if hasattr(self, 'class_weights') and hasattr(self, 'label_to_class_idx'):
129 | sample_weights = torch.stack([
130 | sum(self.class_weights[self.label_to_class_idx[label]] for label in entities)
131 | for idx, entities in enumerate(text) if idx in batch_indices
132 | ])
133 |
134 | features: ClipFeatures = (hyper_features, image_features, text_features)
135 |
136 | logit_scale_hi = torch.clamp(self.logit_scale_hi.exp(), min=1.0, max=100.0)
137 | logit_scale_ht = torch.clamp(self.logit_scale_ht.exp(), min=1.0, max=100.0)
138 | logit_scale_it = torch.clamp(self.logit_scale.exp(), min=1.0, max=100.0)
139 |
140 | logits_hyper_image = None
141 | logits_hyper_text = None
142 | logits_image_text = None
143 |
144 | if (hyper_features is not None) and (image_features is not None):
145 | logits_hyper_image = logit_scale_hi * hyper_features @ image_features.T
146 |
147 | if (hyper_features is not None) and (text_features is not None):
148 | logits_hyper_text = logit_scale_ht * hyper_features @ text_features.T
149 |
150 | if (image_features is not None) and (text_features is not None):
151 | logits_image_text = logit_scale_it * image_features @ text_features.T
152 |
153 | logits: ClipLogits = (logits_hyper_image, logits_hyper_text, logits_image_text)
154 |
155 | loss = self.loss_fn(logits, sample_weights)
156 |
157 | return (features, logits), loss
158 |
159 | def loss_fn(self, logits: ClipLogits, sample_weights: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
160 | logits_hyper_image, logits_hyper_text, logits_image_text = logits
161 |
162 | if logits_hyper_image is not None:
163 | batch_size = logits_hyper_image.shape[0]
164 | elif logits_hyper_text is not None:
165 | batch_size = logits_hyper_text.shape[0]
166 | elif logits_image_text is not None:
167 | batch_size = logits_image_text.shape[0]
168 | else:
169 | return None
170 |
171 | reference = torch.arange(
172 | batch_size,
173 | dtype=torch.int64,
174 | device=self.device
175 | )
176 |
177 | loss = torch.tensor(0.0, dtype=self.dtype, device=self.device)
178 |
179 | num_modalities: int = 0
180 | scale = torch.tensor(1.0, dtype=self.dtype, device=self.device)
181 |
182 | if logits_hyper_image is not None:
183 | loss_hi = F.cross_entropy(
184 | logits_hyper_image, reference, weight=sample_weights
185 | ) + F.cross_entropy(
186 | logits_hyper_image.transpose(-1, -2), reference, weight=sample_weights
187 | )
188 | loss = loss + loss_hi
189 | num_modalities += 1
190 |
191 | if logits_hyper_text is not None:
192 | loss_ht = F.cross_entropy(
193 | logits_hyper_text, reference, weight=sample_weights
194 | ) + F.cross_entropy(
195 | logits_hyper_text.transpose(-1, -2), reference, weight=sample_weights
196 | )
197 | loss = loss + loss_ht
198 | num_modalities += 1
199 |
200 | if logits_image_text is not None:
201 | loss_it = F.cross_entropy(
202 | logits_image_text, reference, weight=sample_weights
203 | ) + F.cross_entropy(
204 | logits_image_text.transpose(-1, -2), reference, weight=sample_weights
205 | )
206 | loss = loss + loss_it
207 | num_modalities += 1
208 |
209 | for idx in range(num_modalities):
210 | scale = scale * (idx + 1)
211 |
212 | return loss / scale
213 |
214 | @property
215 | def loss_fn_name(self) -> str:
216 | return 'Cross Entropy'
217 |
218 |
219 | def build_hyperclip_from_classic_clip(state_dict: Union[dict, str],
220 | hyper_model: str = "mlp",
221 |                                       mainnet_param_count: int = 1024*256*2,
222 | hyper_hidden_dims: List[int] = [512],
223 | pretrained_it_location: Optional[str] = None,
224 | pretrained_hyper_location: Optional[str] = None):
225 |
226 | if isinstance(state_dict, str):
227 | state_dict = torch.jit.load(state_dict, map_location='cpu').state_dict()
228 |
229 | vit = "visual.proj" in state_dict
230 |
231 | if vit:
232 | vision_width = state_dict["visual.conv1.weight"].shape[0]
233 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
234 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
235 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
236 | image_resolution = vision_patch_size * grid_size
237 | else:
238 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
239 | vision_layers = tuple(counts)
240 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
241 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
242 | vision_patch_size = None
243 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
244 | image_resolution = output_width * 32
245 |
246 | embed_dim = state_dict["text_projection"].shape[1]
247 | context_length = state_dict["positional_embedding"].shape[0]
248 | vocab_size = state_dict["token_embedding.weight"].shape[0]
249 | transformer_width = state_dict["ln_final.weight"].shape[0]
250 | transformer_heads = transformer_width // 64
251 |     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
252 |
253 | model = HyperCLIP(
254 | embed_dim,
255 | image_resolution, vision_layers, vision_width, vision_patch_size,
256 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
257 | hyper_model, mainnet_param_count, hyper_hidden_dims, pretrained_it_location, pretrained_hyper_location
258 | )
259 |
260 | return model
261 |
--------------------------------------------------------------------------------
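
Example (not repository code): HyperCLIP.loss_fn above applies a symmetric cross-entropy to each pair of available modalities, with matching pairs sitting on the diagonal of the logit matrix. A minimal self-contained sketch of the hyper-vs-text term on randomly generated, normalized stand-in features:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, dim = 8, 16

# Stand-ins for encoded network weights ("hyper") and task questions ("text");
# row i of each tensor belongs to task i, so matches lie on the diagonal.
hyper = F.normalize(torch.randn(batch, dim), dim=-1)
text = F.normalize(torch.randn(batch, dim), dim=-1)

logit_scale_ht = 100.0                        # same order of magnitude as exp(log(100)) above
logits_hyper_text = logit_scale_ht * hyper @ text.T
reference = torch.arange(batch)

# Symmetric cross-entropy: rows classify hyper -> text, the transpose text -> hyper.
loss_ht = F.cross_entropy(logits_hyper_text, reference) \
        + F.cross_entropy(logits_hyper_text.transpose(-1, -2), reference)
print(float(loss_ht))
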
/model/latent_diffuser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utils.diffusion_utils import timestep_embedding
4 |
5 |
6 | class FeedForwardBlock(nn.Module):
7 | def __init__(self, input_dim, output_dim):
8 | super().__init__()
9 | self.input_dim = input_dim
10 | self.output_dim = output_dim
11 | self.layers = nn.Sequential(nn.Linear(input_dim, output_dim),
12 | nn.LayerNorm(output_dim),
13 | nn.ReLU())
14 |
15 | def forward(self, x):
16 | return self.layers(x)
17 |
18 |
19 | class LatentDiffuser(nn.Module):
20 | def __init__(self, x_dim=128, cond_emb_dim=768, timestep_emb_dim=100, hidden_dims=[128,128]):
21 | super().__init__()
22 | self.x_dim = x_dim
23 | self.cond_emb_dim = cond_emb_dim
24 | self.timestep_emb_dim = timestep_emb_dim
25 | self.hidden_dims = hidden_dims
26 | layers = [FeedForwardBlock(x_dim + cond_emb_dim + timestep_emb_dim, hidden_dims[0])]
27 | for i, h in enumerate(hidden_dims[1:]):
28 | layers += [FeedForwardBlock(hidden_dims[i], h)]
29 | layers += [nn.Linear(hidden_dims[-1], x_dim)]
30 | self.layers = nn.ModuleList(layers)
31 |
32 | def forward(self, x, t, cond_emb):
33 |         if isinstance(t, int):
34 |             t = torch.tensor(t, device=x.device).repeat((x.shape[0], 1))
35 | # right now we only support basic concat of cond_emb and timestep_emb to input x
36 | timestep_emb = timestep_embedding(t.flatten(), self.timestep_emb_dim)
37 | eps = torch.cat((x, cond_emb, timestep_emb), 1)
38 | for layer in self.layers:
39 | eps = layer(eps)
40 | return eps
41 |
42 | class SELayer(nn.Module):
43 | def __init__(self, input_dim, reduction=8):
44 | super().__init__()
45 | self.input_dim = input_dim
46 | self.reduction = reduction
47 | self.fc = nn.Sequential(
48 | nn.Linear(input_dim, input_dim // reduction, bias=False),
49 | nn.ReLU(),
50 | nn.Linear(input_dim // reduction, input_dim, bias=False),
51 | nn.Sigmoid()
52 | )
53 |
54 | def forward(self, x):
55 | return x * self.fc(x)
56 |
57 | class LatentDiffuserV2(nn.Module):
58 | def __init__(self, x_dim=128, cond_emb_dim=768, timestep_emb_dim=100, hidden_dims=[128,128], se_reduction=8):
59 | super().__init__()
60 | self.x_dim = x_dim
61 | self.cond_emb_dim = cond_emb_dim
62 | self.timestep_emb_dim = timestep_emb_dim
63 | self.hidden_dims = hidden_dims
64 | self.se_reduction = se_reduction
65 | layers = [FeedForwardBlock(x_dim + cond_emb_dim + timestep_emb_dim, hidden_dims[0])]
66 | res_maps = [nn.Linear(x_dim, hidden_dims[0])]
67 | se_layers = [SELayer(x_dim)]
68 | for i, h in enumerate(hidden_dims[1:]):
69 | layers += [FeedForwardBlock(hidden_dims[i] + cond_emb_dim + timestep_emb_dim, h)]
70 | res_maps += [nn.Linear(hidden_dims[i], h)]
71 | se_layers += [SELayer(hidden_dims[i], reduction=self.se_reduction)]
72 | self.layers = nn.ModuleList(layers)
73 | self.res_maps = nn.ModuleList(res_maps)
74 | self.final_layer = nn.Linear(hidden_dims[-1], x_dim)
75 |
76 | def forward(self, x, t, cond_emb):
77 |         if isinstance(t, int):
78 |             t = torch.tensor(t, device=x.device).repeat((x.shape[0], 1))
79 | # right now we only support basic concat of cond_emb and timestep_emb to input x
80 | timestep_emb = timestep_embedding(t.flatten(), self.timestep_emb_dim)
81 | eps_in = x
82 | for i in range(len(self.layers)):
83 | eps = self.layers[i](torch.cat((eps_in, cond_emb, timestep_emb), 1))
84 | eps_in = eps + self.res_maps[i](eps_in)
85 | eps = self.final_layer(eps)
86 | return eps
87 |
--------------------------------------------------------------------------------
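
Example (not repository code): LatentDiffuserV2 re-concatenates the conditioning and timestep embeddings at every block and carries the input forward through learned linear residual maps. A self-contained shape sketch of that forward pass; the sinusoidal embed below is only an illustrative stand-in for utils.diffusion_utils.timestep_embedding:

import math
import torch
from torch import nn

def embed(t, dim):
    # Illustrative sinusoidal timestep embedding of shape [len(t), dim] (a stand-in only).
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t.float()[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

x_dim, cond_dim, t_dim, hidden = 128, 768, 100, [128, 128]
batch = 4

block_in = [x_dim] + hidden[:-1]
blocks = nn.ModuleList([nn.Sequential(nn.Linear(i + cond_dim + t_dim, h), nn.LayerNorm(h), nn.ReLU())
                        for i, h in zip(block_in, hidden)])
res_maps = nn.ModuleList([nn.Linear(i, h) for i, h in zip(block_in, hidden)])
final = nn.Linear(hidden[-1], x_dim)

x = torch.randn(batch, x_dim)
cond = torch.randn(batch, cond_dim)
t_emb = embed(torch.randint(0, 1000, (batch,)), t_dim)

eps_in = x
for block, res in zip(blocks, res_maps):
    eps = block(torch.cat((eps_in, cond, t_emb), dim=1))  # re-inject cond + timestep at each block
    eps_in = eps + res(eps_in)                            # learned residual map to the new width
print(final(eps).shape)  # torch.Size([4, 128])
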
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/hyperclip_classification_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from functools import partial
5 |
6 | import numpy as np
7 | import torch
8 | import wandb
9 | from features.image_features import load_image_features
10 | from features.ques_features import load_ques_features
11 | from features.text_features import load_text_features
12 | from model.hyperclip import build_hyperclip_from_classic_clip
13 | from tqdm import tqdm
14 | from training.store_few_shot_latent import StoreFewShotLatent
15 | from training.vae_learn import sample_from_enc
16 | from utils import clip_utils
17 | from utils.build_opt import build_optimizer
18 | from utils.init_utils import (load_metamodel_from_checkpoint,
19 | load_vae_and_metamodel_from_checkpoint)
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--wandb_mode', type=str, default='online', help='Set to "disabled" to disable Weights & Biases logging')
23 | parser.add_argument('--run_id', type=str, default='srl_ethz/hyperclip-scripts/irwxxdeh', help='The full "entity/project/run_id" identifier of the run to load')
24 | parser.add_argument('--task_batch_size', type=int, default=10, help='Size of the batch of randomly sampled tasks used for task classification')
25 |
26 | base_path = os.path.dirname(os.path.dirname(__file__))
27 |
28 | default_config = {
29 | "inner_epochs": '50',
30 | "inner_optimizer": "sgd",
31 | "inner_learning_rate": 0.1,
32 |
33 | "vae_stochastic_init": False,
34 | }
35 |
36 | torch.manual_seed(42)
37 | rng = np.random.RandomState(42)
38 | np.random.seed(42)
39 |
40 | def main(args):
41 |
42 | cfg = default_config
43 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None})
44 | api = wandb.Api()
45 | loaded_run = api.run(args.run_id)
46 | cfg.update({k: v for (k, v) in loaded_run.config.items() if v is not None and k not in cfg})
47 | wandb.init(project="hyperclip-classification", entity="srl_ethz", name=loaded_run.name, config=cfg, mode=args.wandb_mode)
48 | config = wandb.config
49 |
50 | device = "cuda" if torch.cuda.is_available() else "cpu"
51 |
52 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
53 | train_data = json.load(file)
54 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
55 | test_data = json.load(file)
56 |
57 | train_tasks = list(train_data.keys())
58 | test_tasks = list(test_data.keys())
59 |
60 | hyperclip_path = base_path + "/evaluation/hyperclip/hyperclip_"+str(loaded_run.name)+".pth"
61 |
62 | hnet_gen, hnet_enc = None, None
63 | if "vae_checkpoint" in config and config["vae_checkpoint"] is not None:
64 | meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
65 | load_vae_and_metamodel_from_checkpoint(config, device)
66 | else:
67 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
68 | load_metamodel_from_checkpoint(config, device)
69 |
70 | # pre-computed features
71 | image_features = load_image_features(config["clip_model"])
72 | text_features = load_text_features(config["clip_model"])
73 | ques_emb = load_ques_features(config["clip_model"])
74 |
75 | hyperclip = build_hyperclip_from_classic_clip(
76 | os.path.expanduser(clip_utils.cached_location[config["clip_model"]]),
77 | hyper_model=config["hyperclip_model"],
78 | mainnet_param_count=meta_module.mnet.get_parameter_vector().shape[0],
79 | hyper_hidden_dims=[] if config["hyperclip_hidden_dim"] == "" else [int(i) for i in config["hyperclip_hidden_dim"].split(",")],
80 | pretrained_it_location=os.path.expanduser(clip_utils.cached_location[config["clip_model"]]),
81 | pretrained_hyper_location=None).to(device)
82 |
83 | hyperclip.hyper.load_state_dict(torch.load(hyperclip_path), strict=False)
84 | hyperclip.hyper.eval()
85 |
86 | unguided_inner_optim = partial(build_optimizer, config=config, loop="inner")
87 | few_shot_saver = StoreFewShotLatent( meta_module,
88 | image_features=image_features,
89 | text_features=text_features,
90 | ques_emb=ques_emb,
91 | config=config,
92 | device=device, compute_hessian=False,
93 | reset_normal_embedding=config["vae_stochastic_init"] and "vae_checkpoint" in config)
94 |
95 | _, _, _, train_tasks_optimized_params = few_shot_saver.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
96 | batch_size=len(list(train_data.keys())),
97 | train=False, skip_cond=True,
98 | train_subtype = config["train_subtype"], val_subtype=config["val_subtype"],
99 | output_mean_std=True,
100 | debug=True)
101 |
102 |
103 | hyperclip_ques_batch = torch.zeros(config["task_batch_size"], clip_utils.embedding_size[config["clip_model"]], dtype=torch.float32).to(device)
104 | hyperclip_net_batch = torch.zeros(config["task_batch_size"], hyperclip.hyper.input_dim, dtype=torch.float32).to(device)
105 |
106 | train_correct = 0
107 | tot_loss_train = 0
108 |
109 | for i, task in enumerate(tqdm(train_tasks)):
110 |
111 | prob_other_sample = np.ones(len(train_tasks)) * (1.0 / (len(train_tasks)-1))
112 | prob_other_sample[i] = 0.0
113 | other_task_ids = np.random.choice(len(train_tasks), size=config["task_batch_size"]-1, replace=False, p=prob_other_sample)
114 |
115 | hyperclip_ques_batch[0] = ques_emb[task]
116 |
117 | opt_latent = train_tasks_optimized_params[i]
118 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config)
119 | hyperclip_net_batch[0] = opt_weights
120 |
121 | for batch_i, other_id in enumerate(other_task_ids):
122 | hyperclip_ques_batch[batch_i+1] = ques_emb[train_tasks[other_id]]
123 | opt_latent = train_tasks_optimized_params[other_id]
124 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config)
125 | hyperclip_net_batch[batch_i+1] = opt_weights
126 |
127 | with torch.no_grad():
128 | (features, logits), hyperclip_loss = hyperclip.forward(hyper = hyperclip_net_batch, text = hyperclip_ques_batch, precomputed_it_embs = True)
129 | tot_loss_train += hyperclip_loss.detach().cpu().numpy()
130 |
131 | _, logits_hyper_ques, _ = logits
132 |
133 | logits_hyper_ques = logits_hyper_ques[0] #consider the classification of the true task at position 0 against randomly sampled other tasks
134 | _, pred_index = logits_hyper_ques.topk(1)
135 | if pred_index == 0:
136 | train_correct += 1
137 |
138 | train_accuracy = train_correct / len(train_tasks)
139 | print(f"Trainset hyperclip accuracy: {train_accuracy}")
140 | train_avg_loss = tot_loss_train / len(train_tasks)
141 | print(f"Trainset hyperclip average loss: {train_avg_loss}")
142 | wandb.log({"hyperclip_trainset_accuracy": train_accuracy, "hyperclip_trainset_avg_loss": train_avg_loss})
143 |
144 |
145 | _, _, _, test_tasks_optimized_params = few_shot_saver.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim,
146 | batch_size=len(list(test_data.keys())),
147 | train=False, skip_cond=True,
148 |                                                     train_subtype = config["train_subtype"], val_subtype=config["val_subtype"],
149 | output_mean_std=True,
150 | debug=True)
151 |
152 | test_correct = 0
153 | tot_loss_test = 0
154 |
155 | for i, task in enumerate(tqdm(test_tasks)):
156 |
157 | prob_other_sample = np.ones(len(test_tasks)) * (1.0 / (len(test_tasks)-1))
158 | prob_other_sample[i] = 0.0
159 | other_task_ids = np.random.choice(len(test_tasks), size=config["task_batch_size"]-1, replace=False, p=prob_other_sample)
160 |
161 | hyperclip_ques_batch[0] = ques_emb[task]
162 | opt_latent = test_tasks_optimized_params[i]
163 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config)
164 | hyperclip_net_batch[0] = opt_weights
165 | for batch_i, other_id in enumerate(other_task_ids):
166 | hyperclip_ques_batch[batch_i+1] = ques_emb[test_tasks[other_id]]
167 | opt_latent = test_tasks_optimized_params[other_id]
168 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config)
169 | hyperclip_net_batch[batch_i+1] = opt_weights
170 |
171 | with torch.no_grad():
172 | (features, logits), hyperclip_loss = hyperclip.forward(hyper = hyperclip_net_batch, text = hyperclip_ques_batch, precomputed_it_embs = True)
173 | tot_loss_test += hyperclip_loss.detach().cpu().numpy()
174 |
175 | _, logits_hyper_ques, _ = logits
176 |
177 | logits_hyper_ques = logits_hyper_ques[0] #consider the classification of the true task at position 0 against randomly sampled other tasks
178 | _, pred_index = logits_hyper_ques.topk(1)
179 | if pred_index == 0:
180 | test_correct += 1
181 |
182 | test_accuracy = test_correct / len(test_tasks)
183 | print(f"Testset hyperclip accuracy: {test_accuracy}")
184 | test_avg_loss = tot_loss_test / len(test_tasks)
185 | print(f"Testset hyperclip average loss: {test_avg_loss}")
186 | wandb.log({"hyperclip_testset_accuracy": test_accuracy, "hyperclip_testset_avg_loss": test_avg_loss})
187 |
188 | wandb.finish()
189 |
190 |
191 | def get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config):
192 | if "vae_checkpoint" in config and config["vae_checkpoint"] is not None and hnet_gen is not None:
193 | return hnet_gen(sample_from_enc(hnet_enc, get_weights_from_metamodule(meta_module, opt_latent))).detach()
194 | return get_weights_from_metamodule(meta_module, opt_latent)
195 |
196 | def get_weights_from_metamodule(meta_module, opt_latent):
197 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in meta_module.get_inner_params().items()}
198 | inner_params["enet.embedding"] = opt_latent.clone().detach().requires_grad_()
199 | return meta_module.get_mainnet_weights(params = inner_params).detach()
200 |
201 | if __name__ == "__main__":
202 | args = parser.parse_args()
203 | main(args)
204 |
--------------------------------------------------------------------------------
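
Example (not repository code): in the classification test above, row 0 of each batch holds the true (adapted weights, question) pair and the remaining rows hold distractor tasks, so a prediction counts as correct when the top-1 entry of the first row of logits_hyper_text is index 0. A minimal sketch of that check on random stand-in features:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
task_batch, dim = 10, 32

# Random stand-ins for the encoded adapted weights ("hyper") and question embeddings ("text");
# index 0 is the true task, indices 1..task_batch-1 are distractors.
hyper = F.normalize(torch.randn(task_batch, dim), dim=-1)
text = F.normalize(torch.randn(task_batch, dim), dim=-1)

logits_hyper_text = 100.0 * hyper @ text.T
row0 = logits_hyper_text[0]            # similarity of the true task's weights to every question
_, pred_index = row0.topk(1)
correct = int(pred_index.item() == 0)  # 1 only if the true task's own question scores highest
print(correct)
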
/scripts/precompute_adaptation.py:
--------------------------------------------------------------------------------
1 | # Takes a hypernetwork and precomputes (and stores) the weights generated by embedding adaptation for each task
2 |
3 | import argparse
4 | import json
5 | import os
6 | from functools import partial
7 |
8 | import numpy as np
9 | import torch
10 | import wandb
11 | from data.dataloader.coco_tasks import (filter_categories,
12 | load_coco_answer_features)
13 | from features.image_features import load_image_features
14 | from features.ques_features import load_ques_features
15 | from features.text_features import load_text_features
16 | from training.store_few_shot_latent import StoreFewShotLatent
17 | from utils.build_opt import build_optimizer
18 | from utils.init_utils import load_metamodel_from_checkpoint
19 | from utils.misc_utils import str2bool
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--wandb_mode', type=str, default='online',
23 | help='Set to "disabled" to disable Weights & Biases logging')
24 |
25 | parser.add_argument('--epochs', type=int)
26 |
27 | parser.add_argument('--inner_epochs', type=str)
28 | parser.add_argument('--inner_optimizer', type=str)
29 | parser.add_argument('--inner_learning_rate', type=float)
30 |
31 | parser.add_argument('--compute_hessian', type=str2bool)
32 |
33 | parser.add_argument('--few_shot_checkpoint', type=str, help='required')
34 | parser.add_argument('--vae_checkpoint', type=str, help='required')
35 | parser.add_argument('--vae_stochastic_init', type=str2bool)
36 |
37 | parser.add_argument('--use_extended_coco', type=str2bool)
38 | parser.add_argument('--extend_coco_size', type=int)
39 | parser.add_argument('--extend_coco_frac_train', type=float)
40 |
41 | parser.add_argument('--data_subtype', type=str)
42 |
43 |
44 | base_path = os.path.dirname(os.path.dirname(__file__))
45 |
46 | default_config = {
47 | "epochs": 100,
48 | "inner_epochs": '10',
49 | "inner_optimizer": "sgd",
50 | "inner_learning_rate": 0.1,
51 | "compute_hessian": False,
52 | "vae_stochastic_init": False,
53 |
54 | "clip_model": "ViT-L/14@336px",
55 |
56 | "use_extended_coco": False,
57 | "extend_coco_size": 10 * 870,
58 | "extend_coco_frac_train": 0.5,
59 | "data_subtype": "random",
60 | }
61 |
62 | torch.manual_seed(42)
63 | rng = np.random.RandomState(42)
64 | np.random.seed(42)
65 |
66 | def main(args):
67 | cfg = default_config
68 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None})
69 | print(cfg)
70 | wandb.init(project='precompute_adaptation', entity="srl_ethz", config=cfg,
71 | mode=args.wandb_mode)
72 | config = wandb.config
73 |
74 | log_file = base_path + "/evaluation/precompute_adaptation/" + str(wandb.run.name) + ".pth"
75 | log_file_train_eval = base_path + "/evaluation/precompute_adaptation/" + str(wandb.run.name) + "_train_eval.pth"
76 | log_file_val_eval = base_path + "/evaluation/precompute_adaptation/" + str(wandb.run.name) + "_val_eval.pth"
77 |
78 | # Load the model
79 | device = "cuda" if torch.cuda.is_available() else "cpu"
80 |
81 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
82 | train_data = json.load(file)
83 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
84 | test_data = json.load(file)
85 |
86 | reset_normal_embedding = config["vae_stochastic_init"] and "vae_checkpoint" in config
87 |
88 | meta_module, _, _, _ = load_metamodel_from_checkpoint(config, device)
89 |
90 | # pre-computed features
91 | image_features = load_image_features(config["clip_model"])
92 | text_features = load_text_features(config["clip_model"])
93 | ques_emb = load_ques_features(config["clip_model"])
94 |
95 | coco_categories = None
96 | coco_answer_features = None
97 | if config["use_extended_coco"]:
98 | coco_categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item()
99 | coco_categories = filter_categories(coco_categories, 10)
100 | # right now it's hardcoded, it's train_size + test_size for the extended coco sampled datasets
101 | coco_answer_features = load_coco_answer_features(config["clip_model"])
102 |
103 | unguided_inner_optim = partial(build_optimizer, config=config, loop="inner")
104 | few_shot_saver = StoreFewShotLatent( meta_module,
105 | image_features=image_features,
106 | text_features=text_features,
107 | ques_emb=ques_emb,
108 | config=config,
109 | device=device, compute_hessian=config["compute_hessian"],
110 | reset_normal_embedding=reset_normal_embedding,
111 | coco_categories=coco_categories, coco_answer_features=coco_answer_features,
112 | extend_coco_size=config["extend_coco_size"])
113 |
114 | print("Computing metric")
115 | from utils.train_utils import log_metric
116 | log_dict = few_shot_saver.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
117 | batch_size=len(list(train_data.keys())),train=True,
118 | log_file=log_file_train_eval)
119 | log_metric(log_dict, "eval_train/")
120 | log_dict = few_shot_saver.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim,
121 | batch_size=len(list(test_data.keys())),train=True,
122 | log_file=log_file_val_eval)
123 | log_metric(log_dict, "eval_val/")
124 |
125 | for meta_epoch in range(config["epochs"]):
126 | few_shot_saver.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
127 | batch_size=len(list(train_data.keys())),
128 | extend_coco=config["use_extended_coco"],
129 | extend_coco_frac_train=config["extend_coco_frac_train"],
130 | train=True, train_subtype = config["data_subtype"], val_subtype=config["data_subtype"], debug=True, log_file=log_file)
131 |
132 |
133 | wandb.finish()
134 |
135 |
136 | if __name__ == "__main__":
137 | args = parser.parse_args()
138 | main(args)
139 |
--------------------------------------------------------------------------------
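
Example (not repository code): this script, like several others, passes the inner-loop optimizer as partial(build_optimizer, config=config, loop="inner") so the adaptation code can build a fresh optimizer for every task. A self-contained sketch of that factory pattern; make_inner_optim is an illustrative stand-in for utils.build_opt.build_optimizer, whose real signature may differ:

from functools import partial

import torch
from torch import optim

def make_inner_optim(params, config, loop="inner"):
    # Illustrative stand-in for utils.build_opt.build_optimizer.
    lr = config[f"{loop}_learning_rate"]
    if config[f"{loop}_optimizer"] == "sgd":
        return optim.SGD(params, lr=lr)
    return optim.Adam(params, lr=lr)

config = {"inner_optimizer": "sgd", "inner_learning_rate": 0.1}
unguided_inner_optim = partial(make_inner_optim, config=config, loop="inner")

# Each task gets a freshly initialized embedding and its own optimizer instance.
for _ in range(3):
    embedding = torch.nn.Parameter(torch.randn(128))
    inner_optimizer = unguided_inner_optim([embedding])
    inner_optimizer.zero_grad()
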
/scripts/precompute_image_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from features.image_features import compute_image_features
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--model', type=str, default='ViT-L/14@336px', help='CLIP visual encoder')
7 |
8 |
9 | if __name__ == "__main__":
10 | args = parser.parse_args()
11 | compute_image_features(args.model)
12 |
--------------------------------------------------------------------------------
/scripts/precompute_ques_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from features.ques_features import compute_ques_features
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--model', type=str, default='ViT-L/14@336px', help='CLIP model used for question features')
7 |
8 |
9 | if __name__ == "__main__":
10 | args = parser.parse_args()
11 | compute_ques_features(args.model)
12 |
--------------------------------------------------------------------------------
/scripts/precompute_text_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from features.text_features import compute_text_features
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--model', type=str, default='ViT-L/14@336px', help='CLIP model used for text features')
7 |
8 |
9 | if __name__ == "__main__":
10 | args = parser.parse_args()
11 | compute_text_features(args.model)
12 |
--------------------------------------------------------------------------------
/scripts/train_few_shot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import wandb
8 | from data.dataloader.coco_tasks import (filter_categories,
9 | load_coco_answer_features)
10 | from features.image_features import load_image_features
11 | from features.ques_features import load_ques_features
12 | from features.text_features import load_text_features
13 | from model.custom_hnet import MetaModel
14 | from training.maml_learn import MAML
15 | from utils import clip_utils
16 | from utils.build_opt import build_optimizer
17 | from utils.misc_utils import str2bool
18 | from utils.train_utils import log_metric, n_shot_trials_run
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--wandb_mode', type=str, default='online',
22 | help='Set to "disabled" to disable Weights & Biases logging')
23 |
24 | parser.add_argument('--meta_epochs', type=int)
25 | parser.add_argument('--meta_batch_size', type=int)
26 |
27 | parser.add_argument('--inner_epochs', type=str)
28 |
29 | parser.add_argument('--inner_learning_rate', type=float)
30 | parser.add_argument('--eval_inner_epochs', type=str)
31 | parser.add_argument('--second_order', type=str2bool)
32 | parser.add_argument('--train_subtype', type=str)
33 | parser.add_argument('--val_subtype', type=str)
34 | parser.add_argument('--meta_optimizer', type=str)
35 | parser.add_argument('--meta_learning_rate', type=float)
36 | parser.add_argument('--meta_grad_clip', type=float)
37 |
38 | parser.add_argument('--alpha', type=float)
39 |
40 | # meta_module
41 | parser.add_argument('--inner_param', type=str)
42 | parser.add_argument('--hypernet_hidden_dim', type=str)
43 | parser.add_argument('--straight_through', type=str2bool)
44 | parser.add_argument('--embedding_dim', type=int)
45 |
46 | parser.add_argument('--keep_tasks_frac', type=float)
47 |
48 | parser.add_argument('--load_checkpoint', type=str)
49 | parser.add_argument('--val_epoch_interval', type=int)
50 |
51 | parser.add_argument('--use_extended_coco', type=str2bool)
52 | parser.add_argument('--extend_coco_size', type=int)
53 | parser.add_argument('--extend_coco_frac_train', type=float)
54 |
55 | parser.add_argument('--use_clip_embedding_init', type=str2bool)
56 | parser.add_argument('--save_checkpoint', type=str2bool)
57 |
58 | parser.add_argument('--seed', type=int, default=42)
59 | parser.add_argument('--checkpoint', type=str)
60 | parser.add_argument('--eval', type=str2bool, default=False)
61 |
62 | parser.add_argument('--n_shot_trials_maxN', type=int)
63 |
64 | base_path = os.path.dirname(os.path.dirname(__file__))
65 |
66 | args = parser.parse_args()
67 |
68 | torch.manual_seed(args.seed)
69 | rng = np.random.RandomState(args.seed)
70 | np.random.seed(args.seed)
71 |
72 | default_config = {
73 | # meta_module
74 | "use_clip_embedding_init": False,
75 | "inner_param": "enet",
76 | "hypernet_hidden_dim": "128,128,128",
77 | "straight_through": False,
78 | "mainnet_use_bias": True,
79 | "mainnet_hidden_dim": [256],
80 | "embedding_dim": 128,
81 |
82 | "alpha": 1,
83 |
84 | "clip_model": "ViT-L/14@336px",
85 |
86 | "inner_epochs": "10",
87 | "inner_learning_rate": 0.1,
88 | "train_subtype": "test",
89 | "val_subtype": "train",
90 |
91 | "meta_epochs": 1000,
92 | "meta_batch_size": 32,
93 | "second_order": False,
94 | "meta_grad_clip": 10,
95 |
96 | "eval_inner_epochs": '',
97 | "meta_optimizer": "adam",
98 | "meta_learning_rate": 0.001,
99 |
100 | "val_epoch_interval": 25,
101 | "save_checkpoint": False,
102 | "load_checkpoint": "",
103 | "keep_tasks_frac": 1,
104 |
105 | "use_extended_coco": False,
106 | "extend_coco_size": 10 * 870,
107 | "extend_coco_frac_train": 0.5
108 | }
109 |
110 |
111 | def main():
112 | cfg=default_config
113 | cfg.update({k:v for (k,v) in vars(args).items() if v is not None})
114 |
115 | wandb.init(project="train_few_shot", entity="srl_ethz", config=cfg,
116 | mode=args.wandb_mode)
117 | config = wandb.config
118 |
119 | # Load the model
120 | device = "cuda" if torch.cuda.is_available() else "cpu"
121 |
122 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
123 | train_data = json.load(file)
124 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
125 | test_data = json.load(file)
126 |
127 | # pre-computed features
128 | image_features = load_image_features(config["clip_model"])
129 | text_features = load_text_features(config["clip_model"])
130 | ques_emb = load_ques_features(config["clip_model"])
131 |
132 | coco_categories = None
133 | coco_answer_features = None
134 | if config["use_extended_coco"]:
135 | coco_categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item()
136 | coco_categories = filter_categories(coco_categories, 10) # right now it's hardcoded, it's train_size + test_size for the extended coco sampled datasets
137 | coco_answer_features = load_coco_answer_features(config["clip_model"])
138 |
139 | meta_module = MetaModel(
140 | inner_param=config["inner_param"],
141 | mainnet_use_bias=config["mainnet_use_bias"],
142 | mainnet_hidden_dim=config["mainnet_hidden_dim"],
143 | hypernet_hidden_dim=[] if config["hypernet_hidden_dim"]=="" else [int(i) for i in config["hypernet_hidden_dim"].split(",")],
144 | embedding_dim=config["embedding_dim"] if not config["use_clip_embedding_init"] else clip_utils.embedding_size[config["clip_model"]],
145 | straight_through=config["straight_through"],
146 | config=config).to(device)
147 |
148 | if "checkpoint" in config:
149 | loaded_model_path = config["checkpoint"]
150 | meta_module.load_state_dict(torch.load(loaded_model_path), strict=False)
151 |
152 | meta_optimizer = build_optimizer(meta_module.meta_params, config, loop="meta")
153 | meta_trainer = MAML(meta_module, meta_optimizer, image_features, text_features, ques_emb, config, coco_categories, coco_answer_features, extend_coco_size=config["extend_coco_size"])
154 |
155 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None:
156 | n_shot_trials = []
157 |
158 | best_val_acc=[0]*(1+len(config["eval_inner_epochs"].split(",")))
159 | best_val_epoch=[0]*(1+len(config["eval_inner_epochs"].split(",")))
160 |
161 | for meta_epoch in range(config["meta_epochs"]):
162 |
163 | if not config["eval"]:
164 | meta_trainer.run_epoch(train_data, config["inner_epochs"], config["inner_learning_rate"], meta_batch_size=config["meta_batch_size"],
165 | train=True, second_order=config["second_order"],
166 | train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True, keep_tasks_frac=config["keep_tasks_frac"],
167 | extend_coco=config["use_extended_coco"], extend_coco_frac_train=config["extend_coco_frac_train"], device=device)
168 |
169 | if config["eval"] or meta_epoch % config["val_epoch_interval"] == 0 or meta_epoch == config["meta_epochs"]-1:
170 | log_dict = meta_trainer.run_epoch(train_data, config["inner_epochs"], config["inner_learning_rate"], keep_tasks_frac=config["keep_tasks_frac"], device=device, epoch=meta_epoch)
171 | log_metric(log_dict, "eval_train/")
172 | log_dict = meta_trainer.run_epoch(test_data, config["inner_epochs"], config["inner_learning_rate"], device=device, epoch=meta_epoch)
173 |
174 | if best_val_acc[0] < log_dict["query_accuracy_end"]:
175 | best_val_acc[0] = log_dict["query_accuracy_end"]
176 | best_val_epoch[0] = meta_epoch
177 | log_dict["best_accuracy"] = best_val_acc[0]
178 | log_dict["best_epoch"] = best_val_epoch[0]
179 | log_metric(log_dict, "eval_val/")
180 |
181 | if log_dict["query_accuracy_end"] < 0.3:
182 | print("Stopping training")
183 | return
184 |
185 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None:
186 | n_shot_trial_dict = {"epoch": meta_epoch}
187 |
188 | if config["eval_inner_epochs"] != '':
189 | for idx, inner_epochs in enumerate(config["eval_inner_epochs"].split(",")):
190 | # log_dict = meta_trainer.run_epoch(train_data, int(inner_epochs), config["inner_learning_rate"], keep_tasks_frac=config["keep_tasks_frac"], device=device)
191 | # log_dict["epoch"]=meta_epoch
192 | # log_metric(log_dict, "eval_train_{}step/".format(inner_epochs))
193 |
194 | log_dict = meta_trainer.run_epoch(test_data, int(inner_epochs), config["inner_learning_rate"], device=device, epoch=meta_epoch)
195 |
196 | if best_val_acc[idx+1] < log_dict["query_accuracy_end"]:
197 | best_val_acc[idx+1] = log_dict["query_accuracy_end"]
198 | best_val_epoch[idx+1] = meta_epoch
199 | log_dict["best_accuracy"] = best_val_acc[idx+1]
200 | log_dict["best_epoch"] = best_val_epoch[idx+1]
201 |
202 | log_metric(log_dict, "eval_val_{}step/".format(inner_epochs))
203 |
204 | n_shot_trials_run(meta_trainer, n_shot_trial_dict, config, f"eval_val_{inner_epochs}step/", test_data, int(inner_epochs), config["inner_learning_rate"], device=device, epoch=meta_epoch)
205 |
206 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None:
207 | n_shot_trials += [n_shot_trial_dict]
208 |
209 | if config["eval"]:
210 | return
211 |
212 | if config["save_checkpoint"] and (meta_epoch+1) % 20 == 0:
213 | model_output_path_checkpoint = base_path + "/evaluation/few_shot/meta_module" + str(wandb.run.name) + "_" + str( meta_epoch) + ".pth"
214 | torch.save(meta_module.state_dict(), model_output_path_checkpoint)
215 | print(f"Checkpoint for meta-epoch {meta_epoch} saved!")
216 |
217 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None:
218 | model_output_path_checkpoint = base_path + "/evaluation/few_shot/few_shot_" + str(
219 | wandb.run.name) + "_n_shot.npy"
220 | np.save(model_output_path_checkpoint, n_shot_trials)
221 |
222 | model_output_path_checkpoint = base_path + "/evaluation/few_shot/few_shot_" + str(wandb.run.name) + ".pth"
223 | torch.save(meta_module.state_dict(), model_output_path_checkpoint)
224 |
225 | wandb.finish()
226 |
227 |
228 | if __name__ == "__main__":
229 | main()
230 |
--------------------------------------------------------------------------------
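
Example (not repository code): several hyperparameters in this script, such as hypernet_hidden_dim and eval_inner_epochs, are passed as comma-separated strings and split at use time, with the empty string meaning "none". A tiny sketch of that convention with illustrative values:

hypernet_hidden_dim = "128,128,128"   # illustrative values
eval_inner_epochs = "10,50"

hidden_dims = [] if hypernet_hidden_dim == "" else [int(i) for i in hypernet_hidden_dim.split(",")]
eval_epochs = [] if eval_inner_epochs == "" else [int(e) for e in eval_inner_epochs.split(",")]

print(hidden_dims)  # [128, 128, 128]
print(eval_epochs)  # [10, 50]
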
/scripts/train_hyperclip.py:
--------------------------------------------------------------------------------
1 | # Takes a hypernetwork and learns the hyperclip model on weights generated by embedding adaptation
2 |
3 |
4 | import argparse
5 | import json
6 | import os
7 | from functools import partial
8 |
9 | import numpy as np
10 | import torch
11 | import torch.optim as optim
12 | import wandb
13 | from features.image_features import load_image_features
14 | from features.ques_features import load_ques_features
15 | from features.text_features import load_text_features
16 | from model.hyperclip import build_hyperclip_from_classic_clip
17 | from training.hyperclip_learn import HyperclipTraining
18 | from utils import clip_utils
19 | from utils.build_opt import build_optimizer
20 | from utils.init_utils import (load_metamodel_from_checkpoint,
21 | load_vae_and_metamodel_from_checkpoint)
22 | from utils.misc_utils import str2bool
23 | from utils.train_utils import log_metric
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--wandb_mode', type=str, default='online',
27 | help='Set to "disabled" to disable Weights & Biases logging')
28 |
29 | parser.add_argument('--inner_epochs', type=str)
30 |
31 | parser.add_argument('--few_shot_checkpoint', type=str, help='required')
32 | parser.add_argument('--vae_checkpoint', type=str, help='required')
33 | parser.add_argument('--precompute_checkpoint', type=str, help='required')
34 |
35 | parser.add_argument('--hyperclip_epochs', type=int)
36 | parser.add_argument('--hyperclip_batch_size', type=int)
37 | parser.add_argument('--val_epoch_interval', type=int)
38 |
39 |
40 | parser.add_argument('--hyperclip_optimizer', type=str)
41 | parser.add_argument('--hyperclip_learning_rate', type=float)
42 | parser.add_argument('--eval_inner_epochs', type=str)
43 | parser.add_argument('--guidance_optimizer', type=str)
44 | parser.add_argument('--guidance_learning_rate', type=float)
45 |
46 | parser.add_argument('--guidance_scheduler', type=str)
47 |
48 | # hyperclip
49 | parser.add_argument('--hyperclip_hidden_dim', type=str)
50 | parser.add_argument('--hyperclip_model', type=str)
51 |
52 | parser.add_argument('--guidance_init_l2_weight', type=str)
53 | parser.add_argument('--train_on_vae', type=str2bool)
54 |
55 | parser.add_argument('--checkpoint', type=str)
56 | parser.add_argument('--eval', type=str2bool)
57 | parser.add_argument('--langevin_eps', type=str)
58 |
59 | parser.add_argument('--normalize', type=str2bool)
60 |
61 |
62 | base_path = os.path.dirname(os.path.dirname(__file__))
63 |
64 |
65 | default_config = {
66 | "inner_epochs": '50',
67 | "inner_optimizer": "sgd",
68 | "inner_learning_rate": 0.1,
69 | "train_subtype": "random",
70 | "val_subtype": "random",
71 |
72 | "clip_model": "ViT-L/14@336px",
73 | "hyperclip_model": "mlp",
74 | "hyperclip_hidden_dim": "128,128",
75 |
76 | "hyperclip_epochs": 1000,
77 | "hyperclip_batch_size": 32,
78 |
79 | "hyperclip_optimizer": "adam",
80 | "hyperclip_learning_rate": 0.001,
81 | "hyperclip_weight_decay": 0,
82 | "hyperclip_momentum": 0.9,
83 | "hyperclip_sgd_nesterov": True,
84 |
85 | "eval_inner_epochs": '50',
86 | "guidance_optimizer": "adam",
87 | "guidance_learning_rate": 0.001,
88 | "guidance_momentum": 0.9,
89 | "guidance_sgd_nesterov": True,
90 | "guidance_init_l2_weight": "0",
91 |
92 | "guidance_scheduler": "none",
93 |
94 | "val_epoch_interval": 100,
95 | "save_checkpoint": False,
96 | "train_on_vae":False,
97 |
98 | "eval": False,
99 | "normalize":False,
100 | "langevin_eps": "0"
101 | }
102 |
103 | torch.manual_seed(42)
104 | rng = np.random.RandomState(42)
105 | np.random.seed(42)
106 |
107 | def main(args):
108 | cfg = default_config
109 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None})
110 | print(cfg)
111 | wandb.init(project='train_hyperclip', entity="srl_ethz", config=cfg,
112 | mode=args.wandb_mode)
113 | config = wandb.config
114 |
115 | # Load the model
116 | device = "cuda" if torch.cuda.is_available() else "cpu"
117 |
118 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
119 | train_data = json.load(file)
120 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
121 | test_data = json.load(file)
122 |
123 | hnet_gen, hnet_enc = None, None
124 | if "vae_checkpoint" in config and config["vae_checkpoint"] is not None:
125 | meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
126 | load_vae_and_metamodel_from_checkpoint(config, device)
127 | else:
128 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
129 | load_metamodel_from_checkpoint(config, device)
130 |
131 | # pre-computed features
132 | image_features = load_image_features(config["clip_model"])
133 | text_features = load_text_features(config["clip_model"])
134 | ques_emb = load_ques_features(config["clip_model"])
135 |
136 | hyperclip = build_hyperclip_from_classic_clip(
137 | os.path.expanduser(clip_utils.cached_location[config["clip_model"]]),
138 | hyper_model=config["hyperclip_model"],
139 | mainnet_param_count=meta_module.mnet.get_parameter_vector().shape[0],
140 | hyper_hidden_dims=[] if config["hyperclip_hidden_dim"] == "" else [int(i) for i in config["hyperclip_hidden_dim"].split(",")],
141 | pretrained_it_location=os.path.expanduser(clip_utils.cached_location[config["clip_model"]]),
142 | pretrained_hyper_location=None).to(device)
143 |
144 | if "checkpoint" in config:
145 | api = wandb.Api()
146 | loaded_run = api.run(config["checkpoint"])
147 | loaded_model_path = base_path + "/evaluation/hyperclip/hyperclip_" + str(loaded_run.name) + ".pth"
148 | hyperclip.hyper.load_state_dict(torch.load(loaded_model_path), strict=False)
149 |
150 |
151 | hyperclip.hyper.train()
152 | hyperclip_optimizer = build_optimizer(hyperclip.hyper.parameters(), config, loop="hyperclip")
153 | hyperclip_training = HyperclipTraining(meta_module, hnet_gen, hnet_enc,
154 | hyperclip, hyperclip_optimizer, image_features, text_features,
155 | ques_emb, config, device, train_on_vae=config["train_on_vae"])
156 |
157 |     unguided_inner_optim = partial(build_optimizer, config=config, loop="inner")
158 |     guided_inner_optim = partial(build_optimizer, config=config, loop="guidance")
159 |
160 |
161 | # Get STD and MEAN
162 | if precomputed_latent is not None:
163 | sampled_precomputed_latent = {k:v[0:1].to(device) for (k,v) in precomputed_latent.items()}
164 | else:
165 | sampled_precomputed_latent=None
166 |
167 | _, optimized_params_mean, optimized_params_std, _ = hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
168 | batch_size=config["hyperclip_batch_size"], precomputed_latent=sampled_precomputed_latent,
169 | output_mean_std=True, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], skip_cond=True)
170 |
171 | optimized_params_mean, optimized_params_std=optimized_params_mean[0], optimized_params_std[0].clip(0.01)
172 |
173 | if config["normalize"]:
174 | hyperclip_training.set_stats(optimized_params_mean, optimized_params_std)
175 | wandb.log({"mean":optimized_params_mean, "std":optimized_params_std})
176 |
177 | for meta_epoch in range(config["hyperclip_epochs"]):
178 | if precomputed_latent is not None:
179 | curr_sample_epoch = np.random.randint(0, precomputed_latent["clip_embedding"].shape[0])
180 | curr_sample_batch_perm = np.random.permutation(precomputed_latent["clip_embedding"].shape[1])
181 | sampled_precomputed_latent = {k:v[curr_sample_epoch:curr_sample_epoch+1, curr_sample_batch_perm].to(device) for (k,v) in precomputed_latent.items()}
182 | else:
183 | sampled_precomputed_latent=None
184 |
185 | if not config["eval"]:
186 | hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
187 | batch_size=config["hyperclip_batch_size"], precomputed_latent=sampled_precomputed_latent,
188 | train=True, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True)
189 |
190 | if config["eval"] or (meta_epoch + 1) % (config["val_epoch_interval"]//10+1) == 0:
191 |
192 | log_dict = hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim, precomputed_latent=precomputed_latent_train_eval, epoch=meta_epoch)
193 | log_metric(log_dict, "eval_train/")
194 | log_dict = hyperclip_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, precomputed_latent=precomputed_latent_val_eval, epoch=meta_epoch)
195 | log_metric(log_dict, "eval_val/")
196 |
197 | if config["eval"] or (meta_epoch + 1) % config["val_epoch_interval"] == 0:
198 | if config["eval_inner_epochs"] != '':
199 | def eval(inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps, train=False):
200 | guidance_scheduler_fn=None
201 | if guidance_scheduler == "cos":
202 | guidance_scheduler_fn = partial(optim.lr_scheduler.CosineAnnealingLR, T_max=inner_epochs, eta_min=0)
203 |
204 | if train:
205 | log_dict = hyperclip_training.run_epoch(train_data, inner_epochs, guided_inner_optim,
206 | guided_inner=True, use_vae=True, init_guidance_at="pre-trained", skip_cond=True, guidance_init_l2_weight=guidance_init_l2_weight, langevin_eps=langevin_eps, guidance_scheduler_fn=guidance_scheduler_fn, epoch=meta_epoch)
207 | log_metric(log_dict, "guided_eval_train_{}step_{}l2_{}_{}/".format(inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps))
208 |
209 | log_dict = hyperclip_training.run_epoch(test_data, inner_epochs, guided_inner_optim, guided_inner=True, use_vae=True, init_guidance_at="pre-trained",
210 | skip_cond=True, guidance_init_l2_weight=guidance_init_l2_weight, langevin_eps=langevin_eps, guidance_scheduler_fn=guidance_scheduler_fn, epoch=meta_epoch)
211 |
212 | log_metric(log_dict, "guided_eval_val_{}step_{}l2_{}_{}/".format(inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps))
213 |
214 | for idx, inner_epochs in enumerate(config["eval_inner_epochs"].split(",")):
215 | for _, guidance_init_l2_weight in enumerate(config["guidance_init_l2_weight"].split(",")):
216 | for _, guidance_scheduler in enumerate(config["guidance_scheduler"].split(",")):
217 | for _, langevin_eps in enumerate(config["langevin_eps"].split(",")):
218 | eval(int(inner_epochs), float(guidance_init_l2_weight), guidance_scheduler, float(langevin_eps), train=(meta_epoch + 1) % 1000 == 0)
219 |
220 | if config["eval"]:
221 | return
222 |
223 | if config["save_checkpoint"]:
224 | hyperclip_output_path_checkpoint = base_path + "/evaluation/hyperclip/hyperclip_" + str(
225 | wandb.run.name) + "_" + str(
226 | meta_epoch) + ".pth"
227 | torch.save(hyperclip.hyper.state_dict(), hyperclip_output_path_checkpoint)
228 | print(f"Checkpoint for meta-epoch {meta_epoch} saved!")
229 |
230 | hyperclip_output_path_checkpoint = base_path + "/evaluation/hyperclip/hyperclip_" + str(
231 | wandb.run.name)+ ".pth"
232 | torch.save(hyperclip.hyper.state_dict(), hyperclip_output_path_checkpoint)
233 |
234 | wandb.finish()
235 |
236 |
237 | if __name__ == "__main__":
238 | args = parser.parse_args()
239 | main(args)
240 |
--------------------------------------------------------------------------------
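
Example (not repository code): the guided evaluation above builds its learning-rate schedule as partial(optim.lr_scheduler.CosineAnnealingLR, T_max=inner_epochs, eta_min=0) and attaches it to the guidance optimizer. A self-contained sketch of that pattern on a dummy latent parameter and a placeholder objective:

from functools import partial

import torch
from torch import optim

guidance_scheduler_fn = partial(optim.lr_scheduler.CosineAnnealingLR, T_max=50, eta_min=0)

latent = torch.nn.Parameter(torch.zeros(128))          # stand-in for the guided latent
guidance_optimizer = optim.Adam([latent], lr=0.001)
scheduler = guidance_scheduler_fn(guidance_optimizer)  # the partial only needs the optimizer

for step in range(50):
    guidance_optimizer.zero_grad()
    loss = (latent ** 2).sum()                         # placeholder objective
    loss.backward()
    guidance_optimizer.step()
    scheduler.step()                                   # cosine-anneal the guidance learning rate

print(guidance_optimizer.param_groups[0]["lr"])        # close to eta_min (0) after T_max steps
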
/scripts/train_vae.py:
--------------------------------------------------------------------------------
1 | # Takes a base model and learns a generative hypernetwork using adapted weights as data.
2 |
3 | import argparse
4 | import json
5 | import os
6 | from functools import partial
7 |
8 | import numpy as np
9 | import torch
10 | import wandb
11 | from features.image_features import load_image_features
12 | from features.ques_features import load_ques_features
13 | from features.text_features import load_text_features
14 | from model.custom_hnet import HyperEncoder, HyperGenerator
15 | from training.vae_learn import VAETraining
16 | from utils.build_opt import build_optimizer
17 | from utils.init_utils import load_metamodel_from_checkpoint
18 | from utils.misc_utils import str2bool
19 | from utils.train_utils import log_metric
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--wandb_mode', type=str, default='online',
23 | help='Set to "disabled" to disable Weights & Biases logging')
24 |
25 | parser.add_argument('--few_shot_checkpoint', type=str, help='required')
26 | parser.add_argument('--precompute_checkpoint', type=str, help='required')
27 |
28 | parser.add_argument('--vae_epochs', type=int)
29 | parser.add_argument('--vae_batch_size', type=int)
30 | parser.add_argument('--vae_noise_dim', type=int)
31 |
32 | parser.add_argument('--vae_optimizer', type=str)
33 | parser.add_argument('--vae_learning_rate', type=float)
34 | parser.add_argument('--kld_weight', type=float)
35 |
36 | parser.add_argument('--eval_inner_epochs', type=str)
37 | parser.add_argument('--generated_optimizer', type=str)
38 | parser.add_argument('--generated_learning_rate', type=float)
39 |
40 | # vae architecture / evaluation
41 | parser.add_argument('--vae_hidden_dim', type=str)
42 | parser.add_argument('--val_epoch_interval', type=int)
43 | base_path = os.path.dirname(os.path.dirname(__file__))
44 |
45 | parser.add_argument('--normalize', type=str2bool)
46 | parser.add_argument('--grad_clip', type=float)
47 | parser.add_argument('--save_checkpoint', type=str2bool)
48 | parser.add_argument('--seed', type=int, default=42)
49 |
50 | args = parser.parse_args()
51 | torch.manual_seed(args.seed)
52 | rng = np.random.RandomState(args.seed)
53 | np.random.seed(args.seed)
54 | default_config = {
55 | "inner_epochs": 50,
56 | "inner_optimizer": "sgd",
57 | "inner_learning_rate": 0.1,
58 | "train_subtype": "random",
59 | "val_subtype": "random",
60 |
61 | "vae_hidden_dim": "128,128",
62 | "vae_noise_dim": 128,
63 |
64 | "vae_epochs": 10000,
65 | "vae_batch_size": 32,
66 | "kld_weight":1,
67 | "vae_optimizer": "adam",
68 | "vae_learning_rate": 0.0001,
69 |
70 | "eval_inner_epochs": '50',
71 | "generated_optimizer": "adam",
72 | "generated_learning_rate": 0.001,
73 | "generated_momentum": 0.9,
74 | "generated_sgd_nesterov": True,
75 |
76 | "val_epoch_interval": 2000,
77 | "clip_model": "ViT-L/14@336px",
78 | "save_checkpoint": False,
79 | "normalize": False,
80 | "grad_clip": -1,
81 | }
82 |
83 | def main(args):
84 | cfg = default_config
85 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None})
86 | print(cfg)
87 | wandb.init(project="train_vae", entity="srl_ethz", config=cfg,
88 | mode=args.wandb_mode)
89 | config = wandb.config
90 |
91 | # Load the model
92 | device = "cuda" if torch.cuda.is_available() else "cpu"
93 |
94 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
95 | train_data = json.load(file)
96 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
97 | test_data = json.load(file)
98 |
99 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
100 | load_metamodel_from_checkpoint(config, device)
101 |
102 | # pre-computed features
103 | image_features = load_image_features(config["clip_model"])
104 | text_features = load_text_features(config["clip_model"])
105 | ques_emb = load_ques_features(config["clip_model"])
106 |
107 | hidden_dims = [int(h) for h in config["vae_hidden_dim"].split(",")]
108 | hnet_enc = HyperEncoder(meta_module.mnet, e_dim=config["vae_noise_dim"],
109 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device)
110 | hidden_dims.reverse()
111 | hnet_gen = HyperGenerator(meta_module.mnet, e_dim=config["vae_noise_dim"],
112 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device)
113 |
114 | optimizer_gen = build_optimizer(hnet_gen.parameters(), config, loop="vae")
115 | optimizer_enc = build_optimizer(hnet_enc.parameters(), config, loop="vae")
116 |
117 | vae_trainer = VAETraining(meta_module, hnet_gen, hnet_enc, optimizer_gen, optimizer_enc, image_features,
118 | text_features, ques_emb, config, device)
119 |
120 |     inner_optim = partial(build_optimizer, config=config, loop="inner")
121 |     generated_inner_optim = partial(build_optimizer, config=config, loop="generated")
122 |
123 | # Get STD and MEAN
124 | if precomputed_latent is not None:
125 | sampled_precomputed_latent = {k:v[0:1].to(device) for (k,v) in precomputed_latent.items()}
126 | else:
127 | sampled_precomputed_latent=None
128 |
129 | _, optimized_params_mean, optimized_params_std, _ = \
130 | vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim, precomputed_latent=sampled_precomputed_latent,
131 | batch_size=config["vae_batch_size"],
132 | output_mean_std=True, device=device, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True)
133 |
134 | optimized_params_mean, optimized_params_std=optimized_params_mean[0], optimized_params_std[0]
135 | wandb.log({"mean":optimized_params_mean, "std":optimized_params_std, "std_avr": optimized_params_std.mean().item()})
136 |
137 | if config["normalize"]:
138 | optimized_params_std=optimized_params_std.clip(optimized_params_std.mean().item(),optimized_params_std.mean().item())
139 | hnet_gen.set_stats(optimized_params_mean, optimized_params_std)
140 | hnet_enc.set_stats(optimized_params_mean, optimized_params_std)
141 | vae_trainer.set_stats(optimized_params_mean, optimized_params_std)
142 |
143 | for meta_epoch in range(config["vae_epochs"]):
144 | if precomputed_latent is not None:
145 | curr_sample_epoch = np.random.randint(0, precomputed_latent["clip_embedding"].shape[0])
146 | curr_sample_batch_perm = np.random.permutation(precomputed_latent["clip_embedding"].shape[1])
147 | sampled_precomputed_latent = {k:v[curr_sample_epoch:curr_sample_epoch+1, curr_sample_batch_perm].to(device) for (k,v) in precomputed_latent.items()}
148 | else:
149 | sampled_precomputed_latent=None
150 |
151 | vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim, precomputed_latent=sampled_precomputed_latent,
152 | batch_size=config["vae_batch_size"],
153 | train=True, device=device, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True)
154 |
155 | if (meta_epoch+1) % (config["val_epoch_interval"]//10+1) == 0:
156 | log_dict = vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim,
157 | precomputed_latent=precomputed_latent_train_eval, device=device, epoch=meta_epoch)
158 | log_metric(log_dict, "eval_train/")
159 | log_dict = vae_trainer.run_epoch(test_data, config["inner_epochs"], inner_optim,
160 | precomputed_latent=precomputed_latent_val_eval, device=device, epoch=meta_epoch)
161 | log_metric(log_dict, "eval_val/")
162 |
163 | log_dict = vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim, reconstructed=True,
164 | precomputed_latent=precomputed_latent_train_eval, device=device, epoch=meta_epoch)
165 | log_metric(log_dict, "reconstr_eval_train/")
166 | log_dict = vae_trainer.run_epoch(test_data, config["inner_epochs"], inner_optim, reconstructed=True,
167 | precomputed_latent=precomputed_latent_val_eval, device=device, epoch=meta_epoch)
168 | log_metric(log_dict, "reconstr_eval_val/")
169 |
170 | if meta_epoch % config["val_epoch_interval"] == config["val_epoch_interval"]-1:
171 |
172 | if config["eval_inner_epochs"] != '':
173 | for idx, inner_epochs in enumerate(config["eval_inner_epochs"].split(",")):
174 | log_dict = vae_trainer.run_epoch(train_data, int(inner_epochs), generated_inner_optim, device=device, generated=True, epoch=meta_epoch)
175 | log_metric(log_dict, "generated_eval_train_{}step/".format(inner_epochs))
176 |
177 | log_dict = vae_trainer.run_epoch(test_data, int(inner_epochs), generated_inner_optim, device=device, generated=True, epoch=meta_epoch)
178 | log_metric(log_dict, "generated_eval_val_{}step/".format(inner_epochs))
179 |
180 | if config["save_checkpoint"]:
181 | hnet_gen_output_path_checkpoint = base_path + "/evaluation/vae/hnet_gen_" + str(
182 | wandb.run.name) + ".pth"
183 | hnet_enc_output_path_checkpoint = base_path + "/evaluation/vae/hnet_enc_" + str(
184 | wandb.run.name) + ".pth"
185 | torch.save(hnet_gen.state_dict(), hnet_gen_output_path_checkpoint)
186 | torch.save(hnet_enc.state_dict(), hnet_enc_output_path_checkpoint)
187 | print(f"Checkpoint for meta-epoch {meta_epoch} saved!")
188 |
189 | hnet_gen_output_path_checkpoint = base_path + "/evaluation/vae/hnet_gen_" + str(
190 | wandb.run.name) + ".pth"
191 | hnet_enc_output_path_checkpoint = base_path + "/evaluation/vae/hnet_enc_" + str(
192 | wandb.run.name) + ".pth"
193 | torch.save(hnet_gen.state_dict(), hnet_gen_output_path_checkpoint)
194 | torch.save(hnet_enc.state_dict(), hnet_enc_output_path_checkpoint)
195 |
196 | wandb.finish()
197 |
198 |
199 | if __name__ == "__main__":
200 | main(args)
201 |
--------------------------------------------------------------------------------
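
scripts/train_vae.py fits HyperEncoder / HyperGenerator as a VAE over flattened adapted weights; the objective itself lives in training/vae_learn.py further below. The standalone sketch here isolates the two pieces that objective relies on, reparameterized sampling and a reconstruction-plus-KL loss, using toy shapes and a toy linear decoder as stand-ins.

# Standalone sketch (illustrative, not the repository implementation) of the VAE objective
# optimized over flattened adapter weights: reparameterized sampling plus a reconstruction
# term and a KL term weighted by kld_weight.
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I): keeps sampling differentiable in mu, logvar.
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

def vae_loss(x, x_recon, mu, logvar, kld_weight=1.0):
    recons_loss = (x_recon - x).pow(2).mean(0).sum()
    kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    return recons_loss + kld_weight * kld_loss

x = torch.randn(32, 128)            # stand-in for a batch of flattened adapter weights
mu, logvar = torch.zeros(32, 16), torch.zeros(32, 16)
z = reparameterize(mu, logvar)
decoder = torch.nn.Linear(16, 128)  # toy stand-in for HyperGenerator
print(vae_loss(x, decoder(z), mu, logvar).item())
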
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name='hyperclip', version='0.1', packages=find_packages())
4 |
--------------------------------------------------------------------------------
/training/conditional_model_learn.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | import torch
6 | import wandb
7 | from data.dataloader.clip_vqa import CLIP_VQA
8 | from data.dataloader.coco_tasks import COCO_Tasks
9 | from model.custom_hnet import CLIPAdapter, EmbeddingModule, MetaModel
10 | from torch.nn import functional as F
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 | from utils import clip_utils
14 | from utils.misc_utils import append_dict, mean_dict
15 | from utils.train_utils import get_pred, log_metric, test_accuracy
16 |
17 | from training.vae_learn import sample_from_enc
18 |
19 |
20 | class ConditionalModelTraining():
21 | # Train a "conditional" model from question embeddings and finetuned network features (weights, latents)
22 | def __init__(self, meta_module, cond_model, optimizer, image_features, text_features, ques_emb, config, device, net_feature_dim,
23 | hnet_gen, hnet_enc, train_on_vae=False, compute_hessian=False,
24 | coco_categories=None, coco_answer_features=None, extend_coco_size=10 * 870):
25 | self.meta_module = meta_module
26 | self.cond_model = cond_model
27 | self.optimizer = optimizer
28 | self.image_features = image_features
29 | self.text_features = text_features
30 | self.ques_emb = ques_emb
31 | self.config = config
32 | self.device = device
33 | self.net_feature_dim = net_feature_dim
34 | self.compute_hessian = compute_hessian
35 | self.hnet_gen=hnet_gen
36 | self.hnet_enc=hnet_enc
37 | self.base_module=meta_module
38 | self.train_on_vae=train_on_vae
39 |
40 | self.coco_categories=coco_categories
41 | self.coco_answer_features=coco_answer_features
42 | self.extend_coco_size=extend_coco_size
43 |
44 | def reset(self, batch_size):
45 | self.sub_batch_i = 0
46 | self.ques_batch = torch.zeros(batch_size, clip_utils.embedding_size[self.config["clip_model"]], dtype=torch.float32).to(self.device)
47 | self.net_batch = torch.zeros(batch_size, self.net_feature_dim, dtype=torch.float32).to(self.device)
48 | self.task_batch = torch.zeros(batch_size, 1, dtype=torch.int).to(self.device)
49 | self.coco_batch = torch.zeros(batch_size, 1, dtype=torch.bool).to(self.device)
50 |
51 | self.hessian_batch = None
52 | if self.compute_hessian:
53 | self.hessian_batch = torch.zeros(batch_size, self.net_feature_dim, self.net_feature_dim, dtype=torch.float32).to(self.device)
54 |
55 | def reset_task(self, use_vae):
56 | if self.hnet_gen is not None and use_vae:
57 | embedding = sample_from_enc(self.hnet_enc, self.base_module.get_mainnet_weights(params = None).detach()).detach()
58 |
59 | enet = EmbeddingModule(self.hnet_gen.input_dim).to(self.device)
60 |
61 | enet.reset(embedding=embedding)
62 | mnet = CLIPAdapter(e_dim=self.meta_module.mnet.e_dim,
63 | hidden_layers=[self.meta_module.mnet.hidden_size],
64 | use_bias=self.meta_module.mnet.use_bias,
65 | straight_through=self.meta_module.mnet.straight_through,
66 | no_weights=True).to(self.device)
67 |
68 | self.meta_module = MetaModel(mnet=mnet, hnet=self.hnet_gen, enet=enet, config=self.config)
69 | else:
70 | self.meta_module = self.base_module
71 |
72 |
73 | def run_iter(self, task_idx, train_dataloader, net_features, train=True, clip_embedding=None, embed_hessian=None, coco=False, **kwargs):
74 | log_dict=dict()
75 |         self.ques_batch[self.sub_batch_i] = clip_embedding if clip_embedding is not None else next(iter(train_dataloader))["ques_emb"][0]
76 | self.net_batch[self.sub_batch_i] = net_features
77 | self.task_batch[self.sub_batch_i] = task_idx
78 | self.coco_batch[self.sub_batch_i] = coco
79 | if self.compute_hessian:
80 | self.hessian_batch[self.sub_batch_i] = embed_hessian
81 |
82 | if self.sub_batch_i == self.ques_batch.shape[0]-1: # Train/test one batch of hyperclip once the batch is filled
83 | self.train_step(log_dict, **kwargs) if train else self.test_step(log_dict, **kwargs)
84 |
85 | self.sub_batch_i = (self.sub_batch_i+1) % self.ques_batch.shape[0]
86 | return log_dict
87 |
88 | def train_step(self, log_dict, **kwargs):
89 | raise NotImplementedError("Use a subclass of ConditionalModelTraining")
90 |
91 | def test_step(self, log_dict, **kwargs):
92 | raise NotImplementedError("Use a subclass of ConditionalModelTraining")
93 |
94 | def run_epoch(self, data, inner_epochs, inner_optim_fct, batch_size=32,
95 | train_subtype="train", val_subtype="test", guided_inner=False,
96 | precomputed_latent=None,
97 | train=False, skip_cond=False, output_mean_std=False, tasks_idxs=None,
98 | n_shot_training=None, opt_latents_for_n_shot=None,
99 | debug=False, use_vae=False, keep_tasks_frac=1, extend_coco=False,
100 |                   extend_coco_frac_train=0.5,  # frac of tasks to replace with extended coco
101 |                   num_ensemble=1, epoch=0, **kwargs):
102 | tasks = list(data.keys())
103 | if tasks_idxs is None:
104 | tasks_idxs = np.arange(len(tasks))
105 |
106 | if batch_size > len(tasks)*keep_tasks_frac:
107 | batch_size = int(len(tasks)*keep_tasks_frac)
108 | print("Warning: batch size too big, decreasing to {}".format(batch_size))
109 |
110 | if train:
111 | self.reset(batch_size)
112 | else:
113 | self.reset(len(data.keys()))
114 |
115 | if inner_epochs is not None:
116 | inner_epochs_range = [inner_epochs] if type(inner_epochs) == int else [int(i) for i in inner_epochs.split(",")]
117 | else:
118 | inner_epochs_range = None
119 | log_dict = dict()
120 |
121 | enable_coco = [False] * len(tasks)
122 | if precomputed_latent is None or "task_idx" not in precomputed_latent:
123 | shuffled_train_tasks = [tasks_idxs[idx] for idx in torch.randperm(len(tasks_idxs))]
124 | shuffled_coco_tasks = torch.randperm(self.extend_coco_size)
125 | shuffle_for_extended_coco_replace = torch.randperm(len(tasks))
126 | if extend_coco:
127 | enable_coco=[shuffle_for_extended_coco_replace[i] < extend_coco_frac_train * len(tasks) for i in range(len(tasks))]
128 | else:
129 | shuffled_train_tasks=precomputed_latent["task_idx"][0].long()
130 | shuffled_coco_tasks=precomputed_latent["task_idx"][0].long()
131 | if "coco" in precomputed_latent:
132 | enable_coco=precomputed_latent["coco"][0]
133 |
134 | if output_mean_std:
135 | all_tasks_optimized_params = torch.zeros((len(tasks), self.net_feature_dim)).to(self.device)
136 | n_shot_seed = np.random.randint(0, 1000000)
137 | n_corr_guesses_support_start = 0
138 | n_corr_guesses_support_end = 0
139 | n_corr_guesses_query_start = 0
140 | n_corr_guesses_query_end = 0
141 | n_tot_samples_support = 0
142 | n_tot_samples_query = 0
143 |
144 | for inner_train_iter in tqdm(range(len(shuffled_train_tasks))):
145 | self.reset_task(use_vae)
146 | curr_log_dict = dict()
147 | if enable_coco[inner_train_iter]:
148 | task_idx = shuffled_coco_tasks[inner_train_iter]
149 | train_dataset = COCO_Tasks(categories=self.coco_categories,
150 | dataSubType=train_subtype,
151 | image_features=self.image_features,
152 | coco_answer_features=self.coco_answer_features,
153 | task_seed=task_idx)
154 | test_dataset = COCO_Tasks(categories=self.coco_categories,
155 | dataSubType=val_subtype,
156 | image_features=self.image_features,
157 | coco_answer_features=self.coco_answer_features,
158 | task_seed=task_idx)
159 | else:
160 | task_idx = shuffled_train_tasks[inner_train_iter]
161 | if task_idx > len(tasks)*keep_tasks_frac:
162 | continue
163 | train_dataset = CLIP_VQA(meta_data=data,
164 | dataSubType=train_subtype,
165 | task=tasks[task_idx],
166 | image_features=self.image_features,
167 | text_features=self.text_features,
168 | ques_emb=self.ques_emb,
169 | n_shot=n_shot_training)
170 | test_dataset = CLIP_VQA(meta_data=data,
171 | dataSubType=val_subtype,
172 | task=tasks[task_idx],
173 | image_features=self.image_features,
174 | text_features=self.text_features,
175 | ques_emb=self.ques_emb)
176 |
177 | train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
178 | test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
179 |
180 | # Inner loop
181 | if inner_epochs_range is not None:
182 | inner_epochs_sampled = inner_epochs_range[0] if len(inner_epochs_range) == 1 else \
183 | np.random.randint(inner_epochs_range[0], inner_epochs_range[1]+1)
184 | else:
185 | inner_epochs_sampled = None
186 |
187 | embed_hessian=None
188 |
189 | init_inner_params = self.meta_module.get_inner_params()
190 | # Make sure to clone and detach to not optimize the actual initialization.
191 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in self.meta_module.get_inner_params().items()}
192 | if opt_latents_for_n_shot is not None:
193 | assert self.meta_module.mnet.no_weight, "Not implemented when meta_module is a mainnet."
194 | inner_params["enet.embedding"] = opt_latents_for_n_shot[task_idx].clone().detach().requires_grad_()
195 |
196 | train_start_acc, train_start_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params)
197 | val_start_acc, val_start_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params)
198 |
199 | inner_params_list=None
200 | if guided_inner:
201 | if num_ensemble>1:
202 | inner_params_list=[]
203 | for _ in range(num_ensemble):
204 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in self.meta_module.get_inner_params().items()}
205 | inner_params_list.append(inner_params)
206 | curr_log_dict.update(self.guided_inner(train_dataloader, inner_params, init_inner_params, inner_optim_fct,
207 | inner_train_iter, inner_epochs_sampled, batch_size, debug=inner_train_iter==1, **kwargs))
208 | else:
209 | curr_log_dict.update(self.guided_inner(train_dataloader, inner_params, init_inner_params, inner_optim_fct,
210 | inner_train_iter, inner_epochs_sampled, batch_size, debug=inner_train_iter==1, **kwargs))
211 | else:
212 | if precomputed_latent is None or n_shot_training is not None:
213 | inner_optimizer = inner_optim_fct(list(inner_params.values()))
214 | for _ in range(inner_epochs_sampled):
215 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params)
216 | inner_loss = F.cross_entropy(outputs, labels)
217 | if debug and (self.sub_batch_i+1) % batch_size == 0:
218 | wandb.log({"debug_inner_loss": inner_loss.item()})
219 | inner_optimizer.zero_grad()
220 | inner_loss.backward()
221 | inner_optimizer.step()
222 |
223 | if self.compute_hessian:
224 | embed_hessian = self.compute_feature_hessian(train_dataloader, self.meta_module, inner_params["enet.embedding"]).squeeze()
225 |
226 | else:
227 |                     # If we use the precomputed latents, we "simulate" the finetuning inner loop after it was
228 |                     # supposed to happen, so as not to break metrics and mean/std calculations (ugly; to be refactored).
229 | if self.meta_module.mnet.no_weight:
230 | inner_params["enet.embedding"] = precomputed_latent["embedding"][0, inner_train_iter]
231 | else:
232 | inner_params.update({"mnet."+k:v for (k,v) in self.meta_module.mnet.load_from_vector(precomputed_latent["w_vect"][0, inner_train_iter]).items()})
233 |
234 | if self.compute_hessian:
235 | embed_hessian = precomputed_latent["hessian"][0, inner_train_iter]
236 |
237 | # Train set accuracy
238 | if inner_params_list is not None:
239 | train_end_acc, train_end_loss = test_accuracy(self.meta_module, train_dataloader, params_list=inner_params_list)
240 | val_end_acc, val_end_loss = test_accuracy(self.meta_module, test_dataloader, params_list=inner_params_list)
241 |
242 | avr_inner_params = OrderedDict()
243 | avr_inner_params.update({k: torch.stack([p[k] for p in inner_params_list]).mean(0).detach() for k in inner_params.keys()})
244 |
245 | train_end_acc_avr, train_end_loss_avr = test_accuracy(self.meta_module, train_dataloader, params=avr_inner_params)
246 | val_end_acc_avr, val_end_loss_avr = test_accuracy(self.meta_module, test_dataloader, params=avr_inner_params)
247 |
248 | curr_log_dict["support_loss_end_avr"] = train_end_loss_avr
249 | curr_log_dict["query_loss_end_avr"] = val_end_loss_avr
250 |
251 | curr_log_dict["support_accuracy_end_avr"] = train_end_acc_avr
252 | curr_log_dict["query_accuracy_end_avr"] = val_end_acc_avr
253 |
254 | else:
255 | train_end_acc, train_end_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params)
256 | val_end_acc, val_end_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params)
257 |
258 | n_tot_samples_support += len(train_dataset)
259 | n_tot_samples_query += len(test_dataset)
260 | curr_log_dict["query_accuracy_start"] = val_start_acc
261 | n_corr_guesses_query_start += val_start_acc * len(test_dataset)
262 | curr_log_dict["query_accuracy_end"] = val_end_acc
263 | n_corr_guesses_query_end += val_end_acc * len(test_dataset)
264 | curr_log_dict["support_accuracy_start"] = train_start_acc
265 | n_corr_guesses_support_start += train_start_acc * len(train_dataset)
266 | curr_log_dict["support_accuracy_end"] = train_end_acc
267 | n_corr_guesses_support_end += train_end_acc * len(train_dataset)
268 | curr_log_dict["query_loss_start"] = val_start_loss
269 | curr_log_dict["query_loss_end"] = val_end_loss
270 | curr_log_dict["support_loss_start"] = train_start_loss
271 | curr_log_dict["support_loss_end"] = train_end_loss
272 |
273 | # Actually run cond model training/eval
274 | if not skip_cond:
275 | cond_dict = self.run_iter(task_idx, train_dataloader,
276 | self.feed_net_feature(inner_params = inner_params),embed_hessian=embed_hessian,
277 | train=train, coco=enable_coco[inner_train_iter],
278 | clip_embedding=precomputed_latent["clip_embedding"][0, inner_train_iter], **kwargs)
279 | curr_log_dict.update(cond_dict)
280 |
281 | if output_mean_std:
282 | all_tasks_optimized_params[task_idx] = self.feed_net_feature(inner_params = inner_params)
283 |
284 | append_dict(log_dict, curr_log_dict)
285 |
286 | if debug and self.sub_batch_i % batch_size == 0:
287 | output_dict = mean_dict(log_dict)
288 | output_dict["query_accuracy_start_flatten"] = n_corr_guesses_query_start / n_tot_samples_query
289 | output_dict["query_accuracy_end_flatten"] = n_corr_guesses_query_end / n_tot_samples_query
290 | output_dict["support_accuracy_start_flatten"] = n_corr_guesses_support_start / n_tot_samples_support
291 | output_dict["support_accuracy_end_flatten"] = n_corr_guesses_support_end / n_tot_samples_support
292 | log_metric(output_dict, prefix="debug_")
293 | log_dict = dict()
294 |
295 | output_dict = mean_dict(log_dict)
296 | output_dict["query_accuracy_start_flatten"] = n_corr_guesses_query_start / n_tot_samples_query
297 | output_dict["query_accuracy_end_flatten"] = n_corr_guesses_query_end / n_tot_samples_query
298 | output_dict["support_accuracy_start_flatten"] = n_corr_guesses_support_start / n_tot_samples_support
299 | output_dict["support_accuracy_end_flatten"] = n_corr_guesses_support_end / n_tot_samples_support
300 | output_dict["epoch"] = epoch
301 |
302 | if not output_mean_std:
303 | return output_dict
304 | else:
305 | optimized_params_mean = torch.mean(all_tasks_optimized_params, dim=0, keepdim=True)
306 | optimized_params_std = torch.std(all_tasks_optimized_params, dim=0, keepdim=True)
307 | return output_dict, optimized_params_mean, optimized_params_std, all_tasks_optimized_params
308 |
309 | def guided_inner(self, train_dataloader, inner_params, init_inner_params, inner_optim_fct,
310 | inner_train_iter, inner_epochs, batch_size, debug, **kwargs):
311 | raise NotImplementedError("Use a subclass of ConditionalModelTraining")
312 |
313 | def feed_net_feature(self, **kwargs):
314 | raise NotImplementedError("Use a subclass of ConditionalModelTraining")
315 |
316 | def compute_feature_hessian(self, train_dataloader, meta_module, embedding):
317 |         """ Only supported when the embedding itself is the net feature; generalize if needed. """
318 | def get_nll(embedding):
319 | curr_params = OrderedDict()
320 | curr_params["enet.embedding"] = embedding
321 | outputs, labels = get_pred(meta_module, train_dataloader, params=curr_params)
322 | inner_loss = F.cross_entropy(outputs, labels)
323 | return inner_loss
324 |
325 | return torch.autograd.functional.hessian(get_nll, embedding)
326 |
327 |
--------------------------------------------------------------------------------
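
ConditionalModelTraining.compute_feature_hessian differentiates the support-set loss twice with respect to the task embedding via torch.autograd.functional.hessian. The toy sketch below shows the same call on a small quadratic stand-in for that loss, so the shape and meaning of the returned matrix are easy to check.

# Toy sketch of the second-derivative computation used by compute_feature_hessian:
# torch.autograd.functional.hessian evaluates the full Hessian of a scalar function
# at a given input. The quadratic below stands in for the support-set cross-entropy.
import torch

def nll(embedding):
    scale = torch.arange(1.0, 4.0)
    return (scale * embedding).pow(2).sum()   # sum_i (i * e_i)^2

embedding = torch.randn(3)
H = torch.autograd.functional.hessian(nll, embedding)
print(H)  # 3x3 matrix; here diag(2, 8, 18)
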
/training/hyperclip_learn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import wandb
5 | from torch.nn import functional as F
6 | from utils.train_utils import get_pred, test_accuracy
7 |
8 | from training.conditional_model_learn import ConditionalModelTraining
9 | from training.vae_learn import sample_from_enc
10 |
11 |
12 | class HyperclipTraining(ConditionalModelTraining):
13 | def __init__(self, meta_module, hnet_gen, hnet_enc, hyperclip, optimizer,
14 | image_features, text_features, ques_emb, config, device, train_on_vae=False, **kwargs):
15 | net_feature_dim = hyperclip.hyper.input_dim
16 | super().__init__(meta_module, hyperclip, optimizer, image_features, text_features, ques_emb,
17 | config, device, net_feature_dim, hnet_gen, hnet_enc, train_on_vae=train_on_vae, **kwargs)
18 | self.feature_mean=None
19 | self.feature_std=None
20 |
21 | def train_step(self, log_dict):
22 | self.optimizer.zero_grad()
23 |
24 | net_batch = self.net_batch
25 | if self.feature_mean is not None:
26 | net_batch = net_batch - self.feature_mean
27 | net_batch = net_batch / self.feature_std
28 |
29 | _, hyperclip_loss = self.cond_model.forward(hyper = net_batch, text = self.ques_batch, precomputed_it_embs = True)
30 | log_dict.update({"hyperclip_loss": hyperclip_loss.detach().cpu().numpy()})
31 | hyperclip_loss.backward()
32 | self.optimizer.step()
33 |
34 | def test_step(self, log_dict):
35 | with torch.no_grad():
36 | self.cond_model.eval()
37 |
38 | net_batch = self.net_batch
39 | if self.feature_mean is not None:
40 | net_batch = net_batch-self.feature_mean
41 | net_batch = net_batch / self.feature_std
42 |
43 | _, hyperclip_loss = self.cond_model.forward(hyper = net_batch,
44 | text = self.ques_batch, precomputed_it_embs = True)
45 | log_dict.update({"hyperclip_val_loss": hyperclip_loss.detach().cpu().numpy()})
46 | self.cond_model.train()
47 |
48 |
49 | def set_stats(self, mean, std):
50 | self.feature_mean= mean
51 | self.feature_std=std
52 |
53 |
54 | def get_compute_grad(self, train_dataloader):
55 | def compute_grad(inner_params):
56 | with torch.enable_grad():
57 | x = inner_params["enet.embedding"]
58 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params)
59 | inner_loss = F.cross_entropy(outputs, labels)
60 | grad =torch.autograd.grad(inner_loss, x)
61 | return grad[0]
62 | return compute_grad
63 |
64 |
65 | def guided_inner(self, train_dataloader, inner_params, init_inner_params, inner_optim_fct,
66 | inner_train_iter, inner_epochs, batch_size, debug, langevin_eps=0, guidance_scheduler_fn=None, guidance_init_l2_weight=0, **kwargs):
67 |
68 | # if eval:
69 | compute_grad_fn = self.get_compute_grad(train_dataloader)
70 |
71 | inner_optimizer = inner_optim_fct(list(inner_params.values()))
72 | if guidance_scheduler_fn is not None:
73 | guidance_scheduler = guidance_scheduler_fn(inner_optimizer)
74 | else:
75 | guidance_scheduler = None
76 |
77 |
78 | log_dict=dict()
79 | log_dict["cos_sim"]=[]
80 | log_dict["acc"]=[]
81 | log_dict["loss"]=[]
82 |
83 | for _ in range(inner_epochs):
84 | task_ques_emb = next(iter(train_dataloader))["ques_emb"][0]
85 | weights = self.meta_module.get_mainnet_weights(ques_emb=task_ques_emb, params=inner_params)
86 |
87 | if self.feature_mean is not None:
88 | weights = weights - self.feature_mean
89 | weights = weights / self.feature_std
90 |
91 | hyperclip = self.cond_model
92 |
93 | task_weight_emb = hyperclip.encode_hyper(weights)
94 |
95 | norm_task_ques_emb = task_ques_emb / task_ques_emb.norm(dim=-1, keepdim=True)
96 | norm_task_weight_emb = task_weight_emb / task_weight_emb.norm(dim=-1, keepdim=True)
97 |
98 | inner_product_embs_loss = - norm_task_weight_emb @ norm_task_ques_emb.T
99 |
100 | init_l2_loss = torch.stack(
101 | [(ip - p).pow(2).sum() for (ip, p)
102 | in zip(init_inner_params.values(), inner_params.values())]).sum() / 2
103 | inner_loss = inner_product_embs_loss + init_l2_loss* guidance_init_l2_weight
104 |
105 | inner_optimizer.zero_grad()
106 | inner_loss.backward()
107 |
108 | if True: #eval:
109 | alignment = torch.nn.CosineSimilarity(dim=0)(inner_params["enet.embedding"].grad.view(-1), compute_grad_fn(inner_params).view(-1))
110 | log_dict["cos_sim"].append(alignment.item())
111 | a, l = test_accuracy(self.meta_module, train_dataloader, params=inner_params)
112 | log_dict["acc"].append(a)
113 | log_dict["loss"].append(l)
114 |
115 | if debug:
116 | wandb.log({"debug_inner_guidance_loss": inner_product_embs_loss.item(),
117 | "debug_inner_l2_loss": init_l2_loss.item(),
118 | "inner_lr": inner_optimizer.param_groups[0]['lr'],
119 | "gradnorm": torch.stack([p.grad.pow(2).sum() for p in inner_params.values() if p.grad is not None]).sum().item()})
120 |
121 |
122 | if langevin_eps>0:
123 | for p in inner_params.values():
124 | p.grad.data += langevin_eps*torch.randn_like(p.grad.data)/math.sqrt(inner_optimizer.param_groups[0]['lr'])
125 |
126 | inner_optimizer.step()
127 | if guidance_scheduler is not None:
128 | guidance_scheduler.step()
129 |
130 | return log_dict
131 |
132 |
133 | def feed_net_feature(self, inner_params=None):
134 | if self.train_on_vae and self.hnet_gen is not None:
135 | return self.hnet_gen(sample_from_enc(self.hnet_enc, self.base_module.get_mainnet_weights(params = inner_params).detach())).detach()
136 | return self.base_module.get_mainnet_weights(params = inner_params).detach()
137 |
--------------------------------------------------------------------------------
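
HyperclipTraining.guided_inner descends on the negative cosine alignment between the hyperclip embedding of the current main-network weights and the task's question embedding, with an L2 term anchoring the latent to its initialization and optional Langevin noise on the gradient. The standalone sketch below reproduces that objective with a toy linear encoder; every module, dimension, and hyperparameter value here is a placeholder.

# Standalone sketch (illustrative) of the guided inner loss: negative cosine alignment
# plus an L2 anchor to the initial latent, with optional Langevin noise on the gradient.
import math
import torch

torch.manual_seed(0)
ques_emb = torch.randn(64)
init_latent = torch.randn(64)
latent = init_latent.clone().requires_grad_()
encode = torch.nn.Linear(64, 64)          # toy stand-in for hyperclip.encode_hyper

opt = torch.optim.SGD([latent], lr=0.1)
l2_weight, langevin_eps = 0.1, 0.0
for _ in range(10):
    w = encode(latent)
    align_loss = -torch.nn.functional.cosine_similarity(w, ques_emb, dim=0)
    l2_loss = 0.5 * (latent - init_latent).pow(2).sum()
    loss = align_loss + l2_weight * l2_loss
    opt.zero_grad()
    loss.backward()
    if langevin_eps > 0:
        # Langevin-style perturbation of the gradient, scaled by 1/sqrt(lr) as in guided_inner.
        latent.grad += langevin_eps * torch.randn_like(latent.grad) / math.sqrt(opt.param_groups[0]["lr"])
    opt.step()
print(loss.item())
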
/training/latent_diffusion_learn.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import wandb
5 | from torch.nn import functional as F
6 | from utils.train_utils import get_pred
7 |
8 | from training.conditional_model_learn import ConditionalModelTraining
9 | from training.vae_learn import sample_from_enc
10 |
11 |
12 | class LatentDiffusionTraining(ConditionalModelTraining):
13 | def __init__(self,
14 | meta_module,
15 | diffusion_model,
16 | optimizer,
17 | n_timestep,
18 | beta,
19 | ema,
20 | net_feature_dim,
21 | train_data_for_mean_std,
22 | inner_opt_for_mean_std,
23 | image_features,
24 | text_features,
25 | ques_emb,
26 | config,
27 | device,
28 | hnet_gen, hnet_enc, train_on_vae=False,
29 | compute_hessian=False,
30 | compute_latents_mean_std = True,
31 |                  v_posterior=0.,  # weight for the posterior variance: sigma^2 = (1-v) * beta_tilde + v * beta
32 | class_free_guidance=True,
33 | class_free_training_uncond_rate=0.2,
34 | precomputed_latent_for_mean_std=None,
35 | **kwargs
36 | ):
37 |
38 | super().__init__(meta_module, diffusion_model, optimizer, image_features, text_features,
39 | ques_emb, config, device, net_feature_dim, hnet_gen, hnet_enc, train_on_vae, compute_hessian, **kwargs)
40 | self.n_timestep = n_timestep
41 | self.beta = torch.tensor(beta).to(self.device)
42 | self.ema = ema
43 |
44 | self.alpha = (1 - self.beta)
45 | self.alpha_cumprod = self.alpha.log().cumsum(0).exp().float()
46 | self.alpha_cumprod_prev = torch.cat((torch.ones(1, device = self.device), self.alpha_cumprod[:-1]))
47 | self.v_posterior = v_posterior
48 | self.sigma = ((1 - self.v_posterior) * (self.beta * (1 - self.alpha_cumprod_prev) / (1 - self.alpha_cumprod)) + self.v_posterior * self.beta).sqrt().float()
49 | self.beta, self.alpha = self.beta.float(), self.alpha.float()
50 |
51 | self.class_free_guidance = class_free_guidance
52 | self.class_free_training_uncond_rate = class_free_training_uncond_rate
53 |
54 | if compute_latents_mean_std:
55 | _, self.latents_mean, self.latents_std, _ = self.run_epoch(train_data_for_mean_std, config["inner_epochs"], inner_opt_for_mean_std,
56 | batch_size=config["diffusion_batch_size"],
57 | train=True, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"],
58 | precomputed_latent=precomputed_latent_for_mean_std,
59 | skip_cond=True, output_mean_std=True, keep_tasks_frac=config["diffusion_keep_tasks_frac"])
60 | else:
61 | self.latents_mean, self.latents_std = torch.zeros((1,self.net_feature_dim)).to(self.device), torch.ones((1,self.net_feature_dim)).to(self.device)
62 |
63 | wandb.log({"latents_mean": self.latents_mean.detach().cpu().numpy()})
64 | wandb.log({"latents_std": self.latents_std.detach().cpu().numpy()})
65 |
66 | def forward_diffusion(self, latents_batch, ques_batch, metric=None):
67 | norm_latents_batch = (latents_batch - self.latents_mean) / self.latents_std
68 | t = torch.randint(self.n_timestep, (norm_latents_batch.size(0),) + (1,) * (norm_latents_batch.dim() - 1), device = self.device)
69 | eps = torch.randn_like(norm_latents_batch)
70 | norm_latents_batch_t = torch.sqrt(self.alpha_cumprod[t]) * norm_latents_batch + torch.sqrt(1 - self.alpha_cumprod[t]) * eps
71 | output = self.cond_model(norm_latents_batch_t, t, ques_batch)
72 | if metric is not None:
73 | loss = (((eps - output).unsqueeze(1) @ metric @ (eps - output).unsqueeze(-1)) / metric.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)).mean()
74 | else:
75 | loss = (eps - output).pow(2).mean()
76 | return loss
77 |
78 | def train_step(self, log_dict):
79 | self.optimizer.zero_grad()
80 |
81 | ques_batch = torch.clone(self.ques_batch)
82 | if self.class_free_guidance:
83 | perm = torch.randperm(ques_batch.shape[0])
84 | idx = perm[:int(perm.shape[0]*self.class_free_training_uncond_rate)]
85 | ques_batch[idx] = torch.zeros(ques_batch.shape[1]).to(self.device)
86 |
87 | loss = self.forward_diffusion(self.net_batch, ques_batch, self.hessian_batch)
88 |
89 | log_dict.update({"diffusion_loss": loss.detach().cpu().numpy()})
90 | loss.backward()
91 | self.optimizer.step()
92 |
93 | if self.ema is not None: self.ema.step()
94 |
95 | def test_step(self, log_dict):
96 | with torch.no_grad():
97 | self.cond_model.eval()
98 |
99 | ques_batch = torch.clone(self.ques_batch)
100 | if self.class_free_guidance:
101 | perm = torch.randperm(ques_batch.shape[0])
102 | idx = perm[:int(perm.shape[0]*self.class_free_training_uncond_rate)]
103 | ques_batch[idx] = torch.zeros(ques_batch.shape[1]).to(self.device)
104 |
105 | loss = self.forward_diffusion(self.net_batch, ques_batch, self.hessian_batch)
106 |
107 | log_dict.update({"diffusion_val_loss": loss.detach().cpu().numpy()})
108 | self.cond_model.train()
109 |
110 | def get_compute_grad(self, train_dataloader):
111 | def compute_grad(x):
112 | with torch.enable_grad():
113 | inner_params = OrderedDict()
114 | x = x.detach().requires_grad_()
115 | inner_params["enet.embedding"] = x * self.latents_std + self.latents_mean
116 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params)
117 | inner_loss = F.cross_entropy(outputs, labels)
118 | grad =torch.autograd.grad(inner_loss, x)
119 | return grad[0]
120 | return compute_grad
121 |
122 | def guided_inner(self, train_dataloader, inner_params, init_inner_params, inner_optim_fct,
123 | inner_train_iter, inner_epochs, batch_size, debug, class_guidance_gamma=None,
124 | init_guidance_at="random", guidance_start_from_t_frac=1, fast_sampling_factor=None, few_shot_guidance=False, few_shot_gamma=1, **kwargs):
125 |         # IMPORTANT: the *_accuracy_start values logged to wandb are not those of the randomly initialized embedding
126 |         # (to which the diffusion is applied), but of the initial embedding we are NOT using (unless init_guidance_at="pre-trained").
127 | if few_shot_guidance:
128 | compute_grad_fn = self.get_compute_grad(train_dataloader)
129 | else:
130 | compute_grad_fn=None
131 | if class_guidance_gamma is None:
132 | class_guidance_gamma = 1.
133 | if init_guidance_at=="random":
134 | latent_start_point = None
135 | elif init_guidance_at=="pre-trained":
136 | latent_start_point = init_inner_params["enet.embedding"].clone().detach()
137 | sampled_emb = self.generate(next(iter(train_dataloader))["ques_emb"][0],
138 | self.class_free_guidance,
139 | class_guidance_gamma,
140 | latent_start_point=latent_start_point,
141 | start_from_t_frac=guidance_start_from_t_frac,
142 | fast_sampling_factor=fast_sampling_factor,
143 | compute_grad_fn=compute_grad_fn,
144 | few_shot_gamma=few_shot_gamma)
145 | inner_params["enet.embedding"] = sampled_emb # refactor to be more general
146 | return dict()
147 |
148 |
149 | def generate(self, cond_emb, class_free_guidance, gamma, latent_start_point=None, start_from_t_frac=1., fast_sampling_factor=None, compute_grad_fn=None, few_shot_gamma=1):
150 |
151 | with torch.no_grad():
152 |
153 | alpha_cumprod = self.alpha_cumprod
154 | sigma = self.sigma
155 | if fast_sampling_factor is not None and fast_sampling_factor > 1:
156 | alpha_cumprod = self.alpha_cumprod[::fast_sampling_factor]
157 | alpha_cumprod_prev = torch.cat((torch.ones(1, device = self.device), alpha_cumprod[:-1]))
158 | beta = 1 - (alpha_cumprod / alpha_cumprod_prev)
159 | sigma = ((1 - self.v_posterior) * (beta * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) + self.v_posterior * beta).sqrt().float()
160 |
161 | if latent_start_point is None:
162 | x = torch.randn(self.latents_mean.shape, device = self.device)
163 | else:
164 | x = (latent_start_point - self.latents_mean) / self.latents_std
165 |
166 | for t in range(int(sigma.shape[0]*start_from_t_frac)-1, -1, -1):
167 | t_for_model = t
168 | if fast_sampling_factor is not None and fast_sampling_factor > 1:
169 | t_for_model = t * fast_sampling_factor
170 | if not class_free_guidance or gamma==1.:
171 | output = self.cond_model(x, t_for_model, cond_emb)
172 | else:
173 | output = (1 - gamma) * self.cond_model(x, t_for_model, torch.zeros_like(cond_emb)) + gamma * self.cond_model(x, t_for_model, cond_emb)
174 |
175 | if compute_grad_fn is not None:
176 | grad = compute_grad_fn(x)
177 | grad = grad / grad.norm()*output.norm()*few_shot_gamma
178 | output = output + grad
179 |
180 | z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
181 | x = 1/torch.sqrt(self.alpha[t]) * (x - (1-self.alpha[t]) / torch.sqrt(1-alpha_cumprod[t]) * output) + sigma[t] * z
182 |
183 | x = x * self.latents_std + self.latents_mean
184 |
185 | return x
186 |
187 | def feed_net_feature(self, inner_params=None):
188 | if self.train_on_vae and self.hnet_gen is not None:
189 | return sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params = inner_params).detach()).detach()
190 |         # Works only if inner_params contains a single element: the hnet/VAE embedding.
191 | return inner_params["enet.embedding"].detach()
192 |
193 |
--------------------------------------------------------------------------------
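
forward_diffusion and generate above implement the standard DDPM noising and ancestral-sampling updates on normalized latents (v_posterior = 0 gives the usual beta-tilde variance). The numeric sketch below writes out those two equations with an assumed linear beta schedule and a "perfect" denoiser, purely to make the arithmetic concrete.

# Numeric sketch (illustrative) of the two DDPM updates used above, with v_posterior = 0:
#   forward:  x_t     = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
#   reverse:  x_{t-1} = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t) + sigma_t * z
import torch

T = 100
beta = torch.linspace(1e-4, 2e-2, T)      # assumed linear schedule, for illustration only
alpha = 1 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
alpha_bar_prev = torch.cat((torch.ones(1), alpha_bar[:-1]))
sigma = (beta * (1 - alpha_bar_prev) / (1 - alpha_bar)).sqrt()

x0 = torch.randn(4, 8)                    # stand-in for a batch of normalized latents
t = 50
eps = torch.randn_like(x0)
x_t = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps   # forward noising

eps_hat = eps                             # pretend the denoiser predicts the noise exactly
z = torch.randn_like(x_t)
x_prev = (x_t - (1 - alpha[t]) / (1 - alpha_bar[t]).sqrt() * eps_hat) / alpha[t].sqrt() + sigma[t] * z
print(x_prev.shape)
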
/training/maml_learn.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | from collections import OrderedDict
4 |
5 | import numpy as np
6 | import torch
7 | import wandb
8 | from data.dataloader.clip_vqa import CLIP_VQA
9 | from data.dataloader.coco_tasks import COCO_Tasks
10 | from torch.nn import functional as F
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 | from utils.misc_utils import append_dict, mean_dict
14 | from utils.train_utils import get_pred, log_metric, test_accuracy
15 |
16 |
17 | class GradientBuffer():
18 | def __init__(self, param_list):
19 | self.param_list=param_list
20 | self.reset_buffer()
21 |
22 | def reset_buffer(self):
23 | self.grad_list=[torch.zeros_like(param) for param in self.param_list]
24 | self.num = 0
25 |
26 | def accumulate(self):
27 | for param, grad in zip(self.param_list, self.grad_list):
28 | if param.grad is not None:
29 | grad += param.grad.data
30 | self.num += 1
31 |
32 | def unload(self):
33 | for param, grad in zip(self.param_list, self.grad_list):
34 | if param.grad is not None:
35 | param.grad.data += grad/self.num
36 | self.reset_buffer()
37 |
38 | class MAML():
39 | def __init__(self, meta_module, meta_optimizer, image_features, text_features, ques_emb, config, coco_categories=None, coco_answer_features=None,
40 | extend_coco_size=10 * 870, # size of virtual extended coco dataset
41 | ):
42 | self.meta_module=meta_module
43 | self.meta_optimizer=meta_optimizer
44 | self.image_features=image_features
45 | self.text_features=text_features
46 | self.ques_emb=ques_emb
47 | self.config=config
48 | self.coco_categories=coco_categories
49 | self.coco_answer_features=coco_answer_features
50 | self.extend_coco_size=extend_coco_size
51 | self.reset_coco()
52 |
53 | def reset(self, batch_size, device):
54 | self.buffer = GradientBuffer(self.meta_module.meta_params)
55 | self.batch_size=batch_size
56 | self.iter=0
57 |
58 | def reset_coco(self):
59 | self.coco_iter = 0
60 | self.shuffled_coco_tasks = torch.randperm(self.extend_coco_size)
61 |
62 | def train(self, test_dataloader, inner_params):
63 | # Validation loss
64 | log_dict=dict()
65 | self.meta_module.zero_grad()
66 | outputs, labels = get_pred(self.meta_module, test_dataloader, params=inner_params)
67 | loss=F.cross_entropy(outputs, labels)
68 | loss.backward()
69 | self.buffer.accumulate()
70 |
71 |         if (self.iter % self.batch_size == 0):
72 |             self.buffer.unload()
73 |             if self.config["meta_grad_clip"] > 0:
74 |                 torch.nn.utils.clip_grad_norm_(self.meta_module.parameters(), self.config["meta_grad_clip"])
75 |             self.meta_optimizer.step()
76 |
77 | def get_gradnorm(module):
78 | return np.sqrt(np.sum([p.grad.pow(2).sum().item() for p in module.parameters() if p.grad is not None])) if module is not None else -1
79 |
80 | log_dict["gradnorm_mnet"] = get_gradnorm(self.meta_module.mnet)
81 | log_dict["gradnorm_hnet"] = get_gradnorm(self.meta_module.hnet)
82 | log_dict["gradnorm_enet"] = get_gradnorm(self.meta_module.enet)
83 | self.iter = self.iter+1
84 |
85 | return log_dict
86 |
87 | def run_epoch(self, data, inner_epochs, inner_lr, meta_batch_size=1, train=False, second_order=False,
88 | train_subtype="train", val_subtype="test", keep_tasks_frac=1., extend_coco=False,
89 | extend_coco_frac_train=0.5, # frac of tasks to replace with extended coco
90 | debug=False, device=None, filter_tasks_by_max_k=None, filter_tasks_answers=None, n_shot_training=None, epoch=0):
91 |
92 | tasks = list(data.keys())
93 |
94 | if inner_epochs is not None:
95 | inner_epochs_range = [inner_epochs] if type(inner_epochs) == int else [int(i) for i in inner_epochs.split(",")]
96 | else:
97 | inner_epochs_range = None
98 |
99 | if train:
100 | self.reset(meta_batch_size, None)
101 | if extend_coco and self.coco_iter >= self.extend_coco_size:
102 | self.reset_coco()
103 |
104 | log_dict = dict()
105 |
106 | tasks_idxs = np.arange(len(tasks))
107 | if filter_tasks_by_max_k is not None:
108 | if not filter_tasks_answers:
109 | tasks_idxs = [i for i, t in enumerate(tasks) if np.min(np.unique([a for [_,a] in data[t][train_subtype]], return_counts=True)[1]) >= filter_tasks_by_max_k]
110 | else:
111 | data = copy.deepcopy(data)
112 | for t in tasks:
113 | ans, count = np.unique([a for [_,a] in data[t][train_subtype]], return_counts=True)
114 | filtered_ans = ans[count >= filter_tasks_by_max_k]
115 | data[t][train_subtype] = [d for d in data[t][train_subtype] if d[1] in filtered_ans]
116 | data[t][val_subtype] = [d for d in data[t][val_subtype] if d[1] in filtered_ans]
117 | tasks_idxs = [i for i, t in enumerate(tasks) if len(np.unique([a for [_,a] in data[t][train_subtype]])) >= 2]
118 |
119 | tasks = list(data.keys())
120 | if meta_batch_size > len(tasks_idxs)*keep_tasks_frac:
121 | meta_batch_size = int(len(tasks_idxs)*keep_tasks_frac)
122 | print("Warning: batch size too big, decreasing to {}".format(meta_batch_size))
123 |
124 | shuffled_train_tasks = [tasks_idxs[idx] for idx in torch.randperm(len(tasks_idxs))]
125 | shuffle_for_extended_coco_replace = torch.randperm(len(tasks_idxs))
126 |
127 | for inner_train_iter in tqdm(range(len(tasks_idxs))):
128 | curr_log_dict = dict()
129 | enable_coco = extend_coco and shuffle_for_extended_coco_replace[inner_train_iter] < extend_coco_frac_train * len(tasks_idxs)
130 | if enable_coco:
131 | train_dataset = COCO_Tasks(categories=self.coco_categories,
132 | dataSubType=train_subtype,
133 | image_features=self.image_features,
134 | coco_answer_features=self.coco_answer_features,
135 | task_seed=self.shuffled_coco_tasks[self.coco_iter])
136 | test_dataset = COCO_Tasks(categories=self.coco_categories,
137 | dataSubType=val_subtype,
138 | image_features=self.image_features,
139 | coco_answer_features=self.coco_answer_features,
140 | task_seed=self.shuffled_coco_tasks[self.coco_iter])
141 | else:
142 | task_idx = shuffled_train_tasks[inner_train_iter]
143 | if task_idx > len(tasks)*keep_tasks_frac:
144 | continue
145 |
146 | train_idx=None
147 | val_idx=None
148 | if val_subtype == "random":
149 | num_data = len(data[tasks[task_idx]]["train"] + data[tasks[task_idx]]["test"])
150 | randfrac = 2/3
151 | randperm = np.random.permutation(num_data)
152 | val_idx = randperm[:math.floor(num_data*randfrac)]
153 | train_idx = randperm[math.floor(num_data*randfrac):]
154 |
155 | train_dataset = CLIP_VQA(meta_data=data,
156 | dataSubType=train_subtype,
157 | task=tasks[task_idx],
158 | image_features=self.image_features,
159 | text_features=self.text_features,
160 | ques_emb=self.ques_emb,
161 | data_idx=train_idx,
162 | n_shot=n_shot_training)
163 | test_dataset = CLIP_VQA(meta_data=data,
164 | dataSubType=val_subtype,
165 | task=tasks[task_idx],
166 | image_features=self.image_features,
167 | text_features=self.text_features,
168 | ques_emb=self.ques_emb,
169 | data_idx=val_idx
170 | )
171 |
172 | train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
173 | test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
174 |
175 | inner_params = self.meta_module.get_inner_params()
176 |
177 |             task_embedding = next(iter(train_dataloader))["ques_emb"][0].to(device).detach()
178 | if self.config["use_clip_embedding_init"]:
179 | assert self.meta_module.hnet is not None
180 | inner_params["enet.embedding"] = task_embedding.requires_grad_()
181 |
182 | train_start_acc, train_start_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params)
183 | val_start_acc, val_start_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params)
184 |
185 | if inner_epochs_range is not None:
186 | inner_epochs_sampled = inner_epochs_range[0] if len(inner_epochs_range) == 1 else \
187 | np.random.randint(inner_epochs_range[0], inner_epochs_range[1]+1)
188 | else:
189 | inner_epochs_sampled = None
190 |
191 | # Inner loop
192 | for _ in range(inner_epochs_sampled):
193 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params)
194 | inner_loss = F.cross_entropy(outputs, labels)
195 |
196 | if debug and inner_train_iter % meta_batch_size == 0:
197 | wandb.log({"debug_inner_loss": inner_loss.item()})
198 | grads = torch.autograd.grad(inner_loss, inner_params.values(), retain_graph=True,
199 | create_graph=True if train and second_order else False)
200 | params_next = OrderedDict()
201 | for (name, param), grad in zip(list(inner_params.items()), grads):
202 | params_next[name] = param - inner_lr * grad
203 | inner_params = params_next
204 |
205 | # Train set accuracy
206 | train_end_acc, train_end_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params)
207 | val_end_acc, val_end_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params)
208 |
209 | curr_log_dict["query_accuracy_start"] = val_start_acc
210 | curr_log_dict["query_accuracy_end"] = val_end_acc
211 | curr_log_dict["support_accuracy_start"] = train_start_acc
212 | curr_log_dict["support_accuracy_end"] = train_end_acc
213 | curr_log_dict["query_loss_start"] = val_start_loss
214 | curr_log_dict["query_loss_end"] = val_end_loss
215 | curr_log_dict["support_loss_start"] = train_start_loss
216 | curr_log_dict["support_loss_end"] = train_end_loss
217 |
218 | if train:
219 | curr_log_dict.update(self.train(test_dataloader, inner_params))
220 |
221 | append_dict(log_dict, curr_log_dict)
222 |
223 | if debug and inner_train_iter % meta_batch_size == 0:
224 | log_metric(mean_dict(log_dict), prefix = "debug_")
225 | log_dict=dict()
226 |
227 | if enable_coco:
228 | self.coco_iter += 1
229 |
230 | output_dict = mean_dict(log_dict)
231 | output_dict["epoch"] = epoch
232 |
233 | return output_dict
234 |
--------------------------------------------------------------------------------
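
MAML.run_epoch performs the inner adaptation functionally: gradients are taken with torch.autograd.grad and the fast parameters are rebuilt as param - inner_lr * grad, keeping the graph only when second-order meta-gradients are requested. Below is a compact standalone sketch of that update on a toy linear classifier; the model, data, and learning rate are placeholders.

# Standalone sketch (illustrative) of the MAML-style functional inner update used above.
from collections import OrderedDict
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 3)
params = OrderedDict(model.named_parameters())
x, y = torch.randn(16, 4), torch.randint(3, (16,))

inner_lr, second_order = 0.1, False
for _ in range(5):
    logits = F.linear(x, params["weight"], params["bias"])
    loss = F.cross_entropy(logits, y)
    # create_graph=True would keep the graph so meta-gradients can flow through the update.
    grads = torch.autograd.grad(loss, params.values(), create_graph=second_order)
    params = OrderedDict((name, p - inner_lr * g) for (name, p), g in zip(params.items(), grads))
print(loss.item())
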
/training/store_few_shot_latent.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from training.conditional_model_learn import ConditionalModelTraining
7 |
8 |
9 | class StoreFewShotLatent(ConditionalModelTraining):
10 | def __init__(self, meta_module, image_features, text_features, ques_emb, config, device, compute_hessian=False,
11 | reset_normal_embedding=False, **kwargs):
12 | assert hasattr(meta_module, "hnet")
13 | feature_dim = np.sum([v.numel() for v in meta_module.get_inner_params().values()])
14 | super().__init__(meta_module, None, None, image_features, text_features, ques_emb, config, device, feature_dim,
15 | None, None, compute_hessian=compute_hessian, **kwargs)
16 | self.reset_normal_embedding=reset_normal_embedding
17 |
18 | def reset_task(self, guided_inner):
19 | if self.reset_normal_embedding:
20 | self.meta_module.enet.reset()
21 |
22 | def train_step(self, log_dict, log_file=None):
23 | clip_embedding=[]
24 | embedding=[]
25 | w_vect=[]
26 | task_idx=[]
27 | hessian=[]
28 | coco=[]
29 |
30 | if path.exists(log_file):
31 | matrix_dict = torch.load(log_file)
32 | clip_embedding.append(matrix_dict["clip_embedding"])
33 | coco.append(matrix_dict["coco"])
34 |
35 | if self.base_module.mnet.no_weight:
36 | embedding.append(matrix_dict["embedding"])
37 | else:
38 | w_vect.append(matrix_dict["w_vect"])
39 |
40 | task_idx.append(matrix_dict["task_idx"])
41 | if self.compute_hessian:
42 | hessian.append(matrix_dict["hessian"])
43 |
44 | clip_embedding.append(self.ques_batch.unsqueeze(0).cpu())
45 | coco.append(self.coco_batch.unsqueeze(0).cpu())
46 | if self.base_module.mnet.no_weight:
47 | embedding.append(self.net_batch.unsqueeze(0).cpu())
48 | else:
49 | w_vect.append(self.net_batch.unsqueeze(0).cpu())
50 |
51 | task_idx.append(self.task_batch.unsqueeze(0).cpu())
52 | if self.compute_hessian:
53 | hessian.append(self.hessian_batch.unsqueeze(0).cpu())
54 |
55 | matrix_dict=dict()
56 | matrix_dict["clip_embedding"]=torch.cat(clip_embedding, dim=0)
57 | matrix_dict["coco"]=torch.cat(coco, dim=0)
58 |
59 | if self.base_module.mnet.no_weight:
60 | matrix_dict["embedding"]=torch.cat(embedding, dim=0)
61 | else:
62 | matrix_dict["w_vect"]=torch.cat(w_vect, dim=0)
63 |
64 | matrix_dict["task_idx"]=torch.cat(task_idx, dim=0)
65 | if self.compute_hessian:
66 | matrix_dict["hessian"]=torch.cat(hessian, dim=0)
67 |
68 | print("Saving new tensor")
69 | torch.save(matrix_dict, log_file)
70 |
71 | def test_step(self, log_dict):
72 | pass
73 |
74 | def guided_inner(self, *args, **kwargs):
75 | pass
76 |
77 | def feed_net_feature(self, inner_params):
78 | if self.base_module.mnet.no_weight:
79 | return inner_params["enet.embedding"].detach()
80 | return self.base_module.get_mainnet_weights(params = inner_params).detach()
81 |
--------------------------------------------------------------------------------
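
StoreFewShotLatent.train_step grows an on-disk tensor dictionary one batch at a time: load the existing file if present, concatenate the new batch along a leading epoch dimension, and save it back. The sketch below shows the same pattern reduced to a single key; the path and tensor shape are placeholders.

# Minimal sketch (illustrative) of the load / concatenate / re-save pattern in train_step,
# reduced to a single "clip_embedding" key. Path and shape below are placeholders.
from os import path
import torch

def append_batch(log_file, ques_batch):
    chunks = []
    if path.exists(log_file):
        chunks.append(torch.load(log_file)["clip_embedding"])
    chunks.append(ques_batch.unsqueeze(0).cpu())   # new batch gets a leading epoch dimension
    torch.save({"clip_embedding": torch.cat(chunks, dim=0)}, log_file)

append_batch("/tmp/precomputed_demo.pth", torch.randn(32, 768))
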
/training/vae_learn.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import wandb
6 | from data.dataloader.clip_vqa import CLIP_VQA
7 | from model.custom_hnet import CLIPAdapter, EmbeddingModule, MetaModel
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader
10 | from tqdm import tqdm
11 | from utils.misc_utils import append_dict, mean_dict
12 | from utils.train_utils import get_pred, log_metric, test_accuracy
13 |
14 |
15 | def sample_from_enc(encoder, real_imgs):
16 | vect = encoder(real_imgs)
17 | enc_dim = vect.shape[1] // 2
18 | mu, logvar = vect[:, :enc_dim], vect[:, enc_dim:] - 1
19 | std = torch.exp(0.5 * logvar)
20 | eps = torch.randn_like(std)
21 | encoding = eps * std + mu
22 | return encoding
23 |
24 |
25 | class VAETraining():
26 | def __init__(self, meta_module, hnet_gen, hnet_enc, optimizer_gen, optimizer_enc, image_features, text_features,
27 | ques_emb, config, device):
28 | self.meta_module = meta_module
29 | self.hnet_gen = hnet_gen
30 | self.hnet_enc = hnet_enc
31 | self.optimizer_gen = optimizer_gen
32 | self.optimizer_enc = optimizer_enc
33 |
34 | self.image_features = image_features
35 | self.text_features = text_features
36 | self.ques_emb = ques_emb
37 | self.config = config
38 | self.device=device
39 | self.feature_mean=None
40 | self.feature_std=None
41 |
42 | def set_stats(self, mean, std):
43 | self.feature_mean= mean
44 | self.feature_std=std
45 |
46 | def reset(self, batch_size, device):
47 | self.vae_sub_batch_i = 0
48 | self.weight_batch = torch.zeros(batch_size, self.hnet_enc.input_dim, dtype=torch.float32).to(device)
49 |
50 | def train(self, model_weights):
51 | log_dict=dict()
52 |
53 | self.weight_batch[self.vae_sub_batch_i] = model_weights
54 |
55 | if self.vae_sub_batch_i == self.weight_batch.shape[0]-1:
56 | real_imgs = self.weight_batch
57 | encoder=self.hnet_enc
58 | generator=self.hnet_gen
59 | kld_weight=self.config["kld_weight"]
60 |
61 | vect = encoder(real_imgs)
62 | enc_dim = vect.shape[1] // 2
63 | mu, logvar = vect[:, :enc_dim], vect[:, enc_dim:] - 1
64 | std = torch.exp(0.5 * logvar)
65 | eps = torch.randn_like(std)
66 | factor = math.sqrt(real_imgs.shape[-1])
67 |
68 | encoding = eps * std + mu
69 | generated_imgs = generator(encoding)
70 |
71 | recons = (generated_imgs - real_imgs)
72 |
73 | if self.feature_std is not None:
74 | recons = recons / self.feature_std
75 |
76 | recons_loss = recons.pow(2).mean(0).sum() / factor
77 |
78 | kld_loss = kld_weight * torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1),
79 | dim=0) / factor
80 |
81 | loss = (recons_loss + kld_loss)
82 |
83 |
84 | self.optimizer_gen.zero_grad()
85 | self.optimizer_enc.zero_grad()
86 | loss.backward()
87 | log_dict = {"loss": loss.item(), "kld_loss": kld_loss.item(), "recons_loss": recons_loss.item(),
88 | "encoding_norm": encoding.pow(2).mean().sqrt().item(),
89 | "norm_fake": generated_imgs.norm(dim=1).mean(dim=0).item(),
90 | "norm_real": real_imgs.norm(dim=1).mean(dim=0).item(),
91 | "grad_norm_enc": torch.stack([p.grad.pow(2).sum() for p in encoder.parameters() if p.grad is not None]).sum().sqrt().item(),
92 | "grad_norm_gen": torch.stack([p.grad.pow(2).sum() for p in generator.parameters() if p.grad is not None]).sum().sqrt().item()}
93 |
94 |
95 | if self.config["grad_clip"] > 0:
96 | torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=self.config["grad_clip"])
97 | torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=self.config["grad_clip"])
98 |
99 | log_dict.update({"grad_norm_enc_clipped": torch.stack([p.grad.pow(2).sum() for p in encoder.parameters() if p.grad is not None]).sum().sqrt().item(),
100 | "grad_norm_gen_clipped": torch.stack([p.grad.pow(2).sum() for p in generator.parameters() if p.grad is not None]).sum().sqrt().item()})
101 |
102 | self.optimizer_gen.step()
103 | self.optimizer_enc.step()
104 |
105 | self.vae_sub_batch_i = (self.vae_sub_batch_i + 1) % self.weight_batch.shape[0]
106 | return log_dict
107 |
108 | def run_epoch(self, data, inner_epochs, inner_optim_fct, batch_size=32, device=None,
109 | train_subtype="train", val_subtype="test",
110 | train=False, generated=False, reconstructed=False, debug=False, precomputed_latent=None, output_mean_std=False, epoch=0):
111 | if train:
112 | self.reset(batch_size, device)
113 |
114 | log_dict = dict()
115 | tasks = list(data.keys())
116 | if precomputed_latent is None or "task_idx" not in precomputed_latent:
117 | shuffled_train_tasks = torch.randperm(len(tasks))
118 | else:
119 | shuffled_train_tasks=precomputed_latent["task_idx"][0]
120 |
121 | if output_mean_std:
122 | all_tasks_optimized_params = torch.zeros((len(tasks), self.hnet_enc.input_dim)).to(self.device)
123 |
124 | for inner_train_iter in tqdm(range(len(tasks))):
125 | curr_log_dict = dict()
126 | task_idx = shuffled_train_tasks[inner_train_iter]
127 | train_dataset = CLIP_VQA(meta_data=data,
128 | dataSubType=train_subtype,
129 | task=tasks[task_idx],
130 | image_features=self.image_features,
131 | text_features=self.text_features,
132 | ques_emb=self.ques_emb)
133 | test_dataset = CLIP_VQA(meta_data=data,
134 | dataSubType=val_subtype,
135 | task=tasks[task_idx],
136 | image_features=self.image_features,
137 | text_features=self.text_features,
138 | ques_emb=self.ques_emb)
139 |
140 | train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
141 | test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
142 |
143 | # task_embedding = iter(train_dataloader).next()["ques_emb"][0]
144 | if generated or reconstructed:
145 | embedding = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights().detach()).detach()
146 | enet = EmbeddingModule(self.hnet_gen.input_dim).to(device)
147 | enet.reset(embedding=embedding)
148 |
149 | mnet = CLIPAdapter(e_dim=self.meta_module.mnet.e_dim,
150 | hidden_layers=[self.meta_module.mnet.hidden_size],
151 | use_bias=self.meta_module.mnet.use_bias,
152 | straight_through=self.meta_module.mnet.straight_through,
153 | no_weights=True).to(device)
154 | meta_module=MetaModel(mnet=mnet, hnet=self.hnet_gen, enet=enet, config=self.config)
155 | elif False: #reconstructed:
156 | mnet = CLIPAdapter(e_dim=self.meta_module.mnet.e_dim,
157 | hidden_layers=[self.meta_module.mnet.hidden_size],
158 | use_bias=self.meta_module.mnet.use_bias,
159 | straight_through=self.meta_module.mnet.straight_through,
160 | no_weights=True).to(device)
161 | meta_module=MetaModel(mnet=mnet, hnet=self.hnet_gen, enet=None, config=self.config)
162 | else:
163 | meta_module = self.meta_module
164 |
165 | # Make sure to clone and detach to not optimize the actual initialization.
166 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in meta_module.get_inner_params().items()}
167 |
168 | if reconstructed:
169 | tmp_inner_params = OrderedDict()
170 | tmp_inner_params["enet.embedding"] = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights()).detach()
171 | inner_params=tmp_inner_params
172 | #task_embedding = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params=inner_params)).detach()
173 |
174 | train_start_acc, train_start_loss = test_accuracy(meta_module, train_dataloader,
175 | params=inner_params) #, embedding = task_embedding)
176 | val_start_acc, val_start_loss = test_accuracy(meta_module, test_dataloader, params=inner_params) #, embedding = task_embedding)
177 |
178 | if precomputed_latent is None or generated:
179 | # Inner loop
180 | inner_optimizer = inner_optim_fct(list(inner_params.values()))
181 | for _ in range(inner_epochs):
182 | outputs, labels = get_pred(meta_module, train_dataloader, params=inner_params)
183 | inner_loss = F.cross_entropy(outputs, labels)
184 | if debug and inner_train_iter % batch_size == 0:
185 | wandb.log({"debug_inner_loss": inner_loss.item()})
186 | inner_optimizer.zero_grad()
187 | inner_loss.backward()
188 | inner_optimizer.step()
189 | else:
190 | if self.meta_module.mnet.no_weight:
191 | inner_params["enet.embedding"] = precomputed_latent["embedding"][0, inner_train_iter]
192 | else:
193 | inner_params.update({"mnet."+k:v for (k,v) in self.meta_module.mnet.load_from_vector(precomputed_latent["w_vect"][0, inner_train_iter]).items()})
194 |
195 | if reconstructed:
196 | tmp_inner_params = OrderedDict()
197 | tmp_inner_params["enet.embedding"] = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params=inner_params)).detach()
198 | inner_params=tmp_inner_params
199 | #task_embedding = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params=inner_params)).detach()
200 |
201 | # Train set accuracy
202 | train_end_acc, train_end_loss = test_accuracy(meta_module, train_dataloader, params=inner_params) #, embedding = task_embedding)
203 | val_end_acc, val_end_loss = test_accuracy(meta_module, test_dataloader, params=inner_params) #, embedding = task_embedding)
204 |
205 | curr_log_dict["query_accuracy_start"] = val_start_acc
206 | curr_log_dict["query_accuracy_end"] = val_end_acc
207 | curr_log_dict["support_accuracy_start"] = train_start_acc
208 | curr_log_dict["support_accuracy_end"] = train_end_acc
209 | curr_log_dict["query_loss_start"] = val_start_loss
210 | curr_log_dict["query_loss_end"] = val_end_loss
211 | curr_log_dict["support_loss_start"] = train_start_loss
212 | curr_log_dict["support_loss_end"] = train_end_loss
213 |
214 | if train:
215 | vae_dict = self.train(self.meta_module.get_mainnet_weights(params=inner_params).detach())
216 | curr_log_dict.update(vae_dict)
217 |
218 | append_dict(log_dict, curr_log_dict)
219 |
220 | if debug and inner_train_iter % batch_size == 0:
221 | log_metric(mean_dict(log_dict), prefix="debug_")
222 | log_dict = dict()
223 |
224 | if output_mean_std:
225 | all_tasks_optimized_params[inner_train_iter] = self.meta_module.get_mainnet_weights(params=inner_params).detach()
226 |
227 | output_dict = mean_dict(log_dict)
228 | output_dict["epoch"] = epoch
229 |
230 | if not output_mean_std:
231 | return output_dict
232 | else:
233 | optimized_params_mean = torch.mean(all_tasks_optimized_params, dim=0, keepdim=True)
234 | optimized_params_std = torch.std(all_tasks_optimized_params, dim=0, keepdim=True)
235 | return output_dict, optimized_params_mean, optimized_params_std, all_tasks_optimized_params
236 |
237 |
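The train() step above is a standard VAE objective over flattened main-network weight vectors: an (optionally std-whitened) MSE reconstruction term plus a KL term against a unit Gaussian, both divided by the square root of the weight dimension. Below is a minimal standalone sketch of that loss, with generic callables standing in for the repository's HyperEncoder/HyperGenerator and omitting the -1 shift applied to the predicted log-variance above; all dimensions and the kld_weight value are illustrative only.

import math
import torch

def vae_loss(encoder, generator, weights, kld_weight=1.0, feature_std=None):
    # Sketch of the objective computed in the train() step above; `encoder`
    # maps weight vectors to concatenated [mu, logvar], `generator` decodes
    # a sampled latent back to a weight vector.
    vect = encoder(weights)
    enc_dim = vect.shape[1] // 2
    mu, logvar = vect[:, :enc_dim], vect[:, enc_dim:]
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)   # reparameterization trick
    recon = generator(z) - weights
    if feature_std is not None:
        recon = recon / feature_std                           # per-dimension whitening
    factor = math.sqrt(weights.shape[-1])
    recons_loss = recon.pow(2).mean(0).sum() / factor
    kld_loss = kld_weight * (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1)).mean() / factor
    return recons_loss + kld_loss

# Toy stand-ins (dimensions and kld_weight are illustrative):
enc = torch.nn.Linear(256, 2 * 32)   # 256-d weight vector -> [mu, logvar]
gen = torch.nn.Linear(32, 256)       # 32-d latent -> reconstructed weight vector
loss = vae_loss(enc, gen, torch.randn(8, 256), kld_weight=5e-4)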
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/utils/__init__.py
--------------------------------------------------------------------------------
/utils/build_opt.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | supported_optims = ["adam", "amsgrad", "sgd", "rmsprop", "adamw"]
4 |
5 |
6 | def build_optimizer(parameters, config, loop="inner"):
7 | # loop should be either "inner" or "hyperclip"
8 | in_str = f"{loop}_"
9 | optim = config[in_str+"optimizer"]
10 | lr = config[in_str+"learning_rate"]
11 | weight_decay = config.get(in_str+"weight_decay", 0)
12 | adam_beta1 = config.get(in_str+"adam_beta1", 0.9)
13 | adam_beta2 = config.get(in_str+"adam_beta2", 0.999)
14 | momentum = config.get(in_str+"momentum", 0)
15 | sgd_dampening = config.get(in_str+"sgd_dampening", False)
16 | sgd_nesterov = config.get(in_str+"sgd_nesterov", False)
17 | rmsprop_alpha = config.get(in_str+"rmsprop_alpha", 0.99)
18 |
19 | if optim not in supported_optims:
20 | raise ValueError("Unsupported optim: {}. Must be one of {}".format(optim, supported_optims))
21 |
22 | if optim == "adam":
23 | optimizer = torch.optim.Adam(
24 | parameters,
25 | lr=lr,
26 | weight_decay=weight_decay,
27 | betas=(adam_beta1, adam_beta2),
28 | )
29 |
30 | elif optim == "amsgrad":
31 | optimizer = torch.optim.Adam(
32 | parameters,
33 | lr=lr,
34 | weight_decay=weight_decay,
35 | betas=(adam_beta1, adam_beta2),
36 | amsgrad=True,
37 | )
38 |
39 | elif optim == "sgd":
40 | optimizer = torch.optim.SGD(
41 | parameters,
42 | lr=lr,
43 | momentum=momentum,
44 | weight_decay=weight_decay,
45 | dampening=sgd_dampening,
46 | nesterov=sgd_nesterov,
47 | )
48 |
49 | elif optim == "rmsprop":
50 | optimizer = torch.optim.RMSprop(
51 | parameters,
52 | lr=lr,
53 | momentum=momentum,
54 | weight_decay=weight_decay,
55 | alpha=rmsprop_alpha,
56 | )
57 |
58 | elif optim == "adamw":
59 | optimizer = torch.optim.AdamW(
60 | parameters,
61 | lr=lr,
62 | weight_decay=weight_decay,
63 | betas=(adam_beta1, adam_beta2),
64 | )
65 |
66 | return optimizer
67 |
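A usage sketch for build_optimizer: the config keys follow the `<loop>_`-prefixed naming read above, and the concrete values here are illustrative only.

import torch
from utils.build_opt import build_optimizer

# Only "<loop>_optimizer" and "<loop>_learning_rate" are required; the
# remaining keys fall back to the defaults read via config.get().
config = {"inner_optimizer": "adamw", "inner_learning_rate": 1e-3, "inner_weight_decay": 1e-2}
model = torch.nn.Linear(512, 10)
opt = build_optimizer(model.parameters(), config, loop="inner")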
--------------------------------------------------------------------------------
/utils/clip_utils.py:
--------------------------------------------------------------------------------
1 |
2 | embedding_size = {
3 |
4 | 'RN50': 1024,
5 | 'RN101': 512,
6 | 'RN50x4': 640,
7 | 'RN50x16': 768,
8 | 'RN50x64': 1024,
9 | 'ViT-B/32': 512,
10 | 'ViT-B/16': 512,
11 | 'ViT-L/14': 768,
12 | 'ViT-L/14@336px': 768
13 |
14 | }
15 |
16 | image_resolution = {
17 |
18 | 'RN50': 224,
19 | 'RN101': 224,
20 | 'RN50x4': 288,
21 | 'RN50x16': 384,
22 | 'RN50x64': 448,
23 | 'ViT-B/32': 224,
24 | 'ViT-B/16': 224,
25 | 'ViT-L/14': 224,
26 | 'ViT-L/14@336px': 336
27 |
28 | }
29 |
30 | cached_location = {
31 |
32 | 'RN50': "~/.cache/clip/RN50.pt",
33 | 'RN101': "~/.cache/clip/RN101.pt",
34 | 'RN50x4': "~/.cache/clip/RN50x4.pt",
35 | 'RN50x16': "~/.cache/clip/RN50x16.pt",
36 | 'RN50x64': "~/.cache/clip/RN50x64.pt",
37 | 'ViT-B/32': "~/.cache/clip/ViT-B-32.pt",
38 | 'ViT-B/16': "~/.cache/clip/ViT-B-16.pt",
39 | 'ViT-L/14': "~/.cache/clip/ViT-L-14.pt",
40 | 'ViT-L/14@336px': "~/.cache/clip/ViT-L-14-336px.pt"
41 |
42 | }
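These tables are keyed by the model names accepted by clip.load; a small sketch of how downstream code can look up the feature width and input resolution (the chosen model name is illustrative):

from utils.clip_utils import embedding_size, image_resolution

clip_model = "ViT-B/32"                 # illustrative choice
e_dim = embedding_size[clip_model]      # 512: width of CLIP image/text features
res = image_resolution[clip_model]      # 224: expected input resolution in pixels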
--------------------------------------------------------------------------------
/utils/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | # through
8 | # https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
9 |
10 | import math
11 |
12 | import numpy as np
13 | import torch
14 | from einops import repeat
15 |
16 |
17 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
18 | if schedule == "linear":
19 | betas = (
20 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
21 | )
22 |
23 | elif schedule == "cosine":
24 | timesteps = (
25 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
26 | )
27 | alpha_cums = timesteps / (1 + cosine_s) * np.pi / 2
28 | alpha_cums = torch.cos(alpha_cums).pow(2)
29 | alpha_cums = alpha_cums / alpha_cums[0]
30 | betas = 1 - alpha_cums[1:] / alpha_cums[:-1]
31 | betas = np.clip(betas, a_min=0, a_max=0.999)
32 |
33 | elif schedule == "sqrt_linear":
34 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
35 | elif schedule == "sqrt":
36 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
37 | elif schedule == "quad":
38 | betas = (
39 | torch.linspace(linear_start ** 0.25, linear_end ** 0.25, n_timestep, dtype=torch.float64) ** 4
40 | )
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 | # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
103 | """
104 | Create sinusoidal timestep embeddings.
105 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
106 | These may be fractional.
107 | :param dim: the dimension of the output.
108 | :param max_period: controls the minimum frequency of the embeddings.
109 | :return: an [N x dim] Tensor of positional embeddings.
110 | """
111 | if not repeat_only:
112 | half = dim // 2
113 | freqs = torch.exp(
114 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
115 | ).to(device=timesteps.device)
116 | args = timesteps[:, None].float() * freqs[None]
117 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
118 | if dim % 2:
119 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
120 | else:
121 | embedding = repeat(timesteps, 'b -> b d', d=dim)
122 | return embedding
123 |
124 | class EMA:
125 | def __init__(self, model, decay):
126 | self.model = model
127 | self.decay = decay
128 | self.mem = { }
129 | with torch.no_grad():
130 | for p in model.parameters():
131 | self.mem[p] = p.clone()
132 |
133 | def step(self):
134 | with torch.no_grad():
135 | for p in self.model.parameters():
136 | self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
137 |
138 | def copy_to_model(self):
139 | with torch.no_grad():
140 | for p in self.model.parameters():
141 | p.copy_(self.mem[p])
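A short usage sketch combining the helpers above; the hyperparameters and the placeholder network are illustrative and do not correspond to the repository's latent diffuser.

import torch
from utils.diffusion_utils import make_beta_schedule, timestep_embedding, EMA

betas = make_beta_schedule("linear", n_timestep=1000)          # numpy array of shape (1000,)
alphas_cumprod = torch.tensor(1.0 - betas).cumprod(dim=0)      # cumulative alpha_bar_t used by samplers

t = torch.randint(0, 1000, (8,))            # a batch of timesteps
emb = timestep_embedding(t, dim=128)        # (8, 128) sinusoidal embeddings

model = torch.nn.Linear(128, 128)           # placeholder network
ema = EMA(model, decay=0.999)
# ... after each optimizer step:
ema.step()
# ... before evaluation or sampling:
ema.copy_to_model()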
--------------------------------------------------------------------------------
/utils/init_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import wandb
5 | from model.custom_hnet import (CLIPAdapter, EmbeddingModule, HyperEncoder,
6 | HyperGenerator, MetaModel)
7 | from training.vae_learn import sample_from_enc
8 |
9 | base_path = os.path.dirname(os.path.dirname(__file__))
10 |
11 | def load_few_shot_to_metamodel(run, device):
12 | loaded_config = run.config
13 | meta_module = MetaModel(
14 | inner_param=loaded_config["inner_param"],
15 | mainnet_use_bias=loaded_config["mainnet_use_bias"],
16 | mainnet_hidden_dim=loaded_config["mainnet_hidden_dim"],
17 | hypernet_hidden_dim=[] if loaded_config["hypernet_hidden_dim"]=="" else [int(i) for i in loaded_config["hypernet_hidden_dim"].split(",")],
18 | embedding_dim=loaded_config["embedding_dim"],
19 | straight_through=loaded_config["straight_through"],
20 | config=loaded_config).to(device)
21 | loaded_model_path = base_path + "/evaluation/few_shot/meta_module_" + str(run.name) + ".pth"
22 | meta_module.load_state_dict(torch.load(loaded_model_path), strict=False)
23 | return meta_module
24 |
25 | def load_vae(run, device):
26 | config=run.config
27 |
28 | api = wandb.Api()
29 | if "precompute_checkpoint" in config:
30 | print("Loading from precomputed run {}".format(config["precompute_checkpoint"]))
31 | precomp_run = api.run(config["precompute_checkpoint"])
32 | precomp_config = precomp_run.config
33 | config.update({"few_shot_checkpoint": precomp_config["few_shot_checkpoint"]}, allow_val_change=True)
34 |
35 | loaded_run = api.run(config["few_shot_checkpoint"])
36 | tmp_meta_module = load_few_shot_to_metamodel(loaded_run, device)
37 |
38 | hidden_dims = [int(h) for h in config["vae_hidden_dim"].split(",")]
39 | hnet_enc = HyperEncoder(tmp_meta_module.mnet, e_dim=config["vae_noise_dim"],
40 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device)
41 | hidden_dims.reverse()
42 | hnet_gen = HyperGenerator(tmp_meta_module.mnet, e_dim=config["vae_noise_dim"],
43 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device)
44 |
45 | loaded_model_path = base_path + "/evaluation/vae/hnet_gen_" + str(run.name) + ".pth"
46 | hnet_gen.load_state_dict(torch.load(loaded_model_path), strict=False)
47 | loaded_model_path = base_path + "/evaluation/vae/hnet_enc_" + str(run.name) + ".pth"
48 | hnet_enc.load_state_dict(torch.load(loaded_model_path), strict=False)
49 |
50 | return hnet_gen, hnet_enc, tmp_meta_module
51 |
52 | def load_vae_to_metamodel(run, device):
53 | hnet_gen, hnet_enc, tmp_meta_module = load_vae(run, device)
54 |
55 | embedding = sample_from_enc(hnet_enc, tmp_meta_module.get_mainnet_weights(params = None).detach()).detach()
56 | enet = EmbeddingModule(hnet_gen.input_dim).to(device)
57 | enet.reset(embedding=embedding)
58 |
59 | mnet = CLIPAdapter(e_dim=tmp_meta_module.mnet.e_dim,
60 | hidden_layers=[tmp_meta_module.mnet.hidden_size],
61 | use_bias=tmp_meta_module.mnet.use_bias,
62 | straight_through=tmp_meta_module.mnet.straight_through,
63 | no_weights=True).to(device)
64 | meta_module = MetaModel(mnet=mnet, hnet=hnet_gen, enet=enet, config=run.config)
65 |
66 | return meta_module
67 |
68 | def load_metamodel_from_checkpoint(config, device):
69 | precomputed_latent = None
70 | precomputed_latent_train_eval = None
71 | precomputed_latent_val_eval = None
72 | api = wandb.Api()
73 | if "precompute_checkpoint" in config:
74 | print("Loading from precomputed run {}".format(config["precompute_checkpoint"]))
75 | precomp_run = api.run(config["precompute_checkpoint"])
76 | precomp_config = precomp_run.config
77 | config.update({"few_shot_checkpoint": None if "few_shot_checkpoint" not in precomp_config else precomp_config["few_shot_checkpoint"]}, allow_val_change=True)
78 | config.update({"vae_checkpoint": None if "vae_checkpoint" not in precomp_config else precomp_config["vae_checkpoint"]}, allow_val_change=True)
79 |
80 | precomputed_file = base_path + "/evaluation/precompute_adaptation/" + str(precomp_run.name) + ".pth"
81 | precomputed_latent = torch.load(precomputed_file) #{k:v.to(device) for (k,v) in torch.load(precomputed_file).items()}
82 |
83 | try:
84 | precomputed_file_train_eval = base_path + "/evaluation/precompute_adaptation/" + str(precomp_run.name) + "_train_eval.pth"
85 | precomputed_file_val_eval = base_path + "/evaluation/precompute_adaptation/" + str(precomp_run.name) + "_val_eval.pth"
86 | precomputed_latent_train_eval = {k:v.to(device) for (k,v) in torch.load(precomputed_file_train_eval).items()}
87 | precomputed_latent_val_eval = {k:v.to(device) for (k,v) in torch.load(precomputed_file_val_eval).items()}
88 | except FileNotFoundError:
89 | print("Did not find precomputed eval latents")
90 |
91 | if "few_shot_checkpoint" in config and config["few_shot_checkpoint"] is not None:
92 | print("Creating MetaModule from FewShot trained model")
93 | loaded_run = api.run(config["few_shot_checkpoint"])
94 | meta_module = load_few_shot_to_metamodel(loaded_run, device)
95 | elif "vae_checkpoint" in config and config["vae_checkpoint"] is not None:
96 | print("Creating MetaModule from VAE model")
97 | loaded_run = api.run(config["vae_checkpoint"])
98 | meta_module = load_vae_to_metamodel(loaded_run, device)
99 | else:
100 | raise NotImplementedError("config must reference either a few_shot_checkpoint or a vae_checkpoint")
101 |
102 | return meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval
103 |
104 | def load_vae_and_metamodel_from_checkpoint(config, device):
105 | api = wandb.Api()
106 | assert "vae_checkpoint" in config and config["vae_checkpoint"] is not None
107 | loaded_run = api.run(config["vae_checkpoint"])
108 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
109 | load_metamodel_from_checkpoint(loaded_run.config, device)
110 | hnet_gen, hnet_enc, _ = load_vae(loaded_run, device)
111 |
112 | return meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval
113 |
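A sketch of the expected call pattern for load_metamodel_from_checkpoint; the run path is a placeholder, and resolving it requires wandb access plus the corresponding local checkpoint under evaluation/few_shot/.

import torch
from utils.init_utils import load_metamodel_from_checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"
# "entity/project/run_id" is a placeholder wandb run path. With only a
# few_shot_checkpoint set, no precomputed latents are loaded and the
# config.update(..., allow_val_change=True) branch (which expects a
# wandb.config object) is never reached, so a plain dict suffices here.
config = {"few_shot_checkpoint": "entity/project/run_id"}
meta_module, latent, latent_train_eval, latent_val_eval = \
    load_metamodel_from_checkpoint(config, device)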
--------------------------------------------------------------------------------
/utils/misc_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 |
5 |
6 | def str2bool(v):
7 | if isinstance(v, bool):
8 | return v
9 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
10 | return True
11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
12 | return False
13 | else:
14 | raise argparse.ArgumentTypeError('Boolean value expected.')
15 |
16 | def append_dict(dictionary, new_dict):
17 | for key in new_dict:
18 | if key not in dictionary:
19 | dictionary[key]=[]
20 | dictionary[key].append(new_dict[key])
21 |
22 | def mean_dict(dictionary):
23 | out_dict = dict()
24 | for key in dictionary:
25 | if isinstance(dictionary[key], list):
26 | if isinstance(dictionary[key][0], list):
27 | out_dict[key] = [np.mean([dictionary[key][j][i] for j in range(len(dictionary[key]))]) for i in range(len(dictionary[key][0]))]
28 | else:
29 | out_dict[key] = np.mean(np.array(dictionary[key]))
30 | else:
31 | out_dict[key] = dictionary[key]
32 | return out_dict
33 |
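A quick illustration of the accumulate-then-average pattern append_dict and mean_dict implement:

from utils.misc_utils import append_dict, mean_dict

log = {}
append_dict(log, {"loss": 1.0, "acc": 0.50})
append_dict(log, {"loss": 0.5, "acc": 0.75})
print(mean_dict(log))   # -> loss = 0.75, acc = 0.625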
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import wandb
3 | from sklearn.metrics import accuracy_score
4 | from torch.nn import functional as F
5 | import copy
6 | from utils.misc_utils import append_dict
7 | import numpy as np
8 |
9 | def get_pred( meta_module, dataloader, params=None, embedding=None):
10 | y_pred = []
11 | y_true = []
12 | for sample in dataloader:
13 | sample_image_features = sample["image_features"]
14 | sample_text_features = sample["text_features"]
15 | if embedding is None:
16 | embedding = sample["ques_emb"][0]
17 | labels = sample["label"].to(sample_image_features.device)
18 | similarity = meta_module(sample_image_features, sample_text_features, embedding, params=params)
19 | y_pred.append(similarity)
20 | y_true.append(labels)
21 |
22 | return torch.cat(y_pred), torch.cat(y_true)
23 |
24 | def test_accuracy( meta_module, dataloader, params=None, embedding=None, params_list=None):
25 | # Validation inner-loop testing
26 | meta_module.eval()
27 |
28 | with torch.no_grad():
29 | if params_list is not None:
30 | output_list = []
31 | for p in params_list:
32 | output, y_true = get_pred(meta_module, dataloader, params=p, embedding=embedding)
33 | output_list.append(output)
34 | output = torch.stack(output_list).mean(0)
35 | else:
36 | output, y_true = get_pred(meta_module, dataloader, params=params, embedding=embedding)
37 | _, y_pred = output.topk(1)
38 | loss = F.cross_entropy(output, y_true)
39 |
40 | acc = accuracy_score(y_true.cpu().numpy(), y_pred.cpu().numpy())
41 | meta_module.train()
42 | return acc, loss.item()
43 |
44 | def log_metric(log_dict, prefix=""):
45 | prefixed_dict = dict()
46 | for key in log_dict:
47 | if not isinstance(log_dict[key], list):
48 | prefixed_dict[prefix+key] = log_dict[key]
49 |
50 | wandb.log(prefixed_dict)
51 |
52 | for key in log_dict:
53 | if isinstance(log_dict[key], list):
54 | for i in range(len(log_dict[key])):
55 | wandb.log({prefix+key: log_dict[key][i]})
56 |
57 | def n_shot_trials_run(model_training, n_shot_trial_dict, config, log_name, data, *args, **kwargs):
58 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None:
59 | filtered_data, tasks_idx = filter_data_by_max_n_shot(data, config["n_shot_trials_maxN"])
60 | for n_shot in range(1, config["n_shot_trials_maxN"]+1):
61 | n_shot_trial_dict[f"{log_name}n_shot_{n_shot}/"] = {}
62 | log_dict = model_training.run_epoch(filtered_data, *args, n_shot_training=n_shot, tasks_idxs=tasks_idx, **kwargs)
63 | append_dict(n_shot_trial_dict[f"{log_name}n_shot_{n_shot}/"], log_dict)
64 | # full few shot
65 | log_dict = model_training.run_epoch(data, *args, n_shot_training="full", **kwargs)
66 | log_metric(log_dict, f"{log_name}few_shot/")
67 |
68 | def filter_data_by_fraction(data, keep_tasks_frac):
69 | task_idx = np.arange(int(len(list(data.keys())) * keep_tasks_frac))
70 | return data, task_idx
71 |
72 | def filter_data_by_max_n_shot(data, max_k):
73 | data = copy.deepcopy(data)
74 | tasks_idx =[]
75 | for i,t in enumerate(list(data.keys())):
76 | ans, count = np.unique([a for [_, a] in data[t]["train"]], return_counts=True)
77 | filtered_ans = ans[count >= max_k]
78 | data[t]["train"] = [d for d in data[t]["train"] if d[1] in filtered_ans]
79 | data[t]["test"] = [d for d in data[t]["test"] if d[1] in filtered_ans]
80 | if len(ans)>=2:
81 | tasks_idx.append(i)
82 | return data, tasks_idx
83 |
84 | def update_best_val(best_val_dict, meta_epoch, log_dict):
85 | if best_val_dict["best_val_accuracy"] < log_dict["query_accuracy_end"]:
86 | best_val_dict["best_val_accuracy"] = log_dict["query_accuracy_end"]
87 | best_val_dict["best_val_epoch"] = meta_epoch
88 |
89 | def run_evals(model_training, train_data, test_data, n_shot_trial_dict, config, log_name, *args, skip_train=False, n_shot_from_opt_latent=False, best_val_dict=None, precomputed_latent_train=None, precomputed_latent_val=None, guided_inner=False, use_vae=False, **kwargs):
90 | # Eval on train data
91 | if not skip_train:
92 | log_dict = model_training.run_epoch(train_data, *args, guided_inner=guided_inner, keep_tasks_frac=config["diffusion_keep_tasks_frac"], precomputed_latent=precomputed_latent_train, use_vae=use_vae if precomputed_latent_train is None else False, **kwargs)
93 | log_metric(log_dict, log_name.format("train"))
94 |
95 | # Eval on test data
96 | log_dict, _, _, opt_latents = model_training.run_epoch(test_data, *args, guided_inner=guided_inner, precomputed_latent=precomputed_latent_val, use_vae=use_vae if precomputed_latent_val is None else False, output_mean_std=True, **kwargs)
97 | log_metric(log_dict, log_name.format("val"))
98 | if best_val_dict is not None:
99 | update_best_val(best_val_dict, kwargs["epoch"], log_dict)
100 | log_metric(best_val_dict)
101 |
102 | # N-shot eval on test data
103 | n_shot_trials_run(model_training, n_shot_trial_dict, config, log_name.format("val"), test_data, *args, opt_latents_for_n_shot=opt_latents if n_shot_from_opt_latent else None, use_vae=use_vae, **kwargs)
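filter_data_by_max_n_shot keeps, for each task, only the answers that have at least max_k support examples. A self-contained toy example (answer strings are used here for readability; in the repository these entries may be label indices):

from utils.train_utils import filter_data_by_max_n_shot

# Toy task dictionary: data[task]["train"/"test"] is a list of [image_name, answer] pairs.
data = {
    "what color is the car?": {
        "train": [["img_a", "red"], ["img_b", "red"], ["img_c", "blue"]],
        "test":  [["img_d", "red"], ["img_e", "blue"]],
    }
}
filtered, task_idx = filter_data_by_max_n_shot(data, max_k=2)
# Only answers with at least max_k support examples ("red") are kept.
print(filtered["what color is the car?"]["train"])  # [['img_a', 'red'], ['img_b', 'red']]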
--------------------------------------------------------------------------------
/utils/wandb_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import wandb
4 |
5 | base_path = os.path.dirname(os.path.dirname(__file__))
6 |
7 | def populate_wandb_table(wandb_val_table, test_data, test_tasks, val_task_id, y_pred, y_true):
8 | ans = test_data[test_tasks[val_task_id]]["answers"]
9 | for i in range(len(y_pred)):
10 | image_name = test_data[test_tasks[val_task_id]]["test"][i][0]
11 | image_folder = image_name.split("_")[1]
12 | wandb_val_table.add_data(
13 | f"{val_task_id}_{i}",
14 | test_tasks[val_task_id],
15 | wandb.Image(base_path + f'/data/VQA/Images/{image_folder}/{image_name}.jpg'),
16 | ans[y_pred[i]],
17 | ans[y_true[i]],
18 | int(y_true[i]==y_pred[i]))
19 |
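populate_wandb_table assumes a wandb.Table with one column per add_data argument; a sketch of constructing a matching table (the column names are an illustrative choice, wandb only requires the arity to match):

import wandb

wandb_val_table = wandb.Table(
    columns=["sample_id", "task", "image", "predicted_answer", "true_answer", "correct"]
)

The table can then be filled per validation task with populate_wandb_table and logged once with wandb.log({"val_table": wandb_val_table}).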
--------------------------------------------------------------------------------