├── requirements.txt
├── figures
│   └── teaser.png
├── Evaluation
│   ├── fms_224.pkl
│   ├── visualize_DFM.py
│   ├── compute_ADCS.py
│   ├── verify_mask_imgn.py
│   └── test_rank.py
├── LICENSE
├── backbone
│   ├── vgg.py
│   ├── alexnet.py
│   └── resnet.py
├── data
│   └── Synthetic.py
├── blocks
│   └── resnet
│       └── Blocks.py
├── README.md
└── train.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | matplotlib
3 | pytorch_lightning
4 | torchmetrics
5 | timm
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nis-research/nn-frequency-shortcuts/HEAD/figures/teaser.png
--------------------------------------------------------------------------------
/Evaluation/fms_224.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nis-research/nn-frequency-shortcuts/HEAD/Evaluation/fms_224.pkl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 nis-research
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Evaluation/visualize_DFM.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 | def main(args):
 7 |     dir = './DFMs/'  # directory where test_rank.py stores the DFM pickles
8 | imagenet_classes = ['Airliner','Wagon','Humming\n Bird','Siamese\n Cat','Ox','Golden\n Retriever','Tailed\n Frog','Zebra','Container\n Ship','Trailer\n Truck']
9 | m_path = args.DFMs+'.pkl'
10 | with open(dir+m_path, 'rb') as f:
11 | all_mask = pickle.load(f)
12 | fig, axs = plt.subplots(1,10,sharex=True,sharey=True)
13 | fig.set_figheight(15)
14 | fig.set_figwidth(15)
15 | for mask_i in range(len(all_mask)):
16 | map = np.array(all_mask[mask_i])
17 | axs[mask_i].imshow(map,cmap='gray')
18 | axs[mask_i].set_title(imagenet_classes[mask_i])
19 | axs[mask_i].set_yticks([])
20 | axs[mask_i].set_xticks([])
21 |
22 | plt.subplots_adjust(left=0.2, bottom=0.1, right=0.8, top=0.8, wspace=0.05, hspace=-0.85)
23 | plt.savefig(dir + args.DFMs + '.pdf',bbox_inches='tight')
24 |
25 |
26 |
27 | if __name__ == '__main__':
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--DFMs', type=str, default='resnet18_DFM_1',
30 | help='File name of DFMs')
31 |
32 | args = parser.parse_args()
33 |
34 | main(args)
--------------------------------------------------------------------------------
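A minimal sketch of inspecting a saved DFM pickle directly (assuming a file produced by `Evaluation/test_rank.py`, e.g. `./DFMs/resnet18_DFM_1.pkl`; the pickle maps class indices to binary frequency masks):

```python
import pickle
import numpy as np

# Hypothetical path: any DFM pickle written by test_rank.py should work here.
with open('./DFMs/resnet18_DFM_1.pkl', 'rb') as f:
    dfms = pickle.load(f)

for class_idx, mask in dfms.items():
    mask = np.array(mask)
    # Fraction of frequencies the DFM retains for this class.
    print(class_idx, mask.shape, mask.sum() / mask.size)
```
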
/backbone/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | cfg = {
4 | 'VGG8': [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],
5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
6 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
7 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
9 | }
10 |
11 |
12 | class VGG(nn.Module):
13 | def __init__(self, vgg_name,num_class):
14 | super(VGG, self).__init__()
15 | self.features = self._make_layers(cfg[vgg_name])
16 | self.classifier = nn.Linear(self.in_planes, num_class)
17 |
18 | def forward(self, x):
19 | out = self.features(x)
20 | enc = F.avg_pool2d(out, out.size(2))
21 | enc = enc.view(enc.size(0), -1)
22 | prediction = self.classifier(enc)
23 | return prediction
24 |
25 | def _make_layers(self, cfg):
26 | layers = []
27 | in_channels = 3
28 | for x in cfg:
29 | if x == 'M':
30 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
31 | else:
32 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
33 | nn.BatchNorm2d(x),
34 | nn.ReLU(inplace=True)]
35 | in_channels = x
36 | # layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
37 | self.in_planes = in_channels
38 |
39 | return nn.Sequential(*layers)
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
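A quick shape check for the VGG backbone above (a sketch; the 32x32 input matches the synthetic-dataset setting used elsewhere in this repository):

```python
import torch
from backbone.vgg import VGG

model = VGG('VGG16', num_class=10)
x = torch.randn(2, 3, 32, 32)  # five 2x2 max-pools reduce 32 -> 1
print(model(x).shape)          # torch.Size([2, 10])
```
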
/backbone/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class AlexNet(nn.Module):
5 | def __init__(self, num_classes=11):
6 | super(AlexNet, self).__init__()
7 | self.features = nn.Sequential(
8 | nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=1,bias=False),
9 |
10 | nn.ReLU(inplace=True),
11 | nn.BatchNorm2d(96),
12 | nn.MaxPool2d(kernel_size=2),
13 |
14 | nn.Conv2d(96, 256, kernel_size=3, padding=1,bias=False),
15 | nn.ReLU(inplace=True),
16 | nn.BatchNorm2d(256),
17 | nn.MaxPool2d(kernel_size=2),
18 |
19 | nn.Conv2d(256, 384, kernel_size=3, padding=1,bias=False),
20 | nn.ReLU(inplace=True),
21 | nn.BatchNorm2d(384),
22 |
23 | nn.Conv2d(384, 384, kernel_size=3, padding=1,bias=False),
24 | nn.ReLU(inplace=True),
25 | nn.BatchNorm2d(384),
26 |
27 | nn.Conv2d(384, 256, kernel_size=3, padding=1,bias=False),
28 | nn.ReLU(inplace=True),
29 | nn.BatchNorm2d(256),
30 | nn.MaxPool2d(kernel_size=2),
31 | )
32 | self.classifier = nn.Sequential(
33 | nn.Dropout(),
34 | nn.Linear(256 * 2 * 2, 4096,bias=False),
35 | nn.ReLU(inplace=True),
36 | nn.Dropout(),
37 | nn.Linear(4096, 4096,bias=False),
38 | nn.ReLU(inplace=True),
39 | nn.Linear(4096, num_classes,bias=False)
40 | )
41 |
42 |     def forward(self, x):
43 |
44 |         x = self.features(x.float())
45 |         enc = x.view(x.size(0), 256 * 2 * 2)  # flatten; assumes 32x32 inputs (-> 256x2x2 feature maps)
46 |         prediction = self.classifier(enc)
47 |         return prediction
--------------------------------------------------------------------------------
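Because the flatten hard-codes 256 * 2 * 2 features, this AlexNet variant assumes 32x32 inputs: the stride-2 first conv and three max-pools reduce 32 -> 16 -> 8 -> 4 -> 2. A minimal shape check, as a sketch:

```python
import torch
from backbone.alexnet import AlexNet

model = AlexNet(num_classes=10)
x = torch.randn(2, 3, 32, 32)
print(model(x).shape)  # torch.Size([2, 10])
```
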
/data/Synthetic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data.dataset import Dataset
4 | from PIL import Image
5 |
6 | import torch
7 |
8 | class Synthetic(Dataset):
9 | def __init__(self, root_dir,train=True, transform = None, complex = '', band = '',t=True):
10 | super(Synthetic).__init__()
11 |
12 | if train is False:
13 |
14 | self.labels_path = os.path.join(root_dir,'synthetic','test_label'+complex +'.npy')
15 | self.root_dir = os.path.join(root_dir,'synthetic','test_data'+complex +band+'.npy')
16 | else:
17 | self.labels_path = os.path.join(root_dir,'synthetic','train_label'+complex+'.npy')
18 | self.root_dir = os.path.join(root_dir,'synthetic','train_data'+complex+band+'.npy')
19 |
20 |
21 | print(self.root_dir)
22 | self.transform = transform
23 | self.data = np.load(self.root_dir, allow_pickle=True)
24 | self.targets = np.load(self.labels_path, allow_pickle=True)
25 | self.band = band
26 | self.t = t
27 | # self.data = self.data.transpose((0, 3, 1, 2))
28 |
29 |
30 | def __len__(self):
31 | return len(self.data)
32 |
33 | def __getitem__(self, index):
34 | if self.t:
35 | img = self.data[index].permute(1,2,0).numpy()
36 | else:
37 | img = self.data[index]
38 |
39 | if self.transform is not None:
40 | # print(np.max(img))
41 | # print(np.min(img))
42 | img = np.clip(img,0,1)
43 | img = img*255
44 | img = Image.fromarray(img.astype(np.uint8),mode='RGB')
45 |
46 | img = self.transform(img)
47 |
48 | target = self.targets[index]
49 |
50 | return img, torch.tensor(target, dtype=torch.long)
--------------------------------------------------------------------------------
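A usage sketch for the `Synthetic` dataset. Assumptions: the .npy files from the linked Google Drive folder sit under `./data/synthetic/`, they store channel-first torch tensors (so the default `t=True` applies), and the suffix follows the `--special` convention from the README; the exact file names are illustrative:

```python
import torchvision.transforms as transforms
from data.Synthetic import Synthetic

transform = transforms.Compose([transforms.ToTensor()])
# Hypothetically loads ./data/synthetic/train_data_complex_special_1_par.npy (Syn_1).
trainset = Synthetic('./data', train=True, transform=transform,
                     complex='_complex_special_1_par', band='')
img, target = trainset[0]
print(img.shape, target)
```
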
/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 | class ResNet(nn.Module):
8 | def __init__(self, block_en, num_blocks,num_class):
9 | super(ResNet, self).__init__()
10 |
11 | self.in_planes = 64
12 | self.num_class = num_class
13 |
14 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
15 | stride=1, padding=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(64)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.layer1 = self._make_layer(block_en, 64, num_blocks[0], stride=1)
19 | self.layer2 = self._make_layer(block_en, 128, num_blocks[1], stride=2)
20 | self.layer3 = self._make_layer(block_en, 256, num_blocks[2], stride=2)
21 | self.layer4 = self._make_layer(block_en, 512, num_blocks[3], stride=2)
22 |
23 | self.features = nn.Sequential(self.conv1, self.bn1,self.relu,self.layer1,self.layer2,self.layer3,self.layer4)
24 |
25 |
26 | # self.sm = nn.Softmax(dim=1)
27 | self.classifier = nn.Linear(512*block_en.expansion,self.num_class)
28 |
29 | def _make_layer(self, block, planes, num_blocks, stride):
30 | strides = [stride] + [1]*(num_blocks-1)
31 | layers = []
32 | for stride in strides:
33 | layers.append(block(self.in_planes, planes, stride,shortcut=True))
34 | self.in_planes = planes*block.expansion
35 | return nn.Sequential(*layers)
36 |
37 |
38 | def forward(self, x):
39 | # out = F.relu(self.bn1(self.conv1(x)))
40 | enc = self.features(x)
41 | # print(enc.size())
42 | # out = self.layer1(out)
43 | # out = self.layer2(out)
44 | # out = self.layer3(out)
45 | # out = self.layer4(out)
46 | enc = F.avg_pool2d(enc, enc.size(2))
47 | enc = enc.view(enc.size(0), -1) # flatten
48 |
49 | prediction = self.classifier(enc)
50 | # prediction = self.sm(prediction)
51 | return prediction
52 |
--------------------------------------------------------------------------------
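The block type and block counts select the ResNet variant, mirroring the construction in train.py. A sketch:

```python
import torch
from blocks.resnet.Blocks import BasicBlock, Bottleneck
import backbone.resnet as resnet

resnet18 = resnet.ResNet(BasicBlock, [2, 2, 2, 2], num_class=10)
resnet50 = resnet.ResNet(Bottleneck, [3, 4, 6, 3], num_class=10)

x = torch.randn(2, 3, 224, 224)
print(resnet18(x).shape)  # torch.Size([2, 10])
```
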
/Evaluation/compute_ADCS.py:
--------------------------------------------------------------------------------
1 | from scipy import signal
2 | from scipy.ndimage import gaussian_filter
3 | import numpy.fft as fft
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import torchvision
7 | import torch
8 | import torchvision.transforms as transforms
9 | from torchvision.datasets import ImageFolder
10 |
11 | def distance(i, j, imageSize, r1,r2):
12 | dis = np.sqrt((i - imageSize/2) ** 2 + (j - imageSize/2) ** 2)
13 | if dis < r2 and dis >=r1:
14 | return 1.0
15 | else:
16 | return 0
17 |
18 | def mask_radial(img, r1,r2):
19 | rows, cols = img.shape
20 | mask = np.zeros((rows, cols))
21 | for i in range(rows):
22 | for j in range(cols):
23 | mask[i, j] = distance(i, j, imageSize=rows, r1=r1,r2=r2)
24 | return mask
25 |
26 |
27 | def rgb2gray(rgb):
28 | r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
29 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
30 | return gray
31 |
32 |
33 |
34 | Energy = {}  # per-class accumulated magnitude spectra
35 |
36 | mean = [0.479838, 0.470448, 0.429404]
37 | std = [0.258143, 0.252662, 0.272406]
38 | transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(),transforms.Normalize(mean, std)])
39 | batchsize = 1
40 | data_test = ImageFolder('./datasets/ImageNet/val/',transform=transform) # data path to be changed
41 | test_loader = torch.utils.data.DataLoader(dataset=data_test, batch_size=batchsize, shuffle=False)
42 | img_size = 224
43 | for x,y in test_loader:
44 | x1=x[0]
45 | y1 = np.zeros((img_size,img_size,3),dtype=np.complex128)
46 | for j in range(3):
47 | y1[:,:,j] = fft.fftshift(fft.fft2(x1[j,:,:]))
48 | y1[y1==0] = 12e-12
49 | abs_y1 = np.abs(y1)
50 | if y.item() in Energy:
51 | Energy[y.item()] += abs_y1
52 | else:
53 | Energy.update({y.item():abs_y1})
54 |
55 |
56 | fig, axs = plt.subplots(2,5,sharex=True,sharey=True)
57 | fig.set_figheight(8)
58 | fig.set_figwidth(20)
59 | for j in range(10):
60 | olp = np.zeros((img_size,img_size))
61 | for i in range(10):
62 | diff = Energy[j] - Energy[i]
63 | diff = rgb2gray(diff)
64 | diff[diff>0] = 1
65 | diff[diff<=0] = -1
66 | olp += diff
67 | if j >=5:
68 | axs[1,j-5].imshow(olp,cmap='jet',vmin=-9,vmax=9)
69 | axs[1,j-5].axis('off')
70 | axs[1,j-5].set_title('Class: %d' %j)
71 | else:
72 | axs[0,j].imshow(olp,cmap='jet',vmin=-9,vmax=9)
73 | axs[0,j].axis('off')
74 | axs[0,j].set_title('Class: %d' %j)
75 | plt.rcParams.update({'font.size': 25})
76 | plt.savefig('ADCS_imagenet10.pdf')
--------------------------------------------------------------------------------
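The `distance`/`mask_radial` helpers above build ring-shaped band-pass masks in the shifted frequency plane; the script defines them but never calls them. A sketch of how such a mask filters an image (the radii are illustrative; the inline `mask_radial` is a vectorized equivalent of the one above for square images):

```python
import numpy as np
import numpy.fft as fft

def mask_radial(img, r1, r2):
    # Same ring mask as in compute_ADCS.py above, vectorized.
    rows, cols = img.shape
    yy, xx = np.ogrid[:rows, :cols]
    dis = np.sqrt((yy - rows / 2) ** 2 + (xx - cols / 2) ** 2)
    return ((dis >= r1) & (dis < r2)).astype(float)

img = np.random.rand(224, 224)          # stand-in grayscale image
band = mask_radial(img, r1=28, r2=56)   # keep the ring between radii 28 and 56
spec = fft.fftshift(fft.fft2(img))
filtered = np.real(fft.ifft2(fft.ifftshift(spec * band)))
```
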
/blocks/resnet/Blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class BasicBlock(nn.Module):
5 | expansion = 1
6 |
7 | def __init__(self, in_planes, planes, stride=1,shortcut=True):
8 | super(BasicBlock, self).__init__()
9 | self.conv1 = nn.Conv2d(
10 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
11 | self.bn1 = nn.BatchNorm2d(planes)
12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
13 | stride=1, padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(planes)
15 | self.shortcut_flag = shortcut
16 | self.shortcut = nn.Sequential()
17 |
18 | if stride != 1 or in_planes != self.expansion*planes:
19 | self.shortcut = nn.Sequential(
20 | nn.Conv2d(in_planes, self.expansion*planes,
21 | kernel_size=1, stride=stride, bias=False),
22 | nn.BatchNorm2d(self.expansion*planes)
23 | )
24 |
25 | def forward(self, x):
26 | out = F.relu(self.bn1(self.conv1(x)))
27 | out = self.bn2(self.conv2(out))
28 |         if self.shortcut_flag:
29 | out += self.shortcut(x)
30 |
31 | out = F.relu(out)
32 |
33 | return out
34 |
35 |
36 | class Upconvblock(nn.Module):
37 | expansion = 16
38 | def __init__(self,in_planes,output_channels, stride=1):
39 | super(Upconvblock,self).__init__()
40 | if stride == 1:
41 | self.scaleup = nn.Conv2d(in_planes,output_channels,kernel_size=3,padding=1)
42 | elif stride == 2:
43 | self.scaleup = nn.ConvTranspose2d(in_planes,output_channels,kernel_size=2,stride=stride)
44 | # self.scaleup = nn.ConvTranspose2d(in_planes,output_channels,kernel_size=2,stride=2)
45 | # self.scaleup = nn.Upsample(scale_factor=stride,mode="bilinear",align_corners=True)
46 | self.conv1 = nn.Conv2d(output_channels,output_channels,stride = 1,kernel_size=3,padding=1)
47 | self.bn1 = nn.BatchNorm2d(output_channels)
48 | self.conv2 = nn.Conv2d(output_channels,output_channels,stride=1,kernel_size=3,padding=1)
49 | self.bn2 = nn.BatchNorm2d(output_channels)
50 |
51 | def forward(self, x):
52 | x = self.scaleup(x)
53 | x = F.relu(self.bn1(self.conv1(x)))
54 | x = F.relu(self.bn2(self.conv2(x)))
55 |
56 | return x
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | expansion = 4
61 |
62 | def __init__(self, in_planes, planes, stride=1,shortcut=True):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=stride, padding=1, bias=False)
67 | self.bn2 = nn.BatchNorm2d(planes)
68 | self.conv3 = nn.Conv2d(planes, planes*self.expansion, kernel_size=1, bias=False)
69 | self.bn3 = nn.BatchNorm2d(planes*self.expansion)
70 | self.shortcut_flag = shortcut
71 | self.shortcut = nn.Sequential()
72 |
73 | if stride != 1 or in_planes != self.expansion*planes:
74 | self.shortcut = nn.Sequential(
75 | nn.Conv2d(in_planes, self.expansion*planes,
76 | kernel_size=1, stride=stride, bias=False),
77 | nn.BatchNorm2d(self.expansion*planes)
78 | )
79 |
80 | def forward(self, x):
81 | out = F.relu(self.bn1(self.conv1(x)))
82 | out = F.relu(self.bn2(self.conv2(out)))
83 |
84 | out = self.bn3(self.conv3(out))
85 |         if self.shortcut_flag:
86 | out += self.shortcut(x)
87 |
88 | out = F.relu(out)
89 | return out
--------------------------------------------------------------------------------
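A shape sketch contrasting the two residual blocks: BasicBlock keeps the channel count (expansion 1), while Bottleneck multiplies it by 4, and its 1x1 projection shortcut matches shapes when the stride or width changes:

```python
import torch
from blocks.resnet.Blocks import BasicBlock, Bottleneck

x = torch.randn(1, 64, 56, 56)
basic = BasicBlock(64, 64, stride=1)   # identity shortcut: shapes already match
bottle = Bottleneck(64, 64, stride=2)  # projection shortcut: 64 -> 256 channels, 56 -> 28
print(basic(x).shape)   # torch.Size([1, 64, 56, 56])
print(bottle(x).shape)  # torch.Size([1, 256, 28, 28])
```
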
/README.md:
--------------------------------------------------------------------------------
1 | ## Official repository of 'What do neural networks learn in image classification? A frequency shortcut perspective' (ICCV 2023)
2 | #### The paper is available on [arXiv](https://arxiv.org/abs/2307.09829).
3 |
4 | ### Introduction
5 |
 6 | Frequency analysis is useful for understanding the mechanisms of representation learning in neural networks (NNs). Most research in this area focuses on the learning dynamics of NNs for regression tasks, while little attention is paid to classification. This study empirically investigates the latter and expands the understanding of frequency shortcuts. First, we perform experiments on synthetic datasets, designed to have a bias in different frequency bands. Our results demonstrate that NNs tend to find simple solutions for classification, and that what they learn first during training depends on the most distinctive frequency characteristics, which can be either low- or high-frequency.
 7 | Second, we confirm this phenomenon on natural images. We propose a metric to measure class-wise frequency characteristics and a method to identify frequency shortcuts. The results show that frequency shortcuts can be texture-based or shape-based, depending on what best simplifies the objective.
 8 | Third, we validate the transferability of frequency shortcuts on out-of-distribution (OOD) test sets. Our results suggest that frequency shortcuts can be transferred across datasets and cannot be fully avoided by larger model capacity and data augmentation.
 9 | We recommend that future research focus on effective training schemes that mitigate frequency shortcut learning.
10 |
11 |
12 |
![teaser](figures/teaser.png)
13 |
14 |
15 | ### Quick start
16 |
17 | * Clone this repository:
18 | ```
19 | git clone https://github.com/nis-research/nn-frequency-shortcuts.git
20 | cd nn-frequency-shortcuts
21 | ```
22 |
23 | * Installation
24 |   * Python 3.9.12, CUDA 11.7, cuDNN 8.6
25 | * You can create a virtual environment with conda and activate the environment before the next step
26 | ```
27 | conda create -n virtualenv python=3.9 anaconda
28 | source activate virtualenv
29 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
30 | ```
31 | * Install other packages
32 | ```
33 | pip install -r requirements.txt
34 | ```
35 | * Datasets can be [downloaded here](https://drive.google.com/drive/folders/1Ug4WDwQWlFJpdks1woSsY6gWuSMYzNSB?usp=sharing)
36 | * Computing ADCS
37 | ```
38 | python -u Evaluation/compute_ADCS.py
39 | ```
40 |
41 | * Computing DFM, e.g.
42 |
43 | ```
44 | python -u Evaluation/test_rank.py --backbone_model resnet18 --model_path /checkpoints/last.ckpt --patch_size 1
45 | ```
46 |
47 | * Visualizing DFMs, e.g.
48 | ```
49 | python -u Evaluation/visualize_DFM.py --DFMs resnet18_DFM_1
50 | ```
51 | * Testing on DFM-filtered datasets, e.g.
52 | ```
53 | python -u Evaluation/verify_mask_imgn.py --backbone_model resnet18 --m_path ./DFMs/resnet18_DFM_1 --model_path /checkpoints/last.ckpt
54 |
55 | ```
56 |
57 |
58 |
59 |
60 | * Training models, e.g.
61 | ```
62 | python -u train.py --backbone_model resnet18 --lr 0.01 --dataset imagenet10 --save_dir results/ --image_size 224 --num_class 10
63 | ```
64 | * Options for `--dataset`: synthetic, imagenet10
65 | * Options for `--image_size`: 32, 224
66 | * Options for `--num_class`: 4, 10
67 | * There are four synthetic datasets; choose one by adding the argument `--special _complex_special_1_par` for `Syn_1`, `--special _complex_special_2_par` for `Syn_2`, etc.
68 |
69 | ## Citation
70 |
71 | ```
72 | @InProceedings{wang2023neural,
73 | title={What do neural networks learn in image classification? A frequency shortcut perspective},
74 | author={Shunxin Wang and Raymond Veldhuis and Christoph Brune and Nicola Strisciuglio},
75 | booktitle = {International Conference on Computer Vision (ICCV)},
76 | year = {2023},
77 | }
78 | ```
79 |
80 |
--------------------------------------------------------------------------------
/Evaluation/verify_mask_imgn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torchvision.transforms import transforms
4 | import torch.fft as fft
5 | import argparse
6 | from torchmetrics import ConfusionMatrix
7 | from torchvision.datasets import ImageFolder
8 | import pickle
9 |
10 | import sys
11 | sys.path.insert(0,'/home/wangs1/nn-frequency-shortcuts/')
12 | from train import Model
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | print(device)
16 |
17 | def main(args):
18 |
19 | model_path = args.model_path
20 |
21 | if args.backbone_model == 'resnet18':
22 | from blocks.resnet.Blocks import BasicBlock
23 | elif args.backbone_model == 'resnet50':
24 | from blocks.resnet.Blocks import Bottleneck
25 |
26 | model = Model.load_from_checkpoint(model_path)
27 | model.to(device)
28 | model.eval()
29 | model.freeze()
30 | encoder = model.backbone_model
31 |
32 | confmat = ConfusionMatrix(num_classes=10)
33 | # model performance on original dataset
34 | mean = [0.479838, 0.470448, 0.429404]
35 | std = [0.258143, 0.252662, 0.272406]
36 | transform=transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(),transforms.Normalize(mean, std)])
37 | data_test = ImageFolder('./data/ImageNet/val/',transform=transform)
38 |
39 |
40 | test_loader = torch.utils.data.DataLoader(data_test, batch_size= 16, shuffle=False,num_workers=2)
41 | total = 0
42 | Matrix2 = torch.zeros((10,10))
43 | for x, y in test_loader:
44 | x, y = x.to(device), y.to(device)
45 | y_hat = encoder(x)
46 | total += y.size(0)
47 | Matrix2 += confmat(y_hat.cpu(), y.cpu())
48 |     print('Confusion matrix on the test set:')
49 | print(Matrix2)
50 |
51 | for mask_i in range(10):
52 | print('TP_f/P -- class %d' % mask_i)
53 | delta1 = (Matrix2[mask_i,mask_i])/sum(Matrix2[mask_i,:])
54 | print(delta1)
55 |
56 | print('FP_f/N -- class %d' % mask_i)
57 | delta2 = (sum(Matrix2[:,mask_i])-Matrix2[mask_i,mask_i])/(sum(sum(Matrix2))-sum(Matrix2[mask_i,:]))
58 | print(delta2)
59 |
60 | # model performance on DFM-filtered datasets
61 | batchsize = 16
62 | testset = ImageFolder('./data/ImageNet/val/',transform=transform)
63 |
64 | test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batchsize, shuffle=False)
65 |
66 | with open(args.m_path+'.pkl', 'rb') as f:
67 | all = pickle.load(f)
68 |
69 | for mask_i in all:
70 | print('Using mask %d' %mask_i)
71 | mask = np.array(all[mask_i]) #map
72 | print(len(mask[mask==1]))
73 | mat = torch.zeros((10,10))
74 | for x,y in test_loader:
75 | size = x.size()
76 | x1=x
77 | y1 = torch.zeros(size,dtype=torch.complex128)
78 | y1 = fft.fftshift(fft.fft2(x1))
79 | for num_s in range(size[0]):
80 | for channel in range(3):
81 | y1[num_s,channel,:,:] = y1[num_s,channel,:,:] * mask
82 |
83 | x1 = fft.ifft2(fft.ifftshift(y1))
84 | x1 = torch.real(x1)
85 | x1 = torch.Tensor(x1).to(device)
86 | y_hat = encoder(x1)
87 | mat += confmat(y_hat.cpu(), y.cpu())
88 |
89 | print(mat)
90 |
91 | print('TP_f/P -- class %d' % mask_i)
92 | delta1 = (mat[mask_i,mask_i])/sum(Matrix2[mask_i,:])
93 | print(delta1)
94 |
95 | print('FP_f/N -- class %d' % mask_i)
96 | delta2 = (sum(mat[:,mask_i])-mat[mask_i,mask_i])/(sum(sum(Matrix2))-sum(Matrix2[mask_i,:]))
97 | print(delta2)
98 |
99 | if __name__ == '__main__':
100 |     parser = argparse.ArgumentParser()
101 |     parser.add_argument('--backbone_model', type=str, default='resnet18',
102 |                         help='backbone model')
103 |     parser.add_argument('--model_path', type=str, default='None',
104 |                         help='path of the model checkpoint')
105 |     parser.add_argument('--m_path', type=str, default='./',
106 |                         help='path of the DFM mask file, without the .pkl extension')
107 |
108 |
109 | args = parser.parse_args()
110 |
111 | main(args)
112 |
--------------------------------------------------------------------------------
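The per-class scores printed above are the true-positive rate TP_f/P (a diagonal entry over its row sum) and the false-positive rate FP_f/N (the off-diagonal column mass over everything outside that class's row), both read off the confusion matrix. A toy 3-class sketch with made-up numbers:

```python
import torch

# Rows are true classes, columns are predictions.
M = torch.tensor([[80., 15.,  5.],
                  [10., 85.,  5.],
                  [ 5., 10., 85.]])
i = 0
tp_rate = M[i, i] / M[i, :].sum()                                # TP_f / P = 0.80
fp_rate = (M[:, i].sum() - M[i, i]) / (M.sum() - M[i, :].sum())  # FP_f / N = 0.075
print(tp_rate.item(), fp_rate.item())
```
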
/Evaluation/test_rank.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.transforms import transforms
5 | from torchvision.datasets import ImageFolder
6 | import torch.fft as fft
7 | import argparse
8 | from torchmetrics import ConfusionMatrix
9 | import pickle
10 | import numpy as np
11 |
12 | import sys
13 | sys.path.insert(0,'/home/wangs1/nn-frequency-shortcuts/')
14 | from train import Model
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | print(device)
18 |
19 | def main(args):
20 | model_path = args.model_path
21 | dir = './DFMs/'
22 | if args.backbone_model == 'resnet18':
23 | from blocks.resnet.Blocks import BasicBlock
24 | elif args.backbone_model == 'resnet50':
25 | from blocks.resnet.Blocks import Bottleneck
26 |
27 |
28 |
29 | model = Model.load_from_checkpoint(model_path)
30 | model.to(device)
31 | model.eval()
32 | model.freeze()
33 | encoder = model.backbone_model
34 |
35 | confmat = ConfusionMatrix(num_classes=10)
36 | size = 224
37 | transform=transforms.Compose([transforms.Resize((size,size)), transforms.ToTensor(),transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])])
38 | # Model performance on the original test set
39 | Matrix1 = torch.zeros((10,10))
40 | data_test = ImageFolder('./data/ImageNet/val/',transform=transform)
41 | test_loader = torch.utils.data.DataLoader(data_test, batch_size= 32, shuffle=False,num_workers=4)
42 | for x, y in test_loader:
43 | x, y = x.to(device), y.to(device)
44 | y_hat = encoder(x)
45 | Matrix1 += confmat(y_hat.cpu(), y.cpu())
46 | print(Matrix1)
47 |
48 | # Testing importance of each frequency
49 | batchsize = 100
50 | test_loader = torch.utils.data.DataLoader(dataset=data_test, batch_size=batchsize, shuffle=False)
51 | result_prediction = {}
52 | result_loss = {}
53 | criterion1 = nn.CrossEntropyLoss()
54 | for test_class in ([0,1,2,3,4,5,6,7,8,9]):
55 | prection_matrix = torch.zeros(size,size)
56 | loss_matrix = torch.zeros(size,size)
57 | patch_size = args.patch_size
58 | image_size = 224
59 |
60 | for r in range(int(image_size/patch_size)):
61 |
62 | for c in range(int(image_size/patch_size/2)+1):
63 | mask = torch.ones((image_size,image_size))
64 | #mask[patch_size*0:patch_size*(0+1),int(image_size/2+patch_size):]=0
65 |
66 | mask[patch_size*r:patch_size*(r+1),patch_size*c:patch_size*(c+1)] = 0
67 |                 # NOTE: lines 67-78 of the original file are truncated in this
68 |                 # dump; what follows is a best-effort reconstruction that mirrors
69 |                 # the zeroed patch to its conjugate-symmetric position and sets up
70 |                 # the evaluation loop (loss, x1, sizex, reference_class are used below).
71 |                 if int(image_size/patch_size)-r > 0:
72 |                     mask[image_size-patch_size*(r+1):image_size-patch_size*r,
73 |                          image_size-patch_size*(c+1):image_size-patch_size*c] = 0
74 |                 loss = 0
75 |                 for x, y in test_loader:
76 |                     x1 = x
77 |                     sizex = x1.size()
78 |                     reference_class = test_class*torch.ones_like(y)
79 |
80 | y1 = torch.zeros(sizex,dtype=torch.complex128)
81 | y1 = fft.fftshift(fft.fft2(x1))
82 | for num_s in range(sizex[0]):
83 | for channel in range(3):
84 | y1[num_s,channel,:,:] = y1[num_s,channel,:,:] * mask
85 |
86 | x1 = fft.ifft2(fft.ifftshift(y1))
87 | x1 = torch.real(x1)
88 | x1 = torch.Tensor(x1).to(device)
89 |
90 | y_hat = encoder(x1)
91 | _, predicted = torch.max(y_hat.data,1)
92 |
93 | correct_predictions = (predicted == y.to(device))
94 | correct_predictions = correct_predictions.int()
95 |
96 | # selecting images of the corresponding class
97 | tested_classes = (y.to(device) == reference_class.to(device))
98 | tested_classes = tested_classes.int()
99 |
100 | # correct += (tested_classes*correct_predictions).sum().item()
101 | tc = torch.unsqueeze(tested_classes,1)
102 | test_cla = torch.cat((tc,tc,tc,tc,tc,tc,tc,tc,tc,tc),1).to(device)
103 |
104 | loss += criterion1(test_cla*y_hat,tested_classes*y.to(device))
105 |
106 |
107 | # prection_matrix[patch_size*r:patch_size*(r+1),patch_size*c:patch_size*(c+1)] = correct/50.0
108 | loss_matrix[patch_size*r:patch_size*(r+1),patch_size*c:patch_size*(c+1)] = loss
109 |                 # NOTE: lines 109-136 of the original file are truncated in this
110 |                 # dump; what follows is a best-effort reconstruction that mirrors
111 |                 # the loss into the conjugate-symmetric patch, stores the per-class
112 |                 # loss map, and thresholds the top fraction th of frequencies into a DFM.
113 |                 if int(image_size/patch_size)-r > 0:
114 |                     loss_matrix[image_size-patch_size*(r+1):image_size-patch_size*r,
115 |                                 image_size-patch_size*(c+1):image_size-patch_size*c] = loss
116 |         result_loss.update({test_class: loss_matrix})
117 |
118 |     # keep the th*100% most loss-increasing frequencies per class (th=0.01 -> '_DFM_1')
119 |     th = 0.01
120 |     mask_of_rank_th = {}
121 |     for mask_i in result_loss:
122 |         loss_matrix = result_loss[mask_i]
123 |         t = torch.topk(loss_matrix.flatten(), int(th*loss_matrix.numel())).values[-1]
124 |         map = torch.zeros((size,size))
125 |         map[loss_matrix>=t] = 1
137 |
138 | mask_of_rank_th.update({mask_i:map})
139 | with open(dir+args.backbone_model+'_DFM_'+str(int(th*100))+'.pkl', 'wb') as f:
140 | pickle.dump(mask_of_rank_th, f)
141 | f.close()
142 |
143 |
144 | if __name__ == '__main__':
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument('--backbone_model', type=str, default='resnet18',
147 | help='model ')
148 | parser.add_argument('--model_path', type=str, default='None',
149 | help='path of the model')
150 | parser.add_argument('--patch_size', type=int, default=1,
151 | help='patch_size')
152 |
153 | args = parser.parse_args()
154 |
155 | if not os.path.exists('./DFMs'):
156 | os.makedirs('./DFMs')
157 |
158 | main(args)
159 |
160 |
161 |
162 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import torchvision.transforms as transforms
5 | from torchvision.datasets import ImageFolder
6 | import os
7 | from pytorch_lightning.core.lightning import LightningModule
8 | import pytorch_lightning as pl
9 | from pytorch_lightning.callbacks import ModelCheckpoint
10 | import torchmetrics
11 | import timm
12 | from torch.optim.lr_scheduler import ReduceLROnPlateau
13 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
14 | import sys
15 | sys.path.insert(0,'/home/wangs1/nn-frequency-shortcuts/')
16 | from data.Synthetic import Synthetic
17 | import backbone.resnet as resnet
18 | import backbone.vgg as vgg
19 | import backbone.alexnet as alexnet
20 |
21 |
22 | class Model(LightningModule):
23 | def __init__(self,backbone_model, lr,num_class,dataset,image_size,special):
24 | super(Model, self).__init__()
25 | self.save_hyperparameters()
26 | self.lr = lr
27 | self.train_acc = torchmetrics.Accuracy()
28 | self.val_acc = torchmetrics.Accuracy()
29 | self.test_acc = torchmetrics.Accuracy()
30 | self.dataset = dataset
31 | self.num_class = num_class
32 | self.image_size = image_size
33 | self.backbone_model = backbone_model
34 | self.special = special
35 |
36 | def forward(self, x):
37 | # enc, prediction = self.backbone_model(x)
38 | prediction = self.backbone_model(x)
39 |
40 | return prediction
41 |
42 |
43 | def configure_optimizers(self):
44 | optimizer = torch.optim.SGD(self.parameters(), self.lr,
45 | momentum=0.9, nesterov=True,
46 | weight_decay=1e-4)
47 | scheduler = ReduceLROnPlateau(optimizer, mode='min',verbose=True, factor=0.1)
48 | return {'optimizer': optimizer,
49 | 'lr_scheduler':scheduler,
50 | 'monitor': 'val_loss'}
51 |
52 | def training_step(self, batch, batch_idx):
53 | x, y = batch
54 |
55 | criterion1 = nn.CrossEntropyLoss()
56 |
57 | # _, y_hat = self(x)
58 | y_hat = self(x)
59 | #print(y_hat)
60 | loss1 = criterion1(y_hat, y)
61 | loss = loss1
62 |
63 | _, predicted = torch.max(y_hat.data,1)
64 | self.log_dict({'train_classification_loss': loss1}, on_epoch=True,on_step=True)
65 | self.log_dict({'train_loss': loss}, on_epoch=True,on_step=True)
66 | return {"loss": loss,'epoch_preds': predicted, 'epoch_targets': y}
67 |
68 | def validation_step(self, batch, batch_idx):
69 | x, y = batch
70 | criterion1 = nn.CrossEntropyLoss()
71 |
72 | # _, y_hat = self(x)
73 | y_hat = self(x)
74 | # print(y_hat)
75 | # print(y_hat.size())
76 | loss1 = criterion1(y_hat, y)
77 | self.val_loss = loss1
78 |
79 | _, predicted = torch.max(y_hat.data,1)
80 | self.log_dict( {'val_loss': self.val_loss}, on_epoch=True,on_step=True)
81 |
82 | return {'epoch_preds': predicted, 'epoch_targets': y} #self.val_loss
83 |
84 | def test_step(self, batch, batch_idx):
85 | x, y = batch
86 | # _, y_hat = self(x)
87 | y_hat = self(x)
88 | # print(y_hat.size())
89 |
90 | _, predicted = torch.max(y_hat.data,1)
91 |
92 | return {'batch_preds': predicted, 'batch_targets': y}
93 |
94 |
95 | def test_step_end(self, output_results):
96 |
97 | self.test_acc(output_results['batch_preds'], output_results['batch_targets'])
98 | self.log_dict( {'test_acc': self.test_acc}, on_epoch=True,on_step=False)
99 |
100 | def training_epoch_end(self, output_results):
101 | # print(output_results)
102 | self.train_acc(output_results[0]['epoch_preds'], output_results[0]['epoch_targets'])
103 | self.log_dict({"train_acc": self.train_acc}, on_epoch=True, on_step=False)
104 |
105 | def validation_epoch_end(self, output_results):
106 | # print(output_results)
107 | self.val_acc(output_results[0]['epoch_preds'], output_results[0]['epoch_targets'])
108 | self.log_dict({"valid_acc": self.val_acc}, on_epoch=True, on_step=False)
109 | # print(acc)
110 | # return val_accuracy
111 |
112 | def setup(self, stage):
113 | if self.dataset == 'synthetic':
114 | transform_train = transforms.Compose([
115 | transforms.Pad(4),
116 | transforms.RandomHorizontalFlip(),
117 | transforms.RandomResizedCrop(self.image_size),
118 | transforms.ToTensor(),
119 | transforms.Normalize([0.498, 0.498, 0.498], [0.172, 0.173042, 0.173])
120 | # normalize
121 | ])
122 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.498, 0.498, 0.498], [0.172, 0.173042, 0.173])])
123 | data_train = Synthetic('./data',train=True,complex=self.special, transform=transform_train,band = '')
124 | data_test = Synthetic('./data',train=False,complex=self.special, transform=transform,band = '')
125 | elif self.dataset == 'imagenet10':
126 | transform_train = transforms.Compose([
127 | transforms.Pad(4),
128 | transforms.RandomHorizontalFlip(),
129 | transforms.RandomResizedCrop(self.image_size),
130 | # transforms.AugMix(),# transforms.AutoAugment(), # change here to add other augmentations
131 | transforms.ToTensor(),
132 | transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])
133 | # normalize
134 | ])
135 | transform=transforms.Compose([transforms.Resize((self.image_size,self.image_size)), transforms.ToTensor(),transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])])
136 | data_train = ImageFolder('./data/ImageNet/train/',transform=transform_train)
137 | data_test = ImageFolder('./data/ImageNet/val/',transform=transform)
138 | elif self.dataset == 'imagenet10_style':
139 | transform_train = transforms.Compose([
140 | transforms.Pad(4),
141 | transforms.RandomHorizontalFlip(),
142 | transforms.RandomResizedCrop(self.image_size),
143 | transforms.ToTensor(),
144 | transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])
145 | ])
146 | transform=transforms.Compose([transforms.Resize((self.image_size,self.image_size)), transforms.ToTensor(),transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])])
147 | data_train = ImageFolder('./data/ImageNet_style/train/',transform=transform_train)
148 | data_test = ImageFolder('./data/ImageNet_style/val/',transform=transform)
149 |
150 | # train/val split
151 | data_train2, data_val = torch.utils.data.random_split(data_train, [int(len(data_train)*0.9), len(data_train)-int(len(data_train)*0.9)])
152 |
153 | # assign to use in dataloaders
154 | self.train_dataset = data_train2
155 | self.val_dataset = data_val
156 | self.test_dataset = data_test
157 |
158 |
159 | def train_dataloader(self):
160 | return torch.utils.data.DataLoader(self.train_dataset, batch_size=64, shuffle=True)#,num_workers=2)
161 |
162 | def test_dataloader(self):
163 | return torch.utils.data.DataLoader(self.test_dataset, batch_size=64, shuffle=False)#,num_workers=2)
164 |
165 | def val_dataloader(self):
166 | return torch.utils.data.DataLoader(self.val_dataset, batch_size=64)#,num_workers=2)
167 |
168 |
169 | def main(args):
170 | backbone = ['resnet18', 'resnet34', 'resnet50','resnet101', 'alex', 'ViT', 'vgg16']
171 | print(torch.cuda.device_count())
172 | if args.backbone_model == 'resnet18':
173 | from blocks.resnet.Blocks import BasicBlock
174 | backbone_model = resnet.ResNet(BasicBlock,[2,2,2,2],args.num_class)
175 | elif args.backbone_model == 'resnet34':
176 | from blocks.resnet.Blocks import BasicBlock
177 | backbone_model = resnet.ResNet(BasicBlock, [3,4,6,3],args.num_class)
178 | elif args.backbone_model == 'resnet50':
179 | from blocks.resnet.Blocks import Bottleneck
180 | backbone_model = resnet.ResNet(Bottleneck,[3,4,6,3],args.num_class)
181 | elif args.backbone_model == 'resnet101':
182 | from blocks.resnet.Blocks import Bottleneck
183 |         backbone_model = resnet.ResNet(Bottleneck, [3,4,23,3], args.num_class)
184 | elif args.backbone_model == 'alex':
185 | backbone_model = alexnet.AlexNet(args.num_class)
186 |     elif args.backbone_model == 'ViT':
187 |         backbone_model = timm.create_model('vit_base_patch8_224', pretrained=False, num_classes=args.num_class)
188 |     elif args.backbone_model == 'vgg16':  # vgg is imported above; branch added for completeness
189 |         backbone_model = vgg.VGG('VGG16', args.num_class)
190 | logger = TensorBoardLogger(args.save_dir, name=args.backbone_model)
191 |
192 | model = Model(backbone_model, args.lr,args.num_class,args.dataset,args.image_size, args.special)
193 | maxepoch = 200
194 | checkpoints_callback = ModelCheckpoint(save_last=True,save_top_k=-1)
195 | trainer = pl.Trainer(enable_progress_bar=False,logger=logger, callbacks=[checkpoints_callback], gpus=-1, max_epochs=maxepoch) # accelerator='dp',
196 | trainer.fit(model)
197 | trainer.test()
198 |
199 |
200 |
201 |
202 |
203 | if __name__ == '__main__':
204 | parser = argparse.ArgumentParser(description='Write parameters')
205 | parser.add_argument('--backbone_model', type=str,
206 | help='backbone_model')
207 | parser.add_argument('--image_size', type=int, default= 32,
208 | help='size of images in dataset')
209 | parser.add_argument('--num_class', type=int, default= 10,
210 | help='number of classes in dataset')
211 | parser.add_argument('--dataset', type=str, default='imagenet10',
212 | help='dataset')
213 | parser.add_argument('--lr', type=float, default=0.001,
214 | help='learning rate')
215 | parser.add_argument('--save_dir', type=str, default='results/')
216 | parser.add_argument('--special', required=False, default=None,
217 | help='selecting synthetic dataset')
218 |
219 | args = parser.parse_args()
220 | if not os.path.exists(args.save_dir+'/'+args.backbone_model):
221 | os.makedirs(args.save_dir+'/'+args.backbone_model)
222 | print('make the directory')
223 |
224 | main(args)
--------------------------------------------------------------------------------
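Checkpoints written by the ModelCheckpoint callback can be reloaded for the Evaluation scripts, which is how verify_mask_imgn.py and test_rank.py obtain the trained backbone. A sketch (the checkpoint path is hypothetical):

```python
from train import Model

# Hypothetical path produced by a training run with --save_dir results/.
model = Model.load_from_checkpoint('results/resnet18/version_0/checkpoints/last.ckpt')
model.eval()
model.freeze()
encoder = model.backbone_model  # the bare backbone, as used in Evaluation/
```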