├── README.md
├── data_selection.py
├── evaluation
│   ├── argument.py
│   ├── main.py
│   ├── model_utils
│   │   ├── models.py
│   │   └── utils.py
│   └── validation
│       ├── main.py
│       └── utils.py
├── imgs
│   ├── DELT_framework.png
│   ├── diversity_comparison.png
│   ├── results.png
│   └── samples.png
├── recover
│   ├── models.py
│   ├── recover.py
│   └── utils.py
└── scripts
    ├── conv4_imagenet1k_synthesis.sh
    ├── conv4_imagenet1k_validation.sh
    ├── resnet18_imagenet1k_synthesis.sh
    ├── resnet18_imagenet1k_validation.sh
    └── select_medium_tiny_rn18_ep50_ipc50.sh
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # DELT: A Simple Diversity-driven EarlyLate Training for Dataset Distillation (CVPR 2025)
4 | [[arXiv](https://arxiv.org/abs/2411.19946)] [[Distilled Dataset](https://drive.google.com/file/d/1Rr_ik94FNte75yc4GiKtv927qdBwKskr/view?usp=sharing)]
5 |
6 |
7 |
8 |
9 | Official PyTorch implementation of DELT, which outperforms the SOTA top-1 accuracy by +1.3% on average while increasing diversity per class by more than 5% and reducing synthesis time by up to 39.3%.
10 |
11 |
12 |
![DELT framework](imgs/DELT_framework.png)
13 |
14 |
15 | # Contents
16 |
17 | - [Abstract](#abstract)
18 | - [📸 Visualization](#-visualization)
19 | - [Usage](#usage)
20 | - [🗃 Dataset Format](#-dataset-format)
21 | - [🤖 Squeeze](#-squeeze)
22 | - [🧲 Ranking and Selection](#-ranking-and-selection)
23 | - [♻️ Recover](#️-recover)
24 | - [🧪 Evaluation](#-evaluation)
25 | - [📈 Results](#-results)
26 | - [📜 Citation](#-citation)
27 | # Abstract
28 |
29 | Recent advances in dataset distillation have led to solutions in two main directions. The conventional *batch-to-batch* matching mechanism is ideal for small-scale datasets and includes bi-level optimization methods on models and syntheses, such as FRePo, RCIG, and RaT-BPTT, as well as other methods like distribution matching, gradient matching, and weight trajectory matching. Conversely, *batch-to-global* matching typifies decoupled methods, which are particularly advantageous for large-scale datasets. This approach has garnered substantial interest within the community, as seen in SRe$^2$L, G-VBSM, WMDD, and CDA. A primary challenge with the second approach is the lack of diversity among syntheses within each class, since samples are optimized independently and the same global supervision signals are reused across different synthetic images. In this study, we propose a new **D**iversity-driven **E**`arly`**L**`ate` **T**raining (DELT) scheme to enhance the diversity of images in batch-to-global matching with less computation. Our approach is conceptually simple yet effective: it partitions predefined IPC samples into smaller subtasks and employs local optimizations to distill each subset into distributions from distinct phases, reducing the uniformity induced by the unified optimization process. These distilled images from the subtasks demonstrate effective generalization when applied to the entire task. We conduct extensive experiments on CIFAR, Tiny-ImageNet, ImageNet-1K, and its sub-datasets. Our approach outperforms the previous state-of-the-art by **1.3%** on average across different datasets and IPCs (images per class), increasing diversity per class by more than **5%** while reducing synthesis time by up to **39.3%**, enhancing the overall efficiency.
30 | # 📸 Visualization
31 |
32 |
33 |
![Synthesized samples](imgs/samples.png)
34 |
35 |
36 | # Usage
37 |
38 | ## 🗃 Dataset Format
39 |
40 | The dataset used for recovery and evaluation should be compatible with the [`ImageFolder`](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) class, i.e., one sub-directory per class. Refer to the PyTorch documentation for further details.
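
For illustration, this is the kind of layout `ImageFolder` expects (directory and file names below are hypothetical):

```bash
$ ls /path/to/data              # one sub-directory per class
n01440764  n01443537  ...
$ ls /path/to/data/n01440764    # the images of that class
image_0.JPEG  image_1.JPEG  ...
```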
41 |
42 | ## 🤖 Squeeze
43 |
44 | For teacher models, we follow SRe$^2$L, CDA, and RDED, using the [official torchvision classification code](https://github.com/pytorch/vision/tree/main/references/classification). We use the torchvision pre-trained model for ImageNet-1K; the other teacher models can be found in RDED's models [here](https://drive.google.com/drive/folders/1HmrheO6MgX453a5UPJdxPHK4UTv-4aVt?usp=drive_link).
45 |
46 | ## 🧲 Ranking and Selection
47 |
48 | Our DELT uses initialization based on three main selection criteria:
49 | 1. `top` selects the easiest images, i.e., those scoring the highest probability for the true class.
50 | 2. `min` selects the hardest images, i.e., those scoring the lowest probability for the true class.
51 | 3. `medium` selects the images around the median of the true-class probabilities; this is the criterion used in DELT.
52 |
53 | We provide a sample script in [`scripts`](scripts/) that selects the medium-difficulty images from the Tiny-ImageNet dataset. You can run it as below:
54 |
55 | ```bash
56 | bash /path/to/scripts/select_medium_tiny_rn18_ep50_ipc50.sh
57 | ```
58 |
59 | We overview some of its variables below; a direct invocation example follows the list.
60 | - `TRAIN_DIR`: the training directory from which we select the images
61 | - `OUTPUT_DIR`: the output directory where we store the selected images, compatible with [`ImageFolder`](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html)
62 | - `RANKER_PATH`: the path to the model used in ranking; if not provided, we use the pre-trained model from torchvision
63 | - `RANKING_FILE`: the path of the output CSV file containing the scores of all the images, which is used in selection
64 |
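If you prefer not to use the wrapper script, `data_selection.py` can be invoked directly. Below is a sketch using the flags defined in its `get_args`; all paths and values are illustrative:

```bash
python data_selection.py \
    --dataset tiny-imagenet \
    --data-path /path/to/tiny-imagenet/train \
    --output-path ./selected-data \
    --ranker-path /path/to/ranker_checkpoint.pth \
    --ranker-arch resnet18 \
    --selection-criteria medium \
    --ipc 50 \
    --store-rank-file \
    --ranking-file ./rankings/tiny_rn18_ep50.csv
```

Swap `medium` for `top` or `min` to use the other criteria. Once a ranking CSV exists, you can reuse it by dropping `--store-rank-file` and keeping `--ranking-file`, which skips the scoring pass.
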
65 | ## ♻️ Recover
66 |
67 | We provide some [`scripts`](scripts/) to automate the experimentation process. For instance, to synthesize ImageNet-1K images using a ResNet-18 model, you can use:
68 |
69 | ```bash
70 | bash /path/to/scripts/resnet18_imagenet1k_synthesis.sh
71 | ```
72 |
73 | You can use the script after adjusting the variables as appropriate for your experiment. We overview them below; an illustrative configuration follows the list.
74 |
75 | - `SYN_PATH`: the target directory where we synthesize data
76 | - `INIT_PATH`: the path to the initial ImageNet-1K images that will be used in initialization
77 | - `IPC`: the number of **I**mages **P**er **C**lass
78 | - `ITR`: total number of update iterations
79 | - `ROUND_ITR`: the number of update iterations in a single round
80 |
81 | ## 🧪 Evaluation
82 |
83 | We provide some evaluation [`scripts`](scripts/) to evaluate the synthesized data. For instance, to evaluate the synthesized ImageNet-1K images with a ResNet-18 student, you can use:
84 |
85 | ```bash
86 | bash /path/to/scripts/resnet18_imagenet1k_validation.sh
87 | ```
88 |
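The evaluation entry point can also be called directly. Below is a sketch using flags defined in [`evaluation/argument.py`](evaluation/argument.py); the paths and experiment name are illustrative:

```bash
cd evaluation
python main.py \
    --subset imagenet-1k \
    --arch-name resnet18 \
    --stud-name resnet18 \
    --ipc 50 \
    --re-epochs 300 \
    --syn-data-path /path/to/synthesized-data \
    --val-dir /path/to/imagenet/val \
    --exp-name delt_in1k_rn18_ipc50
```

Weights & Biases logging is controlled via the `--wandb-api-key` and `--wandb-project` flags from the same argument file.
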
89 | # 📈 Results
90 |
91 | We compare our approach against different methods on different datasets, as shown below.
92 |
93 |
94 |
![Results](imgs/results.png)
95 |
96 |
97 | We also visualize the inter-class average cosine similarity as an indication of diversity (lower values indicate more diversity).
98 |
99 |
100 |
![Diversity comparison](imgs/diversity_comparison.png)
101 |
102 |
103 | # 📜 Citation
104 |
105 | ```
106 | @misc{shen2024deltsimplediversitydrivenearlylate,
107 |       title={DELT: A Simple Diversity-driven EarlyLate Training for Dataset Distillation},
108 |       author={Zhiqiang Shen and Ammar Sherif and Zeyuan Yin and Shitong Shao},
109 |       year={2024},
110 |       eprint={2411.19946},
111 |       archivePrefix={arXiv},
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
/data_selection.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torchvision
3 | import os
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torchvision.models as models
7 | import pandas as pd
8 | from tqdm import tqdm
9 | import torch.nn as nn
10 | import math
11 | import shutil
12 |
13 | class ImageFolder(torchvision.datasets.ImageFolder):
14 | def __init__(self, **kwargs):
15 | super(ImageFolder, self).__init__(**kwargs)
16 | self.image_paths = []
17 | self.image_relative_paths = []
18 | self.targets = []
19 | self.samples = []
20 | class_dirs = []
21 | for c_name in os.listdir(self.root):
22 | if os.path.isdir(os.path.join(self.root,c_name)):
23 | class_dirs.append(c_name)
24 | class_dirs.sort()
25 | for c_indx in range(len(class_dirs)):
26 | dir_path = os.path.join(self.root, class_dirs[c_indx])
27 | img_names = os.listdir(dir_path)
28 | for i in range(len(img_names)):
29 | self.image_paths.append(os.path.join(dir_path, img_names[i]))
30 | self.image_relative_paths.append(os.path.join(class_dirs[c_indx], img_names[i]))
31 | self.targets.append(c_indx)
32 |
33 | def __getitem__(self, index):
34 | sample = self.loader(self.image_paths[index])
35 | sample = self.transform(sample)
36 | return sample, self.targets[index], self.image_relative_paths[index]
37 |
38 | def __len__(self):
39 | return len(self.targets)
40 |
41 | def get_args():
42 | parser = argparse.ArgumentParser("Data selection for dataset distillation")
43 | """Data save flags"""
44 | parser.add_argument("--dataset", type=str, default="imagenet-1k")
45 | parser.add_argument("--data-path", type=str, help="location of training data")
46 | parser.add_argument("--output-path", type=str, default="./selected-data", help="location of the selected data")
47 | parser.add_argument("--ranker-path", type=str, default="", help="path to the ranker model")
48 | parser.add_argument("--ranker-arch", type=str, default="resnet18", help="for loading the model")
49 |
50 | parser.add_argument("--ranking-file", type=str, default="", help="csv file that includes the ranking of the data")
51 |
52 | parser.add_argument("--store-rank-file", action="store_true", default=False, help="store csv file that includes the ranking of the data for future use")
53 | parser.add_argument("--ipc", type=int, default=50, help="number of IPC to select")
54 | parser.add_argument("--selection-criteria", type=str, default="medium")
55 |
56 | parser.add_argument("--batch-size", type=int, default=200, help="number of images to load at the same time")
57 | parser.add_argument("--workers", type=int, default=16, help="number of workers in data loader")
58 | parser.add_argument('--gpu-device', default="-1", type=str)
59 | args = parser.parse_args()
60 |
61 | if args.gpu_device != "-1":
62 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
63 | print("set CUDA_VISIBLE_DEVICES to ", args.gpu_device)
64 |
65 | print(args)
66 | return args
67 |
68 | def get_model(args):
69 | if args.dataset == "imagenet-1k":
70 | # ImageNet 1K dataset
71 | if args.ranker_path:
72 | model = models.get_model(args.ranker_arch, weights=None, num_classes=1000)  # weights=None: random init; the checkpoint is loaded below
73 | checkpoint = torch.load(args.ranker_path, map_location="cpu")
74 | model.load_state_dict(checkpoint["model"])
75 | model = nn.DataParallel(model).cuda()
76 | model.eval()
77 | for p in model.parameters():
78 | p.requires_grad = False
79 | return model
80 | elif not args.ranker_path and args.ranker_arch:
81 | model = models.get_model(args.ranker_arch, weights="DEFAULT")
82 | model = nn.DataParallel(model).cuda()
83 | model.eval()
84 | for p in model.parameters():
85 | p.requires_grad = False
86 | return model
87 | elif not args.ranker_path:
88 | print("You must either provide a checkpoint path using --ranker-path,")
89 | print("or provide an architecture name using --ranker-arch to load torchvision pretrained weights")
90 | return None
91 | elif args.dataset == "imagenet-100":
92 | assert args.ranker_path
93 | model = models.get_model(args.ranker_arch, weights=None, num_classes=100)
94 | checkpoint = torch.load(args.ranker_path, map_location="cpu")
95 | model.load_state_dict(checkpoint["model"])
96 | model = nn.DataParallel(model).cuda()
97 | model.eval()
98 | for p in model.parameters():
99 | p.requires_grad = False
100 | return model
101 | elif args.dataset == "tiny-imagenet":
102 | assert args.ranker_path
103 | model = models.get_model(args.ranker_arch, weights=None, num_classes=200)
104 | model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
105 | model.maxpool = nn.Identity()
106 | checkpoint = torch.load(args.ranker_path, map_location="cpu")
107 | model.load_state_dict(checkpoint["model"])
108 |
109 | model = nn.DataParallel(model).cuda()
110 | model.eval()
111 | for p in model.parameters():
112 | p.requires_grad = False
113 | return model
114 | elif args.dataset in ["image-woof", "image-nette"]:
115 | model = models.get_model(args.ranker_arch, weights=None, num_classes=10)
116 | checkpoint = torch.load(args.ranker_path, map_location="cpu")
117 | model.load_state_dict(checkpoint["model"])
118 | model = nn.DataParallel(model).cuda()
119 | model.eval()
120 | for p in model.parameters():
121 | p.requires_grad = False
122 | return model
123 | elif args.dataset == "cifar10":
124 | assert args.ranker_path
125 | model = models.get_model(args.ranker_arch, weights=None, num_classes=10)
126 | model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
127 | model.maxpool = nn.Identity()
128 | checkpoint = torch.load(args.ranker_path, map_location="cpu")
129 | model.load_state_dict(checkpoint["model"])
130 |
131 | model = nn.DataParallel(model).cuda()
132 | model.eval()
133 | for p in model.parameters():
134 | p.requires_grad = False
135 | return model
136 | elif args.dataset == "cifar100":
137 | assert args.ranker_path
138 | model = models.get_model(args.ranker_arch, weights=None, num_classes=100)
139 | model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
140 | model.maxpool = nn.Identity()
141 | checkpoint = torch.load(args.ranker_path, map_location="cpu")
142 | model.load_state_dict(checkpoint["model"])
143 |
144 | model = nn.DataParallel(model).cuda()
145 | model.eval()
146 | for p in model.parameters():
147 | p.requires_grad = False
148 | return model
149 | def get_dataloader(args):
150 | if args.dataset in ["imagenet-1k", "imagenet-100", "image-woof", "image-nette"]:
151 | normalize = transforms.Normalize(
152 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
153 | )
154 | dataset = ImageFolder(root=args.data_path, transform=transforms.Compose(
155 | [
156 | transforms.Resize(224 // 7 * 8, antialias=True),
157 | transforms.CenterCrop(224),
158 | transforms.ToTensor(),
159 | normalize,
160 | ]
161 | ),)
162 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
163 | num_workers= args.workers, shuffle=False)
164 | elif args.dataset == "tiny-imagenet":
165 | normalize = transforms.Normalize(
166 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
167 | )
168 | dataset = ImageFolder(root=args.data_path, transform=transforms.Compose(
169 | [
170 | transforms.Resize(64 // 7 * 8, antialias=True),
171 | transforms.CenterCrop(64),
172 | transforms.ToTensor(),
173 | normalize,
174 | ]
175 | ),)
176 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
177 | num_workers= args.workers, shuffle=False)
178 | elif args.dataset in ["cifar10", "cifar100"] :
179 | normalize = transforms.Normalize(
180 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
181 | )
182 | dataset = ImageFolder(root=args.data_path, transform=transforms.Compose(
183 | [
184 | transforms.Resize(32 // 7 * 8, antialias=True),
185 | transforms.CenterCrop(32),
186 | transforms.ToTensor(),
187 | normalize,
188 | ]
189 | ),)
190 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
191 | num_workers= args.workers, shuffle=False)
192 |
193 | def rank(args, model, data_loader):
194 | softmax = torch.nn.Softmax(dim=1)
195 |
196 | ranking = {
197 | 'score': [],
198 | 'img_path': [],
199 | }
200 |
201 | with torch.no_grad():
202 | for images, labels, paths in tqdm(data_loader):
203 | images, labels = images.cuda(), labels.cuda()
204 | output = model(images)
205 | output = softmax(output)
206 | pr = torch.gather(output, dim=1, index=labels.unsqueeze(-1)).squeeze(-1)
207 | ranking['img_path'] += list(paths)
208 | ranking['score'] += pr.tolist()
209 |
210 | ranking_df = pd.DataFrame(ranking)
211 | ranking_df["class"] = ranking_df['img_path'].apply(lambda path: path.split('/')[0])
212 | # sort the values
213 | ranking_df = ranking_df.sort_values(['class', 'score'], ascending=[True, False])
214 | if args.store_rank_file:
215 | rank_dir = os.path.dirname(args.ranking_file)  # avoid shadowing the builtin `dir`
216 | if rank_dir and not os.path.exists(rank_dir):
217 | os.makedirs(rank_dir)
218 | ranking_df.to_csv(args.ranking_file, index=False)
219 | return ranking_df
220 |
221 | def store_top(ranked_df, class_name, root_dir, target_dir, ipc):
222 | target_dir = os.path.join(target_dir, f"{class_name}")
223 | if os.path.exists(target_dir):
224 | shutil.rmtree(target_dir)
225 |
226 | os.makedirs(target_dir)
227 | df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[False])
228 | for img_indx in range(ipc):
229 | img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path'])
230 | file_name= df.iloc[img_indx]['img_path'].split("/")[-1]
231 | new_path = os.path.join(target_dir, f"{img_indx:03}_{file_name}")
232 | shutil.copy(img_path, new_path)
233 |
234 | def store_min(ranked_df, class_name, root_dir, target_dir, ipc):
235 | target_dir = os.path.join(target_dir, f"{class_name}")
236 | if os.path.exists(target_dir):
237 | shutil.rmtree(target_dir)
238 |
239 | os.makedirs(target_dir)
240 | df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[True])
241 | for img_indx in range(ipc):
242 | img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path'])
243 | file_name= df.iloc[img_indx]['img_path'].split("/")[-1]
244 | new_path = os.path.join(target_dir, f"{img_indx:03}_{file_name}")
245 | shutil.copy(img_path, new_path)
246 |
247 | def store_medium(ranked_df, class_name, root_dir, target_dir, ipc):
248 | target_dir = os.path.join(target_dir, f"{class_name}")
249 | if os.path.exists(target_dir):
250 | shutil.rmtree(target_dir)
251 |
252 | os.makedirs(target_dir)
253 | df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[True])
254 | mid_indx = math.ceil(len(df)/2)  # start from the per-class median score
255 | pos_neg = -1  # alternate around the median: mid, mid+1, mid-1, mid+2, ...
256 | for indx in range(ipc):
257 | img_indx = mid_indx + int(pos_neg*(indx+1)/2)
258 | pos_neg = pos_neg*-1
259 | img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path'])
260 | file_name= df.iloc[img_indx]['img_path'].split("/")[-1]
261 | new_path = os.path.join(target_dir, f"{indx:03}_{file_name}")
262 | shutil.copy(img_path, new_path)
263 |
264 | def store_imgs(ranked_df, class_name, root_dir, target_dir, ipc):
265 | target_dir = os.path.join(target_dir, f"{class_name}")
266 | if os.path.exists(target_dir):
267 | shutil.rmtree(target_dir)
268 |
269 | os.makedirs(target_dir)
270 | df = ranked_df[ranked_df['class'] == class_name].sort_values(['score'], ascending=[False])
271 | for img_indx in range(ipc):
272 | img_path = os.path.join(root_dir, df.iloc[img_indx]['img_path'])
273 | file_name= df.iloc[img_indx]['img_path'].split("/")[-1]
274 | new_path = os.path.join(target_dir, f"{img_indx:03}_{file_name}")
275 | shutil.copy(img_path, new_path)
276 |
277 | if __name__ == "__main__":
278 | args = get_args()
279 | if not args.store_rank_file and args.ranking_file:
280 | ranking_df = pd.read_csv(args.ranking_file)
281 | if "Unnamed: 0" in ranking_df.columns.values.tolist():
282 | del ranking_df['Unnamed: 0']
283 | else:
284 | model = get_model(args)
285 | loader = get_dataloader(args)
286 | if model is None or loader is None:
287 | exit()
288 | ranking_df = rank(args, model, loader)
289 |
290 | if args.selection_criteria =="medium":
291 | store = store_medium
292 | elif args.selection_criteria =="top":
293 | store = store_top
294 | elif args.selection_criteria =="min":
295 | store = store_min
296 | else:
297 | print("Unknown selection criteria")
298 | exit()
299 |
300 | for class_name in tqdm(ranking_df['class'].unique()):
301 | store(ranking_df, class_name, root_dir=args.data_path,
302 | target_dir=args.output_path, ipc = args.ipc)
--------------------------------------------------------------------------------
/evaluation/argument.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser("Evaluation scheme for DELT")
4 | """Architecture and Dataset"""
5 | parser.add_argument(
6 | "--arch-name",
7 | type=str,
8 | default="resnet18",
9 | help="arch name from pretrained torchvision models",
10 | )
11 | parser.add_argument(
12 | "--subset",
13 | type=str,
14 | default="imagenet-1k",
15 | )
16 | parser.add_argument(
17 | "--nclass",
18 | type=int,
19 | default=1000,
20 | help="number of synthesized classes",
21 | )
22 | parser.add_argument(
23 | "--ipc",
24 | type=int,
25 | default=50,
26 | help="number of synthesized images per class",
27 | )
28 | parser.add_argument(
29 | "--input-size",
30 | default=224,
31 | type=int,
32 | metavar="S",
33 | )
34 | """Evaluation arguments"""
35 | parser.add_argument("--re-batch-size", default=0, type=int, metavar="N")
36 | parser.add_argument(
37 | "--re-accum-steps",
38 | type=int,
39 | default=1,
40 | help="gradient accumulation steps for small gpu memory",
41 | )
42 | parser.add_argument(
43 | "--mix-type",
44 | default="cutmix",
45 | type=str,
46 | choices=["mixup", "cutmix", None],
47 | help="mixup or cutmix or None",
48 | )
49 | parser.add_argument(
50 | "--stud-name",
51 | type=str,
52 | default="resnet18",
53 | help="arch name from torchvision models",
54 | )
55 | parser.add_argument(
56 | "--val-ipc",
57 | type=int,
58 | default=30,
59 | )
60 | parser.add_argument(
61 | "--workers",
62 | default=4,
63 | type=int,
64 | metavar="N",
65 | help="number of data loading workers (default: 4)",
66 | )
67 | parser.add_argument(
68 | "--classes",
69 | type=list,
70 | help="number of classes for synthesis",
71 | )
72 | parser.add_argument(
73 | "--temperature",
74 | type=float,
75 | help="temperature for distillation loss",
76 | )
77 | parser.add_argument(
78 | "--val-dir",
79 | type=str,
80 | default="/workspace/data/val/",
81 | help="path to validation dataset",
82 | )
83 | parser.add_argument(
84 | "--min-scale-crops", type=float, default=0.08, help="argument in RandomResizedCrop"
85 | )
86 | parser.add_argument(
87 | "--max-scale-crops", type=float, default=1, help="argument in RandomResizedCrop"
88 | )
89 | parser.add_argument("--re-epochs", default=300, type=int)
90 | parser.add_argument(
91 | "--syn-data-path",
92 | type=str,
93 | default="/workspace/data/DELT_IN_1K/ipc50_cr5_mipc300",
94 | help="path to synthetic data",
95 | )
96 | parser.add_argument(
97 | "--seed", default=42, type=int, help="seed for initializing training. "
98 | )
99 | parser.add_argument(
100 | "--mixup",
101 | type=float,
102 | default=0.8,
103 | help="mixup alpha, mixup enabled if > 0. (default: 0.8)",
104 | )
105 | parser.add_argument(
106 | "--cutmix",
107 | type=float,
108 | default=1.0,
109 | help="cutmix alpha, cutmix enabled if > 0. (default: 1.0)",
110 | )
111 | parser.add_argument("--cos", default=True, help="cosine lr scheduler")
112 | parser.add_argument("--use-rand-augment", default=False, action="store_true", help="use Timm's RandAugment")
113 | parser.add_argument('--rand-augment-config', type=str, default="rand-m6-n2-mstd1.0", help='RandAugment configuration')
114 |
115 | # sgd
116 | parser.add_argument("--sgd", default=False, action="store_true", help="sgd optimizer")
117 | parser.add_argument(
118 | "-lr",
119 | "--learning-rate",
120 | type=float,
121 | default=0.1,
122 | help="sgd init learning rate",
123 | )
124 | parser.add_argument("--momentum", type=float, default=0.9, help="sgd momentum")
125 | parser.add_argument("--weight-decay", type=float, default=1e-4, help="sgd weight decay")
126 |
127 | # adamw
128 | parser.add_argument("--adamw-lr", type=float, default=0, help="adamw learning rate")
129 | parser.add_argument(
130 | "--adamw-weight-decay", type=float, default=0.01, help="adamw weight decay"
131 | )
132 | parser.add_argument(
133 | "--exp-name",
134 | type=str,
135 | help="name of the experiment, subfolder under syn_data_path",
136 | )
137 | parser.add_argument('--wandb-api-key', type=str,
138 | default=None, help='wandb api key')
139 | parser.add_argument('--wandb-project', type=str,
140 | default='imagenet_1k_distillation', help='wandb project name')
141 | parser.add_argument('--gpu-device', default="-1", type=str)
142 | args = parser.parse_args()
143 |
144 | # set up dataset settings
145 | # set smaller val_ipc only for quick validation
146 | if args.subset in [
147 | "imagenet-a",
148 | "imagenet-b",
149 | "imagenet-c",
150 | "imagenet-d",
151 | "imagenet-e",
152 | "imagenet-birds",
153 | "imagenet-fruits",
154 | "imagenet-cats",
155 | "imagenet-10",
156 | ]:
157 | args.nclass = 10
158 | args.classes = range(args.nclass)
159 | args.val_ipc = 50
160 | args.input_size = 224
161 |
162 | elif args.subset == "imagenet-nette":
163 | args.nclass = 10
164 | args.classes = range(args.nclass)
165 | args.val_ipc = 50
166 | args.input_size = 224
167 | if args.arch_name in ["conv5", "conv6"] or args.stud_name in ["conv5", "conv6"]:
168 | args.input_size = 128
169 |
170 | elif args.subset == "imagenet-woof":
171 | args.nclass = 10
172 | args.classes = range(args.nclass)
173 | args.val_ipc = 50
174 | args.input_size = 224
175 | if args.arch_name in ["conv5", "conv6"] or args.stud_name in ["conv5", "conv6"]:
176 | args.input_size = 128
177 |
178 | elif args.subset == "imagenet-100":
179 | args.nclass = 100
180 | args.classes = range(args.nclass)
181 | args.val_ipc = 50
182 | args.input_size = 224
183 | if args.arch_name in ["conv5", "conv6"] or args.stud_name in ["conv5", "conv6"]:
184 | args.input_size = 128
185 |
186 | elif args.subset == "imagenet-1k":
187 | args.nclass = 1000
188 | args.classes = range(args.nclass)
189 | args.val_ipc = 50
190 | args.input_size = 224
191 | if "conv" in args.arch_name or "conv" in args.stud_name:
192 | args.input_size = 64
193 |
194 | elif args.subset == "cifar10":
195 | args.nclass = 10
196 | args.classes = range(args.nclass)
197 | args.val_ipc = 1000
198 | args.input_size = 32
199 |
200 | elif args.subset == "cifar100":
201 | args.nclass = 100
202 | args.classes = range(args.nclass)
203 | args.val_ipc = 100
204 | args.input_size = 32
205 |
206 | elif args.subset == "tinyimagenet":
207 | args.nclass = 200
208 | args.classes = range(args.nclass)
209 | args.val_ipc = 50
210 | args.input_size = 64
211 |
212 | args.nclass = len(args.classes)
213 |
214 | # set up batch size
215 | if args.re_batch_size == 0:
216 | if args.ipc == 100:
217 | args.re_batch_size = 100
218 | args.workers = 8
219 | elif args.ipc == 50:
220 | args.re_batch_size = 100
221 | args.workers = 4
222 | elif args.ipc in [30, 40]:
223 | args.re_batch_size = 100
224 | args.workers = 4
225 | elif args.ipc in [10, 20]:
226 | args.re_batch_size = 50
227 | args.workers = 4
228 | elif args.ipc == 1:
229 | args.re_batch_size = 10
230 | args.workers = 4
231 |
232 | if args.nclass == 10:
233 | args.re_batch_size *= 1
234 | if args.nclass == 100:
235 | args.re_batch_size *= 2
236 | if args.nclass == 1000:
237 | args.re_batch_size *= 2
238 | if args.subset in ["imagenet-1k", "imagenet-100"] and args.stud_name == "resnet101":
239 | args.re_batch_size = args.re_batch_size//2
240 | # ! tinyimagenet
241 | if args.subset == "tinyimagenet":
242 | args.re_batch_size = 100
243 |
244 | # reset batch size below ipc * nclass
245 | if args.re_batch_size > args.ipc * args.nclass:
246 | args.re_batch_size = int(args.ipc * args.nclass)
247 |
248 | # reset batch size with re_accum_steps
249 | if args.re_accum_steps != 1:
250 | args.re_batch_size = int(args.re_batch_size / args.re_accum_steps)
251 |
252 | # temperature
253 | if args.mix_type == "mixup":
254 | args.temperature = 4
255 | elif args.mix_type == "cutmix":
256 | args.temperature = 20
257 |
258 | # adamw learning rate
259 | if args.stud_name == "vgg11":
260 | args.adamw_lr = 0.0005
261 | elif args.stud_name == "conv3":
262 | args.adamw_lr = 0.001
263 | elif args.stud_name == "conv4":
264 | args.adamw_lr = 0.001
265 | elif args.stud_name == "conv5":
266 | args.adamw_lr = 0.001
267 | elif args.stud_name == "conv6":
268 | args.adamw_lr = 0.001
269 | elif args.stud_name == "resnet18":
270 | args.adamw_lr = 0.001
271 | elif args.stud_name == "resnet18_modified":
272 | args.adamw_lr = 0.001
273 | elif args.stud_name == "efficientnet_b0":
274 | args.adamw_lr = 0.002
275 | elif args.stud_name in ["mobilenet_v2", "mobilenet_v2_modified"]:
276 | args.adamw_lr = 0.0025
277 | elif args.stud_name == "alexnet":
278 | args.adamw_lr = 0.0001
279 | elif args.stud_name == "resnet50":
280 | args.adamw_lr = 0.001
281 | elif args.stud_name == "resnet101":
282 | args.adamw_lr = 0.001
283 | elif args.stud_name == "resnet101_modified":
284 | args.adamw_lr = 0.001
285 | elif args.stud_name == "vit_b_16":
286 | args.adamw_lr = 0.0001
287 | elif args.stud_name == "swin_v2_t":
288 | args.adamw_lr = 0.0001
289 |
290 | # special experiment
291 | if (
292 | args.subset == "cifar100"
293 | and args.arch_name == "conv3"
294 | and args.stud_name == "conv3"
295 | ):
296 | args.re_batch_size = 25
297 | args.adamw_lr = 0.002
298 |
299 | print(f"args: {args}")
300 |
--------------------------------------------------------------------------------
/evaluation/main.py:
--------------------------------------------------------------------------------
1 | """This code is adopted from RDED here: https://github.com/LINs-lab/RDED/tree/main"""
2 | from argument import args
3 | from validation.main import (
4 | main as valid_main,
5 | ) # The relabel and validation are combined here for fast experiment
6 | import warnings
7 |
8 | warnings.filterwarnings("ignore")
9 |
10 | if __name__ == "__main__":
11 | valid_main(args)
12 |
--------------------------------------------------------------------------------
/evaluation/model_utils/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | # Conv-3 model
6 | class ConvNet(nn.Module):
7 | def __init__(
8 | self,
9 | num_classes,
10 | net_norm="batch",
11 | net_depth=3,
12 | net_width=128,
13 | channel=3,
14 | net_act="relu",
15 | net_pooling="avgpooling",
16 | im_size=(32, 32),
17 | ):
18 | # print(f"Define Convnet (depth {net_depth}, width {net_width}, norm {net_norm})")
19 | super(ConvNet, self).__init__()
20 | if net_act == "sigmoid":
21 | self.net_act = nn.Sigmoid()
22 | elif net_act == "relu":
23 | self.net_act = nn.ReLU()
24 | elif net_act == "leakyrelu":
25 | self.net_act = nn.LeakyReLU(negative_slope=0.01)
26 | else:
27 | exit("unknown activation function: %s" % net_act)
28 |
29 | if net_pooling == "maxpooling":
30 | self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
31 | elif net_pooling == "avgpooling":
32 | self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2)
33 | elif net_pooling == "none":
34 | self.net_pooling = None
35 | else:
36 | exit("unknown net_pooling: %s" % net_pooling)
37 |
38 | self.depth = net_depth
39 | self.net_norm = net_norm
40 |
41 | self.layers, shape_feat = self._make_layers(
42 | channel, net_width, net_depth, net_norm, net_pooling, im_size
43 | )
44 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
45 | self.classifier = nn.Linear(num_feat, num_classes)
46 |
47 | def forward(self, x, return_features=False):
48 | for d in range(self.depth):
49 | x = self.layers["conv"][d](x)
50 | if len(self.layers["norm"]) > 0:
51 | x = self.layers["norm"][d](x)
52 | x = self.layers["act"][d](x)
53 | if len(self.layers["pool"]) > 0:
54 | x = self.layers["pool"][d](x)
55 |
56 | # x = nn.functional.avg_pool2d(x, x.shape[-1])
57 | out = x.view(x.shape[0], -1)
58 | logit = self.classifier(out)
59 |
60 | if return_features:
61 | return logit, out
62 | else:
63 | return logit
64 |
65 | def get_feature(
66 | self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False
67 | ):
68 | if idx_to == -1:
69 | idx_to = idx_from
70 | features = []
71 |
72 | for d in range(self.depth):
73 | x = self.layers["conv"][d](x)
74 | if self.net_norm:
75 | x = self.layers["norm"][d](x)
76 | x = self.layers["act"][d](x)
77 | if self.net_pooling:
78 | x = self.layers["pool"][d](x)
79 | features.append(x)
80 | if idx_to < len(features):
81 | return features[idx_from : idx_to + 1]
82 |
83 | if return_prob:
84 | out = x.view(x.size(0), -1)
85 | logit = self.classifier(out)
86 | prob = torch.softmax(logit, dim=-1)
87 | return features, prob
88 | elif return_logit:
89 | out = x.view(x.size(0), -1)
90 | logit = self.classifier(out)
91 | return features, logit
92 | else:
93 | return features[idx_from : idx_to + 1]
94 |
95 | def _get_normlayer(self, net_norm, shape_feat):
96 | # shape_feat = (c * h * w)
97 | if net_norm == "batch":
98 | norm = nn.BatchNorm2d(shape_feat[0], affine=True)
99 | elif net_norm == "layer":
100 | norm = nn.LayerNorm(shape_feat, elementwise_affine=True)
101 | elif net_norm == "instance":
102 | norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
103 | elif net_norm == "group":
104 | norm = nn.GroupNorm(4, shape_feat[0], affine=True)
105 | elif net_norm == "none":
106 | norm = None
107 | else:
108 | norm = None
109 | exit("unknown net_norm: %s" % net_norm)
110 | return norm
111 |
112 | def _make_layers(
113 | self, channel, net_width, net_depth, net_norm, net_pooling, im_size
114 | ):
115 | layers = {"conv": [], "norm": [], "act": [], "pool": []}
116 |
117 | in_channels = channel
118 | if im_size[0] == 28:
119 | im_size = (32, 32)
120 | shape_feat = [in_channels, im_size[0], im_size[1]]
121 |
122 | for d in range(net_depth):
123 | layers["conv"] += [
124 | nn.Conv2d(
125 | in_channels,
126 | net_width,
127 | kernel_size=3,
128 | padding=3 if channel == 1 and d == 0 else 1,
129 | )
130 | ]
131 | shape_feat[0] = net_width
132 | if net_norm != "none":
133 | layers["norm"] += [self._get_normlayer(net_norm, shape_feat)]
134 | layers["act"] += [self.net_act]
135 | in_channels = net_width
136 | if net_pooling != "none":
137 | layers["pool"] += [self.net_pooling]
138 | shape_feat[1] //= 2
139 | shape_feat[2] //= 2
140 |
141 | layers["conv"] = nn.ModuleList(layers["conv"])
142 | layers["norm"] = nn.ModuleList(layers["norm"])
143 | layers["act"] = nn.ModuleList(layers["act"])
144 | layers["pool"] = nn.ModuleList(layers["pool"])
145 | layers = nn.ModuleDict(layers)
146 |
147 | return layers, shape_feat
148 |
149 | """https://github.com/megvii-research/mdistiller/blob/a08d46f10d6102bd6e3f258ca5ac880b020ea259/mdistiller/models/cifar/mv2_tinyimagenet.py"""
150 | class LinearBottleNeck(nn.Module):
151 |
152 | def __init__(self, in_channels, out_channels, stride, t=6, class_num=100):
153 | super().__init__()
154 |
155 | self.residual = nn.Sequential(
156 | nn.Conv2d(in_channels, in_channels * t, 1),
157 | nn.BatchNorm2d(in_channels * t),
158 | nn.ReLU6(inplace=True),
159 |
160 | nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t),
161 | nn.BatchNorm2d(in_channels * t),
162 | nn.ReLU6(inplace=True),
163 |
164 | nn.Conv2d(in_channels * t, out_channels, 1),
165 | nn.BatchNorm2d(out_channels)
166 | )
167 |
168 | self.stride = stride
169 | self.in_channels = in_channels
170 | self.out_channels = out_channels
171 |
172 | def forward(self, x):
173 |
174 | residual = self.residual(x)
175 |
176 | if self.stride == 1 and self.in_channels == self.out_channels:
177 | residual += x
178 |
179 | return residual
180 |
181 | class MobileNetV2(nn.Module):
182 |
183 | def __init__(self, num_classes=100):
184 | super().__init__()
185 |
186 | self.pre = nn.Sequential(
187 | nn.Conv2d(3, 32, 1, padding=1),
188 | nn.BatchNorm2d(32),
189 | nn.ReLU6(inplace=True)
190 | )
191 |
192 | self.stage1 = LinearBottleNeck(32, 16, 1, 1)
193 | self.stage2 = self._make_stage(2, 16, 24, 2, 6)
194 | self.stage3 = self._make_stage(3, 24, 32, 2, 6)
195 | self.stage4 = self._make_stage(4, 32, 64, 2, 6)
196 | self.stage5 = self._make_stage(3, 64, 96, 1, 6)
197 | self.stage6 = self._make_stage(3, 96, 160, 1, 6)
198 | self.stage7 = LinearBottleNeck(160, 320, 1, 6)
199 |
200 | self.conv1 = nn.Sequential(
201 | nn.Conv2d(320, 1280, 1),
202 | nn.BatchNorm2d(1280),
203 | nn.ReLU6(inplace=True)
204 | )
205 |
206 | self.conv2 = nn.Conv2d(1280, num_classes, 1)
207 |
208 | def forward(self, x):
209 | x = self.pre(x)
210 | x = self.stage1(x)
211 | x = self.stage2(x)
212 | x = self.stage3(x)
213 | x = self.stage4(x)
214 | x = self.stage5(x)
215 | x = self.stage6(x)
216 | x = self.stage7(x)
217 | x = self.conv1(x)
218 | x = F.adaptive_avg_pool2d(x, 1)
219 | x = self.conv2(x)
220 | x = x.view(x.size(0), -1)
221 |
222 | return x
223 |
224 | def _make_stage(self, repeat, in_channels, out_channels, stride, t):
225 |
226 | layers = []
227 | layers.append(LinearBottleNeck(in_channels, out_channels, stride, t))
228 |
229 | while repeat - 1:
230 | layers.append(LinearBottleNeck(out_channels, out_channels, 1, t))
231 | repeat -= 1
232 |
233 | return nn.Sequential(*layers)
234 |
235 | def mobilenetv2_tinyimagenet():
236 | return MobileNetV2(num_classes = 200)
237 |
238 | def mobilenetv2_cifar10():
239 | return MobileNetV2(num_classes = 10)
--------------------------------------------------------------------------------
/evaluation/model_utils/utils.py:
--------------------------------------------------------------------------------
1 | """this code is modified from the RDED repo: https://github.com/LINs-lab/RDED/tree/main"""
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from argument import args as sys_args
5 | import torch
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 |
9 | import torchvision.models as thmodels
10 | from torchvision.models._api import WeightsEnum
11 | from torch.hub import load_state_dict_from_url
12 | from model_utils.models import ConvNet, mobilenetv2_tinyimagenet, mobilenetv2_cifar10
13 |
14 | import math
15 |
16 | # pad the input tensor with zeros up to the target spatial size
17 | def pad(input_tensor, target_height, target_width=None):
18 | if target_width is None:
19 | target_width = target_height
20 | vertical_padding = target_height - input_tensor.size(2)
21 | horizontal_padding = target_width - input_tensor.size(3)
22 |
23 | top_padding = vertical_padding // 2
24 | bottom_padding = vertical_padding - top_padding
25 | left_padding = horizontal_padding // 2
26 | right_padding = horizontal_padding - left_padding
27 |
28 | padded_tensor = F.pad(
29 | input_tensor, (left_padding, right_padding, top_padding, bottom_padding)
30 | )
31 |
32 | return padded_tensor
33 |
34 |
35 | def batched_forward(model, tensor, batch_size):
36 | total_samples = tensor.size(0)
37 |
38 | all_outputs = []
39 |
40 | model.eval()
41 |
42 | with torch.no_grad():
43 | for i in range(0, total_samples, batch_size):
44 | batch_data = tensor[i : min(i + batch_size, total_samples)]
45 |
46 | output = model(batch_data)
47 |
48 | all_outputs.append(output)
49 |
50 | final_output = torch.cat(all_outputs, dim=0)
51 |
52 | return final_output
53 |
54 |
55 | class MultiRandomCrop(torch.nn.Module):
56 | def __init__(self, num_crop=5, size=224, factor=2):
57 | super().__init__()
58 | self.num_crop = num_crop
59 | self.size = size
60 | self.factor = factor
61 |
62 | def forward(self, image):
63 | cropper = transforms.RandomResizedCrop(
64 | self.size // self.factor,
65 | ratio=(1, 1),
66 | antialias=True,
67 | )
68 | patches = []
69 | for _ in range(self.num_crop):
70 | patches.append(cropper(image))
71 | return torch.stack(patches, 0)
72 |
73 | def __repr__(self) -> str:
74 | detail = f"(num_crop={self.num_crop}, size={self.size})"
75 | return f"{self.__class__.__name__}{detail}"
76 |
77 |
78 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
79 |
80 | denormalize = transforms.Compose(
81 | [
82 | transforms.Normalize(
83 | mean=[0.0, 0.0, 0.0], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
84 | ),
85 | transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1.0, 1.0, 1.0]),
86 | ]
87 | )
88 |
89 |
90 | def get_state_dict(self, *args, **kwargs):
91 | kwargs.pop("check_hash")
92 | return load_state_dict_from_url(self.url, *args, **kwargs)
93 |
94 |
95 | WeightsEnum.get_state_dict = get_state_dict
96 |
97 |
98 | def cross_entropy(y_pre, y):
99 | y_pre = F.softmax(y_pre, dim=1)
100 | return (-torch.log(y_pre.gather(1, y.view(-1, 1))))[:, 0]
101 |
102 |
103 | def selector(n, model, images, labels, size, m=5):
104 | with torch.no_grad():
105 | # [mipc, m, 3, 224, 224]
106 | images = images.cuda()
107 | s = images.shape
108 |
109 | # [mipc * m, 3, 224, 224]
110 | images = images.permute(1, 0, 2, 3, 4)
111 | images = images.reshape(s[0] * s[1], s[2], s[3], s[4])
112 |
113 | # [mipc * m, 1]
114 | labels = labels.repeat(m).cuda()
115 |
116 | # [mipc * m, n_class]
117 | batch_size = s[0] # Change it for small GPU memory
118 | preds = batched_forward(model, pad(images, size).cuda(), batch_size)
119 |
120 | # [mipc * m]
121 | dist = cross_entropy(preds, labels)
122 |
123 | # [m, mipc]
124 | dist = dist.reshape(m, s[0])
125 |
126 | n_patch_per_image = math.ceil(n/dist.shape[-1])
127 |
128 | # [mipc]
129 | # [n_patch_per_image, mipc]
130 | index = torch.argsort(dist, dim=0)[:n_patch_per_image,:]
131 | # print(f"index_1 shape: {index_1.shape}")
132 | # print(f"index_1 val: {index_1[0,0]}")
133 | # index = torch.argmin(dist, 0)
134 | # print(f"index shape: {index.shape}")
135 | # print(f"index val: {index[0]}")
136 | # dist_1 = dist[index_1, torch.arange(s[0])]
137 | # print(f"dist_1: {dist_1.shape}")
138 | dist = dist[index, torch.arange(s[0])]
139 |
140 | # [mipc*n_patch_per_image, 3, 224, 224]
141 | sa = images.shape
142 | images = images.reshape(m, s[0], sa[1], sa[2], sa[3])
143 | images = images[index, torch.arange(s[0])]
144 | images = images.reshape(n_patch_per_image*s[0], sa[1], sa[2], sa[3])
145 | dist = dist.reshape(n_patch_per_image*s[0])
146 |
147 | indices = torch.argsort(dist, descending=False)[:n]
148 | torch.cuda.empty_cache()
149 | images = images[indices]
150 | # shuffle
151 | indexes = torch.randperm(images.shape[0])
152 | return images[indexes].detach()
153 |
154 |
155 | def mix_images(input_img, out_size, factor, n):
156 | s = out_size // factor
157 | remained = out_size % factor
158 | k = 0
159 | mixed_images = torch.zeros(
160 | (n, 3, out_size, out_size),
161 | requires_grad=False,
162 | dtype=torch.float,
163 | )
164 | h_loc = 0
165 | for i in range(factor):
166 | h_r = s + 1 if i < remained else s
167 | w_loc = 0
168 | for j in range(factor):
169 | w_r = s + 1 if j < remained else s
170 | # print(f"{k * n} : {(k + 1) * n}")
171 |
172 | img_part = F.interpolate(
173 | input_img.data[k * n : (k + 1) * n], size=(h_r, w_r)
174 | )
175 | # print(f"shape: {img_part.shape}")
176 | mixed_images.data[
177 | 0:n,
178 | :,
179 | h_loc : h_loc + h_r,
180 | w_loc : w_loc + w_r,
181 | ] = img_part
182 | w_loc += w_r
183 | k += 1
184 | h_loc += h_r
185 | return mixed_images
186 |
187 |
188 | def load_model(model_name="resnet18", dataset="cifar10", pretrained=True, classes=[]):
189 | def get_model(model_name="resnet18"):
190 | if "conv" in model_name:
191 | if dataset in ["cifar10", "cifar100"]:
192 | size = 32
193 | elif dataset in ["tinyimagenet", "imagenet-1k"]:
194 | size = 64
195 | elif dataset in ["imagenet-nette", "imagenet-woof", "imagenet-100"]:
196 | size = 128
197 | else:
198 | raise Exception("Unrecognized dataset")
199 |
200 | nclass = len(classes)
201 |
202 | model = ConvNet(
203 | num_classes=nclass,
204 | net_norm="batch",
205 | net_act="relu",
206 | net_pooling="avgpooling",
207 | net_depth=int(model_name[-1]),
208 | net_width=128,
209 | channel=3,
210 | im_size=(size, size),
211 | )
212 | elif model_name == "resnet18_modified":
213 | model = thmodels.__dict__["resnet18"](pretrained=False)
214 | model.conv1 = nn.Conv2d(
215 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
216 | )
217 | model.maxpool = nn.Identity()
218 | elif model_name == "resnet101_modified":
219 | model = thmodels.__dict__["resnet101"](pretrained=False)
220 | model.conv1 = nn.Conv2d(
221 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
222 | )
223 | model.maxpool = nn.Identity()
224 | else:
225 | model = thmodels.__dict__[model_name](pretrained=False)
226 |
227 | return model
228 |
229 | def pruning_classifier(model=None, classes=[]):
230 | try:
231 | model_named_parameters = [name for name, x in model.named_parameters()]
232 | for name, x in model.named_parameters():
233 | if (
234 | name == model_named_parameters[-1]
235 | or name == model_named_parameters[-2]
236 | ):
237 | x.data = x[classes]
238 | except:
239 | print("ERROR in changing the number of classes.")
240 |
241 | return model
242 |
243 | if not pretrained and dataset == "tinyimagenet" and model_name == "mobilenet_v2_modified":
244 | return mobilenetv2_tinyimagenet()
245 | elif not pretrained and dataset == "cifar10" and model_name == "mobilenet_v2_modified":
246 | return mobilenetv2_cifar10()
247 | # "imagenet-100" "imagenet-10" "imagenet-first" "imagenet-nette" "imagenet-woof"
248 | model = get_model(model_name)
249 | model = pruning_classifier(model, classes)
250 | if pretrained:
251 | if dataset in [
252 | "imagenet-100",
253 | "imagenet-10",
254 | "imagenet-nette",
255 | "imagenet-woof",
256 | "tinyimagenet",
257 | "cifar10",
258 | "cifar100",
259 | ]:
260 | checkpoint = torch.load(
261 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu"
262 | )
263 | model.load_state_dict(checkpoint["model"])
264 | elif dataset in ["imagenet-1k"]:
265 | if model_name == "efficientnet_b0":  # match the torchvision arch name used elsewhere
266 | # Specifically, for loading the pre-trained EfficientNet model, the following modifications are made
267 | from torchvision.models._api import WeightsEnum
268 | from torch.hub import load_state_dict_from_url
269 |
270 | def get_state_dict(self, *args, **kwargs):
271 | kwargs.pop("check_hash")
272 | return load_state_dict_from_url(self.url, *args, **kwargs)
273 |
274 | WeightsEnum.get_state_dict = get_state_dict
model = thmodels.__dict__[model_name](pretrained=True)  # load the pre-trained weights via the patched hook
275 | elif "conv" in model_name:
276 | checkpoint = torch.load(
277 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu"
278 | )
279 | model.load_state_dict(checkpoint["model"])
280 | else:
281 | model = thmodels.__dict__[model_name](pretrained=True)
282 | elif not pretrained and model_name == "mobilenet_v2":
283 | return thmodels.get_model("mobilenet_v2", weights=None, num_classes=len(classes))
284 | return model
285 |
--------------------------------------------------------------------------------
/evaluation/validation/main.py:
--------------------------------------------------------------------------------
1 | """this code is modified from the RDED repo: https://github.com/LINs-lab/RDED/tree/main"""
2 | import os
3 | import random
4 |
5 | import wandb
6 |
7 | import torch
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torch.utils.data
12 | import torch.utils.data.distributed
13 | import torchvision.transforms as transforms
14 | from torch.optim.lr_scheduler import LambdaLR
15 | import math
16 | import time
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 | from model_utils.utils import load_model
21 | from validation.utils import (
22 | ImageFolder,
23 | mix_aug,
24 | AverageMeter,
25 | accuracy,
26 | get_parameters,
27 | )
28 |
29 | from timm.data.auto_augment import rand_augment_transform
30 |
31 | sharing_strategy = "file_system"
32 | torch.multiprocessing.set_sharing_strategy(sharing_strategy)
33 |
34 |
35 | def set_worker_sharing_strategy(worker_id: int) -> None:
36 | torch.multiprocessing.set_sharing_strategy(sharing_strategy)
37 |
38 |
39 | def main(args):
40 | if args.seed is not None:
41 | random.seed(args.seed)
42 | torch.manual_seed(args.seed)
43 | main_worker(args)
44 | wandb.finish()
45 |
46 |
47 | def main_worker(args):
48 | if args.gpu_device != "-1":
49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
50 | print("set CUDA_VISIBLE_DEVICES to ", args.gpu_device)
51 | wandb.login(key=args.wandb_api_key)
52 | wandb.init(project=args.wandb_project, name=args.exp_name)
53 |
54 | print("=> using pytorch pre-trained teacher model '{}'".format(args.arch_name))
55 | teacher_model = load_model(
56 | model_name=args.arch_name,
57 | dataset=args.subset,
58 | pretrained=True,
59 | classes=args.classes,
60 | )
61 |
62 | student_model = load_model(
63 | model_name=args.stud_name,
64 | dataset=args.subset,
65 | pretrained=False,
66 | classes=args.classes,
67 | )
68 | teacher_model = torch.nn.DataParallel(teacher_model).cuda()
69 | student_model = torch.nn.DataParallel(student_model).cuda()
70 |
71 | teacher_model.eval()
72 | student_model.train()
73 |
74 | # freeze all layers
75 | for param in teacher_model.parameters():
76 | param.requires_grad = False
77 |
78 | cudnn.benchmark = True
79 |
80 | # optimizer
81 | if args.sgd:
82 | optimizer = torch.optim.SGD(
83 | get_parameters(student_model),
84 | lr=args.learning_rate,
85 | momentum=args.momentum,
86 | weight_decay=args.weight_decay,
87 | )
88 | else:
89 | optimizer = torch.optim.AdamW(
90 | get_parameters(student_model),
91 | lr=args.adamw_lr,
92 | betas=[0.9, 0.999],
93 | weight_decay=args.adamw_weight_decay,
94 | )
95 |
96 | # lr scheduler
97 | if args.cos == True:
98 | scheduler = LambdaLR(
99 | optimizer,
100 | lambda step: 0.5 * (1.0 + math.cos(math.pi * step / args.re_epochs / 2))
101 | if step <= args.re_epochs
102 | else 0,
103 | last_epoch=-1,
104 | )
105 | else:
106 | scheduler = LambdaLR(
107 | optimizer,
108 | lambda step: (1.0 - step / args.re_epochs) if step <= args.re_epochs else 0,
109 | last_epoch=-1,
110 | )
111 |
112 | print("process data from {}".format(args.syn_data_path))
113 | normalize = transforms.Normalize(
114 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
115 | )
116 |
117 | augment = []
118 |
119 | if args.use_rand_augment:
120 | tfm = rand_augment_transform(
121 | config_str=args.rand_augment_config,
122 | hparams={'translate_const': 117, 'img_mean': (124, 116, 104)}
123 | )
124 |
125 | augment.append(tfm)
126 |
127 | augment.append(transforms.ToTensor())
128 |
129 | if getattr(args, "random_erasing_p", 0.0) != 0.0:  # flag is not defined in argument.py; default to disabled
130 | augment.append(transforms.RandomErasing(p=args.random_erasing_p))
131 |
132 | augment.append(
133 | transforms.RandomResizedCrop(
134 | size=args.input_size,
135 | scale=(args.min_scale_crops, args.max_scale_crops),
136 | antialias=True,
137 | )
138 | )
139 | augment.append(transforms.RandomHorizontalFlip())
140 | augment.append(normalize)
141 |
142 | train_dataset = ImageFolder(
143 | classes=range(args.nclass),
144 | ipc=args.ipc,
145 | mem=False,
146 | shuffle=True,
147 | root=args.syn_data_path,
148 | transform=transforms.Compose(augment),
149 | )
150 |
151 | train_loader = torch.utils.data.DataLoader(
152 | train_dataset,
153 | batch_size=args.re_batch_size,
154 | shuffle=True,
155 | num_workers=args.workers,
156 | pin_memory=True,
157 | worker_init_fn=set_worker_sharing_strategy,
158 | )
159 |
160 | val_loader = torch.utils.data.DataLoader(
161 | ImageFolder(
162 | classes=args.classes,
163 | ipc=args.val_ipc,
164 | mem=False,
165 | root=args.val_dir,
166 | transform=transforms.Compose(
167 | [
168 | transforms.Resize(args.input_size // 7 * 8, antialias=True),
169 | transforms.CenterCrop(args.input_size),
170 | transforms.ToTensor(),
171 | normalize,
172 | ]
173 | ),
174 | ),
175 | batch_size=args.re_batch_size,
176 | shuffle=False,
177 | num_workers=args.workers,
178 | pin_memory=True,
179 | worker_init_fn=set_worker_sharing_strategy,
180 | )
181 | print("load data successfully")
182 |
183 | best_acc1 = 0
184 | best_epoch = 0
185 | args.optimizer = optimizer
186 | args.scheduler = scheduler
187 | args.train_loader = train_loader
188 | args.val_loader = val_loader
189 |
190 | for epoch in range(args.re_epochs):
191 | global wandb_metrics
192 | wandb_metrics = {}
193 | train(epoch, train_loader, teacher_model, student_model, args)
194 |
195 | if epoch % 10 == 0 or epoch == args.re_epochs - 1:
196 | top1 = validate(student_model, args, epoch)
197 | else:
198 | top1 = 0
199 | wandb.log(wandb_metrics)
200 | scheduler.step()
201 | if top1 > best_acc1:
202 | best_acc1 = max(top1, best_acc1)
203 | best_epoch = epoch
204 |
205 | print(f"Train Finish! Best accuracy is {best_acc1}@{best_epoch}")
206 |
207 |
208 | def train(epoch, train_loader, teacher_model, student_model, args):
209 | """Generate soft labels and train"""
210 | objs = AverageMeter()
211 | top1 = AverageMeter()
212 | top5 = AverageMeter()
213 |
214 | optimizer = args.optimizer
215 | loss_function_kl = nn.KLDivLoss(reduction="batchmean")
216 | teacher_model.eval()
217 | student_model.train()
218 | t1 = time.time()
219 | for batch_idx, (images, labels) in enumerate(train_loader):
220 | with torch.no_grad():
221 | images = images.cuda()
222 | labels = labels.cuda()
223 |
224 | mix_images, _, _, _ = mix_aug(images, args)
225 |
226 | pred_label = student_model(images)
227 |
228 | soft_mix_label = teacher_model(mix_images)
229 | soft_mix_label = F.softmax(soft_mix_label / args.temperature, dim=1)
230 |
231 | if batch_idx % args.re_accum_steps == 0:
232 | optimizer.zero_grad()
233 |
234 | prec1, prec5 = accuracy(pred_label, labels, topk=(1, 5))
235 |
236 | pred_mix_label = student_model(mix_images)
237 |
238 | soft_pred_mix_label = F.log_softmax(pred_mix_label / args.temperature, dim=1)
239 | loss = loss_function_kl(soft_pred_mix_label, soft_mix_label)
240 |
241 | loss = loss / args.re_accum_steps
242 |
243 | loss.backward()
244 | if batch_idx % args.re_accum_steps == (args.re_accum_steps - 1):
245 | optimizer.step()
246 |
247 | n = images.size(0)
248 | objs.update(loss.item(), n)
249 | top1.update(prec1.item(), n)
250 | top5.update(prec5.item(), n)
251 |
252 | scheduler = args.scheduler
253 | metrics = {
254 | "train/Top1": top1.avg,
255 | "train/Top5": top5.avg,
256 | "train/loss": objs.avg,
257 | "train/lr": scheduler.get_last_lr()[0],
258 | "train/epoch": epoch,}
259 | wandb_metrics.update(metrics)
260 |
261 | printInfo = (
262 | "TRAIN Iter {}: loss = {:.6f},\t".format(epoch, objs.avg)
263 | + "Top-1 err = {:.6f},\t".format(100 - top1.avg)
264 | + "Top-5 err = {:.6f},\t".format(100 - top5.avg)
265 | + "train_time = {:.6f}".format((time.time() - t1))
266 | )
267 | print(printInfo)
268 | t1 = time.time()
269 |
270 |
271 | def validate(model, args, epoch=None):
272 | objs = AverageMeter()
273 | top1 = AverageMeter()
274 | top5 = AverageMeter()
275 | loss_function = nn.CrossEntropyLoss()
276 |
277 | model.eval()
278 | t1 = time.time()
279 | with torch.no_grad():
280 | for data, target in args.val_loader:
281 | target = target.type(torch.LongTensor)
282 | data, target = data.cuda(), target.cuda()
283 |
284 | output = model(data)
285 | loss = loss_function(output, target)
286 |
287 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
288 | n = data.size(0)
289 | objs.update(loss.item(), n)
290 | top1.update(prec1.item(), n)
291 | top5.update(prec5.item(), n)
292 |
293 | logInfo = (
294 | "TEST:\nIter {}: loss = {:.6f},\t".format(epoch, objs.avg)
295 | + "Top-1 err = {:.6f},\t".format(100 - top1.avg)
296 | + "Top-5 err = {:.6f},\t".format(100 - top5.avg)
297 | + "val_time = {:.6f}".format(time.time() - t1)
298 | )
299 | print(logInfo)
300 | metrics = {
301 | 'val/top1': top1.avg,
302 | 'val/top5': top5.avg,
303 | 'val/loss': objs.avg,
304 | 'val/epoch': epoch,
305 | }
306 | wandb_metrics.update(metrics)
307 | return top1.avg
308 |
309 |
310 | if __name__ == "__main__":
311 | pass
312 | # main(args)
313 |
--------------------------------------------------------------------------------
/evaluation/validation/utils.py:
--------------------------------------------------------------------------------
1 | """this code is modified from the RDED repo: https://github.com/LINs-lab/RDED/tree/main"""
2 | import torch
3 | import numpy as np
4 | import os
5 | import torch.distributed
6 | import torchvision
7 | from torchvision.transforms import functional as t_F
8 | import torch.nn.functional as F
9 | import random
10 |
11 |
12 | # keep top k largest values, and smooth others
13 | def keep_top_k(p, k, n_classes=1000): # p is the softmax on label output
14 | if k == n_classes:
15 | return p
16 |
17 | values, indices = p.topk(k, dim=1)
18 |
19 | mask_topk = torch.zeros_like(p)
20 | mask_topk.scatter_(-1, indices, 1.0)
21 | top_p = mask_topk * p
22 |
23 | minor_value = (1 - torch.sum(values, dim=1)) / (n_classes - k)
24 | minor_value = minor_value.unsqueeze(1).expand(p.shape)
25 | mask_smooth = torch.ones_like(p)
26 | mask_smooth.scatter_(-1, indices, 0)
27 | smooth_p = mask_smooth * minor_value
28 |
29 | topk_smooth_p = top_p + smooth_p
30 | assert np.isclose(
31 | topk_smooth_p.sum().item(), p.shape[0]
32 | ), f"{topk_smooth_p.sum().item()} not close to {p.shape[0]}"
33 | return topk_smooth_p
34 |
35 |
36 | class AverageMeter(object):
37 | def __init__(self):
38 | self.reset()
39 |
40 | def reset(self):
41 | self.avg = 0
42 | self.sum = 0
43 | self.cnt = 0
44 | self.val = 0
45 |
46 | def update(self, val, n=1):
47 | self.val = val
48 | self.sum += val * n
49 | self.cnt += n
50 | self.avg = self.sum / self.cnt
51 |
52 |
53 | def accuracy(output, target, topk=(1,)):
54 | maxk = max(topk)
55 | batch_size = target.size(0)
56 |
57 | _, pred = output.topk(maxk, 1, True, True)
58 | pred = pred.t()
59 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
60 |
61 | res = []
62 | for k in topk:
63 | correct_k = correct[:k].reshape(-1).float().sum(0)
64 | res.append(correct_k.mul_(100.0 / batch_size))
65 | return res
66 |
67 |
68 | def get_parameters(model):
69 | group_no_weight_decay = []
70 | group_weight_decay = []
71 | for pname, p in model.named_parameters():
72 | if pname.find("weight") >= 0 and len(p.size()) > 1:
73 | # print('include ', pname, p.size())
74 | group_weight_decay.append(p)
75 | else:
76 | # print('not include ', pname, p.size())
77 | group_no_weight_decay.append(p)
78 | assert len(list(model.parameters())) == len(group_weight_decay) + len(
79 | group_no_weight_decay
80 | )
81 | groups = [
82 | dict(params=group_weight_decay),
83 | dict(params=group_no_weight_decay, weight_decay=0.0),
84 | ]
85 | return groups
86 |
87 |
88 | class ImageFolder(torchvision.datasets.ImageFolder):
89 | def __init__(self, classes, ipc, mem=False, shuffle=False, **kwargs):
90 | super(ImageFolder, self).__init__(**kwargs)
91 | self.mem = mem
92 | self.image_paths = []
93 | self.targets = []
94 | self.samples = []
95 | dirlist = []
96 | for name in os.listdir(self.root):
97 | if os.path.isdir(os.path.join(self.root,name)):
98 | dirlist.append(name)
99 | dirlist.sort()
100 | # print(f"Num of dirs: {len(dirlist)}")
101 | # print(f"Num of classes: {len(classes)}")
102 | for c in range(len(classes)):
103 | # print(self.root)
104 | # print(dirlist)
105 | dir_path = os.path.join(self.root, dirlist[c])
106 | # print("\n\n\n")
107 |
108 | # self.root + "/" + str(classes[c]).zfill(5)
109 | # print(dir_path)
110 | file_ls = os.listdir(dir_path)
111 | # exit()
112 | if shuffle:
113 | random.shuffle(file_ls)
114 | # print(len(file_ls))
115 | # print(f"IPC: {ipc}")
116 | for i in range(ipc):
117 | self.image_paths.append(dir_path + "/" + file_ls[i])
118 | self.targets.append(c)
119 | if self.mem:
120 | self.samples.append(self.loader(dir_path + "/" + file_ls[i]))
121 |
122 | def __getitem__(self, index):
123 | if self.mem:
124 | sample = self.samples[index]
125 | else:
126 | sample = self.loader(self.image_paths[index])
127 | sample = self.transform(sample)
128 | return sample, self.targets[index]
129 |
130 | def __len__(self):
131 | return len(self.targets)
132 |
133 |
134 | def rand_bbox(size, lam):
135 | W = size[2]
136 | H = size[3]
137 | cut_rat = np.sqrt(1.0 - lam)
138 | cut_w = int(W * cut_rat)
139 | cut_h = int(H * cut_rat)
140 |
141 | # uniform
142 | cx = np.random.randint(W)
143 | cy = np.random.randint(H)
144 |
145 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
146 | bby1 = np.clip(cy - cut_h // 2, 0, H)
147 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
148 | bby2 = np.clip(cy + cut_h // 2, 0, H)
149 |
150 | return bbx1, bby1, bbx2, bby2
151 |
152 |
153 | def cutmix(images, args, rand_index=None, lam=None, bbox=None):
154 |     rand_index = torch.randperm(images.size()[0]).cuda()  # fresh permutation each call; the passed-in rand_index is ignored
155 | lam = np.random.beta(args.cutmix, args.cutmix)
156 | bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
157 |
158 | images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
159 | return images, rand_index.cpu(), lam, [bbx1, bby1, bbx2, bby2]
160 |
161 |
162 | def mixup(images, args, rand_index=None, lam=None):
163 |     rand_index = torch.randperm(images.size()[0]).cuda()  # fresh permutation each call; the passed-in rand_index is ignored
164 | lam = np.random.beta(args.mixup, args.mixup)
165 |
166 | mixed_images = lam * images + (1 - lam) * images[rand_index]
167 | return mixed_images, rand_index.cpu(), lam, None
168 |
169 |
170 | def mix_aug(images, args, rand_index=None, lam=None, bbox=None):
171 | if args.mix_type == "mixup":
172 | return mixup(images, args, rand_index, lam)
173 | elif args.mix_type == "cutmix":
174 | return cutmix(images, args, rand_index, lam, bbox)
175 | else:
176 | return images, None, None, None
177 |
178 |
179 | class ShufflePatches(torch.nn.Module):
180 |     def shuffle_weight(self, img, factor):
181 |         h, w = img.shape[1:]
182 |         th, tw = h // factor, w // factor
183 |         patches = []
184 |         for i in range(factor):
185 |             start = i * tw  # keep the loop index intact; rebinding i broke the last-patch test
186 |             if i != factor - 1:
187 |                 patches.append(img[..., start : start + tw])
188 |             else:
189 |                 patches.append(img[..., start:])  # last patch absorbs the remainder when w % factor != 0
190 |         random.shuffle(patches)
191 |         img = torch.cat(patches, -1)
192 |         return img
193 |
194 | def __init__(self, factor):
195 | super().__init__()
196 | self.factor = factor
197 |
198 | def forward(self, img):
199 | img = self.shuffle_weight(img, self.factor)
200 | img = img.permute(0, 2, 1)
201 | img = self.shuffle_weight(img, self.factor)
202 | img = img.permute(0, 2, 1)
203 | return img
204 |
--------------------------------------------------------------------------------
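Two of these helpers are worth a quick usage sketch. `keep_top_k` keeps the k largest softmax entries per row and spreads the leftover probability mass uniformly over the remaining classes, so each row still sums to 1; `mix_aug` dispatches to cutmix or mixup based on `args.mix_type`. The snippet below is a minimal illustration; the `SimpleNamespace` stands in for the real argparse object, and a GPU is assumed since `cutmix`/`mixup` call `.cuda()`.

```python
from types import SimpleNamespace

import torch
import torch.nn.functional as F

from utils import keep_top_k, mix_aug  # evaluation/validation/utils.py

# keep_top_k: rows remain valid probability distributions after smoothing.
p = F.softmax(torch.randn(4, 1000), dim=1)
p_smooth = keep_top_k(p, k=5, n_classes=1000)
assert torch.allclose(p_smooth.sum(dim=1), torch.ones(4))

# mix_aug: the attribute names mirror the flags consumed by the functions above.
args = SimpleNamespace(mix_type="cutmix", cutmix=1.0, mixup=0.8)
images = torch.randn(4, 3, 224, 224).cuda()
mixed, rand_index, lam, bbox = mix_aug(images, args)
print(mixed.shape, lam)
```
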
/imgs/DELT_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/DELT_framework.png
--------------------------------------------------------------------------------
/imgs/diversity_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/diversity_comparison.png
--------------------------------------------------------------------------------
/imgs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/results.png
--------------------------------------------------------------------------------
/imgs/samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VILA-Lab/DELT/1a45beb6381856081be6d48997c848dca50a375a/imgs/samples.png
--------------------------------------------------------------------------------
/recover/models.py:
--------------------------------------------------------------------------------
1 | """Adopted from https://github.com/LINs-lab/RDED/blob/main/synthesize/models.py to load the custom conv models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.models as models
6 |
7 | # Conv-3 model
8 | class ConvNet(nn.Module):
9 | def __init__(
10 | self,
11 | num_classes,
12 | net_norm="batch",
13 | net_depth=3,
14 | net_width=128,
15 | channel=3,
16 | net_act="relu",
17 | net_pooling="avgpooling",
18 | im_size=(32, 32),
19 | ):
20 | # print(f"Define Convnet (depth {net_depth}, width {net_width}, norm {net_norm})")
21 | super(ConvNet, self).__init__()
22 | if net_act == "sigmoid":
23 | self.net_act = nn.Sigmoid()
24 | elif net_act == "relu":
25 | self.net_act = nn.ReLU()
26 | elif net_act == "leakyrelu":
27 | self.net_act = nn.LeakyReLU(negative_slope=0.01)
28 |         else:
29 |             raise ValueError("unknown activation function: %s" % net_act)
30 |
31 | if net_pooling == "maxpooling":
32 | self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
33 | elif net_pooling == "avgpooling":
34 | self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2)
35 | elif net_pooling == "none":
36 | self.net_pooling = None
37 |         else:
38 |             raise ValueError("unknown net_pooling: %s" % net_pooling)
39 |
40 | self.depth = net_depth
41 | self.net_norm = net_norm
42 |
43 | self.layers, shape_feat = self._make_layers(
44 | channel, net_width, net_depth, net_norm, net_pooling, im_size
45 | )
46 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
47 | self.classifier = nn.Linear(num_feat, num_classes)
48 |
49 | def forward(self, x, return_features=False):
50 | for d in range(self.depth):
51 | x = self.layers["conv"][d](x)
52 | if len(self.layers["norm"]) > 0:
53 | x = self.layers["norm"][d](x)
54 | x = self.layers["act"][d](x)
55 | if len(self.layers["pool"]) > 0:
56 | x = self.layers["pool"][d](x)
57 |
58 | # x = nn.functional.avg_pool2d(x, x.shape[-1])
59 | out = x.view(x.shape[0], -1)
60 | logit = self.classifier(out)
61 |
62 | if return_features:
63 | return logit, out
64 | else:
65 | return logit
66 |
67 | def get_feature(
68 | self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False
69 | ):
70 | if idx_to == -1:
71 | idx_to = idx_from
72 | features = []
73 |
74 | for d in range(self.depth):
75 | x = self.layers["conv"][d](x)
76 | if self.net_norm:
77 | x = self.layers["norm"][d](x)
78 | x = self.layers["act"][d](x)
79 | if self.net_pooling:
80 | x = self.layers["pool"][d](x)
81 | features.append(x)
82 | if idx_to < len(features):
83 | return features[idx_from : idx_to + 1]
84 |
85 | if return_prob:
86 | out = x.view(x.size(0), -1)
87 | logit = self.classifier(out)
88 | prob = torch.softmax(logit, dim=-1)
89 | return features, prob
90 | elif return_logit:
91 | out = x.view(x.size(0), -1)
92 | logit = self.classifier(out)
93 | return features, logit
94 | else:
95 | return features[idx_from : idx_to + 1]
96 |
97 | def _get_normlayer(self, net_norm, shape_feat):
98 | # shape_feat = (c * h * w)
99 | if net_norm == "batch":
100 | norm = nn.BatchNorm2d(shape_feat[0], affine=True)
101 | elif net_norm == "layer":
102 | norm = nn.LayerNorm(shape_feat, elementwise_affine=True)
103 | elif net_norm == "instance":
104 | norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
105 | elif net_norm == "group":
106 | norm = nn.GroupNorm(4, shape_feat[0], affine=True)
107 | elif net_norm == "none":
108 | norm = None
109 |         else:
110 |             # fail loudly instead of silently returning None for an unknown norm
111 |             raise ValueError("unknown net_norm: %s" % net_norm)
112 | return norm
113 |
114 | def _make_layers(
115 | self, channel, net_width, net_depth, net_norm, net_pooling, im_size
116 | ):
117 | layers = {"conv": [], "norm": [], "act": [], "pool": []}
118 |
119 | in_channels = channel
120 | if im_size[0] == 28:
121 | im_size = (32, 32)
122 | shape_feat = [in_channels, im_size[0], im_size[1]]
123 |
124 | for d in range(net_depth):
125 | layers["conv"] += [
126 | nn.Conv2d(
127 | in_channels,
128 | net_width,
129 | kernel_size=3,
130 | padding=3 if channel == 1 and d == 0 else 1,
131 | )
132 | ]
133 | shape_feat[0] = net_width
134 | if net_norm != "none":
135 | layers["norm"] += [self._get_normlayer(net_norm, shape_feat)]
136 | layers["act"] += [self.net_act]
137 | in_channels = net_width
138 | if net_pooling != "none":
139 | layers["pool"] += [self.net_pooling]
140 | shape_feat[1] //= 2
141 | shape_feat[2] //= 2
142 |
143 | layers["conv"] = nn.ModuleList(layers["conv"])
144 | layers["norm"] = nn.ModuleList(layers["norm"])
145 | layers["act"] = nn.ModuleList(layers["act"])
146 | layers["pool"] = nn.ModuleList(layers["pool"])
147 | layers = nn.ModuleDict(layers)
148 |
149 | return layers, shape_feat
150 |
151 | def load_model(model_name="resnet18", dataset="cifar10", pretrained=True, classes=[]):
152 | def get_model(model_name="resnet18"):
153 | if "conv" in model_name:
154 | if dataset in ["cifar10", "cifar100"]:
155 | size = 32
156 | elif dataset in ["tinyimagenet", "imagenet-1k"]:
157 | size = 64
158 | elif dataset in ["imagenet-nette", "imagenet-woof", "imagenet-100"]:
159 | size = 128
160 | else:
161 | raise Exception("Unrecognized dataset")
162 |
163 | nclass = len(classes)
164 |
165 | model = ConvNet(
166 | num_classes=nclass,
167 | net_norm="batch",
168 | net_act="relu",
169 | net_pooling="avgpooling",
170 | net_depth=int(model_name[-1]),
171 | net_width=128,
172 | channel=3,
173 | im_size=(size, size),
174 | )
175 | elif model_name == "resnet18_modified":
176 | model = models.__dict__["resnet18"](pretrained=False)
177 | model.conv1 = nn.Conv2d(
178 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
179 | )
180 | model.maxpool = nn.Identity()
181 | elif model_name == "resnet101_modified":
182 | model = models.__dict__["resnet101"](pretrained=False)
183 | model.conv1 = nn.Conv2d(
184 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
185 | )
186 | model.maxpool = nn.Identity()
187 | else:
188 | model = models.__dict__[model_name](pretrained=False)
189 |
190 | return model
191 |
192 | def pruning_classifier(model=None, classes=[]):
193 | try:
194 | model_named_parameters = [name for name, x in model.named_parameters()]
195 | for name, x in model.named_parameters():
196 | if (
197 | name == model_named_parameters[-1]
198 | or name == model_named_parameters[-2]
199 | ):
200 | x.data = x[classes]
201 |         except Exception:
202 |             print("ERROR in changing the number of classes.")
203 |
204 | return model
205 |
206 | # "imagenet-100" "imagenet-10" "imagenet-first" "imagenet-nette" "imagenet-woof"
207 | model = get_model(model_name)
208 | model = pruning_classifier(model, classes)
209 | if pretrained:
210 | if dataset in [
211 | "imagenet-100",
212 | "imagenet-10",
213 | "imagenet-nette",
214 | "imagenet-woof",
215 | "tinyimagenet",
216 | "cifar10",
217 | "cifar100",
218 | ]:
219 | checkpoint = torch.load(
220 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu"
221 | )
222 | model.load_state_dict(checkpoint["model"])
223 | elif dataset in ["imagenet-1k"]:
224 | if model_name == "efficientNet-b0":
225 | # Specifically, for loading the pre-trained EfficientNet model, the following modifications are made
226 | from torchvision.models._api import WeightsEnum
227 | from torch.hub import load_state_dict_from_url
228 |
229 | def get_state_dict(self, *args, **kwargs):
230 | kwargs.pop("check_hash")
231 | return load_state_dict_from_url(self.url, *args, **kwargs)
232 |
233 | WeightsEnum.get_state_dict = get_state_dict
234 | elif "conv" in model_name:
235 | checkpoint = torch.load(
236 | f"/workspace/save/{dataset}_{model_name}.pth", map_location="cpu"
237 | )
238 | model.load_state_dict(checkpoint["model"])
239 | else:
240 | model = models.__dict__[model_name](pretrained=True)
241 | return model
--------------------------------------------------------------------------------
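A usage sketch for `load_model` follows. Note the checkpoint locations are assumptions carried over from the code above: the conv and sub-dataset branches read weights from `/workspace/save/{dataset}_{model_name}.pth`, which must exist.

```python
import torch

from models import load_model  # recover/models.py

# Torchvision backbone with its default ImageNet-1k pretrained weights:
teacher = load_model(model_name="resnet18", dataset="imagenet-1k",
                     pretrained=True, classes=range(1000))

# Custom Conv-4 teacher; this branch would load
# /workspace/save/imagenet-1k_conv4.pth, so that file must be present:
# teacher = load_model(model_name="conv4", dataset="imagenet-1k",
#                      pretrained=True, classes=range(1000))

teacher.eval()
with torch.no_grad():
    logits = teacher(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 1000])
```
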
/recover/recover.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import os
4 | import random
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from utils import BNFeatureHook, clip, lr_cosine_policy, ImageFolder
10 | from models import load_model # Load the custom conv models
11 |
12 | from torch.utils.data import DataLoader
13 |
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | import torchvision.models as models
17 | from PIL import Image
18 | from torchvision import transforms
19 |
20 |
21 | from tqdm import tqdm
22 |
23 |
24 | def get_round_imgs_indx(args, class_iterations, targets, img_indxer):
25 | # ===================================================================================
26 | # first, get the maximum iterations for each class
27 | # ===================================================================================
28 | max_itr = class_iterations[targets.cpu().unique()].max().item()
29 |     # initialize one index list per round; each holds the batch positions optimized in that round
30 |     slicing_list = [[] for i in range(max_itr//args.round_iterations)]
31 | for class_id in targets.unique():
32 | init_img_indx = -1 if class_id.item() not in img_indxer else img_indxer[class_id.item()]
33 |         # nonzero() returns a 0-d tensor when the batch holds a single image of this class
34 | if (targets==class_id).sum() == 1:
35 | img_index_in_batch = [(targets==class_id).nonzero().squeeze().item()]
36 | else:
37 | img_index_in_batch = (targets==class_id).nonzero().squeeze().tolist()
38 | img_indices = (np.array(img_index_in_batch) - img_index_in_batch[0] \
39 | + init_img_indx + 1).tolist()
40 | # get the class rounds
41 | class_rounds = class_iterations[class_id.item()]//args.round_iterations
42 |
43 | round_imgs_indx = np.ones(class_rounds, dtype=int)*int(args.ipc//class_rounds)
44 | round_imgs_indx[:args.ipc%class_rounds] += 1
45 | round_imgs_indx = np.cumsum(round_imgs_indx)
46 | # Loop over every image in the batch
47 | for img_indx, batch_indx in zip(img_indices, img_index_in_batch):
48 | # get the rounds that include this image
49 | start_round = (img_indx < round_imgs_indx).nonzero()[0][0]
50 | for round_indx in range(start_round, len(round_imgs_indx)):
51 | slicing_list[round_indx].append(batch_indx)
52 | return slicing_list
53 |
54 | def get_images(args, inputs, targets, model_teacher, loss_r_feature_layers,
55 | class_iterations, img_indxer):
56 | save_every = 100
57 |
58 | best_cost = 1e4
59 |
60 | mean, std = get_mean_std(args)
61 | # =======================================================================================
62 | # Prepare the labels and the inputs
63 | targets = targets.to('cuda')
64 | data_type = torch.float
65 | inputs = inputs.type(data_type)
66 | inputs = inputs.to('cuda')
67 |
68 | # ---------------------------------------------------------------------------------------
69 | # store the initial images to be used in the skip connection if needed
70 | # ---------------------------------------------------------------------------------------
71 | total_iterations = args.iteration
72 | skip_images = None
73 | inputs.requires_grad = True
74 | if args.use_early_late or (args.min_iterations != args.iteration):
75 | slicing_list = get_round_imgs_indx(args, class_iterations, targets, img_indxer)
76 | # ===================================================================================
77 | # Finally, check if there is an empty round
78 | # ===================================================================================
79 | round_indx = 0
80 | list_size = len(slicing_list)
81 | while round_indx < list_size:
82 |             if len(slicing_list[round_indx]) == 0:  # an empty round
83 |                 del slicing_list[round_indx]  # delete the round
84 |                 round_indx -= 1  # stay at the same position after the deletion
85 |                 list_size -= 1  # one fewer round
86 |                 total_iterations -= args.round_iterations  # and fewer total iterations
87 | round_indx += 1
88 |
89 | iterations_per_layer = total_iterations
90 | lim_0, lim_1 = args.jitter, args.jitter
91 |
92 | optimizer = optim.Adam([inputs], lr=args.lr, betas=[0.5, 0.9], eps=1e-8)
93 | lr_scheduler = lr_cosine_policy(args.lr, 0, iterations_per_layer) # 0 - do not use warmup
94 |
95 | criterion = nn.CrossEntropyLoss()
96 | criterion = criterion.cuda()
97 |     round_indx = 0
98 | for iteration in range(iterations_per_layer):
99 | # ===================================================================================
100 | # Identify the round index for the Early-Late Training scheme
101 | # ===================================================================================
102 | if args.round_iterations != 0 and (args.use_early_late or \
103 | (args.min_iterations != args.iteration)):
104 | round_indx = iteration//args.round_iterations
105 | # -----------------------------------------------------------------------------------
106 |
107 | # learning rate scheduling: reset the scheduling per round
108 | if args.round_lr_reset and args.round_iterations != 0:
109 | lr_scheduler(optimizer, iteration%args.round_iterations, iteration%args.round_iterations)
110 |         else:
111 |             # cosine schedule over the full optimization horizon
112 |             lr_scheduler(optimizer, iteration, iteration)
113 | # ===================================================================================
114 | # strategy: start with whole image with mix crop of 1, then lower to 0.08
115 | # easy to hard
116 | min_crop = 0.08
117 | max_crop = 1.0
118 | if iteration < args.milestone * iterations_per_layer:
119 | if args.easy2hard_mode == "step":
120 | min_crop = 1.0
121 | elif args.easy2hard_mode == "linear":
122 | # min_crop linear decreasing: 1.0 -> 0.08
123 | min_crop = 0.08 + (1.0 - 0.08) * (1 - iteration / (args.milestone * iterations_per_layer))
124 | elif args.easy2hard_mode == "cosine":
125 | # min_crop cosine decreasing: 1.0 -> 0.08
126 | min_crop = 0.08 + (1.0 - 0.08) * (1 + np.cos(np.pi * iteration / (args.milestone * iterations_per_layer))) / 2
127 |
128 | aug_function = transforms.Compose(
129 | [
130 | # transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
131 | transforms.RandomResizedCrop(args.input_size, scale=(min_crop, max_crop)),
132 | transforms.RandomHorizontalFlip(),
133 | ]
134 | )
135 | if args.round_iterations != 0 and (args.use_early_late or (args.min_iterations != args.iteration)):
136 | inputs_jit = aug_function(inputs[slicing_list[round_indx]])
137 | else:
138 | inputs_jit = aug_function(inputs)
139 | # apply random jitter offsets
140 | off1 = random.randint(0, lim_0)
141 | off2 = random.randint(0, lim_1)
142 | inputs_jit = torch.roll(inputs_jit, shifts=(off1, off2), dims=(2, 3))
143 |
144 | # forward pass
145 | optimizer.zero_grad()
146 | outputs = model_teacher(inputs_jit)
147 |
148 | # R_cross classification loss
149 | if args.round_iterations != 0 and (args.use_early_late or (args.min_iterations != args.iteration)):
150 | loss_ce = criterion(outputs, targets[slicing_list[round_indx]])
151 | else:
152 | loss_ce = criterion(outputs, targets)
153 |
154 | # R_feature loss
155 | rescale = [args.first_bn_multiplier] + [1.0 for _ in range(len(loss_r_feature_layers) - 1)]
156 | loss_r_bn_feature = sum([mod.r_feature * rescale[idx] for (idx, mod) in enumerate(loss_r_feature_layers)])
157 |
158 | # combining losses
159 | loss_aux = args.r_bn * loss_r_bn_feature
160 |
161 | loss = loss_ce + loss_aux
162 |
163 | # ===================================================================================
164 | # do image update
165 | # ===================================================================================
166 | loss.backward()
167 | # ===================================================================================
168 | optimizer.step()
169 |
170 |         # clip color outliers
171 |         inputs.data = clip(inputs.data, mean=mean, std=std)
172 |
173 |         if best_cost > loss.item() or iteration == 1:
174 |             best_inputs = inputs.data.clone()  # note: best_cost is never updated, so this effectively keeps the latest inputs
175 | if args.store_best_images:
176 | best_inputs = inputs.data.clone() # using multicrop, save the last one
177 | # add the denormalizer
178 | denormalize = transforms.Compose(
179 | [
180 | transforms.Normalize(
181 | mean=[0.0, 0.0, 0.0], std= 1/std
182 | ),
183 | transforms.Normalize(mean= -mean, std=[1.0, 1.0, 1.0]),
184 | ]
185 | )
186 | best_inputs = denormalize(best_inputs)
187 | save_images(args, best_inputs, targets, round_indx)
188 |
189 | # to reduce memory consumption by states of the optimizer we deallocate memory
190 | optimizer.state = collections.defaultdict(dict)
191 | torch.cuda.empty_cache()
192 |
193 |
194 | def save_images(args, images, targets, round_indx=0, indx_list=None):
195 | for id in range(images.shape[0]):
196 | if targets.ndimension() == 1:
197 | class_id = targets[id].item()
198 | else:
199 | class_id = targets[id].argmax().item()
200 |
201 | if not os.path.exists(args.syn_data_path):
202 | os.mkdir(args.syn_data_path)
203 |
204 | # save into separate folders
205 | dir_path = "{}/new{:03d}".format(args.syn_data_path, class_id)
206 | place_to_store = dir_path + "/class{:03d}_id{:03d}.jpg".format(class_id, get_img_indx(class_id))
207 |
208 | if not os.path.exists(dir_path):
209 | os.makedirs(dir_path)
210 |
211 | image_np = images[id].data.cpu().numpy().transpose((1, 2, 0))
212 | pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
213 | pil_image.save(place_to_store)
214 |
215 | def get_mean_std(args):
216 | if args.dataset in ["imagenet-1k", "imagenet-100", "imagenet-woof", "imagenet-nette"]:
217 | mean = np.array([0.485, 0.456, 0.406])
218 | std = np.array([0.229, 0.224, 0.225])
219 | elif args.dataset == "tinyimagenet":
220 | mean = np.array([0.4802, 0.4481, 0.3975])
221 | std = np.array([0.2302, 0.2265, 0.2262])
222 | elif args.dataset in ["cifar10", "cifar100"]:
223 | mean = np.array([0.4914, 0.4822, 0.4465])
224 | std = np.array([0.2023, 0.1994, 0.2010])
225 | return mean, std
226 |
227 | def main_syn(args):
228 | if "conv" in args.arch_name:
229 | model_teacher = load_model(model_name=args.arch_name,
230 | dataset=args.dataset,
231 | pretrained=True,
232 | classes=range(args.num_classes))
233 | elif args.arch_path:
234 |         model_teacher = models.get_model(args.arch_name, weights=None, num_classes=args.num_classes)
235 | if args.dataset in ["cifar10", "cifar100", "tinyimagenet"]:
236 | model_teacher.conv1 = nn.Conv2d(
237 | 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
238 | )
239 | model_teacher.maxpool = nn.Identity()
240 | checkpoint = torch.load(args.arch_path, map_location="cpu")
241 | model_teacher.load_state_dict(checkpoint["model"])
242 | else:
243 | model_teacher = models.get_model(args.arch_name, weights="DEFAULT")
244 | model_teacher = nn.DataParallel(model_teacher).cuda()
245 | model_teacher.eval()
246 | for p in model_teacher.parameters():
247 | p.requires_grad = False
248 |
249 | # load the init_set
250 | # =======================================================================================
251 | # do some augmentation to the initialized images
252 | transforms_list = []
253 | mean, std = get_mean_std(args)
254 | transforms_list += [ transforms.Resize(args.input_size // 7 * 8, antialias=True),
255 | transforms.CenterCrop(args.input_size),
256 | transforms.ToTensor(),
257 | transforms.Normalize(mean= mean,
258 | std= std)
259 | ]
260 | init_data = ImageFolder(
261 | ipc = args.ipc,
262 | shuffle = False,
263 | root=args.init_data_path,
264 | transform=transforms.Compose(transforms_list)
265 | )
266 |
267 | data_loader = DataLoader(init_data, batch_size=args.batch_size,
268 | shuffle=False, num_workers=args.workers,
269 | pin_memory=True)
270 | # Loop over the images
271 | global img_indxer
272 | img_indxer = {}
273 |
274 | # =======================================================================================
275 | # define the number of gradient iterations for each class
276 | # =======================================================================================
277 | class_iterations = torch.ones((args.num_classes,),dtype=int)*args.iteration
278 | if args.round_iterations != 0:
279 | min_rounds = args.min_iterations//args.round_iterations
280 | max_rounds = args.iteration//args.round_iterations
281 | # randomly select the number of rounds for each class
282 | class_iterations = torch.randint(min_rounds, max_rounds+1, (args.num_classes,))
283 | # get the number of iterations by multiplying the round number * iteration per round
284 | class_iterations = class_iterations*args.round_iterations
285 | print(f"Class Iterations: {class_iterations}")
286 |
287 | loss_r_feature_layers = []
288 | for module in model_teacher.modules():
289 | if isinstance(module, nn.BatchNorm2d):
290 | loss_r_feature_layers.append(BNFeatureHook(module))
291 |
292 |
293 | for images, labels in tqdm(data_loader):
294 | get_images(args, images, labels, model_teacher, loss_r_feature_layers,
295 | class_iterations, img_indxer)
296 |
297 |
298 | def get_img_indx(class_id):
299 | global img_indxer
300 | if class_id not in img_indxer:
301 | img_indxer[class_id] = 0
302 | else:
303 | img_indxer[class_id] += 1
304 | return img_indxer[class_id]
305 |
306 | def parse_args():
307 | parser = argparse.ArgumentParser("DELT: Early-Late Recovery scheme for different datasets")
308 | """Data save flags"""
309 | parser.add_argument("--exp-name", type=str, default="test", help="name of the experiment, subfolder under syn_data_path")
310 | parser.add_argument("--init-data-path", type=str, default="/workspace/data/RDED_IN_1K/ipc50_cr5_mipc300", help="location of initialization data")
311 | parser.add_argument("--syn-data-path", type=str, default="./syn-data", help="where to store synthetic data")
312 | parser.add_argument("--store-best-images", action="store_true", help="whether to store best images")
313 | parser.add_argument("--ipc", type=int, default=50, help="number of IPC to use")
314 | parser.add_argument("--dataset", type=str, default="imagenet-1k", help="dataset to use")
315 | parser.add_argument("--input-size", type=int, default=224, help="image input size")
316 | parser.add_argument("--num-classes", type=int, default=1000, help="number of classes")
317 |
318 | """Early-Late Training flags"""
319 | parser.add_argument("--use-early-late", action="store_true", default=False, help="use a incremental learn")
320 | parser.add_argument("--round-iterations", type=int, default=0, help="number of iterations in a single round")
321 | parser.add_argument("--min-iterations", type=int, default=-1, help="minimum num of iterations to optimize the synthetic data of a specific class")
322 | parser.add_argument("--round-lr-reset", action="store_true", default=False, help="reset the lr per round")
323 | """Optimization related flags"""
324 | parser.add_argument("--batch-size", type=int, default=100, help="number of images to optimize at the same time")
325 | parser.add_argument('-j', '--workers', default=16, type=int, help='number of data loading workers')
326 | parser.add_argument("--iteration", type=int, default=1000, help="num of iterations to optimize the synthetic data")
327 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate for optimization")
328 | parser.add_argument("--jitter", default=32, type=int, help="random shift on the synthetic data")
329 | parser.add_argument("--r-bn", type=float, default=0.05, help="coefficient for BN feature distribution regularization")
330 | parser.add_argument("--first-bn-multiplier", type=float, default=10.0, help="additional multiplier on first bn layer of R_bn")
331 | """Model related flags"""
332 | parser.add_argument("--arch-name", type=str, default="resnet18", help="arch name from pretrained torchvision models")
333 | parser.add_argument("--arch-path", type=str, default="", help="path to the teacher model")
334 | parser.add_argument("--easy2hard-mode", default="cosine", type=str, choices=["step", "linear", "cosine"])
335 | parser.add_argument("--milestone", default=0, type=float)
336 | parser.add_argument('--gpu-device', default="-1", type=str)
337 | args = parser.parse_args()
338 |
339 | assert args.milestone >= 0 and args.milestone <= 1
340 | # assert args.batch_size%args.ipc == 0 and args.batch_size != 0
341 |
342 | if args.gpu_device != "-1":
343 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
344 | print("set CUDA_VISIBLE_DEVICES to ", args.gpu_device)
345 |
346 | # =======================================================================================
347 | # Do some initializations
348 | # =======================================================================================
349 | if args.dataset == "imagenet-1k":
350 | args.num_classes = 1000
351 | elif args.dataset == "tinyimagenet":
352 | args.num_classes = 200
353 | elif args.dataset in ["imagenet-100", "cifar100"]:
354 | args.num_classes = 100
355 | elif args.dataset in ["imagenet-woof", "imagenet-nette", "cifar10"]:
356 | args.num_classes = 10
357 |
358 | if args.dataset in ["imagenet-1k", "imagenet-100", "imagenet-woof", "imagenet-nette"]:
359 | args.input_size = 224
360 | elif args.dataset == "tinyimagenet":
361 | args.input_size = 64
362 | elif args.dataset in ["cifar10", "cifar100"]:
363 | args.input_size = 32
364 |
365 | if "conv" in args.arch_name:
366 | if args.dataset in ["imagenet-100", "imagenet-woof", "imagenet-nette"]:
367 | args.input_size = 128
368 | elif args.dataset in ["tinyimagenet", "imagenet-1k"]:
369 | args.input_size = 64
370 | elif args.dataset in ["cifar10", "cifar100"]:
371 | args.input_size = 32
372 |
373 | if args.min_iterations == -1:
374 | args.min_iterations = args.iteration
375 | else:
376 | args.min_iterations = max(args.min_iterations, args.round_iterations)
377 |
378 | args.syn_data_path = os.path.join(args.syn_data_path, args.exp_name)
379 | print(args)
380 | return args
381 |
382 |
383 | if __name__ == "__main__":
384 | args = parse_args()
385 |
386 | if not os.path.exists(args.syn_data_path):
387 | os.makedirs(args.syn_data_path)
388 |
389 | main_syn(args)
390 | print("Done.")
391 |
--------------------------------------------------------------------------------
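To make the Early-Late bookkeeping in `get_round_imgs_indx` concrete, here is a small worked example of the same cumulative-sum partition for the default ImageNet-1k recipe (`--ipc 50 --iteration 4000 --round-iterations 500`); the numbers follow directly from the code above.

```python
import numpy as np

ipc, iteration, round_iterations = 50, 4000, 500
class_rounds = iteration // round_iterations           # 8 rounds per class

# Spread the 50 IPC samples over the 8 rounds (the first ipc % class_rounds
# rounds get one extra image), then take the running total.
counts = np.ones(class_rounds, dtype=int) * (ipc // class_rounds)
counts[: ipc % class_rounds] += 1                      # [7 7 6 6 6 6 6 6]
boundaries = np.cumsum(counts)                         # [ 7 14 20 26 32 38 44 50]

# An image enters optimization at the first round whose boundary exceeds its
# index and stays until the end, so image 0 gets all 8 * 500 = 4000 iterations
# while image 49 only gets the final 500 -- the "early" vs. "late" split.
for img_indx in (0, 7, 49):
    start_round = int(np.nonzero(img_indx < boundaries)[0][0])
    iters = (class_rounds - start_round) * round_iterations
    print(f"image {img_indx:2d}: starts at round {start_round}, {iters} iterations")
```
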
/recover/utils.py:
--------------------------------------------------------------------------------
1 | """Modifying the utils from the original CDA here https://github.com/VILA-Lab/SRe2L/blob/main/CDA/utils.py"""
2 | import numpy as np
3 | import torch
4 | import os
5 | import torchvision
6 | import random
7 |
8 | def clip(image_tensor, mean = np.array([0.485, 0.456, 0.406]),
9 | std = np.array([0.229, 0.224, 0.225])):
10 | """
11 | adjust the input based on mean and variance
12 | """
13 | # mean = np.array([0.485, 0.456, 0.406])
14 | # std = np.array([0.229, 0.224, 0.225])
15 | for c in range(3):
16 | m, s = mean[c], std[c]
17 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s)
18 | return image_tensor
19 |
20 |
21 | def tiny_clip(image_tensor):
22 | """
23 | adjust the input based on mean and variance, using tiny-imagenet normalization
24 | """
25 | mean = np.array([0.4802, 0.4481, 0.3975])
26 | std = np.array([0.2302, 0.2265, 0.2262])
27 |
28 | for c in range(3):
29 | m, s = mean[c], std[c]
30 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s)
31 | return image_tensor
32 |
33 |
34 | def denormalize(image_tensor, mean = np.array([0.485, 0.456, 0.406]),
35 | std = np.array([0.229, 0.224, 0.225])):
36 | """
37 | convert floats back to input
38 | """
39 | # mean = np.array([0.485, 0.456, 0.406])
40 | # std = np.array([0.229, 0.224, 0.225])
41 |
42 | for c in range(3):
43 | m, s = mean[c], std[c]
44 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1)
45 |
46 | return image_tensor
47 |
48 |
49 | def tiny_denormalize(image_tensor):
50 | """
51 | convert floats back to input, using tiny-imagenet normalization
52 | """
53 | mean = np.array([0.4802, 0.4481, 0.3975])
54 | std = np.array([0.2302, 0.2265, 0.2262])
55 |
56 | for c in range(3):
57 | m, s = mean[c], std[c]
58 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1)
59 |
60 | return image_tensor
61 |
62 |
63 | def lr_policy(lr_fn):
64 | def _alr(optimizer, iteration, epoch):
65 | lr = lr_fn(iteration, epoch)
66 | for param_group in optimizer.param_groups:
67 | param_group["lr"] = lr
68 |
69 | return _alr
70 |
71 |
72 | def lr_cosine_policy(base_lr, warmup_length, epochs):
73 | def _lr_fn(iteration, epoch):
74 | if epoch < warmup_length:
75 | lr = base_lr * (epoch + 1) / warmup_length
76 | else:
77 | e = epoch - warmup_length
78 | es = epochs - warmup_length
79 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
80 | return lr
81 |
82 | return lr_policy(_lr_fn)
83 |
84 |
85 | class ViT_BNFeatureHook:
86 | def __init__(self, module):
87 | self.hook = module.register_forward_hook(self.hook_fn)
88 |
89 | def hook_fn(self, module, input, output):
90 | B, N, C = input[0].shape
91 | mean = torch.mean(input[0], dim=[0, 1])
92 | var = torch.var(input[0], dim=[0, 1], unbiased=False)
93 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(module.running_mean.data - mean, 2)
94 | self.r_feature = r_feature
95 |
96 | def close(self):
97 | self.hook.remove()
98 |
99 |
100 | class BNFeatureHook:
101 | def __init__(self, module):
102 | self.hook = module.register_forward_hook(self.hook_fn)
103 |
104 | def hook_fn(self, module, input, output):
105 | nch = input[0].shape[1]
106 | mean = input[0].mean([0, 2, 3])
107 | var = input[0].permute(1, 0, 2, 3).contiguous().reshape([nch, -1]).var(1, unbiased=False)
108 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(module.running_mean.data - mean, 2)
109 | self.r_feature = r_feature
110 |
111 | def close(self):
112 | self.hook.remove()
113 |
114 |
115 | # modified from Alibaba-ImageNet21K/src_files/models/utils/factory.py
116 | def load_model_weights(model, model_path):
117 | state = torch.load(model_path, map_location="cpu")
118 |
119 | Flag = False
120 | if "state_dict" in state:
121 | # resume from a model trained with nn.DataParallel
122 | state = state["state_dict"]
123 | Flag = True
124 |
125 | for key in model.state_dict():
126 | if "num_batches_tracked" in key:
127 | continue
128 | p = model.state_dict()[key]
129 |
130 | if Flag:
131 | key = "module." + key
132 |
133 | if key in state:
134 | ip = state[key]
135 | # if key in state['state_dict']:
136 | # ip = state['state_dict'][key]
137 | if p.shape == ip.shape:
138 | p.data.copy_(ip.data) # Copy the data of parameters
139 | else:
140 | print("could not load layer: {}, mismatch shape {} ,{}".format(key, (p.shape), (ip.shape)))
141 | else:
142 | print("could not load layer: {}, not in checkpoint".format(key))
143 | return model
144 |
145 | class ImageFolder(torchvision.datasets.ImageFolder):
146 | def __init__(self, ipc = 50, mem=False, shuffle=False, **kwargs):
147 | super(ImageFolder, self).__init__(**kwargs)
148 | self.mem = mem
149 | self.image_paths = []
150 | self.targets = []
151 | self.samples = []
152 | dirlist = []
153 | for name in os.listdir(self.root):
154 | if os.path.isdir(os.path.join(self.root,name)):
155 | dirlist.append(name)
156 | dirlist.sort()
157 | for c in range(len(dirlist)):
158 | dir_path = os.path.join(self.root, dirlist[c])
159 | file_ls = os.listdir(dir_path)
160 | file_ls.sort()
161 | if shuffle:
162 | random.shuffle(file_ls)
163 | for i in range(ipc):
164 | self.image_paths.append(dir_path + "/" + file_ls[i])
165 | # print(self.image_paths)
166 | # exit()
167 | self.targets.append(c)
168 | if self.mem:
169 | self.samples.append(self.loader(dir_path + "/" + file_ls[i]))
170 |
171 | def __getitem__(self, index):
172 | if self.mem:
173 | sample = self.samples[index]
174 | else:
175 | sample = self.loader(self.image_paths[index])
176 | sample = self.transform(sample)
177 | return sample, self.targets[index]
178 |
179 | def __len__(self):
180 | return len(self.targets)
--------------------------------------------------------------------------------
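A minimal sketch of how `BNFeatureHook` is consumed during recovery, mirroring the loop in `recover.py`: hook every `BatchNorm2d`, run one forward pass, then sum the per-layer statistic-matching terms into the R_bn regularizer.

```python
import torch
import torch.nn as nn
import torchvision.models as models

from utils import BNFeatureHook  # recover/utils.py

model = models.resnet18(weights=None).eval()
hooks = [BNFeatureHook(m) for m in model.modules() if isinstance(m, nn.BatchNorm2d)]

x = torch.randn(4, 3, 224, 224, requires_grad=True)
_ = model(x)  # the forward pass populates each hook's r_feature

# R_bn: distance between the batch statistics of x and the BN running
# statistics stored in the teacher (recover.py up-weights the first layer).
r_bn = sum(h.r_feature for h in hooks)
r_bn.backward()
print(r_bn.item(), x.grad.shape)
```
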
/scripts/conv4_imagenet1k_synthesis.sh:
--------------------------------------------------------------------------------
1 | GPU=0
2 |
3 | # =========================================================================================
4 |
5 | SYN_PATH="/path/to/synthesized/conv_imagenet_1k/"
6 | INIT_PATH="/path/to/initialization/medium_prob"
7 | IPC=50
8 | ITR=4000
9 | ROUND_ITR=500
10 | EXP_NAME="IPC$((IPC))_4K_500_medium"
11 |
12 | echo "$EXP_NAME"
13 |
14 | python /path/to/DELT/recover/recover.py \
15 | --init-data-path "$INIT_PATH" \
16 | --syn-data-path "$SYN_PATH" \
17 | --arch-name "conv4" \
18 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \
19 | --dataset "imagenet-1k" \
20 | --exp-name "$EXP_NAME" \
21 | --use-early-late \
22 | --round-iterations $((ROUND_ITR)) \
23 | --batch-size 100 \
24 | --lr 0.25 \
25 | --r-bn 0.01 \
26 | --gpu-device $((GPU)) \
27 | --iteration $((ITR)) \
28 |     --jitter 0 \
29 | --easy2hard-mode "cosine" --milestone 1 \
30 | --ipc $((IPC)) --store-best-images
31 |
32 | echo "Synthesis -> DONE"
33 |
34 | # =========================================================================================
35 |
36 | IPC=10
37 | ITR=4000
38 | ROUND_ITR=500
39 | EXP_NAME="IPC$((IPC))_4K_500_medium"
40 |
41 | echo "$EXP_NAME"
42 |
43 | python /path/to/DELT/recover/recover.py \
44 | --init-data-path "$INIT_PATH" \
45 | --syn-data-path "$SYN_PATH" \
46 | --arch-name "conv4" \
47 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \
48 | --dataset "imagenet-1k" \
49 | --exp-name "$EXP_NAME" \
50 | --use-early-late \
51 | --round-iterations $((ROUND_ITR)) \
52 | --batch-size 100 \
53 | --lr 0.25 \
54 | --r-bn 0.01 \
55 | --gpu-device $((GPU)) \
56 | --iteration $((ITR)) \
57 |     --jitter 0 \
58 | --easy2hard-mode "cosine" --milestone 1 \
59 | --ipc $((IPC)) --store-best-images
60 |
61 | echo "Synthesis -> DONE"
62 |
63 | # =========================================================================================
64 |
65 | IPC=1
66 | ITR=4000
67 | EXP_NAME="IPC$((IPC))_4K_medium"
68 |
69 | echo "$EXP_NAME"
70 |
71 | python /path/to/DELT/recover/recover.py \
72 | --init-data-path "$INIT_PATH" \
73 | --syn-data-path "$SYN_PATH" \
74 | --arch-name "conv4" \
75 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \
76 | --dataset "imagenet-1k" \
77 | --exp-name "$EXP_NAME" \
78 | --batch-size 100 \
79 | --lr 0.25 \
80 | --r-bn 0.01 \
81 | --gpu-device $((GPU)) \
82 | --iteration $((ITR)) \
83 |     --jitter 0 \
84 | --easy2hard-mode "cosine" --milestone 1 \
85 | --ipc $((IPC)) --store-best-images
86 |
87 | echo "Synthesis -> DONE"
88 |
89 | # =========================================================================================
90 |
91 | IPC=1
92 | ITR=3000
93 | EXP_NAME="IPC$((IPC))_3K_medium"
94 |
95 | echo "$EXP_NAME"
96 |
97 | python /path/to/DELT/recover/recover.py \
98 | --init-data-path "$INIT_PATH" \
99 | --syn-data-path "$SYN_PATH" \
100 | --arch-name "conv4" \
101 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \
102 | --dataset "imagenet-1k" \
103 | --exp-name "$EXP_NAME" \
104 | --batch-size 100 \
105 | --lr 0.25 \
106 | --r-bn 0.01 \
107 | --gpu-device $((GPU)) \
108 | --iteration $((ITR)) \
109 |     --jitter 0 \
110 | --easy2hard-mode "cosine" --milestone 1 \
111 | --ipc $((IPC)) --store-best-images
112 |
113 | echo "Synthesis -> DONE"
114 |
115 | # =========================================================================================
116 |
117 | IPC=1
118 | ITR=2000
119 | EXP_NAME="IPC$((IPC))_2K_medium"
120 |
121 | echo "$EXP_NAME"
122 |
123 | python /path/to/DELT/recover/recover.py \
124 | --init-data-path "$INIT_PATH" \
125 | --syn-data-path "$SYN_PATH" \
126 | --arch-name "conv4" \
127 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \
128 | --dataset "imagenet-1k" \
129 | --exp-name "$EXP_NAME" \
130 | --batch-size 100 \
131 | --lr 0.25 \
132 | --r-bn 0.01 \
133 | --gpu-device $((GPU)) \
134 | --iteration $((ITR)) \
135 |     --jitter 0 \
136 | --easy2hard-mode "cosine" --milestone 1 \
137 | --ipc $((IPC)) --store-best-images
138 |
139 | echo "Synthesis -> DONE"
140 |
141 | # =========================================================================================
142 |
143 | IPC=1
144 | ITR=1000
145 | EXP_NAME="IPC$((IPC))_1K_medium"
146 |
147 | echo "$EXP_NAME"
148 |
149 | python /path/to/DELT/recover/recover.py \
150 | --init-data-path "$INIT_PATH" \
151 | --syn-data-path "$SYN_PATH" \
152 | --arch-name "conv4" \
153 | --arch-path "/path/to/model/imagenet-1k_conv4.pth" \
154 | --dataset "imagenet-1k" \
155 | --exp-name "$EXP_NAME" \
156 | --batch-size 100 \
157 | --lr 0.25 \
158 | --r-bn 0.01 \
159 | --gpu-device $((GPU)) \
160 | --iteration $((ITR)) \
161 |     --jitter 0 \
162 | --easy2hard-mode "cosine" --milestone 1 \
163 | --ipc $((IPC)) --store-best-images
164 |
165 | echo "Synthesis -> DONE"
166 |
167 | # =========================================================================================
168 |
--------------------------------------------------------------------------------
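The script above repeats the same invocation for several (IPC, iterations, round-iterations) settings; only the experiment name and the Early-Late flags change. For sweeps, an equivalent loop can be driven from Python, as in the sketch below (same placeholder paths; settings with `ROUND_ITR` unset drop `--use-early-late`, matching the IPC-1 runs above).

```python
import subprocess

for ipc, itr, round_itr in [(50, 4000, 500), (10, 4000, 500), (1, 4000, 0)]:
    exp = f"IPC{ipc}_{itr // 1000}K" + (f"_{round_itr}" if round_itr else "") + "_medium"
    cmd = [
        "python", "/path/to/DELT/recover/recover.py",
        "--init-data-path", "/path/to/initialization/medium_prob",
        "--syn-data-path", "/path/to/synthesized/conv_imagenet_1k/",
        "--arch-name", "conv4",
        "--arch-path", "/path/to/model/imagenet-1k_conv4.pth",
        "--dataset", "imagenet-1k",
        "--exp-name", exp,
        "--batch-size", "100", "--lr", "0.25", "--r-bn", "0.01",
        "--iteration", str(itr), "--jitter", "0",
        "--easy2hard-mode", "cosine", "--milestone", "1",
        "--ipc", str(ipc), "--store-best-images",
    ]
    if round_itr:  # Early-Late only applies when rounds are enabled
        cmd += ["--use-early-late", "--round-iterations", str(round_itr)]
    subprocess.run(cmd, check=True)
```
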
/scripts/conv4_imagenet1k_validation.sh:
--------------------------------------------------------------------------------
1 | GPU=0
2 |
3 | # =========================================================================================
4 |
5 | SYN_PATH="/path/to/synthesized/imagenet_1k_conv/"
6 | PROJECT_NAME="Imagenet1K-Seeds"
7 | IPC=50
8 | EXP_NAME="IPC$((IPC))_4K_500_medium"
9 | WANDB_API_KEY="write your API key here"
10 |
11 | echo "$EXP_NAME"
12 | SEED=3407
13 |
14 | RAND_AUG="rand-m6-n2-mstd1.0"
15 | # VAL_NAME="$RAND_AUG"
16 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))"
17 |
18 | wandb enabled
19 | wandb online
20 | python /path/to/DELT/evaluation/main.py \
21 | --wandb-project "$PROJECT_NAME" \
22 | --wandb-api-key "$WANDB_API_KEY" \
23 | --val-dir "/path/to/val" \
24 | --syn-data-path "$SYN_PATH$EXP_NAME" \
25 | --exp-name "$VAL_NAME" \
26 | --subset "imagenet-1k" \
27 | --arch-name "conv4" \
28 | --use-rand-augment \
29 | --rand-augment-config "$RAND_AUG" \
30 | --random-erasing-p 0.0 \
31 | --min-scale-crops 0.25 \
32 | --max-scale-crops 1.0 \
33 | --ipc $((IPC)) \
34 | --val-ipc 50 \
35 | --stud-name "conv4" \
36 | --re-epochs 300 \
37 | --gpu-device $((GPU)) \
38 | --seed $((SEED))
39 |
40 | # =========================================================================================
41 |
42 |
43 | SEED=4663
44 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))"
45 |
46 | wandb enabled
47 | wandb online
48 | python /path/to/DELT/evaluation/main.py \
49 | --wandb-project "$PROJECT_NAME" \
50 | --wandb-api-key "$WANDB_API_KEY" \
51 | --val-dir "/path/to/val" \
52 | --syn-data-path "$SYN_PATH$EXP_NAME" \
53 | --exp-name "$VAL_NAME" \
54 | --subset "imagenet-1k" \
55 | --arch-name "conv4" \
56 | --use-rand-augment \
57 | --rand-augment-config "$RAND_AUG" \
58 | --random-erasing-p 0.0 \
59 | --min-scale-crops 0.25 \
60 | --max-scale-crops 1.0 \
61 | --ipc $((IPC)) \
62 | --val-ipc 50 \
63 | --stud-name "conv4" \
64 | --re-epochs 300 \
65 | --gpu-device $((GPU)) \
66 | --seed $((SEED))
67 |
68 | # =========================================================================================
69 |
70 | SEED=2897
71 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))"
72 |
73 | wandb enabled
74 | wandb online
75 | python /path/to/DELT/evaluation/main.py \
76 | --wandb-project "$PROJECT_NAME" \
77 | --wandb-api-key "$WANDB_API_KEY" \
78 | --val-dir "/path/to/val" \
79 | --syn-data-path "$SYN_PATH$EXP_NAME" \
80 | --exp-name "$VAL_NAME" \
81 | --subset "imagenet-1k" \
82 | --arch-name "conv4" \
83 | --use-rand-augment \
84 | --rand-augment-config "$RAND_AUG" \
85 | --random-erasing-p 0.0 \
86 | --min-scale-crops 0.25 \
87 | --max-scale-crops 1.0 \
88 | --ipc $((IPC)) \
89 | --val-ipc 50 \
90 | --stud-name "conv4" \
91 | --re-epochs 300 \
92 | --gpu-device $((GPU)) \
93 | --seed $((SEED))
94 |
95 | # =========================================================================================
96 |
97 | IPC=10
98 | EXP_NAME="IPC$((IPC))_4K_500_medium"
99 |
100 | echo "$EXP_NAME"
101 | SEED=3407
102 |
103 |
104 | # VAL_NAME="$RAND_AUG"
105 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))"
106 |
107 | wandb enabled
108 | wandb online
109 | python /path/to/DELT/evaluation/main.py \
110 | --wandb-project "$PROJECT_NAME" \
111 | --wandb-api-key "$WANDB_API_KEY" \
112 | --val-dir "/path/to/val" \
113 | --syn-data-path "$SYN_PATH$EXP_NAME" \
114 | --exp-name "$VAL_NAME" \
115 | --subset "imagenet-1k" \
116 | --arch-name "conv4" \
117 | --use-rand-augment \
118 | --rand-augment-config "$RAND_AUG" \
119 | --random-erasing-p 0.0 \
120 | --min-scale-crops 0.25 \
121 | --max-scale-crops 1.0 \
122 | --ipc $((IPC)) \
123 | --val-ipc 50 \
124 | --stud-name "conv4" \
125 | --re-epochs 300 \
126 | --gpu-device $((GPU)) \
127 | --seed $((SEED))
128 |
129 | # =========================================================================================
130 |
131 |
132 | SEED=4663
133 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))"
134 |
135 | wandb enabled
136 | wandb online
137 | python /path/to/DELT/evaluation/main.py \
138 | --wandb-project "$PROJECT_NAME" \
139 | --wandb-api-key "$WANDB_API_KEY" \
140 | --val-dir "/path/to/val" \
141 | --syn-data-path "$SYN_PATH$EXP_NAME" \
142 | --exp-name "$VAL_NAME" \
143 | --subset "imagenet-1k" \
144 | --arch-name "conv4" \
145 | --use-rand-augment \
146 | --rand-augment-config "$RAND_AUG" \
147 | --random-erasing-p 0.0 \
148 | --min-scale-crops 0.25 \
149 | --max-scale-crops 1.0 \
150 | --ipc $((IPC)) \
151 | --val-ipc 50 \
152 | --stud-name "conv4" \
153 | --re-epochs 300 \
154 | --gpu-device $((GPU)) \
155 | --seed $((SEED))
156 |
157 | # =========================================================================================
158 |
159 | SEED=2897
160 | VAL_NAME="IPC$((IPC)) 4K_500 Medium Conv4 S$((SEED))"
161 |
162 | wandb enabled
163 | wandb online
164 | python /path/to/DELT/evaluation/main.py \
165 | --wandb-project "$PROJECT_NAME" \
166 | --wandb-api-key "$WANDB_API_KEY" \
167 | --val-dir "/path/to/val" \
168 | --syn-data-path "$SYN_PATH$EXP_NAME" \
169 | --exp-name "$VAL_NAME" \
170 | --subset "imagenet-1k" \
171 | --arch-name "conv4" \
172 | --use-rand-augment \
173 | --rand-augment-config "$RAND_AUG" \
174 | --random-erasing-p 0.0 \
175 | --min-scale-crops 0.25 \
176 | --max-scale-crops 1.0 \
177 | --ipc $((IPC)) \
178 | --val-ipc 50 \
179 | --stud-name "conv4" \
180 | --re-epochs 300 \
181 | --gpu-device $((GPU)) \
182 | --seed $((SEED))
183 |
184 | # =========================================================================================
185 |
186 | IPC=1
187 | EXP_NAME="IPC$((IPC))_4K_medium"
188 |
189 | echo "$EXP_NAME"
190 | SEED=3407
191 |
192 |
193 | # VAL_NAME="$RAND_AUG"
194 | VAL_NAME="IPC$((IPC)) 4K Medium Conv4 S$((SEED))"
195 |
196 | wandb enabled
197 | wandb online
198 | python /path/to/DELT/evaluation/main.py \
199 | --wandb-project "$PROJECT_NAME" \
200 | --wandb-api-key "$WANDB_API_KEY" \
201 | --val-dir "/path/to/val" \
202 | --syn-data-path "$SYN_PATH$EXP_NAME" \
203 | --exp-name "$VAL_NAME" \
204 | --subset "imagenet-1k" \
205 | --arch-name "conv4" \
206 | --use-rand-augment \
207 | --rand-augment-config "$RAND_AUG" \
208 | --random-erasing-p 0.0 \
209 | --min-scale-crops 0.25 \
210 | --max-scale-crops 1.0 \
211 | --ipc $((IPC)) \
212 | --val-ipc 50 \
213 | --stud-name "conv4" \
214 | --re-epochs 300 \
215 | --gpu-device $((GPU)) \
216 | --seed $((SEED))
217 |
218 | # =========================================================================================
219 |
220 |
221 | SEED=4663
222 | VAL_NAME="IPC$((IPC)) 4K Medium Conv4 S$((SEED))"
223 |
224 | wandb enabled
225 | wandb online
226 | python /path/to/DELT/evaluation/main.py \
227 | --wandb-project "$PROJECT_NAME" \
228 | --wandb-api-key "$WANDB_API_KEY" \
229 | --val-dir "/path/to/val" \
230 | --syn-data-path "$SYN_PATH$EXP_NAME" \
231 | --exp-name "$VAL_NAME" \
232 | --subset "imagenet-1k" \
233 | --arch-name "conv4" \
234 | --use-rand-augment \
235 | --rand-augment-config "$RAND_AUG" \
236 | --random-erasing-p 0.0 \
237 | --min-scale-crops 0.25 \
238 | --max-scale-crops 1.0 \
239 | --ipc $((IPC)) \
240 | --val-ipc 50 \
241 | --stud-name "conv4" \
242 | --re-epochs 300 \
243 | --gpu-device $((GPU)) \
244 | --seed $((SEED))
245 |
246 | # =========================================================================================
247 |
248 | SEED=2897
249 | VAL_NAME="IPC$((IPC)) 4K Medium Conv4 S$((SEED))"
250 |
251 | wandb enabled
252 | wandb online
253 | python /path/to/DELT/evaluation/main.py \
254 | --wandb-project "$PROJECT_NAME" \
255 | --wandb-api-key "$WANDB_API_KEY" \
256 | --val-dir "/path/to/val" \
257 | --syn-data-path "$SYN_PATH$EXP_NAME" \
258 | --exp-name "$VAL_NAME" \
259 | --subset "imagenet-1k" \
260 | --arch-name "conv4" \
261 | --use-rand-augment \
262 | --rand-augment-config "$RAND_AUG" \
263 | --random-erasing-p 0.0 \
264 | --min-scale-crops 0.25 \
265 | --max-scale-crops 1.0 \
266 | --ipc $((IPC)) \
267 | --val-ipc 50 \
268 | --stud-name "conv4" \
269 | --re-epochs 300 \
270 | --gpu-device $((GPU)) \
271 | --seed $((SEED))
272 |
273 | # =========================================================================================
274 |
275 | IPC=1
276 | EXP_NAME="IPC$((IPC))_3K_medium"
277 |
278 | echo "$EXP_NAME"
279 | SEED=3407
280 |
281 |
282 | # VAL_NAME="$RAND_AUG"
283 | VAL_NAME="IPC$((IPC)) 3K Medium Conv4 S$((SEED))"
284 |
285 | wandb enabled
286 | wandb online
287 | python /path/to/DELT/evaluation/main.py \
288 | --wandb-project "$PROJECT_NAME" \
289 | --wandb-api-key "$WANDB_API_KEY" \
290 | --val-dir "/path/to/val" \
291 | --syn-data-path "$SYN_PATH$EXP_NAME" \
292 | --exp-name "$VAL_NAME" \
293 | --subset "imagenet-1k" \
294 | --arch-name "conv4" \
295 | --use-rand-augment \
296 | --rand-augment-config "$RAND_AUG" \
297 | --random-erasing-p 0.0 \
298 | --min-scale-crops 0.25 \
299 | --max-scale-crops 1.0 \
300 | --ipc $((IPC)) \
301 | --val-ipc 50 \
302 | --stud-name "conv4" \
303 | --re-epochs 300 \
304 | --gpu-device $((GPU)) \
305 | --seed $((SEED))
306 |
307 | # =========================================================================================
308 |
309 |
310 | SEED=4663
311 | VAL_NAME="IPC$((IPC)) 3K Medium Conv4 S$((SEED))"
312 |
313 | wandb enabled
314 | wandb online
315 | python /path/to/DELT/evaluation/main.py \
316 | --wandb-project "$PROJECT_NAME" \
317 | --wandb-api-key "$WANDB_API_KEY" \
318 | --val-dir "/path/to/val" \
319 | --syn-data-path "$SYN_PATH$EXP_NAME" \
320 | --exp-name "$VAL_NAME" \
321 | --subset "imagenet-1k" \
322 | --arch-name "conv4" \
323 | --use-rand-augment \
324 | --rand-augment-config "$RAND_AUG" \
325 | --random-erasing-p 0.0 \
326 | --min-scale-crops 0.25 \
327 | --max-scale-crops 1.0 \
328 | --ipc $((IPC)) \
329 | --val-ipc 50 \
330 | --stud-name "conv4" \
331 | --re-epochs 300 \
332 | --gpu-device $((GPU)) \
333 | --seed $((SEED))
334 |
335 | # =========================================================================================
336 |
337 | SEED=2897
338 | VAL_NAME="IPC$((IPC)) 3K Medium Conv4 S$((SEED))"
339 |
340 | wandb enabled
341 | wandb online
342 | python /path/to/DELT/evaluation/main.py \
343 | --wandb-project "$PROJECT_NAME" \
344 | --wandb-api-key "$WANDB_API_KEY" \
345 | --val-dir "/path/to/val" \
346 | --syn-data-path "$SYN_PATH$EXP_NAME" \
347 | --exp-name "$VAL_NAME" \
348 | --subset "imagenet-1k" \
349 | --arch-name "conv4" \
350 | --use-rand-augment \
351 | --rand-augment-config "$RAND_AUG" \
352 | --random-erasing-p 0.0 \
353 | --min-scale-crops 0.25 \
354 | --max-scale-crops 1.0 \
355 | --ipc $((IPC)) \
356 | --val-ipc 50 \
357 | --stud-name "conv4" \
358 | --re-epochs 300 \
359 | --gpu-device $((GPU)) \
360 | --seed $((SEED))
361 |
362 | # =========================================================================================
363 |
364 | IPC=1
365 | EXP_NAME="IPC$((IPC))_2K_medium"
366 |
367 | echo "$EXP_NAME"
368 | SEED=3407
369 |
370 |
371 | # VAL_NAME="$RAND_AUG"
372 | VAL_NAME="IPC$((IPC)) 2K Medium Conv4 S$((SEED))"
373 |
374 | wandb enabled
375 | wandb online
376 | python /path/to/DELT/evaluation/main.py \
377 | --wandb-project "$PROJECT_NAME" \
378 | --wandb-api-key "$WANDB_API_KEY" \
379 | --val-dir "/path/to/val" \
380 | --syn-data-path "$SYN_PATH$EXP_NAME" \
381 | --exp-name "$VAL_NAME" \
382 | --subset "imagenet-1k" \
383 | --arch-name "conv4" \
384 | --use-rand-augment \
385 | --rand-augment-config "$RAND_AUG" \
386 | --random-erasing-p 0.0 \
387 | --min-scale-crops 0.25 \
388 | --max-scale-crops 1.0 \
389 | --ipc $((IPC)) \
390 | --val-ipc 50 \
391 | --stud-name "conv4" \
392 | --re-epochs 300 \
393 | --gpu-device $((GPU)) \
394 | --seed $((SEED))
395 |
396 | # =========================================================================================
397 |
398 |
399 | SEED=4663
400 | VAL_NAME="IPC$((IPC)) 2K Medium Conv4 S$((SEED))"
401 |
402 | wandb enabled
403 | wandb online
404 | python /path/to/DELT/evaluation/main.py \
405 | --wandb-project "$PROJECT_NAME" \
406 | --wandb-api-key "$WANDB_API_KEY" \
407 | --val-dir "/path/to/val" \
408 | --syn-data-path "$SYN_PATH$EXP_NAME" \
409 | --exp-name "$VAL_NAME" \
410 | --subset "imagenet-1k" \
411 | --arch-name "conv4" \
412 | --use-rand-augment \
413 | --rand-augment-config "$RAND_AUG" \
414 | --random-erasing-p 0.0 \
415 | --min-scale-crops 0.25 \
416 | --max-scale-crops 1.0 \
417 | --ipc $((IPC)) \
418 | --val-ipc 50 \
419 | --stud-name "conv4" \
420 | --re-epochs 300 \
421 | --gpu-device $((GPU)) \
422 | --seed $((SEED))
423 |
424 | # =========================================================================================
425 |
426 | SEED=2897
427 | VAL_NAME="IPC$((IPC)) 2K Medium Conv4 S$((SEED))"
428 |
429 | wandb enabled
430 | wandb online
431 | python /path/to/DELT/evaluation/main.py \
432 | --wandb-project "$PROJECT_NAME" \
433 | --wandb-api-key "$WANDB_API_KEY" \
434 | --val-dir "/path/to/val" \
435 | --syn-data-path "$SYN_PATH$EXP_NAME" \
436 | --exp-name "$VAL_NAME" \
437 | --subset "imagenet-1k" \
438 | --arch-name "conv4" \
439 | --use-rand-augment \
440 | --rand-augment-config "$RAND_AUG" \
441 | --random-erasing-p 0.0 \
442 | --min-scale-crops 0.25 \
443 | --max-scale-crops 1.0 \
444 | --ipc $((IPC)) \
445 | --val-ipc 50 \
446 | --stud-name "conv4" \
447 | --re-epochs 300 \
448 | --gpu-device $((GPU)) \
449 | --seed $((SEED))
450 |
451 | # =========================================================================================
452 |
453 | IPC=1
454 | EXP_NAME="IPC$((IPC))_1K_medium"
455 |
456 | echo "$EXP_NAME"
457 | SEED=3407
458 |
459 |
460 | # VAL_NAME="$RAND_AUG"
461 | VAL_NAME="IPC$((IPC)) 1K Medium Conv4 S$((SEED))"
462 |
463 | wandb enabled
464 | wandb online
465 | python /path/to/DELT/evaluation/main.py \
466 | --wandb-project "$PROJECT_NAME" \
467 | --wandb-api-key "$WANDB_API_KEY" \
468 | --val-dir "/path/to/val" \
469 | --syn-data-path "$SYN_PATH$EXP_NAME" \
470 | --exp-name "$VAL_NAME" \
471 | --subset "imagenet-1k" \
472 | --arch-name "conv4" \
473 | --use-rand-augment \
474 | --rand-augment-config "$RAND_AUG" \
475 | --random-erasing-p 0.0 \
476 | --min-scale-crops 0.25 \
477 | --max-scale-crops 1.0 \
478 | --ipc $((IPC)) \
479 | --val-ipc 50 \
480 | --stud-name "conv4" \
481 | --re-epochs 300 \
482 | --gpu-device $((GPU)) \
483 | --seed $((SEED))
484 |
485 | # =========================================================================================
486 |
487 |
488 | SEED=4663
489 | VAL_NAME="IPC$((IPC)) 1K Medium Conv4 S$((SEED))"
490 |
491 | wandb enabled
492 | wandb online
493 | python /path/to/DELT/evaluation/main.py \
494 | --wandb-project "$PROJECT_NAME" \
495 | --wandb-api-key "$WANDB_API_KEY" \
496 | --val-dir "/path/to/val" \
497 | --syn-data-path "$SYN_PATH$EXP_NAME" \
498 | --exp-name "$VAL_NAME" \
499 | --subset "imagenet-1k" \
500 | --arch-name "conv4" \
501 | --use-rand-augment \
502 | --rand-augment-config "$RAND_AUG" \
503 | --random-erasing-p 0.0 \
504 | --min-scale-crops 0.25 \
505 | --max-scale-crops 1.0 \
506 | --ipc $((IPC)) \
507 | --val-ipc 50 \
508 | --stud-name "conv4" \
509 | --re-epochs 300 \
510 | --gpu-device $((GPU)) \
511 | --seed $((SEED))
512 |
513 | # =========================================================================================
514 |
515 | SEED=2897
516 | VAL_NAME="IPC$((IPC)) 1K Medium Conv4 S$((SEED))"
517 |
518 | wandb enabled
519 | wandb online
520 | python /path/to/DELT/evaluation/main.py \
521 | --wandb-project "$PROJECT_NAME" \
522 | --wandb-api-key "$WANDB_API_KEY" \
523 | --val-dir "/path/to/val" \
524 | --syn-data-path "$SYN_PATH$EXP_NAME" \
525 | --exp-name "$VAL_NAME" \
526 | --subset "imagenet-1k" \
527 | --arch-name "conv4" \
528 | --use-rand-augment \
529 | --rand-augment-config "$RAND_AUG" \
530 | --random-erasing-p 0.0 \
531 | --min-scale-crops 0.25 \
532 | --max-scale-crops 1.0 \
533 | --ipc $((IPC)) \
534 | --val-ipc 50 \
535 | --stud-name "conv4" \
536 | --re-epochs 300 \
537 | --gpu-device $((GPU)) \
538 | --seed $((SEED))
539 |
540 | # =========================================================================================
541 |
--------------------------------------------------------------------------------
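Each validation block in the Conv4 script above differs only in `EXP_NAME`, `SEED`, and the run label, so the sweep can be expressed as a nested loop. A minimal sketch, not part of the released scripts, assuming the same placeholder paths and flags used above (`PROJECT_NAME`, `WANDB_API_KEY`, `SYN_PATH`, `RAND_AUG`, and `GPU` set as in the script):

```bash
# Sketch: run the Conv4 validation sweep as nested loops over experiments and seeds.
# All variables are the placeholders defined earlier in the script above.
for EXP_NAME in IPC1_2K_medium IPC1_1K_medium; do
  for SEED in 3407 4663 2897; do
    python /path/to/DELT/evaluation/main.py \
      --wandb-project "$PROJECT_NAME" --wandb-api-key "$WANDB_API_KEY" \
      --val-dir "/path/to/val" --syn-data-path "$SYN_PATH$EXP_NAME" \
      --exp-name "$EXP_NAME Conv4 S$SEED" --subset "imagenet-1k" \
      --arch-name "conv4" --stud-name "conv4" \
      --use-rand-augment --rand-augment-config "$RAND_AUG" \
      --random-erasing-p 0.0 --min-scale-crops 0.25 --max-scale-crops 1.0 \
      --ipc 1 --val-ipc 50 --re-epochs 300 \
      --gpu-device "$GPU" --seed "$SEED"
  done
done
```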
/scripts/resnet18_imagenet1k_synthesis.sh:
--------------------------------------------------------------------------------
1 | GPU=0
2 |
3 | # =========================================================================================
4 |
5 | SYN_PATH="/path/to/synthesized/imagenet_1k/"
6 | INIT_PATH="/path/to/initialization/medium_prob"
7 | IPC=50
8 | ITR=4000
9 | ROUND_ITR=500
10 | EXP_NAME="IPC$((IPC))_4K_500_medium"
11 |
12 | echo "$EXP_NAME"
13 |
14 | python /path/to/DELT/recover/recover.py \
15 | --init-data-path "$INIT_PATH" \
16 | --syn-data-path "$SYN_PATH" \
17 | --arch-name "resnet18" \
18 | --dataset "imagenet-1k" \
19 | --exp-name "$EXP_NAME" \
20 | --use-early-late \
21 | --round-iterations $((ROUND_ITR)) \
22 | --batch-size 100 \
23 | --lr 0.25 \
24 | --r-bn 0.01 \
25 | --gpu-device $((GPU)) \
26 | --iteration $((ITR)) \
27 | --jitter 0 \
28 | --easy2hard-mode "cosine" --milestone 1 \
29 | --ipc $((IPC)) --store-best-images
30 |
31 | echo "Synthesis -> DONE"
32 |
33 | # =========================================================================================
34 |
35 | IPC=10
36 | ITR=4000
37 | ROUND_ITR=500
38 | EXP_NAME="IPC$((IPC))_4K_500_medium"
39 |
40 | echo "$EXP_NAME"
41 |
42 | python /path/to/DELT/recover/recover.py \
43 | --init-data-path "$INIT_PATH" \
44 | --syn-data-path "$SYN_PATH" \
45 | --arch-name "resnet18" \
46 | --dataset "imagenet-1k" \
47 | --exp-name "$EXP_NAME" \
48 | --use-early-late \
49 | --round-iterations $((ROUND_ITR)) \
50 | --batch-size 100 \
51 | --lr 0.25 \
52 | --r-bn 0.01 \
53 | --gpu-device $((GPU)) \
54 | --iteration $((ITR)) \
55 | --jitter 0 \
56 | --easy2hard-mode "cosine" --milestone 1 \
57 | --ipc $((IPC)) --store-best-images
58 |
59 | echo "Synthesis -> DONE"
60 |
61 | # =========================================================================================
62 |
63 | IPC=1
64 | ITR=4000
65 | EXP_NAME="IPC$((IPC))_4K_medium"
66 |
67 | echo "$EXP_NAME"
68 |
69 | python /path/to/DELT/recover/recover.py \
70 | --init-data-path "$INIT_PATH" \
71 | --syn-data-path "$SYN_PATH" \
72 | --arch-name "resnet18" \
73 | --dataset "imagenet-1k" \
74 | --exp-name "$EXP_NAME" \
75 | --batch-size 100 \
76 | --lr 0.25 \
77 | --r-bn 0.01 \
78 | --gpu-device $((GPU)) \
79 | --iteration $((ITR)) \
80 | --jitter 0 \
81 | --easy2hard-mode "cosine" --milestone 1 \
82 | --ipc $((IPC)) --store-best-images
83 |
84 | echo "Synthesis -> DONE"
85 |
86 | # =========================================================================================
87 |
88 | IPC=1
89 | ITR=3000
90 | EXP_NAME="IPC$((IPC))_3K_medium"
91 |
92 | echo "$EXP_NAME"
93 |
94 | python /path/to/DELT/recover/recover.py \
95 | --init-data-path "$INIT_PATH" \
96 | --syn-data-path "$SYN_PATH" \
97 | --arch-name "resnet18" \
98 | --dataset "imagenet-1k" \
99 | --exp-name "$EXP_NAME" \
100 | --batch-size 100 \
101 | --lr 0.25 \
102 | --r-bn 0.01 \
103 | --gpu-device $((GPU)) \
104 | --iteration $((ITR)) \
105 | --jitter 0 \
106 | --easy2hard-mode "cosine" --milestone 1 \
107 | --ipc $((IPC)) --store-best-images
108 |
109 | echo "Synthesis -> DONE"
110 |
111 | # =========================================================================================
112 |
113 | IPC=1
114 | ITR=2000
115 | EXP_NAME="IPC$((IPC))_2K_medium"
116 |
117 | echo "$EXP_NAME"
118 |
119 | python /path/to/DELT/recover/recover.py \
120 | --init-data-path "$INIT_PATH" \
121 | --syn-data-path "$SYN_PATH" \
122 | --arch-name "resnet18" \
123 | --dataset "imagenet-1k" \
124 | --exp-name "$EXP_NAME" \
125 | --batch-size 100 \
126 | --lr 0.25 \
127 | --r-bn 0.01 \
128 | --gpu-device $((GPU)) \
129 | --iteration $((ITR)) \
130 | --jitter 0 \
131 | --easy2hard-mode "cosine" --milestone 1 \
132 | --ipc $((IPC)) --store-best-images
133 |
134 | echo "Synthesis -> DONE"
135 |
136 | # =========================================================================================
137 |
138 | IPC=1
139 | ITR=1000
140 | EXP_NAME="IPC$((IPC))_1K_medium"
141 |
142 | echo "$EXP_NAME"
143 |
144 | python /path/to/DELT/recover/recover.py \
145 | --init-data-path "$INIT_PATH" \
146 | --syn-data-path "$SYN_PATH" \
147 | --arch-name "resnet18" \
148 | --dataset "imagenet-1k" \
149 | --exp-name "$EXP_NAME" \
150 | --batch-size 100 \
151 | --lr 0.25 \
152 | --r-bn 0.01 \
153 | --gpu-device $((GPU)) \
154 | --iteration $((ITR)) \
155 | --jitter 0 \
156 | --easy2hard-mode "cosine" --milestone 1 \
157 | --ipc $((IPC)) --store-best-images
158 |
159 | echo "Synthesis -> DONE"
160 |
161 | # =========================================================================================
162 |
--------------------------------------------------------------------------------
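The synthesis script above repeats the same `recover.py` invocation six times, varying only the IPC, the iteration budget, and whether the EarlyLate rounds are enabled (`--use-early-late --round-iterations 500` for the IPC 50 and IPC 10 runs). A config-driven sketch of the same sweep, not part of the released scripts, under the same placeholder paths:

```bash
# Sketch: drive the six ResNet-18 synthesis runs above from one config list.
# Each entry is IPC:ITR:ROUND_ITR; ROUND_ITR=0 disables the EarlyLate rounds.
for CFG in 50:4000:500 10:4000:500 1:4000:0 1:3000:0 1:2000:0 1:1000:0; do
  IFS=: read -r IPC ITR ROUND_ITR <<< "$CFG"
  SUFFIX="$((ITR / 1000))K"
  EXTRA=()
  if [ "$ROUND_ITR" -gt 0 ]; then
    SUFFIX="${SUFFIX}_${ROUND_ITR}"
    EXTRA=(--use-early-late --round-iterations "$ROUND_ITR")
  fi
  EXP_NAME="IPC${IPC}_${SUFFIX}_medium"   # e.g. IPC50_4K_500_medium, IPC1_2K_medium
  echo "$EXP_NAME"
  python /path/to/DELT/recover/recover.py \
    --init-data-path "$INIT_PATH" --syn-data-path "$SYN_PATH" \
    --arch-name "resnet18" --dataset "imagenet-1k" --exp-name "$EXP_NAME" \
    "${EXTRA[@]}" \
    --batch-size 100 --lr 0.25 --r-bn 0.01 --jitter 0 \
    --easy2hard-mode "cosine" --milestone 1 \
    --gpu-device "$GPU" --iteration "$ITR" --ipc "$IPC" --store-best-images
done
```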
/scripts/resnet18_imagenet1k_validation.sh:
--------------------------------------------------------------------------------
1 | GPU=0
2 |
3 | # =========================================================================================
4 |
5 | SYN_PATH="/path/to/synthesized/imagenet_1k/"
6 | PROJECT_NAME="Imagenet1K-Seeds"
7 | IPC=50
8 | EXP_NAME="IPC$((IPC))_4K_500_medium"
9 | WANDB_API_KEY="write your API key here"
10 |
11 | echo "$EXP_NAME"
12 | SEED=3407
13 |
14 | RAND_AUG="rand-m6-n2-mstd1.0"
15 | # VAL_NAME="$RAND_AUG"
16 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))"
17 |
18 | wandb enabled
19 | wandb online
20 | python /path/to/DELT/evaluation/main.py \
21 | --wandb-project "$PROJECT_NAME" \
22 | --wandb-api-key "$WANDB_API_KEY" \
23 | --val-dir "/path/to/val" \
24 | --syn-data-path "$SYN_PATH$EXP_NAME" \
25 | --exp-name "$VAL_NAME" \
26 | --subset "imagenet-1k" \
27 | --arch-name "resnet18" \
28 | --use-rand-augment \
29 | --rand-augment-config "$RAND_AUG" \
30 | --random-erasing-p 0.0 \
31 | --min-scale-crops 0.25 \
32 | --max-scale-crops 1.0 \
33 | --ipc $((IPC)) \
34 | --val-ipc 50 \
35 | --stud-name "resnet18" \
36 | --re-epochs 300 \
37 | --gpu-device $((GPU)) \
38 | --seed $((SEED))
39 |
40 | # =========================================================================================
41 |
42 |
43 | SEED=4663
44 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))"
45 |
46 | wandb enabled
47 | wandb online
48 | python /path/to/DELT/evaluation/main.py \
49 | --wandb-project "$PROJECT_NAME" \
50 | --wandb-api-key "$WANDB_API_KEY" \
51 | --val-dir "/path/to/val" \
52 | --syn-data-path "$SYN_PATH$EXP_NAME" \
53 | --exp-name "$VAL_NAME" \
54 | --subset "imagenet-1k" \
55 | --arch-name "resnet18" \
56 | --use-rand-augment \
57 | --rand-augment-config "$RAND_AUG" \
58 | --random-erasing-p 0.0 \
59 | --min-scale-crops 0.25 \
60 | --max-scale-crops 1.0 \
61 | --ipc $((IPC)) \
62 | --val-ipc 50 \
63 | --stud-name "resnet18" \
64 | --re-epochs 300 \
65 | --gpu-device $((GPU)) \
66 | --seed $((SEED))
67 |
68 | # =========================================================================================
69 |
70 | SEED=2897
71 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))"
72 |
73 | wandb enabled
74 | wandb online
75 | python /path/to/DELT/evaluation/main.py \
76 | --wandb-project "$PROJECT_NAME" \
77 | --wandb-api-key "$WANDB_API_KEY" \
78 | --val-dir "/path/to/val" \
79 | --syn-data-path "$SYN_PATH$EXP_NAME" \
80 | --exp-name "$VAL_NAME" \
81 | --subset "imagenet-1k" \
82 | --arch-name "resnet18" \
83 | --use-rand-augment \
84 | --rand-augment-config "$RAND_AUG" \
85 | --random-erasing-p 0.0 \
86 | --min-scale-crops 0.25 \
87 | --max-scale-crops 1.0 \
88 | --ipc $((IPC)) \
89 | --val-ipc 50 \
90 | --stud-name "resnet18" \
91 | --re-epochs 300 \
92 | --gpu-device $((GPU)) \
93 | --seed $((SEED))
94 |
95 | # =========================================================================================
96 |
97 | IPC=10
98 | EXP_NAME="IPC$((IPC))_4K_500_medium"
99 |
100 | echo "$EXP_NAME"
101 | SEED=3407
102 |
103 |
104 | # VAL_NAME="$RAND_AUG"
105 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))"
106 |
107 | wandb enabled
108 | wandb online
109 | python /path/to/DELT/evaluation/main.py \
110 | --wandb-project "$PROJECT_NAME" \
111 | --wandb-api-key "$WANDB_API_KEY" \
112 | --val-dir "/path/to/val" \
113 | --syn-data-path "$SYN_PATH$EXP_NAME" \
114 | --exp-name "$VAL_NAME" \
115 | --subset "imagenet-1k" \
116 | --arch-name "resnet18" \
117 | --use-rand-augment \
118 | --rand-augment-config "$RAND_AUG" \
119 | --random-erasing-p 0.0 \
120 | --min-scale-crops 0.25 \
121 | --max-scale-crops 1.0 \
122 | --ipc $((IPC)) \
123 | --val-ipc 50 \
124 | --stud-name "resnet18" \
125 | --re-epochs 300 \
126 | --gpu-device $((GPU)) \
127 | --seed $((SEED))
128 |
129 | # =========================================================================================
130 |
131 |
132 | SEED=4663
133 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))"
134 |
135 | wandb enabled
136 | wandb online
137 | python /path/to/DELT/evaluation/main.py \
138 | --wandb-project "$PROJECT_NAME" \
139 | --wandb-api-key "$WANDB_API_KEY" \
140 | --val-dir "/path/to/val" \
141 | --syn-data-path "$SYN_PATH$EXP_NAME" \
142 | --exp-name "$VAL_NAME" \
143 | --subset "imagenet-1k" \
144 | --arch-name "resnet18" \
145 | --use-rand-augment \
146 | --rand-augment-config "$RAND_AUG" \
147 | --random-erasing-p 0.0 \
148 | --min-scale-crops 0.25 \
149 | --max-scale-crops 1.0 \
150 | --ipc $((IPC)) \
151 | --val-ipc 50 \
152 | --stud-name "resnet18" \
153 | --re-epochs 300 \
154 | --gpu-device $((GPU)) \
155 | --seed $((SEED))
156 |
157 | # =========================================================================================
158 |
159 | SEED=2897
160 | VAL_NAME="IPC$((IPC)) 4K_500 Medium R18 S$((SEED))"
161 |
162 | wandb enabled
163 | wandb online
164 | python /path/to/DELT/evaluation/main.py \
165 | --wandb-project "$PROJECT_NAME" \
166 | --wandb-api-key "$WANDB_API_KEY" \
167 | --val-dir "/path/to/val" \
168 | --syn-data-path "$SYN_PATH$EXP_NAME" \
169 | --exp-name "$VAL_NAME" \
170 | --subset "imagenet-1k" \
171 | --arch-name "resnet18" \
172 | --use-rand-augment \
173 | --rand-augment-config "$RAND_AUG" \
174 | --random-erasing-p 0.0 \
175 | --min-scale-crops 0.25 \
176 | --max-scale-crops 1.0 \
177 | --ipc $((IPC)) \
178 | --val-ipc 50 \
179 | --stud-name "resnet18" \
180 | --re-epochs 300 \
181 | --gpu-device $((GPU)) \
182 | --seed $((SEED))
183 |
184 | # =========================================================================================
185 |
186 | IPC=1
187 | EXP_NAME="IPC$((IPC))_4K_medium"
188 |
189 | echo "$EXP_NAME"
190 | SEED=3407
191 |
192 |
193 | # VAL_NAME="$RAND_AUG"
194 | VAL_NAME="IPC$((IPC)) 4K Medium R18 S$((SEED))"
195 |
196 | wandb enabled
197 | wandb online
198 | python /path/to/DELT/evaluation/main.py \
199 | --wandb-project "$PROJECT_NAME" \
200 | --wandb-api-key "$WANDB_API_KEY" \
201 | --val-dir "/path/to/val" \
202 | --syn-data-path "$SYN_PATH$EXP_NAME" \
203 | --exp-name "$VAL_NAME" \
204 | --subset "imagenet-1k" \
205 | --arch-name "resnet18" \
206 | --use-rand-augment \
207 | --rand-augment-config "$RAND_AUG" \
208 | --random-erasing-p 0.0 \
209 | --min-scale-crops 0.25 \
210 | --max-scale-crops 1.0 \
211 | --ipc $((IPC)) \
212 | --val-ipc 50 \
213 | --stud-name "resnet18" \
214 | --re-epochs 300 \
215 | --gpu-device $((GPU)) \
216 | --seed $((SEED))
217 |
218 | # =========================================================================================
219 |
220 |
221 | SEED=4663
222 | VAL_NAME="IPC$((IPC)) 4K Medium R18 S$((SEED))"
223 |
224 | wandb enabled
225 | wandb online
226 | python /path/to/DELT/evaluation/main.py \
227 | --wandb-project "$PROJECT_NAME" \
228 | --wandb-api-key "$WANDB_API_KEY" \
229 | --val-dir "/path/to/val" \
230 | --syn-data-path "$SYN_PATH$EXP_NAME" \
231 | --exp-name "$VAL_NAME" \
232 | --subset "imagenet-1k" \
233 | --arch-name "resnet18" \
234 | --use-rand-augment \
235 | --rand-augment-config "$RAND_AUG" \
236 | --random-erasing-p 0.0 \
237 | --min-scale-crops 0.25 \
238 | --max-scale-crops 1.0 \
239 | --ipc $((IPC)) \
240 | --val-ipc 50 \
241 | --stud-name "resnet18" \
242 | --re-epochs 300 \
243 | --gpu-device $((GPU)) \
244 | --seed $((SEED))
245 |
246 | # =========================================================================================
247 |
248 | SEED=2897
249 | VAL_NAME="IPC$((IPC)) 4K Medium R18 S$((SEED))"
250 |
251 | wandb enabled
252 | wandb online
253 | python /path/to/DELT/evaluation/main.py \
254 | --wandb-project "$PROJECT_NAME" \
255 | --wandb-api-key "$WANDB_API_KEY" \
256 | --val-dir "/path/to/val" \
257 | --syn-data-path "$SYN_PATH$EXP_NAME" \
258 | --exp-name "$VAL_NAME" \
259 | --subset "imagenet-1k" \
260 | --arch-name "resnet18" \
261 | --use-rand-augment \
262 | --rand-augment-config "$RAND_AUG" \
263 | --random-erasing-p 0.0 \
264 | --min-scale-crops 0.25 \
265 | --max-scale-crops 1.0 \
266 | --ipc $((IPC)) \
267 | --val-ipc 50 \
268 | --stud-name "resnet18" \
269 | --re-epochs 300 \
270 | --gpu-device $((GPU)) \
271 | --seed $((SEED))
272 |
273 | # =========================================================================================
274 |
275 | IPC=1
276 | EXP_NAME="IPC$((IPC))_3K_medium"
277 |
278 | echo "$EXP_NAME"
279 | SEED=3407
280 |
281 |
282 | # VAL_NAME="$RAND_AUG"
283 | VAL_NAME="IPC$((IPC)) 3K Medium R18 S$((SEED))"
284 |
285 | wandb enabled
286 | wandb online
287 | python /path/to/DELT/evaluation/main.py \
288 | --wandb-project "$PROJECT_NAME" \
289 | --wandb-api-key "$WANDB_API_KEY" \
290 | --val-dir "/path/to/val" \
291 | --syn-data-path "$SYN_PATH$EXP_NAME" \
292 | --exp-name "$VAL_NAME" \
293 | --subset "imagenet-1k" \
294 | --arch-name "resnet18" \
295 | --use-rand-augment \
296 | --rand-augment-config "$RAND_AUG" \
297 | --random-erasing-p 0.0 \
298 | --min-scale-crops 0.25 \
299 | --max-scale-crops 1.0 \
300 | --ipc $((IPC)) \
301 | --val-ipc 50 \
302 | --stud-name "resnet18" \
303 | --re-epochs 300 \
304 | --gpu-device $((GPU)) \
305 | --seed $((SEED))
306 |
307 | # =========================================================================================
308 |
309 |
310 | SEED=4663
311 | VAL_NAME="IPC$((IPC)) 3K Medium R18 S$((SEED))"
312 |
313 | wandb enabled
314 | wandb online
315 | python /path/to/DELT/evaluation/main.py \
316 | --wandb-project "$PROJECT_NAME" \
317 | --wandb-api-key "$WANDB_API_KEY" \
318 | --val-dir "/path/to/val" \
319 | --syn-data-path "$SYN_PATH$EXP_NAME" \
320 | --exp-name "$VAL_NAME" \
321 | --subset "imagenet-1k" \
322 | --arch-name "resnet18" \
323 | --use-rand-augment \
324 | --rand-augment-config "$RAND_AUG" \
325 | --random-erasing-p 0.0 \
326 | --min-scale-crops 0.25 \
327 | --max-scale-crops 1.0 \
328 | --ipc $((IPC)) \
329 | --val-ipc 50 \
330 | --stud-name "resnet18" \
331 | --re-epochs 300 \
332 | --gpu-device $((GPU)) \
333 | --seed $((SEED))
334 |
335 | # =========================================================================================
336 |
337 | SEED=2897
338 | VAL_NAME="IPC$((IPC)) 3K Medium R18 S$((SEED))"
339 |
340 | wandb enabled
341 | wandb online
342 | python /path/to/DELT/evaluation/main.py \
343 | --wandb-project "$PROJECT_NAME" \
344 | --wandb-api-key "$WANDB_API_KEY" \
345 | --val-dir "/path/to/val" \
346 | --syn-data-path "$SYN_PATH$EXP_NAME" \
347 | --exp-name "$VAL_NAME" \
348 | --subset "imagenet-1k" \
349 | --arch-name "resnet18" \
350 | --use-rand-augment \
351 | --rand-augment-config "$RAND_AUG" \
352 | --random-erasing-p 0.0 \
353 | --min-scale-crops 0.25 \
354 | --max-scale-crops 1.0 \
355 | --ipc $((IPC)) \
356 | --val-ipc 50 \
357 | --stud-name "resnet18" \
358 | --re-epochs 300 \
359 | --gpu-device $((GPU)) \
360 | --seed $((SEED))
361 |
362 | # =========================================================================================
363 |
364 | IPC=1
365 | EXP_NAME="IPC$((IPC))_2K_medium"
366 |
367 | echo "$EXP_NAME"
368 | SEED=3407
369 |
370 |
371 | # VAL_NAME="$RAND_AUG"
372 | VAL_NAME="IPC$((IPC)) 2K Medium R18 S$((SEED))"
373 |
374 | wandb enabled
375 | wandb online
376 | python /path/to/DELT/evaluation/main.py \
377 | --wandb-project "$PROJECT_NAME" \
378 | --wandb-api-key "$WANDB_API_KEY" \
379 | --val-dir "/path/to/val" \
380 | --syn-data-path "$SYN_PATH$EXP_NAME" \
381 | --exp-name "$VAL_NAME" \
382 | --subset "imagenet-1k" \
383 | --arch-name "resnet18" \
384 | --use-rand-augment \
385 | --rand-augment-config "$RAND_AUG" \
386 | --random-erasing-p 0.0 \
387 | --min-scale-crops 0.25 \
388 | --max-scale-crops 1.0 \
389 | --ipc $((IPC)) \
390 | --val-ipc 50 \
391 | --stud-name "resnet18" \
392 | --re-epochs 300 \
393 | --gpu-device $((GPU)) \
394 | --seed $((SEED))
395 |
396 | # =========================================================================================
397 |
398 |
399 | SEED=4663
400 | VAL_NAME="IPC$((IPC)) 2K Medium R18 S$((SEED))"
401 |
402 | wandb enabled
403 | wandb online
404 | python /path/to/DELT/evaluation/main.py \
405 | --wandb-project "$PROJECT_NAME" \
406 | --wandb-api-key "$WANDB_API_KEY" \
407 | --val-dir "/path/to/val" \
408 | --syn-data-path "$SYN_PATH$EXP_NAME" \
409 | --exp-name "$VAL_NAME" \
410 | --subset "imagenet-1k" \
411 | --arch-name "resnet18" \
412 | --use-rand-augment \
413 | --rand-augment-config "$RAND_AUG" \
414 | --random-erasing-p 0.0 \
415 | --min-scale-crops 0.25 \
416 | --max-scale-crops 1.0 \
417 | --ipc $((IPC)) \
418 | --val-ipc 50 \
419 | --stud-name "resnet18" \
420 | --re-epochs 300 \
421 | --gpu-device $((GPU)) \
422 | --seed $((SEED))
423 |
424 | # =========================================================================================
425 |
426 | SEED=2897
427 | VAL_NAME="IPC$((IPC)) 2K Medium R18 S$((SEED))"
428 |
429 | wandb enabled
430 | wandb online
431 | python /path/to/DELT/evaluation/main.py \
432 | --wandb-project "$PROJECT_NAME" \
433 | --wandb-api-key "$WANDB_API_KEY" \
434 | --val-dir "/path/to/val" \
435 | --syn-data-path "$SYN_PATH$EXP_NAME" \
436 | --exp-name "$VAL_NAME" \
437 | --subset "imagenet-1k" \
438 | --arch-name "resnet18" \
439 | --use-rand-augment \
440 | --rand-augment-config "$RAND_AUG" \
441 | --random-erasing-p 0.0 \
442 | --min-scale-crops 0.25 \
443 | --max-scale-crops 1.0 \
444 | --ipc $((IPC)) \
445 | --val-ipc 50 \
446 | --stud-name "resnet18" \
447 | --re-epochs 300 \
448 | --gpu-device $((GPU)) \
449 | --seed $((SEED))
450 |
451 | # =========================================================================================
452 |
453 | IPC=1
454 | EXP_NAME="IPC$((IPC))_1K_medium"
455 |
456 | echo "$EXP_NAME"
457 | SEED=3407
458 |
459 |
460 | # VAL_NAME="$RAND_AUG"
461 | VAL_NAME="IPC$((IPC)) 1K Medium R18 S$((SEED))"
462 |
463 | wandb enabled
464 | wandb online
465 | python /path/to/DELT/evaluation/main.py \
466 | --wandb-project "$PROJECT_NAME" \
467 | --wandb-api-key "$WANDB_API_KEY" \
468 | --val-dir "/path/to/val" \
469 | --syn-data-path "$SYN_PATH$EXP_NAME" \
470 | --exp-name "$VAL_NAME" \
471 | --subset "imagenet-1k" \
472 | --arch-name "resnet18" \
473 | --use-rand-augment \
474 | --rand-augment-config "$RAND_AUG" \
475 | --random-erasing-p 0.0 \
476 | --min-scale-crops 0.25 \
477 | --max-scale-crops 1.0 \
478 | --ipc $((IPC)) \
479 | --val-ipc 50 \
480 | --stud-name "resnet18" \
481 | --re-epochs 300 \
482 | --gpu-device $((GPU)) \
483 | --seed $((SEED))
484 |
485 | # =========================================================================================
486 |
487 |
488 | SEED=4663
489 | VAL_NAME="IPC$((IPC)) 1K Medium R18 S$((SEED))"
490 |
491 | wandb enabled
492 | wandb online
493 | python /path/to/DELT/evaluation/main.py \
494 | --wandb-project "$PROJECT_NAME" \
495 | --wandb-api-key "$WANDB_API_KEY" \
496 | --val-dir "/path/to/val" \
497 | --syn-data-path "$SYN_PATH$EXP_NAME" \
498 | --exp-name "$VAL_NAME" \
499 | --subset "imagenet-1k" \
500 | --arch-name "resnet18" \
501 | --use-rand-augment \
502 | --rand-augment-config "$RAND_AUG" \
503 | --random-erasing-p 0.0 \
504 | --min-scale-crops 0.25 \
505 | --max-scale-crops 1.0 \
506 | --ipc $((IPC)) \
507 | --val-ipc 50 \
508 | --stud-name "resnet18" \
509 | --re-epochs 300 \
510 | --gpu-device $((GPU)) \
511 | --seed $((SEED))
512 |
513 | # =========================================================================================
514 |
515 | SEED=2897
516 | VAL_NAME="IPC$((IPC)) 1K Medium R18 S$((SEED))"
517 |
518 | wandb enabled
519 | wandb online
520 | python /path/to/DELT/evaluation/main.py \
521 | --wandb-project "$PROJECT_NAME" \
522 | --wandb-api-key "$WANDB_API_KEY" \
523 | --val-dir "/path/to/val" \
524 | --syn-data-path "$SYN_PATH$EXP_NAME" \
525 | --exp-name "$VAL_NAME" \
526 | --subset "imagenet-1k" \
527 | --arch-name "resnet18" \
528 | --use-rand-augment \
529 | --rand-augment-config "$RAND_AUG" \
530 | --random-erasing-p 0.0 \
531 | --min-scale-crops 0.25 \
532 | --max-scale-crops 1.0 \
533 | --ipc $((IPC)) \
534 | --val-ipc 50 \
535 | --stud-name "resnet18" \
536 | --re-epochs 300 \
537 | --gpu-device $((GPU)) \
538 | --seed $((SEED))
539 |
540 | # =========================================================================================
541 |
--------------------------------------------------------------------------------
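Every configuration above is validated under the same three fixed seeds (3407, 4663, 2897), so reported numbers are naturally a mean over seeds. A minimal aggregation sketch; the results file name and its `<seed> <top1>` line format are hypothetical, so adapt them to however you collect the per-run accuracies:

```bash
# Hypothetical aggregation: mean and population std of per-seed top-1 accuracy.
# Assumes results.txt holds one line per run, e.g. "3407 62.1".
awk '{ s += $2; ss += $2 * $2; n++ }
     END { m = s / n; printf "mean %.2f  std %.2f  (n=%d)\n", m, sqrt(ss / n - m * m), n }' results.txt
```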
/scripts/select_medium_tiny_rn18_ep50_ipc50.sh:
--------------------------------------------------------------------------------
1 | TRAIN_DIR="/path/to/tiny-imagenet/train"
2 | OUTPUT_DIR="/path/to/ranked/tiny_imagenet_medium"
3 | RANKER_PATH="/path/to/model/tinyimagenet_resnet18_modified.pth"
4 | RANKING_FILE="/path/to/rankings_csv/tiny_imagenet.csv"
5 | # Download the model
6 | curl -L "https://drive.usercontent.google.com/download?id=1h_Enp0_FlgxCED-oriPuyYbmonYwxIi9&confirm=xxx" -o "$RANKER_PATH"
7 |
8 | python /path/to/data_selection.py \
9 | --dataset "tiny-imagenet" \
10 | --data-path "$TRAIN_DIR" \
11 | --output-path "$OUTPUT_DIR" \
12 | --ranker-path "$RANKER_PATH" \
13 | --store-rank-file \
14 | --ranker-arch "resnet18" \
15 | --ranking-file "$RANKING_FILE" \
16 | --selection-criteria "medium" \
17 | --ipc 50 \
18 | --gpu-device 0
19 |
--------------------------------------------------------------------------------
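One caveat on the download step in the script above: Google Drive can return an HTML confirmation page rather than the checkpoint itself, in which case curl writes the wrong bytes to `$RANKER_PATH`. If that happens, the third-party `gdown` utility, which handles Drive's confirmation tokens, is a common fallback (same file id as in the script):

```bash
# Fallback for the curl line above, using gdown.
pip install gdown
gdown 1h_Enp0_FlgxCED-oriPuyYbmonYwxIi9 -O "$RANKER_PATH"
```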