├── .gitignore
├── LICENSE
├── README.md
├── buffer.py
├── data
│   ├── Flickr30k_ann
│   │   ├── coco_karpathy_test.json
│   │   ├── coco_karpathy_train.json
│   │   ├── coco_karpathy_val.json
│   │   ├── flickr30k_test.json
│   │   ├── flickr30k_train.json
│   │   └── flickr30k_val.json
│   ├── __init__.py
│   ├── cifar_dataset.py
│   ├── coco_dataset.py
│   ├── flickr30k_dataset.py
│   └── randaugment.py
├── distill_tesla_lors.py
├── evaluate_only.py
├── images
│   └── method.png
├── requirements.txt
├── sh
│   ├── buffer_coco.sh
│   ├── buffer_flickr.sh
│   ├── distill_coco_lors_100.sh
│   ├── distill_coco_lors_200.sh
│   ├── distill_coco_lors_500.sh
│   ├── distill_flickr_lors_100.sh
│   ├── distill_flickr_lors_200.sh
│   └── distill_flickr_lors_500.sh
└── src
    ├── epoch.py
    ├── model.py
    ├── networks.py
    ├── reparam_module.py
    ├── similarity_mining.py
    ├── utils.py
    └── vl_distill_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /buffer
2 | distill_utils/checkpoints
3 | distill_utils/data
4 | tmp
5 | logged_files
6 | wandb
7 |
8 | # just simply ignore these files
9 | __pycache__/
10 | .DS_Store
11 | .ipynb_checkpoints
12 | .vscode
13 | *.pyc
14 | *.npz
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Yue Xu
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LoRS: Low-Rank Similarity Mining
2 |
3 | ### [Paper](http://arxiv.org/abs/2406.03793)
4 |
5 | This repo contains the code of our ICML'24 work **LoRS**: **Lo**w-**R**ank **S**imilarity Mining for Multimodal Dataset Distillation. LoRS proposes to learn the similarity matrix jointly with the distilled images and texts. This simple, plug-and-play method yields significant performance gains. Please check our [paper](http://arxiv.org/abs/2406.03793) for more analysis. A toy sketch of the core idea follows the figure below.
6 |
7 |
8 | 
9 |
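In a nutshell, instead of treating the ground-truth image-text similarity as a fixed identity matrix, LoRS learns a similarity matrix parameterized in low-rank form jointly with the synthetic data and uses it as the soft target of the contrastive loss. The snippet below is only a minimal, illustrative sketch of that idea; the actual implementation lives in `src/similarity_mining.py` (`LowRankSimilarityGenerator`) and `src/networks.py` (`MultilabelContrastiveLoss`), and the names `U`, `V`, `alpha` here are purely illustrative.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyLowRankSimilarity(nn.Module):
    """Toy sketch: S = alpha * I + U V^T, learned jointly with the synthetic data."""
    def __init__(self, n, rank, alpha=1.0):
        super().__init__()
        self.U = nn.Parameter(torch.zeros(n, rank))
        self.V = nn.Parameter(torch.zeros(n, rank))
        self.alpha = alpha

    def forward(self):
        eye = torch.eye(self.U.shape[0], device=self.U.device)
        return self.alpha * eye + self.U @ self.V.t()

def soft_contrastive_loss(img_emb, txt_emb, sim_target, logit_scale=100.0):
    # cross-entropy against the (normalized) learned similarity instead of the identity
    logits = logit_scale * F.normalize(img_emb, dim=-1) @ F.normalize(txt_emb, dim=-1).t()
    tgt_i2t = F.softmax(sim_target, dim=-1)
    tgt_t2i = F.softmax(sim_target.t(), dim=-1)
    loss_i2t = (-tgt_i2t * F.log_softmax(logits, dim=-1)).sum(-1).mean()
    loss_t2i = (-tgt_t2i * F.log_softmax(logits.t(), dim=-1)).sum(-1).mean()
    return 0.5 * (loss_i2t + loss_t2i)
```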
10 |
11 | ## Getting Started
12 |
13 | **Requirements**: please see `requirements.txt`.
14 |
15 | **Pretrained model checkpoints**: you may manually download the checkpoints of BERT and NFNet (from timm) and put them here (a download sketch follows the tree below):
16 |
17 | ```
18 | distill_utils/checkpoints/
19 | ├── bert-base-uncased/
20 | │   ├── config.json
21 | │   ├── LICENSE.txt
22 | │   ├── model.onnx
23 | │   ├── pytorch_model.bin
24 | │   ├── vocab.txt
25 | │   └── ......
26 | └── nfnet_l0_ra2-45c6688d.pth
27 | ```
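A possible way to fetch these files is sketched below (assuming `transformers` and `timm` are installed; depending on your library versions the saved BERT weight may be `model.safetensors` instead of `pytorch_model.bin`, and the timm cache path may differ):

```
import os, glob, shutil
import timm
from transformers import BertModel, BertTokenizer

ckpt_dir = "distill_utils/checkpoints"
bert_dir = os.path.join(ckpt_dir, "bert-base-uncased")
os.makedirs(bert_dir, exist_ok=True)

# BERT: pull from the HuggingFace hub and save locally
BertModel.from_pretrained("bert-base-uncased").save_pretrained(bert_dir)
BertTokenizer.from_pretrained("bert-base-uncased").save_pretrained(bert_dir)

# NFNet-L0: let timm download the weight into the torch hub cache, then copy it over
timm.create_model("nfnet_l0", pretrained=True)
cached = glob.glob(os.path.expanduser("~/.cache/torch/hub/checkpoints/nfnet_l0_*.pth"))
if cached:
    shutil.copy(cached[0], os.path.join(ckpt_dir, "nfnet_l0_ra2-45c6688d.pth"))
```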
28 |
29 | **Datasets**: please download Flickr30K: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json)[[Images]](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset) and COCO: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json)[[Images]](https://cocodataset.org/#download) datasets, and put them here:
30 |
31 | ```
32 | ./distill_utils/data/
33 | ├── Flickr30k/
34 | │   ├── flickr30k-images/
35 | │   │   ├── 1234.jpg
36 | │   │   └── ......
37 | │   ├── results_20130124.token
38 | │   └── readme.txt
39 | └── COCO/
40 |     ├── train2014/
41 |     ├── val2014/
42 |     └── test2014/
43 | ```
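The annotation JSONs linked above can also be fetched programmatically; here is a sketch using `torchvision`'s `download_url`, the same helper the dataset classes use (the images themselves still have to be downloaded from the Kaggle/COCO links):

```
from torchvision.datasets.utils import download_url

ann_root = "./data/Flickr30k_ann/"
urls = [
    "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json",
    "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json",
    "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json",
    "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json",
    "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json",
    "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json",
]
for url in urls:
    download_url(url, ann_root)
```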
44 |
45 |
46 |
47 | **Training Expert Buffer**: *e.g.* run `sh sh/buffer_flickr.sh`. The expert training takes days; you could manually split `num_experts` and run multiple processes in parallel (see the sketch below).
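For example, a sketch of splitting the experts over several GPUs (the 4-GPU / 25-experts-per-process split is arbitrary; the flags map to `buffer.py`'s argument parser):

```
import os, subprocess

procs = []
for gpu in range(4):  # 4 processes x 25 experts = 100 experts in total
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    cmd = ["python", "buffer.py", "--dataset", "flickr",
           "--num_experts", "25", "--buffer_path", "./buffers"]
    procs.append(subprocess.Popen(cmd, env=env))
for p in procs:
    p.wait()
```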
48 |
49 | **Distill with LoRS**: *e.g.* run `sh sh/distill_flickr_lors_100.sh`. The distillation can run on a single RTX 3090/4090 thanks to [TESLA](https://github.com/justincui03/tesla); a sketch of loading the resulting distilled checkpoint is given below.
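The distilled data are saved as `distilled_{iteration}.pt` under `logged_files/<dataset>/<wandb-run-name>/`. A quick sketch of inspecting such a checkpoint (key names taken from `distill_tesla_lors.py`; the path below is a placeholder):

```
import torch

ckpt = torch.load("logged_files/flickr/<run-name>/distilled_2000.pt", map_location="cpu")
print(ckpt["image"].shape)           # synthetic images, e.g. [num_queries, 3, 224, 224]
print(ckpt["text"].shape)            # synthetic text embeddings, e.g. [num_queries, 768]
print(ckpt["similarity_mat"].shape)  # learned similarity matrix, [num_queries, num_queries]
print(ckpt["syn_lr_img"], ckpt["syn_lr_txt"])  # learned inner-loop learning rates
```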
50 |
51 |
52 |
53 | ## Citation
54 |
55 | If you find our work useful and inspiring, please cite our paper:
56 | ```
57 | @article{xu2024lors,
58 | title={Low-Rank Similarity Mining for Multimodal Dataset Distillation},
59 | author={Xu, Yue and Lin, Zhilin and Qiu, Yusong and Lu, Cewu and Li, Yong-Lu},
60 | journal={arXiv e-prints},
61 | pages={arXiv--2406},
62 | year={2024}
63 | }
64 | ```
65 |
66 |
67 |
68 |
69 | ## Acknowledgement
70 |
71 | We follow the setting and code of [VL-Distill](https://github.com/princetonvisualai/multimodal_dataset_distillation) and re-implement the algorithm with [TESLA](https://github.com/justincui03/tesla). We deeply appreciate their valuable contributions!
72 |
--------------------------------------------------------------------------------
/buffer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | from tqdm import tqdm, trange
6 | from src.vl_distill_utils import load_or_process_file
7 | from src.epoch import epoch, epoch_test, itm_eval
8 | import copy
9 | import wandb
10 | import warnings
11 | import datetime
12 | from data import get_dataset_flickr, textprocess, textprocess_train
13 | from src.networks import CLIPModel_full
14 | import numpy as np
15 |
16 | warnings.filterwarnings("ignore", category=DeprecationWarning)
17 |
18 | def main(args):
19 | no_aug_suffix = "_NoAug" if args.no_aug else ""  # define up front: also used for save_dir below
20 | if args.disabled_wandb:
21 | wandb.init(mode="disabled")
22 | else:
23 | wandb.init(project='LoRS-Buffer', config=args,
24 | name=f"{args.dataset}_{args.image_encoder}_{args.text_encoder}_{args.loss_type}{no_aug_suffix}")
25 |
26 | args.dsa = True if args.dsa == 'True' else False
27 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
28 | args.distributed = torch.cuda.device_count() > 1
29 |
30 |
31 | print('Hyper-parameters: \n', args.__dict__)
32 |
33 | save_dir = os.path.join(args.buffer_path, args.dataset)
34 | if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
35 | save_dir += "_NO_ZCA"
36 | save_dir = os.path.join(save_dir, args.image_encoder+"_"+args.text_encoder, args.loss_type+no_aug_suffix)
37 | if not os.path.exists(save_dir):
38 | os.makedirs(save_dir, exist_ok=True)
39 | ''' organize the datasets '''
40 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)
41 |
42 | train_sentences = train_dataset.get_all_captions()
43 | _ = load_or_process_file('text', textprocess, args, testloader)
44 | _ = load_or_process_file('train_text', textprocess_train, args, train_sentences)
45 |
46 | data = np.load(f'{args.dataset}_{args.text_encoder}_text_embed.npz')
47 | bert_test_embed_loaded = data['bert_test_embed']
48 | bert_test_embed = torch.from_numpy(bert_test_embed_loaded).cpu()
49 |
50 |
51 | img_trajectories = []
52 | txt_trajectories = []
53 |
54 | for it in range(0, args.num_experts):
55 |
56 | ''' Train synthetic data '''
57 |
58 | teacher_net = CLIPModel_full(args).to(args.device)
59 | img_teacher_net = teacher_net.image_encoder.to(args.device)
60 | txt_teacher_net = teacher_net.text_projection.to(args.device)
61 |
62 | if args.text_trainable:
63 | txt_teacher_net = teacher_net.text_encoder.to(args.device)
64 | if args.distributed:
65 | raise NotImplementedError()
66 | img_teacher_net = torch.nn.DataParallel(img_teacher_net)
67 | txt_teacher_net = torch.nn.DataParallel(txt_teacher_net)
68 |
69 | img_teacher_net.train()
70 | txt_teacher_net.train()
71 |
72 | teacher_optim = torch.optim.SGD([
73 | {'params': img_teacher_net.parameters(), 'lr': args.lr_teacher_img},
74 | {'params': txt_teacher_net.parameters(), 'lr': args.lr_teacher_txt},
75 | ], lr=0, momentum=args.mom, weight_decay=args.l2)
76 | teacher_optim.zero_grad()
77 | lr_schedule = [args.train_epochs // 2 + 1] if args.decay else []
78 | teacher_optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(
79 | teacher_optim, milestones=lr_schedule, gamma=0.1)
80 |
81 |
82 | img_timestamps = []
83 | txt_timestamps = []
84 |
85 | img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()])
86 | txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()])
87 |
88 |
89 | for e in trange(args.train_epochs):
90 | train_loss, train_acc = epoch(e, trainloader, teacher_net, teacher_optim, args)
91 |
92 | if (e+1) % args.eval_freq == 0:
93 |
94 | score_val_i2t, score_val_t2i = epoch_test(testloader, teacher_net, args.device, bert_test_embed)
95 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt)
96 |
97 | wandb.log({
98 | "Loss/train_loss": train_loss,
99 | "Loss/train_acc": train_acc,
100 | "Results/txt_r1": val_result['txt_r1'],
101 | "Results/txt_r5": val_result['txt_r5'],
102 | "Results/txt_r10": val_result['txt_r10'],
103 | # "txt_r_mean": val_result['txt_r_mean'],
104 | "Results/img_r1": val_result['img_r1'],
105 | "Results/img_r5": val_result['img_r5'],
106 | "Results/img_r10": val_result['img_r10'],
107 | # "img_r_mean": val_result['img_r_mean'],
108 | "Results/r_mean": val_result['r_mean'],
109 | })
110 |
111 | print("Itr: {} Epoch={} Train Acc={} | Img R@1={} R@5={} R@10={} | Txt R@1={} R@5={} R@10={} | R@Mean={}".format(
112 | it, e, train_acc,
113 | val_result['img_r1'], val_result['img_r5'], val_result['img_r10'],
114 | val_result['txt_r1'], val_result['txt_r5'], val_result['txt_r10'], val_result['r_mean']))
115 |
116 |
117 | img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()])
118 | txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()])
119 |
120 | teacher_optim_scheduler.step()
121 |
122 |
123 | if not args.skip_save:
124 | img_trajectories.append(img_timestamps)
125 | txt_trajectories.append(txt_timestamps)
126 | n = 0
127 | while os.path.exists(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))):
128 | n += 1
129 | print("Saving {}".format(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))))
130 | torch.save(img_trajectories, os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n)))
131 | print("Saving {}".format(os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n))))
132 | torch.save(txt_trajectories, os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n)))
133 |
134 | img_trajectories = []
135 | txt_trajectories = []
136 |
137 |
138 | def make_buffer_parser():
139 | parser = argparse.ArgumentParser(description='Parameter Processing')
140 | parser.add_argument('--dataset', type=str, default='flickr', choices=['flickr', 'coco'], help='dataset')
141 | parser.add_argument('--num_experts', type=int, default=100, help='training iterations')
142 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters')
143 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters')
144 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks')
145 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
146 | help='whether to use differentiable Siamese augmentation.')
147 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
148 | help='differentiable Siamese augmentation strategy')
149 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path')
150 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')
151 | parser.add_argument('--train_epochs', type=int, default=50)
152 | parser.add_argument('--zca', action='store_true')
153 | parser.add_argument('--decay', action='store_true')
154 | parser.add_argument('--mom', type=float, default=0, help='momentum')
155 | parser.add_argument('--l2', type=float, default=0, help='l2 regularization')
156 | parser.add_argument('--save_interval', type=int, default=10)
157 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
158 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')
159 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained')
160 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained')
161 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable')
162 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable')
163 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
164 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')
165 | parser.add_argument('--image_root', type=str, default='distill_utils/data/Flickr30k/', help='location of image root')
166 | parser.add_argument('--ann_root', type=str, default='./data/Flickr30k_ann/', help='location of ann root')
167 | parser.add_argument('--image_size', type=int, default=224, help='image_size')
168 | parser.add_argument('--k_test', type=int, default=128, help='k_test')
169 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy')
170 | parser.add_argument('--image_encoder', type=str, default='resnet50', help='image encoder')
171 | #, choices=['nfnet', 'resnet18_gn', 'vit_tiny', 'nf_resnet50', 'nf_regnet'])
172 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert','gpt1'], help='text encoder')
173 | parser.add_argument('--margin', default=0.2, type=float,
174 | help='Rank loss margin.')
175 | parser.add_argument('--measure', default='cosine',
176 | help='Similarity measure used (cosine|order)')
177 | parser.add_argument('--max_violation', action='store_true',
178 | help='Use max instead of sum in the rank loss.')
179 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None')
180 | parser.add_argument('--grounding', type=bool, default=False, help='None')
181 |
182 | parser.add_argument('--distill', type=bool, default=False, help='whether distill')
183 | parser.add_argument('--loss_type', type=str, default="InfoNCE")
184 |
185 | parser.add_argument('--eval_freq', type=int, default=5, help='eval_freq')
186 | parser.add_argument('--no_aug', action='store_true', help='no_aug')
187 | parser.add_argument('--skip_save', action='store_true', help='skip save buffer')
188 | parser.add_argument('--disabled_wandb', type=bool, default=False)
189 | return parser
190 |
191 |
192 | if __name__ == '__main__':
193 | parser = make_buffer_parser()
194 | args = parser.parse_args()
195 |
196 | args.image_root = {
197 | 'flickr': "distill_utils/data/Flickr30k/",
198 | 'coco': "distill_utils/data/COCO/",
199 | }[args.dataset]
200 |
201 | main(args)
202 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from data.randaugment import RandomAugment
3 | from torchvision.transforms.functional import InterpolationMode
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torch.utils.data import Dataset
7 | from torchvision.datasets.utils import download_url
8 | import json
9 | from PIL import Image
10 | import os
11 | from torchvision import transforms as T
12 | from src.networks import CLIPModel_full
13 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval, pre_caption
14 | from data.coco_dataset import coco_train, coco_caption_eval, coco_retrieval_eval
15 | import numpy as np
16 | from tqdm import tqdm
17 | @torch.no_grad()
18 | def textprocess(args, testloader):
19 | net = CLIPModel_full(args).to('cuda')
20 | net.eval()
21 | texts = testloader.dataset.text
22 | if args.dataset in ['flickr', 'coco']:
23 | if args.dataset == 'flickr':
24 | bert_test_embed = net.text_encoder(texts)
25 | elif args.dataset == 'coco':
26 | bert_test_embed = torch.cat((net.text_encoder(texts[:10000]), net.text_encoder(texts[10000:20000]), net.text_encoder(texts[20000:])), dim=0)
27 |
28 | bert_test_embed_np = bert_test_embed.cpu().numpy()
29 | np.savez(f'{args.dataset}_{args.text_encoder}_text_embed.npz', bert_test_embed=bert_test_embed_np)
30 | else:
31 | raise NotImplementedError
32 | return
33 |
34 | @torch.no_grad()
35 | def textprocess_train(args, texts):
36 | net = CLIPModel_full(args).to('cuda')
37 | net.eval()
38 | chunk_size = 2000
39 | chunks = []
40 | for i in tqdm(range(0, len(texts), chunk_size)):
41 | chunk = net.text_encoder(texts[i:i + chunk_size]).cpu()
42 | chunks.append(chunk)
43 | del chunk
44 | torch.cuda.empty_cache() # free up memory
45 | bert_test_embed = torch.cat(chunks, dim=0)
46 |
47 | print('bert_test_embed.shape: ', bert_test_embed.shape)
48 | bert_test_embed_np = bert_test_embed.numpy()
49 | if args.dataset in ['flickr', 'coco']:
50 | np.savez(f'{args.dataset}_{args.text_encoder}_train_text_embed.npz', bert_test_embed=bert_test_embed_np)
51 | else:
52 | raise NotImplementedError
53 | return
54 |
55 |
56 | def create_dataset(args, min_scale=0.5):
57 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
58 | transform_train = transforms.Compose([
59 | transforms.RandomResizedCrop(args.image_size,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
60 | transforms.RandomHorizontalFlip(),
61 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
62 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
63 | transforms.ToTensor(),
64 | normalize,
65 | ])
66 | transform_test = transforms.Compose([
67 | transforms.Resize((args.image_size, args.image_size),interpolation=InterpolationMode.BICUBIC),
68 | transforms.ToTensor(),
69 | normalize,
70 | ])
71 |
72 | if args.no_aug:
73 | transform_train = transform_test # no augmentation
74 |
75 | if args.dataset=='flickr':
76 | train_dataset = flickr30k_train(transform_train, args.image_root, args.ann_root)
77 | val_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val')
78 | test_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test')
79 | return train_dataset, val_dataset, test_dataset
80 |
81 | elif args.dataset=='coco':
82 | train_dataset = coco_train(transform_train, args.image_root, args.ann_root)
83 | val_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val')
84 | test_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test')
85 | return train_dataset, val_dataset, test_dataset
86 |
87 | else:
88 | raise NotImplementedError
89 |
90 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
91 | samplers = []
92 | for dataset,shuffle in zip(datasets,shuffles):
93 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
94 | samplers.append(sampler)
95 | return samplers
96 |
97 |
98 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
99 | loaders = []
100 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
101 | if is_train:
102 | shuffle = (sampler is None)
103 | drop_last = True
104 | else:
105 | shuffle = False
106 | drop_last = False
107 | loader = DataLoader(
108 | dataset,
109 | batch_size=bs,
110 | num_workers=n_worker,
111 | pin_memory=True,
112 | sampler=sampler,
113 | shuffle=shuffle,
114 | collate_fn=collate_fn,
115 | drop_last=drop_last,
116 | )
117 | loaders.append(loader)
118 | return loaders
119 |
120 |
121 | def get_dataset_flickr(args):
122 | print("Creating retrieval dataset")
123 | train_dataset, val_dataset, test_dataset = create_dataset(args)
124 |
125 | samplers = [None, None, None]
126 | train_shuffle = True
127 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
128 | batch_size=[args.batch_size_train]+[args.batch_size_test]*2,
129 | num_workers=[4,4,4],
130 | is_trains=[train_shuffle, False, False],
131 | collate_fns=[None,None,None])
132 |
133 | return train_loader, test_loader, train_dataset, test_dataset
134 |
135 |
--------------------------------------------------------------------------------
/data/cifar_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from typing import Any, Tuple
4 | from torch.utils.data import Dataset
5 | from PIL import Image
6 | import numpy as np
7 | from torchvision.datasets import CIFAR10
8 | from collections import defaultdict
9 |
10 |
11 | CLASSES = [
12 | "airplane",
13 | "automobile",
14 | "bird",
15 | "cat",
16 | "deer",
17 | "dog",
18 | "frog",
19 | "horse",
20 | "ship",
21 | "truck",
22 | ]
23 |
24 | PROMPTS_1 = ["This is a {}"]
25 |
26 | PROMPTS_5 = [
27 | "a photo of a {}",
28 | "a blurry image of a {}",
29 | "a photo of the {}",
30 | "a pixelated photo of a {}",
31 | "a picture of a {}",
32 | ]
33 |
34 | PROMPTS = [
35 | 'a photo of a {}',
36 | 'a blurry photo of a {}',
37 | 'a low contrast photo of a {}',
38 | 'a high contrast photo of a {}',
39 | 'a bad photo of a {}',
40 | 'a good photo of a {}',
41 | 'a photo of a small {}',
42 | 'a photo of a big {}',
43 | 'a photo of the {}',
44 | 'a blurry photo of the {}',
45 | 'a low contrast photo of the {}',
46 | 'a high contrast photo of the {}',
47 | 'a bad photo of the {}',
48 | 'a good photo of the {}',
49 | 'a photo of the small {}',
50 | 'a photo of the big {}',
51 | ]
52 |
53 |
54 | class cifar10_train(CIFAR10):
55 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, num_prompts=1):
56 | super(cifar10_train, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
57 | if num_prompts == 1:
58 | self.prompts = PROMPTS_1
59 | elif num_prompts == 5:
60 | self.prompts = PROMPTS_5
61 | else:
62 | self.prompts = PROMPTS
63 | self.captions = [prompt.format(cls) for cls in CLASSES for prompt in self.prompts]
64 | self.captions_to_label = {cap: i for i, cap in enumerate(self.captions)}
65 | self.annotations = []
66 | for i in range(len(self.data)):
67 | cls_name = CLASSES[self.targets[i]]
68 | for prompt in self.prompts:
69 | caption = prompt.format(cls_name)
70 | self.annotations.append({"img_id": i, "caption_id": self.captions_to_label[caption]})
71 | if num_prompts == 1:
72 | self.annotations = self.annotations * 5
73 |
74 | def __len__(self):
75 | return len(self.annotations)
76 |
77 | def __getitem__(self, index):
78 | ann = self.annotations[index]
79 | img_id = ann['img_id']
80 | img = self.transform(self.data[img_id])
81 | caption = self.captions[ann['caption_id']]
82 | return img, caption, img_id
83 |
84 | def fetch_distill_images(self, ipc):
85 | """
86 | Randomly fetch `x` number of images from each class using numpy and return as a tensor.
87 | """
88 | class_indices = defaultdict(list)
89 |
90 | for idx, label in enumerate(self.targets):
91 | class_indices[label].append(idx)
92 |
93 | # Randomly sample x indices for each class using numpy
94 | sampled_indices = [np.random.choice(indices, ipc, replace=False) for indices in class_indices.values()]
95 | sampled_indices = [idx for class_indices in sampled_indices for idx in class_indices]
96 |
97 | # Fetch images and labels using the selected indices
98 | images = torch.stack([self.transform(self.data[i]) for i in sampled_indices])
99 | labels = [self.targets[i] for i in sampled_indices]
100 |
101 | captions = []
102 | for label in labels:
103 | cls_name = CLASSES[label]
104 | prompt = np.random.choice(self.prompts)
105 | random_caption = prompt.format(cls_name)
106 | captions.append(random_caption)
107 |
108 | return images, captions
109 |
110 | def get_all_captions(self):
111 | return self.captions
112 |
113 |
114 | class cifar10_retrieval_eval(cifar10_train):
115 | def __init__(self, root, train=False, transform=None, target_transform=None, download=False, num_prompts=1):
116 | """
117 | image_root (string): Root directory of images (e.g. coco/images/)
118 | ann_root (string): directory to store the annotation file
119 | split (string): val or test
120 | """
121 | super(cifar10_retrieval_eval, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download, num_prompts=num_prompts)
122 | self.text = self.captions
123 | self.txt2img = {}
124 | self.img2txt = defaultdict(list)
125 |
126 | for ann in self.annotations:
127 | img_id = ann['img_id']
128 | caption_id = ann['caption_id']
129 | self.img2txt[img_id].append(caption_id)
130 | self.txt2img[caption_id] = img_id
131 |
132 | def __len__(self):
133 | return len(self.data)
134 |
135 | def __getitem__(self, index):
136 | image = self.transform(self.data[index])
137 | return image, index
138 |
--------------------------------------------------------------------------------
/data/coco_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from torch.utils.data import Dataset
4 | from torchvision.datasets.utils import download_url
5 | from PIL import Image
6 | import re
7 |
8 | def pre_caption(caption,max_words=50):
9 | caption = re.sub(
10 | r"([.!\"()*#:;~])",
11 | ' ',
12 | caption.lower(),
13 | )
14 | caption = re.sub(
15 | r"\s{2,}",
16 | ' ',
17 | caption,
18 | )
19 | caption = caption.rstrip('\n')
20 | caption = caption.strip(' ')
21 |
22 | #truncate caption
23 | caption_words = caption.split(' ')
24 | if len(caption_words)>max_words:
25 | caption = ' '.join(caption_words[:max_words])
26 |
27 | return caption
28 |
29 | class coco_train(Dataset):
30 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
31 | '''
32 | image_root (string): Root directory of images (e.g. coco/images/)
33 | ann_root (string): directory to store the annotation file
34 | '''
35 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
36 | filename = 'coco_karpathy_train.json'
37 |
38 | download_url(url,ann_root)
39 |
40 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
41 | self.transform = transform
42 | self.image_root = image_root
43 | self.max_words = max_words
44 | self.prompt = prompt
45 |
46 | self.img_ids = {}
47 | n = 0
48 | for ann in self.annotation:
49 | img_id = ann['image_id']
50 | if img_id not in self.img_ids.keys():
51 | self.img_ids[img_id] = n
52 | n += 1
53 |
54 | def __len__(self):
55 | return len(self.annotation)
56 |
57 | def __getitem__(self, index):
58 |
59 | ann = self.annotation[index]
60 |
61 | image_path = os.path.join(self.image_root,ann['image'])
62 | image = Image.open(image_path).convert('RGB')
63 | image = self.transform(image)
64 |
65 | caption = self.prompt+pre_caption(ann['caption'], self.max_words)
66 |
67 | return image, caption, self.img_ids[ann['image_id']]
68 |
69 | def get_all_captions(self):
70 | captions = []
71 | for ann in self.annotation:
72 | caption = self.prompt + pre_caption(ann['caption'], self.max_words)
73 | captions.append(caption)
74 | return captions
75 |
76 |
77 | class coco_caption_eval(Dataset):
78 | def __init__(self, transform, image_root, ann_root, split):
79 | '''
80 | image_root (string): Root directory of images (e.g. coco/images/)
81 | ann_root (string): directory to store the annotation file
82 | split (string): val or test
83 | '''
84 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
85 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
86 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
87 |
88 | download_url(urls[split],ann_root)
89 |
90 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
91 | self.transform = transform
92 | self.image_root = image_root
93 |
94 | def __len__(self):
95 | return len(self.annotation)
96 |
97 | def __getitem__(self, index):
98 |
99 | ann = self.annotation[index]
100 |
101 | image_path = os.path.join(self.image_root,ann['image'])
102 | image = Image.open(image_path).convert('RGB')
103 | image = self.transform(image)
104 |
105 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
106 |
107 | return image, int(img_id)
108 |
109 |
110 | class coco_retrieval_eval(Dataset):
111 | def __init__(self, transform, image_root, ann_root, split, max_words=30):
112 | '''
113 | image_root (string): Root directory of images (e.g. coco/images/)
114 | ann_root (string): directory to store the annotation file
115 | split (string): val or test
116 | '''
117 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
118 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
119 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
120 |
121 | download_url(urls[split],ann_root)
122 |
123 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
124 | self.transform = transform
125 | self.image_root = image_root
126 |
127 | self.text = []
128 | self.image = []
129 | self.txt2img = {}
130 | self.img2txt = {}
131 |
132 | txt_id = 0
133 | for img_id, ann in enumerate(self.annotation):
134 | self.image.append(ann['image'])
135 | self.img2txt[img_id] = []
136 | for i, caption in enumerate(ann['caption']):
137 | self.text.append(pre_caption(caption,max_words))
138 | self.img2txt[img_id].append(txt_id)
139 | self.txt2img[txt_id] = img_id
140 | txt_id += 1
141 |
142 | def __len__(self):
143 | return len(self.annotation)
144 |
145 | def __getitem__(self, index):
146 |
147 | image_path = os.path.join(self.image_root, self.annotation[index]['image'])
148 | image = Image.open(image_path).convert('RGB')
149 | image = self.transform(image)
150 |
151 | return image, index
--------------------------------------------------------------------------------
/data/flickr30k_dataset.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | import os
3 | import torch
4 | from tqdm import tqdm
5 | import numpy as np
6 | import yaml
7 | from transformers import BertTokenizer, BertModel
8 | from torchvision import transforms as T
9 | from torchvision.transforms.functional import InterpolationMode
10 | from torch.utils.data import Dataset
11 | from torchvision.datasets.utils import download_url
12 | import argparse
13 | import re
14 | import json
15 | from PIL import Image
16 |
17 | def pre_caption(caption,max_words=50):
18 | caption = re.sub(
19 | r"([.!\"()*#:;~])",
20 | ' ',
21 | caption.lower(),
22 | )
23 | caption = re.sub(
24 | r"\s{2,}",
25 | ' ',
26 | caption,
27 | )
28 | caption = caption.rstrip('\n')
29 | caption = caption.strip(' ')
30 |
31 | #truncate caption
32 | caption_words = caption.split(' ')
33 | if len(caption_words)>max_words:
34 | caption = ' '.join(caption_words[:max_words])
35 |
36 | return caption
37 |
38 |
39 | class flickr30k_train(Dataset):
40 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
41 | '''
42 | image_root (string): Root directory of images (e.g. flickr30k/)
43 | ann_root (string): directory to store the annotation file
44 | '''
45 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
46 | filename = 'flickr30k_train.json'
47 |
48 | ####################
49 | # download_url(url,ann_root)
50 | ####################
51 |
52 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
53 | self.transform = transform
54 | self.image_root = image_root
55 | self.max_words = max_words
56 | self.prompt = prompt
57 |
58 | self.img_ids = {}
59 | n = 0
60 | for ann in self.annotation:
61 | img_id = ann['image_id']
62 | if img_id not in self.img_ids.keys():
63 | self.img_ids[img_id] = n
64 | n += 1
65 |
66 | def __len__(self):
67 | return len(self.annotation)
68 |
69 | @lru_cache(maxsize=100)
70 | def read_image(self, image_path):
71 | image = Image.open(image_path).convert('RGB')
72 | image = self.transform(image)
73 | return image
74 |
75 | def __getitem__(self, index):
76 |
77 | ann = self.annotation[index]
78 |
79 | image_path = os.path.join(self.image_root,ann['image'])
80 | image = self.read_image(image_path)
81 | # image = Image.open(image_path).convert('RGB')
82 | # image = self.transform(image)
83 |
84 | caption = self.prompt+pre_caption(ann['caption'], self.max_words)
85 |
86 | return image, caption, self.img_ids[ann['image_id']]
87 |
88 | def get_all_captions(self):
89 | captions = []
90 | for ann in self.annotation:
91 | caption = self.prompt + pre_caption(ann['caption'], self.max_words)
92 | captions.append(caption)
93 | return captions
94 |
95 |
96 |
97 | class flickr30k_retrieval_eval(Dataset):
98 | def __init__(self, transform, image_root, ann_root, split, max_words=30):
99 | '''
100 | image_root (string): Root directory of images (e.g. flickr30k/)
101 | ann_root (string): directory to store the annotation file
102 | split (string): val or test
103 | '''
104 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
105 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
106 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
107 |
108 | #######################
109 | # download_url(urls[split],ann_root)
110 | #######################
111 |
112 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
113 | self.transform = transform
114 | self.image_root = image_root
115 | self.max_words = max_words
116 |
117 | self.text = []
118 | self.image = []
119 | self.txt2img = {}
120 | self.img2txt = {}
121 |
122 | txt_id = 0
123 | for img_id, ann in enumerate(self.annotation):
124 | self.image.append(ann['image'])
125 | self.img2txt[img_id] = []
126 | for i, caption in enumerate(ann['caption']):
127 | self.text.append(pre_caption(caption,max_words))
128 | self.img2txt[img_id].append(txt_id)
129 | self.txt2img[txt_id] = img_id
130 | txt_id += 1
131 |
132 | def __len__(self):
133 | return len(self.annotation)
134 |
135 | def __getitem__(self, index):
136 | image_path = os.path.join(self.image_root, self.annotation[index]['image'])
137 | image = Image.open(image_path).convert('RGB')
138 | image = self.transform(image)
139 |
140 | return image, index
141 |
142 |
143 |
--------------------------------------------------------------------------------
/data/randaugment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | ## aug functions
6 | def identity_func(img):
7 | return img
8 |
9 |
10 | def autocontrast_func(img, cutoff=0):
11 | '''
12 | same output as PIL.ImageOps.autocontrast
13 | '''
14 | n_bins = 256
15 |
16 | def tune_channel(ch):
17 | n = ch.size
18 | cut = cutoff * n // 100
19 | if cut == 0:
20 | high, low = ch.max(), ch.min()
21 | else:
22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23 | low = np.argwhere(np.cumsum(hist) > cut)
24 | low = 0 if low.shape[0] == 0 else low[0]
25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27 | if high <= low:
28 | table = np.arange(n_bins)
29 | else:
30 | ##### modified by lzl ###############
31 | scale = (n_bins - 1) / (high - low)
32 | low = np.int8(low)
33 | offset = -low * scale
34 | ################################
35 | table = np.arange(n_bins) * scale + offset
36 | table[table < 0] = 0
37 | table[table > n_bins - 1] = n_bins - 1
38 | table = table.clip(0, 255).astype(np.uint8)
39 | return table[ch]
40 |
41 | channels = [tune_channel(ch) for ch in cv2.split(img)]
42 | out = cv2.merge(channels)
43 | return out
44 |
45 |
46 | def equalize_func(img):
47 | '''
48 | same output as PIL.ImageOps.equalize
49 | PIL's implementation is different from cv2.equalize
50 | '''
51 | n_bins = 256
52 |
53 | def tune_channel(ch):
54 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
55 | non_zero_hist = hist[hist != 0].reshape(-1)
56 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
57 | if step == 0: return ch
58 | n = np.empty_like(hist)
59 | n[0] = step // 2
60 | n[1:] = hist[:-1]
61 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
62 | return table[ch]
63 |
64 | channels = [tune_channel(ch) for ch in cv2.split(img)]
65 | out = cv2.merge(channels)
66 | return out
67 |
68 |
69 | def rotate_func(img, degree, fill=(0, 0, 0)):
70 | '''
71 | like PIL, rotate by degree, not radians
72 | '''
73 | H, W = img.shape[0], img.shape[1]
74 | center = W / 2, H / 2
75 | M = cv2.getRotationMatrix2D(center, degree, 1)
76 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
77 | return out
78 |
79 |
80 | def solarize_func(img, thresh=128):
81 | '''
82 | same output as PIL.ImageOps.solarize
83 | '''
84 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
85 | table = table.clip(0, 255).astype(np.uint8)
86 | out = table[img]
87 | return out
88 |
89 |
90 | def color_func(img, factor):
91 | '''
92 | same output as PIL.ImageEnhance.Color
93 | '''
94 | ## implementation according to PIL definition, quite slow
95 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
96 | # out = blend(degenerate, img, factor)
97 | # M = (
98 | # np.eye(3) * factor
99 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
100 | # )[np.newaxis, np.newaxis, :]
101 | M = (
102 | np.float32([
103 | [0.886, -0.114, -0.114],
104 | [-0.587, 0.413, -0.587],
105 | [-0.299, -0.299, 0.701]]) * factor
106 | + np.float32([[0.114], [0.587], [0.299]])
107 | )
108 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
109 | return out
110 |
111 |
112 | def contrast_func(img, factor):
113 | """
114 | same output as PIL.ImageEnhance.Contrast
115 | """
116 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
117 | table = np.array([(
118 | el - mean) * factor + mean
119 | for el in range(256)
120 | ]).clip(0, 255).astype(np.uint8)
121 | out = table[img]
122 | return out
123 |
124 |
125 | def brightness_func(img, factor):
126 | '''
127 | same output as PIL.ImageEnhance.Brightness
128 | '''
129 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
130 | out = table[img]
131 | return out
132 |
133 |
134 | def sharpness_func(img, factor):
135 | '''
136 | The differences between this result and PIL are all on the 4 boundaries; the center
137 | areas are the same
138 | '''
139 | kernel = np.ones((3, 3), dtype=np.float32)
140 | kernel[1][1] = 5
141 | kernel /= 13
142 | degenerate = cv2.filter2D(img, -1, kernel)
143 | if factor == 0.0:
144 | out = degenerate
145 | elif factor == 1.0:
146 | out = img
147 | else:
148 | out = img.astype(np.float32)
149 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
150 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
151 | out = out.astype(np.uint8)
152 | return out
153 |
154 |
155 | def shear_x_func(img, factor, fill=(0, 0, 0)):
156 | H, W = img.shape[0], img.shape[1]
157 | M = np.float32([[1, factor, 0], [0, 1, 0]])
158 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
159 | return out
160 |
161 |
162 | def translate_x_func(img, offset, fill=(0, 0, 0)):
163 | '''
164 | same output as PIL.Image.transform
165 | '''
166 | H, W = img.shape[0], img.shape[1]
167 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
168 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
169 | return out
170 |
171 |
172 | def translate_y_func(img, offset, fill=(0, 0, 0)):
173 | '''
174 | same output as PIL.Image.transform
175 | '''
176 | H, W = img.shape[0], img.shape[1]
177 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
178 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
179 | return out
180 |
181 |
182 | def posterize_func(img, bits):
183 | '''
184 | same output as PIL.ImageOps.posterize
185 | '''
186 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
187 | return out
188 |
189 |
190 | def shear_y_func(img, factor, fill=(0, 0, 0)):
191 | H, W = img.shape[0], img.shape[1]
192 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
193 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
194 | return out
195 |
196 |
197 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
198 | replace = np.array(replace, dtype=np.uint8)
199 | H, W = img.shape[0], img.shape[1]
200 | rh, rw = np.random.random(2)
201 | pad_size = pad_size // 2
202 | ch, cw = int(rh * H), int(rw * W)
203 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
204 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
205 | out = img.copy()
206 | out[x1:x2, y1:y2, :] = replace
207 | return out
208 |
209 |
210 | ### level to args
211 | def enhance_level_to_args(MAX_LEVEL):
212 | def level_to_args(level):
213 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
214 | return level_to_args
215 |
216 |
217 | def shear_level_to_args(MAX_LEVEL, replace_value):
218 | def level_to_args(level):
219 | level = (level / MAX_LEVEL) * 0.3
220 | if np.random.random() > 0.5: level = -level
221 | return (level, replace_value)
222 |
223 | return level_to_args
224 |
225 |
226 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
227 | def level_to_args(level):
228 | level = (level / MAX_LEVEL) * float(translate_const)
229 | if np.random.random() > 0.5: level = -level
230 | return (level, replace_value)
231 |
232 | return level_to_args
233 |
234 |
235 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
236 | def level_to_args(level):
237 | level = int((level / MAX_LEVEL) * cutout_const)
238 | return (level, replace_value)
239 |
240 | return level_to_args
241 |
242 |
243 | def solarize_level_to_args(MAX_LEVEL):
244 | def level_to_args(level):
245 | level = int((level / MAX_LEVEL) * 256)
246 | return (level, )
247 | return level_to_args
248 |
249 |
250 | def none_level_to_args(level):
251 | return ()
252 |
253 |
254 | def posterize_level_to_args(MAX_LEVEL):
255 | def level_to_args(level):
256 | level = int((level / MAX_LEVEL) * 4)
257 | return (level, )
258 | return level_to_args
259 |
260 |
261 | def rotate_level_to_args(MAX_LEVEL, replace_value):
262 | def level_to_args(level):
263 | level = (level / MAX_LEVEL) * 30
264 | if np.random.random() < 0.5:
265 | level = -level
266 | return (level, replace_value)
267 |
268 | return level_to_args
269 |
270 |
271 | func_dict = {
272 | 'Identity': identity_func,
273 | 'AutoContrast': autocontrast_func,
274 | 'Equalize': equalize_func,
275 | 'Rotate': rotate_func,
276 | 'Solarize': solarize_func,
277 | 'Color': color_func,
278 | 'Contrast': contrast_func,
279 | 'Brightness': brightness_func,
280 | 'Sharpness': sharpness_func,
281 | 'ShearX': shear_x_func,
282 | 'TranslateX': translate_x_func,
283 | 'TranslateY': translate_y_func,
284 | 'Posterize': posterize_func,
285 | 'ShearY': shear_y_func,
286 | }
287 |
288 | translate_const = 10
289 | MAX_LEVEL = 10
290 | replace_value = (128, 128, 128)
291 | arg_dict = {
292 | 'Identity': none_level_to_args,
293 | 'AutoContrast': none_level_to_args,
294 | 'Equalize': none_level_to_args,
295 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
296 | 'Solarize': solarize_level_to_args(MAX_LEVEL),
297 | 'Color': enhance_level_to_args(MAX_LEVEL),
298 | 'Contrast': enhance_level_to_args(MAX_LEVEL),
299 | 'Brightness': enhance_level_to_args(MAX_LEVEL),
300 | 'Sharpness': enhance_level_to_args(MAX_LEVEL),
301 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
302 | 'TranslateX': translate_level_to_args(
303 | translate_const, MAX_LEVEL, replace_value
304 | ),
305 | 'TranslateY': translate_level_to_args(
306 | translate_const, MAX_LEVEL, replace_value
307 | ),
308 | 'Posterize': posterize_level_to_args(MAX_LEVEL),
309 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
310 | }
311 |
312 |
313 | class RandomAugment(object):
314 |
315 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
316 | self.N = N
317 | self.M = M
318 | self.isPIL = isPIL
319 | if augs:
320 | self.augs = augs
321 | else:
322 | self.augs = list(arg_dict.keys())
323 |
324 | def get_random_ops(self):
325 | sampled_ops = np.random.choice(self.augs, self.N)
326 | return [(op, 0.5, self.M) for op in sampled_ops]
327 |
328 | def __call__(self, img):
329 | if self.isPIL:
330 | img = np.array(img)
331 | ops = self.get_random_ops()
332 | for name, prob, level in ops:
333 | if np.random.random() > prob:
334 | continue
335 | args = arg_dict[name](level)
336 | img = func_dict[name](img, *args)
337 | return img
338 |
339 |
340 | if __name__ == '__main__':
341 | a = RandomAugment()
342 | img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # the augment ops expect uint8 images
343 | a(img)
--------------------------------------------------------------------------------
/distill_tesla_lors.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import glob
3 | import os
4 | import argparse
5 | import re
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from tqdm import tqdm
11 | import torchvision.utils
12 | import math
13 |
14 | import wandb
15 | import copy
16 | import datetime
17 |
18 | from data import get_dataset_flickr, textprocess, textprocess_train
19 | from src.epoch import evaluate_synset_with_similarity
20 | from src.networks import CLIPModel_full, MultilabelContrastiveLoss
21 | from src.reparam_module import ReparamModule
22 | from src.utils import ParamDiffAug, get_time
23 | from src.similarity_mining import LowRankSimilarityGenerator, FullSimilarityGenerator
24 | from src.vl_distill_utils import shuffle_files, nearest_neighbor, get_images_texts, load_or_process_file
25 |
26 |
27 |
28 | def make_timestamp(prefix: str="", suffix: str="") -> str:
29 | tmstamp = '{:%m%d_%H%M%S}'.format(datetime.datetime.now())
30 | return prefix + tmstamp + suffix
31 |
32 |
33 |
34 |
35 | def main(args):
36 | ''' organize the real train dataset '''
37 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)
38 |
39 | train_sentences = train_dataset.get_all_captions()
40 |
41 | data = load_or_process_file('text', textprocess, args, testloader)
42 | train_caption = load_or_process_file('train_text', textprocess_train, args, train_sentences)
43 |
44 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu()
45 | print("The shape of bert_test_embed: {}".format(bert_test_embed.shape))
46 | train_caption_embed = torch.from_numpy(train_caption['bert_test_embed']).cpu()
47 | print("The shape of train_caption_embed: {}".format(train_caption_embed.shape))
48 |
49 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / args.temperature))
50 |
51 |
52 | if args.zca and args.texture:
53 | raise AssertionError("Cannot use zca and texture together")
54 |
55 | if args.texture and args.pix_init == "real":
56 | print("WARNING: Using texture with real initialization will take a very long time to smooth out the boundaries between images.")
57 |
58 | if args.max_experts is not None and args.max_files is not None:
59 | args.total_experts = args.max_experts * args.max_files
60 |
61 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled))
62 |
63 | if args.dsa == True:
64 | print("unfortunately, this repo did not support DSA")
65 | raise AssertionError("DSA is not supported in this repo")
66 | args.dsa = True if args.dsa == 'True' else False
67 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
68 |
69 | if args.eval_it>0:
70 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
71 | else:
72 | eval_it_pool = []
73 |
74 | if args.dsa:
75 | args.dc_aug_param = None
76 |
77 | if args.disabled_wandb:
78 | wandb.init(mode = 'disabled')
79 | else:
80 | wandb.init(project='LoRS', config=args, name=args.name+"_"+make_timestamp())
81 |
82 | args.dsa_param = ParamDiffAug()
83 | zca_trans = args.zca_trans if args.zca else None
84 | args.zca_trans = zca_trans
85 | args.distributed = torch.cuda.device_count() > 1
86 |
87 | syn_lr_img = torch.tensor(args.lr_teacher_img).to(args.device).requires_grad_(True)
88 | syn_lr_txt = torch.tensor(args.lr_teacher_txt).to(args.device).requires_grad_(True)
89 |
90 | ''' initialize the synthetic data '''
91 | image_syn, text_syn = get_images_texts(args.num_queries, train_dataset, args)
92 |
93 | if args.sim_type == 'lowrank':
94 | sim_generator = LowRankSimilarityGenerator(
95 | args.num_queries, args.sim_rank, args.alpha)
96 | elif args.sim_type == 'full':
97 | sim_generator = FullSimilarityGenerator(args.num_queries)
98 | else:
99 | raise AssertionError("Invalid similarity type: {}".format(args.sim_type))
100 | sim_generator = sim_generator.to(args.device)
101 |
102 |
103 | contrastive_criterion = MultilabelContrastiveLoss(args.loss_type)
104 |
105 | if args.pix_init == 'noise':
106 | mean = torch.tensor([-0.0626, -0.0221, 0.0680])
107 | std = torch.tensor([1.0451, 1.0752, 1.0539])
108 | image_syn = torch.randn([args.num_queries, 3, 224, 224])
109 | for c in range(3):
110 | image_syn[:, c] = image_syn[:, c] * std[c] + mean[c]
111 | print('Initialized synthetic image from random noise')
112 |
113 | if args.txt_init == 'noise':
114 | text_syn = torch.normal(mean=-0.0094, std=0.5253, size=(args.num_queries, 768))
115 | print('Initialized synthetic text from random noise')
116 |
117 | start_it = 0
118 | if args.resume_from is not None:
119 | ckpt = torch.load(args.resume_from)
120 | image_syn = ckpt["image"].to(args.device).requires_grad_(True)
121 | text_syn = ckpt["text"].to(args.device).requires_grad_(True)
122 | if "similarity_params" in ckpt:
123 | sim_generator.load_params([x.to(args.device) for x in ckpt["similarity_params"]])
124 | else:
125 | print("WARNING: no similarity matrix in the checkpoint")
126 | syn_lr_img = ckpt["syn_lr_img"].to(args.device).requires_grad_(True)
127 | syn_lr_txt = ckpt["syn_lr_txt"].to(args.device).requires_grad_(True)
128 |
129 | re_res = re.findall(r"distilled_(\d+).pt", args.resume_from)
130 | if len(re_res) == 1:
131 | start_it = int(re_res[0])
132 | else:
133 | start_it = 0
134 |
135 |
136 | ''' training '''
137 | image_syn = image_syn.detach().to(args.device).requires_grad_(True)
138 | text_syn = text_syn.detach().to(args.device).requires_grad_(True)
139 |
140 | optimizer = torch.optim.SGD([
141 | {'params': [image_syn], 'lr': args.lr_img, "momentum": args.momentum_syn},
142 | {'params': [text_syn], 'lr': args.lr_txt, "momentum": args.momentum_syn},
143 | {'params': [syn_lr_img, syn_lr_txt], 'lr': args.lr_lr, "momentum": args.momentum_lr},
144 | {'params': sim_generator.get_indexed_parameters(), 'lr': args.lr_sim, "momentum": args.momentum_sim},
145 | ], lr=0)
146 | optimizer.zero_grad()
147 |
148 | if args.draw:
149 | sentence_list = nearest_neighbor(train_sentences, text_syn.detach().cpu(), train_caption_embed)
150 | wandb.log({"original_sentence_list": wandb.Html('
'.join(sentence_list))}, step=0)
151 | wandb.log({"original_synthetic_images": wandb.Image(torch.nan_to_num(image_syn.detach().cpu()))}, step=0)
152 |
153 |
154 | expert_dir = os.path.join(args.buffer_path, args.dataset)
155 | expert_dir = args.buffer_path
156 | print("Expert Dir: {}".format(expert_dir))
157 |
158 |
159 | img_expert_files = list(glob.glob(os.path.join(expert_dir, "img_replay_buffer_*.pt")))
160 | txt_expert_files = list(glob.glob(os.path.join(expert_dir, "txt_replay_buffer_*.pt")))
161 | if len(txt_expert_files) != len(img_expert_files) or len(txt_expert_files) == 0:
162 | raise AssertionError("No buffers / Error buffers detected at {}".format(expert_dir))
163 |
164 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files)
165 |
166 | file_idx = 0
167 | expert_idx = 0
168 |
169 | img_buffer = torch.load(img_expert_files[file_idx])
170 | txt_buffer = torch.load(txt_expert_files[file_idx])
171 |
172 | for it in tqdm(range(start_it, args.Iteration + 1)):
173 | save_this_it = True
174 |
175 | wandb.log({"Progress": it}, step=it)
176 | ''' Evaluate synthetic data '''
177 | if it in eval_it_pool:
178 | print('Evaluation\nimage_model_train = %s, text_model_train = %s, iteration = %d'%(args.image_encoder, args.text_encoder, it))
179 |
180 | multi_eval_aggr_result = defaultdict(list) # aggregated results of multiple evaluations
181 |
182 | # r_means = []
183 | for it_eval in range(args.num_eval):
184 | net_eval = CLIPModel_full(args, eval_stage=args.transfer)
185 |
186 | with torch.no_grad():
187 | image_save = image_syn
188 | text_save = text_syn
189 | image_syn_eval, text_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(text_save.detach()) # avoid any unaware modification
190 |
191 | lr_img, lr_txt = copy.deepcopy(syn_lr_img.detach().item()), copy.deepcopy(syn_lr_txt.detach().item())
192 |
193 | sim_params = sim_generator.get_indexed_parameters()
194 | similarity_syn_eval = copy.deepcopy(sim_generator.generate_with_param(sim_params).detach()) # avoid any unaware modification
195 |
196 | _, _, best_val_result = evaluate_synset_with_similarity(
197 | it_eval, net_eval, image_syn_eval, text_syn_eval, lr_img, lr_txt,
198 | similarity_syn_eval, testloader, args, bert_test_embed)
199 |
200 | for k, v in best_val_result.items():
201 | multi_eval_aggr_result[k].append(v)
202 |
203 | if not args.std:
204 | wandb.log({
205 | k: v
206 | for k, v in best_val_result.items()
207 | if k not in ["img_r_mean", "txt_r_mean"]
208 | }, step=it)
209 | # logged img_r1, img_r5, img_r10, txt_r1, txt_r5, txt_r10, r_mean
210 |
211 |
212 | if args.std:
213 | for key, values in multi_eval_aggr_result.items():
214 | if key in ["img_r_mean", "txt_r_mean"]:
215 | continue
216 | wandb.log({
217 | "Mean/{}".format(key): np.mean(values),
218 | "Std/{}".format(key): np.std(values)
219 | }, step=it)
220 |
221 |
222 | if it in eval_it_pool and (save_this_it or it % 1000 == 0):
223 | with torch.no_grad():
224 | save_dir = os.path.join(".", "logged_files", args.dataset, wandb.run.name)
225 | print("Saving to {}".format(save_dir))
226 | if not os.path.exists(save_dir):
227 | os.makedirs(save_dir)
228 |
229 | image_save = image_syn.detach().cpu()
230 | text_save = text_syn.detach().cpu()
231 | sim_params = sim_generator.get_indexed_parameters()
232 | sim_mat = sim_generator.generate_with_param(sim_params)
233 |
234 | torch.save({
235 | "image": image_save,
236 | "text": text_save,
237 | "similarity_params": [x.detach().cpu() for x in sim_params],
238 | "similarity_mat": sim_mat.detach().cpu(),
239 | "syn_lr_img": syn_lr_img.detach().cpu(),
240 | "syn_lr_txt": syn_lr_txt.detach().cpu(),
241 | }, os.path.join(save_dir, "distilled_{}.pt".format(it)) )
242 |
243 |
244 | if args.draw:
245 | wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, step=it) # Move tensor to CPU before converting to NumPy
246 |
247 | if args.ipc < 50 or args.force_save:
248 | upsampled = image_save[:90]
249 | if args.dataset != "ImageNet":
250 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
251 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
252 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
253 | sentence_list = nearest_neighbor(train_sentences, text_syn.cpu(), train_caption_embed)
254 | sentence_list = sentence_list[:90]
255 | torchvision.utils.save_image(grid, os.path.join(save_dir, "synthetic_images_{}.png".format(it)))
256 |
257 | with open(os.path.join(save_dir, "synthetic_sentences_{}.txt".format(it)), "w") as file:
258 | file.write('\n'.join(sentence_list))
259 | wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)
260 | wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it)
261 |                     wandb.log({"Synthetic_Sentences": wandb.Html('<br>'.join(sentence_list))}, step=it)
262 | print("finish saving images")
263 |
264 | for clip_val in [2.5]:
265 | std = torch.std(image_save)
266 | mean = torch.mean(image_save)
267 | upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std).cpu() # Move to CPU
268 | if args.dataset != "ImageNet":
269 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
270 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
271 | grid = torchvision.utils.make_grid(upsampled[:90], nrow=10, normalize=True, scale_each=True)
272 | wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid))}, step=it)
273 | torchvision.utils.save_image(grid, os.path.join(save_dir, "clipped_synthetic_images_{}_std_{}.png".format(it, clip_val)))
274 |
275 |
276 | if args.zca:
277 | raise AssertionError("we do not use ZCA transformation")
278 |
279 | wandb.log({"Synthetic_LR/Image": syn_lr_img.detach().cpu()}, step=it)
280 | wandb.log({"Synthetic_LR/Text": syn_lr_txt.detach().cpu()}, step=it)
281 |
282 | torch.cuda.empty_cache()
283 | student_net = CLIPModel_full(args, temperature=args.temperature)
284 | img_student_net = ReparamModule(student_net.image_encoder.to('cpu')).to('cuda')
285 | txt_student_net = ReparamModule(student_net.text_projection.to('cpu')).to('cuda')
286 |
287 | if args.distributed:
288 | img_student_net = torch.nn.DataParallel(img_student_net)
289 | txt_student_net = torch.nn.DataParallel(txt_student_net)
290 |
291 | img_student_net.train()
292 | txt_student_net.train()
293 |
294 |
295 | img_expert_trajectory = img_buffer[expert_idx]
296 | txt_expert_trajectory = txt_buffer[expert_idx]
297 | expert_idx += 1
298 | if expert_idx == len(img_buffer):
299 | expert_idx = 0
300 | file_idx += 1
301 | if file_idx == len(img_expert_files):
302 | file_idx = 0
303 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files)
304 | if args.max_files != 1:
305 | del img_buffer
306 | del txt_buffer
307 | img_buffer = torch.load(img_expert_files[file_idx])
308 | txt_buffer = torch.load(txt_expert_files[file_idx])
309 |
310 | start_epoch = np.random.randint(0, args.max_start_epoch)
311 | img_starting_params = img_expert_trajectory[start_epoch]
312 | txt_starting_params = txt_expert_trajectory[start_epoch]
313 |
314 | img_target_params = img_expert_trajectory[start_epoch+args.expert_epochs]
315 | txt_target_params = txt_expert_trajectory[start_epoch+args.expert_epochs]
316 |
317 | img_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_target_params], 0)
318 | txt_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_target_params], 0)
319 |
320 | img_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0).requires_grad_(True)]
321 | txt_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0).requires_grad_(True)]
322 |
323 | img_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0)
324 | txt_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0)
325 | syn_images = image_syn
326 | syn_texts = text_syn
327 |
328 | x_list = []
329 | y_list = []
330 |         sim_list = []  # the similarity (sub-)matrices generated at each synthetic step
331 |         sim_param_list = []  # the underlying parameters of the parameterized similarity matrix at each step
332 | grad_sum_img = torch.zeros(img_student_params[-1].shape).to(args.device)
333 | grad_sum_txt = torch.zeros(txt_student_params[-1].shape).to(args.device)
334 |
335 | syn_image_gradients = torch.zeros(syn_images.shape).to(args.device)
336 | syn_txt_gradients = torch.zeros(syn_texts.shape).to(args.device)
337 | syn_sim_param_gradients = [torch.zeros(x.shape).to(args.device) for x in sim_generator.get_indexed_parameters()]
338 |
339 | indices_chunks_copy = []
340 | indices = torch.randperm(len(syn_images))
341 | index = 0
342 | for _ in range(args.syn_steps):
343 | if args.mini_batch_size + index > len(syn_images):
344 | indices = torch.randperm(len(syn_images))
345 | index = 0
346 | these_indices = indices[index : index + args.mini_batch_size]
347 | index += args.mini_batch_size
348 |
349 | indices_chunks_copy.append(these_indices)
350 |
351 | x = syn_images[these_indices]
352 | this_y = syn_texts[these_indices]
353 |
354 | x_list.append(x.clone())
355 | y_list.append(this_y.clone())
356 |
357 | this_sim_param = sim_generator.get_indexed_parameters(these_indices)
358 | this_sim = sim_generator.generate_with_param(this_sim_param)
359 | if args.distributed:
360 | img_forward_params = img_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1)
361 | txt_forward_params = txt_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1)
362 | else:
363 | img_forward_params = img_student_params[-1]
364 | txt_forward_params = txt_student_params[-1]
365 |
366 | x = img_student_net(x, flat_param=img_forward_params)
367 | x = x / x.norm(dim=1, keepdim=True)
368 | this_y = txt_student_net(this_y, flat_param=txt_forward_params)
369 | this_y = this_y / this_y.norm(dim=1, keepdim=True)
370 | image_logits = logit_scale.exp() * x.float() @ this_y.float().t()
371 |
372 | sim_list.append(this_sim)
373 | sim_param_list.append(this_sim_param)
374 |
375 | contrastive_loss = contrastive_criterion(image_logits, this_sim)
376 |
377 | img_grad = torch.autograd.grad(contrastive_loss, img_student_params[-1], create_graph=True)[0]
378 | txt_grad = torch.autograd.grad(contrastive_loss, txt_student_params[-1], create_graph=True)[0]
379 |
380 |
381 | img_detached_grad = img_grad.detach().clone()
382 | img_student_params.append(img_student_params[-1] - syn_lr_img.item() * img_detached_grad)
383 | grad_sum_img += img_detached_grad
384 |
385 | txt_detached_grad = txt_grad.detach().clone()
386 | txt_student_params.append(txt_student_params[-1] - syn_lr_txt.item() * txt_detached_grad)
387 | grad_sum_txt += txt_detached_grad
388 |
389 |
390 | del img_grad
391 | del txt_grad
392 |
393 | img_param_dist = torch.tensor(0.0).to(args.device)
394 | txt_param_dist = torch.tensor(0.0).to(args.device)
395 |
396 | img_param_dist += torch.nn.functional.mse_loss(img_starting_params, img_target_params, reduction="sum")
397 | txt_param_dist += torch.nn.functional.mse_loss(txt_starting_params, txt_target_params, reduction="sum")
398 |
399 |
400 |         # compute gradients w.r.t. the synthetic data; this involves second-order gradients (gradients of gradients)
401 | for i in range(args.syn_steps):
402 | x = img_student_net(x_list[i], flat_param=img_student_params[i])
403 | x = x / x.norm(dim=1, keepdim=True)
404 | this_y = txt_student_net(y_list[i], flat_param=txt_student_params[i])
405 | this_y = this_y / this_y.norm(dim=1, keepdim=True)
406 | image_logits = logit_scale.exp() * x.float() @ this_y.float().t()
407 | loss_i = contrastive_criterion(image_logits, sim_list[i])
408 |
409 |
410 | img_single_term = syn_lr_img.item() * (img_target_params - img_starting_params)
411 | img_square_term = (syn_lr_img.item() ** 2) * grad_sum_img
412 | txt_single_term = syn_lr_txt.item() * (txt_target_params - txt_starting_params)
413 | txt_square_term = (syn_lr_txt.item() ** 2) * grad_sum_txt
414 |
415 |
416 | img_grad_i = torch.autograd.grad(loss_i, img_student_params[i], create_graph=True, retain_graph=True)[0]
417 | txt_grad_i = torch.autograd.grad(loss_i, txt_student_params[i], create_graph=True, retain_graph=True)[0]
418 |
419 |
420 | img_syn_real_dist = (img_single_term + img_square_term) @ img_grad_i
421 | txt_syn_real_dist = (txt_single_term + txt_square_term) @ txt_grad_i
422 |
423 | if args.merge_loss_branches:
424 |                 # Following the traditional MTT loss; equivalent to re-weighting the two branches
425 | grand_loss_i = (img_syn_real_dist + txt_syn_real_dist) / ( img_param_dist + txt_param_dist)
426 | else:
427 | # Original loss in vl-distill
428 | grand_loss_i = img_syn_real_dist / img_param_dist + txt_syn_real_dist / txt_param_dist
429 |
430 | multiple_gradients_in_one_time = torch.autograd.grad(
431 | 2 * grand_loss_i,
432 | [x_list[i], y_list[i]] + sim_param_list[i]
433 | )
434 |
435 | img_gradients = multiple_gradients_in_one_time[0]
436 | txt_gradients = multiple_gradients_in_one_time[1]
437 | sim_param_gradients = multiple_gradients_in_one_time[2:]
438 |
439 |
440 | with torch.no_grad():
441 | ids = indices_chunks_copy[i]
442 | syn_image_gradients[ids] += img_gradients
443 | syn_txt_gradients[ids] += txt_gradients
444 |
445 | assert len(sim_param_gradients) == len(syn_sim_param_gradients), f"{len(sim_param_gradients)}, {len(syn_sim_param_gradients)}"
446 | for g_idx in range(len(sim_param_gradients)):
447 | if args.sim_type == 'full':
448 | syn_sim_param_gradients[g_idx][ids[:,None], ids] += sim_param_gradients[g_idx]
449 |                         # NOTE: chained indexing such as x[ids, :][:, ids] += g would update a copy, so the gradient would be lost
450 | else:
451 | syn_sim_param_gradients[g_idx][ids, ...] += sim_param_gradients[g_idx]
452 |
453 |
454 | # ---------end of computing input image gradients and learning rates--------------
455 |
456 | syn_images.grad = syn_image_gradients
457 | syn_texts.grad = syn_txt_gradients
458 |
459 | for g_idx, param in enumerate(sim_generator.get_indexed_parameters()):
460 | param.grad = syn_sim_param_gradients[g_idx]
461 |
462 |
463 | img_grand_loss = img_starting_params - syn_lr_img * grad_sum_img - img_target_params
464 | txt_grand_loss = txt_starting_params - syn_lr_txt * grad_sum_txt - txt_target_params
465 | img_grand_loss = img_grand_loss.dot(img_grand_loss) / img_param_dist
466 | txt_grand_loss = txt_grand_loss.dot(txt_grand_loss) / txt_param_dist
467 | grand_loss = img_grand_loss + txt_grand_loss
468 |
469 | img_lr_grad = torch.autograd.grad(img_grand_loss, syn_lr_img)[0]
470 | syn_lr_img.grad = img_lr_grad
471 | txt_lr_grad = torch.autograd.grad(txt_grand_loss, syn_lr_txt)[0]
472 | syn_lr_txt.grad = txt_lr_grad
473 |
474 | if math.isnan(img_grand_loss):
475 | break
476 |
477 | wandb.log({
478 | "Loss/grand": grand_loss.detach().cpu(),
479 | # "Start_Epoch": start_epoch,
480 | "Loss/img_grand": img_grand_loss.detach().cpu(),
481 | "Loss/txt_grand": txt_grand_loss.detach().cpu(),
482 | }, step=it)
483 |
484 |
485 | wandb.log({"Synthetic_LR/grad_syn_lr_img": syn_lr_img.grad.detach().cpu()}, step=it)
486 | wandb.log({"Synthetic_LR/grad_syn_lr_txt": syn_lr_txt.grad.detach().cpu()}, step=it)
487 |
488 |             optimizer.step()  # no zero_grad() needed: the .grad fields are not accumulated by autograd here, they are overwritten directly above
489 |
490 | if args.clamp_lr:
491 | syn_lr_img.data = torch.clamp(syn_lr_img.data, min=args.clamp_lr) #
492 | syn_lr_txt.data = torch.clamp(syn_lr_txt.data, min=args.clamp_lr)
493 |
494 |         # free the unrolled student parameter tensors (deleting only the loop variable would not release them)
495 |         del img_student_params
496 |         del txt_student_params
497 |
498 |
499 | if it%10 == 0:
500 | print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item()))
501 |
502 |
503 | wandb.finish()
504 |
505 |
506 | if __name__ == '__main__':
507 | parser = argparse.ArgumentParser(description='Parameter Processing')
508 |
509 | parser.add_argument('--dataset', type=str, default='flickr30k', help='dataset')
510 |
511 | parser.add_argument('--eval_mode', type=str, default='S',
512 | help='eval_mode, check utils.py for more info')
513 |
514 | parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on')
515 |
516 | parser.add_argument('--eval_it', type=int, default=50, help='how often to evaluate')
517 |
518 | parser.add_argument('--epoch_eval_train', type=int, default=100, help='epochs to train a model with synthetic data')
519 | parser.add_argument('--Iteration', type=int, default=3000, help='how many distillation steps to perform')
520 |
521 | parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images')
522 | parser.add_argument('--lr_txt', type=float, default=1000, help='learning rate for updating synthetic texts')
523 | parser.add_argument('--lr_lr', type=float, default=1e-03, help='learning rate for updating the synthetic learning rates syn_lr_img / syn_lr_txt')
524 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters')
525 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters')
526 |
527 | parser.add_argument('--loss_type', type=str)
528 |
529 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks')
530 |
531 | parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"],
532 | help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
533 | parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"],
534 |                     help='noise/real: initialize synthetic texts from random noise or randomly sampled real captions.')
535 |
536 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
537 | help='whether to use differentiable Siamese augmentation.')
538 |
539 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
540 | help='differentiable Siamese augmentation strategy')
541 |
542 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path')
543 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')
544 |
545 | parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs ahead of the start epoch the target params are')
546 | parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data')
547 | parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at')
548 |
549 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening")
550 |
551 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM")
552 |
553 | parser.add_argument('--no_aug', action="store_true", default=False, help='this turns off diff aug during distillation')
554 |
555 | parser.add_argument('--texture', action='store_true', help="will distill textures instead")
556 | parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas')
557 | parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration')
558 |
559 |
560 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)')
561 | parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)')
562 |
563 | parser.add_argument('--force_save', action='store_true', help='force saving synthetic images even when ipc >= 50')
564 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
565 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')
566 | parser.add_argument('--num_queries', type=int, default=100, help='number of synthetic image-text pairs (distilled set size)')
567 | parser.add_argument('--mini_batch_size', type=int, default=100, help='mini-batch size of synthetic pairs per synthetic step')
568 | parser.add_argument('--basis', type=bool, default=False, help='whether to use basis or not')
569 | parser.add_argument('--n_basis', type=int, default=64, help='n_basis')
570 | parser.add_argument('--recursive', type=bool, default=False, help='whether to use recursive mode or not')
571 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy')
572 | parser.add_argument('--image_size', type=int, default=224, help='image_size')
573 | parser.add_argument('--image_root', type=str, default='distill_utils/data/Flickr30k/', help='location of image root')
574 | parser.add_argument('--ann_root', type=str, default='./data/Flickr30k_ann/', help='location of ann root')
575 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
576 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')
577 | parser.add_argument('--image_encoder', type=str, default='nfnet', help='image encoder') # , choices=['clip', 'nfnet', 'vit', 'nf_resnet50', "nf_regnet"]
578 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert'], help='text encoder')
579 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained')
580 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained')
581 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable')
582 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable')
583 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None')
584 | parser.add_argument('--distill', type=bool, default=True, help='whether distill')
585 | parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='matching_train')
586 | parser.add_argument('--image_only', type=bool, default=False, help='None')
587 | parser.add_argument('--text_only', type=bool, default=False, help='None')
588 | parser.add_argument('--draw', type=bool, default=False, help='None')
589 | parser.add_argument('--transfer', type=bool, default=False, help='transfer cross architecture')
590 | parser.add_argument('--std', type=bool, default=True, help='report mean and standard deviation over the num_eval evaluations')
591 | parser.add_argument('--disabled_wandb', type=bool, default=False, help='disable wandb')
592 | parser.add_argument('--test_with_norm', type=bool, default=False, help='')
593 |
594 | parser.add_argument('--clamp_lr', type=float, default=None, help='')
595 |
596 |
597 | # Arguments below are for LoRS
598 |
599 | parser.add_argument('--resume_from', default=None, type=str)
600 |
601 | parser.add_argument('--sim_type', type=str, default="full", choices=["full", "lowrank"], help='similarity matrix type')
602 | parser.add_argument('--sim_rank', type=int, default=10, help='similarity matrix rank')
603 | parser.add_argument('--alpha', type=float, default=0.1, help='alpha in LoRA')
604 | parser.add_argument('--lr_sim', type=float, default=1e-03, help='learning rate for updating the similarity matrix parameters')
605 | parser.add_argument('--temperature', type=float, default=0.07, help="temperature of CLIP model")
606 |
607 | parser.add_argument('--momentum_lr', type=float, default=0.5)
608 | parser.add_argument('--momentum_syn', type=float, default=0.5)
609 | parser.add_argument('--momentum_sim', type=float, default=0.5)
610 | parser.add_argument('--merge_loss_branches', action="store_true", default=False)
611 |
612 | args = parser.parse_args()
613 |
614 | main(args)
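615 |
616 | # A typical invocation (illustrative; see the scripts under sh/ for the exact per-setting hyper-parameters):
617 | #   CUDA_VISIBLE_DEVICES=0 python distill_tesla_lors.py --dataset=flickr \
618 | #       --buffer_path='buffer/flickr/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \
619 | #       --sim_type lowrank --sim_rank 10 --lr_sim 10.0 --num_queries 99 --mini_batch_size=20 --loss_type WBCE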
--------------------------------------------------------------------------------
/evaluate_only.py:
--------------------------------------------------------------------------------
1 | """Evaliuate the distilled data, could be used for cross-arch exp
2 | Example:
3 |
4 | python evaluate_only.py --dataset=flickr --num_eval 1 \
5 | --ckpt_path tmp/flickr_500_distilled.pt --loss_type WBCE \
6 | --image_encoder=nf_resnet50 --text_encoder=bert --batch_train 64
7 | """
8 |
9 | from collections import defaultdict
10 | import os
11 | import re
12 |
13 | import numpy as np
14 | import torch
15 |
16 | import copy
17 | import argparse
18 | import datetime
19 |
20 | from data import get_dataset_flickr, textprocess, textprocess_train
21 | from src.reparam_module import ReparamModule
22 | from src.epoch import evaluate_synset_with_similarity
23 | from src.networks import CLIPModel_full
24 |
25 | from src.vl_distill_utils import load_or_process_file
26 |
27 |
28 | def formatting_result_head():
29 | return "Img R@1 | Img R@5 | Img R@10 | Txt R@1 | Txt R@5 | Txt R@10 | Mean"
30 |
31 |
32 | def formatting_result_content(val_result):
33 | return "{img_r1:9.2f} | {img_r5:9.2f} | {img_r10:9.2f} | {txt_r1:9.2f} | {txt_r5:9.2f} | {txt_r10:9.2f} | {r_mean:9.2f}".format(
34 | **val_result
35 | )
36 |
37 | def formatting_result_content_clean(val_result):
38 | return "{img_r1} {img_r5} {img_r10} {txt_r1} {txt_r5} {txt_r10} {r_mean}".format(
39 | **val_result
40 | )
41 |
42 | def formatting_result_all(val_result):
43 | return "Image R@1={img_r1} R@5={img_r5} R@10={img_r10} | Text R@1={txt_r1} R@5={txt_r5} R@10={txt_r10} | Mean={r_mean}".format(
44 | **val_result
45 | )
46 |
47 |
48 |
49 | def main(args):
50 | ''' organize the real train dataset '''
51 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)
52 |
53 | train_sentences = train_dataset.get_all_captions()
54 |
55 | data = load_or_process_file('text', textprocess, args, testloader)
56 | train_caption = load_or_process_file('train_text', textprocess_train, args, train_sentences)
57 |
58 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu()
59 | print("The shape of bert_test_embed: {}".format(bert_test_embed.shape))
60 | train_caption_embed = torch.from_numpy(train_caption['bert_test_embed']).cpu()
61 | print("The shape of train_caption_embed: {}".format(train_caption_embed.shape))
62 |
63 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
64 |
65 |
66 | if not os.path.exists(args.ckpt_path):
67 | args.ckpt_path = "logged_files/"+args.ckpt_path
68 | if not os.path.exists(args.ckpt_path):
69 | raise ValueError(f"{args.ckpt_path} does not exist")
70 |
71 |
72 | print("Load from", args.ckpt_path)
73 | ckpt = torch.load(args.ckpt_path)
74 | image_syn = ckpt["image"].to(args.device)
75 | text_syn = ckpt["text"].to(args.device)
76 | sim_mat = ckpt["similarity_mat"].to(args.device)
77 | syn_lr_img = ckpt.get("syn_lr_img") if args.syn_lr_img is None else args.syn_lr_img
78 | syn_lr_txt = ckpt.get("syn_lr_txt") if args.syn_lr_txt is None else args.syn_lr_txt
79 | print(syn_lr_img, syn_lr_txt)
80 |
81 | if args.clip_similarity:
82 | # ablation - use similarity matrix from pretrained CLIP
83 | buffer_dir = os.path.join("buffer", args.dataset, args.image_encoder+"_"+args.text_encoder, "InfoNCE")
84 | img_starting_params = torch.load(os.path.join(buffer_dir, "img_replay_buffer_0.pt"))[0][-1]
85 | txt_starting_params = torch.load(os.path.join(buffer_dir, "txt_replay_buffer_0.pt"))[0][-1]
86 | img_forward_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0)
87 | txt_forward_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0)
88 |
89 | clip_model = CLIPModel_full(args, eval_stage=args.transfer)
90 | img_student_net = ReparamModule(clip_model.image_encoder.to('cpu')).to('cuda')
91 | txt_student_net = ReparamModule(clip_model.text_projection.to('cpu')).to('cuda')
92 |
93 |
94 | img_feat = img_student_net(image_syn, flat_param=img_forward_params)
95 | img_feat = img_feat / img_feat.norm(dim=1, keepdim=True)
96 | txt_feat = txt_student_net(text_syn, flat_param=txt_forward_params)
97 | txt_feat = txt_feat / txt_feat.norm(dim=1, keepdim=True)
98 | sim_mat = torch.sigmoid(img_feat.float() @ txt_feat.float().t() / args.temperature)
99 |
100 | print('Evaluation\nimage_model_train = %s, text_model_train = %s, iteration = ?'%(args.image_encoder, args.text_encoder))
101 |
102 | multi_eval_aggr_result = defaultdict(list) # aggregated results of multiple evaluations
103 |
104 | for it_eval in range(args.num_eval):
105 | net_eval = CLIPModel_full(args, eval_stage=args.transfer)
106 |
107 | image_syn_eval, text_syn_eval = copy.deepcopy(image_syn), copy.deepcopy(text_syn) # avoid any unaware modification
108 | similarity_syn_eval = copy.deepcopy(sim_mat) # avoid any unaware modification
109 |
110 | _, _, best_val_result = evaluate_synset_with_similarity(
111 | it_eval, net_eval, image_syn_eval, text_syn_eval, syn_lr_img, syn_lr_txt,
112 | similarity_syn_eval, testloader, args, bert_test_embed, mom=args.mom, l2=args.l2)
113 |
114 | for k, v in best_val_result.items():
115 | multi_eval_aggr_result[k].append(v)
116 |
117 |
118 | if not args.std:
119 |             print(formatting_result_content(best_val_result))
120 | # formatting_result_content_clean(best_val_result)
121 | # logged img_r1, img_r5, img_r10, txt_r1, txt_r5, txt_r10, r_mean
122 |
123 | print(formatting_result_head())
124 | if args.std:
125 | mean_results = {k: np.mean(v) for k, v in multi_eval_aggr_result.items()}
126 | std_results = {k: np.std(v) for k, v in multi_eval_aggr_result.items()}
127 |
128 | print(formatting_result_content(mean_results))
129 | print(formatting_result_content(std_results))
130 | print(formatting_result_content_clean({k: "%.2f$\\pm$%.2f"%(mean_results[k],std_results[k]) for k in std_results}))
131 |
132 | print(args.image_encoder)
133 |
134 |
135 |
136 | if __name__ == '__main__':
137 | parser = argparse.ArgumentParser(description='Parameter Processing')
138 |
139 | parser.add_argument('--dataset', type=str, default='flickr30k', help='dataset')
140 |
141 | parser.add_argument('--eval_mode', type=str, default='S',
142 | help='eval_mode, check utils.py for more info')
143 |
144 | parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on')
145 |
146 | parser.add_argument('--eval_it', type=int, default=50, help='how often to evaluate')
147 |
148 | parser.add_argument('--epoch_eval_train', type=int, default=100, help='epochs to train a model with synthetic data')
149 | parser.add_argument('--Iteration', type=int, default=3000, help='how many distillation steps to perform')
150 |
151 | parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images')
152 | parser.add_argument('--lr_txt', type=float, default=1000, help='learning rate for updating synthetic texts')
153 | parser.add_argument('--lr_lr', type=float, default=1e-03, help='learning rate for updating the synthetic learning rates syn_lr_img / syn_lr_txt')
154 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters')
155 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters')
156 |
157 | parser.add_argument('--loss_type', type=str)
158 |
159 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks')
160 |
161 | parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"],
162 | help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
163 | parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"],
164 |                     help='noise/real: initialize synthetic texts from random noise or randomly sampled real captions.')
165 |
166 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
167 | help='whether to use differentiable Siamese augmentation.')
168 |
169 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
170 | help='differentiable Siamese augmentation strategy')
171 |
172 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path')
173 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')
174 |
175 | parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs ahead of the start epoch the target params are')
176 | parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data')
177 | parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at')
178 |
179 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening")
180 |
181 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM")
182 |
183 | parser.add_argument('--no_aug', action="store_true", default=False, help='this turns off diff aug during distillation')
184 |
185 | parser.add_argument('--texture', action='store_true', help="will distill textures instead")
186 | parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas')
187 | parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration')
188 |
189 |
190 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)')
191 | parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)')
192 |
193 | parser.add_argument('--force_save', action='store_true', help='force saving synthetic images even when ipc >= 50')
194 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
195 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')
196 | parser.add_argument('--num_queries', type=int, default=100, help='number of synthetic image-text pairs (distilled set size)')
197 | parser.add_argument('--mini_batch_size', type=int, default=100, help='mini-batch size of synthetic pairs per synthetic step')
198 | parser.add_argument('--basis', type=bool, default=False, help='whether to use basis or not')
199 | parser.add_argument('--n_basis', type=int, default=64, help='n_basis')
200 | parser.add_argument('--recursive', type=bool, default=False, help='whether to use recursive mode or not')
201 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy')
202 | parser.add_argument('--image_size', type=int, default=224, help='image_size')
203 | parser.add_argument('--image_root', type=str, default='distill_utils/data/Flickr30k/', help='location of image root')
204 | parser.add_argument('--ann_root', type=str, default='./data/Flickr30k_ann/', help='location of ann root')
205 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
206 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')
207 | parser.add_argument('--image_encoder', type=str, default='nfnet', help='image encoder') # , choices=['clip', 'nfnet', 'vit', 'nf_resnet50', "nf_regnet"]
208 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert'], help='text encoder')
209 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained')
210 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained')
211 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable')
212 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable')
213 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None')
214 | parser.add_argument('--distill', type=bool, default=True, help='whether distill')
215 | parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='matching_train')
216 | parser.add_argument('--image_only', type=bool, default=False, help='None')
217 | parser.add_argument('--text_only', type=bool, default=False, help='None')
218 | parser.add_argument('--draw', type=bool, default=False, help='None')
219 | parser.add_argument('--transfer', type=bool, default=False, help='transfer cross architecture')
220 | parser.add_argument('--std', type=bool, default=True, help='report mean and standard deviation over the num_eval evaluations')
221 | parser.add_argument('--disabled_wandb', type=bool, default=False, help='disable wandb')
222 | parser.add_argument('--test_with_norm', type=bool, default=False, help='')
223 |
224 | parser.add_argument('--clamp_lr', type=float, default=None, help='')
225 |
226 |
227 | # Arguments below are for LoRS
228 |
229 | parser.add_argument('--resume_from', default=None, type=str)
230 |
231 | parser.add_argument('--sim_type', type=str, default="full", choices=["full", "lowrank"], help='similarity matrix type')
232 | parser.add_argument('--sim_rank', type=int, default=10, help='similarity matrix rank')
233 | parser.add_argument('--alpha', type=float, default=0.1, help='alpha in LoRA')
234 | parser.add_argument('--lr_sim', type=float, default=1e-03, help='learning rate for updating the similarity matrix parameters')
235 | parser.add_argument('--temperature', type=float, default=0.07, help="temperature of CLIP model")
236 |
237 | parser.add_argument('--momentum_lr', type=float, default=0.5)
238 | parser.add_argument('--momentum_syn', type=float, default=0.5)
239 | parser.add_argument('--momentum_sim', type=float, default=0.5)
240 | parser.add_argument('--merge_loss_branches', action="store_true", default=False)
241 |
242 | # Arguments below are for evaluation
243 |
244 | parser.add_argument('--ckpt_path', type=str)
245 | parser.add_argument('--syn_lr_img', type=float, default=None)
246 | parser.add_argument('--syn_lr_txt', type=float, default=None)
247 | parser.add_argument('--mom', type=float, default=0.9)
248 | parser.add_argument('--l2', type=float, default=0.0005)
249 | parser.add_argument('--clip_similarity', action="store_true", default=False)
250 | args = parser.parse_args()
251 |
252 | main(args)
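253 |
254 | # Note: --ckpt_path expects a checkpoint produced by distill_tesla_lors.py, i.e. a dict with the keys
255 | # "image", "text", "similarity_mat", "syn_lr_img" and "syn_lr_txt" (plus the raw "similarity_params").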
--------------------------------------------------------------------------------
/images/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/silicx/LoRS_Distill/8a712afdb30ed483199bd0f00cc73755e0817fe8/images/method.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | timm==0.9.8
4 | git+https://github.com/openai/CLIP.git
5 | kornia==0.6.3
6 | scikit-learn==1.1.2
7 | scipy==1.10.1
8 | matplotlib==3.6.0
9 | transformers
10 | wandb
11 | opencv-python
12 | tqdm
--------------------------------------------------------------------------------
/sh/buffer_coco.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=$1 python buffer.py --dataset=coco --train_epochs=10 --num_experts=20 --buffer_path='buffer' --image_encoder=nfnet --text_encoder=bert --image_size=224
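2 | # Usage: bash sh/buffer_coco.sh <GPU_ID>  (the distill scripts expect the resulting expert buffers under buffer/coco/nfnet_bert/InfoNCE)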
--------------------------------------------------------------------------------
/sh/buffer_flickr.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=$1 python buffer.py --dataset=flickr --train_epochs=10 --num_experts=20 --buffer_path='buffer' --image_encoder=nfnet --text_encoder=bert --image_size=224
--------------------------------------------------------------------------------
/sh/distill_coco_lors_100.sh:
--------------------------------------------------------------------------------
1 | FILE_NAME=$(basename -- "$0")
2 | EXP_NAME="${FILE_NAME%.*}"
3 |
4 | export CUDA_VISIBLE_DEVICES=$1;
5 | python distill_tesla_lors.py --dataset=coco \
6 | --buffer_path='buffer/coco/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \
7 | --image_root='distill_utils/data/COCO/' --merge_loss_branches \
8 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \
9 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \
10 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \
11 | --lr_sim=5.0 --sim_type lowrank --sim_rank 10 --alpha 1.0 \
12 | --num_queries 99 --mini_batch_size=20 \
13 | --loss_type WBCE --name ${EXP_NAME}
14 |
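15 | # LoRS-specific flags (see distill_tesla_lors.py): --sim_type/--sim_rank select the similarity matrix
16 | # parameterization and its rank, --alpha is the LoRA-style scaling, and --lr_sim is the learning rate
17 | # of the similarity parameters.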
--------------------------------------------------------------------------------
/sh/distill_coco_lors_200.sh:
--------------------------------------------------------------------------------
1 | FILE_NAME=$(basename -- "$0")
2 | EXP_NAME="${FILE_NAME%.*}"
3 |
4 | export CUDA_VISIBLE_DEVICES=$1;
5 | python distill_tesla_lors.py --dataset=coco \
6 | --buffer_path='buffer/coco/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \
7 | --image_root='distill_utils/data/COCO/' --merge_loss_branches \
8 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \
9 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \
10 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \
11 | --lr_sim=50 --sim_type lowrank --sim_rank 20 --alpha 1.0 \
12 | --num_queries 199 --mini_batch_size=20 \
13 | --loss_type WBCE --name ${EXP_NAME} --Iteration 2000
--------------------------------------------------------------------------------
/sh/distill_coco_lors_500.sh:
--------------------------------------------------------------------------------
1 | FILE_NAME=$(basename -- "$0")
2 | EXP_NAME="${FILE_NAME%.*}"
3 |
4 | export CUDA_VISIBLE_DEVICES=$1;
5 | python distill_tesla_lors.py --dataset=coco \
6 | --buffer_path='buffer/coco/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \
7 | --image_root='distill_utils/data/COCO/' --merge_loss_branches \
8 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \
9 | --lr_img=5000 --lr_txt=5000 --lr_lr=1e-2 \
10 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \
11 | --lr_sim=500 --sim_type lowrank --sim_rank 40 --alpha 1.0 \
12 | --num_queries 499 --mini_batch_size=30 \
13 | --temperature 0.1 --no_aug \
14 | --loss_type WBCE --name ${EXP_NAME}
--------------------------------------------------------------------------------
/sh/distill_flickr_lors_100.sh:
--------------------------------------------------------------------------------
1 | FILE_NAME=$(basename -- "$0")
2 | EXP_NAME="${FILE_NAME%.*}"
3 |
4 | export CUDA_VISIBLE_DEVICES=$1;
5 | python distill_tesla_lors.py --dataset=flickr \
6 | --buffer_path='buffer/flickr/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \
7 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \
8 | --lr_img=100 --lr_txt=100 --lr_lr=1e-2 \
9 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \
10 | --lr_sim=10.0 --sim_type lowrank --sim_rank 10 --alpha 3 \
11 | --num_queries 99 --mini_batch_size=20 \
12 | --loss_type WBCE --name ${EXP_NAME}
--------------------------------------------------------------------------------
/sh/distill_flickr_lors_200.sh:
--------------------------------------------------------------------------------
1 | FILE_NAME=$(basename -- "$0")
2 | EXP_NAME="${FILE_NAME%.*}"
3 |
4 | export CUDA_VISIBLE_DEVICES=$1;
5 | python distill_tesla_lors.py --dataset=flickr \
6 | --buffer_path='buffer/flickr/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \
7 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 \
8 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \
9 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \
10 | --lr_sim=10.0 --sim_type lowrank --sim_rank 5 --alpha 1.0 \
11 | --num_queries 199 --mini_batch_size=20 \
12 | --loss_type WBCE --name ${EXP_NAME}
--------------------------------------------------------------------------------
/sh/distill_flickr_lors_500.sh:
--------------------------------------------------------------------------------
1 | FILE_NAME=$(basename -- "$0")
2 | EXP_NAME="${FILE_NAME%.*}"
3 |
4 | export CUDA_VISIBLE_DEVICES=$1;
5 | python distill_tesla_lors.py --dataset=flickr \
6 | --buffer_path='buffer/flickr/nfnet_bert/InfoNCE' --image_encoder=nfnet --text_encoder=bert \
7 | --syn_steps=8 --expert_epochs=1 --max_start_epoch=3 \
8 | --lr_img=1000 --lr_txt=1000 --lr_lr=1e-2 \
9 | --lr_teacher_img 0.1 --lr_teacher_txt 0.1 \
10 | --lr_sim=100 --sim_type lowrank --sim_rank 20 --alpha 0.01 \
11 | --num_queries 499 --mini_batch_size=40 \
12 | --loss_type WBCE --name ${EXP_NAME} \
13 | --eval_it 300
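14 | # Note: --eval_it 300 evaluates less frequently than the default (50), since each evaluation retrains a model on the 500-pair synthetic set.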
--------------------------------------------------------------------------------
/src/epoch.py:
--------------------------------------------------------------------------------
1 | '''
2 | * part of the code (i.e. def epoch_test() and itm_eval()) is from: https://github.com/salesforce/BLIP/blob/main/train_retrieval.py#L69
3 | * Copyright (c) 2022, salesforce.com, inc.
4 | * All rights reserved.
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | * By Junnan Li
8 | '''
9 | from math import ceil
10 | import numpy as np
11 | import torch
12 | import time
13 | import datetime
14 | import torch
15 | import torch.nn.functional as F
16 | from tqdm import tqdm
17 | import torch.nn as nn
18 | from src.utils import *
19 |
20 |
21 |
22 |
23 |
24 | def epoch(e, dataloader, net, optimizer, args):
25 | """
26 | Perform a training epoch on the given dataloader.
27 |
28 | Args:
29 | dataloader (torch.utils.data.DataLoader): The dataloader for iterating over the dataset.
30 | net: The model.
31 |         e (int): The current epoch index (passed through to the model's forward).
32 |         optimizer: The optimizer over both the image and text parameters.
33 | args (object): The arguments specifying the training configuration.
34 |
35 | Returns:
36 | Tuple of average loss and average accuracy.
37 | """
38 | net = net.to(args.device)
39 | net.train()
40 | loss_avg, acc_avg, num_exp = 0, 0, 0
41 |
42 | require_similarity = isinstance(dataloader, (SimilarityDataloader, SimilarityDataLoaderWrapper))
43 |
44 | for i, data in (enumerate(dataloader)):
45 | if args.distill:
46 | if require_similarity:
47 | image, caption, similarity = data
48 | else:
49 | image, caption = data[:2]
50 | else:
51 | if require_similarity:
52 | image, caption, index, similarity = data
53 | else:
54 | image, caption, index = data[:3]
55 |
56 | image = image.to(args.device)
57 | n_b = image.shape[0]
58 |
59 | if require_similarity:
60 | loss, acc = net(image, caption, e, similarity)
61 | else:
62 | loss, acc = net(image, caption, e)
63 |
64 | loss_avg += loss.item() * n_b
65 | acc_avg += acc
66 | num_exp += n_b
67 |
68 | optimizer.zero_grad()
69 | loss.backward()
70 | optimizer.step()
71 |
72 | loss_avg /= num_exp
73 | acc_avg /= num_exp
74 |
75 | return loss_avg, acc_avg
76 |
77 |
78 |
79 |
80 |
81 | @torch.no_grad()
82 | def epoch_test(dataloader, model, device, bert_test_embed):
83 | model.eval()
84 | logit_scale = model.logit_scale.detach()
85 | start_time = time.time()
86 |
87 |
88 | txt_embed = model.text_projection(bert_test_embed.float().to('cuda'))
89 | text_embeds = txt_embed / txt_embed.norm(dim=1, keepdim=True) #torch.Size([5000, 768])
90 | text_embeds = text_embeds.to(device)
91 |
92 | image_embeds = []
93 | for image, img_id in dataloader:
94 | image_feat = model.image_encoder(image.to(device))
95 | im_embed = image_feat / image_feat.norm(dim=1, keepdim=True)
96 | image_embeds.append(im_embed)
97 | image_embeds = torch.cat(image_embeds,dim=0)
98 | use_image_projection = False
99 | if use_image_projection:
100 | im_embed = model.image_projection(image_embeds.float())
101 | image_embeds = im_embed / im_embed.norm(dim=1, keepdim=True)
102 | else:
103 | image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
104 |
105 | sims_matrix = logit_scale.exp() * image_embeds @ text_embeds.t()
106 | score_matrix_i2t = torch.full((len(image_embeds),len(text_embeds)),-100.0).to(device) #torch.Size([1000, 5000])
107 | for i, sims in enumerate(sims_matrix[0:sims_matrix.size(0) + 1]):
108 | topk_sim, topk_idx = sims.topk(k=128, dim=0)
109 | score_matrix_i2t[i,topk_idx] = topk_sim #i:0-999, topk_idx:0-4999, find top k (k=128) similar text for each image
110 |
111 | sims_matrix = sims_matrix.t()
112 | score_matrix_t2i = torch.full((len(text_embeds),len(image_embeds)),-100.0).to(device)
113 | for i,sims in enumerate(sims_matrix[0:sims_matrix.size(0) + 1]):
114 | topk_sim, topk_idx = sims.topk(k=128, dim=0)
115 | score_matrix_t2i[i,topk_idx] = topk_sim
116 |
117 | total_time = time.time() - start_time
118 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
119 | print('Evaluation time {}'.format(total_time_str))
120 |
121 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
122 |
123 |
124 | @torch.no_grad()
125 | def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
126 |
127 | #Images->Text
128 | ranks = np.zeros(scores_i2t.shape[0])
129 | # print("TR: ", len(ranks))
130 | for index, score in enumerate(scores_i2t):
131 | inds = np.argsort(score)[::-1]
132 | # Score
133 | rank = 1e20
134 | for i in img2txt[index]:
135 | tmp = np.where(inds == i)[0][0]
136 | if tmp < rank:
137 | rank = tmp
138 | ranks[index] = rank
139 |
140 | # Compute metrics
141 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
142 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
143 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
144 |
145 | #Text->Images
146 | ranks = np.zeros(scores_t2i.shape[0])
147 | # print("IR: ", len(ranks))
148 |
149 | for index,score in enumerate(scores_t2i):
150 | inds = np.argsort(score)[::-1]
151 | ranks[index] = np.where(inds == txt2img[index])[0][0]
152 |
153 | # Compute metrics
154 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
155 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
156 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
157 |
158 | tr_mean = (tr1 + tr5 + tr10) / 3
159 | ir_mean = (ir1 + ir5 + ir10) / 3
160 | r_mean = (tr_mean + ir_mean) / 2
161 |
162 | eval_result = {'txt_r1': tr1,
163 | 'txt_r5': tr5,
164 | 'txt_r10': tr10,
165 | 'txt_r_mean': tr_mean,
166 | 'img_r1': ir1,
167 | 'img_r5': ir5,
168 | 'img_r10': ir10,
169 | 'img_r_mean': ir_mean,
170 | 'r_mean': r_mean}
171 | return eval_result
172 |
173 |
174 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args, bert_test_embed, return_loss=False):
175 |
176 | net = net.to(args.device)
177 | images_train = images_train.to(args.device)
178 | labels_train = labels_train.to(args.device)
179 | Epoch = int(args.epoch_eval_train)
180 | optimizer = torch.optim.SGD([
181 | {'params': net.image_encoder.parameters(), 'lr': args.lr_teacher_img},
182 | {'params': net.text_projection.parameters(), 'lr': args.lr_teacher_txt},
183 | ], lr=0, momentum=0.9, weight_decay=0.0005)
184 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[Epoch//2+1], gamma=0.1)
185 |
186 | dst_train = TensorDataset(images_train, labels_train)
187 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
188 |
189 | start = time.time()
190 | acc_train_list = []
191 | loss_train_list = []
192 |
193 | eval_epochs = [Epoch]
194 | eval_freq = Epoch // 10
195 | eval_epochs = list(range(0, Epoch+1, eval_freq)) + [Epoch] # eval 10 times in total
196 |
197 | best_val_result = None
198 | for ep in tqdm(range(Epoch+1), ncols=60):
199 | loss_train, acc_train = epoch(ep, trainloader, net, optimizer, args)
200 | acc_train_list.append(acc_train)
201 | loss_train_list.append(loss_train)
202 | if ep in eval_epochs:
203 | with torch.no_grad():
204 | score_val_i2t, score_val_t2i = epoch_test(testloader, net, args.device, bert_test_embed)
205 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt)
206 |
207 | print("[Eval_{it:02d}] Ep{ep} | Image R@1={img_r1:.2f} R@5={img_r5:.2f} R@10={img_r10:.2f} | Text R@1={txt_r1:.2f} R@5={txt_r5:.2f} R@10={txt_r10:.2f} | Mean={r_mean:.2f}".format(
208 | it=it_eval, ep=ep, **val_result
209 | ))
210 | if best_val_result is None or val_result["r_mean"] > best_val_result["r_mean"]:
211 | best_val_result = val_result
212 |
213 | lr_scheduler.step()
214 |
215 | time_train = time.time() - start
216 | return net, acc_train_list, best_val_result
217 |
218 |
219 |
220 |
221 |
222 | class SimilarityDataLoaderWrapper:
223 |     """Make a regular dataloader look like a SimilarityDataloader
224 |     by yielding each batch of N samples together with an N*N identity similarity matrix."""
225 |
226 | def __init__(self, dataloader):
227 | super().__init__()
228 | self.dataloader = dataloader
229 |
230 | def __iter__(self):
231 | for data in self.dataloader:
232 | bz = data[0].shape[0]
233 | yield data + [torch.eye(bz, dtype=torch.float32).to(data[0].device)]
234 |
235 |
236 | class SimilarityDataloader:
237 | def __init__(self, images_train, labels_train, similarity_train, batch_size, drop_last=False):
238 |         """images_train, labels_train: N synthetic image / text samples
239 |         similarity_train: N*N similarity matrix; similarity_train[i, j] is the similarity between the i-th image and the j-th text"""
240 | super().__init__()
241 |
242 | self.images_train = images_train
243 | self.labels_train = labels_train
244 | self.similarity_train = similarity_train
245 | self.batch_size = batch_size
246 | assert not drop_last
247 |
248 | def __iter__(self):
249 | size = self.images_train.shape[0]
250 | num_batch = int(ceil(size / self.batch_size))
251 | indices = np.arange(size)
252 | for _ in range(num_batch):
253 | np.random.shuffle(indices)
254 | ids = indices[:self.batch_size]
255 | yield self.images_train[ids], self.labels_train[ids], self.similarity_train[ids[:,None], ids]
256 |
257 |
258 |
259 |
260 | def evaluate_synset_with_similarity(it_eval, net, images_train, labels_train, lr_img, lr_txt, similarity_train, testloader, args, bert_test_embed, mom=0.9, l2=0.0005):
261 | """Added for LoRS, see Algorithm 2 in the paper"""
262 |
263 | net = net.to(args.device)
264 | images_train = images_train.to(args.device)
265 | labels_train = labels_train.to(args.device)
266 | similarity_train = similarity_train.to(args.device)
267 | Epoch = int(args.epoch_eval_train)
268 | optimizer = torch.optim.SGD([
269 | {'params': net.image_encoder.parameters(), 'lr': lr_img},
270 | {'params': net.text_projection.parameters(), 'lr': lr_txt},
271 | ], lr=0, momentum=mom, weight_decay=l2)
272 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[Epoch//2+1], gamma=0.1)
273 |
274 | trainloader = SimilarityDataloader(images_train, labels_train, similarity_train, batch_size=args.batch_train)
275 |
276 | start = time.time()
277 | acc_train_list = []
278 | loss_train_list = []
279 |
280 | eval_freq = Epoch // 10
281 | eval_epochs = list(range(0, Epoch+1, eval_freq)) + [Epoch] # eval 10 times in total
282 |
283 | best_val_result = None
284 | for ep in tqdm(range(Epoch+1), ncols=60):
285 | loss_train, acc_train = epoch(ep, trainloader, net, optimizer, args)
286 | acc_train_list.append(acc_train)
287 | loss_train_list.append(loss_train)
288 | if ep in eval_epochs:
289 | with torch.no_grad():
290 | score_val_i2t, score_val_t2i = epoch_test(testloader, net, args.device, bert_test_embed)
291 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt)
292 |
293 | print("[Eval_{it:02d}] Ep{ep} | Image R@1={img_r1:.2f} R@5={img_r5:.2f} R@10={img_r10:.2f} | Text R@1={txt_r1:.2f} R@5={txt_r5:.2f} R@10={txt_r10:.2f} | Mean={r_mean:.2f}".format(
294 | it=it_eval, ep=ep, **val_result
295 | ))
296 | if best_val_result is None or val_result["r_mean"] > best_val_result["r_mean"]:
297 | best_val_result = val_result
298 |
299 | lr_scheduler.step()
300 |
301 | time_train = time.time() - start
302 | assert best_val_result is not None
303 | return net, acc_train_list, best_val_result
304 |
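305 | # Minimal usage sketch of the LoRS evaluation path (shapes and learning rates are illustrative only):
306 | #   images_train: (N, 3, H, W) synthetic images, labels_train: (N, D) synthetic text features,
307 | #   similarity_train: (N, N) learned soft similarity matrix; SimilarityDataloader feeds each batch
308 | #   the matching (B, B) sub-matrix via similarity_train[ids[:, None], ids].
309 | #
310 | #   net, accs, best = evaluate_synset_with_similarity(0, net, images_train, labels_train,
311 | #                                                      lr_img, lr_txt, similarity_train,
312 | #                                                      testloader, args, bert_test_embed)
313 | #   where net is a CLIPModel_full instance and lr_img / lr_txt are the distilled learning rates.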
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | # Sourced directly from OpenAI's CLIP repo
2 | from collections import OrderedDict
3 | from typing import Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 |
11 | class Bottleneck(nn.Module):
12 | expansion = 4
13 |
14 | def __init__(self, inplanes, planes, stride=1):
15 | super().__init__()
16 |
17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 |
24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
25 |
26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
28 |
29 | self.relu = nn.ReLU(inplace=True)
30 | self.downsample = None
31 | self.stride = stride
32 |
33 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
35 | self.downsample = nn.Sequential(OrderedDict([
36 | ("-1", nn.AvgPool2d(stride)),
37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
38 | ("1", nn.BatchNorm2d(planes * self.expansion))
39 | ]))
40 |
41 | def forward(self, x: torch.Tensor):
42 | identity = x
43 |
44 | out = self.relu(self.bn1(self.conv1(x)))
45 | out = self.relu(self.bn2(self.conv2(out)))
46 | out = self.avgpool(out)
47 | out = self.bn3(self.conv3(out))
48 |
49 | if self.downsample is not None:
50 | identity = self.downsample(x)
51 |
52 | out += identity
53 | out = self.relu(out)
54 | return out
55 |
56 |
57 | class AttentionPool2d(nn.Module):
58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
59 | super().__init__()
60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
61 | self.k_proj = nn.Linear(embed_dim, embed_dim)
62 | self.q_proj = nn.Linear(embed_dim, embed_dim)
63 | self.v_proj = nn.Linear(embed_dim, embed_dim)
64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
65 | self.num_heads = num_heads
66 |
67 | def forward(self, x):
68 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
71 | x, _ = F.multi_head_attention_forward(
72 | query=x, key=x, value=x,
73 | embed_dim_to_check=x.shape[-1],
74 | num_heads=self.num_heads,
75 | q_proj_weight=self.q_proj.weight,
76 | k_proj_weight=self.k_proj.weight,
77 | v_proj_weight=self.v_proj.weight,
78 | in_proj_weight=None,
79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
80 | bias_k=None,
81 | bias_v=None,
82 | add_zero_attn=False,
83 | dropout_p=0,
84 | out_proj_weight=self.c_proj.weight,
85 | out_proj_bias=self.c_proj.bias,
86 | use_separate_proj_weight=True,
87 | training=self.training,
88 | need_weights=False
89 | )
90 |
91 | return x[0]
92 |
93 |
94 |
95 | class LayerNorm(nn.LayerNorm):
96 | """Subclass torch's LayerNorm to handle fp16."""
97 |
98 | def forward(self, x: torch.Tensor):
99 | orig_type = x.dtype
100 | ret = super().forward(x.type(torch.float32))
101 | return ret.type(orig_type)
102 |
103 |
104 | class QuickGELU(nn.Module):
105 | def forward(self, x: torch.Tensor):
106 | return x * torch.sigmoid(1.702 * x)
107 |
108 |
109 | class ResidualAttentionBlock(nn.Module):
110 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
111 | super().__init__()
112 |
113 | self.attn = nn.MultiheadAttention(d_model, n_head)
114 | self.ln_1 = LayerNorm(d_model)
115 | self.mlp = nn.Sequential(OrderedDict([
116 | ("c_fc", nn.Linear(d_model, d_model * 4)),
117 | ("gelu", QuickGELU()),
118 | ("c_proj", nn.Linear(d_model * 4, d_model))
119 | ]))
120 | self.ln_2 = LayerNorm(d_model)
121 | self.attn_mask = attn_mask
122 |
123 | def attention(self, x: torch.Tensor):
124 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
125 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
126 |
127 | def forward(self, x: torch.Tensor):
128 | x = x + self.attention(self.ln_1(x))
129 | x = x + self.mlp(self.ln_2(x))
130 | return x
131 |
132 |
133 |
134 | def convert_weights(model: nn.Module):
135 | """Convert applicable model parameters to fp16"""
136 |
137 | def _convert_weights_to_fp16(l):
138 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
139 | l.weight.data = l.weight.data.half()
140 | if l.bias is not None:
141 | l.bias.data = l.bias.data.half()
142 |
143 | if isinstance(l, nn.MultiheadAttention):
144 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
145 | tensor = getattr(l, attr)
146 | if tensor is not None:
147 | tensor.data = tensor.data.half()
148 |
149 | for name in ["text_projection", "proj"]:
150 | if hasattr(l, name):
151 | attr = getattr(l, name)
152 | if attr is not None:
153 | attr.data = attr.data.half()
154 |
155 | model.apply(_convert_weights_to_fp16)
156 |
157 |
158 | def build_model(state_dict: dict):
159 | vit = "visual.proj" in state_dict
160 |
161 | if vit:
162 | vision_width = state_dict["visual.conv1.weight"].shape[0]
163 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
164 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
165 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
166 | image_resolution = vision_patch_size * grid_size
167 | else:
168 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
169 | vision_layers = tuple(counts)
170 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
171 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
172 | vision_patch_size = None
173 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
174 | image_resolution = output_width * 32
175 |
176 | embed_dim = state_dict["text_projection"].shape[1]
177 | context_length = state_dict["positional_embedding"].shape[0]
178 | vocab_size = state_dict["token_embedding.weight"].shape[0]
179 | transformer_width = state_dict["ln_final.weight"].shape[0]
180 | transformer_heads = transformer_width // 64
181 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
182 |
183 | model = CLIP(
184 | embed_dim,
185 | image_resolution, vision_layers, vision_width, vision_patch_size,
186 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
187 | )
188 |
189 | for key in ["input_resolution", "context_length", "vocab_size"]:
190 | if key in state_dict:
191 | del state_dict[key]
192 |
193 | convert_weights(model)
194 | model.load_state_dict(state_dict)
195 | return model.eval()
--------------------------------------------------------------------------------
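A minimal usage sketch for the `build_model` helper above (an assumption-laden example, not part of the repo: it presumes an official OpenAI CLIP checkpoint such as "ViT-B-32.pt", which ships as a TorchScript archive whose state_dict carries the keys `build_model` inspects):

```python
import torch

# Hypothetical checkpoint path; any official CLIP release works the same way.
jit_model = torch.jit.load("ViT-B-32.pt", map_location="cpu")
state_dict = jit_model.state_dict()

model = build_model(state_dict)  # infers ViT vs. ModifiedResNet hyperparameters from the keys
model = model.float()            # convert_weights() casts to fp16; cast back for CPU inference
```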
/src/networks.py:
--------------------------------------------------------------------------------
1 | # import pdb
2 | # import warnings
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 | # from collections import OrderedDict
7 | # from typing import Tuple, Union
8 | import clip
9 | # from transformers import ViTConfig, ViTModel, AutoTokenizer, CLIPTextModel, CLIPTextConfig, CLIPProcessor, CLIPConfig
10 | import numpy as np
11 | from transformers import BertTokenizer, BertModel, DistilBertModel, DistilBertTokenizer, OpenAIGPTTokenizer, OpenAIGPTModel
12 | # from torchvision.models import resnet18, resnet
13 | from transformers.models.bert.modeling_bert import BertAttention, BertConfig
14 | import functools
15 | import copy
16 |
17 | from .similarity_mining import MultilabelContrastiveLoss
18 |
19 |
20 | # Caching text models, by Yue Xu
21 |
22 | @functools.lru_cache(maxsize=128)
23 | def get_bert_stuff():
24 | tokenizer = BertTokenizer.from_pretrained('./distill_utils/checkpoints/bert-base-uncased')
25 | BERT_model = BertModel.from_pretrained('./distill_utils/checkpoints/bert-base-uncased')
26 | return BERT_model, tokenizer
27 |
28 | @functools.lru_cache(maxsize=128)
29 | def get_distilbert_stuff():
30 | DistilBERT_model_name = "distilbert-base-uncased"
31 | DistilBERT_model = DistilBertModel.from_pretrained(DistilBERT_model_name, local_files_only=True)
32 | DistilBERT_tokenizer = DistilBertTokenizer.from_pretrained(DistilBERT_model_name, local_files_only=True)
33 | return DistilBERT_model, DistilBERT_tokenizer
34 |
35 | @functools.lru_cache(maxsize=128)
36 | def get_gpt1_stuff():
37 | model_name = './distill_utils/checkpoints/openai-gpt'
38 | model = OpenAIGPTModel.from_pretrained(model_name)
39 | tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name)
40 | return model, tokenizer
41 |
42 |
43 | # Acknowledgement to
44 | # https://github.com/kuangliu/pytorch-cifar,
45 | # https://github.com/BIGBALLON/CIFAR-ZOO,
46 |
47 | # adapted from
48 | # https://github.com/VICO-UoE/DatasetCondensation
49 | # https://github.com/Zasder3/train-CLIP
50 |
51 |
52 | ''' MLP '''
53 | class MLP(nn.Module):
54 | def __init__(self, channel, num_classes):
55 | super(MLP, self).__init__()
56 | self.fc_1 = nn.Linear(28*28*1 if channel==1 else 32*32*3, 128)
57 | self.fc_2 = nn.Linear(128, 128)
58 | self.fc_3 = nn.Linear(128, num_classes)
59 |
60 | def forward(self, x):
61 | out = x.view(x.size(0), -1)
62 | out = F.relu(self.fc_1(out))
63 | out = F.relu(self.fc_2(out))
64 | out = self.fc_3(out)
65 | return out
66 |
67 |
68 |
69 | ''' ConvNet '''
70 | class ConvNet(nn.Module):
71 | def __init__(self, channel, num_classes, net_width=128, net_depth=4, net_act='relu', net_norm='instancenorm', net_pooling='avgpooling', im_size = (224,224)):
72 | super(ConvNet, self).__init__()
73 |
74 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
75 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
76 | self.classifier = nn.Linear(num_feat, num_classes)
77 |
78 | def forward(self, x):
79 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device())
80 | out = self.features(x)
81 | out = out.view(out.size(0), -1)
82 | out = self.classifier(out)
83 | return out
84 |
85 | def _get_activation(self, net_act):
86 | if net_act == 'sigmoid':
87 | return nn.Sigmoid()
88 | elif net_act == 'relu':
89 | return nn.ReLU(inplace=True)
90 | elif net_act == 'leakyrelu':
91 | return nn.LeakyReLU(negative_slope=0.01)
92 | else:
93 | exit('unknown activation function: %s'%net_act)
94 |
95 | def _get_pooling(self, net_pooling):
96 | if net_pooling == 'maxpooling':
97 | return nn.MaxPool2d(kernel_size=2, stride=2)
98 | elif net_pooling == 'avgpooling':
99 | return nn.AvgPool2d(kernel_size=2, stride=2)
100 | elif net_pooling == 'none':
101 | return None
102 | else:
103 | exit('unknown net_pooling: %s'%net_pooling)
104 |
105 | def _get_normlayer(self, net_norm, shape_feat):
106 | # shape_feat = (c*h*w)
107 | if net_norm == 'batchnorm':
108 | return nn.BatchNorm2d(shape_feat[0], affine=True)
109 | elif net_norm == 'layernorm':
110 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
111 | elif net_norm == 'instancenorm':
112 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
113 | elif net_norm == 'groupnorm':
114 | return nn.GroupNorm(4, shape_feat[0], affine=True)
115 | elif net_norm == 'none':
116 | return None
117 | else:
118 | exit('unknown net_norm: %s'%net_norm)
119 |
120 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
121 | layers = []
122 | in_channels = channel
123 | if im_size[0] == 28:
124 | im_size = (32, 32)
125 | shape_feat = [in_channels, im_size[0], im_size[1]]
126 | for d in range(net_depth):
127 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
128 | shape_feat[0] = net_width
129 | if net_norm != 'none':
130 | layers += [self._get_normlayer(net_norm, shape_feat)]
131 | layers += [self._get_activation(net_act)]
132 | in_channels = net_width
133 | if net_pooling != 'none':
134 | layers += [self._get_pooling(net_pooling)]
135 | shape_feat[1] //= 2
136 | shape_feat[2] //= 2
137 |
138 |
139 | return nn.Sequential(*layers), shape_feat
140 |
141 |
142 | ''' ConvNetGAP '''
143 | class ConvNetGAP(nn.Module):
144 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)):
145 | super(ConvNetGAP, self).__init__()
146 |
147 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
148 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
149 | # self.classifier = nn.Linear(num_feat, num_classes)
150 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
151 | self.classifier = nn.Linear(shape_feat[0], num_classes)
152 |
153 | def forward(self, x):
154 | out = self.features(x)
155 | out = self.avgpool(out)
156 | out = out.view(out.size(0), -1)
157 | out = self.classifier(out)
158 | return out
159 |
160 | def _get_activation(self, net_act):
161 | if net_act == 'sigmoid':
162 | return nn.Sigmoid()
163 | elif net_act == 'relu':
164 | return nn.ReLU(inplace=True)
165 | elif net_act == 'leakyrelu':
166 | return nn.LeakyReLU(negative_slope=0.01)
167 | else:
168 | exit('unknown activation function: %s'%net_act)
169 |
170 | def _get_pooling(self, net_pooling):
171 | if net_pooling == 'maxpooling':
172 | return nn.MaxPool2d(kernel_size=2, stride=2)
173 | elif net_pooling == 'avgpooling':
174 | return nn.AvgPool2d(kernel_size=2, stride=2)
175 | elif net_pooling == 'none':
176 | return None
177 | else:
178 | exit('unknown net_pooling: %s'%net_pooling)
179 |
180 | def _get_normlayer(self, net_norm, shape_feat):
181 | # shape_feat = (c*h*w)
182 | if net_norm == 'batchnorm':
183 | return nn.BatchNorm2d(shape_feat[0], affine=True)
184 | elif net_norm == 'layernorm':
185 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
186 | elif net_norm == 'instancenorm':
187 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
188 | elif net_norm == 'groupnorm':
189 | return nn.GroupNorm(4, shape_feat[0], affine=True)
190 | elif net_norm == 'none':
191 | return None
192 | else:
193 | exit('unknown net_norm: %s'%net_norm)
194 |
195 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
196 | layers = []
197 | in_channels = channel
198 | if im_size[0] == 28:
199 | im_size = (32, 32)
200 | shape_feat = [in_channels, im_size[0], im_size[1]]
201 | for d in range(net_depth):
202 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
203 | shape_feat[0] = net_width
204 | if net_norm != 'none':
205 | layers += [self._get_normlayer(net_norm, shape_feat)]
206 | layers += [self._get_activation(net_act)]
207 | in_channels = net_width
208 | if net_pooling != 'none':
209 | layers += [self._get_pooling(net_pooling)]
210 | shape_feat[1] //= 2
211 | shape_feat[2] //= 2
212 |
213 | return nn.Sequential(*layers), shape_feat
214 |
215 |
216 | ''' LeNet '''
217 | class LeNet(nn.Module):
218 | def __init__(self, channel, num_classes):
219 | super(LeNet, self).__init__()
220 | self.features = nn.Sequential(
221 | nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0),
222 | nn.ReLU(inplace=True),
223 | nn.MaxPool2d(kernel_size=2, stride=2),
224 | nn.Conv2d(6, 16, kernel_size=5),
225 | nn.ReLU(inplace=True),
226 | nn.MaxPool2d(kernel_size=2, stride=2),
227 | )
228 | self.fc_1 = nn.Linear(16 * 5 * 5, 120)
229 | self.fc_2 = nn.Linear(120, 84)
230 | self.fc_3 = nn.Linear(84, num_classes)
231 |
232 | def forward(self, x):
233 | x = self.features(x)
234 | x = x.view(x.size(0), -1)
235 | x = F.relu(self.fc_1(x))
236 | x = F.relu(self.fc_2(x))
237 | x = self.fc_3(x)
238 | return x
239 |
240 |
241 |
242 | ''' AlexNet '''
243 | class AlexNet(nn.Module):
244 | def __init__(self, channel, num_classes):
245 | super(AlexNet, self).__init__()
246 | self.features = nn.Sequential(
247 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
248 | nn.ReLU(inplace=True),
249 | nn.MaxPool2d(kernel_size=2, stride=2),
250 | nn.Conv2d(128, 192, kernel_size=5, padding=2),
251 | nn.ReLU(inplace=True),
252 | nn.MaxPool2d(kernel_size=2, stride=2),
253 | nn.Conv2d(192, 256, kernel_size=3, padding=1),
254 | nn.ReLU(inplace=True),
255 | nn.Conv2d(256, 192, kernel_size=3, padding=1),
256 | nn.ReLU(inplace=True),
257 | nn.Conv2d(192, 192, kernel_size=3, padding=1),
258 | nn.ReLU(inplace=True),
259 | nn.MaxPool2d(kernel_size=2, stride=2),
260 | )
261 | self.fc = nn.Linear(192 * 4 * 4, num_classes)
262 |
263 | def forward(self, x):
264 | x = self.features(x)
265 | x = x.view(x.size(0), -1)
266 | x = self.fc(x)
267 | return x
268 |
269 |
270 |
271 | ''' VGG '''
272 | cfg_vgg = {
273 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
274 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
275 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
276 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
277 | }
278 | class VGG(nn.Module):
279 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'):
280 | super(VGG, self).__init__()
281 | self.channel = channel
282 | self.features = self._make_layers(cfg_vgg[vgg_name], norm)
283 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes)
284 |
285 | def forward(self, x):
286 | x = self.features(x)
287 | x = x.view(x.size(0), -1)
288 | x = self.classifier(x)
289 | return x
290 |
291 | def _make_layers(self, cfg, norm):
292 | layers = []
293 | in_channels = self.channel
294 | for ic, x in enumerate(cfg):
295 | if x == 'M':
296 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
297 | else:
298 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1),
299 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x),
300 | nn.ReLU(inplace=True)]
301 | in_channels = x
302 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
303 | return nn.Sequential(*layers)
304 |
305 |
306 | def VGG11(channel, num_classes):
307 | return VGG('VGG11', channel, num_classes)
308 | def VGG11BN(channel, num_classes):
309 | return VGG('VGG11', channel, num_classes, norm='batchnorm')
310 | def VGG13(channel, num_classes):
311 | return VGG('VGG13', channel, num_classes)
312 | def VGG16(channel, num_classes):
313 | return VGG('VGG16', channel, num_classes)
314 | def VGG19(channel, num_classes):
315 | return VGG('VGG19', channel, num_classes)
316 |
317 |
318 | ''' ResNet_AP '''
319 | # The conv(stride=2) is replaced by conv(stride=1) + avgpool(kernel_size=2, stride=2)
320 |
321 | class BasicBlock_AP(nn.Module):
322 | expansion = 1
323 |
324 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
325 | super(BasicBlock_AP, self).__init__()
326 | self.norm = norm
327 | self.stride = stride
328 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification
329 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
330 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
331 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
332 |
333 | self.shortcut = nn.Sequential()
334 | if stride != 1 or in_planes != self.expansion * planes:
335 | self.shortcut = nn.Sequential(
336 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False),
337 | nn.AvgPool2d(kernel_size=2, stride=2), # modification
338 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
339 | )
340 |
341 | def forward(self, x):
342 | out = F.relu(self.bn1(self.conv1(x)))
343 | if self.stride != 1: # modification
344 | out = F.avg_pool2d(out, kernel_size=2, stride=2)
345 | out = self.bn2(self.conv2(out))
346 | out += self.shortcut(x)
347 | out = F.relu(out)
348 | return out
349 |
350 |
351 | class Bottleneck_AP(nn.Module):
352 | expansion = 4
353 |
354 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
355 | super(Bottleneck_AP, self).__init__()
356 | self.norm = norm
357 | self.stride = stride
358 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
359 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
360 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification
361 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
362 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
363 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
364 |
365 | self.shortcut = nn.Sequential()
366 | if stride != 1 or in_planes != self.expansion * planes:
367 | self.shortcut = nn.Sequential(
368 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False),
369 | nn.AvgPool2d(kernel_size=2, stride=2), # modification
370 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
371 | )
372 |
373 | def forward(self, x):
374 | out = F.relu(self.bn1(self.conv1(x)))
375 | out = F.relu(self.bn2(self.conv2(out)))
376 | if self.stride != 1: # modification
377 | out = F.avg_pool2d(out, kernel_size=2, stride=2)
378 | out = self.bn3(self.conv3(out))
379 | out += self.shortcut(x)
380 | out = F.relu(out)
381 | return out
382 |
383 |
384 | class ResNet_AP(nn.Module):
385 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
386 | super(ResNet_AP, self).__init__()
387 | self.in_planes = 64
388 | self.norm = norm
389 |
390 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
391 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
392 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
393 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
394 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
395 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
396 | self.classifier = nn.Linear(512 * block.expansion * 3 * 3 if channel==1 else 512 * block.expansion * 4 * 4, num_classes) # modification
397 |
398 | def _make_layer(self, block, planes, num_blocks, stride):
399 | strides = [stride] + [1] * (num_blocks - 1)
400 | layers = []
401 | for stride in strides:
402 | layers.append(block(self.in_planes, planes, stride, self.norm))
403 | self.in_planes = planes * block.expansion
404 | return nn.Sequential(*layers)
405 |
406 | def forward(self, x):
407 | out = F.relu(self.bn1(self.conv1(x)))
408 | out = self.layer1(out)
409 | out = self.layer2(out)
410 | out = self.layer3(out)
411 | out = self.layer4(out)
412 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification
413 | out = out.view(out.size(0), -1)
414 | out = self.classifier(out)
415 | return out
416 |
417 |
418 | def ResNet18BN_AP(channel, num_classes):
419 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm')
420 |
421 | def ResNet18_AP(channel, num_classes):
422 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes)
423 |
424 |
425 | ''' ResNet '''
426 |
427 | class BasicBlock(nn.Module):
428 | expansion = 1
429 |
430 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
431 | super(BasicBlock, self).__init__()
432 | self.norm = norm
433 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
434 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
435 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
436 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
437 |
438 | self.shortcut = nn.Sequential()
439 | if stride != 1 or in_planes != self.expansion*planes:
440 | self.shortcut = nn.Sequential(
441 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
442 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
443 | )
444 |
445 | def forward(self, x):
446 | out = F.relu(self.bn1(self.conv1(x)))
447 | out = self.bn2(self.conv2(out))
448 | out += self.shortcut(x)
449 | out = F.relu(out)
450 | return out
451 |
452 |
453 | class Bottleneck(nn.Module):
454 | expansion = 4
455 |
456 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
457 | super(Bottleneck, self).__init__()
458 | self.norm = norm
459 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
460 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
461 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
462 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
463 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
464 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
465 |
466 | self.shortcut = nn.Sequential()
467 | if stride != 1 or in_planes != self.expansion*planes:
468 | self.shortcut = nn.Sequential(
469 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
470 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
471 | )
472 |
473 | def forward(self, x):
474 | out = F.relu(self.bn1(self.conv1(x)))
475 | out = F.relu(self.bn2(self.conv2(out)))
476 | out = self.bn3(self.conv3(out))
477 | out += self.shortcut(x)
478 | out = F.relu(out)
479 | return out
480 |
481 |
482 | class ResNetImageNet(nn.Module):
483 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
484 | super(ResNetImageNet, self).__init__()
485 | self.in_planes = 64
486 | self.norm = norm
487 |
488 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
489 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
490 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
491 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
492 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
493 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
494 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
495 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
496 | self.classifier = nn.Linear(512*block.expansion, num_classes)
497 |
498 | def _make_layer(self, block, planes, num_blocks, stride):
499 | strides = [stride] + [1]*(num_blocks-1)
500 | layers = []
501 | for stride in strides:
502 | layers.append(block(self.in_planes, planes, stride, self.norm))
503 | self.in_planes = planes * block.expansion
504 | return nn.Sequential(*layers)
505 |
506 | def forward(self, x):
507 | out = F.relu(self.bn1(self.conv1(x)))
508 | out = self.maxpool(out)
509 | out = self.layer1(out)
510 | out = self.layer2(out)
511 | out = self.layer3(out)
512 | out = self.layer4(out)
513 | # out = F.avg_pool2d(out, 4)
514 | # out = out.view(out.size(0), -1)
515 | out = self.avgpool(out)
516 | out = torch.flatten(out, 1)
517 | out = self.classifier(out)
518 | return out
519 |
520 |
521 | def ResNet18BN(channel, num_classes):
522 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm')
523 |
524 | def ResNet18(channel, num_classes):
525 | return ResNet_gn(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes)
526 |
527 | def ResNet34(channel, num_classes):
528 | return ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes)
529 |
530 | def ResNet50(channel, num_classes):
531 | return ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes)
532 |
533 | def ResNet101(channel, num_classes):
534 | return ResNet(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes)
535 |
536 | def ResNet152(channel, num_classes):
537 | return ResNet(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes)
538 |
539 | def ResNet18ImageNet(channel, num_classes):
540 | return ResNetImageNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes)
541 |
542 | def ResNet6ImageNet(channel, num_classes):
543 | return ResNetImageNet(BasicBlock, [1,1,1,1], channel=channel, num_classes=num_classes)
544 |
545 | def resnet18_gn(pretrained=False, **kwargs):
546 | """Constructs a ResNet-18 model.
547 | """
548 | model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
549 | return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))
550 |
551 |
552 | ## Sourced directly from OpenAI's CLIP repo
553 | class ModifiedResNet(nn.Module):
554 | """
555 | A ResNet class that is similar to torchvision's but contains the following changes:
556 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
557 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
558 | - The final pooling layer is a QKV attention instead of an average pool
559 | """
560 |
561 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
562 | super().__init__()
563 | self.output_dim = output_dim
564 | self.input_resolution = input_resolution
565 |
566 | # the 3-layer stem
567 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
568 | self.bn1 = nn.BatchNorm2d(width // 2)
569 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
570 | self.bn2 = nn.BatchNorm2d(width // 2)
571 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
572 | self.bn3 = nn.BatchNorm2d(width)
573 | self.avgpool = nn.AvgPool2d(2)
574 | self.relu = nn.ReLU(inplace=True)
575 |
576 | # residual layers
577 | self._inplanes = width # this is a *mutable* variable used during construction
578 | self.layer1 = self._make_layer(width, layers[0])
579 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
580 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
581 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
582 |
583 | embed_dim = width * 32 # the ResNet feature dimension
584 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
585 |
586 | def _make_layer(self, planes, blocks, stride=1):
587 | layers = [Bottleneck(self._inplanes, planes, stride)]
588 |
589 | self._inplanes = planes * Bottleneck.expansion
590 | for _ in range(1, blocks):
591 | layers.append(Bottleneck(self._inplanes, planes))
592 |
593 | return nn.Sequential(*layers)
594 |
595 | def forward(self, x):
596 | def stem(x):
597 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
598 | x = self.relu(bn(conv(x)))
599 | x = self.avgpool(x)
600 | return x
601 |
602 | x = x.type(self.conv1.weight.dtype)
603 | x = stem(x)
604 | x = self.layer1(x)
605 | x = self.layer2(x)
606 | x = self.layer3(x)
607 | x = self.layer4(x)
608 | x = self.attnpool(x)
609 |
610 | return x
611 |
612 |
613 | class AttentionPool2d(nn.Module):
614 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
615 | super().__init__()
616 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
617 | self.k_proj = nn.Linear(embed_dim, embed_dim)
618 | self.q_proj = nn.Linear(embed_dim, embed_dim)
619 | self.v_proj = nn.Linear(embed_dim, embed_dim)
620 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
621 | self.num_heads = num_heads
622 |
623 | def forward(self, x):
624 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
625 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
626 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
627 | x, _ = F.multi_head_attention_forward(
628 | query=x, key=x, value=x,
629 | embed_dim_to_check=x.shape[-1],
630 | num_heads=self.num_heads,
631 | q_proj_weight=self.q_proj.weight,
632 | k_proj_weight=self.k_proj.weight,
633 | v_proj_weight=self.v_proj.weight,
634 | in_proj_weight=None,
635 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
636 | bias_k=None,
637 | bias_v=None,
638 | add_zero_attn=False,
639 | dropout_p=0,
640 | out_proj_weight=self.c_proj.weight,
641 | out_proj_bias=self.c_proj.bias,
642 | use_separate_proj_weight=True,
643 | training=self.training,
644 | need_weights=False
645 | )
646 |
647 | return x[0]
648 |
649 | import timm
650 |
651 | class ProjectionHead(nn.Module):
652 | def __init__(
653 | self,
654 | embedding_dim,
655 | projection_dim=768,
656 | dropout=0.1
657 | ):
658 | super().__init__()
659 | self.projection = nn.Linear(embedding_dim, projection_dim)
660 | self.gelu = nn.GELU()
661 | self.fc = nn.Linear(projection_dim, projection_dim)
662 | self.dropout = nn.Dropout(dropout)
663 | self.layer_norm = nn.LayerNorm(projection_dim)
664 |
665 | def forward(self, x):
666 | projected = self.projection(x)
667 | x = self.gelu(projected)
668 | x = self.fc(x)
669 | x = self.dropout(x)
670 | x = x + projected
671 | x = self.layer_norm(x)
672 | return x
673 |
674 |
675 |
676 | @functools.lru_cache(maxsize=128)
677 | def load_from_timm(model_name, pretrained):
678 | if model_name == 'clip':
679 | raise NotImplementedError("it is unfair to use pretrained clip")
680 | # if pretrained:
681 | # model, preprocess = clip.load("ViT-B/32", device='cuda')
682 | # else:
683 | # configuration = ViTConfig()
684 | # model = ViTModel(configuration)
685 |
686 | elif model_name == 'nfnet':
687 | model = timm.create_model('nfnet_l0', pretrained=pretrained, num_classes=0, global_pool="avg",
688 | pretrained_cfg_overlay=dict(file='distill_utils/checkpoints/nfnet_l0_ra2-45c6688d.pth'),)
689 | elif model_name == 'vit':
690 | model = timm.create_model('vit_tiny_patch16_224', pretrained=True)
691 | elif model_name == 'nf_resnet50':
692 | model = timm.create_model('nf_resnet50', pretrained=True)
693 | elif model_name == 'nf_regnet':
694 | model = timm.create_model('nf_regnet_b1', pretrained=True)
695 | elif model_name=="efficientvit_m5":
696 | model = timm.create_model(model_name, num_classes=0, pretrained=True)
697 | else:
698 | model = timm.create_model(model_name, num_classes=0, pretrained=True, global_pool="avg")
699 |
700 |
701 | return model
702 |
703 |
704 |
705 |
706 | class ImageEncoder(nn.Module):
707 | """
708 | Encode images to a fixed size vector
709 | """
710 |
711 | def __init__(self, args):
712 | super().__init__()
713 | self.model_name = args.image_encoder
714 | self.pretrained = args.image_pretrained
715 | self.trainable = args.image_trainable
716 |
717 | self.model = copy.deepcopy(load_from_timm(self.model_name, self.pretrained))
718 | # use cached pretrained model
719 |
720 | for p in self.model.parameters():
721 | p.requires_grad = self.trainable
722 |
723 | def forward(self, x):
724 | if self.model_name == 'clip' and self.pretrained:
725 | return self.model.encode_image(x)
726 | else:
727 | return self.model(x)
728 |
729 | def gradient(self, x, y):
730 | # Compute the gradient of the mean squared error loss with respect to the weights
731 | loss = self.loss(x, y)
732 | grad = torch.autograd.grad(loss, self.parameters(), create_graph=True)
733 | return torch.cat([g.view(-1) for g in grad])
734 |
735 |
736 |
737 |
738 | class TextEncoder(nn.Module):
739 | def __init__(self, args):
740 | super().__init__()
741 | self.pretrained = args.text_pretrained
742 | self.trainable = args.text_trainable
743 | self.model_name = args.text_encoder
744 |
745 | if self.model_name == 'clip':
746 | self.model, preprocess = clip.load("ViT-B/32", device='cuda', download_root="distill_utils/checkpoints")
747 | elif self.model_name == 'bert':
748 | pt_model, self.tokenizer = get_bert_stuff()
749 | if args.text_pretrained:
750 | self.model = pt_model
751 | else:
752 | self.model = BertModel(BertConfig())
753 | self.model.init_weights()
754 | elif self.model_name == 'distilbert':
755 | self.model, self.tokenizer = get_distilbert_stuff()
756 | elif self.model_name == 'gpt1':
757 | self.model, self.tokenizer = get_gpt1_stuff()
758 | else:
759 | raise NotImplementedError(self.model_name)
760 |
761 | for p in self.model.parameters():
762 | p.requires_grad = self.trainable
763 |
764 | # we are using the CLS token hidden representation as the sentence's embedding
765 | self.target_token_idx = 0
766 |
767 | def forward(self, texts, device='cuda'):
768 | if self.model_name == 'clip':
769 | output = self.model.encode_text(clip.tokenize(texts).to('cuda'))
770 |
771 | elif self.model_name == 'bert':
772 | # Tokenize the input text
773 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True)
774 | input_ids = encoding['input_ids'].to(device)
775 | attention_mask = encoding['attention_mask'].to(device)
776 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :]
777 |
778 | elif self.model_name == 'distilbert':
779 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True)
780 | input_ids = encoding['input_ids'].to(device)
781 | attention_mask = encoding['attention_mask'].to(device)
782 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :]
783 |
784 | elif self.model_name == 'gpt1':
785 | self.tokenizer.pad_token = ' '
786 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True)
787 | input_ids = encoding['input_ids'].to(device)
788 | attention_mask = encoding['attention_mask'].to(device)
789 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :]
790 |
791 | return output
792 |
793 |
794 |
795 |
796 | class CLIPModel_full(nn.Module):
797 | def __init__(
798 | self,
799 | args,
800 | train_logit_scale=False,
801 | temperature=0.07,
802 | eval_stage=False
803 | ):
804 | super().__init__()
805 |
806 | if args.image_encoder == 'nfnet':
807 | if eval_stage:
808 | self.image_embedding = 1000 #2048
809 | else:
810 | self.image_embedding = 2304
811 | elif args.image_encoder == 'convnet':
812 | self.image_embedding = 768
813 | elif args.image_encoder == 'resnet18':
814 | self.image_embedding = 512
815 | elif args.image_encoder == 'convnext':
816 | self.image_embedding = 640
817 | else:
818 | self.image_embedding = 1000
819 |
820 | if args.text_encoder == 'clip':
821 | self.text_embedding = 512
822 | elif args.text_encoder == 'bert':
823 | self.text_embedding = 768
824 | elif args.text_encoder == 'distilbert':
825 | self.text_embedding = 768
826 | elif args.text_encoder == 'gpt1':
827 | self.text_embedding = 768
828 | else:
829 | raise NotImplementedError
830 |
831 | self.image_encoder = ImageEncoder(args)
832 |
833 | self.text_encoder = TextEncoder(args)
834 |
835 |
836 | if args.only_has_image_projection:
837 | self.image_projection = ProjectionHead(embedding_dim=self.image_embedding)
838 | self.text_projection = ProjectionHead(embedding_dim=self.text_embedding, projection_dim=self.image_embedding).to('cuda')
839 | if train_logit_scale:
840 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1.0 / temperature))
841 | else:
842 | self.logit_scale = torch.ones([]) * np.log(1.0 / temperature)
843 |
844 |
845 | self.args = args
846 | self.distill = args.distill
847 |
848 | self.multilabel_criterion = MultilabelContrastiveLoss(args.loss_type)
849 |
850 |
851 | def forward(self, image, caption, epoch, similarity=None):
852 | self.image_encoder = self.image_encoder.to('cuda')
853 | self.text_encoder = self.text_encoder.to('cuda')
854 |
855 | image_features = self.image_encoder(image)
856 | text_features = caption if self.distill else self.text_encoder(caption)
857 |
858 | use_image_project = False
859 | im_embed = image_features.float() if not use_image_project else self.image_projection(image_features.float())
860 |
861 | txt_embed = self.text_projection(text_features.float())
862 |
863 | combined_image_features = im_embed
864 | combined_text_features = txt_embed
865 | image_features = combined_image_features / combined_image_features.norm(dim=1, keepdim=True)
866 | text_features = combined_text_features / combined_text_features.norm(dim=1, keepdim=True)
867 |
868 | image_logits = self.logit_scale.exp() * image_features @ text_features.t()
869 |
870 | ground_truth = torch.arange(len(image_logits)).type_as(image_logits).long()
871 | acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum().item()
872 | acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum().item()
873 | acc = (acc_i + acc_t) / 2
874 |
875 |
876 | if similarity is None:
877 | loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2
878 | else:
879 | loss = self.multilabel_criterion(image_logits, similarity)
880 |
881 | return loss, acc
--------------------------------------------------------------------------------
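A small self-contained sketch of the symmetric contrastive objective that `CLIPModel_full.forward` computes when no similarity matrix is supplied (the tensors here are random stand-ins; only the shapes and the loss arithmetic mirror the code above):

```python
import torch
import torch.nn.functional as F

# Toy batch of 4 matched image/text embeddings (random stand-ins).
img = F.normalize(torch.randn(4, 768), dim=1)
txt = F.normalize(torch.randn(4, 768), dim=1)
logit_scale = 1 / 0.07                       # exp(log(1/temperature)), the model's default scale

image_logits = logit_scale * img @ txt.t()   # (4, 4): row i scores image i against every caption
ground_truth = torch.arange(4)               # the i-th image matches the i-th caption

# similarity is None: plain symmetric InfoNCE over both retrieval directions.
loss = (F.cross_entropy(image_logits, ground_truth)
        + F.cross_entropy(image_logits.t(), ground_truth)) / 2

# With LoRS, a learned soft matrix replaces the hard arange target and the loss
# becomes MultilabelContrastiveLoss(loss_type)(image_logits, similarity).
```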
/src/reparam_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | import types
5 | from collections import namedtuple
6 | from contextlib import contextmanager
7 |
8 |
9 | class ReparamModule(nn.Module):
10 | def _get_module_from_name(self, mn):
11 | if mn == '':
12 | return self
13 | m = self
14 | for p in mn.split('.'):
15 | m = getattr(m, p)
16 | return m
17 |
18 | def __init__(self, module):
19 | super(ReparamModule, self).__init__()
20 | self.module = module
21 |
22 | param_infos = [] # (module name/path, param name)
23 | shared_param_memo = {}
24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name)
25 | params = []
26 | param_numels = []
27 | param_shapes = []
28 | for mn, m in self.named_modules():
29 | for n, p in m.named_parameters(recurse=False):
30 | if p is not None:
31 | if p in shared_param_memo:
32 | shared_mn, shared_n = shared_param_memo[p]
33 | shared_param_infos.append((mn, n, shared_mn, shared_n))
34 | else:
35 | shared_param_memo[p] = (mn, n)
36 | param_infos.append((mn, n))
37 | params.append(p.detach())
38 | param_numels.append(p.numel())
39 | param_shapes.append(p.size())
40 |
41 | assert len(set(p.dtype for p in params)) <= 1, \
42 | "expects all parameters in module to have same dtype"
43 |
44 | # store the info for unflatten
45 | self._param_infos = tuple(param_infos)
46 | self._shared_param_infos = tuple(shared_param_infos)
47 | self._param_numels = tuple(param_numels)
48 | self._param_shapes = tuple(param_shapes)
49 |
50 | # flatten
51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
52 | self.register_parameter('flat_param', flat_param)
53 | self.param_numel = flat_param.numel()
54 | del params
55 | del shared_param_memo
56 |
57 | # deregister the names as parameters
58 | for mn, n in self._param_infos:
59 | delattr(self._get_module_from_name(mn), n)
60 | for mn, n, _, _ in self._shared_param_infos:
61 | delattr(self._get_module_from_name(mn), n)
62 |
63 | # register the views as plain attributes
64 | self._unflatten_param(self.flat_param)
65 |
66 | # now buffers
67 | # they are not reparametrized. just store info as (module, name, buffer)
68 | buffer_infos = []
69 | for mn, m in self.named_modules():
70 | for n, b in m.named_buffers(recurse=False):
71 | if b is not None:
72 | buffer_infos.append((mn, n, b))
73 |
74 | self._buffer_infos = tuple(buffer_infos)
75 | self._traced_self = None
76 |
77 | def trace(self, example_input, **trace_kwargs):
78 | assert self._traced_self is None, 'This ReparamModule is already traced'
79 |
80 | if isinstance(example_input, torch.Tensor):
81 | example_input = (example_input,)
82 | example_input = tuple(example_input)
83 | example_param = (self.flat_param.detach().clone(),)
84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),)
85 |
86 | self._traced_self = torch.jit.trace_module(
87 | self,
88 | inputs=dict(
89 | _forward_with_param=example_param + example_input,
90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input,
91 | ),
92 | **trace_kwargs,
93 | )
94 |
95 | # replace forwards with traced versions
96 | self._forward_with_param = self._traced_self._forward_with_param
97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers
98 | return self
99 |
100 | def clear_views(self):
101 | for mn, n in self._param_infos:
102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr
103 |
104 | def _apply(self, *args, **kwargs):
105 | if self._traced_self is not None:
106 | self._traced_self._apply(*args, **kwargs)
107 | return self
108 | return super(ReparamModule, self)._apply(*args, **kwargs)
109 |
110 | def _unflatten_param(self, flat_param):
111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
112 | for (mn, n), p in zip(self._param_infos, ps):
113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr
114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
116 |
117 | @contextmanager
118 | def unflattened_param(self, flat_param):
119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos]
120 | self._unflatten_param(flat_param)
121 | yield
122 | # Why not just `self._unflatten_param(self.flat_param)`?
123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583
124 |         # 2. slightly faster since it does not require reconstructing the split+view
125 | # graph
126 | for (mn, n), p in zip(self._param_infos, saved_views):
127 | setattr(self._get_module_from_name(mn), n, p)
128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
130 |
131 | @contextmanager
132 | def replaced_buffers(self, buffers):
133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers):
134 | setattr(self._get_module_from_name(mn), n, new_b)
135 | yield
136 | for mn, n, old_b in self._buffer_infos:
137 | setattr(self._get_module_from_name(mn), n, old_b)
138 |
139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs):
140 | with self.unflattened_param(flat_param):
141 | with self.replaced_buffers(buffers):
142 | return self.module(*inputs, **kwinputs)
143 |
144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs):
145 | with self.unflattened_param(flat_param):
146 | return self.module(*inputs, **kwinputs)
147 |
148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs):
149 |         if flat_param is None:
150 |             flat_param = self.flat_param
151 |         flat_param = torch.squeeze(flat_param)
152 |         # print("PARAMS ON DEVICE: ", flat_param.get_device())
153 |         # print("DATA ON DEVICE: ", inputs[0].get_device())
154 |         # flat_param.to("cuda:{}".format(inputs[0].get_device()))
155 |         # self.module.to("cuda:{}".format(inputs[0].get_device()))
156 | if buffers is None:
157 | return self._forward_with_param(flat_param, *inputs, **kwinputs)
158 | else:
159 | return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs)
--------------------------------------------------------------------------------
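`ReparamModule` flattens every parameter of a wrapped network into a single vector so that a forward pass can be driven by an externally supplied parameter vector, which is what trajectory-matching distillation needs. A minimal sketch with a toy network (the two-layer MLP below is only an illustration, not a model from this repo):

```python
import torch
import torch.nn as nn

net = ReparamModule(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2)))
print(net.param_numel)                       # size of the flattened parameter vector

x = torch.randn(4, 8)
theta = net.flat_param.detach().clone().requires_grad_(True)

out = net(x, flat_param=theta)               # forward pass through the module with external weights
loss = out.pow(2).mean()
grads = torch.autograd.grad(loss, theta)[0]  # gradient w.r.t. the flat vector, shape (param_numel,)
```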
/src/similarity_mining.py:
--------------------------------------------------------------------------------
1 | """Similarity modules and loss modules for LoRS
2 | by Yue Xu"""
3 |
4 | from abc import ABC, abstractmethod
5 | from typing import List
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 |
12 | class BaseSimilarityGenerator(nn.Module, ABC):
13 | def __init__(self, dim):
14 | super().__init__()
15 | self.dim = dim
16 |
17 | @abstractmethod
18 | def generate_with_param(self, params: List):
19 | pass
20 |
21 | @abstractmethod
22 | def get_indexed_parameters(self, indices=None) -> List:
23 | pass
24 |
25 | def load_params(self, params):
26 | raise NotImplementedError("")
27 |
28 |
29 |
30 | class LowRankSimilarityGenerator(BaseSimilarityGenerator):
31 | """Generate a parameterized similarity matrix S = aI + L@R'
32 | where a is a vector, I is the identity matrix, and B is a matrix of learnable parameters."""
33 | def __init__(self, dim, rank, alpha=0.1):
34 | super().__init__(dim)
35 | self.rank = float(rank)
36 |
37 | self.alpha = alpha
38 |
39 | self.diag_weight = nn.Parameter(torch.ones(dim))
40 | self.left = nn.Parameter(torch.randn(dim, rank))
41 | self.right = nn.Parameter(torch.zeros(dim, rank))
42 |
43 |
44 | def generate_with_param(self, params: List):
45 | a, l, r = params
46 |
47 | diag = torch.diag(a)
48 | sim = diag + l @ r.t() * self.alpha / self.rank
49 | return sim
50 |
51 | def get_indexed_parameters(self, indices=None) -> List:
52 | if indices is None:
53 | a, l, r = self.diag_weight, self.left, self.right
54 | else:
55 | a, l, r = self.diag_weight[indices], self.left[indices], self.right[indices]
56 |
57 | return [a, l, r]
58 |
59 | def load_params(self, params):
60 | a, l, r = params
61 | self.diag_weight.data = a
62 | self.left.data = l
63 | self.right.data = r
64 |
65 |
66 |
67 | class FullSimilarityGenerator(BaseSimilarityGenerator):
68 | def __init__(self, dim):
69 | super().__init__(dim)
70 | self.sim_mat = nn.Parameter(torch.eye(dim))
71 |
72 | def generate_with_param(self, params):
73 | assert len(params) == 1
74 | return params[0]
75 |
76 | def get_indexed_parameters(self, indices=None) -> List:
77 | if indices is None:
78 | sim = self.sim_mat
79 | else:
80 | sim = self.sim_mat[indices[:,None], indices]
81 | return [sim]
82 |
83 | def load_params(self, params):
84 | s = params[0]
85 | self.sim_mat.data = s
86 |
87 |
88 |
89 |
90 |
91 | class MultilabelContrastiveLoss(nn.Module):
92 | """
93 | Cross entropy loss with soft target.
94 | """
95 |
96 | def __init__(self, loss_type=None):
97 | """
98 | Args:
99 |             loss_type (str): specifies the soft-target contrastive objective, e.g.
100 |                 "KL", "BCE", "BalanceBCE"/"WBCE", "NCE"/"InfoNCE", "MSE" or "CWCL".
101 | """
102 | super().__init__()
103 | self.loss_type = loss_type
104 |
105 | self.kl_loss_func = nn.KLDivLoss(reduction="mean")
106 | self.bce_loss_func = nn.BCELoss(reduction="none")
107 |
108 | self.n_clusters = 4 # for KmeansBalanceBCE
109 |
110 |
111 |
112 | def __kl_criterion(self, logit, label):
113 | # batchsize = logit.shape[0]
114 | probs1 = F.log_softmax(logit, 1)
115 | probs2 = F.softmax(label * 10, 1)
116 | loss = self.kl_loss_func(probs1, probs2)
117 | return loss
118 |
119 | def __cwcl_criterion(self, logit, label):
120 | logprob = torch.log_softmax(logit, 1)
121 | loss_per = (label * logprob).sum(1) / (label.sum(1)+1e-6)
122 | loss = -loss_per.mean()
123 | return loss
124 |
125 |     def __infonce_nonconventional_criterion(self, logit, label):
126 | logprob = torch.log_softmax(logit, 1)
127 | loss_per = (label * logprob).sum(1)
128 | loss = -loss_per.mean()
129 | return loss
130 |
131 |
132 | def forward(self, logits, gt_matrix):
133 | gt_matrix = gt_matrix.to(logits.device)
134 |
135 | if self.loss_type == "KL":
136 | loss_i = self.__kl_criterion(logits, gt_matrix)
137 | loss_t = self.__kl_criterion(logits.t(), gt_matrix.t())
138 | return (loss_i + loss_t) * 0.5
139 | elif self.loss_type == "BCE":
140 | # print("BCE is called")
141 | probs1 = torch.sigmoid(logits)
142 | probs2 = gt_matrix # torch.sigmoid(gt_matrix)
143 | bce_loss = self.bce_loss_func(probs1, probs2)
144 | loss = bce_loss.mean()
145 | # loss = self.__general_cl_criterion(logits, gt_matrix, "BCE",
146 | # use_norm=False, use_negwgt=False)
147 | return loss
148 |
149 | elif self.loss_type in ["BalanceBCE", "WBCE"]:
150 | probs1 = torch.sigmoid(logits)
151 | probs2 = gt_matrix # torch.sigmoid(gt_matrix)
152 |
153 | loss_matrix = - probs2 * torch.log(probs1+1e-6) - (1-probs2) * torch.log(1-probs1+1e-6)
154 |
155 | pos_mask = (probs2>0.5).detach()
156 | neg_mask = ~pos_mask
157 |
158 | loss_pos = torch.where(pos_mask, loss_matrix, torch.tensor(0.0, device=probs1.device)).sum()
159 | loss_neg = torch.where(neg_mask, loss_matrix, torch.tensor(0.0, device=probs1.device)).sum()
160 |
161 | loss_pos /= (pos_mask.sum()+1e-6)
162 | loss_neg /= (neg_mask.sum()+1e-6)
163 |
164 | return (loss_pos+loss_neg)/2
165 |
166 |
167 | elif self.loss_type in ["NCE", "InfoNCE"]:
168 |             loss_i = self.__infonce_nonconventional_criterion(logits, gt_matrix)
169 |             loss_t = self.__infonce_nonconventional_criterion(logits.t(), gt_matrix.t())
170 | return (loss_i + loss_t) * 0.5
171 |
172 | elif self.loss_type == "MSE":
173 | probs = torch.sigmoid(logits)
174 | return F.mse_loss(probs, gt_matrix)
175 |
176 | elif self.loss_type == "CWCL":
177 | loss_i = self.__cwcl_criterion(logits, gt_matrix)
178 | loss_t = self.__cwcl_criterion(logits.t(), gt_matrix.t())
179 | return (loss_i + loss_t) * 0.5
180 |
181 | else:
182 | raise NotImplementedError(self.loss_type)
--------------------------------------------------------------------------------
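A hedged end-to-end sketch of how the two modules above fit together in LoRS: the generator produces a batch-indexed soft similarity target S = diag(a) + (alpha / r) * L @ R^T, and the multilabel loss scores the image-text logits against it (all sizes below are illustrative, not the paper's settings):

```python
import torch

gen = LowRankSimilarityGenerator(dim=100, rank=10, alpha=0.1)   # 100 synthetic pairs, rank-10 factors
crit = MultilabelContrastiveLoss(loss_type="WBCE")

idx = torch.randperm(100)[:32]                    # indices of the current synthetic batch
params = gen.get_indexed_parameters(idx)          # [a[idx], L[idx], R[idx]]
sim = gen.generate_with_param(params)             # (32, 32) soft similarity target

logits = torch.randn(32, 32)                      # stand-in for logit_scale * image_feat @ text_feat.T
loss = crit(logits, sim)
loss.backward()                                   # gradients flow into a, L and R
```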
/src/utils.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/VICO-UoE/DatasetCondensation
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | import kornia as K
10 | import tqdm
11 | from torch.utils.data import Dataset
12 | from torchvision import datasets, transforms
13 | from scipy.ndimage import rotate as scipyrotate  # scipy.ndimage.interpolation was removed in SciPy >= 1.10
14 | from src.networks import MLP, ConvNet, LeNet, AlexNet, VGG11BN, VGG11, ResNet18, ResNet18BN_AP, ResNet18_AP, ModifiedResNet, resnet18_gn
15 | import re
16 | import json
17 | import torch.distributed as dist
18 | from tqdm import tqdm
19 | from collections import defaultdict
20 |
21 | class Config:
22 | imagenette = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]
23 |
24 | # ["australian_terrier", "border_terrier", "samoyed", "beagle", "shih-tzu", "english_foxhound", "rhodesian_ridgeback", "dingo", "golden_retriever", "english_sheepdog"]
25 | imagewoof = [193, 182, 258, 162, 155, 167, 159, 273, 207, 229]
26 |
27 | # ["tabby_cat", "bengal_cat", "persian_cat", "siamese_cat", "egyptian_cat", "lion", "tiger", "jaguar", "snow_leopard", "lynx"]
28 | imagemeow = [281, 282, 283, 284, 285, 291, 292, 290, 289, 287]
29 |
30 | # ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"]
31 | imagesquawk = [84, 130, 88, 144, 145, 22, 96, 9, 100, 89]
32 |
33 | # ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"]
34 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948]
35 |
36 | # ["bee", "ladys slipper", "banana", "lemon", "corn", "school_bus", "honeycomb", "lion", "garden_spider", "goldfinch"]
37 | imageyellow = [309, 986, 954, 951, 987, 779, 599, 291, 72, 11]
38 |
39 | dict = {
40 | "imagenette" : imagenette,
41 | "imagewoof" : imagewoof,
42 | "imagefruit": imagefruit,
43 | "imageyellow": imageyellow,
44 | "imagemeow": imagemeow,
45 | "imagesquawk": imagesquawk,
46 | }
47 |
48 | config = Config()
49 |
50 | def get_dataset(args):
51 |
52 | if args.dataset == 'CIFAR10_clip':
53 | channel = 3
54 | im_size = (32, 32)
55 | num_classes = 768
56 | mean = [0.4914, 0.4822, 0.4465]
57 | std = [0.2023, 0.1994, 0.2010]
58 | if args.zca:
59 | transform = transforms.Compose([transforms.ToTensor()])
60 | else:
61 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
62 | dst_train = datasets.CIFAR10(args.data_path, train=True, download=True, transform=transform) # no augmentation
63 | dst_test = datasets.CIFAR10(args.data_path, train=False, download=True, transform=transform)
64 | class_names = dst_train.classes
65 | class_map = {x:x for x in range(num_classes)}
66 |
67 | else:
68 | exit('unknown dataset: %s'%args.dataset)
69 |
70 | if args.zca:
71 | images = []
72 | labels = []
73 | print("Train ZCA")
74 | for i in tqdm(range(len(dst_train))):
75 | im, lab = dst_train[i]
76 | images.append(im)
77 | labels.append(lab)
78 | images = torch.stack(images, dim=0).to(args.device)
79 | labels = torch.tensor(labels, dtype=torch.float, device="cpu")
80 | zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
81 | zca.fit(images)
82 | zca_images = zca(images).to("cpu")
83 | dst_train = TensorDataset(zca_images, labels)
84 |
85 | images = []
86 | labels = []
87 | print("Test ZCA")
88 | for i in tqdm(range(len(dst_test))):
89 | im, lab = dst_test[i]
90 | images.append(im)
91 | labels.append(lab)
92 | images = torch.stack(images, dim=0).to(args.device)
93 | labels = torch.tensor(labels, dtype=torch.float, device="cpu")
94 |
95 | zca_images = zca(images).to("cpu")
96 | dst_test = TensorDataset(zca_images, labels)
97 |
98 | args.zca_trans = zca
99 |
100 |
101 | testloader = torch.utils.data.DataLoader(dst_test, batch_size=128, shuffle=False, num_workers=2)
102 |
103 | if "flickr" not in args.dataset:
104 | dst_train_label, dst_test_label = None, None
105 | return channel, im_size, num_classes, mean, std, dst_train, dst_test, testloader, dst_train_label, dst_test_label
106 |
107 |
108 |
109 | class TensorDataset(Dataset):
110 | def __init__(self, images, labels): # images: n x c x h x w tensor
111 | self.images = images.detach().float()
112 | self.labels = labels.detach()
113 |
114 | def __getitem__(self, index):
115 | return self.images[index], self.labels[index]
116 |
117 | def __len__(self):
118 | return self.images.shape[0]
119 |
120 |
121 |
122 | def get_default_convnet_setting():
123 | net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling'
124 | return net_width, net_depth, net_act, net_norm, net_pooling
125 |
126 |
127 |
128 | def get_RN_network(model, vision_width, vision_layers, embed_dim, image_resolution, dist=True):
129 | if model == 'RN50':
130 | vision_heads = vision_width * 32 // 64
131 | net = ModifiedResNet(layers=vision_layers,
132 | output_dim=embed_dim,
133 | heads=vision_heads,
134 | input_resolution=image_resolution,
135 | width=vision_width)
136 | if dist:
137 | gpu_num = torch.cuda.device_count()
138 | if gpu_num>0:
139 | device = 'cuda'
140 | if gpu_num>1:
141 | net = nn.DataParallel(net)
142 | else:
143 | device = 'cpu'
144 | net = net.to(device)
145 |
146 | return net
147 |
148 | def get_network(model, channel, num_classes, im_size=(32, 32), dist=True):
149 | torch.random.manual_seed(int(time.time() * 1000) % 100000)
150 | net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()
151 |
152 | if model == 'MLP':
153 | net = MLP(channel=channel, num_classes=num_classes)
154 | elif model == 'ConvNet':
155 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
156 | elif model == 'LeNet':
157 | net = LeNet(channel=channel, num_classes=num_classes)
158 | elif model == 'AlexNet':
159 | net = AlexNet(channel=channel, num_classes=num_classes)
160 | elif model == 'VGG11':
161 | net = VGG11( channel=channel, num_classes=num_classes)
162 | elif model == 'VGG11BN':
163 | net = VGG11BN(channel=channel, num_classes=num_classes)
164 | elif model == 'ResNet18':
165 | net = ResNet18(channel=channel, num_classes=num_classes)
166 | elif model == 'ResNet18BN_AP':
167 | net = ResNet18BN_AP(channel=channel, num_classes=num_classes)
168 | elif model == 'ResNet18_AP':
169 | net = ResNet18_AP(channel=channel, num_classes=num_classes)
170 |
171 | elif model == 'ConvNetD1':
172 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
173 | elif model == 'ConvNetD2':
174 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
175 | elif model == 'ConvNetD3':
176 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
177 | elif model == 'ConvNetD4':
178 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
179 | elif model == 'ConvNetD5':
180 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=5, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
181 | elif model == 'ConvNetD6':
182 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=6, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
183 | elif model == 'ConvNetD7':
184 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=7, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
185 | elif model == 'ConvNetD8':
186 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=8, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
187 |
188 |
189 | elif model == 'ConvNetW32':
190 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
191 | elif model == 'ConvNetW64':
192 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
193 | elif model == 'ConvNetW128':
194 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
195 | elif model == 'ConvNetW256':
196 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
197 | elif model == 'ConvNetW512':
198 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=512, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
199 | elif model == 'ConvNetW1024':
200 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
201 |
202 | elif model == "ConvNetKIP":
203 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act,
204 | net_norm="none", net_pooling=net_pooling)
205 |
206 | elif model == 'ConvNetAS':
207 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling)
208 | elif model == 'ConvNetAR':
209 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling)
210 | elif model == 'ConvNetAL':
211 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling)
212 |
213 | elif model == 'ConvNetNN':
214 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling)
215 | elif model == 'ConvNetBN':
216 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling)
217 | elif model == 'ConvNetLN':
218 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling)
219 | elif model == 'ConvNetIN':
220 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling)
221 | elif model == 'ConvNetGN':
222 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling)
223 |
224 | elif model == 'ConvNetNP':
225 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none')
226 | elif model == 'ConvNetMP':
227 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling')
228 | elif model == 'ConvNetAP':
229 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling')
230 |
231 |
232 | else:
233 | net = None
234 | exit('DC error: unknown model')
235 |
236 | if dist:
237 | gpu_num = torch.cuda.device_count()
238 | if gpu_num>0:
239 | device = 'cuda'
240 | if gpu_num>1:
241 | net = nn.DataParallel(net)
242 | else:
243 | device = 'cpu'
244 | net = net.to(device)
245 |
246 | return net
247 |
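    | # A minimal usage sketch for get_network (assumes a 3-channel, 32x32 dataset with
    | # 10 classes; the model name must match one of the branches above):
    | #   net = get_network('ConvNetD3', channel=3, num_classes=10, im_size=(32, 32), dist=False)
    | #   logits = net(torch.randn(4, 3, 32, 32))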
248 |
249 |
250 | def get_time():
251 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
252 |
253 |
254 |
255 | def augment(images, dc_aug_param, device):
256 | # This can be sped up in the future.
257 |
258 |     if dc_aug_param is not None and dc_aug_param['strategy'] != 'none':
259 | scale = dc_aug_param['scale']
260 | crop = dc_aug_param['crop']
261 | rotate = dc_aug_param['rotate']
262 | noise = dc_aug_param['noise']
263 | strategy = dc_aug_param['strategy']
264 |
265 | shape = images.shape
266 | mean = []
267 | for c in range(shape[1]):
268 | mean.append(float(torch.mean(images[:,c])))
269 |
270 | def cropfun(i):
271 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device)
272 | for c in range(shape[1]):
273 | im_[c] = mean[c]
274 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i]
275 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0]
276 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]]
277 |
278 | def scalefun(i):
279 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
280 |         w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[3])
281 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0]
282 | mhw = max(h, w, shape[2], shape[3])
283 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device)
284 | r = int((mhw - h) / 2)
285 | c = int((mhw - w) / 2)
286 | im_[:, r:r + h, c:c + w] = tmp
287 | r = int((mhw - shape[2]) / 2)
288 | c = int((mhw - shape[3]) / 2)
289 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]]
290 |
291 | def rotatefun(i):
292 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean))
293 | r = int((im_.shape[-2] - shape[-2]) / 2)
294 | c = int((im_.shape[-1] - shape[-1]) / 2)
295 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device)
296 |
297 | def noisefun(i):
298 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device)
299 |
300 |
301 | augs = strategy.split('_')
302 |
303 | for i in range(shape[0]):
304 | choice = np.random.permutation(augs)[0] # randomly implement one augmentation
305 | if choice == 'crop':
306 | cropfun(i)
307 | elif choice == 'scale':
308 | scalefun(i)
309 | elif choice == 'rotate':
310 | rotatefun(i)
311 | elif choice == 'noise':
312 | noisefun(i)
313 |
314 | return images
315 |
316 |
317 |
318 | def get_daparam(dataset, model, model_eval, ipc):
319 |     # We find that augmentation does not always improve performance,
320 |     # so we only enable it for certain settings.
321 |
322 | dc_aug_param = dict()
323 | dc_aug_param['crop'] = 4
324 | dc_aug_param['scale'] = 0.2
325 | dc_aug_param['rotate'] = 45
326 | dc_aug_param['noise'] = 0.001
327 | dc_aug_param['strategy'] = 'none'
328 |
329 | if dataset == 'MNIST':
330 | dc_aug_param['strategy'] = 'crop_scale_rotate'
331 |
332 | if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier.
333 | dc_aug_param['strategy'] = 'crop_noise'
334 |
335 | return dc_aug_param
336 |
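    | # A minimal sketch combining get_daparam and augment (assumes `images` is an
    | # n x c x h x w tensor already on `device`):
    | #   dc_aug_param = get_daparam('MNIST', 'ConvNet', 'ConvNetBN', ipc=10)
    | #   images = augment(images, dc_aug_param, device=device)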
337 |
338 | def get_eval_pool(eval_mode, model, model_eval):
339 | if eval_mode == 'M': # multiple architectures
340 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18', 'LeNet']
341 | model_eval_pool = ['ConvNet', 'AlexNet', 'VGG11', 'ResNet18_AP', 'ResNet18']
342 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18']
343 | elif eval_mode == 'W': # ablation study on network width
344 | model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256']
345 | elif eval_mode == 'D': # ablation study on network depth
346 | model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4']
347 | elif eval_mode == 'A': # ablation study on network activation function
348 | model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL']
349 | elif eval_mode == 'P': # ablation study on network pooling layer
350 | model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP']
351 | elif eval_mode == 'N': # ablation study on network normalization layer
352 | model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN']
353 | elif eval_mode == 'S': # itself
354 | model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model]
355 | elif eval_mode == 'C':
356 | model_eval_pool = [model, 'ConvNet']
357 | else:
358 | model_eval_pool = [model_eval]
359 | return model_eval_pool
360 |
361 |
362 | class ParamDiffAug():
363 | def __init__(self):
364 |         self.aug_mode = 'S'  # 'S': randomly apply one of the listed strategies, 'M': apply all of them
365 | self.prob_flip = 0.5
366 | self.ratio_scale = 1.2
367 | self.ratio_rotate = 15.0
368 | self.ratio_crop_pad = 0.125
369 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5
370 | self.ratio_noise = 0.05
371 | self.brightness = 1.0
372 | self.saturation = 2.0
373 | self.contrast = 0.5
374 |
375 |
376 | def set_seed_DiffAug(param):
377 | if param.latestseed == -1:
378 | return
379 | else:
380 | torch.random.manual_seed(param.latestseed)
381 | param.latestseed += 1
382 |
383 |
384 | def DiffAugment(x, strategy='', seed = -1, param = None):
385 | if seed == -1:
386 | param.batchmode = False
387 | else:
388 | param.batchmode = True
389 |
390 | param.latestseed = seed
391 |
392 | if strategy == 'None' or strategy == 'none':
393 | return x
394 |
395 | if strategy:
396 |         if param.aug_mode == 'M':  # apply every augmentation in the strategy list
397 | for p in strategy.split('_'):
398 | for f in AUGMENT_FNS[p]:
399 | x = f(x, param)
400 | elif param.aug_mode == 'S':
401 | pbties = strategy.split('_')
402 | set_seed_DiffAug(param)
403 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
404 | for f in AUGMENT_FNS[p]:
405 | x = f(x, param)
406 | else:
407 | exit('Error ZH: unknown augmentation mode.')
408 | x = x.contiguous()
409 | return x
410 |
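    | # A minimal usage sketch for DiffAugment (assumes `x` is a batch of images on the
    | # target device; strategy names must be keys of AUGMENT_FNS defined below):
    | #   param = ParamDiffAug()
    | #   x = DiffAugment(x, strategy='color_crop_cutout_flip_scale_rotate', seed=-1, param=param)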
411 |
412 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
413 | def rand_scale(x, param):
414 |     # ratio > 1 sets the maximum scale factor
415 |     # sx, sy: (0, +oo); 1: original size, 0.5: enlarge 2x
416 | ratio = param.ratio_scale
417 | set_seed_DiffAug(param)
418 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
419 | set_seed_DiffAug(param)
420 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
421 | theta = [[[sx[i], 0, 0],
422 | [0, sy[i], 0],] for i in range(x.shape[0])]
423 | theta = torch.tensor(theta, dtype=torch.float)
424 | if param.batchmode: # batch-wise:
425 | theta[:] = theta[0]
426 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
427 | x = F.grid_sample(x, grid, align_corners=True)
428 | return x
429 |
430 |
431 | def rand_rotate(x, param):  # range [-180, 180]; 90 means 90 degrees anticlockwise
432 | ratio = param.ratio_rotate
433 | set_seed_DiffAug(param)
434 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi)
435 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0],
436 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])]
437 | theta = torch.tensor(theta, dtype=torch.float)
438 | if param.batchmode: # batch-wise:
439 | theta[:] = theta[0]
440 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
441 | x = F.grid_sample(x, grid, align_corners=True)
442 | return x
443 |
444 |
445 | def rand_flip(x, param):
446 | prob = param.prob_flip
447 | set_seed_DiffAug(param)
448 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device)
449 | if param.batchmode: # batch-wise:
450 | randf[:] = randf[0]
451 | return torch.where(randf < prob, x.flip(3), x)
452 |
453 |
454 | def rand_brightness(x, param):
455 | ratio = param.brightness
456 | set_seed_DiffAug(param)
457 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
458 | if param.batchmode: # batch-wise:
459 | randb[:] = randb[0]
460 | x = x + (randb - 0.5)*ratio
461 | return x
462 |
463 |
464 | def rand_saturation(x, param):
465 | ratio = param.saturation
466 | x_mean = x.mean(dim=1, keepdim=True)
467 | set_seed_DiffAug(param)
468 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
469 | if param.batchmode: # batch-wise:
470 | rands[:] = rands[0]
471 | x = (x - x_mean) * (rands * ratio) + x_mean
472 | return x
473 |
474 |
475 | def rand_contrast(x, param):
476 | ratio = param.contrast
477 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
478 | set_seed_DiffAug(param)
479 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
480 | if param.batchmode: # batch-wise:
481 | randc[:] = randc[0]
482 | x = (x - x_mean) * (randc + ratio) + x_mean
483 | return x
484 |
485 |
486 | def rand_crop(x, param):
487 |     # The image is padded around its border and then cropped with a random shift.
488 | ratio = param.ratio_crop_pad
489 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
490 | set_seed_DiffAug(param)
491 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
492 | set_seed_DiffAug(param)
493 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
494 | if param.batchmode: # batch-wise:
495 | translation_x[:] = translation_x[0]
496 | translation_y[:] = translation_y[0]
497 | grid_batch, grid_x, grid_y = torch.meshgrid(
498 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
499 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
500 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
501 | )
502 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
503 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
504 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
505 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
506 | return x
507 |
508 |
509 | def rand_cutout(x, param):
510 | ratio = param.ratio_cutout
511 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
512 | set_seed_DiffAug(param)
513 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
514 | set_seed_DiffAug(param)
515 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
516 | if param.batchmode: # batch-wise:
517 | offset_x[:] = offset_x[0]
518 | offset_y[:] = offset_y[0]
519 | grid_batch, grid_x, grid_y = torch.meshgrid(
520 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
521 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
522 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
523 | )
524 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
525 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
526 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
527 | mask[grid_batch, grid_x, grid_y] = 0
528 | x = x * mask.unsqueeze(1)
529 | return x
530 |
531 |
532 | AUGMENT_FNS = {
533 | 'color': [rand_brightness, rand_saturation, rand_contrast],
534 | 'crop': [rand_crop],
535 | 'cutout': [rand_cutout],
536 | 'flip': [rand_flip],
537 | 'scale': [rand_scale],
538 | 'rotate': [rand_rotate],
539 | }
540 |
541 |
542 | def pre_question(question,max_ques_words=50):
543 | question = re.sub(
544 | r"([.!\"()*#:;~])",
545 | '',
546 | question.lower(),
547 | )
548 | question = question.rstrip(' ')
549 |
550 | #truncate question
551 | question_words = question.split(' ')
552 | if len(question_words)>max_ques_words:
553 | question = ' '.join(question_words[:max_ques_words])
554 |
555 | return question
556 |
557 |
558 | def save_result(result, result_dir, filename, remove_duplicate=''):
559 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
560 | final_result_file = os.path.join(result_dir, '%s.json'%filename)
561 |
562 | json.dump(result,open(result_file,'w'))
563 |
564 | dist.barrier()
565 |
566 | if utils.is_main_process():
567 | # combine results from all processes
568 | result = []
569 |
570 | for rank in range(utils.get_world_size()):
571 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
572 | res = json.load(open(result_file,'r'))
573 | result += res
574 |
575 | if remove_duplicate:
576 | result_new = []
577 | id_list = []
578 | for res in result:
579 | if res[remove_duplicate] not in id_list:
580 | id_list.append(res[remove_duplicate])
581 | result_new.append(res)
582 | result = result_new
583 |
584 | json.dump(result,open(final_result_file,'w'))
585 | print('result file saved to %s'%final_result_file)
586 |
587 | return final_result_file
588 |
589 |
590 |
591 | #### everything below is from https://github.com/salesforce/BLIP/blob/main/utils.py
592 | import math
593 |
594 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
595 | """Decay the learning rate"""
596 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
597 | for param_group in optimizer.param_groups:
598 | param_group['lr'] = lr
599 |
600 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
601 | """Warmup the learning rate"""
602 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
603 | for param_group in optimizer.param_groups:
604 | param_group['lr'] = lr
605 |
606 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
607 | """Decay the learning rate"""
608 | lr = max(min_lr, init_lr * (decay_rate**epoch))
609 | for param_group in optimizer.param_groups:
610 | param_group['lr'] = lr
611 |
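    | # A minimal sketch of how the schedules above are typically driven (assumes `optimizer`
    | # is a torch optimizer and `train_one_epoch` is a hypothetical training step):
    | #   for epoch in range(max_epoch):
    | #       cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr=1e-4, min_lr=1e-6)
    | #       train_one_epoch(...)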
612 | import numpy as np
613 | import io
614 | import os
615 | import time
616 | from collections import defaultdict, deque
617 | import datetime
618 |
619 | import torch
620 | import torch.distributed as dist
621 |
622 |
623 | class MetricLogger(object):
624 | def __init__(self, delimiter="\t"):
625 | self.meters = defaultdict(SmoothedValue)
626 | self.delimiter = delimiter
627 |
628 | def update(self, **kwargs):
629 | for k, v in kwargs.items():
630 | if isinstance(v, torch.Tensor):
631 | v = v.item()
632 | assert isinstance(v, (float, int))
633 | self.meters[k].update(v)
634 |
635 | def __getattr__(self, attr):
636 | if attr in self.meters:
637 | return self.meters[attr]
638 | if attr in self.__dict__:
639 | return self.__dict__[attr]
640 | raise AttributeError("'{}' object has no attribute '{}'".format(
641 | type(self).__name__, attr))
642 |
643 | def __str__(self):
644 | loss_str = []
645 | for name, meter in self.meters.items():
646 | loss_str.append(
647 | "{}: {}".format(name, str(meter))
648 | )
649 | return self.delimiter.join(loss_str)
650 |
651 | def global_avg(self):
652 | loss_str = []
653 | for name, meter in self.meters.items():
654 | loss_str.append(
655 | "{}: {:.4f}".format(name, meter.global_avg)
656 | )
657 | return self.delimiter.join(loss_str)
658 |
659 | def synchronize_between_processes(self):
660 | for meter in self.meters.values():
661 | meter.synchronize_between_processes()
662 |
663 | def add_meter(self, name, meter):
664 | self.meters[name] = meter
665 |
666 | def log_every(self, iterable, print_freq, header=None):
667 | i = 0
668 | if not header:
669 | header = ''
670 | start_time = time.time()
671 | end = time.time()
672 | iter_time = SmoothedValue(fmt='{avg:.4f}')
673 | data_time = SmoothedValue(fmt='{avg:.4f}')
674 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
675 | log_msg = [
676 | header,
677 | '[{0' + space_fmt + '}/{1}]',
678 | 'eta: {eta}',
679 | '{meters}',
680 | 'time: {time}',
681 | 'data: {data}'
682 | ]
683 | if torch.cuda.is_available():
684 | log_msg.append('max mem: {memory:.0f}')
685 | log_msg = self.delimiter.join(log_msg)
686 | MB = 1024.0 * 1024.0
687 | for obj in iterable:
688 | data_time.update(time.time() - end)
689 | yield obj
690 | iter_time.update(time.time() - end)
691 | if i % print_freq == 0 or i == len(iterable) - 1:
692 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
693 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
694 | if torch.cuda.is_available():
695 | print(log_msg.format(
696 | i, len(iterable), eta=eta_string,
697 | meters=str(self),
698 | time=str(iter_time), data=str(data_time),
699 | memory=torch.cuda.max_memory_allocated() / MB))
700 | else:
701 | print(log_msg.format(
702 | i, len(iterable), eta=eta_string,
703 | meters=str(self),
704 | time=str(iter_time), data=str(data_time)))
705 | i += 1
706 | end = time.time()
707 | total_time = time.time() - start_time
708 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
709 | print('{} Total time: {} ({:.4f} s / it)'.format(
710 | header, total_time_str, total_time / len(iterable)))
711 |
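    | # A minimal usage sketch for MetricLogger (assumes `data_loader` is iterable and
    | # `loss_value` is a scalar computed inside the loop):
    | #   logger = MetricLogger(delimiter='  ')
    | #   for batch in logger.log_every(data_loader, print_freq=50, header='Train:'):
    | #       logger.update(loss=loss_value)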
712 |
713 |
714 | class SmoothedValue(object):
715 | """Track a series of values and provide access to smoothed values over a
716 | window or the global series average.
717 | """
718 |
719 | def __init__(self, window_size=20, fmt=None):
720 | if fmt is None:
721 | fmt = "{median:.4f} ({global_avg:.4f})"
722 | self.deque = deque(maxlen=window_size)
723 | self.total = 0.0
724 | self.count = 0
725 | self.fmt = fmt
726 |
727 | def update(self, value, n=1):
728 | self.deque.append(value)
729 | self.count += n
730 | self.total += value * n
731 |
732 | def synchronize_between_processes(self):
733 | """
734 | Warning: does not synchronize the deque!
735 | """
736 | if not is_dist_avail_and_initialized():
737 | return
738 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
739 | dist.barrier()
740 | dist.all_reduce(t)
741 | t = t.tolist()
742 | self.count = int(t[0])
743 | self.total = t[1]
744 |
745 | @property
746 | def median(self):
747 | d = torch.tensor(list(self.deque))
748 | return d.median().item()
749 |
750 | @property
751 | def avg(self):
752 | d = torch.tensor(list(self.deque), dtype=torch.float32)
753 | return d.mean().item()
754 |
755 | @property
756 | def global_avg(self):
757 | return self.total / self.count
758 |
759 | @property
760 | def max(self):
761 | return max(self.deque)
762 |
763 | @property
764 | def value(self):
765 | return self.deque[-1]
766 |
767 | def __str__(self):
768 | return self.fmt.format(
769 | median=self.median,
770 | avg=self.avg,
771 | global_avg=self.global_avg,
772 | max=self.max,
773 | value=self.value)
774 |
775 | class AttrDict(dict):
776 | def __init__(self, *args, **kwargs):
777 | super(AttrDict, self).__init__(*args, **kwargs)
778 | self.__dict__ = self
779 |
780 |
781 | def compute_acc(logits, label, reduction='mean'):
782 | ret = (torch.argmax(logits, dim=1) == label).float()
783 | if reduction == 'none':
784 | return ret.detach()
785 | elif reduction == 'mean':
786 | return ret.mean().item()
787 |
788 | def compute_n_params(model, return_str=True):
789 | tot = 0
790 | for p in model.parameters():
791 | w = 1
792 | for x in p.shape:
793 | w *= x
794 | tot += w
795 | if return_str:
796 | if tot >= 1e6:
797 | return '{:.1f}M'.format(tot / 1e6)
798 | else:
799 | return '{:.1f}K'.format(tot / 1e3)
800 | else:
801 | return tot
802 |
803 | def setup_for_distributed(is_master):
804 | """
805 |     This function disables printing when not in the master process.
806 | """
807 | import builtins as __builtin__
808 | builtin_print = __builtin__.print
809 |
810 | def print(*args, **kwargs):
811 | force = kwargs.pop('force', False)
812 | if is_master or force:
813 | builtin_print(*args, **kwargs)
814 |
815 | __builtin__.print = print
816 |
817 |
818 | def is_dist_avail_and_initialized():
819 | if not dist.is_available():
820 | return False
821 | if not dist.is_initialized():
822 | return False
823 | return True
824 |
825 |
826 | def get_world_size():
827 | if not is_dist_avail_and_initialized():
828 | return 1
829 | return dist.get_world_size()
830 |
831 |
832 | def get_rank():
833 | if not is_dist_avail_and_initialized():
834 | return 0
835 | return dist.get_rank()
836 |
837 |
838 | def is_main_process():
839 | return get_rank() == 0
840 |
841 |
842 | def save_on_master(*args, **kwargs):
843 | if is_main_process():
844 | torch.save(*args, **kwargs)
845 |
846 |
847 | def init_distributed_mode(args):
848 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
849 | args.rank = int(os.environ["RANK"])
850 | args.world_size = int(os.environ['WORLD_SIZE'])
851 | args.gpu = int(os.environ['LOCAL_RANK'])
852 | elif 'SLURM_PROCID' in os.environ:
853 | args.rank = int(os.environ['SLURM_PROCID'])
854 | args.gpu = args.rank % torch.cuda.device_count()
855 | else:
856 | print('Not using distributed mode')
857 | args.distributed = False
858 | return
859 |
860 | args.distributed = True
861 |
862 | torch.cuda.set_device(args.gpu)
863 | args.dist_backend = 'nccl'
864 |     print('| distributed init (rank {}, world {}): {}'.format(
865 | args.rank, args.world_size, args.dist_url), flush=True)
866 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
867 | world_size=args.world_size, rank=args.rank)
868 | torch.distributed.barrier()
869 | setup_for_distributed(args.rank == 0)
870 |
871 |
--------------------------------------------------------------------------------
/src/vl_distill_utils.py:
--------------------------------------------------------------------------------
1 | """Move some basic utils in distill.py in VL-Distill here"""
2 | import os
3 | import numpy as np
4 | import torch
5 | from sklearn.metrics.pairwise import cosine_similarity
6 | from src.networks import TextEncoder
7 |
8 | __all__ = [
9 | "shuffle_files",
10 | "nearest_neighbor",
11 | "get_images_texts",
12 | "load_or_process_file",
13 | ]
14 |
15 |
16 | def shuffle_files(img_expert_files, txt_expert_files):
17 | # Check if both lists have the same length and if the lists are not empty
18 | assert len(img_expert_files) == len(txt_expert_files), "Number of image files and text files does not match"
19 | assert len(img_expert_files) != 0, "No files to shuffle"
20 | shuffled_indices = np.random.permutation(len(img_expert_files))
21 |
22 | # Apply the shuffled indices to both lists
23 | img_expert_files = np.take(img_expert_files, shuffled_indices)
24 | txt_expert_files = np.take(txt_expert_files, shuffled_indices)
25 | return img_expert_files, txt_expert_files
26 |
27 | def nearest_neighbor(sentences, query_embeddings, database_embeddings):
28 | """Find the nearest neighbors for a batch of embeddings.
29 |
30 | Args:
31 | sentences: The original sentences from which the embeddings were computed.
32 | query_embeddings: A batch of embeddings for which to find the nearest neighbors.
33 | database_embeddings: All pre-computed embeddings.
34 |
35 | Returns:
36 | A list of the most similar sentences for each embedding in the batch.
37 | """
38 | nearest_neighbors = []
39 |
40 | for query in query_embeddings:
41 | similarities = cosine_similarity(query.reshape(1, -1), database_embeddings)
42 |
43 | most_similar_index = np.argmax(similarities)
44 |
45 | nearest_neighbors.append(sentences[most_similar_index])
46 |
47 | return nearest_neighbors
48 |
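    | # A minimal usage sketch (assumes `sentences` is aligned row-for-row with
    | # `database_embeddings`, and all embeddings are numpy arrays):
    | #   neighbors = nearest_neighbor(sentences, query_embeddings, database_embeddings)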
49 |
50 | def get_images_texts(n, dataset, args, i_have_indices=None):
51 | """Get random n images and corresponding texts from the dataset.
52 |
53 | Args:
54 | n: Number of images and texts to retrieve.
55 | dataset: The dataset containing image-text pairs.
56 |
57 | Returns:
58 | A tuple containing two elements:
59 | - A tensor of randomly selected images.
60 | - A tensor of the corresponding texts, encoded as floats.
61 | """
62 |     # Use the provided indices if given; otherwise sample n unique random indices
63 | if i_have_indices is not None:
64 | idx_shuffle = i_have_indices
65 | else:
66 | idx_shuffle = np.random.permutation(len(dataset))[:n]
67 |
68 | # Initialize the text encoder
69 | text_encoder = TextEncoder(args)
70 |
71 | image_syn = torch.stack([dataset[i][0] for i in idx_shuffle])
72 |
73 | text_syn = text_encoder([dataset[i][1] for i in idx_shuffle], device="cpu")
74 |
75 | return image_syn, text_syn.float()
76 |
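    | # A minimal usage sketch (assumes `train_dataset` yields (image_tensor, caption) pairs
    | # and `args` carries the text-encoder settings expected by TextEncoder):
    | #   image_syn, text_syn = get_images_texts(100, train_dataset, args)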
77 |
78 | def load_or_process_file(file_type, process_func, args, data_source):
79 | """
80 | Load the processed file if it exists, otherwise process the data source and create the file.
81 |
82 | Args:
83 | file_type: The type of the file (e.g., 'train', 'test').
84 | process_func: The function to process the data source.
85 | args: The arguments required by the process function and to build the filename.
86 | data_source: The source data to be processed.
87 |
88 | Returns:
89 | The loaded data from the file.
90 | """
91 | filename = f'{args.dataset}_{args.text_encoder}_{file_type}_embed.npz'
92 |
93 |
94 | if not os.path.exists(filename):
95 | print(f'Creating {filename}')
96 | process_func(args, data_source)
97 | else:
98 | print(f'Loading {filename}')
99 |
100 | return np.load(filename)
101 |
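    | # A minimal usage sketch (assumes `encode_texts` is a hypothetical function that saves
    | # f'{args.dataset}_{args.text_encoder}_train_embed.npz'):
    | #   data = load_or_process_file('train', encode_texts, args, train_dataset)
    | #   embeddings = data[data.files[0]]  # keys depend on what process_func saved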
102 | def get_LC_images_texts(n, dataset, args):
103 | """Get random n images and corresponding texts from the dataset.
104 |
105 | Args:
106 | n: Number of images and texts to retrieve.
107 | dataset: The dataset containing image-text pairs.
108 |
109 | Returns:
110 | A tuple containing two elements:
111 | - A tensor of randomly selected images.
112 | - A tensor of the corresponding texts, encoded as floats.
113 | """
114 | # Generate n unique random indices
115 | idx_shuffle = np.random.permutation(len(dataset))[:n]
116 |
117 | # Initialize the text encoder
118 | text_encoder = TextEncoder(args)
119 |
120 | image_syn = torch.stack([dataset[i][0] for i in idx_shuffle])
121 |
122 | text_syn = text_encoder([dataset[i][1] for i in idx_shuffle], device="cpu")
123 |
124 | return image_syn, text_syn.float()
125 |
126 |
--------------------------------------------------------------------------------