├── .gitignore
├── LICENSE
├── README.md
├── buffer.py
├── data
│   ├── Flickr30k_ann
│   │   ├── coco_karpathy_test.json
│   │   ├── coco_karpathy_train.json
│   │   ├── coco_karpathy_val.json
│   │   ├── flickr30k_test.json
│   │   ├── flickr30k_train.json
│   │   └── flickr30k_val.json
│   ├── __init__.py
│   ├── cifar_dataset.py
│   ├── coco_dataset.py
│   ├── flickr30k_dataset.py
│   └── randaugment.py
├── distill_tesla_lors.py
├── evaluate_only.py
├── images
│   └── method.png
├── requirements.txt
├── sh
│   ├── buffer_coco.sh
│   ├── buffer_flickr.sh
│   ├── distill_coco_lors_100.sh
│   ├── distill_coco_lors_200.sh
│   ├── distill_coco_lors_500.sh
│   ├── distill_flickr_lors_100.sh
│   ├── distill_flickr_lors_200.sh
│   └── distill_flickr_lors_500.sh
└── src
    ├── epoch.py
    ├── model.py
    ├── networks.py
    ├── reparam_module.py
    ├── similarity_mining.py
    ├── utils.py
    └── vl_distill_utils.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | /buffer
 2 | distill_utils/checkpoints
 3 | distill_utils/data
 4 | tmp
 5 | logged_files
 6 | wandb
 7 | 
 8 | # just simply ignore these files
 9 | __pycache__/
10 | .DS_Store
11 | .ipynb_checkpoints
12 | .vscode
13 | *.pyc
14 | *.npz
15 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | BSD 3-Clause License
 2 | 
 3 | Copyright (c) 2024, Yue Xu
 4 | 
 5 | Redistribution and use in source and binary forms, with or without
 6 | modification, are permitted provided that the following conditions are met:
 7 | 
 8 | 1. Redistributions of source code must retain the above copyright notice, this
 9 |    list of conditions and the following disclaimer.
10 | 
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 |    this list of conditions and the following disclaimer in the documentation
13 |    and/or other materials provided with the distribution.
14 | 
15 | 3. Neither the name of the copyright holder nor the names of its
16 |    contributors may be used to endorse or promote products derived from
17 |    this software without specific prior written permission.
18 | 
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # LoRS: Low-Rank Similarity Mining
 2 | 
 3 | ### [Paper](http://arxiv.org/abs/2406.03793)
 4 | 
 5 | This repo contains the code of our ICML'24 work **LoRS**: **Lo**w-**R**ank **S**imilarity Mining for Multimodal Dataset Distillation. LoRS proposes to learn the similarity matrix while distilling the images and texts. This simple, plug-and-play method yields significant performance gains. Please check our [paper](http://arxiv.org/abs/2406.03793) for more analysis.
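For intuition, below is a minimal sketch of one plausible low-rank parameterization of the learned similarity matrix. The repo's actual implementation is `LowRankSimilarityGenerator` in `src/similarity_mining.py` (constructed with `num_queries`, `sim_rank`, and `alpha`); the class and the exact form used here are illustrative only, not the repo's code:

```python
import torch
import torch.nn as nn

class LowRankSimilaritySketch(nn.Module):
    """Illustrative sketch: similarity target = identity + alpha * U V^T (low-rank residual)."""
    def __init__(self, num_pairs: int, rank: int, alpha: float = 1.0):
        super().__init__()
        # Low-rank factors, learned jointly with the synthetic images/texts.
        self.U = nn.Parameter(torch.zeros(num_pairs, rank))
        self.V = nn.Parameter(torch.zeros(num_pairs, rank))
        self.alpha = alpha

    def forward(self, indices=None):
        U = self.U if indices is None else self.U[indices]
        V = self.V if indices is None else self.V[indices]
        # Diagonal (each image matches its own caption) plus a learned low-rank residual.
        return torch.eye(U.shape[0], device=U.device) + self.alpha * U @ V.t()
```

The generated matrix replaces the identity target of the contrastive loss between synthetic image and text embeddings, so cross-pair similarities are distilled as well.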
 6 | 
 7 | 
 8 | ![Method](images/method.png)
 9 | 
10 | 
11 | ## Getting Started
12 | 
13 | **Requirements**: please see `requirements.txt`.
14 | 
15 | **Pretrained model checkpoints**: you may manually download the checkpoints of BERT and NFNet (from TIMM) and put them here:
16 | 
17 | ```
18 | distill_utils/checkpoints/
19 | ├── bert-base-uncased/
20 | │   ├── config.json
21 | │   ├── LICENSE.txt
22 | │   ├── model.onnx
23 | │   ├── pytorch_model.bin
24 | │   ├── vocab.txt
25 | │   └── ......
26 | └── nfnet_l0_ra2-45c6688d.pth
27 | ```
28 | 
29 | **Datasets**: please download the Flickr30K: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json)[[Images]](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset) and COCO: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json)[[Images]](https://cocodataset.org/#download) datasets, and put them here:
30 | 
31 | ```
32 | ./distill_utils/data/
33 | ├── Flickr30k/
34 | │   ├── flickr30k-images/
35 | │   │   ├── 1234.jpg
36 | │   │   └── ......
37 | │   ├── results_20130124.token
38 | │   └── readme.txt
39 | └── COCO/
40 |     ├── train2014/
41 |     ├── val2014/
42 |     └── test2014/
43 | ```
44 | 
45 | 
46 | 
47 | **Training Expert Buffer**: *e.g.* run `sh sh/buffer_flickr.sh`. Expert training takes days; you can manually split `num_experts` and run multiple processes in parallel.
48 | 
49 | **Distill with LoRS**: *e.g.* run `sh sh/distill_flickr_lors_100.sh`. The distillation can be run on a single RTX 3090/4090 thanks to [TESLA](https://github.com/justincui03/tesla). Distilled data are saved under `logged_files/` (see the loading snippet at the end of this README).
50 | 
51 | 
52 | 
53 | ## Citation
54 | 
55 | If you find our work useful and inspiring, please cite our paper:
56 | ```
57 | @article{xu2024lors,
58 |   title={Low-Rank Similarity Mining for Multimodal Dataset Distillation},
59 |   author={Xu, Yue and Lin, Zhilin and Qiu, Yusong and Lu, Cewu and Li, Yong-Lu},
60 |   journal={arXiv e-prints},
61 |   pages={arXiv--2406},
62 |   year={2024}
63 | }
64 | ```
65 | 
66 | 
67 | 
68 | 
69 | ## Acknowledgement
70 | 
71 | We follow the setting and code of [VL-Distill](https://github.com/princetonvisualai/multimodal_dataset_distillation) and re-implement the algorithm with [TESLA](https://github.com/justincui03/tesla). We deeply appreciate their valuable contributions!
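For reference, the distilled data saved by `distill_tesla_lors.py` (written to `logged_files/<dataset>/<wandb run name>/distilled_<iteration>.pt`) can be inspected offline. Below is a minimal sketch; the path is a placeholder and the shapes assume the default settings (224x224 synthetic images, 768-d text embeddings):

```python
import torch

# Placeholder path: substitute your own run directory and iteration.
ckpt = torch.load("logged_files/flickr/<run_name>/distilled_2000.pt", map_location="cpu")

image_syn = ckpt["image"]               # synthetic images, e.g. (num_queries, 3, 224, 224)
text_syn = ckpt["text"]                 # synthetic text embeddings, e.g. (num_queries, 768)
sim_mat = ckpt["similarity_mat"]        # learned similarity matrix, (num_queries, num_queries)
sim_params = ckpt["similarity_params"]  # low-rank parameters that generate the matrix
lr_img = ckpt["syn_lr_img"]             # distilled learning rate for the image encoder
lr_txt = ckpt["syn_lr_txt"]             # distilled learning rate for the text projection

print(image_syn.shape, text_syn.shape, sim_mat.shape)
```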
72 | 
--------------------------------------------------------------------------------
/buffer.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import argparse
 3 | import torch
 4 | import torch.nn as nn
 5 | from tqdm import tqdm, trange
 6 | from src.vl_distill_utils import load_or_process_file
 7 | from src.epoch import epoch, epoch_test, itm_eval
 8 | import copy
 9 | import wandb
10 | import warnings
11 | import datetime
12 | from data import get_dataset_flickr, textprocess, textprocess_train
13 | from src.networks import CLIPModel_full
14 | import numpy as np
15 | 
16 | warnings.filterwarnings("ignore", category=DeprecationWarning)
17 | 
18 | def main(args):
19 |     no_aug_suffix = "_NoAug" if args.no_aug else ""  # defined unconditionally: also needed for save_dir below when wandb is disabled
20 |     if args.disabled_wandb:
21 |         wandb.init(mode="disabled")
22 |     else:
23 |         wandb.init(project='LoRS-Buffer', config=args,
24 |                    name=f"{args.dataset}_{args.image_encoder}_{args.text_encoder}_{args.loss_type}{no_aug_suffix}")
25 | 
26 |     args.dsa = True if args.dsa == 'True' else False
27 |     args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
28 |     args.distributed = torch.cuda.device_count() > 1
29 | 
30 | 
31 |     print('Hyper-parameters: \n', args.__dict__)
32 | 
33 |     save_dir = os.path.join(args.buffer_path, args.dataset)
34 |     if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
35 |         save_dir += "_NO_ZCA"
36 |     save_dir = os.path.join(save_dir, args.image_encoder+"_"+args.text_encoder, args.loss_type+no_aug_suffix)
37 |     if not os.path.exists(save_dir):
38 |         os.makedirs(save_dir, exist_ok=True)
39 |     ''' organize the datasets '''
40 |     trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)
41 | 
42 |     train_sentences = train_dataset.get_all_captions()
43 |     _ = load_or_process_file('text', textprocess, args, testloader)
44 |     _ = load_or_process_file('train_text', textprocess_train, args, train_sentences)
45 | 
46 |     data = np.load(f'{args.dataset}_{args.text_encoder}_text_embed.npz')
47 |     bert_test_embed_loaded = data['bert_test_embed']
48 |     bert_test_embed = torch.from_numpy(bert_test_embed_loaded).cpu()
49 | 
50 | 
51 |     img_trajectories = []
52 |     txt_trajectories = []
53 | 
54 |     for it in range(0, args.num_experts):
55 | 
56 |         ''' Train synthetic data '''
57 | 
58 |         teacher_net = CLIPModel_full(args).to(args.device)
59 |         img_teacher_net = teacher_net.image_encoder.to(args.device)
60 |         txt_teacher_net = teacher_net.text_projection.to(args.device)
61 | 
62 |         if args.text_trainable:
63 |             txt_teacher_net = teacher_net.text_encoder.to(args.device)
64 |         if args.distributed:
65 |             raise NotImplementedError()
66 |             img_teacher_net = torch.nn.DataParallel(img_teacher_net)
67 |             txt_teacher_net = torch.nn.DataParallel(txt_teacher_net)
68 | 
69 |         img_teacher_net.train()
70 |         txt_teacher_net.train()
71 | 
72 |         teacher_optim = torch.optim.SGD([
73 |             {'params': img_teacher_net.parameters(), 'lr': args.lr_teacher_img},
74 |             {'params': txt_teacher_net.parameters(), 'lr': args.lr_teacher_txt},
75 |         ], lr=0, momentum=args.mom, weight_decay=args.l2)
76 |         teacher_optim.zero_grad()
77 |         lr_schedule = [args.train_epochs // 2 + 1] if args.decay else []
78 |         teacher_optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(
79 |             teacher_optim, milestones=lr_schedule, gamma=0.1)
80 | 
81 | 
82 |         img_timestamps = []
83 |         txt_timestamps = []
84 | 
85 |         img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()])
86 |         txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()])
87 | 
88 | 
89 |         for e in
trange(args.train_epochs): 90 | train_loss, train_acc = epoch(e, trainloader, teacher_net, teacher_optim, args) 91 | 92 | if (e+1) % args.eval_freq == 0: 93 | 94 | score_val_i2t, score_val_t2i = epoch_test(testloader, teacher_net, args.device, bert_test_embed) 95 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt) 96 | 97 | wandb.log({ 98 | "Loss/train_loss": train_loss, 99 | "Loss/train_acc": train_acc, 100 | "Results/txt_r1": val_result['txt_r1'], 101 | "Results/txt_r5": val_result['txt_r5'], 102 | "Results/txt_r10": val_result['txt_r10'], 103 | # "txt_r_mean": val_result['txt_r_mean'], 104 | "Results/img_r1": val_result['img_r1'], 105 | "Results/img_r5": val_result['img_r5'], 106 | "Results/img_r10": val_result['img_r10'], 107 | # "img_r_mean": val_result['img_r_mean'], 108 | "Results/r_mean": val_result['r_mean'], 109 | }) 110 | 111 | print("Itr: {} Epoch={} Train Acc={} | Img R@1={} R@5={} R@10={} | Txt R@1={} R@5={} R@10={} | R@Mean={}".format( 112 | it, e, train_acc, 113 | val_result['img_r1'], val_result['img_r5'], val_result['img_r10'], 114 | val_result['txt_r1'], val_result['txt_r5'], val_result['txt_r10'], val_result['r_mean'])) 115 | 116 | 117 | img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()]) 118 | txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()]) 119 | 120 | teacher_optim_scheduler.step() 121 | 122 | 123 | if not args.skip_save: 124 | img_trajectories.append(img_timestamps) 125 | txt_trajectories.append(txt_timestamps) 126 | n = 0 127 | while os.path.exists(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))): 128 | n += 1 129 | print("Saving {}".format(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n)))) 130 | torch.save(img_trajectories, os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))) 131 | print("Saving {}".format(os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n)))) 132 | torch.save(txt_trajectories, os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n))) 133 | 134 | img_trajectories = [] 135 | txt_trajectories = [] 136 | 137 | 138 | def make_buffer_parser(): 139 | parser = argparse.ArgumentParser(description='Parameter Processing') 140 | parser.add_argument('--dataset', type=str, default='flickr', choices=['flickr', 'coco'], help='dataset') 141 | parser.add_argument('--num_experts', type=int, default=100, help='training iterations') 142 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters') 143 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters') 144 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks') 145 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 146 | help='whether to use differentiable Siamese augmentation.') 147 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 148 | help='differentiable Siamese augmentation strategy') 149 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path') 150 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 151 | parser.add_argument('--train_epochs', type=int, default=50) 152 | parser.add_argument('--zca', action='store_true') 153 | parser.add_argument('--decay', action='store_true') 154 | parser.add_argument('--mom', type=float, 
default=0, help='momentum') 155 | parser.add_argument('--l2', type=float, default=0, help='l2 regularization') 156 | parser.add_argument('--save_interval', type=int, default=10) 157 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 158 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run') 159 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained') 160 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained') 161 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable') 162 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable') 163 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train') 164 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test') 165 | parser.add_argument('--image_root', type=str, default='distill_utils/data/Flickr30k/', help='location of image root') 166 | parser.add_argument('--ann_root', type=str, default='./data/Flickr30k_ann/', help='location of ann root') 167 | parser.add_argument('--image_size', type=int, default=224, help='image_size') 168 | parser.add_argument('--k_test', type=int, default=128, help='k_test') 169 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy') 170 | parser.add_argument('--image_encoder', type=str, default='resnet50', help='image encoder') 171 | #, choices=['nfnet', 'resnet18_gn', 'vit_tiny', 'nf_resnet50', 'nf_regnet']) 172 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert','gpt1'], help='text encoder') 173 | parser.add_argument('--margin', default=0.2, type=float, 174 | help='Rank loss margin.') 175 | parser.add_argument('--measure', default='cosine', 176 | help='Similarity measure used (cosine|order)') 177 | parser.add_argument('--max_violation', action='store_true', 178 | help='Use max instead of sum in the rank loss.') 179 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None') 180 | parser.add_argument('--grounding', type=bool, default=False, help='None') 181 | 182 | parser.add_argument('--distill', type=bool, default=False, help='whether distill') 183 | parser.add_argument('--loss_type', type=str, default="InfoNCE") 184 | 185 | parser.add_argument('--eval_freq', type=int, default=5, help='eval_freq') 186 | parser.add_argument('--no_aug', action='store_true', help='no_aug') 187 | parser.add_argument('--skip_save', action='store_true', help='skip save buffer') 188 | parser.add_argument('--disabled_wandb', type=bool, default=False) 189 | return parser 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = make_buffer_parser() 194 | args = parser.parse_args() 195 | 196 | args.image_root = { 197 | 'flickr': "distill_utils/data/Flickr30k/", 198 | 'coco': "distill_utils/data/COCO/", 199 | }[args.dataset] 200 | 201 | main(args) 202 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from data.randaugment import RandomAugment 3 | from torchvision.transforms.functional import InterpolationMode 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets.utils import download_url 8 | import json 9 | from PIL import Image 10 | 
import os 11 | from torchvision import transforms as T 12 | from src.networks import CLIPModel_full 13 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval, pre_caption 14 | from data.coco_dataset import coco_train, coco_caption_eval, coco_retrieval_eval 15 | import numpy as np 16 | from tqdm import tqdm 17 | @torch.no_grad() 18 | def textprocess(args, testloader): 19 | net = CLIPModel_full(args).to('cuda') 20 | net.eval() 21 | texts = testloader.dataset.text 22 | if args.dataset in ['flickr', 'coco']: 23 | if args.dataset == 'flickr': 24 | bert_test_embed = net.text_encoder(texts) 25 | elif args.dataset == 'coco': 26 | bert_test_embed = torch.cat((net.text_encoder(texts[:10000]), net.text_encoder(texts[10000:20000]), net.text_encoder(texts[20000:])), dim=0) 27 | 28 | bert_test_embed_np = bert_test_embed.cpu().numpy() 29 | np.savez(f'{args.dataset}_{args.text_encoder}_text_embed.npz', bert_test_embed=bert_test_embed_np) 30 | else: 31 | raise NotImplementedError 32 | return 33 | 34 | @torch.no_grad() 35 | def textprocess_train(args, texts): 36 | net = CLIPModel_full(args).to('cuda') 37 | net.eval() 38 | chunk_size = 2000 39 | chunks = [] 40 | for i in tqdm(range(0, len(texts), chunk_size)): 41 | chunk = net.text_encoder(texts[i:i + chunk_size]).cpu() 42 | chunks.append(chunk) 43 | del chunk 44 | torch.cuda.empty_cache() # free up memory 45 | bert_test_embed = torch.cat(chunks, dim=0) 46 | 47 | print('bert_test_embed.shape: ', bert_test_embed.shape) 48 | bert_test_embed_np = bert_test_embed.numpy() 49 | if args.dataset in ['flickr', 'coco']: 50 | np.savez(f'{args.dataset}_{args.text_encoder}_train_text_embed.npz', bert_test_embed=bert_test_embed_np) 51 | else: 52 | raise NotImplementedError 53 | return 54 | 55 | 56 | def create_dataset(args, min_scale=0.5): 57 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 58 | transform_train = transforms.Compose([ 59 | transforms.RandomResizedCrop(args.image_size,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 60 | transforms.RandomHorizontalFlip(), 61 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 62 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 63 | transforms.ToTensor(), 64 | normalize, 65 | ]) 66 | transform_test = transforms.Compose([ 67 | transforms.Resize((args.image_size, args.image_size),interpolation=InterpolationMode.BICUBIC), 68 | transforms.ToTensor(), 69 | normalize, 70 | ]) 71 | 72 | if args.no_aug: 73 | transform_train = transform_test # no augmentation 74 | 75 | if args.dataset=='flickr': 76 | train_dataset = flickr30k_train(transform_train, args.image_root, args.ann_root) 77 | val_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val') 78 | test_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test') 79 | return train_dataset, val_dataset, test_dataset 80 | 81 | elif args.dataset=='coco': 82 | train_dataset = coco_train(transform_train, args.image_root, args.ann_root) 83 | val_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val') 84 | test_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test') 85 | return train_dataset, val_dataset, test_dataset 86 | 87 | else: 88 | raise NotImplementedError 89 | 90 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 91 | samplers = [] 92 | for dataset,shuffle in zip(datasets,shuffles): 93 | 
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
 94 |         samplers.append(sampler)
 95 |     return samplers
 96 | 
 97 | 
 98 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
 99 |     loaders = []
100 |     for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
101 |         if is_train:
102 |             shuffle = (sampler is None)
103 |             drop_last = True
104 |         else:
105 |             shuffle = False
106 |             drop_last = False
107 |         loader = DataLoader(
108 |             dataset,
109 |             batch_size=bs,
110 |             num_workers=n_worker,
111 |             pin_memory=True,
112 |             sampler=sampler,
113 |             shuffle=shuffle,
114 |             collate_fn=collate_fn,
115 |             drop_last=drop_last,
116 |         )
117 |         loaders.append(loader)
118 |     return loaders
119 | 
120 | 
121 | def get_dataset_flickr(args):
122 |     print("Creating retrieval dataset")
123 |     train_dataset, val_dataset, test_dataset = create_dataset(args)
124 | 
125 |     samplers = [None, None, None]
126 |     train_shuffle = True
127 |     train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
128 |                                                           batch_size=[args.batch_size_train]+[args.batch_size_test]*2,
129 |                                                           num_workers=[4,4,4],
130 |                                                           is_trains=[train_shuffle, False, False],
131 |                                                           collate_fns=[None,None,None])
132 | 
133 |     return train_loader, test_loader, train_dataset, test_dataset
134 | 
--------------------------------------------------------------------------------
/data/cifar_dataset.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import torch
  3 | from typing import Any, Tuple
  4 | from torch.utils.data import Dataset
  5 | from PIL import Image
  6 | import numpy as np
  7 | from torchvision.datasets import CIFAR10
  8 | from collections import defaultdict
  9 | 
 10 | 
 11 | CLASSES = [
 12 |     "airplane",
 13 |     "automobile",
 14 |     "bird",
 15 |     "cat",
 16 |     "deer",
 17 |     "dog",
 18 |     "frog",
 19 |     "horse",
 20 |     "ship",
 21 |     "truck",
 22 | ]
 23 | 
 24 | PROMPTS_1 = ["This is a {}"]
 25 | 
 26 | PROMPTS_5 = [
 27 |     "a photo of a {}",
 28 |     "a blurry image of a {}",
 29 |     "a photo of the {}",
 30 |     "a pixelated photo of a {}",
 31 |     "a picture of a {}",
 32 | ]
 33 | 
 34 | PROMPTS = [
 35 |     'a photo of a {}',
 36 |     'a blurry photo of a {}',
 37 |     'a low contrast photo of a {}',
 38 |     'a high contrast photo of a {}',
 39 |     'a bad photo of a {}',
 40 |     'a good photo of a {}',
 41 |     'a photo of a small {}',
 42 |     'a photo of a big {}',
 43 |     'a photo of the {}',
 44 |     'a blurry photo of the {}',
 45 |     'a low contrast photo of the {}',
 46 |     'a high contrast photo of the {}',
 47 |     'a bad photo of the {}',
 48 |     'a good photo of the {}',
 49 |     'a photo of the small {}',
 50 |     'a photo of the big {}',
 51 | ]
 52 | 
 53 | 
 54 | class cifar10_train(CIFAR10):
 55 |     def __init__(self, root, train=True, transform=None, target_transform=None, download=False, num_prompts=1):
 56 |         super(cifar10_train, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
 57 |         if num_prompts == 1:
 58 |             self.prompts = PROMPTS_1
 59 |         elif num_prompts == 5:
 60 |             self.prompts = PROMPTS_5
 61 |         else:
 62 |             self.prompts = PROMPTS
 63 |         self.captions = [prompt.format(cls) for cls in CLASSES for prompt in self.prompts]
 64 |         self.captions_to_label = {cap: i for i, cap in enumerate(self.captions)}
 65 |         self.annotations = []
 66 |         for i in range(len(self.data)):
 67 |             cls_name = CLASSES[self.targets[i]]
 68 |             for prompt in self.prompts:
 69 |                 caption = prompt.format(cls_name)
70 | self.annotations.append({"img_id": i, "caption_id": self.captions_to_label[caption]}) 71 | if num_prompts == 1: 72 | self.annotations = self.annotations * 5 73 | 74 | def __len__(self): 75 | return len(self.annotations) 76 | 77 | def __getitem__(self, index): 78 | ann = self.annotations[index] 79 | img_id = ann['img_id'] 80 | img = self.transform(self.data[img_id]) 81 | caption = self.captions[ann['caption_id']] 82 | return img, caption, img_id 83 | 84 | def fetch_distill_images(self, ipc): 85 | """ 86 | Randomly fetch `x` number of images from each class using numpy and return as a tensor. 87 | """ 88 | class_indices = defaultdict(list) 89 | 90 | for idx, label in enumerate(self.targets): 91 | class_indices[label].append(idx) 92 | 93 | # Randomly sample x indices for each class using numpy 94 | sampled_indices = [np.random.choice(indices, ipc, replace=False) for indices in class_indices.values()] 95 | sampled_indices = [idx for class_indices in sampled_indices for idx in class_indices] 96 | 97 | # Fetch images and labels using the selected indices 98 | images = torch.stack([self.transform(self.data[i]) for i in sampled_indices]) 99 | labels = [self.targets[i] for i in sampled_indices] 100 | 101 | captions = [] 102 | for label in labels: 103 | cls_name = CLASSES[label] 104 | prompt = np.random.choice(self.prompts) 105 | random_caption = prompt.format(cls_name) 106 | captions.append(random_caption) 107 | 108 | return images, captions 109 | 110 | def get_all_captions(self): 111 | return self.captions 112 | 113 | 114 | class cifar10_retrieval_eval(cifar10_train): 115 | def __init__(self, root, train=False, transform=None, target_transform=None, download=False, num_prompts=1): 116 | """ 117 | image_root (string): Root directory of images (e.g. coco/images/) 118 | ann_root (string): directory to store the annotation file 119 | split (string): val or test 120 | """ 121 | super(cifar10_retrieval_eval, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download, num_prompts=num_prompts) 122 | self.text = self.captions 123 | self.txt2img = {} 124 | self.img2txt = defaultdict(list) 125 | 126 | for ann in self.annotations: 127 | img_id = ann['img_id'] 128 | caption_id = ann['caption_id'] 129 | self.img2txt[img_id].append(caption_id) 130 | self.txt2img[caption_id] = img_id 131 | 132 | def __len__(self): 133 | return len(self.data) 134 | 135 | def __getitem__(self, index): 136 | image = self.transform(self.data[index]) 137 | return image, index 138 | -------------------------------------------------------------------------------- /data/coco_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from torch.utils.data import Dataset 4 | from torchvision.datasets.utils import download_url 5 | from PIL import Image 6 | import re 7 | 8 | def pre_caption(caption,max_words=50): 9 | caption = re.sub( 10 | r"([.!\"()*#:;~])", 11 | ' ', 12 | caption.lower(), 13 | ) 14 | caption = re.sub( 15 | r"\s{2,}", 16 | ' ', 17 | caption, 18 | ) 19 | caption = caption.rstrip('\n') 20 | caption = caption.strip(' ') 21 | 22 | #truncate caption 23 | caption_words = caption.split(' ') 24 | if len(caption_words)>max_words: 25 | caption = ' '.join(caption_words[:max_words]) 26 | 27 | return caption 28 | 29 | class coco_train(Dataset): 30 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 31 | ''' 32 | image_root (string): Root directory of images (e.g. 
coco/images/) 33 | ann_root (string): directory to store the annotation file 34 | ''' 35 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 36 | filename = 'coco_karpathy_train.json' 37 | 38 | download_url(url,ann_root) 39 | 40 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 41 | self.transform = transform 42 | self.image_root = image_root 43 | self.max_words = max_words 44 | self.prompt = prompt 45 | 46 | self.img_ids = {} 47 | n = 0 48 | for ann in self.annotation: 49 | img_id = ann['image_id'] 50 | if img_id not in self.img_ids.keys(): 51 | self.img_ids[img_id] = n 52 | n += 1 53 | 54 | def __len__(self): 55 | return len(self.annotation) 56 | 57 | def __getitem__(self, index): 58 | 59 | ann = self.annotation[index] 60 | 61 | image_path = os.path.join(self.image_root,ann['image']) 62 | image = Image.open(image_path).convert('RGB') 63 | image = self.transform(image) 64 | 65 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 66 | 67 | return image, caption, self.img_ids[ann['image_id']] 68 | 69 | def get_all_captions(self): 70 | captions = [] 71 | for ann in self.annotation: 72 | caption = self.prompt + pre_caption(ann['caption'], self.max_words) 73 | captions.append(caption) 74 | return captions 75 | 76 | 77 | class coco_caption_eval(Dataset): 78 | def __init__(self, transform, image_root, ann_root, split): 79 | ''' 80 | image_root (string): Root directory of images (e.g. coco/images/) 81 | ann_root (string): directory to store the annotation file 82 | split (string): val or test 83 | ''' 84 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 85 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 86 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 87 | 88 | download_url(urls[split],ann_root) 89 | 90 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 91 | self.transform = transform 92 | self.image_root = image_root 93 | 94 | def __len__(self): 95 | return len(self.annotation) 96 | 97 | def __getitem__(self, index): 98 | 99 | ann = self.annotation[index] 100 | 101 | image_path = os.path.join(self.image_root,ann['image']) 102 | image = Image.open(image_path).convert('RGB') 103 | image = self.transform(image) 104 | 105 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] 106 | 107 | return image, int(img_id) 108 | 109 | 110 | class coco_retrieval_eval(Dataset): 111 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 112 | ''' 113 | image_root (string): Root directory of images (e.g. 
coco/images/) 114 | ann_root (string): directory to store the annotation file 115 | split (string): val or test 116 | ''' 117 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 118 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 119 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 120 | 121 | download_url(urls[split],ann_root) 122 | 123 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 124 | self.transform = transform 125 | self.image_root = image_root 126 | 127 | self.text = [] 128 | self.image = [] 129 | self.txt2img = {} 130 | self.img2txt = {} 131 | 132 | txt_id = 0 133 | for img_id, ann in enumerate(self.annotation): 134 | self.image.append(ann['image']) 135 | self.img2txt[img_id] = [] 136 | for i, caption in enumerate(ann['caption']): 137 | self.text.append(pre_caption(caption,max_words)) 138 | self.img2txt[img_id].append(txt_id) 139 | self.txt2img[txt_id] = img_id 140 | txt_id += 1 141 | 142 | def __len__(self): 143 | return len(self.annotation) 144 | 145 | def __getitem__(self, index): 146 | 147 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 148 | image = Image.open(image_path).convert('RGB') 149 | image = self.transform(image) 150 | 151 | return image, index -------------------------------------------------------------------------------- /data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | import numpy as np 6 | import yaml 7 | from transformers import BertTokenizer, BertModel 8 | from torchvision import transforms as T 9 | from torchvision.transforms.functional import InterpolationMode 10 | from torch.utils.data import Dataset 11 | from torchvision.datasets.utils import download_url 12 | import argparse 13 | import re 14 | import json 15 | from PIL import Image 16 | 17 | def pre_caption(caption,max_words=50): 18 | caption = re.sub( 19 | r"([.!\"()*#:;~])", 20 | ' ', 21 | caption.lower(), 22 | ) 23 | caption = re.sub( 24 | r"\s{2,}", 25 | ' ', 26 | caption, 27 | ) 28 | caption = caption.rstrip('\n') 29 | caption = caption.strip(' ') 30 | 31 | #truncate caption 32 | caption_words = caption.split(' ') 33 | if len(caption_words)>max_words: 34 | caption = ' '.join(caption_words[:max_words]) 35 | 36 | return caption 37 | 38 | 39 | class flickr30k_train(Dataset): 40 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 41 | ''' 42 | image_root (string): Root directory of images (e.g. 
flickr30k/) 43 | ann_root (string): directory to store the annotation file 44 | ''' 45 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' 46 | filename = 'flickr30k_train.json' 47 | 48 | #################### 49 | # download_url(url,ann_root) 50 | #################### 51 | 52 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 53 | self.transform = transform 54 | self.image_root = image_root 55 | self.max_words = max_words 56 | self.prompt = prompt 57 | 58 | self.img_ids = {} 59 | n = 0 60 | for ann in self.annotation: 61 | img_id = ann['image_id'] 62 | if img_id not in self.img_ids.keys(): 63 | self.img_ids[img_id] = n 64 | n += 1 65 | 66 | def __len__(self): 67 | return len(self.annotation) 68 | 69 | @lru_cache(maxsize=100) 70 | def read_image(self, image_path): 71 | image = Image.open(image_path).convert('RGB') 72 | image = self.transform(image) 73 | return image 74 | 75 | def __getitem__(self, index): 76 | 77 | ann = self.annotation[index] 78 | 79 | image_path = os.path.join(self.image_root,ann['image']) 80 | image = self.read_image(image_path) 81 | # image = Image.open(image_path).convert('RGB') 82 | # image = self.transform(image) 83 | 84 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 85 | 86 | return image, caption, self.img_ids[ann['image_id']] 87 | 88 | def get_all_captions(self): 89 | captions = [] 90 | for ann in self.annotation: 91 | caption = self.prompt + pre_caption(ann['caption'], self.max_words) 92 | captions.append(caption) 93 | return captions 94 | 95 | 96 | 97 | class flickr30k_retrieval_eval(Dataset): 98 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 99 | ''' 100 | image_root (string): Root directory of images (e.g. 
flickr30k/) 101 | ann_root (string): directory to store the annotation file 102 | split (string): val or test 103 | ''' 104 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', 105 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} 106 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} 107 | 108 | ####################### 109 | # download_url(urls[split],ann_root) 110 | ####################### 111 | 112 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 113 | self.transform = transform 114 | self.image_root = image_root 115 | self.max_words = max_words 116 | 117 | self.text = [] 118 | self.image = [] 119 | self.txt2img = {} 120 | self.img2txt = {} 121 | 122 | txt_id = 0 123 | for img_id, ann in enumerate(self.annotation): 124 | self.image.append(ann['image']) 125 | self.img2txt[img_id] = [] 126 | for i, caption in enumerate(ann['caption']): 127 | self.text.append(pre_caption(caption,max_words)) 128 | self.img2txt[img_id].append(txt_id) 129 | self.txt2img[txt_id] = img_id 130 | txt_id += 1 131 | 132 | def __len__(self): 133 | return len(self.annotation) 134 | 135 | def __getitem__(self, index): 136 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 137 | image = Image.open(image_path).convert('RGB') 138 | image = self.transform(image) 139 | 140 | return image, index 141 | 142 | 143 | -------------------------------------------------------------------------------- /data/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | ##### modified by lzl ############### 31 | scale = (n_bins - 1) / (high - low) 32 | low = np.int8(low) 33 | offset = -low * scale 34 | ################################ 35 | table = np.arange(n_bins) * scale + offset 36 | table[table < 0] = 0 37 | table[table > n_bins - 1] = n_bins - 1 38 | table = table.clip(0, 255).astype(np.uint8) 39 | return table[ch] 40 | 41 | channels = [tune_channel(ch) for ch in cv2.split(img)] 42 | out = cv2.merge(channels) 43 | return out 44 | 45 | 46 | def equalize_func(img): 47 | ''' 48 | same output as PIL.ImageOps.equalize 49 | PIL's implementation is different from cv2.equalize 50 | ''' 51 | n_bins = 256 52 | 53 | def tune_channel(ch): 54 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 55 | non_zero_hist = hist[hist != 0].reshape(-1) 56 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 57 | if step == 0: return ch 58 | n = np.empty_like(hist) 59 | n[0] = step // 2 60 | n[1:] = hist[:-1] 61 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 62 | return table[ch] 63 | 64 | channels = [tune_channel(ch) for ch in cv2.split(img)] 65 | out = cv2.merge(channels) 66 | return out 67 
| 68 | 69 | def rotate_func(img, degree, fill=(0, 0, 0)): 70 | ''' 71 | like PIL, rotate by degree, not radians 72 | ''' 73 | H, W = img.shape[0], img.shape[1] 74 | center = W / 2, H / 2 75 | M = cv2.getRotationMatrix2D(center, degree, 1) 76 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 77 | return out 78 | 79 | 80 | def solarize_func(img, thresh=128): 81 | ''' 82 | same output as PIL.ImageOps.posterize 83 | ''' 84 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 85 | table = table.clip(0, 255).astype(np.uint8) 86 | out = table[img] 87 | return out 88 | 89 | 90 | def color_func(img, factor): 91 | ''' 92 | same output as PIL.ImageEnhance.Color 93 | ''' 94 | ## implementation according to PIL definition, quite slow 95 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 96 | # out = blend(degenerate, img, factor) 97 | # M = ( 98 | # np.eye(3) * factor 99 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 100 | # )[np.newaxis, np.newaxis, :] 101 | M = ( 102 | np.float32([ 103 | [0.886, -0.114, -0.114], 104 | [-0.587, 0.413, -0.587], 105 | [-0.299, -0.299, 0.701]]) * factor 106 | + np.float32([[0.114], [0.587], [0.299]]) 107 | ) 108 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 109 | return out 110 | 111 | 112 | def contrast_func(img, factor): 113 | """ 114 | same output as PIL.ImageEnhance.Contrast 115 | """ 116 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 117 | table = np.array([( 118 | el - mean) * factor + mean 119 | for el in range(256) 120 | ]).clip(0, 255).astype(np.uint8) 121 | out = table[img] 122 | return out 123 | 124 | 125 | def brightness_func(img, factor): 126 | ''' 127 | same output as PIL.ImageEnhance.Contrast 128 | ''' 129 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 130 | out = table[img] 131 | return out 132 | 133 | 134 | def sharpness_func(img, factor): 135 | ''' 136 | The differences the this result and PIL are all on the 4 boundaries, the center 137 | areas are same 138 | ''' 139 | kernel = np.ones((3, 3), dtype=np.float32) 140 | kernel[1][1] = 5 141 | kernel /= 13 142 | degenerate = cv2.filter2D(img, -1, kernel) 143 | if factor == 0.0: 144 | out = degenerate 145 | elif factor == 1.0: 146 | out = img 147 | else: 148 | out = img.astype(np.float32) 149 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 150 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 151 | out = out.astype(np.uint8) 152 | return out 153 | 154 | 155 | def shear_x_func(img, factor, fill=(0, 0, 0)): 156 | H, W = img.shape[0], img.shape[1] 157 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 158 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 159 | return out 160 | 161 | 162 | def translate_x_func(img, offset, fill=(0, 0, 0)): 163 | ''' 164 | same output as PIL.Image.transform 165 | ''' 166 | H, W = img.shape[0], img.shape[1] 167 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 168 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 169 | return out 170 | 171 | 172 | def translate_y_func(img, offset, fill=(0, 0, 0)): 173 | ''' 174 | same output as PIL.Image.transform 175 | ''' 176 | H, W = img.shape[0], img.shape[1] 177 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 178 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 179 | return out 180 | 181 | 182 | def posterize_func(img, 
bits): 183 | ''' 184 | same output as PIL.ImageOps.posterize 185 | ''' 186 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 187 | return out 188 | 189 | 190 | def shear_y_func(img, factor, fill=(0, 0, 0)): 191 | H, W = img.shape[0], img.shape[1] 192 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 193 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 194 | return out 195 | 196 | 197 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 198 | replace = np.array(replace, dtype=np.uint8) 199 | H, W = img.shape[0], img.shape[1] 200 | rh, rw = np.random.random(2) 201 | pad_size = pad_size // 2 202 | ch, cw = int(rh * H), int(rw * W) 203 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 204 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 205 | out = img.copy() 206 | out[x1:x2, y1:y2, :] = replace 207 | return out 208 | 209 | 210 | ### level to args 211 | def enhance_level_to_args(MAX_LEVEL): 212 | def level_to_args(level): 213 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 214 | return level_to_args 215 | 216 | 217 | def shear_level_to_args(MAX_LEVEL, replace_value): 218 | def level_to_args(level): 219 | level = (level / MAX_LEVEL) * 0.3 220 | if np.random.random() > 0.5: level = -level 221 | return (level, replace_value) 222 | 223 | return level_to_args 224 | 225 | 226 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 227 | def level_to_args(level): 228 | level = (level / MAX_LEVEL) * float(translate_const) 229 | if np.random.random() > 0.5: level = -level 230 | return (level, replace_value) 231 | 232 | return level_to_args 233 | 234 | 235 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 236 | def level_to_args(level): 237 | level = int((level / MAX_LEVEL) * cutout_const) 238 | return (level, replace_value) 239 | 240 | return level_to_args 241 | 242 | 243 | def solarize_level_to_args(MAX_LEVEL): 244 | def level_to_args(level): 245 | level = int((level / MAX_LEVEL) * 256) 246 | return (level, ) 247 | return level_to_args 248 | 249 | 250 | def none_level_to_args(level): 251 | return () 252 | 253 | 254 | def posterize_level_to_args(MAX_LEVEL): 255 | def level_to_args(level): 256 | level = int((level / MAX_LEVEL) * 4) 257 | return (level, ) 258 | return level_to_args 259 | 260 | 261 | def rotate_level_to_args(MAX_LEVEL, replace_value): 262 | def level_to_args(level): 263 | level = (level / MAX_LEVEL) * 30 264 | if np.random.random() < 0.5: 265 | level = -level 266 | return (level, replace_value) 267 | 268 | return level_to_args 269 | 270 | 271 | func_dict = { 272 | 'Identity': identity_func, 273 | 'AutoContrast': autocontrast_func, 274 | 'Equalize': equalize_func, 275 | 'Rotate': rotate_func, 276 | 'Solarize': solarize_func, 277 | 'Color': color_func, 278 | 'Contrast': contrast_func, 279 | 'Brightness': brightness_func, 280 | 'Sharpness': sharpness_func, 281 | 'ShearX': shear_x_func, 282 | 'TranslateX': translate_x_func, 283 | 'TranslateY': translate_y_func, 284 | 'Posterize': posterize_func, 285 | 'ShearY': shear_y_func, 286 | } 287 | 288 | translate_const = 10 289 | MAX_LEVEL = 10 290 | replace_value = (128, 128, 128) 291 | arg_dict = { 292 | 'Identity': none_level_to_args, 293 | 'AutoContrast': none_level_to_args, 294 | 'Equalize': none_level_to_args, 295 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 296 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 297 | 'Color': enhance_level_to_args(MAX_LEVEL), 298 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 299 | 'Brightness': 
enhance_level_to_args(MAX_LEVEL), 300 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 301 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 302 | 'TranslateX': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'TranslateY': translate_level_to_args( 306 | translate_const, MAX_LEVEL, replace_value 307 | ), 308 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 309 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 310 | } 311 | 312 | 313 | class RandomAugment(object): 314 | 315 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 316 | self.N = N 317 | self.M = M 318 | self.isPIL = isPIL 319 | if augs: 320 | self.augs = augs 321 | else: 322 | self.augs = list(arg_dict.keys()) 323 | 324 | def get_random_ops(self): 325 | sampled_ops = np.random.choice(self.augs, self.N) 326 | return [(op, 0.5, self.M) for op in sampled_ops] 327 | 328 | def __call__(self, img): 329 | if self.isPIL: 330 | img = np.array(img) 331 | ops = self.get_random_ops() 332 | for name, prob, level in ops: 333 | if np.random.random() > prob: 334 | continue 335 | args = arg_dict[name](level) 336 | img = func_dict[name](img, *args) 337 | return img 338 | 339 | 340 | if __name__ == '__main__': 341 | a = RandomAugment() 342 | img = np.random.randn(32, 32, 3) 343 | a(img) -------------------------------------------------------------------------------- /distill_tesla_lors.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import glob 3 | import os 4 | import argparse 5 | import re 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | import torchvision.utils 12 | import math 13 | 14 | import wandb 15 | import copy 16 | import datetime 17 | 18 | from data import get_dataset_flickr, textprocess, textprocess_train 19 | from src.epoch import evaluate_synset_with_similarity 20 | from src.networks import CLIPModel_full, MultilabelContrastiveLoss 21 | from src.reparam_module import ReparamModule 22 | from src.utils import ParamDiffAug, get_time 23 | from src.similarity_mining import LowRankSimilarityGenerator, FullSimilarityGenerator 24 | from src.vl_distill_utils import shuffle_files, nearest_neighbor, get_images_texts, load_or_process_file 25 | 26 | 27 | 28 | def make_timestamp(prefix: str="", suffix: str="") -> str: 29 | tmstamp = '{:%m%d_%H%M%S}'.format(datetime.datetime.now()) 30 | return prefix + tmstamp + suffix 31 | 32 | 33 | 34 | 35 | def main(args): 36 | ''' organize the real train dataset ''' 37 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args) 38 | 39 | train_sentences = train_dataset.get_all_captions() 40 | 41 | data = load_or_process_file('text', textprocess, args, testloader) 42 | train_caption = load_or_process_file('train_text', textprocess_train, args, train_sentences) 43 | 44 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu() 45 | print("The shape of bert_test_embed: {}".format(bert_test_embed.shape)) 46 | train_caption_embed = torch.from_numpy(train_caption['bert_test_embed']).cpu() 47 | print("The shape of train_caption_embed: {}".format(train_caption_embed.shape)) 48 | 49 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / args.temperature)) 50 | 51 | 52 | if args.zca and args.texture: 53 | raise AssertionError("Cannot use zca and texture together") 54 | 55 | if args.texture and args.pix_init == "real": 56 | print("WARNING: Using texture with real initialization will take a very 
long time to smooth out the boundaries between images.") 57 | 58 | if args.max_experts is not None and args.max_files is not None: 59 | args.total_experts = args.max_experts * args.max_files 60 | 61 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled)) 62 | 63 | if args.dsa == True: 64 | print("unfortunately, this repo did not support DSA") 65 | raise AssertionError("DSA is not supported in this repo") 66 | args.dsa = True if args.dsa == 'True' else False 67 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 68 | 69 | if args.eval_it>0: 70 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 71 | else: 72 | eval_it_pool = [] 73 | 74 | if args.dsa: 75 | args.dc_aug_param = None 76 | 77 | if args.disabled_wandb: 78 | wandb.init(mode = 'disabled') 79 | else: 80 | wandb.init(project='LoRS', config=args, name=args.name+"_"+make_timestamp()) 81 | 82 | args.dsa_param = ParamDiffAug() 83 | zca_trans = args.zca_trans if args.zca else None 84 | args.zca_trans = zca_trans 85 | args.distributed = torch.cuda.device_count() > 1 86 | 87 | syn_lr_img = torch.tensor(args.lr_teacher_img).to(args.device).requires_grad_(True) 88 | syn_lr_txt = torch.tensor(args.lr_teacher_txt).to(args.device).requires_grad_(True) 89 | 90 | ''' initialize the synthetic data ''' 91 | image_syn, text_syn = get_images_texts(args.num_queries, train_dataset, args) 92 | 93 | if args.sim_type == 'lowrank': 94 | sim_generator = LowRankSimilarityGenerator( 95 | args.num_queries, args.sim_rank, args.alpha) 96 | elif args.sim_type == 'full': 97 | sim_generator = FullSimilarityGenerator(args.num_queries) 98 | else: 99 | raise AssertionError("Invalid similarity type: {}".format(args.sim_type)) 100 | sim_generator = sim_generator.to(args.device) 101 | 102 | 103 | contrastive_criterion = MultilabelContrastiveLoss(args.loss_type) 104 | 105 | if args.pix_init == 'noise': 106 | mean = torch.tensor([-0.0626, -0.0221, 0.0680]) 107 | std = torch.tensor([1.0451, 1.0752, 1.0539]) 108 | image_syn = torch.randn([args.num_queries, 3, 224, 224]) 109 | for c in range(3): 110 | image_syn[:, c] = image_syn[:, c] * std[c] + mean[c] 111 | print('Initialized synthetic image from random noise') 112 | 113 | if args.txt_init == 'noise': 114 | text_syn = torch.normal(mean=-0.0094, std=0.5253, size=(args.num_queries, 768)) 115 | print('Initialized synthetic text from random noise') 116 | 117 | start_it = 0 118 | if args.resume_from is not None: 119 | ckpt = torch.load(args.resume_from) 120 | image_syn = ckpt["image"].to(args.device).requires_grad_(True) 121 | text_syn = ckpt["text"].to(args.device).requires_grad_(True) 122 | if "similarity_params" in ckpt: 123 | sim_generator.load_params([x.to(args.device) for x in ckpt["similarity_params"]]) 124 | else: 125 | print("WARNING: no similarity matrix in the checkpoint") 126 | syn_lr_img = ckpt["syn_lr_img"].to(args.device).requires_grad_(True) 127 | syn_lr_txt = ckpt["syn_lr_txt"].to(args.device).requires_grad_(True) 128 | 129 | re_res = re.findall(r"distilled_(\d+).pt", args.resume_from) 130 | if len(re_res) == 1: 131 | start_it = int(re_res[0]) 132 | else: 133 | start_it = 0 134 | 135 | 136 | ''' training ''' 137 | image_syn = image_syn.detach().to(args.device).requires_grad_(True) 138 | text_syn = text_syn.detach().to(args.device).requires_grad_(True) 139 | 140 | optimizer = torch.optim.SGD([ 141 | {'params': [image_syn], 'lr': args.lr_img, "momentum": args.momentum_syn}, 142 | {'params': [text_syn], 'lr': args.lr_txt, "momentum": args.momentum_syn}, 143 | {'params': 
[syn_lr_img, syn_lr_txt], 'lr': args.lr_lr, "momentum": args.momentum_lr}, 144 | {'params': sim_generator.get_indexed_parameters(), 'lr': args.lr_sim, "momentum": args.momentum_sim}, 145 | ], lr=0) 146 | optimizer.zero_grad() 147 | 148 | if args.draw: 149 | sentence_list = nearest_neighbor(train_sentences, text_syn.detach().cpu(), train_caption_embed) 150 | wandb.log({"original_sentence_list": wandb.Html('
'.join(sentence_list))}, step=0) 151 | wandb.log({"original_synthetic_images": wandb.Image(torch.nan_to_num(image_syn.detach().cpu()))}, step=0) 152 | 153 | 154 | expert_dir = os.path.join(args.buffer_path, args.dataset) 155 | expert_dir = args.buffer_path 156 | print("Expert Dir: {}".format(expert_dir)) 157 | 158 | 159 | img_expert_files = list(glob.glob(os.path.join(expert_dir, "img_replay_buffer_*.pt"))) 160 | txt_expert_files = list(glob.glob(os.path.join(expert_dir, "txt_replay_buffer_*.pt"))) 161 | if len(txt_expert_files) != len(img_expert_files) or len(txt_expert_files) == 0: 162 | raise AssertionError("No buffers / Error buffers detected at {}".format(expert_dir)) 163 | 164 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files) 165 | 166 | file_idx = 0 167 | expert_idx = 0 168 | 169 | img_buffer = torch.load(img_expert_files[file_idx]) 170 | txt_buffer = torch.load(txt_expert_files[file_idx]) 171 | 172 | for it in tqdm(range(start_it, args.Iteration + 1)): 173 | save_this_it = True 174 | 175 | wandb.log({"Progress": it}, step=it) 176 | ''' Evaluate synthetic data ''' 177 | if it in eval_it_pool: 178 | print('Evaluation\nimage_model_train = %s, text_model_train = %s, iteration = %d'%(args.image_encoder, args.text_encoder, it)) 179 | 180 | multi_eval_aggr_result = defaultdict(list) # aggregated results of multiple evaluations 181 | 182 | # r_means = [] 183 | for it_eval in range(args.num_eval): 184 | net_eval = CLIPModel_full(args, eval_stage=args.transfer) 185 | 186 | with torch.no_grad(): 187 | image_save = image_syn 188 | text_save = text_syn 189 | image_syn_eval, text_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(text_save.detach()) # avoid any unaware modification 190 | 191 | lr_img, lr_txt = copy.deepcopy(syn_lr_img.detach().item()), copy.deepcopy(syn_lr_txt.detach().item()) 192 | 193 | sim_params = sim_generator.get_indexed_parameters() 194 | similarity_syn_eval = copy.deepcopy(sim_generator.generate_with_param(sim_params).detach()) # avoid any unaware modification 195 | 196 | _, _, best_val_result = evaluate_synset_with_similarity( 197 | it_eval, net_eval, image_syn_eval, text_syn_eval, lr_img, lr_txt, 198 | similarity_syn_eval, testloader, args, bert_test_embed) 199 | 200 | for k, v in best_val_result.items(): 201 | multi_eval_aggr_result[k].append(v) 202 | 203 | if not args.std: 204 | wandb.log({ 205 | k: v 206 | for k, v in best_val_result.items() 207 | if k not in ["img_r_mean", "txt_r_mean"] 208 | }, step=it) 209 | # logged img_r1, img_r5, img_r10, txt_r1, txt_r5, txt_r10, r_mean 210 | 211 | 212 | if args.std: 213 | for key, values in multi_eval_aggr_result.items(): 214 | if key in ["img_r_mean", "txt_r_mean"]: 215 | continue 216 | wandb.log({ 217 | "Mean/{}".format(key): np.mean(values), 218 | "Std/{}".format(key): np.std(values) 219 | }, step=it) 220 | 221 | 222 | if it in eval_it_pool and (save_this_it or it % 1000 == 0): 223 | with torch.no_grad(): 224 | save_dir = os.path.join(".", "logged_files", args.dataset, wandb.run.name) 225 | print("Saving to {}".format(save_dir)) 226 | if not os.path.exists(save_dir): 227 | os.makedirs(save_dir) 228 | 229 | image_save = image_syn.detach().cpu() 230 | text_save = text_syn.detach().cpu() 231 | sim_params = sim_generator.get_indexed_parameters() 232 | sim_mat = sim_generator.generate_with_param(sim_params) 233 | 234 | torch.save({ 235 | "image": image_save, 236 | "text": text_save, 237 | "similarity_params": [x.detach().cpu() for x in sim_params], 238 | "similarity_mat": 
sim_mat.detach().cpu(), 239 | "syn_lr_img": syn_lr_img.detach().cpu(), 240 | "syn_lr_txt": syn_lr_txt.detach().cpu(), 241 | }, os.path.join(save_dir, "distilled_{}.pt".format(it)) ) 242 | 243 | 244 | if args.draw: 245 | wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, step=it) # Move tensor to CPU before converting to NumPy 246 | 247 | if args.ipc < 50 or args.force_save: 248 | upsampled = image_save[:90] 249 | if args.dataset != "ImageNet": 250 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 251 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 252 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 253 | sentence_list = nearest_neighbor(train_sentences, text_syn.cpu(), train_caption_embed) 254 | sentence_list = sentence_list[:90] 255 | torchvision.utils.save_image(grid, os.path.join(save_dir, "synthetic_images_{}.png".format(it))) 256 | 257 | with open(os.path.join(save_dir, "synthetic_sentences_{}.txt".format(it)), "w") as file: 258 | file.write('\n'.join(sentence_list)) 259 | wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 260 | wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it) 261 | wandb.log({"Synthetic_Sentences": wandb.Html('
'.join(sentence_list))}, step=it) 262 | print("finish saving images") 263 | 264 | for clip_val in [2.5]: 265 | std = torch.std(image_save) 266 | mean = torch.mean(image_save) 267 | upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std).cpu() # Move to CPU 268 | if args.dataset != "ImageNet": 269 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 270 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 271 | grid = torchvision.utils.make_grid(upsampled[:90], nrow=10, normalize=True, scale_each=True) 272 | wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid))}, step=it) 273 | torchvision.utils.save_image(grid, os.path.join(save_dir, "clipped_synthetic_images_{}_std_{}.png".format(it, clip_val))) 274 | 275 | 276 | if args.zca: 277 | raise AssertionError("we do not use ZCA transformation") 278 | 279 | wandb.log({"Synthetic_LR/Image": syn_lr_img.detach().cpu()}, step=it) 280 | wandb.log({"Synthetic_LR/Text": syn_lr_txt.detach().cpu()}, step=it) 281 | 282 | torch.cuda.empty_cache() 283 | student_net = CLIPModel_full(args, temperature=args.temperature) 284 | img_student_net = ReparamModule(student_net.image_encoder.to('cpu')).to('cuda') 285 | txt_student_net = ReparamModule(student_net.text_projection.to('cpu')).to('cuda') 286 | 287 | if args.distributed: 288 | img_student_net = torch.nn.DataParallel(img_student_net) 289 | txt_student_net = torch.nn.DataParallel(txt_student_net) 290 | 291 | img_student_net.train() 292 | txt_student_net.train() 293 | 294 | 295 | img_expert_trajectory = img_buffer[expert_idx] 296 | txt_expert_trajectory = txt_buffer[expert_idx] 297 | expert_idx += 1 298 | if expert_idx == len(img_buffer): 299 | expert_idx = 0 300 | file_idx += 1 301 | if file_idx == len(img_expert_files): 302 | file_idx = 0 303 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files) 304 | if args.max_files != 1: 305 | del img_buffer 306 | del txt_buffer 307 | img_buffer = torch.load(img_expert_files[file_idx]) 308 | txt_buffer = torch.load(txt_expert_files[file_idx]) 309 | 310 | start_epoch = np.random.randint(0, args.max_start_epoch) 311 | img_starting_params = img_expert_trajectory[start_epoch] 312 | txt_starting_params = txt_expert_trajectory[start_epoch] 313 | 314 | img_target_params = img_expert_trajectory[start_epoch+args.expert_epochs] 315 | txt_target_params = txt_expert_trajectory[start_epoch+args.expert_epochs] 316 | 317 | img_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_target_params], 0) 318 | txt_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_target_params], 0) 319 | 320 | img_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0).requires_grad_(True)] 321 | txt_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0).requires_grad_(True)] 322 | 323 | img_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0) 324 | txt_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0) 325 | syn_images = image_syn 326 | syn_texts = text_syn 327 | 328 | x_list = [] 329 | y_list = [] 330 | sim_list = [] # the parameters of parameterized similarity matrix 331 | sim_param_list = [] # the parameters of parameterized similarity matrix 332 | grad_sum_img = torch.zeros(img_student_params[-1].shape).to(args.device) 333 | grad_sum_txt = 
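# ReparamModule drives a whole encoder from one flat parameter vector, which is
# why the expert/student parameters above are flattened with
# torch.cat([p.reshape(-1), ...]). A minimal sketch of the pattern
# (`some_module` is a placeholder name):
#
#   net  = ReparamModule(some_module)
#   flat = torch.cat([p.data.reshape(-1) for p in some_module.parameters()])
#   out  = net(inputs, flat_param=flat)   # forward pass with explicit weights
#
# Keeping the weights as one differentiable vector is what makes the manual
# update below (student_params[-1] - lr * grad) and the closed-form gradient
# pass further down possible.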
torch.zeros(txt_student_params[-1].shape).to(args.device) 334 | 335 | syn_image_gradients = torch.zeros(syn_images.shape).to(args.device) 336 | syn_txt_gradients = torch.zeros(syn_texts.shape).to(args.device) 337 | syn_sim_param_gradients = [torch.zeros(x.shape).to(args.device) for x in sim_generator.get_indexed_parameters()] 338 | 339 | indices_chunks_copy = [] 340 | indices = torch.randperm(len(syn_images)) 341 | index = 0 342 | for _ in range(args.syn_steps): 343 | if args.mini_batch_size + index > len(syn_images): 344 | indices = torch.randperm(len(syn_images)) 345 | index = 0 346 | these_indices = indices[index : index + args.mini_batch_size] 347 | index += args.mini_batch_size 348 | 349 | indices_chunks_copy.append(these_indices) 350 | 351 | x = syn_images[these_indices] 352 | this_y = syn_texts[these_indices] 353 | 354 | x_list.append(x.clone()) 355 | y_list.append(this_y.clone()) 356 | 357 | this_sim_param = sim_generator.get_indexed_parameters(these_indices) 358 | this_sim = sim_generator.generate_with_param(this_sim_param) 359 | if args.distributed: 360 | img_forward_params = img_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1) 361 | txt_forward_params = txt_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1) 362 | else: 363 | img_forward_params = img_student_params[-1] 364 | txt_forward_params = txt_student_params[-1] 365 | 366 | x = img_student_net(x, flat_param=img_forward_params) 367 | x = x / x.norm(dim=1, keepdim=True) 368 | this_y = txt_student_net(this_y, flat_param=txt_forward_params) 369 | this_y = this_y / this_y.norm(dim=1, keepdim=True) 370 | image_logits = logit_scale.exp() * x.float() @ this_y.float().t() 371 | 372 | sim_list.append(this_sim) 373 | sim_param_list.append(this_sim_param) 374 | 375 | contrastive_loss = contrastive_criterion(image_logits, this_sim) 376 | 377 | img_grad = torch.autograd.grad(contrastive_loss, img_student_params[-1], create_graph=True)[0] 378 | txt_grad = torch.autograd.grad(contrastive_loss, txt_student_params[-1], create_graph=True)[0] 379 | 380 | 381 | img_detached_grad = img_grad.detach().clone() 382 | img_student_params.append(img_student_params[-1] - syn_lr_img.item() * img_detached_grad) 383 | grad_sum_img += img_detached_grad 384 | 385 | txt_detached_grad = txt_grad.detach().clone() 386 | txt_student_params.append(txt_student_params[-1] - syn_lr_txt.item() * txt_detached_grad) 387 | grad_sum_txt += txt_detached_grad 388 | 389 | 390 | del img_grad 391 | del txt_grad 392 | 393 | img_param_dist = torch.tensor(0.0).to(args.device) 394 | txt_param_dist = torch.tensor(0.0).to(args.device) 395 | 396 | img_param_dist += torch.nn.functional.mse_loss(img_starting_params, img_target_params, reduction="sum") 397 | txt_param_dist += torch.nn.functional.mse_loss(txt_starting_params, txt_target_params, reduction="sum") 398 | 399 | 400 | # compute gradients invoving 2 gradients 401 | for i in range(args.syn_steps): 402 | x = img_student_net(x_list[i], flat_param=img_student_params[i]) 403 | x = x / x.norm(dim=1, keepdim=True) 404 | this_y = txt_student_net(y_list[i], flat_param=txt_student_params[i]) 405 | this_y = this_y / this_y.norm(dim=1, keepdim=True) 406 | image_logits = logit_scale.exp() * x.float() @ this_y.float().t() 407 | loss_i = contrastive_criterion(image_logits, sim_list[i]) 408 | 409 | 410 | img_single_term = syn_lr_img.item() * (img_target_params - img_starting_params) 411 | img_square_term = (syn_lr_img.item() ** 2) * grad_sum_img 412 | txt_single_term = syn_lr_txt.item() * 
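# This second pass is the memory-efficient (TESLA-style) way of getting
# gradients for the synthetic data. With theta_T = theta_0 - lr * sum_t g_t and
# the matching loss
#   L = ||theta_0 - lr * sum_t g_t - theta*||^2 / ||theta_0 - theta*||^2,
# the gradient w.r.t. the synthetic batch used at step i is
#   dL/dx_i = (2 / ||theta_0 - theta*||^2)
#             * (lr * (theta* - theta_0) + lr^2 * sum_t g_t) @ dg_i/dx_i,
# i.e. (single_term + square_term) @ grad_i below, with the factor 2 and the
# 1/param_dist normalisation applied around it. Each g_i is re-derived one step
# at a time, so the fully unrolled trajectory never has to stay in memory.
# (A sketch of the algebra as read from the code; it is computed separately for
# the image and text branches.)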
(txt_target_params - txt_starting_params) 413 | txt_square_term = (syn_lr_txt.item() ** 2) * grad_sum_txt 414 | 415 | 416 | img_grad_i = torch.autograd.grad(loss_i, img_student_params[i], create_graph=True, retain_graph=True)[0] 417 | txt_grad_i = torch.autograd.grad(loss_i, txt_student_params[i], create_graph=True, retain_graph=True)[0] 418 | 419 | 420 | img_syn_real_dist = (img_single_term + img_square_term) @ img_grad_i 421 | txt_syn_real_dist = (txt_single_term + txt_square_term) @ txt_grad_i 422 | 423 | if args.merge_loss_branches: 424 | # Following traditional MTT loss, equivalent to adding some weights on the two branches 425 | grand_loss_i = (img_syn_real_dist + txt_syn_real_dist) / ( img_param_dist + txt_param_dist) 426 | else: 427 | # Original loss in vl-distill 428 | grand_loss_i = img_syn_real_dist / img_param_dist + txt_syn_real_dist / txt_param_dist 429 | 430 | multiple_gradients_in_one_time = torch.autograd.grad( 431 | 2 * grand_loss_i, 432 | [x_list[i], y_list[i]] + sim_param_list[i] 433 | ) 434 | 435 | img_gradients = multiple_gradients_in_one_time[0] 436 | txt_gradients = multiple_gradients_in_one_time[1] 437 | sim_param_gradients = multiple_gradients_in_one_time[2:] 438 | 439 | 440 | with torch.no_grad(): 441 | ids = indices_chunks_copy[i] 442 | syn_image_gradients[ids] += img_gradients 443 | syn_txt_gradients[ids] += txt_gradients 444 | 445 | assert len(sim_param_gradients) == len(syn_sim_param_gradients), f"{len(sim_param_gradients)}, {len(syn_sim_param_gradients)}" 446 | for g_idx in range(len(sim_param_gradients)): 447 | if args.sim_type == 'full': 448 | syn_sim_param_gradients[g_idx][ids[:,None], ids] += sim_param_gradients[g_idx] 449 | # !!! gradient will be lost if using xxxx[ids, :][:, ids] 450 | else: 451 | syn_sim_param_gradients[g_idx][ids, ...] 
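# The ids[:, None], ids indexing in the 'full' branch above matters: chained
# fancy indexing operates on a copy, so in-place accumulation into it is lost.
# A tiny self-contained check of that behaviour:
#
#   A = torch.zeros(4, 4); ids = torch.tensor([0, 2]); g = torch.ones(2, 2)
#   A[ids][:, ids] += g           # A[ids] is a copy -> A stays all zeros
#   A[ids[:, None], ids] += g     # advanced indexing writes back into A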
+= sim_param_gradients[g_idx] 452 | 453 | 454 | # ---------end of computing input image gradients and learning rates-------------- 455 | 456 | syn_images.grad = syn_image_gradients 457 | syn_texts.grad = syn_txt_gradients 458 | 459 | for g_idx, param in enumerate(sim_generator.get_indexed_parameters()): 460 | param.grad = syn_sim_param_gradients[g_idx] 461 | 462 | 463 | img_grand_loss = img_starting_params - syn_lr_img * grad_sum_img - img_target_params 464 | txt_grand_loss = txt_starting_params - syn_lr_txt * grad_sum_txt - txt_target_params 465 | img_grand_loss = img_grand_loss.dot(img_grand_loss) / img_param_dist 466 | txt_grand_loss = txt_grand_loss.dot(txt_grand_loss) / txt_param_dist 467 | grand_loss = img_grand_loss + txt_grand_loss 468 | 469 | img_lr_grad = torch.autograd.grad(img_grand_loss, syn_lr_img)[0] 470 | syn_lr_img.grad = img_lr_grad 471 | txt_lr_grad = torch.autograd.grad(txt_grand_loss, syn_lr_txt)[0] 472 | syn_lr_txt.grad = txt_lr_grad 473 | 474 | if math.isnan(img_grand_loss): 475 | break 476 | 477 | wandb.log({ 478 | "Loss/grand": grand_loss.detach().cpu(), 479 | # "Start_Epoch": start_epoch, 480 | "Loss/img_grand": img_grand_loss.detach().cpu(), 481 | "Loss/txt_grand": txt_grand_loss.detach().cpu(), 482 | }, step=it) 483 | 484 | 485 | wandb.log({"Synthetic_LR/grad_syn_lr_img": syn_lr_img.grad.detach().cpu()}, step=it) 486 | wandb.log({"Synthetic_LR/grad_syn_lr_txt": syn_lr_txt.grad.detach().cpu()}, step=it) 487 | 488 | optimizer.step() # no need zero_grad: grad is not computed by autograd; it is directly override by our computation 489 | 490 | if args.clamp_lr: 491 | syn_lr_img.data = torch.clamp(syn_lr_img.data, min=args.clamp_lr) # 492 | syn_lr_txt.data = torch.clamp(syn_lr_txt.data, min=args.clamp_lr) 493 | 494 | for _ in img_student_params: 495 | del _ 496 | for _ in txt_student_params: 497 | del _ 498 | 499 | if it%10 == 0: 500 | print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item())) 501 | 502 | 503 | wandb.finish() 504 | 505 | 506 | if __name__ == '__main__': 507 | parser = argparse.ArgumentParser(description='Parameter Processing') 508 | 509 | parser.add_argument('--dataset', type=str, default='flickr30k', help='dataset') 510 | 511 | parser.add_argument('--eval_mode', type=str, default='S', 512 | help='eval_mode, check utils.py for more info') 513 | 514 | parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on') 515 | 516 | parser.add_argument('--eval_it', type=int, default=50, help='how often to evaluate') 517 | 518 | parser.add_argument('--epoch_eval_train', type=int, default=100, help='epochs to train a model with synthetic data') 519 | parser.add_argument('--Iteration', type=int, default=3000, help='how many distillation steps to perform') 520 | 521 | parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images') 522 | parser.add_argument('--lr_txt', type=float, default=1000, help='learning rate for updating synthetic texts') 523 | parser.add_argument('--lr_lr', type=float, default=1e-03, help='learning rate for updating... 
learning rate') 524 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters') 525 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters') 526 | 527 | parser.add_argument('--loss_type', type=str) 528 | 529 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks') 530 | 531 | parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"], 532 | help='noise/real: initialize synthetic images from random noise or randomly sampled real images.') 533 | parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"], 534 | help='noise/real: initialize synthetic texts from random noise or randomly sampled real images.') 535 | 536 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 537 | help='whether to use differentiable Siamese augmentation.') 538 | 539 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 540 | help='differentiable Siamese augmentation strategy') 541 | 542 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path') 543 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 544 | 545 | parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are') 546 | parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data') 547 | parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at') 548 | 549 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 550 | 551 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM") 552 | 553 | parser.add_argument('--no_aug', action="store_true", default=False, help='this turns off diff aug during distillation') 554 | 555 | parser.add_argument('--texture', action='store_true', help="will distill textures instead") 556 | parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas') 557 | parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration') 558 | 559 | 560 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)') 561 | parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)') 562 | 563 | parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc') 564 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 565 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run') 566 | parser.add_argument('--num_queries', type=int, default=100, help='number of queries') 567 | parser.add_argument('--mini_batch_size', type=int, default=100, help='number of queries') 568 | parser.add_argument('--basis', type=bool, default=False, help='whether use basis or not') 569 | parser.add_argument('--n_basis', type=int, default=64, help='n_basis') 570 | parser.add_argument('--recursive', type=bool, default=False, help='whether use basis or not') 571 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy') 572 | parser.add_argument('--image_size', type=int, 
default=224, help='image_size') 573 | parser.add_argument('--image_root', type=str, default='distill_utils/data/Flickr30k/', help='location of image root') 574 | parser.add_argument('--ann_root', type=str, default='./data/Flickr30k_ann/', help='location of ann root') 575 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train') 576 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test') 577 | parser.add_argument('--image_encoder', type=str, default='nfnet', help='image encoder') # , choices=['clip', 'nfnet', 'vit', 'nf_resnet50', "nf_regnet"] 578 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert'], help='text encoder') 579 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained') 580 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained') 581 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable') 582 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable') 583 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None') 584 | parser.add_argument('--distill', type=bool, default=True, help='whether distill') 585 | parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='matching_train') 586 | parser.add_argument('--image_only', type=bool, default=False, help='None') 587 | parser.add_argument('--text_only', type=bool, default=False, help='None') 588 | parser.add_argument('--draw', type=bool, default=False, help='None') 589 | parser.add_argument('--transfer', type=bool, default=False, help='transfer cross architecture') 590 | parser.add_argument('--std', type=bool, default=True, help='standard deviation') 591 | parser.add_argument('--disabled_wandb', type=bool, default=False, help='disable wandb') 592 | parser.add_argument('--test_with_norm', type=bool, default=False, help='') 593 | 594 | parser.add_argument('--clamp_lr', type=float, default=None, help='') 595 | 596 | 597 | # Arguments below are for LoRS 598 | 599 | parser.add_argument('--resume_from', default=None, type=str) 600 | 601 | parser.add_argument('--sim_type', type=str, default="full", choices=["full", "lowrank"], help='similarity matrix type') 602 | parser.add_argument('--sim_rank', type=int, default=10, help='similarity matrix rank') 603 | parser.add_argument('--alpha', type=float, default=0.1, help='alpha in LoRA') 604 | parser.add_argument('--lr_sim', type=float, default=1e-03, help='learning rate for updating similarity mat learning rate') 605 | parser.add_argument('--temperature', type=float, default=0.07, help="temperature of CLIP model") 606 | 607 | parser.add_argument('--momentum_lr', type=float, default=0.5) 608 | parser.add_argument('--momentum_syn', type=float, default=0.5) 609 | parser.add_argument('--momentum_sim', type=float, default=0.5) 610 | parser.add_argument('--merge_loss_branches', action="store_true", default=False) 611 | 612 | args = parser.parse_args() 613 | 614 | main(args) -------------------------------------------------------------------------------- /evaluate_only.py: -------------------------------------------------------------------------------- 1 | """Evaliuate the distilled data, could be used for cross-arch exp 2 | Example: 3 | 4 | python evaluate_only.py --dataset=flickr --num_eval 1 \ 5 | --ckpt_path tmp/flickr_500_distilled.pt --loss_type WBCE \ 6 | 
--image_encoder=nf_resnet50 --text_encoder=bert --batch_train 64 7 | """ 8 | 9 | from collections import defaultdict 10 | import os 11 | import re 12 | 13 | import numpy as np 14 | import torch 15 | 16 | import copy 17 | import argparse 18 | import datetime 19 | 20 | from data import get_dataset_flickr, textprocess, textprocess_train 21 | from src.reparam_module import ReparamModule 22 | from src.epoch import evaluate_synset_with_similarity 23 | from src.networks import CLIPModel_full 24 | 25 | from src.vl_distill_utils import load_or_process_file 26 | 27 | 28 | def formatting_result_head(): 29 | return "Img R@1 | Img R@5 | Img R@10 | Txt R@1 | Txt R@5 | Txt R@10 | Mean" 30 | 31 | 32 | def formatting_result_content(val_result): 33 | return "{img_r1:9.2f} | {img_r5:9.2f} | {img_r10:9.2f} | {txt_r1:9.2f} | {txt_r5:9.2f} | {txt_r10:9.2f} | {r_mean:9.2f}".format( 34 | **val_result 35 | ) 36 | 37 | def formatting_result_content_clean(val_result): 38 | return "{img_r1} {img_r5} {img_r10} {txt_r1} {txt_r5} {txt_r10} {r_mean}".format( 39 | **val_result 40 | ) 41 | 42 | def formatting_result_all(val_result): 43 | return "Image R@1={img_r1} R@5={img_r5} R@10={img_r10} | Text R@1={txt_r1} R@5={txt_r5} R@10={txt_r10} | Mean={r_mean}".format( 44 | **val_result 45 | ) 46 | 47 | 48 | 49 | def main(args): 50 | ''' organize the real train dataset ''' 51 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args) 52 | 53 | train_sentences = train_dataset.get_all_captions() 54 | 55 | data = load_or_process_file('text', textprocess, args, testloader) 56 | train_caption = load_or_process_file('train_text', textprocess_train, args, train_sentences) 57 | 58 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu() 59 | print("The shape of bert_test_embed: {}".format(bert_test_embed.shape)) 60 | train_caption_embed = torch.from_numpy(train_caption['bert_test_embed']).cpu() 61 | print("The shape of train_caption_embed: {}".format(train_caption_embed.shape)) 62 | 63 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 64 | 65 | 66 | if not os.path.exists(args.ckpt_path): 67 | args.ckpt_path = "logged_files/"+args.ckpt_path 68 | if not os.path.exists(args.ckpt_path): 69 | raise ValueError(f"{args.ckpt_path} does not exist") 70 | 71 | 72 | print("Load from", args.ckpt_path) 73 | ckpt = torch.load(args.ckpt_path) 74 | image_syn = ckpt["image"].to(args.device) 75 | text_syn = ckpt["text"].to(args.device) 76 | sim_mat = ckpt["similarity_mat"].to(args.device) 77 | syn_lr_img = ckpt.get("syn_lr_img") if args.syn_lr_img is None else args.syn_lr_img 78 | syn_lr_txt = ckpt.get("syn_lr_txt") if args.syn_lr_txt is None else args.syn_lr_txt 79 | print(syn_lr_img, syn_lr_txt) 80 | 81 | if args.clip_similarity: 82 | # ablation - use similarity matrix from pretrained CLIP 83 | buffer_dir = os.path.join("buffer", args.dataset, args.image_encoder+"_"+args.text_encoder, "InfoNCE") 84 | img_starting_params = torch.load(os.path.join(buffer_dir, "img_replay_buffer_0.pt"))[0][-1] 85 | txt_starting_params = torch.load(os.path.join(buffer_dir, "txt_replay_buffer_0.pt"))[0][-1] 86 | img_forward_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0) 87 | txt_forward_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0) 88 | 89 | clip_model = CLIPModel_full(args, eval_stage=args.transfer) 90 | img_student_net = ReparamModule(clip_model.image_encoder.to('cpu')).to('cuda') 91 | txt_student_net = 
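# The --clip_similarity branch is an ablation: instead of the learned
# similarity matrix stored in the checkpoint, it rebuilds soft targets from the
# final epoch of the first expert trajectory (sigmoid of the scaled image-text
# feature products a few lines below), and it expects the buffer to exist at
# buffer/<dataset>/<image_encoder>_<text_encoder>/InfoNCE. A hypothetical
# invocation reusing the flags from the module docstring above:
#
#   python evaluate_only.py --dataset=flickr --loss_type WBCE \
#       --ckpt_path tmp/flickr_500_distilled.pt --clip_similarity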
ReparamModule(clip_model.text_projection.to('cpu')).to('cuda') 92 | 93 | 94 | img_feat = img_student_net(image_syn, flat_param=img_forward_params) 95 | img_feat = img_feat / img_feat.norm(dim=1, keepdim=True) 96 | txt_feat = txt_student_net(text_syn, flat_param=txt_forward_params) 97 | txt_feat = txt_feat / txt_feat.norm(dim=1, keepdim=True) 98 | sim_mat = torch.sigmoid(img_feat.float() @ txt_feat.float().t() / args.temperature) 99 | 100 | print('Evaluation\nimage_model_train = %s, text_model_train = %s, iteration = ?'%(args.image_encoder, args.text_encoder)) 101 | 102 | multi_eval_aggr_result = defaultdict(list) # aggregated results of multiple evaluations 103 | 104 | for it_eval in range(args.num_eval): 105 | net_eval = CLIPModel_full(args, eval_stage=args.transfer) 106 | 107 | image_syn_eval, text_syn_eval = copy.deepcopy(image_syn), copy.deepcopy(text_syn) # avoid any unaware modification 108 | similarity_syn_eval = copy.deepcopy(sim_mat) # avoid any unaware modification 109 | 110 | _, _, best_val_result = evaluate_synset_with_similarity( 111 | it_eval, net_eval, image_syn_eval, text_syn_eval, syn_lr_img, syn_lr_txt, 112 | similarity_syn_eval, testloader, args, bert_test_embed, mom=args.mom, l2=args.l2) 113 | 114 | for k, v in best_val_result.items(): 115 | multi_eval_aggr_result[k].append(v) 116 | 117 | 118 | if not args.std: 119 | formatting_result_content(best_val_result) 120 | # formatting_result_content_clean(best_val_result) 121 | # logged img_r1, img_r5, img_r10, txt_r1, txt_r5, txt_r10, r_mean 122 | 123 | print(formatting_result_head()) 124 | if args.std: 125 | mean_results = {k: np.mean(v) for k, v in multi_eval_aggr_result.items()} 126 | std_results = {k: np.std(v) for k, v in multi_eval_aggr_result.items()} 127 | 128 | print(formatting_result_content(mean_results)) 129 | print(formatting_result_content(std_results)) 130 | print(formatting_result_content_clean({k: "%.2f$\\pm$%.2f"%(mean_results[k],std_results[k]) for k in std_results})) 131 | 132 | print(args.image_encoder) 133 | 134 | 135 | 136 | if __name__ == '__main__': 137 | parser = argparse.ArgumentParser(description='Parameter Processing') 138 | 139 | parser.add_argument('--dataset', type=str, default='flickr30k', help='dataset') 140 | 141 | parser.add_argument('--eval_mode', type=str, default='S', 142 | help='eval_mode, check utils.py for more info') 143 | 144 | parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on') 145 | 146 | parser.add_argument('--eval_it', type=int, default=50, help='how often to evaluate') 147 | 148 | parser.add_argument('--epoch_eval_train', type=int, default=100, help='epochs to train a model with synthetic data') 149 | parser.add_argument('--Iteration', type=int, default=3000, help='how many distillation steps to perform') 150 | 151 | parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images') 152 | parser.add_argument('--lr_txt', type=float, default=1000, help='learning rate for updating synthetic texts') 153 | parser.add_argument('--lr_lr', type=float, default=1e-03, help='learning rate for updating... 
learning rate') 154 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters') 155 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters') 156 | 157 | parser.add_argument('--loss_type', type=str) 158 | 159 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks') 160 | 161 | parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"], 162 | help='noise/real: initialize synthetic images from random noise or randomly sampled real images.') 163 | parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"], 164 | help='noise/real: initialize synthetic texts from random noise or randomly sampled real images.') 165 | 166 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 167 | help='whether to use differentiable Siamese augmentation.') 168 | 169 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 170 | help='differentiable Siamese augmentation strategy') 171 | 172 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path') 173 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 174 | 175 | parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are') 176 | parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data') 177 | parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at') 178 | 179 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 180 | 181 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM") 182 | 183 | parser.add_argument('--no_aug', action="store_true", default=False, help='this turns off diff aug during distillation') 184 | 185 | parser.add_argument('--texture', action='store_true', help="will distill textures instead") 186 | parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas') 187 | parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration') 188 | 189 | 190 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)') 191 | parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)') 192 | 193 | parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc') 194 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 195 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run') 196 | parser.add_argument('--num_queries', type=int, default=100, help='number of queries') 197 | parser.add_argument('--mini_batch_size', type=int, default=100, help='number of queries') 198 | parser.add_argument('--basis', type=bool, default=False, help='whether use basis or not') 199 | parser.add_argument('--n_basis', type=int, default=64, help='n_basis') 200 | parser.add_argument('--recursive', type=bool, default=False, help='whether use basis or not') 201 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy') 202 | parser.add_argument('--image_size', type=int, 
default=224, help='image_size') 203 | parser.add_argument('--image_root', type=str, default='distill_utils/data/Flickr30k/', help='location of image root') 204 | parser.add_argument('--ann_root', type=str, default='./data/Flickr30k_ann/', help='location of ann root') 205 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train') 206 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test') 207 | parser.add_argument('--image_encoder', type=str, default='nfnet', help='image encoder') # , choices=['clip', 'nfnet', 'vit', 'nf_resnet50', "nf_regnet"] 208 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert'], help='text encoder') 209 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained') 210 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained') 211 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable') 212 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable') 213 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None') 214 | parser.add_argument('--distill', type=bool, default=True, help='whether distill') 215 | parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='matching_train') 216 | parser.add_argument('--image_only', type=bool, default=False, help='None') 217 | parser.add_argument('--text_only', type=bool, default=False, help='None') 218 | parser.add_argument('--draw', type=bool, default=False, help='None') 219 | parser.add_argument('--transfer', type=bool, default=False, help='transfer cross architecture') 220 | parser.add_argument('--std', type=bool, default=True, help='standard deviation') 221 | parser.add_argument('--disabled_wandb', type=bool, default=False, help='disable wandb') 222 | parser.add_argument('--test_with_norm', type=bool, default=False, help='') 223 | 224 | parser.add_argument('--clamp_lr', type=float, default=None, help='') 225 | 226 | 227 | # Arguments below are for LoRS 228 | 229 | parser.add_argument('--resume_from', default=None, type=str) 230 | 231 | parser.add_argument('--sim_type', type=str, default="full", choices=["full", "lowrank"], help='similarity matrix type') 232 | parser.add_argument('--sim_rank', type=int, default=10, help='similarity matrix rank') 233 | parser.add_argument('--alpha', type=float, default=0.1, help='alpha in LoRA') 234 | parser.add_argument('--lr_sim', type=float, default=1e-03, help='learning rate for updating similarity mat learning rate') 235 | parser.add_argument('--temperature', type=float, default=0.07, help="temperature of CLIP model") 236 | 237 | parser.add_argument('--momentum_lr', type=float, default=0.5) 238 | parser.add_argument('--momentum_syn', type=float, default=0.5) 239 | parser.add_argument('--momentum_sim', type=float, default=0.5) 240 | parser.add_argument('--merge_loss_branches', action="store_true", default=False) 241 | 242 | # Arguments below are for evaluation 243 | 244 | parser.add_argument('--ckpt_path', type=str) 245 | parser.add_argument('--syn_lr_img', type=float, default=None) 246 | parser.add_argument('--syn_lr_txt', type=float, default=None) 247 | parser.add_argument('--mom', type=float, default=0.9) 248 | parser.add_argument('--l2', type=float, default=0.0005) 249 | parser.add_argument('--clip_similarity', action="store_true", default=False) 250 | args = 
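# The LoRS arguments above (--sim_type, --sim_rank, --alpha, --lr_sim) control
# how the learnable similarity matrix is parameterised; the actual construction
# is presumably in src/similarity_mining.py. As a rough, LoRA-style sketch of
# what a low-rank variant for N synthetic pairs could look like (hypothetical
# names, not the repo's exact formulation):
#
#   U = torch.randn(N, sim_rank, requires_grad=True)
#   V = torch.randn(N, sim_rank, requires_grad=True)
#   S = torch.eye(N) + alpha * (U @ V.t())   # identity base + low-rank update
#
# With --sim_type full, all N*N entries are learned directly instead.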
parser.parse_args() 251 | 252 | main(args) -------------------------------------------------------------------------------- /images/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silicx/LoRS_Distill/8a712afdb30ed483199bd0f00cc73755e0817fe8/images/method.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | timm==0.9.8 4 | git+https://github.com/openai/CLIP.git 5 | kornia==0.6.3 6 | scikit-learn==1.1.2 7 | scipy==1.10.1 8 | matplotlib==3.6.0 9 | transformers 10 | wandb 11 | opencv-python 12 | transformers 13 | tqdm -------------------------------------------------------------------------------- /sh/buffer_coco.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python buffer.py --dataset=coco --train_epochs=10 --num_experts=20 --buffer_path='buffer' --image_encoder=nfnet --text_encoder=bert --image_size=224 -------------------------------------------------------------------------------- /sh/buffer_flickr.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python buffer.py --dataset=flickr --train_epochs=10 --num_experts=20 --buffer_path='buffer' --image_encoder=nfnet --text_encoder=bert --image_size=224 -------------------------------------------------------------------------------- /sh/distill_coco_lors_100.sh: -------------------------------------------------------------------------------- 1 | FILE_NAME=$(basename -- "$0") 2 | EXP_NAME="${FILE_NAME%.*}" 3 | 4 | export CUDA_VISIBLE_DEVICES=$1; 5 | python distill_tesla_lors.py --dataset=coco \ 6 | --buffer_path='buffer/coco/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \ 7 | --image_root='distill_utils/data/COCO/' --merge_loss_branches \ 8 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \ 9 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \ 10 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \ 11 | --lr_sim=5.0 --sim_type lowrank --sim_rank 10 --alpha 1.0 \ 12 | --num_queries 99 --mini_batch_size=20 \ 13 | --loss_type WBCE --name ${EXP_NAME} 14 | -------------------------------------------------------------------------------- /sh/distill_coco_lors_200.sh: -------------------------------------------------------------------------------- 1 | FILE_NAME=$(basename -- "$0") 2 | EXP_NAME="${FILE_NAME%.*}" 3 | 4 | export CUDA_VISIBLE_DEVICES=$1; 5 | python distill_tesla_lors.py --dataset=coco \ 6 | --buffer_path='buffer/coco/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \ 7 | --image_root='distill_utils/data/COCO/' --merge_loss_branches \ 8 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \ 9 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \ 10 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \ 11 | --lr_sim=50 --sim_type lowrank --sim_rank 20 --alpha 1.0 \ 12 | --num_queries 199 --mini_batch_size=20 \ 13 | --loss_type WBCE --name ${EXP_NAME} --Iteration 2000 -------------------------------------------------------------------------------- /sh/distill_coco_lors_500.sh: -------------------------------------------------------------------------------- 1 | FILE_NAME=$(basename -- "$0") 2 | EXP_NAME="${FILE_NAME%.*}" 3 | 4 | export CUDA_VISIBLE_DEVICES=$1; 5 | python distill_tesla_lors.py --dataset=coco \ 6 | --buffer_path='buffer/coco/nfnet_bert/InfoNCE' --image_encoder=nfnet 
--text_encoder=bert \ 7 | --image_root='distill_utils/data/COCO/' --merge_loss_branches \ 8 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \ 9 | --lr_img=5000 --lr_txt=5000 --lr_lr=1e-2 \ 10 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \ 11 | --lr_sim=500 --sim_type lowrank --sim_rank 40 --alpha 1.0 \ 12 | --num_queries 499 --mini_batch_size=30 \ 13 | --temperature 0.1 --no_aug \ 14 | --loss_type WBCE --name ${EXP_NAME} -------------------------------------------------------------------------------- /sh/distill_flickr_lors_100.sh: -------------------------------------------------------------------------------- 1 | FILE_NAME=$(basename -- "$0") 2 | EXP_NAME="${FILE_NAME%.*}" 3 | 4 | export CUDA_VISIBLE_DEVICES=$1; 5 | python distill_tesla_lors.py --dataset=flickr \ 6 | --buffer_path='buffer/flickr/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \ 7 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \ 8 | --lr_img=100 --lr_txt=100 --lr_lr=1e-2 \ 9 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \ 10 | --lr_sim=10.0 --sim_type lowrank --sim_rank 10 --alpha 3 \ 11 | --num_queries 99 --mini_batch_size=20 \ 12 | --loss_type WBCE --name ${EXP_NAME} -------------------------------------------------------------------------------- /sh/distill_flickr_lors_200.sh: -------------------------------------------------------------------------------- 1 | FILE_NAME=$(basename -- "$0") 2 | EXP_NAME="${FILE_NAME%.*}" 3 | 4 | export CUDA_VISIBLE_DEVICES=$1; 5 | python distill_tesla_lors.py --dataset=flickr \ 6 | --buffer_path='buffer/flickr/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \ 7 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \ 8 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \ 9 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \ 10 | --lr_sim=10.0 --sim_type lowrank --sim_rank 5 --alpha 1.0 \ 11 | --num_queries 199 --mini_batch_size=20 \ 12 | --loss_type WBCE --name ${EXP_NAME} -------------------------------------------------------------------------------- /sh/distill_flickr_lors_500.sh: -------------------------------------------------------------------------------- 1 | FILE_NAME=$(basename -- "$0") 2 | EXP_NAME="${FILE_NAME%.*}" 3 | 4 | export CUDA_VISIBLE_DEVICES=$1; 5 | python distill_tesla_lors.py --dataset=flickr \ 6 | --buffer_path='buffer/flickr/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \ 7 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=3 \ 8 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \ 9 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \ 10 | --lr_sim=100 --sim_type lowrank --sim_rank 20 --alpha 0.01 \ 11 | --num_queries 499 --mini_batch_size=40 \ 12 | --loss_type WBCE --name ${EXP_NAME} \ 13 | --eval_it 300 -------------------------------------------------------------------------------- /src/epoch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * part of the code (i.e. def epoch_test() and itm_eval()) is from: https://github.com/salesforce/BLIP/blob/main/train_retrieval.py#L69 3 | * Copyright (c) 2022, salesforce.com, inc. 4 | * All rights reserved. 
5 | * SPDX-License-Identifier: BSD-3-Clause 6 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | * By Junnan Li 8 | ''' 9 | from math import ceil 10 | import numpy as np 11 | import torch 12 | import time 13 | import datetime 14 | import torch 15 | import torch.nn.functional as F 16 | from tqdm import tqdm 17 | import torch.nn as nn 18 | from src.utils import * 19 | 20 | 21 | 22 | 23 | 24 | def epoch(e, dataloader, net, optimizer, args): 25 | """ 26 | Perform a training epoch on the given dataloader. 27 | 28 | Args: 29 | dataloader (torch.utils.data.DataLoader): The dataloader for iterating over the dataset. 30 | net: The model. 31 | optimizer_img: The optimizer for image parameters. 32 | optimizer_txt: The optimizer for text parameters. 33 | args (object): The arguments specifying the training configuration. 34 | 35 | Returns: 36 | Tuple of average loss and average accuracy. 37 | """ 38 | net = net.to(args.device) 39 | net.train() 40 | loss_avg, acc_avg, num_exp = 0, 0, 0 41 | 42 | require_similarity = isinstance(dataloader, (SimilarityDataloader, SimilarityDataLoaderWrapper)) 43 | 44 | for i, data in (enumerate(dataloader)): 45 | if args.distill: 46 | if require_similarity: 47 | image, caption, similarity = data 48 | else: 49 | image, caption = data[:2] 50 | else: 51 | if require_similarity: 52 | image, caption, index, similarity = data 53 | else: 54 | image, caption, index = data[:3] 55 | 56 | image = image.to(args.device) 57 | n_b = image.shape[0] 58 | 59 | if require_similarity: 60 | loss, acc = net(image, caption, e, similarity) 61 | else: 62 | loss, acc = net(image, caption, e) 63 | 64 | loss_avg += loss.item() * n_b 65 | acc_avg += acc 66 | num_exp += n_b 67 | 68 | optimizer.zero_grad() 69 | loss.backward() 70 | optimizer.step() 71 | 72 | loss_avg /= num_exp 73 | acc_avg /= num_exp 74 | 75 | return loss_avg, acc_avg 76 | 77 | 78 | 79 | 80 | 81 | @torch.no_grad() 82 | def epoch_test(dataloader, model, device, bert_test_embed): 83 | model.eval() 84 | logit_scale = model.logit_scale.detach() 85 | start_time = time.time() 86 | 87 | 88 | txt_embed = model.text_projection(bert_test_embed.float().to('cuda')) 89 | text_embeds = txt_embed / txt_embed.norm(dim=1, keepdim=True) #torch.Size([5000, 768]) 90 | text_embeds = text_embeds.to(device) 91 | 92 | image_embeds = [] 93 | for image, img_id in dataloader: 94 | image_feat = model.image_encoder(image.to(device)) 95 | im_embed = image_feat / image_feat.norm(dim=1, keepdim=True) 96 | image_embeds.append(im_embed) 97 | image_embeds = torch.cat(image_embeds,dim=0) 98 | use_image_projection = False 99 | if use_image_projection: 100 | im_embed = model.image_projection(image_embeds.float()) 101 | image_embeds = im_embed / im_embed.norm(dim=1, keepdim=True) 102 | else: 103 | image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True) 104 | 105 | sims_matrix = logit_scale.exp() * image_embeds @ text_embeds.t() 106 | score_matrix_i2t = torch.full((len(image_embeds),len(text_embeds)),-100.0).to(device) #torch.Size([1000, 5000]) 107 | for i, sims in enumerate(sims_matrix[0:sims_matrix.size(0) + 1]): 108 | topk_sim, topk_idx = sims.topk(k=128, dim=0) 109 | score_matrix_i2t[i,topk_idx] = topk_sim #i:0-999, topk_idx:0-4999, find top k (k=128) similar text for each image 110 | 111 | sims_matrix = sims_matrix.t() 112 | score_matrix_t2i = torch.full((len(text_embeds),len(image_embeds)),-100.0).to(device) 113 | for i,sims in enumerate(sims_matrix[0:sims_matrix.size(0) + 1]): 114 | 
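# As in the image->text pass above, only the top k=128 candidates per query
# keep their true score; everything else stays at the -100.0 sentinel, so it is
# effectively ranked last by itm_eval(). Note that the
# sims_matrix[0:sims_matrix.size(0) + 1] slice is just sims_matrix itself
# (slicing past the end is clamped), so every query row is scored.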
topk_sim, topk_idx = sims.topk(k=128, dim=0) 115 | score_matrix_t2i[i,topk_idx] = topk_sim 116 | 117 | total_time = time.time() - start_time 118 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 119 | print('Evaluation time {}'.format(total_time_str)) 120 | 121 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 122 | 123 | 124 | @torch.no_grad() 125 | def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): 126 | 127 | #Images->Text 128 | ranks = np.zeros(scores_i2t.shape[0]) 129 | # print("TR: ", len(ranks)) 130 | for index, score in enumerate(scores_i2t): 131 | inds = np.argsort(score)[::-1] 132 | # Score 133 | rank = 1e20 134 | for i in img2txt[index]: 135 | tmp = np.where(inds == i)[0][0] 136 | if tmp < rank: 137 | rank = tmp 138 | ranks[index] = rank 139 | 140 | # Compute metrics 141 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 142 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 143 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 144 | 145 | #Text->Images 146 | ranks = np.zeros(scores_t2i.shape[0]) 147 | # print("IR: ", len(ranks)) 148 | 149 | for index,score in enumerate(scores_t2i): 150 | inds = np.argsort(score)[::-1] 151 | ranks[index] = np.where(inds == txt2img[index])[0][0] 152 | 153 | # Compute metrics 154 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 155 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 156 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 157 | 158 | tr_mean = (tr1 + tr5 + tr10) / 3 159 | ir_mean = (ir1 + ir5 + ir10) / 3 160 | r_mean = (tr_mean + ir_mean) / 2 161 | 162 | eval_result = {'txt_r1': tr1, 163 | 'txt_r5': tr5, 164 | 'txt_r10': tr10, 165 | 'txt_r_mean': tr_mean, 166 | 'img_r1': ir1, 167 | 'img_r5': ir5, 168 | 'img_r10': ir10, 169 | 'img_r_mean': ir_mean, 170 | 'r_mean': r_mean} 171 | return eval_result 172 | 173 | 174 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args, bert_test_embed, return_loss=False): 175 | 176 | net = net.to(args.device) 177 | images_train = images_train.to(args.device) 178 | labels_train = labels_train.to(args.device) 179 | Epoch = int(args.epoch_eval_train) 180 | optimizer = torch.optim.SGD([ 181 | {'params': net.image_encoder.parameters(), 'lr': args.lr_teacher_img}, 182 | {'params': net.text_projection.parameters(), 'lr': args.lr_teacher_txt}, 183 | ], lr=0, momentum=0.9, weight_decay=0.0005) 184 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[Epoch//2+1], gamma=0.1) 185 | 186 | dst_train = TensorDataset(images_train, labels_train) 187 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 188 | 189 | start = time.time() 190 | acc_train_list = [] 191 | loss_train_list = [] 192 | 193 | eval_epochs = [Epoch] 194 | eval_freq = Epoch // 10 195 | eval_epochs = list(range(0, Epoch+1, eval_freq)) + [Epoch] # eval 10 times in total 196 | 197 | best_val_result = None 198 | for ep in tqdm(range(Epoch+1), ncols=60): 199 | loss_train, acc_train = epoch(ep, trainloader, net, optimizer, args) 200 | acc_train_list.append(acc_train) 201 | loss_train_list.append(loss_train) 202 | if ep in eval_epochs: 203 | with torch.no_grad(): 204 | score_val_i2t, score_val_t2i = epoch_test(testloader, net, args.device, bert_test_embed) 205 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt) 206 | 207 | print("[Eval_{it:02d}] Ep{ep} | Image R@1={img_r1:.2f} R@5={img_r5:.2f} 
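# The recall metrics printed here come from itm_eval() above: for image->text,
# an image's rank is the best sorted position among all of its ground-truth
# captions, and R@K is the fraction of queries with rank < K. Worked example
# (positions are illustrative only): if an image's five captions land at sorted
# positions [3, 17, 42, 88, 401], its rank is 3, so it counts towards R@5 and
# R@10 but not R@1.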
R@10={img_r10:.2f} | Text R@1={txt_r1:.2f} R@5={txt_r5:.2f} R@10={txt_r10:.2f} | Mean={r_mean:.2f}".format( 208 | it=it_eval, ep=ep, **val_result 209 | )) 210 | if best_val_result is None or val_result["r_mean"] > best_val_result["r_mean"]: 211 | best_val_result = val_result 212 | 213 | lr_scheduler.step() 214 | 215 | time_train = time.time() - start 216 | return net, acc_train_list, best_val_result 217 | 218 | 219 | 220 | 221 | 222 | class SimilarityDataLoaderWrapper: 223 | """make a regular dataloader looks like a SimilarityDataloader, 224 | by return N samples and a N*N identity similarity matrix together""" 225 | 226 | def __init__(self, dataloader): 227 | super().__init__() 228 | self.dataloader = dataloader 229 | 230 | def __iter__(self): 231 | for data in self.dataloader: 232 | bz = data[0].shape[0] 233 | yield data + [torch.eye(bz, dtype=torch.float32).to(data[0].device)] 234 | 235 | 236 | class SimilarityDataloader: 237 | def __init__(self, images_train, labels_train, similarity_train, batch_size, drop_last=False): 238 | """images_train, labels_train: N samples 239 | similarity_train: N*N similarity matrix, similarity_train[i,j] is the similarity between i-th and j-th samples""" 240 | super().__init__() 241 | 242 | self.images_train = images_train 243 | self.labels_train = labels_train 244 | self.similarity_train = similarity_train 245 | self.batch_size = batch_size 246 | assert not drop_last 247 | 248 | def __iter__(self): 249 | size = self.images_train.shape[0] 250 | num_batch = int(ceil(size / self.batch_size)) 251 | indices = np.arange(size) 252 | for _ in range(num_batch): 253 | np.random.shuffle(indices) 254 | ids = indices[:self.batch_size] 255 | yield self.images_train[ids], self.labels_train[ids], self.similarity_train[ids[:,None], ids] 256 | 257 | 258 | 259 | 260 | def evaluate_synset_with_similarity(it_eval, net, images_train, labels_train, lr_img, lr_txt, similarity_train, testloader, args, bert_test_embed, mom=0.9, l2=0.0005): 261 | """Added for LoRS, see Algorithm 2 in the paper""" 262 | 263 | net = net.to(args.device) 264 | images_train = images_train.to(args.device) 265 | labels_train = labels_train.to(args.device) 266 | similarity_train = similarity_train.to(args.device) 267 | Epoch = int(args.epoch_eval_train) 268 | optimizer = torch.optim.SGD([ 269 | {'params': net.image_encoder.parameters(), 'lr': lr_img}, 270 | {'params': net.text_projection.parameters(), 'lr': lr_txt}, 271 | ], lr=0, momentum=mom, weight_decay=l2) 272 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[Epoch//2+1], gamma=0.1) 273 | 274 | trainloader = SimilarityDataloader(images_train, labels_train, similarity_train, batch_size=args.batch_train) 275 | 276 | start = time.time() 277 | acc_train_list = [] 278 | loss_train_list = [] 279 | 280 | eval_freq = Epoch // 10 281 | eval_epochs = list(range(0, Epoch+1, eval_freq)) + [Epoch] # eval 10 times in total 282 | 283 | best_val_result = None 284 | for ep in tqdm(range(Epoch+1), ncols=60): 285 | loss_train, acc_train = epoch(ep, trainloader, net, optimizer, args) 286 | acc_train_list.append(acc_train) 287 | loss_train_list.append(loss_train) 288 | if ep in eval_epochs: 289 | with torch.no_grad(): 290 | score_val_i2t, score_val_t2i = epoch_test(testloader, net, args.device, bert_test_embed) 291 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt) 292 | 293 | print("[Eval_{it:02d}] Ep{ep} | Image R@1={img_r1:.2f} R@5={img_r5:.2f} R@10={img_r10:.2f} | Text 
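# evaluate_synset_with_similarity() trains with the SimilarityDataloader
# defined above, so the soft target stays aligned with whichever pairs are
# sampled: for batch indices ids = [4, 1, 7] it yields images[[4, 1, 7]],
# texts[[4, 1, 7]] and the 3x3 block similarity_train[[4, 1, 7]][:, [4, 1, 7]].
# A plain dataloader wrapped in SimilarityDataLoaderWrapper instead receives an
# identity matrix, i.e. each image is matched only to its own caption.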
R@1={txt_r1:.2f} R@5={txt_r5:.2f} R@10={txt_r10:.2f} | Mean={r_mean:.2f}".format( 294 | it=it_eval, ep=ep, **val_result 295 | )) 296 | if best_val_result is None or val_result["r_mean"] > best_val_result["r_mean"]: 297 | best_val_result = val_result 298 | 299 | lr_scheduler.step() 300 | 301 | time_train = time.time() - start 302 | assert best_val_result is not None 303 | return net, acc_train_list, best_val_result 304 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # Sourced directly from OpenAI's CLIP repo 2 | from collections import OrderedDict 3 | from typing import Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | expansion = 4 13 | 14 | def __init__(self, inplanes, planes, stride=1): 15 | super().__init__() 16 | 17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 25 | 26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 28 | 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = None 31 | self.stride = stride 32 | 33 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 35 | self.downsample = nn.Sequential(OrderedDict([ 36 | ("-1", nn.AvgPool2d(stride)), 37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 38 | ("1", nn.BatchNorm2d(planes * self.expansion)) 39 | ])) 40 | 41 | def forward(self, x: torch.Tensor): 42 | identity = x 43 | 44 | out = self.relu(self.bn1(self.conv1(x))) 45 | out = self.relu(self.bn2(self.conv2(out))) 46 | out = self.avgpool(out) 47 | out = self.bn3(self.conv3(out)) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | out = self.relu(out) 54 | return out 55 | 56 | 57 | class AttentionPool2d(nn.Module): 58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 59 | super().__init__() 60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 61 | self.k_proj = nn.Linear(embed_dim, embed_dim) 62 | self.q_proj = nn.Linear(embed_dim, embed_dim) 63 | self.v_proj = nn.Linear(embed_dim, embed_dim) 64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 65 | self.num_heads = num_heads 66 | 67 | def forward(self, x): 68 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 71 | x, _ = F.multi_head_attention_forward( 72 | query=x, key=x, value=x, 73 | embed_dim_to_check=x.shape[-1], 74 | num_heads=self.num_heads, 75 | q_proj_weight=self.q_proj.weight, 76 | k_proj_weight=self.k_proj.weight, 77 | v_proj_weight=self.v_proj.weight, 78 | in_proj_weight=None, 79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, 
self.v_proj.bias]), 80 | bias_k=None, 81 | bias_v=None, 82 | add_zero_attn=False, 83 | dropout_p=0, 84 | out_proj_weight=self.c_proj.weight, 85 | out_proj_bias=self.c_proj.bias, 86 | use_separate_proj_weight=True, 87 | training=self.training, 88 | need_weights=False 89 | ) 90 | 91 | return x[0] 92 | 93 | 94 | 95 | class LayerNorm(nn.LayerNorm): 96 | """Subclass torch's LayerNorm to handle fp16.""" 97 | 98 | def forward(self, x: torch.Tensor): 99 | orig_type = x.dtype 100 | ret = super().forward(x.type(torch.float32)) 101 | return ret.type(orig_type) 102 | 103 | 104 | class QuickGELU(nn.Module): 105 | def forward(self, x: torch.Tensor): 106 | return x * torch.sigmoid(1.702 * x) 107 | 108 | 109 | class ResidualAttentionBlock(nn.Module): 110 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 111 | super().__init__() 112 | 113 | self.attn = nn.MultiheadAttention(d_model, n_head) 114 | self.ln_1 = LayerNorm(d_model) 115 | self.mlp = nn.Sequential(OrderedDict([ 116 | ("c_fc", nn.Linear(d_model, d_model * 4)), 117 | ("gelu", QuickGELU()), 118 | ("c_proj", nn.Linear(d_model * 4, d_model)) 119 | ])) 120 | self.ln_2 = LayerNorm(d_model) 121 | self.attn_mask = attn_mask 122 | 123 | def attention(self, x: torch.Tensor): 124 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 125 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 126 | 127 | def forward(self, x: torch.Tensor): 128 | x = x + self.attention(self.ln_1(x)) 129 | x = x + self.mlp(self.ln_2(x)) 130 | return x 131 | 132 | 133 | 134 | def convert_weights(model: nn.Module): 135 | """Convert applicable model parameters to fp16""" 136 | 137 | def _convert_weights_to_fp16(l): 138 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 139 | l.weight.data = l.weight.data.half() 140 | if l.bias is not None: 141 | l.bias.data = l.bias.data.half() 142 | 143 | if isinstance(l, nn.MultiheadAttention): 144 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 145 | tensor = getattr(l, attr) 146 | if tensor is not None: 147 | tensor.data = tensor.data.half() 148 | 149 | for name in ["text_projection", "proj"]: 150 | if hasattr(l, name): 151 | attr = getattr(l, name) 152 | if attr is not None: 153 | attr.data = attr.data.half() 154 | 155 | model.apply(_convert_weights_to_fp16) 156 | 157 | 158 | def build_model(state_dict: dict): 159 | vit = "visual.proj" in state_dict 160 | 161 | if vit: 162 | vision_width = state_dict["visual.conv1.weight"].shape[0] 163 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 164 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 165 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 166 | image_resolution = vision_patch_size * grid_size 167 | else: 168 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 169 | vision_layers = tuple(counts) 170 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 171 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 172 | vision_patch_size = None 173 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 174 | image_resolution = output_width * 32 175 | 176 | embed_dim = state_dict["text_projection"].shape[1] 177 | context_length = 
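# build_model() infers every architecture hyper-parameter from tensor shapes in
# the state_dict rather than from a config file. As a sanity check with
# OpenAI's public ViT-B/32 weights (figures from that checkpoint, not from this
# repo): visual.conv1.weight has shape [768, 3, 32, 32] -> vision_width=768 and
# vision_patch_size=32; visual.positional_embedding has 50 = 7*7 + 1 rows ->
# grid_size=7 and image_resolution = 32 * 7 = 224.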
state_dict["positional_embedding"].shape[0] 178 | vocab_size = state_dict["token_embedding.weight"].shape[0] 179 | transformer_width = state_dict["ln_final.weight"].shape[0] 180 | transformer_heads = transformer_width // 64 181 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 182 | 183 | model = CLIP( 184 | embed_dim, 185 | image_resolution, vision_layers, vision_width, vision_patch_size, 186 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 187 | ) 188 | 189 | for key in ["input_resolution", "context_length", "vocab_size"]: 190 | if key in state_dict: 191 | del state_dict[key] 192 | 193 | convert_weights(model) 194 | model.load_state_dict(state_dict) 195 | return model.eval() -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | # import pdb 2 | # import warnings 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | # from collections import OrderedDict 7 | # from typing import Tuple, Union 8 | import clip 9 | # from transformers import ViTConfig, ViTModel, AutoTokenizer, CLIPTextModel, CLIPTextConfig, CLIPProcessor, CLIPConfig 10 | import numpy as np 11 | from transformers import BertTokenizer, BertModel, DistilBertModel, DistilBertTokenizer, OpenAIGPTTokenizer, OpenAIGPTModel 12 | # from torchvision.models import resnet18, resnet 13 | from transformers.models.bert.modeling_bert import BertAttention, BertConfig 14 | import functools 15 | import copy 16 | 17 | from .similarity_mining import MultilabelContrastiveLoss 18 | 19 | 20 | # Caching text models, by Yue Xu 21 | 22 | @functools.lru_cache(maxsize=128) 23 | def get_bert_stuff(): 24 | tokenizer = BertTokenizer.from_pretrained('./distill_utils/checkpoints/bert-base-uncased') 25 | BERT_model = BertModel.from_pretrained('./distill_utils/checkpoints/bert-base-uncased') 26 | return BERT_model, tokenizer 27 | 28 | @functools.lru_cache(maxsize=128) 29 | def get_distilbert_stuff(): 30 | DistilBERT_model_name = "distilbert-base-uncased" 31 | DistilBERT_model = DistilBertModel.from_pretrained(DistilBERT_model_name, local_files_only=True) 32 | DistilBERT_tokenizer = DistilBertTokenizer.from_pretrained(DistilBERT_model_name, local_files_only=True) 33 | return DistilBERT_model, DistilBERT_tokenizer 34 | 35 | @functools.lru_cache(maxsize=128) 36 | def get_gpt1_stuff(): 37 | model_name = './distill_utils/checkpoints/openai-gpt' 38 | model = OpenAIGPTModel.from_pretrained(model_name) 39 | tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name) 40 | return model, tokenizer 41 | 42 | 43 | # Acknowledgement to 44 | # https://github.com/kuangliu/pytorch-cifar, 45 | # https://github.com/BIGBALLON/CIFAR-ZOO, 46 | 47 | # adapted from 48 | # https://github.com/VICO-UoE/DatasetCondensation 49 | # https://github.com/Zasder3/train-CLIP 50 | 51 | 52 | ''' MLP ''' 53 | class MLP(nn.Module): 54 | def __init__(self, channel, num_classes): 55 | super(MLP, self).__init__() 56 | self.fc_1 = nn.Linear(28*28*1 if channel==1 else 32*32*3, 128) 57 | self.fc_2 = nn.Linear(128, 128) 58 | self.fc_3 = nn.Linear(128, num_classes) 59 | 60 | def forward(self, x): 61 | out = x.view(x.size(0), -1) 62 | out = F.relu(self.fc_1(out)) 63 | out = F.relu(self.fc_2(out)) 64 | out = self.fc_3(out) 65 | return out 66 | 67 | 68 | 69 | ''' ConvNet ''' 70 | class ConvNet(nn.Module): 71 | def __init__(self, channel, num_classes, 
net_width=128, net_depth=4, net_act='relu', net_norm='instancenorm', net_pooling='avgpooling', im_size = (224,224)): 72 | super(ConvNet, self).__init__() 73 | 74 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 75 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 76 | self.classifier = nn.Linear(num_feat, num_classes) 77 | 78 | def forward(self, x): 79 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 80 | out = self.features(x) 81 | out = out.view(out.size(0), -1) 82 | out = self.classifier(out) 83 | return out 84 | 85 | def _get_activation(self, net_act): 86 | if net_act == 'sigmoid': 87 | return nn.Sigmoid() 88 | elif net_act == 'relu': 89 | return nn.ReLU(inplace=True) 90 | elif net_act == 'leakyrelu': 91 | return nn.LeakyReLU(negative_slope=0.01) 92 | else: 93 | exit('unknown activation function: %s'%net_act) 94 | 95 | def _get_pooling(self, net_pooling): 96 | if net_pooling == 'maxpooling': 97 | return nn.MaxPool2d(kernel_size=2, stride=2) 98 | elif net_pooling == 'avgpooling': 99 | return nn.AvgPool2d(kernel_size=2, stride=2) 100 | elif net_pooling == 'none': 101 | return None 102 | else: 103 | exit('unknown net_pooling: %s'%net_pooling) 104 | 105 | def _get_normlayer(self, net_norm, shape_feat): 106 | # shape_feat = (c*h*w) 107 | if net_norm == 'batchnorm': 108 | return nn.BatchNorm2d(shape_feat[0], affine=True) 109 | elif net_norm == 'layernorm': 110 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 111 | elif net_norm == 'instancenorm': 112 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 113 | elif net_norm == 'groupnorm': 114 | return nn.GroupNorm(4, shape_feat[0], affine=True) 115 | elif net_norm == 'none': 116 | return None 117 | else: 118 | exit('unknown net_norm: %s'%net_norm) 119 | 120 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 121 | layers = [] 122 | in_channels = channel 123 | if im_size[0] == 28: 124 | im_size = (32, 32) 125 | shape_feat = [in_channels, im_size[0], im_size[1]] 126 | for d in range(net_depth): 127 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 128 | shape_feat[0] = net_width 129 | if net_norm != 'none': 130 | layers += [self._get_normlayer(net_norm, shape_feat)] 131 | layers += [self._get_activation(net_act)] 132 | in_channels = net_width 133 | if net_pooling != 'none': 134 | layers += [self._get_pooling(net_pooling)] 135 | shape_feat[1] //= 2 136 | shape_feat[2] //= 2 137 | 138 | 139 | return nn.Sequential(*layers), shape_feat 140 | 141 | 142 | ''' ConvNet ''' 143 | class ConvNetGAP(nn.Module): 144 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)): 145 | super(ConvNetGAP, self).__init__() 146 | 147 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 148 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 149 | # self.classifier = nn.Linear(num_feat, num_classes) 150 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 151 | self.classifier = nn.Linear(shape_feat[0], num_classes) 152 | 153 | def forward(self, x): 154 | out = self.features(x) 155 | out = self.avgpool(out) 156 | out = out.view(out.size(0), -1) 157 | out = self.classifier(out) 158 | return out 159 | 160 | def _get_activation(self, net_act): 161 | if net_act == 'sigmoid': 162 | return 
nn.Sigmoid() 163 | elif net_act == 'relu': 164 | return nn.ReLU(inplace=True) 165 | elif net_act == 'leakyrelu': 166 | return nn.LeakyReLU(negative_slope=0.01) 167 | else: 168 | exit('unknown activation function: %s'%net_act) 169 | 170 | def _get_pooling(self, net_pooling): 171 | if net_pooling == 'maxpooling': 172 | return nn.MaxPool2d(kernel_size=2, stride=2) 173 | elif net_pooling == 'avgpooling': 174 | return nn.AvgPool2d(kernel_size=2, stride=2) 175 | elif net_pooling == 'none': 176 | return None 177 | else: 178 | exit('unknown net_pooling: %s'%net_pooling) 179 | 180 | def _get_normlayer(self, net_norm, shape_feat): 181 | # shape_feat = (c*h*w) 182 | if net_norm == 'batchnorm': 183 | return nn.BatchNorm2d(shape_feat[0], affine=True) 184 | elif net_norm == 'layernorm': 185 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 186 | elif net_norm == 'instancenorm': 187 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 188 | elif net_norm == 'groupnorm': 189 | return nn.GroupNorm(4, shape_feat[0], affine=True) 190 | elif net_norm == 'none': 191 | return None 192 | else: 193 | exit('unknown net_norm: %s'%net_norm) 194 | 195 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 196 | layers = [] 197 | in_channels = channel 198 | if im_size[0] == 28: 199 | im_size = (32, 32) 200 | shape_feat = [in_channels, im_size[0], im_size[1]] 201 | for d in range(net_depth): 202 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 203 | shape_feat[0] = net_width 204 | if net_norm != 'none': 205 | layers += [self._get_normlayer(net_norm, shape_feat)] 206 | layers += [self._get_activation(net_act)] 207 | in_channels = net_width 208 | if net_pooling != 'none': 209 | layers += [self._get_pooling(net_pooling)] 210 | shape_feat[1] //= 2 211 | shape_feat[2] //= 2 212 | 213 | return nn.Sequential(*layers), shape_feat 214 | 215 | 216 | ''' LeNet ''' 217 | class LeNet(nn.Module): 218 | def __init__(self, channel, num_classes): 219 | super(LeNet, self).__init__() 220 | self.features = nn.Sequential( 221 | nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0), 222 | nn.ReLU(inplace=True), 223 | nn.MaxPool2d(kernel_size=2, stride=2), 224 | nn.Conv2d(6, 16, kernel_size=5), 225 | nn.ReLU(inplace=True), 226 | nn.MaxPool2d(kernel_size=2, stride=2), 227 | ) 228 | self.fc_1 = nn.Linear(16 * 5 * 5, 120) 229 | self.fc_2 = nn.Linear(120, 84) 230 | self.fc_3 = nn.Linear(84, num_classes) 231 | 232 | def forward(self, x): 233 | x = self.features(x) 234 | x = x.view(x.size(0), -1) 235 | x = F.relu(self.fc_1(x)) 236 | x = F.relu(self.fc_2(x)) 237 | x = self.fc_3(x) 238 | return x 239 | 240 | 241 | 242 | ''' AlexNet ''' 243 | class AlexNet(nn.Module): 244 | def __init__(self, channel, num_classes): 245 | super(AlexNet, self).__init__() 246 | self.features = nn.Sequential( 247 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 248 | nn.ReLU(inplace=True), 249 | nn.MaxPool2d(kernel_size=2, stride=2), 250 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 251 | nn.ReLU(inplace=True), 252 | nn.MaxPool2d(kernel_size=2, stride=2), 253 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 254 | nn.ReLU(inplace=True), 255 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 256 | nn.ReLU(inplace=True), 257 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 258 | nn.ReLU(inplace=True), 259 | nn.MaxPool2d(kernel_size=2, stride=2), 260 | ) 261 | self.fc = nn.Linear(192 * 4 * 4, 
num_classes) 262 | 263 | def forward(self, x): 264 | x = self.features(x) 265 | x = x.view(x.size(0), -1) 266 | x = self.fc(x) 267 | return x 268 | 269 | 270 | 271 | ''' VGG ''' 272 | cfg_vgg = { 273 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 274 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 275 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 276 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 277 | } 278 | class VGG(nn.Module): 279 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'): 280 | super(VGG, self).__init__() 281 | self.channel = channel 282 | self.features = self._make_layers(cfg_vgg[vgg_name], norm) 283 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes) 284 | 285 | def forward(self, x): 286 | x = self.features(x) 287 | x = x.view(x.size(0), -1) 288 | x = self.classifier(x) 289 | return x 290 | 291 | def _make_layers(self, cfg, norm): 292 | layers = [] 293 | in_channels = self.channel 294 | for ic, x in enumerate(cfg): 295 | if x == 'M': 296 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 297 | else: 298 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 299 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x), 300 | nn.ReLU(inplace=True)] 301 | in_channels = x 302 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 303 | return nn.Sequential(*layers) 304 | 305 | 306 | def VGG11(channel, num_classes): 307 | return VGG('VGG11', channel, num_classes) 308 | def VGG11BN(channel, num_classes): 309 | return VGG('VGG11', channel, num_classes, norm='batchnorm') 310 | def VGG13(channel, num_classes): 311 | return VGG('VGG13', channel, num_classes) 312 | def VGG16(channel, num_classes): 313 | return VGG('VGG16', channel, num_classes) 314 | def VGG19(channel, num_classes): 315 | return VGG('VGG19', channel, num_classes) 316 | 317 | 318 | ''' ResNet_AP ''' 319 | # The conv(stride=2) is replaced by conv(stride=1) + avgpool(kernel_size=2, stride=2) 320 | 321 | class BasicBlock_AP(nn.Module): 322 | expansion = 1 323 | 324 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 325 | super(BasicBlock_AP, self).__init__() 326 | self.norm = norm 327 | self.stride = stride 328 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 329 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 330 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 331 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 332 | 333 | self.shortcut = nn.Sequential() 334 | if stride != 1 or in_planes != self.expansion * planes: 335 | self.shortcut = nn.Sequential( 336 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 337 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 338 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 339 | ) 340 | 341 | def forward(self, x): 342 | out = F.relu(self.bn1(self.conv1(x))) 343 | if self.stride != 1: # modification 344 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 345 | out = 
self.bn2(self.conv2(out)) 346 | out += self.shortcut(x) 347 | out = F.relu(out) 348 | return out 349 | 350 | 351 | class Bottleneck_AP(nn.Module): 352 | expansion = 4 353 | 354 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 355 | super(Bottleneck_AP, self).__init__() 356 | self.norm = norm 357 | self.stride = stride 358 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 359 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 360 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 361 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 362 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 363 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 364 | 365 | self.shortcut = nn.Sequential() 366 | if stride != 1 or in_planes != self.expansion * planes: 367 | self.shortcut = nn.Sequential( 368 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 369 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 370 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 371 | ) 372 | 373 | def forward(self, x): 374 | out = F.relu(self.bn1(self.conv1(x))) 375 | out = F.relu(self.bn2(self.conv2(out))) 376 | if self.stride != 1: # modification 377 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 378 | out = self.bn3(self.conv3(out)) 379 | out += self.shortcut(x) 380 | out = F.relu(out) 381 | return out 382 | 383 | 384 | class ResNet_AP(nn.Module): 385 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 386 | super(ResNet_AP, self).__init__() 387 | self.in_planes = 64 388 | self.norm = norm 389 | 390 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 391 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 392 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 393 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 394 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 395 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 396 | self.classifier = nn.Linear(512 * block.expansion * 3 * 3 if channel==1 else 512 * block.expansion * 4 * 4, num_classes) # modification 397 | 398 | def _make_layer(self, block, planes, num_blocks, stride): 399 | strides = [stride] + [1] * (num_blocks - 1) 400 | layers = [] 401 | for stride in strides: 402 | layers.append(block(self.in_planes, planes, stride, self.norm)) 403 | self.in_planes = planes * block.expansion 404 | return nn.Sequential(*layers) 405 | 406 | def forward(self, x): 407 | out = F.relu(self.bn1(self.conv1(x))) 408 | out = self.layer1(out) 409 | out = self.layer2(out) 410 | out = self.layer3(out) 411 | out = self.layer4(out) 412 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification 413 | out = out.view(out.size(0), -1) 414 | out = self.classifier(out) 415 | return out 416 | 417 | 418 | def ResNet18BN_AP(channel, num_classes): 419 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 420 | 421 | def ResNet18_AP(channel, 
num_classes): 422 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes) 423 | 424 | 425 | ''' ResNet ''' 426 | 427 | class BasicBlock(nn.Module): 428 | expansion = 1 429 | 430 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 431 | super(BasicBlock, self).__init__() 432 | self.norm = norm 433 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 434 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 435 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 436 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 437 | 438 | self.shortcut = nn.Sequential() 439 | if stride != 1 or in_planes != self.expansion*planes: 440 | self.shortcut = nn.Sequential( 441 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 442 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 443 | ) 444 | 445 | def forward(self, x): 446 | out = F.relu(self.bn1(self.conv1(x))) 447 | out = self.bn2(self.conv2(out)) 448 | out += self.shortcut(x) 449 | out = F.relu(out) 450 | return out 451 | 452 | 453 | class Bottleneck(nn.Module): 454 | expansion = 4 455 | 456 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 457 | super(Bottleneck, self).__init__() 458 | self.norm = norm 459 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 460 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 461 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 462 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 463 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 464 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 465 | 466 | self.shortcut = nn.Sequential() 467 | if stride != 1 or in_planes != self.expansion*planes: 468 | self.shortcut = nn.Sequential( 469 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 470 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 471 | ) 472 | 473 | def forward(self, x): 474 | out = F.relu(self.bn1(self.conv1(x))) 475 | out = F.relu(self.bn2(self.conv2(out))) 476 | out = self.bn3(self.conv3(out)) 477 | out += self.shortcut(x) 478 | out = F.relu(out) 479 | return out 480 | 481 | 482 | class ResNetImageNet(nn.Module): 483 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 484 | super(ResNetImageNet, self).__init__() 485 | self.in_planes = 64 486 | self.norm = norm 487 | 488 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 489 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 490 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 491 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 492 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 493 | self.layer3 = 
self._make_layer(block, 256, num_blocks[2], stride=2) 494 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 495 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 496 | self.classifier = nn.Linear(512*block.expansion, num_classes) 497 | 498 | def _make_layer(self, block, planes, num_blocks, stride): 499 | strides = [stride] + [1]*(num_blocks-1) 500 | layers = [] 501 | for stride in strides: 502 | layers.append(block(self.in_planes, planes, stride, self.norm)) 503 | self.in_planes = planes * block.expansion 504 | return nn.Sequential(*layers) 505 | 506 | def forward(self, x): 507 | out = F.relu(self.bn1(self.conv1(x))) 508 | out = self.maxpool(out) 509 | out = self.layer1(out) 510 | out = self.layer2(out) 511 | out = self.layer3(out) 512 | out = self.layer4(out) 513 | # out = F.avg_pool2d(out, 4) 514 | # out = out.view(out.size(0), -1) 515 | out = self.avgpool(out) 516 | out = torch.flatten(out, 1) 517 | out = self.classifier(out) 518 | return out 519 | 520 | 521 | def ResNet18BN(channel, num_classes): 522 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 523 | 524 | def ResNet18(channel, num_classes): 525 | return ResNet_gn(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 526 | 527 | def ResNet34(channel, num_classes): 528 | return ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes) 529 | 530 | def ResNet50(channel, num_classes): 531 | return ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes) 532 | 533 | def ResNet101(channel, num_classes): 534 | return ResNet(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes) 535 | 536 | def ResNet152(channel, num_classes): 537 | return ResNet(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes) 538 | 539 | def ResNet18ImageNet(channel, num_classes): 540 | return ResNetImageNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 541 | 542 | def ResNet6ImageNet(channel, num_classes): 543 | return ResNetImageNet(BasicBlock, [1,1,1,1], channel=channel, num_classes=num_classes) 544 | 545 | def resnet18_gn(pretrained=False, **kwargs): 546 | """Constructs a ResNet-18 model. 547 | """ 548 | model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) 549 | return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs)) 550 | 551 | 552 | ## Sourced directly from OpenAI's CLIP repo 553 | class ModifiedResNet(nn.Module): 554 | """ 555 | A ResNet class that is similar to torchvision's but contains the following changes: 556 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
557 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 558 | - The final pooling layer is a QKV attention instead of an average pool 559 | """ 560 | 561 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 562 | super().__init__() 563 | self.output_dim = output_dim 564 | self.input_resolution = input_resolution 565 | 566 | # the 3-layer stem 567 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 568 | self.bn1 = nn.BatchNorm2d(width // 2) 569 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 570 | self.bn2 = nn.BatchNorm2d(width // 2) 571 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 572 | self.bn3 = nn.BatchNorm2d(width) 573 | self.avgpool = nn.AvgPool2d(2) 574 | self.relu = nn.ReLU(inplace=True) 575 | 576 | # residual layers 577 | self._inplanes = width # this is a *mutable* variable used during construction 578 | self.layer1 = self._make_layer(width, layers[0]) 579 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 580 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 581 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 582 | 583 | embed_dim = width * 32 # the ResNet feature dimension 584 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 585 | 586 | def _make_layer(self, planes, blocks, stride=1): 587 | layers = [Bottleneck(self._inplanes, planes, stride)] 588 | 589 | self._inplanes = planes * Bottleneck.expansion 590 | for _ in range(1, blocks): 591 | layers.append(Bottleneck(self._inplanes, planes)) 592 | 593 | return nn.Sequential(*layers) 594 | 595 | def forward(self, x): 596 | def stem(x): 597 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 598 | x = self.relu(bn(conv(x))) 599 | x = self.avgpool(x) 600 | return x 601 | 602 | x = x.type(self.conv1.weight.dtype) 603 | x = stem(x) 604 | x = self.layer1(x) 605 | x = self.layer2(x) 606 | x = self.layer3(x) 607 | x = self.layer4(x) 608 | x = self.attnpool(x) 609 | 610 | return x 611 | 612 | 613 | class AttentionPool2d(nn.Module): 614 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 615 | super().__init__() 616 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 617 | self.k_proj = nn.Linear(embed_dim, embed_dim) 618 | self.q_proj = nn.Linear(embed_dim, embed_dim) 619 | self.v_proj = nn.Linear(embed_dim, embed_dim) 620 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 621 | self.num_heads = num_heads 622 | 623 | def forward(self, x): 624 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 625 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 626 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 627 | x, _ = F.multi_head_attention_forward( 628 | query=x, key=x, value=x, 629 | embed_dim_to_check=x.shape[-1], 630 | num_heads=self.num_heads, 631 | q_proj_weight=self.q_proj.weight, 632 | k_proj_weight=self.k_proj.weight, 633 | v_proj_weight=self.v_proj.weight, 634 | in_proj_weight=None, 635 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 636 | bias_k=None, 637 | bias_v=None, 638 | add_zero_attn=False, 639 | dropout_p=0, 640 | out_proj_weight=self.c_proj.weight, 641 | 
out_proj_bias=self.c_proj.bias, 642 | use_separate_proj_weight=True, 643 | training=self.training, 644 | need_weights=False 645 | ) 646 | 647 | return x[0] 648 | 649 | import timm 650 | 651 | class ProjectionHead(nn.Module): 652 | def __init__( 653 | self, 654 | embedding_dim, 655 | projection_dim=768, 656 | dropout=0.1 657 | ): 658 | super().__init__() 659 | self.projection = nn.Linear(embedding_dim, projection_dim) 660 | self.gelu = nn.GELU() 661 | self.fc = nn.Linear(projection_dim, projection_dim) 662 | self.dropout = nn.Dropout(dropout) 663 | self.layer_norm = nn.LayerNorm(projection_dim) 664 | 665 | def forward(self, x): 666 | projected = self.projection(x) 667 | x = self.gelu(projected) 668 | x = self.fc(x) 669 | x = self.dropout(x) 670 | x = x + projected 671 | x = self.layer_norm(x) 672 | return x 673 | 674 | 675 | 676 | @functools.lru_cache(maxsize=128) 677 | def load_from_timm(model_name, pretrained): 678 | if model_name == 'clip': 679 | raise NotImplementedError("it is unfair to use pretrained clip") 680 | # if pretrained: 681 | # model, preprocess = clip.load("ViT-B/32", device='cuda') 682 | # else: 683 | # configuration = ViTConfig() 684 | # model = ViTModel(configuration) 685 | 686 | elif model_name == 'nfnet': 687 | model = timm.create_model('nfnet_l0', pretrained=pretrained, num_classes=0, global_pool="avg", 688 | pretrained_cfg_overlay=dict(file='distill_utils/checkpoints/nfnet_l0_ra2-45c6688d.pth'),) 689 | elif model_name == 'vit': 690 | model = timm.create_model('vit_tiny_patch16_224', pretrained=True) 691 | elif model_name == 'nf_resnet50': 692 | model = timm.create_model('nf_resnet50', pretrained=True) 693 | elif model_name == 'nf_regnet': 694 | model = timm.create_model('nf_regnet_b1', pretrained=True) 695 | elif model_name=="efficientvit_m5": 696 | model = timm.create_model(model_name, num_classes=0, pretrained=True) 697 | else: 698 | model = timm.create_model(model_name, num_classes=0, pretrained=True, global_pool="avg") 699 | 700 | 701 | return model 702 | 703 | 704 | 705 | 706 | class ImageEncoder(nn.Module): 707 | """ 708 | Encode images to a fixed size vector 709 | """ 710 | 711 | def __init__(self, args): 712 | super().__init__() 713 | self.model_name = args.image_encoder 714 | self.pretrained = args.image_pretrained 715 | self.trainable = args.image_trainable 716 | 717 | self.model = copy.deepcopy(load_from_timm(self.model_name, self.pretrained)) 718 | # use cached pretrained model 719 | 720 | for p in self.model.parameters(): 721 | p.requires_grad = self.trainable 722 | 723 | def forward(self, x): 724 | if self.model_name == 'clip' and self.pretrained: 725 | return self.model.encode_image(x) 726 | else: 727 | return self.model(x) 728 | 729 | def gradient(self, x, y): 730 | # Compute the gradient of the mean squared error loss with respect to the weights 731 | loss = self.loss(x, y) 732 | grad = torch.autograd.grad(loss, self.parameters(), create_graph=True) 733 | return torch.cat([g.view(-1) for g in grad]) 734 | 735 | 736 | 737 | 738 | class TextEncoder(nn.Module): 739 | def __init__(self, args): 740 | super().__init__() 741 | self.pretrained = args.text_pretrained 742 | self.trainable = args.text_trainable 743 | self.model_name = args.text_encoder 744 | 745 | if self.model_name == 'clip': 746 | self.model, preprocess = clip.load("ViT-B/32", device='cuda', download_root="distill_utils/checkpoints") 747 | elif self.model_name == 'bert': 748 | pt_model, self.tokenizer = get_bert_stuff() 749 | if args.text_pretrained: 750 | self.model = pt_model 751 | else: 
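                # Note: when `args.text_pretrained` is False, the branch below rebuilds BERT from a
                # default BertConfig, so the text encoder starts from randomly initialized weights;
                # only the tokenizer returned by get_bert_stuff() (loaded from
                # ./distill_utils/checkpoints/bert-base-uncased) is reused.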
752 | self.model = BertModel(BertConfig()) 753 | self.model.init_weights() 754 | elif self.model_name == 'distilbert': 755 | self.model, self.tokenizer = get_distilbert_stuff() 756 | elif self.model_name == 'gpt1': 757 | self.model, self.tokenizer = get_gpt1_stuff() 758 | else: 759 | raise NotImplementedError(self.model_name) 760 | 761 | for p in self.model.parameters(): 762 | p.requires_grad = self.trainable 763 | 764 | # we are using the CLS token hidden representation as the sentence's embedding 765 | self.target_token_idx = 0 766 | 767 | def forward(self, texts, device='cuda'): 768 | if self.model_name == 'clip': 769 | output = self.model.encode_text(clip.tokenize(texts).to('cuda')) 770 | 771 | elif self.model_name == 'bert': 772 | # Tokenize the input text 773 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True) 774 | input_ids = encoding['input_ids'].to(device) 775 | attention_mask = encoding['attention_mask'].to(device) 776 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :] 777 | 778 | elif self.model_name == 'distilbert': 779 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True) 780 | input_ids = encoding['input_ids'].to(device) 781 | attention_mask = encoding['attention_mask'].to(device) 782 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :] 783 | 784 | elif self.model_name == 'gpt1': 785 | self.tokenizer.pad_token = ' ' 786 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True) 787 | input_ids = encoding['input_ids'].to(device) 788 | attention_mask = encoding['attention_mask'].to(device) 789 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :] 790 | 791 | return output 792 | 793 | 794 | 795 | 796 | class CLIPModel_full(nn.Module): 797 | def __init__( 798 | self, 799 | args, 800 | train_logit_scale=False, 801 | temperature=0.07, 802 | eval_stage=False 803 | ): 804 | super().__init__() 805 | 806 | if args.image_encoder == 'nfnet': 807 | if eval_stage: 808 | self.image_embedding = 1000 #2048 809 | else: 810 | self.image_embedding = 2304 811 | elif args.image_encoder == 'convnet': 812 | self.image_embedding = 768 813 | elif args.image_encoder == 'resnet18': 814 | self.image_embedding = 512 815 | elif args.image_encoder == 'convnext': 816 | self.image_embedding = 640 817 | else: 818 | self.image_embedding = 1000 819 | 820 | if args.text_encoder == 'clip': 821 | self.text_embedding = 512 822 | elif args.text_encoder == 'bert': 823 | self.text_embedding = 768 824 | elif args.text_encoder == 'distilbert': 825 | self.text_embedding = 768 826 | elif args.text_encoder == 'gpt1': 827 | self.text_embedding = 768 828 | else: 829 | raise NotImplementedError 830 | 831 | self.image_encoder = ImageEncoder(args) 832 | 833 | self.text_encoder = TextEncoder(args) 834 | 835 | 836 | if args.only_has_image_projection: 837 | self.image_projection = ProjectionHead(embedding_dim=self.image_embedding) 838 | self.text_projection = ProjectionHead(embedding_dim=self.text_embedding, projection_dim=self.image_embedding).to('cuda') 839 | if train_logit_scale: 840 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1.0 / temperature)) 841 | else: 842 | self.logit_scale = torch.ones([]) * np.log(1.0 / temperature) 843 | 844 | 845 | self.args = args 846 | self.distill = args.distill 
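        # Note on the training objective: `logit_scale` stores log(1 / temperature); forward()
        # applies .exp(), so with the default temperature=0.07 the image-text logits are scaled by
        # roughly 14.3, as in CLIP. When a `similarity` matrix is passed to forward(), the
        # multilabel loss below replaces the one-hot InfoNCE targets with these learned soft
        # targets; otherwise the symmetric cross-entropy of plain CLIP is used.
        #
        # Hedged usage sketch (illustrative only; names such as `sim_gen`, `batch_indices`,
        # `syn_images`, `syn_text_embeds` and rank=5 are assumptions, not the repo's exact wiring):
        #
        #   from src.similarity_mining import LowRankSimilarityGenerator
        #   sim_gen = LowRankSimilarityGenerator(dim=num_synthetic_pairs, rank=5)
        #   params  = sim_gen.get_indexed_parameters(batch_indices)   # slice rows for this batch
        #   sim     = sim_gen.generate_with_param(params)             # diag(a) + (alpha/r) * L R^T
        #   loss, acc = model(syn_images, syn_text_embeds, epoch, similarity=sim)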
847 | 848 | self.multilabel_criterion = MultilabelContrastiveLoss(args.loss_type) 849 | 850 | 851 | def forward(self, image, caption, epoch, similarity=None): 852 | self.image_encoder = self.image_encoder.to('cuda') 853 | self.text_encoder = self.text_encoder.to('cuda') 854 | 855 | image_features = self.image_encoder(image) 856 | text_features = caption if self.distill else self.text_encoder(caption) 857 | 858 | use_image_project = False 859 | im_embed = image_features.float() if not use_image_project else self.image_projection(image_features.float()) 860 | 861 | txt_embed = self.text_projection(text_features.float()) 862 | 863 | combined_image_features = im_embed 864 | combined_text_features = txt_embed 865 | image_features = combined_image_features / combined_image_features.norm(dim=1, keepdim=True) 866 | text_features = combined_text_features / combined_text_features.norm(dim=1, keepdim=True) 867 | 868 | image_logits = self.logit_scale.exp() * image_features @ text_features.t() 869 | 870 | ground_truth = torch.arange(len(image_logits)).type_as(image_logits).long() 871 | acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum().item() 872 | acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum().item() 873 | acc = (acc_i + acc_t) / 2 874 | 875 | 876 | if similarity is None: 877 | loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2 878 | else: 879 | loss = self.multilabel_criterion(image_logits, similarity) 880 | 881 | return loss, acc -------------------------------------------------------------------------------- /src/reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import types 5 | from collections import namedtuple 6 | from contextlib import contextmanager 7 | 8 | 9 | class ReparamModule(nn.Module): 10 | def _get_module_from_name(self, mn): 11 | if mn == '': 12 | return self 13 | m = self 14 | for p in mn.split('.'): 15 | m = getattr(m, p) 16 | return m 17 | 18 | def __init__(self, module): 19 | super(ReparamModule, self).__init__() 20 | self.module = module 21 | 22 | param_infos = [] # (module name/path, param name) 23 | shared_param_memo = {} 24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name) 25 | params = [] 26 | param_numels = [] 27 | param_shapes = [] 28 | for mn, m in self.named_modules(): 29 | for n, p in m.named_parameters(recurse=False): 30 | if p is not None: 31 | if p in shared_param_memo: 32 | shared_mn, shared_n = shared_param_memo[p] 33 | shared_param_infos.append((mn, n, shared_mn, shared_n)) 34 | else: 35 | shared_param_memo[p] = (mn, n) 36 | param_infos.append((mn, n)) 37 | params.append(p.detach()) 38 | param_numels.append(p.numel()) 39 | param_shapes.append(p.size()) 40 | 41 | assert len(set(p.dtype for p in params)) <= 1, \ 42 | "expects all parameters in module to have same dtype" 43 | 44 | # store the info for unflatten 45 | self._param_infos = tuple(param_infos) 46 | self._shared_param_infos = tuple(shared_param_infos) 47 | self._param_numels = tuple(param_numels) 48 | self._param_shapes = tuple(param_shapes) 49 | 50 | # flatten 51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0)) 52 | self.register_parameter('flat_param', flat_param) 53 | self.param_numel = flat_param.numel() 54 | del params 55 | del shared_param_memo 56 | 57 | # deregister the names as parameters 58 | for mn, n in self._param_infos: 59 | 
delattr(self._get_module_from_name(mn), n) 60 | for mn, n, _, _ in self._shared_param_infos: 61 | delattr(self._get_module_from_name(mn), n) 62 | 63 | # register the views as plain attributes 64 | self._unflatten_param(self.flat_param) 65 | 66 | # now buffers 67 | # they are not reparametrized. just store info as (module, name, buffer) 68 | buffer_infos = [] 69 | for mn, m in self.named_modules(): 70 | for n, b in m.named_buffers(recurse=False): 71 | if b is not None: 72 | buffer_infos.append((mn, n, b)) 73 | 74 | self._buffer_infos = tuple(buffer_infos) 75 | self._traced_self = None 76 | 77 | def trace(self, example_input, **trace_kwargs): 78 | assert self._traced_self is None, 'This ReparamModule is already traced' 79 | 80 | if isinstance(example_input, torch.Tensor): 81 | example_input = (example_input,) 82 | example_input = tuple(example_input) 83 | example_param = (self.flat_param.detach().clone(),) 84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 85 | 86 | self._traced_self = torch.jit.trace_module( 87 | self, 88 | inputs=dict( 89 | _forward_with_param=example_param + example_input, 90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 91 | ), 92 | **trace_kwargs, 93 | ) 94 | 95 | # replace forwards with traced versions 96 | self._forward_with_param = self._traced_self._forward_with_param 97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 98 | return self 99 | 100 | def clear_views(self): 101 | for mn, n in self._param_infos: 102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr 103 | 104 | def _apply(self, *args, **kwargs): 105 | if self._traced_self is not None: 106 | self._traced_self._apply(*args, **kwargs) 107 | return self 108 | return super(ReparamModule, self)._apply(*args, **kwargs) 109 | 110 | def _unflatten_param(self, flat_param): 111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 112 | for (mn, n), p in zip(self._param_infos, ps): 113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr 114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 116 | 117 | @contextmanager 118 | def unflattened_param(self, flat_param): 119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos] 120 | self._unflatten_param(flat_param) 121 | yield 122 | # Why not just `self._unflatten_param(self.flat_param)`? 123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 124 | # 2. 
slightly faster since it does not require reconstruct the split+view 125 | # graph 126 | for (mn, n), p in zip(self._param_infos, saved_views): 127 | setattr(self._get_module_from_name(mn), n, p) 128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 130 | 131 | @contextmanager 132 | def replaced_buffers(self, buffers): 133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers): 134 | setattr(self._get_module_from_name(mn), n, new_b) 135 | yield 136 | for mn, n, old_b in self._buffer_infos: 137 | setattr(self._get_module_from_name(mn), n, old_b) 138 | 139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs): 140 | with self.unflattened_param(flat_param): 141 | with self.replaced_buffers(buffers): 142 | return self.module(*inputs, **kwinputs) 143 | 144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs): 145 | with self.unflattened_param(flat_param): 146 | return self.module(*inputs, **kwinputs) 147 | 148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs): 149 | flat_param = torch.squeeze(flat_param) 150 | # print("PARAMS ON DEVICE: ", flat_param.get_device()) 151 | # print("DATA ON DEVICE: ", inputs[0].get_device()) 152 | # flat_param.to("cuda:{}".format(inputs[0].get_device())) 153 | # self.module.to("cuda:{}".format(inputs[0].get_device())) 154 | if flat_param is None: 155 | flat_param = self.flat_param 156 | if buffers is None: 157 | return self._forward_with_param(flat_param, *inputs, **kwinputs) 158 | else: 159 | return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs) -------------------------------------------------------------------------------- /src/similarity_mining.py: -------------------------------------------------------------------------------- 1 | """Similarity modules and loss modules for LoRS 2 | by Yue Xu""" 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import List 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | 12 | class BaseSimilarityGenerator(nn.Module, ABC): 13 | def __init__(self, dim): 14 | super().__init__() 15 | self.dim = dim 16 | 17 | @abstractmethod 18 | def generate_with_param(self, params: List): 19 | pass 20 | 21 | @abstractmethod 22 | def get_indexed_parameters(self, indices=None) -> List: 23 | pass 24 | 25 | def load_params(self, params): 26 | raise NotImplementedError("") 27 | 28 | 29 | 30 | class LowRankSimilarityGenerator(BaseSimilarityGenerator): 31 | """Generate a parameterized similarity matrix S = aI + L@R' 32 | where a is a vector, I is the identity matrix, and B is a matrix of learnable parameters.""" 33 | def __init__(self, dim, rank, alpha=0.1): 34 | super().__init__(dim) 35 | self.rank = float(rank) 36 | 37 | self.alpha = alpha 38 | 39 | self.diag_weight = nn.Parameter(torch.ones(dim)) 40 | self.left = nn.Parameter(torch.randn(dim, rank)) 41 | self.right = nn.Parameter(torch.zeros(dim, rank)) 42 | 43 | 44 | def generate_with_param(self, params: List): 45 | a, l, r = params 46 | 47 | diag = torch.diag(a) 48 | sim = diag + l @ r.t() * self.alpha / self.rank 49 | return sim 50 | 51 | def get_indexed_parameters(self, indices=None) -> List: 52 | if indices is None: 53 | a, l, r = self.diag_weight, self.left, self.right 54 | else: 55 | a, l, r = self.diag_weight[indices], self.left[indices], self.right[indices] 56 | 57 | return [a, l, r] 58 | 59 | def 
load_params(self, params): 60 | a, l, r = params 61 | self.diag_weight.data = a 62 | self.left.data = l 63 | self.right.data = r 64 | 65 | 66 | 67 | class FullSimilarityGenerator(BaseSimilarityGenerator): 68 | def __init__(self, dim): 69 | super().__init__(dim) 70 | self.sim_mat = nn.Parameter(torch.eye(dim)) 71 | 72 | def generate_with_param(self, params): 73 | assert len(params) == 1 74 | return params[0] 75 | 76 | def get_indexed_parameters(self, indices=None) -> List: 77 | if indices is None: 78 | sim = self.sim_mat 79 | else: 80 | sim = self.sim_mat[indices[:,None], indices] 81 | return [sim] 82 | 83 | def load_params(self, params): 84 | s = params[0] 85 | self.sim_mat.data = s 86 | 87 | 88 | 89 | 90 | 91 | class MultilabelContrastiveLoss(nn.Module): 92 | """ 93 | Cross entropy loss with soft target. 94 | """ 95 | 96 | def __init__(self, loss_type=None): 97 | """ 98 | Args: 99 | reduction (str): specifies reduction to apply to the output. It can be 100 | "mean" (default) or "none". 101 | """ 102 | super().__init__() 103 | self.loss_type = loss_type 104 | 105 | self.kl_loss_func = nn.KLDivLoss(reduction="mean") 106 | self.bce_loss_func = nn.BCELoss(reduction="none") 107 | 108 | self.n_clusters = 4 # for KmeansBalanceBCE 109 | 110 | 111 | 112 | def __kl_criterion(self, logit, label): 113 | # batchsize = logit.shape[0] 114 | probs1 = F.log_softmax(logit, 1) 115 | probs2 = F.softmax(label * 10, 1) 116 | loss = self.kl_loss_func(probs1, probs2) 117 | return loss 118 | 119 | def __cwcl_criterion(self, logit, label): 120 | logprob = torch.log_softmax(logit, 1) 121 | loss_per = (label * logprob).sum(1) / (label.sum(1)+1e-6) 122 | loss = -loss_per.mean() 123 | return loss 124 | 125 | def __infonce_nonvonventional_criterion(self, logit, label): 126 | logprob = torch.log_softmax(logit, 1) 127 | loss_per = (label * logprob).sum(1) 128 | loss = -loss_per.mean() 129 | return loss 130 | 131 | 132 | def forward(self, logits, gt_matrix): 133 | gt_matrix = gt_matrix.to(logits.device) 134 | 135 | if self.loss_type == "KL": 136 | loss_i = self.__kl_criterion(logits, gt_matrix) 137 | loss_t = self.__kl_criterion(logits.t(), gt_matrix.t()) 138 | return (loss_i + loss_t) * 0.5 139 | elif self.loss_type == "BCE": 140 | # print("BCE is called") 141 | probs1 = torch.sigmoid(logits) 142 | probs2 = gt_matrix # torch.sigmoid(gt_matrix) 143 | bce_loss = self.bce_loss_func(probs1, probs2) 144 | loss = bce_loss.mean() 145 | # loss = self.__general_cl_criterion(logits, gt_matrix, "BCE", 146 | # use_norm=False, use_negwgt=False) 147 | return loss 148 | 149 | elif self.loss_type in ["BalanceBCE", "WBCE"]: 150 | probs1 = torch.sigmoid(logits) 151 | probs2 = gt_matrix # torch.sigmoid(gt_matrix) 152 | 153 | loss_matrix = - probs2 * torch.log(probs1+1e-6) - (1-probs2) * torch.log(1-probs1+1e-6) 154 | 155 | pos_mask = (probs2>0.5).detach() 156 | neg_mask = ~pos_mask 157 | 158 | loss_pos = torch.where(pos_mask, loss_matrix, torch.tensor(0.0, device=probs1.device)).sum() 159 | loss_neg = torch.where(neg_mask, loss_matrix, torch.tensor(0.0, device=probs1.device)).sum() 160 | 161 | loss_pos /= (pos_mask.sum()+1e-6) 162 | loss_neg /= (neg_mask.sum()+1e-6) 163 | 164 | return (loss_pos+loss_neg)/2 165 | 166 | 167 | elif self.loss_type in ["NCE", "InfoNCE"]: 168 | loss_i = self.__infonce_nonvonventional_criterion(logits, gt_matrix) 169 | loss_t = self.__infonce_nonvonventional_criterion(logits.t(), gt_matrix.t()) 170 | return (loss_i + loss_t) * 0.5 171 | 172 | elif self.loss_type == "MSE": 173 | probs = torch.sigmoid(logits) 174 | 
return F.mse_loss(probs, gt_matrix) 175 | 176 | elif self.loss_type == "CWCL": 177 | loss_i = self.__cwcl_criterion(logits, gt_matrix) 178 | loss_t = self.__cwcl_criterion(logits.t(), gt_matrix.t()) 179 | return (loss_i + loss_t) * 0.5 180 | 181 | else: 182 | raise NotImplementedError(self.loss_type) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # adapted from 2 | # https://github.com/VICO-UoE/DatasetCondensation 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | import kornia as K 10 | import tqdm 11 | from torch.utils.data import Dataset 12 | from torchvision import datasets, transforms 13 | from scipy.ndimage.interpolation import rotate as scipyrotate 14 | from src.networks import MLP, ConvNet, LeNet, AlexNet, VGG11BN, VGG11, ResNet18, ResNet18BN_AP, ResNet18_AP, ModifiedResNet, resnet18_gn 15 | import re 16 | import json 17 | import torch.distributed as dist 18 | from tqdm import tqdm 19 | from collections import defaultdict 20 | 21 | class Config: 22 | imagenette = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701] 23 | 24 | # ["australian_terrier", "border_terrier", "samoyed", "beagle", "shih-tzu", "english_foxhound", "rhodesian_ridgeback", "dingo", "golden_retriever", "english_sheepdog"] 25 | imagewoof = [193, 182, 258, 162, 155, 167, 159, 273, 207, 229] 26 | 27 | # ["tabby_cat", "bengal_cat", "persian_cat", "siamese_cat", "egyptian_cat", "lion", "tiger", "jaguar", "snow_leopard", "lynx"] 28 | imagemeow = [281, 282, 283, 284, 285, 291, 292, 290, 289, 287] 29 | 30 | # ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"] 31 | imagesquawk = [84, 130, 88, 144, 145, 22, 96, 9, 100, 89] 32 | 33 | # ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"] 34 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948] 35 | 36 | # ["bee", "ladys slipper", "banana", "lemon", "corn", "school_bus", "honeycomb", "lion", "garden_spider", "goldfinch"] 37 | imageyellow = [309, 986, 954, 951, 987, 779, 599, 291, 72, 11] 38 | 39 | dict = { 40 | "imagenette" : imagenette, 41 | "imagewoof" : imagewoof, 42 | "imagefruit": imagefruit, 43 | "imageyellow": imageyellow, 44 | "imagemeow": imagemeow, 45 | "imagesquawk": imagesquawk, 46 | } 47 | 48 | config = Config() 49 | 50 | def get_dataset(args): 51 | 52 | if args.dataset == 'CIFAR10_clip': 53 | channel = 3 54 | im_size = (32, 32) 55 | num_classes = 768 56 | mean = [0.4914, 0.4822, 0.4465] 57 | std = [0.2023, 0.1994, 0.2010] 58 | if args.zca: 59 | transform = transforms.Compose([transforms.ToTensor()]) 60 | else: 61 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 62 | dst_train = datasets.CIFAR10(args.data_path, train=True, download=True, transform=transform) # no augmentation 63 | dst_test = datasets.CIFAR10(args.data_path, train=False, download=True, transform=transform) 64 | class_names = dst_train.classes 65 | class_map = {x:x for x in range(num_classes)} 66 | 67 | else: 68 | exit('unknown dataset: %s'%args.dataset) 69 | 70 | if args.zca: 71 | images = [] 72 | labels = [] 73 | print("Train ZCA") 74 | for i in tqdm(range(len(dst_train))): 75 | im, lab = dst_train[i] 76 | images.append(im) 77 | labels.append(lab) 78 | images = torch.stack(images, 
dim=0).to(args.device) 79 | labels = torch.tensor(labels, dtype=torch.float, device="cpu") 80 | zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True) 81 | zca.fit(images) 82 | zca_images = zca(images).to("cpu") 83 | dst_train = TensorDataset(zca_images, labels) 84 | 85 | images = [] 86 | labels = [] 87 | print("Test ZCA") 88 | for i in tqdm(range(len(dst_test))): 89 | im, lab = dst_test[i] 90 | images.append(im) 91 | labels.append(lab) 92 | images = torch.stack(images, dim=0).to(args.device) 93 | labels = torch.tensor(labels, dtype=torch.float, device="cpu") 94 | 95 | zca_images = zca(images).to("cpu") 96 | dst_test = TensorDataset(zca_images, labels) 97 | 98 | args.zca_trans = zca 99 | 100 | 101 | testloader = torch.utils.data.DataLoader(dst_test, batch_size=128, shuffle=False, num_workers=2) 102 | 103 | if "flickr" not in args.dataset: 104 | dst_train_label, dst_test_label = None, None 105 | return channel, im_size, num_classes, mean, std, dst_train, dst_test, testloader, dst_train_label, dst_test_label 106 | 107 | 108 | 109 | class TensorDataset(Dataset): 110 | def __init__(self, images, labels): # images: n x c x h x w tensor 111 | self.images = images.detach().float() 112 | self.labels = labels.detach() 113 | 114 | def __getitem__(self, index): 115 | return self.images[index], self.labels[index] 116 | 117 | def __len__(self): 118 | return self.images.shape[0] 119 | 120 | 121 | 122 | def get_default_convnet_setting(): 123 | net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling' 124 | return net_width, net_depth, net_act, net_norm, net_pooling 125 | 126 | 127 | 128 | def get_RN_network(model, vision_width, vision_layers, embed_dim, image_resolution): 129 | if model == 'RN50': 130 | vision_heads = vision_width * 32 // 64 131 | net = ModifiedResNet(layers=vision_layers, 132 | output_dim=embed_dim, 133 | heads=vision_heads, 134 | input_resolution=image_resolution, 135 | width=vision_width) 136 | if dist: 137 | gpu_num = torch.cuda.device_count() 138 | if gpu_num>0: 139 | device = 'cuda' 140 | if gpu_num>1: 141 | net = nn.DataParallel(net) 142 | else: 143 | device = 'cpu' 144 | net = net.to(device) 145 | 146 | return net 147 | 148 | def get_network(model, channel, num_classes, im_size=(32, 32), dist=True): 149 | torch.random.manual_seed(int(time.time() * 1000) % 100000) 150 | net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting() 151 | 152 | if model == 'MLP': 153 | net = MLP(channel=channel, num_classes=num_classes) 154 | elif model == 'ConvNet': 155 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 156 | elif model == 'LeNet': 157 | net = LeNet(channel=channel, num_classes=num_classes) 158 | elif model == 'AlexNet': 159 | net = AlexNet(channel=channel, num_classes=num_classes) 160 | elif model == 'VGG11': 161 | net = VGG11( channel=channel, num_classes=num_classes) 162 | elif model == 'VGG11BN': 163 | net = VGG11BN(channel=channel, num_classes=num_classes) 164 | elif model == 'ResNet18': 165 | net = ResNet18(channel=channel, num_classes=num_classes) 166 | elif model == 'ResNet18BN_AP': 167 | net = ResNet18BN_AP(channel=channel, num_classes=num_classes) 168 | elif model == 'ResNet18_AP': 169 | net = ResNet18_AP(channel=channel, num_classes=num_classes) 170 | 171 | elif model == 'ConvNetD1': 172 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, 
net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 173 | elif model == 'ConvNetD2': 174 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 175 | elif model == 'ConvNetD3': 176 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 177 | elif model == 'ConvNetD4': 178 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 179 | elif model == 'ConvNetD5': 180 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=5, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 181 | elif model == 'ConvNetD6': 182 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=6, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 183 | elif model == 'ConvNetD7': 184 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=7, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 185 | elif model == 'ConvNetD8': 186 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=8, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 187 | 188 | 189 | elif model == 'ConvNetW32': 190 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 191 | elif model == 'ConvNetW64': 192 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 193 | elif model == 'ConvNetW128': 194 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 195 | elif model == 'ConvNetW256': 196 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 197 | elif model == 'ConvNetW512': 198 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=512, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 199 | elif model == 'ConvNetW1024': 200 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 201 | 202 | elif model == "ConvNetKIP": 203 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, 204 | net_norm="none", net_pooling=net_pooling) 205 | 206 | elif model == 'ConvNetAS': 207 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling) 208 | elif model == 'ConvNetAR': 209 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling) 210 | elif model == 'ConvNetAL': 211 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling) 212 | 213 | elif model == 'ConvNetNN': 214 | 
net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling) 215 | elif model == 'ConvNetBN': 216 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling) 217 | elif model == 'ConvNetLN': 218 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling) 219 | elif model == 'ConvNetIN': 220 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling) 221 | elif model == 'ConvNetGN': 222 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling) 223 | 224 | elif model == 'ConvNetNP': 225 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none') 226 | elif model == 'ConvNetMP': 227 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling') 228 | elif model == 'ConvNetAP': 229 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling') 230 | 231 | 232 | else: 233 | net = None 234 | exit('DC error: unknown model') 235 | 236 | if dist: 237 | gpu_num = torch.cuda.device_count() 238 | if gpu_num>0: 239 | device = 'cuda' 240 | if gpu_num>1: 241 | net = nn.DataParallel(net) 242 | else: 243 | device = 'cpu' 244 | net = net.to(device) 245 | 246 | return net 247 | 248 | 249 | 250 | def get_time(): 251 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())) 252 | 253 | 254 | 255 | def augment(images, dc_aug_param, device): 256 | # This can be sped up in the future. 
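    # Illustrative pairing (a sketch, values assumed): the dc_aug_param dict here is the one
    # produced by get_daparam() further below, e.g.
    #   dc_aug_param = get_daparam('MNIST', 'ConvNet', 'ConvNetBN', ipc=10)
    #   images = augment(images, dc_aug_param, device)
    # Each image in the batch then receives one op drawn at random from dc_aug_param['strategy'].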
257 | 258 | if dc_aug_param != None and dc_aug_param['strategy'] != 'none': 259 | scale = dc_aug_param['scale'] 260 | crop = dc_aug_param['crop'] 261 | rotate = dc_aug_param['rotate'] 262 | noise = dc_aug_param['noise'] 263 | strategy = dc_aug_param['strategy'] 264 | 265 | shape = images.shape 266 | mean = [] 267 | for c in range(shape[1]): 268 | mean.append(float(torch.mean(images[:,c]))) 269 | 270 | def cropfun(i): 271 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device) 272 | for c in range(shape[1]): 273 | im_[c] = mean[c] 274 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i] 275 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0] 276 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]] 277 | 278 | def scalefun(i): 279 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 280 | w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 281 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0] 282 | mhw = max(h, w, shape[2], shape[3]) 283 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device) 284 | r = int((mhw - h) / 2) 285 | c = int((mhw - w) / 2) 286 | im_[:, r:r + h, c:c + w] = tmp 287 | r = int((mhw - shape[2]) / 2) 288 | c = int((mhw - shape[3]) / 2) 289 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]] 290 | 291 | def rotatefun(i): 292 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean)) 293 | r = int((im_.shape[-2] - shape[-2]) / 2) 294 | c = int((im_.shape[-1] - shape[-1]) / 2) 295 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device) 296 | 297 | def noisefun(i): 298 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device) 299 | 300 | 301 | augs = strategy.split('_') 302 | 303 | for i in range(shape[0]): 304 | choice = np.random.permutation(augs)[0] # randomly implement one augmentation 305 | if choice == 'crop': 306 | cropfun(i) 307 | elif choice == 'scale': 308 | scalefun(i) 309 | elif choice == 'rotate': 310 | rotatefun(i) 311 | elif choice == 'noise': 312 | noisefun(i) 313 | 314 | return images 315 | 316 | 317 | 318 | def get_daparam(dataset, model, model_eval, ipc): 319 | # We find that augmentation doesn't always benefit the performance. 320 | # So we do augmentation for some of the settings. 321 | 322 | dc_aug_param = dict() 323 | dc_aug_param['crop'] = 4 324 | dc_aug_param['scale'] = 0.2 325 | dc_aug_param['rotate'] = 45 326 | dc_aug_param['noise'] = 0.001 327 | dc_aug_param['strategy'] = 'none' 328 | 329 | if dataset == 'MNIST': 330 | dc_aug_param['strategy'] = 'crop_scale_rotate' 331 | 332 | if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier. 
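        # Note: this BN-specific override replaces whatever strategy was set above
        # (e.g. the MNIST 'crop_scale_rotate' setting), since it is assigned last.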
333 | dc_aug_param['strategy'] = 'crop_noise' 334 | 335 | return dc_aug_param 336 | 337 | 338 | def get_eval_pool(eval_mode, model, model_eval): 339 | if eval_mode == 'M': # multiple architectures 340 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18', 'LeNet'] 341 | model_eval_pool = ['ConvNet', 'AlexNet', 'VGG11', 'ResNet18_AP', 'ResNet18'] 342 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18'] 343 | elif eval_mode == 'W': # ablation study on network width 344 | model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256'] 345 | elif eval_mode == 'D': # ablation study on network depth 346 | model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4'] 347 | elif eval_mode == 'A': # ablation study on network activation function 348 | model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL'] 349 | elif eval_mode == 'P': # ablation study on network pooling layer 350 | model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP'] 351 | elif eval_mode == 'N': # ablation study on network normalization layer 352 | model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN'] 353 | elif eval_mode == 'S': # itself 354 | model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model] 355 | elif eval_mode == 'C': 356 | model_eval_pool = [model, 'ConvNet'] 357 | else: 358 | model_eval_pool = [model_eval] 359 | return model_eval_pool 360 | 361 | 362 | class ParamDiffAug(): 363 | def __init__(self): 364 | self.aug_mode = 'S' #'multiple or single' 365 | self.prob_flip = 0.5 366 | self.ratio_scale = 1.2 367 | self.ratio_rotate = 15.0 368 | self.ratio_crop_pad = 0.125 369 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5 370 | self.ratio_noise = 0.05 371 | self.brightness = 1.0 372 | self.saturation = 2.0 373 | self.contrast = 0.5 374 | 375 | 376 | def set_seed_DiffAug(param): 377 | if param.latestseed == -1: 378 | return 379 | else: 380 | torch.random.manual_seed(param.latestseed) 381 | param.latestseed += 1 382 | 383 | 384 | def DiffAugment(x, strategy='', seed = -1, param = None): 385 | if seed == -1: 386 | param.batchmode = False 387 | else: 388 | param.batchmode = True 389 | 390 | param.latestseed = seed 391 | 392 | if strategy == 'None' or strategy == 'none': 393 | return x 394 | 395 | if strategy: 396 | if param.aug_mode == 'M': # original 397 | for p in strategy.split('_'): 398 | for f in AUGMENT_FNS[p]: 399 | x = f(x, param) 400 | elif param.aug_mode == 'S': 401 | pbties = strategy.split('_') 402 | set_seed_DiffAug(param) 403 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()] 404 | for f in AUGMENT_FNS[p]: 405 | x = f(x, param) 406 | else: 407 | exit('Error ZH: unknown augmentation mode.') 408 | x = x.contiguous() 409 | return x 410 | 411 | 412 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans. 
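# Usage sketch (illustrative; the strategy string and seed below are assumed, not fixed by this file):
#   param = ParamDiffAug()
#   x_aug = DiffAugment(x, 'color_crop_cutout_flip_scale_rotate', seed=42, param=param)
# With the default aug_mode='S' one augmentation family is sampled per call; seed=-1 keeps
# per-image randomness, while any other seed enables batchmode so the whole batch shares the
# same random transform parameters.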
413 | def rand_scale(x, param): 414 | # x>1, max scale 415 | # sx, sy: (0, +oo), 1: original size, 0.5: enlarge 2 times 416 | ratio = param.ratio_scale 417 | set_seed_DiffAug(param) 418 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 419 | set_seed_DiffAug(param) 420 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 421 | theta = [[[sx[i], 0, 0], 422 | [0, sy[i], 0],] for i in range(x.shape[0])] 423 | theta = torch.tensor(theta, dtype=torch.float) 424 | if param.batchmode: # batch-wise: 425 | theta[:] = theta[0] 426 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 427 | x = F.grid_sample(x, grid, align_corners=True) 428 | return x 429 | 430 | 431 | def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degrees 432 | ratio = param.ratio_rotate 433 | set_seed_DiffAug(param) 434 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi) 435 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0], 436 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])] 437 | theta = torch.tensor(theta, dtype=torch.float) 438 | if param.batchmode: # batch-wise: 439 | theta[:] = theta[0] 440 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 441 | x = F.grid_sample(x, grid, align_corners=True) 442 | return x 443 | 444 | 445 | def rand_flip(x, param): 446 | prob = param.prob_flip 447 | set_seed_DiffAug(param) 448 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device) 449 | if param.batchmode: # batch-wise: 450 | randf[:] = randf[0] 451 | return torch.where(randf < prob, x.flip(3), x) 452 | 453 | 454 | def rand_brightness(x, param): 455 | ratio = param.brightness 456 | set_seed_DiffAug(param) 457 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 458 | if param.batchmode: # batch-wise: 459 | randb[:] = randb[0] 460 | x = x + (randb - 0.5)*ratio 461 | return x 462 | 463 | 464 | def rand_saturation(x, param): 465 | ratio = param.saturation 466 | x_mean = x.mean(dim=1, keepdim=True) 467 | set_seed_DiffAug(param) 468 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 469 | if param.batchmode: # batch-wise: 470 | rands[:] = rands[0] 471 | x = (x - x_mean) * (rands * ratio) + x_mean 472 | return x 473 | 474 | 475 | def rand_contrast(x, param): 476 | ratio = param.contrast 477 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 478 | set_seed_DiffAug(param) 479 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 480 | if param.batchmode: # batch-wise: 481 | randc[:] = randc[0] 482 | x = (x - x_mean) * (randc + ratio) + x_mean 483 | return x 484 | 485 | 486 | def rand_crop(x, param): 487 | # The image is padded on its surrounding and then cropped.
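    # Worked example (a sketch, assuming a 32x32 input and the default ratio_crop_pad=0.125):
    # shift_x = shift_y = int(32*0.125 + 0.5) = 4, so each image is padded by one pixel on
    # every side and re-sampled at a random integer offset in [-4, 4]; positions that fall
    # outside the original extent are clamped into the zero padding.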
488 | ratio = param.ratio_crop_pad 489 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 490 | set_seed_DiffAug(param) 491 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 492 | set_seed_DiffAug(param) 493 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 494 | if param.batchmode: # batch-wise: 495 | translation_x[:] = translation_x[0] 496 | translation_y[:] = translation_y[0] 497 | grid_batch, grid_x, grid_y = torch.meshgrid( 498 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 499 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 500 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 501 | ) 502 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 503 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 504 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 505 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 506 | return x 507 | 508 | 509 | def rand_cutout(x, param): 510 | ratio = param.ratio_cutout 511 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 512 | set_seed_DiffAug(param) 513 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 514 | set_seed_DiffAug(param) 515 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 516 | if param.batchmode: # batch-wise: 517 | offset_x[:] = offset_x[0] 518 | offset_y[:] = offset_y[0] 519 | grid_batch, grid_x, grid_y = torch.meshgrid( 520 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 521 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 522 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 523 | ) 524 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 525 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 526 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 527 | mask[grid_batch, grid_x, grid_y] = 0 528 | x = x * mask.unsqueeze(1) 529 | return x 530 | 531 | 532 | AUGMENT_FNS = { 533 | 'color': [rand_brightness, rand_saturation, rand_contrast], 534 | 'crop': [rand_crop], 535 | 'cutout': [rand_cutout], 536 | 'flip': [rand_flip], 537 | 'scale': [rand_scale], 538 | 'rotate': [rand_rotate], 539 | } 540 | 541 | 542 | def pre_question(question,max_ques_words=50): 543 | question = re.sub( 544 | r"([.!\"()*#:;~])", 545 | '', 546 | question.lower(), 547 | ) 548 | question = question.rstrip(' ') 549 | 550 | #truncate question 551 | question_words = question.split(' ') 552 | if len(question_words)>max_ques_words: 553 | question = ' '.join(question_words[:max_ques_words]) 554 | 555 | return question 556 | 557 | 558 | def save_result(result, result_dir, filename, remove_duplicate=''): 559 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 560 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 561 | 562 | json.dump(result,open(result_file,'w')) 563 | 564 | dist.barrier() 565 | 566 | if utils.is_main_process(): 567 | # combine results from all processes 568 | result = [] 569 | 570 | for rank in range(utils.get_world_size()): 571 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 572 | res = json.load(open(result_file,'r')) 573 | result += res 574 | 575 | if 
remove_duplicate: 576 | result_new = [] 577 | id_list = [] 578 | for res in result: 579 | if res[remove_duplicate] not in id_list: 580 | id_list.append(res[remove_duplicate]) 581 | result_new.append(res) 582 | result = result_new 583 | 584 | json.dump(result,open(final_result_file,'w')) 585 | print('result file saved to %s'%final_result_file) 586 | 587 | return final_result_file 588 | 589 | 590 | 591 | #### everything below is from https://github.com/salesforce/BLIP/blob/main/utils.py 592 | import math 593 | 594 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 595 | """Decay the learning rate""" 596 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr 597 | for param_group in optimizer.param_groups: 598 | param_group['lr'] = lr 599 | 600 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 601 | """Warmup the learning rate""" 602 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 603 | for param_group in optimizer.param_groups: 604 | param_group['lr'] = lr 605 | 606 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 607 | """Decay the learning rate""" 608 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 609 | for param_group in optimizer.param_groups: 610 | param_group['lr'] = lr 611 | 612 | import numpy as np 613 | import io 614 | import os 615 | import time 616 | from collections import defaultdict, deque 617 | import datetime 618 | 619 | import torch 620 | import torch.distributed as dist 621 | 622 | 623 | class MetricLogger(object): 624 | def __init__(self, delimiter="\t"): 625 | self.meters = defaultdict(SmoothedValue) 626 | self.delimiter = delimiter 627 | 628 | def update(self, **kwargs): 629 | for k, v in kwargs.items(): 630 | if isinstance(v, torch.Tensor): 631 | v = v.item() 632 | assert isinstance(v, (float, int)) 633 | self.meters[k].update(v) 634 | 635 | def __getattr__(self, attr): 636 | if attr in self.meters: 637 | return self.meters[attr] 638 | if attr in self.__dict__: 639 | return self.__dict__[attr] 640 | raise AttributeError("'{}' object has no attribute '{}'".format( 641 | type(self).__name__, attr)) 642 | 643 | def __str__(self): 644 | loss_str = [] 645 | for name, meter in self.meters.items(): 646 | loss_str.append( 647 | "{}: {}".format(name, str(meter)) 648 | ) 649 | return self.delimiter.join(loss_str) 650 | 651 | def global_avg(self): 652 | loss_str = [] 653 | for name, meter in self.meters.items(): 654 | loss_str.append( 655 | "{}: {:.4f}".format(name, meter.global_avg) 656 | ) 657 | return self.delimiter.join(loss_str) 658 | 659 | def synchronize_between_processes(self): 660 | for meter in self.meters.values(): 661 | meter.synchronize_between_processes() 662 | 663 | def add_meter(self, name, meter): 664 | self.meters[name] = meter 665 | 666 | def log_every(self, iterable, print_freq, header=None): 667 | i = 0 668 | if not header: 669 | header = '' 670 | start_time = time.time() 671 | end = time.time() 672 | iter_time = SmoothedValue(fmt='{avg:.4f}') 673 | data_time = SmoothedValue(fmt='{avg:.4f}') 674 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 675 | log_msg = [ 676 | header, 677 | '[{0' + space_fmt + '}/{1}]', 678 | 'eta: {eta}', 679 | '{meters}', 680 | 'time: {time}', 681 | 'data: {data}' 682 | ] 683 | if torch.cuda.is_available(): 684 | log_msg.append('max mem: {memory:.0f}') 685 | log_msg = self.delimiter.join(log_msg) 686 | MB = 1024.0 * 1024.0 687 | for obj in iterable: 688 | data_time.update(time.time() - end) 689 | yield obj 690 | 
iter_time.update(time.time() - end) 691 | if i % print_freq == 0 or i == len(iterable) - 1: 692 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 693 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 694 | if torch.cuda.is_available(): 695 | print(log_msg.format( 696 | i, len(iterable), eta=eta_string, 697 | meters=str(self), 698 | time=str(iter_time), data=str(data_time), 699 | memory=torch.cuda.max_memory_allocated() / MB)) 700 | else: 701 | print(log_msg.format( 702 | i, len(iterable), eta=eta_string, 703 | meters=str(self), 704 | time=str(iter_time), data=str(data_time))) 705 | i += 1 706 | end = time.time() 707 | total_time = time.time() - start_time 708 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 709 | print('{} Total time: {} ({:.4f} s / it)'.format( 710 | header, total_time_str, total_time / len(iterable))) 711 | 712 | 713 | 714 | class SmoothedValue(object): 715 | """Track a series of values and provide access to smoothed values over a 716 | window or the global series average. 717 | """ 718 | 719 | def __init__(self, window_size=20, fmt=None): 720 | if fmt is None: 721 | fmt = "{median:.4f} ({global_avg:.4f})" 722 | self.deque = deque(maxlen=window_size) 723 | self.total = 0.0 724 | self.count = 0 725 | self.fmt = fmt 726 | 727 | def update(self, value, n=1): 728 | self.deque.append(value) 729 | self.count += n 730 | self.total += value * n 731 | 732 | def synchronize_between_processes(self): 733 | """ 734 | Warning: does not synchronize the deque! 735 | """ 736 | if not is_dist_avail_and_initialized(): 737 | return 738 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 739 | dist.barrier() 740 | dist.all_reduce(t) 741 | t = t.tolist() 742 | self.count = int(t[0]) 743 | self.total = t[1] 744 | 745 | @property 746 | def median(self): 747 | d = torch.tensor(list(self.deque)) 748 | return d.median().item() 749 | 750 | @property 751 | def avg(self): 752 | d = torch.tensor(list(self.deque), dtype=torch.float32) 753 | return d.mean().item() 754 | 755 | @property 756 | def global_avg(self): 757 | return self.total / self.count 758 | 759 | @property 760 | def max(self): 761 | return max(self.deque) 762 | 763 | @property 764 | def value(self): 765 | return self.deque[-1] 766 | 767 | def __str__(self): 768 | return self.fmt.format( 769 | median=self.median, 770 | avg=self.avg, 771 | global_avg=self.global_avg, 772 | max=self.max, 773 | value=self.value) 774 | 775 | class AttrDict(dict): 776 | def __init__(self, *args, **kwargs): 777 | super(AttrDict, self).__init__(*args, **kwargs) 778 | self.__dict__ = self 779 | 780 | 781 | def compute_acc(logits, label, reduction='mean'): 782 | ret = (torch.argmax(logits, dim=1) == label).float() 783 | if reduction == 'none': 784 | return ret.detach() 785 | elif reduction == 'mean': 786 | return ret.mean().item() 787 | 788 | def compute_n_params(model, return_str=True): 789 | tot = 0 790 | for p in model.parameters(): 791 | w = 1 792 | for x in p.shape: 793 | w *= x 794 | tot += w 795 | if return_str: 796 | if tot >= 1e6: 797 | return '{:.1f}M'.format(tot / 1e6) 798 | else: 799 | return '{:.1f}K'.format(tot / 1e3) 800 | else: 801 | return tot 802 | 803 | def setup_for_distributed(is_master): 804 | """ 805 | This function disables printing when not in master process 806 | """ 807 | import builtins as __builtin__ 808 | builtin_print = __builtin__.print 809 | 810 | def print(*args, **kwargs): 811 | force = kwargs.pop('force', False) 812 | if is_master or force: 813 | 
builtin_print(*args, **kwargs) 814 | 815 | __builtin__.print = print 816 | 817 | 818 | def is_dist_avail_and_initialized(): 819 | if not dist.is_available(): 820 | return False 821 | if not dist.is_initialized(): 822 | return False 823 | return True 824 | 825 | 826 | def get_world_size(): 827 | if not is_dist_avail_and_initialized(): 828 | return 1 829 | return dist.get_world_size() 830 | 831 | 832 | def get_rank(): 833 | if not is_dist_avail_and_initialized(): 834 | return 0 835 | return dist.get_rank() 836 | 837 | 838 | def is_main_process(): 839 | return get_rank() == 0 840 | 841 | 842 | def save_on_master(*args, **kwargs): 843 | if is_main_process(): 844 | torch.save(*args, **kwargs) 845 | 846 | 847 | def init_distributed_mode(args): 848 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 849 | args.rank = int(os.environ["RANK"]) 850 | args.world_size = int(os.environ['WORLD_SIZE']) 851 | args.gpu = int(os.environ['LOCAL_RANK']) 852 | elif 'SLURM_PROCID' in os.environ: 853 | args.rank = int(os.environ['SLURM_PROCID']) 854 | args.gpu = args.rank % torch.cuda.device_count() 855 | else: 856 | print('Not using distributed mode') 857 | args.distributed = False 858 | return 859 | 860 | args.distributed = True 861 | 862 | torch.cuda.set_device(args.gpu) 863 | args.dist_backend = 'nccl' 864 | print('| distributed init (rank {}, world {}): {}'.format( 865 | args.rank, args.world_size, args.dist_url), flush=True) 866 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 867 | world_size=args.world_size, rank=args.rank) 868 | torch.distributed.barrier() 869 | setup_for_distributed(args.rank == 0) 870 | 871 | -------------------------------------------------------------------------------- /src/vl_distill_utils.py: -------------------------------------------------------------------------------- 1 | """Move some basic utils from distill.py in VL-Distill here""" 2 | import os 3 | import numpy as np 4 | import torch 5 | from sklearn.metrics.pairwise import cosine_similarity 6 | from src.networks import TextEncoder 7 | 8 | __all__ = [ 9 | "shuffle_files", 10 | "nearest_neighbor", 11 | "get_images_texts", 12 | "load_or_process_file", 13 | ] 14 | 15 | 16 | def shuffle_files(img_expert_files, txt_expert_files): 17 | # Check if both lists have the same length and if the lists are not empty 18 | assert len(img_expert_files) == len(txt_expert_files), "Number of image files and text files does not match" 19 | assert len(img_expert_files) != 0, "No files to shuffle" 20 | shuffled_indices = np.random.permutation(len(img_expert_files)) 21 | 22 | # Apply the shuffled indices to both lists 23 | img_expert_files = np.take(img_expert_files, shuffled_indices) 24 | txt_expert_files = np.take(txt_expert_files, shuffled_indices) 25 | return img_expert_files, txt_expert_files 26 | 27 | def nearest_neighbor(sentences, query_embeddings, database_embeddings): 28 | """Find the nearest neighbors for a batch of embeddings. 29 | 30 | Args: 31 | sentences: The original sentences from which the embeddings were computed. 32 | query_embeddings: A batch of embeddings for which to find the nearest neighbors. 33 | database_embeddings: All pre-computed embeddings. 34 | 35 | Returns: 36 | A list of the most similar sentences for each embedding in the batch.
37 | """ 38 | nearest_neighbors = [] 39 | 40 | for query in query_embeddings: 41 | similarities = cosine_similarity(query.reshape(1, -1), database_embeddings) 42 | 43 | most_similar_index = np.argmax(similarities) 44 | 45 | nearest_neighbors.append(sentences[most_similar_index]) 46 | 47 | return nearest_neighbors 48 | 49 | 50 | def get_images_texts(n, dataset, args, i_have_indices=None): 51 | """Get random n images and corresponding texts from the dataset. 52 | 53 | Args: 54 | n: Number of images and texts to retrieve. 55 | dataset: The dataset containing image-text pairs. 56 | 57 | Returns: 58 | A tuple containing two elements: 59 | - A tensor of randomly selected images. 60 | - A tensor of the corresponding texts, encoded as floats. 61 | """ 62 | # Generate n unique random indices 63 | if i_have_indices is not None: 64 | idx_shuffle = i_have_indices 65 | else: 66 | idx_shuffle = np.random.permutation(len(dataset))[:n] 67 | 68 | # Initialize the text encoder 69 | text_encoder = TextEncoder(args) 70 | 71 | image_syn = torch.stack([dataset[i][0] for i in idx_shuffle]) 72 | 73 | text_syn = text_encoder([dataset[i][1] for i in idx_shuffle], device="cpu") 74 | 75 | return image_syn, text_syn.float() 76 | 77 | 78 | def load_or_process_file(file_type, process_func, args, data_source): 79 | """ 80 | Load the processed file if it exists, otherwise process the data source and create the file. 81 | 82 | Args: 83 | file_type: The type of the file (e.g., 'train', 'test'). 84 | process_func: The function to process the data source. 85 | args: The arguments required by the process function and to build the filename. 86 | data_source: The source data to be processed. 87 | 88 | Returns: 89 | The loaded data from the file. 90 | """ 91 | filename = f'{args.dataset}_{args.text_encoder}_{file_type}_embed.npz' 92 | 93 | 94 | if not os.path.exists(filename): 95 | print(f'Creating {filename}') 96 | process_func(args, data_source) 97 | else: 98 | print(f'Loading {filename}') 99 | 100 | return np.load(filename) 101 | 102 | def get_LC_images_texts(n, dataset, args): 103 | """Get random n images and corresponding texts from the dataset. 104 | 105 | Args: 106 | n: Number of images and texts to retrieve. 107 | dataset: The dataset containing image-text pairs. 108 | 109 | Returns: 110 | A tuple containing two elements: 111 | - A tensor of randomly selected images. 112 | - A tensor of the corresponding texts, encoded as floats. 113 | """ 114 | # Generate n unique random indices 115 | idx_shuffle = np.random.permutation(len(dataset))[:n] 116 | 117 | # Initialize the text encoder 118 | text_encoder = TextEncoder(args) 119 | 120 | image_syn = torch.stack([dataset[i][0] for i in idx_shuffle]) 121 | 122 | text_syn = text_encoder([dataset[i][1] for i in idx_shuffle], device="cpu") 123 | 124 | return image_syn, text_syn.float() 125 | 126 | --------------------------------------------------------------------------------