├── .gitignore ├── LICENSE ├── README.md ├── __main__.py ├── data ├── VQA │ ├── Annotations │ │ └── prompt_meta.json │ └── Meta │ │ ├── meta_test.json │ │ ├── meta_train.json │ │ ├── meta_traintest.json │ │ └── ques_ans_count.json ├── __init__.py ├── dataloader │ ├── __init__.py │ ├── clip_vqa.py │ └── coco_tasks.py └── helper │ ├── PythonHelperTools │ ├── __init__.py │ ├── generate_metadata.py │ ├── prompt_meta.txt │ ├── questions.txt │ ├── unique_question_train.txt │ ├── unique_question_val.txt │ └── vqaTools │ │ ├── __init__.py │ │ └── vqa.py │ ├── QuestionTypes │ ├── abstract_v002_question_types.txt │ └── mscoco_question_types.txt │ └── __init__.py ├── environment.yml ├── features ├── image_features.py ├── ques_features.py └── text_features.py ├── model ├── custom_hnet.py ├── hyperclip.py └── latent_diffuser.py ├── scripts ├── __init__.py ├── hyperclip_classification_test.py ├── precompute_adaptation.py ├── precompute_image_features.py ├── precompute_ques_features.py ├── precompute_text_features.py ├── train_few_shot.py ├── train_hyperclip.py ├── train_latent_diffusion.py └── train_vae.py ├── setup.py ├── training ├── conditional_model_learn.py ├── hyperclip_learn.py ├── latent_diffusion_learn.py ├── maml_learn.py ├── store_few_shot_latent.py └── vae_learn.py └── utils ├── __init__.py ├── build_opt.py ├── clip_utils.py ├── diffusion_utils.py ├── init_utils.py ├── misc_utils.py ├── train_utils.py └── wandb_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints 3 | *.npy 4 | *.pyc 5 | .idea 6 | .vscode 7 | wandb/ 8 | *.egg-info/ 9 | *.swp 10 | 11 | data/VQA/Features/ 12 | data/VQA/Images/val2014/ 13 | data/VQA/Images/train2014/ 14 | data/VQA/Questions/ 15 | data/VQA/Annotations/ 16 | data/helper/Results/ 17 | data/VQA/Features/train2014/ 18 | data/VQA/Features/val2014/ 19 | evaluation/ckp_adapter/ 20 | evaluation/results/old/ 21 | evaluation/ckp_reptile/ 22 | evaluation/results/ 23 | evaluation/precompute_adaptation/ 24 | evaluation/diffusion/ 25 | evaluation/hyperclip/ 26 | evaluation/hypergan/ 27 | evaluation/maml/ 28 | evaluation/few_shot/ 29 | evaluation/vae/ 30 | scripts/scratch/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Elvis Nava and Seijin Kobayashi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta-Learning via Classifier(-free) Diffusion Guidance 2 | 3 | [arxiv](https://arxiv.org/abs/2210.08942) | [BibTeX](#citation) 4 | 5 | **Meta-Learning via Classifier(-free) Diffusion Guidance**
6 | [Elvis Nava](https://github.com/elvisnava)\*, 7 | [Seijin Kobayashi](https://github.com/seijin-kobayashi)\*, 8 | Yifei Yin, 9 | Robert K. Katzschmann, 10 | Benjamin F. Grewe
11 | \* equal contribution 12 | 13 | # Installation 14 | 15 | The `hyperclip` conda environment can be created with the following commands: 16 | ``` 17 | conda env create -f environment.yml 18 | conda activate hyperclip 19 | pip install git+https://github.com/openai/CLIP.git 20 | conda install pytorch cudatoolkit=11.3 -c pytorch 21 | pip install -e . 22 | ``` 23 | 24 | To setup Weights and Biases run 25 | ``` 26 | wandb login 27 | ``` 28 | and paste your W&B API key. 29 | 30 | # Meta-VQA Dataset 31 | 32 | To re-compute the Meta-VQA dataset, first download the [original VQA v2 dataset](https://visualqa.org/download.html) and place it in the `data/VQA/` folder, and then run (while in the `hyperclip` environment): 33 | ``` 34 | python scripts/precompute_image_features.py 35 | python scripts/precompute_ques_features.py 36 | python scripts/precompute_text_features.py 37 | ``` 38 | to re-generate the pre-computed CLIP embeddings for images, task questions and answers. 39 | 40 | # Experiment scripts 41 | 42 | To train multitask/MAML baselines or an unconditional Hypernetwork generative model (to later use as basis for conditional generation), use the script: 43 | ``` 44 | python scripts/train_few_shot.py [...] 45 | ``` 46 | 47 | To train a number of our models, we first need to prepare a precomputed "dataset" of fine-tuned networks/hnet latents/vae latents. We can do so with the script: 48 | ``` 49 | python scripts/precompute_adaptation.py (--few_shot_checkpoint | --vae_checkpoint ) [...] 50 | ``` 51 | 52 | In order to train the unconditional VAE hypernetwork (alternative to the previous HNET as basis for conditional generation methods), use the script: 53 | ``` 54 | python scripts/train_vae.py --precompute_checkpoint [...] 55 | ``` 56 | 57 | To train the HyperCLIP encoder (either from precomputed VAE/HNET fine-tunings, a VAE, or an HNET), use the script: 58 | ``` 59 | python scripts/train_hyperclip.py (--precompute_checkpoint | --vae_checkpoint | --few_shot_checkpoint ) [...] 60 | ``` 61 | 62 | To train a hypernetwork latent diffusion model (HyperLDM), use the script: 63 | ``` 64 | python scripts/train_latent_diffusion.py (--precompute_checkpoint | --vae_checkpoint | --few_shot_checkpoint ) [...] 65 | ``` 66 | 67 | 68 | # Citation 69 | ``` 70 | @misc{nava_meta-learning_2022, 71 | title = {Meta-{Learning} via {Classifier}(-free) {Diffusion} {Guidance}}, 72 | url = {http://arxiv.org/abs/2210.08942}, 73 | doi = {10.48550/arXiv.2210.08942}, 74 | publisher = {arXiv}, 75 | author = {Nava, Elvis and Kobayashi, Seijin and Yin, Yifei and Katzschmann, Robert K. 
and Grewe, Benjamin F.}, 76 | month = oct, 77 | year = {2022}, 78 | note = {arXiv:2210.08942 [cs]}, 79 | keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences}, 80 | copyright = {arXiv.org perpetual, non-exclusive license} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /__main__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/__main__.py -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/data/dataloader/__init__.py -------------------------------------------------------------------------------- /data/dataloader/clip_vqa.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class CLIP_VQA(Dataset): 11 | 12 | def __init__(self, meta_data, dataSubType, task , image_features, text_features, ques_emb, n_shot=None, n_shot_seed=42, 13 | data_idx=None): 14 | """ 15 | Args: 16 | meta_data(string): Path to the meta learning data file 17 | dataSubType(string): train/test/traintest 18 | task(string): question 19 | image_features(dict): image features dict loaded from features.image_features 20 | text_features(dict): text features dict loaded from features.text_features 21 | ques_emb(dict): question feature dict loaded from features.ques_features 22 | n_shot(int): number of examples per class 23 | n_shot_seed(int): random seed for n_shot sampling 24 | """ 25 | self.dataSubType = dataSubType 26 | self.task = task 27 | self.answers = meta_data[self.task]["answers"] 28 | self.image_features = image_features 29 | self.text_features = text_features 30 | self.ques_emb = ques_emb 31 | if data_idx is not None: 32 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"] 33 | self.data = [self.data[i] for i in data_idx] 34 | elif dataSubType in ["train","test"]: 35 | self.data = meta_data[self.task][self.dataSubType] 36 | elif dataSubType == "traintest": 37 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"] 38 | elif dataSubType == "random": 39 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"] 40 | frac = torch.rand(())/3.+2./3. 
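            # The "random" subset keeps a uniformly sampled fraction in [2/3, 1) of the
            # pooled train+test examples (ceil(len(data) * frac) items after a random
            # permutation); the "random50" branch below does the same with a fixed 0.5.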
41 | self.data = [self.data[i] for i in np.random.permutation(len(self.data))[:math.ceil(len(self.data)*frac)]] 42 | elif dataSubType == "random50": 43 | self.data = meta_data[self.task]["train"] + meta_data[self.task]["test"] 44 | frac = 0.5 45 | self.data = [self.data[i] for i in np.random.permutation(len(self.data))[:math.ceil(len(self.data)*frac)]] 46 | else: 47 | raise ValueError 48 | 49 | if n_shot is not None and n_shot != "full": 50 | all_answers = np.array([a for [_,a] in self.data]) 51 | classes_idx = [np.arange(len(self.data))[np.array(all_answers) == self.answers[i]] for i in range(len(self.answers))] 52 | classes_idx = [np.random.RandomState(seed=n_shot_seed).permutation(i)[:n_shot] for i in classes_idx] 53 | self.data = [self.data[i] for i in np.concatenate(classes_idx)] 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __getitem__(self, idx): 59 | if torch.is_tensor(idx): 60 | idx = idx.tolist() 61 | 62 | [image_name, answer] = self.data[idx] 63 | image_features = self.image_features[image_name] 64 | text_features = self.text_features[self.task] 65 | ques_emb = self.ques_emb[self.task] 66 | 67 | 68 | sample = {'ques_emb': ques_emb, 'image_features': image_features, 'text_features': text_features, "label": self.answers.index(answer)} 69 | 70 | return sample 71 | -------------------------------------------------------------------------------- /data/dataloader/coco_tasks.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import json 4 | import os 5 | import random 6 | 7 | import clip 8 | import numpy as np 9 | import torch 10 | import tqdm 11 | from features.image_features import load_image_features 12 | from slugify import slugify 13 | from torch.utils.data import Dataset 14 | 15 | base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) 16 | device = "cuda" if torch.cuda.is_available() else "cpu" 17 | 18 | def compute_coco_tasks(): 19 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file: 20 | train_data = json.load(file) 21 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file: 22 | test_data = json.load(file) 23 | 24 | data_types = ['val2014', 'train2014'] 25 | coco_data = {} 26 | for dt in data_types: 27 | with open(base_path + f"/data/VQA/Annotations/annotations/instances_{dt}.json") as file: 28 | tmp = json.load(file) 29 | if coco_data == {}: 30 | coco_data = tmp 31 | else: 32 | coco_data['images'] += tmp['images'] 33 | coco_data['annotations'] += tmp['annotations'] 34 | 35 | imgs_test = [] 36 | for task in test_data: 37 | for dataSubType in ["train", "test"]: 38 | for [image_name, _] in test_data[task][dataSubType]: 39 | imgs_test.append(image_name) 40 | 41 | cat_id_to_name = {cat['id']: cat['name'] for cat in coco_data['categories']} 42 | imgs_id_to_name = {i['id']: i['file_name'] for i in coco_data['images']} 43 | 44 | vanilla_coco_categories = {} 45 | for ann in tqdm.tqdm(coco_data["annotations"]): 46 | if ann["area"] > 20000 and not ann["iscrowd"]: 47 | cat = cat_id_to_name[ann["category_id"]] 48 | img_name = imgs_id_to_name[ann['image_id']].split(".")[0] 49 | if img_name not in imgs_test: 50 | if cat not in vanilla_coco_categories.keys(): 51 | vanilla_coco_categories[cat] = [] 52 | if img_name not in vanilla_coco_categories[cat]: 53 | vanilla_coco_categories[cat].append(img_name) 54 | 55 | np.save(base_path+"/data/Attributes/vanilla_coco_categories.npy", vanilla_coco_categories) 56 | 57 | 58 | def 
compute_coco_answer_features(model='ViT-L/14@336px'): 59 | print("Load CLIP {} model".format(model)) 60 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model 61 | clip_model = clip_model.to(device) 62 | 63 | categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item() 64 | 65 | categories_features = {} 66 | for cat in tqdm.tqdm(categories.keys()): 67 | prompt = f"A picture of a {cat}" 68 | prompt = clip.tokenize(prompt).to(device) 69 | 70 | with torch.no_grad(): 71 | text_feature = clip_model.encode_text(prompt) 72 | 73 | text_feature = text_feature.float().cpu().numpy() 74 | 75 | categories_features[cat] = text_feature 76 | 77 | np.save(base_path+f'/data/Attributes/coco_answer_features_{slugify(model)}.npy',categories_features) 78 | 79 | def load_coco_answer_features(model='ViT-L/14@336px'): 80 | coco_answer_features = np.load(base_path+f'/data/Attributes/coco_answer_features_{slugify(model)}.npy', allow_pickle=True).item() 81 | for cat in coco_answer_features.keys(): 82 | coco_answer_features[cat] = torch.from_numpy(coco_answer_features[cat]).to(device) 83 | 84 | return coco_answer_features 85 | 86 | def filter_categories(categories, min_size=5): 87 | filtered_categories = {} 88 | for cat in categories.keys(): 89 | if len(categories[cat]) >= min_size: 90 | filtered_categories[cat] = categories[cat] 91 | return filtered_categories 92 | 93 | class COCO_Tasks(Dataset): 94 | 95 | def __init__(self, categories, dataSubType, image_features, coco_answer_features, n_way=5, train_size=5, test_size=5, task_seed=42): 96 | self.categories = categories 97 | self.image_features = image_features 98 | self.coco_answer_features = coco_answer_features 99 | self.n_way = n_way 100 | self.train_size = train_size 101 | self.test_size = test_size 102 | self.task_seed = task_seed 103 | 104 | random.seed(self.task_seed) 105 | self.answers = random.sample(list(self.categories.keys()), n_way) 106 | self.text_features = torch.cat([self.coco_answer_features[answer] for answer in self.answers], dim=0) 107 | 108 | train_data = [] 109 | test_data = [] 110 | for answer in self.answers: 111 | data_for_answer = random.sample(self.categories[answer], self.train_size+self.test_size) 112 | for i, image_name in enumerate(data_for_answer): 113 | if i < self.train_size: 114 | train_data.append([image_name, answer]) 115 | else: 116 | test_data.append([image_name, answer]) 117 | 118 | if dataSubType == "train": 119 | self.data = train_data 120 | elif dataSubType == "test": 121 | self.data = test_data 122 | elif dataSubType == "traintest": 123 | self.data = train_data + test_data 124 | 125 | random.shuffle(self.data) 126 | 127 | def __len__(self): 128 | return len(self.data) 129 | 130 | def __getitem__(self, idx): 131 | if torch.is_tensor(idx): 132 | idx = idx.tolist() 133 | 134 | [image_name, answer] = self.data[idx] 135 | image_features = self.image_features[image_name] 136 | text_features = self.text_features 137 | 138 | 139 | sample = {'ques_emb': [0.], 'image_features': image_features, 'text_features': text_features, "label": self.answers.index(answer)} 140 | 141 | return sample 142 | 143 | if __name__ == "__main__": 144 | #compute_coco_tasks() 145 | #compute_coco_answer_features() 146 | coco_categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item() 147 | coco_answer_features = load_coco_answer_features(model='ViT-L/14@336px') 148 | image_features = 
load_image_features(model='ViT-L/14@336px') 149 | dataset = COCO_Tasks(coco_categories, "train", image_features, coco_answer_features, n_way=5) 150 | for sample in dataset: 151 | print(sample) 152 | 153 | -------------------------------------------------------------------------------- /data/helper/PythonHelperTools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/helper/PythonHelperTools/generate_metadata.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from vqaTools.vqa import VQA 6 | 7 | dataDir ='../../VQA' 8 | versionType ='v2_' # this should be '' when using VQA v2.0 dataset 9 | taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 10 | dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. 11 | dataSubType ='train2014' 12 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) 13 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) 14 | imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) 15 | 16 | # initialize VQA api for QA annotations 17 | vqa_train=VQA(annFile, quesFile) 18 | 19 | unique_questions_train = {} 20 | 21 | for question_id in vqa_train.qa.keys(): 22 | if vqa_train.qa[question_id]["answer_type"] in ["number","other"]: 23 | ques = vqa_train.qqa[question_id]['question'] 24 | if ques not in unique_questions_train.keys(): 25 | unique_questions_train[ques] = {"count":1,"ids":[question_id]} 26 | else: 27 | unique_questions_train[ques]["count"] +=1 28 | unique_questions_train[ques]["ids"].append(question_id) 29 | 30 | dataSubType ='val2014' 31 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) 32 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) 33 | imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) 34 | 35 | # initialize VQA api for QA annotations 36 | vqa_val=VQA(annFile, quesFile) 37 | 38 | unique_questions_val = {} 39 | 40 | for question_id in vqa_val.qa.keys(): 41 | if vqa_val.qa[question_id]["answer_type"] in ["number","other"]: 42 | ques = vqa_val.qqa[question_id]['question'] 43 | if ques not in unique_questions_val.keys(): 44 | unique_questions_val[ques] = {"count":1,"ids":[question_id]} 45 | else: 46 | unique_questions_val[ques]["count"] +=1 47 | unique_questions_val[ques]["ids"].append(question_id) 48 | 49 | unique_questions_trainval = unique_questions_train 50 | for i, item in unique_questions_val.items(): 51 | if i in unique_questions_train: 52 | unique_questions_trainval[i]['count'] += item['count'] 53 | unique_questions_trainval[i]['ids'] += item['ids'] 54 | else: 55 | unique_questions_trainval[i] = {'count':item['count'], 'ids': item['ids']} 56 | 57 | vqa_trainval_qa = {**vqa_train.qa, **vqa_val.qa} 58 | vqa_trainval_qqa = {**vqa_train.qqa, **vqa_val.qqa} 59 | 60 | unique_questions_trainval_selected ={} 61 | for key, item in unique_questions_trainval.items(): 62 | image_ans = {} 63 | image_list = [] 64 | for i in item["ids"]: 65 | a = vqa_trainval_qa[i]['multiple_choice_answer'] # multiple_choice_answer: most frequent ground-truth answer. 
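        # For each unique question, image_ans maps an answer string to the list of
        # (image_name, question_id) pairs supporting it; each image is used at most
        # once per question, and answers that occur only once are dropped below.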
66 | # answers[a] = answers.get(a,0) + 1 67 | image_id = vqa_trainval_qa[i]['image_id'] 68 | if i in vqa_train.qa: 69 | image_name = 'COCO_train2014_'+ str(image_id).zfill(12) 70 | elif i in vqa_val.qa: 71 | image_name = 'COCO_val2014_'+ str(image_id).zfill(12) 72 | 73 | if image_name not in image_list: # Avoding same image appear for the same task 74 | image_list.append(image_name) 75 | if a not in image_ans: 76 | image_ans[a] = [(image_name,i)] 77 | else: 78 | image_ans[a].append((image_name,i)) 79 | 80 | # delete the answer that only appears once 81 | image_ans = {key:val for key, val in image_ans.items() if len(val) > 1} 82 | answers = {key:len(val) for key, val in image_ans.items()} 83 | unique_questions_trainval_selected[key] = {"count":sum(answers.values()), "answers_count": len(answers), "answers": answers ,"image_ans":image_ans} 84 | 85 | ques_ans_count = {} 86 | ques_image_ans = {} 87 | 88 | count = 0 89 | for key, item in unique_questions_trainval_selected.items(): 90 | cond1 = item["count"] > 20 # question appear at least 20 times 91 | cond2 = item["answers_count"] > 1 # question contains multiple answers 92 | cond3 = "or" not in key.split() # question not in "choose from" form 93 | if cond1 and cond2 and cond3: 94 | ques_ans_count[key] = item["answers"] 95 | ques_image_ans[key] = item["image_ans"] 96 | count += item["count"] 97 | 98 | color = 0 99 | count = 0 100 | for ques in ques_ans_count.keys(): 101 | if 'color' in ques: # question about color 102 | color += 1 103 | if 'How many' in ques: # question about counting 104 | count += 1 105 | 106 | meta_traintest = {} 107 | 108 | for ques, ans in ques_image_ans.items(): 109 | train = [] 110 | test = [] 111 | for a, data in ans.items(): 112 | split = round(len(data) * 0.7) 113 | shuffled_data = data.copy() 114 | random.Random(2021).shuffle(shuffled_data) 115 | train += [(i[0],a) for i in shuffled_data[:split]] 116 | test += [(i[0],a) for i in shuffled_data[split:]] 117 | 118 | meta_traintest[ques] = {"train" : train, "test": test, "answers": list(ans.keys())} 119 | 120 | tasks = list(meta_traintest.keys()) 121 | 122 | meta_train_tasks = [] 123 | meta_test_tasks = [] 124 | split = round(len(tasks) * 0.7) 125 | shuffled_data = tasks.copy() 126 | random.Random(2021).shuffle(shuffled_data) 127 | meta_train_tasks += [i for i in shuffled_data[:split]] 128 | meta_test_tasks += [i for i in shuffled_data[split:]] 129 | 130 | meta_test = {} 131 | 132 | for task in meta_test_tasks: 133 | meta_test[task] = meta_traintest[task] 134 | 135 | with open(os.path.join(dataDir,"Meta/meta_test.json"),"w") as file: 136 | json.dump(meta_test,file) 137 | 138 | meta_train = {} 139 | 140 | for task in meta_train_tasks: 141 | meta_train[task] = meta_traintest[task] 142 | 143 | with open(os.path.join(dataDir,"Meta/meta_train.json"),"w") as file: 144 | json.dump(meta_train,file) 145 | -------------------------------------------------------------------------------- /data/helper/PythonHelperTools/vqaTools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | -------------------------------------------------------------------------------- /data/helper/PythonHelperTools/vqaTools/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | __version__ = '0.9' 3 | 4 | # Interface for accessing the VQA dataset. 
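# Example usage (VQA is instantiated the same way in
# data/helper/PythonHelperTools/generate_metadata.py; annFile and quesFile are
# paths to the VQA v2 annotation and question JSON files):
#   vqa = VQA(annFile, quesFile)
#   ques_ids = vqa.getQuesIds(ansTypes=['number', 'other'])
#   anns = vqa.loadQA(ques_ids)
#   vqa.showQA(anns[:3])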
5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | class VQA: 24 | def __init__(self, annotation_file=None, question_file=None): 25 | """ 26 | Constructor of VQA helper class for reading and visualizing questions and answers. 27 | :param annotation_file (str): location of VQA annotation file 28 | :return: 29 | """ 30 | # load dataset 31 | self.dataset = {} 32 | self.questions = {} 33 | self.qa = {} 34 | self.qqa = {} 35 | self.imgToQA = {} 36 | if not annotation_file == None and not question_file == None: 37 | print('loading VQA annotations and questions into memory...') 38 | time_t = datetime.datetime.utcnow() 39 | dataset = json.load(open(annotation_file, 'r')) 40 | questions = json.load(open(question_file, 'r')) 41 | print(datetime.datetime.utcnow() - time_t) 42 | self.dataset = dataset 43 | self.questions = questions 44 | self.createIndex() 45 | 46 | def createIndex(self): 47 | # create index 48 | print('creating index...') 49 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 50 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 51 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 52 | for ann in self.dataset['annotations']: 53 | imgToQA[ann['image_id']] += [ann] 54 | qa[ann['question_id']] = ann 55 | for ques in self.questions['questions']: 56 | qqa[ques['question_id']] = ques 57 | print('index created!') 58 | 59 | # create class members 60 | self.qa = qa 61 | self.qqa = qqa 62 | self.imgToQA = imgToQA 63 | 64 | def info(self): 65 | """ 66 | Print information about the VQA annotation file. 67 | :return: 68 | """ 69 | for key, value in self.datset['info'].items(): 70 | print('%s: %s'%(key, value)) 71 | 72 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 73 | """ 74 | Get question ids that satisfy given filter conditions. 
default skips that filter 75 | :param imgIds (int array) : get question ids for given imgs 76 | quesTypes (str array) : get question ids for given question types 77 | ansTypes (str array) : get question ids for given answer types 78 | :return: ids (int array) : integer array of question ids 79 | """ 80 | imgIds = imgIds if type(imgIds) == list else [imgIds] 81 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 82 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 83 | 84 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 85 | anns = self.dataset['annotations'] 86 | else: 87 | if not len(imgIds) == 0: 88 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[]) 89 | else: 90 | anns = self.dataset['annotations'] 91 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 92 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 93 | ids = [ann['question_id'] for ann in anns] 94 | return ids 95 | 96 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 97 | """ 98 | Get image ids that satisfy given filter conditions. default skips that filter 99 | :param quesIds (int array) : get image ids for given question ids 100 | quesTypes (str array) : get image ids for given question types 101 | ansTypes (str array) : get image ids for given answer types 102 | :return: ids (int array) : integer array of image ids 103 | """ 104 | quesIds = quesIds if type(quesIds) == list else [quesIds] 105 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 106 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 107 | 108 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 109 | anns = self.dataset['annotations'] 110 | else: 111 | if not len(quesIds) == 0: 112 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[]) 113 | else: 114 | anns = self.dataset['annotations'] 115 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 116 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 117 | ids = [ann['image_id'] for ann in anns] 118 | return ids 119 | 120 | def loadQA(self, ids=[]): 121 | """ 122 | Load questions and answers with the specified question ids. 123 | :param ids (int array) : integer ids specifying question ids 124 | :return: qa (object array) : loaded qa objects 125 | """ 126 | if type(ids) == list: 127 | return [self.qa[id] for id in ids] 128 | elif type(ids) == int: 129 | return [self.qa[ids]] 130 | 131 | def showQA(self, anns): 132 | """ 133 | Display the specified annotations. 134 | :param anns (array of object): annotations to display 135 | :return: None 136 | """ 137 | if len(anns) == 0: 138 | return 0 139 | for ann in anns: 140 | quesId = ann['question_id'] 141 | print("Question: %s" %(self.qqa[quesId]['question'])) 142 | for ans in ann['answers']: 143 | print("Answer %d: %s" %(ans['answer_id'], ans['answer'])) 144 | 145 | def loadRes(self, resFile, quesFile): 146 | """ 147 | Load result file and return a result object. 
148 | :param resFile (str) : file name of result file 149 | :return: res (obj) : result api object 150 | """ 151 | res = VQA() 152 | res.questions = json.load(open(quesFile)) 153 | res.dataset['info'] = copy.deepcopy(self.questions['info']) 154 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) 155 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) 156 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) 157 | res.dataset['license'] = copy.deepcopy(self.questions['license']) 158 | 159 | print('Loading and preparing results... ') 160 | time_t = datetime.datetime.utcnow() 161 | anns = json.load(open(resFile)) 162 | assert type(anns) == list, 'results is not an array of objects' 163 | annsQuesIds = [ann['question_id'] for ann in anns] 164 | # assert set(annsQuesIds) == set(self.getQuesIds()), \ 165 | # 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 166 | for ann in anns: 167 | quesId = ann['question_id'] 168 | if res.dataset['task_type'] == 'Multiple Choice': 169 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices' 170 | qaAnn = self.qa[quesId] 171 | ann['image_id'] = qaAnn['image_id'] 172 | ann['question_type'] = qaAnn['question_type'] 173 | ann['answer_type'] = qaAnn['answer_type'] 174 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())) 175 | 176 | res.dataset['annotations'] = anns 177 | res.createIndex() 178 | return res -------------------------------------------------------------------------------- /data/helper/QuestionTypes/abstract_v002_question_types.txt: -------------------------------------------------------------------------------- 1 | how many 2 | what color is the 3 | is the 4 | where is the 5 | what 6 | what is 7 | are the 8 | what is the 9 | is there a 10 | does the 11 | is the woman 12 | is the man 13 | what is on the 14 | is it 15 | is the girl 16 | is the boy 17 | is the dog 18 | are they 19 | who is 20 | what kind of 21 | what color are the 22 | what is in the 23 | what is the man 24 | is there 25 | what is the woman 26 | what are the 27 | what is the boy 28 | are there 29 | what is the girl 30 | is this 31 | how 32 | which 33 | how many people are 34 | is the cat 35 | why is the 36 | are 37 | will the 38 | what type of 39 | what is the dog 40 | do 41 | is she 42 | does 43 | do the 44 | is 45 | is the baby 46 | are there any 47 | is the lady 48 | can 49 | what animal is 50 | where are the 51 | is the sun 52 | what are they 53 | did the 54 | what is the cat 55 | what is the lady 56 | how many clouds are 57 | is that 58 | is the little girl 59 | is he 60 | are these 61 | how many trees are 62 | how many pillows 63 | are the people 64 | why 65 | is the young 66 | how many windows are 67 | is this a 68 | what is the little 69 | is the tv 70 | how many animals are 71 | who 72 | how many pictures 73 | how many plants are 74 | how many birds are 75 | what color is 76 | what is the baby 77 | is anyone 78 | what color 79 | how many bushes 80 | is the old man 81 | none of the above 82 | -------------------------------------------------------------------------------- /data/helper/QuestionTypes/mscoco_question_types.txt: -------------------------------------------------------------------------------- 1 | how many 2 | is the 3 | what 4 | what 
color is the 5 | what is the 6 | is this 7 | is this a 8 | what is 9 | are the 10 | what kind of 11 | is there a 12 | what type of 13 | is it 14 | what are the 15 | where is the 16 | is there 17 | does the 18 | what color are the 19 | are these 20 | are there 21 | which 22 | is 23 | what is the man 24 | is the man 25 | are 26 | how 27 | does this 28 | what is on the 29 | what does the 30 | how many people are 31 | what is in the 32 | what is this 33 | do 34 | what are 35 | are they 36 | what time 37 | what sport is 38 | are there any 39 | is he 40 | what color is 41 | why 42 | where are the 43 | what color 44 | who is 45 | what animal is 46 | is the woman 47 | is this an 48 | do you 49 | how many people are in 50 | what room is 51 | has 52 | is this person 53 | what is the woman 54 | can you 55 | why is the 56 | is the person 57 | what is the color of the 58 | what is the person 59 | could 60 | was 61 | is that a 62 | what number is 63 | what is the name 64 | what brand 65 | none of the above 66 | -------------------------------------------------------------------------------- /data/helper/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hyperclip 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 11 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0 12 | - asttokens=2.0.5=pyhd3eb1b0_0 13 | - attrs=21.4.0=pyhd3eb1b0_0 14 | - backcall=0.2.0=pyhd3eb1b0_0 15 | - beautifulsoup4=4.11.1=py39h06a4308_0 16 | - blas=1.0=mkl 17 | - bleach=4.1.0=pyhd3eb1b0_0 18 | - blosc=1.21.0=h8c45485_0 19 | - bottleneck=1.3.4=py39hce1f21e_0 20 | - brotli=1.0.9=he6710b0_2 21 | - brotlipy=0.7.0=py39h27cfd23_1003 22 | - brunsli=0.1=h2531618_0 23 | - bzip2=1.0.8=h7b6447c_0 24 | - c-ares=1.18.1=h7f8727e_0 25 | - ca-certificates=2022.6.15=ha878542_0 26 | - certifi=2022.6.15=pyhd8ed1ab_1 27 | - cffi=1.15.0=py39hd667e15_1 28 | - cfitsio=3.470=hf0d0db6_6 29 | - charls=2.2.0=h2531618_0 30 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 31 | - cloudpickle=2.0.0=pyhd3eb1b0_0 32 | - cryptography=36.0.0=py39h9ce1e76_0 33 | - cudatoolkit=11.3.1=ha36c431_9 34 | - cycler=0.11.0=pyhd3eb1b0_0 35 | - cytoolz=0.11.0=py39h27cfd23_0 36 | - dask-core=2022.2.1=pyhd3eb1b0_0 37 | - dbus=1.13.18=hb2f20db_0 38 | - debugpy=1.5.1=py39h295c915_0 39 | - decorator=5.1.1=pyhd3eb1b0_0 40 | - defusedxml=0.7.1=pyhd3eb1b0_0 41 | - einops=0.4.1=pyhd8ed1ab_0 42 | - entrypoints=0.4=py39h06a4308_0 43 | - executing=0.8.3=pyhd3eb1b0_0 44 | - expat=2.4.4=h295c915_0 45 | - ffmpeg=4.2.2=h20bf706_0 46 | - fontconfig=2.13.1=h6c09931_0 47 | - fonttools=4.25.0=pyhd3eb1b0_0 48 | - freetype=2.11.0=h70c0345_0 49 | - fsspec=2022.2.0=pyhd3eb1b0_0 50 | - giflib=5.2.1=h7b6447c_0 51 | - glib=2.69.1=h4ff587b_1 52 | - gmp=6.2.1=h2531618_2 53 | - gnutls=3.6.15=he1e5248_0 54 | - gst-plugins-base=1.14.0=h8213a91_2 55 | - gstreamer=1.14.0=h28cd5cc_2 56 | - icu=58.2=he6710b0_3 57 | - idna=3.3=pyhd3eb1b0_0 58 | - imagecodecs=2021.8.26=py39h4cda21f_0 59 | - imageio=2.9.0=pyhd3eb1b0_0 60 | - intel-openmp=2021.4.0=h06a4308_3561 61 | - ipykernel=6.9.1=py39h06a4308_0 62 | - ipython=8.3.0=py39h06a4308_0 63 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 64 | - ipywidgets=7.6.5=pyhd3eb1b0_1 65 | - jedi=0.18.1=py39h06a4308_1 66 | - 
jinja2=3.0.3=pyhd3eb1b0_0 67 | - joblib=1.1.0=pyhd3eb1b0_0 68 | - jpeg=9e=h7f8727e_0 69 | - jsonschema=4.4.0=py39h06a4308_0 70 | - jupyter_client=7.2.2=py39h06a4308_0 71 | - jupyter_core=4.10.0=py39h06a4308_0 72 | - jupyterlab_pygments=0.1.2=py_0 73 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 74 | - jxrlib=1.1=h7b6447c_2 75 | - kiwisolver=1.3.2=py39h295c915_0 76 | - krb5=1.19.2=hac12032_0 77 | - lame=3.100=h7b6447c_0 78 | - lcms2=2.12=h3be6417_0 79 | - ld_impl_linux-64=2.38=h1181459_1 80 | - lerc=3.0=h295c915_0 81 | - libaec=1.0.4=he6710b0_1 82 | - libcurl=7.82.0=h0b77cf5_0 83 | - libdeflate=1.8=h7f8727e_5 84 | - libedit=3.1.20210910=h7f8727e_0 85 | - libev=4.33=h7f8727e_1 86 | - libffi=3.3=he6710b0_2 87 | - libgcc-ng=11.2.0=h1234567_1 88 | - libgfortran-ng=7.5.0=ha8ba4b0_17 89 | - libgfortran4=7.5.0=ha8ba4b0_17 90 | - libgomp=11.2.0=h1234567_1 91 | - libidn2=2.3.2=h7f8727e_0 92 | - libnghttp2=1.46.0=hce63b2e_0 93 | - libopus=1.3.1=h7b6447c_0 94 | - libpng=1.6.37=hbc83047_0 95 | - libsodium=1.0.18=h7b6447c_0 96 | - libssh2=1.10.0=h8f2d780_0 97 | - libstdcxx-ng=11.2.0=h1234567_1 98 | - libtasn1=4.16.0=h27cfd23_0 99 | - libtiff=4.2.0=h85742a9_0 100 | - libunistring=0.9.10=h27cfd23_0 101 | - libuuid=1.0.3=h7f8727e_2 102 | - libuv=1.40.0=h7b6447c_0 103 | - libvpx=1.7.0=h439df22_0 104 | - libwebp=1.2.2=h55f646e_0 105 | - libwebp-base=1.2.2=h7f8727e_0 106 | - libxcb=1.14=h7b6447c_0 107 | - libxml2=2.9.12=h03d6c58_0 108 | - libzopfli=1.0.3=he6710b0_0 109 | - locket=0.2.1=py39h06a4308_2 110 | - lz4-c=1.9.3=h295c915_1 111 | - markupsafe=2.1.1=py39h7f8727e_0 112 | - matplotlib=3.5.1=py39hf3d152e_0 113 | - matplotlib-base=3.5.1=py39ha18d171_1 114 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 115 | - mistune=0.8.4=py39h27cfd23_1000 116 | - mkl=2021.4.0=h06a4308_640 117 | - mkl-service=2.4.0=py39h7f8727e_0 118 | - mkl_fft=1.3.1=py39hd3c417c_0 119 | - mkl_random=1.2.2=py39h51133e4_0 120 | - munkres=1.1.4=py_0 121 | - nbclient=0.5.13=py39h06a4308_0 122 | - nbconvert=6.4.4=py39h06a4308_0 123 | - nbformat=5.3.0=py39h06a4308_0 124 | - ncurses=6.3=h7f8727e_2 125 | - nest-asyncio=1.5.5=py39h06a4308_0 126 | - nettle=3.7.3=hbbd107a_1 127 | - networkx=2.7.1=pyhd3eb1b0_0 128 | - notebook=6.4.11=py39h06a4308_0 129 | - numexpr=2.8.1=py39h6abb31d_0 130 | - numpy=1.21.5=py39he7a7128_1 131 | - numpy-base=1.21.5=py39hf524024_1 132 | - openh264=2.1.1=h4ff587b_0 133 | - openjpeg=2.4.0=h3ad879b_0 134 | - openssl=1.1.1q=h7f8727e_0 135 | - packaging=21.3=pyhd3eb1b0_0 136 | - pandas=1.4.2=py39h295c915_0 137 | - pandocfilters=1.5.0=pyhd3eb1b0_0 138 | - parso=0.8.3=pyhd3eb1b0_0 139 | - partd=1.2.0=pyhd3eb1b0_1 140 | - pcre=8.45=h295c915_0 141 | - pexpect=4.8.0=pyhd3eb1b0_3 142 | - pickleshare=0.7.5=pyhd3eb1b0_1003 143 | - pillow=9.0.1=py39h22f2fdc_0 144 | - pip=21.2.4=py39h06a4308_0 145 | - prometheus_client=0.13.1=pyhd3eb1b0_0 146 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 147 | - ptyprocess=0.7.0=pyhd3eb1b0_2 148 | - pure_eval=0.2.2=pyhd3eb1b0_0 149 | - pycparser=2.21=pyhd3eb1b0_0 150 | - pygments=2.11.2=pyhd3eb1b0_0 151 | - pyopenssl=22.0.0=pyhd3eb1b0_0 152 | - pyparsing=3.0.4=pyhd3eb1b0_0 153 | - pyqt=5.9.2=py39h2531618_6 154 | - pyrsistent=0.18.0=py39heee7806_0 155 | - pysocks=1.7.1=py39h06a4308_0 156 | - python=3.9.12=h12debd9_1 157 | - python-dateutil=2.8.2=pyhd3eb1b0_0 158 | - python-fastjsonschema=2.15.1=pyhd3eb1b0_0 159 | - python-slugify=6.1.2=pyhd8ed1ab_0 160 | - python_abi=3.9=2_cp39 161 | - pytorch=1.12.0=py3.9_cuda11.3_cudnn8.3.2_0 162 | - pytorch-mutex=1.0=cuda 163 | - pytz=2021.3=pyhd3eb1b0_0 164 | - 
pywavelets=1.3.0=py39h7f8727e_0 165 | - pyyaml=6.0=py39h7f8727e_1 166 | - pyzmq=22.3.0=py39h295c915_2 167 | - qt=5.9.7=h5867ecd_1 168 | - readline=8.1.2=h7f8727e_1 169 | - requests=2.27.1=pyhd3eb1b0_0 170 | - scikit-image=0.19.2=py39h51133e4_0 171 | - scikit-learn=1.0.2=py39h51133e4_1 172 | - scipy=1.7.3=py39hc147768_0 173 | - send2trash=1.8.0=pyhd3eb1b0_1 174 | - setuptools=61.2.0=py39h06a4308_0 175 | - sip=4.19.13=py39h295c915_0 176 | - six=1.16.0=pyhd3eb1b0_1 177 | - snappy=1.1.9=h295c915_0 178 | - soupsieve=2.3.1=pyhd3eb1b0_0 179 | - sqlite=3.38.3=hc218d9a_0 180 | - stack_data=0.2.0=pyhd3eb1b0_0 181 | - terminado=0.13.1=py39h06a4308_0 182 | - testpath=0.5.0=pyhd3eb1b0_0 183 | - text-unidecode=1.3=pyhd3eb1b0_0 184 | - threadpoolctl=2.2.0=pyh0d69192_0 185 | - tifffile=2021.7.2=pyhd3eb1b0_2 186 | - tk=8.6.12=h1ccaba5_0 187 | - toolz=0.11.2=pyhd3eb1b0_0 188 | - tornado=6.1=py39h27cfd23_0 189 | - traitlets=5.1.1=pyhd3eb1b0_0 190 | - typing-extensions=4.1.1=hd3eb1b0_0 191 | - typing_extensions=4.1.1=pyh06a4308_0 192 | - tzdata=2022a=hda174b7_0 193 | - unidecode=1.2.0=pyhd3eb1b0_0 194 | - urllib3=1.26.9=py39h06a4308_0 195 | - wcwidth=0.2.5=pyhd3eb1b0_0 196 | - webencodings=0.5.1=py39h06a4308_1 197 | - wheel=0.37.1=pyhd3eb1b0_0 198 | - widgetsnbextension=3.5.2=py39h06a4308_0 199 | - x264=1!157.20191217=h7b6447c_0 200 | - xz=5.2.5=h7f8727e_1 201 | - yacs=0.1.8=pyhd8ed1ab_0 202 | - yaml=0.2.5=h7b6447c_0 203 | - zeromq=4.3.4=h2531618_0 204 | - zfp=0.5.5=h295c915_6 205 | - zlib=1.2.12=h7f8727e_2 206 | - zstd=1.4.9=haebb681_0 207 | - pip: 208 | - blobfile==1.3.3 209 | - click==8.1.2 210 | - docker-pycreds==0.4.0 211 | - filelock==3.8.0 212 | - ftfy==6.1.1 213 | - gitdb==4.0.9 214 | - gitpython==3.1.27 215 | - h5py==3.7.0 216 | - ordered-set==4.1.0 217 | - pathtools==0.1.2 218 | - promise==2.3 219 | - protobuf==3.19.4 220 | - psutil==5.9.0 221 | - pycryptodomex==3.15.0 222 | - regex==2022.4.24 223 | - sentry-sdk==1.5.10 224 | - setproctitle==1.2.3 225 | - shortuuid==1.0.8 226 | - smmap==5.0.0 227 | - torchmeta==1.8.0 228 | - torchvision==0.10.1 229 | - tqdm==4.64.0 230 | - wandb==0.12.15 231 | - xmltodict==0.12.0 232 | -------------------------------------------------------------------------------- /features/image_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import clip 4 | import numpy as np 5 | import torch 6 | import tqdm 7 | from PIL import Image 8 | from slugify import slugify 9 | from torchvision.transforms import Compose, Normalize, Resize, ToTensor 10 | from utils import clip_utils 11 | 12 | try: 13 | from torchvision.transforms import InterpolationMode 14 | BICUBIC = InterpolationMode.BICUBIC 15 | except ImportError: 16 | BICUBIC = Image.BICUBIC 17 | 18 | base_path = os.path.dirname(os.path.dirname(__file__)) 19 | device = "cuda" if torch.cuda.is_available() else "cpu" 20 | 21 | def _convert_image_to_rgb(image): 22 | return image.convert("RGB") 23 | 24 | def compute_image_features(model): 25 | 26 | print("Load CLIP {} model".format(model)) 27 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model 28 | clip_model = clip_model.to(device) 29 | image_resolution = clip_utils.image_resolution[model] 30 | preprocess = Compose([Resize((image_resolution,image_resolution), interpolation=BICUBIC), 31 | _convert_image_to_rgb, 32 | ToTensor(), 33 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 34 | ]) 35 | 36 | img_features = {} 37 | 38 
| for data_subtype in ['train2014', 'val2014']: 39 | 40 | img_dir = base_path + f'/data/VQA/Images/{data_subtype}/' 41 | 42 | for filename in tqdm.tqdm(os.listdir(img_dir)): 43 | 44 | img_name = filename.replace(".jpg", "") 45 | image = preprocess(Image.open(img_dir + filename)).unsqueeze(0).to(device) 46 | 47 | with torch.no_grad(): 48 | features = clip_model.encode_image(image) 49 | 50 | img_features[img_name] = features.float().cpu().numpy() 51 | 52 | np.save(base_path+f'/data/VQA/Features/ImageFeatures/image_features_{slugify(model)}.npy',img_features) 53 | 54 | def load_image_features(model): 55 | img_features = np.load(base_path + f"/data/VQA/Features/ImageFeatures/image_features_{slugify(model)}.npy", allow_pickle=True).item() 56 | 57 | for img_name in img_features.keys(): 58 | img_features[img_name] = torch.from_numpy(img_features[img_name]).to(device) 59 | 60 | return img_features 61 | -------------------------------------------------------------------------------- /features/ques_features.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import clip 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | from slugify import slugify 9 | 10 | base_path = os.path.dirname(os.path.dirname(__file__)) 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | 13 | def compute_ques_features(model): 14 | print("Load CLIP {} model".format(model)) 15 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model 16 | clip_model = clip_model.to(device) 17 | 18 | with open(base_path+"/data/VQA/Meta/ques_ans_count.json") as file: 19 | qa = json.load(file) 20 | 21 | ques_features = {} 22 | for ques in tqdm.tqdm(qa.keys()): 23 | text = clip.tokenize(ques).to(device) 24 | 25 | with torch.no_grad(): 26 | features = clip_model.encode_text(text) 27 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True) 28 | 29 | features = features.float().cpu().numpy() #np.squeeze(features.float().cpu().numpy()) 30 | 31 | ques_features[ques] = features 32 | 33 | # with open(base_path+"/data/VQA/Features/TextFeatures/text_features.json", "w") as outfile: 34 | # json.dump(text_features_val, outfile) 35 | np.save(base_path+f'/data/VQA/Features/QuesFeatures/ques_features_{slugify(model)}.npy',ques_features) 36 | 37 | def load_ques_features(model): 38 | ques_features = np.load(base_path + f"/data/VQA/Features/QuesFeatures/ques_features_{slugify(model)}.npy", allow_pickle=True).item() 39 | 40 | for ques in ques_features.keys(): 41 | ques_features[ques] = torch.from_numpy(ques_features[ques]).to(device) 42 | 43 | return ques_features 44 | -------------------------------------------------------------------------------- /features/text_features.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import clip 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | from slugify import slugify 9 | 10 | base_path = os.path.dirname(os.path.dirname(__file__)) 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | 13 | def compute_text_features(model): 14 | print("Load CLIP {} model".format(model)) 15 | clip_model, _ = clip.load(model, device=torch.device("cpu"), jit=False) # load cpu version to get float32 model 16 | clip_model = clip_model.to(device) 17 | 18 | with open(base_path+"/data/VQA/Meta/ques_ans_count.json") as file: 19 | qa = json.load(file) 20 | 21 | with 
open(base_path+"/data/VQA/Annotations/prompt_meta.json") as file: 22 | prompt_meta = json.load(file) 23 | 24 | text_features_meta = {} 25 | for ques in tqdm.tqdm(qa.keys()): 26 | temp = prompt_meta[ques] 27 | answers = list(qa[ques].keys()) 28 | prompts = [temp.format(a.replace("_", " ")) for a in answers] 29 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 30 | prompts = prompts.to(device) 31 | 32 | with torch.no_grad(): 33 | text_features = clip_model.encode_text(prompts) 34 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True) 35 | 36 | text_features = text_features.float().cpu().numpy() 37 | 38 | text_features_meta[ques] = text_features 39 | 40 | # with open(base_path+"/data/VQA/Features/TextFeatures/text_features.json", "w") as outfile: 41 | # json.dump(text_features_val, outfile) 42 | np.save(base_path+f'/data/VQA/Features/TextFeatures/text_features_{slugify(model)}.npy',text_features_meta) 43 | 44 | def load_text_features(model): 45 | text_features = np.load(base_path + f"/data/VQA/Features/TextFeatures/text_features_{slugify(model)}.npy", allow_pickle=True).item() 46 | 47 | for ques in text_features.keys(): 48 | text_features[ques] = torch.from_numpy(text_features[ques]).to(device) 49 | 50 | return text_features 51 | -------------------------------------------------------------------------------- /model/custom_hnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from torchmeta.modules import MetaLinear, MetaModule, MetaSequential 9 | from utils import clip_utils 10 | 11 | 12 | def get_from_dict_or_default(module, dict, key): 13 | if key in dict: 14 | return dict[key] 15 | return getattr(module, key) 16 | 17 | class Mlp(MetaModule): 18 | def __init__(self, input_dim, hidden_dims, output_dim, nonlin='relu'): 19 | super().__init__() 20 | self.input_dim=input_dim 21 | self.hidden_dims=hidden_dims 22 | self.output_dim=output_dim 23 | 24 | linear = [] 25 | prev_layer_dim = self.input_dim 26 | for dim in hidden_dims: 27 | linear.append(MetaLinear(prev_layer_dim, dim)) 28 | torch.nn.init.kaiming_normal_(linear[-1].weight, nonlinearity="relu") 29 | linear[-1].bias.data *= 0 30 | if nonlin == 'relu': 31 | linear.append(nn.ReLU()) 32 | elif nonlin == 'tanh': 33 | linear.append(nn.Tanh()) 34 | elif nonlin == 'softplus': 35 | linear.append(nn.Softplus()) 36 | else: 37 | assert False 38 | 39 | prev_layer_dim = dim 40 | 41 | linear.append(MetaLinear(prev_layer_dim, output_dim)) 42 | torch.nn.init.kaiming_normal_(linear[-1].weight, nonlinearity="linear") 43 | linear[-1].bias.data *= 0 44 | self.mlp = MetaSequential(*linear) 45 | 46 | def forward(self, uncond_input, params = None): 47 | if len(uncond_input.shape) == 1: 48 | uncond_input = uncond_input.unsqueeze(0) 49 | return self.mlp(uncond_input, params = self.get_subdict(params, "mlp")) 50 | 51 | class HyperGenerator(Mlp): 52 | def __init__(self, mnet, e_dim, hidden_dims, normalize=False): 53 | super().__init__(e_dim, hidden_dims, mnet.get_parameter_vector().shape[0]) 54 | self.normalize=normalize 55 | if self.normalize: 56 | self.register_buffer("feature_mean", torch.zeros(self.output_dim)) 57 | self.register_buffer("feature_std", torch.ones(self.output_dim)) 58 | 59 | def set_stats(self, mean, std): 60 | self.feature_mean.data = mean.detach() 61 | self.feature_std.data = std.detach() 62 | 63 | def 
forward(self, uncond_input, params = None): 64 | res = super().forward(uncond_input, params) 65 | if self.normalize: 66 | res = res * self.feature_std 67 | res = res + self.feature_mean 68 | return res 69 | 70 | 71 | class HyperEncoder(Mlp): 72 | def __init__(self, mnet, e_dim, hidden_dims, normalize=False): 73 | super().__init__(mnet.get_parameter_vector().shape[0], hidden_dims, 2*e_dim) 74 | self.normalize=normalize 75 | if self.normalize: 76 | self.register_buffer("feature_mean", torch.zeros(self.input_dim)) 77 | self.register_buffer("feature_std", torch.ones(self.input_dim)) 78 | 79 | def set_stats(self, mean, std): 80 | self.feature_mean.data = mean.detach() 81 | self.feature_std.data = std.detach() 82 | 83 | def forward(self, uncond_input, params = None): 84 | if self.normalize: 85 | uncond_input = uncond_input - self.feature_mean 86 | uncond_input = uncond_input / self.feature_std 87 | return super().forward(uncond_input, params) 88 | 89 | 90 | class HyperDiscriminator(Mlp): 91 | def __init__(self, mnet, hidden_dims): 92 | super().__init__(mnet.get_parameter_vector().shape[0], hidden_dims, 1) 93 | 94 | class CLIPAdapter(MetaModule): 95 | def __init__(self, e_dim, hidden_layers, use_bias, no_weights=False, straight_through=False, ignore_passed_weights=False): 96 | super().__init__() 97 | assert len(hidden_layers) == 1, "Architecture supports a single hidden layer." 98 | hidden_size = hidden_layers[0] 99 | self.e_dim=e_dim 100 | self.hidden_size=hidden_size 101 | self.use_bias=use_bias 102 | self.no_weights=no_weights 103 | self.no_weight=no_weights 104 | self.straight_through=straight_through 105 | self.ignore_passed_weights = ignore_passed_weights 106 | 107 | if no_weights: 108 | self.register_buffer("W1", torch.randn(hidden_size, e_dim).requires_grad_()) 109 | self.register_buffer("b1", torch.randn(hidden_size).requires_grad_() if use_bias else None) 110 | self.register_buffer("W2", torch.randn(e_dim, hidden_size).requires_grad_()) 111 | self.register_buffer("b2", torch.randn(e_dim).requires_grad_() if use_bias else None) 112 | else: 113 | norm_W1=math.sqrt(self.e_dim) if self.straight_through else 1 114 | norm_W2=math.sqrt(self.hidden_size) if self.straight_through else 1 115 | 116 | self.W1=torch.nn.Parameter((torch.randn(hidden_size, e_dim)/norm_W1).requires_grad_()) 117 | self.b1=torch.nn.Parameter(torch.randn(hidden_size).requires_grad_()) if use_bias else None 118 | self.W2=torch.nn.Parameter((torch.randn(e_dim, hidden_size)/norm_W2).requires_grad_()) 119 | self.b2=torch.nn.Parameter(torch.randn(e_dim).requires_grad_()) if use_bias else None 120 | 121 | def get_parameter_vector(self, params=None): 122 | if params is not None: 123 | return torch.cat([params["W1"].flatten(), params["W2"].flatten()] + [params["b1"], params["b2"]] if self.use_bias else []) 124 | return torch.cat([self.W1.flatten(), self.W2.flatten()] + [self.b1, self.b2] if self.use_bias else []) 125 | 126 | def get_gradient_vector(self): 127 | return torch.cat([self.W1.grad.flatten(), self.W2.grad.flatten()] + [self.b1.grad, self.b2.grad] if self.use_bias else []) 128 | 129 | def load_from_vector(self, vector): 130 | 131 | def get_last_elements(v, shape): 132 | numel = np.prod(shape) 133 | res = v[-numel:] 134 | v = v[:-numel] 135 | return res.reshape(*shape), v 136 | 137 | params = OrderedDict() 138 | vector = vector.flatten() 139 | if self.use_bias: 140 | params["b2"], vector = get_last_elements(vector, [self.e_dim]) 141 | params["b1"], vector = get_last_elements(vector, [self.hidden_size]) 142 | params["W2"], 
vector = get_last_elements(vector, [self.e_dim, self.hidden_size]) 143 | params["W1"], vector = get_last_elements(vector, [self.hidden_size, self.e_dim]) 144 | assert len(vector) == 0 145 | 146 | return params 147 | 148 | def forward(self, image_features, text_features, weights=None, params=None): 149 | param_dict = OrderedDict(params) if params is not None else OrderedDict() 150 | 151 | if weights is not None and not self.ignore_passed_weights: 152 | param_dict.update(self.load_from_vector(weights)) 153 | 154 | W1 = get_from_dict_or_default(self, param_dict, "W1") 155 | W2 = get_from_dict_or_default(self, param_dict, "W2") 156 | b1 = get_from_dict_or_default(self, param_dict, "b1") 157 | b2 = get_from_dict_or_default(self, param_dict, "b2") 158 | 159 | normalized_W1 = W1 160 | normalized_W2 = W2 161 | if not self.straight_through: 162 | normalized_W1 = W1/math.sqrt(self.e_dim) 163 | normalized_W2 = W2/math.sqrt(self.hidden_size) 164 | 165 | identity = image_features 166 | out = F.linear(image_features, normalized_W1, b1) 167 | out = F.relu(out) 168 | out = F.linear(out, normalized_W2, b2) 169 | out += identity 170 | adapted_image_features = out / out.norm(dim=-1, keepdim=True) 171 | 172 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 173 | text_features = text_features.to(torch.float32) 174 | logits = 100.0 * adapted_image_features @ torch.transpose(text_features, 1, 2) 175 | logits = torch.squeeze(logits,1) 176 | 177 | return logits 178 | 179 | class EmbeddingModule(torch.nn.Module): 180 | def __init__(self, input_dim): 181 | super().__init__() 182 | self.embedding = torch.nn.Parameter(torch.randn([1, input_dim])) 183 | 184 | def reset(self, embedding=None): 185 | self.embedding.data *= 0 186 | if embedding is None: 187 | embedding = torch.randn_like(self.embedding) 188 | self.embedding.data += embedding 189 | 190 | 191 | def forward(self, params=None): 192 | param_dict = OrderedDict(params) if params is not None else OrderedDict() 193 | embedding = get_from_dict_or_default(self, param_dict, "embedding") 194 | 195 | return embedding 196 | 197 | class MetaModel(MetaModule): 198 | def _init(self, mnet=None, hnet=None, enet=None, alpha=1): 199 | super().__init__() 200 | self.mnet = mnet 201 | self.hnet = hnet 202 | self.enet = enet 203 | self.alpha=alpha 204 | if self.hnet is not None: 205 | if self.enet is not None: 206 | self.meta_params = list(self.hnet.parameters())+list(self.enet.parameters()) 207 | self.inner_module = self.enet 208 | else: 209 | self.meta_params = list(self.hnet.parameters()) 210 | self.inner_module = self.hnet 211 | else: 212 | self.meta_params = list(self.mnet.parameters()) 213 | self.inner_module = self.mnet 214 | 215 | self.inner_params= self._get_inner_params() 216 | 217 | def __init__(self, mnet=None, hnet=None, enet=None, inner_param=None, mainnet_use_bias=None, mainnet_hidden_dim=None, hypernet_hidden_dim=None, embedding_dim=None, straight_through=False, config=None): 218 | super().__init__() 219 | if "alpha" not in config: 220 | alpha=1 221 | else: 222 | alpha = config["alpha"] 223 | 224 | if mnet is not None or hnet is not None or enet is not None: 225 | self._init(mnet, hnet, enet, alpha) 226 | if inner_param == "enet": 227 | self.mnet = CLIPAdapter(clip_utils.embedding_size[config["clip_model"]], 228 | mainnet_hidden_dim, use_bias=mainnet_use_bias, no_weights=True, straight_through=straight_through) 229 | self.hnet = HyperGenerator(self.mnet, e_dim=embedding_dim, 230 | hidden_dims=hypernet_hidden_dim, ) 231 | self.enet = 
EmbeddingModule(embedding_dim) 232 | self.meta_params = list(self.hnet.parameters())+list(self.enet.parameters()) 233 | self.inner_module = self.enet 234 | 235 | elif inner_param == "hnet": 236 | self.mnet = CLIPAdapter(clip_utils.embedding_size[config["clip_model"]], 237 | mainnet_hidden_dim, use_bias=mainnet_use_bias, no_weights=True, straight_through=straight_through) 238 | self.hnet = HyperGenerator(self.mnet, e_dim=clip_utils.embedding_size[config["clip_model"]], 239 | hidden_dims=hypernet_hidden_dim, ) 240 | self.enet = None 241 | self.meta_params = list(self.hnet.parameters()) 242 | self.inner_module = self.hnet 243 | 244 | elif inner_param == "mnet": 245 | self.mnet = CLIPAdapter(clip_utils.embedding_size[config["clip_model"]], 246 | mainnet_hidden_dim, use_bias=mainnet_use_bias, no_weights=False, straight_through=straight_through) 247 | self.hnet = None 248 | self.enet = None 249 | self.meta_params = list(self.mnet.parameters()) 250 | self.inner_module = self.mnet 251 | self.alpha=alpha 252 | self.inner_params= self._get_inner_params() 253 | 254 | def _get_inner_params(self): 255 | params = OrderedDict() 256 | for (name, param) in self.named_parameters(): 257 | if any([id(param) == id(b) for b in self.inner_module.parameters()]): 258 | params[name] = param 259 | return params 260 | 261 | def get_inner_params(self): 262 | return self.inner_params 263 | 264 | def _forward(self, sample_image_features, sample_text_features, sample_ques_emb, params=None): 265 | if self.hnet is not None: 266 | if self.enet is None: 267 | weights = self.hnet.forward(uncond_input=sample_ques_emb, params=self.get_subdict(params, "hnet")) 268 | else: 269 | weights = self.hnet.forward(uncond_input=self.enet(params=self.get_subdict(params, "enet"))) 270 | similarity = self.mnet(sample_image_features, sample_text_features, weights=weights) 271 | else: 272 | similarity = self.mnet(sample_image_features, sample_text_features, params=self.get_subdict(params, "mnet")) 273 | 274 | return similarity 275 | 276 | def forward(self, *args, params=None, **kwargs): 277 | if self.alpha == 1: 278 | return self._forward( *args, params=params, **kwargs) 279 | init_output = self._forward( *args, **kwargs) 280 | adapted_output = self._forward( *args, params = params, **kwargs) 281 | return self.alpha * (adapted_output - init_output) + init_output 282 | 283 | def get_mainnet_weights(self, ques_emb = None, params=None): 284 | if self.hnet is not None: 285 | if self.enet is None: 286 | return self.hnet.forward(uncond_input=ques_emb, params=self.get_subdict(params, "hnet")) 287 | else: 288 | return self.hnet.forward(uncond_input=self.enet(params=self.get_subdict(params, "enet"))) 289 | else: 290 | return self.mnet.get_parameter_vector(params=self.get_subdict(params, "mnet")) 291 | -------------------------------------------------------------------------------- /model/hyperclip.py: -------------------------------------------------------------------------------- 1 | # Credit: modification of code from https://github.com/AndreyGuzhov/AudioCLIP 2 | 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from clip.model import CLIP 8 | 9 | from model.custom_hnet import Mlp 10 | 11 | ClipFeatures = Tuple[ 12 | Optional[torch.Tensor], # text 13 | Optional[torch.Tensor], # image 14 | Optional[torch.Tensor] # hyper 15 | ] 16 | 17 | 18 | ClipLogits = Tuple[ 19 | Optional[torch.Tensor], # hyper x image 20 | Optional[torch.Tensor], # hyper x text 21 | Optional[torch.Tensor] # image x 
text 22 | ] 23 | 24 | 25 | ClipOutput = Tuple[ 26 | Tuple[ClipFeatures, ClipLogits], 27 | Optional[torch.Tensor] # loss 28 | ] 29 | 30 | class HyperCLIP(CLIP): 31 | 32 | def __init__(self, 33 | embed_dim: int = 1024, 34 | # vision 35 | image_resolution: int = 224, 36 | vision_layers: Union[Tuple[int, int, int, int], int] = (3, 4, 6, 3), 37 | vision_width: int = 64, 38 | vision_patch_size: Optional[int] = None, 39 | # text 40 | context_length: int = 77, 41 | vocab_size: int = 49408, 42 | transformer_width: int = 512, 43 | transformer_heads: int = 8, 44 | transformer_layers: int = 12, 45 | # hyper 46 | hyper_model: str = "mlp", # choose between "mlp" and "embedder_hypernet" 47 | mainnet_param_count: int = 1024*256*2, 48 | hyper_hidden_dims: List[int] = [512], 49 | # pretrained model 50 | pretrained_it_location: Optional[str] = None, 51 | pretrained_hyper_location: Optional[str] = None): 52 | 53 | super(HyperCLIP, self).__init__( 54 | embed_dim=embed_dim, 55 | image_resolution=image_resolution, 56 | vision_layers=vision_layers, 57 | vision_width=vision_width, 58 | vision_patch_size=vision_patch_size, 59 | context_length=context_length, 60 | vocab_size=vocab_size, 61 | transformer_width=transformer_width, 62 | transformer_heads=transformer_heads, 63 | transformer_layers=transformer_layers 64 | ) 65 | 66 | self.embed_dim = embed_dim 67 | self.hyper_model = hyper_model 68 | 69 | self.pretrained_it_location = pretrained_it_location 70 | self.pretrained_hyper_location = pretrained_hyper_location 71 | 72 | self.logit_scale_hi = torch.nn.Parameter(torch.log(torch.ones([]) * 100)) 73 | self.logit_scale_ht = torch.nn.Parameter(torch.log(torch.ones([]) * 100)) 74 | 75 | if pretrained_it_location is not None: 76 | self.load_state_dict(torch.jit.load(self.pretrained_it_location, map_location='cpu').state_dict(), strict=False) 77 | print('Image & Text weights loaded') 78 | 79 | if self.hyper_model == "mlp": 80 | self.hyper = Mlp(mainnet_param_count, hyper_hidden_dims, embed_dim, 'softplus') 81 | else: 82 | raise ValueError(f"Unsupported hyper model {self.hyper_model}") 83 | 84 | if pretrained_hyper_location is not None: 85 | self.hyper.load_state_dict(torch.load(self.pretrained_hyper_location, map_location='cpu'), strict=False) 86 | print('Hyper weights loaded') 87 | 88 | @property 89 | def device(self): 90 | return self.visual.conv1.weight.device 91 | 92 | def encode_hyper(self, weights: torch.Tensor) -> torch.Tensor: 93 | return self.hyper(weights.to(self.device)) 94 | 95 | def forward(self, 96 | hyper: Optional[torch.Tensor] = None, 97 | image: Optional[torch.Tensor] = None, 98 | text: Optional[Union[List[List[str]],torch.Tensor]] = None, 99 | batch_indices: Optional[torch.Tensor] = None, 100 | # precomputed embeddings 101 | precomputed_it_embs: bool = False) -> ClipOutput: 102 | 103 | hyper_features = None 104 | image_features = None 105 | text_features = None 106 | sample_weights = None 107 | 108 | if hyper is not None: 109 | hyper_features = self.encode_hyper(hyper) 110 | hyper_features = hyper_features / hyper_features.norm(dim=-1, keepdim=True) 111 | 112 | if image is not None: 113 | if precomputed_it_embs: 114 | image_features = image 115 | else: 116 | image_features = self.encode_image(image) 117 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 118 | 119 | if text is not None: 120 | if precomputed_it_embs: 121 | text_features = text 122 | else: 123 | if batch_indices is None: 124 | batch_indices = torch.arange(len(text), dtype=torch.int64, device=self.device) 125 | 
text_features = self.encode_text(text, '{}', batch_indices) 126 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 127 | 128 | if hasattr(self, 'class_weights') and hasattr(self, 'label_to_class_idx'): 129 | sample_weights = torch.stack([ 130 | sum(self.class_weights[self.label_to_class_idx[label]] for label in entities) 131 | for idx, entities in enumerate(text) if idx in batch_indices 132 | ]) 133 | 134 | features: ClipFeatures = (hyper_features, image_features, text_features) 135 | 136 | logit_scale_hi = torch.clamp(self.logit_scale_hi.exp(), min=1.0, max=100.0) 137 | logit_scale_ht = torch.clamp(self.logit_scale_ht.exp(), min=1.0, max=100.0) 138 | logit_scale_it = torch.clamp(self.logit_scale.exp(), min=1.0, max=100.0) 139 | 140 | logits_hyper_image = None 141 | logits_hyper_text = None 142 | logits_image_text = None 143 | 144 | if (hyper_features is not None) and (image_features is not None): 145 | logits_hyper_image = logit_scale_hi * hyper_features @ image_features.T 146 | 147 | if (hyper_features is not None) and (text_features is not None): 148 | logits_hyper_text = logit_scale_ht * hyper_features @ text_features.T 149 | 150 | if (image_features is not None) and (text_features is not None): 151 | logits_image_text = logit_scale_it * image_features @ text_features.T 152 | 153 | logits: ClipLogits = (logits_hyper_image, logits_hyper_text, logits_image_text) 154 | 155 | loss = self.loss_fn(logits, sample_weights) 156 | 157 | return (features, logits), loss 158 | 159 | def loss_fn(self, logits: ClipLogits, sample_weights: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: 160 | logits_hyper_image, logits_hyper_text, logits_image_text = logits 161 | 162 | if logits_hyper_image is not None: 163 | batch_size = logits_hyper_image.shape[0] 164 | elif logits_hyper_text is not None: 165 | batch_size = logits_hyper_text.shape[0] 166 | elif logits_image_text is not None: 167 | batch_size = logits_image_text.shape[0] 168 | else: 169 | return None 170 | 171 | reference = torch.arange( 172 | batch_size, 173 | dtype=torch.int64, 174 | device=self.device 175 | ) 176 | 177 | loss = torch.tensor(0.0, dtype=self.dtype, device=self.device) 178 | 179 | num_modalities: int = 0 180 | scale = torch.tensor(1.0, dtype=self.dtype, device=self.device) 181 | 182 | if logits_hyper_image is not None: 183 | loss_hi = F.cross_entropy( 184 | logits_hyper_image, reference, weight=sample_weights 185 | ) + F.cross_entropy( 186 | logits_hyper_image.transpose(-1, -2), reference, weight=sample_weights 187 | ) 188 | loss = loss + loss_hi 189 | num_modalities += 1 190 | 191 | if logits_hyper_text is not None: 192 | loss_ht = F.cross_entropy( 193 | logits_hyper_text, reference, weight=sample_weights 194 | ) + F.cross_entropy( 195 | logits_hyper_text.transpose(-1, -2), reference, weight=sample_weights 196 | ) 197 | loss = loss + loss_ht 198 | num_modalities += 1 199 | 200 | if logits_image_text is not None: 201 | loss_it = F.cross_entropy( 202 | logits_image_text, reference, weight=sample_weights 203 | ) + F.cross_entropy( 204 | logits_image_text.transpose(-1, -2), reference, weight=sample_weights 205 | ) 206 | loss = loss + loss_it 207 | num_modalities += 1 208 | 209 | for idx in range(num_modalities): 210 | scale = scale * (idx + 1) 211 | 212 | return loss / scale 213 | 214 | @property 215 | def loss_fn_name(self) -> str: 216 | return 'Cross Entropy' 217 | 218 | 219 | def build_hyperclip_from_classic_clip(state_dict: Union[dict, str], 220 | hyper_model: str = "mlp", 221 | mainnet_param_count: 
int = 2014*256*2, 222 | hyper_hidden_dims: List[int] = [512], 223 | pretrained_it_location: Optional[str] = None, 224 | pretrained_hyper_location: Optional[str] = None): 225 | 226 | if isinstance(state_dict, str): 227 | state_dict = torch.jit.load(state_dict, map_location='cpu').state_dict() 228 | 229 | vit = "visual.proj" in state_dict 230 | 231 | if vit: 232 | vision_width = state_dict["visual.conv1.weight"].shape[0] 233 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 234 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 235 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 236 | image_resolution = vision_patch_size * grid_size 237 | else: 238 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 239 | vision_layers = tuple(counts) 240 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 241 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 242 | vision_patch_size = None 243 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 244 | image_resolution = output_width * 32 245 | 246 | embed_dim = state_dict["text_projection"].shape[1] 247 | context_length = state_dict["positional_embedding"].shape[0] 248 | vocab_size = state_dict["token_embedding.weight"].shape[0] 249 | transformer_width = state_dict["ln_final.weight"].shape[0] 250 | transformer_heads = transformer_width // 64 251 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 252 | 253 | model = HyperCLIP( 254 | embed_dim, 255 | image_resolution, vision_layers, vision_width, vision_patch_size, 256 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, 257 | hyper_model, mainnet_param_count, hyper_hidden_dims, pretrained_it_location, pretrained_hyper_location 258 | ) 259 | 260 | return model 261 | -------------------------------------------------------------------------------- /model/latent_diffuser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.diffusion_utils import timestep_embedding 4 | 5 | 6 | class FeedForwardBlock(nn.Module): 7 | def __init__(self, input_dim, output_dim): 8 | super().__init__() 9 | self.input_dim = input_dim 10 | self.output_dim = output_dim 11 | self.layers = nn.Sequential(nn.Linear(input_dim, output_dim), 12 | nn.LayerNorm(output_dim), 13 | nn.ReLU()) 14 | 15 | def forward(self, x): 16 | return self.layers(x) 17 | 18 | 19 | class LatentDiffuser(nn.Module): 20 | def __init__(self, x_dim=128, cond_emb_dim=768, timestep_emb_dim=100, hidden_dims=[128,128]): 21 | super().__init__() 22 | self.x_dim = x_dim 23 | self.cond_emb_dim = cond_emb_dim 24 | self.timestep_emb_dim = timestep_emb_dim 25 | self.hidden_dims = hidden_dims 26 | layers = [FeedForwardBlock(x_dim + cond_emb_dim + timestep_emb_dim, hidden_dims[0])] 27 | for i, h in enumerate(hidden_dims[1:]): 28 | layers += [FeedForwardBlock(hidden_dims[i], h)] 29 | layers += [nn.Linear(hidden_dims[-1], x_dim)] 30 | self.layers = nn.ModuleList(layers) 31 | 32 | def forward(self, x, t, cond_emb): 33 | if type(t) == int: 34 | t = torch.tensor(t, device = x.get_device()).repeat((x.shape[0], 1)) 35 | # right now we only support basic concat of cond_emb and timestep_emb to input x 36 | timestep_emb = 
timestep_embedding(t.flatten(), self.timestep_emb_dim) 37 | eps = torch.cat((x, cond_emb, timestep_emb), 1) 38 | for layer in self.layers: 39 | eps = layer(eps) 40 | return eps 41 | 42 | class SELayer(nn.Module): 43 | def __init__(self, input_dim, reduction=8): 44 | super().__init__() 45 | self.input_dim = input_dim 46 | self.reduction = reduction 47 | self.fc = nn.Sequential( 48 | nn.Linear(input_dim, input_dim // reduction, bias=False), 49 | nn.ReLU(), 50 | nn.Linear(input_dim // reduction, input_dim, bias=False), 51 | nn.Sigmoid() 52 | ) 53 | 54 | def forward(self, x): 55 | return x * self.fc(x) 56 | 57 | class LatentDiffuserV2(nn.Module): 58 | def __init__(self, x_dim=128, cond_emb_dim=768, timestep_emb_dim=100, hidden_dims=[128,128], se_reduction=8): 59 | super().__init__() 60 | self.x_dim = x_dim 61 | self.cond_emb_dim = cond_emb_dim 62 | self.timestep_emb_dim = timestep_emb_dim 63 | self.hidden_dims = hidden_dims 64 | self.se_reduction = se_reduction 65 | layers = [FeedForwardBlock(x_dim + cond_emb_dim + timestep_emb_dim, hidden_dims[0])] 66 | res_maps = [nn.Linear(x_dim, hidden_dims[0])] 67 | se_layers = [SELayer(x_dim)] 68 | for i, h in enumerate(hidden_dims[1:]): 69 | layers += [FeedForwardBlock(hidden_dims[i] + cond_emb_dim + timestep_emb_dim, h)] 70 | res_maps += [nn.Linear(hidden_dims[i], h)] 71 | se_layers += [SELayer(hidden_dims[i], reduction=self.se_reduction)] 72 | self.layers = nn.ModuleList(layers) 73 | self.res_maps = nn.ModuleList(res_maps) 74 | self.final_layer = nn.Linear(hidden_dims[-1], x_dim) 75 | 76 | def forward(self, x, t, cond_emb): 77 | if type(t) == int: 78 | t = torch.tensor(t, device = x.get_device()).repeat((x.shape[0], 1)) 79 | # right now we only support basic concat of cond_emb and timestep_emb to input x 80 | timestep_emb = timestep_embedding(t.flatten(), self.timestep_emb_dim) 81 | eps_in = x 82 | for i in range(len(self.layers)): 83 | eps = self.layers[i](torch.cat((eps_in, cond_emb, timestep_emb), 1)) 84 | eps_in = eps + self.res_maps[i](eps_in) 85 | eps = self.final_layer(eps) 86 | return eps 87 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/hyperclip_classification_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from functools import partial 5 | 6 | import numpy as np 7 | import torch 8 | import wandb 9 | from features.image_features import load_image_features 10 | from features.ques_features import load_ques_features 11 | from features.text_features import load_text_features 12 | from model.hyperclip import build_hyperclip_from_classic_clip 13 | from tqdm import tqdm 14 | from training.store_few_shot_latent import StoreFewShotLatent 15 | from training.vae_learn import sample_from_enc 16 | from utils import clip_utils 17 | from utils.build_opt import build_optimizer 18 | from utils.init_utils import (load_metamodel_from_checkpoint, 19 | load_vae_and_metamodel_from_checkpoint) 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--wandb_mode', type=str, default='online', help='Set to "disabled" to disable Weights & Biases logging') 23 | parser.add_argument('--run_id', type=str, 
default='srl_ethz/hyperclip-scripts/irwxxdeh', help='The full "//" identifier of the run to load') 24 | parser.add_argument('--task_batch_size', type=int, default=10, help='Size of randomly sampled tasks for task classification') 25 | 26 | base_path = os.path.dirname(os.path.dirname(__file__)) 27 | 28 | default_config = { 29 | "inner_epochs": '50', 30 | "inner_optimizer": "sgd", 31 | "inner_learning_rate": 0.1, 32 | 33 | "vae_stochastic_init": False, 34 | } 35 | 36 | torch.manual_seed(42) 37 | rng = np.random.RandomState(42) 38 | np.random.seed(42) 39 | 40 | def main(args): 41 | 42 | cfg = default_config 43 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None}) 44 | api = wandb.Api() 45 | loaded_run = api.run(args.run_id) 46 | cfg.update({k: v for (k, v) in loaded_run.config.items() if v is not None and k not in cfg}) 47 | wandb.init(project="hyperclip-classification", entity="srl_ethz", name=loaded_run.name, config=cfg, mode=args.wandb_mode) 48 | config = wandb.config 49 | 50 | device = "cuda" if torch.cuda.is_available() else "cpu" 51 | 52 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file: 53 | train_data = json.load(file) 54 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file: 55 | test_data = json.load(file) 56 | 57 | train_tasks = list(train_data.keys()) 58 | test_tasks = list(test_data.keys()) 59 | 60 | hyperclip_path = base_path + "/evaluation/hyperclip/hyperclip_"+str(loaded_run.name)+".pth" 61 | 62 | hnet_gen, hnet_enc = None, None 63 | if "vae_checkpoint" in config and config["vae_checkpoint"] is not None: 64 | meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \ 65 | load_vae_and_metamodel_from_checkpoint(config, device) 66 | else: 67 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \ 68 | load_metamodel_from_checkpoint(config, device) 69 | 70 | # pre-computed features 71 | image_features = load_image_features(config["clip_model"]) 72 | text_features = load_text_features(config["clip_model"]) 73 | ques_emb = load_ques_features(config["clip_model"]) 74 | 75 | hyperclip = build_hyperclip_from_classic_clip( 76 | os.path.expanduser(clip_utils.cached_location[config["clip_model"]]), 77 | hyper_model=config["hyperclip_model"], 78 | mainnet_param_count=meta_module.mnet.get_parameter_vector().shape[0], 79 | hyper_hidden_dims=[] if config["hyperclip_hidden_dim"] == "" else [int(i) for i in config["hyperclip_hidden_dim"].split(",")], 80 | pretrained_it_location=os.path.expanduser(clip_utils.cached_location[config["clip_model"]]), 81 | pretrained_hyper_location=None).to(device) 82 | 83 | hyperclip.hyper.load_state_dict(torch.load(hyperclip_path), strict=False) 84 | hyperclip.hyper.eval() 85 | 86 | unguided_inner_optim = partial(build_optimizer, config=config, loop="inner") 87 | few_shot_saver = StoreFewShotLatent( meta_module, 88 | image_features=image_features, 89 | text_features=text_features, 90 | ques_emb=ques_emb, 91 | config=config, 92 | device=device, compute_hessian=False, 93 | reset_normal_embedding=config["vae_stochastic_init"] and "vae_checkpoint" in config) 94 | 95 | _, _, _, train_tasks_optimized_params = few_shot_saver.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim, 96 | batch_size=len(list(train_data.keys())), 97 | train=False, skip_cond=True, 98 | train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], 99 | output_mean_std=True, 100 | debug=True) 101 | 102 | 103 | 
hyperclip_ques_batch = torch.zeros(config["task_batch_size"], clip_utils.embedding_size[config["clip_model"]], dtype=torch.float32).to(device) 104 | hyperclip_net_batch = torch.zeros(config["task_batch_size"], hyperclip.hyper.input_dim, dtype=torch.float32).to(device) 105 | 106 | train_correct = 0 107 | tot_loss_train = 0 108 | 109 | for i, task in enumerate(tqdm(train_tasks)): 110 | 111 | prob_other_sample = np.ones(len(train_tasks)) * (1.0 / (len(train_tasks)-1)) 112 | prob_other_sample[i] = 0.0 113 | other_task_ids = np.random.choice(len(train_tasks), size=config["task_batch_size"]-1, replace=False, p=prob_other_sample) 114 | 115 | hyperclip_ques_batch[0] = ques_emb[task] 116 | 117 | opt_latent = train_tasks_optimized_params[i] 118 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config) 119 | hyperclip_net_batch[0] = opt_weights 120 | 121 | for batch_i, other_id in enumerate(other_task_ids): 122 | hyperclip_ques_batch[batch_i+1] = ques_emb[train_tasks[other_id]] 123 | opt_latent = train_tasks_optimized_params[other_id] 124 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config) 125 | hyperclip_net_batch[batch_i+1] = opt_weights 126 | 127 | with torch.no_grad(): 128 | (features, logits), hyperclip_loss = hyperclip.forward(hyper = hyperclip_net_batch, text = hyperclip_ques_batch, precomputed_it_embs = True) 129 | tot_loss_train += hyperclip_loss.detach().cpu().numpy() 130 | 131 | _, logits_hyper_ques, _ = logits 132 | 133 | logits_hyper_ques = logits_hyper_ques[0] #consider the classification of the true task at position 0 against randomly sampled other tasks 134 | _, pred_index = logits_hyper_ques.topk(1) 135 | if pred_index == 0: 136 | train_correct += 1 137 | 138 | train_accuracy = train_correct / len(train_tasks) 139 | print(f"Trainset hyperclip accuracy: {train_accuracy}") 140 | train_avg_loss = tot_loss_train / len(train_tasks) 141 | print(f"Trainset hyperclip average loss: {train_avg_loss}") 142 | wandb.log({"hyperclip_trainset_accuracy": train_accuracy, "hyperclip_trainset_avg_loss": train_avg_loss}) 143 | 144 | 145 | _, _, _, test_tasks_optimized_params = few_shot_saver.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, 146 | batch_size=len(list(test_data.keys())), 147 | train=False, skip_cond=True, 148 | train_subtype = config["train_subtype"], val_subtype=config["train_subtype"], 149 | output_mean_std=True, 150 | debug=True) 151 | 152 | test_correct = 0 153 | tot_loss_test = 0 154 | 155 | for i, task in enumerate(tqdm(test_tasks)): 156 | 157 | prob_other_sample = np.ones(len(test_tasks)) * (1.0 / (len(test_tasks)-1)) 158 | prob_other_sample[i] = 0.0 159 | other_task_ids = np.random.choice(len(test_tasks), size=config["task_batch_size"]-1, replace=False, p=prob_other_sample) 160 | 161 | hyperclip_ques_batch[0] = ques_emb[task] 162 | opt_latent = test_tasks_optimized_params[i] 163 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config) 164 | hyperclip_net_batch[0] = opt_weights 165 | for batch_i, other_id in enumerate(other_task_ids): 166 | hyperclip_ques_batch[batch_i+1] = ques_emb[test_tasks[other_id]] 167 | opt_latent = test_tasks_optimized_params[other_id] 168 | opt_weights = get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config) 169 | hyperclip_net_batch[batch_i+1] = opt_weights 170 | 171 | with torch.no_grad(): 172 | (features, logits), hyperclip_loss = hyperclip.forward(hyper = hyperclip_net_batch, text = hyperclip_ques_batch, precomputed_it_embs = True) 173 | tot_loss_test 
+= hyperclip_loss.detach().cpu().numpy() 174 | 175 | _, logits_hyper_ques, _ = logits 176 | 177 | logits_hyper_ques = logits_hyper_ques[0] #consider the classification of the true task at position 0 against randomly sampled other tasks 178 | _, pred_index = logits_hyper_ques.topk(1) 179 | if pred_index == 0: 180 | test_correct += 1 181 | 182 | test_accuracy = test_correct / len(test_tasks) 183 | print(f"Testset hyperclip accuracy: {test_accuracy}") 184 | test_avg_loss = tot_loss_test / len(test_tasks) 185 | print(f"Testset hyperclip average loss: {test_avg_loss}") 186 | wandb.log({"hyperclip_testset_accuracy": test_accuracy, "hyperclip_testset_avg_loss": test_avg_loss}) 187 | 188 | wandb.finish() 189 | 190 | 191 | def get_weights(meta_module, hnet_gen, hnet_enc, opt_latent, config): 192 | if "vae_checkpoint" in config and config["vae_checkpoint"] is not None and hnet_gen is not None: 193 | return hnet_gen(sample_from_enc(hnet_enc, get_weights_from_metamodule(meta_module, opt_latent))).detach() 194 | return get_weights_from_metamodule(meta_module, opt_latent) 195 | 196 | def get_weights_from_metamodule(meta_module, opt_latent): 197 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in meta_module.get_inner_params().items()} 198 | inner_params["enet.embedding"] = opt_latent.clone().detach().requires_grad_() 199 | return meta_module.get_mainnet_weights(params = inner_params).detach() 200 | 201 | if __name__ == "__main__": 202 | args = parser.parse_args() 203 | main(args) 204 | -------------------------------------------------------------------------------- /scripts/precompute_adaptation.py: -------------------------------------------------------------------------------- 1 | # Takes a hypernetwork and learns the hyperclip model on weight generated by embedding adaptation 2 | 3 | import argparse 4 | import json 5 | import os 6 | from functools import partial 7 | 8 | import numpy as np 9 | import torch 10 | import wandb 11 | from data.dataloader.coco_tasks import (filter_categories, 12 | load_coco_answer_features) 13 | from features.image_features import load_image_features 14 | from features.ques_features import load_ques_features 15 | from features.text_features import load_text_features 16 | from training.store_few_shot_latent import StoreFewShotLatent 17 | from utils.build_opt import build_optimizer 18 | from utils.init_utils import load_metamodel_from_checkpoint 19 | from utils.misc_utils import str2bool 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--wandb_mode', type=str, default='online', 23 | help='Set to "disabled" to disable Weights & Biases logging') 24 | 25 | parser.add_argument('--epochs', type=int) 26 | 27 | parser.add_argument('--inner_epochs', type=str) 28 | parser.add_argument('--inner_optimizer', type=str) 29 | parser.add_argument('--inner_learning_rate', type=float) 30 | 31 | parser.add_argument('--compute_hessian', type=str2bool) 32 | 33 | parser.add_argument('--few_shot_checkpoint', type=str, help='required') 34 | parser.add_argument('--vae_checkpoint', type=str, help='required') 35 | parser.add_argument('--vae_stochastic_init', type=str2bool) 36 | 37 | parser.add_argument('--use_extended_coco', type=str2bool) 38 | parser.add_argument('--extend_coco_size', type=int) 39 | parser.add_argument('--extend_coco_frac_train', type=float) 40 | 41 | parser.add_argument('--data_subtype', type=str) 42 | 43 | 44 | base_path = os.path.dirname(os.path.dirname(__file__)) 45 | 46 | default_config = { 47 | "epochs": 100, 48 | "inner_epochs": '10', 49 
| "inner_optimizer": "sgd", 50 | "inner_learning_rate": 0.1, 51 | "compute_hessian": False, 52 | "vae_stochastic_init": False, 53 | 54 | "clip_model": "ViT-L/14@336px", 55 | 56 | "use_extended_coco": False, 57 | "extend_coco_size": 10 * 870, 58 | "extend_coco_frac_train": 0.5, 59 | "data_subtype": "random", 60 | } 61 | 62 | torch.manual_seed(42) 63 | rng = np.random.RandomState(42) 64 | np.random.seed(42) 65 | 66 | def main(args): 67 | cfg = default_config 68 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None}) 69 | print(cfg) 70 | wandb.init(project='precompute_adaptation', entity="srl_ethz", config=cfg, 71 | mode=args.wandb_mode) 72 | config = wandb.config 73 | 74 | log_file = base_path + "/evaluation/precompute_adaptation/" + str(wandb.run.name) + ".pth" 75 | log_file_train_eval = base_path + "/evaluation/precompute_adaptation/" + str(wandb.run.name) + "_train_eval.pth" 76 | log_file_val_eval = base_path + "/evaluation/precompute_adaptation/" + str(wandb.run.name) + "_val_eval.pth" 77 | 78 | # Load the model 79 | device = "cuda" if torch.cuda.is_available() else "cpu" 80 | 81 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file: 82 | train_data = json.load(file) 83 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file: 84 | test_data = json.load(file) 85 | 86 | reset_normal_embedding = config["vae_stochastic_init"] and "vae_checkpoint" in config 87 | 88 | meta_module, _, _, _ = load_metamodel_from_checkpoint(config, device) 89 | 90 | # pre-computed features 91 | image_features = load_image_features(config["clip_model"]) 92 | text_features = load_text_features(config["clip_model"]) 93 | ques_emb = load_ques_features(config["clip_model"]) 94 | 95 | coco_categories = None 96 | coco_answer_features = None 97 | if config["use_extended_coco"]: 98 | coco_categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item() 99 | coco_categories = filter_categories(coco_categories, 10) 100 | # right now it's hardcoded, it's train_size + test_size for the extended coco sampled datasets 101 | coco_answer_features = load_coco_answer_features(config["clip_model"]) 102 | 103 | unguided_inner_optim = partial(build_optimizer, config=config, loop="inner") 104 | few_shot_saver = StoreFewShotLatent( meta_module, 105 | image_features=image_features, 106 | text_features=text_features, 107 | ques_emb=ques_emb, 108 | config=config, 109 | device=device, compute_hessian=config["compute_hessian"], 110 | reset_normal_embedding=reset_normal_embedding, 111 | coco_categories=coco_categories, coco_answer_features=coco_answer_features, 112 | extend_coco_size=config["extend_coco_size"]) 113 | 114 | print("Computing metric") 115 | from utils.train_utils import log_metric 116 | log_dict = few_shot_saver.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim, 117 | batch_size=len(list(train_data.keys())),train=True, 118 | log_file=log_file_train_eval) 119 | log_metric(log_dict, "eval_train/") 120 | log_dict = few_shot_saver.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, 121 | batch_size=len(list(test_data.keys())),train=True, 122 | log_file=log_file_val_eval) 123 | log_metric(log_dict, "eval_val/") 124 | 125 | for meta_epoch in range(config["epochs"]): 126 | few_shot_saver.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim, 127 | batch_size=len(list(train_data.keys())), 128 | extend_coco=config["use_extended_coco"], 129 | extend_coco_frac_train=config["extend_coco_frac_train"], 130 | train=True, 
train_subtype = config["data_subtype"], val_subtype=config["data_subtype"], debug=True, log_file=log_file) 131 | 132 | 133 | wandb.finish() 134 | 135 | 136 | if __name__ == "__main__": 137 | args = parser.parse_args() 138 | main(args) 139 | -------------------------------------------------------------------------------- /scripts/precompute_image_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from features.image_features import compute_image_features 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--model', type=str, default='ViT-L/14@336px', help='CLIP visual encoder') 7 | 8 | 9 | if __name__ == "__main__": 10 | args = parser.parse_args() 11 | compute_image_features(args.model) 12 | -------------------------------------------------------------------------------- /scripts/precompute_ques_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from features.ques_features import compute_ques_features 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--model', type=str, default='ViT-L/14@336px', help='CLIP visual encoder') 7 | 8 | 9 | if __name__ == "__main__": 10 | args = parser.parse_args() 11 | compute_ques_features(args.model) 12 | -------------------------------------------------------------------------------- /scripts/precompute_text_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from features.text_features import compute_text_features 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--model', type=str, default='ViT-L/14@336px', help='CLIP visual encoder') 7 | 8 | 9 | if __name__ == "__main__": 10 | args = parser.parse_args() 11 | compute_text_features(args.model) 12 | -------------------------------------------------------------------------------- /scripts/train_few_shot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import wandb 8 | from data.dataloader.coco_tasks import (filter_categories, 9 | load_coco_answer_features) 10 | from features.image_features import load_image_features 11 | from features.ques_features import load_ques_features 12 | from features.text_features import load_text_features 13 | from model.custom_hnet import MetaModel 14 | from training.maml_learn import MAML 15 | from utils import clip_utils 16 | from utils.build_opt import build_optimizer 17 | from utils.misc_utils import str2bool 18 | from utils.train_utils import log_metric, n_shot_trials_run 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--wandb_mode', type=str, default='online', 22 | help='Set to "disabled" to disable Weights & Biases logging') 23 | 24 | parser.add_argument('--meta_epochs', type=int) 25 | parser.add_argument('--meta_batch_size', type=int) 26 | 27 | parser.add_argument('--inner_epochs', type=str) 28 | 29 | parser.add_argument('--inner_learning_rate', type=float) 30 | parser.add_argument('--eval_inner_epochs', type=str) 31 | parser.add_argument('--second_order', type=str2bool) 32 | parser.add_argument('--train_subtype', type=str) 33 | parser.add_argument('--val_subtype', type=str) 34 | parser.add_argument('--meta_optimizer', type=str) 35 | parser.add_argument('--meta_learning_rate', type=float) 36 | parser.add_argument('--meta_grad_clip', type=float) 37 | 38 | 
parser.add_argument('--alpha', type=float) 39 | 40 | # meta_module 41 | parser.add_argument('--inner_param', type=str) 42 | parser.add_argument('--hypernet_hidden_dim', type=str) 43 | parser.add_argument('--straight_through', type=str2bool) 44 | parser.add_argument('--embedding_dim', type=int) 45 | 46 | parser.add_argument('--keep_tasks_frac', type=float) 47 | 48 | parser.add_argument('--load_checkpoint', type=str) 49 | parser.add_argument('--val_epoch_interval', type=int) 50 | 51 | parser.add_argument('--use_extended_coco', type=str2bool) 52 | parser.add_argument('--extend_coco_size', type=int) 53 | parser.add_argument('--extend_coco_frac_train', type=float) 54 | 55 | parser.add_argument('--use_clip_embedding_init', type=str2bool) 56 | parser.add_argument('--save_checkpoint', type=str2bool) 57 | 58 | parser.add_argument('--seed', type=int, default=42) 59 | parser.add_argument('--checkpoint', type=str) 60 | parser.add_argument('--eval', type=str2bool, default=False) 61 | 62 | parser.add_argument('--n_shot_trials_maxN', type=int) 63 | 64 | base_path = os.path.dirname(os.path.dirname(__file__)) 65 | 66 | args = parser.parse_args() 67 | 68 | torch.manual_seed(args.seed) 69 | rng = np.random.RandomState(args.seed) 70 | np.random.seed(args.seed) 71 | 72 | default_config = { 73 | # meta_module 74 | "use_clip_embedding_init": False, 75 | "inner_param": "enet", 76 | "hypernet_hidden_dim": "128,128,128", 77 | "straight_through": False, 78 | "mainnet_use_bias": True, 79 | "mainnet_hidden_dim": [256], 80 | "embedding_dim": 128, 81 | 82 | "alpha": 1, 83 | 84 | "clip_model": "ViT-L/14@336px", 85 | 86 | "inner_epochs": "10", 87 | "inner_learning_rate": 0.1, 88 | "train_subtype": "test", 89 | "val_subtype": "train", 90 | 91 | "meta_epochs": 1000, 92 | "meta_batch_size": 32, 93 | "second_order": False, 94 | "meta_grad_clip": 10, 95 | 96 | "eval_inner_epochs": '', 97 | "meta_optimizer": "adam", 98 | "meta_learning_rate": 0.001, 99 | 100 | "val_epoch_interval": 25, 101 | "save_checkpoint": False, 102 | "load_checkpoint": "", 103 | "keep_tasks_frac": 1, 104 | 105 | "use_extended_coco": False, 106 | "extend_coco_size": 10 * 870, 107 | "extend_coco_frac_train": 0.5 108 | } 109 | 110 | 111 | def main(): 112 | cfg=default_config 113 | cfg.update({k:v for (k,v) in vars(args).items() if v is not None}) 114 | 115 | wandb.init(project="train_few_shot", entity="srl_ethz", config=cfg, 116 | mode=args.wandb_mode) 117 | config = wandb.config 118 | 119 | # Load the model 120 | device = "cuda" if torch.cuda.is_available() else "cpu" 121 | 122 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file: 123 | train_data = json.load(file) 124 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file: 125 | test_data = json.load(file) 126 | 127 | # pre-computed features 128 | image_features = load_image_features(config["clip_model"]) 129 | text_features = load_text_features(config["clip_model"]) 130 | ques_emb = load_ques_features(config["clip_model"]) 131 | 132 | coco_categories = None 133 | coco_answer_features = None 134 | if config["use_extended_coco"]: 135 | coco_categories = np.load(base_path+"/data/Attributes/vanilla_coco_categories.npy", allow_pickle=True).item() 136 | coco_categories = filter_categories(coco_categories, 10) # right now it's hardcoded, it's train_size + test_size for the extended coco sampled datasets 137 | coco_answer_features = load_coco_answer_features(config["clip_model"]) 138 | 139 | meta_module = MetaModel( 140 | inner_param=config["inner_param"], 141 | 
mainnet_use_bias=config["mainnet_use_bias"], 142 | mainnet_hidden_dim=config["mainnet_hidden_dim"], 143 | hypernet_hidden_dim=[] if config["hypernet_hidden_dim"]=="" else [int(i) for i in config["hypernet_hidden_dim"].split(",")], 144 | embedding_dim=config["embedding_dim"] if not config["use_clip_embedding_init"] else clip_utils.embedding_size[config["clip_model"]], 145 | straight_through=config["straight_through"], 146 | config=config).to(device) 147 | 148 | if "checkpoint" in config: 149 | loaded_model_path = config["checkpoint"] 150 | meta_module.load_state_dict(torch.load(loaded_model_path), strict=False) 151 | 152 | meta_optimizer = build_optimizer(meta_module.meta_params, config, loop="meta") 153 | meta_trainer = MAML(meta_module, meta_optimizer, image_features, text_features, ques_emb, config, coco_categories, coco_answer_features, extend_coco_size=config["extend_coco_size"]) 154 | 155 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None: 156 | n_shot_trials = [] 157 | 158 | best_val_acc=[0]*(1+len(config["eval_inner_epochs"].split(","))) 159 | best_val_epoch=[0]*(1+len(config["eval_inner_epochs"].split(","))) 160 | 161 | for meta_epoch in range(config["meta_epochs"]): 162 | 163 | if not config["eval"]: 164 | meta_trainer.run_epoch(train_data, config["inner_epochs"], config["inner_learning_rate"], meta_batch_size=config["meta_batch_size"], 165 | train=True, second_order=config["second_order"], 166 | train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True, keep_tasks_frac=config["keep_tasks_frac"], 167 | extend_coco=config["use_extended_coco"], extend_coco_frac_train=config["extend_coco_frac_train"], device=device) 168 | 169 | if config["eval"] or meta_epoch % config["val_epoch_interval"] == 0 or meta_epoch == config["meta_epochs"]-1: 170 | log_dict = meta_trainer.run_epoch(train_data, config["inner_epochs"], config["inner_learning_rate"], keep_tasks_frac=config["keep_tasks_frac"], device=device, epoch=meta_epoch) 171 | log_metric(log_dict, "eval_train/") 172 | log_dict = meta_trainer.run_epoch(test_data, config["inner_epochs"], config["inner_learning_rate"], device=device, epoch=meta_epoch) 173 | 174 | if best_val_acc[0] < log_dict["query_accuracy_end"]: 175 | best_val_acc[0] = log_dict["query_accuracy_end"] 176 | best_val_epoch[0] = meta_epoch 177 | log_dict["best_accuracy"] = best_val_acc[0] 178 | log_dict["best_epoch"] = best_val_epoch[0] 179 | log_metric(log_dict, "eval_val/") 180 | 181 | if log_dict["query_accuracy_end"] < 0.3: 182 | print("Stopping training") 183 | return 184 | 185 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None: 186 | n_shot_trial_dict = {"epoch": meta_epoch} 187 | 188 | if config["eval_inner_epochs"] != '': 189 | for idx, inner_epochs in enumerate(config["eval_inner_epochs"].split(",")): 190 | # log_dict = meta_trainer.run_epoch(train_data, int(inner_epochs), config["inner_learning_rate"], keep_tasks_frac=config["keep_tasks_frac"], device=device) 191 | # log_dict["epoch"]=meta_epoch 192 | # log_metric(log_dict, "eval_train_{}step/".format(inner_epochs)) 193 | 194 | log_dict = meta_trainer.run_epoch(test_data, int(inner_epochs), config["inner_learning_rate"], device=device, epoch=meta_epoch) 195 | 196 | if best_val_acc[idx+1] < log_dict["query_accuracy_end"]: 197 | best_val_acc[idx+1] = log_dict["query_accuracy_end"] 198 | best_val_epoch[idx+1] = meta_epoch 199 | log_dict["best_accuracy"] = best_val_acc[idx+1] 200 | log_dict["best_epoch"] = best_val_epoch[idx+1] 201 | 
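                    # Metrics for this eval_inner_epochs setting are logged below under the
                    # per-setting "eval_val_{k}step/" prefix; n_shot_trials_run (imported from
                    # utils.train_utils, body not shown here) is expected to populate
                    # n_shot_trial_dict with n-shot results when n_shot_trials_maxN is set.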
202 | log_metric(log_dict, "eval_val_{}step/".format(inner_epochs)) 203 | 204 | n_shot_trials_run(meta_trainer, n_shot_trial_dict, config, f"eval_val_{inner_epochs}step/", test_data, int(inner_epochs), config["inner_learning_rate"], device=device, epoch=meta_epoch) 205 | 206 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None: 207 | n_shot_trials += [n_shot_trial_dict] 208 | 209 | if config["eval"]: 210 | return 211 | 212 | if config["save_checkpoint"] and (meta_epoch+1) % 20 == 0: 213 | model_output_path_checkpoint = base_path + "/evaluation/few_shot/meta_module" + str(wandb.run.name) + "_" + str( meta_epoch) + ".pth" 214 | torch.save(meta_module.state_dict(), model_output_path_checkpoint) 215 | print(f"Checkpoint for meta-epoch {meta_epoch} saved!") 216 | 217 | if "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None: 218 | model_output_path_checkpoint = base_path + "/evaluation/few_shot/few_shot_" + str( 219 | wandb.run.name) + "_n_shot.npy" 220 | np.save(model_output_path_checkpoint, n_shot_trials) 221 | 222 | model_output_path_checkpoint = base_path + "/evaluation/few_shot/few_shot_" + str(wandb.run.name) + ".pth" 223 | torch.save(meta_module.state_dict(), model_output_path_checkpoint) 224 | 225 | wandb.finish() 226 | 227 | 228 | if __name__ == "__main__": 229 | main() 230 | -------------------------------------------------------------------------------- /scripts/train_hyperclip.py: -------------------------------------------------------------------------------- 1 | # Takes a hypernetwork and learns the hyperclip model on weight generated by embedding adaptation 2 | 3 | 4 | import argparse 5 | import json 6 | import os 7 | from functools import partial 8 | 9 | import numpy as np 10 | import torch 11 | import torch.optim as optim 12 | import wandb 13 | from features.image_features import load_image_features 14 | from features.ques_features import load_ques_features 15 | from features.text_features import load_text_features 16 | from model.hyperclip import build_hyperclip_from_classic_clip 17 | from training.hyperclip_learn import HyperclipTraining 18 | from utils import clip_utils 19 | from utils.build_opt import build_optimizer 20 | from utils.init_utils import (load_metamodel_from_checkpoint, 21 | load_vae_and_metamodel_from_checkpoint) 22 | from utils.misc_utils import str2bool 23 | from utils.train_utils import log_metric 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--wandb_mode', type=str, default='online', 27 | help='Set to "disabled" to disable Weights & Biases logging') 28 | 29 | parser.add_argument('--inner_epochs', type=str) 30 | 31 | parser.add_argument('--few_shot_checkpoint', type=str, help='required') 32 | parser.add_argument('--vae_checkpoint', type=str, help='required') 33 | parser.add_argument('--precompute_checkpoint', type=str, help='required') 34 | 35 | parser.add_argument('--hyperclip_epochs', type=int) 36 | parser.add_argument('--hyperclip_batch_size', type=int) 37 | parser.add_argument('--val_epoch_interval', type=int) 38 | 39 | 40 | parser.add_argument('--hyperclip_optimizer', type=str) 41 | parser.add_argument('--hyperclip_learning_rate', type=float) 42 | parser.add_argument('--eval_inner_epochs', type=str) 43 | parser.add_argument('--guidance_optimizer', type=str) 44 | parser.add_argument('--guidance_learning_rate', type=float) 45 | 46 | parser.add_argument('--guidance_scheduler', type=str) 47 | 48 | # hyperclip 49 | parser.add_argument('--hyperclip_hidden_dim', type=str) 50 | 
parser.add_argument('--hyperclip_model', type=str) 51 | 52 | parser.add_argument('--guidance_init_l2_weight', type=str) 53 | parser.add_argument('--train_on_vae', type=str2bool) 54 | 55 | parser.add_argument('--checkpoint', type=str) 56 | parser.add_argument('--eval', type=str2bool) 57 | parser.add_argument('--langevin_eps', type=str) 58 | 59 | parser.add_argument('--normalize', type=str2bool) 60 | 61 | 62 | base_path = os.path.dirname(os.path.dirname(__file__)) 63 | 64 | 65 | default_config = { 66 | "inner_epochs": '50', 67 | "inner_optimizer": "sgd", 68 | "inner_learning_rate": 0.1, 69 | "train_subtype": "random", 70 | "val_subtype": "random", 71 | 72 | "clip_model": "ViT-L/14@336px", 73 | "hyperclip_model": "mlp", 74 | "hyperclip_hidden_dim": "128,128", 75 | 76 | "hyperclip_epochs": 1000, 77 | "hyperclip_batch_size": 32, 78 | 79 | "hyperclip_optimizer": "adam", 80 | "hyperclip_learning_rate": 0.001, 81 | "hyperclip_weight_decay": 0, 82 | "hyperclip_momentum": 0.9, 83 | "hyperclip_sgd_nesterov": True, 84 | 85 | "eval_inner_epochs": '50', 86 | "guidance_optimizer": "adam", 87 | "guidance_learning_rate": 0.001, 88 | "guidance_momentum": 0.9, 89 | "guidance_sgd_nesterov": True, 90 | "guidance_init_l2_weight": "0", 91 | 92 | "guidance_scheduler": "none", 93 | 94 | "val_epoch_interval": 100, 95 | "save_checkpoint": False, 96 | "train_on_vae":False, 97 | 98 | "eval": False, 99 | "normalize":False, 100 | "langevin_eps": "0" 101 | } 102 | 103 | torch.manual_seed(42) 104 | rng = np.random.RandomState(42) 105 | np.random.seed(42) 106 | 107 | def main(args): 108 | cfg = default_config 109 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None}) 110 | print(cfg) 111 | wandb.init(project='train_hyperclip', entity="srl_ethz", config=cfg, 112 | mode=args.wandb_mode) 113 | config = wandb.config 114 | 115 | # Load the model 116 | device = "cuda" if torch.cuda.is_available() else "cpu" 117 | 118 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file: 119 | train_data = json.load(file) 120 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file: 121 | test_data = json.load(file) 122 | 123 | hnet_gen, hnet_enc = None, None 124 | if "vae_checkpoint" in config and config["vae_checkpoint"] is not None: 125 | meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \ 126 | load_vae_and_metamodel_from_checkpoint(config, device) 127 | else: 128 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \ 129 | load_metamodel_from_checkpoint(config, device) 130 | 131 | # pre-computed features 132 | image_features = load_image_features(config["clip_model"]) 133 | text_features = load_text_features(config["clip_model"]) 134 | ques_emb = load_ques_features(config["clip_model"]) 135 | 136 | hyperclip = build_hyperclip_from_classic_clip( 137 | os.path.expanduser(clip_utils.cached_location[config["clip_model"]]), 138 | hyper_model=config["hyperclip_model"], 139 | mainnet_param_count=meta_module.mnet.get_parameter_vector().shape[0], 140 | hyper_hidden_dims=[] if config["hyperclip_hidden_dim"] == "" else [int(i) for i in config["hyperclip_hidden_dim"].split(",")], 141 | pretrained_it_location=os.path.expanduser(clip_utils.cached_location[config["clip_model"]]), 142 | pretrained_hyper_location=None).to(device) 143 | 144 | if "checkpoint" in config: 145 | api = wandb.Api() 146 | loaded_run = api.run(config["checkpoint"]) 147 | loaded_model_path = base_path + "/evaluation/hyperclip/hyperclip_" + 
str(loaded_run.name) + ".pth" 148 | hyperclip.hyper.load_state_dict(torch.load(loaded_model_path), strict=False) 149 | 150 | 151 | hyperclip.hyper.train() 152 | hyperclip_optimizer = build_optimizer(hyperclip.hyper.parameters(), config, loop="hyperclip") 153 | hyperclip_training = HyperclipTraining(meta_module, hnet_gen, hnet_enc, 154 | hyperclip, hyperclip_optimizer, image_features, text_features, 155 | ques_emb, config, device, train_on_vae=config["train_on_vae"]) 156 | 157 | unguided_inner_optim =partial(build_optimizer, config=config, loop="inner") 158 | guided_inner_optim =partial(build_optimizer, config=config, loop="guidance") 159 | 160 | 161 | # Get STD and MEAN 162 | if precomputed_latent is not None: 163 | sampled_precomputed_latent = {k:v[0:1].to(device) for (k,v) in precomputed_latent.items()} 164 | else: 165 | sampled_precomputed_latent=None 166 | 167 | _, optimized_params_mean, optimized_params_std, _ = hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim, 168 | batch_size=config["hyperclip_batch_size"], precomputed_latent=sampled_precomputed_latent, 169 | output_mean_std=True, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], skip_cond=True) 170 | 171 | optimized_params_mean, optimized_params_std=optimized_params_mean[0], optimized_params_std[0].clip(0.01) 172 | 173 | if config["normalize"]: 174 | hyperclip_training.set_stats(optimized_params_mean, optimized_params_std) 175 | wandb.log({"mean":optimized_params_mean, "std":optimized_params_std}) 176 | 177 | for meta_epoch in range(config["hyperclip_epochs"]): 178 | if precomputed_latent is not None: 179 | curr_sample_epoch = np.random.randint(0, precomputed_latent["clip_embedding"].shape[0]) 180 | curr_sample_batch_perm = np.random.permutation(precomputed_latent["clip_embedding"].shape[1]) 181 | sampled_precomputed_latent = {k:v[curr_sample_epoch:curr_sample_epoch+1, curr_sample_batch_perm].to(device) for (k,v) in precomputed_latent.items()} 182 | else: 183 | sampled_precomputed_latent=None 184 | 185 | if not config["eval"]: 186 | hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim, 187 | batch_size=config["hyperclip_batch_size"], precomputed_latent=sampled_precomputed_latent, 188 | train=True, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True) 189 | 190 | if config["eval"] or (meta_epoch + 1) % (config["val_epoch_interval"]//10+1) == 0: 191 | 192 | log_dict = hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim, precomputed_latent=precomputed_latent_train_eval, epoch=meta_epoch) 193 | log_metric(log_dict, "eval_train/") 194 | log_dict = hyperclip_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, precomputed_latent=precomputed_latent_val_eval, epoch=meta_epoch) 195 | log_metric(log_dict, "eval_val/") 196 | 197 | if config["eval"] or (meta_epoch + 1) % config["val_epoch_interval"] == 0: 198 | if config["eval_inner_epochs"] != '': 199 | def eval(inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps, train=False): 200 | guidance_scheduler_fn=None 201 | if guidance_scheduler == "cos": 202 | guidance_scheduler_fn = partial(optim.lr_scheduler.CosineAnnealingLR, T_max=inner_epochs, eta_min=0) 203 | 204 | if train: 205 | log_dict = hyperclip_training.run_epoch(train_data, inner_epochs, guided_inner_optim, 206 | guided_inner=True, use_vae=True, init_guidance_at="pre-trained", skip_cond=True, 
guidance_init_l2_weight=guidance_init_l2_weight, langevin_eps=langevin_eps, guidance_scheduler_fn=guidance_scheduler_fn, epoch=meta_epoch) 207 | log_metric(log_dict, "guided_eval_train_{}step_{}l2_{}_{}/".format(inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps)) 208 | 209 | log_dict = hyperclip_training.run_epoch(test_data, inner_epochs, guided_inner_optim, guided_inner=True, use_vae=True, init_guidance_at="pre-trained", 210 | skip_cond=True, guidance_init_l2_weight=guidance_init_l2_weight, langevin_eps=langevin_eps, guidance_scheduler_fn=guidance_scheduler_fn, epoch=meta_epoch) 211 | 212 | log_metric(log_dict, "guided_eval_val_{}step_{}l2_{}_{}/".format(inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps)) 213 | 214 | for idx, inner_epochs in enumerate(config["eval_inner_epochs"].split(",")): 215 | for _, guidance_init_l2_weight in enumerate(config["guidance_init_l2_weight"].split(",")): 216 | for _, guidance_scheduler in enumerate(config["guidance_scheduler"].split(",")): 217 | for _, langevin_eps in enumerate(config["langevin_eps"].split(",")): 218 | eval(int(inner_epochs), float(guidance_init_l2_weight), guidance_scheduler, float(langevin_eps), train=(meta_epoch + 1) % 1000 == 0) 219 | 220 | if config["eval"]: 221 | return 222 | 223 | if config["save_checkpoint"]: 224 | hyperclip_output_path_checkpoint = base_path + "/evaluation/hyperclip/hyperclip_" + str( 225 | wandb.run.name) + "_" + str( 226 | meta_epoch) + ".pth" 227 | torch.save(hyperclip.hyper.state_dict(), hyperclip_output_path_checkpoint) 228 | print(f"Checkpoint for meta-epoch {meta_epoch} saved!") 229 | 230 | hyperclip_output_path_checkpoint = base_path + "/evaluation/hyperclip/hyperclip_" + str( 231 | wandb.run.name)+ ".pth" 232 | torch.save(hyperclip.hyper.state_dict(), hyperclip_output_path_checkpoint) 233 | 234 | wandb.finish() 235 | 236 | 237 | if __name__ == "__main__": 238 | args = parser.parse_args() 239 | main(args) 240 | -------------------------------------------------------------------------------- /scripts/train_vae.py: -------------------------------------------------------------------------------- 1 | # Takes a base model and learns a generative hypernetwork using adapted weights as data. 
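# Illustrative sketch (not part of the original repository): the training below treats flattened,
# task-adapted CLIP-adapter weights as data points for a VAE. HyperEncoder (model/custom_hnet.py)
# maps a weight vector to a 2*e_dim output (presumably a concatenated [mu | logvar]), and
# HyperGenerator decodes a latent back to a weight vector. The toy dimensions, the MSE
# reconstruction term and the helper name below are assumptions for exposition only; the actual
# loss lives in training/vae_learn.py.
def _vae_weight_step_sketch():
    import torch
    import torch.nn.functional as F

    w = torch.randn(32, 1000)                 # batch of flattened adapted weights (toy sizes)
    enc = torch.nn.Linear(1000, 2 * 128)      # stand-in for HyperEncoder -> [mu | logvar]
    dec = torch.nn.Linear(128, 1000)          # stand-in for HyperGenerator

    mu, logvar = enc(w).chunk(2, dim=-1)
    z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()     # reparameterization trick
    recon = dec(z)

    recon_loss = F.mse_loss(recon, w)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + 1.0 * kld             # "kld_weight" in the config scales the KLD term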
2 | 3 | import argparse 4 | import json 5 | import os 6 | from functools import partial 7 | 8 | import numpy as np 9 | import torch 10 | import wandb 11 | from features.image_features import load_image_features 12 | from features.ques_features import load_ques_features 13 | from features.text_features import load_text_features 14 | from model.custom_hnet import HyperEncoder, HyperGenerator 15 | from training.vae_learn import VAETraining 16 | from utils.build_opt import build_optimizer 17 | from utils.init_utils import load_metamodel_from_checkpoint 18 | from utils.misc_utils import str2bool 19 | from utils.train_utils import log_metric 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--wandb_mode', type=str, default='online', 23 | help='Set to "disabled" to disable Weights & Biases logging') 24 | 25 | parser.add_argument('--few_shot_checkpoint', type=str, help='required') 26 | parser.add_argument('--precompute_checkpoint', type=str, help='required') 27 | 28 | parser.add_argument('--vae_epochs', type=int) 29 | parser.add_argument('--vae_batch_size', type=int) 30 | parser.add_argument('--vae_noise_dim', type=int) 31 | 32 | parser.add_argument('--vae_optimizer', type=str) 33 | parser.add_argument('--vae_learning_rate', type=float) 34 | parser.add_argument('--kld_weight', type=float) 35 | 36 | parser.add_argument('--eval_inner_epochs', type=str) 37 | parser.add_argument('--generated_optimizer', type=str) 38 | parser.add_argument('--generated_learning_rate', type=float) 39 | 40 | # hyperclip 41 | parser.add_argument('--vae_hidden_dim', type=str) 42 | parser.add_argument('--val_epoch_interval', type=int) 43 | base_path = os.path.dirname(os.path.dirname(__file__)) 44 | 45 | parser.add_argument('--normalize', type=str2bool) 46 | parser.add_argument('--grad_clip', type=float) 47 | parser.add_argument('--save_checkpoint', type=str2bool) 48 | parser.add_argument('--seed', type=int, default=42) 49 | 50 | args = parser.parse_args() 51 | torch.manual_seed(args.seed) 52 | rng = np.random.RandomState(args.seed) 53 | np.random.seed(args.seed) 54 | default_config = { 55 | "inner_epochs": 50, 56 | "inner_optimizer": "sgd", 57 | "inner_learning_rate": 0.1, 58 | "train_subtype": "random", 59 | "val_subtype": "random", 60 | 61 | "vae_hidden_dim": "128,128", 62 | "vae_noise_dim": 128, 63 | 64 | "vae_epochs": 10000, 65 | "vae_batch_size": 32, 66 | "kld_weight":1, 67 | "vae_optimizer": "adam", 68 | "vae_learning_rate": 0.0001, 69 | 70 | "eval_inner_epochs": '50', 71 | "generated_optimizer": "adam", 72 | "generated_learning_rate": 0.001, 73 | "generated_momentum": 0.9, 74 | "generated_sgd_nesterov": True, 75 | 76 | "val_epoch_interval": 2000, 77 | "clip_model": "ViT-L/14@336px", 78 | "save_checkpoint": False, 79 | "normalize": False, 80 | "grad_clip": -1, 81 | } 82 | 83 | def main(args): 84 | cfg = default_config 85 | cfg.update({k: v for (k, v) in vars(args).items() if v is not None}) 86 | print(cfg) 87 | wandb.init(project="train_vae", entity="srl_ethz", config=cfg, 88 | mode=args.wandb_mode) 89 | config = wandb.config 90 | 91 | # Load the model 92 | device = "cuda" if torch.cuda.is_available() else "cpu" 93 | 94 | with open(base_path + "/data/VQA/Meta/meta_train.json") as file: 95 | train_data = json.load(file) 96 | with open(base_path + "/data/VQA/Meta/meta_test.json") as file: 97 | test_data = json.load(file) 98 | 99 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \ 100 | load_metamodel_from_checkpoint(config, device) 101 | 102 | # 
pre-computed features 103 | image_features = load_image_features(config["clip_model"]) 104 | text_features = load_text_features(config["clip_model"]) 105 | ques_emb = load_ques_features(config["clip_model"]) 106 | 107 | hidden_dims = [int(h) for h in config["vae_hidden_dim"].split(",")] 108 | hnet_enc = HyperEncoder(meta_module.mnet, e_dim=config["vae_noise_dim"], 109 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device) 110 | hidden_dims.reverse() 111 | hnet_gen = HyperGenerator(meta_module.mnet, e_dim=config["vae_noise_dim"], 112 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device) 113 | 114 | optimizer_gen = build_optimizer(hnet_gen.parameters(), config, loop="vae") 115 | optimizer_enc = build_optimizer(hnet_enc.parameters(), config, loop="vae") 116 | 117 | vae_trainer = VAETraining(meta_module, hnet_gen, hnet_enc, optimizer_gen, optimizer_enc, image_features, 118 | text_features, ques_emb, config, device) 119 | 120 | inner_optim =partial(build_optimizer, config=config, loop="inner") 121 | generated_inner_optim =partial(build_optimizer, config=config, loop="generated") 122 | 123 | # Get STD and MEAN 124 | if precomputed_latent is not None: 125 | sampled_precomputed_latent = {k:v[0:1].to(device) for (k,v) in precomputed_latent.items()} 126 | else: 127 | sampled_precomputed_latent=None 128 | 129 | _, optimized_params_mean, optimized_params_std, _ = \ 130 | vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim, precomputed_latent=sampled_precomputed_latent, 131 | batch_size=config["vae_batch_size"], 132 | output_mean_std=True, device=device, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True) 133 | 134 | optimized_params_mean, optimized_params_std=optimized_params_mean[0], optimized_params_std[0] 135 | wandb.log({"mean":optimized_params_mean, "std":optimized_params_std, "std_avr": optimized_params_std.mean().item()}) 136 | 137 | if config["normalize"]: 138 | optimized_params_std=optimized_params_std.clip(optimized_params_std.mean().item(),optimized_params_std.mean().item()) 139 | hnet_gen.set_stats(optimized_params_mean, optimized_params_std) 140 | hnet_enc.set_stats(optimized_params_mean, optimized_params_std) 141 | vae_trainer.set_stats(optimized_params_mean, optimized_params_std) 142 | 143 | for meta_epoch in range(config["vae_epochs"]): 144 | if precomputed_latent is not None: 145 | curr_sample_epoch = np.random.randint(0, precomputed_latent["clip_embedding"].shape[0]) 146 | curr_sample_batch_perm = np.random.permutation(precomputed_latent["clip_embedding"].shape[1]) 147 | sampled_precomputed_latent = {k:v[curr_sample_epoch:curr_sample_epoch+1, curr_sample_batch_perm].to(device) for (k,v) in precomputed_latent.items()} 148 | else: 149 | sampled_precomputed_latent=None 150 | 151 | vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim, precomputed_latent=sampled_precomputed_latent, 152 | batch_size=config["vae_batch_size"], 153 | train=True, device=device, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], debug=True) 154 | 155 | if (meta_epoch+1) % (config["val_epoch_interval"]//10+1) == 0: 156 | log_dict = vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim, 157 | precomputed_latent=precomputed_latent_train_eval, device=device, epoch=meta_epoch) 158 | log_metric(log_dict, "eval_train/") 159 | log_dict = vae_trainer.run_epoch(test_data, config["inner_epochs"], inner_optim, 160 | 
precomputed_latent=precomputed_latent_val_eval, device=device, epoch=meta_epoch) 161 | log_metric(log_dict, "eval_val/") 162 | 163 | log_dict = vae_trainer.run_epoch(train_data, config["inner_epochs"], inner_optim, reconstructed=True, 164 | precomputed_latent=precomputed_latent_train_eval, device=device, epoch=meta_epoch) 165 | log_metric(log_dict, "reconstr_eval_train/") 166 | log_dict = vae_trainer.run_epoch(test_data, config["inner_epochs"], inner_optim, reconstructed=True, 167 | precomputed_latent=precomputed_latent_val_eval, device=device, epoch=meta_epoch) 168 | log_metric(log_dict, "reconstr_eval_val/") 169 | 170 | if meta_epoch % config["val_epoch_interval"] == config["val_epoch_interval"]-1: 171 | 172 | if config["eval_inner_epochs"] != '': 173 | for idx, inner_epochs in enumerate(config["eval_inner_epochs"].split(",")): 174 | log_dict = vae_trainer.run_epoch(train_data, int(inner_epochs), generated_inner_optim, device=device, generated=True, epoch=meta_epoch) 175 | log_metric(log_dict, "generated_eval_train_{}step/".format(inner_epochs)) 176 | 177 | log_dict = vae_trainer.run_epoch(test_data, int(inner_epochs), generated_inner_optim, device=device, generated=True, epoch=meta_epoch) 178 | log_metric(log_dict, "generated_eval_val_{}step/".format(inner_epochs)) 179 | 180 | if config["save_checkpoint"]: 181 | hnet_gen_output_path_checkpoint = base_path + "/evaluation/vae/hnet_gen_" + str( 182 | wandb.run.name) + ".pth" 183 | hnet_enc_output_path_checkpoint = base_path + "/evaluation/vae/hnet_enc_" + str( 184 | wandb.run.name) + ".pth" 185 | torch.save(hnet_gen.state_dict(), hnet_gen_output_path_checkpoint) 186 | torch.save(hnet_enc.state_dict(), hnet_enc_output_path_checkpoint) 187 | print(f"Checkpoint for meta-epoch {meta_epoch} saved!") 188 | 189 | hnet_gen_output_path_checkpoint = base_path + "/evaluation/vae/hnet_gen_" + str( 190 | wandb.run.name) + ".pth" 191 | hnet_enc_output_path_checkpoint = base_path + "/evaluation/vae/hnet_enc_" + str( 192 | wandb.run.name) + ".pth" 193 | torch.save(hnet_gen.state_dict(), hnet_gen_output_path_checkpoint) 194 | torch.save(hnet_enc.state_dict(), hnet_enc_output_path_checkpoint) 195 | 196 | wandb.finish() 197 | 198 | 199 | if __name__ == "__main__": 200 | main(args) 201 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='hyperclip', version='0.1', packages=find_packages()) 4 | -------------------------------------------------------------------------------- /training/conditional_model_learn.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch 6 | import wandb 7 | from data.dataloader.clip_vqa import CLIP_VQA 8 | from data.dataloader.coco_tasks import COCO_Tasks 9 | from model.custom_hnet import CLIPAdapter, EmbeddingModule, MetaModel 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | from utils import clip_utils 14 | from utils.misc_utils import append_dict, mean_dict 15 | from utils.train_utils import get_pred, log_metric, test_accuracy 16 | 17 | from training.vae_learn import sample_from_enc 18 | 19 | 20 | class ConditionalModelTraining(): 21 | # Train a "conditional" model from question embeddings and finetuned network features (weights, latents) 22 | def 
__init__(self, meta_module, cond_model, optimizer, image_features, text_features, ques_emb, config, device, net_feature_dim, 23 | hnet_gen, hnet_enc, train_on_vae=False, compute_hessian=False, 24 | coco_categories=None, coco_answer_features=None, extend_coco_size=10 * 870): 25 | self.meta_module = meta_module 26 | self.cond_model = cond_model 27 | self.optimizer = optimizer 28 | self.image_features = image_features 29 | self.text_features = text_features 30 | self.ques_emb = ques_emb 31 | self.config = config 32 | self.device = device 33 | self.net_feature_dim = net_feature_dim 34 | self.compute_hessian = compute_hessian 35 | self.hnet_gen=hnet_gen 36 | self.hnet_enc=hnet_enc 37 | self.base_module=meta_module 38 | self.train_on_vae=train_on_vae 39 | 40 | self.coco_categories=coco_categories 41 | self.coco_answer_features=coco_answer_features 42 | self.extend_coco_size=extend_coco_size 43 | 44 | def reset(self, batch_size): 45 | self.sub_batch_i = 0 46 | self.ques_batch = torch.zeros(batch_size, clip_utils.embedding_size[self.config["clip_model"]], dtype=torch.float32).to(self.device) 47 | self.net_batch = torch.zeros(batch_size, self.net_feature_dim, dtype=torch.float32).to(self.device) 48 | self.task_batch = torch.zeros(batch_size, 1, dtype=torch.int).to(self.device) 49 | self.coco_batch = torch.zeros(batch_size, 1, dtype=torch.bool).to(self.device) 50 | 51 | self.hessian_batch = None 52 | if self.compute_hessian: 53 | self.hessian_batch = torch.zeros(batch_size, self.net_feature_dim, self.net_feature_dim, dtype=torch.float32).to(self.device) 54 | 55 | def reset_task(self, use_vae): 56 | if self.hnet_gen is not None and use_vae: 57 | embedding = sample_from_enc(self.hnet_enc, self.base_module.get_mainnet_weights(params = None).detach()).detach() 58 | 59 | enet = EmbeddingModule(self.hnet_gen.input_dim).to(self.device) 60 | 61 | enet.reset(embedding=embedding) 62 | mnet = CLIPAdapter(e_dim=self.meta_module.mnet.e_dim, 63 | hidden_layers=[self.meta_module.mnet.hidden_size], 64 | use_bias=self.meta_module.mnet.use_bias, 65 | straight_through=self.meta_module.mnet.straight_through, 66 | no_weights=True).to(self.device) 67 | 68 | self.meta_module = MetaModel(mnet=mnet, hnet=self.hnet_gen, enet=enet, config=self.config) 69 | else: 70 | self.meta_module = self.base_module 71 | 72 | 73 | def run_iter(self, task_idx, train_dataloader, net_features, train=True, clip_embedding=None, embed_hessian=None, coco=False, **kwargs): 74 | log_dict=dict() 75 | self.ques_batch[self.sub_batch_i] = clip_embedding if clip_embedding is not None else iter(train_dataloader).next()["ques_emb"][0] 76 | self.net_batch[self.sub_batch_i] = net_features 77 | self.task_batch[self.sub_batch_i] = task_idx 78 | self.coco_batch[self.sub_batch_i] = coco 79 | if self.compute_hessian: 80 | self.hessian_batch[self.sub_batch_i] = embed_hessian 81 | 82 | if self.sub_batch_i == self.ques_batch.shape[0]-1: # Train/test one batch of hyperclip once the batch is filled 83 | self.train_step(log_dict, **kwargs) if train else self.test_step(log_dict, **kwargs) 84 | 85 | self.sub_batch_i = (self.sub_batch_i+1) % self.ques_batch.shape[0] 86 | return log_dict 87 | 88 | def train_step(self, log_dict, **kwargs): 89 | raise NotImplementedError("Use a subclass of ConditionalModelTraining") 90 | 91 | def test_step(self, log_dict, **kwargs): 92 | raise NotImplementedError("Use a subclass of ConditionalModelTraining") 93 | 94 | def run_epoch(self, data, inner_epochs, inner_optim_fct, batch_size=32, 95 | train_subtype="train", val_subtype="test", 
guided_inner=False, 96 | precomputed_latent=None, 97 | train=False, skip_cond=False, output_mean_std=False, tasks_idxs=None, 98 | n_shot_training=None, opt_latents_for_n_shot=None, 99 | debug=False, use_vae=False, keep_tasks_frac=1, extend_coco=False, 100 | extend_coco_frac_train=0.5, num_ensemble=1,# frac of tasks to replace with extended coco 101 | epoch=0, **kwargs): 102 | tasks = list(data.keys()) 103 | if tasks_idxs is None: 104 | tasks_idxs = np.arange(len(tasks)) 105 | 106 | if batch_size > len(tasks)*keep_tasks_frac: 107 | batch_size = int(len(tasks)*keep_tasks_frac) 108 | print("Warning: batch size too big, decreasing to {}".format(batch_size)) 109 | 110 | if train: 111 | self.reset(batch_size) 112 | else: 113 | self.reset(len(data.keys())) 114 | 115 | if inner_epochs is not None: 116 | inner_epochs_range = [inner_epochs] if type(inner_epochs) == int else [int(i) for i in inner_epochs.split(",")] 117 | else: 118 | inner_epochs_range = None 119 | log_dict = dict() 120 | 121 | enable_coco = [False] * len(tasks) 122 | if precomputed_latent is None or "task_idx" not in precomputed_latent: 123 | shuffled_train_tasks = [tasks_idxs[idx] for idx in torch.randperm(len(tasks_idxs))] 124 | shuffled_coco_tasks = torch.randperm(self.extend_coco_size) 125 | shuffle_for_extended_coco_replace = torch.randperm(len(tasks)) 126 | if extend_coco: 127 | enable_coco=[shuffle_for_extended_coco_replace[i] < extend_coco_frac_train * len(tasks) for i in range(len(tasks))] 128 | else: 129 | shuffled_train_tasks=precomputed_latent["task_idx"][0].long() 130 | shuffled_coco_tasks=precomputed_latent["task_idx"][0].long() 131 | if "coco" in precomputed_latent: 132 | enable_coco=precomputed_latent["coco"][0] 133 | 134 | if output_mean_std: 135 | all_tasks_optimized_params = torch.zeros((len(tasks), self.net_feature_dim)).to(self.device) 136 | n_shot_seed = np.random.randint(0, 1000000) 137 | n_corr_guesses_support_start = 0 138 | n_corr_guesses_support_end = 0 139 | n_corr_guesses_query_start = 0 140 | n_corr_guesses_query_end = 0 141 | n_tot_samples_support = 0 142 | n_tot_samples_query = 0 143 | 144 | for inner_train_iter in tqdm(range(len(shuffled_train_tasks))): 145 | self.reset_task(use_vae) 146 | curr_log_dict = dict() 147 | if enable_coco[inner_train_iter]: 148 | task_idx = shuffled_coco_tasks[inner_train_iter] 149 | train_dataset = COCO_Tasks(categories=self.coco_categories, 150 | dataSubType=train_subtype, 151 | image_features=self.image_features, 152 | coco_answer_features=self.coco_answer_features, 153 | task_seed=task_idx) 154 | test_dataset = COCO_Tasks(categories=self.coco_categories, 155 | dataSubType=val_subtype, 156 | image_features=self.image_features, 157 | coco_answer_features=self.coco_answer_features, 158 | task_seed=task_idx) 159 | else: 160 | task_idx = shuffled_train_tasks[inner_train_iter] 161 | if task_idx > len(tasks)*keep_tasks_frac: 162 | continue 163 | train_dataset = CLIP_VQA(meta_data=data, 164 | dataSubType=train_subtype, 165 | task=tasks[task_idx], 166 | image_features=self.image_features, 167 | text_features=self.text_features, 168 | ques_emb=self.ques_emb, 169 | n_shot=n_shot_training) 170 | test_dataset = CLIP_VQA(meta_data=data, 171 | dataSubType=val_subtype, 172 | task=tasks[task_idx], 173 | image_features=self.image_features, 174 | text_features=self.text_features, 175 | ques_emb=self.ques_emb) 176 | 177 | train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True) 178 | test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False) 179 | 180 | 
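            # What follows for each sampled task, summarizing the code below:
            #   - draw the number of inner steps from `inner_epochs_range`;
            #   - clone-and-detach the meta-module's inner params so the shared initialization is
            #     never optimized in place;
            #   - adapt them either by minimizing the support-set cross-entropy with `inner_optim_fct`,
            #     by the subclass-specific `guided_inner` routine, or by loading precomputed
            #     latents/weights;
            #   - record support/query accuracy and loss before and after adaptation, and (unless
            #     `skip_cond`) pass the adapted network features to the conditional model via `run_iter`.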
# Inner loop 181 | if inner_epochs_range is not None: 182 | inner_epochs_sampled = inner_epochs_range[0] if len(inner_epochs_range) == 1 else \ 183 | np.random.randint(inner_epochs_range[0], inner_epochs_range[1]+1) 184 | else: 185 | inner_epochs_sampled = None 186 | 187 | embed_hessian=None 188 | 189 | init_inner_params = self.meta_module.get_inner_params() 190 | # Make sure to clone and detach to not optimize the actual initialization. 191 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in self.meta_module.get_inner_params().items()} 192 | if opt_latents_for_n_shot is not None: 193 | assert self.meta_module.mnet.no_weight, "Not implemented when meta_module is a mainnet." 194 | inner_params["enet.embedding"] = opt_latents_for_n_shot[task_idx].clone().detach().requires_grad_() 195 | 196 | train_start_acc, train_start_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params) 197 | val_start_acc, val_start_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params) 198 | 199 | inner_params_list=None 200 | if guided_inner: 201 | if num_ensemble>1: 202 | inner_params_list=[] 203 | for _ in range(num_ensemble): 204 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in self.meta_module.get_inner_params().items()} 205 | inner_params_list.append(inner_params) 206 | curr_log_dict.update(self.guided_inner(train_dataloader, inner_params, init_inner_params, inner_optim_fct, 207 | inner_train_iter, inner_epochs_sampled, batch_size, debug=inner_train_iter==1, **kwargs)) 208 | else: 209 | curr_log_dict.update(self.guided_inner(train_dataloader, inner_params, init_inner_params, inner_optim_fct, 210 | inner_train_iter, inner_epochs_sampled, batch_size, debug=inner_train_iter==1, **kwargs)) 211 | else: 212 | if precomputed_latent is None or n_shot_training is not None: 213 | inner_optimizer = inner_optim_fct(list(inner_params.values())) 214 | for _ in range(inner_epochs_sampled): 215 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params) 216 | inner_loss = F.cross_entropy(outputs, labels) 217 | if debug and (self.sub_batch_i+1) % batch_size == 0: 218 | wandb.log({"debug_inner_loss": inner_loss.item()}) 219 | inner_optimizer.zero_grad() 220 | inner_loss.backward() 221 | inner_optimizer.step() 222 | 223 | if self.compute_hessian: 224 | embed_hessian = self.compute_feature_hessian(train_dataloader, self.meta_module, inner_params["enet.embedding"]).squeeze() 225 | 226 | else: 227 | # If i use the precomputed latents, I "simulate" the finetuning inner loop after it's supposed 228 | # to happen, to not break metrics and mean/std calculations (super ugly, to refactor) 229 | if self.meta_module.mnet.no_weight: 230 | inner_params["enet.embedding"] = precomputed_latent["embedding"][0, inner_train_iter] 231 | else: 232 | inner_params.update({"mnet."+k:v for (k,v) in self.meta_module.mnet.load_from_vector(precomputed_latent["w_vect"][0, inner_train_iter]).items()}) 233 | 234 | if self.compute_hessian: 235 | embed_hessian = precomputed_latent["hessian"][0, inner_train_iter] 236 | 237 | # Train set accuracy 238 | if inner_params_list is not None: 239 | train_end_acc, train_end_loss = test_accuracy(self.meta_module, train_dataloader, params_list=inner_params_list) 240 | val_end_acc, val_end_loss = test_accuracy(self.meta_module, test_dataloader, params_list=inner_params_list) 241 | 242 | avr_inner_params = OrderedDict() 243 | avr_inner_params.update({k: torch.stack([p[k] for p in inner_params_list]).mean(0).detach() for 
k in inner_params.keys()}) 244 | 245 | train_end_acc_avr, train_end_loss_avr = test_accuracy(self.meta_module, train_dataloader, params=avr_inner_params) 246 | val_end_acc_avr, val_end_loss_avr = test_accuracy(self.meta_module, test_dataloader, params=avr_inner_params) 247 | 248 | curr_log_dict["support_loss_end_avr"] = train_end_loss_avr 249 | curr_log_dict["query_loss_end_avr"] = val_end_loss_avr 250 | 251 | curr_log_dict["support_accuracy_end_avr"] = train_end_acc_avr 252 | curr_log_dict["query_accuracy_end_avr"] = val_end_acc_avr 253 | 254 | else: 255 | train_end_acc, train_end_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params) 256 | val_end_acc, val_end_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params) 257 | 258 | n_tot_samples_support += len(train_dataset) 259 | n_tot_samples_query += len(test_dataset) 260 | curr_log_dict["query_accuracy_start"] = val_start_acc 261 | n_corr_guesses_query_start += val_start_acc * len(test_dataset) 262 | curr_log_dict["query_accuracy_end"] = val_end_acc 263 | n_corr_guesses_query_end += val_end_acc * len(test_dataset) 264 | curr_log_dict["support_accuracy_start"] = train_start_acc 265 | n_corr_guesses_support_start += train_start_acc * len(train_dataset) 266 | curr_log_dict["support_accuracy_end"] = train_end_acc 267 | n_corr_guesses_support_end += train_end_acc * len(train_dataset) 268 | curr_log_dict["query_loss_start"] = val_start_loss 269 | curr_log_dict["query_loss_end"] = val_end_loss 270 | curr_log_dict["support_loss_start"] = train_start_loss 271 | curr_log_dict["support_loss_end"] = train_end_loss 272 | 273 | # Actually run cond model training/eval 274 | if not skip_cond: 275 | cond_dict = self.run_iter(task_idx, train_dataloader, 276 | self.feed_net_feature(inner_params = inner_params),embed_hessian=embed_hessian, 277 | train=train, coco=enable_coco[inner_train_iter], 278 | clip_embedding=precomputed_latent["clip_embedding"][0, inner_train_iter], **kwargs) 279 | curr_log_dict.update(cond_dict) 280 | 281 | if output_mean_std: 282 | all_tasks_optimized_params[task_idx] = self.feed_net_feature(inner_params = inner_params) 283 | 284 | append_dict(log_dict, curr_log_dict) 285 | 286 | if debug and self.sub_batch_i % batch_size == 0: 287 | output_dict = mean_dict(log_dict) 288 | output_dict["query_accuracy_start_flatten"] = n_corr_guesses_query_start / n_tot_samples_query 289 | output_dict["query_accuracy_end_flatten"] = n_corr_guesses_query_end / n_tot_samples_query 290 | output_dict["support_accuracy_start_flatten"] = n_corr_guesses_support_start / n_tot_samples_support 291 | output_dict["support_accuracy_end_flatten"] = n_corr_guesses_support_end / n_tot_samples_support 292 | log_metric(output_dict, prefix="debug_") 293 | log_dict = dict() 294 | 295 | output_dict = mean_dict(log_dict) 296 | output_dict["query_accuracy_start_flatten"] = n_corr_guesses_query_start / n_tot_samples_query 297 | output_dict["query_accuracy_end_flatten"] = n_corr_guesses_query_end / n_tot_samples_query 298 | output_dict["support_accuracy_start_flatten"] = n_corr_guesses_support_start / n_tot_samples_support 299 | output_dict["support_accuracy_end_flatten"] = n_corr_guesses_support_end / n_tot_samples_support 300 | output_dict["epoch"] = epoch 301 | 302 | if not output_mean_std: 303 | return output_dict 304 | else: 305 | optimized_params_mean = torch.mean(all_tasks_optimized_params, dim=0, keepdim=True) 306 | optimized_params_std = torch.std(all_tasks_optimized_params, dim=0, keepdim=True) 307 | return output_dict, 
optimized_params_mean, optimized_params_std, all_tasks_optimized_params 308 | 309 | def guided_inner(self, train_dataloader, inner_params, init_inner_params, inner_optim_fct, 310 | inner_train_iter, inner_epochs, batch_size, debug, **kwargs): 311 | raise NotImplementedError("Use a subclass of ConditionalModelTraining") 312 | 313 | def feed_net_feature(self, **kwargs): 314 | raise NotImplementedError("Use a subclass of ConditionalModelTraining") 315 | 316 | def compute_feature_hessian(self, train_dataloader, meta_module, embedding): 317 | """ Only supports when the embedding is the net feature. Would generalize this if needed. """ 318 | def get_nll(embedding): 319 | curr_params = OrderedDict() 320 | curr_params["enet.embedding"] = embedding 321 | outputs, labels = get_pred(meta_module, train_dataloader, params=curr_params) 322 | inner_loss = F.cross_entropy(outputs, labels) 323 | return inner_loss 324 | 325 | return torch.autograd.functional.hessian(get_nll, embedding) 326 | 327 | -------------------------------------------------------------------------------- /training/hyperclip_learn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import wandb 5 | from torch.nn import functional as F 6 | from utils.train_utils import get_pred, test_accuracy 7 | 8 | from training.conditional_model_learn import ConditionalModelTraining 9 | from training.vae_learn import sample_from_enc 10 | 11 | 12 | class HyperclipTraining(ConditionalModelTraining): 13 | def __init__(self, meta_module, hnet_gen, hnet_enc, hyperclip, optimizer, 14 | image_features, text_features, ques_emb, config, device, train_on_vae=False, **kwargs): 15 | net_feature_dim = hyperclip.hyper.input_dim 16 | super().__init__(meta_module, hyperclip, optimizer, image_features, text_features, ques_emb, 17 | config, device, net_feature_dim, hnet_gen, hnet_enc, train_on_vae=train_on_vae, **kwargs) 18 | self.feature_mean=None 19 | self.feature_std=None 20 | 21 | def train_step(self, log_dict): 22 | self.optimizer.zero_grad() 23 | 24 | net_batch = self.net_batch 25 | if self.feature_mean is not None: 26 | net_batch = net_batch - self.feature_mean 27 | net_batch = net_batch / self.feature_std 28 | 29 | _, hyperclip_loss = self.cond_model.forward(hyper = net_batch, text = self.ques_batch, precomputed_it_embs = True) 30 | log_dict.update({"hyperclip_loss": hyperclip_loss.detach().cpu().numpy()}) 31 | hyperclip_loss.backward() 32 | self.optimizer.step() 33 | 34 | def test_step(self, log_dict): 35 | with torch.no_grad(): 36 | self.cond_model.eval() 37 | 38 | net_batch = self.net_batch 39 | if self.feature_mean is not None: 40 | net_batch = net_batch-self.feature_mean 41 | net_batch = net_batch / self.feature_std 42 | 43 | _, hyperclip_loss = self.cond_model.forward(hyper = net_batch, 44 | text = self.ques_batch, precomputed_it_embs = True) 45 | log_dict.update({"hyperclip_val_loss": hyperclip_loss.detach().cpu().numpy()}) 46 | self.cond_model.train() 47 | 48 | 49 | def set_stats(self, mean, std): 50 | self.feature_mean= mean 51 | self.feature_std=std 52 | 53 | 54 | def get_compute_grad(self, train_dataloader): 55 | def compute_grad(inner_params): 56 | with torch.enable_grad(): 57 | x = inner_params["enet.embedding"] 58 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params) 59 | inner_loss = F.cross_entropy(outputs, labels) 60 | grad =torch.autograd.grad(inner_loss, x) 61 | return grad[0] 62 | return compute_grad 63 | 64 | 65 | def 
guided_inner(self, train_dataloader, inner_params, init_inner_params, inner_optim_fct, 66 | inner_train_iter, inner_epochs, batch_size, debug, langevin_eps=0, guidance_scheduler_fn=None, guidance_init_l2_weight=0, **kwargs): 67 | 68 | # if eval: 69 | compute_grad_fn = self.get_compute_grad(train_dataloader) 70 | 71 | inner_optimizer = inner_optim_fct(list(inner_params.values())) 72 | if guidance_scheduler_fn is not None: 73 | guidance_scheduler = guidance_scheduler_fn(inner_optimizer) 74 | else: 75 | guidance_scheduler = None 76 | 77 | 78 | log_dict=dict() 79 | log_dict["cos_sim"]=[] 80 | log_dict["acc"]=[] 81 | log_dict["loss"]=[] 82 | 83 | for _ in range(inner_epochs): 84 | task_ques_emb = next(iter(train_dataloader))["ques_emb"][0] 85 | weights = self.meta_module.get_mainnet_weights(ques_emb=task_ques_emb, params=inner_params) 86 | 87 | if self.feature_mean is not None: 88 | weights = weights - self.feature_mean 89 | weights = weights / self.feature_std 90 | 91 | hyperclip = self.cond_model 92 | 93 | task_weight_emb = hyperclip.encode_hyper(weights) 94 | 95 | norm_task_ques_emb = task_ques_emb / task_ques_emb.norm(dim=-1, keepdim=True) 96 | norm_task_weight_emb = task_weight_emb / task_weight_emb.norm(dim=-1, keepdim=True) 97 | 98 | inner_product_embs_loss = - norm_task_weight_emb @ norm_task_ques_emb.T 99 | 100 | init_l2_loss = torch.stack( 101 | [(ip - p).pow(2).sum() for (ip, p) 102 | in zip(init_inner_params.values(), inner_params.values())]).sum() / 2 103 | inner_loss = inner_product_embs_loss + init_l2_loss* guidance_init_l2_weight 104 | 105 | inner_optimizer.zero_grad() 106 | inner_loss.backward() 107 | 108 | if True: #eval: 109 | alignment = torch.nn.CosineSimilarity(dim=0)(inner_params["enet.embedding"].grad.view(-1), compute_grad_fn(inner_params).view(-1)) 110 | log_dict["cos_sim"].append(alignment.item()) 111 | a, l = test_accuracy(self.meta_module, train_dataloader, params=inner_params) 112 | log_dict["acc"].append(a) 113 | log_dict["loss"].append(l) 114 | 115 | if debug: 116 | wandb.log({"debug_inner_guidance_loss": inner_product_embs_loss.item(), 117 | "debug_inner_l2_loss": init_l2_loss.item(), 118 | "inner_lr": inner_optimizer.param_groups[0]['lr'], 119 | "gradnorm": torch.stack([p.grad.pow(2).sum() for p in inner_params.values() if p.grad is not None]).sum().item()}) 120 | 121 | 122 | if langevin_eps>0: 123 | for p in inner_params.values(): 124 | p.grad.data += langevin_eps*torch.randn_like(p.grad.data)/math.sqrt(inner_optimizer.param_groups[0]['lr']) 125 | 126 | inner_optimizer.step() 127 | if guidance_scheduler is not None: 128 | guidance_scheduler.step() 129 | 130 | return log_dict 131 | 132 | 133 | def feed_net_feature(self, inner_params=None): 134 | if self.train_on_vae and self.hnet_gen is not None: 135 | return self.hnet_gen(sample_from_enc(self.hnet_enc, self.base_module.get_mainnet_weights(params = inner_params).detach())).detach() 136 | return self.base_module.get_mainnet_weights(params = inner_params).detach() 137 | -------------------------------------------------------------------------------- /training/latent_diffusion_learn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import wandb 5 | from torch.nn import functional as F 6 | from utils.train_utils import get_pred 7 | 8 | from training.conditional_model_learn import ConditionalModelTraining 9 | from training.vae_learn import sample_from_enc 10 | 11 | 12 | class 
LatentDiffusionTraining(ConditionalModelTraining): 13 | def __init__(self, 14 | meta_module, 15 | diffusion_model, 16 | optimizer, 17 | n_timestep, 18 | beta, 19 | ema, 20 | net_feature_dim, 21 | train_data_for_mean_std, 22 | inner_opt_for_mean_std, 23 | image_features, 24 | text_features, 25 | ques_emb, 26 | config, 27 | device, 28 | hnet_gen, hnet_enc, train_on_vae=False, 29 | compute_hessian=False, 30 | compute_latents_mean_std = True, 31 | v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta 32 | class_free_guidance=True, 33 | class_free_training_uncond_rate=0.2, 34 | precomputed_latent_for_mean_std=None, 35 | **kwargs 36 | ): 37 | 38 | super().__init__(meta_module, diffusion_model, optimizer, image_features, text_features, 39 | ques_emb, config, device, net_feature_dim, hnet_gen, hnet_enc, train_on_vae, compute_hessian, **kwargs) 40 | self.n_timestep = n_timestep 41 | self.beta = torch.tensor(beta).to(self.device) 42 | self.ema = ema 43 | 44 | self.alpha = (1 - self.beta) 45 | self.alpha_cumprod = self.alpha.log().cumsum(0).exp().float() 46 | self.alpha_cumprod_prev = torch.cat((torch.ones(1, device = self.device), self.alpha_cumprod[:-1])) 47 | self.v_posterior = v_posterior 48 | self.sigma = ((1 - self.v_posterior) * (self.beta * (1 - self.alpha_cumprod_prev) / (1 - self.alpha_cumprod)) + self.v_posterior * self.beta).sqrt().float() 49 | self.beta, self.alpha = self.beta.float(), self.alpha.float() 50 | 51 | self.class_free_guidance = class_free_guidance 52 | self.class_free_training_uncond_rate = class_free_training_uncond_rate 53 | 54 | if compute_latents_mean_std: 55 | _, self.latents_mean, self.latents_std, _ = self.run_epoch(train_data_for_mean_std, config["inner_epochs"], inner_opt_for_mean_std, 56 | batch_size=config["diffusion_batch_size"], 57 | train=True, train_subtype = config["train_subtype"], val_subtype=config["val_subtype"], 58 | precomputed_latent=precomputed_latent_for_mean_std, 59 | skip_cond=True, output_mean_std=True, keep_tasks_frac=config["diffusion_keep_tasks_frac"]) 60 | else: 61 | self.latents_mean, self.latents_std = torch.zeros((1,self.net_feature_dim)).to(self.device), torch.ones((1,self.net_feature_dim)).to(self.device) 62 | 63 | wandb.log({"latents_mean": self.latents_mean.detach().cpu().numpy()}) 64 | wandb.log({"latents_std": self.latents_std.detach().cpu().numpy()}) 65 | 66 | def forward_diffusion(self, latents_batch, ques_batch, metric=None): 67 | norm_latents_batch = (latents_batch - self.latents_mean) / self.latents_std 68 | t = torch.randint(self.n_timestep, (norm_latents_batch.size(0),) + (1,) * (norm_latents_batch.dim() - 1), device = self.device) 69 | eps = torch.randn_like(norm_latents_batch) 70 | norm_latents_batch_t = torch.sqrt(self.alpha_cumprod[t]) * norm_latents_batch + torch.sqrt(1 - self.alpha_cumprod[t]) * eps 71 | output = self.cond_model(norm_latents_batch_t, t, ques_batch) 72 | if metric is not None: 73 | loss = (((eps - output).unsqueeze(1) @ metric @ (eps - output).unsqueeze(-1)) / metric.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)).mean() 74 | else: 75 | loss = (eps - output).pow(2).mean() 76 | return loss 77 | 78 | def train_step(self, log_dict): 79 | self.optimizer.zero_grad() 80 | 81 | ques_batch = torch.clone(self.ques_batch) 82 | if self.class_free_guidance: 83 | perm = torch.randperm(ques_batch.shape[0]) 84 | idx = perm[:int(perm.shape[0]*self.class_free_training_uncond_rate)] 85 | ques_batch[idx] = torch.zeros(ques_batch.shape[1]).to(self.device) 86 | 87 | loss = 
self.forward_diffusion(self.net_batch, ques_batch, self.hessian_batch) 88 | 89 | log_dict.update({"diffusion_loss": loss.detach().cpu().numpy()}) 90 | loss.backward() 91 | self.optimizer.step() 92 | 93 | if self.ema is not None: self.ema.step() 94 | 95 | def test_step(self, log_dict): 96 | with torch.no_grad(): 97 | self.cond_model.eval() 98 | 99 | ques_batch = torch.clone(self.ques_batch) 100 | if self.class_free_guidance: 101 | perm = torch.randperm(ques_batch.shape[0]) 102 | idx = perm[:int(perm.shape[0]*self.class_free_training_uncond_rate)] 103 | ques_batch[idx] = torch.zeros(ques_batch.shape[1]).to(self.device) 104 | 105 | loss = self.forward_diffusion(self.net_batch, ques_batch, self.hessian_batch) 106 | 107 | log_dict.update({"diffusion_val_loss": loss.detach().cpu().numpy()}) 108 | self.cond_model.train() 109 | 110 | def get_compute_grad(self, train_dataloader): 111 | def compute_grad(x): 112 | with torch.enable_grad(): 113 | inner_params = OrderedDict() 114 | x = x.detach().requires_grad_() 115 | inner_params["enet.embedding"] = x * self.latents_std + self.latents_mean 116 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params) 117 | inner_loss = F.cross_entropy(outputs, labels) 118 | grad =torch.autograd.grad(inner_loss, x) 119 | return grad[0] 120 | return compute_grad 121 | 122 | def guided_inner(self, train_dataloader, inner_params, init_inner_params, inner_optim_fct, 123 | inner_train_iter, inner_epochs, batch_size, debug, class_guidance_gamma=None, 124 | init_guidance_at="random", guidance_start_from_t_frac=1, fast_sampling_factor=None, few_shot_guidance=False, few_shot_gamma=1, **kwargs): 125 | #IMPORTANT: x_accuracy_start printed in wandb is not the one for the random initialized embedding (to which the diffusion is applied to) 126 | # but the one for the initial emb initialization we are NOT using (except if we start from pre-trained !) 127 | if few_shot_guidance: 128 | compute_grad_fn = self.get_compute_grad(train_dataloader) 129 | else: 130 | compute_grad_fn=None 131 | if class_guidance_gamma is None: 132 | class_guidance_gamma = 1. 
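        # Sampling below is standard DDPM ancestral sampling in the normalized latent space.
        # With classifier-free guidance enabled (and gamma != 1) the noise prediction is combined as
        #     eps_hat = (1 - gamma) * eps_theta(x_t, t, 0) + gamma * eps_theta(x_t, t, c)
        # and each reverse step computes
        #     x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_hat) + sigma_t * z
        # with z ~ N(0, I) (z = 0 at t = 0). When `few_shot_guidance` is on, a rescaled support-set
        # gradient is added to eps_hat. See `generate` further down for the implementation.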
133 | if init_guidance_at=="random": 134 | latent_start_point = None 135 | elif init_guidance_at=="pre-trained": 136 | latent_start_point = init_inner_params["enet.embedding"].clone().detach() 137 | sampled_emb = self.generate(next(iter(train_dataloader))["ques_emb"][0], 138 | self.class_free_guidance, 139 | class_guidance_gamma, 140 | latent_start_point=latent_start_point, 141 | start_from_t_frac=guidance_start_from_t_frac, 142 | fast_sampling_factor=fast_sampling_factor, 143 | compute_grad_fn=compute_grad_fn, 144 | few_shot_gamma=few_shot_gamma) 145 | inner_params["enet.embedding"] = sampled_emb # refactor to be more general 146 | return dict() 147 | 148 | 149 | def generate(self, cond_emb, class_free_guidance, gamma, latent_start_point=None, start_from_t_frac=1., fast_sampling_factor=None, compute_grad_fn=None, few_shot_gamma=1): 150 | 151 | with torch.no_grad(): 152 | 153 | alpha_cumprod = self.alpha_cumprod 154 | sigma = self.sigma 155 | if fast_sampling_factor is not None and fast_sampling_factor > 1: 156 | alpha_cumprod = self.alpha_cumprod[::fast_sampling_factor] 157 | alpha_cumprod_prev = torch.cat((torch.ones(1, device = self.device), alpha_cumprod[:-1])) 158 | beta = 1 - (alpha_cumprod / alpha_cumprod_prev) 159 | sigma = ((1 - self.v_posterior) * (beta * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) + self.v_posterior * beta).sqrt().float() 160 | 161 | if latent_start_point is None: 162 | x = torch.randn(self.latents_mean.shape, device = self.device) 163 | else: 164 | x = (latent_start_point - self.latents_mean) / self.latents_std 165 | 166 | for t in range(int(sigma.shape[0]*start_from_t_frac)-1, -1, -1): 167 | t_for_model = t 168 | if fast_sampling_factor is not None and fast_sampling_factor > 1: 169 | t_for_model = t * fast_sampling_factor 170 | if not class_free_guidance or gamma==1.: 171 | output = self.cond_model(x, t_for_model, cond_emb) 172 | else: 173 | output = (1 - gamma) * self.cond_model(x, t_for_model, torch.zeros_like(cond_emb)) + gamma * self.cond_model(x, t_for_model, cond_emb) 174 | 175 | if compute_grad_fn is not None: 176 | grad = compute_grad_fn(x) 177 | grad = grad / grad.norm()*output.norm()*few_shot_gamma 178 | output = output + grad 179 | 180 | z = torch.zeros_like(x) if t == 0 else torch.randn_like(x) 181 | x = 1/torch.sqrt(self.alpha[t]) * (x - (1-self.alpha[t]) / torch.sqrt(1-alpha_cumprod[t]) * output) + sigma[t] * z 182 | 183 | x = x * self.latents_std + self.latents_mean 184 | 185 | return x 186 | 187 | def feed_net_feature(self, inner_params=None): 188 | if self.train_on_vae and self.hnet_gen is not None: 189 | return sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params = inner_params).detach()).detach() 190 | # works if inner_params contains only one element, the hnet/vae embedding 191 | return inner_params["enet.embedding"].detach() 192 | 193 | -------------------------------------------------------------------------------- /training/maml_learn.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import torch 7 | import wandb 8 | from data.dataloader.clip_vqa import CLIP_VQA 9 | from data.dataloader.coco_tasks import COCO_Tasks 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | from utils.misc_utils import append_dict, mean_dict 14 | from utils.train_utils import get_pred, log_metric, test_accuracy 15 | 16 | 17 | class 
GradientBuffer(): 18 | def __init__(self, param_list): 19 | self.param_list=param_list 20 | self.reset_buffer() 21 | 22 | def reset_buffer(self): 23 | self.grad_list=[torch.zeros_like(param) for param in self.param_list] 24 | self.num = 0 25 | 26 | def accumulate(self): 27 | for param, grad in zip(self.param_list, self.grad_list): 28 | if param.grad is not None: 29 | grad += param.grad.data 30 | self.num += 1 31 | 32 | def unload(self): 33 | for param, grad in zip(self.param_list, self.grad_list): 34 | if param.grad is not None: 35 | param.grad.data += grad/self.num 36 | self.reset_buffer() 37 | 38 | class MAML(): 39 | def __init__(self, meta_module, meta_optimizer, image_features, text_features, ques_emb, config, coco_categories=None, coco_answer_features=None, 40 | extend_coco_size=10 * 870, # size of virtual extended coco dataset 41 | ): 42 | self.meta_module=meta_module 43 | self.meta_optimizer=meta_optimizer 44 | self.image_features=image_features 45 | self.text_features=text_features 46 | self.ques_emb=ques_emb 47 | self.config=config 48 | self.coco_categories=coco_categories 49 | self.coco_answer_features=coco_answer_features 50 | self.extend_coco_size=extend_coco_size 51 | self.reset_coco() 52 | 53 | def reset(self, batch_size, device): 54 | self.buffer = GradientBuffer(self.meta_module.meta_params) 55 | self.batch_size=batch_size 56 | self.iter=0 57 | 58 | def reset_coco(self): 59 | self.coco_iter = 0 60 | self.shuffled_coco_tasks = torch.randperm(self.extend_coco_size) 61 | 62 | def train(self, test_dataloader, inner_params): 63 | # Validation loss 64 | log_dict=dict() 65 | self.meta_module.zero_grad() 66 | outputs, labels = get_pred(self.meta_module, test_dataloader, params=inner_params) 67 | loss=F.cross_entropy(outputs, labels) 68 | loss.backward() 69 | self.buffer.accumulate() 70 | 71 | if (self.iter % self.batch_size == 0): 72 | self.buffer.unload() 73 | self.meta_optimizer.step() 74 | if self.config["meta_grad_clip"] > 0: 75 | torch.nn.utils.clip_grad_norm_(self.meta_module.parameters(), self.config["meta_grad_clip"]) 76 | 77 | def get_gradnorm(module): 78 | return np.sqrt(np.sum([p.grad.pow(2).sum().item() for p in module.parameters() if p.grad is not None])) if module is not None else -1 79 | 80 | log_dict["gradnorm_mnet"] = get_gradnorm(self.meta_module.mnet) 81 | log_dict["gradnorm_hnet"] = get_gradnorm(self.meta_module.hnet) 82 | log_dict["gradnorm_enet"] = get_gradnorm(self.meta_module.enet) 83 | self.iter = self.iter+1 84 | 85 | return log_dict 86 | 87 | def run_epoch(self, data, inner_epochs, inner_lr, meta_batch_size=1, train=False, second_order=False, 88 | train_subtype="train", val_subtype="test", keep_tasks_frac=1., extend_coco=False, 89 | extend_coco_frac_train=0.5, # frac of tasks to replace with extended coco 90 | debug=False, device=None, filter_tasks_by_max_k=None, filter_tasks_answers=None, n_shot_training=None, epoch=0): 91 | 92 | tasks = list(data.keys()) 93 | 94 | if inner_epochs is not None: 95 | inner_epochs_range = [inner_epochs] if type(inner_epochs) == int else [int(i) for i in inner_epochs.split(",")] 96 | else: 97 | inner_epochs_range = None 98 | 99 | if train: 100 | self.reset(meta_batch_size, None) 101 | if extend_coco and self.coco_iter >= self.extend_coco_size: 102 | self.reset_coco() 103 | 104 | log_dict = dict() 105 | 106 | tasks_idxs = np.arange(len(tasks)) 107 | if filter_tasks_by_max_k is not None: 108 | if not filter_tasks_answers: 109 | tasks_idxs = [i for i, t in enumerate(tasks) if np.min(np.unique([a for [_,a] in 
data[t][train_subtype]], return_counts=True)[1]) >= filter_tasks_by_max_k] 110 | else: 111 | data = copy.deepcopy(data) 112 | for t in tasks: 113 | ans, count = np.unique([a for [_,a] in data[t][train_subtype]], return_counts=True) 114 | filtered_ans = ans[count >= filter_tasks_by_max_k] 115 | data[t][train_subtype] = [d for d in data[t][train_subtype] if d[1] in filtered_ans] 116 | data[t][val_subtype] = [d for d in data[t][val_subtype] if d[1] in filtered_ans] 117 | tasks_idxs = [i for i, t in enumerate(tasks) if len(np.unique([a for [_,a] in data[t][train_subtype]])) >= 2] 118 | 119 | tasks = list(data.keys()) 120 | if meta_batch_size > len(tasks_idxs)*keep_tasks_frac: 121 | meta_batch_size = int(len(tasks_idxs)*keep_tasks_frac) 122 | print("Warning: batch size too big, decreasing to {}".format(meta_batch_size)) 123 | 124 | shuffled_train_tasks = [tasks_idxs[idx] for idx in torch.randperm(len(tasks_idxs))] 125 | shuffle_for_extended_coco_replace = torch.randperm(len(tasks_idxs)) 126 | 127 | for inner_train_iter in tqdm(range(len(tasks_idxs))): 128 | curr_log_dict = dict() 129 | enable_coco = extend_coco and shuffle_for_extended_coco_replace[inner_train_iter] < extend_coco_frac_train * len(tasks_idxs) 130 | if enable_coco: 131 | train_dataset = COCO_Tasks(categories=self.coco_categories, 132 | dataSubType=train_subtype, 133 | image_features=self.image_features, 134 | coco_answer_features=self.coco_answer_features, 135 | task_seed=self.shuffled_coco_tasks[self.coco_iter]) 136 | test_dataset = COCO_Tasks(categories=self.coco_categories, 137 | dataSubType=val_subtype, 138 | image_features=self.image_features, 139 | coco_answer_features=self.coco_answer_features, 140 | task_seed=self.shuffled_coco_tasks[self.coco_iter]) 141 | else: 142 | task_idx = shuffled_train_tasks[inner_train_iter] 143 | if task_idx > len(tasks)*keep_tasks_frac: 144 | continue 145 | 146 | train_idx=None 147 | val_idx=None 148 | if val_subtype == "random": 149 | num_data = len(data[tasks[task_idx]]["train"] + data[tasks[task_idx]]["test"]) 150 | randfrac = 2/3 151 | randperm = np.random.permutation(num_data) 152 | val_idx = randperm[:math.floor(num_data*randfrac)] 153 | train_idx = randperm[math.floor(num_data*randfrac):] 154 | 155 | train_dataset = CLIP_VQA(meta_data=data, 156 | dataSubType=train_subtype, 157 | task=tasks[task_idx], 158 | image_features=self.image_features, 159 | text_features=self.text_features, 160 | ques_emb=self.ques_emb, 161 | data_idx=train_idx, 162 | n_shot=n_shot_training) 163 | test_dataset = CLIP_VQA(meta_data=data, 164 | dataSubType=val_subtype, 165 | task=tasks[task_idx], 166 | image_features=self.image_features, 167 | text_features=self.text_features, 168 | ques_emb=self.ques_emb, 169 | data_idx=val_idx 170 | ) 171 | 172 | train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True) 173 | test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False) 174 | 175 | inner_params = self.meta_module.get_inner_params() 176 | 177 | task_embedding = iter(train_dataloader).next()["ques_emb"][0].to(device).detach() 178 | if self.config["use_clip_embedding_init"]: 179 | assert self.meta_module.hnet is not None 180 | inner_params["enet.embedding"] = task_embedding.requires_grad_() 181 | 182 | train_start_acc, train_start_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params) 183 | val_start_acc, val_start_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params) 184 | 185 | if inner_epochs_range is not None: 186 | inner_epochs_sampled 
= inner_epochs_range[0] if len(inner_epochs_range) == 1 else \ 187 | np.random.randint(inner_epochs_range[0], inner_epochs_range[1]+1) 188 | else: 189 | inner_epochs_sampled = None 190 | 191 | # Inner loop 192 | for _ in range(inner_epochs_sampled): 193 | outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params) 194 | inner_loss = F.cross_entropy(outputs, labels) 195 | 196 | if debug and inner_train_iter % meta_batch_size == 0: 197 | wandb.log({"debug_inner_loss": inner_loss.item()}) 198 | grads = torch.autograd.grad(inner_loss, inner_params.values(), retain_graph=True, 199 | create_graph=True if train and second_order else False) 200 | params_next = OrderedDict() 201 | for (name, param), grad in zip(list(inner_params.items()), grads): 202 | params_next[name] = param - inner_lr * grad 203 | inner_params = params_next 204 | 205 | # Train set accuracy 206 | train_end_acc, train_end_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params) 207 | val_end_acc, val_end_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params) 208 | 209 | curr_log_dict["query_accuracy_start"] = val_start_acc 210 | curr_log_dict["query_accuracy_end"] = val_end_acc 211 | curr_log_dict["support_accuracy_start"] = train_start_acc 212 | curr_log_dict["support_accuracy_end"] = train_end_acc 213 | curr_log_dict["query_loss_start"] = val_start_loss 214 | curr_log_dict["query_loss_end"] = val_end_loss 215 | curr_log_dict["support_loss_start"] = train_start_loss 216 | curr_log_dict["support_loss_end"] = train_end_loss 217 | 218 | if train: 219 | curr_log_dict.update(self.train(test_dataloader, inner_params)) 220 | 221 | append_dict(log_dict, curr_log_dict) 222 | 223 | if debug and inner_train_iter % meta_batch_size == 0: 224 | log_metric(mean_dict(log_dict), prefix = "debug_") 225 | log_dict=dict() 226 | 227 | if enable_coco: 228 | self.coco_iter += 1 229 | 230 | output_dict = mean_dict(log_dict) 231 | output_dict["epoch"] = epoch 232 | 233 | return output_dict 234 | -------------------------------------------------------------------------------- /training/store_few_shot_latent.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from training.conditional_model_learn import ConditionalModelTraining 7 | 8 | 9 | class StoreFewShotLatent(ConditionalModelTraining): 10 | def __init__(self, meta_module, image_features, text_features, ques_emb, config, device, compute_hessian=False, 11 | reset_normal_embedding=False, **kwargs): 12 | assert hasattr(meta_module, "hnet") 13 | feature_dim = np.sum([v.numel() for v in meta_module.get_inner_params().values()]) 14 | super().__init__(meta_module, None, None, image_features, text_features, ques_emb, config, device, feature_dim, 15 | None, None, compute_hessian=compute_hessian, **kwargs) 16 | self.reset_normal_embedding=reset_normal_embedding 17 | 18 | def reset_task(self, guided_inner): 19 | if self.reset_normal_embedding: 20 | self.meta_module.enet.reset() 21 | 22 | def train_step(self, log_dict, log_file=None): 23 | clip_embedding=[] 24 | embedding=[] 25 | w_vect=[] 26 | task_idx=[] 27 | hessian=[] 28 | coco=[] 29 | 30 | if path.exists(log_file): 31 | matrix_dict = torch.load(log_file) 32 | clip_embedding.append(matrix_dict["clip_embedding"]) 33 | coco.append(matrix_dict["coco"]) 34 | 35 | if self.base_module.mnet.no_weight: 36 | embedding.append(matrix_dict["embedding"]) 37 | else: 38 | 
w_vect.append(matrix_dict["w_vect"]) 39 | 40 | task_idx.append(matrix_dict["task_idx"]) 41 | if self.compute_hessian: 42 | hessian.append(matrix_dict["hessian"]) 43 | 44 | clip_embedding.append(self.ques_batch.unsqueeze(0).cpu()) 45 | coco.append(self.coco_batch.unsqueeze(0).cpu()) 46 | if self.base_module.mnet.no_weight: 47 | embedding.append(self.net_batch.unsqueeze(0).cpu()) 48 | else: 49 | w_vect.append(self.net_batch.unsqueeze(0).cpu()) 50 | 51 | task_idx.append(self.task_batch.unsqueeze(0).cpu()) 52 | if self.compute_hessian: 53 | hessian.append(self.hessian_batch.unsqueeze(0).cpu()) 54 | 55 | matrix_dict=dict() 56 | matrix_dict["clip_embedding"]=torch.cat(clip_embedding, dim=0) 57 | matrix_dict["coco"]=torch.cat(coco, dim=0) 58 | 59 | if self.base_module.mnet.no_weight: 60 | matrix_dict["embedding"]=torch.cat(embedding, dim=0) 61 | else: 62 | matrix_dict["w_vect"]=torch.cat(w_vect, dim=0) 63 | 64 | matrix_dict["task_idx"]=torch.cat(task_idx, dim=0) 65 | if self.compute_hessian: 66 | matrix_dict["hessian"]=torch.cat(hessian, dim=0) 67 | 68 | print("Saving new tensor") 69 | torch.save(matrix_dict, log_file) 70 | 71 | def test_step(self, log_dict): 72 | pass 73 | 74 | def guided_inner(self, *args, **kwargs): 75 | pass 76 | 77 | def feed_net_feature(self, inner_params): 78 | if self.base_module.mnet.no_weight: 79 | return inner_params["enet.embedding"].detach() 80 | return self.base_module.get_mainnet_weights(params = inner_params).detach() 81 | -------------------------------------------------------------------------------- /training/vae_learn.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import wandb 6 | from data.dataloader.clip_vqa import CLIP_VQA 7 | from model.custom_hnet import CLIPAdapter, EmbeddingModule, MetaModel 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | from utils.misc_utils import append_dict, mean_dict 12 | from utils.train_utils import get_pred, log_metric, test_accuracy 13 | 14 | 15 | def sample_from_enc(encoder, real_imgs): 16 | vect = encoder(real_imgs) 17 | enc_dim = vect.shape[1] // 2 18 | mu, logvar = vect[:, :enc_dim], vect[:, enc_dim:] - 1 19 | std = torch.exp(0.5 * logvar) 20 | eps = torch.randn_like(std) 21 | encoding = eps * std + mu 22 | return encoding 23 | 24 | 25 | class VAETraining(): 26 | def __init__(self, meta_module, hnet_gen, hnet_enc, optimizer_gen, optimizer_enc, image_features, text_features, 27 | ques_emb, config, device): 28 | self.meta_module = meta_module 29 | self.hnet_gen = hnet_gen 30 | self.hnet_enc = hnet_enc 31 | self.optimizer_gen = optimizer_gen 32 | self.optimizer_enc = optimizer_enc 33 | 34 | self.image_features = image_features 35 | self.text_features = text_features 36 | self.ques_emb = ques_emb 37 | self.config = config 38 | self.device=device 39 | self.feature_mean=None 40 | self.feature_std=None 41 | 42 | def set_stats(self, mean, std): 43 | self.feature_mean= mean 44 | self.feature_std=std 45 | 46 | def reset(self, batch_size, device): 47 | self.vae_sub_batch_i = 0 48 | self.weight_batch = torch.zeros(batch_size, self.hnet_enc.input_dim, dtype=torch.float32).to(device) 49 | 50 | def train(self, model_weights): 51 | log_dict=dict() 52 | 53 | self.weight_batch[self.vae_sub_batch_i] = model_weights 54 | 55 | if self.vae_sub_batch_i == self.weight_batch.shape[0]-1: 56 | real_imgs = self.weight_batch 57 | encoder=self.hnet_enc 58 | 
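            # The block below is a standard beta-VAE objective on flattened adapted weights w,
            # with both terms divided by sqrt(dim(w)) and the residual optionally divided by the
            # stored per-weight std:
            #     z = mu + sigma * eps,  eps ~ N(0, I)                        (reparameterization)
            #     L = ||G(z) - w||^2 / sqrt(d) + kld_weight * KL(N(mu, sigma^2) || N(0, I)) / sqrt(d)
            # where G is the generator hypernetwork and (mu, logvar) come from the encoder.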
generator=self.hnet_gen 59 | kld_weight=self.config["kld_weight"] 60 | 61 | vect = encoder(real_imgs) 62 | enc_dim = vect.shape[1] // 2 63 | mu, logvar = vect[:, :enc_dim], vect[:, enc_dim:] - 1 64 | std = torch.exp(0.5 * logvar) 65 | eps = torch.randn_like(std) 66 | factor = math.sqrt(real_imgs.shape[-1]) 67 | 68 | encoding = eps * std + mu 69 | generated_imgs = generator(encoding) 70 | 71 | recons = (generated_imgs - real_imgs) 72 | 73 | if self.feature_std is not None: 74 | recons = recons / self.feature_std 75 | 76 | recons_loss = recons.pow(2).mean(0).sum() / factor 77 | 78 | kld_loss = kld_weight * torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), 79 | dim=0) / factor 80 | 81 | loss = (recons_loss + kld_loss) 82 | 83 | 84 | self.optimizer_gen.zero_grad() 85 | self.optimizer_enc.zero_grad() 86 | loss.backward() 87 | log_dict = {"loss": loss.item(), "kld_loss": kld_loss.item(), "recons_loss": recons_loss.item(), 88 | "encoding_norm": encoding.pow(2).mean().sqrt().item(), 89 | "norm_fake": generated_imgs.norm(dim=1).mean(dim=0).item(), 90 | "norm_real": real_imgs.norm(dim=1).mean(dim=0).item(), 91 | "grad_norm_enc": torch.stack([p.grad.pow(2).sum() for p in encoder.parameters() if p.grad is not None]).sum().sqrt().item(), 92 | "grad_norm_gen": torch.stack([p.grad.pow(2).sum() for p in generator.parameters() if p.grad is not None]).sum().sqrt().item()} 93 | 94 | 95 | if self.config["grad_clip"] > 0: 96 | torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=self.config["grad_clip"]) 97 | torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=self.config["grad_clip"]) 98 | 99 | log_dict.update({"grad_norm_enc_clipped": torch.stack([p.grad.pow(2).sum() for p in encoder.parameters() if p.grad is not None]).sum().sqrt().item(), 100 | "grad_norm_gen_clipped": torch.stack([p.grad.pow(2).sum() for p in generator.parameters() if p.grad is not None]).sum().sqrt().item()}) 101 | 102 | self.optimizer_gen.step() 103 | self.optimizer_enc.step() 104 | 105 | self.vae_sub_batch_i = (self.vae_sub_batch_i + 1) % self.weight_batch.shape[0] 106 | return log_dict 107 | 108 | def run_epoch(self, data, inner_epochs, inner_optim_fct, batch_size=32, device=None, 109 | train_subtype="train", val_subtype="test", 110 | train=False, generated=False, reconstructed=False, debug=False, precomputed_latent=None, output_mean_std=False, epoch=0): 111 | if train: 112 | self.reset(batch_size, device) 113 | 114 | log_dict = dict() 115 | tasks = list(data.keys()) 116 | if precomputed_latent is None or "task_idx" not in precomputed_latent: 117 | shuffled_train_tasks = torch.randperm(len(tasks)) 118 | else: 119 | shuffled_train_tasks=precomputed_latent["task_idx"][0] 120 | 121 | if output_mean_std: 122 | all_tasks_optimized_params = torch.zeros((len(tasks), self.hnet_enc.input_dim)).to(self.device) 123 | 124 | for inner_train_iter in tqdm(range(len(tasks))): 125 | curr_log_dict = dict() 126 | task_idx = shuffled_train_tasks[inner_train_iter] 127 | train_dataset = CLIP_VQA(meta_data=data, 128 | dataSubType=train_subtype, 129 | task=tasks[task_idx], 130 | image_features=self.image_features, 131 | text_features=self.text_features, 132 | ques_emb=self.ques_emb) 133 | test_dataset = CLIP_VQA(meta_data=data, 134 | dataSubType=val_subtype, 135 | task=tasks[task_idx], 136 | image_features=self.image_features, 137 | text_features=self.text_features, 138 | ques_emb=self.ques_emb) 139 | 140 | train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True) 141 | test_dataloader = 
DataLoader(test_dataset, batch_size=1024, shuffle=False) 142 | 143 | # task_embedding = iter(train_dataloader).next()["ques_emb"][0] 144 | if generated or reconstructed: 145 | embedding = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights().detach()).detach() 146 | enet = EmbeddingModule(self.hnet_gen.input_dim).to(device) 147 | enet.reset(embedding=embedding) 148 | 149 | mnet = CLIPAdapter(e_dim=self.meta_module.mnet.e_dim, 150 | hidden_layers=[self.meta_module.mnet.hidden_size], 151 | use_bias=self.meta_module.mnet.use_bias, 152 | straight_through=self.meta_module.mnet.straight_through, 153 | no_weights=True).to(device) 154 | meta_module=MetaModel(mnet=mnet, hnet=self.hnet_gen, enet=enet, config=self.config) 155 | elif False: #reconstructed: 156 | mnet = CLIPAdapter(e_dim=self.meta_module.mnet.e_dim, 157 | hidden_layers=[self.meta_module.mnet.hidden_size], 158 | use_bias=self.meta_module.mnet.use_bias, 159 | straight_through=self.meta_module.mnet.straight_through, 160 | no_weights=True).to(device) 161 | meta_module=MetaModel(mnet=mnet, hnet=self.hnet_gen, enet=None, config=self.config) 162 | else: 163 | meta_module = self.meta_module 164 | 165 | # Make sure to clone and detach to not optimize the actual initialization. 166 | inner_params = {k: v.clone().detach().requires_grad_() for (k,v) in meta_module.get_inner_params().items()} 167 | 168 | if reconstructed: 169 | tmp_inner_params = OrderedDict() 170 | tmp_inner_params["enet.embedding"] = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights()).detach() 171 | inner_params=tmp_inner_params 172 | #task_embedding = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params=inner_params)).detach() 173 | 174 | train_start_acc, train_start_loss = test_accuracy(meta_module, train_dataloader, 175 | params=inner_params) #, embedding = task_embedding) 176 | val_start_acc, val_start_loss = test_accuracy(meta_module, test_dataloader, params=inner_params) #, embedding = task_embedding) 177 | 178 | if precomputed_latent is None or generated: 179 | # Inner loop 180 | inner_optimizer = inner_optim_fct(list(inner_params.values())) 181 | for _ in range(inner_epochs): 182 | outputs, labels = get_pred(meta_module, train_dataloader, params=inner_params) 183 | inner_loss = F.cross_entropy(outputs, labels) 184 | if debug and inner_train_iter % batch_size == 0: 185 | wandb.log({"debug_inner_loss": inner_loss.item()}) 186 | inner_optimizer.zero_grad() 187 | inner_loss.backward() 188 | inner_optimizer.step() 189 | else: 190 | if self.meta_module.mnet.no_weight: 191 | inner_params["enet.embedding"] = precomputed_latent["embedding"][0, inner_train_iter] 192 | else: 193 | inner_params.update({"mnet."+k:v for (k,v) in self.meta_module.mnet.load_from_vector(precomputed_latent["w_vect"][0, inner_train_iter]).items()}) 194 | 195 | if reconstructed: 196 | tmp_inner_params = OrderedDict() 197 | tmp_inner_params["enet.embedding"] = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params=inner_params)).detach() 198 | inner_params=tmp_inner_params 199 | #task_embedding = sample_from_enc(self.hnet_enc, self.meta_module.get_mainnet_weights(params=inner_params)).detach() 200 | 201 | # Train set accuracy 202 | train_end_acc, train_end_loss = test_accuracy(meta_module, train_dataloader, params=inner_params) #, embedding = task_embedding) 203 | val_end_acc, val_end_loss = test_accuracy(meta_module, test_dataloader, params=inner_params) #, embedding = task_embedding) 204 | 205 | 
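            # Metric naming: "support_*" is computed on the adaptation split (train_subtype) and
            # "query_*" on the held-out split (val_subtype); "*_start" is measured before and
            # "*_end" after the task-specific parameters are obtained.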
curr_log_dict["query_accuracy_start"] = val_start_acc 206 | curr_log_dict["query_accuracy_end"] = val_end_acc 207 | curr_log_dict["support_accuracy_start"] = train_start_acc 208 | curr_log_dict["support_accuracy_end"] = train_end_acc 209 | curr_log_dict["query_loss_start"] = val_start_loss 210 | curr_log_dict["query_loss_end"] = val_end_loss 211 | curr_log_dict["support_loss_start"] = train_start_loss 212 | curr_log_dict["support_loss_end"] = train_end_loss 213 | 214 | if train: 215 | vae_dict = self.train(self.meta_module.get_mainnet_weights(params=inner_params).detach()) 216 | curr_log_dict.update(vae_dict) 217 | 218 | append_dict(log_dict, curr_log_dict) 219 | 220 | if debug and inner_train_iter % batch_size == 0: 221 | log_metric(mean_dict(log_dict), prefix="debug_") 222 | log_dict = dict() 223 | 224 | if output_mean_std: 225 | all_tasks_optimized_params[inner_train_iter] = self.meta_module.get_mainnet_weights(params=inner_params).detach() 226 | 227 | output_dict = mean_dict(log_dict) 228 | output_dict["epoch"] = epoch 229 | 230 | if not output_mean_std: 231 | return output_dict 232 | else: 233 | optimized_params_mean = torch.mean(all_tasks_optimized_params, dim=0, keepdim=True) 234 | optimized_params_std = torch.std(all_tasks_optimized_params, dim=0, keepdim=True) 235 | return output_dict, optimized_params_mean, optimized_params_std, all_tasks_optimized_params 236 | 237 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elvisnava/hyperclip/8574d3d36fbe1bb3311c3cbb214f07fd73ca0a05/utils/__init__.py -------------------------------------------------------------------------------- /utils/build_opt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | supported_optims = ["adam", "amsgrad", "sgd", "rmsprop", "adamw"] 4 | 5 | 6 | def build_optimizer(parameters, config, loop="inner"): 7 | # loop should be either "inner" or "hyperclip" 8 | in_str = f"{loop}_" 9 | optim = config[in_str+"optimizer"] 10 | lr = config[in_str+"learning_rate"] 11 | weight_decay = config.get(in_str+"weight_decay", 0) 12 | adam_beta1 = config.get(in_str+"adam_beta1", 0.9) 13 | adam_beta2 = config.get(in_str+"adam_beta2", 0.999) 14 | momentum = config.get(in_str+"momentum", 0) 15 | sgd_dampening = config.get(in_str+"sgd_dampening", False) 16 | sgd_nesterov = config.get(in_str+"sgd_nesterov", False) 17 | rmsprop_alpha = config.get(in_str+"rmsprop_alpha", 0.99) 18 | 19 | if optim not in supported_optims: 20 | raise ValueError("Unsupported optim: {}. 
Must be one of {}".format(optim, supported_optims)) 21 | 22 | if optim == "adam": 23 | optimizer = torch.optim.Adam( 24 | parameters, 25 | lr=lr, 26 | weight_decay=weight_decay, 27 | betas=(adam_beta1, adam_beta2), 28 | ) 29 | 30 | elif optim == "amsgrad": 31 | optimizer = torch.optim.Adam( 32 | parameters, 33 | lr=lr, 34 | weight_decay=weight_decay, 35 | betas=(adam_beta1, adam_beta2), 36 | amsgrad=True, 37 | ) 38 | 39 | elif optim == "sgd": 40 | optimizer = torch.optim.SGD( 41 | parameters, 42 | lr=lr, 43 | momentum=momentum, 44 | weight_decay=weight_decay, 45 | dampening=sgd_dampening, 46 | nesterov=sgd_nesterov, 47 | ) 48 | 49 | elif optim == "rmsprop": 50 | optimizer = torch.optim.RMSprop( 51 | parameters, 52 | lr=lr, 53 | momentum=momentum, 54 | weight_decay=weight_decay, 55 | alpha=rmsprop_alpha, 56 | ) 57 | 58 | elif optim == "adamw": 59 | optimizer = torch.optim.AdamW( 60 | parameters, 61 | lr=lr, 62 | weight_decay=weight_decay, 63 | betas=(adam_beta1, adam_beta2), 64 | ) 65 | 66 | return optimizer 67 | -------------------------------------------------------------------------------- /utils/clip_utils.py: -------------------------------------------------------------------------------- 1 | 2 | embedding_size = { 3 | 4 | 'RN50': 1024, 5 | 'RN101': 512, 6 | 'RN50x4': 640, 7 | 'RN50x16': 768, 8 | 'RN50x64': 1024, 9 | 'ViT-B/32': 512, 10 | 'ViT-B/16': 512, 11 | 'ViT-L/14': 768, 12 | 'ViT-L/14@336px': 768 13 | 14 | } 15 | 16 | image_resolution = { 17 | 18 | 'RN50': 224, 19 | 'RN101': 224, 20 | 'RN50x4': 288, 21 | 'RN50x16': 384, 22 | 'RN50x64': 448, 23 | 'ViT-B/32': 224, 24 | 'ViT-B/16': 224, 25 | 'ViT-L/14': 224, 26 | 'ViT-L/14@336px': 336 27 | 28 | } 29 | 30 | cached_location = { 31 | 32 | 'RN50': "~/.cache/clip/RN50.pt", 33 | 'RN101': "~/.cache/clip/RN101.pt", 34 | 'RN50x4': "~/.cache/clip/RN50x4.pt", 35 | 'RN50x16': "~/.cache/clip/RN50x16.pt", 36 | 'RN50x64': "~/.cache/clip/RN50x64.pt", 37 | 'ViT-B/32': "~/.cache/clip/ViT-B-32.pt", 38 | 'ViT-B/16': "~/.cache/clip/ViT-B-16.pt", 39 | 'ViT-L/14': "~/.cache/clip/ViT-L-14.pt", 40 | 'ViT-L/14@336px': "~/.cache/clip/ViT-L-14-336px.pt" 41 | 42 | } -------------------------------------------------------------------------------- /utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # through 8 | # https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py 9 | 10 | import math 11 | 12 | import numpy as np 13 | import torch 14 | from einops import repeat 15 | 16 | 17 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 18 | if schedule == "linear": 19 | betas = ( 20 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 21 | ) 22 | 23 | elif schedule == "cosine": 24 | timesteps = ( 25 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 26 | ) 27 | alpha_cums = timesteps / (1 + cosine_s) * np.pi / 2 28 | alpha_cums = torch.cos(alpha_cums).pow(2) 29 | alpha_cums = alpha_cums / alpha_cums[0] 30 | betas = 1 - alpha_cums[1:] / 
alpha_cums[:-1] 31 | betas = np.clip(betas, a_min=0, a_max=0.999) 32 | 33 | elif schedule == "sqrt_linear": 34 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 35 | elif schedule == "sqrt": 36 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 37 | elif schedule == "quad": 38 | betas = ( 39 | torch.linspace(linear_start ** 0.25, linear_end ** 0.25, n_timestep, dtype=torch.float64) ** 4 40 | ) 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 103 | """ 104 | Create sinusoidal timestep embeddings. 105 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 106 | These may be fractional. 107 | :param dim: the dimension of the output. 108 | :param max_period: controls the minimum frequency of the embeddings. 
109 | :return: an [N x dim] Tensor of positional embeddings. 110 | """ 111 | if not repeat_only: 112 | half = dim // 2 113 | freqs = torch.exp( 114 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 115 | ).to(device=timesteps.device) 116 | args = timesteps[:, None].float() * freqs[None] 117 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 118 | if dim % 2: 119 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 120 | else: 121 | embedding = repeat(timesteps, 'b -> b d', d=dim) 122 | return embedding 123 | 124 | class EMA: 125 | def __init__(self, model, decay): 126 | self.model = model 127 | self.decay = decay 128 | self.mem = { } 129 | with torch.no_grad(): 130 | for p in model.parameters(): 131 | self.mem[p] = p.clone() 132 | 133 | def step(self): 134 | with torch.no_grad(): 135 | for p in self.model.parameters(): 136 | self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p) 137 | 138 | def copy_to_model(self): 139 | with torch.no_grad(): 140 | for p in self.model.parameters(): 141 | p.copy_(self.mem[p]) -------------------------------------------------------------------------------- /utils/init_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import wandb 5 | from model.custom_hnet import (CLIPAdapter, EmbeddingModule, HyperEncoder, 6 | HyperGenerator, MetaModel) 7 | from training.vae_learn import sample_from_enc 8 | 9 | base_path = os.path.dirname(os.path.dirname(__file__)) 10 | 11 | def load_few_shot_to_metamodel(run, device): 12 | loaded_config = run.config 13 | meta_module = MetaModel( 14 | inner_param=loaded_config["inner_param"], 15 | mainnet_use_bias=loaded_config["mainnet_use_bias"], 16 | mainnet_hidden_dim=loaded_config["mainnet_hidden_dim"], 17 | hypernet_hidden_dim=[] if loaded_config["hypernet_hidden_dim"]=="" else [int(i) for i in loaded_config["hypernet_hidden_dim"].split(",")], 18 | embedding_dim=loaded_config["embedding_dim"], 19 | straight_through=loaded_config["straight_through"], 20 | config=loaded_config).to(device) 21 | loaded_model_path = base_path + "/evaluation/few_shot/meta_module_" + str(run.name) + ".pth" 22 | meta_module.load_state_dict(torch.load(loaded_model_path), strict=False) 23 | return meta_module 24 | 25 | def load_vae(run, device): 26 | config=run.config 27 | 28 | api = wandb.Api() 29 | if "precompute_checkpoint" in config: 30 | print("Loading from precomputed run {}".format(config["precompute_checkpoint"])) 31 | precomp_run = api.run(config["precompute_checkpoint"]) 32 | precomp_config = precomp_run.config 33 | config.update({"few_shot_checkpoint": precomp_config["few_shot_checkpoint"]}, allow_val_change=True) 34 | 35 | loaded_run = api.run(config["few_shot_checkpoint"]) 36 | tmp_meta_module = load_few_shot_to_metamodel(loaded_run, device) 37 | 38 | hidden_dims = [int(h) for h in config["vae_hidden_dim"].split(",")] 39 | hnet_enc = HyperEncoder(tmp_meta_module.mnet, e_dim=config["vae_noise_dim"], 40 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device) 41 | hidden_dims.reverse() 42 | hnet_gen = HyperGenerator(tmp_meta_module.mnet, e_dim=config["vae_noise_dim"], 43 | hidden_dims=hidden_dims, normalize="normalize" in config and config["normalize"]).to(device) 44 | 45 | loaded_model_path = base_path + "/evaluation/vae/hnet_gen_" + str(run.name) + ".pth" 46 | hnet_gen.load_state_dict(torch.load(loaded_model_path), strict=False) 47 | 
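# The VAE generator and encoder weights are loaded from separate checkpoints under
# evaluation/vae/, both keyed by the wandb run name.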
loaded_model_path = base_path + "/evaluation/vae/hnet_enc_" + str(run.name) + ".pth" 48 | hnet_enc.load_state_dict(torch.load(loaded_model_path), strict=False) 49 | 50 | return hnet_gen, hnet_enc, tmp_meta_module 51 | 52 | def load_vae_to_metamodel(run, device): 53 | hnet_gen, hnet_enc, tmp_meta_module = load_vae(run, device) 54 | 55 | embedding = sample_from_enc(hnet_enc, tmp_meta_module.get_mainnet_weights(params = None).detach()).detach() 56 | enet = EmbeddingModule(hnet_gen.input_dim).to(device) 57 | enet.reset(embedding=embedding) 58 | 59 | mnet = CLIPAdapter(e_dim=tmp_meta_module.mnet.e_dim, 60 | hidden_layers=[tmp_meta_module.mnet.hidden_size], 61 | use_bias=tmp_meta_module.mnet.use_bias, 62 | straight_through=tmp_meta_module.mnet.straight_through, 63 | no_weights=True).to(device) 64 | meta_module = MetaModel(mnet=mnet, hnet=hnet_gen, enet=enet, config=run.config) 65 | 66 | return meta_module 67 | 68 | def load_metamodel_from_checkpoint(config, device): 69 | precomputed_latent = None 70 | precomputed_latent_train_eval = None 71 | precomputed_latent_val_eval = None 72 | api = wandb.Api() 73 | if "precompute_checkpoint" in config: 74 | print("Loading from precomputed run {}".format(config["precompute_checkpoint"])) 75 | precomp_run = api.run(config["precompute_checkpoint"]) 76 | precomp_config = precomp_run.config 77 | config.update({"few_shot_checkpoint": None if "few_shot_checkpoint" not in precomp_config else precomp_config["few_shot_checkpoint"]}, allow_val_change=True) 78 | config.update({"vae_checkpoint": None if "vae_checkpoint" not in precomp_config else precomp_config["vae_checkpoint"]}, allow_val_change=True) 79 | 80 | precomputed_file = base_path + "/evaluation/precompute_adaptation/" + str(precomp_run.name) + ".pth" 81 | precomputed_latent = torch.load(precomputed_file) #{k:v.to(device) for (k,v) in torch.load(precomputed_file).items()} 82 | 83 | try: 84 | precomputed_file_train_eval = base_path + "/evaluation/precompute_adaptation/" + str(precomp_run.name) + "_train_eval.pth" 85 | precomputed_file_val_eval = base_path + "/evaluation/precompute_adaptation/" + str(precomp_run.name) + "_val_eval.pth" 86 | precomputed_latent_train_eval = {k:v.to(device) for (k,v) in torch.load(precomputed_file_train_eval).items()} 87 | precomputed_latent_val_eval = {k:v.to(device) for (k,v) in torch.load(precomputed_file_val_eval).items()} 88 | except: 89 | print("Did not find eval latent") 90 | 91 | if "few_shot_checkpoint" in config and config["few_shot_checkpoint"] is not None: 92 | print("Creating MetaModule from FewShot trained model") 93 | loaded_run = api.run(config["few_shot_checkpoint"]) 94 | meta_module = load_few_shot_to_metamodel(loaded_run, device) 95 | elif "vae_checkpoint" in config and config["vae_checkpoint"] is not None: 96 | print("Creating MetaModule from VAE model") 97 | loaded_run = api.run(config["vae_checkpoint"]) 98 | meta_module = load_vae_to_metamodel(loaded_run, device) 99 | else: 100 | return NotImplementedError("something's wrong") 101 | 102 | return meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval 103 | 104 | def load_vae_and_metamodel_from_checkpoint(config, device): 105 | api = wandb.Api() 106 | assert "vae_checkpoint" in config and config["vae_checkpoint"] is not None 107 | loaded_run = api.run(config["vae_checkpoint"]) 108 | meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \ 109 | load_metamodel_from_checkpoint(loaded_run.config, device) 110 | hnet_gen, hnet_enc, _ = 
load_vae(loaded_run, device) 111 | 112 | return meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval 113 | -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | 6 | def str2bool(v): 7 | if isinstance(v, bool): 8 | return v 9 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 10 | return True 11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 12 | return False 13 | else: 14 | raise argparse.ArgumentTypeError('Boolean value expected.') 15 | 16 | def append_dict(dictionary, new_dict): 17 | for key in new_dict: 18 | if key not in dictionary: 19 | dictionary[key]=[] 20 | dictionary[key].append(new_dict[key]) 21 | 22 | def mean_dict(dictionary): 23 | out_dict = dict() 24 | for key in dictionary: 25 | if isinstance(dictionary[key], list): 26 | if isinstance(dictionary[key][0], list): 27 | out_dict[key] = [np.mean([dictionary[key][j][i] for j in range(len(dictionary[key]))]) for i in range(len(dictionary[key][0]))] 28 | else: 29 | out_dict[key] = np.mean(np.array(dictionary[key])) 30 | else: 31 | out_dict[key] = dictionary[key] 32 | return out_dict 33 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from sklearn.metrics import accuracy_score 4 | from torch.nn import functional as F 5 | import copy 6 | from utils.misc_utils import append_dict 7 | import numpy as np 8 | 9 | def get_pred( meta_module, dataloader, params=None, embedding=None): 10 | y_pred = [] 11 | y_true = [] 12 | for sample in dataloader: 13 | sample_image_features = sample["image_features"] 14 | sample_text_features = sample["text_features"] 15 | if embedding is None: 16 | embedding = sample["ques_emb"][0] 17 | labels = sample["label"].to(sample_image_features.device) 18 | similarity = meta_module(sample_image_features, sample_text_features, embedding, params=params) 19 | y_pred.append(similarity) 20 | y_true.append(labels) 21 | 22 | return torch.cat(y_pred), torch.cat(y_true) 23 | 24 | def test_accuracy( meta_module, dataloader, params=None, embedding=None, params_list=None): 25 | # Validation inner-loop testing 26 | meta_module.eval() 27 | 28 | with torch.no_grad(): 29 | if params_list is not None: 30 | output_list = [] 31 | for p in params_list: 32 | output, y_true = get_pred(meta_module, dataloader, params=p, embedding=embedding) 33 | output_list.append(output) 34 | output = torch.stack(output_list).mean(0) 35 | else: 36 | output, y_true = get_pred(meta_module, dataloader, params=params, embedding=embedding) 37 | _, y_pred = output.topk(1) 38 | loss = F.cross_entropy(output, y_true) 39 | 40 | acc = accuracy_score(y_true.cpu().numpy(), y_pred.cpu().numpy()) 41 | meta_module.train() 42 | return acc, loss.item() 43 | 44 | def log_metric(log_dict, prefix=""): 45 | prefixed_dict = dict() 46 | for key in log_dict: 47 | if not isinstance(log_dict[key], list): 48 | prefixed_dict[prefix+key] = log_dict[key] 49 | 50 | wandb.log(prefixed_dict) 51 | 52 | for key in log_dict: 53 | if isinstance(log_dict[key], list): 54 | for i in range(len(log_dict[key])): 55 | wandb.log({prefix+key: log_dict[key][i]}) 56 | 57 | def n_shot_trials_run(model_training, n_shot_trial_dict, config, log_name, data, *args, **kwargs): 58 | if 
"n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None: 59 | filtered_data, tasks_idx = filter_data_by_max_n_shot(data, config["n_shot_trials_maxN"]) 60 | for n_shot in range(1, config["n_shot_trials_maxN"]+1): 61 | n_shot_trial_dict[f"{log_name}n_shot_{n_shot}/"] = {} 62 | log_dict = model_training.run_epoch(filtered_data, *args, n_shot_training=n_shot, tasks_idxs=tasks_idx, **kwargs) 63 | append_dict(n_shot_trial_dict[f"{log_name}n_shot_{n_shot}/"], log_dict) 64 | # full few shot 65 | log_dict = model_training.run_epoch(data, *args, n_shot_training="full", **kwargs) 66 | log_metric(log_dict, f"{log_name}few_shot/") 67 | 68 | def filter_data_by_fraction(data, keep_tasks_frac): 69 | task_idx = np.range(len(list(data.keys()))*keep_tasks_frac) 70 | return data, task_idx 71 | 72 | def filter_data_by_max_n_shot(data, max_k): 73 | data = copy.deepcopy(data) 74 | tasks_idx =[] 75 | for i,t in enumerate(list(data.keys())): 76 | ans, count = np.unique([a for [_, a] in data[t]["train"]], return_counts=True) 77 | filtered_ans = ans[count >= max_k] 78 | data[t]["train"] = [d for d in data[t]["train"] if d[1] in filtered_ans] 79 | data[t]["test"] = [d for d in data[t]["test"] if d[1] in filtered_ans] 80 | if len(ans)>=2: 81 | tasks_idx.append(i) 82 | return data, tasks_idx 83 | 84 | def update_best_val(best_val_dict, meta_epoch, log_dict): 85 | if best_val_dict["best_val_accuracy"] < log_dict["query_accuracy_end"]: 86 | best_val_dict["best_val_accuracy"] = log_dict["query_accuracy_end"] 87 | best_val_dict["best_val_epoch"] = meta_epoch 88 | 89 | def run_evals(model_training, train_data, test_data, n_shot_trial_dict, config, log_name, *args, skip_train=False, n_shot_from_opt_latent=False, best_val_dict=None, precomputed_latent_train=None, precomputed_latent_val=None, guided_inner=False, use_vae=False, **kwargs): 90 | # Eval on train data 91 | if not skip_train: 92 | log_dict = model_training.run_epoch(train_data, *args, guided_inner=guided_inner, keep_tasks_frac=config["diffusion_keep_tasks_frac"], precomputed_latent=precomputed_latent_train, use_vae=use_vae if precomputed_latent_train is None else False, **kwargs) 93 | log_metric(log_dict, log_name.format("train")) 94 | 95 | # Eval on test data 96 | log_dict, _, _, opt_latents = model_training.run_epoch(test_data, *args, guided_inner=guided_inner, precomputed_latent=precomputed_latent_val, use_vae=use_vae if precomputed_latent_val is None else False, output_mean_std=True, **kwargs) 97 | log_metric(log_dict, log_name.format("val")) 98 | if best_val_dict is not None: 99 | update_best_val(best_val_dict, kwargs["epoch"], log_dict) 100 | log_metric(best_val_dict) 101 | 102 | # N-shot eval on test data 103 | n_shot_trials_run(model_training, n_shot_trial_dict, config, log_name.format("val"), test_data, *args, opt_latents_for_n_shot=opt_latents if n_shot_from_opt_latent else None, use_vae=use_vae, **kwargs) -------------------------------------------------------------------------------- /utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import wandb 4 | 5 | base_path = os.path.dirname(os.path.dirname(__file__)) 6 | 7 | def populate_wandb_table(wandb_val_table, test_data, test_tasks, val_task_id, y_pred, y_true): 8 | ans = test_data[test_tasks[val_task_id]]["answers"] 9 | for i in range(len(y_pred)): 10 | image_name = test_data[test_tasks[val_task_id]]["test"][i][0] 11 | image_folder = image_name.split("_")[1] 12 | wandb_val_table.add_data( 13 | 
f"{val_task_id}_{i}", 14 | test_tasks[val_task_id], 15 | wandb.Image(base_path + f'/data/VQA/Images/{image_folder}/{image_name}.jpg'), 16 | ans[y_pred[i]], 17 | ans[y_true[i]], 18 | int(y_true[i]==y_pred[i])) 19 | --------------------------------------------------------------------------------
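As a quick orientation to the two smallest utilities above, the following sketch (not part of the repository) wires build_optimizer from utils/build_opt.py and EMA from utils/diffusion_utils.py around a toy linear model. The config keys follow the "<loop>_optimizer" / "<loop>_learning_rate" naming that build_optimizer reads; the model, data, and hyperparameter values are hypothetical stand-ins.

import torch

from utils.build_opt import build_optimizer
from utils.diffusion_utils import EMA

# Hypothetical config; build_optimizer reads "<loop>_*" keys, here with loop="inner".
config = {"inner_optimizer": "adamw", "inner_learning_rate": 1e-3, "inner_weight_decay": 1e-2}

model = torch.nn.Linear(16, 4)          # toy stand-in for an adapter / hypernetwork
optimizer = build_optimizer(model.parameters(), config, loop="inner")
ema = EMA(model, decay=0.999)           # keeps an exponential moving average of the weights

for _ in range(10):
    x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step()                          # update the averaged copy after every optimizer step

ema.copy_to_model()                     # evaluate or checkpoint with the averaged weights

Keeping the averaged copy separate and only calling copy_to_model() at evaluation time is the usual pattern when the EMA weights, rather than the raw ones, are meant to be used for sampling or scoring.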