├── .gitignore ├── Flickr30k └── ann_file │ ├── flickr30k_test.json │ ├── flickr30k_train.json │ └── flickr30k_val.json ├── README.md ├── __pycache__ ├── epoch.cpython-39.pyc ├── networks.cpython-39.pyc └── utils.cpython-39.pyc ├── buffer.py ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── coco_dataset.cpython-39.pyc │ └── flickr30k_dataset.cpython-39.pyc ├── cifar_dataset.py ├── coco_dataset.py └── flickr30k_dataset.py ├── distill.py ├── epoch.py ├── images ├── clip_0.png ├── clip_2000.png ├── coco_syn_0.png ├── coco_syn_2000.png ├── logo.png ├── loss.png ├── more_vis.png ├── pipeline.png ├── synthetic_images_0.png ├── synthetic_images_2000.png ├── table.png ├── teaser.png ├── text_noise_0.png ├── text_noise_1.png └── visualization.png ├── index.html ├── model.py ├── networks.py ├── reparam_module.py ├── requirements.yaml ├── style.css ├── transform ├── .DS_Store ├── __pycache__ │ └── randaugment.cpython-39.pyc └── randaugment.py ├── utils.py └── website └── poster.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | **/logged_files 2 | **/wandb 3 | *.png 4 | **/flickr30k_images 5 | **/output 6 | **/buffers 7 | **/data/cifar-10-batches-py 8 | **/data/cifar-10-python.tar.gz 9 | *.npz 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision-Language Dataset Distillation 2 | 3 | ### [Project Page](https://princetonvisualai.github.io/multimodal_dataset_distillation/) | [Paper](https://arxiv.org/abs/2308.07545) 4 | TMLR, 2024 5 | 6 | [Xindi Wu](https://xindiwu.github.io/), [Byron Zhang](https://www.linkedin.com/in/byron-zhang), [Zhiwei Deng](https://lucas2012.github.io/), [Olga Russakovsky](https://www.cs.princeton.edu/~olgarus/) 7 | 8 | ![Teaser image](images/teaser.png) 9 | This codebase is for the paper [Vision-Language Dataset Distillation](https://arxiv.org/abs/2308.07545). Please visit our [Project Page](https://princetonvisualai.github.io/multimodal_dataset_distillation/) for detailed results. 10 | 11 | 12 | Dataset distillation methods offer the promise of reducing a large-scale dataset down to a significantly smaller set of (potentially synthetic) training examples, which preserve sufficient information for training a new model from scratch. So far dataset distillation methods have been developed for image classification. However, with the rise in capabilities of vision-language models, and especially given the scale of datasets necessary to train these models, the time is ripe to expand dataset distillation methods beyond image classification. 13 | 14 | ![pipeline](images/pipeline.png) 15 | 16 | In this work, we take the first steps towards this goal by expanding on the idea of trajectory matching to create a distillation method for vision-language datasets. We demonstrate significant improvements on the challenging Flickr30K and COCO retrieval benchmarks: for example, on Flickr30K the best coreset selection method which selects 1000 image-text pairs for training is able to achieve only 5.6% image-to-text retrieval accuracy (i.e., recall@1); in contrast, our dataset distillation approach almost doubles that to 9.9% with just 100 (an order of magnitude fewer) training pairs. 
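At its core, the method matches short student runs on the synthetic image-text pairs against expert trajectories recorded on the full dataset, for the image and text branches jointly. The snippet below is a minimal sketch of that bi-trajectory matching loss, mirroring the normalized parameter-matching terms computed in `distill.py`; the function name and the flattened-parameter arguments are illustrative placeholders, not the exact API of this repository.

```python
import torch.nn.functional as F

def bi_trajectory_matching_loss(img_student, img_start, img_target,
                                txt_student, txt_start, txt_target):
    # Each argument is a flattened parameter vector: the student weights reached
    # after a few steps on the synthetic pairs, the expert weights the student
    # started from, and the expert weights a few epochs later that it should reach.
    img_loss = (F.mse_loss(img_student, img_target, reduction="sum")
                / F.mse_loss(img_start, img_target, reduction="sum"))
    txt_loss = (F.mse_loss(txt_student, txt_target, reduction="sum")
                / F.mse_loss(txt_start, txt_target, reduction="sum"))
    # The synthetic images, synthetic text embeddings, and the learnable
    # learning rates are all updated by backpropagating through this loss.
    return img_loss + txt_loss
```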
17 | 
18 | ![Visualization](images/visualization.png)
19 | ## Getting Started
20 | [Adapted from [mtt-distillation](https://github.com/GeorgeCazenavette/mtt-distillation) by [George Cazenavette](https://georgecazenavette.github.io) et al.]
21 | First, download our repo:
22 | ```bash
23 | git clone https://github.com/princetonvisualai/multimodal_dataset_distillation.git
24 | cd multimodal_dataset_distillation
25 | ```
26 | 
27 | For an express installation, we include ```.yaml``` files.
28 | 
29 | You will need an RTX 30XX GPU (or newer). Then run
30 | 
31 | ```bash
32 | conda env create -f requirements.yaml
33 | ```
34 | 
35 | You can then activate your conda environment with
36 | ```bash
37 | conda activate vl-distill
38 | ```
39 | 
40 | ## Datasets and Annotations
41 | Please download the images for the Flickr30K and COCO datasets and create separate directories for the annotation files, which are linked below.
42 | 
43 | Flickr30K: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json)[[Images]](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset)
44 | 
45 | COCO: [[Train]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json)[[Val]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json)[[Test]](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json)[[Images]](https://cocodataset.org/#download)
46 | 
47 | ## Training Expert Trajectories
48 | The following command generates 20 expert trajectories using an NFNet image encoder and a BERT text encoder. Training is done on the Flickr30K dataset by simultaneously finetuning a pre-trained NFNet model and training a projection layer on top of a frozen pre-trained BERT.
49 | ```bash
50 | python buffer.py --dataset=flickr --train_epochs=10 --num_experts=20 --buffer_path={path_to_buffer} --image_encoder=nfnet --text_encoder=bert --image_size=224 --image_root={path_to_image_directory} --ann_root={path_to_annotation_directory}
51 | ```
52 | 
53 | ## Bi-Trajectories Guided Co-Distillation
54 | The following command distills 100 synthetic image-text pairs for the Flickr30K dataset, given the expert trajectories.
55 | ```bash
56 | python distill.py --dataset=flickr --syn_steps=8 --expert_epochs=1 --max_start_epoch=2 --lr_img=1000 --lr_txt=1000 --lr_lr=1e-02 --buffer_path={path_to_buffer} --num_queries 100 --image_encoder=nfnet --text_encoder=bert --draw True --image_root={path_to_image_directory} --ann_root={path_to_annotation_directory} --save_dir={path_to_saved_distilled_data}
57 | ```
58 | 
59 | ## Acknowledgements
60 | This material is based upon work supported by the National Science Foundation under Grant No. 2107048. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. We thank many people from Princeton Visual AI lab (Allison Chen, Jihoon Chung, Tyler Zhu, Ye Zhu, William Yang and Kaiqu Liang) and Princeton NLP group (Carlos E. Jimenez, John Yang), Tiffany Ling and George Cazenavette for their helpful feedback on this work.
61 | 62 | ## Citation 63 | ``` 64 | @article{wu2023multimodal, 65 | title={Multimodal Dataset Distillation for Image-Text Retrieval}, 66 | author={Wu, Xindi and Zhang, Byron and Deng, Zhiwei and Russakovsky, Olga}, 67 | journal={arXiv preprint arXiv:2308.07545}, 68 | year={2023} 69 | } 70 | ``` 71 | 72 | -------------------------------------------------------------------------------- /__pycache__/epoch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/__pycache__/epoch.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/networks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/__pycache__/networks.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from epoch import epoch, epoch_test, itm_eval 5 | import wandb 6 | import warnings 7 | import datetime 8 | from data import get_dataset_flickr, textprocess 9 | from networks import CLIPModel_full 10 | from utils import load_or_process_file 11 | 12 | warnings.filterwarnings("ignore", category=DeprecationWarning) 13 | 14 | def main(args): 15 | #wandb.init(mode="disabled") 16 | wandb.init(project='DatasetDistillation', entity='dataset_distillation', config=args, name=args.name) 17 | 18 | 19 | args.dsa = True if args.dsa == 'True' else False 20 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | args.distributed = torch.cuda.device_count() > 1 22 | 23 | 24 | # print('\n================== Exp %d ==================\n '%exp) 25 | print('Hyper-parameters: \n', args.__dict__) 26 | 27 | save_dir = os.path.join(args.buffer_path, args.dataset) 28 | if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca: 29 | save_dir += "_NO_ZCA" 30 | save_dir = os.path.join(save_dir, args.image_encoder, args.text_encoder) 31 | if not os.path.exists(save_dir): 32 | os.makedirs(save_dir) 33 | ''' organize the datasets ''' 34 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args) 35 | data = load_or_process_file('text', textprocess, args, testloader) 36 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu() 37 | 38 | 39 | img_trajectories = [] 40 | txt_trajectories = [] 41 | 42 | for it in range(0, args.num_experts): 43 | 44 | ''' Train synthetic data ''' 45 | 46 | teacher_net = CLIPModel_full(args) 47 | img_teacher_net = teacher_net.image_encoder.to(args.device) 48 | txt_teacher_net = teacher_net.text_projection.to(args.device) 49 | if args.text_trainable: 50 | txt_teacher_net = teacher_net.text_encoder.to(args.device) 51 | if args.distributed: 52 | img_teacher_net = torch.nn.DataParallel(img_teacher_net) 53 | txt_teacher_net = torch.nn.DataParallel(txt_teacher_net) 54 | 
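        # Train the teacher (expert) pair below and record its trajectory: parameter
        # snapshots are taken before training and after every epoch (img_timestamps /
        # txt_timestamps) and saved further down as img_replay_buffer_{n}.pt and
        # txt_replay_buffer_{n}.pt, the expert trajectories that distill.py matches against.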
img_teacher_net.train() 55 | txt_teacher_net.train() 56 | lr_img = args.lr_teacher_img 57 | lr_txt = args.lr_teacher_txt 58 | 59 | teacher_optim_img = torch.optim.SGD(img_teacher_net.parameters(), lr=lr_img, momentum=args.mom, weight_decay=args.l2) 60 | teacher_optim_txt = torch.optim.SGD(txt_teacher_net.parameters(), lr=lr_txt, momentum=args.mom, weight_decay=args.l2) 61 | teacher_optim_img.zero_grad() 62 | teacher_optim_txt.zero_grad() 63 | 64 | img_timestamps = [] 65 | txt_timestamps = [] 66 | 67 | img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()]) 68 | txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()]) 69 | 70 | lr_schedule = [args.train_epochs // 2 + 1] 71 | 72 | for e in range(args.train_epochs): 73 | train_loss, train_acc = epoch(e, trainloader, teacher_net, teacher_optim_img, teacher_optim_txt, args) 74 | score_val_i2t, score_val_t2i = epoch_test(testloader, teacher_net, args.device, bert_test_embed) 75 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt) 76 | 77 | 78 | wandb.log({"train_loss": train_loss}) 79 | wandb.log({"train_acc": train_acc}) 80 | wandb.log({"txt_r1": val_result['txt_r1']}) 81 | wandb.log({"txt_r5": val_result['txt_r5']}) 82 | wandb.log({"txt_r10": val_result['txt_r10']}) 83 | wandb.log({"txt_r_mean": val_result['txt_r_mean']}) 84 | wandb.log({"img_r1": val_result['img_r1']}) 85 | wandb.log({"img_r5": val_result['img_r5']}) 86 | wandb.log({"img_r10": val_result['img_r10']}) 87 | wandb.log({"img_r_mean": val_result['img_r_mean']}) 88 | wandb.log({"r_mean": val_result['r_mean']}) 89 | 90 | print("Itr: {}\tEpoch: {}\tTrain Acc: {}\tImg R@1: {}\tR@5: {}\tR@10: {}\tR@Mean: {}\tTxt R@1: {}\tR@5: {}\tR@10: {}\tR@Mean: {}".format( 91 | it, e, train_acc, 92 | val_result['img_r1'], val_result['img_r5'], val_result['img_r10'], val_result['img_r_mean'], 93 | val_result['txt_r1'], val_result['txt_r5'], val_result['txt_r10'], val_result['txt_r_mean'])) 94 | img_timestamps.append([p.detach().cpu() for p in img_teacher_net.parameters()]) 95 | txt_timestamps.append([p.detach().cpu() for p in txt_teacher_net.parameters()]) 96 | 97 | if e in lr_schedule and args.decay: 98 | lr *= 0.1 99 | teacher_optim_img = torch.optim.SGD(img_teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) 100 | teacher_optim_txt = torch.optim.SGD(txt_teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) 101 | teacher_optim_img.zero_grad() 102 | teacher_optim_txt.zero_grad() 103 | 104 | img_trajectories.append(img_timestamps) 105 | txt_trajectories.append(txt_timestamps) 106 | n = 0 107 | while os.path.exists(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))): 108 | n += 1 109 | print("Saving {}".format(os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n)))) 110 | torch.save(img_trajectories, os.path.join(save_dir, "img_replay_buffer_{}.pt".format(n))) 111 | print("Saving {}".format(os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n)))) 112 | torch.save(txt_trajectories, os.path.join(save_dir, "txt_replay_buffer_{}.pt".format(n))) 113 | 114 | img_trajectories = [] 115 | txt_trajectories = [] 116 | 117 | 118 | if __name__ == '__main__': 119 | parser = argparse.ArgumentParser(description='Parameter Processing') 120 | parser.add_argument('--dataset', type=str, default='flickr', choices=['flickr', 'coco'], help='dataset') 121 | parser.add_argument('--num_experts', type=int, default=100, help='training iterations') 122 | 
parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters') 123 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters') 124 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks') 125 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 126 | help='whether to use differentiable Siamese augmentation.') 127 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 128 | help='differentiable Siamese augmentation strategy') 129 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path') 130 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 131 | parser.add_argument('--train_epochs', type=int, default=50) 132 | parser.add_argument('--zca', action='store_true') 133 | parser.add_argument('--decay', action='store_true') 134 | parser.add_argument('--mom', type=float, default=0, help='momentum') 135 | parser.add_argument('--l2', type=float, default=0, help='l2 regularization') 136 | parser.add_argument('--save_interval', type=int, default=10) 137 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 138 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run') 139 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained') 140 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained') 141 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable') 142 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable') 143 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train') 144 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test') 145 | parser.add_argument('--image_root', type=str, default='./Flickr30k/flickr-image-dataset/flickr30k-images/', help='location of image root') 146 | parser.add_argument('--ann_root', type=str, default='./Flickr30k/ann_file/', help='location of ann root') 147 | parser.add_argument('--image_size', type=int, default=224, help='image_size') 148 | parser.add_argument('--k_test', type=int, default=128, help='k_test') 149 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy') 150 | parser.add_argument('--image_encoder', type=str, default='resnet50', choices=['nfnet', 'resnet18_gn', 'vit_tiny', 'nf_resnet50', 'nf_regnet'], help='image encoder') 151 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip'], help='text encoder') 152 | parser.add_argument('--margin', default=0.2, type=float, 153 | help='Rank loss margin.') 154 | parser.add_argument('--measure', default='cosine', 155 | help='Similarity measure used (cosine|order)') 156 | parser.add_argument('--max_violation', action='store_true', 157 | help='Use max instead of sum in the rank loss.') 158 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None') 159 | parser.add_argument('--grounding', type=bool, default=False, help='None') 160 | parser.add_argument('--distill', type=bool, default=False, help='if distill') 161 | args = parser.parse_args() 162 | 163 | main(args) 164 | 165 | -------------------------------------------------------------------------------- /data/__init__.py: 
-------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from transform.randaugment import RandomAugment 3 | from torchvision.transforms.functional import InterpolationMode 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets.utils import download_url 8 | import json 9 | from PIL import Image 10 | import os 11 | from torchvision import transforms as T 12 | from networks import CLIPModel_full 13 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 14 | from data.coco_dataset import coco_train, coco_caption_eval, coco_retrieval_eval 15 | import numpy as np 16 | from tqdm import tqdm 17 | @torch.no_grad() 18 | def textprocess(args, testloader): 19 | net = CLIPModel_full(args).to('cuda') 20 | net.eval() 21 | texts = testloader.dataset.text 22 | if args.dataset in ['flickr', 'coco']: 23 | if args.dataset == 'flickr': 24 | bert_test_embed = net.text_encoder(texts) 25 | elif args.dataset == 'coco': 26 | bert_test_embed = torch.cat((net.text_encoder(texts[:10000]), net.text_encoder(texts[10000:20000]), net.text_encoder(texts[20000:])), dim=0) 27 | 28 | bert_test_embed_np = bert_test_embed.cpu().numpy() 29 | np.savez(f'{args.dataset}_{args.text_encoder}_text_embed.npz', bert_test_embed=bert_test_embed_np) 30 | else: 31 | raise NotImplementedError 32 | return 33 | 34 | @torch.no_grad() 35 | def textprocess_train(args, texts): 36 | net = CLIPModel_full(args).to('cuda') 37 | net.eval() 38 | chunk_size = 2000 39 | chunks = [] 40 | for i in tqdm(range(0, len(texts), chunk_size)): 41 | chunk = net.text_encoder(texts[i:i + chunk_size]).cpu() 42 | chunks.append(chunk) 43 | del chunk 44 | torch.cuda.empty_cache() # free up memory 45 | bert_test_embed = torch.cat(chunks, dim=0) 46 | 47 | print('bert_test_embed.shape: ', bert_test_embed.shape) 48 | bert_test_embed_np = bert_test_embed.numpy() 49 | if args.dataset in ['flickr', 'coco']: 50 | np.savez(f'{args.dataset}_{args.text_encoder}_train_text_embed.npz', bert_test_embed=bert_test_embed_np) 51 | else: 52 | raise NotImplementedError 53 | return 54 | 55 | 56 | def create_dataset(args, min_scale=0.5): 57 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 58 | transform_train = transforms.Compose([ 59 | transforms.RandomResizedCrop(args.image_size,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 60 | transforms.RandomHorizontalFlip(), 61 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 62 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 63 | transforms.ToTensor(), 64 | normalize, 65 | ]) 66 | transform_test = transforms.Compose([ 67 | transforms.Resize((args.image_size, args.image_size),interpolation=InterpolationMode.BICUBIC), 68 | transforms.ToTensor(), 69 | normalize, 70 | ]) 71 | if args.dataset=='flickr': 72 | train_dataset = flickr30k_train(transform_train, args.image_root, args.ann_root) 73 | val_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val') 74 | test_dataset = flickr30k_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test') 75 | return train_dataset, val_dataset, test_dataset 76 | 77 | elif args.dataset=='coco': 78 | train_dataset = coco_train(transform_train, args.image_root, args.ann_root) 79 | val_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'val') 80 | 
test_dataset = coco_retrieval_eval(transform_test, args.image_root, args.ann_root, 'test') 81 | return train_dataset, val_dataset, test_dataset 82 | else: 83 | raise NotImplementedError 84 | return train_dataset, val_dataset, test_dataset 85 | 86 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 87 | samplers = [] 88 | for dataset,shuffle in zip(datasets,shuffles): 89 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 90 | samplers.append(sampler) 91 | return samplers 92 | 93 | 94 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 95 | loaders = [] 96 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 97 | if is_train: 98 | shuffle = (sampler is None) 99 | drop_last = True 100 | else: 101 | shuffle = False 102 | drop_last = False 103 | loader = DataLoader( 104 | dataset, 105 | batch_size=bs, 106 | num_workers=n_worker, 107 | pin_memory=True, 108 | sampler=sampler, 109 | shuffle=shuffle, 110 | collate_fn=collate_fn, 111 | drop_last=drop_last, 112 | ) 113 | loaders.append(loader) 114 | return loaders 115 | 116 | 117 | def get_dataset_flickr(args): 118 | print("Creating retrieval dataset") 119 | train_dataset, val_dataset, test_dataset = create_dataset(args) 120 | 121 | samplers = [None, None, None] 122 | train_shuffle = True 123 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, 124 | batch_size=[args.batch_size_train]+[args.batch_size_test]*2, 125 | num_workers=[4,4,4], 126 | is_trains=[train_shuffle, False, False], 127 | collate_fns=[None,None,None]) 128 | return train_loader, test_loader, train_dataset, test_dataset 129 | 130 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/coco_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/data/__pycache__/coco_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/flickr30k_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/data/__pycache__/flickr30k_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /data/cifar_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from typing import Any, Tuple 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | import numpy as np 7 | from torchvision.datasets import CIFAR10 8 | from collections import defaultdict 9 | 10 | 11 | CLASSES = [ 12 | "airplane", 13 | "automobile", 14 | "bird", 15 | "cat", 16 | "deer", 17 | "dog", 18 | "frog", 19 | "horse", 20 | "ship", 21 | "truck", 22 | ] 23 | 24 | 
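# Caption templates for pairing CIFAR-10 class labels with text: each class name is
# substituted into 1, 5, or 16 prompts (chosen via the num_prompts argument below),
# which lets the classification dataset be used for image-text retrieval experiments.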
PROMPTS_1 = ["This is a {}"] 25 | 26 | PROMPTS_5 = [ 27 | "a photo of a {}", 28 | "a blurry image of a {}", 29 | "a photo of the {}", 30 | "a pixelated photo of a {}", 31 | "a picture of a {}", 32 | ] 33 | 34 | PROMPTS = [ 35 | 'a photo of a {}', 36 | 'a blurry photo of a {}', 37 | 'a low contrast photo of a {}', 38 | 'a high contrast photo of a {}', 39 | 'a bad photo of a {}', 40 | 'a good photo of a {}',s 41 | 'a photo of a small {}', 42 | 'a photo of a big {}', 43 | 'a photo of the {}', 44 | 'a blurry photo of the {}', 45 | 'a low contrast photo of the {}', 46 | 'a high contrast photo of the {}', 47 | 'a bad photo of the {}', 48 | 'a good photo of the {}', 49 | 'a photo of the small {}', 50 | 'a photo of the big {}', 51 | ] 52 | 53 | 54 | class cifar10_train(CIFAR10): 55 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, num_prompts=1): 56 | super(cifar10_train, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download) 57 | if num_prompts == 1: 58 | self.prompts = PROMPTS_1 59 | elif num_prompts == 5: 60 | self.prompts = PROMPTS_5 61 | else: 62 | self.prompts = PROMPTS 63 | self.captions = [prompt.format(cls) for cls in CLASSES for prompt in self.prompts] 64 | self.captions_to_label = {cap: i for i, cap in enumerate(self.captions)} 65 | self.annotations = [] 66 | for i in range(len(self.data)): 67 | cls_name = CLASSES[self.targets[i]] 68 | for prompt in self.prompts: 69 | caption = prompt.format(cls_name) 70 | self.annotations.append({"img_id": i, "caption_id": self.captions_to_label[caption]}) 71 | if num_prompts == 1: 72 | self.annotations = self.annotations * 5 73 | 74 | def __len__(self): 75 | return len(self.annotations) 76 | 77 | def __getitem__(self, index): 78 | ann = self.annotations[index] 79 | img_id = ann['img_id'] 80 | img = self.transform(self.data[img_id]) 81 | caption = self.captions[ann['caption_id']] 82 | return img, caption, img_id 83 | 84 | def fetch_distill_images(self, ipc): 85 | """ 86 | Randomly fetch `x` number of images from each class using numpy and return as a tensor. 87 | """ 88 | class_indices = defaultdict(list) 89 | 90 | for idx, label in enumerate(self.targets): 91 | class_indices[label].append(idx) 92 | 93 | # Randomly sample x indices for each class using numpy 94 | sampled_indices = [np.random.choice(indices, ipc, replace=False) for indices in class_indices.values()] 95 | sampled_indices = [idx for class_indices in sampled_indices for idx in class_indices] 96 | 97 | # Fetch images and labels using the selected indices 98 | images = torch.stack([self.transform(self.data[i]) for i in sampled_indices]) 99 | labels = [self.targets[i] for i in sampled_indices] 100 | 101 | captions = [] 102 | for label in labels: 103 | cls_name = CLASSES[label] 104 | prompt = np.random.choice(self.prompts) 105 | random_caption = prompt.format(cls_name) 106 | captions.append(random_caption) 107 | 108 | return images, captions 109 | 110 | def get_all_captions(self): 111 | return self.captions 112 | 113 | 114 | class cifar10_retrieval_eval(cifar10_train): 115 | def __init__(self, root, train=False, transform=None, target_transform=None, download=False, num_prompts=1): 116 | """ 117 | image_root (string): Root directory of images (e.g. 
coco/images/) 118 | ann_root (string): directory to store the annotation file 119 | split (string): val or test 120 | """ 121 | super(cifar10_retrieval_eval, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download, num_prompts=num_prompts) 122 | self.text = self.captions 123 | self.txt2img = {} 124 | self.img2txt = defaultdict(list) 125 | 126 | for ann in self.annotations: 127 | img_id = ann['img_id'] 128 | caption_id = ann['caption_id'] 129 | self.img2txt[img_id].append(caption_id) 130 | self.txt2img[caption_id] = img_id 131 | 132 | def __len__(self): 133 | return len(self.data) 134 | 135 | def __getitem__(self, index): 136 | image = self.transform(self.data[index]) 137 | return image, index 138 | -------------------------------------------------------------------------------- /data/coco_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from torch.utils.data import Dataset 4 | from torchvision.datasets.utils import download_url 5 | from PIL import Image 6 | import re 7 | 8 | def pre_caption(caption,max_words=50): 9 | caption = re.sub( 10 | r"([.!\"()*#:;~])", 11 | ' ', 12 | caption.lower(), 13 | ) 14 | caption = re.sub( 15 | r"\s{2,}", 16 | ' ', 17 | caption, 18 | ) 19 | caption = caption.rstrip('\n') 20 | caption = caption.strip(' ') 21 | 22 | #truncate caption 23 | caption_words = caption.split(' ') 24 | if len(caption_words)>max_words: 25 | caption = ' '.join(caption_words[:max_words]) 26 | 27 | return caption 28 | 29 | class coco_train(Dataset): 30 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 31 | ''' 32 | image_root (string): Root directory of images (e.g. coco/images/) 33 | ann_root (string): directory to store the annotation file 34 | ''' 35 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 36 | filename = 'coco_karpathy_train.json' 37 | 38 | download_url(url,ann_root) 39 | 40 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 41 | self.transform = transform 42 | self.image_root = image_root 43 | self.max_words = max_words 44 | self.prompt = prompt 45 | 46 | self.img_ids = {} 47 | n = 0 48 | for ann in self.annotation: 49 | img_id = ann['image_id'] 50 | if img_id not in self.img_ids.keys(): 51 | self.img_ids[img_id] = n 52 | n += 1 53 | 54 | def __len__(self): 55 | return len(self.annotation) 56 | 57 | def __getitem__(self, index): 58 | 59 | ann = self.annotation[index] 60 | 61 | image_path = os.path.join(self.image_root,ann['image']) 62 | image = Image.open(image_path).convert('RGB') 63 | image = self.transform(image) 64 | 65 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 66 | 67 | return image, caption, self.img_ids[ann['image_id']] 68 | 69 | def get_all_captions(self): 70 | captions = [] 71 | for ann in self.annotation: 72 | caption = self.prompt + pre_caption(ann['caption'], self.max_words) 73 | captions.append(caption) 74 | return captions 75 | 76 | 77 | class coco_caption_eval(Dataset): 78 | def __init__(self, transform, image_root, ann_root, split): 79 | ''' 80 | image_root (string): Root directory of images (e.g. 
coco/images/) 81 | ann_root (string): directory to store the annotation file 82 | split (string): val or test 83 | ''' 84 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 85 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 86 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 87 | 88 | download_url(urls[split],ann_root) 89 | 90 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 91 | self.transform = transform 92 | self.image_root = image_root 93 | 94 | def __len__(self): 95 | return len(self.annotation) 96 | 97 | def __getitem__(self, index): 98 | 99 | ann = self.annotation[index] 100 | 101 | image_path = os.path.join(self.image_root,ann['image']) 102 | image = Image.open(image_path).convert('RGB') 103 | image = self.transform(image) 104 | 105 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] 106 | 107 | return image, int(img_id) 108 | 109 | 110 | class coco_retrieval_eval(Dataset): 111 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 112 | ''' 113 | image_root (string): Root directory of images (e.g. coco/images/) 114 | ann_root (string): directory to store the annotation file 115 | split (string): val or test 116 | ''' 117 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 118 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 119 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 120 | 121 | download_url(urls[split],ann_root) 122 | 123 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 124 | self.transform = transform 125 | self.image_root = image_root 126 | 127 | self.text = [] 128 | self.image = [] 129 | self.txt2img = {} 130 | self.img2txt = {} 131 | 132 | txt_id = 0 133 | for img_id, ann in enumerate(self.annotation): 134 | self.image.append(ann['image']) 135 | self.img2txt[img_id] = [] 136 | for i, caption in enumerate(ann['caption']): 137 | self.text.append(pre_caption(caption,max_words)) 138 | self.img2txt[img_id].append(txt_id) 139 | self.txt2img[txt_id] = img_id 140 | txt_id += 1 141 | 142 | def __len__(self): 143 | return len(self.annotation) 144 | 145 | def __getitem__(self, index): 146 | 147 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 148 | image = Image.open(image_path).convert('RGB') 149 | image = self.transform(image) 150 | 151 | return image, index -------------------------------------------------------------------------------- /data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | import numpy as np 5 | import yaml 6 | from transformers import BertTokenizer, BertModel 7 | from torchvision import transforms as T 8 | from torchvision.transforms.functional import InterpolationMode 9 | from torch.utils.data import Dataset 10 | from torchvision.datasets.utils import download_url 11 | import argparse 12 | import re 13 | import json 14 | from PIL import Image 15 | 16 | def pre_caption(caption,max_words=50): 17 | caption = re.sub( 18 | r"([.!\"()*#:;~])", 19 | ' ', 20 | caption.lower(), 21 | ) 22 | caption = re.sub( 23 | r"\s{2,}", 24 | ' ', 25 | caption, 26 | ) 27 | caption = caption.rstrip('\n') 28 | caption = caption.strip(' ') 29 | 30 | 
#truncate caption 31 | caption_words = caption.split(' ') 32 | if len(caption_words)>max_words: 33 | caption = ' '.join(caption_words[:max_words]) 34 | 35 | return caption 36 | 37 | 38 | class flickr30k_train(Dataset): 39 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 40 | ''' 41 | image_root (string): Root directory of images (e.g. flickr30k/) 42 | ann_root (string): directory to store the annotation file 43 | ''' 44 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' 45 | filename = 'flickr30k_train.json' 46 | 47 | download_url(url,ann_root) 48 | 49 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 50 | self.transform = transform 51 | self.image_root = image_root 52 | self.max_words = max_words 53 | self.prompt = prompt 54 | 55 | self.img_ids = {} 56 | n = 0 57 | for ann in self.annotation: 58 | img_id = ann['image_id'] 59 | if img_id not in self.img_ids.keys(): 60 | self.img_ids[img_id] = n 61 | n += 1 62 | 63 | def __len__(self): 64 | return len(self.annotation) 65 | 66 | def __getitem__(self, index): 67 | 68 | ann = self.annotation[index] 69 | 70 | image_path = os.path.join(self.image_root,ann['image']) 71 | image = Image.open(image_path).convert('RGB') 72 | image = self.transform(image) 73 | 74 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 75 | 76 | return image, caption, self.img_ids[ann['image_id']] 77 | 78 | def get_all_captions(self): 79 | captions = [] 80 | for ann in self.annotation: 81 | caption = self.prompt + pre_caption(ann['caption'], self.max_words) 82 | captions.append(caption) 83 | return captions 84 | 85 | 86 | 87 | class flickr30k_retrieval_eval(Dataset): 88 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 89 | ''' 90 | image_root (string): Root directory of images (e.g. 
flickr30k/) 91 | ann_root (string): directory to store the annotation file 92 | split (string): val or test 93 | ''' 94 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', 95 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} 96 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} 97 | 98 | download_url(urls[split],ann_root) 99 | 100 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 101 | self.transform = transform 102 | self.image_root = image_root 103 | self.max_words = max_words 104 | 105 | self.text = [] 106 | self.image = [] 107 | self.txt2img = {} 108 | self.img2txt = {} 109 | 110 | txt_id = 0 111 | for img_id, ann in enumerate(self.annotation): 112 | self.image.append(ann['image']) 113 | self.img2txt[img_id] = [] 114 | for i, caption in enumerate(ann['caption']): 115 | self.text.append(pre_caption(caption,max_words)) 116 | self.img2txt[img_id].append(txt_id) 117 | self.txt2img[txt_id] = img_id 118 | txt_id += 1 119 | 120 | def __len__(self): 121 | return len(self.annotation) 122 | 123 | def __getitem__(self, index): 124 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 125 | image = Image.open(image_path).convert('RGB') 126 | image = self.transform(image) 127 | 128 | return image, index 129 | 130 | 131 | -------------------------------------------------------------------------------- /distill.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import datetime 4 | import os 5 | import random 6 | import sys 7 | import warnings 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torchvision.utils 15 | from sklearn.metrics.pairwise import cosine_similarity 16 | from tqdm import tqdm 17 | import math 18 | 19 | from transformers import BertTokenizer, BertConfig, BertModel 20 | import wandb 21 | 22 | from data import get_dataset_flickr, textprocess, textprocess_train 23 | from epoch import evaluate_synset, epoch, epoch_test, itm_eval 24 | from networks import CLIPModel_full, TextEncoder 25 | from reparam_module import ReparamModule 26 | from utils import DiffAugment, ParamDiffAug, TensorDataset, get_dataset, get_network, get_eval_pool, get_time, load_or_process_file 27 | 28 | 29 | def shuffle_files(img_expert_files, txt_expert_files): 30 | # Check if both lists have the same length and if the lists are not empty 31 | assert len(img_expert_files) == len(txt_expert_files), "Number of image files and text files does not match" 32 | assert len(img_expert_files) != 0, "No files to shuffle" 33 | shuffled_indices = np.random.permutation(len(img_expert_files)) 34 | 35 | # Apply the shuffled indices to both lists 36 | img_expert_files = np.take(img_expert_files, shuffled_indices) 37 | txt_expert_files = np.take(txt_expert_files, shuffled_indices) 38 | print(f"img_expert_files: {img_expert_files}") 39 | print(f"txt_expert_files: {txt_expert_files}") 40 | return img_expert_files, txt_expert_files 41 | 42 | def nearest_neighbor(sentences, query_embeddings, database_embeddings): 43 | """Find the nearest neighbors for a batch of embeddings. 44 | 45 | Args: 46 | sentences: The original sentences from which the embeddings were computed. 47 | query_embeddings: A batch of embeddings for which to find the nearest neighbors. 
48 | database_embeddings: All pre-computed embeddings. 49 | 50 | Returns: 51 | A list of the most similar sentences for each embedding in the batch. 52 | """ 53 | nearest_neighbors = [] 54 | 55 | for query in query_embeddings: 56 | similarities = cosine_similarity(query.reshape(1, -1), database_embeddings) 57 | 58 | most_similar_index = np.argmax(similarities) 59 | 60 | nearest_neighbors.append(sentences[most_similar_index]) 61 | 62 | return nearest_neighbors 63 | 64 | 65 | def get_images_texts(n, dataset): 66 | """Get random n images and corresponding texts from the dataset. 67 | 68 | Args: 69 | n: Number of images and texts to retrieve. 70 | dataset: The dataset containing image-text pairs. 71 | 72 | Returns: 73 | A tuple containing two elements: 74 | - A tensor of randomly selected images. 75 | - A tensor of the corresponding texts, encoded as floats. 76 | """ 77 | # Generate n unique random indices 78 | idx_shuffle = np.random.permutation(len(dataset))[:n] 79 | 80 | # Initialize the text encoder 81 | text_encoder = TextEncoder(args) 82 | 83 | image_syn = torch.stack([dataset[i][0] for i in idx_shuffle]) 84 | text_syn = text_encoder([dataset[i][1] for i in idx_shuffle], device="cpu") 85 | 86 | return image_syn, text_syn.float() 87 | 88 | 89 | def main(args): 90 | ''' organize the real train dataset ''' 91 | trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args) 92 | 93 | train_sentences = train_dataset.get_all_captions() 94 | 95 | data = load_or_process_file('text', textprocess, args, testloader) 96 | train_caption = load_or_process_file('train_text', textprocess_train, args, train_sentences) 97 | 98 | bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu() 99 | print("The shape of bert_test_embed: {}".format(bert_test_embed.shape)) 100 | train_caption_embed = torch.from_numpy(train_caption['bert_test_embed']).cpu() 101 | print("The shape of train_caption_embed: {}".format(train_caption_embed.shape)) 102 | 103 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 104 | if args.zca and args.texture: 105 | raise AssertionError("Cannot use zca and texture together") 106 | 107 | if args.texture and args.pix_init == "real": 108 | print("WARNING: Using texture with real initialization will take a very long time to smooth out the boundaries between images.") 109 | 110 | if args.max_experts is not None and args.max_files is not None: 111 | args.total_experts = args.max_experts * args.max_files 112 | 113 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled)) 114 | 115 | args.dsa = True if args.dsa == 'True' else False 116 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 117 | 118 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 119 | 120 | if args.dsa: 121 | args.dc_aug_param = None 122 | 123 | # wandb.init(mode="disabled") 124 | wandb.init(project='DatasetDistillation', entity='dataset_distillation', config=args, name=args.name) 125 | 126 | args.dsa_param = ParamDiffAug() 127 | zca_trans = args.zca_trans if args.zca else None 128 | args.zca_trans = zca_trans 129 | args.distributed = torch.cuda.device_count() > 1 130 | 131 | print('Hyper-parameters: \n', args.__dict__) 132 | syn_lr_img = torch.tensor(args.lr_teacher_img).to(args.device) 133 | syn_lr_txt = torch.tensor(args.lr_teacher_txt).to(args.device) 134 | 135 | ''' initialize the synthetic data ''' 136 | image_syn, text_syn = get_images_texts(args.num_queries, train_dataset) 137 | 138 | if args.pix_init == 'noise': 139 | mean = 
torch.tensor([-0.0626, -0.0221, 0.0680]) 140 | std = torch.tensor([1.0451, 1.0752, 1.0539]) 141 | image_syn = torch.randn([args.num_queries, 3, 224, 224]) 142 | for c in range(3): 143 | image_syn[:, c] = image_syn[:, c] * std[c] + mean[c] 144 | print('Initialized synthetic image from random noise') 145 | 146 | if args.txt_init == 'noise': 147 | text_syn = torch.normal(mean=-0.0094, std=0.5253, size=(args.num_queries, 768)) 148 | print('Initialized synthetic text from random noise') 149 | 150 | 151 | ''' training ''' 152 | image_syn = image_syn.detach().to(args.device).requires_grad_(True) 153 | optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5) 154 | optimizer_img.zero_grad() 155 | 156 | syn_lr_img = syn_lr_img.to(args.device).requires_grad_(True) 157 | syn_lr_txt = syn_lr_txt.to(args.device).requires_grad_(True) 158 | optimizer_lr = torch.optim.SGD([syn_lr_img, syn_lr_txt], lr=args.lr_lr, momentum=0.5) 159 | 160 | text_syn = text_syn.detach().to(args.device).requires_grad_(True) 161 | optimizer_txt = torch.optim.SGD([text_syn], lr=args.lr_txt, momentum=0.5) 162 | optimizer_txt.zero_grad() 163 | sentence_list = nearest_neighbor(train_sentences, text_syn.detach().cpu(), train_caption_embed) 164 | if args.draw: 165 | wandb.log({"original_sentence_list": wandb.Html('<br>
'.join(sentence_list))}) 166 | wandb.log({"original_synthetic_images": wandb.Image(torch.nan_to_num(image_syn.detach().cpu()))}) 167 | 168 | criterion = nn.CrossEntropyLoss().to(args.device) 169 | print('%s training begins'%get_time()) 170 | 171 | expert_dir = os.path.join(args.buffer_path, args.dataset) 172 | expert_dir = args.buffer_path 173 | print("Expert Dir: {}".format(expert_dir)) 174 | 175 | 176 | img_expert_files = [] 177 | txt_expert_files = [] 178 | n = 0 179 | while os.path.exists(os.path.join(expert_dir, "img_replay_buffer_{}.pt".format(n))): 180 | img_expert_files.append(os.path.join(expert_dir, "img_replay_buffer_{}.pt".format(n))) 181 | txt_expert_files.append(os.path.join(expert_dir, "txt_replay_buffer_{}.pt".format(n))) 182 | n += 1 183 | if n == 0: 184 | raise AssertionError("No buffers detected at {}".format(expert_dir)) 185 | 186 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files) 187 | 188 | file_idx = 0 189 | expert_idx = 0 190 | print("loading file {}".format(img_expert_files[file_idx])) 191 | print("loading file {}".format(txt_expert_files[file_idx])) 192 | 193 | img_buffer = torch.load(img_expert_files[file_idx]) 194 | txt_buffer = torch.load(txt_expert_files[file_idx]) 195 | 196 | for it in tqdm(range(args.Iteration + 1)): 197 | save_this_it = True 198 | 199 | wandb.log({"Progress": it}, step=it) 200 | ''' Evaluate synthetic data ''' 201 | if it in eval_it_pool: 202 | print('-------------------------\nEvaluation\nimage_model_train = %s, text_model_train = %s, iteration = %d'%(args.image_encoder, args.text_encoder, it)) 203 | if args.dsa: 204 | print('DSA augmentation strategy: \n', args.dsa_strategy) 205 | print('DSA augmentation parameters: \n', args.dsa_param.__dict__) 206 | else: 207 | print('DC augmentation parameters: \n', args.dc_aug_param) 208 | 209 | accs_train = [] 210 | img_r1s = [] 211 | img_r5s = [] 212 | img_r10s = [] 213 | img_r_means = [] 214 | 215 | txt_r1s = [] 216 | txt_r5s = [] 217 | txt_r10s = [] 218 | txt_r_means = [] 219 | 220 | r_means = [] 221 | for it_eval in range(args.num_eval): 222 | net_eval = CLIPModel_full(args, eval_stage=args.transfer) 223 | 224 | with torch.no_grad(): 225 | image_save = image_syn 226 | text_save = text_syn 227 | image_syn_eval, text_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(text_save.detach()) # avoid any unaware modification 228 | 229 | args.lr_net = syn_lr_img.item() 230 | print(image_syn_eval.shape) 231 | _, acc_train, val_result = evaluate_synset(it_eval, net_eval, image_syn_eval, text_syn_eval, testloader, args, bert_test_embed) 232 | print('Evaluate_%02d: Img R@1 = %.4f, Img R@5 = %.4f, Img R@10 = %.4f, Img R@Mean = %.4f, Txt R@1 = %.4f, Txt R@5 = %.4f, Txt R@10 = %.4f, Txt R@Mean = %.4f, R@Mean = %.4f' % 233 | (it_eval, 234 | val_result['img_r1'], val_result['img_r5'], val_result['img_r10'], val_result['img_r_mean'], 235 | val_result['txt_r1'], val_result['txt_r5'], val_result['txt_r10'], val_result['txt_r_mean'], 236 | val_result['r_mean'])) 237 | 238 | img_r1s.append(val_result['img_r1']) 239 | img_r5s.append(val_result['img_r5']) 240 | img_r10s.append(val_result['img_r10']) 241 | img_r_means.append(val_result['img_r_mean']) 242 | 243 | txt_r1s.append(val_result['txt_r1']) 244 | txt_r5s.append(val_result['txt_r5']) 245 | txt_r10s.append(val_result['txt_r10']) 246 | txt_r_means.append(val_result['txt_r_mean']) 247 | r_means.append(val_result['r_mean']) 248 | 249 | if not args.std: 250 | wandb.log({"txt_r1": val_result['txt_r1']}) 251 | 
wandb.log({"txt_r5": val_result['txt_r5']}) 252 | wandb.log({"txt_r10": val_result['txt_r10']}) 253 | wandb.log({"txt_r_mean": val_result['txt_r_mean']}) 254 | wandb.log({"img_r1": val_result['img_r1']}) 255 | wandb.log({"img_r5": val_result['img_r5']}) 256 | wandb.log({"img_r10": val_result['img_r10']}) 257 | wandb.log({"img_r_mean": val_result['img_r_mean']}) 258 | wandb.log({"r_mean": val_result['r_mean']}) 259 | if args.std: 260 | img_r1_mean, img_r1_std = np.mean(img_r1s), np.std(img_r1s) 261 | img_r5_mean, img_r5_std = np.mean(img_r5s), np.std(img_r5s) 262 | img_r10_mean, img_r10_std = np.mean(img_r10s), np.std(img_r10s) 263 | img_r_mean_mean, img_r_mean_std = np.mean(img_r_means), np.std(img_r_means) 264 | 265 | txt_r1_mean, txt_r1_std = np.mean(txt_r1s), np.std(txt_r1s) 266 | txt_r5_mean, txt_r5_std = np.mean(txt_r5s), np.std(txt_r5s) 267 | txt_r10_mean, txt_r10_std = np.mean(txt_r10s), np.std(txt_r10s) 268 | txt_r_mean_mean, txt_r_mean_std = np.mean(txt_r_means), np.std(txt_r_means) 269 | r_mean_mean, r_mean_std = np.mean(r_means), np.std(r_means) 270 | 271 | wandb.log({'Mean/txt_r1': txt_r1_mean, 'Std/txt_r1': txt_r1_std}) 272 | wandb.log({'Mean/txt_r5': txt_r5_mean, 'Std/txt_r5': txt_r5_std}) 273 | wandb.log({'Mean/txt_r10': txt_r10_mean, 'Std/txt_r10': txt_r10_std}) 274 | wandb.log({'Mean/txt_r_mean': txt_r_mean_mean, 'Std/txt_r_mean': txt_r_mean_std}) 275 | wandb.log({'Mean/img_r1': img_r1_mean, 'Std/img_r1': img_r1_std}) 276 | wandb.log({'Mean/img_r5': img_r5_mean, 'Std/img_r5': img_r5_std}) 277 | wandb.log({'Mean/img_r10': img_r10_mean, 'Std/img_r10': img_r10_std}) 278 | wandb.log({'Mean/img_r_mean': img_r_mean_mean, 'Std/img_r_mean': img_r_mean_std}) 279 | wandb.log({'Mean/r_mean': r_mean_mean, 'Std/r_mean': r_mean_std}) 280 | 281 | if it in eval_it_pool and (save_this_it or it % 1000 == 0): 282 | if args.draw: 283 | with torch.no_grad(): 284 | image_save = image_syn_eval.cuda() 285 | text_save = text_syn_eval.cuda() 286 | save_dir = os.path.join(".", "logged_files", args.dataset, wandb.run.name) 287 | print("Saving to {}".format(save_dir)) 288 | 289 | if not os.path.exists(save_dir): 290 | os.makedirs(save_dir) 291 | 292 | #torch.save(image_save, os.path.join(save_dir, "images_{}.pt".format(it))) 293 | #torch.save(text_save, os.path.join(save_dir, "labels_{}.pt".format(it))) 294 | 295 | #torch.save(image_save, os.path.join(save_dir, "images_best.pt".format(it))) 296 | #torch.save(text_save, os.path.join(save_dir, "labels_best.pt".format(it))) 297 | 298 | wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, step=it) # Move tensor to CPU before converting to NumPy 299 | 300 | if args.ipc < 50 or args.force_save: 301 | upsampled = image_save[:90] 302 | if args.dataset != "ImageNet": 303 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 304 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 305 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 306 | sentence_list = nearest_neighbor(train_sentences, text_syn.cpu(), train_caption_embed) 307 | sentence_list = sentence_list[:90] 308 | torchvision.utils.save_image(grid, os.path.join(save_dir, "synthetic_images_{}.png".format(it))) 309 | 310 | with open(os.path.join(save_dir, "synthetic_sentences_{}.txt".format(it)), "w") as file: 311 | file.write('\n'.join(sentence_list)) 312 | wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 313 | wandb.log({'Synthetic_Pixels': 
wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it) 314 | wandb.log({"Synthetic_Sentences": wandb.Html('<br>
'.join(sentence_list))}, step=it) 315 | print("finish saving images") 316 | 317 | for clip_val in [2.5]: 318 | std = torch.std(image_save) 319 | mean = torch.mean(image_save) 320 | upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std).cpu() # Move to CPU 321 | if args.dataset != "ImageNet": 322 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 323 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 324 | grid = torchvision.utils.make_grid(upsampled[:90], nrow=10, normalize=True, scale_each=True) 325 | wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid))}, step=it) 326 | torchvision.utils.save_image(grid, os.path.join(save_dir, "clipped_synthetic_images_{}_std_{}.png".format(it, clip_val))) 327 | 328 | 329 | if args.zca: 330 | image_save = image_save.to(args.device) 331 | image_save = args.zca_trans.inverse_transform(image_save.cpu()) # Move to CPU for ZCA transformation 332 | torch.save(image_save, os.path.join(save_dir, "images_zca_{}.pt".format(it))) 333 | 334 | upsampled = image_save 335 | if args.dataset != "ImageNet": 336 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 337 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 338 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 339 | wandb.log({"Reconstructed_Images": wandb.Image(torch.nan_to_num(grid))}, step=it) # Log GPU tensor directly 340 | wandb.log({'Reconstructed_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it) 341 | 342 | for clip_val in [2.5]: 343 | std = torch.std(image_save) 344 | mean = torch.mean(image_save) 345 | upsampled = torch.clip(image_save, min=mean - clip_val * std, max=mean + clip_val * std) 346 | if args.dataset != "ImageNet": 347 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 348 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 349 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 350 | wandb.log({"Clipped_Reconstructed_Images/std_{}".format(clip_val): wandb.Image( 351 | torch.nan_to_num(grid.detach().cpu()))}, step=it) 352 | 353 | wandb.log({"Synthetic_LR_Image": syn_lr_img.detach().cpu()}, step=it) 354 | wandb.log({"Synthetic_LR_Text": syn_lr_txt.detach().cpu()}, step=it) 355 | 356 | torch.cuda.empty_cache() 357 | student_net = CLIPModel_full(args) 358 | img_student_net = ReparamModule(student_net.image_encoder.to('cpu')).to('cuda') 359 | txt_student_net = ReparamModule(student_net.text_projection.to('cpu')).to('cuda') 360 | 361 | if args.distributed: 362 | img_student_net = torch.nn.DataParallel(img_student_net) 363 | txt_student_net = torch.nn.DataParallel(txt_student_net) 364 | 365 | img_student_net.train() 366 | txt_student_net.train() 367 | img_num_params = sum([np.prod(p.size()) for p in (img_student_net.parameters())]) 368 | txt_num_params = sum([np.prod(p.size()) for p in (txt_student_net.parameters())]) 369 | 370 | 371 | img_expert_trajectory = img_buffer[expert_idx] 372 | txt_expert_trajectory = txt_buffer[expert_idx] 373 | expert_idx += 1 374 | if expert_idx == len(img_buffer): 375 | expert_idx = 0 376 | file_idx += 1 377 | if file_idx == len(img_expert_files): 378 | file_idx = 0 379 | img_expert_files, txt_expert_files = shuffle_files(img_expert_files, txt_expert_files) 380 | print("loading file {}".format(img_expert_files[file_idx])) 381 | print("loading file {}".format(txt_expert_files[file_idx])) 382 | if 
args.max_files != 1: 383 | del img_buffer 384 | del txt_buffer 385 | img_buffer = torch.load(img_expert_files[file_idx]) 386 | txt_buffer = torch.load(txt_expert_files[file_idx]) 387 | 388 | start_epoch = np.random.randint(0, args.max_start_epoch) 389 | img_starting_params = img_expert_trajectory[start_epoch] 390 | txt_starting_params = txt_expert_trajectory[start_epoch] 391 | 392 | img_target_params = img_expert_trajectory[start_epoch+args.expert_epochs] 393 | txt_target_params = txt_expert_trajectory[start_epoch+args.expert_epochs] 394 | 395 | img_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_target_params], 0) 396 | txt_target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_target_params], 0) 397 | 398 | img_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0).requires_grad_(True)] 399 | txt_student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0).requires_grad_(True)] 400 | 401 | img_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in img_starting_params], 0) 402 | txt_starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in txt_starting_params], 0) 403 | syn_images = image_syn 404 | syn_texts = text_syn 405 | 406 | img_param_loss_list = [] 407 | txt_param_loss_list = [] 408 | 409 | img_param_dist_list = [] 410 | txt_param_dist_list = [] 411 | 412 | indices_chunks = [] 413 | for step in range(args.syn_steps): 414 | indices = torch.randperm(len(syn_images)) 415 | these_indices = indices[:args.mini_batch_size] 416 | #these_indices = indices 417 | x = syn_images[these_indices] 418 | this_y = syn_texts[these_indices] 419 | if args.distributed: 420 | img_forward_params = img_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1) 421 | txt_forward_params = txt_student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1) 422 | else: 423 | img_forward_params = img_student_params[-1] 424 | txt_forward_params = txt_student_params[-1] 425 | 426 | x = img_student_net(x, flat_param=img_forward_params) 427 | x = x / x.norm(dim=1, keepdim=True) 428 | this_y = txt_student_net(this_y, flat_param=txt_forward_params) 429 | this_y = this_y / this_y.norm(dim=1, keepdim=True) 430 | image_logits = logit_scale * x.float() @ this_y.float().t() 431 | ground_truth = torch.arange(len(image_logits)).type_as(image_logits).long() 432 | contrastive_loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2 433 | 434 | img_grad = torch.autograd.grad(contrastive_loss, img_student_params[-1], create_graph=True)[0] 435 | txt_grad = torch.autograd.grad(contrastive_loss, txt_student_params[-1], create_graph=True)[0] 436 | print(contrastive_loss) 437 | img_student_params.append(img_student_params[-1] - syn_lr_img * img_grad) 438 | txt_student_params.append(txt_student_params[-1] - syn_lr_txt * txt_grad) 439 | img_param_loss = torch.tensor(0.0).to(args.device) 440 | img_param_dist = torch.tensor(0.0).to(args.device) 441 | txt_param_loss = torch.tensor(0.0).to(args.device) 442 | txt_param_dist = torch.tensor(0.0).to(args.device) 443 | 444 | 445 | img_param_loss += torch.nn.functional.mse_loss(img_student_params[-1], img_target_params, reduction="sum") 446 | img_param_dist += torch.nn.functional.mse_loss(img_starting_params, img_target_params, reduction="sum") 447 | txt_param_loss += torch.nn.functional.mse_loss(txt_student_params[-1], txt_target_params, reduction="sum") 448 | 
txt_param_dist += torch.nn.functional.mse_loss(txt_starting_params, txt_target_params, reduction="sum") 449 | 450 | img_param_loss_list.append(img_param_loss) 451 | img_param_dist_list.append(img_param_dist) 452 | txt_param_loss_list.append(txt_param_loss) 453 | txt_param_dist_list.append(txt_param_dist) 454 | 455 | 456 | img_param_loss /= img_param_dist 457 | txt_param_loss /= txt_param_dist 458 | grand_loss = img_param_loss + txt_param_loss 459 | 460 | if math.isnan(img_param_loss): 461 | break 462 | print("img_param_loss: {}".format(img_param_loss)) 463 | print("txt_param_loss: {}".format(txt_param_loss)) 464 | 465 | optimizer_lr.zero_grad() 466 | optimizer_img.zero_grad() 467 | optimizer_txt.zero_grad() 468 | 469 | grand_loss.backward() 470 | # clip_value = 0.5 471 | 472 | #torch.nn.utils.clip_grad_norm_([image_syn], clip_value) 473 | #torch.nn.utils.clip_grad_norm_([text_syn], clip_value) 474 | #torch.nn.utils.clip_grad_norm_([syn_lr_img], clip_value) 475 | #torch.nn.utils.clip_grad_norm_([syn_lr_txt], clip_value) 476 | print("syn_lr_img: {}".format(syn_lr_img.grad)) 477 | print("syn_lr_txt: {}".format(syn_lr_txt.grad)) 478 | wandb.log({"syn_lr_img": syn_lr_img.grad.detach().cpu()}, step=it) 479 | wandb.log({"syn_lr_txt": syn_lr_txt.grad.detach().cpu()}, step=it) 480 | 481 | optimizer_lr.step() 482 | optimizer_img.step() 483 | optimizer_txt.step() 484 | 485 | wandb.log({"Grand_Loss": grand_loss.detach().cpu(), 486 | "Start_Epoch": start_epoch}) 487 | 488 | for _ in img_student_params: 489 | del _ 490 | for _ in txt_student_params: 491 | del _ 492 | 493 | if it%10 == 0: 494 | print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item())) 495 | 496 | wandb.finish() 497 | 498 | 499 | if __name__ == '__main__': 500 | parser = argparse.ArgumentParser(description='Parameter Processing') 501 | 502 | parser.add_argument('--dataset', type=str, default='flickr30k', help='dataset') 503 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 504 | 505 | parser.add_argument('--eval_mode', type=str, default='S', 506 | help='eval_mode, check utils.py for more info') 507 | 508 | parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on') 509 | 510 | parser.add_argument('--eval_it', type=int, default=50, help='how often to evaluate') 511 | 512 | parser.add_argument('--epoch_eval_train', type=int, default=50, help='epochs to train a model with synthetic data') 513 | parser.add_argument('--Iteration', type=int, default=50000, help='how many distillation steps to perform') 514 | 515 | parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images') 516 | parser.add_argument('--lr_txt', type=float, default=1000, help='learning rate for updating synthetic texts') 517 | parser.add_argument('--lr_lr', type=float, default=1e-03, help='learning rate for updating... 
learning rate') 518 | parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters') 519 | parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters') 520 | 521 | parser.add_argument('--batch_train', type=int, default=64, help='batch size for training networks') 522 | 523 | parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"], 524 | help='noise/real: initialize synthetic images from random noise or randomly sampled real images.') 525 | parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"], 526 | help='noise/real: initialize synthetic texts from random noise or randomly sampled real images.') 527 | 528 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 529 | help='whether to use differentiable Siamese augmentation.') 530 | 531 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 532 | help='differentiable Siamese augmentation strategy') 533 | 534 | parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path') 535 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 536 | 537 | parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are') 538 | parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data') 539 | parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at') 540 | 541 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 542 | 543 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM") 544 | 545 | parser.add_argument('--no_aug', type=bool, default=False, help='this turns off diff aug during distillation') 546 | 547 | parser.add_argument('--texture', action='store_true', help="will distill textures instead") 548 | parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas') 549 | parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration') 550 | 551 | 552 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)') 553 | parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)') 554 | 555 | parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc') 556 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 557 | parser.add_argument('--name', type=str, default=current_time, help='name of wandb run') 558 | parser.add_argument('--num_queries', type=int, default=100, help='number of queries') 559 | parser.add_argument('--mini_batch_size', type=int, default=100, help='number of queries') 560 | parser.add_argument('--basis', type=bool, default=False, help='whether use basis or not') 561 | parser.add_argument('--n_basis', type=int, default=64, help='n_basis') 562 | parser.add_argument('--recursive', type=bool, default=False, help='whether use basis or not') 563 | parser.add_argument('--load_npy', type=bool, default=False, help='load_npy') 564 | parser.add_argument('--image_size', type=int, default=224, help='image_size') 565 | parser.add_argument('--image_root', 
type=str, default='./Flickr30k/flickr-image-dataset/flickr30k-images/', help='location of image root') 566 | parser.add_argument('--ann_root', type=str, default='./Flickr30k/ann_file/', help='location of ann root') 567 | parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train') 568 | parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test') 569 | parser.add_argument('--image_encoder', type=str, default='nfnet', choices=['clip', 'nfnet', 'vit', 'nf_resnet50'], help='image encoder') 570 | parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip'], help='text encoder') 571 | parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained') 572 | parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained') 573 | parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable') 574 | parser.add_argument('--image_trainable', type=bool, default=True, help='image_trainable') 575 | parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None') 576 | parser.add_argument('--distill', type=bool, default=True, help='whether distill') 577 | parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='matching_train') 578 | parser.add_argument('--image_only', type=bool, default=False, help='None') 579 | parser.add_argument('--text_only', type=bool, default=False, help='None') 580 | parser.add_argument('--draw', type=bool, default=True, help='None') 581 | parser.add_argument('--transfer', type=bool, default=False, help='transfer cross architecture') 582 | parser.add_argument('--std', type=bool, default=False, help='standard deviation') 583 | args = parser.parse_args() 584 | 585 | main(args) -------------------------------------------------------------------------------- /epoch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * part of the code (i.e. def epoch_test() and itm_eval()) is from: https://github.com/salesforce/BLIP/blob/main/train_retrieval.py#L69 3 | * Copyright (c) 2022, salesforce.com, inc. 4 | * All rights reserved. 5 | * SPDX-License-Identifier: BSD-3-Clause 6 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | * By Junnan Li 8 | ''' 9 | import numpy as np 10 | import torch 11 | import time 12 | import datetime 13 | import torch 14 | import torch.nn.functional as F 15 | from tqdm import tqdm 16 | import torch.nn as nn 17 | from utils import * 18 | 19 | 20 | def epoch(e, dataloader, net, optimizer_img, optimizer_txt, args): 21 | """ 22 | Perform a training epoch on the given dataloader. 23 | 24 | Args: 25 | dataloader (torch.utils.data.DataLoader): The dataloader for iterating over the dataset. 26 | net: The model. 27 | optimizer_img: The optimizer for image parameters. 28 | optimizer_txt: The optimizer for text parameters. 29 | args (object): The arguments specifying the training configuration. 30 | 31 | Returns: 32 | Tuple of average loss and average accuracy. 
33 | """ 34 | net = net.to(args.device) 35 | net.train() 36 | loss_avg, acc_avg, num_exp = 0, 0, 0 37 | 38 | for i, data in tqdm(enumerate(dataloader)): 39 | if args.distill: 40 | image, caption = data[:2] 41 | else: 42 | image, caption, index = data[:3] 43 | 44 | image = image.to(args.device) 45 | n_b = image.shape[0] 46 | 47 | loss, acc = net(image, caption, e) 48 | 49 | loss_avg += loss.item() * n_b 50 | acc_avg += acc 51 | num_exp += n_b 52 | 53 | optimizer_img.zero_grad() 54 | optimizer_txt.zero_grad() 55 | loss.backward() 56 | optimizer_img.step() 57 | optimizer_txt.step() 58 | 59 | loss_avg /= num_exp 60 | acc_avg /= num_exp 61 | 62 | return loss_avg, acc_avg 63 | 64 | 65 | 66 | 67 | @torch.no_grad() 68 | def epoch_test(dataloader, model, device, bert_test_embed): 69 | model.eval() 70 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 71 | metric_logger = MetricLogger(delimiter=" ") 72 | header = 'Evaluation:' 73 | print('Computing features for evaluation...') 74 | start_time = time.time() 75 | 76 | 77 | txt_embed = model.text_projection(bert_test_embed.float().to('cuda')) 78 | text_embeds = txt_embed / txt_embed.norm(dim=1, keepdim=True) #torch.Size([5000, 768]) 79 | text_embeds = text_embeds.to(device) 80 | 81 | image_embeds = [] 82 | for image, img_id in dataloader: 83 | image_feat = model.image_encoder(image.to(device)) 84 | im_embed = image_feat / image_feat.norm(dim=1, keepdim=True) 85 | image_embeds.append(im_embed) 86 | image_embeds = torch.cat(image_embeds,dim=0) 87 | use_image_projection = False 88 | if use_image_projection: 89 | im_embed = model.image_projection(image_embeds.float()) 90 | image_embeds = im_embed / im_embed.norm(dim=1, keepdim=True) 91 | else: 92 | image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True) 93 | 94 | sims_matrix = logit_scale.exp() * image_embeds @ text_embeds.t() 95 | score_matrix_i2t = torch.full((len(image_embeds),len(text_embeds)),-100.0).to(device) #torch.Size([1000, 5000]) 96 | #for i, sims in enumerate(metric_logger.log_every(sims_matrix[0:sims_matrix.size(0) + 1], 50, header)): 97 | for i, sims in enumerate(sims_matrix[0:sims_matrix.size(0) + 1]): 98 | topk_sim, topk_idx = sims.topk(k=128, dim=0) 99 | score_matrix_i2t[i,topk_idx] = topk_sim #i:0-999, topk_idx:0-4999, find top k (k=128) similar text for each image 100 | 101 | sims_matrix = sims_matrix.t() 102 | score_matrix_t2i = torch.full((len(text_embeds),len(image_embeds)),-100.0).to(device) 103 | for i,sims in enumerate(sims_matrix[0:sims_matrix.size(0) + 1]): 104 | topk_sim, topk_idx = sims.topk(k=128, dim=0) 105 | score_matrix_t2i[i,topk_idx] = topk_sim 106 | 107 | total_time = time.time() - start_time 108 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 109 | print('Evaluation time {}'.format(total_time_str)) 110 | 111 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 112 | 113 | 114 | @torch.no_grad() 115 | def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): 116 | 117 | #Images->Text 118 | ranks = np.zeros(scores_i2t.shape[0]) 119 | print("TR: ", len(ranks)) 120 | for index, score in enumerate(scores_i2t): 121 | inds = np.argsort(score)[::-1] 122 | # Score 123 | rank = 1e20 124 | for i in img2txt[index]: 125 | tmp = np.where(inds == i)[0][0] 126 | if tmp < rank: 127 | rank = tmp 128 | ranks[index] = rank 129 | 130 | # Compute metrics 131 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 132 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 133 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / 
len(ranks) 134 | 135 | #Text->Images 136 | ranks = np.zeros(scores_t2i.shape[0]) 137 | print("IR: ", len(ranks)) 138 | 139 | for index,score in enumerate(scores_t2i): 140 | inds = np.argsort(score)[::-1] 141 | ranks[index] = np.where(inds == txt2img[index])[0][0] 142 | 143 | # Compute metrics 144 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 145 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 146 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 147 | 148 | tr_mean = (tr1 + tr5 + tr10) / 3 149 | ir_mean = (ir1 + ir5 + ir10) / 3 150 | r_mean = (tr_mean + ir_mean) / 2 151 | 152 | eval_result = {'txt_r1': tr1, 153 | 'txt_r5': tr5, 154 | 'txt_r10': tr10, 155 | 'txt_r_mean': tr_mean, 156 | 'img_r1': ir1, 157 | 'img_r5': ir5, 158 | 'img_r10': ir10, 159 | 'img_r_mean': ir_mean, 160 | 'r_mean': r_mean} 161 | return eval_result 162 | 163 | 164 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args, bert_test_embed, return_loss=False): 165 | 166 | net = net.to(args.device) 167 | images_train = images_train.to(args.device) 168 | labels_train = labels_train.to(args.device) 169 | lr = float(args.lr_net) 170 | Epoch = int(args.epoch_eval_train) 171 | lr_schedule = [Epoch//2+1] 172 | optimizer_img = torch.optim.SGD(net.image_encoder.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 173 | optimizer_txt = torch.optim.SGD(net.text_projection.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 174 | 175 | dst_train = TensorDataset(images_train, labels_train) 176 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 177 | 178 | start = time.time() 179 | acc_train_list = [] 180 | loss_train_list = [] 181 | 182 | for ep in tqdm(range(Epoch+1)): 183 | loss_train, acc_train = epoch(ep, trainloader, net, optimizer_img, optimizer_txt, args) 184 | acc_train_list.append(acc_train) 185 | loss_train_list.append(loss_train) 186 | if ep == Epoch: 187 | with torch.no_grad(): 188 | score_val_i2t, score_val_t2i = epoch_test(testloader, net, args.device, bert_test_embed) 189 | val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt) 190 | lr *= 0.1 191 | optimizer_img = torch.optim.SGD(net.image_encoder.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 192 | optimizer_txt = torch.optim.SGD(net.text_projection.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 193 | 194 | time_train = time.time() - start 195 | return net, acc_train_list, val_result 196 | -------------------------------------------------------------------------------- /images/clip_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/clip_0.png -------------------------------------------------------------------------------- /images/clip_2000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/clip_2000.png -------------------------------------------------------------------------------- /images/coco_syn_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/coco_syn_0.png 
-------------------------------------------------------------------------------- /images/coco_syn_2000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/coco_syn_2000.png -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/logo.png -------------------------------------------------------------------------------- /images/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/loss.png -------------------------------------------------------------------------------- /images/more_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/more_vis.png -------------------------------------------------------------------------------- /images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/pipeline.png -------------------------------------------------------------------------------- /images/synthetic_images_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/synthetic_images_0.png -------------------------------------------------------------------------------- /images/synthetic_images_2000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/synthetic_images_2000.png -------------------------------------------------------------------------------- /images/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/table.png -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/teaser.png -------------------------------------------------------------------------------- /images/text_noise_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/text_noise_0.png -------------------------------------------------------------------------------- /images/text_noise_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/text_noise_1.png -------------------------------------------------------------------------------- /images/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/images/visualization.png -------------------------------------------------------------------------------- /index.html: --------------------------------------------------------------------------------
Vision-Language Dataset Distillation

Xindi Wu (Princeton University), Byron Zhang (Princeton University), Zhiwei Deng (Google Research), Olga Russakovsky (Princeton University)

[Arxiv] [OpenReview] [Code] [Video] [Slides] [Poster]

TMLR 2024

Abstract


Dataset distillation methods offer the promise of reducing a large-scale dataset down to a significantly smaller set of (potentially synthetic) training examples, which preserve sufficient information for training a new model from scratch. So far dataset distillation methods have been developed for image classification. However, with the rise in capabilities of vision-language models, and especially given the scale of datasets necessary to train these models, the time is ripe to expand dataset distillation methods beyond image classification. In this work, we take the first steps towards this goal by expanding on the idea of trajectory matching to create a distillation method for vision-language datasets. The key challenge is that vision-language datasets do not have a set of discrete classes. To overcome this, our proposed vision-and-language dataset distillation method jointly distills the images and their corresponding language descriptions in a contrastive formulation. Since there are no existing baselines, we compare our approach to three coreset selection methods (strategic subsampling of the training dataset), which we adapt to the vision-language setting. We demonstrate significant improvements on the challenging Flickr30K and COCO retrieval benchmarks: for example, on Flickr30K the best coreset selection method which selects 1000 image-text pairs for training is able to achieve only 5.6% image-to-text retrieval accuracy (i.e., recall@1); in contrast, our dataset distillation approach almost doubles that to 9.9% with just 100 (an order of magnitude fewer) training pairs.


Bi-Trajectory-Guided Co-Distillation


Dataset distillation traditionally focuses on classification tasks with distinct labels, creating compact distilled datasets for efficient learning. We have expanded this to a multimodal approach, distilling both vision and language data and emphasizing their interrelation. Unlike simple classification, our method captures the complex connections between image and text data. It is worth noting that this would be impossible if we solely optimized a single modality, which is supported by our single-modality distillation results.

The approach consists of two stages:

  1. Obtaining the expert training trajectories \( \{\tau^*\} \), with each trajectory \( \tau^* = \{\theta^*_t\}_{t=0}^T \), by training multiple models for \( T \) epochs on the full dataset \( \mathbf{D} \). For our multimodal setting, the models are trained using a bidirectional contrastive loss (a minimal sketch of this loss follows the list).
  2. Training a set of student models on the current distilled dataset \( \hat{\mathbf{D}} \) using the same bidirectional contrastive loss, and then updating the distilled dataset \( \hat{\mathbf{D}} \) based on the multimodal trajectory matching loss between the student models' parameters and the optimal \( \theta^* \).
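Both stages rely on the same symmetric image-text objective. The following is a minimal sketch of that bidirectional contrastive loss, mirroring the computation in distill.py; the feature tensors and the temperature `logit_scale` are placeholders supplied by the caller.

```python
import torch
import torch.nn.functional as F

def bidirectional_contrastive_loss(image_features, text_features, logit_scale):
    # Normalize both modalities onto the unit hypersphere.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # Pairwise cosine similarities scaled by the temperature.
    logits_per_image = logit_scale * image_features @ text_features.t()

    # The i-th image matches the i-th caption, so the targets are the diagonal.
    targets = torch.arange(len(logits_per_image), device=logits_per_image.device)

    # Average the image-to-text and text-to-image cross-entropy terms.
    return (F.cross_entropy(logits_per_image, targets) +
            F.cross_entropy(logits_per_image.t(), targets)) / 2
```

This is the same computation distill.py applies to each synthetic mini-batch in its inner loop; averaging the two cross-entropy terms is what ties the image and text modalities together during both expert training and distillation.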


Vision-Language Bi-Trajectory Matching


Following the MTT formulation, we randomly sample \( M \) image-text pairs from \( \mathbf{D} \) to initialize the distilled dataset \( \mathbf{\hat{D}} \) (more details can be found elsewhere). We sample an expert trajectory (i.e., the trajectory of a model trained on the full dataset) \( \tau^* = \{\theta^*_t\}_{t=0}^T \) and a random starting epoch \( s \) to initialize \( \hat{\theta}_s = \theta^*_s \).
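As a rough sketch of this sampling step (loosely following the logic in distill.py, not the repository's exact code), the initialization could look like the following; `train_images`, `train_texts`, `expert_buffer`, `M`, and `max_start_epoch` are illustrative names.

```python
import random
import torch

def sample_initialization(train_images, train_texts, expert_buffer, M, max_start_epoch):
    """Initialize the distilled set from M real pairs and pick an expert
    trajectory plus a random start epoch s for one matching iteration."""
    idx = random.sample(range(len(train_images)), M)
    image_syn = train_images[idx].clone().detach().requires_grad_(True)  # distilled pixels
    text_syn = train_texts[idx].clone().detach().requires_grad_(True)    # distilled text embeddings

    trajectory = random.choice(expert_buffer)        # one expert: a list of per-epoch checkpoints
    s = random.randint(0, max_start_epoch - 1)
    # The student starts from the expert's parameters at epoch s, flattened into one vector.
    student_params = torch.cat([p.reshape(-1) for p in trajectory[s]]).requires_grad_(True)
    return image_syn, text_syn, trajectory, s, student_params
```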

We train the student model on the distilled dataset for \( \hat{R} \) steps to obtain \( \hat{\theta}_{s+\hat{R}} \). We then update the distilled dataset based on the multimodal trajectory matching loss \( \ell_{trajectory} \), computed from the accumulated difference between the student trajectory and the expert trajectory:

$$
\ell_{trajectory} = \frac{\|\hat{\theta}_{img, s+\hat{R}} - \theta^*_{img, s+R}\|_2^2}{\|\theta^*_{img, s} - \theta^*_{img, s+R}\|_2^2} + \frac{\|\hat{\theta}_{txt, s+\hat{R}} - \theta^*_{txt, s+R}\|_2^2}{\|\theta^*_{txt, s} - \theta^*_{txt, s+R}\|_2^2}.
$$

We update the distilled dataset by back-propagating through the multiple (\( \hat{R} \)) gradient-descent updates to \( \hat{\mathbf{D}} \), specifically to the image pixel space and the text embedding space. We initialize the continuous sentence embeddings with a pretrained BERT model and update the distilled text in this continuous embedding space; for the distilled images, we directly update the pixel values.
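A minimal sketch of this normalized bi-trajectory matching loss on flattened parameter vectors, following the grand-loss computation in distill.py (the construction of the parameter vectors and the surrounding optimization loop are omitted):

```python
import torch
import torch.nn.functional as F

def bi_trajectory_matching_loss(img_student, img_start, img_target,
                                txt_student, txt_start, txt_target):
    """All arguments are flattened parameter vectors (1-D tensors).

    *_student: student parameters after R_hat synthetic steps,
    *_start:   expert parameters at the sampled start epoch s,
    *_target:  expert parameters at epoch s + R.
    """
    # Squared L2 distance between student and expert end points,
    # normalized by how far the expert itself moved over the same span.
    img_loss = F.mse_loss(img_student, img_target, reduction="sum") \
             / F.mse_loss(img_start, img_target, reduction="sum")
    txt_loss = F.mse_loss(txt_student, txt_target, reduction="sum") \
             / F.mse_loss(txt_start, txt_target, reduction="sum")
    return img_loss + txt_loss
```

In distill.py this corresponds to dividing `img_param_loss` and `txt_param_loss` by their respective `param_dist` terms before summing them into `grand_loss`; the normalization keeps the image and text terms on a comparable scale.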


Results


We compare our distillation method to four coreset selection methods: random selection of training examples, herding, k-center and forgetting. We consider different selected sizes (100, 200, 500, and 1000) and report the image-to-text (TR) and text-to-image (IR) retrieval performance on the Flickr30K dataset in Table A.


We also provide an ablation study on the selection of vision (Table B) and language (Table C) backbones. We introduce the Performance Recovery Ratio (PRR) to evaluate the effectiveness of dataset distillation; it quantifies the percentage of performance retained from the original data. The performance of various backbone combinations is shown in Table D.
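As a hedged illustration, assuming PRR is the ratio of the retrieval metric obtained with the distilled data to the metric obtained with the full dataset, expressed as a percentage:

```python
def performance_recovery_ratio(distilled_metric: float, full_data_metric: float) -> float:
    """Percentage of the full-data performance recovered by training on the distilled set."""
    return 100.0 * distilled_metric / full_data_metric

# Hypothetical example: 9.9 (R@1 with 100 distilled pairs) against an assumed
# full-data R@1 of 39.6 would give a PRR of 25.0.
```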


Visualization


Left: the image and text pairs before distillation. Right: the image and text pairs after 2000 distillation steps. Note that the texts visualized here are the nearest sentence decodings from the training set corresponding to the distilled text embeddings.
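A minimal sketch of that nearest-sentence decoding, assuming the distilled text embeddings and precomputed training-caption embeddings (e.g. from the same BERT encoder) share one embedding space; the function and argument names are illustrative, not part of the repository's API.

```python
import torch

def decode_distilled_texts(distilled_txt_emb, train_sent_emb, train_sentences):
    """Map each distilled text embedding to its closest training caption.

    distilled_txt_emb: (M, d) distilled embeddings being visualized
    train_sent_emb:    (N, d) embeddings of all training captions
    train_sentences:   list of N caption strings
    """
    a = distilled_txt_emb / distilled_txt_emb.norm(dim=1, keepdim=True)
    b = train_sent_emb / train_sent_emb.norm(dim=1, keepdim=True)
    nearest = (a @ b.t()).argmax(dim=1)          # cosine-similarity argmax
    return [train_sentences[i] for i in nearest.tolist()]
```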


Here we include a number of visualizations of the data distilled from the multimodal datasets (both Flickr30K and COCO) for a more intuitive understanding of the distilled set. We provide 50 distilled image-text pairs, together with their visualization before the distillation process. These experiments use 100 distilled pairs, with pretrained NFNet and BERT as backbones, and the number of synthetic steps is set to 8 during distillation.


Conclusion


In this work, we propose a multimodal dataset distillation method for the image-text retrieval task. By co-distilling both the vision and language modalities, we can progressively optimize and distill the most critical information. Our experiments show that co-distilling different modalities via trajectory matching holds promise. We hope the insights we gathered can serve as a roadmap for future studies exploring more complex settings, and that our work lays the groundwork for future research on the minimum information a vision-language model needs to quickly reach comparable performance, thereby building a better understanding of the compositionality of compact visual-linguistic knowledge.


Paper

227 | 228 | 229 | 239 | 240 | 241 | 242 | 252 | 253 | 256 | 257 | 258 |
Xindi Wu, Byron Zhang, Zhiwei Deng, Olga Russakovsky.
230 | Vision-Language Dataset Distillation
231 | TMLR 2024.
232 | [Arxiv]     233 | [OpenReview]     234 | [Code]     235 | [Video]     236 | [Slides]     237 | [Poster] 238 |

Acknowledgements

This material is based upon work supported by the National Science Foundation under Grant No. 2107048. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. We thank many people from Princeton Visual AI lab (Allison Chen, Jihoon Chung, Tyler Zhu, Ye Zhu, William Yang and Kaiqu Liang) and Princeton NLP group (Carlos E. Jimenez, John Yang), Tiffany Ling and George Cazenavette for their helpful feedback on this work.

278 | 279 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Sourced directly from OpenAI's CLIP repo 2 | from collections import OrderedDict 3 | from typing import Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | expansion = 4 13 | 14 | def __init__(self, inplanes, planes, stride=1): 15 | super().__init__() 16 | 17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 25 | 26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 28 | 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = None 31 | self.stride = stride 32 | 33 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 35 | self.downsample = nn.Sequential(OrderedDict([ 36 | ("-1", nn.AvgPool2d(stride)), 37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 38 | ("1", nn.BatchNorm2d(planes * self.expansion)) 39 | ])) 40 | 41 | def forward(self, x: torch.Tensor): 42 | identity = x 43 | 44 | out = self.relu(self.bn1(self.conv1(x))) 45 | out = self.relu(self.bn2(self.conv2(out))) 46 | out = self.avgpool(out) 47 | out = self.bn3(self.conv3(out)) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | out = self.relu(out) 54 | return out 55 | 56 | 57 | class AttentionPool2d(nn.Module): 58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 59 | super().__init__() 60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 61 | self.k_proj = nn.Linear(embed_dim, embed_dim) 62 | self.q_proj = nn.Linear(embed_dim, embed_dim) 63 | self.v_proj = nn.Linear(embed_dim, embed_dim) 64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 65 | self.num_heads = num_heads 66 | 67 | def forward(self, x): 68 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 71 | x, _ = F.multi_head_attention_forward( 72 | query=x, key=x, value=x, 73 | embed_dim_to_check=x.shape[-1], 74 | num_heads=self.num_heads, 75 | q_proj_weight=self.q_proj.weight, 76 | k_proj_weight=self.k_proj.weight, 77 | v_proj_weight=self.v_proj.weight, 78 | in_proj_weight=None, 79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 80 | bias_k=None, 81 | bias_v=None, 82 | add_zero_attn=False, 83 | dropout_p=0, 84 | out_proj_weight=self.c_proj.weight, 85 | out_proj_bias=self.c_proj.bias, 86 | use_separate_proj_weight=True, 87 | training=self.training, 88 | need_weights=False 89 | ) 90 | 91 | return x[0] 92 | 93 | 94 | 95 | class LayerNorm(nn.LayerNorm): 96 | """Subclass torch's LayerNorm to handle fp16.""" 97 | 98 | def 
forward(self, x: torch.Tensor): 99 | orig_type = x.dtype 100 | ret = super().forward(x.type(torch.float32)) 101 | return ret.type(orig_type) 102 | 103 | 104 | class QuickGELU(nn.Module): 105 | def forward(self, x: torch.Tensor): 106 | return x * torch.sigmoid(1.702 * x) 107 | 108 | 109 | class ResidualAttentionBlock(nn.Module): 110 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 111 | super().__init__() 112 | 113 | self.attn = nn.MultiheadAttention(d_model, n_head) 114 | self.ln_1 = LayerNorm(d_model) 115 | self.mlp = nn.Sequential(OrderedDict([ 116 | ("c_fc", nn.Linear(d_model, d_model * 4)), 117 | ("gelu", QuickGELU()), 118 | ("c_proj", nn.Linear(d_model * 4, d_model)) 119 | ])) 120 | self.ln_2 = LayerNorm(d_model) 121 | self.attn_mask = attn_mask 122 | 123 | def attention(self, x: torch.Tensor): 124 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 125 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 126 | 127 | def forward(self, x: torch.Tensor): 128 | x = x + self.attention(self.ln_1(x)) 129 | x = x + self.mlp(self.ln_2(x)) 130 | return x 131 | 132 | 133 | 134 | def convert_weights(model: nn.Module): 135 | """Convert applicable model parameters to fp16""" 136 | 137 | def _convert_weights_to_fp16(l): 138 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 139 | l.weight.data = l.weight.data.half() 140 | if l.bias is not None: 141 | l.bias.data = l.bias.data.half() 142 | 143 | if isinstance(l, nn.MultiheadAttention): 144 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 145 | tensor = getattr(l, attr) 146 | if tensor is not None: 147 | tensor.data = tensor.data.half() 148 | 149 | for name in ["text_projection", "proj"]: 150 | if hasattr(l, name): 151 | attr = getattr(l, name) 152 | if attr is not None: 153 | attr.data = attr.data.half() 154 | 155 | model.apply(_convert_weights_to_fp16) 156 | 157 | 158 | def build_model(state_dict: dict): 159 | vit = "visual.proj" in state_dict 160 | 161 | if vit: 162 | vision_width = state_dict["visual.conv1.weight"].shape[0] 163 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 164 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 165 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 166 | image_resolution = vision_patch_size * grid_size 167 | else: 168 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 169 | vision_layers = tuple(counts) 170 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 171 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 172 | vision_patch_size = None 173 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 174 | image_resolution = output_width * 32 175 | 176 | embed_dim = state_dict["text_projection"].shape[1] 177 | context_length = state_dict["positional_embedding"].shape[0] 178 | vocab_size = state_dict["token_embedding.weight"].shape[0] 179 | transformer_width = state_dict["ln_final.weight"].shape[0] 180 | transformer_heads = transformer_width // 64 181 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 182 | 183 | model = CLIP( 184 | embed_dim, 185 | image_resolution, vision_layers, 
vision_width, vision_patch_size, 186 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 187 | ) 188 | 189 | for key in ["input_resolution", "context_length", "vocab_size"]: 190 | if key in state_dict: 191 | del state_dict[key] 192 | 193 | convert_weights(model) 194 | model.load_state_dict(state_dict) 195 | return model.eval() -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from collections import OrderedDict 5 | from typing import Tuple, Union 6 | import clip 7 | from transformers import ViTConfig, ViTModel, AutoTokenizer, CLIPTextModel, CLIPTextConfig, CLIPProcessor, CLIPConfig 8 | import numpy as np 9 | from transformers import BertTokenizer, BertModel 10 | from torchvision.models import resnet18, resnet 11 | from transformers.models.bert.modeling_bert import BertAttention, BertConfig 12 | 13 | tokenizer=BertTokenizer.from_pretrained('bert-base-uncased') 14 | BERT_model = BertModel.from_pretrained('bert-base-uncased') 15 | 16 | 17 | # Acknowledgement to 18 | # https://github.com/kuangliu/pytorch-cifar, 19 | # https://github.com/BIGBALLON/CIFAR-ZOO, 20 | 21 | # adapted from 22 | # https://github.com/VICO-UoE/DatasetCondensation 23 | # https://github.com/Zasder3/train-CLIP 24 | 25 | 26 | ''' MLP ''' 27 | class MLP(nn.Module): 28 | def __init__(self, channel, num_classes): 29 | super(MLP, self).__init__() 30 | self.fc_1 = nn.Linear(28*28*1 if channel==1 else 32*32*3, 128) 31 | self.fc_2 = nn.Linear(128, 128) 32 | self.fc_3 = nn.Linear(128, num_classes) 33 | 34 | def forward(self, x): 35 | out = x.view(x.size(0), -1) 36 | out = F.relu(self.fc_1(out)) 37 | out = F.relu(self.fc_2(out)) 38 | out = self.fc_3(out) 39 | return out 40 | 41 | 42 | 43 | ''' ConvNet ''' 44 | class ConvNet(nn.Module): 45 | def __init__(self, channel, num_classes, net_width=128, net_depth=4, net_act='relu', net_norm='instancenorm', net_pooling='avgpooling', im_size = (224,224)): 46 | super(ConvNet, self).__init__() 47 | 48 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 49 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 50 | self.classifier = nn.Linear(num_feat, num_classes) 51 | 52 | def forward(self, x): 53 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 54 | out = self.features(x) 55 | out = out.view(out.size(0), -1) 56 | out = self.classifier(out) 57 | return out 58 | 59 | def _get_activation(self, net_act): 60 | if net_act == 'sigmoid': 61 | return nn.Sigmoid() 62 | elif net_act == 'relu': 63 | return nn.ReLU(inplace=True) 64 | elif net_act == 'leakyrelu': 65 | return nn.LeakyReLU(negative_slope=0.01) 66 | else: 67 | exit('unknown activation function: %s'%net_act) 68 | 69 | def _get_pooling(self, net_pooling): 70 | if net_pooling == 'maxpooling': 71 | return nn.MaxPool2d(kernel_size=2, stride=2) 72 | elif net_pooling == 'avgpooling': 73 | return nn.AvgPool2d(kernel_size=2, stride=2) 74 | elif net_pooling == 'none': 75 | return None 76 | else: 77 | exit('unknown net_pooling: %s'%net_pooling) 78 | 79 | def _get_normlayer(self, net_norm, shape_feat): 80 | # shape_feat = (c*h*w) 81 | if net_norm == 'batchnorm': 82 | return nn.BatchNorm2d(shape_feat[0], affine=True) 83 | elif net_norm == 'layernorm': 84 | return nn.LayerNorm(shape_feat, 
elementwise_affine=True) 85 | elif net_norm == 'instancenorm': 86 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 87 | elif net_norm == 'groupnorm': 88 | return nn.GroupNorm(4, shape_feat[0], affine=True) 89 | elif net_norm == 'none': 90 | return None 91 | else: 92 | exit('unknown net_norm: %s'%net_norm) 93 | 94 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 95 | layers = [] 96 | in_channels = channel 97 | if im_size[0] == 28: 98 | im_size = (32, 32) 99 | shape_feat = [in_channels, im_size[0], im_size[1]] 100 | for d in range(net_depth): 101 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 102 | shape_feat[0] = net_width 103 | if net_norm != 'none': 104 | layers += [self._get_normlayer(net_norm, shape_feat)] 105 | layers += [self._get_activation(net_act)] 106 | in_channels = net_width 107 | if net_pooling != 'none': 108 | layers += [self._get_pooling(net_pooling)] 109 | shape_feat[1] //= 2 110 | shape_feat[2] //= 2 111 | 112 | 113 | return nn.Sequential(*layers), shape_feat 114 | 115 | 116 | ''' ConvNet ''' 117 | class ConvNetGAP(nn.Module): 118 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)): 119 | super(ConvNetGAP, self).__init__() 120 | 121 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 122 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 123 | # self.classifier = nn.Linear(num_feat, num_classes) 124 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 125 | self.classifier = nn.Linear(shape_feat[0], num_classes) 126 | 127 | def forward(self, x): 128 | out = self.features(x) 129 | out = self.avgpool(out) 130 | out = out.view(out.size(0), -1) 131 | out = self.classifier(out) 132 | return out 133 | 134 | def _get_activation(self, net_act): 135 | if net_act == 'sigmoid': 136 | return nn.Sigmoid() 137 | elif net_act == 'relu': 138 | return nn.ReLU(inplace=True) 139 | elif net_act == 'leakyrelu': 140 | return nn.LeakyReLU(negative_slope=0.01) 141 | else: 142 | exit('unknown activation function: %s'%net_act) 143 | 144 | def _get_pooling(self, net_pooling): 145 | if net_pooling == 'maxpooling': 146 | return nn.MaxPool2d(kernel_size=2, stride=2) 147 | elif net_pooling == 'avgpooling': 148 | return nn.AvgPool2d(kernel_size=2, stride=2) 149 | elif net_pooling == 'none': 150 | return None 151 | else: 152 | exit('unknown net_pooling: %s'%net_pooling) 153 | 154 | def _get_normlayer(self, net_norm, shape_feat): 155 | # shape_feat = (c*h*w) 156 | if net_norm == 'batchnorm': 157 | return nn.BatchNorm2d(shape_feat[0], affine=True) 158 | elif net_norm == 'layernorm': 159 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 160 | elif net_norm == 'instancenorm': 161 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 162 | elif net_norm == 'groupnorm': 163 | return nn.GroupNorm(4, shape_feat[0], affine=True) 164 | elif net_norm == 'none': 165 | return None 166 | else: 167 | exit('unknown net_norm: %s'%net_norm) 168 | 169 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 170 | layers = [] 171 | in_channels = channel 172 | if im_size[0] == 28: 173 | im_size = (32, 32) 174 | shape_feat = [in_channels, im_size[0], im_size[1]] 175 | for d in range(net_depth): 176 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 177 | 
shape_feat[0] = net_width 178 | if net_norm != 'none': 179 | layers += [self._get_normlayer(net_norm, shape_feat)] 180 | layers += [self._get_activation(net_act)] 181 | in_channels = net_width 182 | if net_pooling != 'none': 183 | layers += [self._get_pooling(net_pooling)] 184 | shape_feat[1] //= 2 185 | shape_feat[2] //= 2 186 | 187 | return nn.Sequential(*layers), shape_feat 188 | 189 | 190 | ''' LeNet ''' 191 | class LeNet(nn.Module): 192 | def __init__(self, channel, num_classes): 193 | super(LeNet, self).__init__() 194 | self.features = nn.Sequential( 195 | nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0), 196 | nn.ReLU(inplace=True), 197 | nn.MaxPool2d(kernel_size=2, stride=2), 198 | nn.Conv2d(6, 16, kernel_size=5), 199 | nn.ReLU(inplace=True), 200 | nn.MaxPool2d(kernel_size=2, stride=2), 201 | ) 202 | self.fc_1 = nn.Linear(16 * 5 * 5, 120) 203 | self.fc_2 = nn.Linear(120, 84) 204 | self.fc_3 = nn.Linear(84, num_classes) 205 | 206 | def forward(self, x): 207 | x = self.features(x) 208 | x = x.view(x.size(0), -1) 209 | x = F.relu(self.fc_1(x)) 210 | x = F.relu(self.fc_2(x)) 211 | x = self.fc_3(x) 212 | return x 213 | 214 | 215 | 216 | ''' AlexNet ''' 217 | class AlexNet(nn.Module): 218 | def __init__(self, channel, num_classes): 219 | super(AlexNet, self).__init__() 220 | self.features = nn.Sequential( 221 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 222 | nn.ReLU(inplace=True), 223 | nn.MaxPool2d(kernel_size=2, stride=2), 224 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 225 | nn.ReLU(inplace=True), 226 | nn.MaxPool2d(kernel_size=2, stride=2), 227 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 228 | nn.ReLU(inplace=True), 229 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 230 | nn.ReLU(inplace=True), 231 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 232 | nn.ReLU(inplace=True), 233 | nn.MaxPool2d(kernel_size=2, stride=2), 234 | ) 235 | self.fc = nn.Linear(192 * 4 * 4, num_classes) 236 | 237 | def forward(self, x): 238 | x = self.features(x) 239 | x = x.view(x.size(0), -1) 240 | x = self.fc(x) 241 | return x 242 | 243 | 244 | 245 | ''' VGG ''' 246 | cfg_vgg = { 247 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 248 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 249 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 250 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 251 | } 252 | class VGG(nn.Module): 253 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'): 254 | super(VGG, self).__init__() 255 | self.channel = channel 256 | self.features = self._make_layers(cfg_vgg[vgg_name], norm) 257 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes) 258 | 259 | def forward(self, x): 260 | x = self.features(x) 261 | x = x.view(x.size(0), -1) 262 | x = self.classifier(x) 263 | return x 264 | 265 | def _make_layers(self, cfg, norm): 266 | layers = [] 267 | in_channels = self.channel 268 | for ic, x in enumerate(cfg): 269 | if x == 'M': 270 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 271 | else: 272 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 273 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x), 274 | nn.ReLU(inplace=True)] 275 | in_channels = x 276 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 277 
| return nn.Sequential(*layers) 278 | 279 | 280 | def VGG11(channel, num_classes): 281 | return VGG('VGG11', channel, num_classes) 282 | def VGG11BN(channel, num_classes): 283 | return VGG('VGG11', channel, num_classes, norm='batchnorm') 284 | def VGG13(channel, num_classes): 285 | return VGG('VGG13', channel, num_classes) 286 | def VGG16(channel, num_classes): 287 | return VGG('VGG16', channel, num_classes) 288 | def VGG19(channel, num_classes): 289 | return VGG('VGG19', channel, num_classes) 290 | 291 | 292 | ''' ResNet_AP ''' 293 | # The conv(stride=2) is replaced by conv(stride=1) + avgpool(kernel_size=2, stride=2) 294 | 295 | class BasicBlock_AP(nn.Module): 296 | expansion = 1 297 | 298 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 299 | super(BasicBlock_AP, self).__init__() 300 | self.norm = norm 301 | self.stride = stride 302 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 303 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 304 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 305 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 306 | 307 | self.shortcut = nn.Sequential() 308 | if stride != 1 or in_planes != self.expansion * planes: 309 | self.shortcut = nn.Sequential( 310 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 311 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 312 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 313 | ) 314 | 315 | def forward(self, x): 316 | out = F.relu(self.bn1(self.conv1(x))) 317 | if self.stride != 1: # modification 318 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 319 | out = self.bn2(self.conv2(out)) 320 | out += self.shortcut(x) 321 | out = F.relu(out) 322 | return out 323 | 324 | 325 | class Bottleneck_AP(nn.Module): 326 | expansion = 4 327 | 328 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 329 | super(Bottleneck_AP, self).__init__() 330 | self.norm = norm 331 | self.stride = stride 332 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 333 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 334 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 335 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 336 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 337 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 338 | 339 | self.shortcut = nn.Sequential() 340 | if stride != 1 or in_planes != self.expansion * planes: 341 | self.shortcut = nn.Sequential( 342 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 343 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 344 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 345 | ) 346 | 347 | def forward(self, x): 348 | out = F.relu(self.bn1(self.conv1(x))) 349 | out = 
F.relu(self.bn2(self.conv2(out))) 350 | if self.stride != 1: # modification 351 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 352 | out = self.bn3(self.conv3(out)) 353 | out += self.shortcut(x) 354 | out = F.relu(out) 355 | return out 356 | 357 | 358 | class ResNet_AP(nn.Module): 359 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 360 | super(ResNet_AP, self).__init__() 361 | self.in_planes = 64 362 | self.norm = norm 363 | 364 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 365 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 366 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 367 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 368 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 369 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 370 | self.classifier = nn.Linear(512 * block.expansion * 3 * 3 if channel==1 else 512 * block.expansion * 4 * 4, num_classes) # modification 371 | 372 | def _make_layer(self, block, planes, num_blocks, stride): 373 | strides = [stride] + [1] * (num_blocks - 1) 374 | layers = [] 375 | for stride in strides: 376 | layers.append(block(self.in_planes, planes, stride, self.norm)) 377 | self.in_planes = planes * block.expansion 378 | return nn.Sequential(*layers) 379 | 380 | def forward(self, x): 381 | out = F.relu(self.bn1(self.conv1(x))) 382 | out = self.layer1(out) 383 | out = self.layer2(out) 384 | out = self.layer3(out) 385 | out = self.layer4(out) 386 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification 387 | out = out.view(out.size(0), -1) 388 | out = self.classifier(out) 389 | return out 390 | 391 | 392 | def ResNet18BN_AP(channel, num_classes): 393 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 394 | 395 | def ResNet18_AP(channel, num_classes): 396 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes) 397 | 398 | 399 | ''' ResNet ''' 400 | 401 | class BasicBlock(nn.Module): 402 | expansion = 1 403 | 404 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 405 | super(BasicBlock, self).__init__() 406 | self.norm = norm 407 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 408 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 409 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 410 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 411 | 412 | self.shortcut = nn.Sequential() 413 | if stride != 1 or in_planes != self.expansion*planes: 414 | self.shortcut = nn.Sequential( 415 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 416 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 417 | ) 418 | 419 | def forward(self, x): 420 | out = F.relu(self.bn1(self.conv1(x))) 421 | out = self.bn2(self.conv2(out)) 422 | out += self.shortcut(x) 423 | out = F.relu(out) 424 | return out 425 | 426 | 427 | class Bottleneck(nn.Module): 428 | expansion = 4 429 | 430 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 431 | super(Bottleneck, self).__init__() 432 | self.norm = norm 
433 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 434 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 435 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 436 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 437 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 438 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 439 | 440 | self.shortcut = nn.Sequential() 441 | if stride != 1 or in_planes != self.expansion*planes: 442 | self.shortcut = nn.Sequential( 443 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 444 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 445 | ) 446 | 447 | def forward(self, x): 448 | out = F.relu(self.bn1(self.conv1(x))) 449 | out = F.relu(self.bn2(self.conv2(out))) 450 | out = self.bn3(self.conv3(out)) 451 | out += self.shortcut(x) 452 | out = F.relu(out) 453 | return out 454 | 455 | 456 | class ResNetImageNet(nn.Module): 457 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 458 | super(ResNetImageNet, self).__init__() 459 | self.in_planes = 64 460 | self.norm = norm 461 | 462 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 463 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 464 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 465 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 466 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 467 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 468 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 469 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 470 | self.classifier = nn.Linear(512*block.expansion, num_classes) 471 | 472 | def _make_layer(self, block, planes, num_blocks, stride): 473 | strides = [stride] + [1]*(num_blocks-1) 474 | layers = [] 475 | for stride in strides: 476 | layers.append(block(self.in_planes, planes, stride, self.norm)) 477 | self.in_planes = planes * block.expansion 478 | return nn.Sequential(*layers) 479 | 480 | def forward(self, x): 481 | out = F.relu(self.bn1(self.conv1(x))) 482 | out = self.maxpool(out) 483 | out = self.layer1(out) 484 | out = self.layer2(out) 485 | out = self.layer3(out) 486 | out = self.layer4(out) 487 | # out = F.avg_pool2d(out, 4) 488 | # out = out.view(out.size(0), -1) 489 | out = self.avgpool(out) 490 | out = torch.flatten(out, 1) 491 | out = self.classifier(out) 492 | return out 493 | 494 | 495 | def ResNet18BN(channel, num_classes): 496 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 497 | 498 | def ResNet18(channel, num_classes): 499 | return ResNet_gn(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 500 | 501 | def ResNet34(channel, num_classes): 502 | return ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes) 503 | 504 | def ResNet50(channel, num_classes): 505 | return ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes) 506 | 507 | def ResNet101(channel, num_classes): 508 | 
return ResNet(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes) 509 | 510 | def ResNet152(channel, num_classes): 511 | return ResNet(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes) 512 | 513 | def ResNet18ImageNet(channel, num_classes): 514 | return ResNetImageNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 515 | 516 | def ResNet6ImageNet(channel, num_classes): 517 | return ResNetImageNet(BasicBlock, [1,1,1,1], channel=channel, num_classes=num_classes) 518 | 519 | def resnet18_gn(pretrained=False, **kwargs): 520 | """Constructs a ResNet-18 model. 521 | """ 522 | model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) 523 | return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs)) 524 | 525 | 526 | ## Sourced directly from OpenAI's CLIP repo 527 | class ModifiedResNet(nn.Module): 528 | """ 529 | A ResNet class that is similar to torchvision's but contains the following changes: 530 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 531 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 532 | - The final pooling layer is a QKV attention instead of an average pool 533 | """ 534 | 535 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 536 | super().__init__() 537 | self.output_dim = output_dim 538 | self.input_resolution = input_resolution 539 | 540 | # the 3-layer stem 541 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 542 | self.bn1 = nn.BatchNorm2d(width // 2) 543 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 544 | self.bn2 = nn.BatchNorm2d(width // 2) 545 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 546 | self.bn3 = nn.BatchNorm2d(width) 547 | self.avgpool = nn.AvgPool2d(2) 548 | self.relu = nn.ReLU(inplace=True) 549 | 550 | # residual layers 551 | self._inplanes = width # this is a *mutable* variable used during construction 552 | self.layer1 = self._make_layer(width, layers[0]) 553 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 554 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 555 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 556 | 557 | embed_dim = width * 32 # the ResNet feature dimension 558 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 559 | 560 | def _make_layer(self, planes, blocks, stride=1): 561 | layers = [Bottleneck(self._inplanes, planes, stride)] 562 | 563 | self._inplanes = planes * Bottleneck.expansion 564 | for _ in range(1, blocks): 565 | layers.append(Bottleneck(self._inplanes, planes)) 566 | 567 | return nn.Sequential(*layers) 568 | 569 | def forward(self, x): 570 | def stem(x): 571 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 572 | x = self.relu(bn(conv(x))) 573 | x = self.avgpool(x) 574 | return x 575 | 576 | x = x.type(self.conv1.weight.dtype) 577 | x = stem(x) 578 | x = self.layer1(x) 579 | x = self.layer2(x) 580 | x = self.layer3(x) 581 | x = self.layer4(x) 582 | x = self.attnpool(x) 583 | 584 | return x 585 | 586 | 587 | class AttentionPool2d(nn.Module): 588 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 589 | super().__init__() 590 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 591 | 
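# Separate query/key/value and output projections for the QKV attention-pooling head
# (ModifiedResNet.attnpool uses this in place of a global average pool, per the class docstring above).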
self.k_proj = nn.Linear(embed_dim, embed_dim) 592 | self.q_proj = nn.Linear(embed_dim, embed_dim) 593 | self.v_proj = nn.Linear(embed_dim, embed_dim) 594 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 595 | self.num_heads = num_heads 596 | 597 | def forward(self, x): 598 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 599 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 600 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 601 | x, _ = F.multi_head_attention_forward( 602 | query=x, key=x, value=x, 603 | embed_dim_to_check=x.shape[-1], 604 | num_heads=self.num_heads, 605 | q_proj_weight=self.q_proj.weight, 606 | k_proj_weight=self.k_proj.weight, 607 | v_proj_weight=self.v_proj.weight, 608 | in_proj_weight=None, 609 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 610 | bias_k=None, 611 | bias_v=None, 612 | add_zero_attn=False, 613 | dropout_p=0, 614 | out_proj_weight=self.c_proj.weight, 615 | out_proj_bias=self.c_proj.bias, 616 | use_separate_proj_weight=True, 617 | training=self.training, 618 | need_weights=False 619 | ) 620 | 621 | return x[0] 622 | 623 | import timm 624 | 625 | class ProjectionHead(nn.Module): 626 | def __init__( 627 | self, 628 | embedding_dim, 629 | projection_dim=768, 630 | dropout=0.1 631 | ): 632 | super().__init__() 633 | self.projection = nn.Linear(embedding_dim, projection_dim) 634 | self.gelu = nn.GELU() 635 | self.fc = nn.Linear(projection_dim, projection_dim) 636 | self.dropout = nn.Dropout(dropout) 637 | self.layer_norm = nn.LayerNorm(projection_dim) 638 | 639 | def forward(self, x): 640 | projected = self.projection(x) 641 | x = self.gelu(projected) 642 | x = self.fc(x) 643 | x = self.dropout(x) 644 | x = x + projected 645 | x = self.layer_norm(x) 646 | return x 647 | 648 | class ImageEncoder(nn.Module): 649 | """ 650 | Encode images to a fixed size vector 651 | """ 652 | 653 | def __init__(self, args, eval_stage): 654 | super().__init__() 655 | self.model_name = args.image_encoder 656 | self.pretrained = args.image_pretrained 657 | self.trainable = args.image_trainable 658 | 659 | if self.model_name == 'clip': 660 | if self.pretrained: 661 | self.model, preprocess = clip.load("ViT-B/32", device='cuda') 662 | else: 663 | configuration = ViTConfig() 664 | self.model = ViTModel(configuration) 665 | elif self.model_name == 'nfnet': 666 | self.model = timm.create_model('nfnet_l0', pretrained=self.pretrained, num_classes=0, global_pool="avg") 667 | elif self.model_name == 'vit': 668 | self.model = timm.create_model('vit_tiny_patch16_224', pretrained=True) 669 | elif self.model_name == 'nf_resnet50': 670 | self.model = timm.create_model('nf_resnet50', pretrained=True) 671 | elif self.model_name == 'nf_regnet': 672 | self.model = timm.create_model('nf_regnet_b1', pretrained=True) 673 | else: 674 | self.model = timm.create_model(self.model_name, self.pretrained, num_classes=0, global_pool="avg") 675 | for p in self.model.parameters(): 676 | p.requires_grad = self.trainable 677 | 678 | def forward(self, x): 679 | if self.model_name == 'clip' and self.pretrained: 680 | return self.model.encode_image(x) 681 | else: 682 | return self.model(x) 683 | 684 | def gradient(self, x, y): 685 | # Compute the gradient of the mean squared error loss with respect to the weights 686 | loss = self.loss(x, y) 687 | grad = torch.autograd.grad(loss, self.parameters(), create_graph=True) 688 | return torch.cat([g.view(-1) for g in grad]) 689 | 690 | 
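# Note on ImageEncoder.gradient above: it flattens torch.autograd.grad of `self.loss(x, y)`
# over all encoder parameters into a single vector; as written, this class does not define
# `loss`, so the method appears to assume a loss function is attached to the encoder elsewhere.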
691 | 692 | 693 | class TextEncoder(nn.Module): 694 | def __init__(self, args): 695 | super().__init__() 696 | self.pretrained = args.text_pretrained 697 | self.trainable = args.text_trainable 698 | self.model_name = args.text_encoder 699 | if self.model_name == 'clip': 700 | self.model, preprocess = clip.load("ViT-B/32", device='cuda') 701 | elif self.model_name == 'bert': 702 | if args.text_pretrained: 703 | self.model = BERT_model 704 | else: 705 | self.model = BertModel(BertConfig()) 706 | self.model.init_weights() 707 | self.tokenizer = tokenizer 708 | else: 709 | raise NotImplementedError 710 | 711 | for p in self.model.parameters(): 712 | p.requires_grad = self.trainable 713 | 714 | # we are using the CLS token hidden representation as the sentence's embedding 715 | self.target_token_idx = 0 716 | 717 | def forward(self, texts, device='cuda'): 718 | if self.model_name == 'clip': 719 | output = self.model.encode_text(clip.tokenize(texts).to('cuda')) 720 | 721 | elif self.model_name == 'bert': 722 | # Tokenize the input text 723 | encoding = self.tokenizer.batch_encode_plus(texts, return_tensors='pt', padding=True, truncation=True) 724 | input_ids = encoding['input_ids'].to(device) 725 | attention_mask = encoding['attention_mask'].to(device) 726 | output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state[:, self.target_token_idx, :] 727 | return output 728 | 729 | 730 | 731 | class CLIPModel_full(nn.Module): 732 | def __init__( 733 | self, 734 | args, 735 | temperature=1.0, 736 | eval_stage=False 737 | ): 738 | super().__init__() 739 | 740 | if args.image_encoder == 'nfnet': 741 | if eval_stage: 742 | self.image_embedding = 1000#2048 743 | else: 744 | self.image_embedding = 2304 745 | elif args.image_encoder == 'convnet': 746 | self.image_embedding = 768 747 | elif args.image_encoder == 'resnet18': 748 | self.image_embedding = 512 749 | elif args.image_encoder == 'convnext': 750 | self.image_embedding = 640 751 | else: 752 | self.image_embedding = 1000 753 | if args.text_encoder == 'clip': 754 | self.text_embedding = 512 755 | elif args.text_encoder == 'bert': 756 | self.text_embedding = 768 757 | else: 758 | raise NotImplementedError 759 | 760 | self.image_encoder = ImageEncoder(args, eval_stage=eval_stage) 761 | self.text_encoder = TextEncoder(args) 762 | 763 | if args.only_has_image_projection: 764 | self.image_projection = ProjectionHead(embedding_dim=self.image_embedding) 765 | self.text_projection = ProjectionHead(embedding_dim=self.text_embedding, projection_dim=self.image_embedding).to('cuda') 766 | self.temperature = temperature 767 | #self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 768 | self.args = args 769 | self.distill = args.distill 770 | 771 | def forward(self, image, caption, epoch): 772 | self.image_encoder = self.image_encoder.to('cuda') 773 | self.text_encoder = self.text_encoder.to('cuda') 774 | 775 | image_features = self.image_encoder(image) 776 | text_features = caption if self.distill else self.text_encoder(caption) 777 | 778 | use_image_project = False 779 | im_embed = image_features.float() if not use_image_project else self.image_projection(image_features.float()) 780 | txt_embed = self.text_projection(text_features.float()) 781 | 782 | combined_image_features = im_embed 783 | combined_text_features = txt_embed 784 | image_features = combined_image_features / combined_image_features.norm(dim=1, keepdim=True) 785 | text_features = combined_text_features / combined_text_features.norm(dim=1, keepdim=True) 786 | 787 | 
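# Symmetric CLIP-style contrastive objective below: the cosine-similarity logits (both feature
# sets are L2-normalized above) are scaled by a fixed logit scale of 1/0.07 (np.exp(np.log(1/0.07))
# simply evaluates to ~14.29), and cross-entropy is taken along both the image-to-text and
# text-to-image axes with the diagonal (matched pairs) as ground truth; the two directions are averaged.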
788 | image_logits = np.exp(np.log(1 / 0.07)) * image_features @ text_features.t() 789 | ground_truth = torch.arange(len(image_logits)).type_as(image_logits).long() 790 | loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2 791 | acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum().item() 792 | acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum().item() 793 | acc = (acc_i + acc_t) / 2 794 | return loss, acc -------------------------------------------------------------------------------- /reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import types 5 | from collections import namedtuple 6 | from contextlib import contextmanager 7 | 8 | 9 | class ReparamModule(nn.Module): 10 | def _get_module_from_name(self, mn): 11 | if mn == '': 12 | return self 13 | m = self 14 | for p in mn.split('.'): 15 | m = getattr(m, p) 16 | return m 17 | 18 | def __init__(self, module): 19 | super(ReparamModule, self).__init__() 20 | self.module = module 21 | 22 | param_infos = [] # (module name/path, param name) 23 | shared_param_memo = {} 24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name) 25 | params = [] 26 | param_numels = [] 27 | param_shapes = [] 28 | for mn, m in self.named_modules(): 29 | for n, p in m.named_parameters(recurse=False): 30 | if p is not None: 31 | if p in shared_param_memo: 32 | shared_mn, shared_n = shared_param_memo[p] 33 | shared_param_infos.append((mn, n, shared_mn, shared_n)) 34 | else: 35 | shared_param_memo[p] = (mn, n) 36 | param_infos.append((mn, n)) 37 | params.append(p.detach()) 38 | param_numels.append(p.numel()) 39 | param_shapes.append(p.size()) 40 | 41 | assert len(set(p.dtype for p in params)) <= 1, \ 42 | "expects all parameters in module to have same dtype" 43 | 44 | # store the info for unflatten 45 | self._param_infos = tuple(param_infos) 46 | self._shared_param_infos = tuple(shared_param_infos) 47 | self._param_numels = tuple(param_numels) 48 | self._param_shapes = tuple(param_shapes) 49 | 50 | # flatten 51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0)) 52 | self.register_parameter('flat_param', flat_param) 53 | self.param_numel = flat_param.numel() 54 | del params 55 | del shared_param_memo 56 | 57 | # deregister the names as parameters 58 | for mn, n in self._param_infos: 59 | delattr(self._get_module_from_name(mn), n) 60 | for mn, n, _, _ in self._shared_param_infos: 61 | delattr(self._get_module_from_name(mn), n) 62 | 63 | # register the views as plain attributes 64 | self._unflatten_param(self.flat_param) 65 | 66 | # now buffers 67 | # they are not reparametrized. 
just store info as (module, name, buffer) 68 | buffer_infos = [] 69 | for mn, m in self.named_modules(): 70 | for n, b in m.named_buffers(recurse=False): 71 | if b is not None: 72 | buffer_infos.append((mn, n, b)) 73 | 74 | self._buffer_infos = tuple(buffer_infos) 75 | self._traced_self = None 76 | 77 | def trace(self, example_input, **trace_kwargs): 78 | assert self._traced_self is None, 'This ReparamModule is already traced' 79 | 80 | if isinstance(example_input, torch.Tensor): 81 | example_input = (example_input,) 82 | example_input = tuple(example_input) 83 | example_param = (self.flat_param.detach().clone(),) 84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 85 | 86 | self._traced_self = torch.jit.trace_module( 87 | self, 88 | inputs=dict( 89 | _forward_with_param=example_param + example_input, 90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 91 | ), 92 | **trace_kwargs, 93 | ) 94 | 95 | # replace forwards with traced versions 96 | self._forward_with_param = self._traced_self._forward_with_param 97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 98 | return self 99 | 100 | def clear_views(self): 101 | for mn, n in self._param_infos: 102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr 103 | 104 | def _apply(self, *args, **kwargs): 105 | if self._traced_self is not None: 106 | self._traced_self._apply(*args, **kwargs) 107 | return self 108 | return super(ReparamModule, self)._apply(*args, **kwargs) 109 | 110 | def _unflatten_param(self, flat_param): 111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 112 | for (mn, n), p in zip(self._param_infos, ps): 113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr 114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 116 | 117 | @contextmanager 118 | def unflattened_param(self, flat_param): 119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos] 120 | self._unflatten_param(flat_param) 121 | yield 122 | # Why not just `self._unflatten_param(self.flat_param)`? 123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 124 | # 2. 
slightly faster since it does not require reconstruct the split+view 125 | # graph 126 | for (mn, n), p in zip(self._param_infos, saved_views): 127 | setattr(self._get_module_from_name(mn), n, p) 128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 130 | 131 | @contextmanager 132 | def replaced_buffers(self, buffers): 133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers): 134 | setattr(self._get_module_from_name(mn), n, new_b) 135 | yield 136 | for mn, n, old_b in self._buffer_infos: 137 | setattr(self._get_module_from_name(mn), n, old_b) 138 | 139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs): 140 | with self.unflattened_param(flat_param): 141 | with self.replaced_buffers(buffers): 142 | return self.module(*inputs, **kwinputs) 143 | 144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs): 145 | with self.unflattened_param(flat_param): 146 | return self.module(*inputs, **kwinputs) 147 | 148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs): 149 | flat_param = torch.squeeze(flat_param) 150 | # print("PARAMS ON DEVICE: ", flat_param.get_device()) 151 | # print("DATA ON DEVICE: ", inputs[0].get_device()) 152 | # flat_param.to("cuda:{}".format(inputs[0].get_device())) 153 | # self.module.to("cuda:{}".format(inputs[0].get_device())) 154 | if flat_param is None: 155 | flat_param = self.flat_param 156 | if buffers is None: 157 | return self._forward_with_param(flat_param, *inputs, **kwinputs) 158 | else: 159 | return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs) -------------------------------------------------------------------------------- /requirements.yaml: -------------------------------------------------------------------------------- 1 | name: vl-distill 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - absl-py=1.2.0=pypi_0 10 | - accelerate=0.17.0=pypi_0 11 | - addict=2.4.0=pypi_0 12 | - aiohttp=3.8.2=pypi_0 13 | - aiosignal=1.2.0=pypi_0 14 | - antlr4-python3-runtime=4.9.3=pypi_0 15 | - anyio=3.6.1=pypi_0 16 | - argcomplete=2.0.0=pypi_0 17 | - argon2-cffi=21.3.0=pypi_0 18 | - argon2-cffi-bindings=21.2.0=pypi_0 19 | - asttokens=2.0.8=pypi_0 20 | - async-timeout=4.0.2=pypi_0 21 | - attrs=22.1.0=pypi_0 22 | - autopep8=1.7.0=pypi_0 23 | - babel=2.10.3=pypi_0 24 | - backcall=0.2.0=pypi_0 25 | - beautifulsoup4=4.11.1=pypi_0 26 | - blas=1.0=mkl 27 | - bleach=5.0.1=pypi_0 28 | - boto=2.49.0=pypi_0 29 | - braceexpand=0.1.7=pypi_0 30 | - brotlipy=0.7.0=py39h27cfd23_1003 31 | - bzip2=1.0.8=h7b6447c_0 32 | - ca-certificates=2021.10.8=ha878542_0 33 | - cachetools=5.2.0=pypi_0 34 | - certifi=2021.10.8=py39hf3d152e_1 35 | - cffi=1.15.0=py39hd667e15_1 36 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 37 | - chatgdb=1.0.3=pypi_0 38 | - click=7.1.2=pyh9f0ad1d_0 39 | - clip=1.0=pypi_0 40 | - cloudpickle=2.2.0=pypi_0 41 | - cmake=3.26.3=pypi_0 42 | - colorama=0.4.4=pyh9f0ad1d_0 43 | - configparser=5.2.0=pyhd8ed1ab_0 44 | - contourpy=1.0.5=pypi_0 45 | - crcmod=1.7=pypi_0 46 | - cryptography=36.0.0=py39h9ce1e76_0 47 | - cudatoolkit=11.3.1=h2bc3f7f_2 48 | - cycler=0.11.0=pypi_0 49 | - dataclasses=0.8=pyhc8e2a94_3 50 | - datasets=2.7.1=pypi_0 51 | - debugpy=1.6.3=pypi_0 52 | - decorator=5.1.1=pypi_0 53 | - decord=0.6.0=pypi_0 54 | - defusedxml=0.7.1=pypi_0 55 | - 
deprecation=2.1.0=pypi_0 56 | - diffusers=0.14.0=pypi_0 57 | - dill=0.3.6=pypi_0 58 | - distlib=0.3.6=pypi_0 59 | - dm-haiku=0.0.9=pypi_0 60 | - docker-pycreds=0.4.0=py_0 61 | - einops=0.5.0=pypi_0 62 | - entrypoints=0.4=pypi_0 63 | - executing=1.1.0=pypi_0 64 | - fairscale=0.4.12=pypi_0 65 | - fasteners=0.18=pypi_0 66 | - fastjsonschema=2.16.2=pypi_0 67 | - ffmpeg=4.3=hf484d3e_0 68 | - filelock=3.8.0=pypi_0 69 | - fonttools=4.37.3=pypi_0 70 | - freetype=2.11.0=h70c0345_0 71 | - frozenlist=1.3.1=pypi_0 72 | - fsspec=2022.8.2=pypi_0 73 | - ftfy=6.1.1=pypi_0 74 | - fvcore=0.1.5.post20220512=pypi_0 75 | - gcs-oauth2-boto-plugin=3.0=pypi_0 76 | - gdown=4.7.1=pypi_0 77 | - giflib=5.2.1=h7b6447c_0 78 | - git-filter-repo=2.38.0=pypi_0 79 | - gitdb=4.0.9=pyhd8ed1ab_0 80 | - gitpython=3.1.27=pyhd8ed1ab_0 81 | - gmp=6.2.1=h2531618_2 82 | - gnutls=3.6.15=he1e5248_0 83 | - google-apitools=0.5.32=pypi_0 84 | - google-auth=2.11.1=pypi_0 85 | - google-auth-oauthlib=0.4.6=pypi_0 86 | - google-reauth=0.1.1=pypi_0 87 | - gql=0.2.0=pypi_0 88 | - graphql-core=1.1=pypi_0 89 | - grpcio=1.48.1=pypi_0 90 | - gsutil=5.14=pypi_0 91 | - gym=0.26.2=pypi_0 92 | - gym-notices=0.0.8=pypi_0 93 | - httplib2=0.20.4=pypi_0 94 | - huggingface-hub=0.13.2=pypi_0 95 | - idna=3.3=pyhd3eb1b0_0 96 | - importlib-metadata=4.12.0=pypi_0 97 | - intel-openmp=2021.4.0=h06a4308_3561 98 | - iopath=0.1.10=pypi_0 99 | - ipdb=0.13.13=pypi_0 100 | - ipykernel=6.16.0=pypi_0 101 | - ipython=8.5.0=pypi_0 102 | - ipython-genutils=0.2.0=pypi_0 103 | - ipywidgets=8.0.2=pypi_0 104 | - isort=5.10.1=pypi_0 105 | - jedi=0.18.1=pypi_0 106 | - jinja2=3.1.2=pypi_0 107 | - jmp=0.0.4=pypi_0 108 | - joblib=1.2.0=pypi_0 109 | - jpeg=9d=h7f8727e_0 110 | - json5=0.9.10=pypi_0 111 | - jsonschema=4.16.0=pypi_0 112 | - jstyleson=0.0.2=pypi_0 113 | - jupyter-client=7.3.5=pypi_0 114 | - jupyter-core=4.11.1=pypi_0 115 | - jupyter-packaging=0.12.3=pypi_0 116 | - jupyter-server=1.19.0=pypi_0 117 | - jupyterlab=3.4.7=pypi_0 118 | - jupyterlab-pygments=0.2.2=pypi_0 119 | - jupyterlab-server=2.15.2=pypi_0 120 | - jupyterlab-widgets=3.0.3=pypi_0 121 | - kaggle=1.5.12=pypi_0 122 | - kiwisolver=1.4.4=pypi_0 123 | - kornia=0.6.3=pyhd8ed1ab_0 124 | - lame=3.100=h7b6447c_0 125 | - lcms2=2.12=h3be6417_0 126 | - ld_impl_linux-64=2.35.1=h7274673_9 127 | - learn2learn=0.1.7=pypi_0 128 | - libffi=3.3=he6710b0_2 129 | - libgcc-ng=9.3.0=h5101ec6_17 130 | - libgfortran-ng=7.5.0=ha8ba4b0_17 131 | - libgfortran4=7.5.0=ha8ba4b0_17 132 | - libgomp=9.3.0=h5101ec6_17 133 | - libiconv=1.15=h63c8f33_5 134 | - libidn2=2.3.2=h7f8727e_0 135 | - libpng=1.6.37=hbc83047_0 136 | - libprotobuf=3.15.8=h780b84a_0 137 | - libstdcxx-ng=9.3.0=hd4cf53a_17 138 | - libtasn1=4.16.0=h27cfd23_0 139 | - libtiff=4.2.0=h85742a9_0 140 | - libunistring=0.9.10=h27cfd23_0 141 | - libuv=1.40.0=h7b6447c_0 142 | - libwebp=1.2.2=h55f646e_0 143 | - libwebp-base=1.2.2=h7f8727e_0 144 | - lit=16.0.1=pypi_0 145 | - llvmlite=0.39.1=pypi_0 146 | - logger=1.4=pypi_0 147 | - lxml=4.9.1=pypi_0 148 | - lz4-c=1.9.3=h295c915_1 149 | - markdown=3.4.1=pypi_0 150 | - markupsafe=2.1.1=pypi_0 151 | - matplotlib=3.6.0=pypi_0 152 | - matplotlib-inline=0.1.6=pypi_0 153 | - mistune=2.0.4=pypi_0 154 | - mkl=2021.4.0=h06a4308_640 155 | - mkl-service=2.4.0=py39h7f8727e_0 156 | - mkl_fft=1.3.1=py39hd3c417c_0 157 | - mkl_random=1.2.2=py39h51133e4_0 158 | - monotonic=1.6=pypi_0 159 | - mpmath=1.3.0=pypi_0 160 | - mss=6.1.0=pypi_0 161 | - multidict=5.2.0=pypi_0 162 | - multiprocess=0.70.14=pypi_0 163 | - nbclassic=0.4.3=pypi_0 164 | - 
nbclient=0.6.8=pypi_0 165 | - nbconvert=7.0.0=pypi_0 166 | - nbformat=5.6.1=pypi_0 167 | - ncurses=6.3=h7f8727e_2 168 | - nest-asyncio=1.5.5=pypi_0 169 | - nettle=3.7.3=hbbd107a_1 170 | - networkx=2.8.8=pypi_0 171 | - nltk=3.8=pypi_0 172 | - notebook=6.4.12=pypi_0 173 | - notebook-shim=0.1.0=pypi_0 174 | - numba=0.56.4=pypi_0 175 | - numpy=1.22.4=pypi_0 176 | - nvidia-cublas-cu11=11.10.3.66=pypi_0 177 | - nvidia-cuda-cupti-cu11=11.7.101=pypi_0 178 | - nvidia-cuda-nvrtc-cu11=11.7.99=pypi_0 179 | - nvidia-cuda-runtime-cu11=11.7.99=pypi_0 180 | - nvidia-cudnn-cu11=8.5.0.96=pypi_0 181 | - nvidia-cufft-cu11=10.9.0.58=pypi_0 182 | - nvidia-curand-cu11=10.2.10.91=pypi_0 183 | - nvidia-cusolver-cu11=11.4.0.1=pypi_0 184 | - nvidia-cusparse-cu11=11.7.4.91=pypi_0 185 | - nvidia-ml-py3=7.352.0=pypi_0 186 | - nvidia-nccl-cu11=2.14.3=pypi_0 187 | - nvidia-nvtx-cu11=11.7.91=pypi_0 188 | - oauth2client=4.1.3=pypi_0 189 | - oauthlib=3.2.1=pypi_0 190 | - omegaconf=2.2.3=pypi_0 191 | - open3d=0.15.2=pypi_0 192 | - opencv-python=4.6.0.66=pypi_0 193 | - opendatasets=0.1.22=pypi_0 194 | - openh264=2.1.1=h4ff587b_0 195 | - openssl=1.1.1m=h7f8727e_0 196 | - openvino=2022.3.0=pypi_0 197 | - openvino-dev=2022.3.0=pypi_0 198 | - openvino-telemetry=2022.3.0=pypi_0 199 | - optree=0.9.0=pypi_0 200 | - packaging=21.3=pyhd8ed1ab_0 201 | - pandas=1.3.5=pypi_0 202 | - pandocfilters=1.5.0=pypi_0 203 | - parso=0.8.3=pypi_0 204 | - pathtools=0.1.2=py_1 205 | - pexpect=4.8.0=pypi_0 206 | - pickleshare=0.7.5=pypi_0 207 | - pillow=9.0.1=py39h22f2fdc_0 208 | - pip=21.2.4=py39h06a4308_0 209 | - pipenv=2022.9.21=pypi_0 210 | - platformdirs=2.5.2=pypi_0 211 | - plotext=5.2.8=pypi_0 212 | - portalocker=2.5.1=pypi_0 213 | - prefetch-generator=1.0.3=pypi_0 214 | - prometheus-client=0.14.1=pypi_0 215 | - promise=2.3=py39hf3d152e_5 216 | - prompt-toolkit=3.0.31=pypi_0 217 | - protobuf=3.15.8=py39he80948d_0 218 | - psutil=5.8.0=py39h27cfd23_1 219 | - ptyprocess=0.7.0=pypi_0 220 | - pure-eval=0.2.2=pypi_0 221 | - pyarrow=6.0.1=pypi_0 222 | - pyasn1=0.4.8=pypi_0 223 | - pyasn1-modules=0.2.8=pypi_0 224 | - pycocoevalcap=1.2=pypi_0 225 | - pycocotools=2.0.6=pypi_0 226 | - pycodestyle=2.9.1=pypi_0 227 | - pycparser=2.21=pyhd3eb1b0_0 228 | - pydeprecate=0.3.2=pypi_0 229 | - pygments=2.13.0=pypi_0 230 | - pyopenssl=22.0.0=pyhd3eb1b0_0 231 | - pyparsing=3.0.7=pyhd8ed1ab_0 232 | - pyquaternion=0.9.9=pypi_0 233 | - pyrsistent=0.18.1=pypi_0 234 | - pysocks=1.7.1=py39h06a4308_0 235 | - python=3.9.7=h12debd9_1 236 | - python-dateutil=2.8.2=pyhd8ed1ab_0 237 | - python-graphviz=0.20.1=pypi_0 238 | - python-slugify=6.1.2=pypi_0 239 | - python_abi=3.9=2_cp39 240 | - pytorch-lightning=1.7.6=pypi_0 241 | - pytorch-mutex=1.0=cuda 242 | - pytz=2022.2.1=pypi_0 243 | - pyu2f=0.1.5=pypi_0 244 | - pyyaml=5.4.1=py39h3811e60_0 245 | - pyzmq=24.0.1=pypi_0 246 | - qpth=0.0.15=pypi_0 247 | - readline=8.1.2=h7f8727e_1 248 | - regex=2022.9.13=pypi_0 249 | - requests=2.27.1=pyhd3eb1b0_0 250 | - requests-oauthlib=1.3.1=pypi_0 251 | - responses=0.18.0=pypi_0 252 | - retry-decorator=1.1.1=pypi_0 253 | - rsa=4.7.2=pypi_0 254 | - ruamel-yaml=0.17.21=pypi_0 255 | - ruamel-yaml-clib=0.2.7=pypi_0 256 | - scikit-learn=1.1.2=pypi_0 257 | - scipy=1.10.1=pypi_0 258 | - send2trash=1.8.0=pypi_0 259 | - sentence-transformers=2.2.2=pypi_0 260 | - sentencepiece=0.1.97=pypi_0 261 | - sentry-sdk=1.5.7=pyhd8ed1ab_0 262 | - setproctitle=1.2.2=py39h3811e60_0 263 | - setuptools=65.4.0=pypi_0 264 | - shortuuid=1.0.8=py39hf3d152e_0 265 | - six=1.16.0=pyhd3eb1b0_1 266 | - smmap=3.0.5=pyh44b312d_0 
267 | - sniffio=1.3.0=pypi_0 268 | - soupsieve=2.3.2.post1=pypi_0 269 | - sqlite=3.38.0=hc218d9a_0 270 | - stack-data=0.5.1=pypi_0 271 | - subprocess32=3.5.4=pypi_0 272 | - sympy=1.11.1=pypi_0 273 | - tabulate=0.8.10=pypi_0 274 | - tensorboard=2.10.0=pypi_0 275 | - tensorboard-data-server=0.6.1=pypi_0 276 | - tensorboard-plugin-wit=1.8.1=pypi_0 277 | - termcolor=1.1.0=py_2 278 | - terminado=0.15.0=pypi_0 279 | - text-unidecode=1.3=pypi_0 280 | - texttable=1.6.7=pypi_0 281 | - threadpoolctl=3.1.0=pypi_0 282 | - timm=0.6.7=pypi_0 283 | - tinycss2=1.1.1=pypi_0 284 | - tk=8.6.11=h1ccaba5_0 285 | - tokenizers=0.12.1=pypi_0 286 | - toml=0.10.2=pypi_0 287 | - tomli=2.0.1=pypi_0 288 | - tomlkit=0.11.4=pypi_0 289 | - torch=2.0.0=pypi_0 290 | - torchaudio=0.11.0=py39_cu113 291 | - torchmetrics=0.9.3=pypi_0 292 | - torchopt=0.7.0=pypi_0 293 | - torchvision=0.15.1=pypi_0 294 | - tornado=6.2=pypi_0 295 | - tqdm=4.63.0=pyhd8ed1ab_0 296 | - traitlets=5.4.0=pypi_0 297 | - transformers=4.26.1=pypi_0 298 | - triton=2.0.0=pypi_0 299 | - typing-extensions=4.3.0=pypi_0 300 | - tzdata=2021e=hda174b7_0 301 | - urllib3=1.26.8=pyhd3eb1b0_0 302 | - virtualenv=20.16.5=pypi_0 303 | - virtualenv-clone=0.5.7=pypi_0 304 | - wandb=0.13.7=pypi_0 305 | - watchdog=2.2.1=pypi_0 306 | - wcwidth=0.2.5=pypi_0 307 | - webdataset=0.2.30=pypi_0 308 | - webencodings=0.5.1=pypi_0 309 | - websocket-client=1.4.1=pypi_0 310 | - werkzeug=2.2.2=pypi_0 311 | - wheel=0.37.1=pyhd3eb1b0_0 312 | - widgetsnbextension=4.0.3=pypi_0 313 | - xxhash=3.1.0=pypi_0 314 | - xz=5.2.5=h7b6447c_0 315 | - yacs=0.1.8=pypi_0 316 | - yaml=0.2.5=h516909a_0 317 | - yarl=1.8.1=pypi_0 318 | - yaspin=2.1.0=pyhd8ed1ab_0 319 | - zipp=3.8.1=pypi_0 320 | - zlib=1.2.11=h7f8727e_4 321 | - zstd=1.4.9=haebb681_0 322 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif; 3 | font-weight:300; 4 | font-size:18px; 5 | margin-left: auto; 6 | margin-right: auto; 7 | width: 1100px; 8 | } 9 | 10 | table { 11 | margin-left: auto; 12 | margin-right: auto; 13 | } 14 | 15 | tr { 16 | /*margin-left: auto; 17 | margin-right: auto;*/ 18 | text-align: center; 19 | } 20 | 21 | h1 { 22 | font-weight:300; 23 | text-align: center; 24 | } 25 | 26 | h2 { 27 | text-align: center; 28 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif; 29 | } 30 | 31 | div { 32 | float: none; 33 | margin-left: auto; 34 | margin-right: auto; 35 | font-size: 0; 36 | } 37 | 38 | .disclaimerbox { 39 | background-color: #eee; 40 | border: 1px solid #eeeeee; 41 | border-radius: 10px ; 42 | -moz-border-radius: 10px ; 43 | -webkit-border-radius: 10px ; 44 | padding: 20px; 45 | } 46 | 47 | video.header-vid { 48 | height: 140px; 49 | border: 1px solid black; 50 | border-radius: 10px ; 51 | -moz-border-radius: 10px ; 52 | -webkit-border-radius: 10px ; 53 | } 54 | 55 | img.header-img { 56 | height: 140px; 57 | border: 1px solid black; 58 | border-radius: 10px ; 59 | -moz-border-radius: 10px ; 60 | -webkit-border-radius: 10px ; 61 | } 62 | 63 | img.rounded { 64 | border: 1px solid #eeeeee; 65 | border-radius: 10px ; 66 | -moz-border-radius: 10px ; 67 | -webkit-border-radius: 10px ; 68 | } 69 | 70 | a:link,a:visited 71 | { 72 | color: #1367a7; 73 | text-decoration: none; 74 | } 75 | a:hover { 76 | 
color: #208799; 77 | } 78 | 79 | td.dl-link { 80 | height: 160px; 81 | text-align: center; 82 | font-size: 22px; 83 | } 84 | 85 | 86 | .layered-paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ 87 | box-shadow: 88 | 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */ 89 | 5px 5px 0 0px #fff, /* The second layer */ 90 | 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */ 91 | 10px 10px 0 0px #fff, /* The third layer */ 92 | 10px 10px 1px 1px rgba(0,0,0,0.35), /* The third layer shadow */ 93 | 15px 15px 0 0px #fff, /* The fourth layer */ 94 | 15px 15px 1px 1px rgba(0,0,0,0.35), /* The fourth layer shadow */ 95 | 20px 20px 0 0px #fff, /* The fifth layer */ 96 | 20px 20px 1px 1px rgba(0,0,0,0.35), /* The fifth layer shadow */ 97 | 25px 25px 0 0px #fff, /* The fifth layer */ 98 | 25px 25px 1px 1px rgba(0,0,0,0.35); /* The fifth layer shadow */ 99 | margin-left: 10px; 100 | margin-right: 45px; 101 | } 102 | 103 | 104 | .layered-paper { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ 105 | box-shadow: 106 | 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */ 107 | 5px 5px 0 0px #fff, /* The second layer */ 108 | 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */ 109 | 10px 10px 0 0px #fff, /* The third layer */ 110 | 10px 10px 1px 1px rgba(0,0,0,0.35); /* The third layer shadow */ 111 | margin-top: 5px; 112 | margin-left: 10px; 113 | margin-right: 30px; 114 | margin-bottom: 5px; 115 | } 116 | 117 | .vert-cent { 118 | position: relative; 119 | top: 50%; 120 | transform: translateY(-50%); 121 | } 122 | 123 | pre { 124 | overflow-x: auto; 125 | text-align: left; 126 | border: 1px solid grey; 127 | border-radius: 3px; 128 | background: #eee; 129 | padding: 5px 5px 5px 10px; 130 | line-height:1.2 131 | } 132 | 133 | pre code { 134 | text-align: left; 135 | word-wrap: normal; 136 | white-space: pre-wrap; 137 | font-size:12px 138 | } 139 | 140 | hr { 141 | border: 0; 142 | height: 1.5px; 143 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); 144 | } 145 | 146 | .vertical { 147 | display: table-cell; 148 | vertical-align: middle; 149 | } 150 | 151 | /* Below are hover effect - by Zhiqiu */ 152 | 153 | .slide-container { 154 | /*position: relative;*/ 155 | width: 250px; 156 | overflow: hidden; 157 | } 158 | 159 | 160 | .slide-container img:hover, .slide-container img:active{ 161 | /*position: relative;*/ 162 | transition: ease-in-out; 163 | transform: translate3d(-250px, 0px, 0px); 164 | } 165 | 166 | .slideup-container { 167 | /*position: relative;*/ 168 | height: 125px; 169 | overflow: hidden; 170 | } 171 | 172 | .slideup-container img:hover, .slideup-container img:active{ 173 | /*position: relative;*/ 174 | transition: ease-in-out; 175 | transform: translate3d(0px, -125px, 0px); 176 | } 177 | 178 | /*.slideup-container img:hover{ 179 | position: relative; 180 | transform: translate3d(0px, -125px, 0px); 181 | }*/ 182 | 183 | .teaser { 184 | /*position: absolute;*/ 185 | width: 270px; 186 | height: 55px; 187 | } 188 | 189 | .teaser-answer { 190 | opacity:0; 191 | transition:1.2s; 192 | } 193 | 194 | .teaser .teaser-answer:hover, .teaser .teaser-answer:active { 195 | opacity:1; 196 | } 197 | 198 | /*.overlay { 199 | position: absolute; 200 | top: 0; 201 | bottom: 0; 202 | left: 0; 203 | right: 0; 204 | height: 100%; 205 | width: 100%; 206 | opacity: 0; 207 | transition: .5s ease; 208 | background-color: #008CBA; 209 | } 210 | 211 | .hover-container:hover 
.overlay { 212 | opacity: 1; 213 | }*/ 214 | 215 | /*.overlay:hover .overlay { 216 | opacity: 0.5; 217 | transition: .5s ease; 218 | }*/ 219 | 220 | /*.text { 221 | color: white; 222 | font-size: 20px; 223 | position: absolute; 224 | top: 50%; 225 | left: 50%; 226 | -webkit-transform: translate(-50%, -50%); 227 | -ms-transform: translate(-50%, -50%); 228 | transform: translate(-50%, -50%); 229 | text-align: center; 230 | }*/ 231 | -------------------------------------------------------------------------------- /transform/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/transform/.DS_Store -------------------------------------------------------------------------------- /transform/__pycache__/randaugment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/transform/__pycache__/randaugment.cpython-39.pyc -------------------------------------------------------------------------------- /transform/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | 
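# Solarize: pixel values at or above `thresh` are inverted (255 - v) and values below are kept,
# implemented as a 256-entry lookup table (cf. PIL.ImageOps.solarize).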
table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 
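# cutout_func below fills a randomly centered square patch (side roughly pad_size, clipped at
# the image borders) with the `replace` color; the *_level_to_args helpers that follow map an
# integer magnitude in [0, MAX_LEVEL] to concrete arguments for each augmentation op.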
194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': 
shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randn(32, 32, 3) 340 | a(img) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # adapted from 2 | # https://github.com/VICO-UoE/DatasetCondensation 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | import kornia as K 10 | import tqdm 11 | from torch.utils.data import Dataset 12 | from torchvision import datasets, transforms 13 | from scipy.ndimage.interpolation import rotate as scipyrotate 14 | from networks import MLP, ConvNet, LeNet, AlexNet, VGG11BN, VGG11, ResNet18, ResNet18BN_AP, ResNet18_AP, ModifiedResNet, resnet18_gn 15 | import re 16 | import json 17 | import torch.distributed as dist 18 | from tqdm import tqdm 19 | from collections import defaultdict 20 | 21 | class Config: 22 | imagenette = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701] 23 | 24 | # ["australian_terrier", "border_terrier", "samoyed", "beagle", "shih-tzu", "english_foxhound", "rhodesian_ridgeback", "dingo", "golden_retriever", "english_sheepdog"] 25 | imagewoof = [193, 182, 258, 162, 155, 167, 159, 273, 207, 229] 26 | 27 | # ["tabby_cat", "bengal_cat", "persian_cat", "siamese_cat", "egyptian_cat", "lion", "tiger", "jaguar", "snow_leopard", "lynx"] 28 | imagemeow = [281, 282, 283, 284, 285, 291, 292, 290, 289, 287] 29 | 30 | # ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"] 31 | imagesquawk = [84, 130, 88, 144, 145, 22, 96, 9, 100, 89] 32 | 33 | # ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"] 34 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948] 35 | 36 | # ["bee", "ladys slipper", "banana", "lemon", "corn", "school_bus", "honeycomb", "lion", "garden_spider", "goldfinch"] 37 | imageyellow = [309, 986, 954, 951, 987, 779, 599, 291, 72, 11] 38 | 39 | dict = { 40 | "imagenette" : imagenette, 41 | "imagewoof" : imagewoof, 42 | "imagefruit": imagefruit, 43 | "imageyellow": imageyellow, 44 | "imagemeow": imagemeow, 45 | "imagesquawk": imagesquawk, 46 | } 47 | 48 | config = Config() 49 | 50 | def get_dataset(args): 51 | 52 | if args.dataset == 'CIFAR10_clip': 53 | channel = 3 54 | im_size = (32, 32) 55 | num_classes = 768 56 | mean = [0.4914, 0.4822, 0.4465] 57 | std = [0.2023, 0.1994, 0.2010] 58 | if args.zca: 59 | transform = transforms.Compose([transforms.ToTensor()]) 60 | else: 61 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 62 | dst_train = 
datasets.CIFAR10(args.data_path, train=True, download=True, transform=transform) # no augmentation 63 | dst_test = datasets.CIFAR10(args.data_path, train=False, download=True, transform=transform) 64 | class_names = dst_train.classes 65 | class_map = {x:x for x in range(num_classes)} 66 | 67 | else: 68 | exit('unknown dataset: %s'%args.dataset) 69 | 70 | if args.zca: 71 | images = [] 72 | labels = [] 73 | print("Train ZCA") 74 | for i in tqdm(range(len(dst_train))): 75 | im, lab = dst_train[i] 76 | images.append(im) 77 | labels.append(lab) 78 | images = torch.stack(images, dim=0).to(args.device) 79 | labels = torch.tensor(labels, dtype=torch.float, device="cpu") 80 | zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True) 81 | zca.fit(images) 82 | zca_images = zca(images).to("cpu") 83 | dst_train = TensorDataset(zca_images, labels) 84 | 85 | images = [] 86 | labels = [] 87 | print("Test ZCA") 88 | for i in tqdm(range(len(dst_test))): 89 | im, lab = dst_test[i] 90 | images.append(im) 91 | labels.append(lab) 92 | images = torch.stack(images, dim=0).to(args.device) 93 | labels = torch.tensor(labels, dtype=torch.float, device="cpu") 94 | 95 | zca_images = zca(images).to("cpu") 96 | dst_test = TensorDataset(zca_images, labels) 97 | 98 | args.zca_trans = zca 99 | 100 | 101 | testloader = torch.utils.data.DataLoader(dst_test, batch_size=128, shuffle=False, num_workers=2) 102 | 103 | if "flickr" not in args.dataset: 104 | dst_train_label, dst_test_label = None, None 105 | return channel, im_size, num_classes, mean, std, dst_train, dst_test, testloader, dst_train_label, dst_test_label 106 | 107 | 108 | 109 | class TensorDataset(Dataset): 110 | def __init__(self, images, labels): # images: n x c x h x w tensor 111 | self.images = images.detach().float() 112 | self.labels = labels.detach() 113 | 114 | def __getitem__(self, index): 115 | return self.images[index], self.labels[index] 116 | 117 | def __len__(self): 118 | return self.images.shape[0] 119 | 120 | 121 | 122 | def get_default_convnet_setting(): 123 | net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling' 124 | return net_width, net_depth, net_act, net_norm, net_pooling 125 | 126 | 127 | 128 | def get_RN_network(model, vision_width, vision_layers, embed_dim, image_resolution): 129 | if model == 'RN50': 130 | vision_heads = vision_width * 32 // 64 131 | net = ModifiedResNet(layers=vision_layers, 132 | output_dim=embed_dim, 133 | heads=vision_heads, 134 | input_resolution=image_resolution, 135 | width=vision_width) 136 | if dist: 137 | gpu_num = torch.cuda.device_count() 138 | if gpu_num>0: 139 | device = 'cuda' 140 | if gpu_num>1: 141 | net = nn.DataParallel(net) 142 | else: 143 | device = 'cpu' 144 | net = net.to(device) 145 | 146 | return net 147 | 148 | def get_network(model, channel, num_classes, im_size=(32, 32), dist=True): 149 | torch.random.manual_seed(int(time.time() * 1000) % 100000) 150 | net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting() 151 | 152 | if model == 'MLP': 153 | net = MLP(channel=channel, num_classes=num_classes) 154 | elif model == 'ConvNet': 155 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 156 | elif model == 'LeNet': 157 | net = LeNet(channel=channel, num_classes=num_classes) 158 | elif model == 'AlexNet': 159 | net = AlexNet(channel=channel, num_classes=num_classes) 160 | elif model == 'VGG11': 161 | net = 
VGG11( channel=channel, num_classes=num_classes) 162 | elif model == 'VGG11BN': 163 | net = VGG11BN(channel=channel, num_classes=num_classes) 164 | elif model == 'ResNet18': 165 | net = ResNet18(channel=channel, num_classes=num_classes) 166 | elif model == 'ResNet18BN_AP': 167 | net = ResNet18BN_AP(channel=channel, num_classes=num_classes) 168 | elif model == 'ResNet18_AP': 169 | net = ResNet18_AP(channel=channel, num_classes=num_classes) 170 | 171 | elif model == 'ConvNetD1': 172 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 173 | elif model == 'ConvNetD2': 174 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 175 | elif model == 'ConvNetD3': 176 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 177 | elif model == 'ConvNetD4': 178 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 179 | elif model == 'ConvNetD5': 180 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=5, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 181 | elif model == 'ConvNetD6': 182 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=6, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 183 | elif model == 'ConvNetD7': 184 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=7, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 185 | elif model == 'ConvNetD8': 186 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=8, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 187 | 188 | 189 | elif model == 'ConvNetW32': 190 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 191 | elif model == 'ConvNetW64': 192 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 193 | elif model == 'ConvNetW128': 194 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 195 | elif model == 'ConvNetW256': 196 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 197 | elif model == 'ConvNetW512': 198 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=512, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 199 | elif model == 'ConvNetW1024': 200 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 201 | 202 | elif model == "ConvNetKIP": 203 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, 204 | net_norm="none", net_pooling=net_pooling) 205 | 206 | elif model == 'ConvNetAS': 207 | net = 
ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling) 208 | elif model == 'ConvNetAR': 209 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling) 210 | elif model == 'ConvNetAL': 211 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling) 212 | 213 | elif model == 'ConvNetNN': 214 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling) 215 | elif model == 'ConvNetBN': 216 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling) 217 | elif model == 'ConvNetLN': 218 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling) 219 | elif model == 'ConvNetIN': 220 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling) 221 | elif model == 'ConvNetGN': 222 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling) 223 | 224 | elif model == 'ConvNetNP': 225 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none') 226 | elif model == 'ConvNetMP': 227 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling') 228 | elif model == 'ConvNetAP': 229 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling') 230 | 231 | 232 | else: 233 | net = None 234 | exit('DC error: unknown model') 235 | 236 | if dist: 237 | gpu_num = torch.cuda.device_count() 238 | if gpu_num>0: 239 | device = 'cuda' 240 | if gpu_num>1: 241 | net = nn.DataParallel(net) 242 | else: 243 | device = 'cpu' 244 | net = net.to(device) 245 | 246 | return net 247 | 248 | 249 | 250 | def get_time(): 251 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())) 252 | 253 | 254 | 255 | def augment(images, dc_aug_param, device): 256 | # This can be sped up in the future. 
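# `dc_aug_param['strategy']` is an underscore-separated list (e.g. 'crop_scale_rotate'); for each
# image one augmentation is sampled uniformly from that list and applied in place: crop pads with
# the per-channel mean and re-crops at a random offset, scale resizes by a random factor in
# [1-scale, 1+scale], rotate uses scipy rotation with mean fill, and noise adds zero-mean
# Gaussian noise of magnitude `noise`.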
257 | 258 | if dc_aug_param != None and dc_aug_param['strategy'] != 'none': 259 | scale = dc_aug_param['scale'] 260 | crop = dc_aug_param['crop'] 261 | rotate = dc_aug_param['rotate'] 262 | noise = dc_aug_param['noise'] 263 | strategy = dc_aug_param['strategy'] 264 | 265 | shape = images.shape 266 | mean = [] 267 | for c in range(shape[1]): 268 | mean.append(float(torch.mean(images[:,c]))) 269 | 270 | def cropfun(i): 271 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device) 272 | for c in range(shape[1]): 273 | im_[c] = mean[c] 274 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i] 275 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0] 276 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]] 277 | 278 | def scalefun(i): 279 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 280 | w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 281 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0] 282 | mhw = max(h, w, shape[2], shape[3]) 283 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device) 284 | r = int((mhw - h) / 2) 285 | c = int((mhw - w) / 2) 286 | im_[:, r:r + h, c:c + w] = tmp 287 | r = int((mhw - shape[2]) / 2) 288 | c = int((mhw - shape[3]) / 2) 289 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]] 290 | 291 | def rotatefun(i): 292 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean)) 293 | r = int((im_.shape[-2] - shape[-2]) / 2) 294 | c = int((im_.shape[-1] - shape[-1]) / 2) 295 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device) 296 | 297 | def noisefun(i): 298 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device) 299 | 300 | 301 | augs = strategy.split('_') 302 | 303 | for i in range(shape[0]): 304 | choice = np.random.permutation(augs)[0] # randomly implement one augmentation 305 | if choice == 'crop': 306 | cropfun(i) 307 | elif choice == 'scale': 308 | scalefun(i) 309 | elif choice == 'rotate': 310 | rotatefun(i) 311 | elif choice == 'noise': 312 | noisefun(i) 313 | 314 | return images 315 | 316 | 317 | 318 | def get_daparam(dataset, model, model_eval, ipc): 319 | # We find that augmentation doesn't always benefit the performance. 320 | # So we do augmentation for some of the settings. 321 | 322 | dc_aug_param = dict() 323 | dc_aug_param['crop'] = 4 324 | dc_aug_param['scale'] = 0.2 325 | dc_aug_param['rotate'] = 45 326 | dc_aug_param['noise'] = 0.001 327 | dc_aug_param['strategy'] = 'none' 328 | 329 | if dataset == 'MNIST': 330 | dc_aug_param['strategy'] = 'crop_scale_rotate' 331 | 332 | if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier. 
333 | dc_aug_param['strategy'] = 'crop_noise' 334 | 335 | return dc_aug_param 336 | 337 | 338 | def get_eval_pool(eval_mode, model, model_eval): 339 | if eval_mode == 'M': # multiple architectures 340 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18', 'LeNet'] 341 | model_eval_pool = ['ConvNet', 'AlexNet', 'VGG11', 'ResNet18_AP', 'ResNet18'] 342 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18'] 343 | elif eval_mode == 'W': # ablation study on network width 344 | model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256'] 345 | elif eval_mode == 'D': # ablation study on network depth 346 | model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4'] 347 | elif eval_mode == 'A': # ablation study on network activation function 348 | model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL'] 349 | elif eval_mode == 'P': # ablation study on network pooling layer 350 | model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP'] 351 | elif eval_mode == 'N': # ablation study on network normalization layer 352 | model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN'] 353 | elif eval_mode == 'S': # itself 354 | model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model] 355 | elif eval_mode == 'C': 356 | model_eval_pool = [model, 'ConvNet'] 357 | else: 358 | model_eval_pool = [model_eval] 359 | return model_eval_pool 360 | 361 | 362 | class ParamDiffAug(): 363 | def __init__(self): 364 | self.aug_mode = 'S' #'multiple or single' 365 | self.prob_flip = 0.5 366 | self.ratio_scale = 1.2 367 | self.ratio_rotate = 15.0 368 | self.ratio_crop_pad = 0.125 369 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5 370 | self.ratio_noise = 0.05 371 | self.brightness = 1.0 372 | self.saturation = 2.0 373 | self.contrast = 0.5 374 | 375 | 376 | def set_seed_DiffAug(param): 377 | if param.latestseed == -1: 378 | return 379 | else: 380 | torch.random.manual_seed(param.latestseed) 381 | param.latestseed += 1 382 | 383 | 384 | def DiffAugment(x, strategy='', seed = -1, param = None): 385 | if seed == -1: 386 | param.batchmode = False 387 | else: 388 | param.batchmode = True 389 | 390 | param.latestseed = seed 391 | 392 | if strategy == 'None' or strategy == 'none': 393 | return x 394 | 395 | if strategy: 396 | if param.aug_mode == 'M': # original 397 | for p in strategy.split('_'): 398 | for f in AUGMENT_FNS[p]: 399 | x = f(x, param) 400 | elif param.aug_mode == 'S': 401 | pbties = strategy.split('_') 402 | set_seed_DiffAug(param) 403 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()] 404 | for f in AUGMENT_FNS[p]: 405 | x = f(x, param) 406 | else: 407 | exit('Error ZH: unknown augmentation mode.') 408 | x = x.contiguous() 409 | return x 410 | 411 | 412 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans. 
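# Hedged usage sketch (not part of the original file): DiffAugment is typically applied to an image
# batch with a strategy string built from the AUGMENT_FNS keys defined below, e.g.
#
#   param = ParamDiffAug()
#   x_aug = DiffAugment(x, strategy='color_crop_cutout_flip_scale_rotate', seed=-1, param=param)
#
# With the default aug_mode='S' one category from the strategy is sampled per call ('M' applies them all).
# seed=-1 lets every sample draw its own parameters; passing a fixed seed (!= -1) makes the whole batch
# share one set of parameters and keeps the draws reproducible.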
413 | def rand_scale(x, param): 414 | # x>1, max scale 415 | # sx, sy: (0, +oo), 1: original size, 0.5: enlarge 2 times 416 | ratio = param.ratio_scale 417 | set_seed_DiffAug(param) 418 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 419 | set_seed_DiffAug(param) 420 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 421 | theta = [[[sx[i], 0, 0], 422 | [0, sy[i], 0],] for i in range(x.shape[0])] 423 | theta = torch.tensor(theta, dtype=torch.float) 424 | if param.batchmode: # batch-wise: 425 | theta[:] = theta[0] 426 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 427 | x = F.grid_sample(x, grid, align_corners=True) 428 | return x 429 | 430 | 431 | def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degrees 432 | ratio = param.ratio_rotate 433 | set_seed_DiffAug(param) 434 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi) 435 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0], 436 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])] 437 | theta = torch.tensor(theta, dtype=torch.float) 438 | if param.batchmode: # batch-wise: 439 | theta[:] = theta[0] 440 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 441 | x = F.grid_sample(x, grid, align_corners=True) 442 | return x 443 | 444 | 445 | def rand_flip(x, param): 446 | prob = param.prob_flip 447 | set_seed_DiffAug(param) 448 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device) 449 | if param.batchmode: # batch-wise: 450 | randf[:] = randf[0] 451 | return torch.where(randf < prob, x.flip(3), x) 452 | 453 | 454 | def rand_brightness(x, param): 455 | ratio = param.brightness 456 | set_seed_DiffAug(param) 457 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 458 | if param.batchmode: # batch-wise: 459 | randb[:] = randb[0] 460 | x = x + (randb - 0.5)*ratio 461 | return x 462 | 463 | 464 | def rand_saturation(x, param): 465 | ratio = param.saturation 466 | x_mean = x.mean(dim=1, keepdim=True) 467 | set_seed_DiffAug(param) 468 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 469 | if param.batchmode: # batch-wise: 470 | rands[:] = rands[0] 471 | x = (x - x_mean) * (rands * ratio) + x_mean 472 | return x 473 | 474 | 475 | def rand_contrast(x, param): 476 | ratio = param.contrast 477 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 478 | set_seed_DiffAug(param) 479 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 480 | if param.batchmode: # batch-wise: 481 | randc[:] = randc[0] 482 | x = (x - x_mean) * (randc + ratio) + x_mean 483 | return x 484 | 485 | 486 | def rand_crop(x, param): 487 | # The image is padded around its border and then cropped.
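# Concretely: x is zero-padded by one pixel on each side, per-sample integer offsets of up to
# ratio_crop_pad * (H, W) are drawn (shared across the batch when param.batchmode is True),
# and the shifted pixels are gathered back with clamped index grids.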
488 | ratio = param.ratio_crop_pad 489 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 490 | set_seed_DiffAug(param) 491 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 492 | set_seed_DiffAug(param) 493 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 494 | if param.batchmode: # batch-wise: 495 | translation_x[:] = translation_x[0] 496 | translation_y[:] = translation_y[0] 497 | grid_batch, grid_x, grid_y = torch.meshgrid( 498 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 499 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 500 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 501 | ) 502 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 503 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 504 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 505 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 506 | return x 507 | 508 | 509 | def rand_cutout(x, param): 510 | ratio = param.ratio_cutout 511 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 512 | set_seed_DiffAug(param) 513 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 514 | set_seed_DiffAug(param) 515 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 516 | if param.batchmode: # batch-wise: 517 | offset_x[:] = offset_x[0] 518 | offset_y[:] = offset_y[0] 519 | grid_batch, grid_x, grid_y = torch.meshgrid( 520 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 521 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 522 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 523 | ) 524 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 525 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 526 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 527 | mask[grid_batch, grid_x, grid_y] = 0 528 | x = x * mask.unsqueeze(1) 529 | return x 530 | 531 | 532 | AUGMENT_FNS = { 533 | 'color': [rand_brightness, rand_saturation, rand_contrast], 534 | 'crop': [rand_crop], 535 | 'cutout': [rand_cutout], 536 | 'flip': [rand_flip], 537 | 'scale': [rand_scale], 538 | 'rotate': [rand_rotate], 539 | } 540 | 541 | 542 | def pre_question(question,max_ques_words=50): 543 | question = re.sub( 544 | r"([.!\"()*#:;~])", 545 | '', 546 | question.lower(), 547 | ) 548 | question = question.rstrip(' ') 549 | 550 | #truncate question 551 | question_words = question.split(' ') 552 | if len(question_words)>max_ques_words: 553 | question = ' '.join(question_words[:max_ques_words]) 554 | 555 | return question 556 | 557 | 558 | def save_result(result, result_dir, filename, remove_duplicate=''): 559 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 560 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 561 | 562 | json.dump(result,open(result_file,'w')) 563 | 564 | dist.barrier() 565 | 566 | if utils.is_main_process(): 567 | # combine results from all processes 568 | result = [] 569 | 570 | for rank in range(utils.get_world_size()): 571 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 572 | res = json.load(open(result_file,'r')) 573 | result += res 574 | 575 | if 
remove_duplicate: 576 | result_new = [] 577 | id_list = [] 578 | for res in result: 579 | if res[remove_duplicate] not in id_list: 580 | id_list.append(res[remove_duplicate]) 581 | result_new.append(res) 582 | result = result_new 583 | 584 | json.dump(result,open(final_result_file,'w')) 585 | print('result file saved to %s'%final_result_file) 586 | 587 | return final_result_file 588 | 589 | 590 | 591 | #### everything below is from https://github.com/salesforce/BLIP/blob/main/utils.py 592 | import math 593 | 594 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 595 | """Decay the learning rate""" 596 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr 597 | for param_group in optimizer.param_groups: 598 | param_group['lr'] = lr 599 | 600 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 601 | """Warmup the learning rate""" 602 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 603 | for param_group in optimizer.param_groups: 604 | param_group['lr'] = lr 605 | 606 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 607 | """Decay the learning rate""" 608 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 609 | for param_group in optimizer.param_groups: 610 | param_group['lr'] = lr 611 | 612 | import numpy as np 613 | import io 614 | import os 615 | import time 616 | from collections import defaultdict, deque 617 | import datetime 618 | 619 | import torch 620 | import torch.distributed as dist 621 | 622 | 623 | class MetricLogger(object): 624 | def __init__(self, delimiter="\t"): 625 | self.meters = defaultdict(SmoothedValue) 626 | self.delimiter = delimiter 627 | 628 | def update(self, **kwargs): 629 | for k, v in kwargs.items(): 630 | if isinstance(v, torch.Tensor): 631 | v = v.item() 632 | assert isinstance(v, (float, int)) 633 | self.meters[k].update(v) 634 | 635 | def __getattr__(self, attr): 636 | if attr in self.meters: 637 | return self.meters[attr] 638 | if attr in self.__dict__: 639 | return self.__dict__[attr] 640 | raise AttributeError("'{}' object has no attribute '{}'".format( 641 | type(self).__name__, attr)) 642 | 643 | def __str__(self): 644 | loss_str = [] 645 | for name, meter in self.meters.items(): 646 | loss_str.append( 647 | "{}: {}".format(name, str(meter)) 648 | ) 649 | return self.delimiter.join(loss_str) 650 | 651 | def global_avg(self): 652 | loss_str = [] 653 | for name, meter in self.meters.items(): 654 | loss_str.append( 655 | "{}: {:.4f}".format(name, meter.global_avg) 656 | ) 657 | return self.delimiter.join(loss_str) 658 | 659 | def synchronize_between_processes(self): 660 | for meter in self.meters.values(): 661 | meter.synchronize_between_processes() 662 | 663 | def add_meter(self, name, meter): 664 | self.meters[name] = meter 665 | 666 | def log_every(self, iterable, print_freq, header=None): 667 | i = 0 668 | if not header: 669 | header = '' 670 | start_time = time.time() 671 | end = time.time() 672 | iter_time = SmoothedValue(fmt='{avg:.4f}') 673 | data_time = SmoothedValue(fmt='{avg:.4f}') 674 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 675 | log_msg = [ 676 | header, 677 | '[{0' + space_fmt + '}/{1}]', 678 | 'eta: {eta}', 679 | '{meters}', 680 | 'time: {time}', 681 | 'data: {data}' 682 | ] 683 | if torch.cuda.is_available(): 684 | log_msg.append('max mem: {memory:.0f}') 685 | log_msg = self.delimiter.join(log_msg) 686 | MB = 1024.0 * 1024.0 687 | for obj in iterable: 688 | data_time.update(time.time() - end) 689 | yield obj 690 | 
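# Control returns here after the caller's loop body finishes with obj, so iter_time measures the
# full per-iteration wall time (data loading plus the training step) before the periodic log line is printed.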
iter_time.update(time.time() - end) 691 | if i % print_freq == 0 or i == len(iterable) - 1: 692 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 693 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 694 | if torch.cuda.is_available(): 695 | print(log_msg.format( 696 | i, len(iterable), eta=eta_string, 697 | meters=str(self), 698 | time=str(iter_time), data=str(data_time), 699 | memory=torch.cuda.max_memory_allocated() / MB)) 700 | else: 701 | print(log_msg.format( 702 | i, len(iterable), eta=eta_string, 703 | meters=str(self), 704 | time=str(iter_time), data=str(data_time))) 705 | i += 1 706 | end = time.time() 707 | total_time = time.time() - start_time 708 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 709 | print('{} Total time: {} ({:.4f} s / it)'.format( 710 | header, total_time_str, total_time / len(iterable))) 711 | 712 | 713 | 714 | class SmoothedValue(object): 715 | """Track a series of values and provide access to smoothed values over a 716 | window or the global series average. 717 | """ 718 | 719 | def __init__(self, window_size=20, fmt=None): 720 | if fmt is None: 721 | fmt = "{median:.4f} ({global_avg:.4f})" 722 | self.deque = deque(maxlen=window_size) 723 | self.total = 0.0 724 | self.count = 0 725 | self.fmt = fmt 726 | 727 | def update(self, value, n=1): 728 | self.deque.append(value) 729 | self.count += n 730 | self.total += value * n 731 | 732 | def synchronize_between_processes(self): 733 | """ 734 | Warning: does not synchronize the deque! 735 | """ 736 | if not is_dist_avail_and_initialized(): 737 | return 738 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 739 | dist.barrier() 740 | dist.all_reduce(t) 741 | t = t.tolist() 742 | self.count = int(t[0]) 743 | self.total = t[1] 744 | 745 | @property 746 | def median(self): 747 | d = torch.tensor(list(self.deque)) 748 | return d.median().item() 749 | 750 | @property 751 | def avg(self): 752 | d = torch.tensor(list(self.deque), dtype=torch.float32) 753 | return d.mean().item() 754 | 755 | @property 756 | def global_avg(self): 757 | return self.total / self.count 758 | 759 | @property 760 | def max(self): 761 | return max(self.deque) 762 | 763 | @property 764 | def value(self): 765 | return self.deque[-1] 766 | 767 | def __str__(self): 768 | return self.fmt.format( 769 | median=self.median, 770 | avg=self.avg, 771 | global_avg=self.global_avg, 772 | max=self.max, 773 | value=self.value) 774 | 775 | class AttrDict(dict): 776 | def __init__(self, *args, **kwargs): 777 | super(AttrDict, self).__init__(*args, **kwargs) 778 | self.__dict__ = self 779 | 780 | 781 | def compute_acc(logits, label, reduction='mean'): 782 | ret = (torch.argmax(logits, dim=1) == label).float() 783 | if reduction == 'none': 784 | return ret.detach() 785 | elif reduction == 'mean': 786 | return ret.mean().item() 787 | 788 | def compute_n_params(model, return_str=True): 789 | tot = 0 790 | for p in model.parameters(): 791 | w = 1 792 | for x in p.shape: 793 | w *= x 794 | tot += w 795 | if return_str: 796 | if tot >= 1e6: 797 | return '{:.1f}M'.format(tot / 1e6) 798 | else: 799 | return '{:.1f}K'.format(tot / 1e3) 800 | else: 801 | return tot 802 | 803 | def setup_for_distributed(is_master): 804 | """ 805 | This function disables printing when not in master process 806 | """ 807 | import builtins as __builtin__ 808 | builtin_print = __builtin__.print 809 | 810 | def print(*args, **kwargs): 811 | force = kwargs.pop('force', False) 812 | if is_master or force: 813 | 
builtin_print(*args, **kwargs) 814 | 815 | __builtin__.print = print 816 | 817 | 818 | def is_dist_avail_and_initialized(): 819 | if not dist.is_available(): 820 | return False 821 | if not dist.is_initialized(): 822 | return False 823 | return True 824 | 825 | 826 | def get_world_size(): 827 | if not is_dist_avail_and_initialized(): 828 | return 1 829 | return dist.get_world_size() 830 | 831 | 832 | def get_rank(): 833 | if not is_dist_avail_and_initialized(): 834 | return 0 835 | return dist.get_rank() 836 | 837 | 838 | def is_main_process(): 839 | return get_rank() == 0 840 | 841 | 842 | def save_on_master(*args, **kwargs): 843 | if is_main_process(): 844 | torch.save(*args, **kwargs) 845 | 846 | 847 | def init_distributed_mode(args): 848 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 849 | args.rank = int(os.environ["RANK"]) 850 | args.world_size = int(os.environ['WORLD_SIZE']) 851 | args.gpu = int(os.environ['LOCAL_RANK']) 852 | elif 'SLURM_PROCID' in os.environ: 853 | args.rank = int(os.environ['SLURM_PROCID']) 854 | args.gpu = args.rank % torch.cuda.device_count() 855 | else: 856 | print('Not using distributed mode') 857 | args.distributed = False 858 | return 859 | 860 | args.distributed = True 861 | 862 | torch.cuda.set_device(args.gpu) 863 | args.dist_backend = 'nccl' 864 | print('| distributed init (rank {}, world {}): {}'.format( 865 | args.rank, args.world_size, args.dist_url), flush=True) 866 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 867 | world_size=args.world_size, rank=args.rank) 868 | torch.distributed.barrier() 869 | setup_for_distributed(args.rank == 0) 870 | 871 | 872 | def load_or_process_file(file_type, process_func, args, data_source): 873 | """ 874 | Load the processed file if it exists, otherwise process the data source and create the file. 875 | 876 | Args: 877 | file_type: The type of the file (e.g., 'train', 'test'). 878 | process_func: The function to process the data source. 879 | args: The arguments required by the process function and to build the filename. 880 | data_source: The source data to be processed. 881 | 882 | Returns: 883 | The loaded data from the file. 884 | """ 885 | filename = f'{args.dataset}_{args.text_encoder}_{file_type}_embed.npz' 886 | 887 | 888 | if not os.path.exists(filename): 889 | print(f'Creating {filename}') 890 | process_func(args, data_source) 891 | else: 892 | print(f'Loading {filename}') 893 | 894 | return np.load(filename) -------------------------------------------------------------------------------- /website/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princetonvisualai/multimodal_dataset_distillation/c696fd6eaa5435394a613376f7fc41b2478ba241/website/poster.pdf --------------------------------------------------------------------------------
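For reference, below is a minimal, hypothetical usage sketch of the `load_or_process_file` helper defined in `utils.py` above. Only the filename pattern `{dataset}_{text_encoder}_{file_type}_embed.npz` comes from the repository; the `encode_captions` callback, the `embed` key, and the embedding width are illustrative stand-ins for the caller's own code.

```python
import numpy as np
from types import SimpleNamespace

from utils import load_or_process_file  # defined in utils.py above

def encode_captions(args, captions):
    # Hypothetical process_func: embed the captions (zeros here as a placeholder)
    # and persist them under the filename pattern load_or_process_file expects.
    embeds = np.zeros((len(captions), 768), dtype=np.float32)
    np.savez(f'{args.dataset}_{args.text_encoder}_train_embed.npz', embed=embeds)

args = SimpleNamespace(dataset='flickr', text_encoder='bert')
captions = ['a dog runs on the beach', 'two people ride bikes along the river']

# The first call computes and saves the .npz; subsequent calls just load it.
data = load_or_process_file('train', encode_captions, args, captions)
print(data['embed'].shape)  # (2, 768)
```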