├── README.md
├── data_selection.py
├── evaluation
│   ├── argument.py
│   ├── main.py
│   ├── model_utils
│   │   ├── models.py
│   │   └── utils.py
│   └── validation
│       ├── main.py
│       └── utils.py
├── imgs
│   ├── DELT_framework.png
│   ├── diversity_comparison.png
│   ├── results.png
│   └── samples.png
├── recover
│   ├── models.py
│   ├── recover.py
│   └── utils.py
└── scripts
    ├── conv4_imagenet1k_synthesis.sh
    ├── conv4_imagenet1k_validation.sh
    ├── resnet18_imagenet1k_synthesis.sh
    ├── resnet18_imagenet1k_validation.sh
    └── select_medium_tiny_rn18_ep50_ipc50.sh

/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # DELT: A Simple Diversity-driven EarlyLate Training for Dataset Distillation (CVPR 2025) 4 | [[arXiv](https://arxiv.org/abs/2411.19946)]   [[Distilled Dataset](https://drive.google.com/file/d/1Rr_ik94FNte75yc4GiKtv927qdBwKskr/view?usp=sharing)] 5 | 6 | 7 |
8 | 9 | Official PyTorch implementation of DELT, which outperforms the SOTA top-1 accuracy by +1.3% on average while increasing diversity per class by more than +5% and reducing synthesis time by up to 39.3%. 10 | 11 |
12 | ![DELT framework overview](imgs/DELT_framework.png) 13 |
14 | 
15 | # Contents
16 | 
17 | - [Abstract](#abstract)
18 | - [📸 Visualization](#-visualization)
19 | - [Usage](#usage)
20 | - [🗃 Dataset Format](#-dataset-format)
21 | - [🤖 Squeeze](#-squeeze)
22 | - [🧲 Ranking and Selection](#-ranking-and-selection)
23 | - [♻️ Recover](#️-recover)
24 | - [🧪 Evaluation](#-evaluation)
25 | - [📈 Results](#-results)
26 | - [📜 Citation](#-citation)
27 | # Abstract
28 | 
29 | Recent advances in dataset distillation have led to solutions in two main directions. The conventional *batch-to-batch* matching mechanism is ideal for small-scale datasets and includes bi-level optimization methods on models and syntheses, such as FRePo, RCIG, and RaT-BPTT, as well as other methods like distribution matching, gradient matching, and weight trajectory matching. Conversely, *batch-to-global* matching typifies decoupled methods, which are particularly advantageous for large-scale datasets. This approach has garnered substantial interest within the community, as seen in SRe$^2$L, G-VBSM, WMDD, and CDA. A primary challenge with the second approach is the lack of diversity among syntheses within each class, since samples are optimized independently and the same global supervision signals are reused across different synthetic images. In this study, we propose a new **D**iversity-driven **E**`arly`**L**`ate` **T**raining (DELT) scheme to enhance the diversity of images in batch-to-global matching with less computation. Our approach is conceptually simple yet effective: it partitions predefined IPC samples into smaller subtasks and employs local optimizations to distill each subset into distributions from distinct phases, reducing the uniformity induced by the unified optimization process. These distilled images from the subtasks demonstrate effective generalization when applied to the entire task. We conduct extensive experiments on CIFAR, Tiny-ImageNet, ImageNet-1K, and its sub-datasets. Our approach outperforms the previous state-of-the-art by **1.3%** on average across different datasets and IPCs (images per class), increasing diversity per class by more than **5%** while reducing synthesis time by up to **39.3%**, enhancing the overall efficiency.
30 | # 📸 Visualization
31 | 
33 | 34 |
35 | 36 | # Usage 37 | 38 | ## 🗃 Dataset Format 39 | 40 | The dataset used for recovery and evaluation should be compatbile with [`ImageFolder`](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) class. Refer to the PyTorch documentation for further details. 41 | 42 | ## 🤖 Squeeze 43 | 44 | For teacher models, we follow both SRe$^2$l, CDA, and RDED, using [official torchvision classification code](https://github.com/pytorch/vision/tree/main/references/classification)—we use torchvision pre-trained model for ImageNet-1K, and the other teacher models could be found in RDED's models [here](https://drive.google.com/drive/folders/1HmrheO6MgX453a5UPJdxPHK4UTv-4aVt?usp=drive_link). 45 | 46 | ## 🧲 Ranking and Selection 47 | 48 | Our DELT uses initialization based on 3 main selection criteria: 49 | 1. `top` selects the easiest images, scoring the highest probability of the true class 50 | 2. `min` selects the hardest images, scoring the lowest probabilities of the true class 51 | 3. `medium` selects the images around the median scores of true class probabilities, used in DELT. 52 | 53 | We provide a sample script that in [`scripts`](scripts/) that selects the medium difficulty images from TinyImageNet dataset. You can run the script as below 54 | 55 | ```bash 56 | bash /path/to/scripts/select_medium_tiny_rn18_ep50_ipc50.sh 57 | ``` 58 | 59 | We overview some variables 60 | - `TRAIN_DIR`: the training directory from which we select the images 61 | - `OUTPUT_DIR`: the output directory where we store the selected images, compatible with [`ImageFolder`](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) 62 | - `RANKER_PATH`: the path to the model used in ranking; if not provided, we use the pre-trained model from torchvision 63 | - `RANKING_FILE`: the path of the output csv file containing the scores of all the images, which is used in selection 64 | 65 | ## ♻️ Recover 66 | 67 | We provide some [`scripts`](scripts/) to automate the experimentation process. For instance to synthesize ImageNet-1K images using ResNet-18 model, you can use 68 | 69 | ```bash 70 | bash /path/to/scripts/resnet18_imagenet1k_synthesis.sh 71 | ``` 72 | 73 | You can use the script while changing the variables as appropriate to your experiment. We overview them below 74 | 75 | - `SYN_PATH`: the target directory where we synthesize data 76 | - `INIT_PATH`: the path to the initial ImageNet-1K images that will be used in initialization 77 | - `IPC`: the number of **I**mage **P**er **C**lass 78 | - `ITR`: total number of update iterations 79 | - `ROUND_ITR`: the number of update iterations in a single round 80 | 81 | ## 🧪 Evaluation 82 | 83 | We provide some evaluation [`scripts`](scripts/) to evaluate the synthesized data. For instance to evaluate the synthesized ImageNet-1K images using ResNet-18 model, you can use 84 | 85 | ```bash 86 | bash /path/to/scripts/resnet18_imagenet1k_validation.sh 87 | ``` 88 | 89 | # 📈 Results 90 | 91 | We compare our approach against different methods on different datasets as below 92 | 93 |
94 | ![Results comparison across datasets and IPCs](imgs/results.png) 95 |
96 | 97 | We also visualize the inter-class average cosine similarity as an indication of diversity (lower values are more diverse) 98 | 99 |
100 | ![Diversity comparison via average cosine similarity](imgs/diversity_comparison.png) 101 |
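As a rough, hypothetical sketch (the exact measurement code is not included in this repository), such an average pairwise cosine similarity can be computed over flattened images, where `images` is an assumed `[N, C, H, W]` tensor of synthetic samples (e.g., all images of one class, with N >= 2):

```python
import torch
import torch.nn.functional as F

def avg_pairwise_cosine_similarity(images: torch.Tensor) -> float:
    """Mean off-diagonal cosine similarity between flattened images."""
    flat = F.normalize(images.flatten(start_dim=1), dim=1)  # [N, D] unit-norm rows
    sim = flat @ flat.T                                     # [N, N] pairwise cosine similarities
    n = sim.size(0)
    off_diag = sim.sum() - sim.diagonal().sum()             # exclude the N self-similarities
    return (off_diag / (n * (n - 1))).item()
```

Lower averages mean the synthesized images are less alike, i.e., more diverse.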
102 | 103 | # 📜 Citation 104 | 105 | ``` 106 | @misc{shen2024deltsimplediversitydrivenearlylate, 107 | title={DELT: A Simple Diversity-driven EarlyLate Training for Dataset Distillation}, 108 | author={Zhiqiang Shen and Ammar Sherif and Zeyuan Yin and Shitong Shao}, 109 | year={2024}, 110 | eprint={2411.19946}, 111 | archivePrefix={arXiv}, 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /data_selection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchvision 3 | import os 4 | import torch 5 | import torchvision.transforms as transforms 6 | import torchvision.models as models 7 | import pandas as pd 8 | from tqdm import tqdm 9 | import torch.nn as nn 10 | import math 11 | import shutil 12 | 13 | class ImageFolder(torchvision.datasets.ImageFolder): 14 | def __init__(self, **kwargs): 15 | super(ImageFolder, self).__init__(**kwargs) 16 | self.image_paths = [] 17 | self.image_relative_paths = [] 18 | self.targets = [] 19 | self.samples = [] 20 | class_dirs = [] 21 | for c_name in os.listdir(self.root): 22 | if os.path.isdir(os.path.join(self.root,c_name)): 23 | class_dirs.append(c_name) 24 | class_dirs.sort() 25 | for c_indx in range(len(class_dirs)): 26 | dir_path = os.path.join(self.root, class_dirs[c_indx]) 27 | img_names = os.listdir(dir_path) 28 | for i in range(len(img_names)): 29 | self.image_paths.append(os.path.join(dir_path, img_names[i])) 30 | self.image_relative_paths.append(os.path.join(class_dirs[c_indx], img_names[i])) 31 | self.targets.append(c_indx) 32 | 33 | def __getitem__(self, index): 34 | sample = self.loader(self.image_paths[index]) 35 | sample = self.transform(sample) 36 | return sample, self.targets[index], self.image_relative_paths[index] 37 | 38 | def __len__(self): 39 | return len(self.targets) 40 | 41 | def get_args(): 42 | parser = argparse.ArgumentParser("Data selection for dataset distillation") 43 | """Data save flags""" 44 | parser.add_argument("--dataset", type=str, default="imagenet-1k") 45 | parser.add_argument("--data-path", type=str, help="location of training data") 46 | parser.add_argument("--output-path", type=str, default="./selected-data", help="location of the selected data") 47 | parser.add_argument("--ranker-path", type=str, default="", help="path to the ranker model") 48 | parser.add_argument("--ranker-arch", type=str, default="resnet18", help="for loading the model") 49 | 50 | parser.add_argument("--ranking-file", type=str, default="", help="csv file that includes the ranking of the data") 51 | 52 | parser.add_argument("--store-rank-file", action="store_true", default=False, help="store csv file that includes the ranking of the data for future use") 53 | parser.add_argument("--ipc", type=int, default=50, help="number of IPC to select") 54 | parser.add_argument("--selection-criteria", type=str, default="medium") 55 | 56 | parser.add_argument("--batch-size", type=int, default=200, help="number of images to load at the same time") 57 | parser.add_argument("--workers", type=int, default=16, help="number of workers in data loader") 58 | parser.add_argument('--gpu-device', default="-1", type=str) 59 | args = parser.parse_args() 60 | 61 | if args.gpu_device != "-1": 62 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device 63 | print("set CUDA_VISIBLE_DEVICES to ", args.gpu_device) 64 | 65 | print(args) 66 | return args 67 | 68 | def get_model(args): 69 | if args.dataset == "imagenet-1k": 70 | # ImageNet 1K dataset 71 | if 
args.ranker_path: 72 | model = models.get_model(args.ranker_arch, weights=False, num_classes=1000) 73 | checkpoint = torch.load(args.ranker_path, map_location="cpu") 74 | model.load_state_dict(checkpoint["model"]) 75 | model = nn.DataParallel(model).cuda() 76 | model.eval() 77 | for p in model.parameters(): 78 | p.requires_grad = False 79 | return model 80 | elif not args.ranker_path and args.ranker_arch: 81 | model = models.get_model(args.ranker_arch, weights="DEFAULT") 82 | model = nn.DataParallel(model).cuda() 83 | model.eval() 84 | for p in model.parameters(): 85 | p.requires_grad = False 86 | return model 87 | elif not args.ranker_path: 88 | print(f"You must either provide a checkpoint path using --ranker-path, ") 89 | print("or provide an architecture name using --ranker-arch to load torchvision pretrained weights") 90 | return None 91 | elif args.dataset == "imagenet-100": 92 | assert args.ranker_path 93 | model = models.get_model(args.ranker_arch, weights=False, num_classes=100) 94 | checkpoint = torch.load(args.ranker_path, map_location="cpu") 95 | model.load_state_dict(checkpoint["model"]) 96 | model = nn.DataParallel(model).cuda() 97 | model.eval() 98 | for p in model.parameters(): 99 | p.requires_grad = False 100 | return model 101 | elif args.dataset == "tiny-imagenet": 102 | assert args.ranker_path 103 | model = models.get_model(args.ranker_arch, weights=False, num_classes=200) 104 | model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 105 | model.maxpool = nn.Identity() 106 | checkpoint = torch.load(args.ranker_path, map_location="cpu") 107 | model.load_state_dict(checkpoint["model"]) 108 | 109 | model = nn.DataParallel(model).cuda() 110 | model.eval() 111 | for p in model.parameters(): 112 | p.requires_grad = False 113 | return model 114 | elif args.dataset in ["image-woof", "image-nette"]: 115 | model = models.get_model(args.ranker_arch, weights=False, num_classes=10) 116 | checkpoint = torch.load(args.ranker_path, map_location="cpu") 117 | model.load_state_dict(checkpoint["model"]) 118 | model = nn.DataParallel(model).cuda() 119 | model.eval() 120 | for p in model.parameters(): 121 | p.requires_grad = False 122 | return model 123 | elif args.dataset == "cifar10": 124 | assert args.ranker_path 125 | model = models.get_model(args.ranker_arch, weights=False, num_classes=10) 126 | model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 127 | model.maxpool = nn.Identity() 128 | checkpoint = torch.load(args.ranker_path, map_location="cpu") 129 | model.load_state_dict(checkpoint["model"]) 130 | 131 | model = nn.DataParallel(model).cuda() 132 | model.eval() 133 | for p in model.parameters(): 134 | p.requires_grad = False 135 | return model 136 | elif args.dataset == "cifar100": 137 | assert args.ranker_path 138 | model = models.get_model(args.ranker_arch, weights=False, num_classes=100) 139 | model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 140 | model.maxpool = nn.Identity() 141 | checkpoint = torch.load(args.ranker_path, map_location="cpu") 142 | model.load_state_dict(checkpoint["model"]) 143 | 144 | model = nn.DataParallel(model).cuda() 145 | model.eval() 146 | for p in model.parameters(): 147 | p.requires_grad = False 148 | return model 149 | def get_dataloader(args): 150 | if args.dataset in ["imagenet-1k", "imagenet-100", "image-woof", "image-nette"]: 151 | normalize = transforms.Normalize( 152 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 
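# standard ImageNet channel statistics; note that the same normalization
# constants are reused for every dataset branch in this script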
153 | ) 154 | dataset = ImageFolder(root=args.data_path, transform=transforms.Compose( 155 | [ 156 | transforms.Resize(224 // 7 * 8, antialias=True), 157 | transforms.CenterCrop(224), 158 | transforms.ToTensor(), 159 | normalize, 160 | ] 161 | ),) 162 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, 163 | num_workers= args.workers, shuffle=False) 164 | elif args.dataset == "tiny-imagenet": 165 | normalize = transforms.Normalize( 166 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 167 | ) 168 | dataset = ImageFolder(root=args.data_path, transform=transforms.Compose( 169 | [ 170 | transforms.Resize(64 // 7 * 8, antialias=True), 171 | transforms.CenterCrop(64), 172 | transforms.ToTensor(), 173 | normalize, 174 | ] 175 | ),) 176 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, 177 | num_workers= args.workers, shuffle=False) 178 | elif args.dataset in ["cifar10", "cifar100"] : 179 | normalize = transforms.Normalize( 180 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 181 | ) 182 | dataset = ImageFolder(root=args.data_path, transform=transforms.Compose( 183 | [ 184 | transforms.Resize(32 // 7 * 8, antialias=True), 185 | transforms.CenterCrop(32), 186 | transforms.ToTensor(), 187 | normalize, 188 | ] 189 | ),) 190 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, 191 | num_workers= args.workers, shuffle=False) 192 | 193 | def rank(args, model, data_loader): 194 | softmax = torch.nn.Softmax(dim=1) 195 | 196 | ranking = { 197 | 'score': [], 198 | 'img_path': [], 199 | } 200 | 201 | with torch.no_grad(): 202 | for images, labels, paths in tqdm(data_loader): 203 | images, labels = images.cuda(), labels.cuda() 204 | output = model(images) 205 | output = softmax(output) 206 | pr = torch.gather(output, dim=1, index=labels.unsqueeze(-1)).squeeze(-1) 207 | ranking['img_path'] += list(paths) 208 | ranking['score'] += pr.tolist() 209 | 210 | ranking_df = pd.DataFrame(ranking) 211 | ranking_df["class"] = ranking_df['img_path'].apply(lambda path: path.split('/')[0]) 212 | # sort the values 213 | ranking_df = ranking_df.sort_values(['class', 'score'], ascending=[True, False]) 214 | if args.store_rank_file: 215 | dir = "/".join(args.ranking_file.split("/")[:-1]) 216 | if not os.path.exists(dir): 217 | os.makedirs(dir) 218 | ranking_df.to_csv(args.ranking_file, index=False) 219 | return ranking_df 220 | 221 | def store_top(ranked_df, class_name, root_dir, target_dir, ipc): 222 | target_dir = os.path.join(target_dir, f"{class_name}") 223 | if os.path.exists(target_dir): 224 | shutil.rmtree(target_dir) 225 | 226 | os.makedirs(target_dir) 227 | df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[False]) 228 | for img_indx in range(ipc): 229 | img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path']) 230 | file_name= df.iloc[img_indx]['img_path'].split("/")[-1] 231 | new_path = os.path.join(target_dir, f"{img_indx:03}_{file_name}") 232 | shutil.copy(img_path, new_path) 233 | 234 | def store_min(ranked_df, class_name, root_dir, target_dir, ipc): 235 | target_dir = os.path.join(target_dir, f"{class_name}") 236 | if os.path.exists(target_dir): 237 | shutil.rmtree(target_dir) 238 | 239 | os.makedirs(target_dir) 240 | df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[True]) 241 | for img_indx in range(ipc): 242 | img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path']) 243 | file_name= df.iloc[img_indx]['img_path'].split("/")[-1] 244 | new_path = 
os.path.join(target_dir, f"{img_indx:03}_{file_name}")
245 |         shutil.copy(img_path, new_path)
246 | 
247 | def store_medium(ranked_df, class_name, root_dir, target_dir, ipc):
248 |     target_dir = os.path.join(target_dir, f"{class_name}")
249 |     if os.path.exists(target_dir):
250 |         shutil.rmtree(target_dir)
251 | 
252 |     os.makedirs(target_dir)
253 |     df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[True])
254 |     mid_indx = math.ceil(len(df)/2)
255 |     pos_neg = -1
256 |     for indx in range(ipc):
257 |         img_indx = mid_indx + int(pos_neg*(indx+1)/2)
258 |         pos_neg = pos_neg*-1
259 |         img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path'])
260 |         file_name = df.iloc[img_indx]['img_path'].split("/")[-1]
261 |         new_path = os.path.join(target_dir, f"{indx:03}_{file_name}")
262 |         shutil.copy(img_path, new_path)
263 | 
264 | def store_imgs(ranked_df, class_name, root_dir, target_dir, ipc):
265 |     target_dir = os.path.join(target_dir, f"{class_name}")
266 |     if os.path.exists(target_dir):
267 |         shutil.rmtree(target_dir)
268 | 
269 |     os.makedirs(target_dir)
270 |     df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[False])
271 |     for img_indx in range(ipc):
272 |         img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path'])
273 |         file_name = df.iloc[img_indx]['img_path'].split("/")[-1]
274 |         new_path = os.path.join(target_dir, f"{img_indx:03}_{file_name}")
275 |         shutil.copy(img_path, new_path)
276 | 
277 | if __name__ == "__main__":
278 |     args = get_args()
279 |     if not args.store_rank_file and args.ranking_file:
280 |         ranking_df = pd.read_csv(args.ranking_file)
281 |         if "Unnamed: 0" in ranking_df.columns.values.tolist():
282 |             del ranking_df['Unnamed: 0']
283 |     else:
284 |         model = get_model(args)
285 |         loader = get_dataloader(args)
286 |         if model is None or loader is None:
287 |             exit()
288 |         ranking_df = rank(args, model, loader)
289 | 
290 |     if args.selection_criteria == "medium":
291 |         store = store_medium
292 |     elif args.selection_criteria == "top":
293 |         store = store_top
294 |     elif args.selection_criteria == "min":
295 |         store = store_min
296 |     else:
297 |         print("Unknown selection criteria")
298 |         exit()
299 | 
300 |     for class_name in tqdm(ranking_df['class'].unique()):
301 |         store(ranking_df, class_name, root_dir=args.data_path,
302 |               target_dir=args.output_path, ipc=args.ipc) -------------------------------------------------------------------------------- /evaluation/argument.py: --------------------------------------------------------------------------------
 1 | import argparse
 2 | 
 3 | parser = argparse.ArgumentParser("Evaluation scheme for DELT")
 4 | """Architecture and Dataset"""
 5 | parser.add_argument(
 6 |     "--arch-name",
 7 |     type=str,
 8 |     default="resnet18",
 9 |     help="arch name from pretrained torchvision models",
10 | )
11 | parser.add_argument(
12 |     "--subset",
13 |     type=str,
14 |     default="imagenet-1k",
15 | )
16 | parser.add_argument(
17 |     "--nclass",
18 |     type=int,
19 |     default=1000,
20 |     help="number of synthesized classes",
21 | )
22 | parser.add_argument(
23 |     "--ipc",
24 |     type=int,
25 |     default=50,
26 |     help="number of synthesized images per class",
27 | )
28 | parser.add_argument(
29 |     "--input-size",
30 |     default=224,
31 |     type=int,
32 |     metavar="S",
33 | )
34 | """Evaluation arguments"""
35 | parser.add_argument("--re-batch-size", default=0, type=int, metavar="N")
36 | parser.add_argument(
37 |     "--re-accum-steps",
38 |     type=int,
39 |     default=1,
40 |     help="gradient accumulation steps for small gpu memory",
41 | )
42 | 
parser.add_argument( 43 | "--mix-type", 44 | default="cutmix", 45 | type=str, 46 | choices=["mixup", "cutmix", None], 47 | help="mixup or cutmix or None", 48 | ) 49 | parser.add_argument( 50 | "--stud-name", 51 | type=str, 52 | default="resnet18", 53 | help="arch name from torchvision models", 54 | ) 55 | parser.add_argument( 56 | "--val-ipc", 57 | type=int, 58 | default=30, 59 | ) 60 | parser.add_argument( 61 | "--workers", 62 | default=4, 63 | type=int, 64 | metavar="N", 65 | help="number of data loading workers (default: 4)", 66 | ) 67 | parser.add_argument( 68 | "--classes", 69 | type=list, 70 | help="number of classes for synthesis", 71 | ) 72 | parser.add_argument( 73 | "--temperature", 74 | type=float, 75 | help="temperature for distillation loss", 76 | ) 77 | parser.add_argument( 78 | "--val-dir", 79 | type=str, 80 | default="/workspace/data/val/", 81 | help="path to validation dataset", 82 | ) 83 | parser.add_argument( 84 | "--min-scale-crops", type=float, default=0.08, help="argument in RandomResizedCrop" 85 | ) 86 | parser.add_argument( 87 | "--max-scale-crops", type=float, default=1, help="argument in RandomResizedCrop" 88 | ) 89 | parser.add_argument("--re-epochs", default=300, type=int) 90 | parser.add_argument( 91 | "--syn-data-path", 92 | type=str, 93 | default="/workspace/data/DELT_IN_1K/ipc50_cr5_mipc300", 94 | help="path to synthetic data", 95 | ) 96 | parser.add_argument( 97 | "--seed", default=42, type=int, help="seed for initializing training. " 98 | ) 99 | parser.add_argument( 100 | "--mixup", 101 | type=float, 102 | default=0.8, 103 | help="mixup alpha, mixup enabled if > 0. (default: 0.8)", 104 | ) 105 | parser.add_argument( 106 | "--cutmix", 107 | type=float, 108 | default=1.0, 109 | help="cutmix alpha, cutmix enabled if > 0. 
(default: 1.0)", 110 | ) 111 | parser.add_argument("--cos", default=True, help="cosine lr scheduler") 112 | parser.add_argument("--use-rand-augment", default=False, action="store_true", help="use Timm's RandAugment") 113 | parser.add_argument('--rand-augment-config', type=str, default="rand-m6-n2-mstd1.0", help='RandAugment configuration') 114 | 115 | # sgd 116 | parser.add_argument("--sgd", default=False, action="store_true", help="sgd optimizer") 117 | parser.add_argument( 118 | "-lr", 119 | "--learning-rate", 120 | type=float, 121 | default=0.1, 122 | help="sgd init learning rate", 123 | ) 124 | parser.add_argument("--momentum", type=float, default=0.9, help="sgd momentum") 125 | parser.add_argument("--weight-decay", type=float, default=1e-4, help="sgd weight decay") 126 | 127 | # adamw 128 | parser.add_argument("--adamw-lr", type=float, default=0, help="adamw learning rate") 129 | parser.add_argument( 130 | "--adamw-weight-decay", type=float, default=0.01, help="adamw weight decay" 131 | ) 132 | parser.add_argument( 133 | "--exp-name", 134 | type=str, 135 | help="name of the experiment, subfolder under syn_data_path", 136 | ) 137 | parser.add_argument('--wandb-api-key', type=str, 138 | default=None, help='wandb api key') 139 | parser.add_argument('--wandb-project', type=str, 140 | default='imagnet_1k_distillation', help='wandb project name') 141 | parser.add_argument('--gpu-device', default="-1", type=str) 142 | args = parser.parse_args() 143 | 144 | # set up dataset settings 145 | # set smaller val_ipc only for quick validation 146 | if args.subset in [ 147 | "imagenet-a", 148 | "imagenet-b", 149 | "imagenet-c", 150 | "imagenet-d", 151 | "imagenet-e", 152 | "imagenet-birds", 153 | "imagenet-fruits", 154 | "imagenet-cats", 155 | "imagenet-10", 156 | ]: 157 | args.nclass = 10 158 | args.classes = range(args.nclass) 159 | args.val_ipc = 50 160 | args.input_size = 224 161 | 162 | elif args.subset == "imagenet-nette": 163 | args.nclass = 10 164 | args.classes = range(args.nclass) 165 | args.val_ipc = 50 166 | args.input_size = 224 167 | if args.arch_name in ["conv5", "conv6"] or args.stud_name in ["conv5", "conv6"]: 168 | args.input_size = 128 169 | 170 | elif args.subset == "imagenet-woof": 171 | args.nclass = 10 172 | args.classes = range(args.nclass) 173 | args.val_ipc = 50 174 | args.input_size = 224 175 | if args.arch_name in ["conv5", "conv6"] or args.stud_name in ["conv5", "conv6"]: 176 | args.input_size = 128 177 | 178 | elif args.subset == "imagenet-100": 179 | args.nclass = 100 180 | args.classes = range(args.nclass) 181 | args.val_ipc = 50 182 | args.input_size = 224 183 | if args.arch_name in ["conv5", "conv6"] or args.stud_name in ["conv5", "conv6"]: 184 | args.input_size = 128 185 | 186 | elif args.subset == "imagenet-1k": 187 | args.nclass = 1000 188 | args.classes = range(args.nclass) 189 | args.val_ipc = 50 190 | args.input_size = 224 191 | if "conv" in args.arch_name or "conv" in args.stud_name: 192 | args.input_size = 64 193 | 194 | elif args.subset == "cifar10": 195 | args.nclass = 10 196 | args.classes = range(args.nclass) 197 | args.val_ipc = 1000 198 | args.input_size = 32 199 | 200 | elif args.subset == "cifar100": 201 | args.nclass = 100 202 | args.classes = range(args.nclass) 203 | args.val_ipc = 100 204 | args.input_size = 32 205 | 206 | elif args.subset == "tinyimagenet": 207 | args.nclass = 200 208 | args.classes = range(args.nclass) 209 | args.val_ipc = 50 210 | args.input_size = 64 211 | 212 | args.nclass = len(args.classes) 213 | 214 | # set up batch size 
215 | if args.re_batch_size == 0: 216 | if args.ipc == 100: 217 | args.re_batch_size = 100 218 | args.workers = 8 219 | elif args.ipc == 50: 220 | args.re_batch_size = 100 221 | args.workers = 4 222 | elif args.ipc in [30, 40]: 223 | args.re_batch_size = 100 224 | args.workers = 4 225 | elif args.ipc in [10, 20]: 226 | args.re_batch_size = 50 227 | args.workers = 4 228 | elif args.ipc == 1: 229 | args.re_batch_size = 10 230 | args.workers = 4 231 | 232 | if args.nclass == 10: 233 | args.re_batch_size *= 1 234 | if args.nclass == 100: 235 | args.re_batch_size *= 2 236 | if args.nclass == 1000: 237 | args.re_batch_size *= 2 238 | if args.subset in ["imagenet-1k", "imagenet-100"] and args.stud_name == "resnet101": 239 | args.re_batch_size = args.re_batch_size//2 240 | # ! tinyimagenet 241 | if args.subset == "tinyimagenet": 242 | args.re_batch_size = 100 243 | 244 | # reset batch size below ipc * nclass 245 | if args.re_batch_size > args.ipc * args.nclass: 246 | args.re_batch_size = int(args.ipc * args.nclass) 247 | 248 | # reset batch size with re_accum_steps 249 | if args.re_accum_steps != 1: 250 | args.re_batch_size = int(args.re_batch_size / args.re_accum_steps) 251 | 252 | # temperature 253 | if args.mix_type == "mixup": 254 | args.temperature = 4 255 | elif args.mix_type == "cutmix": 256 | args.temperature = 20 257 | 258 | # adamw learning rate 259 | if args.stud_name == "vgg11": 260 | args.adamw_lr = 0.0005 261 | elif args.stud_name == "conv3": 262 | args.adamw_lr = 0.001 263 | elif args.stud_name == "conv4": 264 | args.adamw_lr = 0.001 265 | elif args.stud_name == "conv5": 266 | args.adamw_lr = 0.001 267 | elif args.stud_name == "conv6": 268 | args.adamw_lr = 0.001 269 | elif args.stud_name == "resnet18": 270 | args.adamw_lr = 0.001 271 | elif args.stud_name == "resnet18_modified": 272 | args.adamw_lr = 0.001 273 | elif args.stud_name == "efficientnet_b0": 274 | args.adamw_lr = 0.002 275 | elif args.stud_name in ["mobilenet_v2", "mobilenet_v2_modified"]: 276 | args.adamw_lr = 0.0025 277 | elif args.stud_name == "alexnet": 278 | args.adamw_lr = 0.0001 279 | elif args.stud_name == "resnet50": 280 | args.adamw_lr = 0.001 281 | elif args.stud_name == "resnet101": 282 | args.adamw_lr = 0.001 283 | elif args.stud_name == "resnet101_modified": 284 | args.adamw_lr = 0.001 285 | elif args.stud_name == "vit_b_16": 286 | args.adamw_lr = 0.0001 287 | elif args.stud_name == "swin_v2_t": 288 | args.adamw_lr = 0.0001 289 | 290 | # special experiment 291 | if ( 292 | args.subset == "cifar100" 293 | and args.arch_name == "conv3" 294 | and args.stud_name == "conv3" 295 | ): 296 | args.re_batch_size = 25 297 | args.adamw_lr = 0.002 298 | 299 | print(f"args: {args}") 300 | -------------------------------------------------------------------------------- /evaluation/main.py: -------------------------------------------------------------------------------- 1 | """This code is adopted from RDED here: https://github.com/LINs-lab/RDED/tree/main""" 2 | from argument import args 3 | from validation.main import ( 4 | main as valid_main, 5 | ) # The relabel and validation are combined here for fast experiment 6 | import warnings 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | if __name__ == "__main__": 11 | valid_main(args) 12 | -------------------------------------------------------------------------------- /evaluation/model_utils/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # 
Conv-3 model 6 | class ConvNet(nn.Module): 7 | def __init__( 8 | self, 9 | num_classes, 10 | net_norm="batch", 11 | net_depth=3, 12 | net_width=128, 13 | channel=3, 14 | net_act="relu", 15 | net_pooling="avgpooling", 16 | im_size=(32, 32), 17 | ): 18 | # print(f"Define Convnet (depth {net_depth}, width {net_width}, norm {net_norm})") 19 | super(ConvNet, self).__init__() 20 | if net_act == "sigmoid": 21 | self.net_act = nn.Sigmoid() 22 | elif net_act == "relu": 23 | self.net_act = nn.ReLU() 24 | elif net_act == "leakyrelu": 25 | self.net_act = nn.LeakyReLU(negative_slope=0.01) 26 | else: 27 | exit("unknown activation function: %s" % net_act) 28 | 29 | if net_pooling == "maxpooling": 30 | self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2) 31 | elif net_pooling == "avgpooling": 32 | self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2) 33 | elif net_pooling == "none": 34 | self.net_pooling = None 35 | else: 36 | exit("unknown net_pooling: %s" % net_pooling) 37 | 38 | self.depth = net_depth 39 | self.net_norm = net_norm 40 | 41 | self.layers, shape_feat = self._make_layers( 42 | channel, net_width, net_depth, net_norm, net_pooling, im_size 43 | ) 44 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2] 45 | self.classifier = nn.Linear(num_feat, num_classes) 46 | 47 | def forward(self, x, return_features=False): 48 | for d in range(self.depth): 49 | x = self.layers["conv"][d](x) 50 | if len(self.layers["norm"]) > 0: 51 | x = self.layers["norm"][d](x) 52 | x = self.layers["act"][d](x) 53 | if len(self.layers["pool"]) > 0: 54 | x = self.layers["pool"][d](x) 55 | 56 | # x = nn.functional.avg_pool2d(x, x.shape[-1]) 57 | out = x.view(x.shape[0], -1) 58 | logit = self.classifier(out) 59 | 60 | if return_features: 61 | return logit, out 62 | else: 63 | return logit 64 | 65 | def get_feature( 66 | self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False 67 | ): 68 | if idx_to == -1: 69 | idx_to = idx_from 70 | features = [] 71 | 72 | for d in range(self.depth): 73 | x = self.layers["conv"][d](x) 74 | if self.net_norm: 75 | x = self.layers["norm"][d](x) 76 | x = self.layers["act"][d](x) 77 | if self.net_pooling: 78 | x = self.layers["pool"][d](x) 79 | features.append(x) 80 | if idx_to < len(features): 81 | return features[idx_from : idx_to + 1] 82 | 83 | if return_prob: 84 | out = x.view(x.size(0), -1) 85 | logit = self.classifier(out) 86 | prob = torch.softmax(logit, dim=-1) 87 | return features, prob 88 | elif return_logit: 89 | out = x.view(x.size(0), -1) 90 | logit = self.classifier(out) 91 | return features, logit 92 | else: 93 | return features[idx_from : idx_to + 1] 94 | 95 | def _get_normlayer(self, net_norm, shape_feat): 96 | # shape_feat = (c * h * w) 97 | if net_norm == "batch": 98 | norm = nn.BatchNorm2d(shape_feat[0], affine=True) 99 | elif net_norm == "layer": 100 | norm = nn.LayerNorm(shape_feat, elementwise_affine=True) 101 | elif net_norm == "instance": 102 | norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 103 | elif net_norm == "group": 104 | norm = nn.GroupNorm(4, shape_feat[0], affine=True) 105 | elif net_norm == "none": 106 | norm = None 107 | else: 108 | norm = None 109 | exit("unknown net_norm: %s" % net_norm) 110 | return norm 111 | 112 | def _make_layers( 113 | self, channel, net_width, net_depth, net_norm, net_pooling, im_size 114 | ): 115 | layers = {"conv": [], "norm": [], "act": [], "pool": []} 116 | 117 | in_channels = channel 118 | if im_size[0] == 28: 119 | im_size = (32, 32) 120 | shape_feat = [in_channels, im_size[0], im_size[1]] 
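# each depth level appends conv -> (optional norm) -> act -> (optional pool);
# every pooling layer halves the spatial dims, and shape_feat tracks the running
# feature shape so the final linear classifier can be sized from the last block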
121 | 122 | for d in range(net_depth): 123 | layers["conv"] += [ 124 | nn.Conv2d( 125 | in_channels, 126 | net_width, 127 | kernel_size=3, 128 | padding=3 if channel == 1 and d == 0 else 1, 129 | ) 130 | ] 131 | shape_feat[0] = net_width 132 | if net_norm != "none": 133 | layers["norm"] += [self._get_normlayer(net_norm, shape_feat)] 134 | layers["act"] += [self.net_act] 135 | in_channels = net_width 136 | if net_pooling != "none": 137 | layers["pool"] += [self.net_pooling] 138 | shape_feat[1] //= 2 139 | shape_feat[2] //= 2 140 | 141 | layers["conv"] = nn.ModuleList(layers["conv"]) 142 | layers["norm"] = nn.ModuleList(layers["norm"]) 143 | layers["act"] = nn.ModuleList(layers["act"]) 144 | layers["pool"] = nn.ModuleList(layers["pool"]) 145 | layers = nn.ModuleDict(layers) 146 | 147 | return layers, shape_feat 148 | 149 | """https://github.com/megvii-research/mdistiller/blob/a08d46f10d6102bd6e3f258ca5ac880b020ea259/mdistiller/models/cifar/mv2_tinyimagenet.py""" 150 | class LinearBottleNeck(nn.Module): 151 | 152 | def __init__(self, in_channels, out_channels, stride, t=6, class_num=100): 153 | super().__init__() 154 | 155 | self.residual = nn.Sequential( 156 | nn.Conv2d(in_channels, in_channels * t, 1), 157 | nn.BatchNorm2d(in_channels * t), 158 | nn.ReLU6(inplace=True), 159 | 160 | nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), 161 | nn.BatchNorm2d(in_channels * t), 162 | nn.ReLU6(inplace=True), 163 | 164 | nn.Conv2d(in_channels * t, out_channels, 1), 165 | nn.BatchNorm2d(out_channels) 166 | ) 167 | 168 | self.stride = stride 169 | self.in_channels = in_channels 170 | self.out_channels = out_channels 171 | 172 | def forward(self, x): 173 | 174 | residual = self.residual(x) 175 | 176 | if self.stride == 1 and self.in_channels == self.out_channels: 177 | residual += x 178 | 179 | return residual 180 | 181 | class MobileNetV2(nn.Module): 182 | 183 | def __init__(self, num_classes=100): 184 | super().__init__() 185 | 186 | self.pre = nn.Sequential( 187 | nn.Conv2d(3, 32, 1, padding=1), 188 | nn.BatchNorm2d(32), 189 | nn.ReLU6(inplace=True) 190 | ) 191 | 192 | self.stage1 = LinearBottleNeck(32, 16, 1, 1) 193 | self.stage2 = self._make_stage(2, 16, 24, 2, 6) 194 | self.stage3 = self._make_stage(3, 24, 32, 2, 6) 195 | self.stage4 = self._make_stage(4, 32, 64, 2, 6) 196 | self.stage5 = self._make_stage(3, 64, 96, 1, 6) 197 | self.stage6 = self._make_stage(3, 96, 160, 1, 6) 198 | self.stage7 = LinearBottleNeck(160, 320, 1, 6) 199 | 200 | self.conv1 = nn.Sequential( 201 | nn.Conv2d(320, 1280, 1), 202 | nn.BatchNorm2d(1280), 203 | nn.ReLU6(inplace=True) 204 | ) 205 | 206 | self.conv2 = nn.Conv2d(1280, num_classes, 1) 207 | 208 | def forward(self, x): 209 | x = self.pre(x) 210 | x = self.stage1(x) 211 | x = self.stage2(x) 212 | x = self.stage3(x) 213 | x = self.stage4(x) 214 | x = self.stage5(x) 215 | x = self.stage6(x) 216 | x = self.stage7(x) 217 | x = self.conv1(x) 218 | x = F.adaptive_avg_pool2d(x, 1) 219 | x = self.conv2(x) 220 | x = x.view(x.size(0), -1) 221 | 222 | return x 223 | 224 | def _make_stage(self, repeat, in_channels, out_channels, stride, t): 225 | 226 | layers = [] 227 | layers.append(LinearBottleNeck(in_channels, out_channels, stride, t)) 228 | 229 | while repeat - 1: 230 | layers.append(LinearBottleNeck(out_channels, out_channels, 1, t)) 231 | repeat -= 1 232 | 233 | return nn.Sequential(*layers) 234 | 235 | def mobilenetv2_tinyimagenet(): 236 | return MobileNetV2(num_classes = 200) 237 | 238 | def mobilenetv2_cifar10(): 239 | 
return MobileNetV2(num_classes = 10) -------------------------------------------------------------------------------- /evaluation/model_utils/utils.py: -------------------------------------------------------------------------------- 1 | """this code is modified from the RDED repo: https://github.com/LINs-lab/RDED/tree/main""" 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from argument import args as sys_args 5 | import torch 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | 9 | import torchvision.models as thmodels 10 | from torchvision.models._api import WeightsEnum 11 | from torch.hub import load_state_dict_from_url 12 | from model_utils.models import ConvNet, mobilenetv2_tinyimagenet, mobilenetv2_cifar10 13 | 14 | import math 15 | 16 | # use 0 to pad "other three picture" 17 | def pad(input_tensor, target_height, target_width=None): 18 | if target_width is None: 19 | target_width = target_height 20 | vertical_padding = target_height - input_tensor.size(2) 21 | horizontal_padding = target_width - input_tensor.size(3) 22 | 23 | top_padding = vertical_padding // 2 24 | bottom_padding = vertical_padding - top_padding 25 | left_padding = horizontal_padding // 2 26 | right_padding = horizontal_padding - left_padding 27 | 28 | padded_tensor = F.pad( 29 | input_tensor, (left_padding, right_padding, top_padding, bottom_padding) 30 | ) 31 | 32 | return padded_tensor 33 | 34 | 35 | def batched_forward(model, tensor, batch_size): 36 | total_samples = tensor.size(0) 37 | 38 | all_outputs = [] 39 | 40 | model.eval() 41 | 42 | with torch.no_grad(): 43 | for i in range(0, total_samples, batch_size): 44 | batch_data = tensor[i : min(i + batch_size, total_samples)] 45 | 46 | output = model(batch_data) 47 | 48 | all_outputs.append(output) 49 | 50 | final_output = torch.cat(all_outputs, dim=0) 51 | 52 | return final_output 53 | 54 | 55 | class MultiRandomCrop(torch.nn.Module): 56 | def __init__(self, num_crop=5, size=224, factor=2): 57 | super().__init__() 58 | self.num_crop = num_crop 59 | self.size = size 60 | self.factor = factor 61 | 62 | def forward(self, image): 63 | cropper = transforms.RandomResizedCrop( 64 | self.size // self.factor, 65 | ratio=(1, 1), 66 | antialias=True, 67 | ) 68 | patches = [] 69 | for _ in range(self.num_crop): 70 | patches.append(cropper(image)) 71 | return torch.stack(patches, 0) 72 | 73 | def __repr__(self) -> str: 74 | detail = f"(num_crop={self.num_crop}, size={self.size})" 75 | return f"{self.__class__.__name__}{detail}" 76 | 77 | 78 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 79 | 80 | denormalize = transforms.Compose( 81 | [ 82 | transforms.Normalize( 83 | mean=[0.0, 0.0, 0.0], std=[1 / 0.229, 1 / 0.224, 1 / 0.225] 84 | ), 85 | transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1.0, 1.0, 1.0]), 86 | ] 87 | ) 88 | 89 | 90 | def get_state_dict(self, *args, **kwargs): 91 | kwargs.pop("check_hash") 92 | return load_state_dict_from_url(self.url, *args, **kwargs) 93 | 94 | 95 | WeightsEnum.get_state_dict = get_state_dict 96 | 97 | 98 | def cross_entropy(y_pre, y): 99 | y_pre = F.softmax(y_pre, dim=1) 100 | return (-torch.log(y_pre.gather(1, y.view(-1, 1))))[:, 0] 101 | 102 | 103 | def selector(n, model, images, labels, size, m=5): 104 | with torch.no_grad(): 105 | # [mipc, m, 3, 224, 224] 106 | images = images.cuda() 107 | s = images.shape 108 | 109 | # [mipc * m, 3, 224, 224] 110 | images = images.permute(1, 0, 2, 3, 4) 111 | images = images.reshape(s[0] * s[1], s[2], s[3], 
s[4]) 112 | 113 | # [mipc * m, 1] 114 | labels = labels.repeat(m).cuda() 115 | 116 | # [mipc * m, n_class] 117 | batch_size = s[0] # Change it for small GPU memory 118 | preds = batched_forward(model, pad(images, size).cuda(), batch_size) 119 | 120 | # [mipc * m] 121 | dist = cross_entropy(preds, labels) 122 | 123 | # [m, mipc] 124 | dist = dist.reshape(m, s[0]) 125 | 126 | n_patch_per_image = math.ceil(n/dist.shape[-1]) 127 | 128 | # [mipc] 129 | # [n_patch_per_image, mipc] 130 | index = torch.argsort(dist, dim=0)[:n_patch_per_image,:] 131 | # print(f"index_1 shape: {index_1.shape}") 132 | # print(f"index_1 val: {index_1[0,0]}") 133 | # index = torch.argmin(dist, 0) 134 | # print(f"index shape: {index.shape}") 135 | # print(f"index val: {index[0]}") 136 | # dist_1 = dist[index_1, torch.arange(s[0])] 137 | # print(f"dist_1: {dist_1.shape}") 138 | dist = dist[index, torch.arange(s[0])] 139 | 140 | # [mipc*n_patch_per_image, 3, 224, 224] 141 | sa = images.shape 142 | images = images.reshape(m, s[0], sa[1], sa[2], sa[3]) 143 | images = images[index, torch.arange(s[0])] 144 | images = images.reshape(n_patch_per_image*s[0], sa[1], sa[2], sa[3]) 145 | dist = dist.reshape(n_patch_per_image*s[0]) 146 | 147 | indices = torch.argsort(dist, descending=False)[:n] 148 | torch.cuda.empty_cache() 149 | images = images[indices] 150 | # shuffle 151 | indexes = torch.randperm(images.shape[0]) 152 | return images[indexes].detach() 153 | 154 | 155 | def mix_images(input_img, out_size, factor, n): 156 | s = out_size // factor 157 | remained = out_size % factor 158 | k = 0 159 | mixed_images = torch.zeros( 160 | (n, 3, out_size, out_size), 161 | requires_grad=False, 162 | dtype=torch.float, 163 | ) 164 | h_loc = 0 165 | for i in range(factor): 166 | h_r = s + 1 if i < remained else s 167 | w_loc = 0 168 | for j in range(factor): 169 | w_r = s + 1 if j < remained else s 170 | # print(f"{k * n} : {(k + 1) * n}") 171 | 172 | img_part = F.interpolate( 173 | input_img.data[k * n : (k + 1) * n], size=(h_r, w_r) 174 | ) 175 | # print(f"shape: {img_part.shape}") 176 | mixed_images.data[ 177 | 0:n, 178 | :, 179 | h_loc : h_loc + h_r, 180 | w_loc : w_loc + w_r, 181 | ] = img_part 182 | w_loc += w_r 183 | k += 1 184 | h_loc += h_r 185 | return mixed_images 186 | 187 | 188 | def load_model(model_name="resnet18", dataset="cifar10", pretrained=True, classes=[]): 189 | def get_model(model_name="resnet18"): 190 | if "conv" in model_name: 191 | if dataset in ["cifar10", "cifar100"]: 192 | size = 32 193 | elif dataset in ["tinyimagenet", "imagenet-1k"]: 194 | size = 64 195 | elif dataset in ["imagenet-nette", "imagenet-woof", "imagenet-100"]: 196 | size = 128 197 | else: 198 | raise Exception("Unrecognized dataset") 199 | 200 | nclass = len(classes) 201 | 202 | model = ConvNet( 203 | num_classes=nclass, 204 | net_norm="batch", 205 | net_act="relu", 206 | net_pooling="avgpooling", 207 | net_depth=int(model_name[-1]), 208 | net_width=128, 209 | channel=3, 210 | im_size=(size, size), 211 | ) 212 | elif model_name == "resnet18_modified": 213 | model = thmodels.__dict__["resnet18"](pretrained=False) 214 | model.conv1 = nn.Conv2d( 215 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False 216 | ) 217 | model.maxpool = nn.Identity() 218 | elif model_name == "resnet101_modified": 219 | model = thmodels.__dict__["resnet101"](pretrained=False) 220 | model.conv1 = nn.Conv2d( 221 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False 222 | ) 223 | model.maxpool = nn.Identity() 224 | else: 225 | model = 
thmodels.__dict__[model_name](pretrained=False) 226 | 227 | return model 228 | 229 | def pruning_classifier(model=None, classes=[]): 230 | try: 231 | model_named_parameters = [name for name, x in model.named_parameters()] 232 | for name, x in model.named_parameters(): 233 | if ( 234 | name == model_named_parameters[-1] 235 | or name == model_named_parameters[-2] 236 | ): 237 | x.data = x[classes] 238 | except: 239 | print("ERROR in changing the number of classes.") 240 | 241 | return model 242 | 243 | if not pretrained and dataset == "tinyimagenet" and model_name == "mobilenet_v2_modified": 244 | return mobilenetv2_tinyimagenet() 245 | elif not pretrained and dataset == "cifar10" and model_name == "mobilenet_v2_modified": 246 | return mobilenetv2_cifar10() 247 | # "imagenet-100" "imagenet-10" "imagenet-first" "imagenet-nette" "imagenet-woof" 248 | model = get_model(model_name) 249 | model = pruning_classifier(model, classes) 250 | if pretrained: 251 | if dataset in [ 252 | "imagenet-100", 253 | "imagenet-10", 254 | "imagenet-nette", 255 | "imagenet-woof", 256 | "tinyimagenet", 257 | "cifar10", 258 | "cifar100", 259 | ]: 260 | checkpoint = torch.load( 261 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu" 262 | ) 263 | model.load_state_dict(checkpoint["model"]) 264 | elif dataset in ["imagenet-1k"]: 265 | if model_name == "efficientNet-b0": 266 | # Specifically, for loading the pre-trained EfficientNet model, the following modifications are made 267 | from torchvision.models._api import WeightsEnum 268 | from torch.hub import load_state_dict_from_url 269 | 270 | def get_state_dict(self, *args, **kwargs): 271 | kwargs.pop("check_hash") 272 | return load_state_dict_from_url(self.url, *args, **kwargs) 273 | 274 | WeightsEnum.get_state_dict = get_state_dict 275 | elif "conv" in model_name: 276 | checkpoint = torch.load( 277 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu" 278 | ) 279 | model.load_state_dict(checkpoint["model"]) 280 | else: 281 | model = thmodels.__dict__[model_name](pretrained=True) 282 | elif not pretrained and model_name == "mobilenet_v2": 283 | return thmodels.get_model("mobilenet_v2", weights=False, num_classes=len(classes)) 284 | return model 285 | -------------------------------------------------------------------------------- /evaluation/validation/main.py: -------------------------------------------------------------------------------- 1 | """this code is modified from the RDED repo: https://github.com/LINs-lab/RDED/tree/main""" 2 | import os 3 | import random 4 | 5 | import wandb 6 | 7 | import torch 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data 12 | import torch.utils.data.distributed 13 | import torchvision.transforms as transforms 14 | from torch.optim.lr_scheduler import LambdaLR 15 | import math 16 | import time 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from model_utils.utils import load_model 21 | from validation.utils import ( 22 | ImageFolder, 23 | mix_aug, 24 | AverageMeter, 25 | accuracy, 26 | get_parameters, 27 | ) 28 | 29 | from timm.data.auto_augment import rand_augment_transform 30 | 31 | sharing_strategy = "file_system" 32 | torch.multiprocessing.set_sharing_strategy(sharing_strategy) 33 | 34 | 35 | def set_worker_sharing_strategy(worker_id: int) -> None: 36 | torch.multiprocessing.set_sharing_strategy(sharing_strategy) 37 | 38 | 39 | def main(args): 40 | if args.seed is not None: 41 | random.seed(args.seed) 42 
| torch.manual_seed(args.seed) 43 | main_worker(args) 44 | wandb.finish() 45 | 46 | 47 | def main_worker(args): 48 | if args.gpu_device != "-1": 49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device 50 | print("set CUDA_VISIBLE_DEVICES to ", args.gpu_device) 51 | wandb.login(key=args.wandb_api_key) 52 | wandb.init(project=args.wandb_project, name=args.exp_name) 53 | 54 | print("=> using pytorch pre-trained teacher model '{}'".format(args.arch_name)) 55 | teacher_model = load_model( 56 | model_name=args.arch_name, 57 | dataset=args.subset, 58 | pretrained=True, 59 | classes=args.classes, 60 | ) 61 | 62 | student_model = load_model( 63 | model_name=args.stud_name, 64 | dataset=args.subset, 65 | pretrained=False, 66 | classes=args.classes, 67 | ) 68 | teacher_model = torch.nn.DataParallel(teacher_model).cuda() 69 | student_model = torch.nn.DataParallel(student_model).cuda() 70 | 71 | teacher_model.eval() 72 | student_model.train() 73 | 74 | # freeze all layers 75 | for param in teacher_model.parameters(): 76 | param.requires_grad = False 77 | 78 | cudnn.benchmark = True 79 | 80 | # optimizer 81 | if args.sgd: 82 | optimizer = torch.optim.SGD( 83 | get_parameters(student_model), 84 | lr=args.learning_rate, 85 | momentum=args.momentum, 86 | weight_decay=args.weight_decay, 87 | ) 88 | else: 89 | optimizer = torch.optim.AdamW( 90 | get_parameters(student_model), 91 | lr=args.adamw_lr, 92 | betas=[0.9, 0.999], 93 | weight_decay=args.adamw_weight_decay, 94 | ) 95 | 96 | # lr scheduler 97 | if args.cos == True: 98 | scheduler = LambdaLR( 99 | optimizer, 100 | lambda step: 0.5 * (1.0 + math.cos(math.pi * step / args.re_epochs / 2)) 101 | if step <= args.re_epochs 102 | else 0, 103 | last_epoch=-1, 104 | ) 105 | else: 106 | scheduler = LambdaLR( 107 | optimizer, 108 | lambda step: (1.0 - step / args.re_epochs) if step <= args.re_epochs else 0, 109 | last_epoch=-1, 110 | ) 111 | 112 | print("process data from {}".format(args.syn_data_path)) 113 | normalize = transforms.Normalize( 114 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 115 | ) 116 | 117 | augment = [] 118 | 119 | if args.use_rand_augment: 120 | tfm = rand_augment_transform( 121 | config_str=args.rand_augment_config, 122 | hparams={'translate_const': 117, 'img_mean': (124, 116, 104)} 123 | ) 124 | 125 | augment.append(tfm) 126 | 127 | augment.append(transforms.ToTensor()) 128 | 129 | if args.random_erasing_p != 0.0: 130 | augment.append(transforms.RandomErasing(p=args.random_erasing_p)) 131 | 132 | augment.append( 133 | transforms.RandomResizedCrop( 134 | size=args.input_size, 135 | scale=(args.min_scale_crops, args.max_scale_crops), 136 | antialias=True, 137 | ) 138 | ) 139 | augment.append(transforms.RandomHorizontalFlip()) 140 | augment.append(normalize) 141 | 142 | train_dataset = ImageFolder( 143 | classes=range(args.nclass), 144 | ipc=args.ipc, 145 | mem=False, 146 | shuffle=True, 147 | root=args.syn_data_path, 148 | transform=transforms.Compose(augment), 149 | ) 150 | 151 | train_loader = torch.utils.data.DataLoader( 152 | train_dataset, 153 | batch_size=args.re_batch_size, 154 | shuffle=True, 155 | num_workers=args.workers, 156 | pin_memory=True, 157 | worker_init_fn=set_worker_sharing_strategy, 158 | ) 159 | 160 | val_loader = torch.utils.data.DataLoader( 161 | ImageFolder( 162 | classes=args.classes, 163 | ipc=args.val_ipc, 164 | mem=False, 165 | root=args.val_dir, 166 | transform=transforms.Compose( 167 | [ 168 | transforms.Resize(args.input_size // 7 * 8, antialias=True), 169 | transforms.CenterCrop(args.input_size), 
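# the 8/7 resize followed by the center crop above mirrors the standard
# torchvision ImageNet evaluation recipe (e.g., 256 -> 224 for input_size 224)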
170 | transforms.ToTensor(), 171 | normalize, 172 | ] 173 | ), 174 | ), 175 | batch_size=args.re_batch_size, 176 | shuffle=False, 177 | num_workers=args.workers, 178 | pin_memory=True, 179 | worker_init_fn=set_worker_sharing_strategy, 180 | ) 181 | print("load data successfully") 182 | 183 | best_acc1 = 0 184 | best_epoch = 0 185 | args.optimizer = optimizer 186 | args.scheduler = scheduler 187 | args.train_loader = train_loader 188 | args.val_loader = val_loader 189 | 190 | for epoch in range(args.re_epochs): 191 | global wandb_metrics 192 | wandb_metrics = {} 193 | train(epoch, train_loader, teacher_model, student_model, args) 194 | 195 | if epoch % 10 == 0 or epoch == args.re_epochs - 1: 196 | top1 = validate(student_model, args, epoch) 197 | else: 198 | top1 = 0 199 | wandb.log(wandb_metrics) 200 | scheduler.step() 201 | if top1 > best_acc1: 202 | best_acc1 = max(top1, best_acc1) 203 | best_epoch = epoch 204 | 205 | print(f"Train Finish! Best accuracy is {best_acc1}@{best_epoch}") 206 | 207 | 208 | def train(epoch, train_loader, teacher_model, student_model, args): 209 | """Generate soft labels and train""" 210 | objs = AverageMeter() 211 | top1 = AverageMeter() 212 | top5 = AverageMeter() 213 | 214 | optimizer = args.optimizer 215 | loss_function_kl = nn.KLDivLoss(reduction="batchmean") 216 | teacher_model.eval() 217 | student_model.train() 218 | t1 = time.time() 219 | for batch_idx, (images, labels) in enumerate(train_loader): 220 | with torch.no_grad(): 221 | images = images.cuda() 222 | labels = labels.cuda() 223 | 224 | mix_images, _, _, _ = mix_aug(images, args) 225 | 226 | pred_label = student_model(images) 227 | 228 | soft_mix_label = teacher_model(mix_images) 229 | soft_mix_label = F.softmax(soft_mix_label / args.temperature, dim=1) 230 | 231 | if batch_idx % args.re_accum_steps == 0: 232 | optimizer.zero_grad() 233 | 234 | prec1, prec5 = accuracy(pred_label, labels, topk=(1, 5)) 235 | 236 | pred_mix_label = student_model(mix_images) 237 | 238 | soft_pred_mix_label = F.log_softmax(pred_mix_label / args.temperature, dim=1) 239 | loss = loss_function_kl(soft_pred_mix_label, soft_mix_label) 240 | 241 | loss = loss / args.re_accum_steps 242 | 243 | loss.backward() 244 | if batch_idx % args.re_accum_steps == (args.re_accum_steps - 1): 245 | optimizer.step() 246 | 247 | n = images.size(0) 248 | objs.update(loss.item(), n) 249 | top1.update(prec1.item(), n) 250 | top5.update(prec5.item(), n) 251 | 252 | scheduler = args.scheduler 253 | metrics = { 254 | "train/Top1": top1.avg, 255 | "train/Top5": top5.avg, 256 | "train/loss": objs.avg, 257 | "train/lr": scheduler.get_last_lr()[0], 258 | "train/epoch": epoch,} 259 | wandb_metrics.update(metrics) 260 | 261 | printInfo = ( 262 | "TRAIN Iter {}: loss = {:.6f},\t".format(epoch, objs.avg) 263 | + "Top-1 err = {:.6f},\t".format(100 - top1.avg) 264 | + "Top-5 err = {:.6f},\t".format(100 - top5.avg) 265 | + "train_time = {:.6f}".format((time.time() - t1)) 266 | ) 267 | print(printInfo) 268 | t1 = time.time() 269 | 270 | 271 | def validate(model, args, epoch=None): 272 | objs = AverageMeter() 273 | top1 = AverageMeter() 274 | top5 = AverageMeter() 275 | loss_function = nn.CrossEntropyLoss() 276 | 277 | model.eval() 278 | t1 = time.time() 279 | with torch.no_grad(): 280 | for data, target in args.val_loader: 281 | target = target.type(torch.LongTensor) 282 | data, target = data.cuda(), target.cuda() 283 | 284 | output = model(data) 285 | loss = loss_function(output, target) 286 | 287 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 288 | 
n = data.size(0) 289 | objs.update(loss.item(), n) 290 | top1.update(prec1.item(), n) 291 | top5.update(prec5.item(), n) 292 | 293 | logInfo = ( 294 | "TEST:\nIter {}: loss = {:.6f},\t".format(epoch, objs.avg) 295 | + "Top-1 err = {:.6f},\t".format(100 - top1.avg) 296 | + "Top-5 err = {:.6f},\t".format(100 - top5.avg) 297 | + "val_time = {:.6f}".format(time.time() - t1) 298 | ) 299 | print(logInfo) 300 | metrics = { 301 | 'val/top1': top1.avg, 302 | 'val/top5': top5.avg, 303 | 'val/loss': objs.avg, 304 | 'val/epoch': epoch, 305 | } 306 | wandb_metrics.update(metrics) 307 | return top1.avg 308 | 309 | 310 | if __name__ == "__main__": 311 | pass 312 | # main(args) 313 | -------------------------------------------------------------------------------- /evaluation/validation/utils.py: -------------------------------------------------------------------------------- 1 | """this code is modified from the RDED repo: https://github.com/LINs-lab/RDED/tree/main""" 2 | import torch 3 | import numpy as np 4 | import os 5 | import torch.distributed 6 | import torchvision 7 | from torchvision.transforms import functional as t_F 8 | import torch.nn.functional as F 9 | import random 10 | 11 | 12 | # keep top k largest values, and smooth others 13 | def keep_top_k(p, k, n_classes=1000): # p is the softmax on label output 14 | if k == n_classes: 15 | return p 16 | 17 | values, indices = p.topk(k, dim=1) 18 | 19 | mask_topk = torch.zeros_like(p) 20 | mask_topk.scatter_(-1, indices, 1.0) 21 | top_p = mask_topk * p 22 | 23 | minor_value = (1 - torch.sum(values, dim=1)) / (n_classes - k) 24 | minor_value = minor_value.unsqueeze(1).expand(p.shape) 25 | mask_smooth = torch.ones_like(p) 26 | mask_smooth.scatter_(-1, indices, 0) 27 | smooth_p = mask_smooth * minor_value 28 | 29 | topk_smooth_p = top_p + smooth_p 30 | assert np.isclose( 31 | topk_smooth_p.sum().item(), p.shape[0] 32 | ), f"{topk_smooth_p.sum().item()} not close to {p.shape[0]}" 33 | return topk_smooth_p 34 | 35 | 36 | class AverageMeter(object): 37 | def __init__(self): 38 | self.reset() 39 | 40 | def reset(self): 41 | self.avg = 0 42 | self.sum = 0 43 | self.cnt = 0 44 | self.val = 0 45 | 46 | def update(self, val, n=1): 47 | self.val = val 48 | self.sum += val * n 49 | self.cnt += n 50 | self.avg = self.sum / self.cnt 51 | 52 | 53 | def accuracy(output, target, topk=(1,)): 54 | maxk = max(topk) 55 | batch_size = target.size(0) 56 | 57 | _, pred = output.topk(maxk, 1, True, True) 58 | pred = pred.t() 59 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 60 | 61 | res = [] 62 | for k in topk: 63 | correct_k = correct[:k].reshape(-1).float().sum(0) 64 | res.append(correct_k.mul_(100.0 / batch_size)) 65 | return res 66 | 67 | 68 | def get_parameters(model): 69 | group_no_weight_decay = [] 70 | group_weight_decay = [] 71 | for pname, p in model.named_parameters(): 72 | if pname.find("weight") >= 0 and len(p.size()) > 1: 73 | # print('include ', pname, p.size()) 74 | group_weight_decay.append(p) 75 | else: 76 | # print('not include ', pname, p.size()) 77 | group_no_weight_decay.append(p) 78 | assert len(list(model.parameters())) == len(group_weight_decay) + len( 79 | group_no_weight_decay 80 | ) 81 | groups = [ 82 | dict(params=group_weight_decay), 83 | dict(params=group_no_weight_decay, weight_decay=0.0), 84 | ] 85 | return groups 86 | 87 | 88 | class ImageFolder(torchvision.datasets.ImageFolder): 89 | def __init__(self, classes, ipc, mem=False, shuffle=False, **kwargs): 90 | super(ImageFolder, self).__init__(**kwargs) 91 | self.mem = mem 92 | 
self.image_paths = [] 93 | self.targets = [] 94 | self.samples = [] 95 | dirlist = [] 96 | for name in os.listdir(self.root): 97 | if os.path.isdir(os.path.join(self.root,name)): 98 | dirlist.append(name) 99 | dirlist.sort() 100 | # print(f"Num of dirs: {len(dirlist)}") 101 | # print(f"Num of classes: {len(classes)}") 102 | for c in range(len(classes)): 103 | # print(self.root) 104 | # print(dirlist) 105 | dir_path = os.path.join(self.root, dirlist[c]) 106 | # print("\n\n\n") 107 | 108 | # self.root + "/" + str(classes[c]).zfill(5) 109 | # print(dir_path) 110 | file_ls = os.listdir(dir_path) 111 | # exit() 112 | if shuffle: 113 | random.shuffle(file_ls) 114 | # print(len(file_ls)) 115 | # print(f"IPC: {ipc}") 116 | for i in range(ipc): 117 | self.image_paths.append(dir_path + "/" + file_ls[i]) 118 | self.targets.append(c) 119 | if self.mem: 120 | self.samples.append(self.loader(dir_path + "/" + file_ls[i])) 121 | 122 | def __getitem__(self, index): 123 | if self.mem: 124 | sample = self.samples[index] 125 | else: 126 | sample = self.loader(self.image_paths[index]) 127 | sample = self.transform(sample) 128 | return sample, self.targets[index] 129 | 130 | def __len__(self): 131 | return len(self.targets) 132 | 133 | 134 | def rand_bbox(size, lam): 135 | W = size[2] 136 | H = size[3] 137 | cut_rat = np.sqrt(1.0 - lam) 138 | cut_w = int(W * cut_rat) 139 | cut_h = int(H * cut_rat) 140 | 141 | # uniform 142 | cx = np.random.randint(W) 143 | cy = np.random.randint(H) 144 | 145 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 146 | bby1 = np.clip(cy - cut_h // 2, 0, H) 147 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 148 | bby2 = np.clip(cy + cut_h // 2, 0, H) 149 | 150 | return bbx1, bby1, bbx2, bby2 151 | 152 | 153 | def cutmix(images, args, rand_index=None, lam=None, bbox=None): 154 | rand_index = torch.randperm(images.size()[0]).cuda() 155 | lam = np.random.beta(args.cutmix, args.cutmix) 156 | bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam) 157 | 158 | images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2] 159 | return images, rand_index.cpu(), lam, [bbx1, bby1, bbx2, bby2] 160 | 161 | 162 | def mixup(images, args, rand_index=None, lam=None): 163 | rand_index = torch.randperm(images.size()[0]).cuda() 164 | lam = np.random.beta(args.mixup, args.mixup) 165 | 166 | mixed_images = lam * images + (1 - lam) * images[rand_index] 167 | return mixed_images, rand_index.cpu(), lam, None 168 | 169 | 170 | def mix_aug(images, args, rand_index=None, lam=None, bbox=None): 171 | if args.mix_type == "mixup": 172 | return mixup(images, args, rand_index, lam) 173 | elif args.mix_type == "cutmix": 174 | return cutmix(images, args, rand_index, lam, bbox) 175 | else: 176 | return images, None, None, None 177 | 178 | 179 | class ShufflePatches(torch.nn.Module): 180 | def shuffle_weight(self, img, factor): 181 | h, w = img.shape[1:] 182 | th, tw = h // factor, w // factor 183 | patches = [] 184 | for i in range(factor): 185 | i = i * tw 186 | if i != factor - 1: 187 | patches.append(img[..., i : i + tw]) 188 | else: 189 | patches.append(img[..., i:]) 190 | random.shuffle(patches) 191 | img = torch.cat(patches, -1) 192 | return img 193 | 194 | def __init__(self, factor): 195 | super().__init__() 196 | self.factor = factor 197 | 198 | def forward(self, img): 199 | img = self.shuffle_weight(img, self.factor) 200 | img = img.permute(0, 2, 1) 201 | img = self.shuffle_weight(img, self.factor) 202 | img = img.permute(0, 2, 1) 203 | return img 204 | 
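A minimal usage sketch for the helpers above (illustrative, not part of the repo): `mix_aug` produces the mixed views that the evaluation loop labels with the teacher, and `keep_top_k` keeps the k largest teacher probabilities while smoothing the rest, so each row remains a valid distribution. The `args` fields, shapes, and values below are assumptions for illustration, and a CUDA device is assumed because `cutmix`/`mixup` move indices to the GPU.

```python
from types import SimpleNamespace

import torch
import torch.nn.functional as F

# hypothetical mixing config; only mix_type/cutmix/mixup are read by mix_aug
args = SimpleNamespace(mix_type="cutmix", cutmix=1.0, mixup=0.8)

images = torch.randn(8, 3, 224, 224).cuda()           # toy batch on the GPU
mixed, rand_index, lam, bbox = mix_aug(images, args)  # CutMix view of the batch

logits = torch.randn(8, 1000)                         # stand-in teacher logits
p = F.softmax(logits, dim=1)
p_top5 = keep_top_k(p, k=5, n_classes=1000)           # keep top-5, smooth the rest
assert torch.allclose(p_top5.sum(dim=1), torch.ones(8), atol=1e-5)
```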
-------------------------------------------------------------------------------- /imgs/DELT_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/DELT_framework.png -------------------------------------------------------------------------------- /imgs/diversity_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/diversity_comparison.png -------------------------------------------------------------------------------- /imgs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/results.png -------------------------------------------------------------------------------- /imgs/samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/samples.png -------------------------------------------------------------------------------- /recover/models.py: -------------------------------------------------------------------------------- 1 | """Adopted from https://github.com/LINs-lab/RDED/blob/main/synthesize/models.py to load the custom conv models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | 7 | # Conv-3 model 8 | class ConvNet(nn.Module): 9 | def __init__( 10 | self, 11 | num_classes, 12 | net_norm="batch", 13 | net_depth=3, 14 | net_width=128, 15 | channel=3, 16 | net_act="relu", 17 | net_pooling="avgpooling", 18 | im_size=(32, 32), 19 | ): 20 | # print(f"Define Convnet (depth {net_depth}, width {net_width}, norm {net_norm})") 21 | super(ConvNet, self).__init__() 22 | if net_act == "sigmoid": 23 | self.net_act = nn.Sigmoid() 24 | elif net_act == "relu": 25 | self.net_act = nn.ReLU() 26 | elif net_act == "leakyrelu": 27 | self.net_act = nn.LeakyReLU(negative_slope=0.01) 28 | else: 29 | exit("unknown activation function: %s" % net_act) 30 | 31 | if net_pooling == "maxpooling": 32 | self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2) 33 | elif net_pooling == "avgpooling": 34 | self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2) 35 | elif net_pooling == "none": 36 | self.net_pooling = None 37 | else: 38 | exit("unknown net_pooling: %s" % net_pooling) 39 | 40 | self.depth = net_depth 41 | self.net_norm = net_norm 42 | 43 | self.layers, shape_feat = self._make_layers( 44 | channel, net_width, net_depth, net_norm, net_pooling, im_size 45 | ) 46 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2] 47 | self.classifier = nn.Linear(num_feat, num_classes) 48 | 49 | def forward(self, x, return_features=False): 50 | for d in range(self.depth): 51 | x = self.layers["conv"][d](x) 52 | if len(self.layers["norm"]) > 0: 53 | x = self.layers["norm"][d](x) 54 | x = self.layers["act"][d](x) 55 | if len(self.layers["pool"]) > 0: 56 | x = self.layers["pool"][d](x) 57 | 58 | # x = nn.functional.avg_pool2d(x, x.shape[-1]) 59 | out = x.view(x.shape[0], -1) 60 | logit = self.classifier(out) 61 | 62 | if return_features: 63 | return logit, out 64 | else: 65 | return logit 66 | 67 | def get_feature( 68 | self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False 69 | ): 70 | if idx_to == -1: 71 | idx_to = idx_from 72 | features = 
[] 73 | 74 | for d in range(self.depth): 75 | x = self.layers["conv"][d](x) 76 | if self.net_norm: 77 | x = self.layers["norm"][d](x) 78 | x = self.layers["act"][d](x) 79 | if self.net_pooling: 80 | x = self.layers["pool"][d](x) 81 | features.append(x) 82 | if idx_to < len(features): 83 | return features[idx_from : idx_to + 1] 84 | 85 | if return_prob: 86 | out = x.view(x.size(0), -1) 87 | logit = self.classifier(out) 88 | prob = torch.softmax(logit, dim=-1) 89 | return features, prob 90 | elif return_logit: 91 | out = x.view(x.size(0), -1) 92 | logit = self.classifier(out) 93 | return features, logit 94 | else: 95 | return features[idx_from : idx_to + 1] 96 | 97 | def _get_normlayer(self, net_norm, shape_feat): 98 | # shape_feat = (c * h * w) 99 | if net_norm == "batch": 100 | norm = nn.BatchNorm2d(shape_feat[0], affine=True) 101 | elif net_norm == "layer": 102 | norm = nn.LayerNorm(shape_feat, elementwise_affine=True) 103 | elif net_norm == "instance": 104 | norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 105 | elif net_norm == "group": 106 | norm = nn.GroupNorm(4, shape_feat[0], affine=True) 107 | elif net_norm == "none": 108 | norm = None 109 | else: 110 | norm = None 111 | exit("unknown net_norm: %s" % net_norm) 112 | return norm 113 | 114 | def _make_layers( 115 | self, channel, net_width, net_depth, net_norm, net_pooling, im_size 116 | ): 117 | layers = {"conv": [], "norm": [], "act": [], "pool": []} 118 | 119 | in_channels = channel 120 | if im_size[0] == 28: 121 | im_size = (32, 32) 122 | shape_feat = [in_channels, im_size[0], im_size[1]] 123 | 124 | for d in range(net_depth): 125 | layers["conv"] += [ 126 | nn.Conv2d( 127 | in_channels, 128 | net_width, 129 | kernel_size=3, 130 | padding=3 if channel == 1 and d == 0 else 1, 131 | ) 132 | ] 133 | shape_feat[0] = net_width 134 | if net_norm != "none": 135 | layers["norm"] += [self._get_normlayer(net_norm, shape_feat)] 136 | layers["act"] += [self.net_act] 137 | in_channels = net_width 138 | if net_pooling != "none": 139 | layers["pool"] += [self.net_pooling] 140 | shape_feat[1] //= 2 141 | shape_feat[2] //= 2 142 | 143 | layers["conv"] = nn.ModuleList(layers["conv"]) 144 | layers["norm"] = nn.ModuleList(layers["norm"]) 145 | layers["act"] = nn.ModuleList(layers["act"]) 146 | layers["pool"] = nn.ModuleList(layers["pool"]) 147 | layers = nn.ModuleDict(layers) 148 | 149 | return layers, shape_feat 150 | 151 | def load_model(model_name="resnet18", dataset="cifar10", pretrained=True, classes=[]): 152 | def get_model(model_name="resnet18"): 153 | if "conv" in model_name: 154 | if dataset in ["cifar10", "cifar100"]: 155 | size = 32 156 | elif dataset in ["tinyimagenet", "imagenet-1k"]: 157 | size = 64 158 | elif dataset in ["imagenet-nette", "imagenet-woof", "imagenet-100"]: 159 | size = 128 160 | else: 161 | raise Exception("Unrecognized dataset") 162 | 163 | nclass = len(classes) 164 | 165 | model = ConvNet( 166 | num_classes=nclass, 167 | net_norm="batch", 168 | net_act="relu", 169 | net_pooling="avgpooling", 170 | net_depth=int(model_name[-1]), 171 | net_width=128, 172 | channel=3, 173 | im_size=(size, size), 174 | ) 175 | elif model_name == "resnet18_modified": 176 | model = models.__dict__["resnet18"](pretrained=False) 177 | model.conv1 = nn.Conv2d( 178 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False 179 | ) 180 | model.maxpool = nn.Identity() 181 | elif model_name == "resnet101_modified": 182 | model = models.__dict__["resnet101"](pretrained=False) 183 | model.conv1 = nn.Conv2d( 184 | 3, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False 185 | ) 186 | model.maxpool = nn.Identity() 187 | else: 188 | model = models.__dict__[model_name](pretrained=False) 189 | 190 | return model 191 | 192 | def pruning_classifier(model=None, classes=[]): 193 | try: 194 | model_named_parameters = [name for name, x in model.named_parameters()] 195 | for name, x in model.named_parameters(): 196 | if ( 197 | name == model_named_parameters[-1] 198 | or name == model_named_parameters[-2] 199 | ): 200 | x.data = x[classes] 201 | except Exception: 202 | print("ERROR in changing the number of classes.") 203 | 204 | return model 205 | 206 | # "imagenet-100" "imagenet-10" "imagenet-first" "imagenet-nette" "imagenet-woof" 207 | model = get_model(model_name) 208 | model = pruning_classifier(model, classes) 209 | if pretrained: 210 | if dataset in [ 211 | "imagenet-100", 212 | "imagenet-10", 213 | "imagenet-nette", 214 | "imagenet-woof", 215 | "tinyimagenet", 216 | "cifar10", 217 | "cifar100", 218 | ]: 219 | checkpoint = torch.load( 220 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu" 221 | ) 222 | model.load_state_dict(checkpoint["model"]) 223 | elif dataset in ["imagenet-1k"]: 224 | if model_name == "efficientNet-b0": 225 | # workaround for loading the pre-trained EfficientNet weights: drop the unsupported check_hash kwarg 226 | from torchvision.models._api import WeightsEnum 227 | from torch.hub import load_state_dict_from_url 228 | 229 | def get_state_dict(self, *args, **kwargs): 230 | kwargs.pop("check_hash") 231 | return load_state_dict_from_url(self.url, *args, **kwargs) 232 | 233 | WeightsEnum.get_state_dict = get_state_dict 234 | elif "conv" in model_name: 235 | checkpoint = torch.load( 236 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu" 237 | ) 238 | model.load_state_dict(checkpoint["model"]) 239 | else: 240 | model = models.__dict__[model_name](pretrained=True) 241 | return model -------------------------------------------------------------------------------- /recover/recover.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from utils import BNFeatureHook, clip, lr_cosine_policy, ImageFolder 10 | from models import load_model # Load the custom conv models 11 | 12 | from torch.utils.data import DataLoader 13 | 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torchvision.models as models 17 | from PIL import Image 18 | from torchvision import transforms 19 | 20 | 21 | from tqdm import tqdm 22 | 23 | 24 | def get_round_imgs_indx(args, class_iterations, targets, img_indxer): 25 | # =================================================================================== 26 | # first, get the maximum iterations for each class 27 | # =================================================================================== 28 | max_itr = class_iterations[targets.cpu().unique()].max().item() 29 | # initialize one list of batch indices per round (the images that round optimizes) 30 | slicing_list = [[] for i in range(max_itr//args.round_iterations)] 31 | for class_id in targets.unique(): 32 | init_img_indx = -1 if class_id.item() not in img_indxer else img_indxer[class_id.item()] 33 | # if this class has a single image in the batch, wrap its index in a list 34 | if (targets==class_id).sum() == 1: 35 | img_index_in_batch = [(targets==class_id).nonzero().squeeze().item()] 36 | else: 37 | img_index_in_batch = (targets==class_id).nonzero().squeeze().tolist() 38 | img_indices = 
(np.array(img_index_in_batch) - img_index_in_batch[0] \ 39 | + init_img_indx + 1).tolist() 40 | # get the class rounds 41 | class_rounds = class_iterations[class_id.item()]//args.round_iterations 42 | 43 | round_imgs_indx = np.ones(class_rounds, dtype=int)*int(args.ipc//class_rounds) 44 | round_imgs_indx[:args.ipc%class_rounds] += 1 45 | round_imgs_indx = np.cumsum(round_imgs_indx) 46 | # Loop over every image in the batch 47 | for img_indx, batch_indx in zip(img_indices, img_index_in_batch): 48 | # get the rounds that include this image 49 | start_round = (img_indx < round_imgs_indx).nonzero()[0][0] 50 | for round_indx in range(start_round, len(round_imgs_indx)): 51 | slicing_list[round_indx].append(batch_indx) 52 | return slicing_list 53 | 54 | def get_images(args, inputs, targets, model_teacher, loss_r_feature_layers, 55 | class_iterations, img_indxer): 56 | save_every = 100 57 | 58 | best_cost = 1e4 59 | 60 | mean, std = get_mean_std(args) 61 | # ======================================================================================= 62 | # Prepare the labels and the inputs 63 | targets = targets.to('cuda') 64 | data_type = torch.float 65 | inputs = inputs.type(data_type) 66 | inputs = inputs.to('cuda') 67 | 68 | # --------------------------------------------------------------------------------------- 69 | # store the initial images to be used in the skip connection if needed 70 | # --------------------------------------------------------------------------------------- 71 | total_iterations = args.iteration 72 | skip_images = None 73 | inputs.requires_grad = True 74 | if args.use_early_late or (args.min_iterations != args.iteration): 75 | slicing_list = get_round_imgs_indx(args, class_iterations, targets, img_indxer) 76 | # =================================================================================== 77 | # Finally, check if there is an empty round 78 | # =================================================================================== 79 | round_indx = 0 80 | list_size = len(slicing_list) 81 | while round_indx < list_size: 82 | if len(slicing_list[round_indx]) == 0: # an empty round 83 | del slicing_list[round_indx] # delete the round 84 | round_indx -= 1 # reduce the number of rounds 85 | list_size -= 1 86 | total_iterations -= args.round_iterations # reduce the number of iterations 87 | round_indx += 1 88 | 89 | iterations_per_layer = total_iterations 90 | lim_0, lim_1 = args.jitter, args.jitter 91 | 92 | optimizer = optim.Adam([inputs], lr=args.lr, betas=[0.5, 0.9], eps=1e-8) 93 | lr_scheduler = lr_cosine_policy(args.lr, 0, iterations_per_layer) # 0 - do not use warmup 94 | 95 | criterion = nn.CrossEntropyLoss() 96 | criterion = criterion.cuda() 97 | round_indx= 0 98 | for iteration in range(iterations_per_layer): 99 | # =================================================================================== 100 | # Identify the round index for the Early-Late Training scheme 101 | # =================================================================================== 102 | if args.round_iterations != 0 and (args.use_early_late or \ 103 | (args.min_iterations != args.iteration)): 104 | round_indx = iteration//args.round_iterations 105 | # ----------------------------------------------------------------------------------- 106 | 107 | # learning rate scheduling: reset the scheduling per round 108 | if args.round_lr_reset and args.round_iterations != 0: 109 | lr_scheduler(optimizer, iteration%args.round_iterations, iteration%args.round_iterations) 110 | else: 111 | 
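# LR scheduling: with --round-lr-reset, the cosine schedule restarts at the
# beginning of every EarlyLate round (note the iteration % args.round_iterations
# above); otherwise a single cosine decay spans all remaining iterations: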
lr_scheduler(optimizer, iteration, iteration) 112 | 113 | # =================================================================================== 114 | # strategy: start with whole image with mix crop of 1, then lower to 0.08 115 | # easy to hard 116 | min_crop = 0.08 117 | max_crop = 1.0 118 | if iteration < args.milestone * iterations_per_layer: 119 | if args.easy2hard_mode == "step": 120 | min_crop = 1.0 121 | elif args.easy2hard_mode == "linear": 122 | # min_crop linear decreasing: 1.0 -> 0.08 123 | min_crop = 0.08 + (1.0 - 0.08) * (1 - iteration / (args.milestone * iterations_per_layer)) 124 | elif args.easy2hard_mode == "cosine": 125 | # min_crop cosine decreasing: 1.0 -> 0.08 126 | min_crop = 0.08 + (1.0 - 0.08) * (1 + np.cos(np.pi * iteration / (args.milestone * iterations_per_layer))) / 2 127 | 128 | aug_function = transforms.Compose( 129 | [ 130 | # transforms.RandomResizedCrop(224, scale=(0.08, 1.0)), 131 | transforms.RandomResizedCrop(args.input_size, scale=(min_crop, max_crop)), 132 | transforms.RandomHorizontalFlip(), 133 | ] 134 | ) 135 | if args.round_iterations != 0 and (args.use_early_late or (args.min_iterations != args.iteration)): 136 | inputs_jit = aug_function(inputs[slicing_list[round_indx]]) 137 | else: 138 | inputs_jit = aug_function(inputs) 139 | # apply random jitter offsets 140 | off1 = random.randint(0, lim_0) 141 | off2 = random.randint(0, lim_1) 142 | inputs_jit = torch.roll(inputs_jit, shifts=(off1, off2), dims=(2, 3)) 143 | 144 | # forward pass 145 | optimizer.zero_grad() 146 | outputs = model_teacher(inputs_jit) 147 | 148 | # R_cross classification loss 149 | if args.round_iterations != 0 and (args.use_early_late or (args.min_iterations != args.iteration)): 150 | loss_ce = criterion(outputs, targets[slicing_list[round_indx]]) 151 | else: 152 | loss_ce = criterion(outputs, targets) 153 | 154 | # R_feature loss 155 | rescale = [args.first_bn_multiplier] + [1.0 for _ in range(len(loss_r_feature_layers) - 1)] 156 | loss_r_bn_feature = sum([mod.r_feature * rescale[idx] for (idx, mod) in enumerate(loss_r_feature_layers)]) 157 | 158 | # combining losses 159 | loss_aux = args.r_bn * loss_r_bn_feature 160 | 161 | loss = loss_ce + loss_aux 162 | 163 | # =================================================================================== 164 | # do image update 165 | # =================================================================================== 166 | loss.backward() 167 | # =================================================================================== 168 | optimizer.step() 169 | 170 | # clip color outliers 171 | inputs.data = clip(inputs.data, mean=mean, std=std) 172 | 173 | if best_cost > loss.item() or iteration == 1: 174 | best_inputs = inputs.data.clone() 175 | if args.store_best_images: 176 | best_inputs = inputs.data.clone() # using multicrop, save the last one 177 | # add the denormalizer 178 | denormalize = transforms.Compose( 179 | [ 180 | transforms.Normalize( 181 | mean=[0.0, 0.0, 0.0], std=1/std 182 | ), 183 | transforms.Normalize(mean=-mean, std=[1.0, 1.0, 1.0]), 184 | ] 185 | ) 186 | best_inputs = denormalize(best_inputs) 187 | save_images(args, best_inputs, targets, round_indx) 188 | 189 | # deallocate the optimizer state to reduce memory consumption 190 | optimizer.state = collections.defaultdict(dict) 191 | torch.cuda.empty_cache() 192 | 193 | 194 | def save_images(args, images, targets, round_indx=0, indx_list=None): 195 | for id in range(images.shape[0]): 196 | if 
targets.ndimension() == 1: 197 | class_id = targets[id].item() 198 | else: 199 | class_id = targets[id].argmax().item() 200 | 201 | if not os.path.exists(args.syn_data_path): 202 | os.mkdir(args.syn_data_path) 203 | 204 | # save into separate folders 205 | dir_path = "{}/new{:03d}".format(args.syn_data_path, class_id) 206 | place_to_store = dir_path + "/class{:03d}_id{:03d}.jpg".format(class_id, get_img_indx(class_id)) 207 | 208 | if not os.path.exists(dir_path): 209 | os.makedirs(dir_path) 210 | 211 | image_np = images[id].data.cpu().numpy().transpose((1, 2, 0)) 212 | pil_image = Image.fromarray((image_np * 255).astype(np.uint8)) 213 | pil_image.save(place_to_store) 214 | 215 | def get_mean_std(args): 216 | if args.dataset in ["imagenet-1k", "imagenet-100", "imagenet-woof", "imagenet-nette"]: 217 | mean = np.array([0.485, 0.456, 0.406]) 218 | std = np.array([0.229, 0.224, 0.225]) 219 | elif args.dataset == "tinyimagenet": 220 | mean = np.array([0.4802, 0.4481, 0.3975]) 221 | std = np.array([0.2302, 0.2265, 0.2262]) 222 | elif args.dataset in ["cifar10", "cifar100"]: 223 | mean = np.array([0.4914, 0.4822, 0.4465]) 224 | std = np.array([0.2023, 0.1994, 0.2010]) 225 | return mean, std 226 | 227 | def main_syn(args): 228 | if "conv" in args.arch_name: 229 | model_teacher = load_model(model_name=args.arch_name, 230 | dataset=args.dataset, 231 | pretrained=True, 232 | classes=range(args.num_classes)) 233 | elif args.arch_path: 234 | model_teacher = models.get_model(args.arch_name, weights=False, num_classes=args.num_classes) 235 | if args.dataset in ["cifar10", "cifar100", "tinyimagenet"]: 236 | model_teacher.conv1 = nn.Conv2d( 237 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False 238 | ) 239 | model_teacher.maxpool = nn.Identity() 240 | checkpoint = torch.load(args.arch_path, map_location="cpu") 241 | model_teacher.load_state_dict(checkpoint["model"]) 242 | else: 243 | model_teacher = models.get_model(args.arch_name, weights="DEFAULT") 244 | model_teacher = nn.DataParallel(model_teacher).cuda() 245 | model_teacher.eval() 246 | for p in model_teacher.parameters(): 247 | p.requires_grad = False 248 | 249 | # load the init_set 250 | # ======================================================================================= 251 | # do some augmentation to the initialized images 252 | transforms_list = [] 253 | mean, std = get_mean_std(args) 254 | transforms_list += [ transforms.Resize(args.input_size // 7 * 8, antialias=True), 255 | transforms.CenterCrop(args.input_size), 256 | transforms.ToTensor(), 257 | transforms.Normalize(mean= mean, 258 | std= std) 259 | ] 260 | init_data = ImageFolder( 261 | ipc = args.ipc, 262 | shuffle = False, 263 | root=args.init_data_path, 264 | transform=transforms.Compose(transforms_list) 265 | ) 266 | 267 | data_loader = DataLoader(init_data, batch_size=args.batch_size, 268 | shuffle=False, num_workers=args.workers, 269 | pin_memory=True) 270 | # Loop over the images 271 | global img_indxer 272 | img_indxer = {} 273 | 274 | # ======================================================================================= 275 | # define the number of gradient iterations for each class 276 | # ======================================================================================= 277 | class_iterations = torch.ones((args.num_classes,),dtype=int)*args.iteration 278 | if args.round_iterations != 0: 279 | min_rounds = args.min_iterations//args.round_iterations 280 | max_rounds = args.iteration//args.round_iterations 281 | # randomly select the number of rounds 
for each class 282 | class_iterations = torch.randint(min_rounds, max_rounds+1, (args.num_classes,)) 283 | # convert rounds to iterations: number of rounds times iterations per round 284 | class_iterations = class_iterations*args.round_iterations 285 | print(f"Class Iterations: {class_iterations}") 286 | 287 | loss_r_feature_layers = [] 288 | for module in model_teacher.modules(): 289 | if isinstance(module, nn.BatchNorm2d): 290 | loss_r_feature_layers.append(BNFeatureHook(module)) 291 | 292 | 293 | for images, labels in tqdm(data_loader): 294 | get_images(args, images, labels, model_teacher, loss_r_feature_layers, 295 | class_iterations, img_indxer) 296 | 297 | 298 | def get_img_indx(class_id): 299 | global img_indxer 300 | if class_id not in img_indxer: 301 | img_indxer[class_id] = 0 302 | else: 303 | img_indxer[class_id] += 1 304 | return img_indxer[class_id] 305 | 306 | def parse_args(): 307 | parser = argparse.ArgumentParser("DELT: Early-Late Recovery scheme for different datasets") 308 | """Data save flags""" 309 | parser.add_argument("--exp-name", type=str, default="test", help="name of the experiment, subfolder under syn_data_path") 310 | parser.add_argument("--init-data-path", type=str, default="/workspace/data/RDED_IN_1K/ipc50_cr5_mipc300", help="location of initialization data") 311 | parser.add_argument("--syn-data-path", type=str, default="./syn-data", help="where to store synthetic data") 312 | parser.add_argument("--store-best-images", action="store_true", help="whether to store best images") 313 | parser.add_argument("--ipc", type=int, default=50, help="number of images per class (IPC)") 314 | parser.add_argument("--dataset", type=str, default="imagenet-1k", help="dataset to use") 315 | parser.add_argument("--input-size", type=int, default=224, help="image input size") 316 | parser.add_argument("--num-classes", type=int, default=1000, help="number of classes") 317 | 318 | """Early-Late Training flags""" 319 | parser.add_argument("--use-early-late", action="store_true", default=False, help="use the incremental EarlyLate training scheme") 320 | parser.add_argument("--round-iterations", type=int, default=0, help="number of iterations in a single round") 321 | parser.add_argument("--min-iterations", type=int, default=-1, help="minimum number of iterations to optimize the synthetic data of a specific class") 322 | parser.add_argument("--round-lr-reset", action="store_true", default=False, help="reset the lr per round") 323 | """Optimization related flags""" 324 | parser.add_argument("--batch-size", type=int, default=100, help="number of images to optimize at the same time") 325 | parser.add_argument('-j', '--workers', default=16, type=int, help='number of data loading workers') 326 | parser.add_argument("--iteration", type=int, default=1000, help="number of iterations to optimize the synthetic data") 327 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate for optimization") 328 | parser.add_argument("--jitter", default=32, type=int, help="random shift on the synthetic data") 329 | parser.add_argument("--r-bn", type=float, default=0.05, help="coefficient for BN feature distribution regularization") 330 | parser.add_argument("--first-bn-multiplier", type=float, default=10.0, help="additional multiplier on first bn layer of R_bn") 331 | """Model related flags""" 332 | parser.add_argument("--arch-name", type=str, default="resnet18", help="arch name from pretrained torchvision models") 333 | parser.add_argument("--arch-path", type=str, default="", help="path to the teacher model") 334 | 
parser.add_argument("--easy2hard-mode", default="cosine", type=str, choices=["step", "linear", "cosine"]) 335 | parser.add_argument("--milestone", default=0, type=float) 336 | parser.add_argument('--gpu-device', default="-1", type=str) 337 | args = parser.parse_args() 338 | 339 | assert args.milestone >= 0 and args.milestone <= 1 340 | # assert args.batch_size%args.ipc == 0 and args.batch_size != 0 341 | 342 | if args.gpu_device != "-1": 343 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device 344 | print("set CUDA_VISIBLE_DEVICES to ", args.gpu_device) 345 | 346 | # ======================================================================================= 347 | # Do some initializations 348 | # ======================================================================================= 349 | if args.dataset == "imagenet-1k": 350 | args.num_classes = 1000 351 | elif args.dataset == "tinyimagenet": 352 | args.num_classes = 200 353 | elif args.dataset in ["imagenet-100", "cifar100"]: 354 | args.num_classes = 100 355 | elif args.dataset in ["imagenet-woof", "imagenet-nette", "cifar10"]: 356 | args.num_classes = 10 357 | 358 | if args.dataset in ["imagenet-1k", "imagenet-100", "imagenet-woof", "imagenet-nette"]: 359 | args.input_size = 224 360 | elif args.dataset == "tinyimagenet": 361 | args.input_size = 64 362 | elif args.dataset in ["cifar10", "cifar100"]: 363 | args.input_size = 32 364 | 365 | if "conv" in args.arch_name: 366 | if args.dataset in ["imagenet-100", "imagenet-woof", "imagenet-nette"]: 367 | args.input_size = 128 368 | elif args.dataset in ["tinyimagenet", "imagenet-1k"]: 369 | args.input_size = 64 370 | elif args.dataset in ["cifar10", "cifar100"]: 371 | args.input_size = 32 372 | 373 | if args.min_iterations == -1: 374 | args.min_iterations = args.iteration 375 | else: 376 | args.min_iterations = max(args.min_iterations, args.round_iterations) 377 | 378 | args.syn_data_path = os.path.join(args.syn_data_path, args.exp_name) 379 | print(args) 380 | return args 381 | 382 | 383 | if __name__ == "__main__": 384 | args = parse_args() 385 | 386 | if not os.path.exists(args.syn_data_path): 387 | os.makedirs(args.syn_data_path) 388 | 389 | main_syn(args) 390 | print("Done.") 391 | -------------------------------------------------------------------------------- /recover/utils.py: -------------------------------------------------------------------------------- 1 | """Modifying the utils from the original CDA here https://github.com/VILA-Lab/SRe2L/blob/main/CDA/utils.py""" 2 | import numpy as np 3 | import torch 4 | import os 5 | import torchvision 6 | import random 7 | 8 | def clip(image_tensor, mean = np.array([0.485, 0.456, 0.406]), 9 | std = np.array([0.229, 0.224, 0.225])): 10 | """ 11 | adjust the input based on mean and variance 12 | """ 13 | # mean = np.array([0.485, 0.456, 0.406]) 14 | # std = np.array([0.229, 0.224, 0.225]) 15 | for c in range(3): 16 | m, s = mean[c], std[c] 17 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s) 18 | return image_tensor 19 | 20 | 21 | def tiny_clip(image_tensor): 22 | """ 23 | adjust the input based on mean and variance, using tiny-imagenet normalization 24 | """ 25 | mean = np.array([0.4802, 0.4481, 0.3975]) 26 | std = np.array([0.2302, 0.2265, 0.2262]) 27 | 28 | for c in range(3): 29 | m, s = mean[c], std[c] 30 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s) 31 | return image_tensor 32 | 33 | 34 | def denormalize(image_tensor, mean = np.array([0.485, 0.456, 0.406]), 35 | std = np.array([0.229, 
0.224, 0.225])): 36 | """ 37 | convert floats back to input 38 | """ 39 | # mean = np.array([0.485, 0.456, 0.406]) 40 | # std = np.array([0.229, 0.224, 0.225]) 41 | 42 | for c in range(3): 43 | m, s = mean[c], std[c] 44 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1) 45 | 46 | return image_tensor 47 | 48 | 49 | def tiny_denormalize(image_tensor): 50 | """ 51 | convert floats back to input, using tiny-imagenet normalization 52 | """ 53 | mean = np.array([0.4802, 0.4481, 0.3975]) 54 | std = np.array([0.2302, 0.2265, 0.2262]) 55 | 56 | for c in range(3): 57 | m, s = mean[c], std[c] 58 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1) 59 | 60 | return image_tensor 61 | 62 | 63 | def lr_policy(lr_fn): 64 | def _alr(optimizer, iteration, epoch): 65 | lr = lr_fn(iteration, epoch) 66 | for param_group in optimizer.param_groups: 67 | param_group["lr"] = lr 68 | 69 | return _alr 70 | 71 | 72 | def lr_cosine_policy(base_lr, warmup_length, epochs): 73 | def _lr_fn(iteration, epoch): 74 | if epoch < warmup_length: 75 | lr = base_lr * (epoch + 1) / warmup_length 76 | else: 77 | e = epoch - warmup_length 78 | es = epochs - warmup_length 79 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 80 | return lr 81 | 82 | return lr_policy(_lr_fn) 83 | 84 | 85 | class ViT_BNFeatureHook: 86 | def __init__(self, module): 87 | self.hook = module.register_forward_hook(self.hook_fn) 88 | 89 | def hook_fn(self, module, input, output): 90 | B, N, C = input[0].shape 91 | mean = torch.mean(input[0], dim=[0, 1]) 92 | var = torch.var(input[0], dim=[0, 1], unbiased=False) 93 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(module.running_mean.data - mean, 2) 94 | self.r_feature = r_feature 95 | 96 | def close(self): 97 | self.hook.remove() 98 | 99 | 100 | class BNFeatureHook: 101 | def __init__(self, module): 102 | self.hook = module.register_forward_hook(self.hook_fn) 103 | 104 | def hook_fn(self, module, input, output): 105 | nch = input[0].shape[1] 106 | mean = input[0].mean([0, 2, 3]) 107 | var = input[0].permute(1, 0, 2, 3).contiguous().reshape([nch, -1]).var(1, unbiased=False) 108 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(module.running_mean.data - mean, 2) 109 | self.r_feature = r_feature 110 | 111 | def close(self): 112 | self.hook.remove() 113 | 114 | 115 | # modified from Alibaba-ImageNet21K/src_files/models/utils/factory.py 116 | def load_model_weights(model, model_path): 117 | state = torch.load(model_path, map_location="cpu") 118 | 119 | Flag = False 120 | if "state_dict" in state: 121 | # resume from a model trained with nn.DataParallel 122 | state = state["state_dict"] 123 | Flag = True 124 | 125 | for key in model.state_dict(): 126 | if "num_batches_tracked" in key: 127 | continue 128 | p = model.state_dict()[key] 129 | 130 | if Flag: 131 | key = "module." 
+ key 132 | 133 | if key in state: 134 | ip = state[key] 135 | # if key in state['state_dict']: 136 | # ip = state['state_dict'][key] 137 | if p.shape == ip.shape: 138 | p.data.copy_(ip.data) # Copy the data of parameters 139 | else: 140 | print("could not load layer: {}, mismatch shape {} ,{}".format(key, (p.shape), (ip.shape))) 141 | else: 142 | print("could not load layer: {}, not in checkpoint".format(key)) 143 | return model 144 | 145 | class ImageFolder(torchvision.datasets.ImageFolder): 146 | def __init__(self, ipc = 50, mem=False, shuffle=False, **kwargs): 147 | super(ImageFolder, self).__init__(**kwargs) 148 | self.mem = mem 149 | self.image_paths = [] 150 | self.targets = [] 151 | self.samples = [] 152 | dirlist = [] 153 | for name in os.listdir(self.root): 154 | if os.path.isdir(os.path.join(self.root,name)): 155 | dirlist.append(name) 156 | dirlist.sort() 157 | for c in range(len(dirlist)): 158 | dir_path = os.path.join(self.root, dirlist[c]) 159 | file_ls = os.listdir(dir_path) 160 | file_ls.sort() 161 | if shuffle: 162 | random.shuffle(file_ls) 163 | for i in range(ipc): 164 | self.image_paths.append(dir_path + "/" + file_ls[i]) 165 | # print(self.image_paths) 166 | # exit() 167 | self.targets.append(c) 168 | if self.mem: 169 | self.samples.append(self.loader(dir_path + "/" + file_ls[i])) 170 | 171 | def __getitem__(self, index): 172 | if self.mem: 173 | sample = self.samples[index] 174 | else: 175 | sample = self.loader(self.image_paths[index]) 176 | sample = self.transform(sample) 177 | return sample, self.targets[index] 178 | 179 | def __len__(self): 180 | return len(self.targets) -------------------------------------------------------------------------------- /scripts/conv4_imagenet1k_synthesis.sh: -------------------------------------------------------------------------------- 1 | GPU=0 2 | 3 | # ========================================================================================= 4 | 5 | SYN_PATH="/path/to/synthesized/conv_imagenet_1k/" 6 | INIT_PATH="/path/to/initialization/medium_prob" 7 | IPC=50 8 | ITR=4000 9 | ROUND_ITR=500 10 | EXP_NAME="IPC$((IPC))_4K_500_medium" 11 | 12 | echo "$EXP_NAME" 13 | 14 | python /path/to/DELT/recover/recover.py \ 15 | --init-data-path "$INIT_PATH" \ 16 | --syn-data-path "$SYN_PATH" \ 17 | --arch-name "conv4" \ 18 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \ 19 | --dataset "imagenet-1k" \ 20 | --exp-name "$EXP_NAME" \ 21 | --use-early-late \ 22 | --round-iterations $((ROUND_ITR)) \ 23 | --batch-size 100 \ 24 | --lr 0.25 \ 25 | --r-bn 0.01 \ 26 | --gpu-device $((GPU)) \ 27 | --iteration $((ITR)) \ 28 | --jitter 0\ 29 | --easy2hard-mode "cosine" --milestone 1 \ 30 | --ipc $((IPC)) --store-best-images 31 | 32 | echo "Synthesis -> DONE" 33 | 34 | # ========================================================================================= 35 | 36 | IPC=10 37 | ITR=4000 38 | ROUND_ITR=500 39 | EXP_NAME="IPC$((IPC))_4K_500_medium" 40 | 41 | echo "$EXP_NAME" 42 | 43 | python /path/to/DELT/recover/recover.py \ 44 | --init-data-path "$INIT_PATH" \ 45 | --syn-data-path "$SYN_PATH" \ 46 | --arch-name "conv4" \ 47 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \ 48 | --dataset "imagenet-1k" \ 49 | --exp-name "$EXP_NAME" \ 50 | --use-early-late \ 51 | --round-iterations $((ROUND_ITR)) \ 52 | --batch-size 100 \ 53 | --lr 0.25 \ 54 | --r-bn 0.01 \ 55 | --gpu-device $((GPU)) \ 56 | --iteration $((ITR)) \ 57 | --jitter 0\ 58 | --easy2hard-mode "cosine" --milestone 1 \ 59 | --ipc $((IPC)) --store-best-images 60 | 61 | echo "Synthesis -> 
DONE" 62 | 63 | # ========================================================================================= 64 | 65 | IPC=1 66 | ITR=4000 67 | EXP_NAME="IPC$((IPC))_4K_medium" 68 | 69 | echo "$EXP_NAME" 70 | 71 | python /path/to/DELT/recover/recover.py \ 72 | --init-data-path "$INIT_PATH" \ 73 | --syn-data-path "$SYN_PATH" \ 74 | --arch-name "conv4" \ 75 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \ 76 | --dataset "imagenet-1k" \ 77 | --exp-name "$EXP_NAME" \ 78 | --batch-size 100 \ 79 | --lr 0.25 \ 80 | --r-bn 0.01 \ 81 | --gpu-device $((GPU)) \ 82 | --iteration $((ITR)) \ 83 | --jitter 0\ 84 | --easy2hard-mode "cosine" --milestone 1 \ 85 | --ipc $((IPC)) --store-best-images 86 | 87 | echo "Synthesis -> DONE" 88 | 89 | # ========================================================================================= 90 | 91 | IPC=1 92 | ITR=3000 93 | EXP_NAME="IPC$((IPC))_3K_medium" 94 | 95 | echo "$EXP_NAME" 96 | 97 | python /path/to/DELT/recover/recover.py \ 98 | --init-data-path "$INIT_PATH" \ 99 | --syn-data-path "$SYN_PATH" \ 100 | --arch-name "conv4" \ 101 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \ 102 | --dataset "imagenet-1k" \ 103 | --exp-name "$EXP_NAME" \ 104 | --batch-size 100 \ 105 | --lr 0.25 \ 106 | --r-bn 0.01 \ 107 | --gpu-device $((GPU)) \ 108 | --iteration $((ITR)) \ 109 | --jitter 0\ 110 | --easy2hard-mode "cosine" --milestone 1 \ 111 | --ipc $((IPC)) --store-best-images 112 | 113 | echo "Synthesis -> DONE" 114 | 115 | # ========================================================================================= 116 | 117 | IPC=1 118 | ITR=2000 119 | EXP_NAME="IPC$((IPC))_2K_medium" 120 | 121 | echo "$EXP_NAME" 122 | 123 | python /path/to/DELT/recover/recover.py \ 124 | --init-data-path "$INIT_PATH" \ 125 | --syn-data-path "$SYN_PATH" \ 126 | --arch-name "conv4" \ 127 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \ 128 | --dataset "imagenet-1k" \ 129 | --exp-name "$EXP_NAME" \ 130 | --batch-size 100 \ 131 | --lr 0.25 \ 132 | --r-bn 0.01 \ 133 | --gpu-device $((GPU)) \ 134 | --iteration $((ITR)) \ 135 | --jitter 0\ 136 | --easy2hard-mode "cosine" --milestone 1 \ 137 | --ipc $((IPC)) --store-best-images 138 | 139 | echo "Synthesis -> DONE" 140 | 141 | # ========================================================================================= 142 | 143 | IPC=1 144 | ITR=1000 145 | EXP_NAME="IPC$((IPC))_1K_medium" 146 | 147 | echo "$EXP_NAME" 148 | 149 | python /path/to/DELT/recover/recover.py \ 150 | --init-data-path "$INIT_PATH" \ 151 | --syn-data-path "$SYN_PATH" \ 152 | --arch-name "conv4" \ 153 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \ 154 | --dataset "imagenet-1k" \ 155 | --exp-name "$EXP_NAME" \ 156 | --batch-size 100 \ 157 | --lr 0.25 \ 158 | --r-bn 0.01 \ 159 | --gpu-device $((GPU)) \ 160 | --iteration $((ITR)) \ 161 | --jitter 0\ 162 | --easy2hard-mode "cosine" --milestone 1 \ 163 | --ipc $((IPC)) --store-best-images 164 | 165 | echo "Synthesis -> DONE" 166 | 167 | # ========================================================================================= 168 | -------------------------------------------------------------------------------- /scripts/conv4_imagenet1k_validation.sh: -------------------------------------------------------------------------------- 1 | GPU=0 2 | 3 | # ========================================================================================= 4 | 5 | SYN_PATH="/path/to/synthesized/imagenet_1k_conv/" 6 | PROJECT_NAME="Imagenet1K-Seeds" 7 | IPC=50 8 | EXP_NAME="IPC$((IPC))_4K_500_medium" 9 | 
WANDB_API_KEY="write your api here" 10 | 11 | echo "$EXP_NAME" 12 | SEED=3407 13 | 14 | RAND_AUG="rand-m6-n2-mstd1.0" 15 | # VAL_NAME="$RAND_AUG" 16 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))" 17 | 18 | wandb enabled 19 | wandb online 20 | python /path/to/DELT/evaluation/main.py \ 21 | --wandb-project "$PROJECT_NAME" \ 22 | --wandb-api-key "$WANDB_API_KEY" \ 23 | --val-dir "/path/to/val" \ 24 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 25 | --exp-name "$VAL_NAME" \ 26 | --subset "imagenet-1k" \ 27 | --arch-name "conv4" \ 28 | --use-rand-augment \ 29 | --rand-augment-config "$RAND_AUG" \ 30 | --random-erasing-p 0.0 \ 31 | --min-scale-crops 0.25 \ 32 | --max-scale-crops 1.0 \ 33 | --ipc $((IPC)) \ 34 | --val-ipc 50 \ 35 | --stud-name "conv4" \ 36 | --re-epochs 300 \ 37 | --gpu-device $((GPU)) \ 38 | --seed $((SEED)) 39 | 40 | # ========================================================================================= 41 | 42 | 43 | SEED=4663 44 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))" 45 | 46 | wandb enabled 47 | wandb online 48 | python /path/to/DELT/evaluation/main.py \ 49 | --wandb-project "$PROJECT_NAME" \ 50 | --wandb-api-key "$WANDB_API_KEY" \ 51 | --val-dir "/path/to/val" \ 52 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 53 | --exp-name "$VAL_NAME" \ 54 | --subset "imagenet-1k" \ 55 | --arch-name "conv4" \ 56 | --use-rand-augment \ 57 | --rand-augment-config "$RAND_AUG" \ 58 | --random-erasing-p 0.0 \ 59 | --min-scale-crops 0.25 \ 60 | --max-scale-crops 1.0 \ 61 | --ipc $((IPC)) \ 62 | --val-ipc 50 \ 63 | --stud-name "conv4" \ 64 | --re-epochs 300 \ 65 | --gpu-device $((GPU)) \ 66 | --seed $((SEED)) 67 | 68 | # ========================================================================================= 69 | 70 | SEED=2897 71 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))" 72 | 73 | wandb enabled 74 | wandb online 75 | python /path/to/DELT/evaluation/main.py \ 76 | --wandb-project "$PROJECT_NAME" \ 77 | --wandb-api-key "$WANDB_API_KEY" \ 78 | --val-dir "/path/to/val" \ 79 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 80 | --exp-name "$VAL_NAME" \ 81 | --subset "imagenet-1k" \ 82 | --arch-name "conv4" \ 83 | --use-rand-augment \ 84 | --rand-augment-config "$RAND_AUG" \ 85 | --random-erasing-p 0.0 \ 86 | --min-scale-crops 0.25 \ 87 | --max-scale-crops 1.0 \ 88 | --ipc $((IPC)) \ 89 | --val-ipc 50 \ 90 | --stud-name "conv4" \ 91 | --re-epochs 300 \ 92 | --gpu-device $((GPU)) \ 93 | --seed $((SEED)) 94 | 95 | # ========================================================================================= 96 | 97 | IPC=10 98 | EXP_NAME="IPC$((IPC))_4K_500_medium" 99 | 100 | echo "$EXP_NAME" 101 | SEED=3407 102 | 103 | 104 | # VAL_NAME="$RAND_AUG" 105 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))" 106 | 107 | wandb enabled 108 | wandb online 109 | python /path/to/DELT/evaluation/main.py \ 110 | --wandb-project "$PROJECT_NAME" \ 111 | --wandb-api-key "$WANDB_API_KEY" \ 112 | --val-dir "/path/to/val" \ 113 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 114 | --exp-name "$VAL_NAME" \ 115 | --subset "imagenet-1k" \ 116 | --arch-name "conv4" \ 117 | --use-rand-augment \ 118 | --rand-augment-config "$RAND_AUG" \ 119 | --random-erasing-p 0.0 \ 120 | --min-scale-crops 0.25 \ 121 | --max-scale-crops 1.0 \ 122 | --ipc $((IPC)) \ 123 | --val-ipc 50 \ 124 | --stud-name "conv4" \ 125 | --re-epochs 300 \ 126 | --gpu-device $((GPU)) \ 127 | --seed $((SEED)) 128 | 129 | # ========================================================================================= 130 | 131 | 132 | SEED=4663 
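# NOTE: every block in this script is the same evaluation run, repeated per
# seed and per IPC setting; only SEED, IPC, EXP_NAME, and VAL_NAME change.
# A compact equivalent (a sketch, not the original script) would be:
#   for SEED in 3407 4663 2897; do
#     VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))"
#     python /path/to/DELT/evaluation/main.py --seed $((SEED))  # ...same flags as below
#   done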
133 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))" 134 | 135 | wandb enabled 136 | wandb online 137 | python /path/to/DELT/evaluation/main.py \ 138 | --wandb-project "$PROJECT_NAME" \ 139 | --wandb-api-key "$WANDB_API_KEY" \ 140 | --val-dir "/path/to/val" \ 141 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 142 | --exp-name "$VAL_NAME" \ 143 | --subset "imagenet-1k" \ 144 | --arch-name "conv4" \ 145 | --use-rand-augment \ 146 | --rand-augment-config "$RAND_AUG" \ 147 | --random-erasing-p 0.0 \ 148 | --min-scale-crops 0.25 \ 149 | --max-scale-crops 1.0 \ 150 | --ipc $((IPC)) \ 151 | --val-ipc 50 \ 152 | --stud-name "conv4" \ 153 | --re-epochs 300 \ 154 | --gpu-device $((GPU)) \ 155 | --seed $((SEED)) 156 | 157 | # ========================================================================================= 158 | 159 | SEED=2897 160 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))" 161 | 162 | wandb enabled 163 | wandb online 164 | python /path/to/DELT/evaluation/main.py \ 165 | --wandb-project "$PROJECT_NAME" \ 166 | --wandb-api-key "$WANDB_API_KEY" \ 167 | --val-dir "/path/to/val" \ 168 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 169 | --exp-name "$VAL_NAME" \ 170 | --subset "imagenet-1k" \ 171 | --arch-name "conv4" \ 172 | --use-rand-augment \ 173 | --rand-augment-config "$RAND_AUG" \ 174 | --random-erasing-p 0.0 \ 175 | --min-scale-crops 0.25 \ 176 | --max-scale-crops 1.0 \ 177 | --ipc $((IPC)) \ 178 | --val-ipc 50 \ 179 | --stud-name "conv4" \ 180 | --re-epochs 300 \ 181 | --gpu-device $((GPU)) \ 182 | --seed $((SEED)) 183 | 184 | # ========================================================================================= 185 | 186 | IPC=1 187 | EXP_NAME="IPC$((IPC))_4K_medium" 188 | 189 | echo "$EXP_NAME" 190 | SEED=3407 191 | 192 | 193 | # VAL_NAME="$RAND_AUG" 194 | VAL_NAME="IPC$((IPC)) 4K Medium Conv4 S$((SEED))" 195 | 196 | wandb enabled 197 | wandb online 198 | python /path/to/DELT/evaluation/main.py \ 199 | --wandb-project "$PROJECT_NAME" \ 200 | --wandb-api-key "$WANDB_API_KEY" \ 201 | --val-dir "/path/to/val" \ 202 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 203 | --exp-name "$VAL_NAME" \ 204 | --subset "imagenet-1k" \ 205 | --arch-name "conv4" \ 206 | --use-rand-augment \ 207 | --rand-augment-config "$RAND_AUG" \ 208 | --random-erasing-p 0.0 \ 209 | --min-scale-crops 0.25 \ 210 | --max-scale-crops 1.0 \ 211 | --ipc $((IPC)) \ 212 | --val-ipc 50 \ 213 | --stud-name "conv4" \ 214 | --re-epochs 300 \ 215 | --gpu-device $((GPU)) \ 216 | --seed $((SEED)) 217 | 218 | # ========================================================================================= 219 | 220 | 221 | SEED=4663 222 | VAL_NAME="IPC$((IPC)) 4K Medium Conv4 S$((SEED))" 223 | 224 | wandb enabled 225 | wandb online 226 | python /path/to/DELT/evaluation/main.py \ 227 | --wandb-project "$PROJECT_NAME" \ 228 | --wandb-api-key "$WANDB_API_KEY" \ 229 | --val-dir "/path/to/val" \ 230 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 231 | --exp-name "$VAL_NAME" \ 232 | --subset "imagenet-1k" \ 233 | --arch-name "conv4" \ 234 | --use-rand-augment \ 235 | --rand-augment-config "$RAND_AUG" \ 236 | --random-erasing-p 0.0 \ 237 | --min-scale-crops 0.25 \ 238 | --max-scale-crops 1.0 \ 239 | --ipc $((IPC)) \ 240 | --val-ipc 50 \ 241 | --stud-name "conv4" \ 242 | --re-epochs 300 \ 243 | --gpu-device $((GPU)) \ 244 | --seed $((SEED)) 245 | 246 | # ========================================================================================= 247 | 248 | SEED=2897 249 | VAL_NAME="IPC$((IPC)) 4K Medium Conv4 S$((SEED))" 250 | 251 | wandb enabled 
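# `wandb enabled` re-enables W&B logging in case a previous run turned it off;
# `wandb online` (next line) switches off offline mode so metrics sync live.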
252 | wandb online 253 | python /path/to/DELT/evaluation/main.py \ 254 | --wandb-project "$PROJECT_NAME" \ 255 | --wandb-api-key "$WANDB_API_KEY" \ 256 | --val-dir "/path/to/val" \ 257 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 258 | --exp-name "$VAL_NAME" \ 259 | --subset "imagenet-1k" \ 260 | --arch-name "conv4" \ 261 | --use-rand-augment \ 262 | --rand-augment-config "$RAND_AUG" \ 263 | --random-erasing-p 0.0 \ 264 | --min-scale-crops 0.25 \ 265 | --max-scale-crops 1.0 \ 266 | --ipc $((IPC)) \ 267 | --val-ipc 50 \ 268 | --stud-name "conv4" \ 269 | --re-epochs 300 \ 270 | --gpu-device $((GPU)) \ 271 | --seed $((SEED)) 272 | 273 | # ========================================================================================= 274 | 275 | IPC=1 276 | EXP_NAME="IPC$((IPC))_3K_medium" 277 | 278 | echo "$EXP_NAME" 279 | SEED=3407 280 | 281 | 282 | # VAL_NAME="$RAND_AUG" 283 | VAL_NAME="IPC$((IPC)) 3K Medium Conv4 S$((SEED))" 284 | 285 | wandb enabled 286 | wandb online 287 | python /path/to/DELT/evaluation/main.py \ 288 | --wandb-project "$PROJECT_NAME" \ 289 | --wandb-api-key "$WANDB_API_KEY" \ 290 | --val-dir "/path/to/val" \ 291 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 292 | --exp-name "$VAL_NAME" \ 293 | --subset "imagenet-1k" \ 294 | --arch-name "conv4" \ 295 | --use-rand-augment \ 296 | --rand-augment-config "$RAND_AUG" \ 297 | --random-erasing-p 0.0 \ 298 | --min-scale-crops 0.25 \ 299 | --max-scale-crops 1.0 \ 300 | --ipc $((IPC)) \ 301 | --val-ipc 50 \ 302 | --stud-name "conv4" \ 303 | --re-epochs 300 \ 304 | --gpu-device $((GPU)) \ 305 | --seed $((SEED)) 306 | 307 | # ========================================================================================= 308 | 309 | 310 | SEED=4663 311 | VAL_NAME="IPC$((IPC)) 3K Medium Conv4 S$((SEED))" 312 | 313 | wandb enabled 314 | wandb online 315 | python /path/to/DELT/evaluation/main.py \ 316 | --wandb-project "$PROJECT_NAME" \ 317 | --wandb-api-key "$WANDB_API_KEY" \ 318 | --val-dir "/path/to/val" \ 319 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 320 | --exp-name "$VAL_NAME" \ 321 | --subset "imagenet-1k" \ 322 | --arch-name "conv4" \ 323 | --use-rand-augment \ 324 | --rand-augment-config "$RAND_AUG" \ 325 | --random-erasing-p 0.0 \ 326 | --min-scale-crops 0.25 \ 327 | --max-scale-crops 1.0 \ 328 | --ipc $((IPC)) \ 329 | --val-ipc 50 \ 330 | --stud-name "conv4" \ 331 | --re-epochs 300 \ 332 | --gpu-device $((GPU)) \ 333 | --seed $((SEED)) 334 | 335 | # ========================================================================================= 336 | 337 | SEED=2897 338 | VAL_NAME="IPC$((IPC)) 3K Medium Conv4 S$((SEED))" 339 | 340 | wandb enabled 341 | wandb online 342 | python /path/to/DELT/evaluation/main.py \ 343 | --wandb-project "$PROJECT_NAME" \ 344 | --wandb-api-key "$WANDB_API_KEY" \ 345 | --val-dir "/path/to/val" \ 346 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 347 | --exp-name "$VAL_NAME" \ 348 | --subset "imagenet-1k" \ 349 | --arch-name "conv4" \ 350 | --use-rand-augment \ 351 | --rand-augment-config "$RAND_AUG" \ 352 | --random-erasing-p 0.0 \ 353 | --min-scale-crops 0.25 \ 354 | --max-scale-crops 1.0 \ 355 | --ipc $((IPC)) \ 356 | --val-ipc 50 \ 357 | --stud-name "conv4" \ 358 | --re-epochs 300 \ 359 | --gpu-device $((GPU)) \ 360 | --seed $((SEED)) 361 | 362 | # ========================================================================================= 363 | 364 | IPC=1 365 | EXP_NAME="IPC$((IPC))_2K_medium" 366 | 367 | echo "$EXP_NAME" 368 | SEED=3407 369 | 370 | 371 | # VAL_NAME="$RAND_AUG" 372 | VAL_NAME="IPC$((IPC)) 2K Medium Conv4 
S$((SEED))" 373 | 374 | wandb enabled 375 | wandb online 376 | python /path/to/DELT/evaluation/main.py \ 377 | --wandb-project "$PROJECT_NAME" \ 378 | --wandb-api-key "$WANDB_API_KEY" \ 379 | --val-dir "/path/to/val" \ 380 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 381 | --exp-name "$VAL_NAME" \ 382 | --subset "imagenet-1k" \ 383 | --arch-name "conv4" \ 384 | --use-rand-augment \ 385 | --rand-augment-config "$RAND_AUG" \ 386 | --random-erasing-p 0.0 \ 387 | --min-scale-crops 0.25 \ 388 | --max-scale-crops 1.0 \ 389 | --ipc $((IPC)) \ 390 | --val-ipc 50 \ 391 | --stud-name "conv4" \ 392 | --re-epochs 300 \ 393 | --gpu-device $((GPU)) \ 394 | --seed $((SEED)) 395 | 396 | # ========================================================================================= 397 | 398 | 399 | SEED=4663 400 | VAL_NAME="IPC$((IPC)) 2K Medium Conv4 S$((SEED))" 401 | 402 | wandb enabled 403 | wandb online 404 | python /path/to/DELT/evaluation/main.py \ 405 | --wandb-project "$PROJECT_NAME" \ 406 | --wandb-api-key "$WANDB_API_KEY" \ 407 | --val-dir "/path/to/val" \ 408 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 409 | --exp-name "$VAL_NAME" \ 410 | --subset "imagenet-1k" \ 411 | --arch-name "conv4" \ 412 | --use-rand-augment \ 413 | --rand-augment-config "$RAND_AUG" \ 414 | --random-erasing-p 0.0 \ 415 | --min-scale-crops 0.25 \ 416 | --max-scale-crops 1.0 \ 417 | --ipc $((IPC)) \ 418 | --val-ipc 50 \ 419 | --stud-name "conv4" \ 420 | --re-epochs 300 \ 421 | --gpu-device $((GPU)) \ 422 | --seed $((SEED)) 423 | 424 | # ========================================================================================= 425 | 426 | SEED=2897 427 | VAL_NAME="IPC$((IPC)) 2K Medium Conv4 S$((SEED))" 428 | 429 | wandb enabled 430 | wandb online 431 | python /path/to/DELT/evaluation/main.py \ 432 | --wandb-project "$PROJECT_NAME" \ 433 | --wandb-api-key "$WANDB_API_KEY" \ 434 | --val-dir "/path/to/val" \ 435 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 436 | --exp-name "$VAL_NAME" \ 437 | --subset "imagenet-1k" \ 438 | --arch-name "conv4" \ 439 | --use-rand-augment \ 440 | --rand-augment-config "$RAND_AUG" \ 441 | --random-erasing-p 0.0 \ 442 | --min-scale-crops 0.25 \ 443 | --max-scale-crops 1.0 \ 444 | --ipc $((IPC)) \ 445 | --val-ipc 50 \ 446 | --stud-name "conv4" \ 447 | --re-epochs 300 \ 448 | --gpu-device $((GPU)) \ 449 | --seed $((SEED)) 450 | 451 | # ========================================================================================= 452 | 453 | IPC=1 454 | EXP_NAME="IPC$((IPC))_1K_medium" 455 | 456 | echo "$EXP_NAME" 457 | SEED=3407 458 | 459 | 460 | # VAL_NAME="$RAND_AUG" 461 | VAL_NAME="IPC$((IPC)) 1K Medium Conv4 S$((SEED))" 462 | 463 | wandb enabled 464 | wandb online 465 | python /path/to/DELT/evaluation/main.py \ 466 | --wandb-project "$PROJECT_NAME" \ 467 | --wandb-api-key "$WANDB_API_KEY" \ 468 | --val-dir "/path/to/val" \ 469 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 470 | --exp-name "$VAL_NAME" \ 471 | --subset "imagenet-1k" \ 472 | --arch-name "conv4" \ 473 | --use-rand-augment \ 474 | --rand-augment-config "$RAND_AUG" \ 475 | --random-erasing-p 0.0 \ 476 | --min-scale-crops 0.25 \ 477 | --max-scale-crops 1.0 \ 478 | --ipc $((IPC)) \ 479 | --val-ipc 50 \ 480 | --stud-name "conv4" \ 481 | --re-epochs 300 \ 482 | --gpu-device $((GPU)) \ 483 | --seed $((SEED)) 484 | 485 | # ========================================================================================= 486 | 487 | 488 | SEED=4663 489 | VAL_NAME="IPC$((IPC)) 1K Medium Conv4 S$((SEED))" 490 | 491 | wandb enabled 492 | wandb online 493 | python 
/path/to/DELT/evaluation/main.py \ 494 | --wandb-project "$PROJECT_NAME" \ 495 | --wandb-api-key "$WANDB_API_KEY" \ 496 | --val-dir "/path/to/val" \ 497 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 498 | --exp-name "$VAL_NAME" \ 499 | --subset "imagenet-1k" \ 500 | --arch-name "conv4" \ 501 | --use-rand-augment \ 502 | --rand-augment-config "$RAND_AUG" \ 503 | --random-erasing-p 0.0 \ 504 | --min-scale-crops 0.25 \ 505 | --max-scale-crops 1.0 \ 506 | --ipc $((IPC)) \ 507 | --val-ipc 50 \ 508 | --stud-name "conv4" \ 509 | --re-epochs 300 \ 510 | --gpu-device $((GPU)) \ 511 | --seed $((SEED)) 512 | 513 | # ========================================================================================= 514 | 515 | SEED=2897 516 | VAL_NAME="IPC$((IPC)) 1K Medium Conv4 S$((SEED))" 517 | 518 | wandb enabled 519 | wandb online 520 | python /path/to/DELT/evaluation/main.py \ 521 | --wandb-project "$PROJECT_NAME" \ 522 | --wandb-api-key "$WANDB_API_KEY" \ 523 | --val-dir "/path/to/val" \ 524 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 525 | --exp-name "$VAL_NAME" \ 526 | --subset "imagenet-1k" \ 527 | --arch-name "conv4" \ 528 | --use-rand-augment \ 529 | --rand-augment-config "$RAND_AUG" \ 530 | --random-erasing-p 0.0 \ 531 | --min-scale-crops 0.25 \ 532 | --max-scale-crops 1.0 \ 533 | --ipc $((IPC)) \ 534 | --val-ipc 50 \ 535 | --stud-name "conv4" \ 536 | --re-epochs 300 \ 537 | --gpu-device $((GPU)) \ 538 | --seed $((SEED)) 539 | 540 | # ========================================================================================= 541 | -------------------------------------------------------------------------------- /scripts/resnet18_imagenet1k_synthesis.sh: -------------------------------------------------------------------------------- 1 | GPU=0 2 | 3 | # ========================================================================================= 4 | 5 | SYN_PATH="/path/to/synthesized/imagenet_1k/" 6 | INIT_PATH="/path/to/initialization/medium_prob" 7 | IPC=50 8 | ITR=4000 9 | ROUND_ITR=500 10 | EXP_NAME="IPC$((IPC))_4K_500_medium" 11 | 12 | echo "$EXP_NAME" 13 | 14 | python /path/to/DELT/recover/recover.py \ 15 | --init-data-path "$INIT_PATH" \ 16 | --syn-data-path "$SYN_PATH" \ 17 | --arch-name "resnet18" \ 18 | --dataset "imagenet-1k" \ 19 | --exp-name "$EXP_NAME" \ 20 | --use-early-late \ 21 | --round-iterations $((ROUND_ITR)) \ 22 | --batch-size 100 \ 23 | --lr 0.25 \ 24 | --r-bn 0.01 \ 25 | --gpu-device $((GPU)) \ 26 | --iteration $((ITR)) \ 27 | --jitter 0\ 28 | --easy2hard-mode "cosine" --milestone 1 \ 29 | --ipc $((IPC)) --store-best-images 30 | 31 | echo "Synthesis -> DONE" 32 | 33 | # ========================================================================================= 34 | 35 | IPC=10 36 | ITR=4000 37 | ROUND_ITR=500 38 | EXP_NAME="IPC$((IPC))_4K_500_medium" 39 | 40 | echo "$EXP_NAME" 41 | 42 | python /path/to/DELT/recover/recover.py \ 43 | --init-data-path "$INIT_PATH" \ 44 | --syn-data-path "$SYN_PATH" \ 45 | --arch-name "resnet18" \ 46 | --dataset "imagenet-1k" \ 47 | --exp-name "$EXP_NAME" \ 48 | --use-early-late \ 49 | --round-iterations $((ROUND_ITR)) \ 50 | --batch-size 100 \ 51 | --lr 0.25 \ 52 | --r-bn 0.01 \ 53 | --gpu-device $((GPU)) \ 54 | --iteration $((ITR)) \ 55 | --jitter 0\ 56 | --easy2hard-mode "cosine" --milestone 1 \ 57 | --ipc $((IPC)) --store-best-images 58 | 59 | echo "Synthesis -> DONE" 60 | 61 | # ========================================================================================= 62 | 63 | IPC=1 64 | ITR=4000 65 | EXP_NAME="IPC$((IPC))_4K_medium" 66 | 67 | echo 
"$EXP_NAME" 68 | 69 | python /path/to/DELT/recover/recover.py \ 70 | --init-data-path "$INIT_PATH" \ 71 | --syn-data-path "$SYN_PATH" \ 72 | --arch-name "resnet18" \ 73 | --dataset "imagenet-1k" \ 74 | --exp-name "$EXP_NAME" \ 75 | --batch-size 100 \ 76 | --lr 0.25 \ 77 | --r-bn 0.01 \ 78 | --gpu-device $((GPU)) \ 79 | --iteration $((ITR)) \ 80 | --jitter 0\ 81 | --easy2hard-mode "cosine" --milestone 1 \ 82 | --ipc $((IPC)) --store-best-images 83 | 84 | echo "Synthesis -> DONE" 85 | 86 | # ========================================================================================= 87 | 88 | IPC=1 89 | ITR=3000 90 | EXP_NAME="IPC$((IPC))_3K_medium" 91 | 92 | echo "$EXP_NAME" 93 | 94 | python /path/to/DELT/recover/recover.py \ 95 | --init-data-path "$INIT_PATH" \ 96 | --syn-data-path "$SYN_PATH" \ 97 | --arch-name "resnet18" \ 98 | --dataset "imagenet-1k" \ 99 | --exp-name "$EXP_NAME" \ 100 | --batch-size 100 \ 101 | --lr 0.25 \ 102 | --r-bn 0.01 \ 103 | --gpu-device $((GPU)) \ 104 | --iteration $((ITR)) \ 105 | --jitter 0\ 106 | --easy2hard-mode "cosine" --milestone 1 \ 107 | --ipc $((IPC)) --store-best-images 108 | 109 | echo "Synthesis -> DONE" 110 | 111 | # ========================================================================================= 112 | 113 | IPC=1 114 | ITR=2000 115 | EXP_NAME="IPC$((IPC))_2K_medium" 116 | 117 | echo "$EXP_NAME" 118 | 119 | python /path/to/DELT/recover/recover.py \ 120 | --init-data-path "$INIT_PATH" \ 121 | --syn-data-path "$SYN_PATH" \ 122 | --arch-name "resnet18" \ 123 | --dataset "imagenet-1k" \ 124 | --exp-name "$EXP_NAME" \ 125 | --batch-size 100 \ 126 | --lr 0.25 \ 127 | --r-bn 0.01 \ 128 | --gpu-device $((GPU)) \ 129 | --iteration $((ITR)) \ 130 | --jitter 0\ 131 | --easy2hard-mode "cosine" --milestone 1 \ 132 | --ipc $((IPC)) --store-best-images 133 | 134 | echo "Synthesis -> DONE" 135 | 136 | # ========================================================================================= 137 | 138 | IPC=1 139 | ITR=1000 140 | EXP_NAME="IPC$((IPC))_1K_medium" 141 | 142 | echo "$EXP_NAME" 143 | 144 | python /path/to/DELT/recover/recover.py \ 145 | --init-data-path "$INIT_PATH" \ 146 | --syn-data-path "$SYN_PATH" \ 147 | --arch-name "resnet18" \ 148 | --dataset "imagenet-1k" \ 149 | --exp-name "$EXP_NAME" \ 150 | --batch-size 100 \ 151 | --lr 0.25 \ 152 | --r-bn 0.01 \ 153 | --gpu-device $((GPU)) \ 154 | --iteration $((ITR)) \ 155 | --jitter 0\ 156 | --easy2hard-mode "cosine" --milestone 1 \ 157 | --ipc $((IPC)) --store-best-images 158 | 159 | echo "Synthesis -> DONE" 160 | 161 | # ========================================================================================= 162 | -------------------------------------------------------------------------------- /scripts/resnet18_imagenet1k_validation.sh: -------------------------------------------------------------------------------- 1 | GPU=0 2 | 3 | # ========================================================================================= 4 | 5 | SYN_PATH="/path/to/synthesized/imagenet_1k/" 6 | PROJECT_NAME="Imagenet1K-Seeds" 7 | IPC=50 8 | EXP_NAME="IPC$((IPC))_4K_500_medium" 9 | WANDB_API_KEY="write your api here" 10 | 11 | echo "$EXP_NAME" 12 | SEED=3407 13 | 14 | RAND_AUG="rand-m6-n2-mstd1.0" 15 | # VAL_NAME="$RAND_AUG" 16 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))" 17 | 18 | wandb enabled 19 | wandb online 20 | python /path/to/DELT/evaluation/main.py \ 21 | --wandb-project "$PROJECT_NAME" \ 22 | --wandb-api-key "$WANDB_API_KEY" \ 23 | --val-dir "/path/to/val" \ 24 | --syn-data-path 
"$SYN_PATH$EXP_NAME" \ 25 | --exp-name "$VAL_NAME" \ 26 | --subset "imagenet-1k" \ 27 | --arch-name "resnet18" \ 28 | --use-rand-augment \ 29 | --rand-augment-config "$RAND_AUG" \ 30 | --random-erasing-p 0.0 \ 31 | --min-scale-crops 0.25 \ 32 | --max-scale-crops 1.0 \ 33 | --ipc $((IPC)) \ 34 | --val-ipc 50 \ 35 | --stud-name "resnet18" \ 36 | --re-epochs 300 \ 37 | --gpu-device $((GPU)) \ 38 | --seed $((SEED)) 39 | 40 | # ========================================================================================= 41 | 42 | 43 | SEED=4663 44 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))" 45 | 46 | wandb enabled 47 | wandb online 48 | python /path/to/DELT/evaluation/main.py \ 49 | --wandb-project "$PROJECT_NAME" \ 50 | --wandb-api-key "$WANDB_API_KEY" \ 51 | --val-dir "/path/to/val" \ 52 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 53 | --exp-name "$VAL_NAME" \ 54 | --subset "imagenet-1k" \ 55 | --arch-name "resnet18" \ 56 | --use-rand-augment \ 57 | --rand-augment-config "$RAND_AUG" \ 58 | --random-erasing-p 0.0 \ 59 | --min-scale-crops 0.25 \ 60 | --max-scale-crops 1.0 \ 61 | --ipc $((IPC)) \ 62 | --val-ipc 50 \ 63 | --stud-name "resnet18" \ 64 | --re-epochs 300 \ 65 | --gpu-device $((GPU)) \ 66 | --seed $((SEED)) 67 | 68 | # ========================================================================================= 69 | 70 | SEED=2897 71 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))" 72 | 73 | wandb enabled 74 | wandb online 75 | python /path/to/DELT/evaluation/main.py \ 76 | --wandb-project "$PROJECT_NAME" \ 77 | --wandb-api-key "$WANDB_API_KEY" \ 78 | --val-dir "/path/to/val" \ 79 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 80 | --exp-name "$VAL_NAME" \ 81 | --subset "imagenet-1k" \ 82 | --arch-name "resnet18" \ 83 | --use-rand-augment \ 84 | --rand-augment-config "$RAND_AUG" \ 85 | --random-erasing-p 0.0 \ 86 | --min-scale-crops 0.25 \ 87 | --max-scale-crops 1.0 \ 88 | --ipc $((IPC)) \ 89 | --val-ipc 50 \ 90 | --stud-name "resnet18" \ 91 | --re-epochs 300 \ 92 | --gpu-device $((GPU)) \ 93 | --seed $((SEED)) 94 | 95 | # ========================================================================================= 96 | 97 | IPC=10 98 | EXP_NAME="IPC$((IPC))_4K_500_medium" 99 | 100 | echo "$EXP_NAME" 101 | SEED=3407 102 | 103 | 104 | # VAL_NAME="$RAND_AUG" 105 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))" 106 | 107 | wandb enabled 108 | wandb online 109 | python /path/to/DELT/evaluation/main.py \ 110 | --wandb-project "$PROJECT_NAME" \ 111 | --wandb-api-key "$WANDB_API_KEY" \ 112 | --val-dir "/path/to/val" \ 113 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 114 | --exp-name "$VAL_NAME" \ 115 | --subset "imagenet-1k" \ 116 | --arch-name "resnet18" \ 117 | --use-rand-augment \ 118 | --rand-augment-config "$RAND_AUG" \ 119 | --random-erasing-p 0.0 \ 120 | --min-scale-crops 0.25 \ 121 | --max-scale-crops 1.0 \ 122 | --ipc $((IPC)) \ 123 | --val-ipc 50 \ 124 | --stud-name "resnet18" \ 125 | --re-epochs 300 \ 126 | --gpu-device $((GPU)) \ 127 | --seed $((SEED)) 128 | 129 | # ========================================================================================= 130 | 131 | 132 | SEED=4663 133 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))" 134 | 135 | wandb enabled 136 | wandb online 137 | python /path/to/DELT/evaluation/main.py \ 138 | --wandb-project "$PROJECT_NAME" \ 139 | --wandb-api-key "$WANDB_API_KEY" \ 140 | --val-dir "/path/to/val" \ 141 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 142 | --exp-name "$VAL_NAME" \ 143 | --subset "imagenet-1k" \ 144 | --arch-name "resnet18" \ 145 | 
--use-rand-augment \ 146 | --rand-augment-config "$RAND_AUG" \ 147 | --random-erasing-p 0.0 \ 148 | --min-scale-crops 0.25 \ 149 | --max-scale-crops 1.0 \ 150 | --ipc $((IPC)) \ 151 | --val-ipc 50 \ 152 | --stud-name "resnet18" \ 153 | --re-epochs 300 \ 154 | --gpu-device $((GPU)) \ 155 | --seed $((SEED)) 156 | 157 | # ========================================================================================= 158 | 159 | SEED=2897 160 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))" 161 | 162 | wandb enabled 163 | wandb online 164 | python /path/to/DELT/evaluation/main.py \ 165 | --wandb-project "$PROJECT_NAME" \ 166 | --wandb-api-key "$WANDB_API_KEY" \ 167 | --val-dir "/path/to/val" \ 168 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 169 | --exp-name "$VAL_NAME" \ 170 | --subset "imagenet-1k" \ 171 | --arch-name "resnet18" \ 172 | --use-rand-augment \ 173 | --rand-augment-config "$RAND_AUG" \ 174 | --random-erasing-p 0.0 \ 175 | --min-scale-crops 0.25 \ 176 | --max-scale-crops 1.0 \ 177 | --ipc $((IPC)) \ 178 | --val-ipc 50 \ 179 | --stud-name "resnet18" \ 180 | --re-epochs 300 \ 181 | --gpu-device $((GPU)) \ 182 | --seed $((SEED)) 183 | 184 | # ========================================================================================= 185 | 186 | IPC=1 187 | EXP_NAME="IPC$((IPC))_4K_medium" 188 | 189 | echo "$EXP_NAME" 190 | SEED=3407 191 | 192 | 193 | # VAL_NAME="$RAND_AUG" 194 | VAL_NAME="IPC$((IPC)) 4K Medium R18 S$((SEED))" 195 | 196 | wandb enabled 197 | wandb online 198 | python /path/to/DELT/evaluation/main.py \ 199 | --wandb-project "$PROJECT_NAME" \ 200 | --wandb-api-key "$WANDB_API_KEY" \ 201 | --val-dir "/path/to/val" \ 202 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 203 | --exp-name "$VAL_NAME" \ 204 | --subset "imagenet-1k" \ 205 | --arch-name "resnet18" \ 206 | --use-rand-augment \ 207 | --rand-augment-config "$RAND_AUG" \ 208 | --random-erasing-p 0.0 \ 209 | --min-scale-crops 0.25 \ 210 | --max-scale-crops 1.0 \ 211 | --ipc $((IPC)) \ 212 | --val-ipc 50 \ 213 | --stud-name "resnet18" \ 214 | --re-epochs 300 \ 215 | --gpu-device $((GPU)) \ 216 | --seed $((SEED)) 217 | 218 | # ========================================================================================= 219 | 220 | 221 | SEED=4663 222 | VAL_NAME="IPC$((IPC)) 4K Medium R18 S$((SEED))" 223 | 224 | wandb enabled 225 | wandb online 226 | python /path/to/DELT/evaluation/main.py \ 227 | --wandb-project "$PROJECT_NAME" \ 228 | --wandb-api-key "$WANDB_API_KEY" \ 229 | --val-dir "/path/to/val" \ 230 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 231 | --exp-name "$VAL_NAME" \ 232 | --subset "imagenet-1k" \ 233 | --arch-name "resnet18" \ 234 | --use-rand-augment \ 235 | --rand-augment-config "$RAND_AUG" \ 236 | --random-erasing-p 0.0 \ 237 | --min-scale-crops 0.25 \ 238 | --max-scale-crops 1.0 \ 239 | --ipc $((IPC)) \ 240 | --val-ipc 50 \ 241 | --stud-name "resnet18" \ 242 | --re-epochs 300 \ 243 | --gpu-device $((GPU)) \ 244 | --seed $((SEED)) 245 | 246 | # ========================================================================================= 247 | 248 | SEED=2897 249 | VAL_NAME="IPC$((IPC)) 4K Medium R18 S$((SEED))" 250 | 251 | wandb enabled 252 | wandb online 253 | python /path/to/DELT/evaluation/main.py \ 254 | --wandb-project "$PROJECT_NAME" \ 255 | --wandb-api-key "$WANDB_API_KEY" \ 256 | --val-dir "/path/to/val" \ 257 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 258 | --exp-name "$VAL_NAME" \ 259 | --subset "imagenet-1k" \ 260 | --arch-name "resnet18" \ 261 | --use-rand-augment \ 262 | --rand-augment-config "$RAND_AUG" \ 263 | 
--random-erasing-p 0.0 \ 264 | --min-scale-crops 0.25 \ 265 | --max-scale-crops 1.0 \ 266 | --ipc $((IPC)) \ 267 | --val-ipc 50 \ 268 | --stud-name "resnet18" \ 269 | --re-epochs 300 \ 270 | --gpu-device $((GPU)) \ 271 | --seed $((SEED)) 272 | 273 | # ========================================================================================= 274 | 275 | IPC=1 276 | EXP_NAME="IPC$((IPC))_3K_medium" 277 | 278 | echo "$EXP_NAME" 279 | SEED=3407 280 | 281 | 282 | # VAL_NAME="$RAND_AUG" 283 | VAL_NAME="IPC$((IPC)) 3K Medium R18 S$((SEED))" 284 | 285 | wandb enabled 286 | wandb online 287 | python /path/to/DELT/evaluation/main.py \ 288 | --wandb-project "$PROJECT_NAME" \ 289 | --wandb-api-key "$WANDB_API_KEY" \ 290 | --val-dir "/path/to/val" \ 291 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 292 | --exp-name "$VAL_NAME" \ 293 | --subset "imagenet-1k" \ 294 | --arch-name "resnet18" \ 295 | --use-rand-augment \ 296 | --rand-augment-config "$RAND_AUG" \ 297 | --random-erasing-p 0.0 \ 298 | --min-scale-crops 0.25 \ 299 | --max-scale-crops 1.0 \ 300 | --ipc $((IPC)) \ 301 | --val-ipc 50 \ 302 | --stud-name "resnet18" \ 303 | --re-epochs 300 \ 304 | --gpu-device $((GPU)) \ 305 | --seed $((SEED)) 306 | 307 | # ========================================================================================= 308 | 309 | 310 | SEED=4663 311 | VAL_NAME="IPC$((IPC)) 3K Medium R18 S$((SEED))" 312 | 313 | wandb enabled 314 | wandb online 315 | python /path/to/DELT/evaluation/main.py \ 316 | --wandb-project "$PROJECT_NAME" \ 317 | --wandb-api-key "$WANDB_API_KEY" \ 318 | --val-dir "/path/to/val" \ 319 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 320 | --exp-name "$VAL_NAME" \ 321 | --subset "imagenet-1k" \ 322 | --arch-name "resnet18" \ 323 | --use-rand-augment \ 324 | --rand-augment-config "$RAND_AUG" \ 325 | --random-erasing-p 0.0 \ 326 | --min-scale-crops 0.25 \ 327 | --max-scale-crops 1.0 \ 328 | --ipc $((IPC)) \ 329 | --val-ipc 50 \ 330 | --stud-name "resnet18" \ 331 | --re-epochs 300 \ 332 | --gpu-device $((GPU)) \ 333 | --seed $((SEED)) 334 | 335 | # ========================================================================================= 336 | 337 | SEED=2897 338 | VAL_NAME="IPC$((IPC)) 3K Medium R18 S$((SEED))" 339 | 340 | wandb enabled 341 | wandb online 342 | python /path/to/DELT/evaluation/main.py \ 343 | --wandb-project "$PROJECT_NAME" \ 344 | --wandb-api-key "$WANDB_API_KEY" \ 345 | --val-dir "/path/to/val" \ 346 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 347 | --exp-name "$VAL_NAME" \ 348 | --subset "imagenet-1k" \ 349 | --arch-name "resnet18" \ 350 | --use-rand-augment \ 351 | --rand-augment-config "$RAND_AUG" \ 352 | --random-erasing-p 0.0 \ 353 | --min-scale-crops 0.25 \ 354 | --max-scale-crops 1.0 \ 355 | --ipc $((IPC)) \ 356 | --val-ipc 50 \ 357 | --stud-name "resnet18" \ 358 | --re-epochs 300 \ 359 | --gpu-device $((GPU)) \ 360 | --seed $((SEED)) 361 | 362 | # ========================================================================================= 363 | 364 | IPC=1 365 | EXP_NAME="IPC$((IPC))_2K_medium" 366 | 367 | echo "$EXP_NAME" 368 | SEED=3407 369 | 370 | 371 | # VAL_NAME="$RAND_AUG" 372 | VAL_NAME="IPC$((IPC)) 2K Medium R18 S$((SEED))" 373 | 374 | wandb enabled 375 | wandb online 376 | python /path/to/DELT/evaluation/main.py \ 377 | --wandb-project "$PROJECT_NAME" \ 378 | --wandb-api-key "$WANDB_API_KEY" \ 379 | --val-dir "/path/to/val" \ 380 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 381 | --exp-name "$VAL_NAME" \ 382 | --subset "imagenet-1k" \ 383 | --arch-name "resnet18" \ 384 | --use-rand-augment \ 
385 | --rand-augment-config "$RAND_AUG" \ 386 | --random-erasing-p 0.0 \ 387 | --min-scale-crops 0.25 \ 388 | --max-scale-crops 1.0 \ 389 | --ipc $((IPC)) \ 390 | --val-ipc 50 \ 391 | --stud-name "resnet18" \ 392 | --re-epochs 300 \ 393 | --gpu-device $((GPU)) \ 394 | --seed $((SEED)) 395 | 396 | # ========================================================================================= 397 | 398 | 399 | SEED=4663 400 | VAL_NAME="IPC$((IPC)) 2K Medium R18 S$((SEED))" 401 | 402 | wandb enabled 403 | wandb online 404 | python /path/to/DELT/evaluation/main.py \ 405 | --wandb-project "$PROJECT_NAME" \ 406 | --wandb-api-key "$WANDB_API_KEY" \ 407 | --val-dir "/path/to/val" \ 408 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 409 | --exp-name "$VAL_NAME" \ 410 | --subset "imagenet-1k" \ 411 | --arch-name "resnet18" \ 412 | --use-rand-augment \ 413 | --rand-augment-config "$RAND_AUG" \ 414 | --random-erasing-p 0.0 \ 415 | --min-scale-crops 0.25 \ 416 | --max-scale-crops 1.0 \ 417 | --ipc $((IPC)) \ 418 | --val-ipc 50 \ 419 | --stud-name "resnet18" \ 420 | --re-epochs 300 \ 421 | --gpu-device $((GPU)) \ 422 | --seed $((SEED)) 423 | 424 | # ========================================================================================= 425 | 426 | SEED=2897 427 | VAL_NAME="IPC$((IPC)) 2K Medium R18 S$((SEED))" 428 | 429 | wandb enabled 430 | wandb online 431 | python /path/to/DELT/evaluation/main.py \ 432 | --wandb-project "$PROJECT_NAME" \ 433 | --wandb-api-key "$WANDB_API_KEY" \ 434 | --val-dir "/path/to/val" \ 435 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 436 | --exp-name "$VAL_NAME" \ 437 | --subset "imagenet-1k" \ 438 | --arch-name "resnet18" \ 439 | --use-rand-augment \ 440 | --rand-augment-config "$RAND_AUG" \ 441 | --random-erasing-p 0.0 \ 442 | --min-scale-crops 0.25 \ 443 | --max-scale-crops 1.0 \ 444 | --ipc $((IPC)) \ 445 | --val-ipc 50 \ 446 | --stud-name "resnet18" \ 447 | --re-epochs 300 \ 448 | --gpu-device $((GPU)) \ 449 | --seed $((SEED)) 450 | 451 | # ========================================================================================= 452 | 453 | IPC=1 454 | EXP_NAME="IPC$((IPC))_1K_medium" 455 | 456 | echo "$EXP_NAME" 457 | SEED=3407 458 | 459 | 460 | # VAL_NAME="$RAND_AUG" 461 | VAL_NAME="IPC$((IPC)) 1K Medium R18 S$((SEED))" 462 | 463 | wandb enabled 464 | wandb online 465 | python /path/to/DELT/evaluation/main.py \ 466 | --wandb-project "$PROJECT_NAME" \ 467 | --wandb-api-key "$WANDB_API_KEY" \ 468 | --val-dir "/path/to/val" \ 469 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 470 | --exp-name "$VAL_NAME" \ 471 | --subset "imagenet-1k" \ 472 | --arch-name "resnet18" \ 473 | --use-rand-augment \ 474 | --rand-augment-config "$RAND_AUG" \ 475 | --random-erasing-p 0.0 \ 476 | --min-scale-crops 0.25 \ 477 | --max-scale-crops 1.0 \ 478 | --ipc $((IPC)) \ 479 | --val-ipc 50 \ 480 | --stud-name "resnet18" \ 481 | --re-epochs 300 \ 482 | --gpu-device $((GPU)) \ 483 | --seed $((SEED)) 484 | 485 | # ========================================================================================= 486 | 487 | 488 | SEED=4663 489 | VAL_NAME="IPC$((IPC)) 1K Medium R18 S$((SEED))" 490 | 491 | wandb enabled 492 | wandb online 493 | python /path/to/DELT/evaluation/main.py \ 494 | --wandb-project "$PROJECT_NAME" \ 495 | --wandb-api-key "$WANDB_API_KEY" \ 496 | --val-dir "/path/to/val" \ 497 | --syn-data-path "$SYN_PATH$EXP_NAME" \ 498 | --exp-name "$VAL_NAME" \ 499 | --subset "imagenet-1k" \ 500 | --arch-name "resnet18" \ 501 | --use-rand-augment \ 502 | --rand-augment-config "$RAND_AUG" \ 503 | --random-erasing-p 
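
# Note: the reported accuracies presumably average over the three seeds above.
# Run the full sweep with:
#   bash /path/to/scripts/resnet18_imagenet1k_validation.sh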

# =========================================================================================
--------------------------------------------------------------------------------
/scripts/select_medium_tiny_rn18_ep50_ipc50.sh:
--------------------------------------------------------------------------------
TRAIN_DIR="/path/to/tiny-imagenet/train"
OUTPUT_DIR="/path/to/ranked/tiny_imagenet_medium"
RANKER_PATH="/path/to/model/tinyimagenet_resnet18_modified.pth"
RANKING_FILE="/path/to/rankings_csv/tiny_imagenet.csv"

# Download the ranker checkpoint (-L follows Google Drive redirects; the file id
# must be passed bare, without braces).
curl -L "https://drive.usercontent.google.com/download?id=1h_Enp0_FlgxCED-oriPuyYbmonYwxIi9&confirm=xxx" -o "$RANKER_PATH"

python /path/to/data_selection.py \
    --dataset "tiny-imagenet" \
    --data-path "$TRAIN_DIR" \
    --output-path "$OUTPUT_DIR" \
    --ranker-path "$RANKER_PATH" \
    --store-rank-file \
    --ranker-arch "resnet18" \
    --ranking-file "$RANKING_FILE" \
    --selection-criteria "medium" \
    --ipc 50 \
    --gpu-device 0
--------------------------------------------------------------------------------