├── requirements.txt
├── figures
│   └── teaser.png
├── Evaluation
│   ├── fms_224.pkl
│   ├── visualize_DFM.py
│   ├── compute_ADCS.py
│   ├── verify_mask_imgn.py
│   └── test_rank.py
├── LICENSE
├── backbone
│   ├── vgg.py
│   ├── alexnet.py
│   └── resnet.py
├── data
│   └── Synthetic.py
├── blocks
│   └── resnet
│       └── Blocks.py
├── README.md
└── train.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | matplotlib
3 | pytorch_lightning
4 | torchmetrics
5 | timm
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nis-research/nn-frequency-shortcuts/HEAD/figures/teaser.png
--------------------------------------------------------------------------------
/Evaluation/fms_224.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nis-research/nn-frequency-shortcuts/HEAD/Evaluation/fms_224.pkl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 nis-research
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /Evaluation/visualize_DFM.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | def main(args): 7 | dir = './DFMs/' 8 | imagenet_classes = ['Airliner','Wagon','Humming\n Bird','Siamese\n Cat','Ox','Golden\n Retriever','Tailed\n Frog','Zebra','Container\n Ship','Trailer\n Truck'] 9 | m_path = args.DFMs+'.pkl' 10 | with open(dir+m_path, 'rb') as f: 11 | all_mask = pickle.load(f) 12 | fig, axs = plt.subplots(1,10,sharex=True,sharey=True) 13 | fig.set_figheight(15) 14 | fig.set_figwidth(15) 15 | for mask_i in range(len(all_mask)): 16 | map = np.array(all_mask[mask_i]) 17 | axs[mask_i].imshow(map,cmap='gray') 18 | axs[mask_i].set_title(imagenet_classes[mask_i]) 19 | axs[mask_i].set_yticks([]) 20 | axs[mask_i].set_xticks([]) 21 | 22 | plt.subplots_adjust(left=0.2, bottom=0.1, right=0.8, top=0.8, wspace=0.05, hspace=-0.85) 23 | plt.savefig(dir + args.DFMs + '.pdf',bbox_inches='tight') 24 | 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--DFMs', type=str, default='resnet18_DFM_1', 30 | help='File name of DFMs') 31 | 32 | args = parser.parse_args() 33 | 34 | main(args) -------------------------------------------------------------------------------- /backbone/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | cfg = { 4 | 'VGG8': [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'], 5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 6 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 7 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 9 | } 10 | 11 | 12 | class VGG(nn.Module): 13 | def __init__(self, vgg_name,num_class): 14 | super(VGG, self).__init__() 15 | self.features = self._make_layers(cfg[vgg_name]) 16 | self.classifier = nn.Linear(self.in_planes, num_class) 17 | 18 | def forward(self, x): 19 | out = self.features(x) 20 | enc = F.avg_pool2d(out, out.size(2)) 21 | enc = enc.view(enc.size(0), -1) 22 | prediction = self.classifier(enc) 23 | return prediction 24 | 25 | def _make_layers(self, cfg): 26 | layers = [] 27 | in_channels = 3 28 | for x in cfg: 29 | if x == 'M': 30 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 31 | else: 32 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(x), 34 | nn.ReLU(inplace=True)] 35 | in_channels = x 36 | # layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 37 | self.in_planes = in_channels 38 | 39 | return nn.Sequential(*layers) 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /backbone/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class AlexNet(nn.Module): 5 | def __init__(self, num_classes=11): 6 | super(AlexNet, self).__init__() 7 | self.features = nn.Sequential( 8 | nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=1,bias=False), 9 | 10 | nn.ReLU(inplace=True), 11 | nn.BatchNorm2d(96), 12 | nn.MaxPool2d(kernel_size=2), 13 | 14 | nn.Conv2d(96, 256, kernel_size=3, 
padding=1,bias=False), 15 | nn.ReLU(inplace=True), 16 | nn.BatchNorm2d(256), 17 | nn.MaxPool2d(kernel_size=2), 18 | 19 | nn.Conv2d(256, 384, kernel_size=3, padding=1,bias=False), 20 | nn.ReLU(inplace=True), 21 | nn.BatchNorm2d(384), 22 | 23 | nn.Conv2d(384, 384, kernel_size=3, padding=1,bias=False), 24 | nn.ReLU(inplace=True), 25 | nn.BatchNorm2d(384), 26 | 27 | nn.Conv2d(384, 256, kernel_size=3, padding=1,bias=False), 28 | nn.ReLU(inplace=True), 29 | nn.BatchNorm2d(256), 30 | nn.MaxPool2d(kernel_size=2), 31 | ) 32 | self.classifier = nn.Sequential( 33 | nn.Dropout(), 34 | nn.Linear(256 * 2 * 2, 4096,bias=False), 35 | nn.ReLU(inplace=True), 36 | nn.Dropout(), 37 | nn.Linear(4096, 4096,bias=False), 38 | nn.ReLU(inplace=True), 39 | nn.Linear(4096, num_classes,bias=False) 40 | ) 41 | 42 | def forward(self, x): 43 | 44 | x = self.features(x.float()) 45 | enc = x.view(x.size(0), 256 * 2 * 2) 46 | prediction = self.classifier(enc) 47 | out = enc 48 | return prediction -------------------------------------------------------------------------------- /data/Synthetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data.dataset import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | 8 | class Synthetic(Dataset): 9 | def __init__(self, root_dir,train=True, transform = None, complex = '', band = '',t=True): 10 | super(Synthetic).__init__() 11 | 12 | if train is False: 13 | 14 | self.labels_path = os.path.join(root_dir,'synthetic','test_label'+complex +'.npy') 15 | self.root_dir = os.path.join(root_dir,'synthetic','test_data'+complex +band+'.npy') 16 | else: 17 | self.labels_path = os.path.join(root_dir,'synthetic','train_label'+complex+'.npy') 18 | self.root_dir = os.path.join(root_dir,'synthetic','train_data'+complex+band+'.npy') 19 | 20 | 21 | print(self.root_dir) 22 | self.transform = transform 23 | self.data = np.load(self.root_dir, allow_pickle=True) 24 | self.targets = np.load(self.labels_path, allow_pickle=True) 25 | self.band = band 26 | self.t = t 27 | # self.data = self.data.transpose((0, 3, 1, 2)) 28 | 29 | 30 | def __len__(self): 31 | return len(self.data) 32 | 33 | def __getitem__(self, index): 34 | if self.t: 35 | img = self.data[index].permute(1,2,0).numpy() 36 | else: 37 | img = self.data[index] 38 | 39 | if self.transform is not None: 40 | # print(np.max(img)) 41 | # print(np.min(img)) 42 | img = np.clip(img,0,1) 43 | img = img*255 44 | img = Image.fromarray(img.astype(np.uint8),mode='RGB') 45 | 46 | img = self.transform(img) 47 | 48 | target = self.targets[index] 49 | 50 | return img, torch.tensor(target, dtype=torch.long) -------------------------------------------------------------------------------- /backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class ResNet(nn.Module): 8 | def __init__(self, block_en, num_blocks,num_class): 9 | super(ResNet, self).__init__() 10 | 11 | self.in_planes = 64 12 | self.num_class = num_class 13 | 14 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 15 | stride=1, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(64) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.layer1 = self._make_layer(block_en, 64, num_blocks[0], stride=1) 19 | self.layer2 = self._make_layer(block_en, 128, num_blocks[1], stride=2) 20 | self.layer3 = self._make_layer(block_en, 256, num_blocks[2], stride=2) 21 | 
self.layer4 = self._make_layer(block_en, 512, num_blocks[3], stride=2) 22 | 23 | self.features = nn.Sequential(self.conv1, self.bn1,self.relu,self.layer1,self.layer2,self.layer3,self.layer4) 24 | 25 | 26 | # self.sm = nn.Softmax(dim=1) 27 | self.classifier = nn.Linear(512*block_en.expansion,self.num_class) 28 | 29 | def _make_layer(self, block, planes, num_blocks, stride): 30 | strides = [stride] + [1]*(num_blocks-1) 31 | layers = [] 32 | for stride in strides: 33 | layers.append(block(self.in_planes, planes, stride,shortcut=True)) 34 | self.in_planes = planes*block.expansion 35 | return nn.Sequential(*layers) 36 | 37 | 38 | def forward(self, x): 39 | # out = F.relu(self.bn1(self.conv1(x))) 40 | enc = self.features(x) 41 | # print(enc.size()) 42 | # out = self.layer1(out) 43 | # out = self.layer2(out) 44 | # out = self.layer3(out) 45 | # out = self.layer4(out) 46 | enc = F.avg_pool2d(enc, enc.size(2)) 47 | enc = enc.view(enc.size(0), -1) # flatten 48 | 49 | prediction = self.classifier(enc) 50 | # prediction = self.sm(prediction) 51 | return prediction 52 | -------------------------------------------------------------------------------- /Evaluation/compute_ADCS.py: -------------------------------------------------------------------------------- 1 | from scipy import signal 2 | from scipy.ndimage import gaussian_filter 3 | import numpy.fft as fft 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import torchvision 7 | import torch 8 | import torchvision.transforms as transforms 9 | from torchvision.datasets import ImageFolder 10 | 11 | def distance(i, j, imageSize, r1,r2): 12 | dis = np.sqrt((i - imageSize/2) ** 2 + (j - imageSize/2) ** 2) 13 | if dis < r2 and dis >=r1: 14 | return 1.0 15 | else: 16 | return 0 17 | 18 | def mask_radial(img, r1,r2): 19 | rows, cols = img.shape 20 | mask = np.zeros((rows, cols)) 21 | for i in range(rows): 22 | for j in range(cols): 23 | mask[i, j] = distance(i, j, imageSize=rows, r1=r1,r2=r2) 24 | return mask 25 | 26 | 27 | def rgb2gray(rgb): 28 | r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2] 29 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 30 | return gray 31 | 32 | 33 | 34 | Energy = {} 35 | 36 | mean = [0.479838, 0.470448, 0.429404] 37 | std = [0.258143, 0.252662, 0.272406] 38 | transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(),transforms.Normalize(mean, std)]) 39 | batchsize = 1 40 | data_test = ImageFolder('./datasets/ImageNet/val/',transform=transform) # data path to be changed 41 | test_loader = torch.utils.data.DataLoader(dataset=data_test, batch_size=batchsize, shuffle=False) 42 | img_size =224 43 | for x,y in test_loader: 44 | x1=x[0] 45 | y1 = np.zeros((img_size,img_size,3),dtype=np.complex128) 46 | for j in range(3): 47 | y1[:,:,j] = fft.fftshift(fft.fft2(x1[j,:,:])) 48 | y1[y1==0] = 12e-12 49 | abs_y1 = np.abs(y1) 50 | if y.item() in Energy: 51 | Energy[y.item()] += abs_y1 52 | else: 53 | Energy.update({y.item():abs_y1}) 54 | 55 | 56 | fig, axs = plt.subplots(2,5,sharex=True,sharey=True) 57 | fig.set_figheight(8) 58 | fig.set_figwidth(20) 59 | for j in range(10): 60 | olp = np.zeros((img_size,img_size)) 61 | for i in range(10): 62 | diff = Energy[j] - Energy[i] 63 | diff = rgb2gray(diff) 64 | diff[diff>0] = 1 65 | diff[diff<=0] = -1 66 | olp += diff 67 | if j >=5: 68 | axs[1,j-5].imshow(olp,cmap='jet',vmin=-9,vmax=9) 69 | axs[1,j-5].axis('off') 70 | axs[1,j-5].set_title('Class: %d' %j) 71 | else: 72 | axs[0,j].imshow(olp,cmap='jet',vmin=-9,vmax=9) 73 | axs[0,j].axis('off') 74 | 
axs[0,j].set_title('Class: %d' %j) 75 | plt.rcParams.update({'font.size': 25}) 76 | plt.savefig('ADCS_imagenet10.pdf') -------------------------------------------------------------------------------- /blocks/resnet/Blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class BasicBlock(nn.Module): 5 | expansion = 1 6 | 7 | def __init__(self, in_planes, planes, stride=1,shortcut=True): 8 | super(BasicBlock, self).__init__() 9 | self.conv1 = nn.Conv2d( 10 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 11 | self.bn1 = nn.BatchNorm2d(planes) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 13 | stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | self.shortcut_flag = shortcut 16 | self.shortcut = nn.Sequential() 17 | 18 | if stride != 1 or in_planes != self.expansion*planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion*planes, 21 | kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(self.expansion*planes) 23 | ) 24 | 25 | def forward(self, x): 26 | out = F.relu(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | if self.shortcut_flag == True: 29 | out += self.shortcut(x) 30 | 31 | out = F.relu(out) 32 | 33 | return out 34 | 35 | 36 | class Upconvblock(nn.Module): 37 | expansion = 16 38 | def __init__(self,in_planes,output_channels, stride=1): 39 | super(Upconvblock,self).__init__() 40 | if stride == 1: 41 | self.scaleup = nn.Conv2d(in_planes,output_channels,kernel_size=3,padding=1) 42 | elif stride == 2: 43 | self.scaleup = nn.ConvTranspose2d(in_planes,output_channels,kernel_size=2,stride=stride) 44 | # self.scaleup = nn.ConvTranspose2d(in_planes,output_channels,kernel_size=2,stride=2) 45 | # self.scaleup = nn.Upsample(scale_factor=stride,mode="bilinear",align_corners=True) 46 | self.conv1 = nn.Conv2d(output_channels,output_channels,stride = 1,kernel_size=3,padding=1) 47 | self.bn1 = nn.BatchNorm2d(output_channels) 48 | self.conv2 = nn.Conv2d(output_channels,output_channels,stride=1,kernel_size=3,padding=1) 49 | self.bn2 = nn.BatchNorm2d(output_channels) 50 | 51 | def forward(self, x): 52 | x = self.scaleup(x) 53 | x = F.relu(self.bn1(self.conv1(x))) 54 | x = F.relu(self.bn2(self.conv2(x))) 55 | 56 | return x 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, in_planes, planes, stride=1,shortcut=True): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=stride, padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes*self.expansion, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes*self.expansion) 70 | self.shortcut_flag = shortcut 71 | self.shortcut = nn.Sequential() 72 | 73 | if stride != 1 or in_planes != self.expansion*planes: 74 | self.shortcut = nn.Sequential( 75 | nn.Conv2d(in_planes, self.expansion*planes, 76 | kernel_size=1, stride=stride, bias=False), 77 | nn.BatchNorm2d(self.expansion*planes) 78 | ) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out = F.relu(self.bn2(self.conv2(out))) 83 | 84 | out = self.bn3(self.conv3(out)) 85 | if self.shortcut_flag == True: 86 | out += self.shortcut(x) 87 | 88 | out = F.relu(out) 89 | return out 
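
For context, a minimal sketch of how these blocks are assembled into classifiers. It mirrors the constructor calls in `train.py` (`BasicBlock` for ResNet-18, `Bottleneck` for ResNet-50); the batch shape is illustrative:

```
import torch
from blocks.resnet.Blocks import BasicBlock, Bottleneck
import backbone.resnet as resnet

# BasicBlock has expansion=1 (512-dim features); Bottleneck has expansion=4 (2048-dim).
resnet18 = resnet.ResNet(BasicBlock, [2, 2, 2, 2], 10)
resnet50 = resnet.ResNet(Bottleneck, [3, 4, 6, 3], 10)

x = torch.randn(2, 3, 224, 224)   # illustrative batch of two RGB images
print(resnet18(x).shape)          # torch.Size([2, 10])
```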
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Official website of 'What do neural networks learn in image classification? A frequency shortcut perspective' (accepted at ICCV 2023)
2 | #### The paper is available on [arXiv](https://arxiv.org/abs/2307.09829).
3 | 
4 | ### Introduction
5 | 
6 | Frequency analysis is useful for understanding the mechanisms of representation learning in neural networks (NNs). Most research in this area focuses on the learning dynamics of NNs for regression tasks, while classification has received little attention. This study empirically investigates the latter and expands the understanding of frequency shortcuts. First, we perform experiments on synthetic datasets designed to have a bias in different frequency bands. Our results demonstrate that NNs tend to find simple solutions for classification, and that what they learn first during training depends on the most distinctive frequency characteristics, which can be either low- or high-frequency. Second, we confirm this phenomenon on natural images: we propose a metric to measure class-wise frequency characteristics and a method to identify frequency shortcuts. The results show that frequency shortcuts can be texture-based or shape-based, depending on what best simplifies the objective. Third, we validate the transferability of frequency shortcuts on out-of-distribution (OOD) test sets. Our results suggest that frequency shortcuts can be transferred across datasets and cannot be fully avoided by larger model capacity or data augmentation. We recommend that future research focus on effective training schemes that mitigate frequency shortcut learning.
7 | 
8 | 
9 | ![Teaser figure](figures/teaser.png)
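
All evaluation scripts revolve around one operation: an image is moved to the frequency domain with a centered FFT, multiplied by a binary dominant-frequency map (DFM), and transformed back before classification. A minimal vectorized sketch of this step (implemented with explicit loops in `Evaluation/verify_mask_imgn.py`; the helper name `apply_dfm` is ours):

```
import torch
import torch.fft as fft

def apply_dfm(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Keep only the frequencies selected by a binary DFM.

    x:    batch of images, shape (N, C, H, W)
    mask: binary map of shape (H, W); 1 = keep frequency, 0 = remove
    """
    spec = fft.fftshift(fft.fft2(x), dim=(-2, -1))   # centered spectrum per channel
    spec = spec * mask                               # broadcasts over batch and channels
    return fft.ifft2(fft.ifftshift(spec, dim=(-2, -1))).real
```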
10 | 
11 | ### Quick start
12 | 
13 | * Clone this repository:
14 | ```
15 | git clone https://github.com/nis-research/nn-frequency-shortcuts.git
16 | cd nn-frequency-shortcuts
17 | ```
18 | 
19 | * Installation
20 |     * Python 3.9.12, cuda-11.7, cuda-11.x_cudnn-8.6
21 |     * You can create a virtual environment with conda and activate it before the next step
22 | ```
23 | conda create -n virtualenv python=3.9 anaconda
24 | source activate virtualenv
25 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
26 | ```
27 |     * Install other packages
28 | ```
29 | pip install -r requirements.txt
30 | ```
31 | * Datasets can be downloaded [here](https://drive.google.com/drive/folders/1Ug4WDwQWlFJpdks1woSsY6gWuSMYzNSB?usp=sharing)
32 | * Computing ADCS
33 | ```
34 | python -u Evaluation/compute_ADCS.py
35 | ```
36 | * Computing DFMs, e.g.
37 | ```
38 | python -u Evaluation/test_rank.py --backbone_model resnet18 --model_path /checkpoints/last.ckpt --patch_size 1
39 | ```
40 | * Visualizing DFMs, e.g.
41 | ```
42 | python -u Evaluation/visualize_DFM.py --DFMs resnet18_DFM_1
43 | ```
44 | * Testing on DFM-filtered datasets, e.g.
45 | ```
46 | python -u Evaluation/verify_mask_imgn.py --backbone_model resnet18 --m_path ./DFMs/resnet18_DFM_1 --model_path /checkpoints/last.ckpt
47 | ```
48 | * Training models, e.g.
49 | ```
50 | python -u train.py --backbone_model resnet18 --lr 0.01 --dataset imagenet10 --save_dir results/ --image_size 224 --num_class 10
51 | ```
52 |     * Options for `--dataset`: synthetic, imagenet10
53 |     * Options for `--image_size`: 32, 224
54 |     * Options for `--num_class`: 4, 10
55 |     * There are four synthetic datasets; choose one by adding the argument `--special _complex_special_1_par` for `Syn_1`, `--special _complex_special_2_par` for `Syn_2`, etc.
56 | 
57 | ## Citation
58 | 
59 | ```
60 | @InProceedings{wang2023neural,
61 |   title = {What do neural networks learn in image classification? A frequency shortcut perspective},
62 |   author = {Shunxin Wang and Raymond Veldhuis and Christoph Brune and Nicola Strisciuglio},
63 |   booktitle = {International Conference on Computer Vision (ICCV)},
64 |   year = {2023},
65 | }
66 | ```
67 | 
--------------------------------------------------------------------------------
/Evaluation/verify_mask_imgn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torchvision.transforms import transforms
4 | import torch.fft as fft
5 | import argparse
6 | from torchmetrics import ConfusionMatrix
7 | from torchvision.datasets import ImageFolder
8 | import pickle
9 | 
10 | import sys
11 | sys.path.insert(0,'/home/wangs1/nn-frequency-shortcuts/')
12 | from train import Model
13 | 
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | print(device)
16 | 
17 | def main(args):
18 | 
19 |     model_path = args.model_path
20 | 
21 |     if args.backbone_model == 'resnet18':
22 |         from blocks.resnet.Blocks import BasicBlock
23 |     elif args.backbone_model == 'resnet50':
24 |         from blocks.resnet.Blocks import Bottleneck
25 | 
26 |     model = Model.load_from_checkpoint(model_path)
27 |     model.to(device)
28 |     model.eval()
29 |     model.freeze()
30 |     encoder = model.backbone_model
31 | 
32 |     confmat = ConfusionMatrix(num_classes=10)
33 |     # model performance on the original dataset
34 |     mean = [0.479838, 0.470448, 0.429404]
35 |     std = [0.258143, 0.252662, 0.272406]
36 |     transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
37 |     data_test = ImageFolder('./data/ImageNet/val/',transform=transform)
38 | 
39 | 
40 |     test_loader = torch.utils.data.DataLoader(data_test, batch_size=16, shuffle=False, num_workers=2)
41 |     total = 0
42 |     Matrix2 = torch.zeros((10,10))
43 |     for x, y in test_loader:
44 |         x, y = x.to(device), y.to(device)
45 |         y_hat = encoder(x)
46 |         total += y.size(0)
47 |         Matrix2 += confmat(y_hat.cpu(), y.cpu())
48 |     print('Confusion matrix on the testing set:')
49 |     print(Matrix2)
50 | 
51 |     for mask_i in range(10):
52 |         print('TP_f/P -- class %d' % mask_i)
53 |         delta1 = (Matrix2[mask_i,mask_i])/sum(Matrix2[mask_i,:])
54 |         print(delta1)
55 | 
56 |         print('FP_f/N -- class %d' % mask_i)
57 |         delta2 = (sum(Matrix2[:,mask_i])-Matrix2[mask_i,mask_i])/(sum(sum(Matrix2))-sum(Matrix2[mask_i,:]))
58 |         print(delta2)
59 | 
60 |     # model performance on DFM-filtered datasets
61 |     batchsize = 16
62 |     testset = ImageFolder('./data/ImageNet/val/',transform=transform)
63 | 
64 |     test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batchsize, shuffle=False)
65 | 
66 |     with open(args.m_path+'.pkl', 'rb') as f:
67 |         all = pickle.load(f)
68 | 
69 |     for mask_i in all:
70 |         print('Using mask %d' % mask_i)
71 |         mask = np.array(all[mask_i])  # DFM of class mask_i
72 |         print(len(mask[mask==1]))
73 |         mat = torch.zeros((10,10))
74 |         for x, y in test_loader:
75 |             size = x.size()
76 |             x1 = x
77 |             y1 = torch.zeros(size, dtype=torch.complex128)
78 |             y1 = fft.fftshift(fft.fft2(x1))
79 |             for num_s in range(size[0]):
80 |                 for channel in range(3):
81 |                     y1[num_s,channel,:,:] = y1[num_s,channel,:,:] * mask
82 | 
83 |             x1 = fft.ifft2(fft.ifftshift(y1))
84 |             x1 = torch.real(x1)
85 |             x1 = torch.Tensor(x1).to(device)
86 |             y_hat = encoder(x1)
87 |             mat += confmat(y_hat.cpu(), y.cpu())
88 | 
89 |         print(mat)
90 | 
91 |         print('TP_f/P -- class %d' % mask_i)
92 |         delta1 = (mat[mask_i,mask_i])/sum(Matrix2[mask_i,:])
93 |         print(delta1)
94 | 
95 |         print('FP_f/N -- class %d' % mask_i)
96 |         delta2 = (sum(mat[:,mask_i])-mat[mask_i,mask_i])/(sum(sum(Matrix2))-sum(Matrix2[mask_i,:]))
97 |         print(delta2)
98 | 
99 | if __name__ == '__main__':
100 |     parser = argparse.ArgumentParser()
101 |     parser.add_argument('--backbone_model', type=str, default='resnet18',
102 |                         help='backbone model')
103 |     parser.add_argument('--model_path', type=str, default='None',
104 |                         help='path of the model')
105 |     parser.add_argument('--m_path', type=str, default='./',
106 |                         help='path of the mask')
107 | 
108 | 
109 |     args = parser.parse_args()
110 | 
111 |     main(args)
112 | 
--------------------------------------------------------------------------------
/Evaluation/test_rank.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.transforms import transforms
5 | from torchvision.datasets import ImageFolder
6 | import torch.fft as fft
7 | import argparse
8 | from torchmetrics import ConfusionMatrix
9 | import pickle
10 | import numpy as np
11 | 
12 | import sys
13 | sys.path.insert(0,'/home/wangs1/nn-frequency-shortcuts/')
14 | from train import Model
15 | 
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | print(device)
18 | 
19 | def main(args):
20 |     model_path = args.model_path
21 |     dir = './DFMs/'
22 |     if args.backbone_model == 'resnet18':
23 |         from blocks.resnet.Blocks import BasicBlock
24 |     elif args.backbone_model == 'resnet50':
25 |         from blocks.resnet.Blocks import Bottleneck
26 | 
27 |     model = Model.load_from_checkpoint(model_path)
28 |     model.to(device)
29 |     model.eval()
30 |     model.freeze()
31 |     encoder = model.backbone_model
32 | 
33 |     confmat = ConfusionMatrix(num_classes=10)
34 |     size = 224
35 |     transform = transforms.Compose([transforms.Resize((size,size)), transforms.ToTensor(), transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])])
36 |     # Model performance on the original test set
37 |     Matrix1 = torch.zeros((10,10))
38 |     data_test = ImageFolder('./data/ImageNet/val/',transform=transform)
39 |     test_loader = torch.utils.data.DataLoader(data_test, batch_size=32, shuffle=False, num_workers=4)
40 |     for x, y in test_loader:
41 |         x, y = x.to(device), y.to(device)
42 |         y_hat = encoder(x)
43 |         Matrix1 += confmat(y_hat.cpu(), y.cpu())
44 |     print(Matrix1)
45 | 
46 |     # Testing the importance of each frequency: remove one frequency patch at a time
47 |     batchsize = 100
48 |     test_loader = torch.utils.data.DataLoader(dataset=data_test, batch_size=batchsize, shuffle=False)
49 |     result_prediction = {}
50 |     result_loss = {}
51 |     criterion1 = nn.CrossEntropyLoss()
52 |     for test_class in ([0,1,2,3,4,5,6,7,8,9]):
53 |         prediction_matrix = torch.zeros(size,size)
54 |         loss_matrix = torch.zeros(size,size)
55 |         patch_size = args.patch_size
56 |         image_size = 224
57 | 
58 |         for r in range(int(image_size/patch_size)):
59 |             for c in range(int(image_size/patch_size/2)+1):
60 |                 mask = torch.ones((image_size,image_size))
61 |                 #mask[patch_size*0:patch_size*(0+1),int(image_size/2+patch_size):]=0
62 |                 mask[patch_size*r:patch_size*(r+1),patch_size*c:patch_size*(c+1)] = 0
63 |                 if int(image_size/patch_size)-r < int(image_size/patch_size) and int(image_size/patch_size)-c < int(image_size/patch_size):
64 |                     # also remove the point-symmetric patch: the spectrum of a real image is conjugate-symmetric
65 |                     mask[image_size-patch_size*(r+1):image_size-patch_size*r,image_size-patch_size*(c+1):image_size-patch_size*c] = 0
66 | 
67 |                 loss = 0
68 |                 for x, y in test_loader:
69 |                     sizex = x.size()
70 |                     x1 = x
71 |                     reference_class = torch.ones_like(y)*test_class
72 |                     if len(mask[mask==0]) > 0:
73 |                         y1 = torch.zeros(sizex,dtype=torch.complex128)
74 |                         y1 = fft.fftshift(fft.fft2(x1))
75 |                         for num_s in range(sizex[0]):
76 |                             for channel in range(3):
77 |                                 y1[num_s,channel,:,:] = y1[num_s,channel,:,:] * mask
78 | 
79 |                         x1 = fft.ifft2(fft.ifftshift(y1))
80 |                         x1 = torch.real(x1)
81 |                         x1 = torch.Tensor(x1).to(device)
82 | 
83 |                     y_hat = encoder(x1)
84 |                     _, predicted = torch.max(y_hat.data,1)
85 |                     correct_predictions = (predicted == y.to(device))
86 |                     correct_predictions = correct_predictions.int()
87 | 
88 |                     # selecting images of the corresponding class
89 |                     tested_classes = (y.to(device) == reference_class.to(device))
90 |                     tested_classes = tested_classes.int()
91 | 
92 |                     # correct += (tested_classes*correct_predictions).sum().item()
93 |                     tc = torch.unsqueeze(tested_classes,1)
94 |                     test_cla = torch.cat((tc,tc,tc,tc,tc,tc,tc,tc,tc,tc),1).to(device)
95 | 
96 |                     loss += criterion1(test_cla*y_hat,tested_classes*y.to(device))
97 | 
98 |                 # prediction_matrix[patch_size*r:patch_size*(r+1),patch_size*c:patch_size*(c+1)] = correct/50.0
99 |                 loss_matrix[patch_size*r:patch_size*(r+1),patch_size*c:patch_size*(c+1)] = loss
100 |                 if int(image_size/patch_size)-r < int(image_size/patch_size) and int(image_size/patch_size)-c < int(image_size/patch_size):
101 |                     # mirror the accumulated loss to the point-symmetric patch removed together with (r,c)
102 |                     loss_matrix[image_size-patch_size*(r+1):image_size-patch_size*r,image_size-patch_size*(c+1):image_size-patch_size*c] = loss
103 | 
104 |         result_loss.update({test_class:loss_matrix})
105 | 
106 |     # Rank the frequencies of each class by the loss increase caused by their removal;
107 |     # the top-th fraction forms the dominant frequency map (DFM) of that class.
108 |     for th in [0.01]:
109 |         mask_of_rank_th = {}
110 |         for mask_i in result_loss:
111 |             loss_matrix = result_loss[mask_i]
112 |             ranked = np.sort(np.array(loss_matrix).flatten())
113 |             t = ranked[-int(th*image_size*image_size)]
114 |             map = torch.zeros((size,size))
115 |             map[loss_matrix>=t] = 1
116 | 
117 |             mask_of_rank_th.update({mask_i:map})
118 |         with open(dir+args.backbone_model+'_DFM_'+str(int(th*100))+'.pkl', 'wb') as f:
119 |             pickle.dump(mask_of_rank_th, f)
120 |         f.close()
121 | 
122 | 
123 | if __name__ == '__main__':
124 |     parser = argparse.ArgumentParser()
125 |     parser.add_argument('--backbone_model', type=str, default='resnet18',
126 |                         help='backbone model')
127 |     parser.add_argument('--model_path', type=str, default='None',
128 |                         help='path of the model')
129 |     parser.add_argument('--patch_size', type=int, default=1,
130 |                         help='patch_size')
131 | 
132 |     args = parser.parse_args()
133 | 
134 |     if not os.path.exists('./DFMs'):
135 |         os.makedirs('./DFMs')
136 | 
137 |     main(args)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import torchvision.transforms as transforms
5 | from torchvision.datasets import ImageFolder
6 | import os
7 | from pytorch_lightning.core.lightning import LightningModule
8 | import pytorch_lightning as pl
9 | from pytorch_lightning.callbacks import ModelCheckpoint
10 | import torchmetrics
11 | import timm
12 | from torch.optim.lr_scheduler import ReduceLROnPlateau
13 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
14 | import sys
15 | sys.path.insert(0,'/home/wangs1/nn-frequency-shortcuts/')
16 | from data.Synthetic import Synthetic
17 | import backbone.resnet as resnet
18 | import backbone.vgg as vgg
19 | import backbone.alexnet as alexnet
20 | 
21 | 
22 | class Model(LightningModule):
23 |     def __init__(self, backbone_model, lr, num_class, dataset, image_size, special):
24 |         super(Model, self).__init__()
25 |         self.save_hyperparameters()
26 |         self.lr = lr
27 |         self.train_acc = torchmetrics.Accuracy()
28 |         self.val_acc = torchmetrics.Accuracy()
29 |         self.test_acc = torchmetrics.Accuracy()
30 |         self.dataset = dataset
31 |         self.num_class = num_class
32 |         self.image_size = image_size
33 |         self.backbone_model = backbone_model
34 |         self.special = special
35 | 
36 |     def forward(self, x):
37 |         # enc, prediction = self.backbone_model(x)
38 |         prediction = self.backbone_model(x)
39 | 
40 |         return prediction
41 | 
42 | 
43 |     def configure_optimizers(self):
44 |         optimizer = torch.optim.SGD(self.parameters(), self.lr,
45 |                                     momentum=0.9, nesterov=True,
46 |                                     weight_decay=1e-4)
47 |         scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, factor=0.1)
48 |         return {'optimizer': optimizer,
49 |                 'lr_scheduler': scheduler,
50 |                 'monitor': 'val_loss'}
51 | 
52 |     def training_step(self, batch, batch_idx):
53 |         x, y = batch
54 | 
55 |         criterion1 = nn.CrossEntropyLoss()
56 | 
57 |         # _, y_hat =
self(x) 58 | y_hat = self(x) 59 | #print(y_hat) 60 | loss1 = criterion1(y_hat, y) 61 | loss = loss1 62 | 63 | _, predicted = torch.max(y_hat.data,1) 64 | self.log_dict({'train_classification_loss': loss1}, on_epoch=True,on_step=True) 65 | self.log_dict({'train_loss': loss}, on_epoch=True,on_step=True) 66 | return {"loss": loss,'epoch_preds': predicted, 'epoch_targets': y} 67 | 68 | def validation_step(self, batch, batch_idx): 69 | x, y = batch 70 | criterion1 = nn.CrossEntropyLoss() 71 | 72 | # _, y_hat = self(x) 73 | y_hat = self(x) 74 | # print(y_hat) 75 | # print(y_hat.size()) 76 | loss1 = criterion1(y_hat, y) 77 | self.val_loss = loss1 78 | 79 | _, predicted = torch.max(y_hat.data,1) 80 | self.log_dict( {'val_loss': self.val_loss}, on_epoch=True,on_step=True) 81 | 82 | return {'epoch_preds': predicted, 'epoch_targets': y} #self.val_loss 83 | 84 | def test_step(self, batch, batch_idx): 85 | x, y = batch 86 | # _, y_hat = self(x) 87 | y_hat = self(x) 88 | # print(y_hat.size()) 89 | 90 | _, predicted = torch.max(y_hat.data,1) 91 | 92 | return {'batch_preds': predicted, 'batch_targets': y} 93 | 94 | 95 | def test_step_end(self, output_results): 96 | 97 | self.test_acc(output_results['batch_preds'], output_results['batch_targets']) 98 | self.log_dict( {'test_acc': self.test_acc}, on_epoch=True,on_step=False) 99 | 100 | def training_epoch_end(self, output_results): 101 | # print(output_results) 102 | self.train_acc(output_results[0]['epoch_preds'], output_results[0]['epoch_targets']) 103 | self.log_dict({"train_acc": self.train_acc}, on_epoch=True, on_step=False) 104 | 105 | def validation_epoch_end(self, output_results): 106 | # print(output_results) 107 | self.val_acc(output_results[0]['epoch_preds'], output_results[0]['epoch_targets']) 108 | self.log_dict({"valid_acc": self.val_acc}, on_epoch=True, on_step=False) 109 | # print(acc) 110 | # return val_accuracy 111 | 112 | def setup(self, stage): 113 | if self.dataset == 'synthetic': 114 | transform_train = transforms.Compose([ 115 | transforms.Pad(4), 116 | transforms.RandomHorizontalFlip(), 117 | transforms.RandomResizedCrop(self.image_size), 118 | transforms.ToTensor(), 119 | transforms.Normalize([0.498, 0.498, 0.498], [0.172, 0.173042, 0.173]) 120 | # normalize 121 | ]) 122 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.498, 0.498, 0.498], [0.172, 0.173042, 0.173])]) 123 | data_train = Synthetic('./data',train=True,complex=self.special, transform=transform_train,band = '') 124 | data_test = Synthetic('./data',train=False,complex=self.special, transform=transform,band = '') 125 | elif self.dataset == 'imagenet10': 126 | transform_train = transforms.Compose([ 127 | transforms.Pad(4), 128 | transforms.RandomHorizontalFlip(), 129 | transforms.RandomResizedCrop(self.image_size), 130 | # transforms.AugMix(),# transforms.AutoAugment(), # change here to add other augmentations 131 | transforms.ToTensor(), 132 | transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406]) 133 | # normalize 134 | ]) 135 | transform=transforms.Compose([transforms.Resize((self.image_size,self.image_size)), transforms.ToTensor(),transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])]) 136 | data_train = ImageFolder('./data/ImageNet/train/',transform=transform_train) 137 | data_test = ImageFolder('./data/ImageNet/val/',transform=transform) 138 | elif self.dataset == 'imagenet10_style': 139 | transform_train = transforms.Compose([ 140 | transforms.Pad(4), 141 | 
transforms.RandomHorizontalFlip(),
142 |                 transforms.RandomResizedCrop(self.image_size),
143 |                 transforms.ToTensor(),
144 |                 transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])
145 |             ])
146 |             transform = transforms.Compose([transforms.Resize((self.image_size,self.image_size)), transforms.ToTensor(), transforms.Normalize([0.479838, 0.470448, 0.429404], [0.258143, 0.252662, 0.272406])])
147 |             data_train = ImageFolder('./data/ImageNet_style/train/',transform=transform_train)
148 |             data_test = ImageFolder('./data/ImageNet_style/val/',transform=transform)
149 | 
150 |         # train/val split
151 |         data_train2, data_val = torch.utils.data.random_split(data_train, [int(len(data_train)*0.9), len(data_train)-int(len(data_train)*0.9)])
152 | 
153 |         # assign to use in dataloaders
154 |         self.train_dataset = data_train2
155 |         self.val_dataset = data_val
156 |         self.test_dataset = data_test
157 | 
158 | 
159 |     def train_dataloader(self):
160 |         return torch.utils.data.DataLoader(self.train_dataset, batch_size=64, shuffle=True)#,num_workers=2)
161 | 
162 |     def test_dataloader(self):
163 |         return torch.utils.data.DataLoader(self.test_dataset, batch_size=64, shuffle=False)#,num_workers=2)
164 | 
165 |     def val_dataloader(self):
166 |         return torch.utils.data.DataLoader(self.val_dataset, batch_size=64)#,num_workers=2)
167 | 
168 | 
169 | def main(args):
170 |     backbone = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'alex', 'ViT', 'vgg16']
171 |     print(torch.cuda.device_count())
172 |     if args.backbone_model == 'resnet18':
173 |         from blocks.resnet.Blocks import BasicBlock
174 |         backbone_model = resnet.ResNet(BasicBlock, [2,2,2,2], args.num_class)
175 |     elif args.backbone_model == 'resnet34':
176 |         from blocks.resnet.Blocks import BasicBlock
177 |         backbone_model = resnet.ResNet(BasicBlock, [3,4,6,3], args.num_class)
178 |     elif args.backbone_model == 'resnet50':
179 |         from blocks.resnet.Blocks import Bottleneck
180 |         backbone_model = resnet.ResNet(Bottleneck, [3,4,6,3], args.num_class)
181 |     elif args.backbone_model == 'resnet101':
182 |         from blocks.resnet.Blocks import Bottleneck
183 |         backbone_model = resnet.ResNet(Bottleneck, [3,4,23,3], args.num_class)
184 |     elif args.backbone_model == 'alex':
185 |         backbone_model = alexnet.AlexNet(args.num_class)
186 |     elif args.backbone_model == 'vgg16':
187 |         backbone_model = vgg.VGG('VGG16', args.num_class)
188 |     elif args.backbone_model == 'ViT':
189 |         backbone_model = timm.create_model('vit_base_patch8_224', pretrained=False, num_classes=args.num_class)
190 | 
191 | 
192 |     logger = TensorBoardLogger(args.save_dir, name=args.backbone_model)
193 | 
194 |     model = Model(backbone_model, args.lr, args.num_class, args.dataset, args.image_size, args.special)
195 |     maxepoch = 200
196 |     checkpoints_callback = ModelCheckpoint(save_last=True, save_top_k=-1)
197 |     trainer = pl.Trainer(enable_progress_bar=False, logger=logger, callbacks=[checkpoints_callback], gpus=-1, max_epochs=maxepoch) # accelerator='dp',
198 |     trainer.fit(model)
199 |     trainer.test()
200 | 
201 | 
202 | 
203 | 
204 | 
205 | if __name__ == '__main__':
206 |     parser = argparse.ArgumentParser(description='Write parameters')
207 |     parser.add_argument('--backbone_model', type=str,
208 |                         help='backbone model')
209 |     parser.add_argument('--image_size', type=int, default=32,
210 |                         help='size of images in dataset')
211 |     parser.add_argument('--num_class', type=int, default=10,
212 |                         help='number of classes in dataset')
213 |     parser.add_argument('--dataset', type=str, default='imagenet10',
214 |                         help='dataset')
215 |     parser.add_argument('--lr', type=float, default=0.001,
216 |                         help='learning rate')
217 |     parser.add_argument('--save_dir', type=str, default='results/')
218 |     parser.add_argument('--special', required=False, default='',
219 |                         help='suffix selecting a synthetic dataset, e.g. _complex_special_1_par')
220 | 
221 |     args = parser.parse_args()
222 |     if not os.path.exists(args.save_dir+'/'+args.backbone_model):
223 |         os.makedirs(args.save_dir+'/'+args.backbone_model)
224 |         print('make the directory')
225 | 
226 |     main(args)
--------------------------------------------------------------------------------
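
For completeness, this is how the evaluation scripts reload what `train.py` saves. A minimal sketch; the checkpoint path is hypothetical and depends on `--save_dir` and the Lightning version directory:

```
import torch
from train import Model

# Illustrative path: train.py writes checkpoints under <save_dir>/<backbone_model>/version_*/checkpoints/.
ckpt = 'results/resnet18/version_0/checkpoints/last.ckpt'

model = Model.load_from_checkpoint(ckpt)
model.eval()
model.freeze()
encoder = model.backbone_model  # the bare backbone, as used by the Evaluation scripts

with torch.no_grad():
    logits = encoder(torch.randn(1, 3, 224, 224))
print(logits.argmax(dim=1))
```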