├── overview_FreD.png
├── requirements.txt
├── ImageNet-abcde
│   ├── networks
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── alexnet_cifar.py
│   │   ├── vgg.py
│   │   ├── vgg_cifar.py
│   │   ├── conv_gap.py
│   │   ├── dnfr.py
│   │   ├── conv.py
│   │   ├── vit_cifar.py
│   │   ├── vit.py
│   │   ├── resnet_cifar.py
│   │   └── resnet.py
│   ├── scripts
│   │   ├── run_buffer.sh
│   │   ├── run_DC_FreD.sh
│   │   ├── run_DM_FreD.sh
│   │   └── run_TM_FreD.sh
│   ├── hyper_params.py
│   ├── shared_args.py
│   ├── frequency_transforms.py
│   ├── reparam_module.py
│   ├── glad_utils.py
│   ├── main_DM_FreD.py
│   ├── main_DC_FreD.py
│   └── buffer.py
├── TM
│   ├── scripts
│   │   ├── run_buffer.sh
│   │   └── run_TM_FreD.sh
│   ├── hyper_params.py
│   ├── frequency_transforms.py
│   ├── buffer.py
│   └── reparam_module.py
├── corruption-exp
│   ├── scripts
│   │   └── run.sh
│   └── main.py
├── DC
│   ├── scripts
│   │   └── run_DC_FreD.sh
│   ├── frequency_transforms.py
│   └── main_DC_FreD.py
├── DM
│   ├── scripts
│   │   └── run_DM_FreD.sh
│   └── frequency_transforms.py
├── 3D-MNIST
│   ├── scripts
│   │   └── run_DM_FreD.sh
│   ├── networks.py
│   ├── frequency_transforms.py
│   ├── utils.py
│   └── main_DM_FreD.py
├── README.md
└── LICENSE
/overview_FreD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdh0818/FreD/HEAD/overview_FreD.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch==1.11.0 3 | torchaudio==0.11.0 4 | torchvision==0.12.0 5 | scipy 6 | tqdm 7 | matplotlib 8 | kornia 9 | ema-pytorch 10 | einops 11 | h5py -------------------------------------------------------------------------------- /ImageNet-abcde/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import AlexNet 2 | from .conv import ConvNet 3 | from .dnfr import DNFR 4 | from .resnet import ResNet, ResNet18, ResNet18ImageNet 5 | from .vgg import VGG, VGG11, VGG13, VGG16, VGG19, VGG11BN 6 | from .vit import ViT 7 | from .conv_gap import ConvNetGAP 8 | from .alexnet_cifar import AlexNetCIFAR 9 | from .resnet_cifar import ResNet18CIFAR 10 | from .vgg_cifar import VGG11CIFAR 11 | from .vit_cifar import ViTCIFAR -------------------------------------------------------------------------------- /TM/scripts/run_buffer.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | cuda_id=3 3 | dst="CIFAR10" 4 | subset="None" 5 | model="ConvNetD3" 6 | data_path="../data" 7 | buffer_path="../buffers" 8 | 9 | train_epochs=50 10 | num_experts=100 11 | 12 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 buffer.py \ 13 | --dataset=${dst} --subset ${subset} \ 14 | --model=${model} \ 15 | --data_path ${data_path} --buffer_path ${buffer_path} \ 16 | --train_epochs=${train_epochs} \ 17 | --num_experts=${num_experts} -------------------------------------------------------------------------------- /ImageNet-abcde/scripts/run_buffer.sh: -------------------------------------------------------------------------------- 1 | cd .. 
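# Note (assumed from the script arguments above/below): buffer.py trains
# num_experts expert models on the real dataset and saves their parameter
# trajectories under buffer_path; run_TM_FreD.sh later matches synthetic-data
# training against these stored expert trajectories.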
2 | cuda_id=3 3 | dst="imagenet-a" 4 | res=128 5 | net="ConvNet" 6 | depth=5 7 | norm_train="instancenorm" 8 | data_path="../../../../../../data/IMAGENET2012" 9 | buffer_path="../buffers" 10 | 11 | train_epochs=50 12 | num_experts=100 13 | 14 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 buffer.py \ 15 | --dataset=${dst} --res=${res} \ 16 | --model=${net} --depth=${depth} --norm_train=${norm_train} \ 17 | --data_path=${data_path} --buffer_path=${buffer_path} \ 18 | --train_epochs=${train_epochs} --Iteration=${num_experts} -------------------------------------------------------------------------------- /corruption-exp/scripts/run.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | cuda_id=2 3 | method="FreD" 4 | source_dataset="CIFAR10" 5 | target_dataset="CIFAR10-C" 6 | subset="none" 7 | level=5 8 | ipc=1 9 | sh_file="run.sh" 10 | data_path="../data" 11 | save_path="./results" 12 | synset_path="./trained_synset" 13 | 14 | num_eval=5 15 | epoch_eval_train=1000 16 | batch_train=256 17 | 18 | FLAG="${target_dataset}_level${level}#${subset}#${method}_ipc${ipc}" 19 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main.py \ 20 | --source_dataset ${source_dataset} --target_dataset ${target_dataset} --subset ${subset} \ 21 | --level ${level} \ 22 | --ipc ${ipc} \ 23 | --sh_file ${sh_file} \ 24 | --data_path ${data_path} --save_path ${save_path} --synset_path ${synset_path} \ 25 | --num_eval ${num_eval} --epoch_eval_train ${epoch_eval_train} --batch_train ${batch_train} \ 26 | --FLAG ${FLAG} 27 | -------------------------------------------------------------------------------- /DC/scripts/run_DC_FreD.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | cuda_id=0 3 | dst="CIFAR10" 4 | net="ConvNetD3" 5 | ipc=2 6 | sh_file="run_DC_FreD.sh" 7 | eval_mode="S" 8 | data_path="../data" 9 | save_path="./results" 10 | 11 | num_eval=5 12 | Iteration=1000 13 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset) 14 | msz_per_channel=32 15 | lr_freq=1e3 16 | mom_freq=0.5 17 | 18 | TAG="" 19 | FLAG="${dst}_${ipc}ipc_${net}#DC_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}" 20 | 21 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DC_FreD.py \ 22 | --dataset ${dst} \ 23 | --model ${net} \ 24 | --ipc ${ipc} \ 25 | --sh_file ${sh_file} \ 26 | --eval_mode ${eval_mode} \ 27 | --data_path ${data_path} --save_path ${save_path} \ 28 | --num_eval ${num_eval} \ 29 | --Iteration ${Iteration} \ 30 | --batch_syn ${batch_syn} \ 31 | --msz_per_channel ${msz_per_channel} \ 32 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \ 33 | --FLAG ${FLAG} -------------------------------------------------------------------------------- /DM/scripts/run_DM_FreD.sh: -------------------------------------------------------------------------------- 1 | cd .. 
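# Note (naming assumed from frequency_transforms.py and the main_*_FreD
# arguments): msz_per_channel is the number of frequency-domain (DCT)
# coefficients optimized per channel of each synthetic image, i.e. FreD's
# per-image memory budget; lr_freq and mom_freq are the SGD learning rate
# and momentum used to update those coefficients.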
2 | cuda_id=1 3 | dst="CIFAR10" 4 | net="ConvNetD3" 5 | ipc=2 6 | sh_file="run_DM_FreD.sh" 7 | eval_mode="S" 8 | data_path="../data" 9 | save_path="./results" 10 | 11 | num_eval=5 12 | Iteration=20000 13 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset) 14 | msz_per_channel=64 15 | lr_freq=1e6 16 | mom_freq=0.5 17 | 18 | TAG="" 19 | FLAG="${dst}_${ipc}ipc_${net}#DM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}" 20 | 21 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DM_FreD.py \ 22 | --dataset ${dst} \ 23 | --model ${net} \ 24 | --ipc ${ipc} \ 25 | --sh_file ${sh_file} \ 26 | --eval_mode ${eval_mode} \ 27 | --data_path ${data_path} --save_path ${save_path} \ 28 | --num_eval ${num_eval} \ 29 | --Iteration ${Iteration} \ 30 | --batch_syn ${batch_syn} \ 31 | --msz_per_channel ${msz_per_channel} \ 32 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \ 33 | --FLAG ${FLAG} -------------------------------------------------------------------------------- /3D-MNIST/scripts/run_DM_FreD.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | cuda_id=2 3 | dst="3D-MNIST" 4 | net="Conv3DNet" 5 | ipc=1 6 | sh_file="run_DM_FreD.sh" 7 | eval_mode="S" 8 | data_path="../data" 9 | save_path="./results" 10 | 11 | num_eval=5 12 | Iteration=1000 13 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset) 14 | msz_per_channel=512 15 | lr_freq=1e6 16 | mom_freq=0.5 17 | 18 | TAG="" 19 | FLAG="${dst}_${ipc}ipc_${net}#DM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}" 20 | 21 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DM_FreD.py \ 22 | --dataset ${dst} \ 23 | --model ${net} \ 24 | --ipc ${ipc} \ 25 | --sh_file ${sh_file} \ 26 | --eval_mode ${eval_mode} \ 27 | --data_path ${data_path} --save_path ${save_path} \ 28 | --num_eval ${num_eval} \ 29 | --Iteration ${Iteration} \ 30 | --batch_syn ${batch_syn} \ 31 | --msz_per_channel ${msz_per_channel} \ 32 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \ 33 | --FLAG ${FLAG} -------------------------------------------------------------------------------- /ImageNet-abcde/scripts/run_DC_FreD.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | cuda_id=0 3 | dst="imagenet-a" 4 | res=128 5 | net="ConvNet" 6 | depth=5 7 | ipc=1 8 | sh_file="run_DC_FreD.sh" 9 | eval_mode="M" 10 | data_path="../../../../../../data/IMAGENET2012" 11 | save_path="./results" 12 | 13 | num_eval=5 14 | Iteration=1000 15 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset) 16 | msz_per_channel=2048 17 | lr_freq=1e5 18 | mom_freq=0.5 19 | 20 | TAG="" 21 | FLAG="${dst}_${res}_${ipc}ipc_${net}D${depth}#DC_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}" 22 | 23 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DC_FreD.py \ 24 | --dataset ${dst} --res ${res} \ 25 | --model ${net} --depth ${depth} \ 26 | --ipc ${ipc} \ 27 | --sh_file ${sh_file} \ 28 | --eval_mode ${eval_mode} \ 29 | --data_path ${data_path} --save_path ${save_path} \ 30 | --num_eval ${num_eval} \ 31 | --Iteration ${Iteration} \ 32 | --batch_syn ${batch_syn} \ 33 | --msz_per_channel ${msz_per_channel} \ 34 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \ 35 | --FLAG ${FLAG} -------------------------------------------------------------------------------- /ImageNet-abcde/scripts/run_DM_FreD.sh: -------------------------------------------------------------------------------- 1 | cd .. 
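# eval_mode "M" evaluates the distilled data on multiple architectures,
# while "S" evaluates only on the training architecture (see the eval_mode
# comment in shared_args.py).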
2 | cuda_id=1 3 | dst="imagenet-a" 4 | res=128 5 | net="ConvNet" 6 | depth=5 7 | ipc=1 8 | sh_file="run_DM_FreD.sh" 9 | eval_mode="M" 10 | data_path="../../../../../../data/IMAGENET2012" 11 | save_path="./results" 12 | 13 | num_eval=5 14 | Iteration=1000 15 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset) 16 | msz_per_channel=2048 17 | lr_freq=1e6 18 | mom_freq=0.5 19 | 20 | TAG="" 21 | FLAG="${dst}_${res}_${ipc}ipc_${net}D${depth}#DM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}" 22 | 23 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DM_FreD.py \ 24 | --dataset ${dst} --res ${res} \ 25 | --model ${net} --depth ${depth} \ 26 | --ipc ${ipc} \ 27 | --sh_file ${sh_file} \ 28 | --eval_mode ${eval_mode} \ 29 | --data_path ${data_path} --save_path ${save_path} \ 30 | --num_eval ${num_eval} \ 31 | --Iteration ${Iteration} \ 32 | --batch_syn ${batch_syn} \ 33 | --msz_per_channel ${msz_per_channel} \ 34 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \ 35 | --FLAG ${FLAG} -------------------------------------------------------------------------------- /TM/scripts/run_TM_FreD.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | cuda_id=0 3 | dst="CIFAR10" 4 | subset="none" 5 | net="ConvNetD3" 6 | ipc=1 7 | sh_file="run_TM_FreD.sh" 8 | eval_mode="S" 9 | data_path="../data" 10 | save_path="./results" 11 | buffer_path="../buffers" 12 | 13 | num_eval=5 14 | Iteration=15000 15 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset) 16 | msz_per_channel=64 17 | lr_freq=1e8 18 | mom_freq=0.5 19 | 20 | TAG="" 21 | FLAG="${dst}_${subset}_${ipc}ipc_${net}#TM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}" 22 | 23 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_TM_FreD.py \ 24 | --dataset ${dst} --subset ${subset} \ 25 | --model ${net} \ 26 | --ipc ${ipc} \ 27 | --sh_file ${sh_file} \ 28 | --eval_mode ${eval_mode} \ 29 | --data_path ${data_path} --save_path ${save_path} --buffer_path ${buffer_path} \ 30 | --num_eval ${num_eval} \ 31 | --Iteration ${Iteration} \ 32 | --batch_syn ${batch_syn} \ 33 | --msz_per_channel ${msz_per_channel} \ 34 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \ 35 | --FLAG ${FLAG} -------------------------------------------------------------------------------- /ImageNet-abcde/scripts/run_TM_FreD.sh: -------------------------------------------------------------------------------- 1 | cd .. 
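# Trajectory-matching hyperparameters (syn_steps, expert_epochs,
# max_start_epoch, lr_lr, lr_teacher) are not passed here; when they are
# left as None, hyper_params.load_default fills in the per-dataset/ipc
# defaults tabulated in hyper_params.py.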
2 | cuda_id=0,1,2,3 3 | dst="imagenet-a" 4 | res=128 5 | net="ConvNet" 6 | depth=5 7 | norm_train="instancenorm" 8 | ipc=1 9 | sh_file="run_TM_FreD.sh" 10 | eval_mode="M" 11 | data_path="../../../../../../data/IMAGENET2012" 12 | save_path="./results" 13 | buffer_path="../../../../../../data/tlsehdgur0/buffers" 14 | 15 | num_eval=5 16 | eval_it=500 17 | Iteration=15000 18 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset) 19 | msz_per_channel=2048 20 | lr_freq=1e9 21 | mom_freq=0.5 22 | 23 | TAG="" 24 | FLAG="${dst}_${res}_${ipc}ipc_${net}D${depth}#TM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}" 25 | 26 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_TM_FreD.py \ 27 | --dataset ${dst} --res ${res} \ 28 | --model ${net} --depth ${depth} --norm_train ${norm_train} \ 29 | --ipc ${ipc} \ 30 | --sh_file ${sh_file} \ 31 | --eval_mode ${eval_mode} \ 32 | --data_path ${data_path} --save_path ${save_path} --buffer_path ${buffer_path} \ 33 | --num_eval ${num_eval} --eval_it ${eval_it} \ 34 | --Iteration ${Iteration} \ 35 | --batch_syn ${batch_syn} \ 36 | --msz_per_channel ${msz_per_channel} \ 37 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \ 38 | --FLAG ${FLAG} -------------------------------------------------------------------------------- /ImageNet-abcde/networks/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | ''' AlexNet ''' 5 | class AlexNet(nn.Module): 6 | def __init__(self, channel, num_classes, im_size, **kwargs): 7 | super(AlexNet, self).__init__() 8 | self.features = nn.Sequential( 9 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 10 | nn.ReLU(inplace=True), 11 | nn.MaxPool2d(kernel_size=2, stride=2), 12 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 13 | nn.ReLU(inplace=True), 14 | nn.MaxPool2d(kernel_size=2, stride=2), 15 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | ) 23 | self.fc = nn.Linear(192 * im_size[0]//8 * im_size[1]//8, num_classes) 24 | 25 | def forward(self, x): 26 | x = self.features(x) 27 | feat_fc = x.view(x.size(0), -1) 28 | x = self.fc(feat_fc) 29 | 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /ImageNet-abcde/networks/alexnet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class AlexNetCIFAR(nn.Module): 4 | def __init__(self, channel, num_classes): 5 | super(AlexNetCIFAR, self).__init__() 6 | self.features = nn.Sequential( 7 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 8 | nn.ReLU(inplace=True), 9 | nn.MaxPool2d(kernel_size=2, stride=2), 10 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 11 | nn.ReLU(inplace=True), 12 | nn.MaxPool2d(kernel_size=2, stride=2), 13 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(kernel_size=2, stride=2), 20 | ) 21 | self.fc = nn.Linear(192 * 4 * 4, num_classes) 22 | 23 | def forward(self, x): 24 | x = 
self.features(x) 25 | x = x.view(x.size(0), -1) 26 | x = self.fc(x) 27 | return x 28 | 29 | def embed(self, x): 30 | x = self.features(x) 31 | x = x.view(x.size(0), -1) 32 | return x -------------------------------------------------------------------------------- /3D-MNIST/networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Conv3DNet(nn.Module): 4 | def __init__(self, channel, num_classes, net_width, net_depth, im_size): 5 | super(Conv3DNet, self).__init__() 6 | 7 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, im_size) 8 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2] * shape_feat[3] 9 | self.classifier = nn.Linear(num_feat, num_classes) 10 | 11 | def _get_normlayer(self, shape_feat): 12 | return nn.InstanceNorm3d(num_features=shape_feat[0], affine=True) 13 | 14 | def _get_activation(self): 15 | return nn.ReLU(inplace=True) 16 | 17 | def _get_pooling(self): 18 | return nn.AvgPool3d(kernel_size=(2, 2, 2), stride=2) 19 | 20 | def _make_layers(self, channel, net_width, net_depth, im_size): 21 | 22 | layers = [] 23 | in_channels = channel 24 | shape_feat = [in_channels, im_size[0], im_size[1], im_size[2]] 25 | 26 | for d in range(net_depth): 27 | layers += [nn.Conv3d(in_channels, net_width, kernel_size=(3, 3, 3), padding=3 if channel == 1 and d == 0 else 1)] 28 | shape_feat[0] = net_width 29 | 30 | layers += [self._get_normlayer(shape_feat)] 31 | layers += [self._get_activation()] 32 | 33 | in_channels = net_width 34 | 35 | layers += [self._get_pooling()] 36 | shape_feat[1] //= 2 37 | shape_feat[2] //= 2 38 | shape_feat[3] //= 2 39 | 40 | return nn.Sequential(*layers), shape_feat 41 | 42 | def forward(self, x): 43 | out = self.features(x) 44 | out = out.view(out.size(0), -1) 45 | out = self.classifier(out) 46 | return out 47 | 48 | def embed(self, x): 49 | out = self.features(x) 50 | out = out.view(out.size(0), -1) 51 | return out 52 | 53 | -------------------------------------------------------------------------------- /ImageNet-abcde/networks/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | ''' VGG ''' 7 | cfg_vgg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | class VGG(nn.Module): 14 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'): 15 | super(VGG, self).__init__() 16 | self.channel = channel 17 | self.features = self._make_layers(cfg_vgg[vgg_name], norm) 18 | self.classifier = nn.Linear(512*7*7 if vgg_name != 'VGGS' else 128, num_classes) 19 | 20 | def forward(self, x): 21 | x = self.features(x) 22 | feat_fc = x.view(x.size(0), -1) 23 | x = self.classifier(feat_fc) 24 | 25 | return x 26 | 27 | def _make_layers(self, cfg, norm): 28 | layers = [] 29 | in_channels = self.channel 30 | for ic, x in enumerate(cfg): 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 35 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else 
nn.BatchNorm2d(x) if norm == 'batchnorm' else nn.Identity(), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AdaptiveMaxPool2d((7, 7))] 39 | return nn.Sequential(*layers) 40 | 41 | 42 | def VGG11(channel, num_classes, **kwargs): 43 | return VGG('VGG11', channel, num_classes) 44 | def VGG11BN(channel, num_classes): 45 | return VGG('VGG11', channel, num_classes, norm='batchnorm') 46 | def VGG13(channel, num_classes): 47 | return VGG('VGG13', channel, num_classes) 48 | def VGG16(channel, num_classes): 49 | return VGG('VGG16', channel, num_classes) 50 | def VGG19(channel, num_classes): 51 | return VGG('VGG19', channel, num_classes) -------------------------------------------------------------------------------- /ImageNet-abcde/networks/vgg_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | ''' VGG ''' 4 | cfg_vgg = { 5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 6 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 7 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 9 | } 10 | class VGGCIFAR(nn.Module): 11 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'): 12 | super(VGGCIFAR, self).__init__() 13 | self.channel = channel 14 | self.features = self._make_layers(cfg_vgg[vgg_name], norm) 15 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes) 16 | 17 | def forward(self, x): 18 | x = self.features(x) 19 | x = x.view(x.size(0), -1) 20 | x = self.classifier(x) 21 | return x 22 | 23 | def embed(self, x): 24 | x = self.features(x) 25 | x = x.view(x.size(0), -1) 26 | return x 27 | 28 | def _make_layers(self, cfg, norm): 29 | layers = [] 30 | in_channels = self.channel 31 | for ic, x in enumerate(cfg): 32 | if x == 'M': 33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 34 | else: 35 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 36 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x), 37 | nn.ReLU(inplace=True)] 38 | in_channels = x 39 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 40 | return nn.Sequential(*layers) 41 | 42 | 43 | def VGG11CIFAR(channel, num_classes): 44 | return VGGCIFAR('VGG11', channel, num_classes) 45 | def VGG11BNCIFAR(channel, num_classes): 46 | return VGGCIFAR('VGG11', channel, num_classes, norm='batchnorm') 47 | def VGG13CIFAR(channel, num_classes): 48 | return VGGCIFAR('VGG13', channel, num_classes) 49 | def VGG16CIFAR(channel, num_classes): 50 | return VGGCIFAR('VGG16', channel, num_classes) 51 | def VGG19CIFAR(channel, num_classes): 52 | return VGGCIFAR('VGG19', channel, num_classes) -------------------------------------------------------------------------------- /TM/hyper_params.py: -------------------------------------------------------------------------------- 1 | SYN_STEPS = {"MNIST": {1: 50, 10: 30}, "FashionMNIST": {1: 50, 10: 60}, "SVHN": {1: 50, 10: 30, 50: 40}, 2 | "CIFAR10": {1: 50, 2: 50, 10: 40, 11: 40, 50: 30, 51: 30}, 3 | "CIFAR100": {1: 50, 10: 20, 50: 80}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 20, 2: 20, 10: 20}} 4 | 5 | EXPERT_EPOCHS = {"MNIST": {1: 2, 10: 2}, "FashionMNIST": {1: 2, 10: 2}, "SVHN": {1: 2, 10: 2, 50: 2}, 6 | "CIFAR10": {1: 2, 2: 2, 10: 2, 11: 2, 50: 2, 51: 2}, 7 | 
"CIFAR100": {1: 2, 10: 2, 50: 2}, "Tiny": {1: 2, 10: 2, 50: 2}, "ImageNet": {1: 2, 2: 2, 10: 2}} 8 | 9 | MAX_START_EPOCH = {"MNIST": {1: 5, 10: 15}, "FashionMNIST": {1: 5, 10: 15}, "SVHN": {1: 5, 10: 15, 50: 40}, 10 | "CIFAR10": {1: 5, 2: 5, 10: 15, 11: 15, 50: 40, 51: 40}, 11 | "CIFAR100": {1: 15, 10: 40, 50: 40}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 10, 2: 10, 10: 10}} 12 | 13 | LR_LR = {"MNIST": {1: 1e-7, 10: 1e-5}, "FashionMNIST": {1: 1e-7, 10: 1e-5}, "SVHN": {1: 1e-7, 10: 1e-5, 50: 1e-5}, 14 | "CIFAR10": {1: 1e-7, 2: 1e-7, 10: 1e-5, 11: 1e-5, 50: 1e-5, 51: 1e-5}, 15 | "CIFAR100": {1: 1e-5, 10: 1e-5, 50: 1e-5}, "Tiny": {1: 1e-4, 10: 1e-4, 50: 1e-4}, "ImageNet": {1: 1e-6, 2: 1e-6, 10: 1e-6}} 16 | 17 | LR_TEACHER = {"MNIST": {1: 1e-2, 10: 1e-2}, "FashionMNIST": {1: 1e-2, 10: 1e-2}, "SVHN": {1: 1e-2, 10: 1e-2, 50: 1e-3}, 18 | "CIFAR10": {1: 1e-2, 2: 1e-2, 10: 1e-2, 11: 1e-2, 50: 1e-3, 51: 1e-3}, 19 | "CIFAR100": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "Tiny": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "ImageNet": {1: 1e-2, 2: 1e-2, 10: 1e-2}} 20 | 21 | 22 | def load_default(args): 23 | if args.zca: 24 | exit("Default FreD does not use ZCA") 25 | 26 | dataset = args.dataset 27 | 28 | if args.ipc in [1, 2, 10, 11, 50, 51]: 29 | ipc = args.ipc 30 | else: 31 | exit("Undefined IPC") 32 | 33 | if args.syn_steps == None: 34 | args.syn_steps = SYN_STEPS[dataset][ipc] 35 | 36 | if args.expert_epochs == None: 37 | args.expert_epochs = EXPERT_EPOCHS[dataset][ipc] 38 | 39 | if args.max_start_epoch == None: 40 | args.max_start_epoch = MAX_START_EPOCH[dataset][ipc] 41 | 42 | if args.lr_lr == None: 43 | args.lr_lr = LR_LR[dataset][ipc] 44 | 45 | if args.lr_teacher == None: 46 | args.lr_teacher = LR_TEACHER[dataset][ipc] 47 | return args -------------------------------------------------------------------------------- /ImageNet-abcde/hyper_params.py: -------------------------------------------------------------------------------- 1 | SYN_STEPS = {"MNIST": {1: 50, 10: 30}, "FashionMNIST": {1: 50, 10: 60}, "SVHN": {1: 50, 10: 30, 50: 40}, 2 | "CIFAR10": {1: 50, 2: 50, 10: 40, 11: 40, 50: 30, 51: 30}, 3 | "CIFAR100": {1: 50, 10: 20, 50: 80}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 20, 2: 20, 10: 20}} 4 | 5 | EXPERT_EPOCHS = {"MNIST": {1: 2, 10: 2}, "FashionMNIST": {1: 2, 10: 2}, "SVHN": {1: 2, 10: 2, 50: 2}, 6 | "CIFAR10": {1: 2, 2: 2, 10: 2, 11: 2, 50: 2, 51: 2}, 7 | "CIFAR100": {1: 2, 10: 2, 50: 2}, "Tiny": {1: 2, 10: 2, 50: 2}, "ImageNet": {1: 2, 2: 2, 10: 2}} 8 | 9 | MAX_START_EPOCH = {"MNIST": {1: 5, 10: 15}, "FashionMNIST": {1: 5, 10: 15}, "SVHN": {1: 5, 10: 15, 50: 40}, 10 | "CIFAR10": {1: 5, 2: 5, 10: 15, 11: 15, 50: 40, 51: 40}, 11 | "CIFAR100": {1: 15, 10: 40, 50: 40}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 10, 2: 10, 10: 10}} 12 | 13 | LR_LR = {"MNIST": {1: 1e-7, 10: 1e-5}, "FashionMNIST": {1: 1e-7, 10: 1e-5}, "SVHN": {1: 1e-7, 10: 1e-5, 50: 1e-5}, 14 | "CIFAR10": {1: 1e-7, 2: 1e-7, 10: 1e-5, 11: 1e-5, 50: 1e-5, 51: 1e-5}, 15 | "CIFAR100": {1: 1e-5, 10: 1e-5, 50: 1e-5}, "Tiny": {1: 1e-4, 10: 1e-4, 50: 1e-4}, "ImageNet": {1: 1e-6, 2: 1e-6, 10: 1e-6}} 16 | 17 | LR_TEACHER = {"MNIST": {1: 1e-2, 10: 1e-2}, "FashionMNIST": {1: 1e-2, 10: 1e-2}, "SVHN": {1: 1e-2, 10: 1e-2, 50: 1e-3}, 18 | "CIFAR10": {1: 1e-2, 2: 1e-2, 10: 1e-2, 11: 1e-2, 50: 1e-3, 51: 1e-3}, 19 | "CIFAR100": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "Tiny": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "ImageNet": {1: 1e-2, 2: 1e-2, 10: 1e-2}} 20 | 21 | 22 | def load_default(args): 23 | if args.zca: 24 | exit("Default FreD does not use ZCA") 25 | 26 | 
dataset = args.dataset 27 | if dataset.startswith("imagenet"): 28 | dataset = "ImageNet" 29 | 30 | if args.ipc in [1, 2, 10, 11, 50, 51]: 31 | ipc = args.ipc 32 | else: 33 | exit("Undefined IPC") 34 | 35 | if args.syn_steps == None: 36 | args.syn_steps = SYN_STEPS[dataset][ipc] 37 | 38 | if args.expert_epochs == None: 39 | args.expert_epochs = EXPERT_EPOCHS[dataset][ipc] 40 | 41 | if args.max_start_epoch == None: 42 | args.max_start_epoch = MAX_START_EPOCH[dataset][ipc] 43 | 44 | if args.lr_lr == None: 45 | args.lr_lr = LR_LR[dataset][ipc] 46 | 47 | if args.lr_teacher == None: 48 | args.lr_teacher = LR_TEACHER[dataset][ipc] 49 | return args -------------------------------------------------------------------------------- /ImageNet-abcde/networks/conv_gap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ''' ConvNetGAP ''' 5 | class ConvNetGAP(nn.Module): 6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size = (32,32)): 7 | super(ConvNetGAP, self).__init__() 8 | 9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 10 | num_feat = shape_feat[0] 11 | self.classifier = nn.Linear(num_feat, num_classes) 12 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 13 | 14 | def forward(self, x): 15 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 16 | out = self.features(x) 17 | out = self.pool(out) 18 | out = out.view(out.size(0), -1) 19 | out = self.classifier(out) 20 | return out 21 | 22 | def _get_activation(self, net_act): 23 | if net_act == 'sigmoid': 24 | return nn.Sigmoid() 25 | elif net_act == 'relu': 26 | return nn.ReLU(inplace=True) 27 | elif net_act == 'leakyrelu': 28 | return nn.LeakyReLU(negative_slope=0.01) 29 | else: 30 | exit('unknown activation function: %s'%net_act) 31 | 32 | def _get_pooling(self, net_pooling): 33 | if net_pooling == 'maxpooling': 34 | return nn.MaxPool2d(kernel_size=2, stride=2) 35 | elif net_pooling == 'avgpooling': 36 | return nn.AvgPool2d(kernel_size=2, stride=2) 37 | elif net_pooling == 'none': 38 | return None 39 | else: 40 | exit('unknown net_pooling: %s'%net_pooling) 41 | 42 | def _get_normlayer(self, net_norm, shape_feat): 43 | # shape_feat = (c*h*w) 44 | if net_norm == 'batchnorm': 45 | return nn.BatchNorm2d(shape_feat[0], affine=True) 46 | elif net_norm == 'layernorm': 47 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 48 | elif net_norm == 'instancenorm': 49 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 50 | elif net_norm == 'groupnorm': 51 | return nn.GroupNorm(4, shape_feat[0], affine=True) 52 | elif net_norm == 'none': 53 | return None 54 | else: 55 | exit('unknown net_norm: %s'%net_norm) 56 | 57 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 58 | layers = [] 59 | in_channels = channel 60 | if im_size[0] == 28: 61 | im_size = (32, 32) 62 | shape_feat = [in_channels, im_size[0], im_size[1]] 63 | for d in range(net_depth): 64 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 65 | shape_feat[0] = net_width 66 | if net_norm != 'none': 67 | layers += [self._get_normlayer(net_norm, shape_feat)] 68 | layers += [self._get_activation(net_act)] 69 | in_channels = net_width 70 | if net_pooling != 'none': 71 | layers += [self._get_pooling(net_pooling)] 72 | 
shape_feat[1] //= 2 73 | shape_feat[2] //= 2 74 | net_width *= 2 75 | 76 | 77 | return nn.Sequential(*layers), shape_feat -------------------------------------------------------------------------------- /ImageNet-abcde/networks/dnfr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ''' DNFR_Net ''' 5 | class DNFR(nn.Module): 6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size=(32,32)): 7 | super(DNFR, self).__init__() 8 | 9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 10 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 11 | self.classifier = nn.Linear(num_feat, num_classes) 12 | 13 | def forward(self, x): 14 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 15 | out = self.features(x) 16 | out = out.view(out.size(0), -1) 17 | out = self.classifier(out) 18 | return out 19 | 20 | def _get_activation(self, net_act): 21 | if net_act == 'sigmoid': 22 | return nn.Sigmoid() 23 | elif net_act == 'relu': 24 | return nn.ReLU(inplace=True) 25 | elif net_act == 'leakyrelu': 26 | return nn.LeakyReLU(negative_slope=0.01) 27 | else: 28 | exit('unknown activation function: %s'%net_act) 29 | 30 | def _get_pooling(self, net_pooling): 31 | if net_pooling == 'maxpooling': 32 | return nn.MaxPool2d(kernel_size=2, stride=2) 33 | elif net_pooling == 'avgpooling': 34 | return nn.AvgPool2d(kernel_size=2, stride=2) 35 | elif net_pooling == 'none': 36 | return None 37 | else: 38 | exit('unknown net_pooling: %s'%net_pooling) 39 | 40 | def _get_normlayer(self, net_norm, shape_feat): 41 | # shape_feat = (c*h*w) 42 | if net_norm == 'batchnorm': 43 | return nn.BatchNorm2d(shape_feat[0], affine=True) 44 | elif net_norm == 'layernorm': 45 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 46 | elif net_norm == 'instancenorm': 47 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 48 | elif net_norm == 'groupnorm': 49 | return nn.GroupNorm(4, shape_feat[0], affine=True) 50 | elif net_norm == 'none': 51 | return None 52 | else: 53 | exit('unknown net_norm: %s'%net_norm) 54 | 55 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 56 | layers = [] 57 | in_channels = channel 58 | if im_size[0] == 28: 59 | im_size = (32, 32) 60 | shape_feat = [in_channels, im_size[0], im_size[1]] 61 | for d in range(net_depth): 62 | if net_norm != 'none': 63 | if d == 0 and net_norm == 'groupnorm': 64 | layers += [self._get_normlayer('instancenorm', shape_feat)] 65 | else: 66 | layers += [self._get_normlayer(net_norm, shape_feat)] 67 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 68 | shape_feat[0] = net_width 69 | layers += [self._get_activation(net_act)] 70 | in_channels = net_width 71 | if net_pooling != 'none': 72 | layers += [self._get_pooling(net_pooling)] 73 | shape_feat[1] //= 2 74 | shape_feat[2] //= 2 75 | 76 | 77 | return nn.Sequential(*layers), shape_feat -------------------------------------------------------------------------------- /ImageNet-abcde/shared_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def add_shared_args(): 4 | parser = argparse.ArgumentParser(description='Parameter Processing') 5 | parser.add_argument('--dataset', type=str, 
default='CIFAR10', help='dataset') 6 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 7 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 8 | parser.add_argument('--eval_mode', type=str, default='M', help='eval_mode') # S: the same to training model, M: multi architectures 9 | parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models') 10 | parser.add_argument('--eval_it', type=int, default=100, help='how often to evaluate') 11 | parser.add_argument('--save_it', type=int, default=None, help='how often to save') 12 | parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') 13 | parser.add_argument('--Iteration', type=int, default=1000, help='training iterations') 14 | 15 | parser.add_argument('--mom_img', type=float, default=0.5, help='momentum for updating synthetic images') 16 | 17 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data') 18 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 19 | parser.add_argument('--batch_test', type=int, default=128, help='batch size for testing networks') 20 | 21 | parser.add_argument('--pix_init', type=str, default='noise', choices=["noise", "real"], help='noise/real: initialize synthetic images from random noise or randomly sampled real images.') 22 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], help='whether to use differentiable Siamese augmentation.') 23 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 24 | parser.add_argument('--data_path', type=str, default='data', help='dataset path') 25 | 26 | parser.add_argument('--save_path', type=str, default='result', help='path to save results') 27 | 28 | parser.add_argument('--space', type=str, default='p', choices=['p', 'wp']) 29 | parser.add_argument('--res', type=int, default=128, choices=[128, 256, 512], help='resolution') 30 | parser.add_argument('--layer', type=int, default=12) 31 | parser.add_argument('--avg_w', action='store_true') 32 | 33 | parser.add_argument('--eval_all', action='store_true') 34 | 35 | parser.add_argument('--min_it', type=bool, default=False) 36 | parser.add_argument('--no_aug', type=bool, default=False) 37 | 38 | parser.add_argument('--force_save', action='store_true') 39 | 40 | parser.add_argument('--sg_batch', type=int, default=10) 41 | 42 | parser.add_argument('--rand_f', action='store_true') 43 | 44 | parser.add_argument('--logdir', type=str, default='./logged_files') 45 | 46 | parser.add_argument('--wait_eval', action='store_true') 47 | 48 | parser.add_argument('--idc_factor', type=int, default=1) 49 | 50 | parser.add_argument('--rand_gan_un', action='store_true') 51 | parser.add_argument('--rand_gan_con', action='store_true') 52 | 53 | parser.add_argument('--learn_g', action='store_true') 54 | 55 | parser.add_argument('--width', type=int, default=128) 56 | parser.add_argument('--depth', type=int, default=5) 57 | 58 | parser.add_argument('--norm_train', type=str, default="batchnorm") 59 | 60 | parser.add_argument('--special_gan', default=None) 61 | 62 | return parser 63 | -------------------------------------------------------------------------------- /ImageNet-abcde/networks/conv.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn as nn 3 | 4 | ''' ConvNet ''' 5 | class ConvNet(nn.Module): 6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size = (32,32)): 7 | super(ConvNet, self).__init__() 8 | 9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 10 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 11 | self.classifier = nn.Linear(num_feat, num_classes) 12 | 13 | def forward(self, x): 14 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 15 | out = self.features(x) 16 | out = out.view(out.size(0), -1) 17 | out = self.classifier(out) 18 | return out 19 | 20 | 21 | def forward_cafe(self, x): 22 | f_maps = [] 23 | x = self.features[0:3](x) 24 | f_maps.append(x) 25 | 26 | for i in range(3, len(self.features) - 1, 4): # step over each remaining [pool, conv, norm, act] block (assumes norm and pooling are enabled) 27 | x = self.features[i:i+4](x) 28 | f_maps.append(x) 29 | 30 | out = self.features[-1:](x) 31 | f_maps.append(out) 32 | out = out.view(out.size(0), -1) 33 | out_final = self.classifier(out) 34 | return out_final, out, f_maps 35 | 36 | def embed(self, x): 37 | out = self.features(x) 38 | out = out.view(out.size(0), -1) 39 | return out 40 | 41 | def _get_activation(self, net_act): 42 | if net_act == 'sigmoid': 43 | return nn.Sigmoid() 44 | elif net_act == 'relu': 45 | return nn.ReLU(inplace=True) 46 | elif net_act == 'leakyrelu': 47 | return nn.LeakyReLU(negative_slope=0.01) 48 | else: 49 | exit('unknown activation function: %s'%net_act) 50 | 51 | def _get_pooling(self, net_pooling): 52 | if net_pooling == 'maxpooling': 53 | return nn.MaxPool2d(kernel_size=2, stride=2) 54 | elif net_pooling == 'avgpooling': 55 | return nn.AvgPool2d(kernel_size=2, stride=2) 56 | elif net_pooling == 'none': 57 | return None 58 | else: 59 | exit('unknown net_pooling: %s'%net_pooling) 60 | 61 | def _get_normlayer(self, net_norm, shape_feat): 62 | # shape_feat = (c*h*w) 63 | if net_norm == 'batchnorm': 64 | return nn.BatchNorm2d(shape_feat[0], affine=True) 65 | elif net_norm == 'layernorm': 66 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 67 | elif net_norm == 'instancenorm': 68 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 69 | elif net_norm == 'groupnorm': 70 | return nn.GroupNorm(4, shape_feat[0], affine=True) 71 | elif net_norm == 'none': 72 | return None 73 | else: 74 | exit('unknown net_norm: %s'%net_norm) 75 | 76 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 77 | layers = [] 78 | in_channels = channel 79 | if im_size[0] == 28: 80 | im_size = (32, 32) 81 | shape_feat = [in_channels, im_size[0], im_size[1]] 82 | for d in range(net_depth): 83 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 84 | shape_feat[0] = net_width 85 | if net_norm != 'none': 86 | layers += [self._get_normlayer(net_norm, shape_feat)] 87 | layers += [self._get_activation(net_act)] 88 | in_channels = net_width 89 | if net_pooling != 'none': 90 | layers += [self._get_pooling(net_pooling)] 91 | shape_feat[1] //= 2 92 | shape_feat[2] //= 2 93 | 94 | 95 | return nn.Sequential(*layers), shape_feat -------------------------------------------------------------------------------- /DC/frequency_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/zh217/torch-dct 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 
import torch.nn as nn 8 | 9 | class DCT(): 10 | def __init__(self, resolution, device, norm=None, bias=False): 11 | self.resolution = resolution 12 | self.norm = norm 13 | self.device = device 14 | 15 | I = torch.eye(self.resolution, device=self.device) 16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t() 18 | self.forward_transform.weight.requires_grad = False 19 | 20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t() 22 | self.inverse_transform.weight.requires_grad = False 23 | 24 | def _dct(self, x, norm=None): 25 | """ 26 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 27 | For the meaning of the parameter `norm`, see: 28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 29 | :param x: the input signal 30 | :param norm: the normalization, None or 'ortho' 31 | :return: the DCT-II of the signal over the last dimension 32 | """ 33 | x_shape = x.shape 34 | N = x_shape[-1] 35 | x = x.contiguous().view(-1, N) 36 | 37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 38 | 39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 40 | 41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 42 | W_r = torch.cos(k) 43 | W_i = torch.sin(k) 44 | 45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 46 | 47 | if norm == 'ortho': 48 | V[:, 0] /= np.sqrt(N) * 2 49 | V[:, 1:] /= np.sqrt(N / 2) * 2 50 | 51 | V = 2 * V.view(*x_shape) 52 | return V 53 | 54 | def _idct(self, X, norm=None): 55 | """ 56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 57 | Our definition of idct is that idct(dct(x)) == x 58 | For the meaning of the parameter `norm`, see: 59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 60 | :param X: the input signal 61 | :param norm: the normalization, None or 'ortho' 62 | :return: the inverse DCT-II of the signal over the last dimension 63 | """ 64 | 65 | x_shape = X.shape 66 | N = x_shape[-1] 67 | 68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 69 | 70 | if norm == 'ortho': 71 | X_v[:, 0] *= np.sqrt(N) * 2 72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 73 | 74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 75 | W_r = torch.cos(k) 76 | W_i = torch.sin(k) 77 | 78 | V_t_r = X_v 79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 80 | 81 | V_r = V_t_r * W_r - V_t_i * W_i 82 | V_i = V_t_r * W_i + V_t_i * W_r 83 | 84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 85 | 86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 87 | x = v.new_zeros(v.shape) 88 | x[:, ::2] += v[:, :N - (N // 2)] 89 | x[:, 1::2] += v.flip([1])[:, :N // 2] 90 | return x.view(*x_shape) 91 | 92 | def forward(self, x): 93 | X1 = self.forward_transform(x) 94 | X2 = self.forward_transform(X1.transpose(-1, -2)) 95 | return X2.transpose(-1, -2) 96 | 97 | def inverse(self, x): 98 | X1 = self.inverse_transform(x) 99 | X2 = self.inverse_transform(X1.transpose(-1, -2)) 100 | return X2.transpose(-1, -2) -------------------------------------------------------------------------------- /DM/frequency_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/zh217/torch-dct 3 | """ 4 | 5 | import numpy 
as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | class DCT(): 10 | def __init__(self, resolution, device, norm=None, bias=False): 11 | self.resolution = resolution 12 | self.norm = norm 13 | self.device = device 14 | 15 | I = torch.eye(self.resolution, device=self.device) 16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t() 18 | self.forward_transform.weight.requires_grad = False 19 | 20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t() 22 | self.inverse_transform.weight.requires_grad = False 23 | 24 | def _dct(self, x, norm=None): 25 | """ 26 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 27 | For the meaning of the parameter `norm`, see: 28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 29 | :param x: the input signal 30 | :param norm: the normalization, None or 'ortho' 31 | :return: the DCT-II of the signal over the last dimension 32 | """ 33 | x_shape = x.shape 34 | N = x_shape[-1] 35 | x = x.contiguous().view(-1, N) 36 | 37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 38 | 39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 40 | 41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 42 | W_r = torch.cos(k) 43 | W_i = torch.sin(k) 44 | 45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 46 | 47 | if norm == 'ortho': 48 | V[:, 0] /= np.sqrt(N) * 2 49 | V[:, 1:] /= np.sqrt(N / 2) * 2 50 | 51 | V = 2 * V.view(*x_shape) 52 | return V 53 | 54 | def _idct(self, X, norm=None): 55 | """ 56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 57 | Our definition of idct is that idct(dct(x)) == x 58 | For the meaning of the parameter `norm`, see: 59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 60 | :param X: the input signal 61 | :param norm: the normalization, None or 'ortho' 62 | :return: the inverse DCT-II of the signal over the last dimension 63 | """ 64 | 65 | x_shape = X.shape 66 | N = x_shape[-1] 67 | 68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 69 | 70 | if norm == 'ortho': 71 | X_v[:, 0] *= np.sqrt(N) * 2 72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 73 | 74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 75 | W_r = torch.cos(k) 76 | W_i = torch.sin(k) 77 | 78 | V_t_r = X_v 79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 80 | 81 | V_r = V_t_r * W_r - V_t_i * W_i 82 | V_i = V_t_r * W_i + V_t_i * W_r 83 | 84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 85 | 86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 87 | x = v.new_zeros(v.shape) 88 | x[:, ::2] += v[:, :N - (N // 2)] 89 | x[:, 1::2] += v.flip([1])[:, :N // 2] 90 | return x.view(*x_shape) 91 | 92 | def forward(self, x): 93 | X1 = self.forward_transform(x) 94 | X2 = self.forward_transform(X1.transpose(-1, -2)) 95 | return X2.transpose(-1, -2) 96 | 97 | def inverse(self, x): 98 | X1 = self.inverse_transform(x) 99 | X2 = self.inverse_transform(X1.transpose(-1, -2)) 100 | return X2.transpose(-1, -2) -------------------------------------------------------------------------------- /TM/frequency_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/zh217/torch-dct 3 
| """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | class DCT(): 10 | def __init__(self, resolution, device, norm=None, bias=False): 11 | self.resolution = resolution 12 | self.norm = norm 13 | self.device = device 14 | 15 | I = torch.eye(self.resolution, device=self.device) 16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t() 18 | self.forward_transform.weight.requires_grad = False 19 | 20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t() 22 | self.inverse_transform.weight.requires_grad = False 23 | 24 | def _dct(self, x, norm=None): 25 | """ 26 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 27 | For the meaning of the parameter `norm`, see: 28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 29 | :param x: the input signal 30 | :param norm: the normalization, None or 'ortho' 31 | :return: the DCT-II of the signal over the last dimension 32 | """ 33 | x_shape = x.shape 34 | N = x_shape[-1] 35 | x = x.contiguous().view(-1, N) 36 | 37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 38 | 39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 40 | 41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 42 | W_r = torch.cos(k) 43 | W_i = torch.sin(k) 44 | 45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 46 | 47 | if norm == 'ortho': 48 | V[:, 0] /= np.sqrt(N) * 2 49 | V[:, 1:] /= np.sqrt(N / 2) * 2 50 | 51 | V = 2 * V.view(*x_shape) 52 | return V 53 | 54 | def _idct(self, X, norm=None): 55 | """ 56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 57 | Our definition of idct is that idct(dct(x)) == x 58 | For the meaning of the parameter `norm`, see: 59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 60 | :param X: the input signal 61 | :param norm: the normalization, None or 'ortho' 62 | :return: the inverse DCT-II of the signal over the last dimension 63 | """ 64 | 65 | x_shape = X.shape 66 | N = x_shape[-1] 67 | 68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 69 | 70 | if norm == 'ortho': 71 | X_v[:, 0] *= np.sqrt(N) * 2 72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 73 | 74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 75 | W_r = torch.cos(k) 76 | W_i = torch.sin(k) 77 | 78 | V_t_r = X_v 79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 80 | 81 | V_r = V_t_r * W_r - V_t_i * W_i 82 | V_i = V_t_r * W_i + V_t_i * W_r 83 | 84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 85 | 86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 87 | x = v.new_zeros(v.shape) 88 | x[:, ::2] += v[:, :N - (N // 2)] 89 | x[:, 1::2] += v.flip([1])[:, :N // 2] 90 | return x.view(*x_shape) 91 | 92 | def forward(self, x): 93 | X1 = self.forward_transform(x) 94 | X2 = self.forward_transform(X1.transpose(-1, -2)) 95 | return X2.transpose(-1, -2) 96 | 97 | def inverse(self, x): 98 | X1 = self.inverse_transform(x) 99 | X2 = self.inverse_transform(X1.transpose(-1, -2)) 100 | return X2.transpose(-1, -2) -------------------------------------------------------------------------------- /ImageNet-abcde/frequency_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Reference: https://github.com/zh217/torch-dct 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | class DCT(): 10 | def __init__(self, resolution, device, norm=None, bias=False): 11 | self.resolution = resolution 12 | self.norm = norm 13 | self.device = device 14 | 15 | I = torch.eye(self.resolution, device=self.device) 16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t() 18 | self.forward_transform.weight.requires_grad = False 19 | 20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t() 22 | self.inverse_transform.weight.requires_grad = False 23 | 24 | def _dct(self, x, norm=None): 25 | """ 26 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 27 | For the meaning of the parameter `norm`, see: 28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 29 | :param x: the input signal 30 | :param norm: the normalization, None or 'ortho' 31 | :return: the DCT-II of the signal over the last dimension 32 | """ 33 | x_shape = x.shape 34 | N = x_shape[-1] 35 | x = x.contiguous().view(-1, N) 36 | 37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 38 | 39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 40 | 41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 42 | W_r = torch.cos(k) 43 | W_i = torch.sin(k) 44 | 45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 46 | 47 | if norm == 'ortho': 48 | V[:, 0] /= np.sqrt(N) * 2 49 | V[:, 1:] /= np.sqrt(N / 2) * 2 50 | 51 | V = 2 * V.view(*x_shape) 52 | return V 53 | 54 | def _idct(self, X, norm=None): 55 | """ 56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 57 | Our definition of idct is that idct(dct(x)) == x 58 | For the meaning of the parameter `norm`, see: 59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 60 | :param X: the input signal 61 | :param norm: the normalization, None or 'ortho' 62 | :return: the inverse DCT-II of the signal over the last dimension 63 | """ 64 | 65 | x_shape = X.shape 66 | N = x_shape[-1] 67 | 68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 69 | 70 | if norm == 'ortho': 71 | X_v[:, 0] *= np.sqrt(N) * 2 72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 73 | 74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 75 | W_r = torch.cos(k) 76 | W_i = torch.sin(k) 77 | 78 | V_t_r = X_v 79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 80 | 81 | V_r = V_t_r * W_r - V_t_i * W_i 82 | V_i = V_t_r * W_i + V_t_i * W_r 83 | 84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 85 | 86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 87 | x = v.new_zeros(v.shape) 88 | x[:, ::2] += v[:, :N - (N // 2)] 89 | x[:, 1::2] += v.flip([1])[:, :N // 2] 90 | return x.view(*x_shape) 91 | 92 | def forward(self, x): 93 | X1 = self.forward_transform(x) 94 | X2 = self.forward_transform(X1.transpose(-1, -2)) 95 | return X2.transpose(-1, -2) 96 | 97 | def inverse(self, x): 98 | X1 = self.inverse_transform(x) 99 | X2 = self.inverse_transform(X1.transpose(-1, -2)) 100 | return X2.transpose(-1, -2) -------------------------------------------------------------------------------- /3D-MNIST/frequency_transforms.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class DCT(): 6 | def __init__(self, resolution, device, norm=None, bias=False): 7 | self.resolution = resolution 8 | self.norm = norm 9 | self.device = device 10 | 11 | I = torch.eye(self.resolution, device=self.device) 12 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 13 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t() 14 | self.forward_transform.weight.requires_grad = False 15 | 16 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device) 17 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t() 18 | self.inverse_transform.weight.requires_grad = False 19 | 20 | def _dct(self, x, norm=None): 21 | """ 22 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 23 | For the meaning of the parameter `norm`, see: 24 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 25 | :param x: the input signal 26 | :param norm: the normalization, None or 'ortho' 27 | :return: the DCT-II of the signal over the last dimension 28 | """ 29 | x_shape = x.shape 30 | N = x_shape[-1] 31 | x = x.contiguous().view(-1, N) 32 | 33 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 34 | 35 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 36 | 37 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 38 | W_r = torch.cos(k) 39 | W_i = torch.sin(k) 40 | 41 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 42 | 43 | if norm == 'ortho': 44 | V[:, 0] /= np.sqrt(N) * 2 45 | V[:, 1:] /= np.sqrt(N / 2) * 2 46 | 47 | V = 2 * V.view(*x_shape) 48 | return V 49 | 50 | def _idct(self, X, norm=None): 51 | """ 52 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 53 | Our definition of idct is that idct(dct(x)) == x 54 | For the meaning of the parameter `norm`, see: 55 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 56 | :param X: the input signal 57 | :param norm: the normalization, None or 'ortho' 58 | :return: the inverse DCT-II of the signal over the last dimension 59 | """ 60 | 61 | x_shape = X.shape 62 | N = x_shape[-1] 63 | 64 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 65 | 66 | if norm == 'ortho': 67 | X_v[:, 0] *= np.sqrt(N) * 2 68 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 69 | 70 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 71 | W_r = torch.cos(k) 72 | W_i = torch.sin(k) 73 | 74 | V_t_r = X_v 75 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 76 | 77 | V_r = V_t_r * W_r - V_t_i * W_i 78 | V_i = V_t_r * W_i + V_t_i * W_r 79 | 80 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 81 | 82 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 83 | x = v.new_zeros(v.shape) 84 | x[:, ::2] += v[:, :N - (N // 2)] 85 | x[:, 1::2] += v.flip([1])[:, :N // 2] 86 | return x.view(*x_shape) 87 | 88 | def forward(self, x): 89 | X1 = self.forward_transform(x) # (bsz, C, D, H, W) 90 | X2 = self.forward_transform(X1.transpose(-1, -2)) # (bsz, C, D, W, H) 91 | X3 = self.forward_transform(X2.transpose(-1, -3)) # (bsz, C, H, W, D) 92 | return X3.permute(0, 1, -1, -3, -2) 93 | 94 | def inverse(self, x): 95 | X1 = self.inverse_transform(x) # (bsz, C, D, H, W) 96 | X2 = self.inverse_transform(X1.transpose(-1, -2)) # (bsz, C, D, W, H) 97 | X3 = 
self.inverse_transform(X2.transpose(-1, -3)) # (bsz, C, H, W, D) 98 | return X3.permute(0, 1, -1, -3, -2) -------------------------------------------------------------------------------- /ImageNet-abcde/networks/vit_cifar.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | # classes 15 | 16 | class PreNorm(nn.Module): 17 | def __init__(self, dim, fn): 18 | super().__init__() 19 | self.norm = nn.LayerNorm(dim) 20 | self.fn = fn 21 | def forward(self, x, **kwargs): 22 | return self.fn(self.norm(x), **kwargs) 23 | 24 | class FeedForward(nn.Module): 25 | def __init__(self, dim, hidden_dim, dropout = 0.): 26 | super().__init__() 27 | self.net = nn.Sequential( 28 | nn.Linear(dim, hidden_dim), 29 | nn.GELU(), 30 | nn.Dropout(dropout), 31 | nn.Linear(hidden_dim, dim), 32 | nn.Dropout(dropout) 33 | ) 34 | def forward(self, x): 35 | return self.net(x) 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 39 | super().__init__() 40 | inner_dim = dim_head * heads 41 | project_out = not (heads == 1 and dim_head == dim) 42 | 43 | self.heads = heads 44 | self.scale = dim_head ** -0.5 45 | 46 | self.attend = nn.Softmax(dim = -1) 47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 48 | 49 | self.to_out = nn.Sequential( 50 | nn.Linear(inner_dim, dim), 51 | nn.Dropout(dropout) 52 | ) if project_out else nn.Identity() 53 | 54 | def forward(self, x): 55 | qkv = self.to_qkv(x).chunk(3, dim = -1) 56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 57 | 58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 59 | 60 | attn = self.attend(dots) 61 | 62 | out = torch.matmul(attn, v) 63 | out = rearrange(out, 'b h n d -> b n (h d)') 64 | return self.to_out(out) 65 | 66 | class Transformer(nn.Module): 67 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 68 | super().__init__() 69 | self.layers = nn.ModuleList([]) 70 | for _ in range(depth): 71 | self.layers.append(nn.ModuleList([ 72 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 73 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 74 | ])) 75 | def forward(self, x): 76 | for attn, ff in self.layers: 77 | x = attn(x) + x 78 | x = ff(x) + x 79 | return x 80 | 81 | class ViTCIFAR(nn.Module): 82 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 83 | super().__init__() 84 | image_height, image_width = pair(image_size) 85 | patch_height, patch_width = pair(patch_size) 86 | 87 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
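# Illustrative patch arithmetic under an assumed CIFAR-scale setting: image_size=32
# with patch_size=4 gives num_patches = (32 // 4) * (32 // 4) = 64 and
# patch_dim = 3 * 4 * 4 = 48, i.e., each image becomes a 64-token sequence of
# 48-dimensional flattened patches before the linear projection to `dim`.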
88 | 89 | num_patches = (image_height // patch_height) * (image_width // patch_width) 90 | patch_dim = channels * patch_height * patch_width 91 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 92 | 93 | self.to_patch_embedding = nn.Sequential( 94 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 95 | nn.Linear(patch_dim, dim), 96 | ) 97 | 98 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 99 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 100 | self.dropout = nn.Dropout(emb_dropout) 101 | 102 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 103 | 104 | self.pool = pool 105 | self.to_latent = nn.Identity() 106 | 107 | self.mlp_head = nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, num_classes) 110 | ) 111 | 112 | def forward(self, img): 113 | x = self.to_patch_embedding(img) 114 | b, n, _ = x.shape 115 | 116 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 117 | x = torch.cat((cls_tokens, x), dim=1) 118 | x += self.pos_embedding[:, :(n + 1)] 119 | x = self.dropout(x) 120 | 121 | x = self.transformer(x) 122 | 123 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 124 | 125 | x = self.to_latent(x) 126 | return self.mlp_head(x) -------------------------------------------------------------------------------- /ImageNet-abcde/networks/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | # classes 15 | 16 | class PreNorm(nn.Module): 17 | def __init__(self, dim, fn): 18 | super().__init__() 19 | # self.norm = nn.LayerNorm(dim) 20 | self.fn = fn 21 | def forward(self, x, **kwargs): 22 | # return self.fn(self.norm(x), **kwargs) 23 | return self.fn(x, **kwargs) 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, dim, hidden_dim, dropout = 0.): 27 | super().__init__() 28 | self.net = nn.Sequential( 29 | nn.Linear(dim, hidden_dim), 30 | nn.GELU(), 31 | nn.Dropout(dropout), 32 | nn.Linear(hidden_dim, dim), 33 | nn.Dropout(dropout) 34 | ) 35 | def forward(self, x): 36 | return self.net(x) 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 40 | super().__init__() 41 | inner_dim = dim_head * heads 42 | project_out = not (heads == 1 and dim_head == dim) 43 | 44 | self.heads = heads 45 | self.scale = dim_head ** -0.5 46 | 47 | self.attend = nn.Softmax(dim = -1) 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 49 | 50 | self.to_out = nn.Sequential( 51 | nn.Linear(inner_dim, dim), 52 | nn.Dropout(dropout) 53 | ) if project_out else nn.Identity() 54 | 55 | def forward(self, x): 56 | qkv = self.to_qkv(x).chunk(3, dim = -1) 57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 58 | 59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 60 | 61 | attn = self.attend(dots) 62 | 63 | out = torch.matmul(attn, v) 64 | out = rearrange(out, 'b h n d -> b n (h d)') 65 | return self.to_out(out) 66 | 67 | class Transformer(nn.Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 69 | super().__init__() 70 | self.layers = nn.ModuleList([]) 
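# Each of the `depth` layers appended below pairs a PreNorm-wrapped Attention block
# with a PreNorm-wrapped FeedForward block; the residual connections are added in
# forward(). Note that in this file PreNorm's LayerNorm is commented out, so the
# wrapper is a pass-through and normalization is effectively disabled.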
71 | for _ in range(depth): 72 | self.layers.append(nn.ModuleList([ 73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 75 | ])) 76 | def forward(self, x): 77 | for attn, ff in self.layers: 78 | x = attn(x) + x 79 | x = ff(x) + x 80 | return x 81 | 82 | class ViT(nn.Module): 83 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 84 | super().__init__() 85 | image_height, image_width = pair(image_size) 86 | patch_height, patch_width = pair(patch_size) 87 | 88 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 89 | 90 | num_patches = (image_height // patch_height) * (image_width // patch_width) 91 | patch_dim = channels * patch_height * patch_width 92 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 93 | 94 | self.to_patch_embedding = nn.Sequential( 95 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 96 | nn.Linear(patch_dim, dim), 97 | ) 98 | 99 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 100 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 101 | self.dropout = nn.Dropout(emb_dropout) 102 | 103 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 104 | 105 | self.pool = pool 106 | self.to_latent = nn.Identity() 107 | 108 | self.mlp_head = nn.Sequential( 109 | # nn.LayerNorm(dim), 110 | nn.Linear(dim, num_classes) 111 | ) 112 | 113 | def forward(self, img): 114 | x = self.to_patch_embedding(img) 115 | b, n, _ = x.shape 116 | 117 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 118 | x = torch.cat((cls_tokens, x), dim=1) 119 | x += self.pos_embedding[:, :(n + 1)] 120 | x = self.dropout(x) 121 | 122 | x = self.transformer(x) 123 | 124 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 125 | 126 | x = self.to_latent(x) 127 | 128 | feat_fc = x 129 | 130 | x = self.mlp_head(x) 131 | 132 | return x -------------------------------------------------------------------------------- /ImageNet-abcde/networks/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | ''' ResNet ''' 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 10 | super(BasicBlock, self).__init__() 11 | self.norm = norm 12 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 16 | 17 | self.shortcut = nn.Sequential() 18 | if stride != 1 or in_planes != self.expansion*planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 21 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 22 | ) 23 | 24 | def forward(self, x): 25 | out = 
F.relu(self.bn1(self.conv1(x))) 26 | out = self.bn2(self.conv2(out)) 27 | out += self.shortcut(x) 28 | out = F.relu(out) 29 | return out 30 | 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 36 | super(Bottleneck, self).__init__() 37 | self.norm = norm 38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 39 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 43 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 44 | 45 | self.shortcut = nn.Sequential() 46 | if stride != 1 or in_planes != self.expansion*planes: 47 | self.shortcut = nn.Sequential( 48 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 49 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 50 | ) 51 | 52 | def forward(self, x): 53 | out = F.relu(self.bn1(self.conv1(x))) 54 | out = F.relu(self.bn2(self.conv2(out))) 55 | out = self.bn3(self.conv3(out)) 56 | out += self.shortcut(x) 57 | out = F.relu(out) 58 | return out 59 | 60 | 61 | class ResNetCIFAR(nn.Module): 62 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 63 | super(ResNetCIFAR, self).__init__() 64 | self.in_planes = 64 65 | self.norm = norm 66 | 67 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 68 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 69 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 70 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 71 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 72 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 73 | self.classifier = nn.Linear(512*block.expansion, num_classes) 74 | 75 | def _make_layer(self, block, planes, num_blocks, stride): 76 | strides = [stride] + [1]*(num_blocks-1) 77 | layers = [] 78 | for stride in strides: 79 | layers.append(block(self.in_planes, planes, stride, self.norm)) 80 | self.in_planes = planes * block.expansion 81 | return nn.Sequential(*layers) 82 | 83 | def forward(self, x): 84 | out = F.relu(self.bn1(self.conv1(x))) 85 | out = self.layer1(out) 86 | out = self.layer2(out) 87 | out = self.layer3(out) 88 | out = self.layer4(out) 89 | out = F.avg_pool2d(out, 4) 90 | out = out.view(out.size(0), -1) 91 | out = self.classifier(out) 92 | return out 93 | 94 | def embed(self, x): 95 | out = F.relu(self.bn1(self.conv1(x))) 96 | out = self.layer1(out) 97 | out = self.layer2(out) 98 | out = self.layer3(out) 99 | out = self.layer4(out) 100 | out = F.avg_pool2d(out, 4) 101 | out = out.view(out.size(0), -1) 102 | return out 103 | 104 | 105 | def ResNet18BNCIFAR(channel, num_classes): 106 | return ResNetCIFAR(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 107 | 108 | def ResNet18CIFAR(channel, num_classes): 109 | return ResNetCIFAR(BasicBlock, [2,2,2,2], 
channel=channel, num_classes=num_classes) 110 | 111 | def ResNet34CIFAR(channel, num_classes): 112 | return ResNetCIFAR(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes) 113 | 114 | def ResNet50CIFAR(channel, num_classes): 115 | return ResNetCIFAR(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes) 116 | 117 | def ResNet101CIFAR(channel, num_classes): 118 | return ResNetCIFAR(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes) 119 | 120 | def ResNet152CIFAR(channel, num_classes): 121 | return ResNetCIFAR(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes) 122 | -------------------------------------------------------------------------------- /corruption-exp/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data 9 | import time 10 | import shutil 11 | from utils import set_seed, get_source_dataset, get_target_dataset, get_network, ParamDiffAug, get_daparam, save_and_print, TensorDataset, epoch, get_time 12 | 13 | np.set_printoptions(linewidth=250, suppress=True, precision=4) 14 | 15 | common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 16 | 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 17 | 'snow', 'frost', 'fog', 'brightness', 18 | 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'] 19 | 20 | def main(args): 21 | 22 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | args.dsa_param = ParamDiffAug() 24 | 25 | if args.method == "DC": 26 | args.dsa = False 27 | else: 28 | args.dsa = True 29 | 30 | if args.dsa: 31 | args.dc_aug_param = None 32 | else: 33 | args.dc_aug_param = get_daparam(args.source_dataset, "", "", args.ipc) 34 | 35 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}') 36 | 37 | ''' organize the source dataset ''' 38 | dst_train, images_train, labels_train = get_source_dataset(args) 39 | 40 | if args.target_dataset[-1] == "C": # CIFAR-10-C / ImageNet-Subset-C 41 | total_accs = np.zeros((args.num_eval, len(common_corruptions))) 42 | else: # CIFAR-10.1 43 | total_accs = np.zeros(args.num_eval) 44 | 45 | for idx_eval in range(args.num_eval): 46 | ''' Train the network from scratch''' 47 | net = get_network(args).to(args.device) 48 | images_train = images_train.to(args.device) 49 | labels_train = labels_train.to(args.device) 50 | lr = float(args.lr_net) 51 | Epoch = int(args.epoch_eval_train) 52 | lr_schedule = [Epoch // 2 + 1] 53 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 54 | 55 | criterion = nn.CrossEntropyLoss().to(args.device) 56 | 57 | dst_train = TensorDataset(images_train, labels_train) 58 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 59 | 60 | start = time.time() 61 | 62 | for ep in range(Epoch + 1): 63 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=True) 64 | if ep in lr_schedule: 65 | lr *= 0.1 66 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 67 | 68 | time_train = time.time() - start 69 | save_and_print(args.log_path, '%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f \n' % (get_time(), idx_eval, Epoch, int(time_train), loss_train, acc_train)) 70 | 71 | ''' Test on the target dataset ''' 72 | 
optimizer = None 73 | criterion = nn.CrossEntropyLoss().to(args.device) 74 | 75 | if args.target_dataset[-1] == "C": # CIFAR-10-C / ImageNet-Subset-C 76 | for idx_corruption, corruption in enumerate(common_corruptions): 77 | args.corruption = corruption 78 | dst_test, testloader = get_target_dataset(args) 79 | with torch.no_grad(): 80 | _, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False) 81 | 82 | total_accs[idx_eval, idx_corruption] = acc_test 83 | else: # CIFAR-10.1 84 | dst_test, testloader = get_target_dataset(args) 85 | with torch.no_grad(): 86 | _, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False) 87 | 88 | total_accs[idx_eval] = acc_test 89 | 90 | save_and_print(args.log_path, '%s Evaluate_%02d: %s \n' % (get_time(), idx_eval, total_accs[idx_eval])) 91 | torch.save(net, f"{args.save_path}/net#{idx_eval}.pt") 92 | 93 | ''' Visualize and Save ''' 94 | save_and_print(args.log_path , "=" * 100) 95 | save_and_print(args.log_path, f"Average across num_eval: {np.average(total_accs, axis=0)}") 96 | save_and_print(args.log_path, f"Average: {np.average(total_accs)}") 97 | torch.save(total_accs, f"{args.save_path}/total_accs.pt") 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--method', type=str, default="FreD") 102 | parser.add_argument("--source_dataset", default='CIFAR10', type=str) 103 | parser.add_argument("--target_dataset", default='CIFAR10-C', type=str) 104 | parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. This only does anything when --dataset=ImageNet') 105 | parser.add_argument("--level", default=1, type=int) 106 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 107 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 108 | parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models') 109 | parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') 110 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters') 111 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 112 | parser.add_argument('--data_path', type=str, default='./data', help='dataset path') 113 | parser.add_argument('--save_path', type=str, default='./results', help='path to save results') 114 | parser.add_argument('--synset_path', type=str, default='./trained_synset', help='trained synthetic dataset path') 115 | 116 | parser.add_argument('--seed', type=int, default=0) 117 | parser.add_argument('--sh_file', type=str) 118 | parser.add_argument('--FLAG', type=str, default="TEST") 119 | args = parser.parse_args() 120 | set_seed(args.seed) 121 | 122 | if not os.path.exists(args.save_path): 123 | os.mkdir(args.save_path) 124 | args.save_path = args.save_path + f"/{args.FLAG}" 125 | if not os.path.exists(args.save_path): 126 | os.mkdir(args.save_path) 127 | 128 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}") 129 | args.log_path = f"{args.save_path}/log.txt" 130 | 131 | main(args) 132 | -------------------------------------------------------------------------------- /TM/buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 
import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | from utils import get_dataset, get_network, get_daparam, TensorDataset, epoch, ParamDiffAug, set_seed, save_and_print 7 | 8 | import warnings 9 | warnings.filterwarnings("ignore", category=DeprecationWarning) 10 | 11 | def main(args): 12 | args.dsa = True if args.dsa == 'True' else False 13 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | args.dsa_param = ParamDiffAug() 15 | 16 | save_dir = os.path.join(args.buffer_path, args.dataset) 17 | if args.dataset == "ImageNet": 18 | save_dir = os.path.join(save_dir, args.subset) 19 | if not args.zca: 20 | save_dir += "_NO_ZCA" 21 | save_dir = os.path.join(save_dir, args.model) 22 | if not os.path.exists(save_dir): 23 | os.makedirs(save_dir) 24 | 25 | args.log_path = f"{args.buffer_path}/log#{args.dataset}" 26 | if args.dataset == "ImageNet": 27 | args.log_path += f"_{args.subset}" 28 | if not args.zca: 29 | args.log_path += "_NO_ZCA" 30 | args.log_path += f"_{args.model}.txt" 31 | 32 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args) 33 | 34 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}') 35 | 36 | ''' organize the real dataset ''' 37 | images_all = [] 38 | labels_all = [] 39 | indices_class = [[] for c in range(num_classes)] 40 | save_and_print(args.log_path, "BUILDING DATASET") 41 | for i in tqdm(range(len(dst_train))): 42 | sample = dst_train[i] 43 | images_all.append(torch.unsqueeze(sample[0], dim=0)) 44 | labels_all.append(class_map[torch.tensor(sample[1]).item()]) 45 | 46 | for i, lab in tqdm(enumerate(labels_all)): 47 | indices_class[lab].append(i) 48 | images_all = torch.cat(images_all, dim=0).to("cpu") 49 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu") 50 | 51 | for c in range(num_classes): 52 | save_and_print(args.log_path, 'class c = %d: %d real images'%(c, len(indices_class[c]))) 53 | 54 | for ch in range(channel): 55 | save_and_print(args.log_path, 'real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch]))) 56 | 57 | criterion = nn.CrossEntropyLoss().to(args.device) 58 | 59 | trajectories = [] 60 | 61 | dst_train = TensorDataset(images_all, labels_all) 62 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 63 | 64 | ''' set augmentation for whole-dataset training ''' 65 | args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None) 66 | args.dc_aug_param['strategy'] = 'crop_scale_rotate' # for whole-dataset training 67 | save_and_print(args.log_path, f'DC augmentation parameters: {args.dc_aug_param}') 68 | 69 | for it in range(0, args.num_experts): 70 | 71 | ''' Train synthetic data ''' 72 | teacher_net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model 73 | teacher_net.train() 74 | lr = args.lr_teacher 75 | teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) # optimizer_img for synthetic data 76 | teacher_optim.zero_grad() 77 | 78 | timestamps = [] 79 | 80 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()]) 81 | 82 | lr_schedule = [args.train_epochs // 2 + 1] 83 | 84 | for e in range(args.train_epochs): 85 | 86 | train_loss, train_acc = epoch("train", dataloader=trainloader, 
net=teacher_net, optimizer=teacher_optim, 87 | criterion=criterion, args=args, aug=True) 88 | 89 | test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None, 90 | criterion=criterion, args=args, aug=False) 91 | 92 | save_and_print(args.log_path, "Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc)) 93 | 94 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()]) 95 | 96 | if e in lr_schedule and args.decay: 97 | lr *= 0.1 98 | teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) 99 | teacher_optim.zero_grad() 100 | 101 | trajectories.append(timestamps) 102 | 103 | if len(trajectories) == args.save_interval: 104 | n = 0 105 | while os.path.exists(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))): 106 | n += 1 107 | save_and_print(args.log_path, "Saving {}".format(os.path.join(save_dir, "replay_buffer_{}.pt".format(n)))) 108 | torch.save(trajectories, os.path.join(save_dir, "replay_buffer_{}.pt".format(n))) 109 | trajectories = [] 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = argparse.ArgumentParser(description='Parameter Processing') 114 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 115 | parser.add_argument('--subset', type=str, default='imagenette', help='subset') 116 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 117 | parser.add_argument('--num_experts', type=int, default=100, help='training iterations') 118 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='learning rate for updating network parameters') 119 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 120 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real loader') 121 | parser.add_argument('--batch_test', type=int, default=128, help='batch size for real loader') 122 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], help='whether to use differentiable Siamese augmentation.') 123 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 124 | parser.add_argument('--data_path', type=str, default='../data', help='dataset path') 125 | parser.add_argument('--buffer_path', type=str, default='../buffers', help='buffer path') 126 | parser.add_argument('--train_epochs', type=int, default=50) 127 | parser.add_argument('--zca', action='store_true') 128 | parser.add_argument('--decay', action='store_true') 129 | parser.add_argument('--mom', type=float, default=0, help='momentum') 130 | parser.add_argument('--l2', type=float, default=0, help='l2 regularization') 131 | parser.add_argument('--save_interval', type=int, default=10) 132 | 133 | parser.add_argument('--res', type=str, default='') 134 | 135 | parser.add_argument('--seed', type=int, default=0) 136 | args = parser.parse_args() 137 | set_seed(args.seed) 138 | 139 | main(args) 140 | 141 | 142 | -------------------------------------------------------------------------------- /ImageNet-abcde/reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import types 5 | from collections import namedtuple 6 | from contextlib import contextmanager 7 | 8 | 9 | class ReparamModule(nn.Module): 10 | def _get_module_from_name(self, mn): 11 | if mn == 
'': 12 | return self 13 | m = self 14 | for p in mn.split('.'): 15 | m = getattr(m, p) 16 | return m 17 | 18 | def __init__(self, module): 19 | super(ReparamModule, self).__init__() 20 | self.module = module 21 | 22 | param_infos = [] # (module name/path, param name) 23 | shared_param_memo = {} 24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name) 25 | params = [] 26 | param_numels = [] 27 | param_shapes = [] 28 | for mn, m in self.named_modules(): 29 | for n, p in m.named_parameters(recurse=False): 30 | if p is not None: 31 | if p in shared_param_memo: 32 | shared_mn, shared_n = shared_param_memo[p] 33 | shared_param_infos.append((mn, n, shared_mn, shared_n)) 34 | else: 35 | shared_param_memo[p] = (mn, n) 36 | param_infos.append((mn, n)) 37 | params.append(p.detach()) 38 | param_numels.append(p.numel()) 39 | param_shapes.append(p.size()) 40 | 41 | assert len(set(p.dtype for p in params)) <= 1, \ 42 | "expects all parameters in module to have same dtype" 43 | 44 | # store the info for unflatten 45 | self._param_infos = tuple(param_infos) 46 | self._shared_param_infos = tuple(shared_param_infos) 47 | self._param_numels = tuple(param_numels) 48 | self._param_shapes = tuple(param_shapes) 49 | 50 | # flatten 51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0)) 52 | self.register_parameter('flat_param', flat_param) 53 | self.param_numel = flat_param.numel() 54 | del params 55 | del shared_param_memo 56 | 57 | # deregister the names as parameters 58 | for mn, n in self._param_infos: 59 | delattr(self._get_module_from_name(mn), n) 60 | for mn, n, _, _ in self._shared_param_infos: 61 | delattr(self._get_module_from_name(mn), n) 62 | 63 | # register the views as plain attributes 64 | self._unflatten_param(self.flat_param) 65 | 66 | # now buffers 67 | # they are not reparametrized. 
just store info as (module, name, buffer) 68 | buffer_infos = [] 69 | for mn, m in self.named_modules(): 70 | for n, b in m.named_buffers(recurse=False): 71 | if b is not None: 72 | buffer_infos.append((mn, n, b)) 73 | 74 | self._buffer_infos = tuple(buffer_infos) 75 | self._traced_self = None 76 | 77 | def trace(self, example_input, **trace_kwargs): 78 | assert self._traced_self is None, 'This ReparamModule is already traced' 79 | 80 | if isinstance(example_input, torch.Tensor): 81 | example_input = (example_input,) 82 | example_input = tuple(example_input) 83 | example_param = (self.flat_param.detach().clone(),) 84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 85 | 86 | self._traced_self = torch.jit.trace_module( 87 | self, 88 | inputs=dict( 89 | _forward_with_param=example_param + example_input, 90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 91 | ), 92 | **trace_kwargs, 93 | ) 94 | 95 | # replace forwards with traced versions 96 | self._forward_with_param = self._traced_self._forward_with_param 97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 98 | return self 99 | 100 | def clear_views(self): 101 | for mn, n in self._param_infos: 102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr 103 | 104 | def _apply(self, *args, **kwargs): 105 | if self._traced_self is not None: 106 | self._traced_self._apply(*args, **kwargs) 107 | return self 108 | return super(ReparamModule, self)._apply(*args, **kwargs) 109 | 110 | def _unflatten_param(self, flat_param): 111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 112 | for (mn, n), p in zip(self._param_infos, ps): 113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr 114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 116 | 117 | @contextmanager 118 | def unflattened_param(self, flat_param): 119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos] 120 | self._unflatten_param(flat_param) 121 | yield 122 | # Why not just `self._unflatten_param(self.flat_param)`? 123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 124 | # 2. 
slightly faster since it does not require reconstructing the split+view 125 | #    graph 126 |         for (mn, n), p in zip(self._param_infos, saved_views): 127 |             setattr(self._get_module_from_name(mn), n, p) 128 |         for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 129 |             setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 130 |  131 |     @contextmanager 132 |     def replaced_buffers(self, buffers): 133 |         for (mn, n, _), new_b in zip(self._buffer_infos, buffers): 134 |             setattr(self._get_module_from_name(mn), n, new_b) 135 |         yield 136 |         for mn, n, old_b in self._buffer_infos: 137 |             setattr(self._get_module_from_name(mn), n, old_b) 138 |  139 |     def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs): 140 |         with self.unflattened_param(flat_param): 141 |             with self.replaced_buffers(buffers): 142 |                 return self.module(*inputs, **kwinputs) 143 |  144 |     def _forward_with_param(self, flat_param, *inputs, **kwinputs): 145 |         with self.unflattened_param(flat_param): 146 |             return self.module(*inputs, **kwinputs) 147 |  148 |     def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs): 149 |         if flat_param is None:  # check before squeezing, since torch.squeeze(None) raises 150 |             flat_param = self.flat_param 151 |         flat_param = torch.squeeze(flat_param) 152 |         # print("PARAMS ON DEVICE: ", flat_param.get_device(), flat_param.shape) 153 |         # print("DATA ON DEVICE: ", inputs[0].get_device(), inputs[0].shape) 154 |         # flat_param.to("cuda:{}".format(inputs[0].get_device())) 155 |         # self.module.to("cuda:{}".format(inputs[0].get_device())) 156 |         if buffers is None: 157 |             return self._forward_with_param(flat_param, *inputs, **kwinputs) 158 |         else: 159 |             return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs) -------------------------------------------------------------------------------- /TM/reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import types 5 | from collections import namedtuple 6 | from contextlib import contextmanager 7 |  8 |  9 | class ReparamModule(nn.Module): 10 |     def _get_module_from_name(self, mn): 11 |         if mn == '': 12 |             return self 13 |         m = self 14 |         for p in mn.split('.'): 15 |             m = getattr(m, p) 16 |         return m 17 |  18 |     def __init__(self, module): 19 |         super(ReparamModule, self).__init__() 20 |         self.module = module 21 |  22 |         param_infos = []  # (module name/path, param name) 23 |         shared_param_memo = {} 24 |         shared_param_infos = []  # (module name/path, param name, src module name/path, src param_name) 25 |         params = [] 26 |         param_numels = [] 27 |         param_shapes = [] 28 |         for mn, m in self.named_modules(): 29 |             for n, p in m.named_parameters(recurse=False): 30 |                 if p is not None: 31 |                     if p in shared_param_memo: 32 |                         shared_mn, shared_n = shared_param_memo[p] 33 |                         shared_param_infos.append((mn, n, shared_mn, shared_n)) 34 |                     else: 35 |                         shared_param_memo[p] = (mn, n) 36 |                         param_infos.append((mn, n)) 37 |                         params.append(p.detach()) 38 |                         param_numels.append(p.numel()) 39 |                         param_shapes.append(p.size()) 40 |  41 |         assert len(set(p.dtype for p in params)) <= 1, \ 42 |             "expects all parameters in module to have same dtype" 43 |  44 |         # store the info for unflatten 45 |         self._param_infos = tuple(param_infos) 46 |         self._shared_param_infos = tuple(shared_param_infos) 47 |         self._param_numels = tuple(param_numels) 48 |         self._param_shapes = tuple(param_shapes) 49 |  50 |         # flatten 51 |         flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params],
0)) 52 | self.register_parameter('flat_param', flat_param) 53 | self.param_numel = flat_param.numel() 54 | del params 55 | del shared_param_memo 56 | 57 | # deregister the names as parameters 58 | for mn, n in self._param_infos: 59 | delattr(self._get_module_from_name(mn), n) 60 | for mn, n, _, _ in self._shared_param_infos: 61 | delattr(self._get_module_from_name(mn), n) 62 | 63 | # register the views as plain attributes 64 | self._unflatten_param(self.flat_param) 65 | 66 | # now buffers 67 | # they are not reparametrized. just store info as (module, name, buffer) 68 | buffer_infos = [] 69 | for mn, m in self.named_modules(): 70 | for n, b in m.named_buffers(recurse=False): 71 | if b is not None: 72 | buffer_infos.append((mn, n, b)) 73 | 74 | self._buffer_infos = tuple(buffer_infos) 75 | self._traced_self = None 76 | 77 | def trace(self, example_input, **trace_kwargs): 78 | assert self._traced_self is None, 'This ReparamModule is already traced' 79 | 80 | if isinstance(example_input, torch.Tensor): 81 | example_input = (example_input,) 82 | example_input = tuple(example_input) 83 | example_param = (self.flat_param.detach().clone(),) 84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 85 | 86 | self._traced_self = torch.jit.trace_module( 87 | self, 88 | inputs=dict( 89 | _forward_with_param=example_param + example_input, 90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 91 | ), 92 | **trace_kwargs, 93 | ) 94 | 95 | # replace forwards with traced versions 96 | self._forward_with_param = self._traced_self._forward_with_param 97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 98 | return self 99 | 100 | def clear_views(self): 101 | for mn, n in self._param_infos: 102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr 103 | 104 | def _apply(self, *args, **kwargs): 105 | if self._traced_self is not None: 106 | self._traced_self._apply(*args, **kwargs) 107 | return self 108 | return super(ReparamModule, self)._apply(*args, **kwargs) 109 | 110 | def _unflatten_param(self, flat_param): 111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 112 | for (mn, n), p in zip(self._param_infos, ps): 113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr 114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 116 | 117 | @contextmanager 118 | def unflattened_param(self, flat_param): 119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos] 120 | self._unflatten_param(flat_param) 121 | yield 122 | # Why not just `self._unflatten_param(self.flat_param)`? 123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 124 | # 2. 
slightly faster since it does not require reconstructing the split+view 125 | #    graph 126 |         for (mn, n), p in zip(self._param_infos, saved_views): 127 |             setattr(self._get_module_from_name(mn), n, p) 128 |         for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 129 |             setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 130 |  131 |     @contextmanager 132 |     def replaced_buffers(self, buffers): 133 |         for (mn, n, _), new_b in zip(self._buffer_infos, buffers): 134 |             setattr(self._get_module_from_name(mn), n, new_b) 135 |         yield 136 |         for mn, n, old_b in self._buffer_infos: 137 |             setattr(self._get_module_from_name(mn), n, old_b) 138 |  139 |     def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs): 140 |         with self.unflattened_param(flat_param): 141 |             with self.replaced_buffers(buffers): 142 |                 return self.module(*inputs, **kwinputs) 143 |  144 |     def _forward_with_param(self, flat_param, *inputs, **kwinputs): 145 |         with self.unflattened_param(flat_param): 146 |             return self.module(*inputs, **kwinputs) 147 |  148 |     def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs): 149 |         if flat_param is None:  # check before squeezing, since torch.squeeze(None) raises 150 |             flat_param = self.flat_param 151 |         flat_param = torch.squeeze(flat_param) 152 |         # print("PARAMS ON DEVICE: ", flat_param.get_device()) 153 |         # print("DATA ON DEVICE: ", inputs[0].get_device()) 154 |         # flat_param.to("cuda:{}".format(inputs[0].get_device())) 155 |         # self.module.to("cuda:{}".format(inputs[0].get_device())) 156 |         if buffers is None: 157 |             return self._forward_with_param(flat_param, *inputs, **kwinputs) 158 |         else: 159 |             return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Frequency Domain-based Dataset Distillation (FreD) [NeurIPS 2023] 2 |  3 | This repository contains the official PyTorch implementation of the NeurIPS 2023 paper "Frequency Domain-based Dataset Distillation". 4 |  5 | **[Donghyeok Shin](http://kaal.dsso.kr/bbs/board.php?bo_table=sub2_1&wr_id=8) \*, [Seungjae Shin](https://sites.google.com/view/seungjae-shin) \*, and Il-Chul Moon** 6 | * Equal contribution
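As a usage sketch for the `ReparamModule` above (a hypothetical toy network and shapes; only `flat_param`, `param_numel`, and the `forward(..., flat_param=...)` entry point of the class are assumed):

```python
import torch
import torch.nn as nn

# Wrap a toy network; ReparamModule flattens all of its weights into one vector.
net = ReparamModule(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)))
print(net.param_numel)  # 4*8 + 8 + 8*2 + 2 = 58

# Forward under an arbitrary candidate parameter vector: gradients flow into
# `theta` itself rather than into the wrapped module's own weights, which is
# what lets trajectory matching differentiate through network parameters.
theta = net.flat_param.detach().clone().requires_grad_(True)
x = torch.randn(5, 4)
loss = net(x, flat_param=theta).square().mean()
loss.backward()
print(theta.grad.shape)  # torch.Size([58])
```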
7 |  8 | | [paper](https://arxiv.org/abs/2311.08819) | [slides](https://neurips.cc/media/neurips-2023/Slides/71874.pdf) | [pretrained](https://drive.google.com/drive/folders/1r1OMVv9llejGmpHfK5DpW4m57Dz_SZ2n?usp=sharing) | 9 |  10 | ## Updates 11 | - (2024.05.07) We have uploaded the distilled synthetic datasets, except for a few cases. Please refer to [pretrained](https://drive.google.com/drive/folders/1r1OMVv9llejGmpHfK5DpW4m57Dz_SZ2n?usp=sharing). The remaining cases will be uploaded as soon as possible. 12 |  13 | ## Overview 14 | ![Teaser image](overview_FreD.png) 15 | > **Abstract** *This paper presents FreD, a novel parameterization method for dataset distillation, which utilizes the frequency domain to distill a small-sized synthetic dataset from a large-sized original dataset. Unlike conventional approaches that focus on the spatial domain, FreD employs frequency-based transforms to optimize the frequency representations of each data instance. By leveraging the concentration of spatial domain information on specific frequency components, FreD intelligently selects a subset of frequency dimensions for optimization, leading to a significant reduction in the required budget for synthesizing an instance. Through the selection of frequency dimensions based on the explained variance, FreD demonstrates both theoretical and empirical evidence of its ability to operate efficiently within a limited budget, while better preserving the information of the original dataset compared to conventional parameterization methods. Furthermore, based on the orthogonal compatibility of FreD with existing methods, we confirm that FreD consistently improves the performances of existing distillation methods over the evaluation scenarios with different benchmark datasets.* 16 |  17 | ## Requirements 18 | This code was tested with CUDA 11.4 and Python 3.8. 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 |  23 | ## Usage 24 | The main hyper-parameters of FreD are as follows: 25 | - `msz_per_channel` : Memory size per channel 26 | - `lr_freq` : Learning rate for the synthetic frequency representation 27 | - `mom_freq` : Momentum for the synthetic frequency representation 28 |  29 | The detailed values of these hyper-parameters can be found in our paper. 30 | For other hyper-parameters, we follow the default settings of each dataset distillation objective. 31 | Please refer to the bash files for the detailed arguments used to run each experiment. 32 |  33 | Below are some example commands to run FreD with each dataset distillation objective. 34 | ### FreD with Gradient Matching (DC) 35 | - Run the following command: 36 | ``` 37 | cd DC/scripts 38 | sh run_DC_FreD.sh 39 | ``` 40 |  41 | ### FreD with Distribution Matching (DM) 42 | - Run the following command: 43 | ``` 44 | cd DM/scripts 45 | sh run_DM_FreD.sh 46 | ``` 47 |  48 | ### FreD with Trajectory Matching (TM) 49 | - Since TM needs expert trajectories, run `run_buffer.sh` to generate them before distillation: 50 | ``` 51 | cd TM/scripts 52 | sh run_buffer.sh 53 | sh run_TM_FreD.sh 54 | ``` 55 |  56 | ### FreD with Other Dataset Distillation Objective 57 | FreD is a highly compatible parameterization method regardless of the dataset distillation objective. 58 | Herein, we provide simple guidelines on how to use FreD with different dataset distillation objectives.
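In the snippet below, `images_all`, `labels_all`, and `indices_class` denote the real images, their labels, and the per-class index lists, built as in `buffer.py`; `SynSet` stands for the frequency domain-based parameterization used throughout this repository, and `indices` is any index array selecting part of the synthetic set.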
59 | ``` 60 | # Define the frequency domain-based parameterization 61 | synset = SynSet(args) 62 |  63 | # Initialization 64 | synset.init(images_all, labels_all, indices_class) 65 |  66 | # Get a partial synthetic dataset 67 | images_syn, labels_syn = synset.get(indices=indices) 68 |  69 | # Get the entire dataset (need_copy is optional) 70 | images_syn, labels_syn = synset.get(need_copy=True) 71 | ``` 72 | ### ImageNet-[A,B,C,D,E] 73 | - For ImageNet-[A,B,C,D,E] experiments, we built upon GLaD's official code. 74 | - Run the following command with the appropriate objective substituted for `XXX`: 75 |   - If you want to run FreD with TM, run `run_buffer.sh` first, just as above. 76 | ``` 77 | cd ImageNet-abcde/scripts 78 | sh run_XXX_FreD.sh 79 | ``` 80 |  81 | ### 3D-MNIST 82 | - Download the [3D-MNIST](https://www.kaggle.com/datasets/daavoo/3d-mnist) dataset. 83 | - Run the following command: 84 | ``` 85 | cd 3D-MNIST/scripts 86 | sh run_DM_FreD.sh 87 | ``` 88 |  89 | ### Robustness to Corruption 90 | - Download the corrupted datasets: [CIFAR-10.1](https://github.com/modestyachts/CIFAR-10.1), [CIFAR-10-C](https://zenodo.org/records/2535967), and [ImageNet-C](https://zenodo.org/records/2235448#.YpCSLxNBxAc). 91 | - Place the trained synthetic dataset at `corruption-exp/trained_synset/FreD/{dataset_name}`. 92 | - Run the following command: 93 | ``` 94 | cd corruption-exp/scripts 95 | sh run.sh 96 | ``` 97 |  98 | ## Experiment Results 99 | - Test accuracies (%) on low-dimensional datasets (≤ 64×64 resolution) with TM. 100 |  101 | |  | MNIST | FashionMNIST | SVHN | CIFAR-10 | CIFAR-100 | Tiny-ImageNet | 102 | | :------: | :-----: | :----: | :-----: | :----: | :----: | :----: | 103 | | 1 img/cls | 95.8 | 84.6 | 82.2 | 60.6 | 34.6 | 19.2 | 104 | | 10 img/cls | 97.6 | 89.1 | 89.5 | 70.3 | 42.7 | 24.2 | 105 | | 50 img/cls | - | - | 90.3 | 75.8 | 47.8 | 26.4 | 106 |  107 | - Test accuracies (%) on Image-[Nette, Woof, Fruit, Yellow, Meow, Squawk] (128 × 128 resolution) with TM. 108 |  109 | |  | ImageNette | ImageWoof | ImageFruit | ImageYellow | ImageMeow | ImageSquawk | 110 | | :------: | :-----: | :----: | :-----: | :----: | :----: | :----: | 111 | | 1 img/cls | 66.8 | 38.3 | 43.7 | 63.2 | 43.2 | 57.0 | 112 | | 10 img/cls | 72.0 | 41.3 | 47.0 | 69.2 | 48.6 | 67.3 | 113 |  114 | - Test accuracies (%) on ImageNet-[A, B, C, D, E] (128 × 128 resolution) with TM under 1 img/cls. 115 |  116 | |  | ImageNet-A | ImageNet-B | ImageNet-C | ImageNet-D | ImageNet-E | 117 | | :------: | :-----: | :----: | :-----: | :----: | :----: | 118 | | DC w/ FreD | 53.1 | 54.8 | 54.2 | 42.8 | 41.0 | 119 | | DM w/ FreD | 58.0 | 58.6 | 55.6 | 46.3 | 45.0 | 120 | | TM w/ FreD | 67.7 | 69.3 | 63.6 | 54.4 | 55.4 | 121 |  122 | More results can be found in our paper. 123 |  124 | ## Citation 125 | If you find the code useful for your research, please consider citing our paper. 126 | ```bib 127 | @inproceedings{shin2023frequency, 128 |   title={Frequency Domain-Based Dataset Distillation}, 129 |   author={Shin, DongHyeok and Shin, Seungjae and Moon, Il-chul}, 130 |   booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 131 |   year={2023} 132 | } 133 | ``` 134 | This work builds heavily upon code from 135 | - *Bo Zhao, Konda Reddy Mopuri, and Hakan Bilen. Dataset condensation with gradient matching. arXiv preprint arXiv:2006.05929, 2020.* [Code Link](https://github.com/VICO-UoE/DatasetCondensation) 136 | - *Bo Zhao and Hakan Bilen. Dataset condensation with distribution matching.
In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 6514–6523, 2023.* [Code Link](https://github.com/VICO-UoE/DatasetCondensation) 137 | - *George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A Efros, and Jun-Yan Zhu. Dataset distillation by matching training trajectories. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 4750–4759, 2022.* [Code Link](https://github.com/georgecazenavette/mtt-distillation) 138 | - *George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A Efros, and Jun-Yan Zhu. Generalizing dataset distillation via deep generative prior. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 3739–3748, 2023.* [Code Link](https://github.com/GeorgeCazenavette/glad/tree/main) 139 | -------------------------------------------------------------------------------- /3D-MNIST/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset 8 | from torchvision import datasets, transforms 9 | from scipy.ndimage.interpolation import rotate as scipyrotate 10 | from networks import Conv3DNet 11 | 12 | import tqdm 13 | 14 | import random 15 | import h5py 16 | 17 | def set_seed(seed): 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | 26 | def save_and_print(dirname, msg): 27 | if not os.path.isfile(dirname): 28 | f = open(dirname, "w") 29 | f.write(str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))) 30 | f.write("\n") 31 | f.close() 32 | f = open(dirname, "a") 33 | f.write(str(msg)) 34 | f.write("\n") 35 | f.close() 36 | print(msg) 37 | 38 | def get_images(images_all, indices_class, c, n): # get random n images from class c 39 | idx_shuffle = np.random.permutation(indices_class[c])[:n] 40 | return images_all[idx_shuffle] 41 | 42 | def prepare_points(tensor, data_config, threshold = False): 43 | if threshold: 44 | tensor = np.where( 45 | tensor > data_config.threshold, 46 | data_config.lower, 47 | data_config.upper 48 | ) 49 | tensor = tensor.reshape(( 50 | tensor.shape[0], 51 | data_config.y_shape, 52 | data_config.x_shape, 53 | data_config.z_shape 54 | )) 55 | return tensor 56 | 57 | class data_config: 58 | threshold = 0.2 59 | upper = 1 60 | lower = 0 61 | x_shape = 16 62 | y_shape = 16 63 | z_shape = 16 64 | 65 | def get_dataset(data_path, args=None): 66 | channel = 1 67 | im_size = (16, 16, 16) 68 | num_classes = 10 69 | class_names = [str(c) for c in range(num_classes)] 70 | 71 | full_vectors = os.path.join(data_path, '3D-MNIST', 'full_dataset_vectors.h5') 72 | 73 | with h5py.File(full_vectors, "r") as hf: 74 | X_train = hf['X_train'][:] 75 | X_test = hf['X_test'][:] 76 | y_train = hf['y_train'][:] 77 | y_test = hf['y_test'][:] 78 | 79 | X_train = prepare_points(X_train, data_config) 80 | X_test = prepare_points(X_test, data_config) 81 | 82 | X_train, X_test = torch.tensor(X_train, dtype=torch.float, device=args.device), torch.tensor(X_test, dtype=torch.float, device=args.device) 83 | X_train, X_test = torch.unsqueeze(X_train, dim=1), torch.unsqueeze(X_test, dim=1) 84 | y_train, y_test = torch.tensor(y_train, dtype=torch.long, device=args.device), 
torch.tensor(y_test, dtype=torch.long, device=args.device) 85 | 86 | dst_train = TensorDataset(X_train, y_train) # no augmentation 87 | dst_test = TensorDataset(X_test, y_test) 88 | testloader = torch.utils.data.DataLoader(dst_test, shuffle=False, batch_size=256, num_workers=0) 89 | 90 | return channel, im_size, num_classes, class_names, dst_train, dst_test, testloader 91 | 92 | 93 | class TensorDataset(Dataset): 94 | def __init__(self, images, labels): # images: n x c x h x w tensor 95 | self.images = images.detach().float() 96 | self.labels = labels.detach() 97 | 98 | def __getitem__(self, index): 99 | return self.images[index], self.labels[index] 100 | 101 | def __len__(self): 102 | return self.images.shape[0] 103 | 104 | 105 | def get_network(model, channel, num_classes, im_size=(32, 32)): 106 | #torch.random.manual_seed(int(time.time() * 1000) % 100000) 107 | net_width, net_depth = 64, 3 108 | 109 | if model == 'Conv3DNet': 110 | net = Conv3DNet(channel, num_classes, net_width, net_depth, im_size) 111 | else: 112 | net = None 113 | exit('unknown model: %s'%model) 114 | 115 | gpu_num = torch.cuda.device_count() 116 | if gpu_num>0: 117 | device = 'cuda' 118 | if gpu_num>1: 119 | net = nn.DataParallel(net) 120 | else: 121 | device = 'cpu' 122 | net = net.to(device) 123 | 124 | return net 125 | 126 | def get_time(): 127 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())) 128 | 129 | def get_loops(ipc): 130 | # Get the two hyper-parameters of outer-loop and inner-loop. 131 | # The following values are empirically good. 132 | if ipc == 1: 133 | outer_loop, inner_loop = 1, 1 134 | elif ipc == 10: 135 | outer_loop, inner_loop = 10, 50 136 | elif ipc == 20: 137 | outer_loop, inner_loop = 20, 25 138 | elif ipc == 30: 139 | outer_loop, inner_loop = 30, 20 140 | elif ipc == 40: 141 | outer_loop, inner_loop = 40, 15 142 | elif ipc == 50: 143 | outer_loop, inner_loop = 50, 10 144 | 145 | elif ipc == 2: 146 | outer_loop, inner_loop = 1, 1 147 | elif ipc == 11: 148 | outer_loop, inner_loop = 10, 50 149 | elif ipc == 51: 150 | outer_loop, inner_loop = 50, 10 151 | 152 | else: 153 | outer_loop, inner_loop = 0, 0 154 | exit('loop hyper-parameters are not defined for %d ipc'%ipc) 155 | return outer_loop, inner_loop 156 | 157 | 158 | def epoch(mode, dataloader, net, optimizer, criterion, args, aug): 159 | loss_avg, acc_avg, num_exp = 0, 0, 0 160 | net = net.to(args.device) 161 | criterion = criterion.to(args.device) 162 | 163 | if mode == 'train': 164 | net.train() 165 | else: 166 | net.eval() 167 | 168 | for i_batch, datum in enumerate(dataloader): 169 | img = datum[0].float().to(args.device) 170 | lab = datum[1].long().to(args.device) 171 | n_b = lab.shape[0] 172 | 173 | output = net(img) 174 | loss = criterion(output, lab) 175 | acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy())) 176 | 177 | loss_avg += loss.item()*n_b 178 | acc_avg += acc 179 | num_exp += n_b 180 | 181 | if mode == 'train': 182 | optimizer.zero_grad() 183 | loss.backward() 184 | optimizer.step() 185 | 186 | loss_avg /= num_exp 187 | acc_avg /= num_exp 188 | 189 | return loss_avg, acc_avg 190 | 191 | 192 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args): 193 | net = net.to(args.device) 194 | images_train = images_train.to(args.device) 195 | labels_train = labels_train.to(args.device) 196 | lr = float(args.lr_net) 197 | Epoch = int(args.epoch_eval_train) 198 | lr_schedule = [Epoch//2+1] 199 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, 
momentum=0.9, weight_decay=0.0005) 200 | criterion = nn.CrossEntropyLoss().to(args.device) 201 | 202 | dst_train = TensorDataset(images_train, labels_train) 203 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 204 | 205 | start = time.time() 206 | for ep in tqdm.tqdm(range(Epoch+1)): 207 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug = False) 208 | if ep in lr_schedule: 209 | lr *= 0.1 210 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 211 | 212 | time_train = time.time() - start 213 | loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug = False) 214 | save_and_print(args.log_path, '%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test)) 215 | 216 | return net, acc_train, acc_test 217 | -------------------------------------------------------------------------------- /ImageNet-abcde/networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | ''' ResNet ''' 6 | 7 | LAYERS = {'resnet18': [2, 2, 2, 2], 8 | 'resnet34': [3, 4, 6, 3], 9 | 'resnet50': [3, 4, 6, 3], 10 | 'resnet101': [3, 4, 23, 3], 11 | 'resnet152': [3, 8, 36, 3], 12 | 'resnet20': [3, 3, 3], 13 | 'resnet32': [5, 5, 5], 14 | 'resnet44': [7, 7, 7], 15 | 'resnet56': [9, 9, 9], 16 | 'resnet110': [18, 18, 18], 17 | } 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 24 | super(BasicBlock, self).__init__() 25 | self.norm = norm 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 27 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 28 | planes) if self.norm == 'batch' else nn.Identity() 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 30 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 31 | planes) if self.norm == 'batch' else nn.Identity() 32 | 33 | self.shortcut = nn.Sequential() 34 | if stride != 1 or in_planes != self.expansion * planes: 35 | self.shortcut = nn.Sequential( 36 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 37 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, 38 | affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 39 | self.expansion * planes) if self.norm == 'batch' else nn.Identity() 40 | ) 41 | 42 | def forward(self, x): 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = self.bn2(self.conv2(out)) 45 | out += self.shortcut(x) 46 | out = F.relu(out) 47 | return out 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 54 | super(Bottleneck, self).__init__() 55 | self.norm = norm 56 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 58 | planes) if self.norm == 'batch' else nn.Identity() 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 60 | self.bn2 = 
nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 61 | planes) if self.norm == 'batch' else nn.Identity() 62 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 63 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, 64 | affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 65 | self.expansion * planes) if self.norm == 'batch' else nn.Identity() 66 | 67 | self.shortcut = nn.Sequential() 68 | if stride != 1 or in_planes != self.expansion * planes: 69 | self.shortcut = nn.Sequential( 70 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 71 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, 72 | affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 73 | self.expansion * planes) if self.norm == 'batch' else nn.Identity() 74 | ) 75 | 76 | def forward(self, x): 77 | out = F.relu(self.bn1(self.conv1(x))) 78 | out = F.relu(self.bn2(self.conv2(out))) 79 | out = self.bn3(self.conv3(out)) 80 | out += self.shortcut(x) 81 | out = F.relu(out) 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 87 | super(ResNet, self).__init__() 88 | self.in_planes = 64 89 | self.norm = norm 90 | 91 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 93 | 64) if self.norm == 'batch' else nn.Identity() 94 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 95 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 96 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 97 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 98 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 99 | self.classifier = nn.Linear(512 * block.expansion, num_classes) 100 | 101 | def _make_layer(self, block, planes, num_blocks, stride): 102 | strides = [stride] + [1] * (num_blocks - 1) 103 | layers = [] 104 | for stride in strides: 105 | layers.append(block(self.in_planes, planes, stride, self.norm)) 106 | self.in_planes = planes * block.expansion 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = F.relu(self.bn1(self.conv1(x))) 111 | out = self.layer1(out) 112 | out = self.layer2(out) 113 | out = self.layer3(out) 114 | out = self.layer4(out) 115 | out = self.pool(out) 116 | 117 | out = out.view(out.size(0), -1) 118 | feat_fc = out 119 | out = self.classifier(out) 120 | 121 | return out 122 | 123 | def embed(self, x): 124 | out = F.relu(self.bn1(self.conv1(x))) 125 | out = self.layer1(out) 126 | out = self.layer2(out) 127 | out = self.layer3(out) 128 | out = self.layer4(out) 129 | out = F.avg_pool2d(out, 4) 130 | out = out.view(out.size(0), -1) 131 | return out 132 | 133 | 134 | class ResNetImageNet(nn.Module): 135 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 136 | super(ResNetImageNet, self).__init__() 137 | self.in_planes = 64 138 | self.norm = norm 139 | 140 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 141 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d( 142 | 64) if self.norm == 'batch' else nn.Identity() 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 144 | self.layer1 = self._make_layer(block, 64, 
num_blocks[0], stride=1) 145 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 146 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 147 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 148 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 149 | self.classifier = nn.Linear(512 * block.expansion, num_classes) 150 | 151 | def _make_layer(self, block, planes, num_blocks, stride): 152 | strides = [stride] + [1] * (num_blocks - 1) 153 | layers = [] 154 | for stride in strides: 155 | layers.append(block(self.in_planes, planes, stride, self.norm)) 156 | self.in_planes = planes * block.expansion 157 | return nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | out = F.relu(self.bn1(self.conv1(x))) 161 | out = self.maxpool(out) 162 | out = self.layer1(out) 163 | out = self.layer2(out) 164 | out = self.layer3(out) 165 | out = self.layer4(out) 166 | # out = F.avg_pool2d(out, 4) 167 | # out = out.view(out.size(0), -1) 168 | out = self.avgpool(out) 169 | out = torch.flatten(out, 1) 170 | out = self.classifier(out) 171 | return out 172 | 173 | 174 | def ResNet18BN(channel, num_classes): 175 | return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes, norm='batch') # 'batch' is the flag that selects nn.BatchNorm2d in the blocks above; any other string (e.g. 'batchnorm') silently falls through to nn.Identity() 176 | 177 | 178 | def ResNet18(channel, num_classes, norm): 179 | return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes, norm=norm) 180 | 181 | 182 | def ResNet18ImageNet(channel, num_classes): 183 | return ResNetImageNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes) 184 | 185 | 186 | def ResNet34(channel, num_classes): 187 | return ResNet(BasicBlock, [3, 4, 6, 3], channel=channel, num_classes=num_classes) 188 | 189 | 190 | def ResNet50(channel, num_classes): 191 | return ResNet(Bottleneck, [3, 4, 6, 3], channel=channel, num_classes=num_classes) 192 | 193 | 194 | def ResNet101(channel, num_classes): 195 | return ResNet(Bottleneck, [3, 4, 23, 3], channel=channel, num_classes=num_classes) 196 | 197 | 198 | def ResNet152(channel, num_classes): 199 | return ResNet(Bottleneck, [3, 8, 36, 3], channel=channel, num_classes=num_classes) 200 | 201 | 202 | def ResNet152Imagenet(channel, num_classes): 203 | return ResNetImageNet(Bottleneck, [3, 8, 36, 3], channel=channel, num_classes=num_classes) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /ImageNet-abcde/glad_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import copy 4 | from tqdm import tqdm 5 | 6 | from utils import get_network, config, evaluate_synset, save_and_print 7 | 8 | def build_dataset(ds, class_map, num_classes): 9 | images_all = [] 10 | labels_all = [] 11 | indices_class = [[] for c in range(num_classes)] 12 | print("BUILDING DATASET") 13 | for i in tqdm(range(len(ds))): 14 | sample = ds[i] 15 | images_all.append(torch.unsqueeze(sample[0], dim=0)) 16 | labels_all.append(class_map[torch.tensor(sample[1]).item()]) 17 | for i, lab in tqdm(enumerate(labels_all)): 18 | indices_class[lab].append(i) 19 | images_all = torch.cat(images_all, dim=0).to("cpu") 20 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu") 21 | 22 | return images_all, labels_all, indices_class 23 | 24 | 25 | def prepare_latents(channel=3, num_classes=10, im_size=(32, 32), zdim=512, G=None, class_map_inv={}, get_images=None, args=None): 26 | with torch.no_grad(): 27 | ''' initialize the synthetic data ''' 28 | label_syn = torch.tensor([i*np.ones(args.ipc, dtype=np.int64) for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 29 | 30 | if args.space == 'p': 31 | latents = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=False, device=args.device) 32 | f_latents = None 33 | 34 | else: 35 | zs = torch.randn(num_classes * args.ipc, zdim, device=args.device, requires_grad=False) 36 | 37 | if "imagenet" in args.dataset: 38 | one_hot_dim = 1000 39 | elif args.dataset == "CIFAR10": 40 | one_hot_dim = 10 41 | elif args.dataset == "CIFAR100": 42 | one_hot_dim = 100 43 | 44 | if args.avg_w: 45 | G_labels = torch.zeros([label_syn.nelement(), one_hot_dim], device=args.device) 46 | G_labels[torch.arange(0, label_syn.nelement(), dtype=torch.long), [class_map_inv[x.item()] for x in label_syn]] = 1 47 | new_latents = [] 48 | for label in G_labels: 49 | zs = torch.randn(1000, zdim).to(args.device) 50 | ws = G.mapping(zs, torch.stack([label] * 1000)) 51 | w = torch.mean(ws, dim=0) 52 | new_latents.append(w) 53 | latents = torch.stack(new_latents) 54 | del zs 55 | for _ in new_latents: 56 | del _ 57 | del new_latents 58 | 59 | else: 60 | G_labels = torch.zeros([label_syn.nelement(), one_hot_dim], device=args.device) 61 | G_labels[torch.arange(0, label_syn.nelement(), dtype=torch.long), [class_map_inv[x.item()] for x in label_syn]] = 1 62 | if args.distributed and False: 63 | latents = G.mapping(zs.to("cuda:1"), G_labels.to("cuda:1")).to("cuda:0") 64 | else: 65 | latents = G.mapping(zs, G_labels) 66 | del zs 67 | 68 | del G_labels 69 | 70 | ws = latents 71 | if args.layer is not None: 72 | f_latents = torch.cat( 73 | [G.forward(split_ws, f_layer=args.layer, mode="to_f").detach() for split_ws in 74 | torch.split(ws, args.sg_batch)]) 75 | f_type = f_latents.dtype 76 | f_latents = f_latents.to(torch.float32).cpu() 77 | f_latents = torch.nan_to_num(f_latents, posinf=5.0, neginf=-5.0) 78 | f_latents = torch.clip(f_latents, min=-10, max=10) 79 | f_latents = f_latents.to(f_type).cuda() 80 | 81 | save_and_print(args.log_path, f"{torch.mean(f_latents)}, {torch.std(f_latents)}") 82 | 83 | if args.rand_f: 84 | f_latents = (torch.randn(f_latents.shape).to(args.device) * torch.std(f_latents, dim=(1,2,3), 
keepdim=True) + torch.mean(f_latents, dim=(1,2,3), keepdim=True)) 85 | 86 | f_latents = f_latents.to(f_type) 87 | save_and_print(args.log_path, f"{torch.mean(f_latents)}, {torch.std(f_latents)}") 88 | f_latents.requires_grad_(True) 89 | else: 90 | f_latents = None 91 | 92 | if args.pix_init == 'real' and args.space == "p": 93 | save_and_print(args.log_path, 'initialize synthetic data from random real images') 94 | for c in range(num_classes): 95 | latents.data[c*args.ipc:(c+1)*args.ipc] = torch.cat([get_images(c, 1).detach().data for s in range(args.ipc)]) 96 | else: 97 | save_and_print(args.log_path, 'initialize synthetic data from random noise') 98 | 99 | latents = latents.detach().to(args.device).requires_grad_(True) 100 | 101 | return latents, f_latents, label_syn 102 | 103 | 104 | def get_optimizer_img(latents=None, f_latents=None, G=None, args=None): 105 | if args.space == "wp" and (args.layer is not None and args.layer != -1): 106 | optimizer_img = torch.optim.SGD([latents], lr=args.lr_w, momentum=0.5) 107 | optimizer_img.add_param_group({'params': f_latents, 'lr': args.lr_img, 'momentum': 0.5}) 108 | else: 109 | optimizer_img = torch.optim.SGD([latents], lr=args.lr_img, momentum=0.5) 110 | 111 | if args.learn_g: 112 | G.requires_grad_(True) 113 | optimizer_img.add_param_group({'params': G.parameters(), 'lr': args.lr_g, 'momentum': 0.5}) 114 | 115 | optimizer_img.zero_grad() 116 | 117 | return optimizer_img 118 | 119 | def get_eval_lrs(args): 120 | eval_pool_dict = { 121 | args.model: 0.001, 122 | "ResNet18": 0.001, 123 | "VGG11": 0.0001, 124 | "AlexNet": 0.001, 125 | "ViT": 0.001, 126 | 127 | "AlexNetCIFAR": 0.001, 128 | "ResNet18CIFAR": 0.001, 129 | "VGG11CIFAR": 0.0001, 130 | "ViTCIFAR": 0.001, 131 | } 132 | 133 | return eval_pool_dict 134 | 135 | 136 | def eval_loop(latents=None, f_latents=None, label_syn=None, G=None, best_acc={}, best_std={}, testloader=None, model_eval_pool=[], it=0, channel=3, num_classes=10, im_size=(32, 32), args=None): 137 | curr_acc_dict = {} 138 | max_acc_dict = {} 139 | 140 | curr_std_dict = {} 141 | max_std_dict = {} 142 | 143 | eval_pool_dict = get_eval_lrs(args) 144 | 145 | save_this_it = False 146 | 147 | for model_eval in model_eval_pool: 148 | 149 | if model_eval != args.model and args.wait_eval and it != args.Iteration: 150 | continue 151 | save_and_print(args.log_path, '-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d' % (args.model, model_eval, it)) 152 | 153 | accs_test = [] 154 | accs_train = [] 155 | 156 | for it_eval in range(args.num_eval): 157 | net_eval = get_network(model_eval, channel, num_classes, im_size, width=args.width, depth=args.depth, dist=False).to(args.device) # get a random model 158 | eval_lats = latents 159 | eval_labs = label_syn 160 | image_syn = latents 161 | image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(eval_labs.detach()) # avoid any unaware modification 162 | 163 | if args.space == "wp": 164 | with torch.no_grad(): 165 | image_syn_eval = torch.cat( 166 | [latent_to_im(G, (image_syn_eval_split, f_latents_split), args=args).detach() for 167 | image_syn_eval_split, f_latents_split, label_syn_split in zip(torch.split(image_syn_eval, args.sg_batch), torch.split(f_latents, args.sg_batch), torch.split(label_syn, args.sg_batch))]) 168 | 169 | args.lr_net = eval_pool_dict[model_eval] 170 | _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args=args, aug=True) 171 | del _ 172 | del net_eval 173 | 
accs_test.append(acc_test) 174 | accs_train.append(acc_train) 175 | 176 | accs_test = np.array(accs_test) 177 | accs_train = np.array(accs_train) 178 | acc_test_mean = np.mean(np.max(accs_test, axis=1)) 179 | acc_test_std = np.std(np.max(accs_test, axis=1)) 180 | best_dict_str = "{}".format(model_eval) 181 | if acc_test_mean > best_acc[best_dict_str]: 182 | best_acc[best_dict_str] = acc_test_mean 183 | best_std[best_dict_str] = acc_test_std 184 | save_this_it = True 185 | 186 | curr_acc_dict[best_dict_str] = acc_test_mean 187 | curr_std_dict[best_dict_str] = acc_test_std 188 | 189 | max_acc_dict[best_dict_str] = best_acc[best_dict_str] 190 | max_std_dict[best_dict_str] = best_std[best_dict_str] 191 | 192 | save_and_print(args.log_path, 'Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------' % (len(accs_test[:, -1]), model_eval, acc_test_mean, np.std(np.max(accs_test, axis=1)))) 193 | save_and_print(args.log_path, f"{args.save_path}") 194 | save_and_print(args.log_path, f"{it:5d} | Accuracy/{model_eval}: {acc_test_mean}") 195 | save_and_print(args.log_path, f"{it:5d} | Max_Accuracy/{model_eval}: {best_acc[best_dict_str]}") 196 | save_and_print(args.log_path, f"{it:5d} | Std/{model_eval}: {acc_test_std}") 197 | save_and_print(args.log_path, f"{it:5d} | Max_Std/{model_eval}: {best_std[best_dict_str]}") 198 | 199 | if len(model_eval_pool) > 1: 200 | save_and_print(args.log_path, "-" * 20) 201 | save_and_print(args.log_path, f"{it:5d} | Accuracy/Avg_All: {np.mean(np.array(list(curr_acc_dict.values())))}") 202 | save_and_print(args.log_path, f"{it:5d} | Std/Avg_All: {np.mean(np.array(list(curr_std_dict.values())))}") 203 | save_and_print(args.log_path, f"{it:5d} | Max_Accuracy/Avg_All: {np.mean(np.array(list(max_acc_dict.values())))}") 204 | save_and_print(args.log_path, f"{it:5d} | Max_Std/Avg_All: {np.mean(np.array(list(max_std_dict.values())))}") 205 | 206 | curr_acc_dict.pop("{}".format(args.model)) 207 | curr_std_dict.pop("{}".format(args.model)) 208 | max_acc_dict.pop("{}".format(args.model)) 209 | max_std_dict.pop("{}".format(args.model)) 210 | 211 | save_and_print(args.log_path, "-" * 20) 212 | save_and_print(args.log_path, f"{it:5d} | Accuracy/Avg_Cross: {np.mean(np.array(list(curr_acc_dict.values())))}") 213 | save_and_print(args.log_path, f"{it:5d} | Std/Avg_Cross: {np.mean(np.array(list(curr_std_dict.values())))}") 214 | save_and_print(args.log_path, f"{it:5d} | Max_Accuracy/Avg_Cross: {np.mean(np.array(list(max_acc_dict.values())))}") 215 | save_and_print(args.log_path, f"{it:5d} | Max_Std/Avg_Cross: {np.mean(np.array(list(max_std_dict.values())))}") 216 | 217 | return save_this_it 218 | 219 | def latent_to_im(G, latents, args=None): 220 | 221 | if args.space == "p": 222 | return latents 223 | 224 | mean, std = config.mean, config.std 225 | 226 | if "imagenet" in args.dataset: 227 | class_map = {i: x for i, x in enumerate(config.img_net_classes)} 228 | 229 | if args.space == "p": 230 | im = latents 231 | 232 | elif args.space == "wp": 233 | if args.layer is None or args.layer==-1: 234 | im = G(latents[0], mode="wp") 235 | else: 236 | im = G(latents[0], latents[1], args.layer, mode="from_f") 237 | 238 | im = (im + 1) / 2 239 | im = (im - mean) / std 240 | 241 | elif args.dataset == "CIFAR10" or args.dataset == "CIFAR100": 242 | if args.space == "p": 243 | im = latents 244 | elif args.space == "wp": 245 | if args.layer is None or args.layer == -1: 246 | im = G(latents[0], mode="wp") 247 | else: 248 | im = G(latents[0], latents[1], args.layer, mode="from_f") 249 | 
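# NOTE: the generator used here emits pixels in [-1, 1]; the `im = (im + 1) / 2`
# step below rescales them to [0, 1], and `im = (im - mean) / std` re-applies the
# dataset normalization taken from `config`, so decoded synthetic images live in
# the same input space that the evaluation networks were trained on.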
250 | if args.distributed and False: 251 | mean, std = config.mean_1, config.std_1 252 | 253 | im = (im + 1) / 2 254 | im = (im - mean) / std 255 | 256 | return im -------------------------------------------------------------------------------- /ImageNet-abcde/main_DM_FreD.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | 4 | import time 5 | import copy 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, set_seed, save_and_print, get_images 10 | 11 | import shutil 12 | import torchvision 13 | import matplotlib.pyplot as plt 14 | from frequency_transforms import DCT 15 | 16 | from glad_utils import build_dataset, get_eval_lrs, eval_loop 17 | 18 | class SynSet(): 19 | def __init__(self, args): 20 | ### Basic ### 21 | self.args = args 22 | self.log_path = self.args.log_path 23 | self.channel = self.args.channel 24 | self.num_classes = self.args.num_classes 25 | self.im_size = self.args.im_size 26 | self.device = self.args.device 27 | self.ipc = self.args.ipc 28 | 29 | ### FreD ### 30 | self.dct = DCT(resolution=self.im_size[0], device=self.device) 31 | self.lr_freq = self.args.lr_freq 32 | self.mom_freq = self.args.mom_freq 33 | self.msz_per_channel = self.args.msz_per_channel 34 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1]) / self.msz_per_channel) 35 | 36 | def init(self, images_real, labels_real, indices_class): 37 | ### Initialize Frequency (F) ### 38 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device) 39 | for c in range(self.num_classes): 40 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class] 41 | images.data[c * self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data 42 | self.freq_syn = self.dct.forward(images) 43 | self.freq_syn.requires_grad = True 44 | del images 45 | 46 | ### Initialize Mask (M) ### 47 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device) 48 | self.mask.requires_grad = False 49 | 50 | ### Initialize Label ### 51 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 52 | self.label_syn = self.label_syn.long() 53 | 54 | ### Initialize Optimizer ### 55 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq) 56 | 57 | self.init_mask() 58 | self.optim_zero_grad() 59 | self.show_budget() 60 | 61 | def get(self, indices=None, need_copy=False): 62 | if not hasattr(indices, '__iter__'): 63 | indices = range(len(self.label_syn)) 64 | 65 | if need_copy: 66 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach()) 67 | mask = copy.deepcopy(self.mask[indices].detach()) 68 | else: 69 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices] 70 | mask = self.mask[indices] 71 | 72 | image_syn = self.dct.inverse(mask * freq_syn) 73 | return image_syn, label_syn 74 | 75 | def init_mask(self): 76 | save_and_print(self.args.log_path, "Initialize Mask") 77 | 78 | for c in 
range(self.num_classes): 79 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach()) 80 | freq_c = torch.mean(freq_c, 1) 81 | freqs_flat = torch.flatten(freq_c, 1) 82 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0) 83 | 84 | try: 85 | cov = torch.cov(freqs_flat.T) 86 | except: 87 | save_and_print(self.args.log_path, f"Can not use torch.cov. Instead use np.cov") 88 | cov = np.cov(freqs_flat.T.cpu().numpy()) 89 | cov = torch.tensor(cov, dtype=torch.float, device=self.device) 90 | total_variance = torch.sum(torch.diag(cov)) 91 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device) 92 | for idx in range(np.prod(self.im_size)): 93 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1) 94 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low) 95 | explained_variance_ratio = vector_variance / total_variance 96 | vr_fl2f[idx] = explained_variance_ratio.item() 97 | 98 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel) 99 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:] 100 | for h, w in top_indices: 101 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, h, w] = 1.0 102 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}") 103 | 104 | ### Visualize and Save ### 105 | indices_save = np.arange(10) * self.num_per_class 106 | grid = torchvision.utils.make_grid(self.mask[indices_save], nrow=10) 107 | plt.imshow(np.transpose(grid.detach().cpu().numpy(), (1, 2, 0))) 108 | plt.savefig(f"{self.args.save_path}/Mask.png", dpi=300) 109 | plt.close() 110 | 111 | mask_save = copy.deepcopy(self.mask.detach()) 112 | torch.save(mask_save.cpu(), os.path.join(self.args.save_path, "mask.pt")) 113 | del mask_save 114 | 115 | def optim_zero_grad(self): 116 | self.optimizers.zero_grad() 117 | 118 | def optim_step(self): 119 | self.optimizers.step() 120 | 121 | def show_budget(self): 122 | save_and_print(self.log_path, '=' * 50) 123 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}") 124 | images, _ = self.get(need_copy=True) 125 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}") 126 | del images 127 | save_and_print(self.log_path, '=' * 50) 128 | 129 | def main(args): 130 | 131 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 132 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args) 133 | args.channel, args.im_size, args.num_classes, args.mean, args.std = channel, im_size, num_classes, mean, std 134 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 135 | 136 | accs_all_exps = dict() # record performances of all experiments 137 | for key in model_eval_pool: 138 | accs_all_exps[key] = [] 139 | 140 | data_save = [] 141 | 142 | save_and_print(args.log_path, f'\n================== Exp {0} ==================\n ') 143 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}') 144 | 145 | ''' organize the real dataset ''' 146 | images_all, labels_all, indices_class = build_dataset(dst_train, class_map, num_classes) 147 | images_all, labels_all = images_all.to(args.device), labels_all.to(args.device) 148 | 149 | ''' initialize the synthetic data''' 150 | synset = 
SynSet(args) 151 | synset.init(images_all, labels_all, indices_class) 152 | 153 | ''' training ''' 154 | criterion = nn.CrossEntropyLoss().to(args.device) 155 | save_and_print(args.log_path, '%s training begins'%get_time()) 156 | 157 | best_acc = {"{}".format(m): 0 for m in model_eval_pool} 158 | best_std = {m: 0 for m in model_eval_pool} 159 | 160 | save_this_it = False 161 | for it in range(args.Iteration+1): 162 | 163 | if it in eval_it_pool: 164 | image_syn_eval, label_syn_eval = synset.get(need_copy=True) 165 | save_this_it = eval_loop(latents=image_syn_eval, f_latents=None, label_syn=label_syn_eval, G=None, best_acc=best_acc, 166 | best_std=best_std, testloader=testloader, 167 | model_eval_pool=model_eval_pool, channel=channel, num_classes=num_classes, 168 | im_size=im_size, it=it, args=args) 169 | 170 | ''' Train synthetic data ''' 171 | net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width).to(args.device) # get a random model 172 | net.train() 173 | for param in list(net.parameters()): 174 | param.requires_grad = False 175 | 176 | embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel 177 | 178 | loss_avg = 0 179 | 180 | ''' update synthetic data ''' 181 | if 'BN' not in args.model: # for ConvNet 182 | loss = torch.tensor(0.0).to(args.device) 183 | for c in range(num_classes): 184 | img_real = get_images(images_all, indices_class, c, args.batch_real) 185 | 186 | if args.batch_syn > 0: 187 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn] 188 | else: 189 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class) 190 | 191 | img_syn, lab_syn = synset.get(indices=indices) 192 | 193 | if args.dsa: 194 | seed = int(time.time() * 1000) % 100000 195 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param) 196 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param) 197 | 198 | output_real = embed(img_real).detach() 199 | output_syn = embed(img_syn) 200 | 201 | loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2) 202 | 203 | else: # for ConvNetBN 204 | images_real_all = [] 205 | images_syn_all = [] 206 | loss = torch.tensor(0.0).to(args.device) 207 | for c in range(num_classes): 208 | img_real = get_images(images_all, indices_class, c, args.batch_real) # keep the 4-argument signature used in the ConvNet branch above 209 | 210 | if args.batch_syn > 0: 211 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn] 212 | else: 213 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class) 214 | 215 | img_syn, lab_syn = synset.get(indices=indices) 216 | 217 | if args.dsa: 218 | seed = int(time.time() * 1000) % 100000 219 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param) 220 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param) 221 | 222 | images_real_all.append(img_real) 223 | images_syn_all.append(img_syn) 224 | 225 | images_real_all = torch.cat(images_real_all, dim=0) 226 | images_syn_all = torch.cat(images_syn_all, dim=0) 227 | 228 | output_real = embed(images_real_all).detach() 229 | output_syn = embed(images_syn_all) 230 | 231 | loss += torch.sum((torch.mean(output_real.reshape(num_classes, args.batch_real, -1), dim=1) - torch.mean(output_syn.reshape(num_classes, -1, output_syn.shape[-1]), dim=1))**2) # under FreD each class contributes num_per_class (or batch_syn) decoded images, not args.ipc, so infer the per-class count 232 | 233 | synset.optim_zero_grad() 234 | loss.backward() 235 | synset.optim_step() 236 | loss_avg 
+= loss.item() 237 | 238 | loss_avg /= (num_classes) 239 | 240 | if it%10 == 0: 241 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg)) 242 | 243 | if it == args.Iteration: # only record the final results 244 | data_save.append([synset.get(need_copy=True)]) 245 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%dipc.pt'%(args.dataset, args.model, args.ipc))) 246 | 247 | 248 | if __name__ == '__main__': 249 | import shared_args 250 | 251 | parser = shared_args.add_shared_args() 252 | 253 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 254 | 255 | parser.add_argument('--seed', type=int, default=0) 256 | parser.add_argument('--sh_file', type=str) 257 | parser.add_argument('--FLAG', type=str, default="TEST") 258 | 259 | ### FreD ### 260 | parser.add_argument('--batch_syn', type=int) 261 | parser.add_argument('--msz_per_channel', type=int) 262 | parser.add_argument('--lr_freq', type=float) 263 | parser.add_argument('--mom_freq', type=float) 264 | args = parser.parse_args() 265 | args.space = "p" 266 | args.zca = False 267 | 268 | set_seed(args.seed) 269 | 270 | args.outer_loop, args.inner_loop = get_loops(args.ipc) 271 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 272 | args.dsa_param = ParamDiffAug() 273 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True 274 | 275 | if not os.path.exists(args.data_path): 276 | os.mkdir(args.data_path) 277 | 278 | if not os.path.exists(args.save_path): 279 | os.mkdir(args.save_path) 280 | args.save_path = args.save_path + f"/{args.FLAG}" 281 | if not os.path.exists(args.save_path): 282 | os.mkdir(args.save_path) 283 | 284 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}") 285 | args.log_path = f"{args.save_path}/log.txt" 286 | 287 | main(args) 288 | 289 | 290 | -------------------------------------------------------------------------------- /3D-MNIST/main_DM_FreD.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | 4 | import copy 5 | import argparse 6 | import numpy as np 7 | import torch 8 | from utils import get_loops, get_dataset, get_network, evaluate_synset, get_time, TensorDataset, epoch, set_seed, save_and_print, get_images 9 | 10 | import shutil 11 | from frequency_transforms import DCT 12 | 13 | class SynSet(): 14 | def __init__(self, args): 15 | ### Basic ### 16 | self.args = args 17 | self.log_path = self.args.log_path 18 | self.channel = self.args.channel 19 | self.num_classes = self.args.num_classes 20 | self.im_size = self.args.im_size 21 | self.device = self.args.device 22 | self.ipc = self.args.ipc 23 | 24 | ### FreD ### 25 | self.dct = DCT(resolution=self.im_size[0], device=self.device) 26 | self.lr_freq = self.args.lr_freq 27 | self.mom_freq = self.args.mom_freq 28 | self.msz_per_channel = self.args.msz_per_channel 29 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1] * self.im_size[2]) / self.msz_per_channel) 30 | 31 | def init(self, images_real, labels_real, indices_class): 32 | ### Initialize Frequency (F) ### 33 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1], self.im_size[2]), dtype=torch.float, device=self.device) 34 | for c in range(self.num_classes): 35 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class] 36 | images.data[c * 
self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data 37 | self.freq_syn = self.dct.forward(images) 38 | self.freq_syn.requires_grad = True 39 | del images 40 | 41 | ### Initialize Mask (M) ### 42 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1], self.im_size[2]), dtype=torch.float, device=self.device) 43 | self.mask.requires_grad = False 44 | 45 | ### Initialize Label ### 46 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 47 | self.label_syn = self.label_syn.long() 48 | 49 | ### Initialize Optimizer ### 50 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq) 51 | 52 | self.init_mask() 53 | self.optim_zero_grad() 54 | self.show_budget() 55 | 56 | def get(self, indices=None, need_copy=False): 57 | if not hasattr(indices, '__iter__'): 58 | indices = range(len(self.label_syn)) 59 | 60 | if need_copy: 61 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach()) 62 | mask = copy.deepcopy(self.mask[indices].detach()) 63 | else: 64 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices] 65 | mask = self.mask[indices] 66 | 67 | image_syn = self.dct.inverse(mask * freq_syn) 68 | return image_syn, label_syn 69 | 70 | def init_mask(self): 71 | save_and_print(self.args.log_path, "Initialize Mask") 72 | 73 | for c in range(self.num_classes): 74 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach()) 75 | freq_c = torch.mean(freq_c, 1) 76 | freqs_flat = torch.flatten(freq_c, 1) 77 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0) 78 | 79 | try: 80 | cov = torch.cov(freqs_flat.T) 81 | except: 82 | save_and_print(self.args.log_path, f"Can not use torch.cov. 
Instead use np.cov") 83 | cov = np.cov(freqs_flat.T.cpu().numpy()) 84 | cov = torch.tensor(cov, dtype=torch.float, device=self.device) 85 | total_variance = torch.sum(torch.diag(cov)) 86 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device) 87 | for idx in range(np.prod(self.im_size)): 88 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1) 89 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low) 90 | explained_variance_ratio = vector_variance / total_variance 91 | vr_fl2f[idx] = explained_variance_ratio.item() 92 | 93 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel) 94 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:] 95 | for d, h, w in top_indices: 96 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, d, h, w] = 1.0 97 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}") 98 | 99 | def optim_zero_grad(self): 100 | self.optimizers.zero_grad() 101 | 102 | def optim_step(self): 103 | self.optimizers.step() 104 | 105 | def show_budget(self): 106 | save_and_print(self.log_path, '=' * 50) 107 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}") 108 | images, _ = self.get(need_copy=True) 109 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}") 110 | del images 111 | save_and_print(self.log_path, '=' * 50) 112 | 113 | def main(): 114 | 115 | parser = argparse.ArgumentParser(description='Parameter Processing') 116 | parser.add_argument('--method', type=str, default='DM', help='DC/DSA/DM') 117 | parser.add_argument('--dataset', type=str, default='3D-MNIST', help='dataset') 118 | parser.add_argument('--model', type=str, default='Conv3DNet', help='model') 119 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 120 | parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures, W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer, 121 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments') 122 | parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models') 123 | parser.add_argument('--eval_it', type=int, default=200) 124 | parser.add_argument('--epoch_eval_train', type=int, default=500, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop 125 | parser.add_argument('--Iteration', type=int, default=20000, help='training iterations') 126 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters') 127 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data') 128 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 129 | # parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 130 | parser.add_argument('--data_path', type=str, default='../data', help='dataset path') 131 | parser.add_argument('--save_path', type=str, default='./results', help='path to save results') 132 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric') 133 | 134 | 135 | 
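# NOTE: FreD re-spends the pixel budget of `ipc` full-resolution volumes on more
# synthetic instances that each store only `msz_per_channel` frequency
# coefficients (see SynSet.__init__ above):
#     num_per_class = ipc * D * H * W / msz_per_channel
# For example, assuming hypothetical 16x16x16 voxel inputs with ipc=1 and
# msz_per_channel=512, this yields 4096 / 512 = 8 synthetic volumes per class
# at the same storage cost as a single raw volume.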
parser.add_argument('--seed', type=int, default=0) 136 | parser.add_argument('--sh_file', type=str) 137 | parser.add_argument('--FLAG', type=str, default="TEST") 138 | 139 | ### FreD ### 140 | parser.add_argument('--batch_syn', type=int) 141 | parser.add_argument('--msz_per_channel', type=int) 142 | parser.add_argument('--lr_freq', type=float) 143 | parser.add_argument('--mom_freq', type=float) 144 | args = parser.parse_args() 145 | set_seed(args.seed) 146 | 147 | args.outer_loop, args.inner_loop = get_loops(args.ipc) 148 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 149 | 150 | if not os.path.exists(args.data_path): 151 | os.mkdir(args.data_path) 152 | 153 | if not os.path.exists(args.save_path): 154 | os.mkdir(args.save_path) 155 | args.save_path = args.save_path + f"/{args.FLAG}" 156 | if not os.path.exists(args.save_path): 157 | os.mkdir(args.save_path) 158 | 159 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}") 160 | args.log_path = f"{args.save_path}/log.txt" 161 | 162 | eval_it_pool = np.arange(0, args.Iteration+1, args.eval_it).tolist() 163 | channel, im_size, num_classes, class_names, dst_train, dst_test, testloader = get_dataset(args.data_path, args=args) 164 | args.channel, args.im_size, args.num_classes = channel, im_size, num_classes 165 | model_eval_pool = ["Conv3DNet"] 166 | 167 | accs_all_exps = dict() # record performances of all experiments 168 | for key in model_eval_pool: 169 | accs_all_exps[key] = [] 170 | 171 | data_save = [] 172 | 173 | for exp in range(args.num_exp): 174 | save_and_print(args.log_path, f'\n================== Exp {exp} ==================\n ') 175 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}') 176 | 177 | ''' organize the real dataset ''' 178 | images_all = [] 179 | labels_all = [] 180 | indices_class = [[] for c in range(num_classes)] 181 | 182 | images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))] 183 | labels_all = [dst_train[i][1] for i in range(len(dst_train))] 184 | for i, lab in enumerate(labels_all): 185 | indices_class[lab].append(i) 186 | images_all = torch.cat(images_all, dim=0).to(args.device) 187 | labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device) 188 | 189 | ''' initialize the synthetic data ''' 190 | synset = SynSet(args) 191 | synset.init(images_all, labels_all, indices_class) 192 | 193 | ''' training ''' 194 | save_and_print(args.log_path, '%s training begins'%get_time()) 195 | 196 | for it in range(args.Iteration+1): 197 | 198 | ''' Evaluate synthetic data ''' 199 | if it in eval_it_pool: 200 | for model_eval in model_eval_pool: 201 | save_and_print(args.log_path, '-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it)) 202 | 203 | accs = [] 204 | for it_eval in range(args.num_eval): 205 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model 206 | image_syn_eval, label_syn_eval = synset.get(need_copy=True) 207 | _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args) 208 | accs.append(acc_test) 209 | 210 | del image_syn_eval, label_syn_eval 211 | save_and_print(args.log_path, 'Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs))) 212 | 213 | if it == args.Iteration: # record the final results 214 | accs_all_exps[model_eval] += accs 215 | 216 | ''' Train synthetic data 
''' 217 | net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model 218 | net.train() 219 | for param in list(net.parameters()): 220 | param.requires_grad = False 221 | 222 | embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel 223 | 224 | loss_avg = 0 225 | 226 | ''' update synthetic data ''' 227 | loss = torch.tensor(0.0).to(args.device) 228 | for c in range(num_classes): 229 | img_real = get_images(images_all, indices_class, c, args.batch_real) 230 | 231 | if args.batch_syn > 0: 232 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn] 233 | else: 234 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class) 235 | 236 | img_syn, lab_syn = synset.get(indices=indices) 237 | 238 | output_real = embed(img_real).detach() 239 | output_syn = embed(img_syn) 240 | 241 | loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2) 242 | 243 | synset.optim_zero_grad() 244 | loss.backward() 245 | synset.optim_step() 246 | loss_avg += loss.item() 247 | 248 | loss_avg /= (num_classes) 249 | 250 | if it%10 == 0: 251 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg)) 252 | 253 | if it == args.Iteration: # only record the final results 254 | data_save.append([synset.get(need_copy=True)]) 255 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc))) 256 | 257 | 258 | save_and_print(args.log_path, '\n==================== Final Results ====================\n') 259 | for key in model_eval_pool: 260 | accs = accs_all_exps[key] 261 | save_and_print(args.log_path, 'Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100)) 262 | 263 | 264 | if __name__ == '__main__': 265 | main() 266 | 267 | 268 | -------------------------------------------------------------------------------- /ImageNet-abcde/main_DC_FreD.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | 4 | import time 5 | import copy 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, set_seed, save_and_print, get_images 10 | 11 | import shutil 12 | import torchvision 13 | import matplotlib.pyplot as plt 14 | from frequency_transforms import DCT 15 | 16 | from glad_utils import build_dataset, get_eval_lrs, eval_loop 17 | 18 | class SynSet(): 19 | def __init__(self, args): 20 | ### Basic ### 21 | self.args = args 22 | self.log_path = self.args.log_path 23 | self.channel = self.args.channel 24 | self.num_classes = self.args.num_classes 25 | self.im_size = self.args.im_size 26 | self.device = self.args.device 27 | self.ipc = self.args.ipc 28 | 29 | ### FreD ### 30 | self.dct = DCT(resolution=self.im_size[0], device=self.device) 31 | self.lr_freq = self.args.lr_freq 32 | self.mom_freq = self.args.mom_freq 33 | self.msz_per_channel = self.args.msz_per_channel 34 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1]) / self.msz_per_channel) 35 | 36 | def init(self, images_real, labels_real, indices_class): 37 | ### Initialize 
Frequency (F) ### 38 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device) 39 | for c in range(self.num_classes): 40 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class] 41 | images.data[c * self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data 42 | self.freq_syn = self.dct.forward(images) 43 | self.freq_syn.requires_grad = True 44 | del images 45 | 46 | ### Initialize Mask (M) ### 47 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device) 48 | self.mask.requires_grad = False 49 | 50 | ### Initialize Label ### 51 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 52 | self.label_syn = self.label_syn.long() 53 | 54 | ### Initialize Optimizer ### 55 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq) 56 | 57 | self.init_mask() 58 | self.optim_zero_grad() 59 | self.show_budget() 60 | 61 | def get(self, indices=None, need_copy=False): 62 | if not hasattr(indices, '__iter__'): 63 | indices = range(len(self.label_syn)) 64 | 65 | if need_copy: 66 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach()) 67 | mask = copy.deepcopy(self.mask[indices].detach()) 68 | else: 69 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices] 70 | mask = self.mask[indices] 71 | 72 | image_syn = self.dct.inverse(mask * freq_syn) 73 | return image_syn, label_syn 74 | 75 | def init_mask(self): 76 | save_and_print(self.args.log_path, "Initialize Mask") 77 | 78 | for c in range(self.num_classes): 79 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach()) 80 | freq_c = torch.mean(freq_c, 1) 81 | freqs_flat = torch.flatten(freq_c, 1) 82 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0) 83 | 84 | try: 85 | cov = torch.cov(freqs_flat.T) 86 | except Exception: 87 | save_and_print(self.args.log_path, "Cannot use torch.cov; using np.cov instead")
88 | cov = np.cov(freqs_flat.T.cpu().numpy()) 89 | cov = torch.tensor(cov, dtype=torch.float, device=self.device) 90 | total_variance = torch.sum(torch.diag(cov)) 91 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device) 92 | for idx in range(np.prod(self.im_size)): 93 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1) 94 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low) 95 | explained_variance_ratio = vector_variance / total_variance 96 | vr_fl2f[idx] = explained_variance_ratio.item() 97 | 98 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel) 99 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:] 100 | for h, w in top_indices: 101 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, h, w] = 1.0 102 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}") 103 | 104 | ### Visualize and Save ### 105 | indices_save = np.arange(10) * self.num_per_class 106 | grid = torchvision.utils.make_grid(self.mask[indices_save], nrow=10) 107 | plt.imshow(np.transpose(grid.detach().cpu().numpy(), (1, 2, 0))) 108 | plt.savefig(f"{self.args.save_path}/Mask.png", dpi=300) 109 | plt.close() 110 | 111 | mask_save = copy.deepcopy(self.mask.detach()) 112 | torch.save(mask_save.cpu(), os.path.join(self.args.save_path, "mask.pt")) 113 | del mask_save 114 | 115 | def optim_zero_grad(self): 116 | self.optimizers.zero_grad() 117 | 118 | def optim_step(self): 119 | self.optimizers.step() 120 | 121 | def show_budget(self): 122 | save_and_print(self.log_path, '=' * 50) 123 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}") 124 | images, _ = self.get(need_copy=True) 125 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}") 126 | del images 127 | save_and_print(self.log_path, '=' * 50) 128 | 129 | def main(args): 130 | 131 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 132 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args) 133 | args.channel, args.im_size, args.num_classes, args.mean, args.std = channel, im_size, num_classes, mean, std 134 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 135 | 136 | accs_all_exps = dict() # record performances of all experiments 137 | for key in model_eval_pool: 138 | accs_all_exps[key] = [] 139 | 140 | data_save = [] 141 | 142 | save_and_print(args.log_path, f'\n================== Exp {0} ==================\n ') 143 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}') 144 | 145 | ''' organize the real dataset ''' 146 | images_all, labels_all, indices_class = build_dataset(dst_train, class_map, num_classes) 147 | images_all, labels_all = images_all.to(args.device), labels_all.to(args.device) 148 | 149 | ''' initialize the synthetic data''' 150 | synset = SynSet(args) 151 | synset.init(images_all, labels_all, indices_class) 152 | 153 | ''' training ''' 154 | criterion = nn.CrossEntropyLoss().to(args.device) 155 | save_and_print(args.log_path, '%s training begins'%get_time()) 156 | 157 | best_acc = {"{}".format(m): 0 for m in model_eval_pool} 158 | best_std = {m: 0 for m in model_eval_pool} 159 | eval_pool_dict = get_eval_lrs(args) 160 | 161 | 
save_this_it = False 162 | for it in range(args.Iteration+1): 163 | 164 | if it in eval_it_pool and it > 0: 165 | image_syn_eval, label_syn_eval = synset.get(need_copy=True) 166 | save_this_it = eval_loop(latents=image_syn_eval, f_latents=None, label_syn=label_syn_eval, G=None, best_acc=best_acc, 167 | best_std=best_std, testloader=testloader, 168 | model_eval_pool=model_eval_pool, channel=channel, num_classes=num_classes, 169 | im_size=im_size, it=it, args=args) 170 | 171 | del image_syn_eval, label_syn_eval 172 | 173 | ''' Train synthetic data ''' 174 | net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width).to(args.device) # get a random model 175 | net.train() 176 | net_parameters = list(net.parameters()) 177 | optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer for the network parameters 178 | optimizer_net.zero_grad() 179 | loss_avg = 0 180 | args.dc_aug_param = None # Mute the DC augmentation when learning synthetic data (in the inner-loop epoch function) in order to be consistent with the DC paper. 181 | 182 | for ol in range(args.outer_loop): 183 | 184 | ''' freeze the running mu and sigma for BatchNorm layers ''' 185 | # The synthetic data batch, e.g. only 1 image per batch, is too small to obtain stable mu and sigma. 186 | # So we calculate and freeze mu and sigma of the BatchNorm layers with a real data batch beforehand. 187 | # This makes training with BatchNorm layers easier. 188 | 189 | BN_flag = False 190 | BNSizePC = 16 # for batch normalization 191 | for module in net.modules(): 192 | if 'BatchNorm' in module._get_name(): #BatchNorm 193 | BN_flag = True 194 | if BN_flag: 195 | img_real = torch.cat([get_images(images_all, indices_class, c, BNSizePC) for c in range(num_classes)], dim=0) 196 | net.train() # for updating the mu, sigma of BatchNorm 197 | output_real = net(img_real) # get running mu, sigma 198 | for module in net.modules(): 199 | if 'BatchNorm' in module._get_name(): #BatchNorm 200 | module.eval() # fix mu and sigma of every BatchNorm layer 201 | 202 | ''' update synthetic data ''' 203 | loss = torch.tensor(0.0).to(args.device) 204 | for c in range(num_classes): 205 | img_real = get_images(images_all, indices_class, c, args.batch_real) 206 | lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c 207 | 208 | if args.batch_syn > 0: 209 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn] 210 | else: 211 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class) 212 | 213 | img_syn, lab_syn = synset.get(indices=indices) 214 | 215 | if args.dsa: 216 | seed = int(time.time() * 1000) % 100000 217 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param) 218 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param) 219 | 220 | output_real = net(img_real) 221 | loss_real = criterion(output_real, lab_real) 222 | gw_real = torch.autograd.grad(loss_real, net_parameters) 223 | gw_real = list((_.detach().clone() for _ in gw_real)) 224 | 225 | output_syn = net(img_syn) 226 | loss_syn = criterion(output_syn, lab_syn) 227 | gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True) 228 | 229 | loss += match_loss(gw_syn, gw_real, args) 230 | 231 | synset.optim_zero_grad() 232 | loss.backward() 233 | synset.optim_step() 234 | loss_avg += loss.item() 235 | 236 | if ol == args.outer_loop - 1: 237 | break 238 | 239 | ''' update network ''' 240 | 
image_syn_train, label_syn_train = synset.get(need_copy=True) 241 | dst_syn_train = TensorDataset(image_syn_train, label_syn_train) 242 | trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 243 | for il in range(args.inner_loop): 244 | epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False) 245 | 246 | loss_avg /= (num_classes*args.outer_loop) 247 | 248 | if it%10 == 0: 249 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg)) 250 | 251 | if it == args.Iteration: # only record the final results 252 | data_save.append([synset.get(need_copy=True)]) 253 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%dipc.pt'%(args.dataset, args.model, args.ipc))) 254 | 255 | 256 | if __name__ == '__main__': 257 | import shared_args 258 | 259 | parser = shared_args.add_shared_args() 260 | 261 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters') 262 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric') 263 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 264 | 265 | parser.add_argument('--seed', type=int, default=0) 266 | parser.add_argument('--sh_file', type=str) 267 | parser.add_argument('--FLAG', type=str, default="TEST") 268 | 269 | ### FreD ### 270 | parser.add_argument('--batch_syn', type=int) 271 | parser.add_argument('--msz_per_channel', type=int) 272 | parser.add_argument('--lr_freq', type=float) 273 | parser.add_argument('--mom_freq', type=float) 274 | args = parser.parse_args() 275 | args.space = "p" 276 | args.zca = False 277 | 278 | set_seed(args.seed) 279 | 280 | args.outer_loop, args.inner_loop = get_loops(args.ipc) 281 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 282 | args.dsa_param = ParamDiffAug() 283 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True 284 | 285 | if not os.path.exists(args.data_path): 286 | os.mkdir(args.data_path) 287 | 288 | if not os.path.exists(args.save_path): 289 | os.mkdir(args.save_path) 290 | args.save_path = args.save_path + f"/{args.FLAG}" 291 | if not os.path.exists(args.save_path): 292 | os.mkdir(args.save_path) 293 | 294 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}") 295 | args.log_path = f"{args.save_path}/log.txt" 296 | 297 | main(args) 298 | 299 | 300 | -------------------------------------------------------------------------------- /ImageNet-abcde/buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from tqdm import tqdm 7 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug 8 | import copy 9 | 10 | def main(args): 11 | 12 | outer_loop_default, inner_loop_default = get_loops(args.ipc) 13 | 14 | args.lr_net = args.lr_teacher 15 | 16 | if args.outer_loop is None: 17 | args.outer_loop = outer_loop_default 18 | if args.inner_loop is None: 19 | args.inner_loop = inner_loop_default 20 | 21 | if args.g_train_ipc is None: 22 | args.g_train_ipc = args.ipc 23 | if args.g_eval_ipc is None: 24 | args.g_eval_ipc = args.ipc 25 | if args.g_grad_ipc is None: 26 | args.g_grad_ipc = args.ipc 27 | 28 | args.dsa = True if 
args.dsa == 'True' else False 29 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 30 | args.dsa_param = ParamDiffAug() 31 | args.dsa_param.blur_perc = args.blur_perc 32 | args.dsa_param.blur_min = args.blur_min 33 | args.dsa_param.blur_max = args.blur_max 34 | 35 | if args.lr_img is None: 36 | if args.space == 'p': 37 | args.lr_img = 0.1 38 | else: 39 | args.lr_img = 0.01 40 | 41 | if args.batch_syn is None: 42 | args.batch_syn = args.ipc 43 | 44 | args.num_batches = args.ipc // args.batch_syn 45 | 46 | run_name = "" 47 | 48 | if args.clip: 49 | run_name += "clip_" 50 | 51 | run_name += "{}_".format(args.model) 52 | 53 | run_name += "{}_".format(args.dataset) 54 | 55 | if args.dataset == "ImageNet": 56 | run_name += "{}_".format(args.subset) 57 | run_name += "{}_".format(args.res) 58 | 59 | 60 | run_name += "space_{}_".format(args.space) 61 | if args.space != "p": 62 | run_name += "tanh_{}_".format(args.tanh) 63 | run_name += "proj_{}_".format(args.proj_ball) 64 | run_name += "trunc_{}_".format(args.trunc) 65 | run_name += "layer_{}_".format(args.layer) 66 | 67 | if args.space == "p": 68 | run_name += "init_{}_".format(args.pix_init) 69 | elif args.space == "z": 70 | run_name += "init_{}_".format(args.gan_init) 71 | run_name += "RandCond_{}_".format(args.rand_cond) 72 | run_name += "RandLat_{}_".format(args.rand_lat) 73 | 74 | elif args.patch: 75 | run_name += "patch_" 76 | 77 | elif args.rand_gen: 78 | run_name += "rand-gen_" 79 | 80 | if args.spec_proj: 81 | run_name += "spec-proj_" 82 | 83 | if args.spec_reg is not None: 84 | run_name += "spec-reg_{}_".format(args.spec_reg) 85 | 86 | 87 | run_name += "aug_{}_".format(args.dsa) 88 | 89 | run_name += "ipc_{}_".format(args.ipc) 90 | 91 | run_name += "batch_{}_".format(args.batch_syn) 92 | 93 | run_name += "ol_{}_il_{}_".format(args.outer_loop, args.inner_loop) 94 | 95 | run_name += "im-opt_{}_".format(args.im_opt) 96 | 97 | run_name += "eval_{}_".format(args.eval_mode) 98 | 99 | args.save_path = os.path.join(args.save_path, run_name) 100 | 101 | if not os.path.exists(args.data_path): 102 | os.mkdir(args.data_path) 103 | 104 | if not os.path.exists(args.save_path): 105 | os.makedirs(args.save_path) 106 | 107 | # eval_it_pool = np.arange(0, args.syn_batches*args.Iteration+1, 100).tolist() if args.eval_mode == 'S' else [args.Iteration] # The list of iterations when we evaluate models and record results. 
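# ------------------------------------------------------------------
# NOTE: illustrative sketch, not part of the original buffer.py.
# The training loop below stores expert trajectories as nested lists:
# trajectories[i][e] is the list of teacher parameter tensors at epoch
# checkpoint e of run i (index 0 is the initialization), saved to
# replay_buffer_{n}.pt. A later distillation stage can restore any such
# checkpoint; the helper name `load_expert_checkpoint` is hypothetical
# (torch is already imported at the top of this file).
def load_expert_checkpoint(buffer_file, run_idx, epoch_idx, net):
    """Copy one saved teacher checkpoint into a freshly built network."""
    trajectories = torch.load(buffer_file)          # list of runs
    checkpoint = trajectories[run_idx][epoch_idx]   # list of parameter tensors
    with torch.no_grad():
        for p, saved in zip(net.parameters(), checkpoint):
            p.copy_(saved.to(p.device))
    return net
# ------------------------------------------------------------------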
108 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 109 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args) 110 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 111 | 112 | im_res = im_size[0] 113 | 114 | # print('\n================== Exp %d ==================\n '%exp) 115 | print('Hyper-parameters: \n', args.__dict__) 116 | print('Evaluation model pool: ', model_eval_pool) 117 | 118 | save_dir = os.path.join(args.buffer_path, args.dataset) 119 | if args.pm1: 120 | save_dir = os.path.join(args.buffer_path, "tanh", args.dataset) 121 | else: 122 | save_dir = os.path.join(args.buffer_path, args.dataset) 123 | # if args.dataset == "ImageNet": 124 | # save_dir = os.path.join(save_dir, args.subset, str(args.res)) 125 | if args.dataset in ["CIFAR10", "CIFAR100"] and args.zca: 126 | save_dir += "_ZCA" 127 | save_dir = os.path.join(save_dir, args.model) 128 | 129 | save_dir = os.path.join(save_dir, "depth-{}".format(args.depth), "width-{}".format(args.width)) 130 | 131 | save_dir = os.path.join(save_dir, args.norm_train) 132 | 133 | if not os.path.exists(save_dir): 134 | os.makedirs(save_dir) 135 | 136 | 137 | if args.dataset != "ImageNet" or True: 138 | ''' organize the real dataset ''' 139 | images_all = [] 140 | labels_all = [] 141 | indices_class = [[] for c in range(num_classes)] 142 | print(len(dst_train)) 143 | print("BUILDING DATASET") 144 | for i in tqdm(range(len(dst_train))): 145 | sample = dst_train[i] 146 | images_all.append(torch.unsqueeze(sample[0], dim=0)) 147 | labels_all.append(class_map[torch.tensor(sample[1]).item()]) 148 | # images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in tqdm(range(len(dst_train)))] 149 | # labels_all = [class_map[dst_train[i][1]] for i in tqdm(range(len(dst_train)))] 150 | for i, lab in tqdm(enumerate(labels_all)): 151 | indices_class[lab].append(i) 152 | images_all = torch.cat(images_all, dim=0).to("cpu") 153 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu") 154 | 155 | for c in range(num_classes): 156 | print('class c = %d: %d real images'%(c, len(indices_class[c]))) 157 | 158 | for ch in range(channel): 159 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch]))) 160 | 161 | criterion = nn.CrossEntropyLoss().to(args.device) 162 | 163 | trajectories = [] 164 | 165 | dst_train = TensorDataset(copy.deepcopy(images_all.detach()), copy.deepcopy(labels_all.detach())) 166 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 167 | 168 | ''' set augmentation for whole-dataset training ''' 169 | # args.dsa = False 170 | args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None) 171 | args.dc_aug_param['strategy'] = 'crop_scale_rotate' # for whole-dataset training 172 | print('DC augmentation parameters: \n', args.dc_aug_param) 173 | 174 | for it in range(0, args.Iteration): 175 | 176 | ''' Train synthetic data ''' 177 | teacher_net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width, norm=args.norm_train).to(args.device) # get a random model 178 | teacher_net.train() 179 | # if torch.cuda.device_count() > 1: 180 | # teacher_net = torch.nn.DataParallel(teacher_net) 181 | lr = args.lr_teacher 182 | # teacher_optim = 
torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) # optimizer for the teacher network 183 | teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) # optimizer for the teacher network 184 | teacher_optim.zero_grad() 185 | 186 | timestamps = [] 187 | 188 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()]) 189 | 190 | lr_schedule = [args.train_epochs // 2 + 1] 191 | 192 | for e in range(args.train_epochs): 193 | 194 | 195 | train_loss, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim, 196 | criterion=criterion, args=args, aug=True) 197 | 198 | test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None, 199 | criterion=criterion, args=args, aug=False) 200 | 201 | print("Itr: {}\tEpoch: {}\tReal Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc)) 202 | 203 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()]) 204 | 205 | if e in lr_schedule and args.decay: 206 | lr *= 0.1 207 | teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) 208 | teacher_optim.zero_grad() 209 | 210 | trajectories.append(timestamps) 211 | 212 | if len(trajectories) == args.save_interval: 213 | n = 0 214 | while os.path.exists(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))): 215 | n += 1 216 | torch.save(trajectories, os.path.join(save_dir, "replay_buffer_{}.pt".format(n))) 217 | trajectories = [] 218 | 219 | print(trajectories) 220 | 221 | 222 | if __name__ == '__main__': 223 | parser = argparse.ArgumentParser(description='Parameter Processing') 224 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 225 | # parser.add_argument('--subset', type=str, default='imagenette', help='subset') 226 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 227 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 228 | parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same as the training model, M: multi architectures, W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer, 229 | # parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments') 230 | parser.add_argument('--num_eval', type=int, default=5, help='the number of randomly initialized models to evaluate') 231 | parser.add_argument('--eval_it', type=int, default=100, help='how often to evaluate') 232 | parser.add_argument('--epoch_eval_train', type=int, default=300, help='epochs to train a model with synthetic data') 233 | parser.add_argument('--Iteration', type=int, default=1000, help='training iterations') 234 | parser.add_argument('--lr_img', type=float, default=None, help='learning rate for updating synthetic images') 235 | parser.add_argument('--lr_lr', type=float, default=None, help='learning rate for the learnable learning rate') 236 | parser.add_argument('--mom_img', type=float, default=0.5, help='momentum for updating synthetic images') 237 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='learning rate for updating network parameters') 238 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data') 239 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 240 | parser.add_argument('--batch_syn', type=int, default=None, help='batch size for syn data') 241 | 
parser.add_argument('--batch_test', type=int, default=128, help='batch size for test data') 242 | parser.add_argument('--syn_batches', type=int, default=1, help='number of synthetic batches') 243 | parser.add_argument('--pix_init', type=str, default='noise', choices=["noise", "real"], help='noise/real: initialize synthetic images from random noise or randomly sampled real images.') 244 | parser.add_argument('--gan_init', type=str, default='class', choices=["class", "rand"]) 245 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], help='whether to use differentiable Siamese augmentation.') 246 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 247 | parser.add_argument('--data_path', type=str, default='data', help='dataset path') 248 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 249 | parser.add_argument('--save_path', type=str, default='results_buffer', help='path to save results') 250 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric') 251 | parser.add_argument('--blur', action='store_true') 252 | parser.add_argument('--patch', action='store_true') 253 | parser.add_argument('--tanh', action='store_true') 254 | parser.add_argument('--proj_ball', action='store_true') 255 | parser.add_argument('--rand_gen', action='store_true') 256 | parser.add_argument('--spec_proj', action='store_true') 257 | parser.add_argument('--spec_reg', type=float, default=None) 258 | parser.add_argument('--clip', action='store_true') 259 | parser.add_argument('--im_opt', type=str, default='sgd', choices=['sgd', 'adam']) 260 | parser.add_argument('--outer_loop', type=int, default=None) 261 | parser.add_argument('--inner_loop', type=int, default=None) 262 | parser.add_argument('--skip_epochs', type=int, default=0) 263 | parser.add_argument('--train_epochs', type=int, default=200) 264 | parser.add_argument('--trunc', type=float, default=1, help='truncation_trick') 265 | # parser.add_argument('--res', type=int, default=128, choices=[128, 256, 512], help='resolution') 266 | parser.add_argument('--res', type=int, default=128, help='resolution') 267 | parser.add_argument('--layer', type=int, default=0) 268 | parser.add_argument('--blur_perc', type=float, default=0.0) 269 | parser.add_argument('--blur_min', type=float, default=0.0) 270 | parser.add_argument('--blur_max', type=float, default=3.0) 271 | parser.add_argument('--rand_cond', action='store_true') 272 | parser.add_argument('--rand_lat', action='store_true') 273 | parser.add_argument('--coarse2fine', action='store_true') 274 | parser.add_argument('--teacher', type=str, default='fake', choices=['real', 'fake']) 275 | group = parser.add_mutually_exclusive_group() 276 | group.add_argument('--texture', action='store_true') 277 | group.add_argument('--tex_avg', action='store_true') 278 | parser.add_argument('--tex_it', type=int, default=500) 279 | parser.add_argument('--tex_batch', type=int, default=10) 280 | parser.add_argument('--lr_decay', type=str, default='none', choices=['none', 'cosine', 'linear', 'step']) 281 | # parser.add_argument('--g_grad_ipc', type=int, default=10) 282 | parser.add_argument('--g_train_ipc', type=int, default=None) 283 | parser.add_argument('--g_eval_ipc', type=int, default=None) 284 | parser.add_argument('--g_grad_ipc', type=int, default=None) 285 | parser.add_argument('--cluster', action='store_true') 286 | 
parser.add_argument('--zca', action='store_true') 287 | parser.add_argument('--learn_labels', action='store_true') 288 | parser.add_argument('--save_interval', type=int, default=10) 289 | parser.add_argument('--mom', type=float, default=0.0) 290 | parser.add_argument('--l2', type=float, default=0.0) 291 | parser.add_argument('--decay', action='store_true') 292 | parser.add_argument('--cl_subset', default=None) 293 | parser.add_argument('--kip_zca', action='store_true') 294 | parser.add_argument('--pm1', action='store_true') 295 | 296 | parser.add_argument('--canvas_size', type=int, default=None) 297 | parser.add_argument('--canvas_samples', type=int, default=1) 298 | parser.add_argument('--canvas_stride', type=int, default=1) 299 | 300 | parser.add_argument('--space', type=str, default='p', choices=['p', 'z', 'w', 'w+', 'wp', 'g'], help='[ p | z | w | w+ ]') 301 | 302 | parser.add_argument('--width', type=int, default=128) 303 | parser.add_argument('--depth', type=int, default=3) 304 | 305 | parser.add_argument('--norm_train', type=str, default="batchnorm") 306 | 307 | parser.add_argument('--norm_eval', type=str, default="none") 308 | 309 | # To speed things up, we can decrease Iteration and epoch_eval_train without a significant drop in performance. 310 | 311 | args = parser.parse_args() 312 | main(args) 313 | 314 | 315 | -------------------------------------------------------------------------------- /DC/main_DC_FreD.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | 4 | import time 5 | import copy 6 | import argparse 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torchvision.utils import save_image 11 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, set_seed, save_and_print, get_images 12 | 13 | import shutil 14 | import torchvision 15 | import matplotlib.pyplot as plt 16 | from frequency_transforms import DCT 17 | 18 | class SynSet(): 19 | def __init__(self, args): 20 | ### Basic ### 21 | self.args = args 22 | self.log_path = self.args.log_path 23 | self.channel = self.args.channel 24 | self.num_classes = self.args.num_classes 25 | self.im_size = self.args.im_size 26 | self.device = self.args.device 27 | self.ipc = self.args.ipc 28 | 29 | ### FreD ### 30 | self.dct = DCT(resolution=self.im_size[0], device=self.device) 31 | self.lr_freq = self.args.lr_freq 32 | self.mom_freq = self.args.mom_freq 33 | self.msz_per_channel = self.args.msz_per_channel 34 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1]) / self.msz_per_channel) 35 | 36 | def init(self, images_real, labels_real, indices_class): 37 | ### Initialize Frequency (F) ### 38 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device) 39 | for c in range(self.num_classes): 40 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class] 41 | images.data[c * self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data 42 | self.freq_syn = self.dct.forward(images) 43 | self.freq_syn.requires_grad = True 44 | del images 45 | 46 | ### Initialize Mask (M) ### 47 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), 
dtype=torch.float, device=self.device) 48 | self.mask.requires_grad = False 49 | 50 | ### Initialize Label ### 51 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 52 | self.label_syn = self.label_syn.long() 53 | 54 | ### Initialize Optimizer ### 55 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq) 56 | 57 | self.init_mask() 58 | self.optim_zero_grad() 59 | self.show_budget() 60 | 61 | def get(self, indices=None, need_copy=False): 62 | if not hasattr(indices, '__iter__'): 63 | indices = range(len(self.label_syn)) 64 | 65 | if need_copy: 66 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach()) 67 | mask = copy.deepcopy(self.mask[indices].detach()) 68 | else: 69 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices] 70 | mask = self.mask[indices] 71 | 72 | image_syn = self.dct.inverse(mask * freq_syn) 73 | return image_syn, label_syn 74 | 75 | def init_mask(self): 76 | save_and_print(self.args.log_path, "Initialize Mask") 77 | 78 | for c in range(self.num_classes): 79 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach()) 80 | freq_c = torch.mean(freq_c, 1) 81 | freqs_flat = torch.flatten(freq_c, 1) 82 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0) 83 | 84 | try: 85 | cov = torch.cov(freqs_flat.T) 86 | except Exception: 87 | save_and_print(self.args.log_path, "Cannot use torch.cov; using np.cov instead") 88 | cov = np.cov(freqs_flat.T.cpu().numpy()) 89 | cov = torch.tensor(cov, dtype=torch.float, device=self.device) 90 | total_variance = torch.sum(torch.diag(cov)) 91 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device) 92 | for idx in range(np.prod(self.im_size)): 93 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1) 94 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low) 95 | explained_variance_ratio = vector_variance / total_variance 96 | vr_fl2f[idx] = explained_variance_ratio.item() 97 | 98 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel) 99 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:] 100 | for h, w in top_indices: 101 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, h, w] = 1.0 102 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}") 103 | 104 | ### Visualize and Save ### 105 | indices_save = np.arange(10) * self.num_per_class 106 | grid = torchvision.utils.make_grid(self.mask[indices_save], nrow=10) 107 | plt.imshow(np.transpose(grid.detach().cpu().numpy(), (1, 2, 0))) 108 | plt.savefig(f"{self.args.save_path}/Mask.png", dpi=300) 109 | plt.close() 110 | 111 | mask_save = copy.deepcopy(self.mask.detach()) 112 | torch.save(mask_save.cpu(), os.path.join(self.args.save_path, "mask.pt")) 113 | del mask_save 114 | 115 | def optim_zero_grad(self): 116 | self.optimizers.zero_grad() 117 | 118 | def optim_step(self): 119 | self.optimizers.step() 120 | 121 | def show_budget(self): 122 | save_and_print(self.log_path, '=' * 50) 123 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}") 124 | images, _ = self.get(need_copy=True) 125 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}") 126 | del 
images 127 | save_and_print(self.log_path, '=' * 50) 128 | 129 | def main(): 130 | parser = argparse.ArgumentParser(description='Parameter Processing') 131 | parser.add_argument('--method', type=str, default='DC', help='DC/DSA') 132 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 133 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 134 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 135 | parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same as the training model, M: multi architectures, W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer, 136 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments') 137 | parser.add_argument('--num_eval', type=int, default=10, help='the number of randomly initialized models to evaluate') 138 | parser.add_argument('--epoch_eval_train', type=int, default=300, help='epochs to train a model with synthetic data') 139 | parser.add_argument('--Iteration', type=int, default=1000, help='training iterations') 140 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters') 141 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data') 142 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 143 | parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy') 144 | parser.add_argument('--data_path', type=str, default='../data', help='dataset path') 145 | parser.add_argument('--save_path', type=str, default='./results', help='path to save results') 146 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric') 147 | 148 | parser.add_argument('--seed', type=int, default=0) 149 | parser.add_argument('--sh_file', type=str) 150 | parser.add_argument('--FLAG', type=str, default="TEST") 151 | 152 | ### FreD ### 153 | parser.add_argument('--batch_syn', type=int) 154 | parser.add_argument('--msz_per_channel', type=int) 155 | parser.add_argument('--lr_freq', type=float) 156 | parser.add_argument('--mom_freq', type=float) 157 | args = parser.parse_args() 158 | set_seed(args.seed) 159 | 160 | args.outer_loop, args.inner_loop = get_loops(args.ipc) 161 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 162 | args.dsa_param = ParamDiffAug() 163 | args.dsa = True if args.method == 'DSA' else False 164 | 165 | if not os.path.exists(args.data_path): 166 | os.mkdir(args.data_path) 167 | 168 | if not os.path.exists(args.save_path): 169 | os.mkdir(args.save_path) 170 | args.save_path = args.save_path + f"/{args.FLAG}" 171 | if not os.path.exists(args.save_path): 172 | os.mkdir(args.save_path) 173 | 174 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}") 175 | print(args) 176 | args.log_path = f"{args.save_path}/log.txt" 177 | 178 | eval_it_pool = np.arange(0, args.Iteration+1, 500).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results. 
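# ------------------------------------------------------------------
# NOTE: illustrative sketch, not part of the original source. FreD keeps
# the storage budget of `ipc` dense images per class but spends it on more
# instances with fewer frequency coefficients each (see SynSet.__init__
# above: num_per_class = int((ipc * H * W) / msz_per_channel)). Worked
# example with assumed values for a 32x32 dataset, ipc=2, msz_per_channel=32:
def _fred_budget_example(ipc=2, H=32, W=32, msz_per_channel=32):
    num_per_class = (ipc * H * W) // msz_per_channel   # -> 64 instances per class
    # Coefficients stored per channel equal the budget of `ipc` dense images:
    assert num_per_class * msz_per_channel == ipc * H * W
    return num_per_class
# ------------------------------------------------------------------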
179 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path, args=args) 180 | args.channel, args.im_size, args.num_classes, args.mean, args.std = channel, im_size, num_classes, mean, std 181 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 182 | 183 | 184 | accs_all_exps = dict() # record performances of all experiments 185 | for key in model_eval_pool: 186 | accs_all_exps[key] = [] 187 | 188 | data_save = [] 189 | 190 | for exp in range(args.num_exp): 191 | save_and_print(args.log_path, f'\n================== Exp {exp} ==================\n ') 192 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}') 193 | 194 | ''' organize the real dataset ''' 195 | images_all = [] 196 | labels_all = [] 197 | indices_class = [[] for c in range(num_classes)] 198 | 199 | images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))] 200 | labels_all = [dst_train[i][1] for i in range(len(dst_train))] 201 | for i, lab in enumerate(labels_all): 202 | indices_class[lab].append(i) 203 | images_all = torch.cat(images_all, dim=0).to(args.device) 204 | labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device) 205 | 206 | ''' initialize the synthetic data ''' 207 | synset = SynSet(args) 208 | synset.init(images_all, labels_all, indices_class) 209 | 210 | ''' training ''' 211 | criterion = nn.CrossEntropyLoss().to(args.device) 212 | save_and_print(args.log_path, '%s training begins'%get_time()) 213 | 214 | for it in range(args.Iteration+1): 215 | 216 | ''' Evaluate synthetic data ''' 217 | if it in eval_it_pool: 218 | for model_eval in model_eval_pool: 219 | save_and_print(args.log_path, '-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it)) 220 | if args.dsa: 221 | args.epoch_eval_train = 1000 222 | args.dc_aug_param = None 223 | save_and_print(args.log_path, f'DSA augmentation strategy: {args.dsa_strategy}') 224 | save_and_print(args.log_path, f'DSA augmentation parameters: {args.dsa_param.__dict__}') 225 | else: 226 | args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True. 227 | save_and_print(args.log_path, f'DC augmentation parameters: {args.dc_aug_param}') 228 | 229 | if args.dsa or args.dc_aug_param['strategy'] != 'none': 230 | args.epoch_eval_train = 1000 # Training with data augmentation needs more epochs. 
231 | else: 232 | args.epoch_eval_train = 300 233 | 234 | accs = [] 235 | for it_eval in range(args.num_eval): 236 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model 237 | image_syn_eval, label_syn_eval = synset.get(need_copy=True) 238 | _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args) 239 | accs.append(acc_test) 240 | save_and_print(args.log_path, 'Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs))) 241 | del image_syn_eval, label_syn_eval 242 | 243 | if it == args.Iteration: # record the final results 244 | accs_all_exps[model_eval] += accs 245 | 246 | ''' visualize and save ''' 247 | save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it)) 248 | image_syn_vis, _ = synset.get(need_copy=True) 249 | for ch in range(channel): 250 | image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch] 251 | image_syn_vis[image_syn_vis<0] = 0.0 252 | image_syn_vis[image_syn_vis>1] = 1.0 253 | save_image(image_syn_vis, save_name, nrow=synset.num_per_class) # Trying normalize=True/False may give better visual effects. 254 | del image_syn_vis 255 | 256 | 257 | ''' Train synthetic data ''' 258 | net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model 259 | net.train() 260 | net_parameters = list(net.parameters()) 261 | optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer for the network parameters 262 | optimizer_net.zero_grad() 263 | loss_avg = 0 264 | args.dc_aug_param = None # Mute the DC augmentation when learning synthetic data (in the inner-loop epoch function) in order to be consistent with the DC paper. 265 | 266 | for ol in range(args.outer_loop): 267 | 268 | ''' freeze the running mu and sigma for BatchNorm layers ''' 269 | # The synthetic data batch, e.g. only 1 image per batch, is too small to obtain stable mu and sigma. 270 | # So we calculate and freeze mu and sigma of the BatchNorm layers with a real data batch beforehand. 271 | # This makes training with BatchNorm layers easier. 
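# (Concretely, in the block below: scan net.modules() for BatchNorm layers;
#  if any are present, push one real batch of BNSizePC images per class
#  through the network so the running mean/variance are estimated on real
#  data, then switch just those layers to eval() so the statistics stay
#  frozen while gradients are matched on small synthetic batches.)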
272 | BN_flag = False 273 | BNSizePC = 16 # for batch normalization 274 | for module in net.modules(): 275 | if 'BatchNorm' in module._get_name(): #BatchNorm 276 | BN_flag = True 277 | if BN_flag: 278 | img_real = torch.cat([get_images(images_all, indices_class, c, BNSizePC) for c in range(num_classes)], dim=0) 279 | net.train() # for updating the mu, sigma of BatchNorm 280 | output_real = net(img_real) # get running mu, sigma 281 | for module in net.modules(): 282 | if 'BatchNorm' in module._get_name(): #BatchNorm 283 | module.eval() # fix mu and sigma of every BatchNorm layer 284 | 285 | 286 | ''' update synthetic data ''' 287 | loss = torch.tensor(0.0).to(args.device) 288 | for c in range(num_classes): 289 | img_real = get_images(images_all, indices_class, c, args.batch_real) 290 | lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c 291 | 292 | if args.batch_syn > 0: 293 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn] 294 | else: 295 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class) 296 | 297 | img_syn, lab_syn = synset.get(indices=indices) 298 | 299 | if args.dsa: 300 | seed = int(time.time() * 1000) % 100000 301 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param) 302 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param) 303 | 304 | output_real = net(img_real) 305 | loss_real = criterion(output_real, lab_real) 306 | gw_real = torch.autograd.grad(loss_real, net_parameters) 307 | gw_real = list((_.detach().clone() for _ in gw_real)) 308 | 309 | output_syn = net(img_syn) 310 | loss_syn = criterion(output_syn, lab_syn) 311 | gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True) 312 | 313 | loss += match_loss(gw_syn, gw_real, args) 314 | 315 | synset.optim_zero_grad() 316 | loss.backward() 317 | synset.optim_step() 318 | loss_avg += loss.item() 319 | 320 | if ol == args.outer_loop - 1: 321 | break 322 | 323 | 324 | ''' update network ''' 325 | image_syn_train, label_syn_train = synset.get(need_copy=True) 326 | dst_syn_train = TensorDataset(image_syn_train, label_syn_train) 327 | trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 328 | for il in range(args.inner_loop): 329 | epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False) 330 | 331 | 332 | loss_avg /= (num_classes*args.outer_loop) 333 | 334 | if it%10 == 0: 335 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg)) 336 | 337 | if it == args.Iteration: # only record the final results 338 | data_save.append([synset.get(need_copy=True)]) 339 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc))) 340 | 341 | 342 | save_and_print(args.log_path, '\n==================== Final Results ====================\n') 343 | for key in model_eval_pool: 344 | accs = accs_all_exps[key] 345 | save_and_print(args.log_path, 'Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100)) 346 | 347 | 348 | 349 | if __name__ == '__main__': 350 | main() 351 | 352 | 353 | --------------------------------------------------------------------------------
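# ------------------------------------------------------------------
# NOTE: illustrative appendix, not part of the repository. The DC-style loops
# above call `match_loss(gw_syn, gw_real, args)` from utils.py, which this
# dump does not include. A minimal sketch of a layer-wise gradient-matching
# distance in the spirit of the Dataset Condensation paper (the helper name
# `gradient_match_loss` is hypothetical):
import torch

def gradient_match_loss(gw_syn, gw_real, eps=1e-6):
    """Sum over layers of (1 - cosine similarity) between gradient rows."""
    dis = torch.tensor(0.0, device=gw_syn[0].device)
    for gs, gr in zip(gw_syn, gw_real):
        if gs.dim() == 1:   # bias: treat the whole vector as one row
            gs, gr = gs.reshape(1, -1), gr.reshape(1, -1)
        else:               # weight: one row per output channel
            gs, gr = gs.reshape(gs.shape[0], -1), gr.reshape(gr.shape[0], -1)
        cos = torch.sum(gs * gr, dim=-1) / (gs.norm(dim=-1) * gr.norm(dim=-1) + eps)
        dis = dis + torch.sum(1.0 - cos)
    return dis
# ------------------------------------------------------------------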