├── data ├── __init__.py ├── save_img.py ├── dataset_statistics.py ├── augment.py ├── transform.py └── dataset.py ├── utils ├── __init__.py ├── mix_cut_up.py ├── init_script.py ├── ddp.py ├── experiment_tracker.py └── train_val.py ├── .DS_Store ├── asset ├── paper.pdf └── figure1.png ├── docs ├── images │ ├── real.png │ ├── str.png │ ├── minmax.png │ ├── weak2s.png │ ├── weak3s.png │ ├── minmax_1.png │ ├── pipeline.png │ └── synthetic.png └── static │ ├── .DS_Store │ ├── js │ ├── index.js │ └── bulma-slider.min.js │ └── css │ ├── bulma-carousel.min.css │ ├── index.css │ └── bulma-slider.min.css ├── .gitignore ├── requirements.txt ├── imagenet_subset ├── classimagemeow.txt ├── classimagewoof.txt ├── classimagefruit.txt ├── classimagenette.txt ├── classimagesquawk.txt ├── classimageyellow.txt └── class100.txt ├── condenser ├── subsample.py ├── condense_transfom.py ├── compute_loss.py ├── decode.py └── evaluate.py ├── config ├── ipc1 │ ├── cifar10.yaml │ ├── cifar100.yaml │ ├── imagewoof.yaml │ ├── imageyellow.yaml │ ├── imagemeow.yaml │ ├── imagenette.yaml │ ├── tinyimagenet.yaml │ ├── imagesquawk.yaml │ └── imagefruit.yaml ├── ipc10 │ ├── cifar10.yaml │ ├── cifar100.yaml │ ├── imagewoof.yaml │ ├── imageyellow.yaml │ ├── imagemeow.yaml │ ├── tinyimagenet.yaml │ ├── imagenette.yaml │ ├── imagesquawk.yaml │ └── imagefruit.yaml └── ipc50 │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── tinyimagenet.yaml ├── CL └── CL_script.py ├── NCFM ├── SampleNet.py └── NCFM.py ├── argsprocessor └── args.py ├── README.md ├── condense └── condense_script.py ├── evaluation └── evaluation_script.py ├── models ├── densenet_cifar.py ├── convnet.py ├── resnet.py └── resnet_ap.py └── pretrain ├── pretrain_script.py └── pretrained_script_for_softlabel.py /data/ __init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/.DS_Store -------------------------------------------------------------------------------- /asset/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/asset/paper.pdf -------------------------------------------------------------------------------- /asset/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/asset/figure1.png -------------------------------------------------------------------------------- /docs/images/real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/real.png -------------------------------------------------------------------------------- /docs/images/str.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/str.png -------------------------------------------------------------------------------- /docs/images/minmax.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/minmax.png -------------------------------------------------------------------------------- /docs/images/weak2s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/weak2s.png -------------------------------------------------------------------------------- /docs/images/weak3s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/weak3s.png -------------------------------------------------------------------------------- /docs/static/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/static/.DS_Store -------------------------------------------------------------------------------- /docs/images/minmax_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/minmax_1.png -------------------------------------------------------------------------------- /docs/images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/pipeline.png -------------------------------------------------------------------------------- /docs/images/synthetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gszfwsb/NCFM/HEAD/docs/images/synthetic.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | pretrained_models/ 2 | results/ 3 | dataset/ 4 | *.pyc 5 | pretrained_models 6 | dataset -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | efficientnet_pytorch==0.7.1 2 | matplotlib==3.8.4 3 | numpy==2.2.1 4 | PyYAML==6.0.2 5 | torch==2.5.0 6 | torchvision==0.20.0 7 | tqdm==4.66.4 8 | -------------------------------------------------------------------------------- /imagenet_subset/classimagemeow.txt: -------------------------------------------------------------------------------- 1 | n02123045 2 | n02123159 3 | n02123394 4 | n02123597 5 | n02124075 6 | n02129165 7 | n02129604 8 | n02128925 9 | n02128757 10 | n02127052 -------------------------------------------------------------------------------- /imagenet_subset/classimagewoof.txt: -------------------------------------------------------------------------------- 1 | n02096294 2 | n02093754 3 | n02111889 4 | n02088364 5 | n02086240 6 | n02089973 7 | n02087394 8 | n02115641 9 | n02099601 10 | n02105641 -------------------------------------------------------------------------------- /imagenet_subset/classimagefruit.txt: -------------------------------------------------------------------------------- 1 | n07753275 2 | n07753592 3 | n07745940 4 | n07747607 5 | n07749582 6 | n07768694 7 | n07753113 8 | n07720875 9 | n07718472 10 | n07760859 -------------------------------------------------------------------------------- /imagenet_subset/classimagenette.txt: -------------------------------------------------------------------------------- 1 | n01440764 2 | n02102040 3 | n02979186 4 | n03000684 5 | n03028079 6 | 
n03394916 7 | n03417042 8 | n03425413 9 | n03445777 10 | n03888257 -------------------------------------------------------------------------------- /imagenet_subset/classimagesquawk.txt: -------------------------------------------------------------------------------- 1 | n01806143 2 | n02007558 3 | n01818515 4 | n02051845 5 | n02056570 6 | n01614925 7 | n01843383 8 | n01518878 9 | n01860187 10 | n01819313 -------------------------------------------------------------------------------- /imagenet_subset/classimageyellow.txt: -------------------------------------------------------------------------------- 1 | n02206856 2 | n12057211 3 | n07753592 4 | n07749582 5 | n12144580 6 | n04146614 7 | n03530642 8 | n02129165 9 | n01773797 10 | n01531178 -------------------------------------------------------------------------------- /condenser/subsample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def subsample(data, target, max_size=-1): 5 | if (data.shape[0] > max_size) and (max_size > 0): 6 | indices = np.random.permutation(data.shape[0]) 7 | data = data[indices[:max_size]] 8 | target = target[indices[:max_size]] 9 | 10 | return data, target 11 | -------------------------------------------------------------------------------- /data/save_img.py: -------------------------------------------------------------------------------- 1 | from .dataset_statistics import STDS,MEANS 2 | import torch 3 | from torchvision.utils import save_image 4 | import torch.nn.functional as F 5 | 6 | def img_denormlaize(img, dataname='imagenet'): 7 | """Scaling and shift a batch of images (NCHW) 8 | """ 9 | mean = MEANS[dataname] 10 | std = STDS[dataname] 11 | nch = img.shape[1] 12 | 13 | mean = torch.tensor(mean, device=img.device).reshape(1, nch, 1, 1) 14 | std = torch.tensor(std, device=img.device).reshape(1, nch, 1, 1) 15 | 16 | return img * std + mean 17 | 18 | 19 | 20 | 21 | 22 | 23 | def save_img(save_dir, img, unnormalize=True, max_num=200, size=64, nrow=10, dataname='imagenet'): 24 | img = img[:max_num].detach() 25 | if unnormalize: 26 | img = img_denormlaize(img, dataname=dataname) 27 | img = torch.clamp(img, min=0., max=1.) 
28 | 29 | if img.shape[-1] > size: 30 | img = F.interpolate(img, size) 31 | save_image(img.cpu(), save_dir, nrow=nrow) -------------------------------------------------------------------------------- /utils/mix_cut_up.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | def random_indices(y, nclass=10, intraclass=False, device="cuda"): 7 | n = len(y) 8 | if intraclass: 9 | index = torch.arange(n).to(device) 10 | for c in range(nclass): 11 | index_c = index[y == c] 12 | if len(index_c) > 0: 13 | randidx = torch.randperm(len(index_c)) 14 | index[y == c] = index_c[randidx] 15 | else: 16 | index = torch.randperm(n).to(device) 17 | return index 18 | 19 | 20 | def rand_bbox(size, lam): 21 | W = size[2] 22 | H = size[3] 23 | cut_rat = np.sqrt(1.0 - lam) 24 | cut_w = int(W * cut_rat) 25 | cut_h = int(H * cut_rat) 26 | 27 | # uniform 28 | cx = np.random.randint(W) 29 | cy = np.random.randint(H) 30 | 31 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 32 | bby1 = np.clip(cy - cut_h // 2, 0, H) 33 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 34 | bby2 = np.clip(cy + cut_h // 2, 0, H) 35 | 36 | return bbx1, bby1, bbx2, bby2 37 | -------------------------------------------------------------------------------- /data/dataset_statistics.py: -------------------------------------------------------------------------------- 1 | 2 | # Values borrowed from https://github.com/VICO-UoE/DatasetCondensation/blob/master/utils.py 3 | 4 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 5 | MEANS = {'cifar': [0.4914, 0.4822, 0.4465], 'imagenet': [0.485, 0.456, 0.406]} 6 | STDS = {'cifar': [0.2023, 0.1994, 0.2010], 'imagenet': [0.229, 0.224, 0.225]} 7 | MEANS['cifar10'] = MEANS['cifar'] 8 | STDS['cifar10'] = STDS['cifar'] 9 | MEANS['cifar100'] = MEANS['cifar'] 10 | STDS['cifar100'] = STDS['cifar'] 11 | MEANS['svhn'] = [0.4377, 0.4438, 0.4728] 12 | STDS['svhn'] = [0.1980, 0.2010, 0.1970] 13 | MEANS['mnist'] = [0.1307] 14 | STDS['mnist'] = [0.3081] 15 | MEANS['fashion'] = [0.2861] 16 | STDS['fashion'] = [0.3530] 17 | MEANS['tinyimagenet'] = [0.485, 0.456, 0.406] 18 | STDS['tinyimagenet'] = [0.229, 0.224, 0.225] 19 | 20 | 21 | # ['imagenette', 'imagewoof', 'imagemeow', 'imagesquawk', 'imagefruit', 'imageyellow'] 22 | MEANS['imagenette'] = [0.485, 0.456, 0.406] 23 | STDS['imagenette'] = [0.229, 0.224, 0.225] 24 | MEANS['imagewoof'] = [0.485, 0.456, 0.406] 25 | STDS['imagewoof'] = [0.229, 0.224, 0.225] 26 | MEANS['imagemeow'] = [0.485, 0.456, 0.406] 27 | STDS['imagemeow'] = [0.229, 0.224, 0.225] 28 | MEANS['imagesquawk'] = [0.485, 0.456, 0.406] 29 | STDS['imagesquawk'] = [0.229, 0.224, 0.225] 30 | MEANS['imagefruit'] = [0.485, 0.456, 0.406] 31 | STDS['imagefruit'] = [0.229, 0.224, 0.225] 32 | MEANS['imageyellow'] = [0.485, 0.456, 0.406] 33 | STDS['imageyellow'] = [0.229, 0.224, 0.225] 34 | -------------------------------------------------------------------------------- /config/ipc1/cifar10.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: cifar10 8 | nclass: 10 9 | size: 32 10 | data_dir: '../dataset' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 3 19 | width: 1.0 20 | 21 | train: 22 | 
evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 60 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: True 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: adamw 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 1 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 | factor: 2 66 | alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/cifar10.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: cifar10 8 | nclass: 10 9 | size: 32 10 | data_dir: '../dataset' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 3 19 | width: 1.0 20 | 21 | train: 22 | evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 60 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: True 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: adamw 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 10 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 | factor: 2 66 | alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc50/cifar10.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: cifar10 8 | nclass: 10 9 | size: 32 10 | data_dir: '../dataset' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 3 19 | width: 1.0 20 | 21 | train: 22 | evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 60 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 
38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: True 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: adamw 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 50 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 | factor: 2 66 | alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc1/cifar100.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: cifar100 8 | nclass: 100 9 | size: 32 10 | data_dir: '../dataset' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 3 19 | width: 1.0 20 | 21 | train: 22 | evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 60 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: True 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: adamw 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 1 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 | factor: 2 66 | alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc1/imagewoof.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | dataset: 6 | dataset: imagewoof 7 | nclass: 10 8 | size: 128 9 | data_dir: '/root/autodl-tmp/imagenet/' 10 | load_memory: true 11 | batch_real: 1024 12 | nch: 3 13 | network: 14 | net_type: convnet 15 | norm_type: instance 16 | depth: 5 17 | width: 1.0 18 | 19 | train: 20 | evaluation_epochs: 2000 21 | epoch_print_freq: 10 22 | epoch_eval_interval: 100 23 | pertrain_epochs: 80 24 | batch_size: 128 25 | lr: 0.01 26 | adamw_lr: 0.001 27 | eval_optimizer: adamw 28 | momentum: 0.9 29 | weight_decay: 5e-4 30 | seed: 0 31 | model_num: 20 32 | 33 | 34 | augmentation: 35 | mixup: cut 36 | beta: 1.0 37 | mix_p: 0.5 38 | rrc: true 39 | dsa: true 40 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 41 | aug_type: 'color_crop_cutout' 42 | 43 | optimization: 44 | optimizer: adamw 45 | lr_scale_adam: 0.1 46 | lr_img: 0.01 47 | mom_img: 0.5 48 | lr_sampling_net: 1e-3 49 | 50 | 
save_path: 51 | save_dir: "../results/condense" 52 | pretrain_dir: '../pretrained_models' 53 | 54 | condense: 55 | ipc: 1 56 | num_premodel: 20 57 | niter: 20000 58 | iter_calib: 1 59 | calib_weight: 1 60 | sampling_net: False 61 | num_freqs: 4096 62 | dis_metrics: "NCFM" 63 | factor: 2 64 | alpha_for_loss: 0.5 65 | beta_for_loss: 0.5 66 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 67 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/cifar100.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: cifar100 8 | nclass: 100 9 | size: 32 10 | data_dir: '../dataset' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 3 19 | width: 1.0 20 | 21 | train: 22 | evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 60 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: True 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: sgd 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 10 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 | factor: 2 66 | alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc50/cifar100.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: cifar100 8 | nclass: 100 9 | size: 32 10 | data_dir: '../dataset' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 3 19 | width: 1.0 20 | 21 | train: 22 | evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 60 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: True 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: adamw 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 50 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 | factor: 2 66 | 
alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc1/imageyellow.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | dataset: 6 | dataset: imageyellow 7 | nclass: 10 8 | size: 128 9 | data_dir: '/root/autodl-tmp/imagenet/' 10 | load_memory: true 11 | batch_real: 1024 12 | nch: 3 13 | network: 14 | net_type: convnet 15 | norm_type: instance 16 | depth: 5 17 | width: 1.0 18 | 19 | train: 20 | evaluation_epochs: 2000 21 | epoch_print_freq: 10 22 | epoch_eval_interval: 100 23 | pertrain_epochs: 80 24 | batch_size: 128 25 | lr: 0.01 26 | adamw_lr: 0.001 27 | eval_optimizer: adamw 28 | momentum: 0.9 29 | weight_decay: 5e-4 30 | seed: 0 31 | model_num: 20 32 | 33 | 34 | augmentation: 35 | mixup: cut 36 | beta: 1.0 37 | mix_p: 0.5 38 | rrc: true 39 | dsa: true 40 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 41 | aug_type: 'color_crop_cutout' 42 | 43 | optimization: 44 | optimizer: adamw 45 | lr_scale_adam: 0.1 46 | lr_img: 0.01 47 | mom_img: 0.5 48 | lr_sampling_net: 1e-3 49 | 50 | save_path: 51 | save_dir: "../results/condense" 52 | pretrain_dir: '../pretrained_models' 53 | 54 | condense: 55 | ipc: 1 56 | num_premodel: 20 57 | niter: 20000 58 | iter_calib: 1 59 | calib_weight: 1 60 | sampling_net: False 61 | num_freqs: 4096 62 | dis_metrics: "NCFM" 63 | factor: 2 64 | alpha_for_loss: 0.5 65 | beta_for_loss: 0.5 66 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 67 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/imagewoof.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | dataset: 6 | dataset: imagewoof 7 | nclass: 10 8 | size: 128 9 | data_dir: '/root/autodl-tmp/imagenet/' 10 | load_memory: true 11 | batch_real: 1024 12 | nch: 3 13 | network: 14 | net_type: convnet 15 | norm_type: instance 16 | depth: 5 17 | width: 1.0 18 | 19 | train: 20 | evaluation_epochs: 2000 21 | epoch_print_freq: 10 22 | epoch_eval_interval: 100 23 | pertrain_epochs: 80 24 | batch_size: 128 25 | lr: 0.01 26 | adamw_lr: 0.001 27 | eval_optimizer: adamw 28 | momentum: 0.9 29 | weight_decay: 5e-4 30 | seed: 0 31 | model_num: 20 32 | 33 | 34 | augmentation: 35 | mixup: cut 36 | beta: 1.0 37 | mix_p: 0.5 38 | rrc: true 39 | dsa: true 40 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 41 | aug_type: 'color_crop_cutout' 42 | 43 | optimization: 44 | optimizer: adamw 45 | lr_scale_adam: 0.1 46 | lr_img: 0.01 47 | mom_img: 0.5 48 | lr_sampling_net: 1e-3 49 | 50 | save_path: 51 | save_dir: "../results/condense" 52 | pretrain_dir: '../pretrained_models' 53 | 54 | condense: 55 | ipc: 10 56 | num_premodel: 20 57 | niter: 20000 58 | iter_calib: 1 59 | calib_weight: 1 60 | sampling_net: False 61 | num_freqs: 4096 62 | dis_metrics: "NCFM" 63 | factor: 2 64 | alpha_for_loss: 0.5 65 | beta_for_loss: 0.5 66 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 67 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/imageyellow.yaml: 
-------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | dataset: 6 | dataset: imageyellow 7 | nclass: 10 8 | size: 128 9 | data_dir: '/root/autodl-tmp/imagenet/' 10 | load_memory: true 11 | batch_real: 1024 12 | nch: 3 13 | network: 14 | net_type: convnet 15 | norm_type: instance 16 | depth: 5 17 | width: 1.0 18 | 19 | train: 20 | evaluation_epochs: 2000 21 | epoch_print_freq: 10 22 | epoch_eval_interval: 100 23 | pertrain_epochs: 80 24 | batch_size: 128 25 | lr: 0.01 26 | adamw_lr: 0.001 27 | eval_optimizer: adamw 28 | momentum: 0.9 29 | weight_decay: 5e-4 30 | seed: 0 31 | model_num: 20 32 | 33 | 34 | augmentation: 35 | mixup: cut 36 | beta: 1.0 37 | mix_p: 0.5 38 | rrc: true 39 | dsa: true 40 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 41 | aug_type: 'color_crop_cutout' 42 | 43 | optimization: 44 | optimizer: adamw 45 | lr_scale_adam: 0.1 46 | lr_img: 0.01 47 | mom_img: 0.5 48 | lr_sampling_net: 1e-3 49 | 50 | save_path: 51 | save_dir: "../results/condense" 52 | pretrain_dir: '../pretrained_models' 53 | 54 | condense: 55 | ipc: 10 56 | num_premodel: 20 57 | niter: 20000 58 | iter_calib: 1 59 | calib_weight: 1 60 | sampling_net: False 61 | num_freqs: 4096 62 | dis_metrics: "NCFM" 63 | factor: 2 64 | alpha_for_loss: 0.5 65 | beta_for_loss: 0.5 66 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 67 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc50/tinyimagenet.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | dataset: 6 | dataset: tinyimagenet 7 | nclass: 200 8 | size: 64 9 | data_dir: '../dataset' 10 | load_memory: true 11 | batch_real: 1024 12 | nch: 3 13 | 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 4 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 40 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 5 24 | pertrain_epochs: 60 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: True 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 50 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 69 | -------------------------------------------------------------------------------- /config/ipc1/imagemeow.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagemeow 8 | nclass: 10 9 | size: 128 10 | 
data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 5 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 2000 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 100 24 | pertrain_epochs: 80 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: true 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 1 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc1/imagenette.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagenette 8 | nclass: 10 9 | size: 128 10 | data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 5 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 2000 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 100 24 | pertrain_epochs: 80 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: true 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 1 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc1/tinyimagenet.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | dataset: 6 | dataset: tinyimagenet 7 | nclass: 200 8 | size: 64 9 | data_dir: '../dataset' 10 | load_memory: true 11 | batch_real: 1024 12 | nch: 3 13 | 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 4 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 2000 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 
100 24 | pertrain_epochs: 60 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: True 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 1 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 69 | 70 | -------------------------------------------------------------------------------- /config/ipc10/imagemeow.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagemeow 8 | nclass: 10 9 | size: 128 10 | data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 5 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 2000 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 100 24 | pertrain_epochs: 80 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: true 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 10 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/tinyimagenet.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | dataset: 6 | dataset: tinyimagenet 7 | nclass: 200 8 | size: 64 9 | data_dir: '../dataset' 10 | load_memory: true 11 | batch_real: 1024 12 | nch: 3 13 | 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 4 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 300 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 50 24 | pertrain_epochs: 60 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: True 
40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 10 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 69 | 70 | -------------------------------------------------------------------------------- /config/ipc1/imagesquawk.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagesquawk 8 | nclass: 10 9 | size: 128 10 | data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 5 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 2000 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 100 24 | pertrain_epochs: 80 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: true 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 1 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/imagenette.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagenette 8 | nclass: 10 9 | size: 128 10 | data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 5 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 2000 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 100 24 | pertrain_epochs: 80 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: true 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 
52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 10 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/imagesquawk.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagesquawk 8 | nclass: 10 9 | size: 128 10 | data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | network: 15 | net_type: convnet 16 | norm_type: instance 17 | depth: 5 18 | width: 1.0 19 | 20 | train: 21 | evaluation_epochs: 2000 22 | epoch_print_freq: 10 23 | epoch_eval_interval: 100 24 | pertrain_epochs: 80 25 | batch_size: 128 26 | lr: 0.01 27 | adamw_lr: 0.001 28 | eval_optimizer: adamw 29 | momentum: 0.9 30 | weight_decay: 5e-4 31 | seed: 0 32 | model_num: 20 33 | 34 | 35 | augmentation: 36 | mixup: cut 37 | beta: 1.0 38 | mix_p: 0.5 39 | rrc: true 40 | dsa: true 41 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 42 | aug_type: 'color_crop_cutout' 43 | 44 | optimization: 45 | optimizer: adamw 46 | lr_scale_adam: 0.1 47 | lr_img: 0.01 48 | mom_img: 0.5 49 | lr_sampling_net: 1e-3 50 | 51 | save_path: 52 | save_dir: "../results/condense" 53 | pretrain_dir: '../pretrained_models' 54 | 55 | condense: 56 | ipc: 10 57 | num_premodel: 20 58 | niter: 20000 59 | iter_calib: 1 60 | calib_weight: 1 61 | sampling_net: False 62 | num_freqs: 4096 63 | dis_metrics: "NCFM" 64 | factor: 2 65 | alpha_for_loss: 0.5 66 | beta_for_loss: 0.5 67 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 68 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc1/imagefruit.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagefruit 8 | nclass: 10 9 | size: 128 10 | data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 5 19 | width: 1.0 20 | 21 | train: 22 | evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 80 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: true 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: adamw 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 1 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 
| factor: 2 66 | alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /config/ipc10/imagefruit.yaml: -------------------------------------------------------------------------------- 1 | distibution_train: 2 | backend: 'nccl' # choices=['nccl', 'gloo', 'mpi', 'torch'] 3 | init_method: 'env://' 4 | workers: 8 5 | 6 | dataset: 7 | dataset: imagefruit 8 | nclass: 10 9 | size: 128 10 | data_dir: '/root/autodl-tmp/imagenet/' 11 | load_memory: true 12 | batch_real: 1024 13 | nch: 3 14 | 15 | network: 16 | net_type: convnet 17 | norm_type: instance 18 | depth: 5 19 | width: 1.0 20 | 21 | train: 22 | evaluation_epochs: 2000 23 | epoch_print_freq: 10 24 | epoch_eval_interval: 100 25 | pertrain_epochs: 80 26 | batch_size: 128 27 | lr: 0.01 28 | adamw_lr: 0.001 29 | eval_optimizer: adamw 30 | momentum: 0.9 31 | weight_decay: 5e-4 32 | seed: 0 33 | model_num: 20 34 | 35 | 36 | augmentation: 37 | mixup: cut 38 | beta: 1.0 39 | mix_p: 0.5 40 | rrc: true 41 | dsa: true 42 | dsa_strategy: "color_crop_cutout_flip_scale_rotate" 43 | aug_type: 'color_crop_cutout' 44 | 45 | optimization: 46 | optimizer: adamw 47 | lr_scale_adam: 0.1 48 | lr_img: 0.01 49 | mom_img: 0.5 50 | lr_sampling_net: 1e-3 51 | 52 | save_path: 53 | save_dir: "../results/condense" 54 | pretrain_dir: '../pretrained_models' 55 | 56 | condense: 57 | ipc: 10 58 | num_premodel: 20 59 | niter: 20000 60 | iter_calib: 1 61 | calib_weight: 1 62 | sampling_net: False 63 | num_freqs: 4096 64 | dis_metrics: "NCFM" 65 | factor: 2 66 | alpha_for_loss: 0.5 67 | beta_for_loss: 0.5 68 | decode_type: 'single' # choices=['single', 'multi', 'bound'] 69 | teacher_model_epoch: 20 -------------------------------------------------------------------------------- /imagenet_subset/class100.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 | n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62 | n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 
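Editorial note: the imagenet_subset/*.txt files above (including class100.txt) each list one WordNet ID per line, and the ImageNet-style configs point data_dir at an ImageNet root, so the loader restricts itself to exactly these classes. Below is a minimal sketch of that selection step, assuming the usual root/wnid/ image-folder layout; the helper names are hypothetical and the actual loading code lives in data/dataset.py, which may differ.

import os

def load_subset_wnids(subset_file):
    # One WordNet ID per line, e.g. "n01440764" (see imagenet_subset/classimagenette.txt).
    with open(subset_file) as f:
        return [line.strip() for line in f if line.strip()]

def subset_class_dirs(imagenet_root, subset_file):
    # Hypothetical helper: keep only the class folders named in the subset file.
    wnids = set(load_subset_wnids(subset_file))
    return [os.path.join(imagenet_root, d)
            for d in sorted(os.listdir(imagenet_root)) if d in wnids]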
-------------------------------------------------------------------------------- /condenser/condense_transfom.py: -------------------------------------------------------------------------------- 1 | from data.transform import ( 2 | transform_imagenet, 3 | transform_cifar, 4 | transform_svhn, 5 | transform_mnist, 6 | transform_fashion, 7 | transform_tiny, 8 | ) 9 | 10 | 11 | def get_train_transform( 12 | dataset, 13 | augment=True, 14 | from_tensor=True, 15 | size=0, 16 | rrc=False, 17 | rrc_size=None, 18 | device="cpu", 19 | ): 20 | if dataset in [ 21 | "imagenette", 22 | "imagewoof", 23 | "imagemeow", 24 | "imagesquawk", 25 | "imagefruit", 26 | "imageyellow", 27 | "imagenet", 28 | ]: 29 | train_transform, _ = transform_imagenet( 30 | augment=augment, 31 | from_tensor=from_tensor, 32 | size=size, 33 | rrc=rrc, 34 | rrc_size=rrc_size, 35 | device=device, 36 | ) 37 | elif dataset[:5] == "cifar": 38 | train_transform, _ = transform_cifar(augment=augment, from_tensor=from_tensor) 39 | elif dataset == "svhn": 40 | train_transform, _ = transform_svhn(augment=augment, from_tensor=from_tensor) 41 | elif dataset == "mnist": 42 | train_transform, _ = transform_mnist(augment=augment, from_tensor=from_tensor) 43 | elif dataset == "fashion": 44 | train_transform, _ = transform_fashion(augment=augment, from_tensor=from_tensor) 45 | elif dataset == "tinyimagenet": 46 | train_transform, _ = transform_tiny(augment=augment, from_tensor=from_tensor) 47 | else: 48 | raise ValueError(f"Unsupported dataset: {dataset}") 49 | 50 | return train_transform, _ 51 | -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', 
function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /CL/CL_script.py: -------------------------------------------------------------------------------- 1 | def main_work(args): 2 | 3 | _,val_loader = get_loader(args) 4 | 5 | 6 | synset = Condenser(args, nclass_list=list(range(0,args.nclass)), nchannel=args.nch, hs=args.size, ws=args.size, device='cuda') 7 | 8 | for rank in range (args.world_size): 9 | if rank==args.rank: 10 | synset.load_condensed_data(loader=None, init_type="load",load_path=args.load_path) 11 | dist.barrier() 12 | 13 | syndataloader = synset.get_syndataLoader(args, args.augment) 14 | 15 | synset.continue_learning(args, syndataloader, val_loader) 16 | dist.destroy_process_group() 17 | 18 | 19 | 20 | if __name__ == '__main__': 21 | import os 22 | import sys 23 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 24 | from utils.utils import get_loader 25 | from utils.init_script import init_script 26 | import argparse 27 | from argsprocessor.args import ArgsProcessor 28 | from condenser.Condenser import Condenser 29 | import torch.distributed as dist 30 | 31 | parser = argparse.ArgumentParser(description='Configuration parser') 32 | parser.add_argument('--debug',dest='debug',action='store_true',help='When dataset is very large , you should get it') 33 | parser.add_argument('--config_path', type=str, required=True, help='Path to the YAML configuration file') 34 | parser.add_argument('--run_mode',type=str,choices=['Condense', 'Evaluation',"Pretrain"],default='Evaluation',help='Condense or Evaluation') 35 | parser.add_argument('--init',type=str,default='load',choices=['random', 'noise', 'mix', 'load'],help='condensed data initialization type') 36 | parser.add_argument('--load_path',type=str,required=True,help="Path to load the synset") 37 | parser.add_argument('--val_repeat',type=int,default=10,help='The times of validation on syn_dataset Imagenet only 3 times') 38 | parser.add_argument('--gpu', type=str, default = "0",required=True, help='GPUs to use, e.g., "0,1,2,3"') 39 | parser.add_argument('-i', '--ipc', type=int, default=1, help='number of condensed data per class') 40 | parser.add_argument('--tf32', action='store_true',default=True,help='Enable TF32') 41 | parser.add_argument('--softlabel',dest='softlabel',action='store_true',help='Use the softlabel to evaluate the dataset') 42 | parser.add_argument('--step', type=int, default=5,required=True, help='number of condensed data per class') 43 | args = parser.parse_args() 44 | args_processor = ArgsProcessor(args.config_path) 45 | 46 | args = args_processor.add_args_from_yaml(args) 47 | 48 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 49 | 50 | 51 | init_script(args) 52 | 53 | 54 | main_work(args) 55 | 56 | 57 | 
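Editorial note: CL/CL_script.py above assembles its full configuration by merging a YAML file into the argparse namespace through ArgsProcessor (defined in argsprocessor/args.py further below). The sketch below illustrates that flattening step under the assumption that PyYAML is installed and one of the configs shown earlier is used; nested keys keep only their leaf names, so dataset -> nclass becomes args.nclass and condense -> ipc becomes args.ipc. This is an illustration of the idea, not the repository's exact code path.

import argparse
import yaml

def flatten(d):
    # Same idea as ArgsProcessor.flatten_dict: recurse into nested dicts, keep only leaf keys.
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out.update(flatten(v))
        else:
            out[k] = v
    return out

args = argparse.Namespace()
with open("config/ipc10/cifar10.yaml") as f:
    cfg = yaml.safe_load(f)
for key, value in flatten(cfg).items():
    setattr(args, key, value)
# e.g. args.dataset == "cifar10", args.ipc == 10, args.niter == 20000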
-------------------------------------------------------------------------------- /condenser/compute_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_match_loss( 5 | args, 6 | loader_real, 7 | sample_fn, 8 | aug_fn, 9 | inner_loss_fn, 10 | optim_img, 11 | class_list, 12 | timing_tracker, 13 | model_interval, 14 | data_grad, 15 | optim_sampling_net = None, 16 | sampling_net =None 17 | ): 18 | 19 | loss_total = 0 20 | match_grad_mean = 0 21 | 22 | for c in class_list: 23 | timing_tracker.start_step() 24 | 25 | img, _ = loader_real.class_sample(c) 26 | timing_tracker.record("data") 27 | img_syn, _ = sample_fn(c) 28 | 29 | img_aug = aug_fn(torch.cat([img, img_syn])) 30 | timing_tracker.record("aug") 31 | n = img.shape[0] 32 | 33 | loss = inner_loss_fn(img_aug[:n], img_aug[n:], model_interval,sampling_net,args) 34 | loss_total += loss.item() 35 | timing_tracker.record("loss") 36 | 37 | optim_img.zero_grad() 38 | if optim_sampling_net is not None: 39 | optim_sampling_net.zero_grad() 40 | loss.backward(retain_graph=True) 41 | optim_img.step() 42 | optim_img.zero_grad() 43 | (-loss).backward() 44 | optim_sampling_net.step() 45 | optim_sampling_net.zero_grad() 46 | else: 47 | loss.backward() 48 | optim_img.step() 49 | if data_grad is not None: 50 | match_grad_mean += torch.norm(data_grad).item() 51 | timing_tracker.record("backward") 52 | 53 | return loss_total, match_grad_mean 54 | 55 | 56 | def compute_calib_loss( 57 | sample_fn, 58 | aug_fn, 59 | inter_loss_fn, 60 | optim_img, 61 | iter_calib, 62 | class_list, 63 | timing_tracker, 64 | model_final, 65 | calib_weight, 66 | data_grad, 67 | ): 68 | 69 | calib_loss_total = 0 70 | calib_grad_norm = 0 71 | for i in range(0, iter_calib): 72 | for c in class_list: 73 | timing_tracker.start_step() 74 | 75 | img_syn, label_syn = sample_fn(c) 76 | timing_tracker.record("data") 77 | 78 | img_aug = aug_fn(torch.cat([img_syn])) 79 | timing_tracker.record("aug") 80 | 81 | loss = calib_weight * inter_loss_fn(img_aug, label_syn, model_final) 82 | calib_loss_total += loss.item() 83 | timing_tracker.record("loss") 84 | 85 | optim_img.zero_grad() 86 | loss.backward() 87 | if data_grad is not None: 88 | calib_grad_norm = torch.norm(data_grad).item() 89 | optim_img.step() 90 | timing_tracker.record("backward") 91 | 92 | return calib_loss_total, calib_grad_norm 93 | -------------------------------------------------------------------------------- /NCFM/SampleNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SampleNet(nn.Module): 6 | """ 7 | TNet module for adversarial networks with fixed activation layers and predefined parameters. 
8 | """ 9 | 10 | def __init__(self, feature_dim=64, t_batchsize=64, t_var=1): 11 | super(SampleNet, self).__init__() 12 | self.feature_dim = feature_dim # Feature dimension 13 | self.t_sigma_num = t_batchsize // 16 # Number of sigmas for t_net 14 | self._input_adv_t_net_dim = feature_dim # Input noise dimension 15 | self._input_t_dim = feature_dim # t_net input dimension 16 | self._input_t_batchsize = t_batchsize # Batch size 17 | self._input_t_var = t_var # Variance of input noise 18 | 19 | # Fixed activation layers 20 | self.activation_1 = nn.LeakyReLU(negative_slope=0.2) 21 | self.activation_2 = nn.Tanh() 22 | 23 | # Create a simple 3-layer fully connected network using fixed activation layers 24 | self.t_layers_list = nn.ModuleList() 25 | ch_in = self.feature_dim 26 | num_layer = 3 27 | for i in range(num_layer): 28 | self.t_layers_list.append(nn.Linear(ch_in, ch_in)) 29 | self.t_layers_list.append(nn.BatchNorm1d(ch_in)) 30 | # Use activation_1 for the first two layers, and activation_2 for the last layer 31 | self.t_layers_list.append( 32 | self.activation_1 if i < (num_layer - 1) else self.activation_2 33 | ) 34 | 35 | def forward(self, device): 36 | # Generate white noise 37 | if self.t_sigma_num > 0: 38 | # Initialize the white noise input 39 | self._t_net_input = torch.randn( 40 | self.t_sigma_num, self._input_adv_t_net_dim 41 | ) * (self._input_t_var**0.5) 42 | self._t_net_input = self._t_net_input.to(device).detach() 43 | 44 | # Forward pass 45 | a = self._t_net_input 46 | for layer in self.t_layers_list: 47 | a = layer(a) 48 | 49 | a = a.repeat(int(self._input_t_batchsize / self.t_sigma_num), 1) 50 | 51 | # Generate the final t value 52 | # self._t = torch.randn(self._input_t_batchsize, self._input_t_dim) * ((self._input_t_var / self._input_t_dim) ** 0.5) 53 | # self._t = self._t.to(device).detach() 54 | self._t = a 55 | else: 56 | # When t_sigma_num = 0, generate standard Gaussian noise as t 57 | self._t = torch.randn(self._input_t_batchsize, self._input_t_dim) * ( 58 | (self._input_t_var / self._input_t_dim) ** 0.5 59 | ) 60 | self._t = self._t.to(device).detach() 61 | return self._t -------------------------------------------------------------------------------- /argsprocessor/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict, Any, Optional 3 | import yaml 4 | 5 | 6 | class ArgsProcessor: 7 | def __init__(self, config_path: str) -> None: 8 | """ 9 | Initialize ArgsProcessor with a configuration file path. 10 | 11 | Args: 12 | config_path (str): Path to the YAML configuration file 13 | 14 | Returns: 15 | None 16 | """ 17 | self.config_path: str = config_path 18 | 19 | def flatten_dict(self, d: Dict[str, Any], parent_key: str = '', sep: str = '_') -> Dict[str, Any]: 20 | """ 21 | Recursively flattens a nested dictionary, but does not add the parent key. 22 | 23 | Args: 24 | d (Dict[str, Any]): Input dictionary to flatten 25 | parent_key (str, optional): Parent key (unused in this implementation). Defaults to '' 26 | sep (str, optional): Separator for nested keys. 
Defaults to '_' 27 | 28 | Returns: 29 | Dict[str, Any]: Flattened dictionary 30 | """ 31 | items: list = [] 32 | for k, v in d.items(): 33 | new_key: str = k # Use the current key directly, without adding the parent key 34 | if isinstance(v, dict): 35 | items.extend(self.flatten_dict(v, new_key, sep=sep).items()) 36 | else: 37 | items.append((new_key, v)) 38 | return dict(items) 39 | 40 | def add_args_from_yaml(self, args: argparse.Namespace) -> argparse.Namespace: 41 | """ 42 | Add contents of YAML configuration file to args object. 43 | 44 | Args: 45 | args (argparse.Namespace): Argument namespace to update 46 | 47 | Returns: 48 | argparse.Namespace: Updated argument namespace 49 | """ 50 | # Read the YAML configuration file 51 | with open(self.config_path, 'r') as f: 52 | config: Dict[str, Any] = yaml.safe_load(f) 53 | 54 | # Flatten the configuration dictionary 55 | flat_config: Dict[str, Any] = self.flatten_dict(config) 56 | 57 | # Convert value types (handle floating point numbers and booleans) 58 | for key, value in flat_config.items(): 59 | # Convert to float if possible 60 | if isinstance(value, str): 61 | if value.lower() in ['true', 'false']: 62 | flat_config[key] = value.lower() == 'true' 63 | elif 'e' in value or '.' in value: 64 | try: 65 | flat_config[key] = float(value) 66 | except ValueError: 67 | pass 68 | 69 | # Add the flattened configuration items to args 70 | for key, value in flat_config.items(): 71 | setattr(args, key, value) 72 | 73 | return args -------------------------------------------------------------------------------- /condenser/decode.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import numpy as np 3 | import torch 4 | from math import ceil 5 | import torch .nn as nn 6 | 7 | def decode(decode_type,size, data, target,factor, bound=128): 8 | 9 | if factor > 1: 10 | if decode_type == 'multi': 11 | data, target = decode_zoom_multi(size,data, target, factor) 12 | elif decode_type == 'bound': 13 | data, target = decode_zoom_bound(size,data, target, factor, bound=bound) 14 | else: 15 | data, target = decode_zoom(size,data, target, factor) 16 | 17 | return data, target 18 | 19 | 20 | def subsample(data, target, max_size=-1): 21 | if (data.shape[0] > max_size) and (max_size > 0): 22 | indices = np.random.permutation(data.shape[0]) 23 | data = data[indices[:max_size]] 24 | target = target[indices[:max_size]] 25 | 26 | return data, target 27 | 28 | def decode_zoom(size, img, target, factor): 29 | resizor = nn.Upsample(size=size, mode='bilinear') 30 | h = img.shape[-1] 31 | remained = h % factor 32 | if remained > 0: 33 | img = F.pad(img, pad=(0, factor - remained, 0, factor - remained), value=0.5) 34 | s_crop = ceil(h / factor) 35 | n_crop = factor**2 36 | 37 | cropped = [] 38 | for i in range(factor): 39 | for j in range(factor): 40 | h_loc = i * s_crop 41 | w_loc = j * s_crop 42 | cropped.append(img[:, :, h_loc:h_loc + s_crop, w_loc:w_loc + s_crop]) 43 | cropped = torch.cat(cropped) 44 | data_dec = resizor(cropped) 45 | target_dec = torch.cat([target for _ in range(n_crop)]) 46 | 47 | return data_dec, target_dec 48 | 49 | def decode_zoom_multi(size, img, target, factor_max): 50 | """Multi-scale multi-formation 51 | """ 52 | data_multi = [] 53 | target_multi = [] 54 | for factor in range(1, factor_max + 1): 55 | decoded = decode_zoom(size,img, target, factor) 56 | data_multi.append(decoded[0]) 57 | target_multi.append(decoded[1]) 58 | 59 | return torch.cat(data_multi), 
torch.cat(target_multi) 60 | 61 | def decode_zoom_bound(size, img, target, factor_max, bound=128): 62 | bound_cur = bound - len(img) 63 | budget = len(img) 64 | 65 | data_multi = [] 66 | target_multi = [] 67 | 68 | idx = 0 69 | decoded_total = 0 70 | for factor in range(factor_max, 0, -1): 71 | decode_size = factor**2 72 | if factor > 1: 73 | n = min(bound_cur // decode_size, budget) 74 | else: 75 | n = budget 76 | 77 | decoded = decode_zoom(size,img[idx:idx + n], target[idx:idx + n], factor) 78 | data_multi.append(decoded[0]) 79 | target_multi.append(decoded[1]) 80 | idx += n 81 | budget -= n 82 | decoded_total += n * decode_size 83 | bound_cur = bound - decoded_total - budget 84 | 85 | if budget == 0: 86 | break 87 | 88 | data_multi = torch.cat(data_multi) 89 | target_multi = torch.cat(target_multi) 90 | return data_multi, target_multi 91 | -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination 
.slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR2025] Dataset Distillation with Neural Characteristic Function: A Minmax Perspective 2 | 3 | Official PyTorch implementation of the paper ["Dataset Distillation with Neural Characteristic Function"](https://arxiv.org/abs/2502.20653) (NCFM) in CVPR 2025. 4 | 5 | 6 | ## :fire: News 7 | 8 | - [2025/03/02] The code of our paper has been released. 9 | - [2025/02/27] Our NCFM paper has been accepted to CVPR 2025 (Rating: 555). Thanks! 10 | 11 | 12 | ## :rocket: Pipeline 13 | 14 | Here's an overview of the process behind our **Neural Characteristic Function Matching (NCFM)** method: 15 | 16 | ![Figure 1](./asset/figure1.png?raw=true) 17 | 18 | 19 | 20 | 21 | ## 🛠️ Getting Started 22 | 23 | To get started with NCFM, follow the installation instructions below. 24 | 25 | 1. Clone the repo 26 | 27 | ```sh 28 | git clone https://github.com/gszfwsb/NCFM.git 29 | ``` 30 | 31 | 2. Install dependencies 32 | 33 | ```sh 34 | pip install -r requirements.txt 35 | ``` 36 | 3. Pretrain the models yourself, or download the **pretrained_models** from [huggingface](https://huggingface.co/maomaocun/NCFM). 37 | ```sh 38 | cd pretrain 39 | torchrun --nproc_per_node={n_gpus} --nnodes=1 pretrain_script.py --gpu={gpu_ids} --config_path=../config/{ipc}/{dataset}.yaml 40 | 41 | ``` 42 | 43 | 4. Condense 44 | ```sh 45 | cd condense 46 | torchrun --nproc_per_node={n_gpus} --nnodes=1 condense_script.py --gpu={gpu_ids} --ipc={ipc} --config_path=../config/{ipc}/{dataset}.yaml 47 | 48 | ``` 49 | 5. Evaluate, or download the **condensed dataset** from [huggingface](https://huggingface.co/maomaocun/NCFM) 50 | ```sh 51 | cd evaluation 52 | torchrun --nproc_per_node={n_gpus} --nnodes=1 evaluation_script.py --gpu={gpu_ids} --ipc={ipc} --config_path=../config/{ipc}/{dataset}.yaml --load_path={distilled_dataset.pt} 53 | ``` 54 | 55 | ### :blue_book: Example Usage 56 | 57 | 1. CIFAR-10 58 | 59 | ```sh 60 | #ipc50 61 | cd condense 62 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=34153 condense_script.py --gpu="0,1,2,3,4,5,6,7" --ipc=50 --config_path=../config/ipc50/cifar10.yaml 63 | ``` 64 | 65 | 2. 
CIFAR-100 66 | 67 | ```sh 68 | #ipc10 69 | cd condense 70 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=34153 condense_script.py --gpu="0,1,2,3,4,5,6,7" --ipc=10 --config_path=../config/ipc10/cifar100.yaml 71 | ``` 72 | 73 | 74 | 75 | ## :postbox: Contact 76 | If you have any questions, please contact [Shaobo Wang](https://gszfwsb.github.io/)(`shaobowang1009@sjtu.edu.cn`). 77 | 78 | ## :pushpin: Citation 79 | If you find NCFM useful for your research and applications, please cite using this BibTeX: 80 | 81 | ```bibtex 82 | @inproceedings{wang2025NCFM, 83 | title={Dataset Distillation with Neural Characteristic Function: A Minmax Perspective}, 84 | author={Shaobo Wang and Yicun Yang and Zhiyuan Liu and Chenghao Sun and Xuming Hu and Conghui He and Linfeng Zhang}, 85 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 86 | year={2025} 87 | } 88 | ``` 89 | 90 | ## Acknowledgement 91 | We sincerely thank the developers of the following projects for their valuable contributions and inspiration: [MTT](https://github.com/GeorgeCazenavette/mtt-distillation), [DATM](https://github.com/NUS-HPC-AI-Lab/DATM), [DC/DM](https://github.com/VICO-UoE/DatasetCondensation), [IDC](https://github.com/snu-mllab/Efficient-Dataset-Condensation), [SRe2L](https://github.com/VILA-Lab/SRe2L), [RDED](https://github.com/LINs-lab/RDED), [DANCE](https://github.com/Hansong-Zhang/DANCE). We draw inspiration from these fantastic projects! 92 | 93 | -------------------------------------------------------------------------------- /condense/condense_script.py: -------------------------------------------------------------------------------- 1 | def main_worker(args): 2 | 3 | args.class_list = distribute_class(args.nclass,args.debug) 4 | 5 | plotter = get_plotter(args) 6 | 7 | loader_real,_ = get_loader(args) 8 | 9 | 10 | aug, _ = diffaug(args) 11 | 12 | condenser = Condenser(args, nclass_list=args.class_list, nchannel=args.nch, hs=args.size, ws=args.size, device='cuda') 13 | for local_rank in range(args.local_world_size): 14 | if args.local_rank == local_rank: 15 | condenser.load_condensed_data(loader_real, init_type=args.init,load_path=args.load_path) 16 | print(f"============RANK:{dist.get_rank()}====LOCAL_RANK {local_rank} Loaded Condensed Data==========================") 17 | dist.barrier() 18 | 19 | optim_img = get_optimizer(optimizer=args.optimizer, parameters=condenser.parameters(),lr=args.lr_img, mom_img=args.mom_img,weight_decay=args.weight_decay,logger=args.logger) 20 | if args.sampling_net: 21 | sampling_net = SampleNet(feature_dim=2048).to(args.device) 22 | optim_sampling_net = get_optimizer(optimizer= "sgd", parameters=sampling_net.parameters(),lr=args.lr_sampling_net, mom_img=args.mom_img,weight_decay=args.weight_decay,logger=args.logger) 23 | else: 24 | sampling_net = None 25 | optim_sampling_net = None 26 | model_init,model_interval,model_final = get_feature_extractor(args) 27 | condenser.condense(args,plotter,loader_real,aug,optim_img,model_init,model_interval,model_final,sampling_net,optim_sampling_net) 28 | 29 | dist.destroy_process_group() 30 | 31 | 32 | 33 | if __name__ == '__main__': 34 | import sys 35 | import os 36 | import torch 37 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 38 | from utils.diffaug import diffaug 39 | import torch.distributed as dist 40 | from utils.ddp import distribute_class 41 | from utils.utils import get_plotter,get_optimizer,get_loader,get_feature_extractor 42 | from utils.init_script import 
init_script 43 | import argparse 44 | from argsprocessor.args import ArgsProcessor 45 | from condenser.Condenser import Condenser 46 | from NCFM.SampleNet import SampleNet 47 | 48 | parser = argparse.ArgumentParser(description='Configuration parser') 49 | parser.add_argument('--debug',dest='debug',action='store_true',help='When dataset is very large , you should get it') 50 | parser.add_argument('--config_path', type=str, required=True, help='Path to the YAML configuration file') 51 | parser.add_argument('--run_mode',type=str,choices=['Condense', 'Evaluation',"Pretrain"],default='Condense',help='Condense or Evaluation') 52 | parser.add_argument('-a','--aug_type',type=str,default='color_crop_cutout',help='augmentation strategy for condensation matching objective') 53 | parser.add_argument('--init',type=str,default='mix',choices=['random', 'noise', 'mix', 'load'],help='condensed data initialization type') 54 | parser.add_argument('--load_path',type=str,default=None,help="Path to load the synset") 55 | parser.add_argument('--gpu', type=str, default = "0",required=True, help='GPUs to use, e.g., "0,1,2,3"') 56 | parser.add_argument('-i', '--ipc', type=int, default=1,required=True, help='number of condensed data per class') 57 | parser.add_argument('--tf32', action='store_true',default=True,help='Enable TF32') 58 | args = parser.parse_args() 59 | args_processor = ArgsProcessor(args.config_path) 60 | 61 | args = args_processor.add_args_from_yaml(args) 62 | 63 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 64 | 65 | init_script(args) 66 | 67 | main_worker(args) -------------------------------------------------------------------------------- /evaluation/evaluation_script.py: -------------------------------------------------------------------------------- 1 | def main_work(args): 2 | 3 | _, val_loader = get_loader(args) 4 | 5 | synset = Condenser( 6 | args, 7 | nclass_list=list(range(0, args.nclass)), 8 | nchannel=args.nch, 9 | hs=args.size, 10 | ws=args.size, 11 | device="cuda", 12 | ) 13 | 14 | for rank in range(args.world_size): 15 | if rank == args.rank: 16 | synset.load_condensed_data( 17 | loader=None, init_type="load", load_path=args.load_path 18 | ) 19 | dist.barrier() 20 | 21 | syndataloader = synset.get_syndataLoader(args, args.augment) 22 | 23 | synset.evaluate(args, syndataloader, val_loader) 24 | dist.destroy_process_group() 25 | 26 | 27 | if __name__ == "__main__": 28 | import os 29 | import sys 30 | 31 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 32 | from utils.utils import get_loader 33 | from utils.init_script import init_script 34 | import argparse 35 | from argsprocessor.args import ArgsProcessor 36 | from condenser.Condenser import Condenser 37 | import torch.distributed as dist 38 | 39 | parser = argparse.ArgumentParser(description="Configuration parser") 40 | parser.add_argument( 41 | "--debug", 42 | dest="debug", 43 | action="store_true", 44 | help="When dataset is very large , you should get it", 45 | ) 46 | parser.add_argument( 47 | "--config_path", 48 | type=str, 49 | required=True, 50 | help="Path to the YAML configuration file", 51 | ) 52 | parser.add_argument( 53 | "--run_mode", 54 | type=str, 55 | choices=["Condense", "Evaluation", "Pretrain"], 56 | default="Evaluation", 57 | help="Condense or Evaluation", 58 | ) 59 | parser.add_argument( 60 | "--init", 61 | type=str, 62 | default="load", 63 | choices=["random", "noise", "mix", "load"], 64 | help="condensed data initialization type", 65 | ) 66 | parser.add_argument( 67 | 
"--load_path", type=str, required=True, help="Path to load the synset" 68 | ) 69 | parser.add_argument( 70 | "--val_repeat", 71 | type=int, 72 | default=10, 73 | help="The times of validation on syn_dataset Imagenet only 3 times", 74 | ) 75 | parser.add_argument( 76 | "--gpu", 77 | type=str, 78 | default="0", 79 | required=True, 80 | help='GPUs to use, e.g., "0,1,2,3"', 81 | ) 82 | parser.add_argument( 83 | "-i", "--ipc", type=int, default=1, help="number of condensed data per class" 84 | ) 85 | parser.add_argument("--tf32", action="store_true", default=True, help="Enable TF32") 86 | parser.add_argument( 87 | "--softlabel", 88 | dest="softlabel", 89 | action="store_true", 90 | help="Use the softlabel to evaluate the dataset", 91 | ) 92 | parser.add_argument( 93 | "--kldiv", 94 | dest="kldiv", 95 | action="store_true", 96 | help="Use the kldiv loss to evaluate the dataset", 97 | ) 98 | parser.add_argument( 99 | "--temperature", type=float, default=1.0, help="The temperature for KLdiv" 100 | ) 101 | args = parser.parse_args() 102 | args_processor = ArgsProcessor(args.config_path) 103 | 104 | args = args_processor.add_args_from_yaml(args) 105 | 106 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 107 | 108 | init_script(args) 109 | 110 | main_work(args) 111 | -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | border: 1px solid #bbb; 121 | border-radius: 10px; 
122 | padding: 0; 123 | font-size: 0; 124 | } 125 | 126 | .results-carousel video { 127 | margin: 0; 128 | } 129 | 130 | 131 | .interpolation-panel { 132 | background: #f5f5f5; 133 | border-radius: 10px; 134 | } 135 | 136 | .interpolation-panel .interpolation-image { 137 | width: 100%; 138 | border-radius: 5px; 139 | } 140 | 141 | .interpolation-video-column { 142 | } 143 | 144 | .interpolation-panel .slider { 145 | margin: 0 !important; 146 | } 147 | 148 | .interpolation-panel .slider { 149 | margin: 0 !important; 150 | } 151 | 152 | #interpolation-image-wrapper { 153 | width: 100%; 154 | } 155 | #interpolation-image-wrapper img { 156 | border-radius: 5px; 157 | } 158 | 159 | .image-container { 160 | display: flex; 161 | justify-content: space-between; 162 | align-items: stretch; 163 | } 164 | 165 | .image-container img { 166 | max-width: 49%; 167 | object-fit: cover; 168 | } 169 | 170 | .publication-authors { 171 | font-family: 'Times New Roman', serif; 172 | line-height: 1.6; 173 | } 174 | .author-list { 175 | display: block; 176 | margin-bottom: 0.5em; 177 | } 178 | .author-block { 179 | display: inline-block; 180 | margin-right: 0.3em; 181 | vertical-align: top; 182 | } 183 | .institutions { 184 | display: block; 185 | font-size: 0.9em; 186 | margin: 0.5em 0; 187 | color: #666; 188 | } 189 | .contribution-notes { 190 | font-size: 0.85em; 191 | color: #333; 192 | margin-top: 0.3em; 193 | } 194 | .affiliation { 195 | color: #333; 196 | font-size: 0.9em; 197 | } 198 | .contribution { 199 | font-size: 0.8em; 200 | vertical-align: super; 201 | } -------------------------------------------------------------------------------- /models/densenet_cifar.py: -------------------------------------------------------------------------------- 1 | # Codes are borrowed from https://github.com/kuangliu/pytorch-cifar/blob/master/models/densenet.py 2 | """DenseNet in PyTorch.""" 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | def __init__(self, in_planes, growth_rate): 12 | super(Bottleneck, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(4 * growth_rate) 16 | self.conv2 = nn.Conv2d( 17 | 4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False 18 | ) 19 | 20 | def forward(self, x): 21 | out = self.conv1(F.relu(self.bn1(x))) 22 | out = self.conv2(F.relu(self.bn2(out))) 23 | out = torch.cat([out, x], 1) 24 | return out 25 | 26 | 27 | class Transition(nn.Module): 28 | def __init__(self, in_planes, out_planes): 29 | super(Transition, self).__init__() 30 | self.bn = nn.BatchNorm2d(in_planes) 31 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 32 | 33 | def forward(self, x): 34 | out = self.conv(F.relu(self.bn(x))) 35 | out = F.avg_pool2d(out, 2) 36 | return out 37 | 38 | 39 | class DenseNet(nn.Module): 40 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 41 | super(DenseNet, self).__init__() 42 | self.growth_rate = growth_rate 43 | 44 | num_planes = 2 * growth_rate 45 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 46 | 47 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 48 | num_planes += nblocks[0] * growth_rate 49 | out_planes = int(math.floor(num_planes * reduction)) 50 | self.trans1 = Transition(num_planes, out_planes) 51 | num_planes = out_planes 52 | 53 | 
self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 54 | num_planes += nblocks[1] * growth_rate 55 | out_planes = int(math.floor(num_planes * reduction)) 56 | self.trans2 = Transition(num_planes, out_planes) 57 | num_planes = out_planes 58 | 59 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 60 | num_planes += nblocks[2] * growth_rate 61 | out_planes = int(math.floor(num_planes * reduction)) 62 | self.trans3 = Transition(num_planes, out_planes) 63 | num_planes = out_planes 64 | 65 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 66 | num_planes += nblocks[3] * growth_rate 67 | 68 | self.bn = nn.BatchNorm2d(num_planes) 69 | self.linear = nn.Linear(num_planes, num_classes) 70 | 71 | def _make_dense_layers(self, block, in_planes, nblock): 72 | layers = [] 73 | for i in range(nblock): 74 | layers.append(block(in_planes, self.growth_rate)) 75 | in_planes += self.growth_rate 76 | return nn.Sequential(*layers) 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | out = self.trans1(self.dense1(out)) 81 | out = self.trans2(self.dense2(out)) 82 | out = self.trans3(self.dense3(out)) 83 | out = self.dense4(out) 84 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 85 | out = out.view(out.size(0), -1) 86 | out = self.linear(out) 87 | return out 88 | 89 | 90 | def DenseNet121(nclass): 91 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32, num_classes=nclass) 92 | 93 | 94 | def DenseNet169(nclass): 95 | return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32, num_classes=nclass) 96 | 97 | 98 | def DenseNet201(nclass): 99 | return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32, num_classes=nclass) 100 | 101 | 102 | def DenseNet161(nclass): 103 | return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48, num_classes=nclass) 104 | 105 | 106 | def densenet_cifar(nclass): 107 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=12, num_classes=nclass) 108 | 109 | -------------------------------------------------------------------------------- /NCFM/NCFM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def calculate_norm(x_r, x_i): 7 | return torch.sqrt(torch.mul(x_r, x_r) + torch.mul(x_i, x_i)) 8 | 9 | 10 | def calculate_imag(x): 11 | return torch.mean(torch.sin(x), dim=1) 12 | 13 | 14 | def calculate_real(x): 15 | return torch.mean(torch.cos(x), dim=1) 16 | 17 | 18 | class CFLossFunc(nn.Module): 19 | """ 20 | CF loss function in terms of phase and amplitude difference. 21 | Args: 22 | alpha_for_loss: the weight for amplitude in CF loss, from 0-1 23 | beta_for_loss: the weight for phase in CF loss, from 0-1 24 | """ 25 | 26 | def __init__(self, alpha_for_loss=0.5, beta_for_loss=0.5): 27 | super(CFLossFunc, self).__init__() 28 | self.alpha = alpha_for_loss 29 | self.beta = beta_for_loss 30 | 31 | def forward(self, feat_tg, feat, t=None, args=None): 32 | """ 33 | Calculate CF loss between target and synthetic features. 
34 | Args: 35 | feat_tg: target features from real data [B1 x D] 36 | feat: synthetic features [B2 x D] 37 | args: additional arguments containing num_freqs 38 | """ 39 | # Generate random frequencies 40 | if t is None: 41 | t = torch.randn((args.num_freqs, feat.size(1)), device=feat.device) 42 | t_x_real = calculate_real(torch.matmul(t, feat.t())) 43 | t_x_imag = calculate_imag(torch.matmul(t, feat.t())) 44 | t_x_norm = calculate_norm(t_x_real, t_x_imag) 45 | 46 | t_target_real = calculate_real(torch.matmul(t, feat_tg.t())) 47 | t_target_imag = calculate_imag(torch.matmul(t, feat_tg.t())) 48 | t_target_norm = calculate_norm(t_target_real, t_target_imag) 49 | 50 | # Calculate amplitude difference and phase difference 51 | amp_diff = t_target_norm - t_x_norm 52 | loss_amp = torch.mul(amp_diff, amp_diff) 53 | 54 | loss_pha = 2 * ( 55 | torch.mul(t_target_norm, t_x_norm) 56 | - torch.mul(t_x_real, t_target_real) 57 | - torch.mul(t_x_imag, t_target_imag) 58 | ) 59 | 60 | loss_pha = loss_pha.clamp(min=1e-12) # Ensure numerical stability 61 | 62 | # Combine losses 63 | loss = torch.mean(torch.sqrt(self.alpha * loss_amp + self.beta * loss_pha)) 64 | return loss 65 | 66 | 67 | def match_loss(img_real, img_syn, model,sampling_net, args=None): 68 | """Matching losses (feature or gradient)""" 69 | with torch.no_grad(): 70 | _, feat_tg = model(img_real, return_features=True) 71 | _, feat = model(img_syn, return_features=True) 72 | feat = F.normalize(feat, dim=1) 73 | feat_tg = F.normalize(feat_tg, dim=1) 74 | if sampling_net is not None: 75 | t = sampling_net(args.device) 76 | else: 77 | t = None 78 | loss = 300 * args.cf_loss_func(feat_tg, feat, t, args) 79 | return loss 80 | 81 | 82 | def mutil_layer_match_loss(img_real, img_syn, model,sampling_net, args=None): 83 | 84 | # Ensure layer_index is a list 85 | assert isinstance( 86 | args.layer_index, list 87 | ), "args.layer_index must be a list of layer indices" 88 | 89 | # Initialize loss as a tensor on the correct device 90 | loss = torch.tensor(0.0).to(img_real.device) 91 | 92 | # Extract features for both real and synthetic images 93 | with torch.no_grad(): 94 | feat_tg_list = model.get_feature_mutil(img_real) # Real image features 95 | feat_list = model.get_feature_mutil(img_syn) # Synthetic image features 96 | 97 | for layer_index in args.layer_index: 98 | assert ( 99 | 0 <= layer_index <= 6 100 | ), f"layer_index {layer_index} must be between 0 and 6" 101 | if args.dis_metrics == "MMD": 102 | # If the metric is MMD, calculate the MMD loss for the selected layer 103 | feat = feat_list[layer_index] 104 | feat_tg = feat_tg_list[layer_index] 105 | loss += torch.sum((feat.mean(0) - feat_tg.mean(0)) ** 2) 106 | else: 107 | # Otherwise, calculate the feature matching loss for the selected layer 108 | feat = feat_list[layer_index] 109 | feat_tg = feat_tg_list[layer_index] 110 | feat = F.normalize(feat, dim=1) # Normalize the feature 111 | feat_tg = F.normalize(feat_tg, dim=1) # Normalize the target feature 112 | t = None # Adjust this based on your CFLossFunc usage 113 | loss += 300 * args.cf_loss_func(feat_tg, feat, t, args) 114 | 115 | return loss 116 | 117 | 118 | def cailb_loss(img_syn, label_syn, trained_model): 119 | logits = trained_model(img_syn, return_features=False) 120 | loss = F.cross_entropy(logits, label_syn) 121 | return loss -------------------------------------------------------------------------------- /data/augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 
3 | 4 | 5 | class Compose(object): 6 | def __init__(self, transforms): 7 | self.transforms = transforms 8 | 9 | def __call__(self, img): 10 | for t in self.transforms: 11 | img = t(img) 12 | return img 13 | 14 | def __repr__(self): 15 | format_string = self.__class__.__name__ + "(" 16 | for t in self.transforms: 17 | format_string += "\n" 18 | format_string += " {0}".format(t) 19 | format_string += "\n)" 20 | return format_string 21 | 22 | 23 | class Lighting(object): 24 | """Lighting noise(AlexNet - style PCA - based noise)""" 25 | 26 | def __init__(self, alphastd, eigval, eigvec, device="cpu"): 27 | self.alphastd = alphastd 28 | self.eigval = torch.tensor(eigval, device=device) 29 | self.eigvec = torch.tensor(eigvec, device=device) 30 | 31 | def __call__(self, img): 32 | if self.alphastd == 0: 33 | return img 34 | 35 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 36 | rgb = ( 37 | self.eigvec.type_as(img) 38 | .clone() 39 | .mul(alpha.view(1, 3).expand(3, 3)) 40 | .mul(self.eigval.view(1, 3).expand(3, 3)) 41 | .sum(1) 42 | .squeeze() 43 | ) 44 | 45 | # make differentiable 46 | if len(img.shape) == 4: 47 | return img + rgb.view(1, 3, 1, 1).expand_as(img) 48 | else: 49 | return img + rgb.view(3, 1, 1).expand_as(img) 50 | 51 | 52 | class Grayscale(object): 53 | def __call__(self, img): 54 | gs = img.clone() 55 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 56 | gs[1].copy_(gs[0]) 57 | gs[2].copy_(gs[0]) 58 | return gs 59 | 60 | 61 | class Saturation(object): 62 | def __init__(self, var): 63 | self.var = var 64 | 65 | def __call__(self, img): 66 | gs = Grayscale()(img) 67 | alpha = random.uniform(-self.var, self.var) 68 | return img.lerp(gs, alpha) 69 | 70 | 71 | class Brightness(object): 72 | def __init__(self, var): 73 | self.var = var 74 | 75 | def __call__(self, img): 76 | gs = img.new().resize_as_(img).zero_() 77 | alpha = random.uniform(-self.var, self.var) 78 | return img.lerp(gs, alpha) 79 | 80 | 81 | class Contrast(object): 82 | def __init__(self, var): 83 | self.var = var 84 | 85 | def __call__(self, img): 86 | gs = Grayscale()(img) 87 | gs.fill_(gs.mean()) 88 | alpha = random.uniform(-self.var, self.var) 89 | return img.lerp(gs, alpha) 90 | 91 | 92 | class ColorJitter(object): 93 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 94 | self.brightness = brightness 95 | self.contrast = contrast 96 | self.saturation = saturation 97 | 98 | def __call__(self, img): 99 | self.transforms = [] 100 | if self.brightness != 0: 101 | self.transforms.append(Brightness(self.brightness)) 102 | if self.contrast != 0: 103 | self.transforms.append(Contrast(self.contrast)) 104 | if self.saturation != 0: 105 | self.transforms.append(Saturation(self.saturation)) 106 | 107 | random.shuffle(self.transforms) 108 | transform = Compose(self.transforms) 109 | # print(transform) 110 | return transform(img) 111 | 112 | 113 | class CutOut: 114 | def __init__(self, ratio, device="cpu"): 115 | self.ratio = ratio 116 | self.device = device 117 | 118 | def __call__(self, x): 119 | n, _, h, w = x.shape 120 | cutout_size = [int(h * self.ratio + 0.5), int(w * self.ratio + 0.5)] 121 | offset_x = torch.randint( 122 | h + (1 - cutout_size[0] % 2), size=[1], device=self.device 123 | )[0] 124 | offset_y = torch.randint( 125 | w + (1 - cutout_size[1] % 2), size=[1], device=self.device 126 | )[0] 127 | 128 | grid_batch, grid_x, grid_y = torch.meshgrid( 129 | torch.arange(n, dtype=torch.long, device=self.device), 130 | torch.arange(cutout_size[0], dtype=torch.long, device=self.device), 
131 | torch.arange(cutout_size[1], dtype=torch.long, device=self.device), 132 | ) 133 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=h - 1) 134 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=w - 1) 135 | mask = torch.ones(n, h, w, dtype=x.dtype, device=self.device) 136 | mask[grid_batch, grid_x, grid_y] = 0 137 | 138 | x = x * mask.unsqueeze(1) 139 | return x 140 | 141 | 142 | class Normalize: 143 | def __init__(self, mean, std, device="cpu"): 144 | self.mean = torch.tensor(mean, device=device).reshape(1, len(mean), 1, 1) 145 | self.std = torch.tensor(std, device=device).reshape(1, len(mean), 1, 1) 146 | 147 | def __call__(self, x, seed=-1): 148 | return (x - self.mean) / self.std 149 | -------------------------------------------------------------------------------- /utils/init_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch 5 | import datetime 6 | from .experiment_tracker import Logger 7 | from .diffaug import remove_aug 8 | import torch.distributed as dist 9 | import datetime 10 | from datetime import timedelta 11 | from torch.backends import cudnn 12 | from .ddp import initialize_distribution_training 13 | 14 | 15 | def init_script(args): 16 | cudnn.benchmark = True 17 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 18 | torch.backends.cudnn.allow_tf32 = args.tf32 19 | 20 | rank, world_size, local_rank, local_world_size, device = ( 21 | initialize_distribution_training(args.backend, args.init_method) 22 | ) 23 | 24 | args.it_save, args.it_log = set_iteration_parameters(args.niter, args.debug) 25 | 26 | args.pretrain_dir = set_Pretrain_Directory( 27 | args.pretrain_dir, args.dataset, args.depth 28 | ) 29 | 30 | args.exp_name, args.save_dir, args.lr_img = set_experiment_name_and_save_Dir( 31 | args.run_mode, 32 | args.dataset, 33 | args.pretrain_dir, 34 | args.save_dir, 35 | args.lr_img, 36 | args.lr_scale_adam, 37 | args.ipc, 38 | args.optimizer, 39 | args.load_path, 40 | args.factor, 41 | args.lr, 42 | args.num_freqs, 43 | ) 44 | 45 | set_random_seeds(args.seed) 46 | 47 | args.mixup, args.dsa_strategy, args.dsa, args.augment = ( 48 | adjust_augmentation_strategy(args.mixup, args.dsa_strategy, args.dsa) 49 | ) 50 | 51 | args.logger = setup_logging_and_directories(args, args.run_mode, args.save_dir) 52 | args.rank, args.world_size, args.local_rank, args.local_world_size, args.device = ( 53 | rank, 54 | world_size, 55 | local_rank, 56 | local_world_size, 57 | device, 58 | ) 59 | if args.rank == 0: 60 | args.logger("TF32 is enabled") if args.tf32 else print("TF32 is disabled") 61 | args.logger( 62 | f"=> creating model {args.net_type}-{args.depth}, norm: {args.norm_type}" 63 | ) 64 | 65 | 66 | def set_iteration_parameters(niter, debug): 67 | 68 | it_save = np.arange(0, niter + 1, 1000).tolist() 69 | it_log = 1 if debug else 20 70 | return it_save, it_log 71 | 72 | 73 | def set_Pretrain_Directory(pretrain_dir, dataset, depth): 74 | 75 | if dataset.lower() == "imagenet": 76 | pretrain_dir = f"./{pretrain_dir}/{dataset}/ResNet-{depth}" 77 | else: 78 | pretrain_dir = f"./{pretrain_dir}/{dataset}" 79 | return pretrain_dir 80 | 81 | 82 | def set_experiment_name_and_save_Dir( 83 | run_mode, 84 | dataset, 85 | pretrain_dir, 86 | save_dir, 87 | lr_img, 88 | lr_scale_adam, 89 | ipc, 90 | optimizer, 91 | load_path, 92 | factor, 93 | lr, 94 | num_freqs, 95 | ): 96 | timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M") 97 | # 
Set the base save directory path according to the run_mode 98 | if run_mode == "Condense": 99 | assert ipc > 0, "IPC must be greater than 0" 100 | if optimizer.lower() == "sgd": 101 | lr_img = lr_img 102 | else: 103 | lr_img = lr_img * lr_scale_adam 104 | 105 | # Generate experiment name 106 | exp_name = f"./condense/{dataset}/ipc{ipc}/{optimizer}_lr_img_{lr_img:.4f}_numr_reqs{num_freqs}_factor{factor}_{timestamp}" 107 | if load_path: 108 | exp_name += f"Reload_SynData_Path_{load_path}" 109 | save_dir = os.path.join(save_dir, exp_name) 110 | 111 | elif run_mode == "Evaluation": 112 | assert ipc > 0, "IPC must be greater than 0" 113 | exp_name = ( 114 | f"./evaluate/{dataset}/ipc{ipc}/_lr{lr:.4f}__factor{factor}_{timestamp}" 115 | ) 116 | save_dir = os.path.join(save_dir, exp_name) 117 | elif run_mode == "Pretrain": 118 | save_dir = pretrain_dir 119 | exp_name = pretrain_dir 120 | else: 121 | raise ValueError( 122 | "Invalid run_mode. Choose 'Condense', 'Evaluation' or 'Pretrain'." 123 | ) 124 | 125 | # Create save directory if the rank is 0 126 | if dist.get_rank() == 0: 127 | os.makedirs(save_dir, exist_ok=True) 128 | 129 | return exp_name, save_dir, lr_img 130 | 131 | 132 | def set_random_seeds(seed): 133 | 134 | if seed > 0: 135 | np.random.seed(seed) 136 | torch.manual_seed(seed) 137 | torch.cuda.manual_seed(seed) 138 | if dist.get_rank() == 0: 139 | print(f"Set Random Seed as {seed}") 140 | 141 | 142 | def setup_logging_and_directories(args, run_mode, save_dir): 143 | if dist.get_rank() == 0: 144 | if run_mode == "Condense": 145 | subdirs = ["images", "distilled_data"] 146 | for subdir in subdirs: 147 | os.makedirs(os.path.join(save_dir, subdir), exist_ok=True) 148 | args_log_path = os.path.join(save_dir, "args.log") 149 | with open(args_log_path, "w") as f: 150 | json.dump(vars(args), f, indent=3) 151 | dist.barrier() 152 | logger = Logger(args.save_dir) 153 | dist.barrier() 154 | if dist.get_rank() == 0: 155 | logger(f"Save dir: {args.save_dir}") 156 | 157 | return logger 158 | 159 | 160 | def adjust_augmentation_strategy(mixup, dsa_strategy, dsa): 161 | 162 | if mixup == "cut": 163 | dsa_strategy = remove_aug(dsa_strategy, "cutout") 164 | 165 | if dsa: 166 | augment = False 167 | if dist.get_rank() == 0: 168 | print( 169 | "DSA strategy: ", 170 | dsa_strategy, 171 | ) 172 | else: 173 | augment = True 174 | return mixup, dsa_strategy, dsa, augment 175 | -------------------------------------------------------------------------------- /utils/ddp.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import os 3 | import torch 4 | from datetime import timedelta 5 | from data.save_img import save_img 6 | from collections import OrderedDict 7 | 8 | 9 | def initialize_distribution_training(backend="nccl", init_method="env://"): 10 | dist.init_process_group( 11 | backend=backend, init_method=init_method, timeout=timedelta(seconds=3000) 12 | ) 13 | rank = dist.get_rank() 14 | world_size = dist.get_world_size() 15 | # Get local rank from environment variable 16 | local_rank = int(os.environ["LOCAL_RANK"]) 17 | local_world_size = int(os.environ["WORLD_SIZE"]) 18 | # Set the current GPU for this process 19 | torch.cuda.set_device(local_rank) 20 | device = torch.device(f"cuda:{local_rank}") 21 | return rank, world_size, local_rank, local_world_size, device 22 | 23 | 24 | def distribute_class(nclass, debug=False): 25 | if debug: 26 | nclass = max(nclass // 100, 10) # Reduce the number of classes for debugging 27 | 
classes_per_process = nclass // dist.get_world_size() # Distribute classes evenly 28 | remainder = ( 29 | nclass % dist.get_world_size() 30 | ) # Handle remainder for unequal distribution 31 | start_class = ( 32 | dist.get_rank() * classes_per_process 33 | ) # Start class index for this rank 34 | end_class = start_class + classes_per_process # End class index for this rank 35 | if dist.get_rank() == dist.get_world_size() - 1: 36 | end_class += remainder # Add remainder to the last rank's class range 37 | class_list = list(range(start_class, end_class)) # List of classes for this rank 38 | for rank in range(dist.get_world_size()): 39 | if dist.get_rank() == rank: 40 | print( 41 | f"==========================Rank {dist.get_rank()} has classes {class_list}==========================" 42 | ) 43 | else: 44 | dist.barrier() 45 | return class_list 46 | 47 | 48 | def load_state_dict(state_dict_path, model): 49 | state_dict = torch.load(state_dict_path, map_location="cpu") 50 | # Remove `module.` prefix from keys if it exists 51 | new_state_dict = OrderedDict() 52 | for key, value in state_dict.items(): 53 | new_key = key.replace("module.", "") # Remove 'module.' prefix 54 | new_state_dict[new_key] = value 55 | model.load_state_dict(new_state_dict) 56 | 57 | 58 | def gather_save_visualize(synset, args, iteration=None): 59 | temp_save_dir = os.path.join( 60 | args.save_dir, "temp_rank_data" 61 | ) # Temporary directory to save rank data 62 | os.makedirs(temp_save_dir, exist_ok=True) 63 | save_iteration = ( 64 | (iteration + 1) if iteration is not None else "init" 65 | ) # Set iteration name 66 | temp_file_path = os.path.join( 67 | temp_save_dir, f"temp_rank_{args.rank}_{save_iteration}.pt" 68 | ) 69 | torch.save( 70 | [synset.data.detach().cpu(), synset.targets.cpu()], temp_file_path 71 | ) # Save data and targets for this rank 72 | dist.barrier() # Synchronize all processes 73 | if args.rank == 0: 74 | all_data = [] 75 | all_targets = [] 76 | for r in range(args.world_size): 77 | temp_file_path = os.path.join( 78 | temp_save_dir, f"temp_rank_{r}_{save_iteration}.pt" 79 | ) 80 | data, targets = torch.load(temp_file_path) # Load data from all ranks 81 | all_data.append(data) 82 | all_targets.append(targets) 83 | all_data = torch.cat(all_data, dim=0) # Concatenate data from all ranks 84 | all_targets = torch.cat( 85 | all_targets, dim=0 86 | ) # Concatenate targets from all ranks 87 | args.logger(f"the shape of saved data {all_data.shape}") 88 | args.logger(f"the shape of saved target {all_targets.shape}") 89 | os.makedirs(args.save_dir, exist_ok=True) 90 | save_img( 91 | os.path.join(args.save_dir, "images", f"img_{save_iteration}.png"), 92 | all_data, 93 | unnormalize=False, 94 | dataname=args.dataset, 95 | ) # Save images 96 | data_save_path = os.path.join( 97 | args.save_dir, "distilled_data", f"data_{save_iteration}.pt" 98 | ) 99 | torch.save( 100 | [all_data, all_targets], data_save_path 101 | ) # Save concatenated data and targets 102 | args.logger(f"All data saved at iteration {save_iteration}.") 103 | # Clean up temporary directory 104 | for r in range(args.world_size): 105 | temp_file_path = os.path.join( 106 | temp_save_dir, f"temp_rank_{r}_{save_iteration}.pt" 107 | ) 108 | os.remove(temp_file_path) # Remove temporary files 109 | os.rmdir(temp_save_dir) # Remove the temporary directory 110 | else: 111 | pass 112 | 113 | 114 | def sync_distributed_metric(metric): 115 | device = torch.device( 116 | f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" 117 | ) 118 | if 
isinstance(metric, list): 119 | # Convert metric to tensor if it isn't already 120 | metric_tensors = [ 121 | torch.tensor(m, device=device) if not isinstance(m, torch.Tensor) else m 122 | for m in metric 123 | ] 124 | # Use all_reduce to synchronize each tensor across ranks 125 | for m in metric_tensors: 126 | dist.all_reduce(m, op=dist.ReduceOp.SUM) 127 | # Return average for each metric 128 | return [m.item() / dist.get_world_size() for m in metric_tensors] 129 | else: 130 | # Single metric 131 | if not isinstance(metric, torch.Tensor): 132 | metric = torch.tensor(metric, device=device) 133 | # Use all_reduce to synchronize the metric 134 | dist.all_reduce(metric, op=dist.ReduceOp.SUM) 135 | # Return the average value 136 | return metric.item() / dist.get_world_size() 137 | -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | from data.dataset_statistics import MEANS, STDS 3 | from data.augment import ColorJitter, Lighting 4 | 5 | 6 | def transform_cifar(augment=False, from_tensor=False, normalize=True): 7 | if not augment: 8 | aug = [] 9 | else: 10 | aug = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] 11 | print("Dataset with basic Cifar augmentation") 12 | 13 | if from_tensor: 14 | cast = [] 15 | else: 16 | cast = [transforms.ToTensor()] 17 | 18 | if normalize: 19 | normal_fn = [transforms.Normalize(mean=MEANS["cifar"], std=STDS["cifar"])] 20 | else: 21 | normal_fn = [] 22 | train_transform = transforms.Compose(cast + aug + normal_fn) 23 | test_transform = transforms.Compose(cast + normal_fn) 24 | 25 | return train_transform, test_transform 26 | 27 | 28 | def transform_svhn(augment=False, from_tensor=False, normalize=True): 29 | if not augment: 30 | aug = [] 31 | else: 32 | aug = [transforms.RandomCrop(32, padding=4)] 33 | print("Dataset with basic SVHN augmentation") 34 | 35 | if from_tensor: 36 | cast = [] 37 | else: 38 | cast = [transforms.ToTensor()] 39 | 40 | if normalize: 41 | normal_fn = [transforms.Normalize(mean=MEANS["svhn"], std=STDS["svhn"])] 42 | else: 43 | normal_fn = [] 44 | 45 | train_transform = transforms.Compose(cast + aug + normal_fn) 46 | test_transform = transforms.Compose(cast + normal_fn) 47 | 48 | return train_transform, test_transform 49 | 50 | 51 | def transform_mnist(augment=False, from_tensor=False, normalize=True): 52 | if not augment: 53 | aug = [] 54 | else: 55 | aug = [transforms.RandomCrop(28, padding=4)] 56 | print("Dataset with basic MNIST augmentation") 57 | 58 | if from_tensor: 59 | cast = [] 60 | else: 61 | cast = [transforms.ToTensor()] 62 | 63 | if normalize: 64 | normal_fn = [transforms.Normalize(mean=MEANS["mnist"], std=STDS["mnist"])] 65 | else: 66 | normal_fn = [] 67 | 68 | train_transform = transforms.Compose(cast + aug + normal_fn) 69 | test_transform = transforms.Compose(cast + normal_fn) 70 | 71 | return train_transform, test_transform 72 | 73 | 74 | def transform_fashion(augment=False, from_tensor=False, normalize=True): 75 | if not augment: 76 | aug = [] 77 | else: 78 | aug = [transforms.RandomCrop(28, padding=4)] 79 | print("Dataset with basic FashionMNIST augmentation") 80 | 81 | if from_tensor: 82 | cast = [] 83 | else: 84 | cast = [transforms.ToTensor()] 85 | 86 | if normalize: 87 | normal_fn = [transforms.Normalize(mean=MEANS["fashion"], std=STDS["fashion"])] 88 | else: 89 | normal_fn = [] 90 | 91 | train_transform = 
transforms.Compose(cast + aug + normal_fn) 92 | test_transform = transforms.Compose(cast + normal_fn) 93 | 94 | return train_transform, test_transform 95 | 96 | 97 | def transform_tiny(augment=False, from_tensor=False, normalize=True): 98 | if not augment: 99 | aug = [] 100 | else: 101 | aug = [transforms.RandomCrop(64, padding=4), transforms.RandomHorizontalFlip()] 102 | print("Dataset with basic Cifar augmentation") 103 | 104 | if from_tensor: 105 | cast = [] 106 | else: 107 | cast = [transforms.ToTensor()] 108 | 109 | if normalize: 110 | normal_fn = [ 111 | transforms.Normalize(mean=MEANS["tinyimagenet"], std=STDS["tinyimagenet"]) 112 | ] 113 | else: 114 | normal_fn = [] 115 | 116 | train_transform = transforms.Compose(cast + aug + normal_fn) 117 | test_transform = transforms.Compose(cast + normal_fn) 118 | 119 | return train_transform, test_transform 120 | 121 | 122 | def transform_imagenet( 123 | size=-1, 124 | augment=False, 125 | from_tensor=False, 126 | normalize=True, 127 | rrc=True, 128 | rrc_size=-1, 129 | device="cpu", 130 | ): 131 | if size > 0: 132 | resize_train = [transforms.Resize(size), transforms.CenterCrop(size)] 133 | resize_test = [transforms.Resize(size), transforms.CenterCrop(size)] 134 | # print(f"Resize and crop training images to {size}") 135 | elif size == 0: 136 | resize_train = [] 137 | resize_test = [] 138 | assert rrc_size > 0, "Set RRC size!" 139 | else: 140 | resize_train = [transforms.RandomResizedCrop(224)] 141 | resize_test = [transforms.Resize(256), transforms.CenterCrop(224)] 142 | 143 | if not augment: 144 | aug = [] 145 | # print("Loader with DSA augmentation") 146 | else: 147 | jittering = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4) 148 | lighting = Lighting( 149 | alphastd=0.1, 150 | eigval=[0.2175, 0.0188, 0.0045], 151 | eigvec=[ 152 | [-0.5675, 0.7192, 0.4009], 153 | [-0.5808, -0.0045, -0.8140], 154 | [-0.5836, -0.6948, 0.4203], 155 | ], 156 | device=device, 157 | ) 158 | aug = [transforms.RandomHorizontalFlip(), jittering, lighting] 159 | 160 | if rrc and size >= 0: 161 | if rrc_size == -1: 162 | rrc_size = size 163 | rrc_fn = transforms.RandomResizedCrop(rrc_size, scale=(0.5, 1.0)) 164 | aug = [rrc_fn] + aug 165 | print("Dataset with basic imagenet augmentation and RRC") 166 | else: 167 | print("Dataset with basic imagenet augmentation") 168 | 169 | if from_tensor: 170 | cast = [] 171 | else: 172 | cast = [transforms.ToTensor()] 173 | 174 | if normalize: 175 | normal_fn = [transforms.Normalize(mean=MEANS["imagenet"], std=STDS["imagenet"])] 176 | else: 177 | normal_fn = [] 178 | 179 | train_transform = transforms.Compose(resize_train + cast + aug + normal_fn) 180 | test_transform = transforms.Compose(resize_test + cast + normal_fn) 181 | 182 | return train_transform, test_transform 183 | -------------------------------------------------------------------------------- /condenser/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from utils.experiment_tracker import get_time 5 | from utils.diffaug import DiffAug 6 | from utils.utils import define_model 7 | from utils.ddp import load_state_dict 8 | import warnings 9 | from utils.train_val import train_epoch, validate, train_epoch_softlabel 10 | 11 | warnings.filterwarnings("ignore") 12 | import torch.nn.functional as F 13 | import torch.distributed as dist 14 | from tqdm import tqdm 15 | import os 16 | 17 | 18 | def SoftCrossEntropy(inputs, target, 
temperature=1.0, reduction="average"): 19 | input_log_likelihood = -F.log_softmax(inputs / temperature, dim=1) 20 | target_log_likelihood = F.softmax(target / temperature, dim=1) 21 | batch = inputs.shape[0] 22 | loss = torch.sum(torch.mul(input_log_likelihood, target_log_likelihood)) / batch 23 | return loss 24 | 25 | 26 | # loss_function_kl = nn.KLDivLoss(reduction="batchmean") 27 | def evaluate_syn_data(args, model, train_loader, val_loader, logger=None): 28 | if args.softlabel: 29 | teacher_model = define_model( 30 | args.dataset, 31 | args.norm_type, 32 | args.net_type, 33 | args.nch, 34 | args.depth, 35 | args.width, 36 | args.nclass, 37 | args.logger, 38 | args.size, 39 | ).to(args.device) 40 | teacher_path = os.path.join(args.pretrain_dir, f"premodel0_trained.pth.tar") 41 | load_state_dict(teacher_path, teacher_model) 42 | train_criterion_sl = SoftCrossEntropy 43 | train_criterion = nn.CrossEntropyLoss().cuda() 44 | val_criterion = nn.CrossEntropyLoss().cuda() 45 | if args.eval_optimizer.lower() == "adamw": 46 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.adamw_lr) 47 | if logger and dist.get_rank() == 0: 48 | logger(f"Using AdamW optimizer with learning rate: {args.adamw_lr}") 49 | elif args.eval_optimizer.lower() == "sgd": 50 | optimizer = torch.optim.SGD( 51 | model.parameters(), lr=args.lr, momentum=args.momentum 52 | ) 53 | if logger and dist.get_rank() == 0: 54 | logger(f"Using SGD optimizer with learning rate: {args.lr}") 55 | scheduler = optim.lr_scheduler.MultiStepLR( 56 | optimizer, 57 | milestones=[ 58 | args.evaluation_epochs // 5, 59 | 2 * args.evaluation_epochs // 5, 60 | 3 * args.evaluation_epochs // 5, 61 | 4 * args.evaluation_epochs // 5, 62 | ], 63 | gamma=0.5, 64 | ) 65 | # scheduler = optim.lr_scheduler.MultiStepLR( 66 | # optimizer, milestones=[args.evaluation_epochs//2], gamma=0.1) 67 | 68 | best_acc1, best_acc5 = 0, 0 69 | acc1, acc5 = 0, 0 70 | model = model.cuda() 71 | model = torch.nn.parallel.DistributedDataParallel( 72 | model, device_ids=[args.rank], output_device=args.rank 73 | ) 74 | 75 | if args.dsa: 76 | aug = DiffAug(strategy=args.dsa_strategy, batch=False) 77 | if args.rank == 0: 78 | logger(f"Start training with DSA and {args.mixup} mixup") 79 | else: 80 | aug = None 81 | if args.rank == 0: 82 | logger(f"Start training with base augmentation and {args.mixup} mixup") 83 | pbar = tqdm(range(1, args.evaluation_epochs + 1)) 84 | for epoch in range(1, args.evaluation_epochs + 1): 85 | train_loader.sampler.set_epoch(epoch) 86 | if args.softlabel and epoch < ( 87 | args.evaluation_epochs - args.epoch_eval_interval 88 | ): 89 | acc1_tr, acc5_tr, loss_tr = train_epoch_softlabel( 90 | args, 91 | train_loader, 92 | model, 93 | teacher_model, 94 | train_criterion_sl, 95 | optimizer, 96 | epoch, 97 | aug, 98 | mixup=args.mixup, 99 | ) 100 | else: 101 | acc1_tr, acc5_tr, loss_tr = train_epoch( 102 | args, 103 | train_loader, 104 | model, 105 | train_criterion, 106 | optimizer, 107 | epoch, 108 | aug, 109 | mixup=args.mixup, 110 | ) 111 | if args.rank == 0: 112 | pbar.set_description( 113 | f"[Epoch {epoch}/{args.evaluation_epochs}] (Train) Top1 {acc1_tr:.1f} Top5 {acc5_tr:.1f} Lr {optimizer.param_groups[0]['lr']} Loss {loss_tr:.3f}" 114 | ) 115 | pbar.update(1) 116 | if (epoch % args.epoch_print_freq == 0) and (logger is not None) and args.rank == 0: 117 | logger( 118 | "(Train) [Epoch {0}/{1}] {2} Top1 {top1:.1f} Top5 {top5:.1f} Loss {loss:.3f}".format( 119 | epoch, 120 | args.evaluation_epochs, 121 | get_time(), 122 | top1=acc1_tr, 123 | top5=acc5_tr, 124 | 
loss=loss_tr, 125 | ) 126 | ) 127 | 128 | if ( 129 | epoch % args.epoch_eval_interval == 0 130 | or epoch == args.evaluation_epochs 131 | or (epoch % (args.epoch_eval_interval / 50) == 0 and args.ipc > 50) 132 | ): 133 | acc1, acc5, loss_val = validate(val_loader, model, val_criterion) 134 | is_best = acc1 > best_acc1 135 | if is_best: 136 | best_acc1 = acc1 137 | best_acc5 = acc5 138 | if logger is not None and args.rank == 0: 139 | logger( 140 | "-------Eval Training Epoch [{} / {}] INFO--------".format( 141 | epoch, args.evaluation_epochs 142 | ) 143 | ) 144 | logger( 145 | f"Current accuracy (top-1 and 5): {acc1:.1f} {acc5:.1f}, loss: {loss_val:.3f}" 146 | ) 147 | logger( 148 | f"Best accuracy (top-1 and 5): {best_acc1:.1f} {best_acc5:.1f}" 149 | ) 150 | 151 | scheduler.step() 152 | 153 | return best_acc1, acc1 154 | -------------------------------------------------------------------------------- /pretrain/pretrain_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.distributed as dist 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | import sys 7 | 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 9 | import torch.optim as optim 10 | from utils.utils import define_model 11 | from utils.utils import get_loader 12 | from utils.train_val import train_epoch, validate 13 | from utils.diffaug import diffaug 14 | 15 | 16 | def get_available_model_id(pretrain_dir, model_id): 17 | while True: 18 | init_path = os.path.join(pretrain_dir, f"premodel{model_id}_init.pth.tar") 19 | trained_path = os.path.join(pretrain_dir, f"premodel{model_id}_trained.pth.tar") 20 | # Check if both files do not exist, if both are missing, return the current model_id 21 | if not os.path.exists(init_path) and not os.path.exists(trained_path): 22 | return model_id # Return the first available model_id 23 | model_id += 1 # If files exist, try the next model_id 24 | 25 | 26 | def count_existing_models(pretrain_dir): 27 | """ 28 | Count the number of initial model files (premodel{model_id}_init.pth.tar) 29 | that exist in pretrain_dir. 
30 | """ 31 | model_count = 0 32 | for filename in os.listdir(pretrain_dir): 33 | if filename.startswith("premodel") and filename.endswith("_init.pth.tar"): 34 | model_count += 1 # Increment count if the file matches the criteria 35 | 36 | return model_count # Return the count of matching files 37 | 38 | 39 | def main_worker(args): 40 | train_loader, val_loader, train_sampler = get_loader(args) 41 | 42 | for model_id in range(args.model_num): 43 | if count_existing_models(args.pretrain_dir) >= args.model_num: 44 | break 45 | model_id = get_available_model_id(args.pretrain_dir, model_id) 46 | if args.rank == 0: 47 | print(f"Training model {model_id + 1}/{args.model_num}") 48 | model = define_model( 49 | args.dataset, 50 | args.norm_type, 51 | args.net_type, 52 | args.nch, 53 | args.depth, 54 | args.width, 55 | args.nclass, 56 | args.logger, 57 | args.size, 58 | ).to(args.device) 59 | model = model.to(args.device) 60 | model = DDP(model, device_ids=[args.rank]) 61 | 62 | # Save initial model state 63 | init_path = os.path.join(args.pretrain_dir, f"premodel{model_id}_init.pth.tar") 64 | if args.rank == 0 and not os.path.exists(init_path): 65 | torch.save(model.state_dict(), init_path) 66 | print(f"Model {model_id} initial state saved at {init_path}") 67 | 68 | # Define loss function, optimizer, and scheduler 69 | criterion = torch.nn.CrossEntropyLoss().to(args.device) 70 | optimizer = optim.SGD( 71 | model.parameters(), 72 | lr=args.lr, 73 | momentum=args.momentum, 74 | weight_decay=args.weight_decay, 75 | ) 76 | scheduler = optim.lr_scheduler.MultiStepLR( 77 | optimizer, 78 | milestones=[2 * args.pertrain_epochs // 3, 5 * args.pertrain_epochs // 6], 79 | gamma=0.2, 80 | ) 81 | _, aug_rand = diffaug(args) 82 | for epoch in range(0, args.pertrain_epochs): 83 | start_time = time.time() 84 | train_sampler.set_epoch(epoch) 85 | train_acc1, train_acc5, train_loss = train_epoch( 86 | args, 87 | train_loader, 88 | model, 89 | criterion, 90 | optimizer, 91 | epoch, 92 | aug_rand, 93 | mixup=args.mixup, 94 | ) 95 | val_acc1, val_acc5, val_loss = validate(val_loader, model, criterion) 96 | epoch_time = time.time() - start_time 97 | if args.rank == 0: 98 | args.logger( 99 | "...[Epoch {:2d}] Train acc: {:.1f} (loss: {:.3f}), Val acc: {:.1f}, Time: {:.2f} seconds".format( 100 | model_id, epoch, train_acc1, train_loss, val_acc1, epoch_time 101 | ) 102 | ) 103 | scheduler.step() 104 | 105 | # Save trained model state 106 | trained_path = os.path.join( 107 | args.pretrain_dir, f"premodel{model_id}_trained.pth.tar" 108 | ) 109 | if args.rank == 0: 110 | torch.save(model.state_dict(), trained_path) 111 | print(f"Model {model_id} trained state saved at {trained_path}") 112 | 113 | dist.destroy_process_group() 114 | 115 | 116 | def main(): 117 | import os 118 | from utils.init_script import init_script 119 | import argparse 120 | from argsprocessor.args import ArgsProcessor 121 | 122 | parser = argparse.ArgumentParser(description="Configuration parser") 123 | parser.add_argument( 124 | "--debug", 125 | dest="debug", 126 | action="store_true", 127 | help="When dataset is very large , you should get it", 128 | ) 129 | parser.add_argument( 130 | "--config_path", 131 | type=str, 132 | required=True, 133 | help="Path to the YAML configuration file", 134 | ) 135 | parser.add_argument( 136 | "--run_mode", 137 | type=str, 138 | choices=["Condense", "Evaluation", "Pretrain"], 139 | default="Pretrain", 140 | help="Condense or Evaluation", 141 | ) 142 | parser.add_argument( 143 | "--gpu", 144 | type=str, 145 | 
default="0", 146 | required=True, 147 | help='GPUs to use, e.g., "0,1,2,3"', 148 | ) 149 | parser.add_argument( 150 | "-i", "--ipc", type=int, default=1, help="number of condensed data per class" 151 | ) 152 | parser.add_argument("--load_path", type=str, help="Path to load the synset") 153 | parser.add_argument("--tf32", action="store_true", default=True, help="Enable TF32") 154 | args = parser.parse_args() 155 | 156 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 157 | 158 | args_processor = ArgsProcessor(args.config_path) 159 | 160 | args = args_processor.add_args_from_yaml(args) 161 | 162 | init_script(args) 163 | 164 | main_worker(args) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /models/convnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvNet(nn.Module): 6 | def __init__( 7 | self, 8 | num_classes, 9 | net_norm="instance", 10 | net_depth=3, 11 | net_width=128, 12 | channel=3, 13 | net_act="relu", 14 | net_pooling="avgpooling", 15 | im_size=(32, 32), 16 | ): 17 | # print(f"Define Convnet (depth {net_depth}, width {net_width}, norm {net_norm})") 18 | super(ConvNet, self).__init__() 19 | if net_act == "sigmoid": 20 | self.net_act = nn.Sigmoid() 21 | elif net_act == "relu": 22 | self.net_act = nn.ReLU() 23 | elif net_act == "leakyrelu": 24 | self.net_act = nn.LeakyReLU(negative_slope=0.01) 25 | else: 26 | exit("unknown activation function: %s" % net_act) 27 | 28 | if net_pooling == "maxpooling": 29 | self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2) 30 | elif net_pooling == "avgpooling": 31 | self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2) 32 | elif net_pooling == "none": 33 | self.net_pooling = None 34 | else: 35 | exit("unknown net_pooling: %s" % net_pooling) 36 | 37 | self.depth = net_depth 38 | self.net_norm = net_norm 39 | 40 | self.layers, shape_feat = self._make_layers( 41 | channel, net_width, net_depth, net_norm, net_pooling, im_size 42 | ) 43 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2] 44 | self.classifier = nn.Linear(num_feat, num_classes) 45 | 46 | def forward(self, x, return_features=False): 47 | for d in 
range(self.depth): 48 | x = self.layers["conv"][d](x) 49 | if len(self.layers["norm"]) > 0: 50 | x = self.layers["norm"][d](x) 51 | x = self.layers["act"][d](x) 52 | if len(self.layers["pool"]) > 0: 53 | x = self.layers["pool"][d](x) 54 | 55 | # x = nn.functional.avg_pool2d(x, x.shape[-1]) 56 | out = x.view(x.shape[0], -1) 57 | logit = self.classifier(out) 58 | 59 | if return_features: 60 | return logit, out 61 | else: 62 | return logit 63 | 64 | def get_feature_from_layer(self, x, return_features=False): 65 | features = [] # Used to store the features from each layer 66 | for d in range(self.depth): 67 | x = self.layers["conv"][d](x) 68 | if len(self.layers["norm"]) > 0: 69 | x = self.layers["norm"][d](x) 70 | x = self.layers["act"][d](x) 71 | if len(self.layers["pool"]) > 0: 72 | x = self.layers["pool"][d](x) 73 | 74 | # If features are required to be returned, add the current layer's output to the list 75 | if return_features: 76 | features.append(x.clone()) 77 | 78 | # x = nn.functional.avg_pool2d(x, x.shape[-1]) 79 | out = x.view(x.shape[0], -1) 80 | logit = self.classifier(out) 81 | 82 | if return_features: 83 | return ( 84 | logit, 85 | features, 86 | ) # Return the classification result and the list of features 87 | else: 88 | return logit 89 | 90 | def get_feature( 91 | self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False 92 | ): 93 | if idx_to == -1: 94 | idx_to = idx_from 95 | features = [] 96 | 97 | for d in range(self.depth): 98 | x = self.layers["conv"][d](x) 99 | if self.net_norm: 100 | x = self.layers["norm"][d](x) 101 | x = self.layers["act"][d](x) 102 | if self.net_pooling: 103 | x = self.layers["pool"][d](x) 104 | features.append(x) 105 | if idx_to < len(features): 106 | return features[idx_from : idx_to + 1] 107 | 108 | if return_prob: 109 | out = x.view(x.size(0), -1) 110 | logit = self.classifier(out) 111 | prob = torch.softmax(logit, dim=-1) 112 | return features, prob 113 | elif return_logit: 114 | out = x.view(x.size(0), -1) 115 | logit = self.classifier(out) 116 | return features, logit 117 | else: 118 | return features[idx_from : idx_to + 1] 119 | 120 | def _get_normlayer(self, net_norm, shape_feat): 121 | # shape_feat = (c * h * w) 122 | if net_norm == "batch": 123 | norm = nn.BatchNorm2d(shape_feat[0], affine=True) 124 | elif net_norm == "layer": 125 | norm = nn.LayerNorm(shape_feat, elementwise_affine=True) 126 | elif net_norm == "instance": 127 | norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 128 | elif net_norm == "group": 129 | norm = nn.GroupNorm(4, shape_feat[0], affine=True) 130 | elif net_norm == "none": 131 | norm = None 132 | else: 133 | norm = None 134 | exit("unknown net_norm: %s" % net_norm) 135 | return norm 136 | 137 | def _make_layers( 138 | self, channel, net_width, net_depth, net_norm, net_pooling, im_size 139 | ): 140 | layers = {"conv": [], "norm": [], "act": [], "pool": []} 141 | 142 | in_channels = channel 143 | if im_size[0] == 28: 144 | im_size = (32, 32) 145 | shape_feat = [in_channels, im_size[0], im_size[1]] 146 | 147 | for d in range(net_depth): 148 | layers["conv"] += [ 149 | nn.Conv2d( 150 | in_channels, 151 | net_width, 152 | kernel_size=3, 153 | padding=3 if channel == 1 and d == 0 else 1, 154 | ) 155 | ] 156 | shape_feat[0] = net_width 157 | if net_norm != "none": 158 | layers["norm"] += [self._get_normlayer(net_norm, shape_feat)] 159 | layers["act"] += [self.net_act] 160 | in_channels = net_width 161 | if net_pooling != "none": 162 | layers["pool"] += [self.net_pooling] 163 | shape_feat[1] //= 2 
164 | shape_feat[2] //= 2 165 | 166 | layers["conv"] = nn.ModuleList(layers["conv"]) 167 | layers["norm"] = nn.ModuleList(layers["norm"]) 168 | layers["act"] = nn.ModuleList(layers["act"]) 169 | layers["pool"] = nn.ModuleList(layers["pool"]) 170 | layers = nn.ModuleDict(layers) 171 | 172 | return layers, shape_feat 173 | -------------------------------------------------------------------------------- /pretrain/pretrained_script_for_softlabel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.distributed as dist 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | import sys 7 | 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 9 | import torch.optim as optim 10 | from utils.utils import define_model 11 | from utils.utils import get_loader 12 | from utils.train_val import train_epoch, validate 13 | from utils.diffaug import diffaug 14 | 15 | 16 | def get_available_model_id(pretrain_dir, model_id): 17 | while True: 18 | init_path = os.path.join(pretrain_dir, f"premodel{model_id}_init.pth.tar") 19 | trained_path = os.path.join(pretrain_dir, f"premodel{model_id}_trained.pth.tar") 20 | # Check if both files do not exist, if both are missing, return the current model_id 21 | if not os.path.exists(init_path) and not os.path.exists(trained_path): 22 | return model_id # Return the first available model_id 23 | model_id += 1 # If the files exist, try the next model_id 24 | 25 | 26 | def count_existing_models(pretrain_dir): 27 | """ 28 | Count the number of initial model files (premodel{model_id}_init.pth.tar) 29 | that exist in pretrain_dir. 30 | """ 31 | model_count = 0 32 | for filename in os.listdir(pretrain_dir): 33 | if filename.startswith("premodel") and filename.endswith("_init.pth.tar"): 34 | model_count += 1 # Increment the count if the file matches the criteria 35 | 36 | return model_count # Return the count of matching files 37 | 38 | 39 | def main_worker(args): 40 | args.pretrain_dir = os.path.join(args.pretrain_dir, f"softlabel") 41 | os.makedirs( 42 | args.pretrain_dir, exist_ok=True 43 | ) # Create the directory if it doesn't exist 44 | train_loader, val_loader, train_sampler = get_loader(args) 45 | 46 | for model_id in range(args.model_num): 47 | if count_existing_models(args.pretrain_dir) >= args.model_num: 48 | break 49 | model_id = get_available_model_id(args.pretrain_dir, model_id) 50 | if args.rank == 0: 51 | print(f"Training model {model_id + 1}/{args.model_num}") 52 | model = define_model( 53 | args.dataset, 54 | args.norm_type, 55 | args.net_type, 56 | args.nch, 57 | args.depth, 58 | args.width, 59 | args.nclass, 60 | args.logger, 61 | args.size, 62 | ).to(args.device) 63 | model = model.to(args.device) 64 | model = DDP(model, device_ids=[args.rank]) 65 | 66 | # Save initial model state 67 | init_path = os.path.join(args.pretrain_dir, f"premodel{model_id}_init.pth.tar") 68 | if args.rank == 0 and not os.path.exists(init_path): 69 | torch.save(model.state_dict(), init_path) 70 | print(f"Model {model_id} initial state saved at {init_path}") 71 | 72 | # Define loss function, optimizer, and scheduler 73 | criterion = torch.nn.CrossEntropyLoss().to(args.device) 74 | optimizer = optim.SGD( 75 | model.parameters(), 76 | lr=args.lr, 77 | momentum=args.momentum, 78 | weight_decay=args.weight_decay, 79 | ) 80 | scheduler = optim.lr_scheduler.MultiStepLR( 81 | optimizer, 82 | milestones=[2 * args.pertrain_epochs // 3, 5 * 
args.pertrain_epochs // 6], 83 | gamma=0.2, 84 | ) 85 | _, aug_rand = diffaug(args) 86 | for epoch in range(0, args.pertrain_epochs): 87 | start_time = time.time() 88 | train_sampler.set_epoch(epoch) 89 | train_acc1, train_acc5, train_loss = train_epoch( 90 | args, 91 | train_loader, 92 | model, 93 | criterion, 94 | optimizer, 95 | epoch, 96 | aug_rand, 97 | mixup=args.mixup, 98 | ) 99 | val_acc1, val_acc5, val_loss = validate(val_loader, model, criterion) 100 | epoch_time = time.time() - start_time 101 | if args.rank == 0: 102 | args.logger( 103 | "...Model {} [Epoch {:2d}] Train acc: {:.1f} (loss: {:.3f}), Val acc: {:.1f}, Time: {:.2f} seconds".format( 104 | model_id, epoch, train_acc1, train_loss, val_acc1, epoch_time 105 | ) 106 | ) 107 | scheduler.step() 108 | 109 | # Save model state after each epoch 110 | training_path = os.path.join( 111 | args.pretrain_dir, f"premodel{model_id}_epoch_{epoch}.pth.tar" 112 | ) 113 | if args.rank == 0: 114 | torch.save(model.state_dict(), training_path) 115 | print( 116 | f"Model {model_id} in Epoch {epoch} trained state saved at {training_path}" 117 | ) 118 | 119 | # Save trained model state 120 | trained_path = os.path.join( 121 | args.pretrain_dir, f"premodel{model_id}_trained.pth.tar" 122 | ) 123 | if args.rank == 0: 124 | torch.save(model.state_dict(), trained_path) 125 | print(f"Model {model_id} trained state saved at {trained_path}") 126 | 127 | dist.destroy_process_group() 128 | 129 | 130 | def main(): 131 | import os 132 | from utils.init_script import init_script 133 | import argparse 134 | from argsprocessor.args import ArgsProcessor 135 | 136 | parser = argparse.ArgumentParser(description="Configuration parser") 137 | parser.add_argument( 138 | "--debug", 139 | dest="debug", 140 | action="store_true", 141 | help="Enable debug mode (useful when the dataset is very large)", 142 | ) 143 | parser.add_argument( 144 | "--config_path", 145 | type=str, 146 | required=True, 147 | help="Path to the YAML configuration file", 148 | ) 149 | parser.add_argument( 150 | "--run_mode", 151 | type=str, 152 | choices=["Condense", "Evaluation", "Pretrain"], 153 | default="Pretrain", 154 | help="Condense, Evaluation, or Pretrain", 155 | ) 156 | parser.add_argument( 157 | "--gpu", 158 | type=str, 159 | default="0", 160 | required=True, 161 | help='GPUs to use, e.g., "0,1,2,3"', 162 | ) 163 | parser.add_argument( 164 | "-i", "--ipc", type=int, default=1, help="number of condensed data per class" 165 | ) 166 | parser.add_argument("--load_path", type=str, help="Path to load the synset") 167 | args = parser.parse_args() 168 | 169 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 170 | 171 | args_processor = ArgsProcessor(args.config_path) 172 | 173 | args = args_processor.add_args_from_yaml(args) 174 | 175 | init_script(args) 176 | 177 | main_worker(args) 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /utils/experiment_tracker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import time 5 | import matplotlib 6 | 7 | matplotlib.use("Agg") 8 | __all__ = ["TimingTracker", "accuracy", "AverageMeter", "Logger", "get_time", "LossPlotter"] 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | class TimingTracker: 13 | def __init__(self, logger): 14 | self.print = logger 15 | self.timing_stats = {"data": 0, "aug": 0, "loss": 0, "backward": 0} 16 | 17 | def start_step(self): 18 | self.step_start_time = time.time() 19 | 20 | def record(self, phase): 21 |
current_time = time.time() 22 | self.timing_stats[phase] += current_time - self.step_start_time 23 | self.step_start_time = current_time 24 | 25 | def report(self, reset=True): 26 | total_time = sum(self.timing_stats.values()) 27 | summary = ", ".join( 28 | f"{key}:{value:.2f}s({value / total_time * 100:.1f}%)" 29 | for key, value in self.timing_stats.items() 30 | ) 31 | if reset: 32 | self.reset_stats() 33 | return summary 34 | 35 | def reset_stats(self): 36 | self.timing_stats = {key: 0 for key in self.timing_stats} 37 | 38 | 39 | def accuracy(output, target, topk=(1,)): 40 | """Computes the precision@k for the specified values of k""" 41 | maxk = max(topk) 42 | batch_size = target.size(0) 43 | 44 | _, pred = output.topk(maxk, 1, True, True) 45 | pred = pred.t() 46 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 47 | 48 | res = [] 49 | for k in topk: 50 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 51 | res.append(correct_k.mul_(100.0 / batch_size)) 52 | 53 | return res 54 | 55 | 56 | class AverageMeter(object): 57 | """Computes and stores the average and current value""" 58 | 59 | def __init__(self): 60 | self.reset() 61 | 62 | def reset(self): 63 | self.val = 0 64 | self.avg = 0 65 | self.sum = 0 66 | self.count = 0 67 | 68 | def update(self, val, n=1): 69 | self.val = val 70 | self.sum += val * n 71 | self.count += n 72 | self.avg = self.sum / self.count 73 | 74 | 75 | class Logger: 76 | def __init__(self, path): 77 | self.logger = open(os.path.abspath(os.path.join(path, "print.log")), "w") 78 | 79 | def __call__(self, string, end="\n", print_=True): 80 | if print_: 81 | print("{}".format(string), end=end) 82 | if end == "\n": 83 | self.logger.write("{}\n".format(string)) 84 | else: 85 | self.logger.write("{} ".format(string)) 86 | self.logger.flush() 87 | 88 | 89 | def get_time(): 90 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())) 91 | 92 | 93 | class LossPlotter: 94 | def __init__( 95 | self, 96 | save_path, 97 | filename_pattern, 98 | dataset, 99 | ipc, 100 | dis_metrics, 101 | optimizer_info, 102 | ncfd_distribution="gussian", 103 | ): 104 | """ 105 | Initializes the LossPlotter with paths, dataset details, and optimizer settings. 106 | """ 107 | self.save_path = save_path 108 | self.filename_pattern = filename_pattern 109 | self.dataset = dataset 110 | self.ipc = ipc 111 | self.dis_metrics = dis_metrics 112 | self.ncfd_distribution = ncfd_distribution 113 | self.optimizer_info = optimizer_info 114 | 115 | # Initialize tracking lists for sigma values and loss/accuracy data 116 | self.sigma_history = [] 117 | self.loss_match_data = [] 118 | self.loss_calib_data = [] 119 | self.acc_data = {} 120 | 121 | # Create the save directory if it doesn't exist 122 | if not os.path.exists(self.save_path): 123 | os.makedirs(self.save_path) 124 | 125 | def _get_optimizer_str(self): 126 | """Generates a string representing the optimizer information.""" 127 | opt_type = self.optimizer_info["type"].upper() 128 | lr = self.optimizer_info["lr"] 129 | if opt_type in ["ADAM", "ADAMW"]: 130 | return f"{opt_type}(lr={lr:.4f}, wd={self.optimizer_info['weight_decay']})" 131 | return f"{opt_type}(lr={lr:.4f})" 132 | 133 | def update_sigma(self, sigma): 134 | """ 135 | Updates the sigma history with the new sigma value. 136 | 137 | Parameters: 138 | sigma : np.ndarray or torch.Tensor 139 | The sigma value for the current iteration. 
140 | """ 141 | self.sigma_history.append(sigma) 142 | 143 | def update_match_loss(self, loss): 144 | """ 145 | Updates the match loss data. 146 | 147 | Parameters: 148 | loss : torch.Tensor 149 | The loss value for the current iteration. 150 | """ 151 | self.loss_match_data.append(loss) 152 | 153 | def update_calib_loss(self, loss): 154 | """ 155 | Updates the calibration loss data. 156 | 157 | Parameters: 158 | loss : torch.Tensor 159 | The calibration loss value for the current iteration. 160 | """ 161 | self.loss_calib_data.append(loss) 162 | 163 | def plot_and_save_loss_curve(self): 164 | """ 165 | Plots and saves the loss and accuracy trends. 166 | """ 167 | # Check if there is any data to plot 168 | has_loss_data = len(self.loss_match_data) > 0 169 | has_calib_data = len(self.loss_calib_data) > 0 170 | has_acc_data = len(self.acc_data) > 0 171 | 172 | if not has_loss_data and not has_acc_data and not has_calib_data: 173 | print("No loss or accuracy data to plot.") 174 | return 175 | 176 | # Create a figure and axis for plotting 177 | fig, ax1 = plt.subplots(figsize=(8, 5)) 178 | 179 | # Plot the match loss if available 180 | if has_loss_data: 181 | color = "tab:red" 182 | ax1.set_xlabel("Iteration") 183 | ax1.set_ylabel("Loss (Match)", color=color) 184 | ax1.plot( 185 | range(len(self.loss_match_data)), 186 | self.loss_match_data, 187 | linestyle="-", 188 | color=color, 189 | ) 190 | ax1.tick_params(axis="y", labelcolor=color) 191 | 192 | # Plot the calibration loss if available 193 | if has_calib_data: 194 | color = "tab:green" 195 | if has_loss_data: 196 | # If match loss is plotted, use a second y-axis 197 | ax2 = ax1.twinx() 198 | ax2.set_ylabel("Loss (Calib)", color=color) 199 | ax2.plot( 200 | range(len(self.loss_calib_data)), 201 | self.loss_calib_data, 202 | linestyle="-", 203 | color=color, 204 | ) 205 | ax2.tick_params(axis="y", labelcolor=color) 206 | else: 207 | # If no match loss, plot calibration loss on the first axis 208 | ax1.set_ylabel("Loss (Calib)", color=color) 209 | ax1.plot( 210 | range(len(self.loss_calib_data)), 211 | self.loss_calib_data, 212 | linestyle="-", 213 | color=color, 214 | ) 215 | ax1.tick_params(axis="y", labelcolor=color) 216 | 217 | # Plot the accuracy if available 218 | if has_acc_data: 219 | iters = sorted(self.acc_data.keys()) 220 | acc_values = [self.acc_data[it] for it in iters] 221 | 222 | if has_loss_data or has_calib_data: 223 | # Create a second y-axis for accuracy if loss is also plotted 224 | ax2 = ax1.twinx() 225 | color = "tab:blue" 226 | ax2.set_ylabel("Validation Mean Accuracy", color=color) 227 | ax2.plot(iters, acc_values, linestyle="--", color=color) 228 | ax2.tick_params(axis="y", labelcolor=color) 229 | else: 230 | # If no loss data, plot accuracy on the first axis 231 | color = "tab:blue" 232 | ax1.set_ylabel("Validation Mean Accuracy", color=color) 233 | ax1.plot(iters, acc_values, linestyle="--", color=color) 234 | ax1.tick_params(axis="y", labelcolor=color) 235 | 236 | # Set the title of the plot with dataset and optimizer information 237 | plt.title( 238 | f"{self.dataset} - IPC {self.ipc} - {self.dis_metrics}\n" 239 | f"{self.ncfd_distribution.capitalize()} - {self._get_optimizer_str()}" 240 | ) 241 | 242 | fig.tight_layout() 243 | 244 | # Save the plot as a PNG file 245 | file_name = os.path.join( 246 | self.save_path, f"{self.filename_pattern}_loss_acc.png" 247 | ) 248 | plt.savefig(file_name) 249 | plt.close() 250 | -------------------------------------------------------------------------------- 
/utils/train_val.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | from utils.experiment_tracker import AverageMeter, accuracy 5 | from utils.mix_cut_up import random_indices, rand_bbox 6 | from utils.ddp import sync_distributed_metric 7 | import torch.nn.functional as F 8 | 9 | 10 | def train_epoch( 11 | args, train_loader, model, criterion, optimizer, epoch, aug=None, mixup="cut" 12 | ): 13 | batch_time = AverageMeter() 14 | data_time = AverageMeter() 15 | losses = AverageMeter() 16 | top1 = AverageMeter() 17 | top5 = AverageMeter() 18 | model.train() 19 | end = time.time() 20 | for i, (input, target) in enumerate(train_loader): 21 | input = input.cuda(non_blocking=True) 22 | target = target.cuda(non_blocking=True) 23 | 24 | data_time.update(time.time() - end) 25 | 26 | if aug is not None: 27 | with torch.no_grad(): 28 | input = aug(input) 29 | r = np.random.rand(1) 30 | if r < args.mix_p and mixup == "cut": 31 | lam = np.random.beta(args.beta, args.beta) 32 | rand_index = random_indices(target, nclass=args.nclass) 33 | target_b = target[rand_index] 34 | bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam) 35 | input[:, :, bbx1:bbx2, bby1:bby2] = input[ 36 | rand_index, :, bbx1:bbx2, bby1:bby2 37 | ] 38 | ratio = 1 - ( 39 | (bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]) 40 | ) 41 | output = model(input) 42 | loss = criterion(output, target) * ratio + criterion(output, target_b) * ( 43 | 1.0 - ratio 44 | ) 45 | else: 46 | output = model(input) 47 | loss = criterion(output, target) 48 | acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) 49 | 50 | losses.update(loss.item(), input.size(0)) 51 | top1.update(acc1.item(), input.size(0)) 52 | top5.update(acc5.item(), input.size(0)) 53 | 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | 58 | batch_time.update(time.time() - end) 59 | end = time.time() 60 | 61 | return sync_distributed_metric([top1.avg, top5.avg, losses.avg]) 62 | 63 | 64 | def get_softlabel(img, teacher_model, target=None): 65 | # Get the soft labels 66 | softlabel = teacher_model(img).detach() # [n, class] 67 | 68 | # If target is None, directly return the soft labels 69 | if target is None: 70 | return softlabel 71 | 72 | # Get the predicted class for each sample in the soft labels 73 | predicted = torch.argmax(softlabel, dim=1) # [n] 74 | 75 | # Find the indices of misclassified samples 76 | incorrect_indices = predicted != target # [n], True indicates misclassified samples 77 | 78 | # Replace the misclassified parts with the correct labels 79 | # Initialize the soft labels to all zeros 80 | corrected_softlabel = softlabel.clone() 81 | corrected_softlabel[incorrect_indices] = ( 82 | 0 # Set all class probabilities to 0 for misclassified samples 83 | ) 84 | corrected_softlabel[incorrect_indices, target[incorrect_indices]] = ( 85 | 1 # Set the correct class probability to 1 86 | ) 87 | 88 | return corrected_softlabel 89 | 90 | 91 | def train_epoch_softlabel( 92 | args, 93 | train_loader, 94 | model, 95 | teacher_model, 96 | criterion, 97 | optimizer, 98 | epoch, 99 | aug=None, 100 | mixup="cut", 101 | ): 102 | batch_time = AverageMeter() 103 | data_time = AverageMeter() 104 | losses = AverageMeter() 105 | top1 = AverageMeter() 106 | top5 = AverageMeter() 107 | model.train() 108 | end = time.time() 109 | teacher_model.eval() 110 | model.train() 111 | for i, (input, target) in enumerate(train_loader): 112 | input = 
input.cuda(non_blocking=True) 113 | target = target.cuda(non_blocking=True) 114 | with torch.no_grad(): 115 | # soft_label = get_softlabel(input,teacher_model,target).detach() 116 | soft_label = teacher_model(input).detach() 117 | soft_label = F.softmax(soft_label / args.temperature, dim=1) 118 | data_time.update(time.time() - end) 119 | 120 | if aug is not None: 121 | with torch.no_grad(): 122 | input = aug(input) 123 | r = np.random.rand(1) 124 | if r < args.mix_p and mixup == "cut": 125 | lam = np.random.beta(args.beta, args.beta) 126 | rand_index = random_indices(target, nclass=args.nclass) 127 | target_b = target[rand_index] 128 | bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam) 129 | input[:, :, bbx1:bbx2, bby1:bby2] = input[ 130 | rand_index, :, bbx1:bbx2, bby1:bby2 131 | ] 132 | ratio = 1 - ( 133 | (bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]) 134 | ) 135 | output = model(input) 136 | output = F.log_softmax(output / args.temperature, dim=1) 137 | loss = criterion(output, soft_label, args.temperature) * ratio + criterion( 138 | output, soft_label[rand_index, :], args.temperature 139 | ) * (1.0 - ratio) 140 | else: 141 | output = model(input) 142 | loss = criterion(output, soft_label, args.temperature) 143 | acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) 144 | 145 | losses.update(loss.item(), input.size(0)) 146 | top1.update(acc1.item(), input.size(0)) 147 | top5.update(acc5.item(), input.size(0)) 148 | 149 | optimizer.zero_grad() 150 | loss.backward() 151 | optimizer.step() 152 | 153 | batch_time.update(time.time() - end) 154 | end = time.time() 155 | 156 | return sync_distributed_metric([top1.avg, top5.avg, losses.avg]) 157 | 158 | 159 | def train_epoch_softlabel( 160 | args, 161 | train_loader, 162 | model, 163 | teacher_model, 164 | criterion, 165 | optimizer, 166 | epoch, 167 | aug=None, 168 | mixup="cut", 169 | ): 170 | batch_time = AverageMeter() 171 | data_time = AverageMeter() 172 | losses = AverageMeter() 173 | top1 = AverageMeter() 174 | top5 = AverageMeter() 175 | model.train() 176 | end = time.time() 177 | teacher_model.eval() 178 | model.train() 179 | for i, (input, target) in enumerate(train_loader): 180 | input = input.cuda(non_blocking=True) 181 | target = target.cuda(non_blocking=True) 182 | with torch.no_grad(): 183 | soft_label = get_softlabel(input, teacher_model, target).detach() 184 | data_time.update(time.time() - end) 185 | 186 | if aug is not None: 187 | with torch.no_grad(): 188 | input = aug(input) 189 | r = np.random.rand(1) 190 | if r < args.mix_p and mixup == "cut": 191 | lam = np.random.beta(args.beta, args.beta) 192 | rand_index = random_indices(target, nclass=args.nclass) 193 | target_b = target[rand_index] 194 | bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam) 195 | input[:, :, bbx1:bbx2, bby1:bby2] = input[ 196 | rand_index, :, bbx1:bbx2, bby1:bby2 197 | ] 198 | ratio = 1 - ( 199 | (bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]) 200 | ) 201 | output = model(input) 202 | loss = criterion(output, soft_label) * ratio + criterion( 203 | output, soft_label[rand_index, :] 204 | ) * (1.0 - ratio) 205 | else: 206 | output = model(input) 207 | loss = criterion(output, soft_label) 208 | acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) 209 | losses.update(loss.item(), input.size(0)) 210 | top1.update(acc1.item(), input.size(0)) 211 | top5.update(acc5.item(), input.size(0)) 212 | 213 | optimizer.zero_grad() 214 | loss.backward() 215 | optimizer.step() 216 | 217 | 
batch_time.update(time.time() - end) 218 | end = time.time() 219 | 220 | return sync_distributed_metric([top1.avg, top5.avg, losses.avg]) 221 | 222 | 223 | def validate(val_loader, model, criterion): 224 | batch_time = AverageMeter() 225 | losses = AverageMeter() 226 | top1 = AverageMeter() 227 | top5 = AverageMeter() 228 | model.eval() 229 | end = time.time() 230 | for i, (input, target) in enumerate(val_loader): 231 | input = input.cuda() 232 | target = target.cuda() 233 | output = model(input) 234 | loss = criterion(output, target) 235 | acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) 236 | 237 | losses.update(loss.item(), input.size(0)) 238 | 239 | top1.update(acc1.item(), input.size(0)) 240 | top5.update(acc5.item(), input.size(0)) 241 | 242 | # measure elapsed time 243 | batch_time.update(time.time() - end) 244 | end = time.time() 245 | 246 | return sync_distributed_metric([top1.avg, top5.avg, losses.avg]) 247 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.datasets as datasets 3 | from data.dataset_statistics import IMG_EXTENSIONS, STDS, MEANS 4 | import numpy as np  # used by find_subclasses and _subset below 5 | 6 | class Data: 7 | def __init__(self, X_train, Y_train): 8 | self.X_train = X_train 9 | self.Y_train = Y_train 10 | 11 | self.n_pool = len(X_train) 12 | 13 | def get_class_data(self, c): 14 | idxs = torch.arange(self.n_pool) 15 | idxs_c = torch.where(self.Y_train[idxs] == c) 16 | idxs = idxs[idxs_c[0]] 17 | dst_train = Dataset(self.X_train[idxs], self.Y_train[idxs]) 18 | trainloader = torch.utils.data.DataLoader( 19 | dst_train, batch_size=256, shuffle=False, num_workers=0 20 | ) 21 | return idxs, trainloader 22 | 23 | 24 | class Dataset(torch.utils.data.Dataset): 25 | def __init__(self, images, labels): 26 | # images: NxCxHxW tensor 27 | self.images = images.float() 28 | self.targets = labels 29 | 30 | def __getitem__(self, index): 31 | sample = self.images[index] 32 | target = self.targets[index] 33 | return sample, target 34 | 35 | def __len__(self): 36 | return self.images.shape[0] 37 | 38 | 39 | class TensorDataset(torch.utils.data.Dataset): 40 | def __init__(self, images, labels, transform=None): 41 | # images: NxCxHxW tensor 42 | self.images = images.detach().float() 43 | self.targets = labels.detach() 44 | self.transform = transform 45 | 46 | def __getitem__(self, index): 47 | sample = self.images[index] 48 | if self.transform != None: 49 | sample = self.transform(sample) 50 | 51 | target = self.targets[index] 52 | return sample, target 53 | 54 | def __len__(self): 55 | return self.images.shape[0] 56 | 57 | 58 | class ImageFolder_mtt(datasets.DatasetFolder): 59 | def __init__( 60 | self, 61 | root, 62 | transform=None, 63 | target_transform=None, 64 | loader=datasets.folder.default_loader, 65 | is_valid_file=None, 66 | load_memory=False, 67 | load_transform=None, 68 | type="none", 69 | slct_type="random", 70 | ipc=-1, 71 | ): 72 | self.extensions = IMG_EXTENSIONS if is_valid_file is None else None 73 | super(ImageFolder_mtt, self).__init__( 74 | root, 75 | loader, 76 | self.extensions, 77 | transform=transform, 78 | target_transform=target_transform, 79 | is_valid_file=is_valid_file, 80 | ) 81 | 82 | # Override 83 | self.nclass = 10 84 | self.classes, self.class_to_idx = self.find_subclasses(type=type) 85 | 86 | self.samples = datasets.folder.make_dataset( 87 | self.root, self.class_to_idx, self.extensions, is_valid_file 88 | ) 89 | 90 | if ipc >
0: 91 | self.samples = self._subset(slct_type=slct_type, ipc=ipc) 92 | 93 | self.targets = [s[1] for s in self.samples] 94 | self.load_memory = load_memory 95 | self.load_transform = load_transform 96 | if self.load_memory: 97 | self.imgs = self._load_images(load_transform) 98 | else: 99 | self.imgs = self.samples 100 | 101 | def find_subclasses(self, type="none"): 102 | """Finds the class folders in a dataset.""" 103 | classes = [] 104 | # ['imagenette', 'imagewoof', 'imagemeow', 'imagesquawk', 'imagefruit', 'imageyellow'] 105 | if type != "none": 106 | with open("./imagenet_subset/class{}.txt".format(type), "r") as f: 107 | class_name = f.readlines() 108 | for c in class_name: 109 | c = c.split("\n")[0] 110 | classes.append(c) 111 | 112 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 113 | assert len(classes) == self.nclass 114 | 115 | return classes, class_to_idx 116 | 117 | def _subset(self, slct_type="random", ipc=10): 118 | n = len(self.samples) 119 | idx_class = [[] for _ in range(self.nclass)] 120 | for i in range(n): 121 | label = self.samples[i][1] 122 | idx_class[label].append(i) 123 | 124 | min_class = np.array([len(idx_class[c]) for c in range(self.nclass)]).min() 125 | print("# examples in the smallest class: ", min_class) 126 | assert ipc < min_class 127 | 128 | if slct_type == "random": 129 | indices = np.arange(n) 130 | else: 131 | raise AssertionError(f"selection type does not exist!") 132 | 133 | samples_subset = [] 134 | idx_class_slct = [[] for _ in range(self.nclass)] 135 | for i in indices: 136 | label = self.samples[i][1] 137 | if len(idx_class_slct[label]) < ipc: 138 | idx_class_slct[label].append(i) 139 | samples_subset.append(self.samples[i]) 140 | 141 | if len(samples_subset) == ipc * self.nclass: 142 | break 143 | 144 | return samples_subset 145 | 146 | def _load_images(self, transform=None): 147 | """Load images on memory""" 148 | imgs = [] 149 | for i, (path, _) in enumerate(self.samples): 150 | sample = self.loader(path) 151 | if transform != None: 152 | sample = transform(sample) 153 | imgs.append(sample) 154 | if i % 100 == 0: 155 | print(f"Image loading.. 
{i}/{len(self.samples)}", end="\r") 156 | 157 | print(" " * 50, end="\r") 158 | return imgs 159 | 160 | def __getitem__(self, index): 161 | if not self.load_memory: 162 | path = self.samples[index][0] 163 | sample = self.loader(path) 164 | else: 165 | sample = self.imgs[index] 166 | 167 | target = self.targets[index] 168 | if self.transform is not None: 169 | sample = self.transform(sample) 170 | if self.target_transform is not None: 171 | target = self.target_transform(target) 172 | 173 | return sample, target 174 | 175 | 176 | class ImageFolder(datasets.DatasetFolder): 177 | def __init__( 178 | self, 179 | root, 180 | transform=None, 181 | target_transform=None, 182 | loader=datasets.folder.default_loader, 183 | is_valid_file=None, 184 | load_memory=False, 185 | load_transform=None, 186 | nclass=100, 187 | phase=0, 188 | slct_type="random", 189 | ipc=-1, 190 | seed=-1, 191 | ): 192 | self.extensions = IMG_EXTENSIONS if is_valid_file is None else None 193 | super(ImageFolder, self).__init__( 194 | root, 195 | loader, 196 | self.extensions, 197 | transform=transform, 198 | target_transform=target_transform, 199 | is_valid_file=is_valid_file, 200 | ) 201 | if nclass < 1000: 202 | self.classes, self.class_to_idx = self.find_subclasses( 203 | nclass=nclass, phase=phase, seed=seed 204 | ) 205 | else: 206 | self.classes, self.class_to_idx = self.find_classes(self.root) 207 | self.nclass = nclass 208 | self.samples = datasets.folder.make_dataset( 209 | self.root, self.class_to_idx, self.extensions, is_valid_file 210 | ) 211 | if ipc > 0: 212 | self.samples = self._subset(slct_type=slct_type, ipc=ipc) 213 | self.targets = [s[1] for s in self.samples] 214 | self.load_memory = load_memory 215 | self.load_transform = load_transform 216 | if self.load_memory: 217 | self.imgs = self._load_images(load_transform) 218 | else: 219 | self.imgs = self.samples 220 | 221 | def find_subclasses(self, nclass=100, phase=0, seed=0): 222 | classes = [] 223 | phase = max(0, phase) 224 | cls_from = nclass * phase 225 | cls_to = nclass * (phase + 1) 226 | if seed == 0: 227 | with open("./imagenet_subset/class100.txt", "r") as f: 228 | class_name = f.readlines() 229 | for c in class_name: 230 | c = c.split("\n")[0] 231 | classes.append(c) 232 | classes = classes[cls_from:cls_to] 233 | else: 234 | np.random.seed(seed) 235 | class_indices = np.random.permutation(len(self.classes))[cls_from:cls_to] 236 | for i in class_indices: 237 | classes.append(self.classes[i]) 238 | 239 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 240 | assert len(classes) == nclass 241 | return classes, class_to_idx 242 | 243 | def _subset(self, slct_type="random", ipc=10): 244 | n = len(self.samples) 245 | idx_class = [[] for _ in range(self.nclass)] 246 | for i in range(n): 247 | label = self.samples[i][1] 248 | idx_class[label].append(i) 249 | min_class = np.array([len(idx_class[c]) for c in range(self.nclass)]).min() 250 | print("# examples in the smallest class: ", min_class) 251 | assert ipc < min_class 252 | if slct_type == "random": 253 | indices = np.arange(n) 254 | else: 255 | raise AssertionError(f"selection type does not exist!") 256 | samples_subset = [] 257 | idx_class_slct = [[] for _ in range(self.nclass)] 258 | for i in indices: 259 | label = self.samples[i][1] 260 | if len(idx_class_slct[label]) < ipc: 261 | idx_class_slct[label].append(i) 262 | samples_subset.append(self.samples[i]) 263 | 264 | if len(samples_subset) == ipc * self.nclass: 265 | break 266 | return samples_subset 267 | 268 | def 
_load_images(self, transform=None): 269 | imgs = [] 270 | for i, (path, _) in enumerate(self.samples): 271 | sample = self.loader(path) 272 | if transform != None: 273 | sample = transform(sample) 274 | imgs.append(sample) 275 | if i % 100 == 0: 276 | print(f"Image loading.. {i}/{len(self.samples)}", end="\r") 277 | print(" " * 50, end="\r") 278 | return imgs 279 | 280 | def __getitem__(self, index): 281 | if not self.load_memory: 282 | path = self.samples[index][0] 283 | sample = self.loader(path) 284 | else: 285 | sample = self.imgs[index] 286 | 287 | target = self.targets[index] 288 | if self.transform is not None: 289 | sample = self.transform(sample) 290 | if self.target_transform is not None: 291 | target = self.target_transform(target) 292 | 293 | return sample, target 294 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import math 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d( 11 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 12 | ) 13 | 14 | 15 | def normalization(inplanes, norm_type): 16 | if norm_type == "batch": 17 | bn = nn.BatchNorm2d(inplanes) 18 | elif norm_type == "instance": 19 | bn = nn.GroupNorm(inplanes, inplanes) 20 | else: 21 | raise AssertionError(f"Check normalization type! {norm_type}") 22 | return bn 23 | 24 | 25 | class IntroBlock(nn.Module): 26 | def __init__(self, size, planes, norm_type, nch=3): 27 | super(IntroBlock, self).__init__() 28 | self.size = size 29 | if size == "large": 30 | self.conv1 = nn.Conv2d( 31 | nch, planes, kernel_size=7, stride=2, padding=3, bias=False 32 | ) 33 | self.bn1 = normalization(planes, norm_type) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 36 | elif size == "mid": 37 | self.conv1 = nn.Conv2d( 38 | nch, planes, kernel_size=3, stride=1, padding=1, bias=False 39 | ) 40 | self.bn1 = normalization(planes, norm_type) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 43 | elif size == "small": 44 | self.conv1 = nn.Conv2d( 45 | nch, planes, kernel_size=3, stride=1, padding=1, bias=False 46 | ) 47 | self.bn1 = normalization(planes, norm_type) 48 | self.relu = nn.ReLU(inplace=True) 49 | else: 50 | raise AssertionError("Check network size type!") 51 | 52 | def forward(self, x): 53 | x = self.conv1(x) 54 | x = self.bn1(x) 55 | x = self.relu(x) 56 | if self.size != "small": 57 | x = self.pool(x) 58 | 59 | return x 60 | 61 | 62 | class BasicBlock(nn.Module): 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, norm_type="batch", stride=1, downsample=None): 66 | super(BasicBlock, self).__init__() 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = normalization(planes, norm_type) 69 | self.conv2 = conv3x3(planes, planes) 70 | self.bn2 = normalization(planes, norm_type) 71 | self.relu = nn.ReLU(inplace=True) 72 | 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | 86 | if self.downsample is not None: 87 | 
residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Bottleneck(nn.Module): 96 | expansion = 4 97 | 98 | def __init__(self, inplanes, planes, norm_type="batch", stride=1, downsample=None): 99 | super(Bottleneck, self).__init__() 100 | 101 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 102 | self.bn1 = normalization(planes, norm_type) 103 | self.conv2 = nn.Conv2d( 104 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 105 | ) 106 | self.bn2 = normalization(planes, norm_type) 107 | self.conv3 = nn.Conv2d( 108 | planes, planes * Bottleneck.expansion, kernel_size=1, bias=False 109 | ) 110 | self.bn3 = normalization(planes * Bottleneck.expansion, norm_type) 111 | self.relu = nn.ReLU(inplace=True) 112 | 113 | self.downsample = downsample 114 | self.stride = stride 115 | 116 | def forward(self, x): 117 | residual = x 118 | 119 | out = self.conv1(x) 120 | out = self.bn1(out) 121 | out = self.relu(out) 122 | 123 | out = self.conv2(out) 124 | out = self.bn2(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv3(out) 128 | out = self.bn3(out) 129 | 130 | if self.downsample is not None: 131 | residual = self.downsample(x) 132 | 133 | out += residual 134 | out = self.relu(out) 135 | 136 | return out 137 | 138 | 139 | class ResNet(nn.Module): 140 | def __init__(self, dataset, depth, num_classes, norm_type="batch", size=-1, nch=3): 141 | super(ResNet, self).__init__() 142 | self.dataset = dataset 143 | self.norm_type = norm_type 144 | 145 | if self.dataset.startswith("cifar") or (0 < size and size <= 64): 146 | self.net_size = "small" 147 | elif 64 < size and size <= 128: 148 | self.net_size = "mid" 149 | else: 150 | self.net_size = "large" 151 | 152 | # print(f"ResNet-{depth}-{self.net_size} norm: {self.norm_type}") 153 | if self.dataset.startswith("cifar"): 154 | self.inplanes = 32 155 | n = int((depth - 2) / 6) 156 | block = BasicBlock 157 | 158 | self.layer0 = IntroBlock(self.net_size, self.inplanes, norm_type, nch=nch) 159 | self.layer1 = self._make_layer(block, 32, n, stride=1) 160 | self.layer2 = self._make_layer(block, 64, n, stride=2) 161 | self.layer3 = self._make_layer(block, 128, n, stride=2) 162 | self.layer4 = self._make_layer(block, 256, n, stride=2) 163 | self.avgpool = nn.AvgPool2d(4) 164 | self.fc = nn.Linear(256 * block.expansion, num_classes) 165 | 166 | else: 167 | blocks = { 168 | 10: BasicBlock, 169 | 18: BasicBlock, 170 | 34: BasicBlock, 171 | 50: Bottleneck, 172 | 101: Bottleneck, 173 | 152: Bottleneck, 174 | 200: Bottleneck, 175 | } 176 | layers = { 177 | 10: [1, 1, 1, 1], 178 | 18: [2, 2, 2, 2], 179 | 34: [3, 4, 6, 3], 180 | 50: [3, 4, 6, 3], 181 | 101: [3, 4, 23, 3], 182 | 152: [3, 8, 36, 3], 183 | 200: [3, 24, 36, 3], 184 | } 185 | assert layers[ 186 | depth 187 | ], "invalid detph for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)" 188 | 189 | self.inplanes = 64 190 | 191 | self.layer0 = IntroBlock(self.net_size, self.inplanes, norm_type, nch=nch) 192 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 193 | self.layer2 = self._make_layer( 194 | blocks[depth], 128, layers[depth][1], stride=2 195 | ) 196 | self.layer3 = self._make_layer( 197 | blocks[depth], 256, layers[depth][2], stride=2 198 | ) 199 | self.layer4 = self._make_layer( 200 | blocks[depth], 512, layers[depth][3], stride=2 201 | ) 202 | self.avgpool = nn.AvgPool2d(7) 203 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 204 | 205 | for m in 
self.modules(): 206 | if isinstance(m, nn.Conv2d): 207 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 208 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 209 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm): 210 | m.weight.data.fill_(1) 211 | m.bias.data.zero_() 212 | 213 | def _make_layer(self, block, planes, blocks, stride=1): 214 | downsample = None 215 | if stride != 1 or self.inplanes != planes * block.expansion: 216 | downsample = nn.Sequential( 217 | nn.Conv2d( 218 | self.inplanes, 219 | planes * block.expansion, 220 | kernel_size=1, 221 | stride=stride, 222 | bias=False, 223 | ), 224 | normalization(planes * block.expansion, self.norm_type), 225 | ) 226 | 227 | layers = [] 228 | layers.append( 229 | block( 230 | self.inplanes, 231 | planes, 232 | norm_type=self.norm_type, 233 | stride=stride, 234 | downsample=downsample, 235 | ) 236 | ) 237 | self.inplanes = planes * block.expansion 238 | for i in range(1, blocks): 239 | layers.append(block(self.inplanes, planes, norm_type=self.norm_type)) 240 | 241 | return nn.Sequential(*layers) 242 | 243 | def forward(self, x): 244 | x = self.layer0(x) 245 | x = self.layer1(x) 246 | x = self.layer2(x) 247 | x = self.layer3(x) 248 | x = self.layer4(x) 249 | 250 | x = F.avg_pool2d(x, x.shape[-1]) 251 | x = x.view(x.size(0), -1) 252 | x = self.fc(x) 253 | 254 | return x 255 | 256 | def get_feature(self, x, idx_from, idx_to=-1): 257 | if idx_to == -1: 258 | idx_to = idx_from 259 | 260 | features = [] 261 | x = self.layer0(x) 262 | features.append(x) # starts from 0 263 | if idx_to < len(features): 264 | return features[idx_from : idx_to + 1] 265 | 266 | x = self.layer1(x) 267 | features.append(x) 268 | if idx_to < len(features): 269 | return features[idx_from : idx_to + 1] 270 | 271 | x = self.layer2(x) 272 | features.append(x) 273 | if idx_to < len(features): 274 | return features[idx_from : idx_to + 1] 275 | 276 | x = self.layer3(x) 277 | features.append(x) 278 | if idx_to < len(features): 279 | return features[idx_from : idx_to + 1] 280 | 281 | x = self.layer4(x) 282 | features.append(x) 283 | if idx_to < len(features): 284 | return features[idx_from : idx_to + 1] 285 | 286 | x = F.avg_pool2d(x, x.shape[-1]) 287 | x = x.view(x.size(0), -1) 288 | features.append(x) 289 | if idx_to < len(features): 290 | return features[idx_from : idx_to + 1] 291 | 292 | x = self.fc(x) 293 | features.append(x) # logit is 6 294 | return features[idx_from : idx_to + 1] 295 | 296 | 297 | if __name__ == "__main__": 298 | import torch 299 | 300 | dataset = "imagenet" 301 | depth = 10 302 | num_classes = 10 303 | size = 56 304 | norm_type = "instance" 305 | 306 | model = ResNet(dataset, depth, num_classes, size=size, norm_type=norm_type).cuda() 307 | print(model) 308 | 309 | data = torch.ones([128, 3, size, size]).to("cuda") 310 | output = model(data) 311 | print(output.shape) 312 | -------------------------------------------------------------------------------- /docs/static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 
0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff}
--------------------------------------------------------------------------------
/models/resnet_ap.py:
--------------------------------------------------------------------------------
1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2 |
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | import math
6 | from torch.nn.utils import spectral_norm
7 |
8 |
9 | def conv_stride1(in_planes, out_planes, kernel_size=3, norm_type="instance"):
10 |     "3x3 convolution with padding"
11 |     if norm_type in ["sn", "none"]:
12 |         bias = True
13 |     else:
14 |         bias = False
15 |
16 |     layer = nn.Conv2d(
17 |         in_planes,
18 |         out_planes,
19 |         kernel_size=kernel_size,
20 |         stride=1,
21 |         padding=kernel_size // 2,
22 |         bias=bias,
23 |     )
24 |
25 |     if norm_type == "sn":
26 |         return spectral_norm(layer)
27 |     else:
28 |         return layer
29 |
30 |
31 | class Null(nn.Module):
32 |     def __init__(self):
33 |         super(Null, self).__init__()
34 |
35 |     def forward(self, x):
36 |         return x
37 |
38 |
39 | def normalization(inplanes, norm_type):
40 |     if norm_type == "batch":
41 |         bn = nn.BatchNorm2d(inplanes)
42 |     elif norm_type == "instance":
43 |         bn = nn.GroupNorm(inplanes, inplanes)
44 |     elif norm_type in ["sn", "none"]:
45 |         bn = Null()
46 |     else:
47 |         raise AssertionError(f"Check normalization type! {norm_type}")
48 |     return bn
49 |
50 |
51 | class IntroBlock(nn.Module):
52 |     def __init__(self, size, planes, norm_type, nch=3):
53 |         super(IntroBlock, self).__init__()
54 |         self.size = size
55 |         if size == "large":
56 |             self.conv1 = conv_stride1(nch, planes, kernel_size=7, norm_type=norm_type)
57 |             self.bn1 = normalization(planes, norm_type)
58 |             self.relu = nn.ReLU(inplace=True)
59 |             self.pool = nn.AvgPool2d(kernel_size=4, stride=4)
60 |         elif size == "mid":
61 |             self.conv1 = conv_stride1(nch, planes, kernel_size=3, norm_type=norm_type)
62 |             self.bn1 = normalization(planes, norm_type)
63 |             self.relu = nn.ReLU(inplace=True)
64 |             self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
65 |         elif size == "small":
66 |             self.conv1 = conv_stride1(nch, planes, kernel_size=3, norm_type=norm_type)
67 |             self.bn1 = normalization(planes, norm_type)
68 |             self.relu = nn.ReLU(inplace=True)
69 |         else:
70 |             raise AssertionError("Check network size type!")
71 |
72 |     def forward(self, x):
73 |         x = self.conv1(x)
74 |         x = self.bn1(x)
75 |         x = self.relu(x)
76 |         if self.size != "small":
77 |             x = self.pool(x)
78 |
79 |         return x
80 |
81 |
82 | class BasicBlock(nn.Module):
83 |     expansion = 1
84 |
85 |     def __init__(self, inplanes, planes, norm_type="batch", stride=1, downsample=None):
86 |         super(BasicBlock, self).__init__()
87 |         self.conv1 = conv_stride1(
88 |             inplanes, planes, kernel_size=3, norm_type=norm_type
89 |         ) # Modification
90 |         self.bn1 = normalization(planes, norm_type)
91 |         self.conv2 = conv_stride1(planes, planes, kernel_size=3, norm_type=norm_type)
92 |         self.bn2 = normalization(planes, norm_type)
93 |         self.relu = nn.ReLU(inplace=True)
94 |
95 |         self.downsample = downsample
96 |         self.stride = stride
97 |
98 |     def forward(self, x):
99 |         residual = x
100 |
101 |         out = self.conv1(x)
102 |         out = self.bn1(out)
103 |         out = self.relu(out)
104 |         if self.stride != 1: # modification
105 |             out = F.avg_pool2d(out, kernel_size=self.stride, stride=self.stride)
106 |
107 |         out = self.conv2(out)
108 |         out = self.bn2(out)
109 |
110 |         if self.downsample is not None:
111 |             residual = self.downsample(x)
112 |
113 |         out += residual
114 |         out = self.relu(out)
115 |
116 |         return out
117 |
118 |
119 | class Bottleneck(nn.Module):
120 |     expansion = 4
121 |
122 |     def __init__(self, inplanes, planes, norm_type="batch", stride=1, downsample=None):
123 |         super(Bottleneck, self).__init__()
124 |
125 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
126 |         self.bn1 = normalization(planes, norm_type)
127 |         self.conv2 = nn.Conv2d(
128 |             planes, planes, kernel_size=3, padding=1, bias=False
129 |         ) # modification
130 |         self.bn2 = normalization(planes, norm_type)
131 |         self.conv3 = nn.Conv2d(
132 |             planes, planes * Bottleneck.expansion, kernel_size=1, bias=False
133 |         )
134 |         self.bn3 = normalization(planes * Bottleneck.expansion, norm_type)
135 |         self.relu = nn.ReLU(inplace=True)
136 |
137 |         self.downsample = downsample
138 |         self.stride = stride
139 |
140 |     def forward(self, x):
141 |         residual = x
142 |
143 |         out = self.conv1(x)
144 |         out = self.bn1(out)
145 |         out = self.relu(out)
146 |
147 |         out = self.conv2(out)
148 |         out = self.bn2(out)
149 |         out = self.relu(out)
150 |         if self.stride != 1: # modification
151 |             out = F.avg_pool2d(out, kernel_size=self.stride, stride=self.stride)
152 |
153 |         out = self.conv3(out)
154 |         out = self.bn3(out)
155 |
156 |         if self.downsample is not None:
157 |             residual = self.downsample(x)
158 |
159 |         out += residual
160 |         out = self.relu(out)
161 |
162 |         return out
163 |
164 |
165 | class ResNetAP(nn.Module):
166 |     def __init__(
167 |         self, dataset, depth, num_classes, width=1.0, norm_type="batch", size=-1, nch=3
168 |     ):
169 |         super(ResNetAP, self).__init__()
170 |         self.dataset = dataset
171 |         self.norm_type = norm_type
172 |         self.nch = nch
173 |
174 |         if self.dataset.startswith("cifar") or (0 < size and size <= 64):
175 |             self.net_size = "small"
176 |         elif 64 < size and size <= 128:
177 |             self.net_size = "mid"
178 |         else:
179 |             self.net_size = "large"
180 |
181 |         # print(f"ResNetAP-{depth}-{self.net_size} norm: {self.norm_type}, width: {width}")
182 |         if self.dataset.startswith("cifar"):
183 |             self.inplanes = 32
184 |             n = int((depth - 2) / 6)
185 |             block = BasicBlock
186 |
187 |             self.layer0 = IntroBlock(self.net_size, self.inplanes, norm_type, nch=nch)
188 |             self.layer1 = self._make_layer(block, 32, n, stride=1)
189 |             self.layer2 = self._make_layer(block, 64, n, stride=2)
190 |             self.layer3 = self._make_layer(block, 128, n, stride=2)
191 |             self.layer4 = self._make_layer(block, 256, n, stride=2)
192 |             self.avgpool = nn.AvgPool2d(4)
193 |             self.fc = nn.Linear(256 * block.expansion, num_classes)
194 |
195 |         else:
196 |             blocks = {
197 |                 10: BasicBlock,
198 |                 18: BasicBlock,
199 |                 34: BasicBlock,
200 |                 50: Bottleneck,
201 |                 101: Bottleneck,
202 |                 152: Bottleneck,
203 |                 200: Bottleneck,
204 |             }
205 |             layers = {
206 |                 10: [1, 1, 1, 1],
207 |                 18: [2, 2, 2, 2],
208 |                 34: [3, 4, 6, 3],
209 |                 50: [3, 4, 6, 3],
210 |                 101: [3, 4, 23, 3],
211 |                 152: [3, 8, 36, 3],
212 |                 200: [3, 24, 36, 3],
213 |             }
214 |             assert layers[depth], "invalid depth for ResNet"
215 |
216 |             self.inplanes = int(64 * width)
217 |             self.layer0 = IntroBlock(self.net_size, self.inplanes, norm_type, nch=nch)
218 |             nc = self.inplanes
219 |             self.layer1 = self._make_layer(blocks[depth], nc, layers[depth][0])
220 |             self.layer2 = self._make_layer(
221 |                 blocks[depth], nc * 2, layers[depth][1], stride=2
222 |             )
223 |             self.layer3 = self._make_layer(
224 |                 blocks[depth], nc * 4, layers[depth][2], stride=2
225 |             )
226 |             self.layer4 = self._make_layer(
227 |                 blocks[depth], nc * 8, layers[depth][3], stride=2
228 |             )
229 |             self.avgpool = nn.AvgPool2d(7)
230 |             self.fc = nn.Linear(self.inplanes, num_classes)
231 |
232 |         for m in self.modules():
233 |             if isinstance(m, nn.Conv2d):
234 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
235 |                 m.weight.data.normal_(0, math.sqrt(2.0 / n))
236 |             elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
237 |                 m.weight.data.fill_(1)
238 |                 m.bias.data.zero_()
239 |
240 |     def _make_layer(self, block, planes, blocks, stride=1):
241 |         downsample = None
242 |         if stride != 1 or self.inplanes != planes * block.expansion:
243 |             downsample = nn.Sequential(
244 |                 conv_stride1(
245 |                     self.inplanes,
246 |                     planes * block.expansion,
247 |                     kernel_size=1,
248 |                     norm_type=self.norm_type,
249 |                 ),
250 |                 nn.AvgPool2d(kernel_size=stride, stride=stride),
251 |                 normalization(planes * block.expansion, self.norm_type),
252 |             )
253 |
254 |         layers = []
255 |         layers.append(
256 |             block(
257 |                 self.inplanes,
258 |                 planes,
259 |                 norm_type=self.norm_type,
260 |                 stride=stride,
261 |                 downsample=downsample,
262 |             )
263 |         )
264 |         self.inplanes = planes * block.expansion
265 |
266 |         for i in range(1, blocks):
267 |             layers.append(block(self.inplanes, planes, norm_type=self.norm_type))
268 |
269 |         return nn.Sequential(*layers)
270 |
271 |     def forward(self, x, return_features=False):
272 |         x = self.layer0(x)
273 |         x = self.layer1(x)
274 |         x = self.layer2(x)
275 |         x = self.layer3(x)
276 |         x = self.layer4(x)
277 |
278 |         x = F.avg_pool2d(x, x.shape[-1])
279 |         x = x.view(x.size(0), -1)
280 |
281 |         out = self.fc(x)
282 |         if return_features:
283 |             return out, x
284 |         else:
285 |             return out
286 |
287 |     def get_feature(self, x, idx_from, idx_to=-1):
288 |         if idx_to == -1:
289 |             idx_to = idx_from
290 |
291 |         features = []
292 |         x = self.layer0(x)
293 |         features.append(x) # starts from 0
294 |         if idx_to < len(features):
295 |             return features[idx_from : idx_to + 1]
296 |
297 |         x = self.layer1(x)
298 |         features.append(x)
299 |         if idx_to < len(features):
300 |             return features[idx_from : idx_to + 1]
301 |
302 |         x = self.layer2(x)
303 |         features.append(x)
304 |         if idx_to < len(features):
305 |             return features[idx_from : idx_to + 1]
306 |
307 |         x = self.layer3(x)
308 |         features.append(x)
309 |         if idx_to < len(features):
310 |             return features[idx_from : idx_to + 1]
311 |
312 |         x = self.layer4(x)
313 |         features.append(x)
314 |         if idx_to < len(features):
315 |             return features[idx_from : idx_to + 1]
316 |
317 |         x = F.avg_pool2d(x, x.shape[-1])
318 |         x = x.view(x.size(0), -1)
319 |         features.append(x)
320 |         if idx_to < len(features):
321 |             return features[idx_from : idx_to + 1]
322 |
323 |         x = self.fc(x)
324 |         features.append(x) # logit is 6
325 |         return features[idx_from : idx_to + 1]
326 |
327 |     def get_feature_mutil(self, x, layer_num=7):
328 |         features = []
329 |         x = self.layer0(x)
330 |         features.append(x.view(x.size(0), -1))
331 |         if layer_num == 1:
332 |             return features
333 |
334 |         x = self.layer1(x)
335 |         features.append(x.view(x.size(0), -1))
336 |         if layer_num == 2:
337 |             return features
338 |
339 |         x = self.layer2(x)
340 |         features.append(x.view(x.size(0), -1))
341 |         if layer_num == 3:
342 |             return features
343 |
344 |         x = self.layer3(x)
345 |         features.append(x.view(x.size(0), -1))
346 |         if layer_num == 4:
347 |             return features
348 |
349 |         x = self.layer4(x)
350 |         features.append(x.view(x.size(0), -1))
351 |         if layer_num == 5:
352 |             return features
353 |
354 |         x = F.avg_pool2d(x, x.shape[-1])
355 |         x = x.view(x.size(0), -1)
356 |         features.append(x)
357 |         if layer_num == 6:
358 |             return features
359 |
360 |         x = self.fc(x)
361 |         features.append(x)
362 |         if layer_num == 7:
363 |             return features
364 |         # features has length 7: indices 0-5 are intermediate features, index 6 is the logit
365 |         return features
366 |
367 |
368 | if __name__ == "__main__":
369 |     import torch
370 |
371 |     dataset = "imagenet"
372 |     num_classes = 10
373 |     size = int(224 * 0.5)
374 |     depth = 10
375 |     width = 1.0
376 |     norm_type = "instance"
377 |     nch = 1
378 |
379 |     model = ResNetAP(
380 |         dataset,
381 |         depth,
382 |         num_classes,
383 |         size=size,
384 |         width=width,
385 |         norm_type=norm_type,
386 |         nch=nch,
387 |     ).cuda()
388 |     # print(model)
389 |     print(
390 |         "# model parameters: {:.1f}M".format(
391 |             sum([p.data.nelement() for p in model.parameters()]) / 10**6
392 |         )
393 |     )
394 |
395 |     model.train()
396 |     for name, param in model.named_parameters():
397 |         if len(param.shape) > 2:
398 |             print(name, param.shape)
399 |     # print(model)
400 |
401 |     data = torch.ones([128, nch, size, size]).to("cuda")
402 |     output = model(data)
403 |     print(output.shape)
404 |
--------------------------------------------------------------------------------
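Usage note (not part of the repository files): the sketch below shows one way ResNetAP from models/resnet_ap.py can be instantiated and queried for intermediate features. The concrete values (CIFAR-10-sized 3x32x32 inputs, depth=10, instance normalization) are illustrative assumptions rather than the settings prescribed by the NCFM configs, and the snippet assumes the repository root is on PYTHONPATH.

import torch
from models.resnet_ap import ResNetAP

# Illustrative CIFAR-10 setup; depth=10 and instance norm are assumed, not mandated by the repo.
model = ResNetAP("cifar10", depth=10, num_classes=10, norm_type="instance", nch=3)
images = torch.randn(4, 3, 32, 32)  # dummy batch of 3x32x32 inputs

logits = model(images)                                  # plain forward pass -> shape (4, 10)
logits, pooled = model(images, return_features=True)   # also returns the pooled 256-d features

# Intermediate activations: indices 0-5 are feature maps / pooled features, index 6 is the logit.
feats = model.get_feature(images, idx_from=0, idx_to=4)
for i, f in enumerate(feats):
    print(i, tuple(f.shape))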