├── overview_FreD.png
├── requirements.txt
├── ImageNet-abcde
│   ├── networks
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── alexnet_cifar.py
│   │   ├── vgg.py
│   │   ├── vgg_cifar.py
│   │   ├── conv_gap.py
│   │   ├── dnfr.py
│   │   ├── conv.py
│   │   ├── vit_cifar.py
│   │   ├── vit.py
│   │   ├── resnet_cifar.py
│   │   └── resnet.py
│   ├── scripts
│   │   ├── run_buffer.sh
│   │   ├── run_DC_FreD.sh
│   │   ├── run_DM_FreD.sh
│   │   └── run_TM_FreD.sh
│   ├── hyper_params.py
│   ├── shared_args.py
│   ├── frequency_transforms.py
│   ├── reparam_module.py
│   ├── glad_utils.py
│   ├── main_DM_FreD.py
│   ├── main_DC_FreD.py
│   └── buffer.py
├── TM
│   ├── scripts
│   │   ├── run_buffer.sh
│   │   └── run_TM_FreD.sh
│   ├── hyper_params.py
│   ├── frequency_transforms.py
│   ├── buffer.py
│   └── reparam_module.py
├── corruption-exp
│   ├── scripts
│   │   └── run.sh
│   └── main.py
├── DC
│   ├── scripts
│   │   └── run_DC_FreD.sh
│   ├── frequency_transforms.py
│   └── main_DC_FreD.py
├── DM
│   ├── scripts
│   │   └── run_DM_FreD.sh
│   └── frequency_transforms.py
├── 3D-MNIST
│   ├── scripts
│   │   └── run_DM_FreD.sh
│   ├── networks.py
│   ├── frequency_transforms.py
│   ├── utils.py
│   └── main_DM_FreD.py
├── README.md
└── LICENSE
/overview_FreD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdh0818/FreD/HEAD/overview_FreD.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch==1.11.0
3 | torchaudio==0.11.0
4 | torchvision==0.12.0
5 | scipy
6 | tqdm
7 | matplotlib
8 | kornia
9 | ema-pytorch
10 | einops
11 | h5py
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .alexnet import AlexNet
2 | from .conv import ConvNet
3 | from .dnfr import DNFR
4 | from .resnet import ResNet, ResNet18, ResNet18ImageNet
5 | from .vgg import VGG, VGG11, VGG13, VGG16, VGG19, VGG11BN
6 | from .vit import ViT
7 | from .conv_gap import ConvNetGAP
8 | from .alexnet_cifar import AlexNetCIFAR
9 | from .resnet_cifar import ResNet18CIFAR
10 | from .vgg_cifar import VGG11CIFAR
11 | from .vit_cifar import ViTCIFAR
--------------------------------------------------------------------------------
/TM/scripts/run_buffer.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=3
3 | dst="CIFAR10"
4 | subset="None"
5 | model="ConvNetD3"
6 | data_path="../data"
7 | buffer_path="../buffers"
8 |
9 | train_epochs=50
10 | num_experts=100
11 |
12 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 buffer.py \
13 | --dataset=${dst} --subset ${subset} \
14 | --model=${model} \
15 | --data_path ${data_path} --buffer_path ${buffer_path} \
16 | --train_epochs=${train_epochs} \
17 | --num_experts=${num_experts}
--------------------------------------------------------------------------------
/ImageNet-abcde/scripts/run_buffer.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=3
3 | dst="imagenet-a"
4 | res=128
5 | net="ConvNet"
6 | depth=5
7 | norm_train="instancenorm"
8 | data_path="../../../../../../data/IMAGENET2012"
9 | buffer_path="../buffers"
10 |
11 | train_epochs=50
12 | num_experts=100
13 |
14 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 buffer.py \
15 | --dataset=${dst} --res=${res} \
16 | --model=${net} --depth=${depth} --norm_train=${norm_train} \
17 | --data_path=${data_path} --buffer_path=${buffer_path} \
18 | --train_epochs=${train_epochs} --Iteration=${num_experts}
--------------------------------------------------------------------------------
/corruption-exp/scripts/run.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=2
3 | method="FreD"
4 | source_dataset="CIFAR10"
5 | target_dataset="CIFAR10-C"
6 | subset="none"
7 | level=5
8 | ipc=1
9 | sh_file="run.sh"
10 | data_path="../data"
11 | save_path="./results"
12 | synset_path="./trained_synset"
13 |
14 | num_eval=5
15 | epoch_eval_train=1000
16 | batch_train=256
17 |
18 | FLAG="${target_dataset}_level${level}#${subset}#${method}_ipc${ipc}"
19 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main.py \
20 | --source_dataset ${source_dataset} --target_dataset ${target_dataset} --subset ${subset} \
21 | --level ${level} \
22 | --ipc ${ipc} \
23 | --sh_file ${sh_file} \
24 | --data_path ${data_path} --save_path ${save_path} --synset_path ${synset_path} \
25 | --num_eval ${num_eval} --epoch_eval_train ${epoch_eval_train} --batch_train ${batch_train} \
26 | --FLAG ${FLAG}
27 |
--------------------------------------------------------------------------------
/DC/scripts/run_DC_FreD.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=0
3 | dst="CIFAR10"
4 | net="ConvNetD3"
5 | ipc=2
6 | sh_file="run_DC_FreD.sh"
7 | eval_mode="S"
8 | data_path="../data"
9 | save_path="./results"
10 |
11 | num_eval=5
12 | Iteration=1000
13 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset)
14 | msz_per_channel=32
15 | lr_freq=1e3
16 | mom_freq=0.5
17 |
18 | TAG=""
19 | FLAG="${dst}_${ipc}ipc_${net}#DC_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}"
20 |
21 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DC_FreD.py \
22 | --dataset ${dst} \
23 | --model ${net} \
24 | --ipc ${ipc} \
25 | --sh_file ${sh_file} \
26 | --eval_mode ${eval_mode} \
27 | --data_path ${data_path} --save_path ${save_path} \
28 | --num_eval ${num_eval} \
29 | --Iteration ${Iteration} \
30 | --batch_syn ${batch_syn} \
31 | --msz_per_channel ${msz_per_channel} \
32 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \
33 | --FLAG ${FLAG}
--------------------------------------------------------------------------------
/DM/scripts/run_DM_FreD.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=1
3 | dst="CIFAR10"
4 | net="ConvNetD3"
5 | ipc=2
6 | sh_file="run_DM_FreD.sh"
7 | eval_mode="S"
8 | data_path="../data"
9 | save_path="./results"
10 |
11 | num_eval=5
12 | Iteration=20000
13 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset)
14 | msz_per_channel=64
15 | lr_freq=1e6
16 | mom_freq=0.5
17 |
18 | TAG=""
19 | FLAG="${dst}_${ipc}ipc_${net}#DM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}"
20 |
21 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DM_FreD.py \
22 | --dataset ${dst} \
23 | --model ${net} \
24 | --ipc ${ipc} \
25 | --sh_file ${sh_file} \
26 | --eval_mode ${eval_mode} \
27 | --data_path ${data_path} --save_path ${save_path} \
28 | --num_eval ${num_eval} \
29 | --Iteration ${Iteration} \
30 | --batch_syn ${batch_syn} \
31 | --msz_per_channel ${msz_per_channel} \
32 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \
33 | --FLAG ${FLAG}
--------------------------------------------------------------------------------
/3D-MNIST/scripts/run_DM_FreD.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=2
3 | dst="3D-MNIST"
4 | net="Conv3DNet"
5 | ipc=1
6 | sh_file="run_DM_FreD.sh"
7 | eval_mode="S"
8 | data_path="../data"
9 | save_path="./results"
10 |
11 | num_eval=5
12 | Iteration=1000
13 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset)
14 | msz_per_channel=512
15 | lr_freq=1e6
16 | mom_freq=0.5
17 |
18 | TAG=""
19 | FLAG="${dst}_${ipc}ipc_${net}#DM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}"
20 |
21 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DM_FreD.py \
22 | --dataset ${dst} \
23 | --model ${net} \
24 | --ipc ${ipc} \
25 | --sh_file ${sh_file} \
26 | --eval_mode ${eval_mode} \
27 | --data_path ${data_path} --save_path ${save_path} \
28 | --num_eval ${num_eval} \
29 | --Iteration ${Iteration} \
30 | --batch_syn ${batch_syn} \
31 | --msz_per_channel ${msz_per_channel} \
32 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \
33 | --FLAG ${FLAG}
--------------------------------------------------------------------------------
/ImageNet-abcde/scripts/run_DC_FreD.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=0
3 | dst="imagenet-a"
4 | res=128
5 | net="ConvNet"
6 | depth=5
7 | ipc=1
8 | sh_file="run_DC_FreD.sh"
9 | eval_mode="M"
10 | data_path="../../../../../../data/IMAGENET2012"
11 | save_path="./results"
12 |
13 | num_eval=5
14 | Iteration=1000
15 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset)
16 | msz_per_channel=2048
17 | lr_freq=1e5
18 | mom_freq=0.5
19 |
20 | TAG=""
21 | FLAG="${dst}_${res}_${ipc}ipc_${net}D${depth}#DC_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}"
22 |
23 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DC_FreD.py \
24 | --dataset ${dst} --res ${res} \
25 | --model ${net} --depth ${depth} \
26 | --ipc ${ipc} \
27 | --sh_file ${sh_file} \
28 | --eval_mode ${eval_mode} \
29 | --data_path ${data_path} --save_path ${save_path} \
30 | --num_eval ${num_eval} \
31 | --Iteration ${Iteration} \
32 | --batch_syn ${batch_syn} \
33 | --msz_per_channel ${msz_per_channel} \
34 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \
35 | --FLAG ${FLAG}
--------------------------------------------------------------------------------
/ImageNet-abcde/scripts/run_DM_FreD.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=1
3 | dst="imagenet-a"
4 | res=128
5 | net="ConvNet"
6 | depth=5
7 | ipc=1
8 | sh_file="run_DM_FreD.sh"
9 | eval_mode="M"
10 | data_path="../../../../../../data/IMAGENET2012"
11 | save_path="./results"
12 |
13 | num_eval=5
14 | Iteration=1000
15 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset)
16 | msz_per_channel=2048
17 | lr_freq=1e6
18 | mom_freq=0.5
19 |
20 | TAG=""
21 | FLAG="${dst}_${res}_${ipc}ipc_${net}D${depth}#DM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}"
22 |
23 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_DM_FreD.py \
24 | --dataset ${dst} --res ${res} \
25 | --model ${net} --depth ${depth} \
26 | --ipc ${ipc} \
27 | --sh_file ${sh_file} \
28 | --eval_mode ${eval_mode} \
29 | --data_path ${data_path} --save_path ${save_path} \
30 | --num_eval ${num_eval} \
31 | --Iteration ${Iteration} \
32 | --batch_syn ${batch_syn} \
33 | --msz_per_channel ${msz_per_channel} \
34 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \
35 | --FLAG ${FLAG}
--------------------------------------------------------------------------------
/TM/scripts/run_TM_FreD.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=0
3 | dst="CIFAR10"
4 | subset="none"
5 | net="ConvNetD3"
6 | ipc=1
7 | sh_file="run_TM_FreD.sh"
8 | eval_mode="S"
9 | data_path="../data"
10 | save_path="./results"
11 | buffer_path="../buffers"
12 |
13 | num_eval=5
14 | Iteration=15000
15 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset)
16 | msz_per_channel=64
17 | lr_freq=1e8
18 | mom_freq=0.5
19 |
20 | TAG=""
21 | FLAG="${dst}_${subset}_${ipc}ipc_${net}#TM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}"
22 |
23 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_TM_FreD.py \
24 | --dataset ${dst} --subset ${subset} \
25 | --model ${net} \
26 | --ipc ${ipc} \
27 | --sh_file ${sh_file} \
28 | --eval_mode ${eval_mode} \
29 | --data_path ${data_path} --save_path ${save_path} --buffer_path ${buffer_path} \
30 | --num_eval ${num_eval} \
31 | --Iteration ${Iteration} \
32 | --batch_syn ${batch_syn} \
33 | --msz_per_channel ${msz_per_channel} \
34 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \
35 | --FLAG ${FLAG}
--------------------------------------------------------------------------------
/ImageNet-abcde/scripts/run_TM_FreD.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 | cuda_id=0,1,2,3
3 | dst="imagenet-a"
4 | res=128
5 | net="ConvNet"
6 | depth=5
7 | norm_train="instancenorm"
8 | ipc=1
9 | sh_file="run_TM_FreD.sh"
10 | eval_mode="M"
11 | data_path="../../../../../../data/IMAGENET2012"
12 | save_path="./results"
13 | buffer_path="../../../../../../data/tlsehdgur0/buffers"
14 |
15 | num_eval=5
16 | eval_it=500
17 | Iteration=15000
18 | batch_syn=0 # 0 means no sampling (use entire synthetic dataset)
19 | msz_per_channel=2048
20 | lr_freq=1e9
21 | mom_freq=0.5
22 |
23 | TAG=""
24 | FLAG="${dst}_${res}_${ipc}ipc_${net}D${depth}#TM_FreD_${msz_per_channel}_${batch_syn}#${Iteration}_${lr_freq}_${mom_freq}#${TAG}"
25 |
26 | CUDA_VISIBLE_DEVICES=${cuda_id} python3.8 main_TM_FreD.py \
27 | --dataset ${dst} --res ${res} \
28 | --model ${net} --depth ${depth} --norm_train ${norm_train} \
29 | --ipc ${ipc} \
30 | --sh_file ${sh_file} \
31 | --eval_mode ${eval_mode} \
32 | --data_path ${data_path} --save_path ${save_path} --buffer_path ${buffer_path} \
33 | --num_eval ${num_eval} --eval_it ${eval_it} \
34 | --Iteration ${Iteration} \
35 | --batch_syn ${batch_syn} \
36 | --msz_per_channel ${msz_per_channel} \
37 | --lr_freq ${lr_freq} --mom_freq ${mom_freq} \
38 | --FLAG ${FLAG}
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | ''' AlexNet '''
5 | class AlexNet(nn.Module):
6 | def __init__(self, channel, num_classes, im_size, **kwargs):
7 | super(AlexNet, self).__init__()
8 | self.features = nn.Sequential(
9 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
10 | nn.ReLU(inplace=True),
11 | nn.MaxPool2d(kernel_size=2, stride=2),
12 | nn.Conv2d(128, 192, kernel_size=5, padding=2),
13 | nn.ReLU(inplace=True),
14 | nn.MaxPool2d(kernel_size=2, stride=2),
15 | nn.Conv2d(192, 256, kernel_size=3, padding=1),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(256, 192, kernel_size=3, padding=1),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(192, 192, kernel_size=3, padding=1),
20 | nn.ReLU(inplace=True),
21 | nn.MaxPool2d(kernel_size=2, stride=2),
22 | )
23 | self.fc = nn.Linear(192 * im_size[0]//8 * im_size[1]//8, num_classes)
24 |
25 | def forward(self, x):
26 | x = self.features(x)
27 | feat_fc = x.view(x.size(0), -1)
28 | x = self.fc(feat_fc)
29 |
30 | return x
31 |
32 |
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/alexnet_cifar.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class AlexNetCIFAR(nn.Module):
4 | def __init__(self, channel, num_classes):
5 | super(AlexNetCIFAR, self).__init__()
6 | self.features = nn.Sequential(
7 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
8 | nn.ReLU(inplace=True),
9 | nn.MaxPool2d(kernel_size=2, stride=2),
10 | nn.Conv2d(128, 192, kernel_size=5, padding=2),
11 | nn.ReLU(inplace=True),
12 | nn.MaxPool2d(kernel_size=2, stride=2),
13 | nn.Conv2d(192, 256, kernel_size=3, padding=1),
14 | nn.ReLU(inplace=True),
15 | nn.Conv2d(256, 192, kernel_size=3, padding=1),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(192, 192, kernel_size=3, padding=1),
18 | nn.ReLU(inplace=True),
19 | nn.MaxPool2d(kernel_size=2, stride=2),
20 | )
21 | self.fc = nn.Linear(192 * 4 * 4, num_classes)
22 |
23 | def forward(self, x):
24 | x = self.features(x)
25 | x = x.view(x.size(0), -1)
26 | x = self.fc(x)
27 | return x
28 |
29 | def embed(self, x):
30 | x = self.features(x)
31 | x = x.view(x.size(0), -1)
32 | return x
--------------------------------------------------------------------------------
/3D-MNIST/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class Conv3DNet(nn.Module):
4 | def __init__(self, channel, num_classes, net_width, net_depth, im_size):
5 | super(Conv3DNet, self).__init__()
6 |
7 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, im_size)
8 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2] * shape_feat[3]
9 | self.classifier = nn.Linear(num_feat, num_classes)
10 |
11 | def _get_normlayer(self, shape_feat):
12 | return nn.InstanceNorm3d(num_features=shape_feat[0], affine=True)
13 |
14 | def _get_activation(self):
15 | return nn.ReLU(inplace=True)
16 |
17 | def _get_pooling(self):
18 | return nn.AvgPool3d(kernel_size=(2, 2, 2), stride=2)
19 |
20 | def _make_layers(self, channel, net_width, net_depth, im_size):
21 |
22 | layers = []
23 | in_channels = channel
24 | shape_feat = [in_channels, im_size[0], im_size[1], im_size[2]]
25 |
26 | for d in range(net_depth):
27 | layers += [nn.Conv3d(in_channels, net_width, kernel_size=(3, 3, 3), padding=3 if channel == 1 and d == 0 else 1)]
28 | shape_feat[0] = net_width
29 |
30 | layers += [self._get_normlayer(shape_feat)]
31 | layers += [self._get_activation()]
32 |
33 | in_channels = net_width
34 |
35 | layers += [self._get_pooling()]
36 | shape_feat[1] //= 2
37 | shape_feat[2] //= 2
38 | shape_feat[3] //= 2
39 |
40 | return nn.Sequential(*layers), shape_feat
41 |
42 | def forward(self, x):
43 | out = self.features(x)
44 | out = out.view(out.size(0), -1)
45 | out = self.classifier(out)
46 | return out
47 |
48 | def embed(self, x):
49 | out = self.features(x)
50 | out = out.view(out.size(0), -1)
51 | return out
52 |
53 |
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | ''' VGG '''
7 | cfg_vgg = {
8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
12 | }
13 | class VGG(nn.Module):
14 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'):
15 | super(VGG, self).__init__()
16 | self.channel = channel
17 | self.features = self._make_layers(cfg_vgg[vgg_name], norm)
18 | self.classifier = nn.Linear(512*7*7 if vgg_name != 'VGGS' else 128, num_classes)
19 |
20 | def forward(self, x):
21 | x = self.features(x)
22 | feat_fc = x.view(x.size(0), -1)
23 | x = self.classifier(feat_fc)
24 |
25 | return x
26 |
27 | def _make_layers(self, cfg, norm):
28 | layers = []
29 | in_channels = self.channel
30 | for ic, x in enumerate(cfg):
31 | if x == 'M':
32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
33 | else:
34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1),
35 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x) if norm == 'batchnorm' else nn.Identity(),
36 | nn.ReLU(inplace=True)]
37 | in_channels = x
38 | layers += [nn.AdaptiveMaxPool2d((7, 7))]
39 | return nn.Sequential(*layers)
40 |
41 |
42 | def VGG11(channel, num_classes, **kwargs):
43 | return VGG('VGG11', channel, num_classes)
44 | def VGG11BN(channel, num_classes):
45 | return VGG('VGG11', channel, num_classes, norm='batchnorm')
46 | def VGG13(channel, num_classes):
47 | return VGG('VGG13', channel, num_classes)
48 | def VGG16(channel, num_classes):
49 | return VGG('VGG16', channel, num_classes)
50 | def VGG19(channel, num_classes):
51 | return VGG('VGG19', channel, num_classes)
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/vgg_cifar.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | ''' VGG '''
4 | cfg_vgg = {
5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
6 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
7 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
9 | }
10 | class VGGCIFAR(nn.Module):
11 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'):
12 | super(VGGCIFAR, self).__init__()
13 | self.channel = channel
14 | self.features = self._make_layers(cfg_vgg[vgg_name], norm)
15 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes)
16 |
17 | def forward(self, x):
18 | x = self.features(x)
19 | x = x.view(x.size(0), -1)
20 | x = self.classifier(x)
21 | return x
22 |
23 | def embed(self, x):
24 | x = self.features(x)
25 | x = x.view(x.size(0), -1)
26 | return x
27 |
28 | def _make_layers(self, cfg, norm):
29 | layers = []
30 | in_channels = self.channel
31 | for ic, x in enumerate(cfg):
32 | if x == 'M':
33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
34 | else:
35 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1),
36 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x),
37 | nn.ReLU(inplace=True)]
38 | in_channels = x
39 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
40 | return nn.Sequential(*layers)
41 |
42 |
43 | def VGG11CIFAR(channel, num_classes):
44 | return VGGCIFAR('VGG11', channel, num_classes)
45 | def VGG11BNCIFAR(channel, num_classes):
46 | return VGGCIFAR('VGG11', channel, num_classes, norm='batchnorm')
47 | def VGG13CIFAR(channel, num_classes):
48 | return VGGCIFAR('VGG13', channel, num_classes)
49 | def VGG16CIFAR(channel, num_classes):
50 | return VGGCIFAR('VGG16', channel, num_classes)
51 | def VGG19CIFAR(channel, num_classes):
52 | return VGGCIFAR('VGG19', channel, num_classes)
--------------------------------------------------------------------------------
/TM/hyper_params.py:
--------------------------------------------------------------------------------
1 | SYN_STEPS = {"MNIST": {1: 50, 10: 30}, "FashionMNIST": {1: 50, 10: 60}, "SVHN": {1: 50, 10: 30, 50: 40},
2 | "CIFAR10": {1: 50, 2: 50, 10: 40, 11: 40, 50: 30, 51: 30},
3 | "CIFAR100": {1: 50, 10: 20, 50: 80}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 20, 2: 20, 10: 20}}
4 |
5 | EXPERT_EPOCHS = {"MNIST": {1: 2, 10: 2}, "FashionMNIST": {1: 2, 10: 2}, "SVHN": {1: 2, 10: 2, 50: 2},
6 | "CIFAR10": {1: 2, 2: 2, 10: 2, 11: 2, 50: 2, 51: 2},
7 | "CIFAR100": {1: 2, 10: 2, 50: 2}, "Tiny": {1: 2, 10: 2, 50: 2}, "ImageNet": {1: 2, 2: 2, 10: 2}}
8 |
9 | MAX_START_EPOCH = {"MNIST": {1: 5, 10: 15}, "FashionMNIST": {1: 5, 10: 15}, "SVHN": {1: 5, 10: 15, 50: 40},
10 | "CIFAR10": {1: 5, 2: 5, 10: 15, 11: 15, 50: 40, 51: 40},
11 | "CIFAR100": {1: 15, 10: 40, 50: 40}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 10, 2: 10, 10: 10}}
12 |
13 | LR_LR = {"MNIST": {1: 1e-7, 10: 1e-5}, "FashionMNIST": {1: 1e-7, 10: 1e-5}, "SVHN": {1: 1e-7, 10: 1e-5, 50: 1e-5},
14 | "CIFAR10": {1: 1e-7, 2: 1e-7, 10: 1e-5, 11: 1e-5, 50: 1e-5, 51: 1e-5},
15 | "CIFAR100": {1: 1e-5, 10: 1e-5, 50: 1e-5}, "Tiny": {1: 1e-4, 10: 1e-4, 50: 1e-4}, "ImageNet": {1: 1e-6, 2: 1e-6, 10: 1e-6}}
16 |
17 | LR_TEACHER = {"MNIST": {1: 1e-2, 10: 1e-2}, "FashionMNIST": {1: 1e-2, 10: 1e-2}, "SVHN": {1: 1e-2, 10: 1e-2, 50: 1e-3},
18 | "CIFAR10": {1: 1e-2, 2: 1e-2, 10: 1e-2, 11: 1e-2, 50: 1e-3, 51: 1e-3},
19 | "CIFAR100": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "Tiny": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "ImageNet": {1: 1e-2, 2: 1e-2, 10: 1e-2}}
20 |
21 |
22 | def load_default(args):
23 | if args.zca:
24 | exit("Default FreD does not use ZCA")
25 |
26 | dataset = args.dataset
27 |
28 | if args.ipc in [1, 2, 10, 11, 50, 51]:
29 | ipc = args.ipc
30 | else:
31 | exit("Undefined IPC")
32 |
33 | if args.syn_steps == None:
34 | args.syn_steps = SYN_STEPS[dataset][ipc]
35 |
36 | if args.expert_epochs == None:
37 | args.expert_epochs = EXPERT_EPOCHS[dataset][ipc]
38 |
39 | if args.max_start_epoch == None:
40 | args.max_start_epoch = MAX_START_EPOCH[dataset][ipc]
41 |
42 | if args.lr_lr == None:
43 | args.lr_lr = LR_LR[dataset][ipc]
44 |
45 | if args.lr_teacher == None:
46 | args.lr_teacher = LR_TEACHER[dataset][ipc]
47 | return args
--------------------------------------------------------------------------------
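
A minimal usage sketch (not part of the repository) of how load_default above fills in unset trajectory-matching hyperparameters from the lookup tables; the Namespace fields simply mirror the attributes the function reads, and the printed values follow the CIFAR10 / ipc=1 table entries.

from argparse import Namespace

from hyper_params import load_default  # the TM variant defined above

# All tunable TM hyperparameters start as None and get filled from the tables.
args = Namespace(zca=False, dataset="CIFAR10", ipc=1,
                 syn_steps=None, expert_epochs=None, max_start_epoch=None,
                 lr_lr=None, lr_teacher=None)
args = load_default(args)

# For CIFAR10 with ipc=1: syn_steps=50, expert_epochs=2, max_start_epoch=5,
# lr_lr=1e-07, lr_teacher=0.01
print(args.syn_steps, args.expert_epochs, args.max_start_epoch, args.lr_lr, args.lr_teacher)
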
/ImageNet-abcde/hyper_params.py:
--------------------------------------------------------------------------------
1 | SYN_STEPS = {"MNIST": {1: 50, 10: 30}, "FashionMNIST": {1: 50, 10: 60}, "SVHN": {1: 50, 10: 30, 50: 40},
2 | "CIFAR10": {1: 50, 2: 50, 10: 40, 11: 40, 50: 30, 51: 30},
3 | "CIFAR100": {1: 50, 10: 20, 50: 80}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 20, 2: 20, 10: 20}}
4 |
5 | EXPERT_EPOCHS = {"MNIST": {1: 2, 10: 2}, "FashionMNIST": {1: 2, 10: 2}, "SVHN": {1: 2, 10: 2, 50: 2},
6 | "CIFAR10": {1: 2, 2: 2, 10: 2, 11: 2, 50: 2, 51: 2},
7 | "CIFAR100": {1: 2, 10: 2, 50: 2}, "Tiny": {1: 2, 10: 2, 50: 2}, "ImageNet": {1: 2, 2: 2, 10: 2}}
8 |
9 | MAX_START_EPOCH = {"MNIST": {1: 5, 10: 15}, "FashionMNIST": {1: 5, 10: 15}, "SVHN": {1: 5, 10: 15, 50: 40},
10 | "CIFAR10": {1: 5, 2: 5, 10: 15, 11: 15, 50: 40, 51: 40},
11 | "CIFAR100": {1: 15, 10: 40, 50: 40}, "Tiny": {1: 30, 10: 40, 50: 40}, "ImageNet": {1: 10, 2: 10, 10: 10}}
12 |
13 | LR_LR = {"MNIST": {1: 1e-7, 10: 1e-5}, "FashionMNIST": {1: 1e-7, 10: 1e-5}, "SVHN": {1: 1e-7, 10: 1e-5, 50: 1e-5},
14 | "CIFAR10": {1: 1e-7, 2: 1e-7, 10: 1e-5, 11: 1e-5, 50: 1e-5, 51: 1e-5},
15 | "CIFAR100": {1: 1e-5, 10: 1e-5, 50: 1e-5}, "Tiny": {1: 1e-4, 10: 1e-4, 50: 1e-4}, "ImageNet": {1: 1e-6, 2: 1e-6, 10: 1e-6}}
16 |
17 | LR_TEACHER = {"MNIST": {1: 1e-2, 10: 1e-2}, "FashionMNIST": {1: 1e-2, 10: 1e-2}, "SVHN": {1: 1e-2, 10: 1e-2, 50: 1e-3},
18 | "CIFAR10": {1: 1e-2, 2: 1e-2, 10: 1e-2, 11: 1e-2, 50: 1e-3, 51: 1e-3},
19 | "CIFAR100": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "Tiny": {1: 1e-2, 10: 1e-2, 50: 1e-2}, "ImageNet": {1: 1e-2, 2: 1e-2, 10: 1e-2}}
20 |
21 |
22 | def load_default(args):
23 | if args.zca:
24 | exit("Default FreD does not use ZCA")
25 |
26 | dataset = args.dataset
27 | if dataset.startswith("imagenet"):
28 | dataset = "ImageNet"
29 |
30 | if args.ipc in [1, 2, 10, 11, 50, 51]:
31 | ipc = args.ipc
32 | else:
33 | exit("Undefined IPC")
34 |
35 | if args.syn_steps == None:
36 | args.syn_steps = SYN_STEPS[dataset][ipc]
37 |
38 | if args.expert_epochs == None:
39 | args.expert_epochs = EXPERT_EPOCHS[dataset][ipc]
40 |
41 | if args.max_start_epoch == None:
42 | args.max_start_epoch = MAX_START_EPOCH[dataset][ipc]
43 |
44 | if args.lr_lr == None:
45 | args.lr_lr = LR_LR[dataset][ipc]
46 |
47 | if args.lr_teacher == None:
48 | args.lr_teacher = LR_TEACHER[dataset][ipc]
49 | return args
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/conv_gap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | ''' ConvNetGAP '''
5 | class ConvNetGAP(nn.Module):
6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size = (32,32)):
7 | super(ConvNetGAP, self).__init__()
8 |
9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
10 | num_feat = shape_feat[0]
11 | self.classifier = nn.Linear(num_feat, num_classes)
12 | self.pool = nn.AdaptiveAvgPool2d((1, 1))
13 |
14 | def forward(self, x):
15 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device())
16 | out = self.features(x)
17 | out = self.pool(out)
18 | out = out.view(out.size(0), -1)
19 | out = self.classifier(out)
20 | return out
21 |
22 | def _get_activation(self, net_act):
23 | if net_act == 'sigmoid':
24 | return nn.Sigmoid()
25 | elif net_act == 'relu':
26 | return nn.ReLU(inplace=True)
27 | elif net_act == 'leakyrelu':
28 | return nn.LeakyReLU(negative_slope=0.01)
29 | else:
30 | exit('unknown activation function: %s'%net_act)
31 |
32 | def _get_pooling(self, net_pooling):
33 | if net_pooling == 'maxpooling':
34 | return nn.MaxPool2d(kernel_size=2, stride=2)
35 | elif net_pooling == 'avgpooling':
36 | return nn.AvgPool2d(kernel_size=2, stride=2)
37 | elif net_pooling == 'none':
38 | return None
39 | else:
40 | exit('unknown net_pooling: %s'%net_pooling)
41 |
42 | def _get_normlayer(self, net_norm, shape_feat):
43 | # shape_feat = (c*h*w)
44 | if net_norm == 'batchnorm':
45 | return nn.BatchNorm2d(shape_feat[0], affine=True)
46 | elif net_norm == 'layernorm':
47 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
48 | elif net_norm == 'instancenorm':
49 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
50 | elif net_norm == 'groupnorm':
51 | return nn.GroupNorm(4, shape_feat[0], affine=True)
52 | elif net_norm == 'none':
53 | return None
54 | else:
55 | exit('unknown net_norm: %s'%net_norm)
56 |
57 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
58 | layers = []
59 | in_channels = channel
60 | if im_size[0] == 28:
61 | im_size = (32, 32)
62 | shape_feat = [in_channels, im_size[0], im_size[1]]
63 | for d in range(net_depth):
64 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
65 | shape_feat[0] = net_width
66 | if net_norm != 'none':
67 | layers += [self._get_normlayer(net_norm, shape_feat)]
68 | layers += [self._get_activation(net_act)]
69 | in_channels = net_width
70 | if net_pooling != 'none':
71 | layers += [self._get_pooling(net_pooling)]
72 | shape_feat[1] //= 2
73 | shape_feat[2] //= 2
74 | net_width *= 2
75 |
76 |
77 | return nn.Sequential(*layers), shape_feat
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/dnfr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | ''' DNFR_Net '''
5 | class DNFR(nn.Module):
6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size=(32,32)):
7 | super(DNFR, self).__init__()
8 |
9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
10 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
11 | self.classifier = nn.Linear(num_feat, num_classes)
12 |
13 | def forward(self, x):
14 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device())
15 | out = self.features(x)
16 | out = out.view(out.size(0), -1)
17 | out = self.classifier(out)
18 | return out
19 |
20 | def _get_activation(self, net_act):
21 | if net_act == 'sigmoid':
22 | return nn.Sigmoid()
23 | elif net_act == 'relu':
24 | return nn.ReLU(inplace=True)
25 | elif net_act == 'leakyrelu':
26 | return nn.LeakyReLU(negative_slope=0.01)
27 | else:
28 | exit('unknown activation function: %s'%net_act)
29 |
30 | def _get_pooling(self, net_pooling):
31 | if net_pooling == 'maxpooling':
32 | return nn.MaxPool2d(kernel_size=2, stride=2)
33 | elif net_pooling == 'avgpooling':
34 | return nn.AvgPool2d(kernel_size=2, stride=2)
35 | elif net_pooling == 'none':
36 | return None
37 | else:
38 | exit('unknown net_pooling: %s'%net_pooling)
39 |
40 | def _get_normlayer(self, net_norm, shape_feat):
41 | # shape_feat = (c*h*w)
42 | if net_norm == 'batchnorm':
43 | return nn.BatchNorm2d(shape_feat[0], affine=True)
44 | elif net_norm == 'layernorm':
45 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
46 | elif net_norm == 'instancenorm':
47 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
48 | elif net_norm == 'groupnorm':
49 | return nn.GroupNorm(4, shape_feat[0], affine=True)
50 | elif net_norm == 'none':
51 | return None
52 | else:
53 | exit('unknown net_norm: %s'%net_norm)
54 |
55 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
56 | layers = []
57 | in_channels = channel
58 | if im_size[0] == 28:
59 | im_size = (32, 32)
60 | shape_feat = [in_channels, im_size[0], im_size[1]]
61 | for d in range(net_depth):
62 | if net_norm != 'none':
63 | if d == 0 and net_norm == 'groupnorm':
64 | layers += [self._get_normlayer('instancenorm', shape_feat)]
65 | else:
66 | layers += [self._get_normlayer(net_norm, shape_feat)]
67 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
68 | shape_feat[0] = net_width
69 | layers += [self._get_activation(net_act)]
70 | in_channels = net_width
71 | if net_pooling != 'none':
72 | layers += [self._get_pooling(net_pooling)]
73 | shape_feat[1] //= 2
74 | shape_feat[2] //= 2
75 |
76 |
77 | return nn.Sequential(*layers), shape_feat
--------------------------------------------------------------------------------
/ImageNet-abcde/shared_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def add_shared_args():
4 | parser = argparse.ArgumentParser(description='Parameter Processing')
5 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
6 | parser.add_argument('--model', type=str, default='ConvNet', help='model')
7 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
8 | parser.add_argument('--eval_mode', type=str, default='M', help='eval_mode') # S: same architecture as the training model, M: multiple architectures
9 | parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models')
10 | parser.add_argument('--eval_it', type=int, default=100, help='how often to evaluate')
11 | parser.add_argument('--save_it', type=int, default=None, help='how often to save')
12 | parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
13 | parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
14 |
15 | parser.add_argument('--mom_img', type=float, default=0.5, help='momentum for updating synthetic images')
16 |
17 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
18 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
19 | parser.add_argument('--batch_test', type=int, default=128, help='batch size for testing networks')
20 |
21 | parser.add_argument('--pix_init', type=str, default='noise', choices=["noise", "real"], help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
22 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], help='whether to use differentiable Siamese augmentation.')
23 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
24 | parser.add_argument('--data_path', type=str, default='data', help='dataset path')
25 |
26 | parser.add_argument('--save_path', type=str, default='result', help='path to save results')
27 |
28 | parser.add_argument('--space', type=str, default='p', choices=['p', 'wp'])
29 | parser.add_argument('--res', type=int, default=128, choices=[128, 256, 512], help='resolution')
30 | parser.add_argument('--layer', type=int, default=12)
31 | parser.add_argument('--avg_w', action='store_true')
32 |
33 | parser.add_argument('--eval_all', action='store_true')
34 |
35 | parser.add_argument('--min_it', type=bool, default=False)
36 | parser.add_argument('--no_aug', type=bool, default=False)
37 |
38 | parser.add_argument('--force_save', action='store_true')
39 |
40 | parser.add_argument('--sg_batch', type=int, default=10)
41 |
42 | parser.add_argument('--rand_f', action='store_true')
43 |
44 | parser.add_argument('--logdir', type=str, default='./logged_files')
45 |
46 | parser.add_argument('--wait_eval', action='store_true')
47 |
48 | parser.add_argument('--idc_factor', type=int, default=1)
49 |
50 | parser.add_argument('--rand_gan_un', action='store_true')
51 | parser.add_argument('--rand_gan_con', action='store_true')
52 |
53 | parser.add_argument('--learn_g', action='store_true')
54 |
55 | parser.add_argument('--width', type=int, default=128)
56 | parser.add_argument('--depth', type=int, default=5)
57 |
58 | parser.add_argument('--norm_train', type=str, default="batchnorm")
59 |
60 | parser.add_argument('--special_gan', default=None)
61 |
62 | return parser
63 |
--------------------------------------------------------------------------------
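
The run scripts above also pass FreD-specific flags (--msz_per_channel, --batch_syn, --lr_freq, --mom_freq, --sh_file, --FLAG) that are not defined in add_shared_args; they are added in the main_*_FreD.py entry points, which are not included in this listing. A hedged sketch of how such an extension might look (argument names and types are inferred from the scripts; the defaults here are illustrative, not the repository's):

from shared_args import add_shared_args

parser = add_shared_args()
# FreD-specific options, inferred from the run scripts (actual definitions live in main_*_FreD.py).
parser.add_argument('--msz_per_channel', type=int, default=64, help='frequency-memory budget per channel')
parser.add_argument('--batch_syn', type=int, default=0, help='0 means no sampling (use entire synthetic dataset)')
parser.add_argument('--lr_freq', type=float, default=1e6, help='learning rate for the frequency coefficients')
parser.add_argument('--mom_freq', type=float, default=0.5, help='momentum for the frequency coefficients')
parser.add_argument('--sh_file', type=str, default=None, help='name of the launching shell script')
parser.add_argument('--FLAG', type=str, default='', help='experiment tag used for result folders')
args = parser.parse_args([])  # empty list: use the defaults instead of sys.argv
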
/ImageNet-abcde/networks/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | ''' ConvNet '''
5 | class ConvNet(nn.Module):
6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size = (32,32)):
7 | super(ConvNet, self).__init__()
8 |
9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
10 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
11 | self.classifier = nn.Linear(num_feat, num_classes)
12 |
13 | def forward(self, x):
14 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device())
15 | out = self.features(x)
16 | out = out.view(out.size(0), -1)
17 | out = self.classifier(out)
18 | return out
19 |
20 |
21 | def forward_cafe(self, x):
22 | f_maps = []
23 | x = self.features[0:3](x)
24 | f_maps.append(x)
25 |
26 | for i in range(3, len(self.features) - 1, 4):  # step through each remaining (pool, conv, norm, act) block
27 | x = self.features[i:i+4](x)
28 | f_maps.append(x)
29 |
30 | out = self.features[-1:](x)
31 | f_maps.append(out)
32 | out = out.view(out.size(0), -1)
33 | out_final = self.classifier(out)
34 | return out_final, out, f_maps
35 |
36 | def embed(self, x):
37 | out = self.features(x)
38 | out = out.view(out.size(0), -1)
39 | return out
40 |
41 | def _get_activation(self, net_act):
42 | if net_act == 'sigmoid':
43 | return nn.Sigmoid()
44 | elif net_act == 'relu':
45 | return nn.ReLU(inplace=True)
46 | elif net_act == 'leakyrelu':
47 | return nn.LeakyReLU(negative_slope=0.01)
48 | else:
49 | exit('unknown activation function: %s'%net_act)
50 |
51 | def _get_pooling(self, net_pooling):
52 | if net_pooling == 'maxpooling':
53 | return nn.MaxPool2d(kernel_size=2, stride=2)
54 | elif net_pooling == 'avgpooling':
55 | return nn.AvgPool2d(kernel_size=2, stride=2)
56 | elif net_pooling == 'none':
57 | return None
58 | else:
59 | exit('unknown net_pooling: %s'%net_pooling)
60 |
61 | def _get_normlayer(self, net_norm, shape_feat):
62 | # shape_feat = (c*h*w)
63 | if net_norm == 'batchnorm':
64 | return nn.BatchNorm2d(shape_feat[0], affine=True)
65 | elif net_norm == 'layernorm':
66 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
67 | elif net_norm == 'instancenorm':
68 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
69 | elif net_norm == 'groupnorm':
70 | return nn.GroupNorm(4, shape_feat[0], affine=True)
71 | elif net_norm == 'none':
72 | return None
73 | else:
74 | exit('unknown net_norm: %s'%net_norm)
75 |
76 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
77 | layers = []
78 | in_channels = channel
79 | if im_size[0] == 28:
80 | im_size = (32, 32)
81 | shape_feat = [in_channels, im_size[0], im_size[1]]
82 | for d in range(net_depth):
83 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
84 | shape_feat[0] = net_width
85 | if net_norm != 'none':
86 | layers += [self._get_normlayer(net_norm, shape_feat)]
87 | layers += [self._get_activation(net_act)]
88 | in_channels = net_width
89 | if net_pooling != 'none':
90 | layers += [self._get_pooling(net_pooling)]
91 | shape_feat[1] //= 2
92 | shape_feat[2] //= 2
93 |
94 |
95 | return nn.Sequential(*layers), shape_feat
--------------------------------------------------------------------------------
/DC/frequency_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/zh217/torch-dct
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | class DCT():
10 | def __init__(self, resolution, device, norm=None, bias=False):
11 | self.resolution = resolution
12 | self.norm = norm
13 | self.device = device
14 |
15 | I = torch.eye(self.resolution, device=self.device)
16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t()
18 | self.forward_transform.weight.requires_grad = False
19 |
20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t()
22 | self.inverse_transform.weight.requires_grad = False
23 |
24 | def _dct(self, x, norm=None):
25 | """
26 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
27 | For the meaning of the parameter `norm`, see:
28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
29 | :param x: the input signal
30 | :param norm: the normalization, None or 'ortho'
31 | :return: the DCT-II of the signal over the last dimension
32 | """
33 | x_shape = x.shape
34 | N = x_shape[-1]
35 | x = x.contiguous().view(-1, N)
36 |
37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
38 |
39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
40 |
41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
42 | W_r = torch.cos(k)
43 | W_i = torch.sin(k)
44 |
45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
46 |
47 | if norm == 'ortho':
48 | V[:, 0] /= np.sqrt(N) * 2
49 | V[:, 1:] /= np.sqrt(N / 2) * 2
50 |
51 | V = 2 * V.view(*x_shape)
52 | return V
53 |
54 | def _idct(self, X, norm=None):
55 | """
56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
57 | Our definition of idct is that idct(dct(x)) == x
58 | For the meaning of the parameter `norm`, see:
59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
60 | :param X: the input signal
61 | :param norm: the normalization, None or 'ortho'
62 | :return: the inverse DCT-II of the signal over the last dimension
63 | """
64 |
65 | x_shape = X.shape
66 | N = x_shape[-1]
67 |
68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
69 |
70 | if norm == 'ortho':
71 | X_v[:, 0] *= np.sqrt(N) * 2
72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
73 |
74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
75 | W_r = torch.cos(k)
76 | W_i = torch.sin(k)
77 |
78 | V_t_r = X_v
79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
80 |
81 | V_r = V_t_r * W_r - V_t_i * W_i
82 | V_i = V_t_r * W_i + V_t_i * W_r
83 |
84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
85 |
86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
87 | x = v.new_zeros(v.shape)
88 | x[:, ::2] += v[:, :N - (N // 2)]
89 | x[:, 1::2] += v.flip([1])[:, :N // 2]
90 | return x.view(*x_shape)
91 |
92 | def forward(self, x):
93 | X1 = self.forward_transform(x)
94 | X2 = self.forward_transform(X1.transpose(-1, -2))
95 | return X2.transpose(-1, -2)
96 |
97 | def inverse(self, x):
98 | X1 = self.inverse_transform(x)
99 | X2 = self.inverse_transform(X1.transpose(-1, -2))
100 | return X2.transpose(-1, -2)
--------------------------------------------------------------------------------
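
A small round-trip sketch (not from the repository) checking the idct(dct(x)) == x property stated in the docstrings, here applied over both spatial dimensions of a batch of 32x32 images:

import torch

from frequency_transforms import DCT

dct = DCT(resolution=32, device="cpu", norm=None)
x = torch.randn(8, 3, 32, 32)                 # (batch, channel, H, W)
freq = dct.forward(x)                         # DCT-II along W, then along H
recon = dct.inverse(freq)                     # inverse transform restores the images
print(torch.allclose(recon, x, atol=1e-3))    # True, up to float32 error
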
/DM/frequency_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/zh217/torch-dct
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | class DCT():
10 | def __init__(self, resolution, device, norm=None, bias=False):
11 | self.resolution = resolution
12 | self.norm = norm
13 | self.device = device
14 |
15 | I = torch.eye(self.resolution, device=self.device)
16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t()
18 | self.forward_transform.weight.requires_grad = False
19 |
20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t()
22 | self.inverse_transform.weight.requires_grad = False
23 |
24 | def _dct(self, x, norm=None):
25 | """
26 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
27 | For the meaning of the parameter `norm`, see:
28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
29 | :param x: the input signal
30 | :param norm: the normalization, None or 'ortho'
31 | :return: the DCT-II of the signal over the last dimension
32 | """
33 | x_shape = x.shape
34 | N = x_shape[-1]
35 | x = x.contiguous().view(-1, N)
36 |
37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
38 |
39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
40 |
41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
42 | W_r = torch.cos(k)
43 | W_i = torch.sin(k)
44 |
45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
46 |
47 | if norm == 'ortho':
48 | V[:, 0] /= np.sqrt(N) * 2
49 | V[:, 1:] /= np.sqrt(N / 2) * 2
50 |
51 | V = 2 * V.view(*x_shape)
52 | return V
53 |
54 | def _idct(self, X, norm=None):
55 | """
56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
57 | Our definition of idct is that idct(dct(x)) == x
58 | For the meaning of the parameter `norm`, see:
59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
60 | :param X: the input signal
61 | :param norm: the normalization, None or 'ortho'
62 | :return: the inverse DCT-II of the signal over the last dimension
63 | """
64 |
65 | x_shape = X.shape
66 | N = x_shape[-1]
67 |
68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
69 |
70 | if norm == 'ortho':
71 | X_v[:, 0] *= np.sqrt(N) * 2
72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
73 |
74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
75 | W_r = torch.cos(k)
76 | W_i = torch.sin(k)
77 |
78 | V_t_r = X_v
79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
80 |
81 | V_r = V_t_r * W_r - V_t_i * W_i
82 | V_i = V_t_r * W_i + V_t_i * W_r
83 |
84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
85 |
86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
87 | x = v.new_zeros(v.shape)
88 | x[:, ::2] += v[:, :N - (N // 2)]
89 | x[:, 1::2] += v.flip([1])[:, :N // 2]
90 | return x.view(*x_shape)
91 |
92 | def forward(self, x):
93 | X1 = self.forward_transform(x)
94 | X2 = self.forward_transform(X1.transpose(-1, -2))
95 | return X2.transpose(-1, -2)
96 |
97 | def inverse(self, x):
98 | X1 = self.inverse_transform(x)
99 | X2 = self.inverse_transform(X1.transpose(-1, -2))
100 | return X2.transpose(-1, -2)
--------------------------------------------------------------------------------
/TM/frequency_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/zh217/torch-dct
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | class DCT():
10 | def __init__(self, resolution, device, norm=None, bias=False):
11 | self.resolution = resolution
12 | self.norm = norm
13 | self.device = device
14 |
15 | I = torch.eye(self.resolution, device=self.device)
16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t()
18 | self.forward_transform.weight.requires_grad = False
19 |
20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t()
22 | self.inverse_transform.weight.requires_grad = False
23 |
24 | def _dct(self, x, norm=None):
25 | """
26 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
27 | For the meaning of the parameter `norm`, see:
28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
29 | :param x: the input signal
30 | :param norm: the normalization, None or 'ortho'
31 | :return: the DCT-II of the signal over the last dimension
32 | """
33 | x_shape = x.shape
34 | N = x_shape[-1]
35 | x = x.contiguous().view(-1, N)
36 |
37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
38 |
39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
40 |
41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
42 | W_r = torch.cos(k)
43 | W_i = torch.sin(k)
44 |
45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
46 |
47 | if norm == 'ortho':
48 | V[:, 0] /= np.sqrt(N) * 2
49 | V[:, 1:] /= np.sqrt(N / 2) * 2
50 |
51 | V = 2 * V.view(*x_shape)
52 | return V
53 |
54 | def _idct(self, X, norm=None):
55 | """
56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
57 | Our definition of idct is that idct(dct(x)) == x
58 | For the meaning of the parameter `norm`, see:
59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
60 | :param X: the input signal
61 | :param norm: the normalization, None or 'ortho'
62 | :return: the inverse DCT-II of the signal over the last dimension
63 | """
64 |
65 | x_shape = X.shape
66 | N = x_shape[-1]
67 |
68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
69 |
70 | if norm == 'ortho':
71 | X_v[:, 0] *= np.sqrt(N) * 2
72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
73 |
74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
75 | W_r = torch.cos(k)
76 | W_i = torch.sin(k)
77 |
78 | V_t_r = X_v
79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
80 |
81 | V_r = V_t_r * W_r - V_t_i * W_i
82 | V_i = V_t_r * W_i + V_t_i * W_r
83 |
84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
85 |
86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
87 | x = v.new_zeros(v.shape)
88 | x[:, ::2] += v[:, :N - (N // 2)]
89 | x[:, 1::2] += v.flip([1])[:, :N // 2]
90 | return x.view(*x_shape)
91 |
92 | def forward(self, x):
93 | X1 = self.forward_transform(x)
94 | X2 = self.forward_transform(X1.transpose(-1, -2))
95 | return X2.transpose(-1, -2)
96 |
97 | def inverse(self, x):
98 | X1 = self.inverse_transform(x)
99 | X2 = self.inverse_transform(X1.transpose(-1, -2))
100 | return X2.transpose(-1, -2)
--------------------------------------------------------------------------------
/ImageNet-abcde/frequency_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/zh217/torch-dct
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | class DCT():
10 | def __init__(self, resolution, device, norm=None, bias=False):
11 | self.resolution = resolution
12 | self.norm = norm
13 | self.device = device
14 |
15 | I = torch.eye(self.resolution, device=self.device)
16 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
17 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t()
18 | self.forward_transform.weight.requires_grad = False
19 |
20 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
21 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t()
22 | self.inverse_transform.weight.requires_grad = False
23 |
24 | def _dct(self, x, norm=None):
25 | """
26 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
27 | For the meaning of the parameter `norm`, see:
28 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
29 | :param x: the input signal
30 | :param norm: the normalization, None or 'ortho'
31 | :return: the DCT-II of the signal over the last dimension
32 | """
33 | x_shape = x.shape
34 | N = x_shape[-1]
35 | x = x.contiguous().view(-1, N)
36 |
37 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
38 |
39 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
40 |
41 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
42 | W_r = torch.cos(k)
43 | W_i = torch.sin(k)
44 |
45 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
46 |
47 | if norm == 'ortho':
48 | V[:, 0] /= np.sqrt(N) * 2
49 | V[:, 1:] /= np.sqrt(N / 2) * 2
50 |
51 | V = 2 * V.view(*x_shape)
52 | return V
53 |
54 | def _idct(self, X, norm=None):
55 | """
56 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
57 | Our definition of idct is that idct(dct(x)) == x
58 | For the meaning of the parameter `norm`, see:
59 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
60 | :param X: the input signal
61 | :param norm: the normalization, None or 'ortho'
62 | :return: the inverse DCT-II of the signal over the last dimension
63 | """
64 |
65 | x_shape = X.shape
66 | N = x_shape[-1]
67 |
68 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
69 |
70 | if norm == 'ortho':
71 | X_v[:, 0] *= np.sqrt(N) * 2
72 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
73 |
74 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
75 | W_r = torch.cos(k)
76 | W_i = torch.sin(k)
77 |
78 | V_t_r = X_v
79 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
80 |
81 | V_r = V_t_r * W_r - V_t_i * W_i
82 | V_i = V_t_r * W_i + V_t_i * W_r
83 |
84 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
85 |
86 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
87 | x = v.new_zeros(v.shape)
88 | x[:, ::2] += v[:, :N - (N // 2)]
89 | x[:, 1::2] += v.flip([1])[:, :N // 2]
90 | return x.view(*x_shape)
91 |
92 | def forward(self, x):
93 | X1 = self.forward_transform(x)
94 | X2 = self.forward_transform(X1.transpose(-1, -2))
95 | return X2.transpose(-1, -2)
96 |
97 | def inverse(self, x):
98 | X1 = self.inverse_transform(x)
99 | X2 = self.inverse_transform(X1.transpose(-1, -2))
100 | return X2.transpose(-1, -2)
--------------------------------------------------------------------------------
/3D-MNIST/frequency_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | class DCT():
6 | def __init__(self, resolution, device, norm=None, bias=False):
7 | self.resolution = resolution
8 | self.norm = norm
9 | self.device = device
10 |
11 | I = torch.eye(self.resolution, device=self.device)
12 | self.forward_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
13 | self.forward_transform.weight.data = self._dct(I, norm=self.norm).data.t()
14 | self.forward_transform.weight.requires_grad = False
15 |
16 | self.inverse_transform = nn.Linear(resolution, resolution, bias=bias).to(self.device)
17 | self.inverse_transform.weight.data = self._idct(I, norm=self.norm).data.t()
18 | self.inverse_transform.weight.requires_grad = False
19 |
20 | def _dct(self, x, norm=None):
21 | """
22 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
23 | For the meaning of the parameter `norm`, see:
24 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
25 | :param x: the input signal
26 | :param norm: the normalization, None or 'ortho'
27 | :return: the DCT-II of the signal over the last dimension
28 | """
29 | x_shape = x.shape
30 | N = x_shape[-1]
31 | x = x.contiguous().view(-1, N)
32 |
33 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
34 |
35 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
36 |
37 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
38 | W_r = torch.cos(k)
39 | W_i = torch.sin(k)
40 |
41 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
42 |
43 | if norm == 'ortho':
44 | V[:, 0] /= np.sqrt(N) * 2
45 | V[:, 1:] /= np.sqrt(N / 2) * 2
46 |
47 | V = 2 * V.view(*x_shape)
48 | return V
49 |
50 | def _idct(self, X, norm=None):
51 | """
52 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
53 | Our definition of idct is that idct(dct(x)) == x
54 | For the meaning of the parameter `norm`, see:
55 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
56 | :param X: the input signal
57 | :param norm: the normalization, None or 'ortho'
58 | :return: the inverse DCT-II of the signal over the last dimension
59 | """
60 |
61 | x_shape = X.shape
62 | N = x_shape[-1]
63 |
64 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
65 |
66 | if norm == 'ortho':
67 | X_v[:, 0] *= np.sqrt(N) * 2
68 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
69 |
70 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
71 | W_r = torch.cos(k)
72 | W_i = torch.sin(k)
73 |
74 | V_t_r = X_v
75 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
76 |
77 | V_r = V_t_r * W_r - V_t_i * W_i
78 | V_i = V_t_r * W_i + V_t_i * W_r
79 |
80 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
81 |
82 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
83 | x = v.new_zeros(v.shape)
84 | x[:, ::2] += v[:, :N - (N // 2)]
85 | x[:, 1::2] += v.flip([1])[:, :N // 2]
86 | return x.view(*x_shape)
87 |
88 | def forward(self, x):
89 | X1 = self.forward_transform(x) # (bsz, C, D, H, W)
90 | X2 = self.forward_transform(X1.transpose(-1, -2)) # (bsz, C, D, W, H)
91 | X3 = self.forward_transform(X2.transpose(-1, -3)) # (bsz, C, H, W, D)
92 | return X3.permute(0, 1, -1, -3, -2)
93 |
94 | def inverse(self, x):
95 | X1 = self.inverse_transform(x) # (bsz, C, D, H, W)
96 | X2 = self.inverse_transform(X1.transpose(-1, -2)) # (bsz, C, D, W, H)
97 | X3 = self.inverse_transform(X2.transpose(-1, -3)) # (bsz, C, H, W, D)
98 | return X3.permute(0, 1, -1, -3, -2)
--------------------------------------------------------------------------------
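Note: `DCT.forward`/`DCT.inverse` above apply the same orthonormal 1D DCT-II separably along the W, H, and D axes, so the pair should round-trip a volume and agree with SciPy's n-dimensional DCT. A minimal sanity-check sketch, assuming the 3D-MNIST file above is importable as `frequency_transforms` and SciPy >= 1.4 is available:

```python
import numpy as np
import torch
from scipy.fft import dctn                      # reference orthonormal DCT-II
from frequency_transforms import DCT            # the class defined above

res = 16                                        # 3D-MNIST volumes are 16x16x16
dct = DCT(resolution=res, device='cpu', norm='ortho')
x = torch.randn(2, 1, res, res, res)            # (bsz, C, D, H, W)

X = dct.forward(x)                              # 3D DCT coefficients
x_rec = dct.inverse(X)                          # reconstructed volume
print(torch.allclose(x, x_rec, atol=1e-3))      # round trip: expect True

ref = dctn(x.numpy(), type=2, norm='ortho', axes=(-3, -2, -1))
print(np.allclose(X.numpy(), ref, atol=1e-3))   # matches SciPy: expect True
```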
/ImageNet-abcde/networks/vit_cifar.py:
--------------------------------------------------------------------------------
1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from einops import rearrange, repeat
7 | from einops.layers.torch import Rearrange
8 |
9 | # helpers
10 |
11 | def pair(t):
12 | return t if isinstance(t, tuple) else (t, t)
13 |
14 | # classes
15 |
16 | class PreNorm(nn.Module):
17 | def __init__(self, dim, fn):
18 | super().__init__()
19 | self.norm = nn.LayerNorm(dim)
20 | self.fn = fn
21 | def forward(self, x, **kwargs):
22 | return self.fn(self.norm(x), **kwargs)
23 |
24 | class FeedForward(nn.Module):
25 | def __init__(self, dim, hidden_dim, dropout = 0.):
26 | super().__init__()
27 | self.net = nn.Sequential(
28 | nn.Linear(dim, hidden_dim),
29 | nn.GELU(),
30 | nn.Dropout(dropout),
31 | nn.Linear(hidden_dim, dim),
32 | nn.Dropout(dropout)
33 | )
34 | def forward(self, x):
35 | return self.net(x)
36 |
37 | class Attention(nn.Module):
38 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
39 | super().__init__()
40 | inner_dim = dim_head * heads
41 | project_out = not (heads == 1 and dim_head == dim)
42 |
43 | self.heads = heads
44 | self.scale = dim_head ** -0.5
45 |
46 | self.attend = nn.Softmax(dim = -1)
47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
48 |
49 | self.to_out = nn.Sequential(
50 | nn.Linear(inner_dim, dim),
51 | nn.Dropout(dropout)
52 | ) if project_out else nn.Identity()
53 |
54 | def forward(self, x):
55 | qkv = self.to_qkv(x).chunk(3, dim = -1)
56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
57 |
58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
59 |
60 | attn = self.attend(dots)
61 |
62 | out = torch.matmul(attn, v)
63 | out = rearrange(out, 'b h n d -> b n (h d)')
64 | return self.to_out(out)
65 |
66 | class Transformer(nn.Module):
67 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
68 | super().__init__()
69 | self.layers = nn.ModuleList([])
70 | for _ in range(depth):
71 | self.layers.append(nn.ModuleList([
72 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
73 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
74 | ]))
75 | def forward(self, x):
76 | for attn, ff in self.layers:
77 | x = attn(x) + x
78 | x = ff(x) + x
79 | return x
80 |
81 | class ViTCIFAR(nn.Module):
82 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
83 | super().__init__()
84 | image_height, image_width = pair(image_size)
85 | patch_height, patch_width = pair(patch_size)
86 |
87 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
88 |
89 | num_patches = (image_height // patch_height) * (image_width // patch_width)
90 | patch_dim = channels * patch_height * patch_width
91 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
92 |
93 | self.to_patch_embedding = nn.Sequential(
94 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
95 | nn.Linear(patch_dim, dim),
96 | )
97 |
98 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
99 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
100 | self.dropout = nn.Dropout(emb_dropout)
101 |
102 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
103 |
104 | self.pool = pool
105 | self.to_latent = nn.Identity()
106 |
107 | self.mlp_head = nn.Sequential(
108 | nn.LayerNorm(dim),
109 | nn.Linear(dim, num_classes)
110 | )
111 |
112 | def forward(self, img):
113 | x = self.to_patch_embedding(img)
114 | b, n, _ = x.shape
115 |
116 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
117 | x = torch.cat((cls_tokens, x), dim=1)
118 | x += self.pos_embedding[:, :(n + 1)]
119 | x = self.dropout(x)
120 |
121 | x = self.transformer(x)
122 |
123 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
124 |
125 | x = self.to_latent(x)
126 | return self.mlp_head(x)
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/vit.py:
--------------------------------------------------------------------------------
1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from einops import rearrange, repeat
7 | from einops.layers.torch import Rearrange
8 |
9 | # helpers
10 |
11 | def pair(t):
12 | return t if isinstance(t, tuple) else (t, t)
13 |
14 | # classes
15 |
16 | class PreNorm(nn.Module):
17 | def __init__(self, dim, fn):
18 | super().__init__()
19 | # self.norm = nn.LayerNorm(dim)
20 | self.fn = fn
21 | def forward(self, x, **kwargs):
22 | # return self.fn(self.norm(x), **kwargs)
23 | return self.fn(x, **kwargs)
24 |
25 | class FeedForward(nn.Module):
26 | def __init__(self, dim, hidden_dim, dropout = 0.):
27 | super().__init__()
28 | self.net = nn.Sequential(
29 | nn.Linear(dim, hidden_dim),
30 | nn.GELU(),
31 | nn.Dropout(dropout),
32 | nn.Linear(hidden_dim, dim),
33 | nn.Dropout(dropout)
34 | )
35 | def forward(self, x):
36 | return self.net(x)
37 |
38 | class Attention(nn.Module):
39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
40 | super().__init__()
41 | inner_dim = dim_head * heads
42 | project_out = not (heads == 1 and dim_head == dim)
43 |
44 | self.heads = heads
45 | self.scale = dim_head ** -0.5
46 |
47 | self.attend = nn.Softmax(dim = -1)
48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
49 |
50 | self.to_out = nn.Sequential(
51 | nn.Linear(inner_dim, dim),
52 | nn.Dropout(dropout)
53 | ) if project_out else nn.Identity()
54 |
55 | def forward(self, x):
56 | qkv = self.to_qkv(x).chunk(3, dim = -1)
57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
58 |
59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
60 |
61 | attn = self.attend(dots)
62 |
63 | out = torch.matmul(attn, v)
64 | out = rearrange(out, 'b h n d -> b n (h d)')
65 | return self.to_out(out)
66 |
67 | class Transformer(nn.Module):
68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
69 | super().__init__()
70 | self.layers = nn.ModuleList([])
71 | for _ in range(depth):
72 | self.layers.append(nn.ModuleList([
73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
75 | ]))
76 | def forward(self, x):
77 | for attn, ff in self.layers:
78 | x = attn(x) + x
79 | x = ff(x) + x
80 | return x
81 |
82 | class ViT(nn.Module):
83 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
84 | super().__init__()
85 | image_height, image_width = pair(image_size)
86 | patch_height, patch_width = pair(patch_size)
87 |
88 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
89 |
90 | num_patches = (image_height // patch_height) * (image_width // patch_width)
91 | patch_dim = channels * patch_height * patch_width
92 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
93 |
94 | self.to_patch_embedding = nn.Sequential(
95 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
96 | nn.Linear(patch_dim, dim),
97 | )
98 |
99 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
100 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
101 | self.dropout = nn.Dropout(emb_dropout)
102 |
103 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
104 |
105 | self.pool = pool
106 | self.to_latent = nn.Identity()
107 |
108 | self.mlp_head = nn.Sequential(
109 | # nn.LayerNorm(dim),
110 | nn.Linear(dim, num_classes)
111 | )
112 |
113 | def forward(self, img):
114 | x = self.to_patch_embedding(img)
115 | b, n, _ = x.shape
116 |
117 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
118 | x = torch.cat((cls_tokens, x), dim=1)
119 | x += self.pos_embedding[:, :(n + 1)]
120 | x = self.dropout(x)
121 |
122 | x = self.transformer(x)
123 |
124 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
125 |
126 | x = self.to_latent(x)
127 |
128 | feat_fc = x
129 |
130 | x = self.mlp_head(x)
131 |
132 | return x
--------------------------------------------------------------------------------
/ImageNet-abcde/networks/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | ''' ResNet '''
5 |
6 | class BasicBlock(nn.Module):
7 | expansion = 1
8 |
9 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
10 | super(BasicBlock, self).__init__()
11 | self.norm = norm
12 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
13 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
15 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
16 |
17 | self.shortcut = nn.Sequential()
18 | if stride != 1 or in_planes != self.expansion*planes:
19 | self.shortcut = nn.Sequential(
20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
21 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
22 | )
23 |
24 | def forward(self, x):
25 | out = F.relu(self.bn1(self.conv1(x)))
26 | out = self.bn2(self.conv2(out))
27 | out += self.shortcut(x)
28 | out = F.relu(out)
29 | return out
30 |
31 |
32 | class Bottleneck(nn.Module):
33 | expansion = 4
34 |
35 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
36 | super(Bottleneck, self).__init__()
37 | self.norm = norm
38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
39 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
41 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
42 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
43 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
44 |
45 | self.shortcut = nn.Sequential()
46 | if stride != 1 or in_planes != self.expansion*planes:
47 | self.shortcut = nn.Sequential(
48 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
49 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
50 | )
51 |
52 | def forward(self, x):
53 | out = F.relu(self.bn1(self.conv1(x)))
54 | out = F.relu(self.bn2(self.conv2(out)))
55 | out = self.bn3(self.conv3(out))
56 | out += self.shortcut(x)
57 | out = F.relu(out)
58 | return out
59 |
60 |
61 | class ResNetCIFAR(nn.Module):
62 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
63 | super(ResNetCIFAR, self).__init__()
64 | self.in_planes = 64
65 | self.norm = norm
66 |
67 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
68 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
69 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
70 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
71 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
72 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
73 | self.classifier = nn.Linear(512*block.expansion, num_classes)
74 |
75 | def _make_layer(self, block, planes, num_blocks, stride):
76 | strides = [stride] + [1]*(num_blocks-1)
77 | layers = []
78 | for stride in strides:
79 | layers.append(block(self.in_planes, planes, stride, self.norm))
80 | self.in_planes = planes * block.expansion
81 | return nn.Sequential(*layers)
82 |
83 | def forward(self, x):
84 | out = F.relu(self.bn1(self.conv1(x)))
85 | out = self.layer1(out)
86 | out = self.layer2(out)
87 | out = self.layer3(out)
88 | out = self.layer4(out)
89 | out = F.avg_pool2d(out, 4)
90 | out = out.view(out.size(0), -1)
91 | out = self.classifier(out)
92 | return out
93 |
94 | def embed(self, x):
95 | out = F.relu(self.bn1(self.conv1(x)))
96 | out = self.layer1(out)
97 | out = self.layer2(out)
98 | out = self.layer3(out)
99 | out = self.layer4(out)
100 | out = F.avg_pool2d(out, 4)
101 | out = out.view(out.size(0), -1)
102 | return out
103 |
104 |
105 | def ResNet18BNCIFAR(channel, num_classes):
106 | return ResNetCIFAR(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm')
107 |
108 | def ResNet18CIFAR(channel, num_classes):
109 | return ResNetCIFAR(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes)
110 |
111 | def ResNet34CIFAR(channel, num_classes):
112 | return ResNetCIFAR(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes)
113 |
114 | def ResNet50CIFAR(channel, num_classes):
115 | return ResNetCIFAR(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes)
116 |
117 | def ResNet101CIFAR(channel, num_classes):
118 | return ResNetCIFAR(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes)
119 |
120 | def ResNet152CIFAR(channel, num_classes):
121 | return ResNetCIFAR(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes)
122 |
--------------------------------------------------------------------------------
/corruption-exp/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 |
4 | import argparse
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.utils.data
9 | import time
10 | import shutil
11 | from utils import set_seed, get_source_dataset, get_target_dataset, get_network, ParamDiffAug, get_daparam, save_and_print, TensorDataset, epoch, get_time
12 |
13 | np.set_printoptions(linewidth=250, suppress=True, precision=4)
14 |
15 | common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise',
16 | 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
17 | 'snow', 'frost', 'fog', 'brightness',
18 | 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
19 |
20 | def main(args):
21 |
22 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
23 | args.dsa_param = ParamDiffAug()
24 |
25 | if args.method == "DC":
26 | args.dsa = False
27 | else:
28 | args.dsa = True
29 |
30 | if args.dsa:
31 | args.dc_aug_param = None
32 | else:
33 | args.dc_aug_param = get_daparam(args.source_dataset, "", "", args.ipc)
34 |
35 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')
36 |
37 | ''' organize the source dataset '''
38 | dst_train, images_train, labels_train = get_source_dataset(args)
39 |
40 | if args.target_dataset[-1] == "C": # CIFAR-10-C / ImageNet-Subset-C
41 | total_accs = np.zeros((args.num_eval, len(common_corruptions)))
42 | else: # CIFAR-10.1
43 | total_accs = np.zeros(args.num_eval)
44 |
45 | for idx_eval in range(args.num_eval):
46 | ''' Train the network from scratch'''
47 | net = get_network(args).to(args.device)
48 | images_train = images_train.to(args.device)
49 | labels_train = labels_train.to(args.device)
50 | lr = float(args.lr_net)
51 | Epoch = int(args.epoch_eval_train)
52 | lr_schedule = [Epoch // 2 + 1]
53 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
54 |
55 | criterion = nn.CrossEntropyLoss().to(args.device)
56 |
57 | dst_train = TensorDataset(images_train, labels_train)
58 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
59 |
60 | start = time.time()
61 |
62 | for ep in range(Epoch + 1):
63 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=True)
64 | if ep in lr_schedule:
65 | lr *= 0.1
66 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
67 |
68 | time_train = time.time() - start
69 | save_and_print(args.log_path, '%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f \n' % (get_time(), idx_eval, Epoch, int(time_train), loss_train, acc_train))
70 |
71 | ''' Test on the target dataset '''
72 | optimizer = None
73 | criterion = nn.CrossEntropyLoss().to(args.device)
74 |
75 | if args.target_dataset[-1] == "C": # CIFAR-10-C / ImageNet-Subset-C
76 | for idx_corruption, corruption in enumerate(common_corruptions):
77 | args.corruption = corruption
78 | dst_test, testloader = get_target_dataset(args)
79 | with torch.no_grad():
80 | _, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False)
81 |
82 | total_accs[idx_eval, idx_corruption] = acc_test
83 | else: # CIFAR-10.1
84 | dst_test, testloader = get_target_dataset(args)
85 | with torch.no_grad():
86 | _, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False)
87 |
88 | total_accs[idx_eval] = acc_test
89 |
90 | save_and_print(args.log_path, '%s Evaluate_%02d: %s \n' % (get_time(), idx_eval, total_accs[idx_eval]))
91 | torch.save(net, f"{args.save_path}/net#{idx_eval}.pt")
92 |
93 | ''' Visualize and Save '''
94 | save_and_print(args.log_path , "=" * 100)
95 | save_and_print(args.log_path, f"Average across num_eval: {np.average(total_accs, axis=0)}")
96 | save_and_print(args.log_path, f"Average: {np.average(total_accs)}")
97 | torch.save(total_accs, f"{args.save_path}/total_accs.pt")
98 |
99 | if __name__ == "__main__":
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument('--method', type=str, default="FreD")
102 | parser.add_argument("--source_dataset", default='CIFAR10', type=str)
103 | parser.add_argument("--target_dataset", default='CIFAR10-C', type=str)
104 | parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. This only does anything when --source_dataset is an ImageNet subset')
105 | parser.add_argument("--level", default=1, type=int)
106 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
107 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
108 | parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models')
109 | parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
110 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
111 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
112 | parser.add_argument('--data_path', type=str, default='./data', help='dataset path')
113 | parser.add_argument('--save_path', type=str, default='./results', help='path to save results')
114 | parser.add_argument('--synset_path', type=str, default='./trained_synset', help='trained synthetic dataset path')
115 |
116 | parser.add_argument('--seed', type=int, default=0)
117 | parser.add_argument('--sh_file', type=str)
118 | parser.add_argument('--FLAG', type=str, default="TEST")
119 | args = parser.parse_args()
120 | set_seed(args.seed)
121 |
122 | if not os.path.exists(args.save_path):
123 | os.mkdir(args.save_path)
124 | args.save_path = args.save_path + f"/{args.FLAG}"
125 | if not os.path.exists(args.save_path):
126 | os.mkdir(args.save_path)
127 |
128 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}")
129 | args.log_path = f"{args.save_path}/log.txt"
130 |
131 | main(args)
132 |
--------------------------------------------------------------------------------
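After a run finishes, `main.py` above stores the accuracy matrix as `total_accs.pt` under the save directory. A small post-hoc analysis sketch, assuming the default `--save_path ./results` and `--FLAG TEST` and the torch version pinned in requirements.txt:

```python
import numpy as np
import torch

# total_accs was saved with torch.save as a NumPy array:
# shape (num_eval, 15) for CIFAR-10-C / ImageNet-Subset-C, or (num_eval,) for CIFAR-10.1.
total_accs = np.asarray(torch.load("./results/TEST/total_accs.pt"))

print("mean over evaluation runs:", total_accs.mean(axis=0))   # per-corruption means
print("overall mean accuracy:", total_accs.mean())
```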
/TM/buffer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 | from utils import get_dataset, get_network, get_daparam, TensorDataset, epoch, ParamDiffAug, set_seed, save_and_print
7 |
8 | import warnings
9 | warnings.filterwarnings("ignore", category=DeprecationWarning)
10 |
11 | def main(args):
12 | args.dsa = True if args.dsa == 'True' else False
13 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
14 | args.dsa_param = ParamDiffAug()
15 |
16 | save_dir = os.path.join(args.buffer_path, args.dataset)
17 | if args.dataset == "ImageNet":
18 | save_dir = os.path.join(save_dir, args.subset)
19 | if not args.zca:
20 | save_dir += "_NO_ZCA"
21 | save_dir = os.path.join(save_dir, args.model)
22 | if not os.path.exists(save_dir):
23 | os.makedirs(save_dir)
24 |
25 | args.log_path = f"{args.buffer_path}/log#{args.dataset}"
26 | if args.dataset == "ImageNet":
27 | args.log_path += f"_{args.subset}"
28 | if not args.zca:
29 | args.log_path += "_NO_ZCA"
30 | args.log_path += f"_{args.model}.txt"
31 |
32 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args)
33 |
34 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')
35 |
36 | ''' organize the real dataset '''
37 | images_all = []
38 | labels_all = []
39 | indices_class = [[] for c in range(num_classes)]
40 | save_and_print(args.log_path, "BUILDING DATASET")
41 | for i in tqdm(range(len(dst_train))):
42 | sample = dst_train[i]
43 | images_all.append(torch.unsqueeze(sample[0], dim=0))
44 | labels_all.append(class_map[torch.tensor(sample[1]).item()])
45 |
46 | for i, lab in tqdm(enumerate(labels_all)):
47 | indices_class[lab].append(i)
48 | images_all = torch.cat(images_all, dim=0).to("cpu")
49 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu")
50 |
51 | for c in range(num_classes):
52 | save_and_print(args.log_path, 'class c = %d: %d real images'%(c, len(indices_class[c])))
53 |
54 | for ch in range(channel):
55 | save_and_print(args.log_path, 'real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))
56 |
57 | criterion = nn.CrossEntropyLoss().to(args.device)
58 |
59 | trajectories = []
60 |
61 | dst_train = TensorDataset(images_all, labels_all)
62 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
63 |
64 | ''' set augmentation for whole-dataset training '''
65 | args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None)
66 | args.dc_aug_param['strategy'] = 'crop_scale_rotate' # for whole-dataset training
67 | save_and_print(args.log_path, f'DC augmentation parameters: {args.dc_aug_param}')
68 |
69 | for it in range(0, args.num_experts):
70 |
71 | ''' Train synthetic data '''
72 | teacher_net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
73 | teacher_net.train()
74 | lr = args.lr_teacher
75 | teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) # optimizer_img for synthetic data
76 | teacher_optim.zero_grad()
77 |
78 | timestamps = []
79 |
80 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()])
81 |
82 | lr_schedule = [args.train_epochs // 2 + 1]
83 |
84 | for e in range(args.train_epochs):
85 |
86 | train_loss, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim,
87 | criterion=criterion, args=args, aug=True)
88 |
89 | test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None,
90 | criterion=criterion, args=args, aug=False)
91 |
92 | save_and_print(args.log_path, "Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc))
93 |
94 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()])
95 |
96 | if e in lr_schedule and args.decay:
97 | lr *= 0.1
98 | teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
99 | teacher_optim.zero_grad()
100 |
101 | trajectories.append(timestamps)
102 |
103 | if len(trajectories) == args.save_interval:
104 | n = 0
105 | while os.path.exists(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))):
106 | n += 1
107 | save_and_print(args.log_path, "Saving {}".format(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))))
108 | torch.save(trajectories, os.path.join(save_dir, "replay_buffer_{}.pt".format(n)))
109 | trajectories = []
110 |
111 |
112 | if __name__ == '__main__':
113 | parser = argparse.ArgumentParser(description='Parameter Processing')
114 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
115 | parser.add_argument('--subset', type=str, default='imagenette', help='subset')
116 | parser.add_argument('--model', type=str, default='ConvNet', help='model')
117 | parser.add_argument('--num_experts', type=int, default=100, help='number of expert trajectories to train')
118 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='learning rate for updating network parameters')
119 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
120 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real loader')
121 | parser.add_argument('--batch_test', type=int, default=128, help='batch size for test loader')
122 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], help='whether to use differentiable Siamese augmentation.')
123 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
124 | parser.add_argument('--data_path', type=str, default='../data', help='dataset path')
125 | parser.add_argument('--buffer_path', type=str, default='../buffers', help='buffer path')
126 | parser.add_argument('--train_epochs', type=int, default=50)
127 | parser.add_argument('--zca', action='store_true')
128 | parser.add_argument('--decay', action='store_true')
129 | parser.add_argument('--mom', type=float, default=0, help='momentum')
130 | parser.add_argument('--l2', type=float, default=0, help='l2 regularization')
131 | parser.add_argument('--save_interval', type=int, default=10)
132 |
133 | parser.add_argument('--res', type=str, default='')
134 |
135 | parser.add_argument('--seed', type=int, default=0)
136 | args = parser.parse_args()
137 | set_seed(args.seed)
138 |
139 | main(args)
140 |
141 |
142 |
--------------------------------------------------------------------------------
/ImageNet-abcde/reparam_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | import types
5 | from collections import namedtuple
6 | from contextlib import contextmanager
7 |
8 |
9 | class ReparamModule(nn.Module):
10 | def _get_module_from_name(self, mn):
11 | if mn == '':
12 | return self
13 | m = self
14 | for p in mn.split('.'):
15 | m = getattr(m, p)
16 | return m
17 |
18 | def __init__(self, module):
19 | super(ReparamModule, self).__init__()
20 | self.module = module
21 |
22 | param_infos = [] # (module name/path, param name)
23 | shared_param_memo = {}
24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name)
25 | params = []
26 | param_numels = []
27 | param_shapes = []
28 | for mn, m in self.named_modules():
29 | for n, p in m.named_parameters(recurse=False):
30 | if p is not None:
31 | if p in shared_param_memo:
32 | shared_mn, shared_n = shared_param_memo[p]
33 | shared_param_infos.append((mn, n, shared_mn, shared_n))
34 | else:
35 | shared_param_memo[p] = (mn, n)
36 | param_infos.append((mn, n))
37 | params.append(p.detach())
38 | param_numels.append(p.numel())
39 | param_shapes.append(p.size())
40 |
41 | assert len(set(p.dtype for p in params)) <= 1, \
42 | "expects all parameters in module to have same dtype"
43 |
44 | # store the info for unflatten
45 | self._param_infos = tuple(param_infos)
46 | self._shared_param_infos = tuple(shared_param_infos)
47 | self._param_numels = tuple(param_numels)
48 | self._param_shapes = tuple(param_shapes)
49 |
50 | # flatten
51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
52 | self.register_parameter('flat_param', flat_param)
53 | self.param_numel = flat_param.numel()
54 | del params
55 | del shared_param_memo
56 |
57 | # deregister the names as parameters
58 | for mn, n in self._param_infos:
59 | delattr(self._get_module_from_name(mn), n)
60 | for mn, n, _, _ in self._shared_param_infos:
61 | delattr(self._get_module_from_name(mn), n)
62 |
63 | # register the views as plain attributes
64 | self._unflatten_param(self.flat_param)
65 |
66 | # now buffers
67 | # they are not reparametrized. just store info as (module, name, buffer)
68 | buffer_infos = []
69 | for mn, m in self.named_modules():
70 | for n, b in m.named_buffers(recurse=False):
71 | if b is not None:
72 | buffer_infos.append((mn, n, b))
73 |
74 | self._buffer_infos = tuple(buffer_infos)
75 | self._traced_self = None
76 |
77 | def trace(self, example_input, **trace_kwargs):
78 | assert self._traced_self is None, 'This ReparamModule is already traced'
79 |
80 | if isinstance(example_input, torch.Tensor):
81 | example_input = (example_input,)
82 | example_input = tuple(example_input)
83 | example_param = (self.flat_param.detach().clone(),)
84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),)
85 |
86 | self._traced_self = torch.jit.trace_module(
87 | self,
88 | inputs=dict(
89 | _forward_with_param=example_param + example_input,
90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input,
91 | ),
92 | **trace_kwargs,
93 | )
94 |
95 | # replace forwards with traced versions
96 | self._forward_with_param = self._traced_self._forward_with_param
97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers
98 | return self
99 |
100 | def clear_views(self):
101 | for mn, n in self._param_infos:
102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr
103 |
104 | def _apply(self, *args, **kwargs):
105 | if self._traced_self is not None:
106 | self._traced_self._apply(*args, **kwargs)
107 | return self
108 | return super(ReparamModule, self)._apply(*args, **kwargs)
109 |
110 | def _unflatten_param(self, flat_param):
111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
112 | for (mn, n), p in zip(self._param_infos, ps):
113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr
114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
116 |
117 | @contextmanager
118 | def unflattened_param(self, flat_param):
119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos]
120 | self._unflatten_param(flat_param)
121 | yield
122 | # Why not just `self._unflatten_param(self.flat_param)`?
123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583
124 | # 2. slightly faster since it does not require reconstruct the split+view
125 | # graph
126 | for (mn, n), p in zip(self._param_infos, saved_views):
127 | setattr(self._get_module_from_name(mn), n, p)
128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
130 |
131 | @contextmanager
132 | def replaced_buffers(self, buffers):
133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers):
134 | setattr(self._get_module_from_name(mn), n, new_b)
135 | yield
136 | for mn, n, old_b in self._buffer_infos:
137 | setattr(self._get_module_from_name(mn), n, old_b)
138 |
139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs):
140 | with self.unflattened_param(flat_param):
141 | with self.replaced_buffers(buffers):
142 | return self.module(*inputs, **kwinputs)
143 |
144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs):
145 | with self.unflattened_param(flat_param):
146 | return self.module(*inputs, **kwinputs)
147 |
148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs):
149 |         # check for None before squeezing; torch.squeeze(None) would raise a TypeError
150 |         if flat_param is None:
151 |             flat_param = self.flat_param
152 |         flat_param = torch.squeeze(flat_param)
153 |         if buffers is None:
154 |             return self._forward_with_param(flat_param, *inputs, **kwinputs)
155 |         else:
156 |             return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs)
--------------------------------------------------------------------------------
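The point of `ReparamModule` is that the wrapped network is driven by one flattened parameter vector, so a trajectory-matching loss can be differentiated with respect to that vector directly. A minimal usage sketch on a hypothetical toy network (not part of the repo), assuming this file is importable as `reparam_module`:

```python
import torch
import torch.nn as nn
from reparam_module import ReparamModule

# Hypothetical toy network; any nn.Module can be wrapped.
net = ReparamModule(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2)))
x = torch.randn(4, 8)

theta = net.flat_param.detach().clone().requires_grad_(True)   # single flat vector of all weights
out = net(x, flat_param=theta)                                  # forward with the supplied parameters
loss = out.pow(2).mean()
grad, = torch.autograd.grad(loss, theta)                        # gradient w.r.t. the flat vector
print(net.param_numel, grad.shape)                              # one gradient entry per network parameter
```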
/TM/reparam_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | import types
5 | from collections import namedtuple
6 | from contextlib import contextmanager
7 |
8 |
9 | class ReparamModule(nn.Module):
10 | def _get_module_from_name(self, mn):
11 | if mn == '':
12 | return self
13 | m = self
14 | for p in mn.split('.'):
15 | m = getattr(m, p)
16 | return m
17 |
18 | def __init__(self, module):
19 | super(ReparamModule, self).__init__()
20 | self.module = module
21 |
22 | param_infos = [] # (module name/path, param name)
23 | shared_param_memo = {}
24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name)
25 | params = []
26 | param_numels = []
27 | param_shapes = []
28 | for mn, m in self.named_modules():
29 | for n, p in m.named_parameters(recurse=False):
30 | if p is not None:
31 | if p in shared_param_memo:
32 | shared_mn, shared_n = shared_param_memo[p]
33 | shared_param_infos.append((mn, n, shared_mn, shared_n))
34 | else:
35 | shared_param_memo[p] = (mn, n)
36 | param_infos.append((mn, n))
37 | params.append(p.detach())
38 | param_numels.append(p.numel())
39 | param_shapes.append(p.size())
40 |
41 | assert len(set(p.dtype for p in params)) <= 1, \
42 | "expects all parameters in module to have same dtype"
43 |
44 | # store the info for unflatten
45 | self._param_infos = tuple(param_infos)
46 | self._shared_param_infos = tuple(shared_param_infos)
47 | self._param_numels = tuple(param_numels)
48 | self._param_shapes = tuple(param_shapes)
49 |
50 | # flatten
51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
52 | self.register_parameter('flat_param', flat_param)
53 | self.param_numel = flat_param.numel()
54 | del params
55 | del shared_param_memo
56 |
57 | # deregister the names as parameters
58 | for mn, n in self._param_infos:
59 | delattr(self._get_module_from_name(mn), n)
60 | for mn, n, _, _ in self._shared_param_infos:
61 | delattr(self._get_module_from_name(mn), n)
62 |
63 | # register the views as plain attributes
64 | self._unflatten_param(self.flat_param)
65 |
66 | # now buffers
67 | # they are not reparametrized. just store info as (module, name, buffer)
68 | buffer_infos = []
69 | for mn, m in self.named_modules():
70 | for n, b in m.named_buffers(recurse=False):
71 | if b is not None:
72 | buffer_infos.append((mn, n, b))
73 |
74 | self._buffer_infos = tuple(buffer_infos)
75 | self._traced_self = None
76 |
77 | def trace(self, example_input, **trace_kwargs):
78 | assert self._traced_self is None, 'This ReparamModule is already traced'
79 |
80 | if isinstance(example_input, torch.Tensor):
81 | example_input = (example_input,)
82 | example_input = tuple(example_input)
83 | example_param = (self.flat_param.detach().clone(),)
84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),)
85 |
86 | self._traced_self = torch.jit.trace_module(
87 | self,
88 | inputs=dict(
89 | _forward_with_param=example_param + example_input,
90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input,
91 | ),
92 | **trace_kwargs,
93 | )
94 |
95 | # replace forwards with traced versions
96 | self._forward_with_param = self._traced_self._forward_with_param
97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers
98 | return self
99 |
100 | def clear_views(self):
101 | for mn, n in self._param_infos:
102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr
103 |
104 | def _apply(self, *args, **kwargs):
105 | if self._traced_self is not None:
106 | self._traced_self._apply(*args, **kwargs)
107 | return self
108 | return super(ReparamModule, self)._apply(*args, **kwargs)
109 |
110 | def _unflatten_param(self, flat_param):
111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
112 | for (mn, n), p in zip(self._param_infos, ps):
113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr
114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
116 |
117 | @contextmanager
118 | def unflattened_param(self, flat_param):
119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos]
120 | self._unflatten_param(flat_param)
121 | yield
122 | # Why not just `self._unflatten_param(self.flat_param)`?
123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583
124 | # 2. slightly faster since it does not require reconstruct the split+view
125 | # graph
126 | for (mn, n), p in zip(self._param_infos, saved_views):
127 | setattr(self._get_module_from_name(mn), n, p)
128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
130 |
131 | @contextmanager
132 | def replaced_buffers(self, buffers):
133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers):
134 | setattr(self._get_module_from_name(mn), n, new_b)
135 | yield
136 | for mn, n, old_b in self._buffer_infos:
137 | setattr(self._get_module_from_name(mn), n, old_b)
138 |
139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs):
140 | with self.unflattened_param(flat_param):
141 | with self.replaced_buffers(buffers):
142 | return self.module(*inputs, **kwinputs)
143 |
144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs):
145 | with self.unflattened_param(flat_param):
146 | return self.module(*inputs, **kwinputs)
147 |
148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs):
149 |         # check for None before squeezing; torch.squeeze(None) would raise a TypeError
150 |         if flat_param is None:
151 |             flat_param = self.flat_param
152 |         flat_param = torch.squeeze(flat_param)
153 |         if buffers is None:
154 |             return self._forward_with_param(flat_param, *inputs, **kwinputs)
155 |         else:
156 |             return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Frequency Domain-based Dataset Distillation (FreD) [NeurIPS 2023]
2 |
3 | This repository contains an official PyTorch implementation for the paper "Frequency Domain-based Dataset Distillation" in NeurIPS 2023.
4 |
5 | **[Donghyeok Shin](http://kaal.dsso.kr/bbs/board.php?bo_table=sub2_1&wr_id=8) \*, [Seungjae Shin](https://sites.google.com/view/seungjae-shin) \*, and Il-Chul Moon**
6 | * Equal contribution
7 |
8 | | [paper](https://arxiv.org/abs/2311.08819) | [slides](https://neurips.cc/media/neurips-2023/Slides/71874.pdf) | [pretrained](https://drive.google.com/drive/folders/1r1OMVv9llejGmpHfK5DpW4m57Dz_SZ2n?usp=sharing) |
9 |
10 | ## Updates
11 | - (2024.05.07) We uploaded the distilled synthetic datasets except for a few cases. Please refer to [pretrained](https://drive.google.com/drive/folders/1r1OMVv9llejGmpHfK5DpW4m57Dz_SZ2n?usp=sharing). The remaining cases will be uploaded as soon as possible.
12 |
13 | ## Overview
14 | 
15 | > **Abstract** *This paper presents FreD, a novel parameterization method for dataset distillation, which utilizes the frequency domain to distill a small-sized synthetic dataset from a large-sized original dataset. Unlike conventional approaches that focus on the spatial domain, FreD employs frequency-based transforms to optimize the frequency representations of each data instance. By leveraging the concentration of spatial domain information on specific frequency components, FreD intelligently selects a subset of frequency dimensions for optimization, leading to a significant reduction in the required budget for synthesizing an instance. Through the selection of frequency dimensions based on the explained variance, FreD demonstrates both theoretical and empirical evidence of its ability to operate efficiently within a limited budget, while better preserving the information of the original dataset compared to conventional parameterization methods. Furthermore, based on the orthogonal compatibility of FreD with existing methods, we confirm that FreD consistently improves the performances of existing distillation methods over the evaluation scenarios with different benchmark datasets.*
16 |
17 | ## Requirements
18 | This code was tested with CUDA 11.4 and Python 3.8.
19 | ```
20 | pip install -r requirements.txt
21 | ```
22 |
23 | ## Usage
24 | The main hyper-parameters of FreD are as follows:
25 | - `msz_per_channel` : Memory size per channel
26 | - `lr_freq` : Learning rate for synthetic frequency representation
27 | - `mom_freq` : Momentum for synthetic frequency representation
28 |
29 | The detailed values of these hyper-parameters can be found in our paper.
30 | For the other hyper-parameters, we follow the default settings of each dataset distillation objective.
31 | Please refer to the bash file for detailed arguments to run the experiment.
32 |
33 | Below are some example commands to run FreD with each dataset distillation objective.
34 | ### FreD with Gradient Matching (DC)
35 | - Run the following command:
36 | ```
37 | cd DC/scripts
38 | sh run_DC_FreD.sh
39 | ```
40 |
41 | ### FreD with Distribution Matching (DM)
42 | - Run the following command:
43 | ```
44 | cd DM/scripts
45 | sh run_DM_FreD.sh
46 | ```
47 |
48 | ### FreD with Trajectory Matching (TM)
49 | - Since TM needs expert trajectories, run `run_buffer.sh` to generate them before distillation:
50 | ```
51 | cd TM/scripts
52 | sh run_buffer.sh
53 | sh run_TM_FreD.sh
54 | ```
55 |
56 | ### FreD with Other Dataset Distillation Objective
57 | FreD is a highly compatible parameterization method regardless of the dataset distillation objective.
58 | Herein, we provide a simple guideline on how to use FreD with different dataset distillation objectives.
59 | ```
60 | # Define the frequency domain-based parameterization
61 | synset = SynSet(args)
62 |
63 | # Initialization
64 | synset.init(images_all, labels_all, indices_class)
65 |
66 | # Get partial synthetic dataset
67 | images_syn, labels_syn = synset.get(indices=indices)
68 |
69 | # Get entire dataset (need_copy is optional)
70 | images_syn, labels_syn = synset.get(need_copy=True)
71 | ```
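For a concrete sense of where these calls sit in a host method, here is a hedged sketch of dropping FreD into a distribution-matching style loop. Only `SynSet(args)`, `synset.init(...)`, and `synset.get(...)` come from the guideline above; `opt_freq` (an optimizer over the synthetic frequency coefficients), `get_network`, `get_images`, and `net.embed` are placeholders for utilities the host distillation code already provides.
```
synset = SynSet(args)
synset.init(images_all, labels_all, indices_class)

for it in range(args.Iteration):
    net = get_network(args).to(args.device)        # freshly initialized embedding network
    images_syn, labels_syn = synset.get()          # decode the current synthetic dataset
    loss = 0.0
    for c in range(num_classes):
        img_real = get_images(images_all, indices_class, c, args.batch_real).to(args.device)
        emb_real = net.embed(img_real).mean(dim=0)
        emb_syn = net.embed(images_syn[labels_syn == c]).mean(dim=0)
        loss = loss + ((emb_real - emb_syn) ** 2).sum()
    opt_freq.zero_grad()
    loss.backward()                                # gradients flow back into the frequency coefficients
    opt_freq.step()
```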
72 | ### ImageNet-[A,B,C,D,E]
73 | - For the ImageNet-[A,B,C,D,E] experiments, we built upon GLaD's official code.
74 | - Run the following command with the appropriate objective substituted for `XXX`:
75 |   - If you want to run FreD with TM, run `run_buffer.sh` first, just like above.
76 | ```
77 | cd ImageNet-abcde/scripts
78 | sh run_XXX_FreD.sh
79 | ```
80 |
81 | ### 3D-MNIST
82 | - Download [3D-MNIST](https://www.kaggle.com/datasets/daavoo/3d-mnist) dataset.
83 | - Run the following command:
84 | ```
85 | cd 3D-MNIST/scripts
86 | sh run_DM_FreD.sh
87 | ```
88 |
89 | ### Robustness to Corruption
90 | - Download the corrupted datasets: [CIFAR-10.1](https://github.com/modestyachts/CIFAR-10.1), [CIFAR-10-C](https://zenodo.org/records/2535967), and [ImageNet-C](https://zenodo.org/records/2235448#.YpCSLxNBxAc).
91 | - Place the trained synthetic dataset at `corruption-exp/trained_synset/FreD/{dataset_name}`.
92 | - Run the following command:
93 | ```
94 | cd corruption-exp/scripts
95 | sh run.sh
96 | ```
97 |
98 | ## Experiment Results
99 | - Test accuracies (%) on low-dimensional datasets (≤ 64×64 resolution) with TM.
100 |
101 | | | MNIST | FashionMNIST | SVHN | CIFAR-10 | CIFAR-100 | Tiny-ImageNet |
102 | | :------: | :-----: | :----: | :-----: | :----: | :----: | :----: |
103 | | 1 img/cls | 95.8 | 84.6 | 82.2 | 60.6 | 34.6 | 19.2 |
104 | | 10 img/cls | 97.6 | 89.1 | 89.5 | 70.3 | 42.7 | 24.2 |
105 | | 50 img/cls | - | - | 90.3 | 75.8 | 47.8 | 26.4 |
106 |
107 | - Test accuracies (%) on Image-[Nette, Woof, Fruit, Yellow, Meow, Squawk] (128 × 128 resolution) with TM.
108 |
109 | | | ImageNette | ImageWoof | ImageFruit | ImageYellow | ImageMeow | ImageSquawk |
110 | | :------: | :-----: | :----: | :-----: | :----: | :----: | :----: |
111 | | 1 img/cls | 66.8 | 38.3 | 43.7 | 63.2 | 43.2 | 57.0 |
112 | | 10 img/cls | 72.0 | 41.3 | 47.0 | 69.2 | 48.6 | 67.3 |
113 |
114 | - Test accuracies (%) on ImageNet-[A, B, C, D, E] (128 × 128 resolution) with TM under 1 img/cls.
115 |
116 | | | ImageNet-A | ImageNet-B | ImageNet-C | ImageNet-D | ImageNet-E |
117 | | :------: | :-----: | :----: | :-----: | :----: | :----: |
118 | | DC w/ FreD | 53.1 | 54.8 | 54.2 | 42.8 | 41.0 |
119 | | DM w/ FreD | 58.0 | 58.6 | 55.6 | 46.3 | 45.0 |
120 | | TM w/ FreD | 67.7 | 69.3 | 63.6 | 54.4 | 55.4 |
121 |
122 | More results can be found in our paper.
123 |
124 | ## Citation
125 | If you find the code useful for your research, please consider citing our paper.
126 | ```bib
127 | @inproceedings{shin2023frequency,
128 | title={Frequency Domain-Based Dataset Distillation},
129 | author={Shin, DongHyeok and Shin, Seungjae and Moon, Il-chul},
130 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
131 | year={2023}
132 | }
133 | ```
134 | This work is heavily built upon the code from
135 | - *Bo Zhao, Konda Reddy Mopuri, and Hakan Bilen. Dataset condensation with gradient matching. arXiv preprint arXiv:2006.05929, 2020.* [Code Link](https://github.com/VICO-UoE/DatasetCondensation)
136 | - *Bo Zhao and Hakan Bilen. Dataset condensation with distribution matching. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 6514–6523, 2023.* [Code Link](https://github.com/VICO-UoE/DatasetCondensation)
137 | - *George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A Efros, and Jun-Yan Zhu. Dataset distillation by matching training trajectories. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 4750–4759, 2022.* [Code Link](https://github.com/georgecazenavette/mtt-distillation)
138 | - *George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A Efros, and Jun-Yan Zhu. Generalizing dataset distillation via deep generative prior. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 3739–3748, 2023.* [Code Link](https://github.com/GeorgeCazenavette/glad/tree/main)
139 |
--------------------------------------------------------------------------------
/3D-MNIST/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset
8 | from torchvision import datasets, transforms
9 | from scipy.ndimage import rotate as scipyrotate  # scipy.ndimage.interpolation is deprecated
10 | from networks import Conv3DNet
11 |
12 | import tqdm
13 |
14 | import random
15 | import h5py
16 |
17 | def set_seed(seed):
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | np.random.seed(seed)
22 | random.seed(seed)
23 | torch.backends.cudnn.deterministic = True
24 | torch.backends.cudnn.benchmark = False
25 |
26 | def save_and_print(dirname, msg):
27 | if not os.path.isfile(dirname):
28 | f = open(dirname, "w")
29 | f.write(str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())))
30 | f.write("\n")
31 | f.close()
32 | f = open(dirname, "a")
33 | f.write(str(msg))
34 | f.write("\n")
35 | f.close()
36 | print(msg)
37 |
38 | def get_images(images_all, indices_class, c, n): # get random n images from class c
39 | idx_shuffle = np.random.permutation(indices_class[c])[:n]
40 | return images_all[idx_shuffle]
41 |
42 | def prepare_points(tensor, data_config, threshold = False):
43 | if threshold:
44 | tensor = np.where(
45 | tensor > data_config.threshold,
46 | data_config.lower,
47 | data_config.upper
48 | )
49 | tensor = tensor.reshape((
50 | tensor.shape[0],
51 | data_config.y_shape,
52 | data_config.x_shape,
53 | data_config.z_shape
54 | ))
55 | return tensor
56 |
57 | class data_config:
58 | threshold = 0.2
59 | upper = 1
60 | lower = 0
61 | x_shape = 16
62 | y_shape = 16
63 | z_shape = 16
64 |
65 | def get_dataset(data_path, args=None):
66 | channel = 1
67 | im_size = (16, 16, 16)
68 | num_classes = 10
69 | class_names = [str(c) for c in range(num_classes)]
70 |
71 | full_vectors = os.path.join(data_path, '3D-MNIST', 'full_dataset_vectors.h5')
72 |
73 | with h5py.File(full_vectors, "r") as hf:
74 | X_train = hf['X_train'][:]
75 | X_test = hf['X_test'][:]
76 | y_train = hf['y_train'][:]
77 | y_test = hf['y_test'][:]
78 |
79 | X_train = prepare_points(X_train, data_config)
80 | X_test = prepare_points(X_test, data_config)
81 |
82 | X_train, X_test = torch.tensor(X_train, dtype=torch.float, device=args.device), torch.tensor(X_test, dtype=torch.float, device=args.device)
83 | X_train, X_test = torch.unsqueeze(X_train, dim=1), torch.unsqueeze(X_test, dim=1)
84 | y_train, y_test = torch.tensor(y_train, dtype=torch.long, device=args.device), torch.tensor(y_test, dtype=torch.long, device=args.device)
85 |
86 | dst_train = TensorDataset(X_train, y_train) # no augmentation
87 | dst_test = TensorDataset(X_test, y_test)
88 | testloader = torch.utils.data.DataLoader(dst_test, shuffle=False, batch_size=256, num_workers=0)
89 |
90 | return channel, im_size, num_classes, class_names, dst_train, dst_test, testloader
91 |
92 |
93 | class TensorDataset(Dataset):
94 |     def __init__(self, images, labels): # images: n x c x d x h x w tensor of 3D volumes
95 | self.images = images.detach().float()
96 | self.labels = labels.detach()
97 |
98 | def __getitem__(self, index):
99 | return self.images[index], self.labels[index]
100 |
101 | def __len__(self):
102 | return self.images.shape[0]
103 |
104 |
105 | def get_network(model, channel, num_classes, im_size=(32, 32)):
106 | #torch.random.manual_seed(int(time.time() * 1000) % 100000)
107 | net_width, net_depth = 64, 3
108 |
109 | if model == 'Conv3DNet':
110 | net = Conv3DNet(channel, num_classes, net_width, net_depth, im_size)
111 | else:
112 | net = None
113 | exit('unknown model: %s'%model)
114 |
115 | gpu_num = torch.cuda.device_count()
116 | if gpu_num>0:
117 | device = 'cuda'
118 | if gpu_num>1:
119 | net = nn.DataParallel(net)
120 | else:
121 | device = 'cpu'
122 | net = net.to(device)
123 |
124 | return net
125 |
126 | def get_time():
127 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
128 |
129 | def get_loops(ipc):
130 | # Get the two hyper-parameters of outer-loop and inner-loop.
131 | # The following values are empirically good.
132 | if ipc == 1:
133 | outer_loop, inner_loop = 1, 1
134 | elif ipc == 10:
135 | outer_loop, inner_loop = 10, 50
136 | elif ipc == 20:
137 | outer_loop, inner_loop = 20, 25
138 | elif ipc == 30:
139 | outer_loop, inner_loop = 30, 20
140 | elif ipc == 40:
141 | outer_loop, inner_loop = 40, 15
142 | elif ipc == 50:
143 | outer_loop, inner_loop = 50, 10
144 |
145 | elif ipc == 2:
146 | outer_loop, inner_loop = 1, 1
147 | elif ipc == 11:
148 | outer_loop, inner_loop = 10, 50
149 | elif ipc == 51:
150 | outer_loop, inner_loop = 50, 10
151 |
152 | else:
153 | outer_loop, inner_loop = 0, 0
154 | exit('loop hyper-parameters are not defined for %d ipc'%ipc)
155 | return outer_loop, inner_loop
156 |
157 |
158 | def epoch(mode, dataloader, net, optimizer, criterion, args, aug):
159 | loss_avg, acc_avg, num_exp = 0, 0, 0
160 | net = net.to(args.device)
161 | criterion = criterion.to(args.device)
162 |
163 | if mode == 'train':
164 | net.train()
165 | else:
166 | net.eval()
167 |
168 | for i_batch, datum in enumerate(dataloader):
169 | img = datum[0].float().to(args.device)
170 | lab = datum[1].long().to(args.device)
171 | n_b = lab.shape[0]
172 |
173 | output = net(img)
174 | loss = criterion(output, lab)
175 | acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy()))
176 |
177 | loss_avg += loss.item()*n_b
178 | acc_avg += acc
179 | num_exp += n_b
180 |
181 | if mode == 'train':
182 | optimizer.zero_grad()
183 | loss.backward()
184 | optimizer.step()
185 |
186 | loss_avg /= num_exp
187 | acc_avg /= num_exp
188 |
189 | return loss_avg, acc_avg
190 |
191 |
192 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args):
193 | net = net.to(args.device)
194 | images_train = images_train.to(args.device)
195 | labels_train = labels_train.to(args.device)
196 | lr = float(args.lr_net)
197 | Epoch = int(args.epoch_eval_train)
198 | lr_schedule = [Epoch//2+1]
199 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
200 | criterion = nn.CrossEntropyLoss().to(args.device)
201 |
202 | dst_train = TensorDataset(images_train, labels_train)
203 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
204 |
205 | start = time.time()
206 | for ep in tqdm.tqdm(range(Epoch+1)):
207 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug = False)
208 | if ep in lr_schedule:
209 | lr *= 0.1
210 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
211 |
212 | time_train = time.time() - start
213 | loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug = False)
214 | save_and_print(args.log_path, '%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test))
215 |
216 | return net, acc_train, acc_test
217 |
--------------------------------------------------------------------------------
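
A minimal usage sketch for the evaluation helpers above (not part of the repository): it wires a random Conv3DNet, a synthetic tensor, and a dummy test loader through evaluate_synset. The 16x16x16 voxel resolution, the dummy data, and the fields set on args are assumptions for illustration; only the fields these functions actually read (device, lr_net, epoch_eval_train, batch_train, log_path) are provided.

    import argparse
    import torch
    from utils import get_network, evaluate_synset, TensorDataset  # this file, importable as `utils`

    args = argparse.Namespace(device='cuda' if torch.cuda.is_available() else 'cpu',
                              lr_net=0.01, epoch_eval_train=1, batch_train=64,
                              log_path='./eval_log.txt')               # assumed log file location
    images_syn = torch.randn(10, 1, 16, 16, 16)    # one dummy volume per class; 16x16x16 is an assumed resolution
    labels_syn = torch.arange(10)
    testloader = torch.utils.data.DataLoader(
        TensorDataset(torch.randn(100, 1, 16, 16, 16), torch.randint(0, 10, (100,))),  # dummy test split
        batch_size=args.batch_train)
    net = get_network('Conv3DNet', 1, 10, (16, 16, 16))
    _, acc_train, acc_test = evaluate_synset(0, net, images_syn, labels_syn, testloader, args)
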
/ImageNet-abcde/networks/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | ''' ResNet '''
6 |
7 | LAYERS = {'resnet18': [2, 2, 2, 2],
8 | 'resnet34': [3, 4, 6, 3],
9 | 'resnet50': [3, 4, 6, 3],
10 | 'resnet101': [3, 4, 23, 3],
11 | 'resnet152': [3, 8, 36, 3],
12 | 'resnet20': [3, 3, 3],
13 | 'resnet32': [5, 5, 5],
14 | 'resnet44': [7, 7, 7],
15 | 'resnet56': [9, 9, 9],
16 | 'resnet110': [18, 18, 18],
17 | }
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
24 | super(BasicBlock, self).__init__()
25 | self.norm = norm
26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
27 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
28 | planes) if self.norm == 'batch' else nn.Identity()
29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
30 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
31 | planes) if self.norm == 'batch' else nn.Identity()
32 |
33 | self.shortcut = nn.Sequential()
34 | if stride != 1 or in_planes != self.expansion * planes:
35 | self.shortcut = nn.Sequential(
36 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
37 | nn.GroupNorm(self.expansion * planes, self.expansion * planes,
38 | affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
39 | self.expansion * planes) if self.norm == 'batch' else nn.Identity()
40 | )
41 |
42 | def forward(self, x):
43 | out = F.relu(self.bn1(self.conv1(x)))
44 | out = self.bn2(self.conv2(out))
45 | out += self.shortcut(x)
46 | out = F.relu(out)
47 | return out
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | expansion = 4
52 |
53 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
54 | super(Bottleneck, self).__init__()
55 | self.norm = norm
56 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
57 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
58 | planes) if self.norm == 'batch' else nn.Identity()
59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
60 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
61 | planes) if self.norm == 'batch' else nn.Identity()
62 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
63 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes,
64 | affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
65 | self.expansion * planes) if self.norm == 'batch' else nn.Identity()
66 |
67 | self.shortcut = nn.Sequential()
68 | if stride != 1 or in_planes != self.expansion * planes:
69 | self.shortcut = nn.Sequential(
70 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
71 | nn.GroupNorm(self.expansion * planes, self.expansion * planes,
72 | affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
73 | self.expansion * planes) if self.norm == 'batch' else nn.Identity()
74 | )
75 |
76 | def forward(self, x):
77 | out = F.relu(self.bn1(self.conv1(x)))
78 | out = F.relu(self.bn2(self.conv2(out)))
79 | out = self.bn3(self.conv3(out))
80 | out += self.shortcut(x)
81 | out = F.relu(out)
82 | return out
83 |
84 |
85 | class ResNet(nn.Module):
86 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
87 | super(ResNet, self).__init__()
88 | self.in_planes = 64
89 | self.norm = norm
90 |
91 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
92 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
93 | 64) if self.norm == 'batch' else nn.Identity()
94 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
95 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
96 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
97 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
98 | self.pool = nn.AdaptiveAvgPool2d((1, 1))
99 | self.classifier = nn.Linear(512 * block.expansion, num_classes)
100 |
101 | def _make_layer(self, block, planes, num_blocks, stride):
102 | strides = [stride] + [1] * (num_blocks - 1)
103 | layers = []
104 | for stride in strides:
105 | layers.append(block(self.in_planes, planes, stride, self.norm))
106 | self.in_planes = planes * block.expansion
107 | return nn.Sequential(*layers)
108 |
109 | def forward(self, x):
110 | out = F.relu(self.bn1(self.conv1(x)))
111 | out = self.layer1(out)
112 | out = self.layer2(out)
113 | out = self.layer3(out)
114 | out = self.layer4(out)
115 | out = self.pool(out)
116 |
117 | out = out.view(out.size(0), -1)
118 | feat_fc = out
119 | out = self.classifier(out)
120 |
121 | return out
122 |
123 | def embed(self, x):
124 | out = F.relu(self.bn1(self.conv1(x)))
125 | out = self.layer1(out)
126 | out = self.layer2(out)
127 | out = self.layer3(out)
128 | out = self.layer4(out)
129 | out = F.avg_pool2d(out, 4)
130 | out = out.view(out.size(0), -1)
131 | return out
132 |
133 |
134 | class ResNetImageNet(nn.Module):
135 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
136 | super(ResNetImageNet, self).__init__()
137 | self.in_planes = 64
138 | self.norm = norm
139 |
140 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
141 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(
142 | 64) if self.norm == 'batch' else nn.Identity()
143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
144 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
145 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
146 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
147 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
148 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
149 | self.classifier = nn.Linear(512 * block.expansion, num_classes)
150 |
151 | def _make_layer(self, block, planes, num_blocks, stride):
152 | strides = [stride] + [1] * (num_blocks - 1)
153 | layers = []
154 | for stride in strides:
155 | layers.append(block(self.in_planes, planes, stride, self.norm))
156 | self.in_planes = planes * block.expansion
157 | return nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | out = F.relu(self.bn1(self.conv1(x)))
161 | out = self.maxpool(out)
162 | out = self.layer1(out)
163 | out = self.layer2(out)
164 | out = self.layer3(out)
165 | out = self.layer4(out)
166 | # out = F.avg_pool2d(out, 4)
167 | # out = out.view(out.size(0), -1)
168 | out = self.avgpool(out)
169 | out = torch.flatten(out, 1)
170 | out = self.classifier(out)
171 | return out
172 |
173 |
174 | def ResNet18BN(channel, num_classes):
175 |     return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes, norm='batch')
176 |
177 |
178 | def ResNet18(channel, num_classes, norm):
179 | return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes, norm=norm)
180 |
181 |
182 | def ResNet18ImageNet(channel, num_classes):
183 | return ResNetImageNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes)
184 |
185 |
186 | def ResNet34(channel, num_classes):
187 | return ResNet(BasicBlock, [3, 4, 6, 3], channel=channel, num_classes=num_classes)
188 |
189 |
190 | def ResNet50(channel, num_classes):
191 | return ResNet(Bottleneck, [3, 4, 6, 3], channel=channel, num_classes=num_classes)
192 |
193 |
194 | def ResNet101(channel, num_classes):
195 | return ResNet(Bottleneck, [3, 4, 23, 3], channel=channel, num_classes=num_classes)
196 |
197 |
198 | def ResNet152(channel, num_classes):
199 | return ResNet(Bottleneck, [3, 8, 36, 3], channel=channel, num_classes=num_classes)
200 |
201 |
202 | def ResNet152Imagenet(channel, num_classes):
203 | return ResNetImageNet(Bottleneck, [3, 8, 36, 3], channel=channel, num_classes=num_classes)
--------------------------------------------------------------------------------
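
A brief sketch (not part of the repository) of how the norm argument above selects the normalization layer: 'instancenorm' builds a GroupNorm with one channel per group (i.e. affine instance normalization), 'batch' builds BatchNorm2d, and any other string falls through to Identity. It assumes the ImageNet-abcde directory is the working directory so the networks package imports cleanly.

    import torch
    from networks.resnet import ResNet, BasicBlock   # assumes ImageNet-abcde is on the Python path

    x = torch.randn(2, 3, 32, 32)
    for norm in ('instancenorm', 'batch'):
        model = ResNet(BasicBlock, [2, 2, 2, 2], channel=3, num_classes=10, norm=norm)
        print(norm, type(model.bn1).__name__, model(x).shape)   # GroupNorm / BatchNorm2d, output torch.Size([2, 10])
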
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/ImageNet-abcde/glad_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import copy
4 | from tqdm import tqdm
5 |
6 | from utils import get_network, config, evaluate_synset, save_and_print
7 |
8 | def build_dataset(ds, class_map, num_classes):
9 | images_all = []
10 | labels_all = []
11 | indices_class = [[] for c in range(num_classes)]
12 | print("BUILDING DATASET")
13 | for i in tqdm(range(len(ds))):
14 | sample = ds[i]
15 | images_all.append(torch.unsqueeze(sample[0], dim=0))
16 | labels_all.append(class_map[torch.tensor(sample[1]).item()])
17 | for i, lab in tqdm(enumerate(labels_all)):
18 | indices_class[lab].append(i)
19 | images_all = torch.cat(images_all, dim=0).to("cpu")
20 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu")
21 |
22 | return images_all, labels_all, indices_class
23 |
24 |
25 | def prepare_latents(channel=3, num_classes=10, im_size=(32, 32), zdim=512, G=None, class_map_inv={}, get_images=None, args=None):
26 | with torch.no_grad():
27 | ''' initialize the synthetic data '''
28 | label_syn = torch.tensor([i*np.ones(args.ipc, dtype=np.int64) for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]
29 |
30 | if args.space == 'p':
31 | latents = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=False, device=args.device)
32 | f_latents = None
33 |
34 | else:
35 | zs = torch.randn(num_classes * args.ipc, zdim, device=args.device, requires_grad=False)
36 |
37 | if "imagenet" in args.dataset:
38 | one_hot_dim = 1000
39 | elif args.dataset == "CIFAR10":
40 | one_hot_dim = 10
41 | elif args.dataset == "CIFAR100":
42 | one_hot_dim = 100
43 |
44 | if args.avg_w:
45 | G_labels = torch.zeros([label_syn.nelement(), one_hot_dim], device=args.device)
46 | G_labels[torch.arange(0, label_syn.nelement(), dtype=torch.long), [class_map_inv[x.item()] for x in label_syn]] = 1
47 | new_latents = []
48 | for label in G_labels:
49 | zs = torch.randn(1000, zdim).to(args.device)
50 | ws = G.mapping(zs, torch.stack([label] * 1000))
51 | w = torch.mean(ws, dim=0)
52 | new_latents.append(w)
53 | latents = torch.stack(new_latents)
54 | del zs
55 | for _ in new_latents:
56 | del _
57 | del new_latents
58 |
59 | else:
60 | G_labels = torch.zeros([label_syn.nelement(), one_hot_dim], device=args.device)
61 | G_labels[torch.arange(0, label_syn.nelement(), dtype=torch.long), [class_map_inv[x.item()] for x in label_syn]] = 1
62 | if args.distributed and False:
63 | latents = G.mapping(zs.to("cuda:1"), G_labels.to("cuda:1")).to("cuda:0")
64 | else:
65 | latents = G.mapping(zs, G_labels)
66 | del zs
67 |
68 | del G_labels
69 |
70 | ws = latents
71 | if args.layer is not None:
72 | f_latents = torch.cat(
73 | [G.forward(split_ws, f_layer=args.layer, mode="to_f").detach() for split_ws in
74 | torch.split(ws, args.sg_batch)])
75 | f_type = f_latents.dtype
76 | f_latents = f_latents.to(torch.float32).cpu()
77 | f_latents = torch.nan_to_num(f_latents, posinf=5.0, neginf=-5.0)
78 | f_latents = torch.clip(f_latents, min=-10, max=10)
79 | f_latents = f_latents.to(f_type).cuda()
80 |
81 | save_and_print(args.log_path, f"{torch.mean(f_latents)}, {torch.std(f_latents)}")
82 |
83 | if args.rand_f:
84 | f_latents = (torch.randn(f_latents.shape).to(args.device) * torch.std(f_latents, dim=(1,2,3), keepdim=True) + torch.mean(f_latents, dim=(1,2,3), keepdim=True))
85 |
86 | f_latents = f_latents.to(f_type)
87 | save_and_print(args.log_path, f"{torch.mean(f_latents)}, {torch.std(f_latents)}")
88 | f_latents.requires_grad_(True)
89 | else:
90 | f_latents = None
91 |
92 | if args.pix_init == 'real' and args.space == "p":
93 | save_and_print(args.log_path, 'initialize synthetic data from random real images')
94 | for c in range(num_classes):
95 | latents.data[c*args.ipc:(c+1)*args.ipc] = torch.cat([get_images(c, 1).detach().data for s in range(args.ipc)])
96 | else:
97 | save_and_print(args.log_path, 'initialize synthetic data from random noise')
98 |
99 | latents = latents.detach().to(args.device).requires_grad_(True)
100 |
101 | return latents, f_latents, label_syn
102 |
103 |
104 | def get_optimizer_img(latents=None, f_latents=None, G=None, args=None):
105 | if args.space == "wp" and (args.layer is not None and args.layer != -1):
106 | optimizer_img = torch.optim.SGD([latents], lr=args.lr_w, momentum=0.5)
107 | optimizer_img.add_param_group({'params': f_latents, 'lr': args.lr_img, 'momentum': 0.5})
108 | else:
109 | optimizer_img = torch.optim.SGD([latents], lr=args.lr_img, momentum=0.5)
110 |
111 | if args.learn_g:
112 | G.requires_grad_(True)
113 | optimizer_img.add_param_group({'params': G.parameters(), 'lr': args.lr_g, 'momentum': 0.5})
114 |
115 | optimizer_img.zero_grad()
116 |
117 | return optimizer_img
118 |
119 | def get_eval_lrs(args):
120 | eval_pool_dict = {
121 | args.model: 0.001,
122 | "ResNet18": 0.001,
123 | "VGG11": 0.0001,
124 | "AlexNet": 0.001,
125 | "ViT": 0.001,
126 |
127 | "AlexNetCIFAR": 0.001,
128 | "ResNet18CIFAR": 0.001,
129 | "VGG11CIFAR": 0.0001,
130 | "ViTCIFAR": 0.001,
131 | }
132 |
133 | return eval_pool_dict
134 |
135 |
136 | def eval_loop(latents=None, f_latents=None, label_syn=None, G=None, best_acc={}, best_std={}, testloader=None, model_eval_pool=[], it=0, channel=3, num_classes=10, im_size=(32, 32), args=None):
137 | curr_acc_dict = {}
138 | max_acc_dict = {}
139 |
140 | curr_std_dict = {}
141 | max_std_dict = {}
142 |
143 | eval_pool_dict = get_eval_lrs(args)
144 |
145 | save_this_it = False
146 |
147 | for model_eval in model_eval_pool:
148 |
149 | if model_eval != args.model and args.wait_eval and it != args.Iteration:
150 | continue
151 | save_and_print(args.log_path, '-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d' % (args.model, model_eval, it))
152 |
153 | accs_test = []
154 | accs_train = []
155 |
156 | for it_eval in range(args.num_eval):
157 | net_eval = get_network(model_eval, channel, num_classes, im_size, width=args.width, depth=args.depth, dist=False).to(args.device) # get a random model
158 | eval_lats = latents
159 | eval_labs = label_syn
160 | image_syn = latents
161 | image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(eval_labs.detach()) # avoid any unaware modification
162 |
163 | if args.space == "wp":
164 | with torch.no_grad():
165 | image_syn_eval = torch.cat(
166 | [latent_to_im(G, (image_syn_eval_split, f_latents_split), args=args).detach() for
167 | image_syn_eval_split, f_latents_split, label_syn_split in zip(torch.split(image_syn_eval, args.sg_batch), torch.split(f_latents, args.sg_batch), torch.split(label_syn, args.sg_batch))])
168 |
169 | args.lr_net = eval_pool_dict[model_eval]
170 | _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args=args, aug=True)
171 | del _
172 | del net_eval
173 | accs_test.append(acc_test)
174 | accs_train.append(acc_train)
175 |
176 | accs_test = np.array(accs_test)
177 | accs_train = np.array(accs_train)
178 | acc_test_mean = np.mean(np.max(accs_test, axis=1))
179 | acc_test_std = np.std(np.max(accs_test, axis=1))
180 | best_dict_str = "{}".format(model_eval)
181 | if acc_test_mean > best_acc[best_dict_str]:
182 | best_acc[best_dict_str] = acc_test_mean
183 | best_std[best_dict_str] = acc_test_std
184 | save_this_it = True
185 |
186 | curr_acc_dict[best_dict_str] = acc_test_mean
187 | curr_std_dict[best_dict_str] = acc_test_std
188 |
189 | max_acc_dict[best_dict_str] = best_acc[best_dict_str]
190 | max_std_dict[best_dict_str] = best_std[best_dict_str]
191 |
192 | save_and_print(args.log_path, 'Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------' % (len(accs_test[:, -1]), model_eval, acc_test_mean, np.std(np.max(accs_test, axis=1))))
193 | save_and_print(args.log_path, f"{args.save_path}")
194 | save_and_print(args.log_path, f"{it:5d} | Accuracy/{model_eval}: {acc_test_mean}")
195 | save_and_print(args.log_path, f"{it:5d} | Max_Accuracy/{model_eval}: {best_acc[best_dict_str]}")
196 | save_and_print(args.log_path, f"{it:5d} | Std/{model_eval}: {acc_test_std}")
197 | save_and_print(args.log_path, f"{it:5d} | Max_Std/{model_eval}: {best_std[best_dict_str]}")
198 |
199 | if len(model_eval_pool) > 1:
200 | save_and_print(args.log_path, "-" * 20)
201 | save_and_print(args.log_path, f"{it:5d} | Accuracy/Avg_All: {np.mean(np.array(list(curr_acc_dict.values())))}")
202 | save_and_print(args.log_path, f"{it:5d} | Std/Avg_All: {np.mean(np.array(list(curr_std_dict.values())))}")
203 | save_and_print(args.log_path, f"{it:5d} | Max_Accuracy/Avg_All: {np.mean(np.array(list(max_acc_dict.values())))}")
204 | save_and_print(args.log_path, f"{it:5d} | Max_Std/Avg_All: {np.mean(np.array(list(max_std_dict.values())))}")
205 |
206 | curr_acc_dict.pop("{}".format(args.model))
207 | curr_std_dict.pop("{}".format(args.model))
208 | max_acc_dict.pop("{}".format(args.model))
209 | max_std_dict.pop("{}".format(args.model))
210 |
211 | save_and_print(args.log_path, "-" * 20)
212 | save_and_print(args.log_path, f"{it:5d} | Accuracy/Avg_Cross: {np.mean(np.array(list(curr_acc_dict.values())))}")
213 | save_and_print(args.log_path, f"{it:5d} | Std/Avg_Cross: {np.mean(np.array(list(curr_std_dict.values())))}")
214 | save_and_print(args.log_path, f"{it:5d} | Max_Accuracy/Avg_Cross: {np.mean(np.array(list(max_acc_dict.values())))}")
215 | save_and_print(args.log_path, f"{it:5d} | Max_Std/Avg_Cross: {np.mean(np.array(list(max_std_dict.values())))}")
216 |
217 | return save_this_it
218 |
219 | def latent_to_im(G, latents, args=None):
220 |
221 | if args.space == "p":
222 | return latents
223 |
224 | mean, std = config.mean, config.std
225 |
226 | if "imagenet" in args.dataset:
227 | class_map = {i: x for i, x in enumerate(config.img_net_classes)}
228 |
229 | if args.space == "p":
230 | im = latents
231 |
232 | elif args.space == "wp":
233 | if args.layer is None or args.layer==-1:
234 | im = G(latents[0], mode="wp")
235 | else:
236 | im = G(latents[0], latents[1], args.layer, mode="from_f")
237 |
238 | im = (im + 1) / 2
239 | im = (im - mean) / std
240 |
241 | elif args.dataset == "CIFAR10" or args.dataset == "CIFAR100":
242 | if args.space == "p":
243 | im = latents
244 | elif args.space == "wp":
245 | if args.layer is None or args.layer == -1:
246 | im = G(latents[0], mode="wp")
247 | else:
248 | im = G(latents[0], latents[1], args.layer, mode="from_f")
249 |
250 | if args.distributed and False:
251 | mean, std = config.mean_1, config.std_1
252 |
253 | im = (im + 1) / 2
254 | im = (im - mean) / std
255 |
256 | return im
--------------------------------------------------------------------------------
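
A hypothetical sketch (not in the repository) of driving build_dataset above with a plain torchvision CIFAR-10 and reproducing the per-class sampling that the main scripts perform. The identity class_map, the "../data" path, and the local get_images helper are assumptions for illustration; the real scripts obtain the mapping and the helper from utils.get_dataset, and the sketch assumes glad_utils (and hence the repository's utils module) imports cleanly from the ImageNet-abcde directory.

    import numpy as np
    import torchvision
    import torchvision.transforms as T
    from glad_utils import build_dataset

    ds = torchvision.datasets.CIFAR10("../data", train=True, download=True, transform=T.ToTensor())  # assumed data path
    class_map = {i: i for i in range(10)}                  # identity mapping, assumed for CIFAR-10
    images_all, labels_all, indices_class = build_dataset(ds, class_map, num_classes=10)

    def get_images(c, n):                                  # per-class random batch, mirroring the helper used by prepare_latents
        idx = np.random.permutation(indices_class[c])[:n]
        return images_all[idx]

    print(get_images(3, 8).shape)                          # torch.Size([8, 3, 32, 32])
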
/ImageNet-abcde/main_DM_FreD.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 |
4 | import time
5 | import copy
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, set_seed, save_and_print, get_images
10 |
11 | import shutil
12 | import torchvision
13 | import matplotlib.pyplot as plt
14 | from frequency_transforms import DCT
15 |
16 | from glad_utils import build_dataset, get_eval_lrs, eval_loop
17 |
18 | class SynSet():
19 | def __init__(self, args):
20 | ### Basic ###
21 | self.args = args
22 | self.log_path = self.args.log_path
23 | self.channel = self.args.channel
24 | self.num_classes = self.args.num_classes
25 | self.im_size = self.args.im_size
26 | self.device = self.args.device
27 | self.ipc = self.args.ipc
28 |
29 | ### FreD ###
30 | self.dct = DCT(resolution=self.im_size[0], device=self.device)
31 | self.lr_freq = self.args.lr_freq
32 | self.mom_freq = self.args.mom_freq
33 | self.msz_per_channel = self.args.msz_per_channel
34 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1]) / self.msz_per_channel)
35 |
36 | def init(self, images_real, labels_real, indices_class):
37 | ### Initialize Frequency (F) ###
38 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device)
39 | for c in range(self.num_classes):
40 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class]
41 | images.data[c * self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data
42 | self.freq_syn = self.dct.forward(images)
43 | self.freq_syn.requires_grad = True
44 | del images
45 |
46 | ### Initialize Mask (M) ###
47 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device)
48 | self.mask.requires_grad = False
49 |
50 | ### Initialize Label ###
51 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]
52 | self.label_syn = self.label_syn.long()
53 |
54 | ### Initialize Optimizer ###
55 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq)
56 |
57 | self.init_mask()
58 | self.optim_zero_grad()
59 | self.show_budget()
60 |
61 | def get(self, indices=None, need_copy=False):
62 | if not hasattr(indices, '__iter__'):
63 | indices = range(len(self.label_syn))
64 |
65 | if need_copy:
66 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach())
67 | mask = copy.deepcopy(self.mask[indices].detach())
68 | else:
69 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices]
70 | mask = self.mask[indices]
71 |
72 | image_syn = self.dct.inverse(mask * freq_syn)
73 | return image_syn, label_syn
74 |
75 | def init_mask(self):
76 | save_and_print(self.args.log_path, "Initialize Mask")
77 |
78 | for c in range(self.num_classes):
79 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach())
80 | freq_c = torch.mean(freq_c, 1)
81 | freqs_flat = torch.flatten(freq_c, 1)
82 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0)
83 |
84 | try:
85 | cov = torch.cov(freqs_flat.T)
86 |             except Exception:
87 |                 save_and_print(self.args.log_path, "Cannot use torch.cov; falling back to np.cov")
88 | cov = np.cov(freqs_flat.T.cpu().numpy())
89 | cov = torch.tensor(cov, dtype=torch.float, device=self.device)
90 | total_variance = torch.sum(torch.diag(cov))
91 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device)
92 | for idx in range(np.prod(self.im_size)):
93 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1)
94 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low)
95 | explained_variance_ratio = vector_variance / total_variance
96 | vr_fl2f[idx] = explained_variance_ratio.item()
97 |
98 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel)
99 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:]
100 | for h, w in top_indices:
101 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, h, w] = 1.0
102 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}")
103 |
104 | ### Visualize and Save ###
105 | indices_save = np.arange(10) * self.num_per_class
106 | grid = torchvision.utils.make_grid(self.mask[indices_save], nrow=10)
107 | plt.imshow(np.transpose(grid.detach().cpu().numpy(), (1, 2, 0)))
108 | plt.savefig(f"{self.args.save_path}/Mask.png", dpi=300)
109 | plt.close()
110 |
111 | mask_save = copy.deepcopy(self.mask.detach())
112 | torch.save(mask_save.cpu(), os.path.join(self.args.save_path, "mask.pt"))
113 | del mask_save
114 |
115 | def optim_zero_grad(self):
116 | self.optimizers.zero_grad()
117 |
118 | def optim_step(self):
119 | self.optimizers.step()
120 |
121 | def show_budget(self):
122 | save_and_print(self.log_path, '=' * 50)
123 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}")
124 | images, _ = self.get(need_copy=True)
125 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}")
126 | del images
127 | save_and_print(self.log_path, '=' * 50)
128 |
129 | def main(args):
130 |
131 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
132 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args)
133 | args.channel, args.im_size, args.num_classes, args.mean, args.std = channel, im_size, num_classes, mean, std
134 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
135 |
136 | accs_all_exps = dict() # record performances of all experiments
137 | for key in model_eval_pool:
138 | accs_all_exps[key] = []
139 |
140 | data_save = []
141 |
142 | save_and_print(args.log_path, f'\n================== Exp {0} ==================\n ')
143 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')
144 |
145 | ''' organize the real dataset '''
146 | images_all, labels_all, indices_class = build_dataset(dst_train, class_map, num_classes)
147 | images_all, labels_all = images_all.to(args.device), labels_all.to(args.device)
148 |
149 | ''' initialize the synthetic data'''
150 | synset = SynSet(args)
151 | synset.init(images_all, labels_all, indices_class)
152 |
153 | ''' training '''
154 | criterion = nn.CrossEntropyLoss().to(args.device)
155 | save_and_print(args.log_path, '%s training begins'%get_time())
156 |
157 | best_acc = {"{}".format(m): 0 for m in model_eval_pool}
158 | best_std = {m: 0 for m in model_eval_pool}
159 |
160 | save_this_it = False
161 | for it in range(args.Iteration+1):
162 |
163 | if it in eval_it_pool:
164 | image_syn_eval, label_syn_eval = synset.get(need_copy=True)
165 | save_this_it = eval_loop(latents=image_syn_eval, f_latents=None, label_syn=label_syn_eval, G=None, best_acc=best_acc,
166 | best_std=best_std, testloader=testloader,
167 | model_eval_pool=model_eval_pool, channel=channel, num_classes=num_classes,
168 | im_size=im_size, it=it, args=args)
169 |
170 | ''' Train synthetic data '''
171 | net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width).to(args.device) # get a random model
172 | net.train()
173 | for param in list(net.parameters()):
174 | param.requires_grad = False
175 |
176 | embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel
177 |
178 | loss_avg = 0
179 |
180 | ''' update synthetic data '''
181 | if 'BN' not in args.model: # for ConvNet
182 | loss = torch.tensor(0.0).to(args.device)
183 | for c in range(num_classes):
184 | img_real = get_images(images_all, indices_class, c, args.batch_real)
185 |
186 | if args.batch_syn > 0:
187 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn]
188 | else:
189 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class)
190 |
191 | img_syn, lab_syn = synset.get(indices=indices)
192 |
193 | if args.dsa:
194 | seed = int(time.time() * 1000) % 100000
195 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
196 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)
197 |
198 | output_real = embed(img_real).detach()
199 | output_syn = embed(img_syn)
200 |
201 | loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)
202 |
203 | else: # for ConvNetBN
204 | images_real_all = []
205 | images_syn_all = []
206 | loss = torch.tensor(0.0).to(args.device)
207 | for c in range(num_classes):
208 |                 img_real = get_images(images_all, indices_class, c, args.batch_real)
209 |
210 | if args.batch_syn > 0:
211 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn]
212 | else:
213 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class)
214 |
215 | img_syn, lab_syn = synset.get(indices=indices)
216 |
217 | if args.dsa:
218 | seed = int(time.time() * 1000) % 100000
219 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
220 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)
221 |
222 | images_real_all.append(img_real)
223 | images_syn_all.append(img_syn)
224 |
225 | images_real_all = torch.cat(images_real_all, dim=0)
226 | images_syn_all = torch.cat(images_syn_all, dim=0)
227 |
228 | output_real = embed(images_real_all).detach()
229 | output_syn = embed(images_syn_all)
230 |
231 |             loss += torch.sum((torch.mean(output_real.reshape(num_classes, args.batch_real, -1), dim=1) - torch.mean(output_syn.reshape(num_classes, -1, output_syn.shape[-1]), dim=1))**2)  # synthetic count per class can differ from args.ipc under FreD
232 |
233 | synset.optim_zero_grad()
234 | loss.backward()
235 | synset.optim_step()
236 | loss_avg += loss.item()
237 |
238 | loss_avg /= (num_classes)
239 |
240 | if it%10 == 0:
241 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))
242 |
243 | if it == args.Iteration: # only record the final results
244 | data_save.append([synset.get(need_copy=True)])
245 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%dipc.pt'%(args.dataset, args.model, args.ipc)))
246 |
247 |
248 | if __name__ == '__main__':
249 | import shared_args
250 |
251 | parser = shared_args.add_shared_args()
252 |
253 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening")
254 |
255 | parser.add_argument('--seed', type=int, default=0)
256 | parser.add_argument('--sh_file', type=str)
257 | parser.add_argument('--FLAG', type=str, default="TEST")
258 |
259 | ### FreD ###
260 | parser.add_argument('--batch_syn', type=int)
261 | parser.add_argument('--msz_per_channel', type=int)
262 | parser.add_argument('--lr_freq', type=float)
263 | parser.add_argument('--mom_freq', type=float)
264 | args = parser.parse_args()
265 | args.space = "p"
266 | args.zca = False
267 |
268 | set_seed(args.seed)
269 |
270 | args.outer_loop, args.inner_loop = get_loops(args.ipc)
271 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
272 | args.dsa_param = ParamDiffAug()
273 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
274 |
275 | if not os.path.exists(args.data_path):
276 | os.mkdir(args.data_path)
277 |
278 | if not os.path.exists(args.save_path):
279 | os.mkdir(args.save_path)
280 | args.save_path = args.save_path + f"/{args.FLAG}"
281 | if not os.path.exists(args.save_path):
282 | os.mkdir(args.save_path)
283 |
284 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}")
285 | args.log_path = f"{args.save_path}/log.txt"
286 |
287 | main(args)
288 |
289 |
290 |
--------------------------------------------------------------------------------
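
The per-index loop in SynSet.init_mask above evaluates e_i^T C e_i / tr(C) for every standard basis vector e_i, which is exactly the i-th diagonal entry of the covariance divided by its trace. A minimal equivalent sketch, assuming freqs_flat is the centered (num_per_class, H*W) matrix built in init_mask:

    import torch

    def explained_variance_ratio(freqs_flat: torch.Tensor) -> torch.Tensor:
        # freqs_flat: (num_per_class, H*W); rows are samples, columns are frequency positions (assumed layout from init_mask)
        cov = torch.cov(freqs_flat.T)        # (H*W, H*W) covariance over frequency positions
        var = torch.diag(cov)                # e_i^T cov e_i for every standard basis vector e_i
        return var / var.sum()               # same values the vr_fl2f loop computes one index at a time

    # selecting the mask positions as in init_mask (freqs_flat and msz_per_channel come from the surrounding code):
    # values, flat_idx = torch.topk(explained_variance_ratio(freqs_flat), msz_per_channel)
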
/3D-MNIST/main_DM_FreD.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 |
4 | import copy
5 | import argparse
6 | import numpy as np
7 | import torch
8 | from utils import get_loops, get_dataset, get_network, evaluate_synset, get_time, TensorDataset, epoch, set_seed, save_and_print, get_images
9 |
10 | import shutil
11 | from frequency_transforms import DCT
12 |
13 | class SynSet():
14 | def __init__(self, args):
15 | ### Basic ###
16 | self.args = args
17 | self.log_path = self.args.log_path
18 | self.channel = self.args.channel
19 | self.num_classes = self.args.num_classes
20 | self.im_size = self.args.im_size
21 | self.device = self.args.device
22 | self.ipc = self.args.ipc
23 |
24 | ### FreD ###
25 | self.dct = DCT(resolution=self.im_size[0], device=self.device)
26 | self.lr_freq = self.args.lr_freq
27 | self.mom_freq = self.args.mom_freq
28 | self.msz_per_channel = self.args.msz_per_channel
29 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1] * self.im_size[2]) / self.msz_per_channel)
30 |
31 | def init(self, images_real, labels_real, indices_class):
32 | ### Initialize Frequency (F) ###
33 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1], self.im_size[2]), dtype=torch.float, device=self.device)
34 | for c in range(self.num_classes):
35 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class]
36 | images.data[c * self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data
37 | self.freq_syn = self.dct.forward(images)
38 | self.freq_syn.requires_grad = True
39 | del images
40 |
41 | ### Initialize Mask (M) ###
42 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1], self.im_size[2]), dtype=torch.float, device=self.device)
43 | self.mask.requires_grad = False
44 |
45 | ### Initialize Label ###
46 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]
47 | self.label_syn = self.label_syn.long()
48 |
49 | ### Initialize Optimizer ###
50 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq)
51 |
52 | self.init_mask()
53 | self.optim_zero_grad()
54 | self.show_budget()
55 |
56 | def get(self, indices=None, need_copy=False):
57 | if not hasattr(indices, '__iter__'):
58 | indices = range(len(self.label_syn))
59 |
60 | if need_copy:
61 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach())
62 | mask = copy.deepcopy(self.mask[indices].detach())
63 | else:
64 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices]
65 | mask = self.mask[indices]
66 |
67 | image_syn = self.dct.inverse(mask * freq_syn)
68 | return image_syn, label_syn
69 |
70 | def init_mask(self):
71 | save_and_print(self.args.log_path, "Initialize Mask")
72 |
73 | for c in range(self.num_classes):
74 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach())
75 | freq_c = torch.mean(freq_c, 1)
76 | freqs_flat = torch.flatten(freq_c, 1)
77 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0)
78 |
79 | try:
80 | cov = torch.cov(freqs_flat.T)
81 |             except Exception:
82 |                 save_and_print(self.args.log_path, "Cannot use torch.cov; falling back to np.cov")
83 | cov = np.cov(freqs_flat.T.cpu().numpy())
84 | cov = torch.tensor(cov, dtype=torch.float, device=self.device)
85 | total_variance = torch.sum(torch.diag(cov))
86 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device)
87 | for idx in range(np.prod(self.im_size)):
88 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1)
89 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low)
90 | explained_variance_ratio = vector_variance / total_variance
91 | vr_fl2f[idx] = explained_variance_ratio.item()
92 |
93 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel)
94 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:]
95 | for d, h, w in top_indices:
96 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, d, h, w] = 1.0
97 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}")
98 |
99 | def optim_zero_grad(self):
100 | self.optimizers.zero_grad()
101 |
102 | def optim_step(self):
103 | self.optimizers.step()
104 |
105 | def show_budget(self):
106 | save_and_print(self.log_path, '=' * 50)
107 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}")
108 | images, _ = self.get(need_copy=True)
109 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}")
110 | del images
111 | save_and_print(self.log_path, '=' * 50)
112 |
113 | def main():
114 |
115 | parser = argparse.ArgumentParser(description='Parameter Processing')
116 | parser.add_argument('--method', type=str, default='DM', help='DC/DSA/DM')
117 | parser.add_argument('--dataset', type=str, default='3D-MNIST', help='dataset')
118 | parser.add_argument('--model', type=str, default='Conv3DNet', help='model')
119 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
120 | parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures, W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
121 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
122 | parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models')
123 | parser.add_argument('--eval_it', type=int, default=200)
124 | parser.add_argument('--epoch_eval_train', type=int, default=500, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
125 | parser.add_argument('--Iteration', type=int, default=20000, help='training iterations')
126 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
127 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
128 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
129 | # parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
130 | parser.add_argument('--data_path', type=str, default='../data', help='dataset path')
131 | parser.add_argument('--save_path', type=str, default='./results', help='path to save results')
132 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
133 |
134 |
135 | parser.add_argument('--seed', type=int, default=0)
136 | parser.add_argument('--sh_file', type=str)
137 | parser.add_argument('--FLAG', type=str, default="TEST")
138 |
139 | ### FreD ###
140 | parser.add_argument('--batch_syn', type=int)
141 | parser.add_argument('--msz_per_channel', type=int)
142 | parser.add_argument('--lr_freq', type=float)
143 | parser.add_argument('--mom_freq', type=float)
144 | args = parser.parse_args()
145 | set_seed(args.seed)
146 |
147 | args.outer_loop, args.inner_loop = get_loops(args.ipc)
148 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
149 |
150 | if not os.path.exists(args.data_path):
151 | os.mkdir(args.data_path)
152 |
153 | if not os.path.exists(args.save_path):
154 | os.mkdir(args.save_path)
155 | args.save_path = args.save_path + f"/{args.FLAG}"
156 | if not os.path.exists(args.save_path):
157 | os.mkdir(args.save_path)
158 |
159 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}")
160 | args.log_path = f"{args.save_path}/log.txt"
161 |
162 | eval_it_pool = np.arange(0, args.Iteration+1, args.eval_it).tolist()
163 | channel, im_size, num_classes, class_names, dst_train, dst_test, testloader = get_dataset(args.data_path, args=args)
164 | args.channel, args.im_size, args.num_classes = channel, im_size, num_classes
165 | model_eval_pool = ["Conv3DNet"]
166 |
167 | accs_all_exps = dict() # record performances of all experiments
168 | for key in model_eval_pool:
169 | accs_all_exps[key] = []
170 |
171 | data_save = []
172 |
173 | for exp in range(args.num_exp):
174 | save_and_print(args.log_path, f'\n================== Exp {exp} ==================\n ')
175 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')
176 |
177 | ''' organize the real dataset '''
178 | images_all = []
179 | labels_all = []
180 | indices_class = [[] for c in range(num_classes)]
181 |
182 | images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
183 | labels_all = [dst_train[i][1] for i in range(len(dst_train))]
184 | for i, lab in enumerate(labels_all):
185 | indices_class[lab].append(i)
186 | images_all = torch.cat(images_all, dim=0).to(args.device)
187 | labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)
188 |
189 | ''' initialize the synthetic data '''
190 | synset = SynSet(args)
191 | synset.init(images_all, labels_all, indices_class)
192 |
193 | ''' training '''
194 | save_and_print(args.log_path, '%s training begins'%get_time())
195 |
196 | for it in range(args.Iteration+1):
197 |
198 | ''' Evaluate synthetic data '''
199 | if it in eval_it_pool:
200 | for model_eval in model_eval_pool:
201 | save_and_print(args.log_path, '-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
202 |
203 | accs = []
204 | for it_eval in range(args.num_eval):
205 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
206 | image_syn_eval, label_syn_eval = synset.get(need_copy=True)
207 | _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
208 | accs.append(acc_test)
209 |
210 | del image_syn_eval, label_syn_eval
211 | save_and_print(args.log_path, 'Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))
212 |
213 | if it == args.Iteration: # record the final results
214 | accs_all_exps[model_eval] += accs
215 |
216 | ''' Train synthetic data '''
217 | net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
218 | net.train()
219 | for param in list(net.parameters()):
220 | param.requires_grad = False
221 |
222 | embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel
223 |
224 | loss_avg = 0
225 |
226 | ''' update synthetic data '''
227 | loss = torch.tensor(0.0).to(args.device)
228 | for c in range(num_classes):
229 | img_real = get_images(images_all, indices_class, c, args.batch_real)
230 |
231 | if args.batch_syn > 0:
232 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn]
233 | else:
234 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class)
235 |
236 | img_syn, lab_syn = synset.get(indices=indices)
237 |
238 | output_real = embed(img_real).detach()
239 | output_syn = embed(img_syn)
240 |
241 | loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)
242 |
243 | synset.optim_zero_grad()
244 | loss.backward()
245 | synset.optim_step()
246 | loss_avg += loss.item()
247 |
248 | loss_avg /= (num_classes)
249 |
250 | if it%10 == 0:
251 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))
252 |
253 | if it == args.Iteration: # only record the final results
254 | data_save.append([synset.get(need_copy=True)])
255 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))
256 |
257 |
258 | save_and_print(args.log_path, '\n==================== Final Results ====================\n')
259 | for key in model_eval_pool:
260 | accs = accs_all_exps[key]
261 | save_and_print(args.log_path, 'Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))
262 |
263 |
264 | if __name__ == '__main__':
265 | main()
266 |
267 |
268 |
--------------------------------------------------------------------------------
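
A short worked example (illustrative numbers, not the paper's settings) of the storage budget used above: keeping msz_per_channel DCT coefficients per channel lets num_per_class = ipc * D*H*W / msz_per_channel decoded volumes occupy the same budget as ipc raw volumes per class. The 16x16x16 grid is an assumed 3D-MNIST resolution.

    ipc, im_size, msz_per_channel = 1, (16, 16, 16), 1024   # illustrative values; 16x16x16 is an assumed voxel grid
    d, h, w = im_size
    num_per_class = (ipc * d * h * w) // msz_per_channel     # -> 4 synthetic volumes per class
    stored_coeffs_per_class = num_per_class * msz_per_channel
    assert stored_coeffs_per_class == ipc * d * h * w        # same budget as ipc raw volumes (4096 voxels)
    print(num_per_class, stored_coeffs_per_class)
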
/ImageNet-abcde/main_DC_FreD.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 |
4 | import time
5 | import copy
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, set_seed, save_and_print, get_images
10 |
11 | import shutil
12 | import torchvision
13 | import matplotlib.pyplot as plt
14 | from frequency_transforms import DCT
15 |
16 | from glad_utils import build_dataset, get_eval_lrs, eval_loop
17 |
18 | class SynSet():
19 | def __init__(self, args):
20 | ### Basic ###
21 | self.args = args
22 | self.log_path = self.args.log_path
23 | self.channel = self.args.channel
24 | self.num_classes = self.args.num_classes
25 | self.im_size = self.args.im_size
26 | self.device = self.args.device
27 | self.ipc = self.args.ipc
28 |
29 | ### FreD ###
30 | self.dct = DCT(resolution=self.im_size[0], device=self.device)
31 | self.lr_freq = self.args.lr_freq
32 | self.mom_freq = self.args.mom_freq
33 | self.msz_per_channel = self.args.msz_per_channel
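   |         # Budget parity with storing ipc full-resolution images per class: ipc * H * W coefficients
   |         # are kept per class and channel, so a smaller msz_per_channel buys more synthetic
   |         # instances per class (num_per_class).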
34 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1]) / self.msz_per_channel)
35 |
36 | def init(self, images_real, labels_real, indices_class):
37 | ### Initialize Frequency (F) ###
38 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device)
39 | for c in range(self.num_classes):
40 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class]
41 | images.data[c * self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data
42 | self.freq_syn = self.dct.forward(images)
43 | self.freq_syn.requires_grad = True
44 | del images
45 |
46 | ### Initialize Mask (M) ###
47 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device)
48 | self.mask.requires_grad = False
49 |
50 | ### Initialize Label ###
51 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]
52 | self.label_syn = self.label_syn.long()
53 |
54 | ### Initialize Optimizer ###
55 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq)
56 |
57 | self.init_mask()
58 | self.optim_zero_grad()
59 | self.show_budget()
60 |
61 | def get(self, indices=None, need_copy=False):
62 | if not hasattr(indices, '__iter__'):
63 | indices = range(len(self.label_syn))
64 |
65 | if need_copy:
66 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach())
67 | mask = copy.deepcopy(self.mask[indices].detach())
68 | else:
69 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices]
70 | mask = self.mask[indices]
71 |
72 | image_syn = self.dct.inverse(mask * freq_syn)
73 | return image_syn, label_syn
74 |
75 | def init_mask(self):
76 | save_and_print(self.args.log_path, "Initialize Mask")
77 |
78 | for c in range(self.num_classes):
79 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach())
80 | freq_c = torch.mean(freq_c, 1)
81 | freqs_flat = torch.flatten(freq_c, 1)
82 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0)
83 |
84 | try:
85 | cov = torch.cov(freqs_flat.T)
86 |             except Exception:
87 |                 save_and_print(self.args.log_path, "Cannot use torch.cov; falling back to np.cov")
88 | cov = np.cov(freqs_flat.T.cpu().numpy())
89 | cov = torch.tensor(cov, dtype=torch.float, device=self.device)
90 | total_variance = torch.sum(torch.diag(cov))
91 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device)
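   |             # pc_low below is a one-hot basis vector, so the quadratic form just reads cov[idx, idx]:
   |             # this loop computes the explained-variance ratio of each individual frequency coefficient.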
92 | for idx in range(np.prod(self.im_size)):
93 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1)
94 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low)
95 | explained_variance_ratio = vector_variance / total_variance
96 | vr_fl2f[idx] = explained_variance_ratio.item()
97 |
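   |             # Keep the msz_per_channel coefficients with the largest variance ratio; the resulting
   |             # binary mask is shared by every synthetic instance and every channel of class c.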
98 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel)
99 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:]
100 | for h, w in top_indices:
101 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, h, w] = 1.0
102 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}")
103 |
104 | ### Visualize and Save ###
105 | indices_save = np.arange(10) * self.num_per_class
106 | grid = torchvision.utils.make_grid(self.mask[indices_save], nrow=10)
107 | plt.imshow(np.transpose(grid.detach().cpu().numpy(), (1, 2, 0)))
108 | plt.savefig(f"{self.args.save_path}/Mask.png", dpi=300)
109 | plt.close()
110 |
111 | mask_save = copy.deepcopy(self.mask.detach())
112 | torch.save(mask_save.cpu(), os.path.join(self.args.save_path, "mask.pt"))
113 | del mask_save
114 |
115 | def optim_zero_grad(self):
116 | self.optimizers.zero_grad()
117 |
118 | def optim_step(self):
119 | self.optimizers.step()
120 |
121 | def show_budget(self):
122 | save_and_print(self.log_path, '=' * 50)
123 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}")
124 | images, _ = self.get(need_copy=True)
125 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}")
126 | del images
127 | save_and_print(self.log_path, '=' * 50)
128 |
129 | def main(args):
130 |
131 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
132 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args)
133 | args.channel, args.im_size, args.num_classes, args.mean, args.std = channel, im_size, num_classes, mean, std
134 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
135 |
136 | accs_all_exps = dict() # record performances of all experiments
137 | for key in model_eval_pool:
138 | accs_all_exps[key] = []
139 |
140 | data_save = []
141 |
142 | save_and_print(args.log_path, f'\n================== Exp {0} ==================\n ')
143 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')
144 |
145 | ''' organize the real dataset '''
146 | images_all, labels_all, indices_class = build_dataset(dst_train, class_map, num_classes)
147 | images_all, labels_all = images_all.to(args.device), labels_all.to(args.device)
148 |
149 | ''' initialize the synthetic data'''
150 | synset = SynSet(args)
151 | synset.init(images_all, labels_all, indices_class)
152 |
153 | ''' training '''
154 | criterion = nn.CrossEntropyLoss().to(args.device)
155 | save_and_print(args.log_path, '%s training begins'%get_time())
156 |
157 | best_acc = {"{}".format(m): 0 for m in model_eval_pool}
158 | best_std = {m: 0 for m in model_eval_pool}
159 | eval_pool_dict = get_eval_lrs(args)
160 |
161 | save_this_it = False
162 | for it in range(args.Iteration+1):
163 |
164 | if it in eval_it_pool and it > 0:
165 | image_syn_eval, label_syn_eval = synset.get(need_copy=True)
166 | save_this_it = eval_loop(latents=image_syn_eval, f_latents=None, label_syn=label_syn_eval, G=None, best_acc=best_acc,
167 | best_std=best_std, testloader=testloader,
168 | model_eval_pool=model_eval_pool, channel=channel, num_classes=num_classes,
169 | im_size=im_size, it=it, args=args)
170 |
171 | del image_syn_eval, label_syn_eval
172 |
173 | ''' Train synthetic data '''
174 | net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width).to(args.device) # get a random model
175 | net.train()
176 | net_parameters = list(net.parameters())
177 |         optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer for the randomly sampled network
178 | optimizer_net.zero_grad()
179 | loss_avg = 0
180 |         args.dc_aug_param = None  # Mute the DC augmentation when learning synthetic data (in the inner-loop epoch function) in order to be consistent with the DC paper.
181 |
182 | for ol in range(args.outer_loop):
183 |
184 | ''' freeze the running mu and sigma for BatchNorm layers '''
185 | # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
186 | # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
187 | # This would make the training with BatchNorm layers easier.
188 |
189 | BN_flag = False
190 | BNSizePC = 16 # for batch normalization
191 | for module in net.modules():
192 | if 'BatchNorm' in module._get_name(): #BatchNorm
193 | BN_flag = True
194 | if BN_flag:
195 | img_real = torch.cat([get_images(images_all, indices_class, c, BNSizePC) for c in range(num_classes)], dim=0)
196 | net.train() # for updating the mu, sigma of BatchNorm
197 | output_real = net(img_real) # get running mu, sigma
198 | for module in net.modules():
199 | if 'BatchNorm' in module._get_name(): #BatchNorm
200 | module.eval() # fix mu and sigma of every BatchNorm layer
201 |
202 | ''' update synthetic data '''
203 | loss = torch.tensor(0.0).to(args.device)
204 | for c in range(num_classes):
205 | img_real = get_images(images_all, indices_class, c, args.batch_real)
206 | lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
207 |
208 | if args.batch_syn > 0:
209 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn]
210 | else:
211 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class)
212 |
213 | img_syn, lab_syn = synset.get(indices=indices)
214 |
215 | if args.dsa:
216 | seed = int(time.time() * 1000) % 100000
217 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
218 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)
219 |
220 | output_real = net(img_real)
221 | loss_real = criterion(output_real, lab_real)
222 | gw_real = torch.autograd.grad(loss_real, net_parameters)
223 | gw_real = list((_.detach().clone() for _ in gw_real))
224 |
225 | output_syn = net(img_syn)
226 | loss_syn = criterion(output_syn, lab_syn)
227 | gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)
228 |
229 | loss += match_loss(gw_syn, gw_real, args)
230 |
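   |             # Backpropagation flows through the inverse DCT into freq_syn; coefficients zeroed out by
   |             # the fixed mask receive zero gradient, so only the selected frequencies are updated.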
231 | synset.optim_zero_grad()
232 | loss.backward()
233 | synset.optim_step()
234 | loss_avg += loss.item()
235 |
236 | if ol == args.outer_loop - 1:
237 | break
238 |
239 | ''' update network '''
240 | image_syn_train, label_syn_train = synset.get(need_copy=True)
241 | dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
242 | trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
243 | for il in range(args.inner_loop):
244 | epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)
245 |
246 | loss_avg /= (num_classes*args.outer_loop)
247 |
248 | if it%10 == 0:
249 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))
250 |
251 | if it == args.Iteration: # only record the final results
252 | data_save.append([synset.get(need_copy=True)])
253 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%dipc.pt'%(args.dataset, args.model, args.ipc)))
254 |
255 |
256 | if __name__ == '__main__':
257 | import shared_args
258 |
259 | parser = shared_args.add_shared_args()
260 |
261 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
262 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
263 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening")
264 |
265 | parser.add_argument('--seed', type=int, default=0)
266 | parser.add_argument('--sh_file', type=str)
267 | parser.add_argument('--FLAG', type=str, default="TEST")
268 |
269 | ### FreD ###
270 | parser.add_argument('--batch_syn', type=int)
271 | parser.add_argument('--msz_per_channel', type=int)
272 | parser.add_argument('--lr_freq', type=float)
273 | parser.add_argument('--mom_freq', type=float)
274 | args = parser.parse_args()
275 | args.space = "p"
276 | args.zca = False
277 |
278 | set_seed(args.seed)
279 |
280 | args.outer_loop, args.inner_loop = get_loops(args.ipc)
281 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
282 | args.dsa_param = ParamDiffAug()
283 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
284 |
285 | if not os.path.exists(args.data_path):
286 | os.mkdir(args.data_path)
287 |
288 | if not os.path.exists(args.save_path):
289 | os.mkdir(args.save_path)
290 | args.save_path = args.save_path + f"/{args.FLAG}"
291 | if not os.path.exists(args.save_path):
292 | os.mkdir(args.save_path)
293 |
294 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}")
295 | args.log_path = f"{args.save_path}/log.txt"
296 |
297 | main(args)
298 |
299 |
300 |
--------------------------------------------------------------------------------
/ImageNet-abcde/buffer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from tqdm import tqdm
7 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
8 | import copy
9 |
10 | def main(args):
11 |
12 | outer_loop_default, inner_loop_default = get_loops(args.ipc)
13 |
14 | args.lr_net = args.lr_teacher
15 |
16 | if args.outer_loop is None:
17 | args.outer_loop = outer_loop_default
18 | if args.inner_loop is None:
19 | args.inner_loop = inner_loop_default
20 |
21 | if args.g_train_ipc is None:
22 | args.g_train_ipc = args.ipc
23 | if args.g_eval_ipc is None:
24 | args.g_eval_ipc = args.ipc
25 | if args.g_grad_ipc is None:
26 | args.g_grad_ipc = args.ipc
27 |
28 | args.dsa = True if args.dsa == 'True' else False
29 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
30 | args.dsa_param = ParamDiffAug()
31 | args.dsa_param.blur_perc = args.blur_perc
32 | args.dsa_param.blur_min = args.blur_min
33 | args.dsa_param.blur_max = args.blur_max
34 |
35 | if args.lr_img is None:
36 | if args.space == 'p':
37 | args.lr_img = 0.1
38 | else:
39 | args.lr_img = 0.01
40 |
41 | if args.batch_syn is None:
42 | args.batch_syn = args.ipc
43 |
44 | args.num_batches = args.ipc // args.batch_syn
45 |
46 | run_name = ""
47 |
48 | if args.clip:
49 | run_name += "clip_"
50 |
51 | run_name += "{}_".format(args.model)
52 |
53 | run_name += "{}_".format(args.dataset)
54 |
55 | if args.dataset == "ImageNet":
56 | run_name += "{}_".format(args.subset)
57 | run_name += "{}_".format(args.res)
58 |
59 |
60 | run_name += "space_{}_".format(args.space)
61 | if args.space != "p":
62 | run_name += "tanh_{}_".format(args.tanh)
63 | run_name += "proj_{}_".format(args.proj_ball)
64 | run_name += "trunc_{}_".format(args.trunc)
65 | run_name += "layer_{}_".format(args.layer)
66 |
67 | if args.space == "p":
68 | run_name += "init_{}_".format(args.pix_init)
69 | elif args.space == "z":
70 | run_name += "init_{}_".format(args.gan_init)
71 | run_name += "RandCond_{}_".format(args.rand_cond)
72 | run_name += "RandLat_{}_".format(args.rand_lat)
73 |
74 | elif args.patch:
75 | run_name += "patch_"
76 |
77 | elif args.rand_gen:
78 | run_name += "rand-gen_"
79 |
80 | if args.spec_proj:
81 | run_name += "spec-proj_"
82 |
83 | if args.spec_reg is not None:
84 | run_name += "spec-reg_{}_".format(args.spec_reg)
85 |
86 |
87 | run_name += "aug_{}_".format(args.dsa)
88 |
89 | run_name += "ipc_{}_".format(args.ipc)
90 |
91 | run_name += "batch_{}_".format(args.batch_syn)
92 |
93 | run_name += "ol_{}_il_{}_".format(args.outer_loop, args.inner_loop)
94 |
95 | run_name += "im-opt_{}_".format(args.im_opt)
96 |
97 | run_name += "eval_{}_".format(args.eval_mode)
98 |
99 | args.save_path = os.path.join(args.save_path, run_name)
100 |
101 | if not os.path.exists(args.data_path):
102 | os.mkdir(args.data_path)
103 |
104 | if not os.path.exists(args.save_path):
105 | os.makedirs(args.save_path)
106 |
107 | # eval_it_pool = np.arange(0, args.syn_batches*args.Iteration+1, 100).tolist() if args.eval_mode == 'S' else [args.Iteration] # The list of iterations when we evaluate models and record results.
108 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
109 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args)
110 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
111 |
112 | im_res = im_size[0]
113 |
114 | # print('\n================== Exp %d ==================\n '%exp)
115 | print('Hyper-parameters: \n', args.__dict__)
116 | print('Evaluation model pool: ', model_eval_pool)
117 |
118 | save_dir = os.path.join(args.buffer_path, args.dataset)
119 | if args.pm1:
120 | save_dir = os.path.join(args.buffer_path, "tanh", args.dataset)
121 | else:
122 | save_dir = os.path.join(args.buffer_path, args.dataset)
123 | # if args.dataset == "ImageNet":
124 | # save_dir = os.path.join(save_dir, args.subset, str(args.res))
125 | if args.dataset in ["CIFAR10", "CIFAR100"] and args.zca:
126 | save_dir += "_ZCA"
127 | save_dir = os.path.join(save_dir, args.model)
128 |
129 | save_dir = os.path.join(save_dir, "depth-{}".format(args.depth), "width-{}".format(args.width))
130 |
131 | save_dir = os.path.join(save_dir, args.norm_train)
132 |
133 | if not os.path.exists(save_dir):
134 | os.makedirs(save_dir)
135 |
136 |
137 |     if args.dataset != "ImageNet" or True:  # the "or True" disables the check: the real dataset is always organized in memory
138 | ''' organize the real dataset '''
139 | images_all = []
140 | labels_all = []
141 | indices_class = [[] for c in range(num_classes)]
142 | print(len(dst_train))
143 | print("BUILDING DATASET")
144 | for i in tqdm(range(len(dst_train))):
145 | sample = dst_train[i]
146 | images_all.append(torch.unsqueeze(sample[0], dim=0))
147 | labels_all.append(class_map[torch.tensor(sample[1]).item()])
148 | # images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in tqdm(range(len(dst_train)))]
149 | # labels_all = [class_map[dst_train[i][1]] for i in tqdm(range(len(dst_train)))]
150 | for i, lab in tqdm(enumerate(labels_all)):
151 | indices_class[lab].append(i)
152 | images_all = torch.cat(images_all, dim=0).to("cpu")
153 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu")
154 |
155 | for c in range(num_classes):
156 | print('class c = %d: %d real images'%(c, len(indices_class[c])))
157 |
158 | for ch in range(channel):
159 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))
160 |
161 | criterion = nn.CrossEntropyLoss().to(args.device)
162 |
163 | trajectories = []
164 |
165 | dst_train = TensorDataset(copy.deepcopy(images_all.detach()), copy.deepcopy(labels_all.detach()))
166 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
167 |
168 | ''' set augmentation for whole-dataset training '''
169 | # args.dsa = False
170 | args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None)
171 | args.dc_aug_param['strategy'] = 'crop_scale_rotate' # for whole-dataset training
172 | print('DC augmentation parameters: \n', args.dc_aug_param)
173 |
174 | for it in range(0, args.Iteration):
175 |
176 | ''' Train synthetic data '''
177 | teacher_net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width, norm=args.norm_train).to(args.device) # get a random model
178 | teacher_net.train()
179 | # if torch.cuda.device_count() > 1:
180 | # teacher_net = torch.nn.DataParallel(teacher_net)
181 | lr = args.lr_teacher
182 | # teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) # optimizer_img for synthetic data
183 |         teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)  # optimizer for the teacher (expert) network
184 | teacher_optim.zero_grad()
185 |
186 | timestamps = []
187 |
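   |         # Teacher (expert) trajectory: store the initial, untrained parameters first, then append one
   |         # CPU snapshot after every training epoch below.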
188 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()])
189 |
190 | lr_schedule = [args.train_epochs // 2 + 1]
191 |
192 | for e in range(args.train_epochs):
193 |
194 |
195 | train_loss, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim,
196 | criterion=criterion, args=args, aug=True)
197 |
198 | test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None,
199 | criterion=criterion, args=args, aug=False)
200 |
201 | print("Itr: {}\tEpoch: {}\tReal Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc))
202 |
203 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()])
204 |
205 | if e in lr_schedule and args.decay:
206 | lr *= 0.1
207 | teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
208 | teacher_optim.zero_grad()
209 |
210 | trajectories.append(timestamps)
211 |
212 | if len(trajectories) == args.save_interval:
213 | n = 0
214 | while os.path.exists(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))):
215 | n += 1
216 | torch.save(trajectories, os.path.join(save_dir, "replay_buffer_{}.pt".format(n)))
217 | trajectories = []
218 |
219 | print(trajectories)
220 |
221 |
222 | if __name__ == '__main__':
223 | parser = argparse.ArgumentParser(description='Parameter Processing')
224 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
225 | # parser.add_argument('--subset', type=str, default='imagenette', help='subset')
226 | parser.add_argument('--model', type=str, default='ConvNet', help='model')
227 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
228 |     parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: same as the training model, M: multiple architectures, W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer
229 | # parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
230 | parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models')
231 | parser.add_argument('--eval_it', type=int, default=100, help='how often to evaluate')
232 | parser.add_argument('--epoch_eval_train', type=int, default=300, help='epochs to train a model with synthetic data')
233 | parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
234 | parser.add_argument('--lr_img', type=float, default=None, help='learning rate for updating synthetic images')
235 |     parser.add_argument('--lr_lr', type=float, default=None, help='learning rate for the learnable learning rate')
236 | parser.add_argument('--mom_img', type=float, default=0.5, help='momentum for updating synthetic images')
237 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='learning rate for updating network parameters')
238 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
239 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
240 | parser.add_argument('--batch_syn', type=int, default=None, help='batch size for syn data')
241 |     parser.add_argument('--batch_test', type=int, default=128, help='batch size for testing networks')
242 | parser.add_argument('--syn_batches', type=int, default=1, help='number of synthetic batches')
243 | parser.add_argument('--pix_init', type=str, default='noise', choices=["noise", "real"], help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
244 | parser.add_argument('--gan_init', type=str, default='class', choices=["class", "rand"])
245 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], help='whether to use differentiable Siamese augmentation.')
246 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
247 | parser.add_argument('--data_path', type=str, default='data', help='dataset path')
248 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')
249 | parser.add_argument('--save_path', type=str, default='results_buffer', help='path to save results')
250 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
251 | parser.add_argument('--blur', action='store_true')
252 | parser.add_argument('--patch', action='store_true')
253 | parser.add_argument('--tanh', action='store_true')
254 | parser.add_argument('--proj_ball', action='store_true')
255 | parser.add_argument('--rand_gen', action='store_true')
256 | parser.add_argument('--spec_proj', action='store_true')
257 | parser.add_argument('--spec_reg', type=float, default=None)
258 | parser.add_argument('--clip', action='store_true')
259 | parser.add_argument('--im_opt', type=str, default='sgd', choices=['sgd', 'adam'])
260 | parser.add_argument('--outer_loop', type=int, default=None)
261 | parser.add_argument('--inner_loop', type=int, default=None)
262 | parser.add_argument('--skip_epochs', type=int, default=0)
263 | parser.add_argument('--train_epochs', type=int, default=200)
264 | parser.add_argument('--trunc', type=float, default=1, help='truncation_trick')
265 | # parser.add_argument('--res', type=int, default=128, choices=[128, 256, 512], help='resolution')
266 | parser.add_argument('--res', type=int, default=128, help='resolution')
267 | parser.add_argument('--layer', type=int, default=0)
268 | parser.add_argument('--blur_perc', type=float, default=0.0)
269 | parser.add_argument('--blur_min', type=float, default=0.0)
270 | parser.add_argument('--blur_max', type=float, default=3.0)
271 | parser.add_argument('--rand_cond', action='store_true')
272 | parser.add_argument('--rand_lat', action='store_true')
273 | parser.add_argument('--coarse2fine', action='store_true')
274 | parser.add_argument('--teacher', type=str, default='fake', choices=['real', 'fake'])
275 | group = parser.add_mutually_exclusive_group()
276 | group.add_argument('--texture', action='store_true')
277 | group.add_argument('--tex_avg', action='store_true')
278 | parser.add_argument('--tex_it', type=int, default=500)
279 | parser.add_argument('--tex_batch', type=int, default=10)
280 | parser.add_argument('--lr_decay', type=str, default='none', choices=['none', 'cosine', 'linear', 'step'])
281 | # parser.add_argument('--g_grad_ipc', type=int, default=10)
282 | parser.add_argument('--g_train_ipc', type=int, default=None)
283 | parser.add_argument('--g_eval_ipc', type=int, default=None)
284 | parser.add_argument('--g_grad_ipc', type=int, default=None)
285 | parser.add_argument('--cluster', action='store_true')
286 | parser.add_argument('--zca', action='store_true')
287 | parser.add_argument('--learn_labels', action='store_true')
288 | parser.add_argument('--save_interval', type=int, default=10)
289 | parser.add_argument('--mom', type=float, default=0.0)
290 | parser.add_argument('--l2', type=float, default=0.0)
291 |     parser.add_argument('--decay', type=bool, default=False) # note: argparse with type=bool treats any non-empty string as True
292 | parser.add_argument('--cl_subset', default=None)
293 | parser.add_argument('--kip_zca', action='store_true')
294 | parser.add_argument('--pm1', action='store_true')
295 |
296 | parser.add_argument('--canvas_size', type=int, default=None)
297 | parser.add_argument('--canvas_samples', type=int, default=1)
298 | parser.add_argument('--canvas_stride', type=int, default=1)
299 |
300 | parser.add_argument('--space', type=str, default='p', choices=['p', 'z', 'w', 'w+', 'wp', 'g'], help='[ p | z | w | w+ ]')
301 |
302 | parser.add_argument('--width', type=int, default=128)
303 | parser.add_argument('--depth', type=int, default=3)
304 |
305 | parser.add_argument('--norm_train', type=str, default="batchnorm")
306 |
307 | parser.add_argument('--norm_eval', type=str, default="none")
308 |
309 |     # For speed, Iteration and epoch_eval_train can be decreased without a significant drop in performance.
310 |
311 | args = parser.parse_args()
312 | main(args)
313 |
314 |
315 |
--------------------------------------------------------------------------------
/DC/main_DC_FreD.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 |
4 | import time
5 | import copy
6 | import argparse
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from torchvision.utils import save_image
11 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, set_seed, save_and_print, get_images
12 |
13 | import shutil
14 | import torchvision
15 | import matplotlib.pyplot as plt
16 | from frequency_transforms import DCT
17 |
18 | class SynSet():
19 | def __init__(self, args):
20 | ### Basic ###
21 | self.args = args
22 | self.log_path = self.args.log_path
23 | self.channel = self.args.channel
24 | self.num_classes = self.args.num_classes
25 | self.im_size = self.args.im_size
26 | self.device = self.args.device
27 | self.ipc = self.args.ipc
28 |
29 | ### FreD ###
30 | self.dct = DCT(resolution=self.im_size[0], device=self.device)
31 | self.lr_freq = self.args.lr_freq
32 | self.mom_freq = self.args.mom_freq
33 | self.msz_per_channel = self.args.msz_per_channel
34 | self.num_per_class = int((self.ipc * self.im_size[0] * self.im_size[1]) / self.msz_per_channel)
35 |
36 | def init(self, images_real, labels_real, indices_class):
37 | ### Initialize Frequency (F) ###
38 | images = torch.randn(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device)
39 | for c in range(self.num_classes):
40 | idx_shuffle = np.random.permutation(indices_class[c])[:self.num_per_class]
41 | images.data[c * self.num_per_class:(c + 1) * self.num_per_class] = images_real[idx_shuffle].detach().data
42 | self.freq_syn = self.dct.forward(images)
43 | self.freq_syn.requires_grad = True
44 | del images
45 |
46 | ### Initialize Mask (M) ###
47 | self.mask = torch.zeros(size=(self.num_classes * self.num_per_class, self.channel, self.im_size[0], self.im_size[1]), dtype=torch.float, device=self.device)
48 | self.mask.requires_grad = False
49 |
50 | ### Initialize Label ###
51 | self.label_syn = torch.tensor([np.ones(self.num_per_class) * i for i in range(self.num_classes)], requires_grad=False, device=self.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]
52 | self.label_syn = self.label_syn.long()
53 |
54 | ### Initialize Optimizer ###
55 | self.optimizers = torch.optim.SGD([self.freq_syn, ], lr=self.lr_freq, momentum=self.mom_freq)
56 |
57 | self.init_mask()
58 | self.optim_zero_grad()
59 | self.show_budget()
60 |
61 | def get(self, indices=None, need_copy=False):
62 | if not hasattr(indices, '__iter__'):
63 | indices = range(len(self.label_syn))
64 |
65 | if need_copy:
66 | freq_syn, label_syn = copy.deepcopy(self.freq_syn[indices].detach()), copy.deepcopy(self.label_syn[indices].detach())
67 | mask = copy.deepcopy(self.mask[indices].detach())
68 | else:
69 | freq_syn, label_syn = self.freq_syn[indices], self.label_syn[indices]
70 | mask = self.mask[indices]
71 |
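   |         # Decode on the fly: the fixed binary mask zeroes unselected coefficients in the frequency
   |         # domain, then the inverse DCT maps back to image space. Only freq_syn is learnable.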
72 | image_syn = self.dct.inverse(mask * freq_syn)
73 | return image_syn, label_syn
74 |
75 | def init_mask(self):
76 | save_and_print(self.args.log_path, "Initialize Mask")
77 |
78 | for c in range(self.num_classes):
79 | freq_c = copy.deepcopy(self.freq_syn[c * self.num_per_class:(c + 1) * self.num_per_class].detach())
80 | freq_c = torch.mean(freq_c, 1)
81 | freqs_flat = torch.flatten(freq_c, 1)
82 | freqs_flat = freqs_flat - torch.mean(freqs_flat, dim=0)
83 |
84 | try:
85 | cov = torch.cov(freqs_flat.T)
86 |             except Exception:
87 |                 save_and_print(self.args.log_path, "Cannot use torch.cov; falling back to np.cov")
88 | cov = np.cov(freqs_flat.T.cpu().numpy())
89 | cov = torch.tensor(cov, dtype=torch.float, device=self.device)
90 | total_variance = torch.sum(torch.diag(cov))
91 | vr_fl2f = torch.zeros((np.prod(self.im_size), 1), device=self.device)
92 | for idx in range(np.prod(self.im_size)):
93 | pc_low = torch.eye(np.prod(self.im_size), device=self.device)[idx].reshape(-1, 1)
94 | vector_variance = torch.matmul(torch.matmul(pc_low.T, cov), pc_low)
95 | explained_variance_ratio = vector_variance / total_variance
96 | vr_fl2f[idx] = explained_variance_ratio.item()
97 |
98 | v, i = torch.topk(vr_fl2f.flatten(), self.msz_per_channel)
99 | top_indices = np.array(np.unravel_index(i.cpu().numpy(), freq_c.shape)).T[:, 1:]
100 | for h, w in top_indices:
101 | self.mask[c * self.num_per_class:(c + 1) * self.num_per_class, :, h, w] = 1.0
102 | save_and_print(self.args.log_path, f"{get_time()} Class {c:3d} | {torch.sum(self.mask[c * self.num_per_class, 0] > 0.0):5d}")
103 |
104 | ### Visualize and Save ###
105 | indices_save = np.arange(10) * self.num_per_class
106 | grid = torchvision.utils.make_grid(self.mask[indices_save], nrow=10)
107 | plt.imshow(np.transpose(grid.detach().cpu().numpy(), (1, 2, 0)))
108 | plt.savefig(f"{self.args.save_path}/Mask.png", dpi=300)
109 | plt.close()
110 |
111 | mask_save = copy.deepcopy(self.mask.detach())
112 | torch.save(mask_save.cpu(), os.path.join(self.args.save_path, "mask.pt"))
113 | del mask_save
114 |
115 | def optim_zero_grad(self):
116 | self.optimizers.zero_grad()
117 |
118 | def optim_step(self):
119 | self.optimizers.step()
120 |
121 | def show_budget(self):
122 | save_and_print(self.log_path, '=' * 50)
123 | save_and_print(self.log_path, f"Freq: {self.freq_syn.shape} | Mask: {self.mask.shape} , {torch.sum(self.mask[0, 0] > 0.0):5d}")
124 | images, _ = self.get(need_copy=True)
125 | save_and_print(self.log_path, f"Decode condensed data: {images.shape}")
126 | del images
127 | save_and_print(self.log_path, '=' * 50)
128 |
129 | def main():
130 | parser = argparse.ArgumentParser(description='Parameter Processing')
131 | parser.add_argument('--method', type=str, default='DC', help='DC/DSA')
132 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
133 | parser.add_argument('--model', type=str, default='ConvNet', help='model')
134 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
135 |     parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: same as the training model, M: multiple architectures, W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer
136 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
137 | parser.add_argument('--num_eval', type=int, default=10, help='the number of evaluating randomly initialized models')
138 | parser.add_argument('--epoch_eval_train', type=int, default=300, help='epochs to train a model with synthetic data')
139 | parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
140 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
141 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
142 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
143 | parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
144 | parser.add_argument('--data_path', type=str, default='../data', help='dataset path')
145 | parser.add_argument('--save_path', type=str, default='./results', help='path to save results')
146 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
147 |
148 | parser.add_argument('--seed', type=int, default=0)
149 | parser.add_argument('--sh_file', type=str)
150 | parser.add_argument('--FLAG', type=str, default="TEST")
151 |
152 | ### FreD ###
153 | parser.add_argument('--batch_syn', type=int)
154 | parser.add_argument('--msz_per_channel', type=int)
155 | parser.add_argument('--lr_freq', type=float)
156 | parser.add_argument('--mom_freq', type=float)
157 | args = parser.parse_args()
158 | set_seed(args.seed)
159 |
160 | args.outer_loop, args.inner_loop = get_loops(args.ipc)
161 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
162 | args.dsa_param = ParamDiffAug()
163 | args.dsa = True if args.method == 'DSA' else False
164 |
165 | if not os.path.exists(args.data_path):
166 | os.mkdir(args.data_path)
167 |
168 | if not os.path.exists(args.save_path):
169 | os.mkdir(args.save_path)
170 | args.save_path = args.save_path + f"/{args.FLAG}"
171 | if not os.path.exists(args.save_path):
172 | os.mkdir(args.save_path)
173 |
174 | shutil.copy(f"./scripts/{args.sh_file}", f"{args.save_path}/{args.sh_file}")
175 |     print(args)
176 | args.log_path = f"{args.save_path}/log.txt"
177 |
178 | eval_it_pool = np.arange(0, args.Iteration+1, 500).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results.
179 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path, args=args)
180 | args.channel, args.im_size, args.num_classes, args.mean, args.std = channel, im_size, num_classes, mean, std
181 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
182 |
183 |
184 | accs_all_exps = dict() # record performances of all experiments
185 | for key in model_eval_pool:
186 | accs_all_exps[key] = []
187 |
188 | data_save = []
189 |
190 | for exp in range(args.num_exp):
191 | save_and_print(args.log_path, f'\n================== Exp {exp} ==================\n ')
192 | save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')
193 |
194 | ''' organize the real dataset '''
195 | images_all = []
196 | labels_all = []
197 | indices_class = [[] for c in range(num_classes)]
198 |
199 | images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
200 | labels_all = [dst_train[i][1] for i in range(len(dst_train))]
201 | for i, lab in enumerate(labels_all):
202 | indices_class[lab].append(i)
203 | images_all = torch.cat(images_all, dim=0).to(args.device)
204 | labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)
205 |
206 | ''' initialize the synthetic data '''
207 | synset = SynSet(args)
208 | synset.init(images_all, labels_all, indices_class)
209 |
210 | ''' training '''
211 | criterion = nn.CrossEntropyLoss().to(args.device)
212 | save_and_print(args.log_path, '%s training begins'%get_time())
213 |
214 | for it in range(args.Iteration+1):
215 |
216 | ''' Evaluate synthetic data '''
217 | if it in eval_it_pool:
218 | for model_eval in model_eval_pool:
219 | save_and_print(args.log_path, '-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
220 | if args.dsa:
221 | args.epoch_eval_train = 1000
222 | args.dc_aug_param = None
223 | save_and_print(args.log_path, f'DSA augmentation strategy: {args.dsa_strategy}')
224 | save_and_print(args.log_path, f'DSA augmentation parameters: {args.dsa_param.__dict__}')
225 | else:
226 | args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.
227 | save_and_print(args.log_path, f'DC augmentation parameters: {args.dc_aug_param}')
228 |
229 | if args.dsa or args.dc_aug_param['strategy'] != 'none':
230 | args.epoch_eval_train = 1000 # Training with data augmentation needs more epochs.
231 | else:
232 | args.epoch_eval_train = 300
233 |
234 | accs = []
235 | for it_eval in range(args.num_eval):
236 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
237 | image_syn_eval, label_syn_eval = synset.get(need_copy=True)
238 | _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
239 | accs.append(acc_test)
240 |             save_and_print(args.log_path, 'Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))
241 |             del image_syn_eval, label_syn_eval
242 |
243 | if it == args.Iteration: # record the final results
244 | accs_all_exps[model_eval] += accs
245 |
246 | ''' visualize and save '''
247 | save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
248 | image_syn_vis, _ = synset.get(need_copy=True)
249 | for ch in range(channel):
250 | image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
251 | image_syn_vis[image_syn_vis<0] = 0.0
252 | image_syn_vis[image_syn_vis>1] = 1.0
253 |                 save_image(image_syn_vis, save_name, nrow=synset.num_per_class) # Setting normalize=True or False may give better visual results.
254 | del image_syn_vis
255 |
256 |
257 | ''' Train synthetic data '''
258 | net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
259 | net.train()
260 | net_parameters = list(net.parameters())
261 |             optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer for the randomly sampled network
262 | optimizer_net.zero_grad()
263 | loss_avg = 0
264 |             args.dc_aug_param = None # Mute the DC augmentation when learning synthetic data (in the inner-loop epoch function) in order to be consistent with the DC paper.
265 |
266 | for ol in range(args.outer_loop):
267 |
268 | ''' freeze the running mu and sigma for BatchNorm layers '''
269 | # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
270 | # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
271 | # This would make the training with BatchNorm layers easier.
272 | BN_flag = False
273 | BNSizePC = 16 # for batch normalization
274 | for module in net.modules():
275 | if 'BatchNorm' in module._get_name(): #BatchNorm
276 | BN_flag = True
277 | if BN_flag:
278 | img_real = torch.cat([get_images(images_all, indices_class, c, BNSizePC) for c in range(num_classes)], dim=0)
279 | net.train() # for updating the mu, sigma of BatchNorm
280 | output_real = net(img_real) # get running mu, sigma
281 | for module in net.modules():
282 | if 'BatchNorm' in module._get_name(): #BatchNorm
283 | module.eval() # fix mu and sigma of every BatchNorm layer
284 |
285 |
286 | ''' update synthetic data '''
287 | loss = torch.tensor(0.0).to(args.device)
288 | for c in range(num_classes):
289 | img_real = get_images(images_all, indices_class, c, args.batch_real)
290 | lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
291 |
292 | if args.batch_syn > 0:
293 | indices = np.random.permutation(range(c * synset.num_per_class, (c + 1) * synset.num_per_class))[:args.batch_syn]
294 | else:
295 | indices = range(c * synset.num_per_class, (c + 1) * synset.num_per_class)
296 |
297 | img_syn, lab_syn = synset.get(indices=indices)
298 |
299 | if args.dsa:
300 | seed = int(time.time() * 1000) % 100000
301 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
302 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)
303 |
304 | output_real = net(img_real)
305 | loss_real = criterion(output_real, lab_real)
306 | gw_real = torch.autograd.grad(loss_real, net_parameters)
307 | gw_real = list((_.detach().clone() for _ in gw_real))
308 |
309 | output_syn = net(img_syn)
310 | loss_syn = criterion(output_syn, lab_syn)
311 | gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)
312 |
313 | loss += match_loss(gw_syn, gw_real, args)
314 |
315 | synset.optim_zero_grad()
316 | loss.backward()
317 | synset.optim_step()
318 | loss_avg += loss.item()
319 |
320 | if ol == args.outer_loop - 1:
321 | break
322 |
323 |
324 | ''' update network '''
325 | image_syn_train, label_syn_train = synset.get(need_copy=True)
326 | dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
327 | trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
328 | for il in range(args.inner_loop):
329 | epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)
330 |
331 |
332 | loss_avg /= (num_classes*args.outer_loop)
333 |
334 | if it%10 == 0:
335 | save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))
336 |
337 | if it == args.Iteration: # only record the final results
338 | data_save.append([synset.get(need_copy=True)])
339 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))
340 |
341 |
342 | save_and_print(args.log_path, '\n==================== Final Results ====================\n')
343 | for key in model_eval_pool:
344 | accs = accs_all_exps[key]
345 | save_and_print(args.log_path, 'Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))
346 |
347 |
348 |
349 | if __name__ == '__main__':
350 | main()
351 |
352 |
353 |
--------------------------------------------------------------------------------