├── .gitignore
├── Flickr30k
│   └── ann_file
│       ├── flickr30k_test.json
│       ├── flickr30k_train.json
│       └── flickr30k_val.json
├── README.md
├── __pycache__
│   ├── epoch.cpython-39.pyc
│   ├── networks.cpython-39.pyc
│   └── utils.cpython-39.pyc
├── buffer.py
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   ├── coco_dataset.cpython-39.pyc
│   │   └── flickr30k_dataset.cpython-39.pyc
│   ├── cifar_dataset.py
│   ├── coco_dataset.py
│   └── flickr30k_dataset.py
├── distill.py
├── epoch.py
├── images
│   ├── clip_0.png
│   ├── clip_2000.png
│   ├── coco_syn_0.png
│   ├── coco_syn_2000.png
│   ├── logo.png
│   ├── loss.png
│   ├── more_vis.png
│   ├── pipeline.png
│   ├── synthetic_images_0.png
│   ├── synthetic_images_2000.png
│   ├── table.png
│   ├── teaser.png
│   ├── text_noise_0.png
│   ├── text_noise_1.png
│   └── visualization.png
├── index.html
├── model.py
├── networks.py
├── reparam_module.py
├── requirements.yaml
├── style.css
├── transform
│   ├── .DS_Store
│   ├── __pycache__
│   │   └── randaugment.cpython-39.pyc
│   └── randaugment.py
├── utils.py
└── website
    └── poster.pdf
/.gitignore:
--------------------------------------------------------------------------------
1 | **/logged_files
2 | **/wandb
3 | *.png
4 | **/flickr30k_images
5 | **/output
6 | **/buffers
7 | **/data/cifar-10-batches-py
8 | **/data/cifar-10-python.tar.gz
9 | *.npz
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vision-Language Dataset Distillation
2 |
3 | ### [Project Page](https://princetonvisualai.github.io/multimodal_dataset_distillation/) | [Paper](https://arxiv.org/abs/2308.07545)
4 | TMLR, 2024
5 |
6 | [Xindi Wu](https://xindiwu.github.io/), [Byron Zhang](https://www.linkedin.com/in/byron-zhang), [Zhiwei Deng](https://lucas2012.github.io/), [Olga Russakovsky](https://www.cs.princeton.edu/~olgarus/)
7 |
8 | 
9 | This codebase is for the paper [Vision-Language Dataset Distillation](https://arxiv.org/abs/2308.07545). Please visit our [Project Page](https://princetonvisualai.github.io/multimodal_dataset_distillation/) for detailed results.
10 |
11 |
12 | Dataset distillation methods offer the promise of reducing a large-scale dataset down to a significantly smaller set of (potentially synthetic) training examples, which preserve sufficient information for training a new model from scratch. So far dataset distillation methods have been developed for image classification. However, with the rise in capabilities of vision-language models, and especially given the scale of datasets necessary to train these models, the time is ripe to expand dataset distillation methods beyond image classification.
13 |
14 | 
15 |
16 | In this work, we take the first steps towards this goal by expanding on the idea of trajectory matching to create a distillation method for vision-language datasets. We demonstrate significant improvements on the challenging Flickr30K and COCO retrieval benchmarks: for example, on Flickr30K the best coreset selection method which selects 1000 image-text pairs for training is able to achieve only 5.6% image-to-text retrieval accuracy (i.e., recall@1); in contrast, our dataset distillation approach almost doubles that to 9.9% with just 100 (an order of magnitude fewer) training pairs.
17 |
18 | 
19 | ## Getting Started
20 | [Adapted from [mtt-distillation](https://github.com/GeorgeCazenavette/mtt-distillation) by [George Cazenavette](https://georgecazenavette.github.io) et al.]
21 | First, download our repo:
22 | ```bash
23 | git clone https://github.com/princetonvisualai/multimodal_dataset_distillation.git
24 | cd multimodal_dataset_distillation
25 | ```
26 |
27 | For an express installation, we include a ```requirements.yaml``` file.
28 |
29 | You will need an RTX 30XX GPU (or newer). Then run
30 |
31 | ```bash
32 | conda env create -f requirements.yaml
33 | ```
34 |
35 | You can then activate your conda environment with
36 | ```bash
37 | conda activate vl-distill
38 | ```
39 |
40 | ## Datasets and Annotations
41 | Please download the images for the Flickr30K and COCO datasets, and create separate directories for the annotation files linked below.
42 |
43 | Flickr30K: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json)[[Images]](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset)
44 |
45 | COCO: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json)[[Images]](https://cocodataset.org/#download)
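
If you prefer to pre-fetch the annotation files rather than let the dataset classes download them on first use, the sketch below does so with the same `download_url` helper the dataset classes call. It assumes the default annotation directory `./Flickr30k/ann_file/` used by `buffer.py`; adjust the path to match your `--ann_root`.

```python
# Optional: pre-download the Flickr30K annotation files into the default
# annotation directory expected by buffer.py (--ann_root=./Flickr30k/ann_file/).
# data/flickr30k_dataset.py downloads the same URLs automatically if missing.
from torchvision.datasets.utils import download_url

ANN_ROOT = './Flickr30k/ann_file/'  # assumed default; change to your --ann_root
URLS = [
    'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json',
    'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
    'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json',
]

for url in URLS:
    download_url(url, ANN_ROOT)
```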
46 |
47 | ## Training Expert Trajectories
48 | The following command generates 20 expert trajectories using an NFNet image encoder and a BERT text encoder. Training is done on the Flickr30K dataset by simultaneously finetuning a pre-trained NFNet model and training a projection layer on top of a frozen pre-trained BERT.
49 | ```bash
50 | python buffer.py --dataset=flickr --train_epochs=10 --num_experts=20 --buffer_path={path_to_buffer} --image_encoder=nfnet --text_encoder=bert --image_size=224 --image_root={path_to_image_directory} --ann_root={path_to_annotation_directory}
51 | ```
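
Each saved buffer file holds a list of expert trajectories; every trajectory is a list of parameter snapshots (one before training plus one per epoch), each snapshot being the encoder's parameters on CPU. A quick sanity check, assuming the default `--buffer_path=./buffers` and the encoders from the command above:

```python
# Inspect a saved expert trajectory buffer. The path assumes the defaults:
# ./buffers / <dataset> / <image_encoder> / <text_encoder> / img_replay_buffer_<n>.pt
import torch

buffer_file = './buffers/flickr/nfnet/bert/img_replay_buffer_0.pt'
trajectories = torch.load(buffer_file, map_location='cpu')

print(len(trajectories))       # number of expert trajectories stored in this file
print(len(trajectories[0]))    # snapshots per trajectory: 1 (init) + train_epochs
print(sum(p.numel() for p in trajectories[0][0]))  # parameter count of one snapshot
```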
52 |
53 | ## Bi-Trajectories Guided Co-Distillation
54 | The following command distills 100 synthetic samples for the Flickr30K dataset given the expert trajectories.
55 | ```bash
56 | python distill.py --dataset=flickr --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 --lr_img=1000 --lr_txt=1000 --lr_lr=1e-02 --buffer_path={path_to_buffer} --num_queries 100 --image_encoder=nfnet --text_encoder=bert --draw True --image_root={path_to_image_directory} --ann_root={path_to_annotation_directory} --save_dir={path_to_saved_distilled_data}
57 | ```
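
At each distillation iteration, `distill.py` unrolls the student encoders on the synthetic pairs for `--syn_steps` steps and matches the resulting parameters against a point `--expert_epochs` later on the expert trajectory. A simplified sketch of that matching objective (the function name here is illustrative; `student`, `start`, and `target` are flattened parameter vectors for one modality):

```python
# Simplified sketch of the bi-trajectory matching loss computed in distill.py:
# the squared distance between the unrolled student parameters and the expert
# target, normalized by how far the expert itself moved from the start epoch.
import torch
import torch.nn.functional as F

def trajectory_matching_loss(student, start, target):
    param_loss = F.mse_loss(student, target, reduction='sum')
    param_dist = F.mse_loss(start, target, reduction='sum')
    return param_loss / param_dist

# grand_loss = trajectory_matching_loss(img_student, img_start, img_target) \
#            + trajectory_matching_loss(txt_student, txt_start, txt_target)
```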
58 |
59 | ## Acknowledgements
60 | This material is based upon work supported by the National Science Foundation under Grant No. 2107048. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. We thank many people from the Princeton Visual AI lab (Allison Chen, Jihoon Chung, Tyler Zhu, Ye Zhu, William Yang and Kaiqu Liang) and the Princeton NLP group (Carlos E. Jimenez, John Yang), as well as Tiffany Ling and George Cazenavette, for their helpful feedback on this work.
61 |
62 | ## Citation
63 | ```
64 | @article{wu2023multimodal,
65 | title={Multimodal Dataset Distillation for Image-Text Retrieval},
66 | author={Wu, Xindi and Zhang, Byron and Deng, Zhiwei and Russakovsky, Olga},
67 | journal={arXiv preprint arXiv:2308.07545},
68 | year={2023}
69 | }
70 | ```
71 |
72 |
--------------------------------------------------------------------------------
/__pycache__/epoch.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/__pycache__/epoch.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/networks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/__pycache__/networks.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/buffer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from epoch import epoch, epoch_test, itm_eval
5 | import wandb
6 | import warnings
7 | import datetime
8 | from data import get_dataset_flickr, textprocess
9 | from networks import CLIPModel_full
10 | from utils import load_or_process_file
11 |
12 | warnings.filterwarnings("ignore", category=DeprecationWarning)
13 |
14 | def main(args):
15 | #wandb.init(mode="disabled")
16 | wandb.init(project='DatasetDistillation', entity='dataset_distillation', config=args, name=args.name)
17 |
18 |
19 | args.dsa = True if args.dsa == 'True' else False
20 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
21 | args.distributed = torch.cuda.device_count() > 1
22 |
23 |
24 | # print('\n================== Exp %d ==================\n '%exp)
25 | print('Hyper-parameters: \n', args.__dict__)
26 |
27 | save_dir = os.path.join(args.buffer_path, args.dataset)
28 | if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
29 | save_dir += "_NO_ZCA"
30 | save_dir = os.path.join(save_dir, args.image_encoder, args.text_encoder)
31 | if not os.path.exists(save_dir):
32 | os.makedirs(save_dir)
33 | ''' organize the datasets '''
34 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)
35 | data = load_or_process_file('text', textprocess, args, testloader)
36 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu()
37 |
38 |
39 | img_trajectories = []
40 | txt_trajectories = []
41 |
42 | for it in range(0, args.num_experts):
43 |
44 |         ''' Train expert (teacher) networks '''
45 |
46 | teacher_net = CLIPModel_full(args)
47 | img_teacher_net = teacher_net.image_encoder.to(args.device)
48 | txt_teacher_net = teacher_net.text_projection.to(args.device)
49 | if args.text_trainable:
50 | txt_teacher_net = teacher_net.text_encoder.to(args.device)
51 | if args.distributed:
52 | img_teacher_net = torch.nn.DataParallel(img_teacher_net)
53 | txt_teacher_net = torch.nn.DataParallel(txt_teacher_net)
54 | img_teacher_net.train()
55 | txt_teacher_net.train()
56 | lr_img = args.lr_teacher_img
57 | lr_txt = args.lr_teacher_txt
58 |
59 | teacher_optim_img = torch.optim.SGD(img_teacher_net.parameters(), lr=lr_img, momentum=args.mom, weight_decay=args.l2)
60 | teacher_optim_txt = torch.optim.SGD(txt_teacher_net.parameters(), lr=lr_txt, momentum=args.mom, weight_decay=args.l2)
61 | teacher_optim_img.zero_grad()
62 | teacher_optim_txt.zero_grad()
63 |
64 | img_timestamps = []
65 | txt_timestamps = []
66 |
67 | img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()])
68 | txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()])
69 |
70 | lr_schedule = [args.train_epochs // 2 + 1]
71 |
72 | for e in range(args.train_epochs):
73 | train_loss, train_acc = epoch(e, trainloader, teacher_net, teacher_optim_img, teacher_optim_txt, args)
74 | score_val_i2t, score_val_t2i = epoch_test(testloader, teacher_net, args.device, bert_test_embed)
75 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt)
76 |
77 |
78 | wandb.log({"train_loss": train_loss})
79 | wandb.log({"train_acc": train_acc})
80 | wandb.log({"txt_r1": val_result['txt_r1']})
81 | wandb.log({"txt_r5": val_result['txt_r5']})
82 | wandb.log({"txt_r10": val_result['txt_r10']})
83 | wandb.log({"txt_r_mean": val_result['txt_r_mean']})
84 | wandb.log({"img_r1": val_result['img_r1']})
85 | wandb.log({"img_r5": val_result['img_r5']})
86 | wandb.log({"img_r10": val_result['img_r10']})
87 | wandb.log({"img_r_mean": val_result['img_r_mean']})
88 | wandb.log({"r_mean": val_result['r_mean']})
89 |
90 | print("Itr: {}\tEpoch: {}\tTrain Acc: {}\tImg R@1: {}\tR@5: {}\tR@10: {}\tR@Mean: {}\tTxt R@1: {}\tR@5: {}\tR@10: {}\tR@Mean: {}".format(
91 | it, e, train_acc,
92 | val_result['img_r1'], val_result['img_r5'], val_result['img_r10'], val_result['img_r_mean'],
93 | val_result['txt_r1'], val_result['txt_r5'], val_result['txt_r10'], val_result['txt_r_mean']))
94 | img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()])
95 | txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()])
96 |
97 | if e in lr_schedule and args.decay:
98 |                 lr_img, lr_txt = lr_img * 0.1, lr_txt * 0.1
99 |                 teacher_optim_img = torch.optim.SGD(img_teacher_net.parameters(), lr=lr_img, momentum=args.mom, weight_decay=args.l2)
100 |                 teacher_optim_txt = torch.optim.SGD(txt_teacher_net.parameters(), lr=lr_txt, momentum=args.mom, weight_decay=args.l2)
101 | teacher_optim_img.zero_grad()
102 | teacher_optim_txt.zero_grad()
103 |
104 | img_trajectories.append(img_timestamps)
105 | txt_trajectories.append(txt_timestamps)
106 | n = 0
107 | while os.path.exists(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))):
108 | n += 1
109 | print("Saving {}".format(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))))
110 | torch.save(img_trajectories, os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n)))
111 | print("Saving {}".format(os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n))))
112 | torch.save(txt_trajectories, os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n)))
113 |
114 | img_trajectories = []
115 | txt_trajectories = []
116 |
117 |
118 | if __name__ == '__main__':
119 | parser = argparse.ArgumentParser(description='Parameter Processing')
120 | parser.add_argument('--dataset', type=str, default='flickr', choices=['flickr', 'coco'], help='dataset')
121 | parser.add_argument('--num_experts', type=int, default=100, help='training iterations')
122 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters')
123 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters')
124 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks')
125 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
126 | help='whether to use differentiable Siamese augmentation.')
127 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
128 | help='differentiable Siamese augmentation strategy')
129 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path')
130 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')
131 | parser.add_argument('--train_epochs', type=int, default=50)
132 | parser.add_argument('--zca', action='store_true')
133 | parser.add_argument('--decay', action='store_true')
134 | parser.add_argument('--mom', type=float, default=0, help='momentum')
135 | parser.add_argument('--l2', type=float, default=0, help='l2 regularization')
136 | parser.add_argument('--save_interval', type=int, default=10)
137 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
138 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')
139 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained')
140 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained')
141 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable')
142 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable')
143 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
144 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')
145 | parser.add_argument('--image_root', type=str, default='./Flickr30k/flickr-image-dataset/flickr30k-images/', help='location of image root')
146 | parser.add_argument('--ann_root', type=str, default='./Flickr30k/ann_file/', help='location of ann root')
147 | parser.add_argument('--image_size', type=int, default=224, help='image_size')
148 | parser.add_argument('--k_test', type=int, default=128, help='k_test')
149 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy')
150 | parser.add_argument('--image_encoder', type=str, default='resnet50', choices=['nfnet', 'resnet18_gn', 'vit_tiny', 'nf_resnet50', 'nf_regnet'], help='image encoder')
151 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip'], help='text encoder')
152 | parser.add_argument('--margin', default=0.2, type=float,
153 | help='Rank loss margin.')
154 | parser.add_argument('--measure', default='cosine',
155 | help='Similarity measure used (cosine|order)')
156 | parser.add_argument('--max_violation', action='store_true',
157 | help='Use max instead of sum in the rank loss.')
158 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None')
159 | parser.add_argument('--grounding', type=bool, default=False, help='None')
160 | parser.add_argument('--distill', type=bool, default=False, help='if distill')
161 | args = parser.parse_args()
162 |
163 | main(args)
164 |
165 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from transform.randaugment import RandomAugment
3 | from torchvision.transforms.functional import InterpolationMode
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torch.utils.data import Dataset
7 | from torchvision.datasets.utils import download_url
8 | import json
9 | from PIL import Image
10 | import os
11 | from torchvision import transforms as T
12 | from networks import CLIPModel_full
13 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
14 | from data.coco_dataset import coco_train, coco_caption_eval, coco_retrieval_eval
15 | import numpy as np
16 | from tqdm import tqdm
17 | @torch.no_grad()
18 | def textprocess(args, testloader):
19 | net = CLIPModel_full(args).to('cuda')
20 | net.eval()
21 | texts = testloader.dataset.text
22 | if args.dataset in ['flickr', 'coco']:
23 | if args.dataset == 'flickr':
24 | bert_test_embed = net.text_encoder(texts)
25 | elif args.dataset == 'coco':
26 | bert_test_embed = torch.cat((net.text_encoder(texts[:10000]), net.text_encoder(texts[10000:20000]), net.text_encoder(texts[20000:])), dim=0)
27 |
28 | bert_test_embed_np = bert_test_embed.cpu().numpy()
29 | np.savez(f'{args.dataset}_{args.text_encoder}_text_embed.npz', bert_test_embed=bert_test_embed_np)
30 | else:
31 | raise NotImplementedError
32 | return
33 |
34 | @torch.no_grad()
35 | def textprocess_train(args, texts):
36 | net = CLIPModel_full(args).to('cuda')
37 | net.eval()
38 | chunk_size = 2000
39 | chunks = []
40 | for i in tqdm(range(0, len(texts), chunk_size)):
41 | chunk = net.text_encoder(texts[i:i + chunk_size]).cpu()
42 | chunks.append(chunk)
43 | del chunk
44 | torch.cuda.empty_cache() # free up memory
45 | bert_test_embed = torch.cat(chunks, dim=0)
46 |
47 | print('bert_test_embed.shape: ', bert_test_embed.shape)
48 | bert_test_embed_np = bert_test_embed.numpy()
49 | if args.dataset in ['flickr', 'coco']:
50 | np.savez(f'{args.dataset}_{args.text_encoder}_train_text_embed.npz', bert_test_embed=bert_test_embed_np)
51 | else:
52 | raise NotImplementedError
53 | return
54 |
55 |
56 | def create_dataset(args, min_scale=0.5):
57 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
58 | transform_train = transforms.Compose([
59 | transforms.RandomResizedCrop(args.image_size,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
60 | transforms.RandomHorizontalFlip(),
61 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
62 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
63 | transforms.ToTensor(),
64 | normalize,
65 | ])
66 | transform_test = transforms.Compose([
67 | transforms.Resize((args.image_size, args.image_size),interpolation=InterpolationMode.BICUBIC),
68 | transforms.ToTensor(),
69 | normalize,
70 | ])
71 | if args.dataset=='flickr':
72 | train_dataset = flickr30k_train(transform_train, args.image_root, args.ann_root)
73 | val_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val')
74 | test_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test')
75 | return train_dataset, val_dataset, test_dataset
76 |
77 | elif args.dataset=='coco':
78 | train_dataset = coco_train(transform_train, args.image_root, args.ann_root)
79 | val_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val')
80 | test_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test')
81 | return train_dataset, val_dataset, test_dataset
82 | else:
83 | raise NotImplementedError
84 | return train_dataset, val_dataset, test_dataset
85 |
86 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
87 | samplers = []
88 | for dataset,shuffle in zip(datasets,shuffles):
89 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
90 | samplers.append(sampler)
91 | return samplers
92 |
93 |
94 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
95 | loaders = []
96 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
97 | if is_train:
98 | shuffle = (sampler is None)
99 | drop_last = True
100 | else:
101 | shuffle = False
102 | drop_last = False
103 | loader = DataLoader(
104 | dataset,
105 | batch_size=bs,
106 | num_workers=n_worker,
107 | pin_memory=True,
108 | sampler=sampler,
109 | shuffle=shuffle,
110 | collate_fn=collate_fn,
111 | drop_last=drop_last,
112 | )
113 | loaders.append(loader)
114 | return loaders
115 |
116 |
117 | def get_dataset_flickr(args):
118 | print("Creating retrieval dataset")
119 | train_dataset, val_dataset, test_dataset = create_dataset(args)
120 |
121 | samplers = [None, None, None]
122 | train_shuffle = True
123 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
124 | batch_size=[args.batch_size_train]+[args.batch_size_test]*2,
125 | num_workers=[4,4,4],
126 | is_trains=[train_shuffle, False, False],
127 | collate_fns=[None,None,None])
128 | return train_loader, test_loader, train_dataset, test_dataset
129 |
130 |
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/data/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/data/__pycache__/coco_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/data/__pycache__/coco_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/data/__pycache__/flickr30k_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/data/__pycache__/flickr30k_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/data/cifar_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from typing import Any, Tuple
4 | from torch.utils.data import Dataset
5 | from PIL import Image
6 | import numpy as np
7 | from torchvision.datasets import CIFAR10
8 | from collections import defaultdict
9 |
10 |
11 | CLASSES = [
12 | "airplane",
13 | "automobile",
14 | "bird",
15 | "cat",
16 | "deer",
17 | "dog",
18 | "frog",
19 | "horse",
20 | "ship",
21 | "truck",
22 | ]
23 |
24 | PROMPTS_1 = ["This is a {}"]
25 |
26 | PROMPTS_5 = [
27 | "a photo of a {}",
28 | "a blurry image of a {}",
29 | "a photo of the {}",
30 | "a pixelated photo of a {}",
31 | "a picture of a {}",
32 | ]
33 |
34 | PROMPTS = [
35 | 'a photo of a {}',
36 | 'a blurry photo of a {}',
37 | 'a low contrast photo of a {}',
38 | 'a high contrast photo of a {}',
39 | 'a bad photo of a {}',
40 |     'a good photo of a {}',
41 | 'a photo of a small {}',
42 | 'a photo of a big {}',
43 | 'a photo of the {}',
44 | 'a blurry photo of the {}',
45 | 'a low contrast photo of the {}',
46 | 'a high contrast photo of the {}',
47 | 'a bad photo of the {}',
48 | 'a good photo of the {}',
49 | 'a photo of the small {}',
50 | 'a photo of the big {}',
51 | ]
52 |
53 |
54 | class cifar10_train(CIFAR10):
55 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, num_prompts=1):
56 | super(cifar10_train, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
57 | if num_prompts == 1:
58 | self.prompts = PROMPTS_1
59 | elif num_prompts == 5:
60 | self.prompts = PROMPTS_5
61 | else:
62 | self.prompts = PROMPTS
63 | self.captions = [prompt.format(cls) for cls in CLASSES for prompt in self.prompts]
64 | self.captions_to_label = {cap: i for i, cap in enumerate(self.captions)}
65 | self.annotations = []
66 | for i in range(len(self.data)):
67 | cls_name = CLASSES[self.targets[i]]
68 | for prompt in self.prompts:
69 | caption = prompt.format(cls_name)
70 | self.annotations.append({"img_id": i, "caption_id": self.captions_to_label[caption]})
71 | if num_prompts == 1:
72 | self.annotations = self.annotations * 5
73 |
74 | def __len__(self):
75 | return len(self.annotations)
76 |
77 | def __getitem__(self, index):
78 | ann = self.annotations[index]
79 | img_id = ann['img_id']
80 | img = self.transform(self.data[img_id])
81 | caption = self.captions[ann['caption_id']]
82 | return img, caption, img_id
83 |
84 | def fetch_distill_images(self, ipc):
85 | """
86 | Randomly fetch `x` number of images from each class using numpy and return as a tensor.
87 | """
88 | class_indices = defaultdict(list)
89 |
90 | for idx, label in enumerate(self.targets):
91 | class_indices[label].append(idx)
92 |
93 | # Randomly sample x indices for each class using numpy
94 | sampled_indices = [np.random.choice(indices, ipc, replace=False) for indices in class_indices.values()]
95 | sampled_indices = [idx for class_indices in sampled_indices for idx in class_indices]
96 |
97 | # Fetch images and labels using the selected indices
98 | images = torch.stack([self.transform(self.data[i]) for i in sampled_indices])
99 | labels = [self.targets[i] for i in sampled_indices]
100 |
101 | captions = []
102 | for label in labels:
103 | cls_name = CLASSES[label]
104 | prompt = np.random.choice(self.prompts)
105 | random_caption = prompt.format(cls_name)
106 | captions.append(random_caption)
107 |
108 | return images, captions
109 |
110 | def get_all_captions(self):
111 | return self.captions
112 |
113 |
114 | class cifar10_retrieval_eval(cifar10_train):
115 | def __init__(self, root, train=False, transform=None, target_transform=None, download=False, num_prompts=1):
116 | """
117 | image_root (string): Root directory of images (e.g. coco/images/)
118 | ann_root (string): directory to store the annotation file
119 | split (string): val or test
120 | """
121 | super(cifar10_retrieval_eval, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download, num_prompts=num_prompts)
122 | self.text = self.captions
123 | self.txt2img = {}
124 | self.img2txt = defaultdict(list)
125 |
126 | for ann in self.annotations:
127 | img_id = ann['img_id']
128 | caption_id = ann['caption_id']
129 | self.img2txt[img_id].append(caption_id)
130 | self.txt2img[caption_id] = img_id
131 |
132 | def __len__(self):
133 | return len(self.data)
134 |
135 | def __getitem__(self, index):
136 | image = self.transform(self.data[index])
137 | return image, index
138 |
--------------------------------------------------------------------------------
/data/coco_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from torch.utils.data import Dataset
4 | from torchvision.datasets.utils import download_url
5 | from PIL import Image
6 | import re
7 |
8 | def pre_caption(caption,max_words=50):
9 | caption = re.sub(
10 | r"([.!\"()*#:;~])",
11 | ' ',
12 | caption.lower(),
13 | )
14 | caption = re.sub(
15 | r"\s{2,}",
16 | ' ',
17 | caption,
18 | )
19 | caption = caption.rstrip('\n')
20 | caption = caption.strip(' ')
21 |
22 | #truncate caption
23 | caption_words = caption.split(' ')
24 | if len(caption_words)>max_words:
25 | caption = ' '.join(caption_words[:max_words])
26 |
27 | return caption
28 |
29 | class coco_train(Dataset):
30 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
31 | '''
32 | image_root (string): Root directory of images (e.g. coco/images/)
33 | ann_root (string): directory to store the annotation file
34 | '''
35 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
36 | filename = 'coco_karpathy_train.json'
37 |
38 | download_url(url,ann_root)
39 |
40 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
41 | self.transform = transform
42 | self.image_root = image_root
43 | self.max_words = max_words
44 | self.prompt = prompt
45 |
46 | self.img_ids = {}
47 | n = 0
48 | for ann in self.annotation:
49 | img_id = ann['image_id']
50 | if img_id not in self.img_ids.keys():
51 | self.img_ids[img_id] = n
52 | n += 1
53 |
54 | def __len__(self):
55 | return len(self.annotation)
56 |
57 | def __getitem__(self, index):
58 |
59 | ann = self.annotation[index]
60 |
61 | image_path = os.path.join(self.image_root,ann['image'])
62 | image = Image.open(image_path).convert('RGB')
63 | image = self.transform(image)
64 |
65 | caption = self.prompt+pre_caption(ann['caption'], self.max_words)
66 |
67 | return image, caption, self.img_ids[ann['image_id']]
68 |
69 | def get_all_captions(self):
70 | captions = []
71 | for ann in self.annotation:
72 | caption = self.prompt + pre_caption(ann['caption'], self.max_words)
73 | captions.append(caption)
74 | return captions
75 |
76 |
77 | class coco_caption_eval(Dataset):
78 | def __init__(self, transform, image_root, ann_root, split):
79 | '''
80 | image_root (string): Root directory of images (e.g. coco/images/)
81 | ann_root (string): directory to store the annotation file
82 | split (string): val or test
83 | '''
84 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
85 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
86 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
87 |
88 | download_url(urls[split],ann_root)
89 |
90 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
91 | self.transform = transform
92 | self.image_root = image_root
93 |
94 | def __len__(self):
95 | return len(self.annotation)
96 |
97 | def __getitem__(self, index):
98 |
99 | ann = self.annotation[index]
100 |
101 | image_path = os.path.join(self.image_root,ann['image'])
102 | image = Image.open(image_path).convert('RGB')
103 | image = self.transform(image)
104 |
105 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
106 |
107 | return image, int(img_id)
108 |
109 |
110 | class coco_retrieval_eval(Dataset):
111 | def __init__(self, transform, image_root, ann_root, split, max_words=30):
112 | '''
113 | image_root (string): Root directory of images (e.g. coco/images/)
114 | ann_root (string): directory to store the annotation file
115 | split (string): val or test
116 | '''
117 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
118 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
119 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
120 |
121 | download_url(urls[split],ann_root)
122 |
123 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
124 | self.transform = transform
125 | self.image_root = image_root
126 |
127 | self.text = []
128 | self.image = []
129 | self.txt2img = {}
130 | self.img2txt = {}
131 |
132 | txt_id = 0
133 | for img_id, ann in enumerate(self.annotation):
134 | self.image.append(ann['image'])
135 | self.img2txt[img_id] = []
136 | for i, caption in enumerate(ann['caption']):
137 | self.text.append(pre_caption(caption,max_words))
138 | self.img2txt[img_id].append(txt_id)
139 | self.txt2img[txt_id] = img_id
140 | txt_id += 1
141 |
142 | def __len__(self):
143 | return len(self.annotation)
144 |
145 | def __getitem__(self, index):
146 |
147 | image_path = os.path.join(self.image_root, self.annotation[index]['image'])
148 | image = Image.open(image_path).convert('RGB')
149 | image = self.transform(image)
150 |
151 | return image, index
--------------------------------------------------------------------------------
/data/flickr30k_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from tqdm import tqdm
4 | import numpy as np
5 | import yaml
6 | from transformers import BertTokenizer, BertModel
7 | from torchvision import transforms as T
8 | from torchvision.transforms.functional import InterpolationMode
9 | from torch.utils.data import Dataset
10 | from torchvision.datasets.utils import download_url
11 | import argparse
12 | import re
13 | import json
14 | from PIL import Image
15 |
16 | def pre_caption(caption,max_words=50):
17 | caption = re.sub(
18 | r"([.!\"()*#:;~])",
19 | ' ',
20 | caption.lower(),
21 | )
22 | caption = re.sub(
23 | r"\s{2,}",
24 | ' ',
25 | caption,
26 | )
27 | caption = caption.rstrip('\n')
28 | caption = caption.strip(' ')
29 |
30 | #truncate caption
31 | caption_words = caption.split(' ')
32 | if len(caption_words)>max_words:
33 | caption = ' '.join(caption_words[:max_words])
34 |
35 | return caption
36 |
37 |
38 | class flickr30k_train(Dataset):
39 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
40 | '''
41 | image_root (string): Root directory of images (e.g. flickr30k/)
42 | ann_root (string): directory to store the annotation file
43 | '''
44 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
45 | filename = 'flickr30k_train.json'
46 |
47 | download_url(url,ann_root)
48 |
49 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
50 | self.transform = transform
51 | self.image_root = image_root
52 | self.max_words = max_words
53 | self.prompt = prompt
54 |
55 | self.img_ids = {}
56 | n = 0
57 | for ann in self.annotation:
58 | img_id = ann['image_id']
59 | if img_id not in self.img_ids.keys():
60 | self.img_ids[img_id] = n
61 | n += 1
62 |
63 | def __len__(self):
64 | return len(self.annotation)
65 |
66 | def __getitem__(self, index):
67 |
68 | ann = self.annotation[index]
69 |
70 | image_path = os.path.join(self.image_root,ann['image'])
71 | image = Image.open(image_path).convert('RGB')
72 | image = self.transform(image)
73 |
74 | caption = self.prompt+pre_caption(ann['caption'], self.max_words)
75 |
76 | return image, caption, self.img_ids[ann['image_id']]
77 |
78 | def get_all_captions(self):
79 | captions = []
80 | for ann in self.annotation:
81 | caption = self.prompt + pre_caption(ann['caption'], self.max_words)
82 | captions.append(caption)
83 | return captions
84 |
85 |
86 |
87 | class flickr30k_retrieval_eval(Dataset):
88 | def __init__(self, transform, image_root, ann_root, split, max_words=30):
89 | '''
90 | image_root (string): Root directory of images (e.g. flickr30k/)
91 | ann_root (string): directory to store the annotation file
92 | split (string): val or test
93 | '''
94 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
95 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
96 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
97 |
98 | download_url(urls[split],ann_root)
99 |
100 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
101 | self.transform = transform
102 | self.image_root = image_root
103 | self.max_words = max_words
104 |
105 | self.text = []
106 | self.image = []
107 | self.txt2img = {}
108 | self.img2txt = {}
109 |
110 | txt_id = 0
111 | for img_id, ann in enumerate(self.annotation):
112 | self.image.append(ann['image'])
113 | self.img2txt[img_id] = []
114 | for i, caption in enumerate(ann['caption']):
115 | self.text.append(pre_caption(caption,max_words))
116 | self.img2txt[img_id].append(txt_id)
117 | self.txt2img[txt_id] = img_id
118 | txt_id += 1
119 |
120 | def __len__(self):
121 | return len(self.annotation)
122 |
123 | def __getitem__(self, index):
124 | image_path = os.path.join(self.image_root, self.annotation[index]['image'])
125 | image = Image.open(image_path).convert('RGB')
126 | image = self.transform(image)
127 |
128 | return image, index
129 |
130 |
131 |
--------------------------------------------------------------------------------
/distill.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import datetime
4 | import os
5 | import random
6 | import sys
7 | import warnings
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torchvision.utils
15 | from sklearn.metrics.pairwise import cosine_similarity
16 | from tqdm import tqdm
17 | import math
18 |
19 | from transformers import BertTokenizer, BertConfig, BertModel
20 | import wandb
21 |
22 | from data import get_dataset_flickr, textprocess, textprocess_train
23 | from epoch import evaluate_synset, epoch, epoch_test, itm_eval
24 | from networks import CLIPModel_full, TextEncoder
25 | from reparam_module import ReparamModule
26 | from utils import DiffAugment, ParamDiffAug, TensorDataset, get_dataset, get_network, get_eval_pool, get_time, load_or_process_file
27 |
28 |
29 | def shuffle_files(img_expert_files, txt_expert_files):
30 | # Check if both lists have the same length and if the lists are not empty
31 | assert len(img_expert_files) == len(txt_expert_files), "Number of image files and text files does not match"
32 | assert len(img_expert_files) != 0, "No files to shuffle"
33 | shuffled_indices = np.random.permutation(len(img_expert_files))
34 |
35 | # Apply the shuffled indices to both lists
36 | img_expert_files = np.take(img_expert_files, shuffled_indices)
37 | txt_expert_files = np.take(txt_expert_files, shuffled_indices)
38 | print(f"img_expert_files: {img_expert_files}")
39 | print(f"txt_expert_files: {txt_expert_files}")
40 | return img_expert_files, txt_expert_files
41 |
42 | def nearest_neighbor(sentences, query_embeddings, database_embeddings):
43 | """Find the nearest neighbors for a batch of embeddings.
44 |
45 | Args:
46 | sentences: The original sentences from which the embeddings were computed.
47 | query_embeddings: A batch of embeddings for which to find the nearest neighbors.
48 | database_embeddings: All pre-computed embeddings.
49 |
50 | Returns:
51 | A list of the most similar sentences for each embedding in the batch.
52 | """
53 | nearest_neighbors = []
54 |
55 | for query in query_embeddings:
56 | similarities = cosine_similarity(query.reshape(1, -1), database_embeddings)
57 |
58 | most_similar_index = np.argmax(similarities)
59 |
60 | nearest_neighbors.append(sentences[most_similar_index])
61 |
62 | return nearest_neighbors
63 |
64 |
65 | def get_images_texts(n, dataset):
66 | """Get random n images and corresponding texts from the dataset.
67 |
68 | Args:
69 | n: Number of images and texts to retrieve.
70 | dataset: The dataset containing image-text pairs.
71 |
72 | Returns:
73 | A tuple containing two elements:
74 | - A tensor of randomly selected images.
75 | - A tensor of the corresponding texts, encoded as floats.
76 | """
77 | # Generate n unique random indices
78 | idx_shuffle = np.random.permutation(len(dataset))[:n]
79 |
80 | # Initialize the text encoder
81 | text_encoder = TextEncoder(args)
82 |
83 | image_syn = torch.stack([dataset[i][0] for i in idx_shuffle])
84 | text_syn = text_encoder([dataset[i][1] for i in idx_shuffle], device="cpu")
85 |
86 | return image_syn, text_syn.float()
87 |
88 |
89 | def main(args):
90 | ''' organize the real train dataset '''
91 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)
92 |
93 | train_sentences = train_dataset.get_all_captions()
94 |
95 | data = load_or_process_file('text', textprocess, args, testloader)
96 | train_caption = load_or_process_file('train_text', textprocess_train, args, train_sentences)
97 |
98 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu()
99 | print("The shape of bert_test_embed: {}".format(bert_test_embed.shape))
100 | train_caption_embed = torch.from_numpy(train_caption['bert_test_embed']).cpu()
101 | print("The shape of train_caption_embed: {}".format(train_caption_embed.shape))
102 |
103 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
104 | if args.zca and args.texture:
105 | raise AssertionError("Cannot use zca and texture together")
106 |
107 | if args.texture and args.pix_init == "real":
108 | print("WARNING: Using texture with real initialization will take a very long time to smooth out the boundaries between images.")
109 |
110 | if args.max_experts is not None and args.max_files is not None:
111 | args.total_experts = args.max_experts * args.max_files
112 |
113 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled))
114 |
115 | args.dsa = True if args.dsa == 'True' else False
116 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
117 |
118 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
119 |
120 | if args.dsa:
121 | args.dc_aug_param = None
122 |
123 | # wandb.init(mode="disabled")
124 | wandb.init(project='DatasetDistillation', entity='dataset_distillation', config=args, name=args.name)
125 |
126 | args.dsa_param = ParamDiffAug()
127 | zca_trans = args.zca_trans if args.zca else None
128 | args.zca_trans = zca_trans
129 | args.distributed = torch.cuda.device_count() > 1
130 |
131 | print('Hyper-parameters: \n', args.__dict__)
132 | syn_lr_img = torch.tensor(args.lr_teacher_img).to(args.device)
133 | syn_lr_txt = torch.tensor(args.lr_teacher_txt).to(args.device)
134 |
135 | ''' initialize the synthetic data '''
136 | image_syn, text_syn = get_images_texts(args.num_queries, train_dataset)
137 |
138 | if args.pix_init == 'noise':
139 | mean = torch.tensor([-0.0626, -0.0221, 0.0680])
140 | std = torch.tensor([1.0451, 1.0752, 1.0539])
141 | image_syn = torch.randn([args.num_queries, 3, 224, 224])
142 | for c in range(3):
143 | image_syn[:, c] = image_syn[:, c] * std[c] + mean[c]
144 | print('Initialized synthetic image from random noise')
145 |
146 | if args.txt_init == 'noise':
147 | text_syn = torch.normal(mean=-0.0094, std=0.5253, size=(args.num_queries, 768))
148 | print('Initialized synthetic text from random noise')
149 |
150 |
151 | ''' training '''
152 | image_syn = image_syn.detach().to(args.device).requires_grad_(True)
153 | optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5)
154 | optimizer_img.zero_grad()
155 |
156 | syn_lr_img = syn_lr_img.to(args.device).requires_grad_(True)
157 | syn_lr_txt = syn_lr_txt.to(args.device).requires_grad_(True)
158 | optimizer_lr = torch.optim.SGD([syn_lr_img, syn_lr_txt], lr=args.lr_lr, momentum=0.5)
159 |
160 | text_syn = text_syn.detach().to(args.device).requires_grad_(True)
161 | optimizer_txt = torch.optim.SGD([text_syn], lr=args.lr_txt, momentum=0.5)
162 | optimizer_txt.zero_grad()
163 | sentence_list = nearest_neighbor(train_sentences, text_syn.detach().cpu(), train_caption_embed)
164 | if args.draw:
165 | wandb.log({"original_sentence_list": wandb.Html(' '.join(sentence_list))})
166 | wandb.log({"original_synthetic_images": wandb.Image(torch.nan_to_num(image_syn.detach().cpu()))})
167 |
168 | criterion = nn.CrossEntropyLoss().to(args.device)
169 | print('%s training begins'%get_time())
170 |
171 | expert_dir = os.path.join(args.buffer_path, args.dataset)
172 | expert_dir = args.buffer_path
173 | print("Expert Dir: {}".format(expert_dir))
174 |
175 |
176 | img_expert_files = []
177 | txt_expert_files = []
178 | n = 0
179 | while os.path.exists(os.path.join(expert_dir, "img_replay_buffer_{}.pt".format(n))):
180 | img_expert_files.append(os.path.join(expert_dir, "img_replay_buffer_{}.pt".format(n)))
181 | txt_expert_files.append(os.path.join(expert_dir, "txt_replay_buffer_{}.pt".format(n)))
182 | n += 1
183 | if n == 0:
184 | raise AssertionError("No buffers detected at {}".format(expert_dir))
185 |
186 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files)
187 |
188 | file_idx = 0
189 | expert_idx = 0
190 | print("loading file {}".format(img_expert_files[file_idx]))
191 | print("loading file {}".format(txt_expert_files[file_idx]))
192 |
193 | img_buffer = torch.load(img_expert_files[file_idx])
194 | txt_buffer = torch.load(txt_expert_files[file_idx])
195 |
196 | for it in tqdm(range(args.Iteration + 1)):
197 | save_this_it = True
198 |
199 | wandb.log({"Progress": it}, step=it)
200 | ''' Evaluate synthetic data '''
201 | if it in eval_it_pool:
202 | print('-------------------------\nEvaluation\nimage_model_train = %s, text_model_train = %s, iteration = %d'%(args.image_encoder, args.text_encoder, it))
203 | if args.dsa:
204 | print('DSA augmentation strategy: \n', args.dsa_strategy)
205 | print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
206 | else:
207 | print('DC augmentation parameters: \n', args.dc_aug_param)
208 |
209 | accs_train = []
210 | img_r1s = []
211 | img_r5s = []
212 | img_r10s = []
213 | img_r_means = []
214 |
215 | txt_r1s = []
216 | txt_r5s = []
217 | txt_r10s = []
218 | txt_r_means = []
219 |
220 | r_means = []
221 | for it_eval in range(args.num_eval):
222 | net_eval = CLIPModel_full(args, eval_stage=args.transfer)
223 |
224 | with torch.no_grad():
225 | image_save = image_syn
226 | text_save = text_syn
227 | image_syn_eval, text_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(text_save.detach()) # avoid any unaware modification
228 |
229 | args.lr_net = syn_lr_img.item()
230 | print(image_syn_eval.shape)
231 | _, acc_train, val_result = evaluate_synset(it_eval, net_eval, image_syn_eval, text_syn_eval, testloader, args, bert_test_embed)
232 | print('Evaluate_%02d: Img R@1 = %.4f, Img R@5 = %.4f, Img R@10 = %.4f, Img R@Mean = %.4f, Txt R@1 = %.4f, Txt R@5 = %.4f, Txt R@10 = %.4f, Txt R@Mean = %.4f, R@Mean = %.4f' %
233 | (it_eval,
234 | val_result['img_r1'], val_result['img_r5'], val_result['img_r10'], val_result['img_r_mean'],
235 | val_result['txt_r1'], val_result['txt_r5'], val_result['txt_r10'], val_result['txt_r_mean'],
236 | val_result['r_mean']))
237 |
238 | img_r1s.append(val_result['img_r1'])
239 | img_r5s.append(val_result['img_r5'])
240 | img_r10s.append(val_result['img_r10'])
241 | img_r_means.append(val_result['img_r_mean'])
242 |
243 | txt_r1s.append(val_result['txt_r1'])
244 | txt_r5s.append(val_result['txt_r5'])
245 | txt_r10s.append(val_result['txt_r10'])
246 | txt_r_means.append(val_result['txt_r_mean'])
247 | r_means.append(val_result['r_mean'])
248 |
249 | if not args.std:
250 | wandb.log({"txt_r1": val_result['txt_r1']})
251 | wandb.log({"txt_r5": val_result['txt_r5']})
252 | wandb.log({"txt_r10": val_result['txt_r10']})
253 | wandb.log({"txt_r_mean": val_result['txt_r_mean']})
254 | wandb.log({"img_r1": val_result['img_r1']})
255 | wandb.log({"img_r5": val_result['img_r5']})
256 | wandb.log({"img_r10": val_result['img_r10']})
257 | wandb.log({"img_r_mean": val_result['img_r_mean']})
258 | wandb.log({"r_mean": val_result['r_mean']})
259 | if args.std:
260 | img_r1_mean, img_r1_std = np.mean(img_r1s), np.std(img_r1s)
261 | img_r5_mean, img_r5_std = np.mean(img_r5s), np.std(img_r5s)
262 | img_r10_mean, img_r10_std = np.mean(img_r10s), np.std(img_r10s)
263 | img_r_mean_mean, img_r_mean_std = np.mean(img_r_means), np.std(img_r_means)
264 |
265 | txt_r1_mean, txt_r1_std = np.mean(txt_r1s), np.std(txt_r1s)
266 | txt_r5_mean, txt_r5_std = np.mean(txt_r5s), np.std(txt_r5s)
267 | txt_r10_mean, txt_r10_std = np.mean(txt_r10s), np.std(txt_r10s)
268 | txt_r_mean_mean, txt_r_mean_std = np.mean(txt_r_means), np.std(txt_r_means)
269 | r_mean_mean, r_mean_std = np.mean(r_means), np.std(r_means)
270 |
271 | wandb.log({'Mean/txt_r1': txt_r1_mean, 'Std/txt_r1': txt_r1_std})
272 | wandb.log({'Mean/txt_r5': txt_r5_mean, 'Std/txt_r5': txt_r5_std})
273 | wandb.log({'Mean/txt_r10': txt_r10_mean, 'Std/txt_r10': txt_r10_std})
274 | wandb.log({'Mean/txt_r_mean': txt_r_mean_mean, 'Std/txt_r_mean': txt_r_mean_std})
275 | wandb.log({'Mean/img_r1': img_r1_mean, 'Std/img_r1': img_r1_std})
276 | wandb.log({'Mean/img_r5': img_r5_mean, 'Std/img_r5': img_r5_std})
277 | wandb.log({'Mean/img_r10': img_r10_mean, 'Std/img_r10': img_r10_std})
278 | wandb.log({'Mean/img_r_mean': img_r_mean_mean, 'Std/img_r_mean': img_r_mean_std})
279 | wandb.log({'Mean/r_mean': r_mean_mean, 'Std/r_mean': r_mean_std})
280 |
281 | if it in eval_it_pool and (save_this_it or it % 1000 == 0):
282 | if args.draw:
283 | with torch.no_grad():
284 | image_save = image_syn_eval.cuda()
285 | text_save = text_syn_eval.cuda()
286 | save_dir = os.path.join(".", "logged_files", args.dataset, wandb.run.name)
287 | print("Saving to {}".format(save_dir))
288 |
289 | if not os.path.exists(save_dir):
290 | os.makedirs(save_dir)
291 |
292 | #torch.save(image_save, os.path.join(save_dir, "images_{}.pt".format(it)))
293 | #torch.save(text_save, os.path.join(save_dir, "labels_{}.pt".format(it)))
294 |
295 | #torch.save(image_save, os.path.join(save_dir, "images_best.pt".format(it)))
296 | #torch.save(text_save, os.path.join(save_dir, "labels_best.pt".format(it)))
297 |
298 | wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, step=it) # Move tensor to CPU before converting to NumPy
299 |
300 | if args.ipc < 50 or args.force_save:
301 | upsampled = image_save[:90]
302 | if args.dataset != "ImageNet":
303 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
304 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
305 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
306 | sentence_list = nearest_neighbor(train_sentences, text_syn.cpu(), train_caption_embed)
307 | sentence_list = sentence_list[:90]
308 | torchvision.utils.save_image(grid, os.path.join(save_dir, "synthetic_images_{}.png".format(it)))
309 |
310 | with open(os.path.join(save_dir, "synthetic_sentences_{}.txt".format(it)), "w") as file:
311 | file.write('\n'.join(sentence_list))
312 | wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)
313 | wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it)
314 | wandb.log({"Synthetic_Sentences": wandb.Html(' '.join(sentence_list))}, step=it)
315 | print("finish saving images")
316 |
317 | for clip_val in [2.5]:
318 | std = torch.std(image_save)
319 | mean = torch.mean(image_save)
320 | upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std).cpu() # Move to CPU
321 | if args.dataset != "ImageNet":
322 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
323 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
324 | grid = torchvision.utils.make_grid(upsampled[:90], nrow=10, normalize=True, scale_each=True)
325 | wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid))}, step=it)
326 | torchvision.utils.save_image(grid, os.path.join(save_dir, "clipped_synthetic_images_{}_std_{}.png".format(it, clip_val)))
327 |
328 |
329 | if args.zca:
330 | image_save = image_save.to(args.device)
331 | image_save = args.zca_trans.inverse_transform(image_save.cpu()) # Move to CPU for ZCA transformation
332 | torch.save(image_save, os.path.join(save_dir, "images_zca_{}.pt".format(it)))
333 |
334 | upsampled = image_save
335 | if args.dataset != "ImageNet":
336 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
337 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
338 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
339 | wandb.log({"Reconstructed_Images": wandb.Image(torch.nan_to_num(grid))}, step=it) # Log GPU tensor directly
340 | wandb.log({'Reconstructed_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it)
341 |
342 | for clip_val in [2.5]:
343 | std = torch.std(image_save)
344 | mean = torch.mean(image_save)
345 | upsampled = torch.clip(image_save, min=mean - clip_val * std, max=mean + clip_val * std)
346 | if args.dataset != "ImageNet":
347 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
348 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
349 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
350 | wandb.log({"Clipped_Reconstructed_Images/std_{}".format(clip_val): wandb.Image(
351 | torch.nan_to_num(grid.detach().cpu()))}, step=it)
352 |
353 | wandb.log({"Synthetic_LR_Image": syn_lr_img.detach().cpu()}, step=it)
354 | wandb.log({"Synthetic_LR_Text": syn_lr_txt.detach().cpu()}, step=it)
355 |
356 | torch.cuda.empty_cache()
357 | student_net = CLIPModel_full(args)
358 | img_student_net = ReparamModule(student_net.image_encoder.to('cpu')).to('cuda')
359 | txt_student_net = ReparamModule(student_net.text_projection.to('cpu')).to('cuda')
360 |
361 | if args.distributed:
362 | img_student_net = torch.nn.DataParallel(img_student_net)
363 | txt_student_net = torch.nn.DataParallel(txt_student_net)
364 |
365 | img_student_net.train()
366 | txt_student_net.train()
367 | img_num_params = sum([np.prod(p.size()) for p in (img_student_net.parameters())])
368 | txt_num_params = sum([np.prod(p.size()) for p in (txt_student_net.parameters())])
369 |
370 |
371 | img_expert_trajectory = img_buffer[expert_idx]
372 | txt_expert_trajectory = txt_buffer[expert_idx]
373 | expert_idx += 1
374 | if expert_idx == len(img_buffer):
375 | expert_idx = 0
376 | file_idx += 1
377 | if file_idx == len(img_expert_files):
378 | file_idx = 0
379 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files)
380 | print("loading file {}".format(img_expert_files[file_idx]))
381 | print("loading file {}".format(txt_expert_files[file_idx]))
382 | if args.max_files != 1:
383 | del img_buffer
384 | del txt_buffer
385 | img_buffer = torch.load(img_expert_files[file_idx])
386 | txt_buffer = torch.load(txt_expert_files[file_idx])
387 |
388 | start_epoch = np.random.randint(0, args.max_start_epoch)
389 | img_starting_params = img_expert_trajectory[start_epoch]
390 | txt_starting_params = txt_expert_trajectory[start_epoch]
391 |
392 | img_target_params = img_expert_trajectory[start_epoch+args.expert_epochs]
393 | txt_target_params = txt_expert_trajectory[start_epoch+args.expert_epochs]
394 |
395 | img_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_target_params], 0)
396 | txt_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_target_params], 0)
397 |
398 | img_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0).requires_grad_(True)]
399 | txt_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0).requires_grad_(True)]
400 |
401 | img_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0)
402 | txt_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0)
403 | syn_images = image_syn
404 | syn_texts = text_syn
405 |
406 | img_param_loss_list = []
407 | txt_param_loss_list = []
408 |
409 | img_param_dist_list = []
410 | txt_param_dist_list = []
411 |
412 | indices_chunks = []
413 | for step in range(args.syn_steps):
414 | indices = torch.randperm(len(syn_images))
415 | these_indices = indices[:args.mini_batch_size]
416 | #these_indices = indices
417 | x = syn_images[these_indices]
418 | this_y = syn_texts[these_indices]
419 | if args.distributed:
420 | img_forward_params = img_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1)
421 | txt_forward_params = txt_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1)
422 | else:
423 | img_forward_params = img_student_params[-1]
424 | txt_forward_params = txt_student_params[-1]
425 |
426 | x = img_student_net(x, flat_param=img_forward_params)
427 | x = x / x.norm(dim=1, keepdim=True)
428 | this_y = txt_student_net(this_y, flat_param=txt_forward_params)
429 | this_y = this_y / this_y.norm(dim=1, keepdim=True)
430 | image_logits = logit_scale * x.float() @ this_y.float().t()
431 | ground_truth = torch.arange(len(image_logits)).type_as(image_logits).long()
432 | contrastive_loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2
433 |
434 | img_grad = torch.autograd.grad(contrastive_loss, img_student_params[-1], create_graph=True)[0]
435 | txt_grad = torch.autograd.grad(contrastive_loss, txt_student_params[-1], create_graph=True)[0]
436 | print("contrastive_loss: {}".format(contrastive_loss.item()))
437 | img_student_params.append(img_student_params[-1] - syn_lr_img * img_grad)
438 | txt_student_params.append(txt_student_params[-1] - syn_lr_txt * txt_grad)
439 | img_param_loss = torch.tensor(0.0).to(args.device)
440 | img_param_dist = torch.tensor(0.0).to(args.device)
441 | txt_param_loss = torch.tensor(0.0).to(args.device)
442 | txt_param_dist = torch.tensor(0.0).to(args.device)
443 |
444 |
445 | img_param_loss += torch.nn.functional.mse_loss(img_student_params[-1], img_target_params, reduction="sum")
446 | img_param_dist += torch.nn.functional.mse_loss(img_starting_params, img_target_params, reduction="sum")
447 | txt_param_loss += torch.nn.functional.mse_loss(txt_student_params[-1], txt_target_params, reduction="sum")
448 | txt_param_dist += torch.nn.functional.mse_loss(txt_starting_params, txt_target_params, reduction="sum")
449 |
450 | img_param_loss_list.append(img_param_loss)
451 | img_param_dist_list.append(img_param_dist)
452 | txt_param_loss_list.append(txt_param_loss)
453 | txt_param_dist_list.append(txt_param_dist)
454 |
455 |
456 | img_param_loss /= img_param_dist
457 | txt_param_loss /= txt_param_dist
458 | grand_loss = img_param_loss + txt_param_loss
459 |
460 | if math.isnan(img_param_loss):
461 | break
462 | print("img_param_loss: {}".format(img_param_loss))
463 | print("txt_param_loss: {}".format(txt_param_loss))
464 |
465 | optimizer_lr.zero_grad()
466 | optimizer_img.zero_grad()
467 | optimizer_txt.zero_grad()
468 |
469 | grand_loss.backward()
470 | # clip_value = 0.5
471 |
472 | #torch.nn.utils.clip_grad_norm_([image_syn], clip_value)
473 | #torch.nn.utils.clip_grad_norm_([text_syn], clip_value)
474 | #torch.nn.utils.clip_grad_norm_([syn_lr_img], clip_value)
475 | #torch.nn.utils.clip_grad_norm_([syn_lr_txt], clip_value)
476 | print("syn_lr_img: {}".format(syn_lr_img.grad))
477 | print("syn_lr_txt: {}".format(syn_lr_txt.grad))
478 | wandb.log({"syn_lr_img": syn_lr_img.grad.detach().cpu()}, step=it)
479 | wandb.log({"syn_lr_txt": syn_lr_txt.grad.detach().cpu()}, step=it)
480 |
481 | optimizer_lr.step()
482 | optimizer_img.step()
483 | optimizer_txt.step()
484 |
485 | wandb.log({"Grand_Loss": grand_loss.detach().cpu(),
486 | "Start_Epoch": start_epoch})
487 |
488 | # release references to the unrolled student parameter tensors so their graphs can be freed
489 | del img_student_params
490 | del txt_student_params
491 |
492 |
493 | if it%10 == 0:
494 | print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item()))
495 |
496 | wandb.finish()
497 |
498 |
499 | if __name__ == '__main__':
500 | parser = argparse.ArgumentParser(description='Parameter Processing')
501 |
502 | parser.add_argument('--dataset', type=str, default='flickr30k', help='dataset')
503 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
504 |
505 | parser.add_argument('--eval_mode', type=str, default='S',
506 | help='eval_mode, check utils.py for more info')
507 |
508 | parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on')
509 |
510 | parser.add_argument('--eval_it', type=int, default=50, help='how often to evaluate')
511 |
512 | parser.add_argument('--epoch_eval_train', type=int, default=50, help='epochs to train a model with synthetic data')
513 | parser.add_argument('--Iteration', type=int, default=50000, help='how many distillation steps to perform')
514 |
515 | parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images')
516 | parser.add_argument('--lr_txt', type=float, default=1000, help='learning rate for updating synthetic texts')
517 | parser.add_argument('--lr_lr', type=float, default=1e-03, help='learning rate for updating the learnable synthetic learning rates')
518 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters')
519 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters')
520 |
521 | parser.add_argument('--batch_train', type=int, default=64, help='batch size for training networks')
522 |
523 | parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"],
524 | help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
525 | parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"],
526 | help='noise/real: initialize synthetic texts from random noise or randomly sampled real captions.')
527 |
528 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
529 | help='whether to use differentiable Siamese augmentation.')
530 |
531 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
532 | help='differentiable Siamese augmentation strategy')
533 |
534 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path')
535 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')
536 |
537 | parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are')
538 | parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data')
539 | parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at')
540 |
541 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening")
542 |
543 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM")
544 |
545 | parser.add_argument('--no_aug', type=bool, default=False, help='this turns off diff aug during distillation')
546 |
547 | parser.add_argument('--texture', action='store_true', help="will distill textures instead")
548 | parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas')
549 | parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration')
550 |
551 |
552 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)')
553 | parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)')
554 |
555 | parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc')
556 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
557 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')
558 | parser.add_argument('--num_queries', type=int, default=100, help='number of queries')
559 | parser.add_argument('--mini_batch_size', type=int, default=100, help='mini-batch size of synthetic pairs per matching step')
560 | parser.add_argument('--basis', type=bool, default=False, help='whether to use basis')
561 | parser.add_argument('--n_basis', type=int, default=64, help='number of basis vectors')
562 | parser.add_argument('--recursive', type=bool, default=False, help='whether to use recursive mode')
563 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy')
564 | parser.add_argument('--image_size', type=int, default=224, help='image_size')
565 | parser.add_argument('--image_root', type=str, default='./Flickr30k/flickr-image-dataset/flickr30k-images/', help='location of image root')
566 | parser.add_argument('--ann_root', type=str, default='./Flickr30k/ann_file/', help='location of ann root')
567 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
568 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')
569 | parser.add_argument('--image_encoder', type=str, default='nfnet', choices=['clip', 'nfnet', 'vit', 'nf_resnet50'], help='image encoder')
570 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip'], help='text encoder')
571 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained')
572 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained')
573 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable')
574 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable')
575 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='whether the model uses only an image projection head')
576 | parser.add_argument('--distill', type=bool, default=True, help='whether to run distillation')
577 | parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='optimization scheme for trajectory matching')
578 | parser.add_argument('--image_only', type=bool, default=False, help='distill the image modality only')
579 | parser.add_argument('--text_only', type=bool, default=False, help='distill the text modality only')
580 | parser.add_argument('--draw', type=bool, default=True, help='whether to visualize/log the synthetic images')
581 | parser.add_argument('--transfer', type=bool, default=False, help='transfer across architectures')
582 | parser.add_argument('--std', type=bool, default=False, help='report standard deviation')
583 | args = parser.parse_args()
584 |
585 | main(args)
--------------------------------------------------------------------------------
/epoch.py:
--------------------------------------------------------------------------------
1 | '''
2 | * part of the code (i.e. def epoch_test() and itm_eval()) is from: https://github.com/salesforce/BLIP/blob/main/train_retrieval.py#L69
3 | * Copyright (c) 2022, salesforce.com, inc.
4 | * All rights reserved.
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | * By Junnan Li
8 | '''
9 | import numpy as np
10 | import torch
11 | import time
12 | import datetime
13 | import torch
14 | import torch.nn.functional as F
15 | from tqdm import tqdm
16 | import torch.nn as nn
17 | from utils import *
18 |
19 |
20 | def epoch(e, dataloader, net, optimizer_img, optimizer_txt, args):
21 | """
22 | Perform one training epoch (with epoch index e) on the given dataloader.
23 |
24 | Args:
25 | dataloader (torch.utils.data.DataLoader): The dataloader for iterating over the dataset.
26 | net: The model.
27 | optimizer_img: The optimizer for image parameters.
28 | optimizer_txt: The optimizer for text parameters.
29 | args (object): The arguments specifying the training configuration.
30 |
31 | Returns:
32 | Tuple of average loss and average accuracy.
33 | """
34 | net = net.to(args.device)
35 | net.train()
36 | loss_avg, acc_avg, num_exp = 0, 0, 0
37 |
38 | for i, data in tqdm(enumerate(dataloader)):
39 | if args.distill:
40 | image, caption = data[:2]
41 | else:
42 | image, caption, index = data[:3]
43 |
44 | image = image.to(args.device)
45 | n_b = image.shape[0]
46 |
47 | loss, acc = net(image, caption, e)
48 |
49 | loss_avg += loss.item() * n_b
50 | acc_avg += acc
51 | num_exp += n_b
52 |
53 | optimizer_img.zero_grad()
54 | optimizer_txt.zero_grad()
55 | loss.backward()
56 | optimizer_img.step()
57 | optimizer_txt.step()
58 |
59 | loss_avg /= num_exp
60 | acc_avg /= num_exp
61 |
62 | return loss_avg, acc_avg
63 |
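# Illustrative usage of epoch() (a sketch, not part of the original training
# scripts; `trainloader`, `num_epochs`, and the optimizer settings below are
# assumptions mirroring evaluate_synset() further down in this file):
#
#   net = CLIPModel_full(args).to(args.device)
#   optimizer_img = torch.optim.SGD(net.image_encoder.parameters(), lr=0.1, momentum=0.9)
#   optimizer_txt = torch.optim.SGD(net.text_projection.parameters(), lr=0.1, momentum=0.9)
#   for e in range(num_epochs):
#       loss_avg, acc_avg = epoch(e, trainloader, net, optimizer_img, optimizer_txt, args)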
64 |
65 |
66 |
67 | @torch.no_grad()
68 | def epoch_test(dataloader, model, device, bert_test_embed):
69 | model.eval()
70 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
71 | metric_logger = MetricLogger(delimiter=" ")
72 | header = 'Evaluation:'
73 | print('Computing features for evaluation...')
74 | start_time = time.time()
75 |
76 |
77 | txt_embed = model.text_projection(bert_test_embed.float().to(device))
78 | text_embeds = txt_embed / txt_embed.norm(dim=1, keepdim=True) #torch.Size([5000, 768])
79 | text_embeds = text_embeds.to(device)
80 |
81 | image_embeds = []
82 | for image, img_id in dataloader:
83 | image_feat = model.image_encoder(image.to(device))
84 | im_embed = image_feat / image_feat.norm(dim=1, keepdim=True)
85 | image_embeds.append(im_embed)
86 | image_embeds = torch.cat(image_embeds,dim=0)
87 | use_image_projection = False
88 | if use_image_projection:
89 | im_embed = model.image_projection(image_embeds.float())
90 | image_embeds = im_embed / im_embed.norm(dim=1, keepdim=True)
91 | else:
92 | image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
93 |
94 | sims_matrix = logit_scale.exp() * image_embeds @ text_embeds.t()
95 | score_matrix_i2t = torch.full((len(image_embeds),len(text_embeds)),-100.0).to(device) #torch.Size([1000, 5000])
96 | #for i, sims in enumerate(metric_logger.log_every(sims_matrix[0:sims_matrix.size(0) + 1], 50, header)):
97 | for i, sims in enumerate(sims_matrix):
98 | topk_sim, topk_idx = sims.topk(k=128, dim=0)
99 | score_matrix_i2t[i,topk_idx] = topk_sim #i:0-999, topk_idx:0-4999, find top k (k=128) similar text for each image
100 |
101 | sims_matrix = sims_matrix.t()
102 | score_matrix_t2i = torch.full((len(text_embeds),len(image_embeds)),-100.0).to(device)
103 | for i, sims in enumerate(sims_matrix):
104 | topk_sim, topk_idx = sims.topk(k=128, dim=0)
105 | score_matrix_t2i[i,topk_idx] = topk_sim
106 |
107 | total_time = time.time() - start_time
108 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
109 | print('Evaluation time {}'.format(total_time_str))
110 |
111 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
112 |
113 |
114 | @torch.no_grad()
115 | def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
116 |
117 | #Images->Text
118 | ranks = np.zeros(scores_i2t.shape[0])
119 | print("TR: ", len(ranks))
120 | for index, score in enumerate(scores_i2t):
121 | inds = np.argsort(score)[::-1]
122 | # Score
123 | rank = 1e20
124 | for i in img2txt[index]:
125 | tmp = np.where(inds == i)[0][0]
126 | if tmp < rank:
127 | rank = tmp
128 | ranks[index] = rank
129 |
130 | # Compute metrics
131 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
132 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
133 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
134 |
135 | #Text->Images
136 | ranks = np.zeros(scores_t2i.shape[0])
137 | print("IR: ", len(ranks))
138 |
139 | for index,score in enumerate(scores_t2i):
140 | inds = np.argsort(score)[::-1]
141 | ranks[index] = np.where(inds == txt2img[index])[0][0]
142 |
143 | # Compute metrics
144 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
145 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
146 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
147 |
148 | tr_mean = (tr1 + tr5 + tr10) / 3
149 | ir_mean = (ir1 + ir5 + ir10) / 3
150 | r_mean = (tr_mean + ir_mean) / 2
151 |
152 | eval_result = {'txt_r1': tr1,
153 | 'txt_r5': tr5,
154 | 'txt_r10': tr10,
155 | 'txt_r_mean': tr_mean,
156 | 'img_r1': ir1,
157 | 'img_r5': ir5,
158 | 'img_r10': ir10,
159 | 'img_r_mean': ir_mean,
160 | 'r_mean': r_mean}
161 | return eval_result
162 |
163 |
164 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args, bert_test_embed, return_loss=False):
165 |
166 | net = net.to(args.device)
167 | images_train = images_train.to(args.device)
168 | labels_train = labels_train.to(args.device)
169 | lr = float(args.lr_net)
170 | Epoch = int(args.epoch_eval_train)
171 | lr_schedule = [Epoch//2+1]
172 | optimizer_img = torch.optim.SGD(net.image_encoder.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
173 | optimizer_txt = torch.optim.SGD(net.text_projection.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
174 |
175 | dst_train = TensorDataset(images_train, labels_train)
176 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
177 |
178 | start = time.time()
179 | acc_train_list = []
180 | loss_train_list = []
181 |
182 | for ep in tqdm(range(Epoch+1)):
183 | loss_train, acc_train = epoch(ep, trainloader, net, optimizer_img, optimizer_txt, args)
184 | acc_train_list.append(acc_train)
185 | loss_train_list.append(loss_train)
186 | if ep == Epoch:
187 | with torch.no_grad():
188 | score_val_i2t, score_val_t2i = epoch_test(testloader, net, args.device, bert_test_embed)
189 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt)
190 | lr *= 0.1
191 | optimizer_img = torch.optim.SGD(net.image_encoder.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
192 | optimizer_txt = torch.optim.SGD(net.text_projection.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
193 |
194 | time_train = time.time() - start
195 | return net, acc_train_list, val_result
196 |
--------------------------------------------------------------------------------
/images/clip_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/clip_0.png
--------------------------------------------------------------------------------
/images/clip_2000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/clip_2000.png
--------------------------------------------------------------------------------
/images/coco_syn_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/coco_syn_0.png
--------------------------------------------------------------------------------
/images/coco_syn_2000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/coco_syn_2000.png
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/logo.png
--------------------------------------------------------------------------------
/images/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/loss.png
--------------------------------------------------------------------------------
/images/more_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/more_vis.png
--------------------------------------------------------------------------------
/images/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/pipeline.png
--------------------------------------------------------------------------------
/images/synthetic_images_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/synthetic_images_0.png
--------------------------------------------------------------------------------
/images/synthetic_images_2000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/synthetic_images_2000.png
--------------------------------------------------------------------------------
/images/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/table.png
--------------------------------------------------------------------------------
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/teaser.png
--------------------------------------------------------------------------------
/images/text_noise_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/text_noise_0.png
--------------------------------------------------------------------------------
/images/text_noise_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/text_noise_1.png
--------------------------------------------------------------------------------
/images/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/visualization.png
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | Vision-Language Dataset Distillation
21 |
22 |
23 |
24 |
25 |
26 |
27 | Vision-Language Dataset Distillation
28 |
29 |
30 |
31 |
32 |
48 |
49 |
50 |
51 |
52 | Princeton University1
53 |
54 |
55 | Google Research2
56 |
57 |
58 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 | Abstract
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 | Dataset distillation methods offer the promise of reducing a large-scale dataset down to a significantly smaller set of (potentially synthetic) training examples, which preserve sufficient information for training a new model from scratch. So far dataset distillation methods have been developed for image classification. However, with the rise in capabilities of vision-language models, and especially given the scale of datasets necessary to train these models, the time is ripe to expand dataset distillation methods beyond image classification. In this work, we take the first steps towards this goal by expanding on the idea of trajectory matching to create a distillation method for vision-language datasets. The key challenge is that vision-language datasets do not have a set of discrete classes. To overcome this, our proposed vision-and-language dataset distillation method jointly distills the images and their corresponding language descriptions in a contrastive formulation. Since there are no existing baselines, we compare our approach to three coreset selection methods (strategic subsampling of the training dataset), which we adapt to the vision-language setting. We demonstrate significant improvements on the challenging Flickr30K and COCO retrieval benchmarks: for example, on Flickr30K the best coreset selection method which selects 1000 image-text pairs for training is able to achieve only 5.6% image-to-text retrieval accuracy (i.e., recall@1); in contrast, our dataset distillation approach almost doubles that to 9.9% with just 100 (an order of magnitude fewer) training pairs.
101 |
102 |
103 |
104 |
105 | Bi-Trajectory-Guided Co-Distillation
106 |
107 |
108 | Dataset distillation traditionally focuses on classification tasks with distinct labels, creating compact distilled datasets for efficient learning.
109 | We expand this to a multimodal setting, distilling both vision and language data while emphasizing their interrelation.
110 | Unlike simple classification, our method captures the complex connections between image and text data.
111 | It is worth noting that this would be impossible if we optimized only a single modality, as supported by our single-modality distillation results.
112 |
113 | The approach consists of two stages, sketched in code below:
114 |
115 |
116 |
117 | Obtaining the expert training trajectories \( \{\tau^*\} \), with each trajectory \( \tau^* = \{\theta^*_t\}_{t=0}^T \), by training multiple models for \( T \) epochs on the full dataset \( \mathbf{D} \). For our multimodal setting, the models are trained using bidirectional contrastive loss .
118 |
119 |
120 | Training a set of student models on the current distilled dataset \( \hat{\mathbf{D}} \) using the same bidirectional contrastive loss, and then updating the distilled dataset \( \hat{\mathbf{D}} \) based on the multimodal trajectory matching loss of the student models' parameters and the optimal \( \theta^* \).
121 |
122 |
123 |
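To make the two stages concrete, here is a minimal, self-contained sketch under simplifying assumptions: toy linear encoders and random tensors stand in for Flickr30K/COCO, and names such as `train_expert` and `contrastive_loss` are illustrative rather than the repository's API (roughly, stage 1 corresponds to `buffer.py` and stage 2 to `distill.py`).

```python
import torch
import torch.nn.functional as F

def contrastive_loss(img_emb, txt_emb, logit_scale=100.0):
    # bidirectional (image-to-text and text-to-image) contrastive loss over an in-batch similarity matrix
    img_emb = F.normalize(img_emb, dim=1)
    txt_emb = F.normalize(txt_emb, dim=1)
    logits = logit_scale * img_emb @ txt_emb.t()
    target = torch.arange(len(logits))
    return (F.cross_entropy(logits, target) + F.cross_entropy(logits.t(), target)) / 2

def train_expert(images, texts, epochs=3, lr=0.1):
    # Stage 1: train a toy image/text encoder pair on the "full" data and
    # record the parameter trajectory after every epoch.
    img_enc = torch.nn.Linear(64, 32)
    txt_enc = torch.nn.Linear(48, 32)
    opt = torch.optim.SGD(list(img_enc.parameters()) + list(txt_enc.parameters()), lr=lr)
    trajectory = []
    for _ in range(epochs):
        loss = contrastive_loss(img_enc(images), txt_enc(texts))
        opt.zero_grad()
        loss.backward()
        opt.step()
        trajectory.append([p.detach().clone() for p in
                           list(img_enc.parameters()) + list(txt_enc.parameters())])
    return trajectory

# Stage 2 (outline): initialize a distilled set from a few real pairs, train a
# student for a handful of steps starting from an expert checkpoint, and update
# the distilled pairs with the trajectory-matching loss shown further below.
images, texts = torch.randn(256, 64), torch.randn(256, 48)
expert_trajectory = train_expert(images, texts)
print(len(expert_trajectory), "expert checkpoints recorded")
```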
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 | Vision-Language Bi-Trajectory Matching
135 |
136 |
137 |
138 |
139 |
140 | Following the MTT formulation, we randomly sample \( M \) image-text pairs from \( \mathbf{D} \) to initialize the distilled dataset \( \mathbf{\hat{D}} \) (more details can be found in the paper). We sample an expert trajectory (i.e., the trajectory of a model trained on the full dataset) \( \tau^* = \{\theta^*_t\}_{t=0}^T \) and a random starting epoch \( s \) to initialize \( \hat{\theta}_s = \theta^*_s \).
141 |
142 | We train the student model on the distilled dataset for \( \hat{R} \) steps to obtain \( \hat{\theta}_{s+\hat{R}} \). We then update the distilled dataset based on the multimodal trajectory matching loss \( \ell_{trajectory} \), computed from the accumulated difference between the student and expert trajectories:
143 |
144 | $$
145 | \ell_{trajectory} = \frac{\|\hat{\theta}_{img, s+\hat{R}} - \theta^*_{img, s+R}\|_2^2}{\|\theta^*_{img, s} - \theta^*_{img, s+R}\|_2^2} + \frac{\|\hat{\theta}_{txt, s+\hat{R}} - \theta^*_{txt, s+R}\|_2^2}{\|\theta^*_{txt, s} - \theta^*_{txt, s+R}\|_2^2}.
146 | $$
147 |
148 | We update the distilled dataset by back-propagating through the multiple (\( \hat{R} \)) gradient descent updates to \( \hat{\mathbf{D}} \), specifically through the image pixel space and the text embedding space. We initialize the continuous sentence embeddings using a pretrained BERT model and update the distilled text in this continuous embedding space. For the distilled images, we directly update the pixel values.
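As a concrete reading of the loss above, the following minimal sketch computes \( \ell_{trajectory} \) from flattened parameter vectors, normalizing each modality's matching error by how far the expert itself moved between epochs \( s \) and \( s+R \). Tensor names here are illustrative; the repository's distill.py accumulates the same quantities with mse_loss(..., reduction='sum').

```python
import torch

def trajectory_matching_loss(student_img, student_txt,
                             expert_img_s, expert_img_sR,
                             expert_txt_s, expert_txt_sR):
    # Each argument is a flattened parameter vector (torch.cat of p.reshape(-1)).
    # Numerator: squared distance between the student's end point and the expert's
    # target parameters; denominator: how far the expert itself moved.
    img_term = ((student_img - expert_img_sR) ** 2).sum() / ((expert_img_s - expert_img_sR) ** 2).sum()
    txt_term = ((student_txt - expert_txt_sR) ** 2).sum() / ((expert_txt_s - expert_txt_sR) ** 2).sum()
    return img_term + txt_term

# toy check with random parameter vectors
d = 1000
loss = trajectory_matching_loss(torch.randn(d), torch.randn(d),
                                torch.randn(d), torch.randn(d),
                                torch.randn(d), torch.randn(d))
print(float(loss))
```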
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 | Results
163 |
164 |
165 | We compare our distillation method to four coreset selection methods: random selection of training examples, herding, k-center, and forgetting. We consider different coreset sizes (100, 200, 500, and 1000) and report the image-to-text (TR) and text-to-image (IR) retrieval performance on the Flickr30K dataset in Table A.
166 |
167 | We also provide an ablation study on the choice of vision (Table B) and language (Table C) backbones. We introduce the Performance Recovery Ratio (PRR) to evaluate the effectiveness of dataset distillation; it quantifies the percentage of performance retained relative to training on the original data. The performance for various backbone combinations is shown in Table D.
168 |
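PRR is presumably computed as the ratio between the retrieval performance obtained with the distilled data and that obtained with the full dataset, expressed as a percentage; a minimal illustration (the numbers below are placeholders, not reported results):

```python
def performance_recovery_ratio(distilled_metric: float, full_data_metric: float) -> float:
    # e.g., recall@1 of a model trained on the distilled pairs vs. on the full dataset
    return 100.0 * distilled_metric / full_data_metric

# placeholder values for illustration only
print(performance_recovery_ratio(distilled_metric=10.0, full_data_metric=40.0))  # -> 25.0
```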
169 |
170 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 | Visualization
185 | Left: the image-text pairs before distillation. Right: the image-text pairs after 2000 distillation steps. Note that the texts visualized here are the nearest training-set sentences to the distilled text embeddings.
186 |
187 |
188 | Here we include a number of visualizations of the data distilled from the multimodal datasets (both Flickr30K and COCO) for a more intuitive understanding of the distilled set.
189 | We provide 50 distilled image-text pair examples, together with their visualization before the distillation process.
190 | These experiments use 100 distilled pairs, with pretrained NFNet and BERT as backbones and 8 synthetic steps during distillation.
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 | Conclusion
209 |
210 |
211 | In this work, we propose a multimodal dataset distillation method for the image-text retrieval task.
212 | By co-distilling both the vision and language modalities, we can progressively optimize and distill the most critical information.
213 | Our experiments show that co-distilling different modalities via trajectory matching holds promise.
214 | We hope that the insights we gathered can serve as a roadmap for future studies exploring more complex settings. More broadly, our work lays the groundwork for research into the minimum information a vision-language model needs in order to reach comparable performance quickly, building a better understanding of the compositionality of compact visual-linguistic knowledge.
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 | Paper
227 |
228 |
229 | Xindi Wu, Byron Zhang, Zhiwei Deng, Olga Russakovsky.
230 | Vision-Language Dataset Distillation
231 | TMLR 2024.
232 | [Arxiv]
233 | [OpenReview]
234 | [Code]
235 | [Video]
236 | [Slides]
237 | [Poster]
238 |
239 |
240 |
241 |
242 |
243 |
251 |
252 |
253 |
254 | Copy to Clipboard
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 | Acknowledgements
270 |
271 | This material is based upon work supported by the National Science Foundation under Grant No. 2107048. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. We thank many people from Princeton Visual AI lab (Allison Chen, Jihoon Chung, Tyler Zhu, Ye Zhu, William Yang and Kaiqu Liang) and Princeton NLP group (Carlos E. Jimenez, John Yang), Tiffany Ling and George Cazenavette for their helpful feedback on this work.
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Sourced directly from OpenAI's CLIP repo
2 | from collections import OrderedDict
3 | from typing import Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 |
11 | class Bottleneck(nn.Module):
12 | expansion = 4
13 |
14 | def __init__(self, inplanes, planes, stride=1):
15 | super().__init__()
16 |
17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 |
24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
25 |
26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
28 |
29 | self.relu = nn.ReLU(inplace=True)
30 | self.downsample = None
31 | self.stride = stride
32 |
33 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
35 | self.downsample = nn.Sequential(OrderedDict([
36 | ("-1", nn.AvgPool2d(stride)),
37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
38 | ("1", nn.BatchNorm2d(planes * self.expansion))
39 | ]))
40 |
41 | def forward(self, x: torch.Tensor):
42 | identity = x
43 |
44 | out = self.relu(self.bn1(self.conv1(x)))
45 | out = self.relu(self.bn2(self.conv2(out)))
46 | out = self.avgpool(out)
47 | out = self.bn3(self.conv3(out))
48 |
49 | if self.downsample is not None:
50 | identity = self.downsample(x)
51 |
52 | out += identity
53 | out = self.relu(out)
54 | return out
55 |
56 |
57 | class AttentionPool2d(nn.Module):
58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
59 | super().__init__()
60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
61 | self.k_proj = nn.Linear(embed_dim, embed_dim)
62 | self.q_proj = nn.Linear(embed_dim, embed_dim)
63 | self.v_proj = nn.Linear(embed_dim, embed_dim)
64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
65 | self.num_heads = num_heads
66 |
67 | def forward(self, x):
68 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
71 | x, _ = F.multi_head_attention_forward(
72 | query=x, key=x, value=x,
73 | embed_dim_to_check=x.shape[-1],
74 | num_heads=self.num_heads,
75 | q_proj_weight=self.q_proj.weight,
76 | k_proj_weight=self.k_proj.weight,
77 | v_proj_weight=self.v_proj.weight,
78 | in_proj_weight=None,
79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
80 | bias_k=None,
81 | bias_v=None,
82 | add_zero_attn=False,
83 | dropout_p=0,
84 | out_proj_weight=self.c_proj.weight,
85 | out_proj_bias=self.c_proj.bias,
86 | use_separate_proj_weight=True,
87 | training=self.training,
88 | need_weights=False
89 | )
90 |
91 | return x[0]
92 |
93 |
94 |
95 | class LayerNorm(nn.LayerNorm):
96 | """Subclass torch's LayerNorm to handle fp16."""
97 |
98 | def forward(self, x: torch.Tensor):
99 | orig_type = x.dtype
100 | ret = super().forward(x.type(torch.float32))
101 | return ret.type(orig_type)
102 |
103 |
104 | class QuickGELU(nn.Module):
105 | def forward(self, x: torch.Tensor):
106 | return x * torch.sigmoid(1.702 * x)
107 |
108 |
109 | class ResidualAttentionBlock(nn.Module):
110 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
111 | super().__init__()
112 |
113 | self.attn = nn.MultiheadAttention(d_model, n_head)
114 | self.ln_1 = LayerNorm(d_model)
115 | self.mlp = nn.Sequential(OrderedDict([
116 | ("c_fc", nn.Linear(d_model, d_model * 4)),
117 | ("gelu", QuickGELU()),
118 | ("c_proj", nn.Linear(d_model * 4, d_model))
119 | ]))
120 | self.ln_2 = LayerNorm(d_model)
121 | self.attn_mask = attn_mask
122 |
123 | def attention(self, x: torch.Tensor):
124 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
125 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
126 |
127 | def forward(self, x: torch.Tensor):
128 | x = x + self.attention(self.ln_1(x))
129 | x = x + self.mlp(self.ln_2(x))
130 | return x
131 |
132 |
133 |
134 | def convert_weights(model: nn.Module):
135 | """Convert applicable model parameters to fp16"""
136 |
137 | def _convert_weights_to_fp16(l):
138 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
139 | l.weight.data = l.weight.data.half()
140 | if l.bias is not None:
141 | l.bias.data = l.bias.data.half()
142 |
143 | if isinstance(l, nn.MultiheadAttention):
144 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
145 | tensor = getattr(l, attr)
146 | if tensor is not None:
147 | tensor.data = tensor.data.half()
148 |
149 | for name in ["text_projection", "proj"]:
150 | if hasattr(l, name):
151 | attr = getattr(l, name)
152 | if attr is not None:
153 | attr.data = attr.data.half()
154 |
155 | model.apply(_convert_weights_to_fp16)
156 |
157 |
158 | def build_model(state_dict: dict):
159 | vit = "visual.proj" in state_dict
160 |
161 | if vit:
162 | vision_width = state_dict["visual.conv1.weight"].shape[0]
163 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
164 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
165 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
166 | image_resolution = vision_patch_size * grid_size
167 | else:
168 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
169 | vision_layers = tuple(counts)
170 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
171 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
172 | vision_patch_size = None
173 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
174 | image_resolution = output_width * 32
175 |
176 | embed_dim = state_dict["text_projection"].shape[1]
177 | context_length = state_dict["positional_embedding"].shape[0]
178 | vocab_size = state_dict["token_embedding.weight"].shape[0]
179 | transformer_width = state_dict["ln_final.weight"].shape[0]
180 | transformer_heads = transformer_width // 64
181 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
182 |
183 | model = CLIP(
184 | embed_dim,
185 | image_resolution, vision_layers, vision_width, vision_patch_size,
186 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
187 | )
188 |
189 | for key in ["input_resolution", "context_length", "vocab_size"]:
190 | if key in state_dict:
191 | del state_dict[key]
192 |
193 | convert_weights(model)
194 | model.load_state_dict(state_dict)
195 | return model.eval()
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | from collections import OrderedDict
5 | from typing import Tuple, Union
6 | import clip
7 | from transformers import ViTConfig, ViTModel, AutoTokenizer, CLIPTextModel, CLIPTextConfig, CLIPProcessor, CLIPConfig
8 | import numpy as np
9 | from transformers import BertTokenizer, BertModel
10 | from torchvision.models import resnet18, resnet
11 | from transformers.models.bert.modeling_bert import BertAttention, BertConfig
12 |
13 | tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
14 | BERT_model = BertModel.from_pretrained('bert-base-uncased')
15 |
16 |
17 | # Acknowledgement to
18 | # https://github.com/kuangliu/pytorch-cifar,
19 | # https://github.com/BIGBALLON/CIFAR-ZOO,
20 |
21 | # adapted from
22 | # https://github.com/VICO-UoE/DatasetCondensation
23 | # https://github.com/Zasder3/train-CLIP
24 |
25 |
26 | ''' MLP '''
27 | class MLP(nn.Module):
28 | def __init__(self, channel, num_classes):
29 | super(MLP, self).__init__()
30 | self.fc_1 = nn.Linear(28*28*1 if channel==1 else 32*32*3, 128)
31 | self.fc_2 = nn.Linear(128, 128)
32 | self.fc_3 = nn.Linear(128, num_classes)
33 |
34 | def forward(self, x):
35 | out = x.view(x.size(0), -1)
36 | out = F.relu(self.fc_1(out))
37 | out = F.relu(self.fc_2(out))
38 | out = self.fc_3(out)
39 | return out
40 |
41 |
42 |
43 | ''' ConvNet '''
44 | class ConvNet(nn.Module):
45 | def __init__(self, channel, num_classes, net_width=128, net_depth=4, net_act='relu', net_norm='instancenorm', net_pooling='avgpooling', im_size = (224,224)):
46 | super(ConvNet, self).__init__()
47 |
48 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
49 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
50 | self.classifier = nn.Linear(num_feat, num_classes)
51 |
52 | def forward(self, x):
53 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device())
54 | out = self.features(x)
55 | out = out.view(out.size(0), -1)
56 | out = self.classifier(out)
57 | return out
58 |
59 | def _get_activation(self, net_act):
60 | if net_act == 'sigmoid':
61 | return nn.Sigmoid()
62 | elif net_act == 'relu':
63 | return nn.ReLU(inplace=True)
64 | elif net_act == 'leakyrelu':
65 | return nn.LeakyReLU(negative_slope=0.01)
66 | else:
67 | exit('unknown activation function: %s'%net_act)
68 |
69 | def _get_pooling(self, net_pooling):
70 | if net_pooling == 'maxpooling':
71 | return nn.MaxPool2d(kernel_size=2, stride=2)
72 | elif net_pooling == 'avgpooling':
73 | return nn.AvgPool2d(kernel_size=2, stride=2)
74 | elif net_pooling == 'none':
75 | return None
76 | else:
77 | exit('unknown net_pooling: %s'%net_pooling)
78 |
79 | def _get_normlayer(self, net_norm, shape_feat):
80 | # shape_feat = (c*h*w)
81 | if net_norm == 'batchnorm':
82 | return nn.BatchNorm2d(shape_feat[0], affine=True)
83 | elif net_norm == 'layernorm':
84 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
85 | elif net_norm == 'instancenorm':
86 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
87 | elif net_norm == 'groupnorm':
88 | return nn.GroupNorm(4, shape_feat[0], affine=True)
89 | elif net_norm == 'none':
90 | return None
91 | else:
92 | exit('unknown net_norm: %s'%net_norm)
93 |
94 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
95 | layers = []
96 | in_channels = channel
97 | if im_size[0] == 28:
98 | im_size = (32, 32)
99 | shape_feat = [in_channels, im_size[0], im_size[1]]
100 | for d in range(net_depth):
101 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
102 | shape_feat[0] = net_width
103 | if net_norm != 'none':
104 | layers += [self._get_normlayer(net_norm, shape_feat)]
105 | layers += [self._get_activation(net_act)]
106 | in_channels = net_width
107 | if net_pooling != 'none':
108 | layers += [self._get_pooling(net_pooling)]
109 | shape_feat[1] //= 2
110 | shape_feat[2] //= 2
111 |
112 |
113 | return nn.Sequential(*layers), shape_feat
114 |
115 |
116 | ''' ConvNet '''
117 | class ConvNetGAP(nn.Module):
118 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)):
119 | super(ConvNetGAP, self).__init__()
120 |
121 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
122 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
123 | # self.classifier = nn.Linear(num_feat, num_classes)
124 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
125 | self.classifier = nn.Linear(shape_feat[0], num_classes)
126 |
127 | def forward(self, x):
128 | out = self.features(x)
129 | out = self.avgpool(out)
130 | out = out.view(out.size(0), -1)
131 | out = self.classifier(out)
132 | return out
133 |
134 | def _get_activation(self, net_act):
135 | if net_act == 'sigmoid':
136 | return nn.Sigmoid()
137 | elif net_act == 'relu':
138 | return nn.ReLU(inplace=True)
139 | elif net_act == 'leakyrelu':
140 | return nn.LeakyReLU(negative_slope=0.01)
141 | else:
142 | exit('unknown activation function: %s'%net_act)
143 |
144 | def _get_pooling(self, net_pooling):
145 | if net_pooling == 'maxpooling':
146 | return nn.MaxPool2d(kernel_size=2, stride=2)
147 | elif net_pooling == 'avgpooling':
148 | return nn.AvgPool2d(kernel_size=2, stride=2)
149 | elif net_pooling == 'none':
150 | return None
151 | else:
152 | exit('unknown net_pooling: %s'%net_pooling)
153 |
154 | def _get_normlayer(self, net_norm, shape_feat):
155 | # shape_feat = (c*h*w)
156 | if net_norm == 'batchnorm':
157 | return nn.BatchNorm2d(shape_feat[0], affine=True)
158 | elif net_norm == 'layernorm':
159 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
160 | elif net_norm == 'instancenorm':
161 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
162 | elif net_norm == 'groupnorm':
163 | return nn.GroupNorm(4, shape_feat[0], affine=True)
164 | elif net_norm == 'none':
165 | return None
166 | else:
167 | exit('unknown net_norm: %s'%net_norm)
168 |
169 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
170 | layers = []
171 | in_channels = channel
172 | if im_size[0] == 28:
173 | im_size = (32, 32)
174 | shape_feat = [in_channels, im_size[0], im_size[1]]
175 | for d in range(net_depth):
176 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
177 | shape_feat[0] = net_width
178 | if net_norm != 'none':
179 | layers += [self._get_normlayer(net_norm, shape_feat)]
180 | layers += [self._get_activation(net_act)]
181 | in_channels = net_width
182 | if net_pooling != 'none':
183 | layers += [self._get_pooling(net_pooling)]
184 | shape_feat[1] //= 2
185 | shape_feat[2] //= 2
186 |
187 | return nn.Sequential(*layers), shape_feat
188 |
189 |
190 | ''' LeNet '''
191 | class LeNet(nn.Module):
192 | def __init__(self, channel, num_classes):
193 | super(LeNet, self).__init__()
194 | self.features = nn.Sequential(
195 | nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0),
196 | nn.ReLU(inplace=True),
197 | nn.MaxPool2d(kernel_size=2, stride=2),
198 | nn.Conv2d(6, 16, kernel_size=5),
199 | nn.ReLU(inplace=True),
200 | nn.MaxPool2d(kernel_size=2, stride=2),
201 | )
202 | self.fc_1 = nn.Linear(16 * 5 * 5, 120)
203 | self.fc_2 = nn.Linear(120, 84)
204 | self.fc_3 = nn.Linear(84, num_classes)
205 |
206 | def forward(self, x):
207 | x = self.features(x)
208 | x = x.view(x.size(0), -1)
209 | x = F.relu(self.fc_1(x))
210 | x = F.relu(self.fc_2(x))
211 | x = self.fc_3(x)
212 | return x
213 |
214 |
215 |
216 | ''' AlexNet '''
217 | class AlexNet(nn.Module):
218 | def __init__(self, channel, num_classes):
219 | super(AlexNet, self).__init__()
220 | self.features = nn.Sequential(
221 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
222 | nn.ReLU(inplace=True),
223 | nn.MaxPool2d(kernel_size=2, stride=2),
224 | nn.Conv2d(128, 192, kernel_size=5, padding=2),
225 | nn.ReLU(inplace=True),
226 | nn.MaxPool2d(kernel_size=2, stride=2),
227 | nn.Conv2d(192, 256, kernel_size=3, padding=1),
228 | nn.ReLU(inplace=True),
229 | nn.Conv2d(256, 192, kernel_size=3, padding=1),
230 | nn.ReLU(inplace=True),
231 | nn.Conv2d(192, 192, kernel_size=3, padding=1),
232 | nn.ReLU(inplace=True),
233 | nn.MaxPool2d(kernel_size=2, stride=2),
234 | )
235 | self.fc = nn.Linear(192 * 4 * 4, num_classes)
236 |
237 | def forward(self, x):
238 | x = self.features(x)
239 | x = x.view(x.size(0), -1)
240 | x = self.fc(x)
241 | return x
242 |
243 |
244 |
245 | ''' VGG '''
246 | cfg_vgg = {
247 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
248 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
249 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
250 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
251 | }
252 | class VGG(nn.Module):
253 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'):
254 | super(VGG, self).__init__()
255 | self.channel = channel
256 | self.features = self._make_layers(cfg_vgg[vgg_name], norm)
257 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes)
258 |
259 | def forward(self, x):
260 | x = self.features(x)
261 | x = x.view(x.size(0), -1)
262 | x = self.classifier(x)
263 | return x
264 |
265 | def _make_layers(self, cfg, norm):
266 | layers = []
267 | in_channels = self.channel
268 | for ic, x in enumerate(cfg):
269 | if x == 'M':
270 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
271 | else:
272 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1),
273 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x),
274 | nn.ReLU(inplace=True)]
275 | in_channels = x
276 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
277 | return nn.Sequential(*layers)
278 |
279 |
280 | def VGG11(channel, num_classes):
281 | return VGG('VGG11', channel, num_classes)
282 | def VGG11BN(channel, num_classes):
283 | return VGG('VGG11', channel, num_classes, norm='batchnorm')
284 | def VGG13(channel, num_classes):
285 | return VGG('VGG13', channel, num_classes)
286 | def VGG16(channel, num_classes):
287 | return VGG('VGG16', channel, num_classes)
288 | def VGG19(channel, num_classes):
289 | return VGG('VGG19', channel, num_classes)
290 |
291 |
292 | ''' ResNet_AP '''
293 | # The conv(stride=2) is replaced by conv(stride=1) + avgpool(kernel_size=2, stride=2)
294 |
295 | class BasicBlock_AP(nn.Module):
296 | expansion = 1
297 |
298 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
299 | super(BasicBlock_AP, self).__init__()
300 | self.norm = norm
301 | self.stride = stride
302 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification
303 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
304 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
305 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
306 |
307 | self.shortcut = nn.Sequential()
308 | if stride != 1 or in_planes != self.expansion * planes:
309 | self.shortcut = nn.Sequential(
310 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False),
311 | nn.AvgPool2d(kernel_size=2, stride=2), # modification
312 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
313 | )
314 |
315 | def forward(self, x):
316 | out = F.relu(self.bn1(self.conv1(x)))
317 | if self.stride != 1: # modification
318 | out = F.avg_pool2d(out, kernel_size=2, stride=2)
319 | out = self.bn2(self.conv2(out))
320 | out += self.shortcut(x)
321 | out = F.relu(out)
322 | return out
323 |
324 |
325 | class Bottleneck_AP(nn.Module):
326 | expansion = 4
327 |
328 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
329 | super(Bottleneck_AP, self).__init__()
330 | self.norm = norm
331 | self.stride = stride
332 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
333 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
334 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification
335 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
336 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
337 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
338 |
339 | self.shortcut = nn.Sequential()
340 | if stride != 1 or in_planes != self.expansion * planes:
341 | self.shortcut = nn.Sequential(
342 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False),
343 | nn.AvgPool2d(kernel_size=2, stride=2), # modification
344 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
345 | )
346 |
347 | def forward(self, x):
348 | out = F.relu(self.bn1(self.conv1(x)))
349 | out = F.relu(self.bn2(self.conv2(out)))
350 | if self.stride != 1: # modification
351 | out = F.avg_pool2d(out, kernel_size=2, stride=2)
352 | out = self.bn3(self.conv3(out))
353 | out += self.shortcut(x)
354 | out = F.relu(out)
355 | return out
356 |
357 |
358 | class ResNet_AP(nn.Module):
359 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
360 | super(ResNet_AP, self).__init__()
361 | self.in_planes = 64
362 | self.norm = norm
363 |
364 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
365 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
366 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
367 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
368 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
369 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
370 | self.classifier = nn.Linear(512 * block.expansion * 3 * 3 if channel==1 else 512 * block.expansion * 4 * 4, num_classes) # modification
371 |
372 | def _make_layer(self, block, planes, num_blocks, stride):
373 | strides = [stride] + [1] * (num_blocks - 1)
374 | layers = []
375 | for stride in strides:
376 | layers.append(block(self.in_planes, planes, stride, self.norm))
377 | self.in_planes = planes * block.expansion
378 | return nn.Sequential(*layers)
379 |
380 | def forward(self, x):
381 | out = F.relu(self.bn1(self.conv1(x)))
382 | out = self.layer1(out)
383 | out = self.layer2(out)
384 | out = self.layer3(out)
385 | out = self.layer4(out)
386 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification
387 | out = out.view(out.size(0), -1)
388 | out = self.classifier(out)
389 | return out
390 |
391 |
392 | def ResNet18BN_AP(channel, num_classes):
393 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm')
394 |
395 | def ResNet18_AP(channel, num_classes):
396 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes)
397 |
398 |
399 | ''' ResNet '''
400 |
401 | class BasicBlock(nn.Module):
402 | expansion = 1
403 |
404 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
405 | super(BasicBlock, self).__init__()
406 | self.norm = norm
407 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
408 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
409 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
410 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
411 |
412 | self.shortcut = nn.Sequential()
413 | if stride != 1 or in_planes != self.expansion*planes:
414 | self.shortcut = nn.Sequential(
415 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
416 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
417 | )
418 |
419 | def forward(self, x):
420 | out = F.relu(self.bn1(self.conv1(x)))
421 | out = self.bn2(self.conv2(out))
422 | out += self.shortcut(x)
423 | out = F.relu(out)
424 | return out
425 |
426 |
427 | class Bottleneck(nn.Module):
428 | expansion = 4
429 |
430 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
431 | super(Bottleneck, self).__init__()
432 | self.norm = norm
433 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
434 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
435 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
436 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
437 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
438 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
439 |
440 | self.shortcut = nn.Sequential()
441 | if stride != 1 or in_planes != self.expansion*planes:
442 | self.shortcut = nn.Sequential(
443 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
444 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
445 | )
446 |
447 | def forward(self, x):
448 | out = F.relu(self.bn1(self.conv1(x)))
449 | out = F.relu(self.bn2(self.conv2(out)))
450 | out = self.bn3(self.conv3(out))
451 | out += self.shortcut(x)
452 | out = F.relu(out)
453 | return out
454 |
455 |
456 | class ResNetImageNet(nn.Module):
457 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
458 | super(ResNetImageNet, self).__init__()
459 | self.in_planes = 64
460 | self.norm = norm
461 |
462 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
463 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
464 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
465 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
466 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
467 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
468 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
469 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
470 | self.classifier = nn.Linear(512*block.expansion, num_classes)
471 |
472 | def _make_layer(self, block, planes, num_blocks, stride):
473 | strides = [stride] + [1]*(num_blocks-1)
474 | layers = []
475 | for stride in strides:
476 | layers.append(block(self.in_planes, planes, stride, self.norm))
477 | self.in_planes = planes * block.expansion
478 | return nn.Sequential(*layers)
479 |
480 | def forward(self, x):
481 | out = F.relu(self.bn1(self.conv1(x)))
482 | out = self.maxpool(out)
483 | out = self.layer1(out)
484 | out = self.layer2(out)
485 | out = self.layer3(out)
486 | out = self.layer4(out)
487 | # out = F.avg_pool2d(out, 4)
488 | # out = out.view(out.size(0), -1)
489 | out = self.avgpool(out)
490 | out = torch.flatten(out, 1)
491 | out = self.classifier(out)
492 | return out
493 |
494 |
495 | def ResNet18BN(channel, num_classes):
496 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm')
497 |
498 | def ResNet18(channel, num_classes):
499 | return ResNet_gn(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes)
500 |
501 | def ResNet34(channel, num_classes):
502 | return ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes)
503 |
504 | def ResNet50(channel, num_classes):
505 | return ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes)
506 |
507 | def ResNet101(channel, num_classes):
508 | return ResNet(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes)
509 |
510 | def ResNet152(channel, num_classes):
511 | return ResNet(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes)
512 |
513 | def ResNet18ImageNet(channel, num_classes):
514 | return ResNetImageNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes)
515 |
516 | def ResNet6ImageNet(channel, num_classes):
517 | return ResNetImageNet(BasicBlock, [1,1,1,1], channel=channel, num_classes=num_classes)
518 |
519 | def resnet18_gn(pretrained=False, **kwargs):
520 |     """Constructs a ResNet-18 model via timm's resnet factory (_create_resnet).
521 |     """
522 | model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
523 | return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))
524 |
525 |
526 | ## Sourced directly from OpenAI's CLIP repo
527 | class ModifiedResNet(nn.Module):
528 | """
529 | A ResNet class that is similar to torchvision's but contains the following changes:
530 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
531 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
532 | - The final pooling layer is a QKV attention instead of an average pool
533 | """
534 |
535 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
536 | super().__init__()
537 | self.output_dim = output_dim
538 | self.input_resolution = input_resolution
539 |
540 | # the 3-layer stem
541 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
542 | self.bn1 = nn.BatchNorm2d(width // 2)
543 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
544 | self.bn2 = nn.BatchNorm2d(width // 2)
545 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
546 | self.bn3 = nn.BatchNorm2d(width)
547 | self.avgpool = nn.AvgPool2d(2)
548 | self.relu = nn.ReLU(inplace=True)
549 |
550 | # residual layers
551 | self._inplanes = width # this is a *mutable* variable used during construction
552 | self.layer1 = self._make_layer(width, layers[0])
553 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
554 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
555 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
556 |
557 | embed_dim = width * 32 # the ResNet feature dimension
558 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
559 |
560 | def _make_layer(self, planes, blocks, stride=1):
561 | layers = [Bottleneck(self._inplanes, planes, stride)]
562 |
563 | self._inplanes = planes * Bottleneck.expansion
564 | for _ in range(1, blocks):
565 | layers.append(Bottleneck(self._inplanes, planes))
566 |
567 | return nn.Sequential(*layers)
568 |
569 | def forward(self, x):
570 | def stem(x):
571 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
572 | x = self.relu(bn(conv(x)))
573 | x = self.avgpool(x)
574 | return x
575 |
576 | x = x.type(self.conv1.weight.dtype)
577 | x = stem(x)
578 | x = self.layer1(x)
579 | x = self.layer2(x)
580 | x = self.layer3(x)
581 | x = self.layer4(x)
582 | x = self.attnpool(x)
583 |
584 | return x
585 |
586 |
587 | class AttentionPool2d(nn.Module):
588 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
589 | super().__init__()
590 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
591 | self.k_proj = nn.Linear(embed_dim, embed_dim)
592 | self.q_proj = nn.Linear(embed_dim, embed_dim)
593 | self.v_proj = nn.Linear(embed_dim, embed_dim)
594 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
595 | self.num_heads = num_heads
596 |
597 | def forward(self, x):
598 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
599 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
600 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
601 | x, _ = F.multi_head_attention_forward(
602 | query=x, key=x, value=x,
603 | embed_dim_to_check=x.shape[-1],
604 | num_heads=self.num_heads,
605 | q_proj_weight=self.q_proj.weight,
606 | k_proj_weight=self.k_proj.weight,
607 | v_proj_weight=self.v_proj.weight,
608 | in_proj_weight=None,
609 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
610 | bias_k=None,
611 | bias_v=None,
612 | add_zero_attn=False,
613 | dropout_p=0,
614 | out_proj_weight=self.c_proj.weight,
615 | out_proj_bias=self.c_proj.bias,
616 | use_separate_proj_weight=True,
617 | training=self.training,
618 | need_weights=False
619 | )
620 |
621 | return x[0]
622 |
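For orientation, a quick shape check of the attention pooling above (an editor's sketch, not part of the repository): with a CLIP RN50-style configuration (width 64, input resolution 224), the feature map reaching `attnpool` has 2048 channels on a 7x7 grid and is pooled into a single `output_dim`-dimensional embedding.

```python
import torch
from networks import AttentionPool2d  # assumes the repo root is on PYTHONPATH

# 7x7 spatial grid, 2048-channel features, 32 heads, 1024-d output (RN50-like numbers)
pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
feats = torch.randn(2, 2048, 7, 7)      # (N, C, H, W) feature map
print(pool(feats).shape)                # torch.Size([2, 1024])
```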
623 | import timm
624 |
625 | class ProjectionHead(nn.Module):
626 | def __init__(
627 | self,
628 | embedding_dim,
629 | projection_dim=768,
630 | dropout=0.1
631 | ):
632 | super().__init__()
633 | self.projection = nn.Linear(embedding_dim, projection_dim)
634 | self.gelu = nn.GELU()
635 | self.fc = nn.Linear(projection_dim, projection_dim)
636 | self.dropout = nn.Dropout(dropout)
637 | self.layer_norm = nn.LayerNorm(projection_dim)
638 |
639 | def forward(self, x):
640 | projected = self.projection(x)
641 | x = self.gelu(projected)
642 | x = self.fc(x)
643 | x = self.dropout(x)
644 | x = x + projected
645 | x = self.layer_norm(x)
646 | return x
647 |
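As a small illustrative sketch (not from the repository): the projection head maps raw encoder features into a joint embedding space through a linear projection, a GELU MLP with a residual connection, and LayerNorm. The 768-to-512 mapping below is an assumed example, roughly a BERT CLS feature projected into a CLIP-sized space.

```python
import torch
from networks import ProjectionHead  # assumes the repo root is on PYTHONPATH

head = ProjectionHead(embedding_dim=768, projection_dim=512)  # e.g. BERT CLS -> joint space
text_features = torch.randn(4, 768)
print(head(text_features).shape)  # torch.Size([4, 512])
```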
648 | class ImageEncoder(nn.Module):
649 | """
650 | Encode images to a fixed size vector
651 | """
652 |
653 | def __init__(self, args, eval_stage):
654 | super().__init__()
655 | self.model_name = args.image_encoder
656 | self.pretrained = args.image_pretrained
657 | self.trainable = args.image_trainable
658 |
659 | if self.model_name == 'clip':
660 | if self.pretrained:
661 | self.model, preprocess = clip.load("ViT-B/32", device='cuda')
662 | else:
663 | configuration = ViTConfig()
664 | self.model = ViTModel(configuration)
665 | elif self.model_name == 'nfnet':
666 | self.model = timm.create_model('nfnet_l0', pretrained=self.pretrained, num_classes=0, global_pool="avg")
667 | elif self.model_name == 'vit':
668 | self.model = timm.create_model('vit_tiny_patch16_224', pretrained=True)
669 | elif self.model_name == 'nf_resnet50':
670 | self.model = timm.create_model('nf_resnet50', pretrained=True)
671 | elif self.model_name == 'nf_regnet':
672 | self.model = timm.create_model('nf_regnet_b1', pretrained=True)
673 | else:
674 | self.model = timm.create_model(self.model_name, self.pretrained, num_classes=0, global_pool="avg")
675 | for p in self.model.parameters():
676 | p.requires_grad = self.trainable
677 |
678 | def forward(self, x):
679 | if self.model_name == 'clip' and self.pretrained:
680 | return self.model.encode_image(x)
681 | else:
682 | return self.model(x)
683 |
684 | def gradient(self, x, y):
685 |         # Compute the gradient of the loss w.r.t. the encoder weights (assumes a self.loss criterion, e.g. MSE, is attached to this module elsewhere)
686 | loss = self.loss(x, y)
687 | grad = torch.autograd.grad(loss, self.parameters(), create_graph=True)
688 | return torch.cat([g.view(-1) for g in grad])
689 |
690 |
691 |
692 |
693 | class TextEncoder(nn.Module):
694 | def __init__(self, args):
695 | super().__init__()
696 | self.pretrained = args.text_pretrained
697 | self.trainable = args.text_trainable
698 | self.model_name = args.text_encoder
699 | if self.model_name == 'clip':
700 | self.model, preprocess = clip.load("ViT-B/32", device='cuda')
701 | elif self.model_name == 'bert':
702 | if args.text_pretrained:
703 | self.model = BERT_model
704 | else:
705 | self.model = BertModel(BertConfig())
706 | self.model.init_weights()
707 | self.tokenizer = tokenizer
708 | else:
709 | raise NotImplementedError
710 |
711 | for p in self.model.parameters():
712 | p.requires_grad = self.trainable
713 |
714 | # we are using the CLS token hidden representation as the sentence's embedding
715 | self.target_token_idx = 0
716 |
717 | def forward(self, texts, device='cuda'):
718 | if self.model_name == 'clip':
719 | output = self.model.encode_text(clip.tokenize(texts).to('cuda'))
720 |
721 | elif self.model_name == 'bert':
722 | # Tokenize the input text
723 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True)
724 | input_ids = encoding['input_ids'].to(device)
725 | attention_mask = encoding['attention_mask'].to(device)
726 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :]
727 | return output
728 |
729 |
730 |
731 | class CLIPModel_full(nn.Module):
732 | def __init__(
733 | self,
734 | args,
735 | temperature=1.0,
736 | eval_stage=False
737 | ):
738 | super().__init__()
739 |
740 | if args.image_encoder == 'nfnet':
741 | if eval_stage:
742 |                 self.image_embedding = 1000  # 2048
743 | else:
744 | self.image_embedding = 2304
745 | elif args.image_encoder == 'convnet':
746 | self.image_embedding = 768
747 | elif args.image_encoder == 'resnet18':
748 | self.image_embedding = 512
749 | elif args.image_encoder == 'convnext':
750 | self.image_embedding = 640
751 | else:
752 | self.image_embedding = 1000
753 | if args.text_encoder == 'clip':
754 | self.text_embedding = 512
755 | elif args.text_encoder == 'bert':
756 | self.text_embedding = 768
757 | else:
758 | raise NotImplementedError
759 |
760 | self.image_encoder = ImageEncoder(args, eval_stage=eval_stage)
761 | self.text_encoder = TextEncoder(args)
762 |
763 | if args.only_has_image_projection:
764 | self.image_projection = ProjectionHead(embedding_dim=self.image_embedding)
765 | self.text_projection = ProjectionHead(embedding_dim=self.text_embedding, projection_dim=self.image_embedding).to('cuda')
766 | self.temperature = temperature
767 | #self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
768 | self.args = args
769 | self.distill = args.distill
770 |
771 | def forward(self, image, caption, epoch):
772 | self.image_encoder = self.image_encoder.to('cuda')
773 | self.text_encoder = self.text_encoder.to('cuda')
774 |
775 | image_features = self.image_encoder(image)
776 | text_features = caption if self.distill else self.text_encoder(caption)
777 |
778 | use_image_project = False
779 | im_embed = image_features.float() if not use_image_project else self.image_projection(image_features.float())
780 | txt_embed = self.text_projection(text_features.float())
781 |
782 | combined_image_features = im_embed
783 | combined_text_features = txt_embed
784 | image_features = combined_image_features / combined_image_features.norm(dim=1, keepdim=True)
785 | text_features = combined_text_features / combined_text_features.norm(dim=1, keepdim=True)
786 |
787 |
788 |         image_logits = np.exp(np.log(1 / 0.07)) * image_features @ text_features.t()  # fixed logit scale of 1/0.07 (CLIP's initial temperature)
789 | ground_truth = torch.arange(len(image_logits)).type_as(image_logits).long()
790 | loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2
791 | acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum().item()
792 | acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum().item()
793 | acc = (acc_i + acc_t) / 2
794 | return loss, acc
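The forward pass above computes a CLIP-style symmetric contrastive loss with a fixed logit scale of 1/0.07. Below is a minimal, self-contained sketch of that objective (an editor's addition) using random unit-normalized features in place of the encoder outputs; the batch size and dimensionality are illustrative.

```python
import torch
import torch.nn.functional as F

batch, dim = 8, 512
image_features = F.normalize(torch.randn(batch, dim), dim=1)
text_features = F.normalize(torch.randn(batch, dim), dim=1)

logit_scale = 1 / 0.07                                   # fixed temperature, as above
image_logits = logit_scale * image_features @ text_features.t()
ground_truth = torch.arange(batch)                       # matching pairs sit on the diagonal

loss = (F.cross_entropy(image_logits, ground_truth) +
        F.cross_entropy(image_logits.t(), ground_truth)) / 2
acc_i = (image_logits.argmax(dim=1) == ground_truth).float().mean()  # image-to-text retrieval
acc_t = (image_logits.argmax(dim=0) == ground_truth).float().mean()  # text-to-image retrieval
print(loss.item(), acc_i.item(), acc_t.item())
```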
--------------------------------------------------------------------------------
/reparam_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | import types
5 | from collections import namedtuple
6 | from contextlib import contextmanager
7 |
8 |
9 | class ReparamModule(nn.Module):
10 | def _get_module_from_name(self, mn):
11 | if mn == '':
12 | return self
13 | m = self
14 | for p in mn.split('.'):
15 | m = getattr(m, p)
16 | return m
17 |
18 | def __init__(self, module):
19 | super(ReparamModule, self).__init__()
20 | self.module = module
21 |
22 | param_infos = [] # (module name/path, param name)
23 | shared_param_memo = {}
24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name)
25 | params = []
26 | param_numels = []
27 | param_shapes = []
28 | for mn, m in self.named_modules():
29 | for n, p in m.named_parameters(recurse=False):
30 | if p is not None:
31 | if p in shared_param_memo:
32 | shared_mn, shared_n = shared_param_memo[p]
33 | shared_param_infos.append((mn, n, shared_mn, shared_n))
34 | else:
35 | shared_param_memo[p] = (mn, n)
36 | param_infos.append((mn, n))
37 | params.append(p.detach())
38 | param_numels.append(p.numel())
39 | param_shapes.append(p.size())
40 |
41 | assert len(set(p.dtype for p in params)) <= 1, \
42 | "expects all parameters in module to have same dtype"
43 |
44 | # store the info for unflatten
45 | self._param_infos = tuple(param_infos)
46 | self._shared_param_infos = tuple(shared_param_infos)
47 | self._param_numels = tuple(param_numels)
48 | self._param_shapes = tuple(param_shapes)
49 |
50 | # flatten
51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
52 | self.register_parameter('flat_param', flat_param)
53 | self.param_numel = flat_param.numel()
54 | del params
55 | del shared_param_memo
56 |
57 | # deregister the names as parameters
58 | for mn, n in self._param_infos:
59 | delattr(self._get_module_from_name(mn), n)
60 | for mn, n, _, _ in self._shared_param_infos:
61 | delattr(self._get_module_from_name(mn), n)
62 |
63 | # register the views as plain attributes
64 | self._unflatten_param(self.flat_param)
65 |
66 | # now buffers
67 | # they are not reparametrized. just store info as (module, name, buffer)
68 | buffer_infos = []
69 | for mn, m in self.named_modules():
70 | for n, b in m.named_buffers(recurse=False):
71 | if b is not None:
72 | buffer_infos.append((mn, n, b))
73 |
74 | self._buffer_infos = tuple(buffer_infos)
75 | self._traced_self = None
76 |
77 | def trace(self, example_input, **trace_kwargs):
78 | assert self._traced_self is None, 'This ReparamModule is already traced'
79 |
80 | if isinstance(example_input, torch.Tensor):
81 | example_input = (example_input,)
82 | example_input = tuple(example_input)
83 | example_param = (self.flat_param.detach().clone(),)
84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),)
85 |
86 | self._traced_self = torch.jit.trace_module(
87 | self,
88 | inputs=dict(
89 | _forward_with_param=example_param + example_input,
90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input,
91 | ),
92 | **trace_kwargs,
93 | )
94 |
95 | # replace forwards with traced versions
96 | self._forward_with_param = self._traced_self._forward_with_param
97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers
98 | return self
99 |
100 | def clear_views(self):
101 | for mn, n in self._param_infos:
102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr
103 |
104 | def _apply(self, *args, **kwargs):
105 | if self._traced_self is not None:
106 | self._traced_self._apply(*args, **kwargs)
107 | return self
108 | return super(ReparamModule, self)._apply(*args, **kwargs)
109 |
110 | def _unflatten_param(self, flat_param):
111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
112 | for (mn, n), p in zip(self._param_infos, ps):
113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr
114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
116 |
117 | @contextmanager
118 | def unflattened_param(self, flat_param):
119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos]
120 | self._unflatten_param(flat_param)
121 | yield
122 | # Why not just `self._unflatten_param(self.flat_param)`?
123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583
124 |         # 2. slightly faster since it does not require reconstructing the split+view
125 | # graph
126 | for (mn, n), p in zip(self._param_infos, saved_views):
127 | setattr(self._get_module_from_name(mn), n, p)
128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
130 |
131 | @contextmanager
132 | def replaced_buffers(self, buffers):
133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers):
134 | setattr(self._get_module_from_name(mn), n, new_b)
135 | yield
136 | for mn, n, old_b in self._buffer_infos:
137 | setattr(self._get_module_from_name(mn), n, old_b)
138 |
139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs):
140 | with self.unflattened_param(flat_param):
141 | with self.replaced_buffers(buffers):
142 | return self.module(*inputs, **kwinputs)
143 |
144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs):
145 | with self.unflattened_param(flat_param):
146 | return self.module(*inputs, **kwinputs)
147 |
148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs):
149 |         if flat_param is None:
150 |             flat_param = self.flat_param
151 |         flat_param = torch.squeeze(flat_param)  # drop a leading singleton dim, if any, before unflattening
152 |         # print("PARAMS ON DEVICE: ", flat_param.get_device())
153 |         # print("DATA ON DEVICE: ", inputs[0].get_device())
154 |         # flat_param.to("cuda:{}".format(inputs[0].get_device()))
155 |         # self.module.to("cuda:{}".format(inputs[0].get_device()))
156 | if buffers is None:
157 | return self._forward_with_param(flat_param, *inputs, **kwinputs)
158 | else:
159 | return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs)
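For orientation, a minimal usage sketch (an editor's addition, not from the repository): ReparamModule exposes all parameters of a wrapped network as one flat vector, so the same architecture can be evaluated under arbitrary candidate parameter vectors, which is the pattern trajectory matching relies on.

```python
import torch
import torch.nn as nn
from reparam_module import ReparamModule  # assumes the repo root is on PYTHONPATH

# Wrap a tiny network; all of its parameters are flattened into `flat_param`.
net = ReparamModule(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)))
print(net.flat_param.shape)  # one 1-D vector holding every weight and bias

# Forward with an arbitrary flat parameter vector of the same length, e.g. a
# perturbed or interpolated copy of the current parameters.
candidate = net.flat_param.detach().clone() + 0.01 * torch.randn_like(net.flat_param)
out = net(torch.randn(5, 4), flat_param=candidate)
print(out.shape)  # torch.Size([5, 2])
```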
--------------------------------------------------------------------------------
/requirements.yaml:
--------------------------------------------------------------------------------
1 | name: vl-distill
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=4.5=1_gnu
9 | - absl-py=1.2.0=pypi_0
10 | - accelerate=0.17.0=pypi_0
11 | - addict=2.4.0=pypi_0
12 | - aiohttp=3.8.2=pypi_0
13 | - aiosignal=1.2.0=pypi_0
14 | - antlr4-python3-runtime=4.9.3=pypi_0
15 | - anyio=3.6.1=pypi_0
16 | - argcomplete=2.0.0=pypi_0
17 | - argon2-cffi=21.3.0=pypi_0
18 | - argon2-cffi-bindings=21.2.0=pypi_0
19 | - asttokens=2.0.8=pypi_0
20 | - async-timeout=4.0.2=pypi_0
21 | - attrs=22.1.0=pypi_0
22 | - autopep8=1.7.0=pypi_0
23 | - babel=2.10.3=pypi_0
24 | - backcall=0.2.0=pypi_0
25 | - beautifulsoup4=4.11.1=pypi_0
26 | - blas=1.0=mkl
27 | - bleach=5.0.1=pypi_0
28 | - boto=2.49.0=pypi_0
29 | - braceexpand=0.1.7=pypi_0
30 | - brotlipy=0.7.0=py39h27cfd23_1003
31 | - bzip2=1.0.8=h7b6447c_0
32 | - ca-certificates=2021.10.8=ha878542_0
33 | - cachetools=5.2.0=pypi_0
34 | - certifi=2021.10.8=py39hf3d152e_1
35 | - cffi=1.15.0=py39hd667e15_1
36 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
37 | - chatgdb=1.0.3=pypi_0
38 | - click=7.1.2=pyh9f0ad1d_0
39 | - clip=1.0=pypi_0
40 | - cloudpickle=2.2.0=pypi_0
41 | - cmake=3.26.3=pypi_0
42 | - colorama=0.4.4=pyh9f0ad1d_0
43 | - configparser=5.2.0=pyhd8ed1ab_0
44 | - contourpy=1.0.5=pypi_0
45 | - crcmod=1.7=pypi_0
46 | - cryptography=36.0.0=py39h9ce1e76_0
47 | - cudatoolkit=11.3.1=h2bc3f7f_2
48 | - cycler=0.11.0=pypi_0
49 | - dataclasses=0.8=pyhc8e2a94_3
50 | - datasets=2.7.1=pypi_0
51 | - debugpy=1.6.3=pypi_0
52 | - decorator=5.1.1=pypi_0
53 | - decord=0.6.0=pypi_0
54 | - defusedxml=0.7.1=pypi_0
55 | - deprecation=2.1.0=pypi_0
56 | - diffusers=0.14.0=pypi_0
57 | - dill=0.3.6=pypi_0
58 | - distlib=0.3.6=pypi_0
59 | - dm-haiku=0.0.9=pypi_0
60 | - docker-pycreds=0.4.0=py_0
61 | - einops=0.5.0=pypi_0
62 | - entrypoints=0.4=pypi_0
63 | - executing=1.1.0=pypi_0
64 | - fairscale=0.4.12=pypi_0
65 | - fasteners=0.18=pypi_0
66 | - fastjsonschema=2.16.2=pypi_0
67 | - ffmpeg=4.3=hf484d3e_0
68 | - filelock=3.8.0=pypi_0
69 | - fonttools=4.37.3=pypi_0
70 | - freetype=2.11.0=h70c0345_0
71 | - frozenlist=1.3.1=pypi_0
72 | - fsspec=2022.8.2=pypi_0
73 | - ftfy=6.1.1=pypi_0
74 | - fvcore=0.1.5.post20220512=pypi_0
75 | - gcs-oauth2-boto-plugin=3.0=pypi_0
76 | - gdown=4.7.1=pypi_0
77 | - giflib=5.2.1=h7b6447c_0
78 | - git-filter-repo=2.38.0=pypi_0
79 | - gitdb=4.0.9=pyhd8ed1ab_0
80 | - gitpython=3.1.27=pyhd8ed1ab_0
81 | - gmp=6.2.1=h2531618_2
82 | - gnutls=3.6.15=he1e5248_0
83 | - google-apitools=0.5.32=pypi_0
84 | - google-auth=2.11.1=pypi_0
85 | - google-auth-oauthlib=0.4.6=pypi_0
86 | - google-reauth=0.1.1=pypi_0
87 | - gql=0.2.0=pypi_0
88 | - graphql-core=1.1=pypi_0
89 | - grpcio=1.48.1=pypi_0
90 | - gsutil=5.14=pypi_0
91 | - gym=0.26.2=pypi_0
92 | - gym-notices=0.0.8=pypi_0
93 | - httplib2=0.20.4=pypi_0
94 | - huggingface-hub=0.13.2=pypi_0
95 | - idna=3.3=pyhd3eb1b0_0
96 | - importlib-metadata=4.12.0=pypi_0
97 | - intel-openmp=2021.4.0=h06a4308_3561
98 | - iopath=0.1.10=pypi_0
99 | - ipdb=0.13.13=pypi_0
100 | - ipykernel=6.16.0=pypi_0
101 | - ipython=8.5.0=pypi_0
102 | - ipython-genutils=0.2.0=pypi_0
103 | - ipywidgets=8.0.2=pypi_0
104 | - isort=5.10.1=pypi_0
105 | - jedi=0.18.1=pypi_0
106 | - jinja2=3.1.2=pypi_0
107 | - jmp=0.0.4=pypi_0
108 | - joblib=1.2.0=pypi_0
109 | - jpeg=9d=h7f8727e_0
110 | - json5=0.9.10=pypi_0
111 | - jsonschema=4.16.0=pypi_0
112 | - jstyleson=0.0.2=pypi_0
113 | - jupyter-client=7.3.5=pypi_0
114 | - jupyter-core=4.11.1=pypi_0
115 | - jupyter-packaging=0.12.3=pypi_0
116 | - jupyter-server=1.19.0=pypi_0
117 | - jupyterlab=3.4.7=pypi_0
118 | - jupyterlab-pygments=0.2.2=pypi_0
119 | - jupyterlab-server=2.15.2=pypi_0
120 | - jupyterlab-widgets=3.0.3=pypi_0
121 | - kaggle=1.5.12=pypi_0
122 | - kiwisolver=1.4.4=pypi_0
123 | - kornia=0.6.3=pyhd8ed1ab_0
124 | - lame=3.100=h7b6447c_0
125 | - lcms2=2.12=h3be6417_0
126 | - ld_impl_linux-64=2.35.1=h7274673_9
127 | - learn2learn=0.1.7=pypi_0
128 | - libffi=3.3=he6710b0_2
129 | - libgcc-ng=9.3.0=h5101ec6_17
130 | - libgfortran-ng=7.5.0=ha8ba4b0_17
131 | - libgfortran4=7.5.0=ha8ba4b0_17
132 | - libgomp=9.3.0=h5101ec6_17
133 | - libiconv=1.15=h63c8f33_5
134 | - libidn2=2.3.2=h7f8727e_0
135 | - libpng=1.6.37=hbc83047_0
136 | - libprotobuf=3.15.8=h780b84a_0
137 | - libstdcxx-ng=9.3.0=hd4cf53a_17
138 | - libtasn1=4.16.0=h27cfd23_0
139 | - libtiff=4.2.0=h85742a9_0
140 | - libunistring=0.9.10=h27cfd23_0
141 | - libuv=1.40.0=h7b6447c_0
142 | - libwebp=1.2.2=h55f646e_0
143 | - libwebp-base=1.2.2=h7f8727e_0
144 | - lit=16.0.1=pypi_0
145 | - llvmlite=0.39.1=pypi_0
146 | - logger=1.4=pypi_0
147 | - lxml=4.9.1=pypi_0
148 | - lz4-c=1.9.3=h295c915_1
149 | - markdown=3.4.1=pypi_0
150 | - markupsafe=2.1.1=pypi_0
151 | - matplotlib=3.6.0=pypi_0
152 | - matplotlib-inline=0.1.6=pypi_0
153 | - mistune=2.0.4=pypi_0
154 | - mkl=2021.4.0=h06a4308_640
155 | - mkl-service=2.4.0=py39h7f8727e_0
156 | - mkl_fft=1.3.1=py39hd3c417c_0
157 | - mkl_random=1.2.2=py39h51133e4_0
158 | - monotonic=1.6=pypi_0
159 | - mpmath=1.3.0=pypi_0
160 | - mss=6.1.0=pypi_0
161 | - multidict=5.2.0=pypi_0
162 | - multiprocess=0.70.14=pypi_0
163 | - nbclassic=0.4.3=pypi_0
164 | - nbclient=0.6.8=pypi_0
165 | - nbconvert=7.0.0=pypi_0
166 | - nbformat=5.6.1=pypi_0
167 | - ncurses=6.3=h7f8727e_2
168 | - nest-asyncio=1.5.5=pypi_0
169 | - nettle=3.7.3=hbbd107a_1
170 | - networkx=2.8.8=pypi_0
171 | - nltk=3.8=pypi_0
172 | - notebook=6.4.12=pypi_0
173 | - notebook-shim=0.1.0=pypi_0
174 | - numba=0.56.4=pypi_0
175 | - numpy=1.22.4=pypi_0
176 | - nvidia-cublas-cu11=11.10.3.66=pypi_0
177 | - nvidia-cuda-cupti-cu11=11.7.101=pypi_0
178 | - nvidia-cuda-nvrtc-cu11=11.7.99=pypi_0
179 | - nvidia-cuda-runtime-cu11=11.7.99=pypi_0
180 | - nvidia-cudnn-cu11=8.5.0.96=pypi_0
181 | - nvidia-cufft-cu11=10.9.0.58=pypi_0
182 | - nvidia-curand-cu11=10.2.10.91=pypi_0
183 | - nvidia-cusolver-cu11=11.4.0.1=pypi_0
184 | - nvidia-cusparse-cu11=11.7.4.91=pypi_0
185 | - nvidia-ml-py3=7.352.0=pypi_0
186 | - nvidia-nccl-cu11=2.14.3=pypi_0
187 | - nvidia-nvtx-cu11=11.7.91=pypi_0
188 | - oauth2client=4.1.3=pypi_0
189 | - oauthlib=3.2.1=pypi_0
190 | - omegaconf=2.2.3=pypi_0
191 | - open3d=0.15.2=pypi_0
192 | - opencv-python=4.6.0.66=pypi_0
193 | - opendatasets=0.1.22=pypi_0
194 | - openh264=2.1.1=h4ff587b_0
195 | - openssl=1.1.1m=h7f8727e_0
196 | - openvino=2022.3.0=pypi_0
197 | - openvino-dev=2022.3.0=pypi_0
198 | - openvino-telemetry=2022.3.0=pypi_0
199 | - optree=0.9.0=pypi_0
200 | - packaging=21.3=pyhd8ed1ab_0
201 | - pandas=1.3.5=pypi_0
202 | - pandocfilters=1.5.0=pypi_0
203 | - parso=0.8.3=pypi_0
204 | - pathtools=0.1.2=py_1
205 | - pexpect=4.8.0=pypi_0
206 | - pickleshare=0.7.5=pypi_0
207 | - pillow=9.0.1=py39h22f2fdc_0
208 | - pip=21.2.4=py39h06a4308_0
209 | - pipenv=2022.9.21=pypi_0
210 | - platformdirs=2.5.2=pypi_0
211 | - plotext=5.2.8=pypi_0
212 | - portalocker=2.5.1=pypi_0
213 | - prefetch-generator=1.0.3=pypi_0
214 | - prometheus-client=0.14.1=pypi_0
215 | - promise=2.3=py39hf3d152e_5
216 | - prompt-toolkit=3.0.31=pypi_0
217 | - protobuf=3.15.8=py39he80948d_0
218 | - psutil=5.8.0=py39h27cfd23_1
219 | - ptyprocess=0.7.0=pypi_0
220 | - pure-eval=0.2.2=pypi_0
221 | - pyarrow=6.0.1=pypi_0
222 | - pyasn1=0.4.8=pypi_0
223 | - pyasn1-modules=0.2.8=pypi_0
224 | - pycocoevalcap=1.2=pypi_0
225 | - pycocotools=2.0.6=pypi_0
226 | - pycodestyle=2.9.1=pypi_0
227 | - pycparser=2.21=pyhd3eb1b0_0
228 | - pydeprecate=0.3.2=pypi_0
229 | - pygments=2.13.0=pypi_0
230 | - pyopenssl=22.0.0=pyhd3eb1b0_0
231 | - pyparsing=3.0.7=pyhd8ed1ab_0
232 | - pyquaternion=0.9.9=pypi_0
233 | - pyrsistent=0.18.1=pypi_0
234 | - pysocks=1.7.1=py39h06a4308_0
235 | - python=3.9.7=h12debd9_1
236 | - python-dateutil=2.8.2=pyhd8ed1ab_0
237 | - python-graphviz=0.20.1=pypi_0
238 | - python-slugify=6.1.2=pypi_0
239 | - python_abi=3.9=2_cp39
240 | - pytorch-lightning=1.7.6=pypi_0
241 | - pytorch-mutex=1.0=cuda
242 | - pytz=2022.2.1=pypi_0
243 | - pyu2f=0.1.5=pypi_0
244 | - pyyaml=5.4.1=py39h3811e60_0
245 | - pyzmq=24.0.1=pypi_0
246 | - qpth=0.0.15=pypi_0
247 | - readline=8.1.2=h7f8727e_1
248 | - regex=2022.9.13=pypi_0
249 | - requests=2.27.1=pyhd3eb1b0_0
250 | - requests-oauthlib=1.3.1=pypi_0
251 | - responses=0.18.0=pypi_0
252 | - retry-decorator=1.1.1=pypi_0
253 | - rsa=4.7.2=pypi_0
254 | - ruamel-yaml=0.17.21=pypi_0
255 | - ruamel-yaml-clib=0.2.7=pypi_0
256 | - scikit-learn=1.1.2=pypi_0
257 | - scipy=1.10.1=pypi_0
258 | - send2trash=1.8.0=pypi_0
259 | - sentence-transformers=2.2.2=pypi_0
260 | - sentencepiece=0.1.97=pypi_0
261 | - sentry-sdk=1.5.7=pyhd8ed1ab_0
262 | - setproctitle=1.2.2=py39h3811e60_0
263 | - setuptools=65.4.0=pypi_0
264 | - shortuuid=1.0.8=py39hf3d152e_0
265 | - six=1.16.0=pyhd3eb1b0_1
266 | - smmap=3.0.5=pyh44b312d_0
267 | - sniffio=1.3.0=pypi_0
268 | - soupsieve=2.3.2.post1=pypi_0
269 | - sqlite=3.38.0=hc218d9a_0
270 | - stack-data=0.5.1=pypi_0
271 | - subprocess32=3.5.4=pypi_0
272 | - sympy=1.11.1=pypi_0
273 | - tabulate=0.8.10=pypi_0
274 | - tensorboard=2.10.0=pypi_0
275 | - tensorboard-data-server=0.6.1=pypi_0
276 | - tensorboard-plugin-wit=1.8.1=pypi_0
277 | - termcolor=1.1.0=py_2
278 | - terminado=0.15.0=pypi_0
279 | - text-unidecode=1.3=pypi_0
280 | - texttable=1.6.7=pypi_0
281 | - threadpoolctl=3.1.0=pypi_0
282 | - timm=0.6.7=pypi_0
283 | - tinycss2=1.1.1=pypi_0
284 | - tk=8.6.11=h1ccaba5_0
285 | - tokenizers=0.12.1=pypi_0
286 | - toml=0.10.2=pypi_0
287 | - tomli=2.0.1=pypi_0
288 | - tomlkit=0.11.4=pypi_0
289 | - torch=2.0.0=pypi_0
290 | - torchaudio=0.11.0=py39_cu113
291 | - torchmetrics=0.9.3=pypi_0
292 | - torchopt=0.7.0=pypi_0
293 | - torchvision=0.15.1=pypi_0
294 | - tornado=6.2=pypi_0
295 | - tqdm=4.63.0=pyhd8ed1ab_0
296 | - traitlets=5.4.0=pypi_0
297 | - transformers=4.26.1=pypi_0
298 | - triton=2.0.0=pypi_0
299 | - typing-extensions=4.3.0=pypi_0
300 | - tzdata=2021e=hda174b7_0
301 | - urllib3=1.26.8=pyhd3eb1b0_0
302 | - virtualenv=20.16.5=pypi_0
303 | - virtualenv-clone=0.5.7=pypi_0
304 | - wandb=0.13.7=pypi_0
305 | - watchdog=2.2.1=pypi_0
306 | - wcwidth=0.2.5=pypi_0
307 | - webdataset=0.2.30=pypi_0
308 | - webencodings=0.5.1=pypi_0
309 | - websocket-client=1.4.1=pypi_0
310 | - werkzeug=2.2.2=pypi_0
311 | - wheel=0.37.1=pyhd3eb1b0_0
312 | - widgetsnbextension=4.0.3=pypi_0
313 | - xxhash=3.1.0=pypi_0
314 | - xz=5.2.5=h7b6447c_0
315 | - yacs=0.1.8=pypi_0
316 | - yaml=0.2.5=h516909a_0
317 | - yarl=1.8.1=pypi_0
318 | - yaspin=2.1.0=pyhd8ed1ab_0
319 | - zipp=3.8.1=pypi_0
320 | - zlib=1.2.11=h7f8727e_4
321 | - zstd=1.4.9=haebb681_0
322 |
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
3 | font-weight:300;
4 | font-size:18px;
5 | margin-left: auto;
6 | margin-right: auto;
7 | width: 1100px;
8 | }
9 |
10 | table {
11 | margin-left: auto;
12 | margin-right: auto;
13 | }
14 |
15 | tr {
16 | /*margin-left: auto;
17 | margin-right: auto;*/
18 | text-align: center;
19 | }
20 |
21 | h1 {
22 | font-weight:300;
23 | text-align: center;
24 | }
25 |
26 | h2 {
27 | text-align: center;
28 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
29 | }
30 |
31 | div {
32 | float: none;
33 | margin-left: auto;
34 | margin-right: auto;
35 | font-size: 0;
36 | }
37 |
38 | .disclaimerbox {
39 | background-color: #eee;
40 | border: 1px solid #eeeeee;
41 | border-radius: 10px ;
42 | -moz-border-radius: 10px ;
43 | -webkit-border-radius: 10px ;
44 | padding: 20px;
45 | }
46 |
47 | video.header-vid {
48 | height: 140px;
49 | border: 1px solid black;
50 | border-radius: 10px ;
51 | -moz-border-radius: 10px ;
52 | -webkit-border-radius: 10px ;
53 | }
54 |
55 | img.header-img {
56 | height: 140px;
57 | border: 1px solid black;
58 | border-radius: 10px ;
59 | -moz-border-radius: 10px ;
60 | -webkit-border-radius: 10px ;
61 | }
62 |
63 | img.rounded {
64 | border: 1px solid #eeeeee;
65 | border-radius: 10px ;
66 | -moz-border-radius: 10px ;
67 | -webkit-border-radius: 10px ;
68 | }
69 |
70 | a:link,a:visited
71 | {
72 | color: #1367a7;
73 | text-decoration: none;
74 | }
75 | a:hover {
76 | color: #208799;
77 | }
78 |
79 | td.dl-link {
80 | height: 160px;
81 | text-align: center;
82 | font-size: 22px;
83 | }
84 |
85 |
86 | .layered-paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
87 | box-shadow:
88 | 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
89 | 5px 5px 0 0px #fff, /* The second layer */
90 | 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
91 | 10px 10px 0 0px #fff, /* The third layer */
92 | 10px 10px 1px 1px rgba(0,0,0,0.35), /* The third layer shadow */
93 | 15px 15px 0 0px #fff, /* The fourth layer */
94 | 15px 15px 1px 1px rgba(0,0,0,0.35), /* The fourth layer shadow */
95 | 20px 20px 0 0px #fff, /* The fifth layer */
96 | 20px 20px 1px 1px rgba(0,0,0,0.35), /* The fifth layer shadow */
97 | 25px 25px 0 0px #fff, /* The fifth layer */
98 | 25px 25px 1px 1px rgba(0,0,0,0.35); /* The fifth layer shadow */
99 | margin-left: 10px;
100 | margin-right: 45px;
101 | }
102 |
103 |
104 | .layered-paper { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
105 | box-shadow:
106 | 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
107 | 5px 5px 0 0px #fff, /* The second layer */
108 | 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
109 | 10px 10px 0 0px #fff, /* The third layer */
110 | 10px 10px 1px 1px rgba(0,0,0,0.35); /* The third layer shadow */
111 | margin-top: 5px;
112 | margin-left: 10px;
113 | margin-right: 30px;
114 | margin-bottom: 5px;
115 | }
116 |
117 | .vert-cent {
118 | position: relative;
119 | top: 50%;
120 | transform: translateY(-50%);
121 | }
122 |
123 | pre {
124 | overflow-x: auto;
125 | text-align: left;
126 | border: 1px solid grey;
127 | border-radius: 3px;
128 | background: #eee;
129 | padding: 5px 5px 5px 10px;
130 | line-height:1.2
131 | }
132 |
133 | pre code {
134 | text-align: left;
135 | word-wrap: normal;
136 | white-space: pre-wrap;
137 | font-size:12px
138 | }
139 |
140 | hr {
141 | border: 0;
142 | height: 1.5px;
143 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
144 | }
145 |
146 | .vertical {
147 | display: table-cell;
148 | vertical-align: middle;
149 | }
150 |
151 | /* Below are hover effect - by Zhiqiu */
152 |
153 | .slide-container {
154 | /*position: relative;*/
155 | width: 250px;
156 | overflow: hidden;
157 | }
158 |
159 |
160 | .slide-container img:hover, .slide-container img:active{
161 | /*position: relative;*/
162 | transition: ease-in-out;
163 | transform: translate3d(-250px, 0px, 0px);
164 | }
165 |
166 | .slideup-container {
167 | /*position: relative;*/
168 | height: 125px;
169 | overflow: hidden;
170 | }
171 |
172 | .slideup-container img:hover, .slideup-container img:active{
173 | /*position: relative;*/
174 | transition: ease-in-out;
175 | transform: translate3d(0px, -125px, 0px);
176 | }
177 |
178 | /*.slideup-container img:hover{
179 | position: relative;
180 | transform: translate3d(0px, -125px, 0px);
181 | }*/
182 |
183 | .teaser {
184 | /*position: absolute;*/
185 | width: 270px;
186 | height: 55px;
187 | }
188 |
189 | .teaser-answer {
190 | opacity:0;
191 | transition:1.2s;
192 | }
193 |
194 | .teaser .teaser-answer:hover, .teaser .teaser-answer:active {
195 | opacity:1;
196 | }
197 |
198 | /*.overlay {
199 | position: absolute;
200 | top: 0;
201 | bottom: 0;
202 | left: 0;
203 | right: 0;
204 | height: 100%;
205 | width: 100%;
206 | opacity: 0;
207 | transition: .5s ease;
208 | background-color: #008CBA;
209 | }
210 |
211 | .hover-container:hover .overlay {
212 | opacity: 1;
213 | }*/
214 |
215 | /*.overlay:hover .overlay {
216 | opacity: 0.5;
217 | transition: .5s ease;
218 | }*/
219 |
220 | /*.text {
221 | color: white;
222 | font-size: 20px;
223 | position: absolute;
224 | top: 50%;
225 | left: 50%;
226 | -webkit-transform: translate(-50%, -50%);
227 | -ms-transform: translate(-50%, -50%);
228 | transform: translate(-50%, -50%);
229 | text-align: center;
230 | }*/
231 |
--------------------------------------------------------------------------------
/transform/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/transform/.DS_Store
--------------------------------------------------------------------------------
/transform/__pycache__/randaugment.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/transform/__pycache__/randaugment.cpython-39.pyc
--------------------------------------------------------------------------------
/transform/randaugment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | ## aug functions
6 | def identity_func(img):
7 | return img
8 |
9 |
10 | def autocontrast_func(img, cutoff=0):
11 | '''
12 | same output as PIL.ImageOps.autocontrast
13 | '''
14 | n_bins = 256
15 |
16 | def tune_channel(ch):
17 | n = ch.size
18 | cut = cutoff * n // 100
19 | if cut == 0:
20 | high, low = ch.max(), ch.min()
21 | else:
22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23 | low = np.argwhere(np.cumsum(hist) > cut)
24 | low = 0 if low.shape[0] == 0 else low[0]
25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27 | if high <= low:
28 | table = np.arange(n_bins)
29 | else:
30 | scale = (n_bins - 1) / (high - low)
31 | offset = -low * scale
32 | table = np.arange(n_bins) * scale + offset
33 | table[table < 0] = 0
34 | table[table > n_bins - 1] = n_bins - 1
35 | table = table.clip(0, 255).astype(np.uint8)
36 | return table[ch]
37 |
38 | channels = [tune_channel(ch) for ch in cv2.split(img)]
39 | out = cv2.merge(channels)
40 | return out
41 |
42 |
43 | def equalize_func(img):
44 | '''
45 | same output as PIL.ImageOps.equalize
46 |     PIL's implementation is different from cv2.equalizeHist
47 | '''
48 | n_bins = 256
49 |
50 | def tune_channel(ch):
51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52 | non_zero_hist = hist[hist != 0].reshape(-1)
53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54 | if step == 0: return ch
55 | n = np.empty_like(hist)
56 | n[0] = step // 2
57 | n[1:] = hist[:-1]
58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59 | return table[ch]
60 |
61 | channels = [tune_channel(ch) for ch in cv2.split(img)]
62 | out = cv2.merge(channels)
63 | return out
64 |
65 |
66 | def rotate_func(img, degree, fill=(0, 0, 0)):
67 | '''
68 |     like PIL, rotates by degrees, not radians
69 | '''
70 | H, W = img.shape[0], img.shape[1]
71 | center = W / 2, H / 2
72 | M = cv2.getRotationMatrix2D(center, degree, 1)
73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74 | return out
75 |
76 |
77 | def solarize_func(img, thresh=128):
78 | '''
79 |     same output as PIL.ImageOps.solarize
80 | '''
81 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
82 | table = table.clip(0, 255).astype(np.uint8)
83 | out = table[img]
84 | return out
85 |
86 |
87 | def color_func(img, factor):
88 | '''
89 | same output as PIL.ImageEnhance.Color
90 | '''
91 | ## implementation according to PIL definition, quite slow
92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93 | # out = blend(degenerate, img, factor)
94 | # M = (
95 | # np.eye(3) * factor
96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97 | # )[np.newaxis, np.newaxis, :]
98 | M = (
99 | np.float32([
100 | [0.886, -0.114, -0.114],
101 | [-0.587, 0.413, -0.587],
102 | [-0.299, -0.299, 0.701]]) * factor
103 | + np.float32([[0.114], [0.587], [0.299]])
104 | )
105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106 | return out
107 |
108 |
109 | def contrast_func(img, factor):
110 | """
111 | same output as PIL.ImageEnhance.Contrast
112 | """
113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114 | table = np.array([(
115 | el - mean) * factor + mean
116 | for el in range(256)
117 | ]).clip(0, 255).astype(np.uint8)
118 | out = table[img]
119 | return out
120 |
121 |
122 | def brightness_func(img, factor):
123 | '''
124 |     same output as PIL.ImageEnhance.Brightness
125 | '''
126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127 | out = table[img]
128 | return out
129 |
130 |
131 | def sharpness_func(img, factor):
132 | '''
133 |     The differences between this result and PIL's are confined to the 4 boundary rows/columns;
134 |     the center areas are the same
135 | '''
136 | kernel = np.ones((3, 3), dtype=np.float32)
137 | kernel[1][1] = 5
138 | kernel /= 13
139 | degenerate = cv2.filter2D(img, -1, kernel)
140 | if factor == 0.0:
141 | out = degenerate
142 | elif factor == 1.0:
143 | out = img
144 | else:
145 | out = img.astype(np.float32)
146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148 | out = out.astype(np.uint8)
149 | return out
150 |
151 |
152 | def shear_x_func(img, factor, fill=(0, 0, 0)):
153 | H, W = img.shape[0], img.shape[1]
154 | M = np.float32([[1, factor, 0], [0, 1, 0]])
155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156 | return out
157 |
158 |
159 | def translate_x_func(img, offset, fill=(0, 0, 0)):
160 | '''
161 | same output as PIL.Image.transform
162 | '''
163 | H, W = img.shape[0], img.shape[1]
164 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166 | return out
167 |
168 |
169 | def translate_y_func(img, offset, fill=(0, 0, 0)):
170 | '''
171 | same output as PIL.Image.transform
172 | '''
173 | H, W = img.shape[0], img.shape[1]
174 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176 | return out
177 |
178 |
179 | def posterize_func(img, bits):
180 | '''
181 | same output as PIL.ImageOps.posterize
182 | '''
183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
184 | return out
185 |
186 |
187 | def shear_y_func(img, factor, fill=(0, 0, 0)):
188 | H, W = img.shape[0], img.shape[1]
189 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191 | return out
192 |
193 |
194 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
195 | replace = np.array(replace, dtype=np.uint8)
196 | H, W = img.shape[0], img.shape[1]
197 | rh, rw = np.random.random(2)
198 | pad_size = pad_size // 2
199 | ch, cw = int(rh * H), int(rw * W)
200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202 | out = img.copy()
203 | out[x1:x2, y1:y2, :] = replace
204 | return out
205 |
206 |
207 | ### level to args
208 | def enhance_level_to_args(MAX_LEVEL):
209 | def level_to_args(level):
210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211 | return level_to_args
212 |
213 |
214 | def shear_level_to_args(MAX_LEVEL, replace_value):
215 | def level_to_args(level):
216 | level = (level / MAX_LEVEL) * 0.3
217 | if np.random.random() > 0.5: level = -level
218 | return (level, replace_value)
219 |
220 | return level_to_args
221 |
222 |
223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224 | def level_to_args(level):
225 | level = (level / MAX_LEVEL) * float(translate_const)
226 | if np.random.random() > 0.5: level = -level
227 | return (level, replace_value)
228 |
229 | return level_to_args
230 |
231 |
232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233 | def level_to_args(level):
234 | level = int((level / MAX_LEVEL) * cutout_const)
235 | return (level, replace_value)
236 |
237 | return level_to_args
238 |
239 |
240 | def solarize_level_to_args(MAX_LEVEL):
241 | def level_to_args(level):
242 | level = int((level / MAX_LEVEL) * 256)
243 | return (level, )
244 | return level_to_args
245 |
246 |
247 | def none_level_to_args(level):
248 | return ()
249 |
250 |
251 | def posterize_level_to_args(MAX_LEVEL):
252 | def level_to_args(level):
253 | level = int((level / MAX_LEVEL) * 4)
254 | return (level, )
255 | return level_to_args
256 |
257 |
258 | def rotate_level_to_args(MAX_LEVEL, replace_value):
259 | def level_to_args(level):
260 | level = (level / MAX_LEVEL) * 30
261 | if np.random.random() < 0.5:
262 | level = -level
263 | return (level, replace_value)
264 |
265 | return level_to_args
266 |
267 |
268 | func_dict = {
269 | 'Identity': identity_func,
270 | 'AutoContrast': autocontrast_func,
271 | 'Equalize': equalize_func,
272 | 'Rotate': rotate_func,
273 | 'Solarize': solarize_func,
274 | 'Color': color_func,
275 | 'Contrast': contrast_func,
276 | 'Brightness': brightness_func,
277 | 'Sharpness': sharpness_func,
278 | 'ShearX': shear_x_func,
279 | 'TranslateX': translate_x_func,
280 | 'TranslateY': translate_y_func,
281 | 'Posterize': posterize_func,
282 | 'ShearY': shear_y_func,
283 | }
284 |
285 | translate_const = 10
286 | MAX_LEVEL = 10
287 | replace_value = (128, 128, 128)
288 | arg_dict = {
289 | 'Identity': none_level_to_args,
290 | 'AutoContrast': none_level_to_args,
291 | 'Equalize': none_level_to_args,
292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293 | 'Solarize': solarize_level_to_args(MAX_LEVEL),
294 | 'Color': enhance_level_to_args(MAX_LEVEL),
295 | 'Contrast': enhance_level_to_args(MAX_LEVEL),
296 | 'Brightness': enhance_level_to_args(MAX_LEVEL),
297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL),
298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299 | 'TranslateX': translate_level_to_args(
300 | translate_const, MAX_LEVEL, replace_value
301 | ),
302 | 'TranslateY': translate_level_to_args(
303 | translate_const, MAX_LEVEL, replace_value
304 | ),
305 | 'Posterize': posterize_level_to_args(MAX_LEVEL),
306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307 | }
308 |
309 |
310 | class RandomAugment(object):
311 |
312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
313 | self.N = N
314 | self.M = M
315 | self.isPIL = isPIL
316 | if augs:
317 | self.augs = augs
318 | else:
319 | self.augs = list(arg_dict.keys())
320 |
321 | def get_random_ops(self):
322 | sampled_ops = np.random.choice(self.augs, self.N)
323 | return [(op, 0.5, self.M) for op in sampled_ops]
324 |
325 | def __call__(self, img):
326 | if self.isPIL:
327 | img = np.array(img)
328 | ops = self.get_random_ops()
329 | for name, prob, level in ops:
330 | if np.random.random() > prob:
331 | continue
332 | args = arg_dict[name](level)
333 | img = func_dict[name](img, *args)
334 | return img
335 |
336 |
337 | if __name__ == '__main__':
338 | a = RandomAugment()
339 | img = np.random.randn(32, 32, 3)
340 | a(img)
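A typical way to use RandomAugment (an editor's sketch; the op list and crop size are assumptions, not necessarily what this repo's data loaders use) is to drop it into a torchvision pipeline with `isPIL=True` so PIL images are converted to NumPy arrays before the cv2-based ops run.

```python
import numpy as np
from PIL import Image
from torchvision import transforms
from transform.randaugment import RandomAugment  # assumes the repo root is on PYTHONPATH

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    RandomAugment(2, 5, isPIL=True,
                  augs=['Identity', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
    transforms.ToTensor(),  # accepts the uint8 HWC array returned by RandomAugment
])

img = Image.fromarray((np.random.rand(256, 256, 3) * 255).astype(np.uint8))
print(train_transform(img).shape)  # torch.Size([3, 224, 224])
```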
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/VICO-UoE/DatasetCondensation
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | import kornia as K
10 | import tqdm
11 | from torch.utils.data import Dataset
12 | from torchvision import datasets, transforms
13 | from scipy.ndimage.interpolation import rotate as scipyrotate
14 | from networks import MLP, ConvNet, LeNet, AlexNet, VGG11BN, VGG11, ResNet18, ResNet18BN_AP, ResNet18_AP, ModifiedResNet, resnet18_gn
15 | import re
16 | import json
17 | import torch.distributed as dist
18 | from tqdm import tqdm
19 | from collections import defaultdict
20 |
21 | class Config:
22 | imagenette = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]
23 |
24 | # ["australian_terrier", "border_terrier", "samoyed", "beagle", "shih-tzu", "english_foxhound", "rhodesian_ridgeback", "dingo", "golden_retriever", "english_sheepdog"]
25 | imagewoof = [193, 182, 258, 162, 155, 167, 159, 273, 207, 229]
26 |
27 | # ["tabby_cat", "bengal_cat", "persian_cat", "siamese_cat", "egyptian_cat", "lion", "tiger", "jaguar", "snow_leopard", "lynx"]
28 | imagemeow = [281, 282, 283, 284, 285, 291, 292, 290, 289, 287]
29 |
30 | # ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"]
31 | imagesquawk = [84, 130, 88, 144, 145, 22, 96, 9, 100, 89]
32 |
33 | # ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"]
34 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948]
35 |
36 | # ["bee", "ladys slipper", "banana", "lemon", "corn", "school_bus", "honeycomb", "lion", "garden_spider", "goldfinch"]
37 | imageyellow = [309, 986, 954, 951, 987, 779, 599, 291, 72, 11]
38 |
39 | dict = {
40 | "imagenette" : imagenette,
41 | "imagewoof" : imagewoof,
42 | "imagefruit": imagefruit,
43 | "imageyellow": imageyellow,
44 | "imagemeow": imagemeow,
45 | "imagesquawk": imagesquawk,
46 | }
47 |
48 | config = Config()
49 |
50 | def get_dataset(args):
51 |
52 | if args.dataset == 'CIFAR10_clip':
53 | channel = 3
54 | im_size = (32, 32)
55 | num_classes = 768
56 | mean = [0.4914, 0.4822, 0.4465]
57 | std = [0.2023, 0.1994, 0.2010]
58 | if args.zca:
59 | transform = transforms.Compose([transforms.ToTensor()])
60 | else:
61 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
62 | dst_train = datasets.CIFAR10(args.data_path, train=True, download=True, transform=transform) # no augmentation
63 | dst_test = datasets.CIFAR10(args.data_path, train=False, download=True, transform=transform)
64 | class_names = dst_train.classes
65 | class_map = {x:x for x in range(num_classes)}
66 |
67 | else:
68 | exit('unknown dataset: %s'%args.dataset)
69 |
70 | if args.zca:
71 | images = []
72 | labels = []
73 | print("Train ZCA")
74 | for i in tqdm(range(len(dst_train))):
75 | im, lab = dst_train[i]
76 | images.append(im)
77 | labels.append(lab)
78 | images = torch.stack(images, dim=0).to(args.device)
79 | labels = torch.tensor(labels, dtype=torch.float, device="cpu")
80 | zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
81 | zca.fit(images)
82 | zca_images = zca(images).to("cpu")
83 | dst_train = TensorDataset(zca_images, labels)
84 |
85 | images = []
86 | labels = []
87 | print("Test ZCA")
88 | for i in tqdm(range(len(dst_test))):
89 | im, lab = dst_test[i]
90 | images.append(im)
91 | labels.append(lab)
92 | images = torch.stack(images, dim=0).to(args.device)
93 | labels = torch.tensor(labels, dtype=torch.float, device="cpu")
94 |
95 | zca_images = zca(images).to("cpu")
96 | dst_test = TensorDataset(zca_images, labels)
97 |
98 | args.zca_trans = zca
99 |
100 |
101 | testloader = torch.utils.data.DataLoader(dst_test, batch_size=128, shuffle=False, num_workers=2)
102 |
103 | if "flickr" not in args.dataset:
104 | dst_train_label, dst_test_label = None, None
105 | return channel, im_size, num_classes, mean, std, dst_train, dst_test, testloader, dst_train_label, dst_test_label
106 |
107 |
108 |
109 | class TensorDataset(Dataset):
110 | def __init__(self, images, labels): # images: n x c x h x w tensor
111 | self.images = images.detach().float()
112 | self.labels = labels.detach()
113 |
114 | def __getitem__(self, index):
115 | return self.images[index], self.labels[index]
116 |
117 | def __len__(self):
118 | return self.images.shape[0]
119 |
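TensorDataset simply pairs precomputed image tensors with labels; it is what the ZCA branch above wraps its whitened images in. A minimal usage sketch (editor's addition, shapes illustrative):

```python
import torch
from torch.utils.data import DataLoader
from utils import TensorDataset  # assumes the repo root is on PYTHONPATH

images = torch.randn(100, 3, 32, 32)   # n x c x h x w, as the class expects
labels = torch.randint(0, 10, (100,)).float()
loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

batch_images, batch_labels = next(iter(loader))
print(batch_images.shape, batch_labels.shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16])
```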
120 |
121 |
122 | def get_default_convnet_setting():
123 | net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling'
124 | return net_width, net_depth, net_act, net_norm, net_pooling
125 |
126 |
127 |
128 | def get_RN_network(model, vision_width, vision_layers, embed_dim, image_resolution):
129 | if model == 'RN50':
130 | vision_heads = vision_width * 32 // 64
131 | net = ModifiedResNet(layers=vision_layers,
132 | output_dim=embed_dim,
133 | heads=vision_heads,
134 | input_resolution=image_resolution,
135 | width=vision_width)
136 |     if dist:  # NOTE: dist here is the imported torch.distributed module (always truthy), so this block always runs
137 | gpu_num = torch.cuda.device_count()
138 | if gpu_num>0:
139 | device = 'cuda'
140 | if gpu_num>1:
141 | net = nn.DataParallel(net)
142 | else:
143 | device = 'cpu'
144 | net = net.to(device)
145 |
146 | return net
147 |
148 | def get_network(model, channel, num_classes, im_size=(32, 32), dist=True):
149 | torch.random.manual_seed(int(time.time() * 1000) % 100000)
150 | net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()
151 |
152 | if model == 'MLP':
153 | net = MLP(channel=channel, num_classes=num_classes)
154 | elif model == 'ConvNet':
155 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
156 | elif model == 'LeNet':
157 | net = LeNet(channel=channel, num_classes=num_classes)
158 | elif model == 'AlexNet':
159 | net = AlexNet(channel=channel, num_classes=num_classes)
160 | elif model == 'VGG11':
161 | net = VGG11( channel=channel, num_classes=num_classes)
162 | elif model == 'VGG11BN':
163 | net = VGG11BN(channel=channel, num_classes=num_classes)
164 | elif model == 'ResNet18':
165 | net = ResNet18(channel=channel, num_classes=num_classes)
166 | elif model == 'ResNet18BN_AP':
167 | net = ResNet18BN_AP(channel=channel, num_classes=num_classes)
168 | elif model == 'ResNet18_AP':
169 | net = ResNet18_AP(channel=channel, num_classes=num_classes)
170 |
171 | elif model == 'ConvNetD1':
172 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
173 | elif model == 'ConvNetD2':
174 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
175 | elif model == 'ConvNetD3':
176 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
177 | elif model == 'ConvNetD4':
178 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
179 | elif model == 'ConvNetD5':
180 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=5, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
181 | elif model == 'ConvNetD6':
182 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=6, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
183 | elif model == 'ConvNetD7':
184 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=7, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
185 | elif model == 'ConvNetD8':
186 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=8, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
187 |
188 |
189 | elif model == 'ConvNetW32':
190 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
191 | elif model == 'ConvNetW64':
192 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
193 | elif model == 'ConvNetW128':
194 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
195 | elif model == 'ConvNetW256':
196 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
197 | elif model == 'ConvNetW512':
198 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=512, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
199 | elif model == 'ConvNetW1024':
200 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling)
201 |
202 | elif model == "ConvNetKIP":
203 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act,
204 | net_norm="none", net_pooling=net_pooling)
205 |
206 | elif model == 'ConvNetAS':
207 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling)
208 | elif model == 'ConvNetAR':
209 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling)
210 | elif model == 'ConvNetAL':
211 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling)
212 |
213 | elif model == 'ConvNetNN':
214 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling)
215 | elif model == 'ConvNetBN':
216 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling)
217 | elif model == 'ConvNetLN':
218 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling)
219 | elif model == 'ConvNetIN':
220 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling)
221 | elif model == 'ConvNetGN':
222 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling)
223 |
224 | elif model == 'ConvNetNP':
225 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none')
226 | elif model == 'ConvNetMP':
227 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling')
228 | elif model == 'ConvNetAP':
229 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling')
230 |
231 |
232 | else:
233 | net = None
234 | exit('DC error: unknown model')
235 |
236 | if dist:
237 | gpu_num = torch.cuda.device_count()
238 | if gpu_num > 0:
239 | device = 'cuda'
240 | if gpu_num > 1:
241 | net = nn.DataParallel(net)
242 | else:
243 | device = 'cpu'
244 | net = net.to(device)
245 |
246 | return net
247 |
248 |
249 |
250 | def get_time():
251 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
252 |
253 |
254 |
255 | def augment(images, dc_aug_param, device):
256 | # This can be sped up in the future.
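# Applies one randomly chosen augmentation (crop / scale / rotate / noise) to each image in-place,
# according to dc_aug_param['strategy'] (e.g. 'crop_scale_rotate').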
257 |
258 | if dc_aug_param is not None and dc_aug_param['strategy'] != 'none':
259 | scale = dc_aug_param['scale']
260 | crop = dc_aug_param['crop']
261 | rotate = dc_aug_param['rotate']
262 | noise = dc_aug_param['noise']
263 | strategy = dc_aug_param['strategy']
264 |
265 | shape = images.shape
266 | mean = []
267 | for c in range(shape[1]):
268 | mean.append(float(torch.mean(images[:,c])))
269 |
270 | def cropfun(i):
271 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device)
272 | for c in range(shape[1]):
273 | im_[c] = mean[c]
274 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i]
275 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0]
276 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]]
277 |
278 | def scalefun(i):
279 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
280 | w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[3])
281 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0]
282 | mhw = max(h, w, shape[2], shape[3])
283 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device)
284 | r = int((mhw - h) / 2)
285 | c = int((mhw - w) / 2)
286 | im_[:, r:r + h, c:c + w] = tmp
287 | r = int((mhw - shape[2]) / 2)
288 | c = int((mhw - shape[3]) / 2)
289 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]]
290 |
291 | def rotatefun(i):
292 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean))
293 | r = int((im_.shape[-2] - shape[-2]) / 2)
294 | c = int((im_.shape[-1] - shape[-1]) / 2)
295 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device)
296 |
297 | def noisefun(i):
298 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device)
299 |
300 |
301 | augs = strategy.split('_')
302 |
303 | for i in range(shape[0]):
304 | choice = np.random.permutation(augs)[0] # randomly pick one augmentation to apply
305 | if choice == 'crop':
306 | cropfun(i)
307 | elif choice == 'scale':
308 | scalefun(i)
309 | elif choice == 'rotate':
310 | rotatefun(i)
311 | elif choice == 'noise':
312 | noisefun(i)
313 |
314 | return images
315 |
316 |
317 |
318 | def get_daparam(dataset, model, model_eval, ipc):
319 | # We find that augmentation does not always improve performance,
320 | # so it is only enabled for certain settings below.
321 |
322 | dc_aug_param = dict()
323 | dc_aug_param['crop'] = 4
324 | dc_aug_param['scale'] = 0.2
325 | dc_aug_param['rotate'] = 45
326 | dc_aug_param['noise'] = 0.001
327 | dc_aug_param['strategy'] = 'none'
328 |
329 | if dataset == 'MNIST':
330 | dc_aug_param['strategy'] = 'crop_scale_rotate'
331 |
332 | if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier.
333 | dc_aug_param['strategy'] = 'crop_noise'
334 |
335 | return dc_aug_param
336 |
337 |
338 | def get_eval_pool(eval_mode, model, model_eval):
339 | if eval_mode == 'M': # multiple architectures
340 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18', 'LeNet']
341 | model_eval_pool = ['ConvNet', 'AlexNet', 'VGG11', 'ResNet18_AP', 'ResNet18']
342 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18']
343 | elif eval_mode == 'W': # ablation study on network width
344 | model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256']
345 | elif eval_mode == 'D': # ablation study on network depth
346 | model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4']
347 | elif eval_mode == 'A': # ablation study on network activation function
348 | model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL']
349 | elif eval_mode == 'P': # ablation study on network pooling layer
350 | model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP']
351 | elif eval_mode == 'N': # ablation study on network normalization layer
352 | model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN']
353 | elif eval_mode == 'S': # itself
354 | model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model]
355 | elif eval_mode == 'C':
356 | model_eval_pool = [model, 'ConvNet']
357 | else:
358 | model_eval_pool = [model_eval]
359 | return model_eval_pool
360 |
361 |
362 | class ParamDiffAug():
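"""Hyperparameters for the differentiable augmentations applied by DiffAugment below."""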
363 | def __init__(self):
364 | self.aug_mode = 'S' # 'S': apply a single randomly chosen augmentation; 'M': apply all augmentations in the strategy
365 | self.prob_flip = 0.5
366 | self.ratio_scale = 1.2
367 | self.ratio_rotate = 15.0
368 | self.ratio_crop_pad = 0.125
369 | self.ratio_cutout = 0.5 # cutout patch covers 0.5 x 0.5 of the image
370 | self.ratio_noise = 0.05
371 | self.brightness = 1.0
372 | self.saturation = 2.0
373 | self.contrast = 0.5
374 |
375 |
376 | def set_seed_DiffAug(param):
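# Re-seed the RNG from param.latestseed (and advance it) so that paired DiffAugment calls
# sample identical transformations; a seed of -1 leaves the RNG untouched.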
377 | if param.latestseed == -1:
378 | return
379 | else:
380 | torch.random.manual_seed(param.latestseed)
381 | param.latestseed += 1
382 |
383 |
384 | def DiffAugment(x, strategy='', seed = -1, param = None):
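# strategy is an underscore-separated list of keys into AUGMENT_FNS (e.g. 'color_crop_cutout_flip_scale_rotate');
# param.aug_mode decides whether all of them ('M') or one randomly chosen one ('S') is applied.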
385 | if seed == -1:
386 | param.batchmode = False
387 | else:
388 | param.batchmode = True
389 |
390 | param.latestseed = seed
391 |
392 | if strategy == 'None' or strategy == 'none':
393 | return x
394 |
395 | if strategy:
396 | if param.aug_mode == 'M': # apply every augmentation in the strategy
397 | for p in strategy.split('_'):
398 | for f in AUGMENT_FNS[p]:
399 | x = f(x, param)
400 | elif param.aug_mode == 'S':
401 | pbties = strategy.split('_')
402 | set_seed_DiffAug(param)
403 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
404 | for f in AUGMENT_FNS[p]:
405 | x = f(x, param)
406 | else:
407 | exit('Error ZH: unknown augmentation mode.')
408 | x = x.contiguous()
409 | return x
410 |
411 |
412 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
413 | def rand_scale(x, param):
414 | # ratio_scale > 1 sets the maximum scale factor
415 | # sx, sy in (0, +inf); 1: original size, 0.5: enlarge 2x
416 | ratio = param.ratio_scale
417 | set_seed_DiffAug(param)
418 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
419 | set_seed_DiffAug(param)
420 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
421 | theta = [[[sx[i], 0, 0],
422 | [0, sy[i], 0],] for i in range(x.shape[0])]
423 | theta = torch.tensor(theta, dtype=torch.float)
424 | if param.batchmode: # batch-wise:
425 | theta[:] = theta[0]
426 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
427 | x = F.grid_sample(x, grid, align_corners=True)
428 | return x
429 |
430 |
431 | def rand_rotate(x, param): # angle sampled from [-ratio_rotate, ratio_rotate] degrees; positive: anticlockwise
432 | ratio = param.ratio_rotate
433 | set_seed_DiffAug(param)
434 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi)
435 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0],
436 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])]
437 | theta = torch.tensor(theta, dtype=torch.float)
438 | if param.batchmode: # batch-wise:
439 | theta[:] = theta[0]
440 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
441 | x = F.grid_sample(x, grid, align_corners=True)
442 | return x
443 |
444 |
445 | def rand_flip(x, param):
446 | prob = param.prob_flip
447 | set_seed_DiffAug(param)
448 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device)
449 | if param.batchmode: # batch-wise:
450 | randf[:] = randf[0]
451 | return torch.where(randf < prob, x.flip(3), x)
452 |
453 |
454 | def rand_brightness(x, param):
455 | ratio = param.brightness
456 | set_seed_DiffAug(param)
457 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
458 | if param.batchmode: # batch-wise:
459 | randb[:] = randb[0]
460 | x = x + (randb - 0.5)*ratio
461 | return x
462 |
463 |
464 | def rand_saturation(x, param):
465 | ratio = param.saturation
466 | x_mean = x.mean(dim=1, keepdim=True)
467 | set_seed_DiffAug(param)
468 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
469 | if param.batchmode: # batch-wise:
470 | rands[:] = rands[0]
471 | x = (x - x_mean) * (rands * ratio) + x_mean
472 | return x
473 |
474 |
475 | def rand_contrast(x, param):
476 | ratio = param.contrast
477 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
478 | set_seed_DiffAug(param)
479 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
480 | if param.batchmode: # batch-wise:
481 | randc[:] = randc[0]
482 | x = (x - x_mean) * (randc + ratio) + x_mean
483 | return x
484 |
485 |
486 | def rand_crop(x, param):
487 | # The image is padded around its border and then randomly shifted/cropped back to its original size.
488 | ratio = param.ratio_crop_pad
489 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
490 | set_seed_DiffAug(param)
491 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
492 | set_seed_DiffAug(param)
493 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
494 | if param.batchmode: # batch-wise:
495 | translation_x[:] = translation_x[0]
496 | translation_y[:] = translation_y[0]
497 | grid_batch, grid_x, grid_y = torch.meshgrid(
498 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
499 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
500 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
501 | )
502 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
503 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
504 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
505 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
506 | return x
507 |
508 |
509 | def rand_cutout(x, param):
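# Zero out a randomly positioned patch covering ratio_cutout of each spatial dimension.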
510 | ratio = param.ratio_cutout
511 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
512 | set_seed_DiffAug(param)
513 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
514 | set_seed_DiffAug(param)
515 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
516 | if param.batchmode: # batch-wise:
517 | offset_x[:] = offset_x[0]
518 | offset_y[:] = offset_y[0]
519 | grid_batch, grid_x, grid_y = torch.meshgrid(
520 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
521 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
522 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
523 | )
524 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
525 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
526 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
527 | mask[grid_batch, grid_x, grid_y] = 0
528 | x = x * mask.unsqueeze(1)
529 | return x
530 |
531 |
532 | AUGMENT_FNS = {
533 | 'color': [rand_brightness, rand_saturation, rand_contrast],
534 | 'crop': [rand_crop],
535 | 'cutout': [rand_cutout],
536 | 'flip': [rand_flip],
537 | 'scale': [rand_scale],
538 | 'rotate': [rand_rotate],
539 | }
540 |
541 |
542 | def pre_question(question, max_ques_words=50):
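# Lowercase the question, strip punctuation, and truncate it to max_ques_words words.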
543 | question = re.sub(
544 | r"([.!\"()*#:;~])",
545 | '',
546 | question.lower(),
547 | )
548 | question = question.rstrip(' ')
549 |
550 | #truncate question
551 | question_words = question.split(' ')
552 | if len(question_words)>max_ques_words:
553 | question = ' '.join(question_words[:max_ques_words])
554 |
555 | return question
556 |
557 |
558 | def save_result(result, result_dir, filename, remove_duplicate=''):
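# Each process writes its results to '<filename>_rank<r>.json'; the main process then merges all ranks
# (optionally de-duplicating on the remove_duplicate key) into '<filename>.json'.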
559 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
560 | final_result_file = os.path.join(result_dir, '%s.json'%filename)
561 |
562 | json.dump(result,open(result_file,'w'))
563 |
564 | dist.barrier()
565 |
566 | if utils.is_main_process():
567 | # combine results from all processes
568 | result = []
569 |
570 | for rank in range(utils.get_world_size()):
571 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
572 | res = json.load(open(result_file,'r'))
573 | result += res
574 |
575 | if remove_duplicate:
576 | result_new = []
577 | id_list = []
578 | for res in result:
579 | if res[remove_duplicate] not in id_list:
580 | id_list.append(res[remove_duplicate])
581 | result_new.append(res)
582 | result = result_new
583 |
584 | json.dump(result,open(final_result_file,'w'))
585 | print('result file saved to %s'%final_result_file)
586 |
587 | return final_result_file
588 |
589 |
590 |
591 | #### everything below is from https://github.com/salesforce/BLIP/blob/main/utils.py
592 | import math
593 |
594 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
595 | """Decay the learning rate"""
596 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
597 | for param_group in optimizer.param_groups:
598 | param_group['lr'] = lr
599 |
600 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
601 | """Warmup the learning rate"""
602 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
603 | for param_group in optimizer.param_groups:
604 | param_group['lr'] = lr
605 |
606 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
607 | """Decay the learning rate"""
608 | lr = max(min_lr, init_lr * (decay_rate**epoch))
609 | for param_group in optimizer.param_groups:
610 | param_group['lr'] = lr
611 |
612 | import numpy as np
613 | import io
614 | import os
615 | import time
616 | from collections import defaultdict, deque
617 | import datetime
618 |
619 | import torch
620 | import torch.distributed as dist
621 |
622 |
623 | class MetricLogger(object):
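"""Collect SmoothedValue meters and print progress, timing and ETA while iterating over a data loader."""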
624 | def __init__(self, delimiter="\t"):
625 | self.meters = defaultdict(SmoothedValue)
626 | self.delimiter = delimiter
627 |
628 | def update(self, **kwargs):
629 | for k, v in kwargs.items():
630 | if isinstance(v, torch.Tensor):
631 | v = v.item()
632 | assert isinstance(v, (float, int))
633 | self.meters[k].update(v)
634 |
635 | def __getattr__(self, attr):
636 | if attr in self.meters:
637 | return self.meters[attr]
638 | if attr in self.__dict__:
639 | return self.__dict__[attr]
640 | raise AttributeError("'{}' object has no attribute '{}'".format(
641 | type(self).__name__, attr))
642 |
643 | def __str__(self):
644 | loss_str = []
645 | for name, meter in self.meters.items():
646 | loss_str.append(
647 | "{}: {}".format(name, str(meter))
648 | )
649 | return self.delimiter.join(loss_str)
650 |
651 | def global_avg(self):
652 | loss_str = []
653 | for name, meter in self.meters.items():
654 | loss_str.append(
655 | "{}: {:.4f}".format(name, meter.global_avg)
656 | )
657 | return self.delimiter.join(loss_str)
658 |
659 | def synchronize_between_processes(self):
660 | for meter in self.meters.values():
661 | meter.synchronize_between_processes()
662 |
663 | def add_meter(self, name, meter):
664 | self.meters[name] = meter
665 |
666 | def log_every(self, iterable, print_freq, header=None):
667 | i = 0
668 | if not header:
669 | header = ''
670 | start_time = time.time()
671 | end = time.time()
672 | iter_time = SmoothedValue(fmt='{avg:.4f}')
673 | data_time = SmoothedValue(fmt='{avg:.4f}')
674 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
675 | log_msg = [
676 | header,
677 | '[{0' + space_fmt + '}/{1}]',
678 | 'eta: {eta}',
679 | '{meters}',
680 | 'time: {time}',
681 | 'data: {data}'
682 | ]
683 | if torch.cuda.is_available():
684 | log_msg.append('max mem: {memory:.0f}')
685 | log_msg = self.delimiter.join(log_msg)
686 | MB = 1024.0 * 1024.0
687 | for obj in iterable:
688 | data_time.update(time.time() - end)
689 | yield obj
690 | iter_time.update(time.time() - end)
691 | if i % print_freq == 0 or i == len(iterable) - 1:
692 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
693 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
694 | if torch.cuda.is_available():
695 | print(log_msg.format(
696 | i, len(iterable), eta=eta_string,
697 | meters=str(self),
698 | time=str(iter_time), data=str(data_time),
699 | memory=torch.cuda.max_memory_allocated() / MB))
700 | else:
701 | print(log_msg.format(
702 | i, len(iterable), eta=eta_string,
703 | meters=str(self),
704 | time=str(iter_time), data=str(data_time)))
705 | i += 1
706 | end = time.time()
707 | total_time = time.time() - start_time
708 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
709 | print('{} Total time: {} ({:.4f} s / it)'.format(
710 | header, total_time_str, total_time / len(iterable)))
711 |
712 |
713 |
714 | class SmoothedValue(object):
715 | """Track a series of values and provide access to smoothed values over a
716 | window or the global series average.
717 | """
718 |
719 | def __init__(self, window_size=20, fmt=None):
720 | if fmt is None:
721 | fmt = "{median:.4f} ({global_avg:.4f})"
722 | self.deque = deque(maxlen=window_size)
723 | self.total = 0.0
724 | self.count = 0
725 | self.fmt = fmt
726 |
727 | def update(self, value, n=1):
728 | self.deque.append(value)
729 | self.count += n
730 | self.total += value * n
731 |
732 | def synchronize_between_processes(self):
733 | """
734 | Warning: does not synchronize the deque!
735 | """
736 | if not is_dist_avail_and_initialized():
737 | return
738 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
739 | dist.barrier()
740 | dist.all_reduce(t)
741 | t = t.tolist()
742 | self.count = int(t[0])
743 | self.total = t[1]
744 |
745 | @property
746 | def median(self):
747 | d = torch.tensor(list(self.deque))
748 | return d.median().item()
749 |
750 | @property
751 | def avg(self):
752 | d = torch.tensor(list(self.deque), dtype=torch.float32)
753 | return d.mean().item()
754 |
755 | @property
756 | def global_avg(self):
757 | return self.total / self.count
758 |
759 | @property
760 | def max(self):
761 | return max(self.deque)
762 |
763 | @property
764 | def value(self):
765 | return self.deque[-1]
766 |
767 | def __str__(self):
768 | return self.fmt.format(
769 | median=self.median,
770 | avg=self.avg,
771 | global_avg=self.global_avg,
772 | max=self.max,
773 | value=self.value)
774 |
775 | class AttrDict(dict):
776 | def __init__(self, *args, **kwargs):
777 | super(AttrDict, self).__init__(*args, **kwargs)
778 | self.__dict__ = self
779 |
780 |
781 | def compute_acc(logits, label, reduction='mean'):
782 | ret = (torch.argmax(logits, dim=1) == label).float()
783 | if reduction == 'none':
784 | return ret.detach()
785 | elif reduction == 'mean':
786 | return ret.mean().item()
787 |
788 | def compute_n_params(model, return_str=True):
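# Count model parameters, returned either as a raw integer or as a human-readable string such as '1.2M' or '340.0K'.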
789 | tot = 0
790 | for p in model.parameters():
791 | w = 1
792 | for x in p.shape:
793 | w *= x
794 | tot += w
795 | if return_str:
796 | if tot >= 1e6:
797 | return '{:.1f}M'.format(tot / 1e6)
798 | else:
799 | return '{:.1f}K'.format(tot / 1e3)
800 | else:
801 | return tot
802 |
803 | def setup_for_distributed(is_master):
804 | """
805 | Disable printing on non-master processes (printing can still be forced with print(..., force=True))
806 | """
807 | import builtins as __builtin__
808 | builtin_print = __builtin__.print
809 |
810 | def print(*args, **kwargs):
811 | force = kwargs.pop('force', False)
812 | if is_master or force:
813 | builtin_print(*args, **kwargs)
814 |
815 | __builtin__.print = print
816 |
817 |
818 | def is_dist_avail_and_initialized():
819 | if not dist.is_available():
820 | return False
821 | if not dist.is_initialized():
822 | return False
823 | return True
824 |
825 |
826 | def get_world_size():
827 | if not is_dist_avail_and_initialized():
828 | return 1
829 | return dist.get_world_size()
830 |
831 |
832 | def get_rank():
833 | if not is_dist_avail_and_initialized():
834 | return 0
835 | return dist.get_rank()
836 |
837 |
838 | def is_main_process():
839 | return get_rank() == 0
840 |
841 |
842 | def save_on_master(*args, **kwargs):
843 | if is_main_process():
844 | torch.save(*args, **kwargs)
845 |
846 |
847 | def init_distributed_mode(args):
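# Read rank / world size from the environment (torchrun-style variables or SLURM), initialize an NCCL
# process group, and silence printing on non-master processes; falls back to single-process mode otherwise.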
848 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
849 | args.rank = int(os.environ["RANK"])
850 | args.world_size = int(os.environ['WORLD_SIZE'])
851 | args.gpu = int(os.environ['LOCAL_RANK'])
852 | elif 'SLURM_PROCID' in os.environ:
853 | args.rank = int(os.environ['SLURM_PROCID'])
854 | args.gpu = args.rank % torch.cuda.device_count()
855 | else:
856 | print('Not using distributed mode')
857 | args.distributed = False
858 | return
859 |
860 | args.distributed = True
861 |
862 | torch.cuda.set_device(args.gpu)
863 | args.dist_backend = 'nccl'
864 | print('| distributed init (rank {}, world size {}): {}'.format(
865 | args.rank, args.world_size, args.dist_url), flush=True)
866 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
867 | world_size=args.world_size, rank=args.rank)
868 | torch.distributed.barrier()
869 | setup_for_distributed(args.rank == 0)
870 |
871 |
872 | def load_or_process_file(file_type, process_func, args, data_source):
873 | """
874 | Load the processed file if it exists, otherwise process the data source and create the file.
875 |
876 | Args:
877 | file_type: The type of the file (e.g., 'train', 'test').
878 | process_func: The function to process the data source.
879 | args: The arguments required by the process function and to build the filename.
880 | data_source: The source data to be processed.
881 |
882 | Returns:
883 | The loaded data from the file.
884 | """
885 | filename = f'{args.dataset}_{args.text_encoder}_{file_type}_embed.npz'
886 |
887 |
888 | if not os.path.exists(filename):
889 | print(f'Creating {filename}')
890 | process_func(args, data_source)
891 | else:
892 | print(f'Loading {filename}')
893 |
894 | return np.load(filename)
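
# Example usage (sketch): assuming a helper `encode_texts(args, data)` that writes the expected .npz
# file on the first call, and an `args` namespace providing `dataset` and `text_encoder`:
#   text_embeddings = load_or_process_file('train', encode_texts, args, train_captions)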
--------------------------------------------------------------------------------
/website/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/website/poster.pdf
--------------------------------------------------------------------------------