├── LICENSE ├── README.md ├── Similar_Mask_Generate.py ├── SpectralClustering.py ├── celeb.py ├── cub_voc.py ├── densenet_iccnn_multi_train.py ├── densenet_iccnn_train.py ├── densenet_ori_train.py ├── draw_fmap.py ├── hyperparameters.txt ├── load_utils.py ├── newPad2d.py ├── resnet_iccnn_multi_train.py ├── resnet_iccnn_train.py ├── resnet_ori_train.py ├── train_all.py ├── tutorial.pdf ├── utils │   └── utils.py ├── vgg_iccnn_multi_train.py ├── vgg_iccnn_train.py └── vgg_ori_train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ada-shen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # icCNN 2 | This repository is a PyTorch implementation of interpretable compositional convolutional neural networks ([arXiv](https://arxiv.org/abs/2107.04474)), published at IJCAI 2021.
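At its core, each `*_iccnn_train.py` script alternates two phases: it periodically estimates a channel-correlation matrix with `SMGBlock` and spectral-clusters the filters into groups, then trains with a similarity (cluster) loss added to the classification loss so that filters in the same group respond to similar object regions. The sketch below illustrates that cycle with NumPy/scikit-learn toys; the squared-error penalty at the end is only a stand-in for the repo's `Cluster_loss` (defined in utils/utils.py), so read it as an outline rather than the exact objective:

```
import numpy as np
from sklearn.cluster import SpectralClustering

np.random.seed(0)
C, HW, N_CLUSTER = 16, 49, 4                 # channels, spatial positions, filter groups
f_map = np.random.randn(32, C, HW)           # toy feature maps for a batch of 32 images

# channel-correlation matrix, shifted to [0, 2] as in Similar_Mask_Generate.py
mean = f_map.mean(axis=0)
cov = np.mean(np.matmul(f_map - mean, np.transpose(f_map - mean, (0, 2, 1))), axis=0)
std = np.sqrt(np.diag(cov)).reshape(C, 1)
corr = cov / (std @ std.T + 1e-5) + 1

# group channels by spectral clustering on the (non-negative) correlation matrix
labels = SpectralClustering(n_clusters=N_CLUSTER, affinity='precomputed', random_state=0).fit_predict(corr)
gt = (labels[:, None] == labels[None, :]).astype(float)   # target similarity pattern

# illustrative penalty: within-cluster correlations should approach 2, others 0
loss = np.mean((corr - 2 * gt) ** 2)
print(labels, round(loss, 3))
```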
3 | 4 | 5 | 6 | ### Document Structure 7 | 8 | ### utils 9 | --utils.py [the utility modules used by the networks] 10 | 11 | --train_all.py [the top-level training entry of the networks] 12 | 13 | --celeb.py [the dataset driver for the Large-scale CelebFaces Attributes (CelebA) dataset] 14 | 15 | --cub_voc.py [the dataset driver for the CUB200-2011, PASCAL-Part, and Helen Facial Feature datasets] 16 | 17 | --vgg_iccnn_multi_train.py [the compositional CNN training script for VGGs, multi-category classification] 18 | 19 | --vgg_iccnn_train.py [the compositional CNN training script for VGGs, single-category classification] 20 | 21 | --vgg_ori_train.py [the traditional CNN training script for VGGs] 22 | 23 | --resnet_iccnn_multi_train.py [the compositional CNN training script for ResNets, multi-category classification] 24 | 25 | --resnet_iccnn_train.py [the compositional CNN training script for ResNets, single-category classification] 26 | 27 | --resnet_ori_train.py [the traditional CNN training script for ResNets] 28 | 29 | --densenet_iccnn_multi_train.py [the compositional CNN training script for DenseNets, multi-category classification] 30 | 31 | --densenet_iccnn_train.py [the compositional CNN training script for DenseNets, single-category classification] 32 | 33 | --densenet_ori_train.py [the traditional CNN training script for DenseNets] 34 | 35 | --load_utils.py [the helper for loading pretrained weights] 36 | 37 | --newPad2d.py [a re-implementation of replication padding] 38 | 39 | --SpectralClustering.py [the spectral clustering module] 40 | 41 | --Similar_Mask_Generate.py [the module that generates the similarity (correlation) matrix] 42 | 43 | ### Train/Test 44 | 45 | ``` 46 | python3 train_all.py -type [ori/iccnn] -is_multi [single/multi 0/1] -model [vgg/resnet/densenet] 47 | ``` 48 | 49 | ### Models corresponding to results in the paper 50 | 51 | We have uploaded the model files used in the paper to the following link: https://b2y4v0n4v2.feishu.cn/drive/folder/fldcnwJeKnPEuQVydwhqNH8op5d 52 | We have uploaded the model files for multi-category classification to the following link: https://pan.baidu.com/s/1RlCfjiJCdGxI8fGCpeClKQ?pwd=87ps 53 | 54 | ### Datasets for training and visualization 55 | 56 | We have uploaded the datasets for training icCNNs, together with the images used for visualization, to the following link: https://pan.baidu.com/s/1RlCfjiJCdGxI8fGCpeClKQ?pwd=87ps 57 | 58 | ## Citation 59 | 60 | If you use this project in your research, please cite it.
61 | 62 | ``` 63 | @inproceedings{shen2021interpretable, 64 | title={Interpretable Compositional Convolutional Neural Networks}, 65 | author={Shen, Wen and Wei, Zhihua and Huang, Shikun and Zhang, Binbin and Fan, Jiaqi and Zhao, Ping and Zhang, Quanshi}, 66 | booktitle={Proceedings of the International Joint Conference on Artificial Intelligence}, 67 | year={2021} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /Similar_Mask_Generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on 03 05 2020 3 | @author: H 4 | """ 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from SpectralClustering import spectral_clustering 9 | from utils.utils import EMA_FM 10 | 11 | class SMGBlock(nn.Module): 12 | def __init__(self, channel_size=2048, f_map_size=196): 13 | super(SMGBlock, self).__init__() 14 | 15 | self.EMA_FM = EMA_FM(decay=0.95, first_decay=0.0, channel_size=channel_size, f_map_size=f_map_size, is_use=True) 16 | 17 | 18 | def forward(self, x): 19 | ''' 20 | :param x: (b, c, h, w) 21 | :return: channel-correlation matrix (b, c, c), normalized to [0, 2] 22 | ''' 23 | batch_size, channel, _,_ = x.size() 24 | theta_x = x.view(batch_size,channel,-1).permute(0,2,1).contiguous() # [b,h*w,c] 25 | transpose_x = x.view(batch_size,channel,-1).permute(0,2,1).contiguous() # [b,h*w,c] (unused duplicate of theta_x) 26 | with torch.no_grad(): 27 | f_mean = self.EMA_FM.update(theta_x) # EMA estimate of the per-channel feature mean 28 | sz = f_mean.size()[0] 29 | f_mean = f_mean.view(1,channel,sz) 30 | f_mean_transposed = f_mean.permute(0,2,1) 31 | Local = torch.matmul(theta_x.permute(0, 2, 1)-f_mean, theta_x-f_mean_transposed) # channel covariance 32 | diag = torch.eye(channel).view(-1,channel,channel).cuda() 33 | cov = torch.sum(Local*diag,dim=2).view(batch_size,channel,1) # per-channel variance 34 | cov_transpose = cov.permute(0,2,1) 35 | norm = torch.sqrt(torch.matmul(cov,cov_transpose)) # outer product of standard deviations 36 | correlation = torch.div(Local,norm)+1 ## normalize to [0,2] 37 | 38 | return correlation 39 | 40 | def bn(input,eps=1e-5): 41 | # input b,c,n 42 | inSize = input.size() 43 | mean = input.mean(dim=0)##.view(inSize[0],-1) 44 | std = input.std(dim=0)#.view(inSize[0],-1) 45 | y = torch.div(input-mean,std+eps) 46 | return y 47 | 48 | def fn(input,eps=1e-5): 49 | # input b,c,n 50 | inSize = input.size() 51 | mean = input.view(inSize[0],-1).mean(dim=-1) 52 | std = input.view(inSize[0],-1).std(dim=-1) 53 | y = torch.div(input-mean.view(inSize[0],1,1),std.view(inSize[0],1,1)+eps) 54 | return y 55 | 56 | def single_max_min_norm(input,eps=1e-5): 57 | # input b,c,n 58 | inSize = input.size() 59 | max_ = torch.max(input.view(inSize[0],-1),-1)[0] 60 | min_ = torch.min(input.view(inSize[0],-1),-1)[0] 61 | #print(min_.shape) 62 | y = torch.div(input-min_.view(inSize[0],1,1),max_.view(inSize[0],1,1)-min_.view(inSize[0],1,1)+eps) 63 | return y 64 | 65 | def batch_max_min_norm(input,eps=1e-5): 66 | # input b,c,n 67 | inSize = input.size() 68 | input_p = input.permute(1,0,2).contiguous() 69 | max_ = torch.max(input_p.view(inSize[1],-1),-1)[0] 70 | min_ = torch.min(input_p.view(inSize[1],-1),-1)[0] 71 | #print(min_.shape) 72 | y = torch.div(input-min_.view(1,inSize[1],1),max_.view(1,inSize[1],1)-min_.view(1,inSize[1],1)+eps) 73 | return y 74 | 75 | if __name__ == '__main__': 76 | import torch 77 | # smoke test; tensors must live on the GPU because forward() hard-codes .cuda() 78 | img = torch.zeros(1, 1024, 14, 14).cuda() 79 | net = SMGBlock(1024, 196).cuda() 80 | out = net(img) 81 | print(out.size()) 82 | 83 | 84 | 85 | --------------------------------------------------------------------------------
/SpectralClustering.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans, SpectralClustering 2 | import numpy as np 3 | import torch 4 | 5 | def spectral_clustering(similarity_matrix,n_cluster=8): 6 | W = similarity_matrix 7 | 8 | sz = W.shape[0] 9 | sp = SpectralClustering(n_clusters=n_cluster,affinity='precomputed',random_state=21) 10 | y_pred = sp.fit_predict(W) 11 | # for i in range(n_cluster): 12 | # print(np.sum(y_pred==i)) 13 | del W 14 | ground_true_matrix = np.zeros((sz,sz)) # 1 where two channels fall in the same cluster 15 | loss_mask_num = [] # per-cluster masks over within-cluster pairs 16 | loss_mask_den = [] # per-cluster masks over all pairs in that cluster's rows 17 | for i in range(n_cluster): 18 | idx = np.where(y_pred==i)[0] 19 | cur_mask_num = np.zeros((sz,sz)) 20 | cur_mask_den = np.zeros((sz,sz)) 21 | for j in idx: 22 | ground_true_matrix[j][idx] = 1 23 | cur_mask_num[j][idx] = 1 24 | cur_mask_den[j][:] = 1 25 | loss_mask_num.append(np.expand_dims(cur_mask_num,0)) 26 | loss_mask_den.append(np.expand_dims(cur_mask_den,0)) 27 | loss_mask_num = np.concatenate(loss_mask_num,axis=0) 28 | loss_mask_den = np.concatenate(loss_mask_den,axis=0) 29 | return torch.from_numpy(ground_true_matrix).float().cuda(), torch.from_numpy(loss_mask_num).float().cuda(), torch.from_numpy(loss_mask_den).float().cuda() --------------------------------------------------------------------------------
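For intuition, here is a minimal NumPy sketch of the masks that `spectral_clustering` above builds, assuming a toy assignment of 4 channels into 2 clusters (`y_pred` stands in for the actual clustering result, and the CUDA transfer at the end of the real function is omitted):

```
import numpy as np

y_pred = np.array([0, 1, 0, 1])   # toy clustering: channels {0,2} and {1,3}
sz, n_cluster = len(y_pred), 2

ground_true_matrix = np.zeros((sz, sz))   # 1 where two channels share a cluster
loss_mask_num, loss_mask_den = [], []
for i in range(n_cluster):
    idx = np.where(y_pred == i)[0]
    cur_mask_num = np.zeros((sz, sz))     # within-cluster pairs only
    cur_mask_den = np.zeros((sz, sz))     # full rows owned by this cluster
    for j in idx:
        ground_true_matrix[j][idx] = 1
        cur_mask_num[j][idx] = 1
        cur_mask_den[j][:] = 1
    loss_mask_num.append(cur_mask_num[None])
    loss_mask_den.append(cur_mask_den[None])

print(ground_true_matrix[0])                    # -> [1. 0. 1. 0.]
print(np.concatenate(loss_mask_num).shape)      # -> (2, 4, 4), one mask per cluster
```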
/celeb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | import random 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms, utils 9 | from torchvision.datasets.folder import default_loader 10 | 11 | class Celeb(Dataset): 12 | def __init__(self, data_file, dst_path='cropped_CelebA', training=True, transform=None, train_num=16000): 13 | src_path = data_file + 'CelebA_info' 14 | if train_num == 10240: 15 | category = 'celeb_sample_10240.txt' 16 | else: 17 | category = 'list_attr_celeba.txt' 18 | fn = open(src_path + '/Anno/' + category, 'r') 19 | fh2 = open(src_path + '/Eval/list_eval_partition.txt', 'r') # opened but never read 20 | imgs = [] 21 | lbls = [] 22 | ln = 0 23 | train_bound = 162770 + 2 24 | test_bound = 182638 + 2 25 | regex = re.compile(r'\s+') 26 | for line in fn: 27 | ln += 1 28 | if ln <= 2: 29 | continue 30 | if ln < test_bound and not training: 31 | continue 32 | if (ln - 2 <= train_num and training and ln <=train_bound) or\ 33 | (ln - test_bound < train_num and not training): 34 | line = line.rstrip('\n') 35 | line_value = regex.split(line) 36 | imgs.append(line_value[0]) 37 | lbls.append(list(int(i) if int(i) > 0 else 0 for i in line_value[1:])) 38 | self.imgs = imgs 39 | self.lbls = lbls 40 | self.is_train = training 41 | self.dst_path = data_file + dst_path 42 | if transform is None: 43 | if training: 44 | self.transform = transforms.Compose([ 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 48 | ]) 49 | else: 50 | self.transform = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 53 | ]) 54 | else: 55 | self.transform = transform 56 | 57 | def __getitem__(self, idx): 58 | fn = self.imgs[idx] 59 | lbls = self.lbls[idx] 60 | if self.is_train: 61 | imgs = default_loader(self.dst_path + '/train/' + fn) 62 | else: 63 | imgs = default_loader(self.dst_path + '/test/' + fn) 64 | imgs = self.transform(imgs) 65 | lbls = torch.Tensor(lbls) 66 | return [imgs, lbls] 67 | 68 | def __len__(self): 69 | return len(self.imgs) 70 | 71 | def sample_celeb(data_file, category='list_attr_celeba.txt', training=True, sample_num=10240, train_num=162770): 72 | src_path = data_file + 'CelebA_info' 73 | fn = open(src_path + '/Anno/' + category, 'r') 74 | sample_path = src_path + '/Anno/celeb_sample_'+str(sample_num)+'.txt' 75 | if os.path.exists(sample_path): 76 | os.system('rm '+ sample_path) 77 | sample_fh = open(sample_path, 'w') 78 | ln = 0 79 | train_bound = 162770 + 2 80 | test_bound = 182638 + 2 81 | regex = re.compile(r'\s+') 82 | content = [] 83 | trainnum_list = np.arange(0, train_bound-2) 84 | sample_num_list = random.sample(trainnum_list.tolist(), sample_num) 85 | for line in fn: 86 | ln += 1 87 | if ln <= 2: 88 | sample_fh.write(line) 89 | if ln < test_bound and not training: 90 | continue 91 | if (ln - 2 <= train_num and training and ln <=train_bound) or\ 92 | (ln - test_bound < train_num and not training): 93 | content.append(line) 94 | 95 | for idx in sample_num_list: 96 | sample_fh.write(content[idx]) 97 | sample_fh.close() 98 | 99 | if __name__ == '__main__': 100 | data_file = '/home/wzh/project/fjq/dataset/CelebA/' 101 | sample_celeb(data_file, sample_num=10240) -------------------------------------------------------------------------------- /cub_voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torchvision.datasets.folder import default_loader 4 | from torchvision.datasets.utils import download_url 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | # object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 8 | # 'bottle', 'bus', 'car', 'cat', 'chair', 9 | # 'cow', 'diningtable', 'dog', 'horse', 10 | # 'motorbike', 'person', 'pottedplant', 11 | # 'sheep', 'sofa', 'train', 'tvmonitor'] 12 | object_categories = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep'] 13 | 14 | class CUB_VOC(Dataset): 15 | 16 | def __init__(self, root, dataname, mytype, train=True, transform=None, loader=default_loader, is_frac=None, sample_num=-1): 17 | self.root = os.path.expanduser(root) 18 | self.dataname = dataname 19 | self.mytype = mytype 20 | self.transform = transform 21 | self.loader = default_loader 22 | self.train = train 23 | self.is_frac = is_frac 24 | self.sample_num = sample_num 25 | if not self._check_integrity(): 26 | raise RuntimeError('Dataset not found or corrupted.'
+ 27 | ' You can use download=True to download it') 28 | 29 | def _load_metadata(self): 30 | data_txt = None 31 | if self.dataname in object_categories: 32 | data_txt = '%s_info.txt' % self.dataname 33 | elif self.dataname == 'cub': 34 | if self.mytype == 'ori': 35 | data_txt = 'image_info.txt' 36 | else: 37 | data_txt = 'cubsample_info.txt' 38 | elif self.dataname == 'helen': 39 | data_txt = 'helen_info.txt' 40 | elif self.dataname == 'voc_multi': 41 | data_txt = 'animal_info.txt' 42 | 43 | self.data = pd.read_csv(os.path.join(self.root, data_txt), 44 | names=['img_id','file_path','target','is_training_img']) 45 | if self.train: 46 | self.data = self.data[self.data.is_training_img == 1] 47 | else: 48 | self.data = self.data[self.data.is_training_img == 0] 49 | 50 | if self.is_frac is not None: 51 | self.data = self.data[self.data.target == self.is_frac] 52 | 53 | if self.sample_num != -1: 54 | self.data = self.data[0:self.sample_num] 55 | 56 | def _check_integrity(self): 57 | try: 58 | self._load_metadata() 59 | except Exception: 60 | return False 61 | 62 | for index, row in self.data.iterrows(): 63 | filepath = os.path.join(self.root, row.file_path) 64 | if not os.path.isfile(filepath): 65 | print(filepath) 66 | return False 67 | return True 68 | 69 | def __len__(self): 70 | return len(self.data) 71 | 72 | def __getitem__(self, idx): 73 | sample = self.data.iloc[idx] 74 | path = os.path.join(self.root, sample.file_path) 75 | target = sample.target # targets in the *_info.txt files are already 0-based 76 | img = self.loader(path) 77 | 78 | if self.transform is not None: 79 | img = self.transform(img) 80 | 81 | return img, target 82 | -------------------------------------------------------------------------------- /densenet_iccnn_multi_train.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | from torch import Tensor 8 | from torch.jit.annotations import List 9 | 10 | #added 11 | import torchvision.transforms as transforms 12 | from torch.utils.data import DataLoader 13 | from load_utils import load_state_dict_from_url 14 | from cub_voc import CUB_VOC 15 | import os 16 | from tqdm import tqdm 17 | import shutil 18 | from utils.utils import Cluster_loss, Multiclass_loss 19 | import numpy as np 20 | from Similar_Mask_Generate import SMGBlock 21 | from SpectralClustering import spectral_clustering 22 | from newPad2d import newPad2d 23 | 24 | MEMORY_EFFICIENT = True 25 | IS_TRAIN = 0 # 0/1 26 | LAYERS = '121' 27 | DATANAME = 'voc_multi' # voc_multi 28 | NUM_CLASSES = 6 29 | cub_file = './dataset/frac_dataset' 30 | voc_file = './dataset/VOCdevkit/VOC2010/voc2010_crop' 31 | log_path = './icCNN/run/densenet/' # for model 32 | save_path = './icCNN/run/basic_fmap/densenet/' # for get_feature 33 | acc_path = './icCNN/run/basic_fmap/densenet/acc/' 34 | dataset = '%s_densenet_%s_iccnn' % (LAYERS, DATANAME) 35 | log_path = log_path + dataset + '/' 36 | pretrain_model = log_path + 'model_1000.pth' 37 | BATCHSIZE = 1 38 | LR = 0.00001 39 | EPOCH = 1000 40 | center_num = 16 41 | lam1 = 0.1 42 | lam2 = 0.1 43 | T = 2 # run spectral clustering every T epochs 44 | F_MAP_SIZE = 49 45 | STOP_CLUSTERING = 200 46 | if LAYERS == '121': 47 | CHANNEL_NUM = 1024 # for SMGBlock1D 48 | elif LAYERS == '161': 49 | CHANNEL_NUM = 2208 # for SMGBlock1D 50 | 51 | 52 | 53 | __all__ = ['DenseNet', 'densenet121', 'densenet169',
'densenet201', 'densenet161'] 54 | 55 | model_urls = { 56 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 57 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 58 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 59 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 60 | } 61 | 62 | class _DenseLayer(nn.Module): 63 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=MEMORY_EFFICIENT): 64 | super(_DenseLayer, self).__init__() 65 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 66 | self.add_module('relu1', nn.ReLU(inplace=True)), 67 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 68 | growth_rate, kernel_size=1, stride=1, 69 | bias=False)), 70 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 71 | self.add_module('relu2', nn.ReLU(inplace=True)), 72 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 73 | kernel_size=3, stride=1, padding=0, #new padding 74 | bias=False)), 75 | self.drop_rate = float(drop_rate) 76 | self.memory_efficient = memory_efficient 77 | self.pad2d_1 = newPad2d(1)#new padding 78 | 79 | def bn_function(self, inputs): 80 | # type: (List[Tensor]) -> Tensor 81 | concated_features = torch.cat(inputs, 1) 82 | bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 83 | return bottleneck_output 84 | 85 | # todo: rewrite when torchscript supports any 86 | def any_requires_grad(self, input): 87 | # type: (List[Tensor]) -> bool 88 | for tensor in input: 89 | if tensor.requires_grad: 90 | return True 91 | return False 92 | 93 | @torch.jit.unused # noqa: T484 94 | def call_checkpoint_bottleneck(self, input): 95 | # type: (List[Tensor]) -> Tensor 96 | def closure(*inputs): 97 | return self.bn_function(inputs) 98 | 99 | return cp.checkpoint(closure, *input) 100 | 101 | @torch.jit._overload_method # noqa: F811 102 | def forward(self, input): 103 | # type: (List[Tensor]) -> (Tensor) 104 | pass 105 | 106 | @torch.jit._overload_method # noqa: F811 107 | def forward(self, input): 108 | # type: (Tensor) -> (Tensor) 109 | pass 110 | 111 | # torchscript does not yet support *args, so we overload method 112 | # allowing it to take either a List[Tensor] or single Tensor 113 | def forward(self, input): # noqa: F811 114 | if isinstance(input, Tensor): 115 | prev_features = [input] 116 | else: 117 | prev_features = input 118 | 119 | if self.memory_efficient and self.any_requires_grad(prev_features): 120 | if torch.jit.is_scripting(): 121 | raise Exception("Memory Efficient not supported in JIT") 122 | 123 | bottleneck_output = self.call_checkpoint_bottleneck(prev_features) 124 | else: 125 | bottleneck_output = self.bn_function(prev_features) 126 | 127 | new_features = self.conv2(self.pad2d_1(self.relu2(self.norm2(bottleneck_output))))#new padding 128 | if self.drop_rate > 0: 129 | new_features = F.dropout(new_features, p=self.drop_rate, 130 | training=self.training) 131 | return new_features 132 | 133 | class _DenseBlock(nn.ModuleDict): 134 | _version = 2 135 | 136 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=MEMORY_EFFICIENT): 137 | super(_DenseBlock, self).__init__() 138 | for i in range(num_layers): 139 | layer = _DenseLayer( 140 | num_input_features + i * growth_rate, 141 | growth_rate=growth_rate, 142 | bn_size=bn_size, 143 | drop_rate=drop_rate, 144 | 
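# memory_efficient=True recomputes the bottleneck (concat + norm1/relu1/conv1) in the backward pass via torch.utils.checkpoint -- see call_checkpoint_bottleneck above; it saves memory at the cost of extra compute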
memory_efficient=memory_efficient, 145 | ) 146 | self.add_module('denselayer%d' % (i + 1), layer) 147 | 148 | def forward(self, init_features): 149 | features = [init_features] 150 | for name, layer in self.items(): 151 | new_features = layer(features) 152 | features.append(new_features) 153 | return torch.cat(features, 1) 154 | 155 | class _Transition(nn.Sequential): 156 | def __init__(self, num_input_features, num_output_features): 157 | super(_Transition, self).__init__() 158 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 159 | self.add_module('relu', nn.ReLU(inplace=True)) 160 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 161 | kernel_size=1, stride=1, bias=False)) 162 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 163 | 164 | class DenseNet(nn.Module): 165 | r"""Densenet-BC model class, based on 166 | `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ 167 | 168 | Args: 169 | growth_rate (int) - how many filters to add each layer (`k` in paper) 170 | block_config (list of 4 ints) - how many layers in each pooling block 171 | num_init_features (int) - the number of filters to learn in the first convolution layer 172 | bn_size (int) - multiplicative factor for number of bottle neck layers 173 | (i.e. bn_size * k features in the bottleneck layer) 174 | drop_rate (float) - dropout rate after each dense layer 175 | num_classes (int) - number of classification classes 176 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 177 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ 178 | """ 179 | 180 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 181 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=NUM_CLASSES, memory_efficient=MEMORY_EFFICIENT): 182 | 183 | super(DenseNet, self).__init__() 184 | 185 | # First convolution 186 | self.features = nn.Sequential(OrderedDict([ 187 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 188 | padding=0, bias=False)), # new padding 189 | ('norm0', nn.BatchNorm2d(num_init_features)), 190 | ('relu0', nn.ReLU(inplace=True)), 191 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), # new padding 192 | ])) 193 | 194 | self.pad2d_1 = newPad2d(1)#new padding 195 | self.pad2d_3 = newPad2d(3)#new padding 196 | 197 | self.smg = SMGBlock(channel_size=CHANNEL_NUM, f_map_size=F_MAP_SIZE) # 7*7 feature map (F_MAP_SIZE = 49) 198 | 199 | 200 | # Each denseblock 201 | num_features = num_init_features 202 | for i, num_layers in enumerate(block_config): 203 | block = _DenseBlock( 204 | num_layers=num_layers, 205 | num_input_features=num_features, 206 | bn_size=bn_size, 207 | growth_rate=growth_rate, 208 | drop_rate=drop_rate, 209 | memory_efficient=memory_efficient 210 | ) 211 | self.features.add_module('denseblock%d' % (i + 1), block) 212 | num_features = num_features + num_layers * growth_rate 213 | if i != len(block_config) - 1: 214 | trans = _Transition(num_input_features=num_features, 215 | num_output_features=num_features // 2) 216 | self.features.add_module('transition%d' % (i + 1), trans) 217 | num_features = num_features // 2 218 | 219 | # Final batch norm 220 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 221 | 222 | # Linear layer 223 | self.classifier = nn.Linear(num_features, num_classes) 224 | 225 | # Official init from torch repo.
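# Conv2d weights: Kaiming-normal init (appropriate for ReLU); BatchNorm starts as the identity map (weight=1, bias=0);
# for the Linear classifier only the bias is zeroed -- its weight keeps PyTorch's default init.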
226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight) 229 | elif isinstance(m, nn.BatchNorm2d): 230 | nn.init.constant_(m.weight, 1) 231 | nn.init.constant_(m.bias, 0) 232 | elif isinstance(m, nn.Linear): 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def forward(self, x, eval=False): 236 | for i, layer in enumerate(self.features): 237 | if i == 0: 238 | x = self.pad2d_3(x) # new padding 239 | if i == 3: 240 | x = self.pad2d_1(x) # new padding 241 | x = layer(x) 242 | out = F.relu(x, inplace=True) 243 | if eval: 244 | return out 245 | corre_matrix = self.smg(out) 246 | f_map = out # get_feature 247 | out = F.adaptive_avg_pool2d(out, (1, 1)) 248 | out = torch.flatten(out, 1) 249 | out = self.classifier(out) 250 | return out,f_map,corre_matrix 251 | 252 | def _load_state_dict(model, model_url, progress): 253 | # '.'s are no longer allowed in module names, but previous _DenseLayer 254 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 255 | # They are also in the checkpoints in model_urls. This pattern is used 256 | # to find such keys. 257 | pattern = re.compile( 258 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 259 | 260 | state_dict = load_state_dict_from_url(model_url, progress=progress) 261 | for key in list(state_dict.keys()): 262 | res = pattern.match(key) 263 | if res: 264 | new_key = res.group(1) + res.group(2) 265 | state_dict[new_key] = state_dict[key] 266 | del state_dict[key] 267 | # model.load_state_dict(state_dict) 268 | # for finetune 269 | pretrained_dict = {k: v for k, v in state_dict.items() if 'classifier' not in k}#'fc' not in k and 'layer4.1' not in k and 270 | model_dict = model.state_dict() 271 | model_dict.update(pretrained_dict) 272 | model.load_state_dict(model_dict) 273 | 274 | 275 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, 276 | **kwargs): 277 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 278 | if pretrained: 279 | _load_state_dict(model, model_urls[arch], progress) 280 | else: 281 | if pretrain_model is not None: 282 | print("Load pretrained model") 283 | device = torch.device("cuda") 284 | model = nn.DataParallel(model).to(device) 285 | # model.load_state_dict(torch.load(pretrain_model)) 286 | pretrained_dict = torch.load(pretrain_model) 287 | if IS_TRAIN == 0: 288 | pretrained_dict = {k[k.find('.')+1:]: v for k, v in pretrained_dict.items()} # for get_feature 289 | model.load_state_dict(pretrained_dict) 290 | else: 291 | print('Error: pretrain_model == None') 292 | return model 293 | 294 | def densenet121(pretrained=False, progress=True, **kwargs): 295 | r"""Densenet-121 model from 296 | `"Densely Connected Convolutional Networks" `_ 297 | 298 | Args: 299 | pretrained (bool): If True, returns a model pre-trained on ImageNet 300 | progress (bool): If True, displays a progress bar of the download to stderr 301 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 302 | but slower. Default: *False*. 
See `"paper" `_ 303 | """ 304 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 305 | **kwargs) 306 | 307 | def densenet161(pretrained=False, progress=True, **kwargs): 308 | r"""Densenet-161 model from 309 | `"Densely Connected Convolutional Networks" `_ 310 | 311 | Args: 312 | pretrained (bool): If True, returns a model pre-trained on ImageNet 313 | progress (bool): If True, displays a progress bar of the download to stderr 314 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 315 | but slower. Default: *False*. See `"paper" `_ 316 | """ 317 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 318 | **kwargs) 319 | 320 | def densenet169(pretrained=False, progress=True, **kwargs): 321 | r"""Densenet-169 model from 322 | `"Densely Connected Convolutional Networks" `_ 323 | 324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 328 | but slower. Default: *False*. See `"paper" `_ 329 | """ 330 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 331 | **kwargs) 332 | 333 | def densenet201(pretrained=False, progress=True, **kwargs): 334 | r"""Densenet-201 model from 335 | `"Densely Connected Convolutional Networks" `_ 336 | 337 | Args: 338 | pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | progress (bool): If True, displays a progress bar of the download to stderr 340 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 341 | but slower. Default: *False*. See `"paper" `_ 342 | """ 343 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 344 | **kwargs) 345 | 346 | def get_Data(is_train, dataset_name, batch_size): 347 | val_transform = transforms.Compose([ 348 | transforms.Resize((224,224)), 349 | transforms.ToTensor(), 350 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 351 | std=[0.229, 0.224, 0.225]) 352 | ]) 353 | voc_helen = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi'] 354 | ##cub dataset### 355 | label = None if is_train else 0 356 | if not is_train: 357 | batch_size = 1 358 | if dataset_name == 'cub': 359 | trainset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 360 | testset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 361 | ###cropped voc dataset### 362 | elif dataset_name in voc_helen: 363 | trainset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 364 | testset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 365 | ###celeb dataset### 366 | #elif dataset_name == 'celeb': 367 | # trainset = Celeb(training = True, transform=None) 368 | # testset = Celeb(training = False, transform=None) 369 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 370 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 371 | return train_loader, test_loader 372 | 373 | def net_train(): 374 | trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE) 375 | if os.path.exists(log_path): 376 | shutil.rmtree(log_path);os.makedirs(log_path) 377 | else: 378 | os.makedirs(log_path) 379 | device = torch.device("cuda") 380 | net = None 381 | 
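# pick the backbone depth; CHANNEL_NUM set at the top of this file must match it (1024 for '121', 2208 for '161'), since SMGBlock is built with that channel size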
if LAYERS == '121': 382 | net = densenet121(pretrained=False) 383 | elif LAYERS == '161': 384 | net = densenet161(pretrained=False) 385 | net = nn.DataParallel(net).to(device) 386 | # Loss and Optimizer 387 | criterion = nn.CrossEntropyLoss() 388 | optimizer = torch.optim.Adam(net.module.parameters(), lr=LR) 389 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.6) 390 | 391 | # Train the model 392 | save_similatiry_loss = [];save_gt=[];save_class_loss= [];save_total_loss = []; 393 | cs_loss = Cluster_loss() 394 | mc_loss = Multiclass_loss(class_num= NUM_CLASSES) 395 | for epoch in range(EPOCH+1): 396 | if epoch % T==0 and epoch < STOP_CLUSTERING: 397 | with torch.no_grad(): 398 | Ground_true, loss_mask_num, loss_mask_den = offline_spectral_cluster(net, trainset_loader) 399 | save_gt.append(Ground_true.cpu().numpy()) 400 | else: 401 | scheduler.step() 402 | net.train() 403 | total_loss = 0.0;similarity_loss = 0.0;class_loss = 0.0 404 | for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9): 405 | inputs, labels = input_data 406 | inputs, labels = inputs.to(device), labels.to(device) 407 | optimizer.zero_grad() 408 | output, f_map, corre = net(inputs, eval=False) 409 | clr_loss = criterion(output, labels) 410 | loss1 = cs_loss.update(corre, loss_mask_num, loss_mask_den, None) 411 | loss2 = mc_loss.update(f_map, loss_mask_num, labels) 412 | loss = clr_loss + lam1 *loss1 + lam2*loss2 413 | loss.backward() 414 | optimizer.step() 415 | total_loss += loss.item() 416 | similarity_loss += loss1.item() 417 | class_loss += loss2.item() 418 | ### loss save code 419 | total_loss = float(total_loss) / len(trainset_loader) 420 | similarity_loss = float(similarity_loss) / len(trainset_loader) 421 | class_loss = float(class_loss) / len(trainset_loader) 422 | save_total_loss.append(total_loss) 423 | save_similatiry_loss.append(similarity_loss) 424 | save_class_loss.append(class_loss) 425 | acc = 0 426 | if epoch % 5==0: 427 | acc = test(net, testset_loader) 428 | print('Epoch', epoch, 'loss: %.4f' % total_loss, 'sc_loss: %.4f' % similarity_loss, 'class_loss: %.4f' % class_loss, 'test accuracy:%.4f' % acc) 429 | if epoch % 100 == 0: 430 | torch.save(net.state_dict(), log_path+'model_%.3d.pth' % (epoch)) 431 | np.savez(log_path+'loss_%.3d.npz'% (epoch), loss=np.array(save_total_loss), similarity_loss = np.array(save_similatiry_loss), class_loss = np.array(save_class_loss),gt=np.array(save_gt)) 432 | print('Finished Training') 433 | return net 434 | 435 | def offline_spectral_cluster(net, train_data): 436 | net.eval() 437 | f_map = [] 438 | device = torch.device("cuda") 439 | for inputs, labels in train_data: 440 | inputs, labels = inputs.to(device), labels.to(device) 441 | cur_fmap= net(inputs,eval=True).detach().cpu().numpy() 442 | f_map.append(cur_fmap) 443 | f_map = np.concatenate(f_map,axis=0) 444 | sample, channel,_,_ = f_map.shape 445 | f_map = f_map.reshape((sample,channel,-1)) 446 | mean = np.mean(f_map,axis=0) 447 | cov = np.mean(np.matmul(f_map-mean,np.transpose(f_map-mean,(0,2,1))),axis=0) 448 | diag = np.diag(cov).reshape(channel,-1) 449 | correlation = cov/(np.sqrt(np.matmul(diag,np.transpose(diag,(1,0))))+1e-5)+1 450 | ground_true, loss_mask_num, loss_mask_den = spectral_clustering(correlation,n_cluster=center_num) 451 | return ground_true, loss_mask_num, loss_mask_den 452 | 453 | def get_feature(): 454 | print('pretrain_model:', pretrain_model) 455 | _, testset_test = get_Data(True, DATANAME, BATCHSIZE) 456 | _, 
testset_feature = get_Data(False, DATANAME, BATCHSIZE) 457 | device = torch.device("cuda") 458 | net = None 459 | if LAYERS == '121': 460 | net = densenet121(pretrained=False) 461 | elif LAYERS == '161': 462 | net = densenet161(pretrained=False) 463 | acc = test(net, testset_test) 464 | f = open(acc_path+dataset+'_test.txt', 'w+') 465 | f.write('%s\n' % dataset) 466 | f.write('acc:%f\n' % acc) 467 | all_feature = [] 468 | testset = testset_test if DATANAME == 'voc_multi' else testset_feature 469 | for batch_step, input_data in tqdm(enumerate(testset,0),total=len(testset),smoothing=0.9): 470 | inputs, labels = input_data 471 | inputs, labels = inputs.to(device), labels.to(device) 472 | net.eval() 473 | f_map = net(inputs,eval=True) 474 | all_feature.append(f_map.detach().cpu().numpy()) 475 | all_feature = np.concatenate(all_feature,axis=0) 476 | print(all_feature.shape) 477 | f.write('sample num:%d' % (all_feature.shape[0])) 478 | f.close() 479 | np.savez(save_path+LAYERS+'_densenet_'+DATANAME+'_iccnn.npz', f_map=all_feature[...]) 480 | print('Finished Operation!') 481 | return net 482 | 483 | def test(net, testdata): 484 | correct, total = .0, .0 485 | for inputs, labels in testdata: 486 | inputs, labels = inputs.cuda(), labels.cuda() 487 | net.eval() 488 | outputs, _,_ = net(inputs) 489 | _, predicted = torch.max(outputs, 1) 490 | total += labels.size(0) 491 | correct += (predicted == labels).sum() 492 | print('test acc = ',float(correct) / total) 493 | return float(correct) / total 494 | 495 | def densenet_multi_train(): 496 | if IS_TRAIN == 1: 497 | net = net_train() 498 | elif IS_TRAIN == 0: 499 | net = get_feature() 500 | -------------------------------------------------------------------------------- /densenet_iccnn_train.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | from torch import Tensor 8 | from torch.jit.annotations import List 9 | 10 | #added 11 | import torchvision.transforms as transforms 12 | from torch.utils.data import DataLoader 13 | from load_utils import load_state_dict_from_url 14 | from cub_voc import CUB_VOC 15 | import os 16 | from tqdm import tqdm 17 | import shutil 18 | import numpy as np 19 | 20 | # for iccnn train 21 | from Similar_Mask_Generate import SMGBlock 22 | from SpectralClustering import spectral_clustering 23 | from utils.utils import Cluster_loss 24 | from newPad2d import newPad2d 25 | 26 | MEMORY_EFFICIENT = True 27 | IS_TRAIN = 0 # 0/1 28 | LAYERS = '121' 29 | DATANAME = 'bird' # bird/cat/.../cub/helen 30 | NUM_CLASSES = 2 31 | cub_file = '/data/sw/dataset/frac_dataset' 32 | voc_file = '/data/sw/dataset/VOCdevkit/VOC2010/voc2010_crop' 33 | log_path = '/data/fjq/iccnn/densenet/' # for model 34 | save_path = '/data/fjq/iccnn/basic_fmap/densenet/' # for get_feature 35 | acc_path = '/data/fjq/iccnn/basic_fmap/densenet/acc/' 36 | 37 | dataset = '%s_densenet_%s_iccnn' % (LAYERS, DATANAME) 38 | log_path = log_path + dataset + '/' 39 | pretrain_model = log_path + 'model_2000.pth' 40 | BATCHSIZE = 1 41 | LR = 0.00001 42 | EPOCH = 2000 43 | center_num = 5 44 | lam = 1 45 | T = 2 46 | F_MAP_SIZE = 49 47 | STOP_CLUSTERING = 200 48 | if LAYERS == '121': 49 | CHANNEL_NUM = 1024 # for SMGBlock1D 50 | elif LAYERS == '161': 51 | CHANNEL_NUM = 2208 # for SMGBlock1D 52 | 53 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 
'densenet161'] 54 | 55 | model_urls = { 56 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 57 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 58 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 59 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 60 | } 61 | 62 | 63 | class _DenseLayer(nn.Module): 64 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=MEMORY_EFFICIENT): 65 | super(_DenseLayer, self).__init__() 66 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 67 | self.add_module('relu1', nn.ReLU(inplace=True)), 68 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 69 | growth_rate, kernel_size=1, stride=1, 70 | bias=False)), 71 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 72 | self.add_module('relu2', nn.ReLU(inplace=True)), 73 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 74 | kernel_size=3, stride=1, padding=0, #new padding 75 | bias=False)), 76 | self.drop_rate = float(drop_rate) 77 | self.memory_efficient = memory_efficient 78 | self.pad2d_1 = newPad2d(1)#new padding 79 | 80 | def bn_function(self, inputs): 81 | # type: (List[Tensor]) -> Tensor 82 | concated_features = torch.cat(inputs, 1) 83 | bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 84 | return bottleneck_output 85 | 86 | # todo: rewrite when torchscript supports any 87 | def any_requires_grad(self, input): 88 | # type: (List[Tensor]) -> bool 89 | for tensor in input: 90 | if tensor.requires_grad: 91 | return True 92 | return False 93 | 94 | @torch.jit.unused # noqa: T484 95 | def call_checkpoint_bottleneck(self, input): 96 | # type: (List[Tensor]) -> Tensor 97 | def closure(*inputs): 98 | return self.bn_function(inputs) 99 | 100 | return cp.checkpoint(closure, *input) 101 | 102 | @torch.jit._overload_method # noqa: F811 103 | def forward(self, input): 104 | # type: (List[Tensor]) -> (Tensor) 105 | pass 106 | 107 | @torch.jit._overload_method # noqa: F811 108 | def forward(self, input): 109 | # type: (Tensor) -> (Tensor) 110 | pass 111 | 112 | # torchscript does not yet support *args, so we overload method 113 | # allowing it to take either a List[Tensor] or single Tensor 114 | def forward(self, input): # noqa: F811 115 | if isinstance(input, Tensor): 116 | prev_features = [input] 117 | else: 118 | prev_features = input 119 | 120 | if self.memory_efficient and self.any_requires_grad(prev_features): 121 | if torch.jit.is_scripting(): 122 | raise Exception("Memory Efficient not supported in JIT") 123 | 124 | bottleneck_output = self.call_checkpoint_bottleneck(prev_features) 125 | else: 126 | bottleneck_output = self.bn_function(prev_features) 127 | 128 | new_features = self.conv2(self.pad2d_1(self.relu2(self.norm2(bottleneck_output))))#new padding 129 | if self.drop_rate > 0: 130 | new_features = F.dropout(new_features, p=self.drop_rate, 131 | training=self.training) 132 | return new_features 133 | 134 | 135 | class _DenseBlock(nn.ModuleDict): 136 | _version = 2 137 | 138 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=MEMORY_EFFICIENT): 139 | super(_DenseBlock, self).__init__() 140 | for i in range(num_layers): 141 | layer = _DenseLayer( 142 | num_input_features + i * growth_rate, 143 | growth_rate=growth_rate, 144 | bn_size=bn_size, 145 | drop_rate=drop_rate, 146 | 
memory_efficient=memory_efficient, 147 | ) 148 | self.add_module('denselayer%d' % (i + 1), layer) 149 | 150 | def forward(self, init_features): 151 | features = [init_features] 152 | for name, layer in self.items(): 153 | new_features = layer(features) 154 | features.append(new_features) 155 | return torch.cat(features, 1) 156 | 157 | 158 | class _Transition(nn.Sequential): 159 | def __init__(self, num_input_features, num_output_features): 160 | super(_Transition, self).__init__() 161 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 162 | self.add_module('relu', nn.ReLU(inplace=True)) 163 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 164 | kernel_size=1, stride=1, bias=False)) 165 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 166 | 167 | 168 | class DenseNet(nn.Module): 169 | r"""Densenet-BC model class, based on 170 | `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ 171 | 172 | Args: 173 | growth_rate (int) - how many filters to add each layer (`k` in paper) 174 | block_config (list of 4 ints) - how many layers in each pooling block 175 | num_init_features (int) - the number of filters to learn in the first convolution layer 176 | bn_size (int) - multiplicative factor for number of bottle neck layers 177 | (i.e. bn_size * k features in the bottleneck layer) 178 | drop_rate (float) - dropout rate after each dense layer 179 | num_classes (int) - number of classification classes 180 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 181 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ 182 | """ 183 | 184 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 185 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=2, memory_efficient=MEMORY_EFFICIENT): 186 | 187 | super(DenseNet, self).__init__() 188 | 189 | # First convolution 190 | self.features = nn.Sequential(OrderedDict([ 191 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 192 | padding=0, bias=False)), # new padding 193 | ('norm0', nn.BatchNorm2d(num_init_features)), 194 | ('relu0', nn.ReLU(inplace=True)), 195 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), # new padding 196 | ])) 197 | 198 | self.pad2d_1 = newPad2d(1)#new padding 199 | self.pad2d_3 = newPad2d(3)#new padding 200 | 201 | self.smg = SMGBlock(channel_size=CHANNEL_NUM, f_map_size=F_MAP_SIZE) # 7*7 feature map (F_MAP_SIZE = 49) 202 | 203 | 204 | # Each denseblock 205 | num_features = num_init_features 206 | for i, num_layers in enumerate(block_config): 207 | block = _DenseBlock( 208 | num_layers=num_layers, 209 | num_input_features=num_features, 210 | bn_size=bn_size, 211 | growth_rate=growth_rate, 212 | drop_rate=drop_rate, 213 | memory_efficient=memory_efficient 214 | ) 215 | self.features.add_module('denseblock%d' % (i + 1), block) 216 | num_features = num_features + num_layers * growth_rate 217 | if i != len(block_config) - 1: 218 | trans = _Transition(num_input_features=num_features, 219 | num_output_features=num_features // 2) 220 | self.features.add_module('transition%d' % (i + 1), trans) 221 | num_features = num_features // 2 222 | 223 | # Final batch norm 224 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 225 | 226 | # Linear layer 227 | self.classifier = nn.Linear(num_features, num_classes) 228 | 229 | # Official init from torch repo.
230 | for m in self.modules(): 231 | if isinstance(m, nn.Conv2d): 232 | nn.init.kaiming_normal_(m.weight) 233 | elif isinstance(m, nn.BatchNorm2d): 234 | nn.init.constant_(m.weight, 1) 235 | nn.init.constant_(m.bias, 0) 236 | elif isinstance(m, nn.Linear): 237 | nn.init.constant_(m.bias, 0) 238 | 239 | def forward(self, x, eval=False): 240 | for i, layer in enumerate(self.features): 241 | if i == 0: 242 | x = self.pad2d_3(x) # new padding 243 | if i == 3: 244 | x = self.pad2d_1(x) # new padding 245 | x = layer(x) 246 | out = F.relu(x, inplace=True) 247 | if eval: 248 | return out 249 | corre_matrix = self.smg(out) 250 | f_map = out.detach() # get_feature 251 | out = F.adaptive_avg_pool2d(out, (1, 1)) 252 | out = torch.flatten(out, 1) 253 | out = self.classifier(out) 254 | return out,f_map,corre_matrix 255 | 256 | def _load_state_dict(model, model_url, progress): 257 | # '.'s are no longer allowed in module names, but previous _DenseLayer 258 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 259 | # They are also in the checkpoints in model_urls. This pattern is used 260 | # to find such keys. 261 | pattern = re.compile( 262 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 263 | 264 | state_dict = load_state_dict_from_url(model_url, progress=progress) 265 | for key in list(state_dict.keys()): 266 | res = pattern.match(key) 267 | if res: 268 | new_key = res.group(1) + res.group(2) 269 | state_dict[new_key] = state_dict[key] 270 | del state_dict[key] 271 | # model.load_state_dict(state_dict) 272 | # for finetune 273 | pretrained_dict = {k: v for k, v in state_dict.items() if 'classifier' not in k}#'fc' not in k and 'layer4.1' not in k and 274 | model_dict = model.state_dict() 275 | model_dict.update(pretrained_dict) 276 | model.load_state_dict(model_dict) 277 | 278 | 279 | def _densenet(arch, growth_rate, block_config, num_init_features, num_class, pretrained, progress, 280 | **kwargs): 281 | model = DenseNet(growth_rate, block_config, num_init_features, num_classes=num_class, **kwargs) 282 | if pretrained: 283 | _load_state_dict(model, model_urls[arch], progress) 284 | else: 285 | if pretrain_model is not None: 286 | print("Load pretrained model") 287 | device = torch.device("cuda") 288 | # model = nn.DataParallel(model).to(device) 289 | pretrained_dict = torch.load(pretrain_model) 290 | # if IS_TRAIN == 0: 291 | pretrained_dict = {k[k.find('.')+1:]: v for k, v in pretrained_dict.items()} # for get_feature 292 | model.load_state_dict(pretrained_dict) 293 | else: 294 | print('Error: pretrain_model == None') 295 | return model 296 | 297 | def densenet121(num_class, pretrained=False, progress=True, **kwargs): 298 | r"""Densenet-121 model from 299 | `"Densely Connected Convolutional Networks" `_ 300 | 301 | Args: 302 | pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | progress (bool): If True, displays a progress bar of the download to stderr 304 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 305 | but slower. Default: *False*. 
See `"paper" `_ 306 | """ 307 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, num_class, pretrained, progress, 308 | **kwargs) 309 | 310 | def densenet161(num_class, pretrained=False, progress=True, **kwargs): 311 | r"""Densenet-161 model from 312 | `"Densely Connected Convolutional Networks" `_ 313 | 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 318 | but slower. Default: *False*. See `"paper" `_ 319 | """ 320 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, num_class, pretrained, progress, 321 | **kwargs) 322 | 323 | 324 | def densenet169(num_class, pretrained=False, progress=True, **kwargs): 325 | r"""Densenet-169 model from 326 | `"Densely Connected Convolutional Networks" `_ 327 | 328 | Args: 329 | pretrained (bool): If True, returns a model pre-trained on ImageNet 330 | progress (bool): If True, displays a progress bar of the download to stderr 331 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 332 | but slower. Default: *False*. See `"paper" `_ 333 | """ 334 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, num_class, pretrained, progress, 335 | **kwargs) 336 | 337 | 338 | def densenet201(num_class, pretrained=False, progress=True, **kwargs): 339 | r"""Densenet-201 model from 340 | `"Densely Connected Convolutional Networks" `_ 341 | 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | progress (bool): If True, displays a progress bar of the download to stderr 345 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 346 | but slower. Default: *False*. 
See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ 347 | """ 348 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, num_class, pretrained, progress, 349 | **kwargs) 350 | 351 | def get_Data(is_train, dataset_name, batch_size): 352 | val_transform = transforms.Compose([ 353 | transforms.Resize((224,224)), 354 | transforms.ToTensor(), 355 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 356 | std=[0.229, 0.224, 0.225]) 357 | ]) 358 | voc_helen = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi'] 359 | ##cub dataset### 360 | label = None if is_train else 0 361 | if not is_train: 362 | batch_size = 1 363 | if dataset_name == 'cub': 364 | trainset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 365 | testset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 366 | ###cropped voc dataset### 367 | elif dataset_name in voc_helen: 368 | trainset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 369 | testset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 370 | ###celeb dataset### 371 | #elif dataset_name == 'celeb': 372 | # trainset = Celeb(training = True, transform=None) 373 | # testset = Celeb(training = False, transform=None) 374 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 375 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 376 | return train_loader, test_loader 377 | 378 | def net_train(): 379 | trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE) 380 | if os.path.exists(log_path): 381 | shutil.rmtree(log_path);os.makedirs(log_path) 382 | else: 383 | os.makedirs(log_path) 384 | device = torch.device("cuda") 385 | net = None 386 | if LAYERS == '121': 387 | net = densenet121(num_class=NUM_CLASSES, pretrained=True) 388 | if LAYERS == '161': 389 | net = densenet161(num_class=NUM_CLASSES, pretrained=True) 390 | net = nn.DataParallel(net).to(device) 391 | # Loss and Optimizer 392 | criterion = nn.CrossEntropyLoss() 393 | optimizer = torch.optim.Adam(net.module.parameters(), lr=LR) 394 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.6) 395 | 396 | # Train the model 397 | best_acc = 0.0 398 | save_loss = [];save_similatiry_loss = [];save_gt=[] 399 | cs_loss = Cluster_loss() 400 | for epoch in range(EPOCH+1): 401 | if epoch % T==0 and epoch < STOP_CLUSTERING: 402 | with torch.no_grad(): 403 | Ground_true, loss_mask_num, loss_mask_den = offline_spectral_cluster(net, trainset_loader) 404 | save_gt.append(Ground_true.cpu().numpy()) 405 | else: 406 | scheduler.step() 407 | net.train() 408 | total_loss = 0.0;similarity_loss = 0.0 409 | for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9): 410 | inputs, labels = input_data 411 | inputs, labels = inputs.to(device), labels.to(device) 412 | optimizer.zero_grad() 413 | output, f_map, corre = net(inputs, eval=False) 414 | clr_loss = criterion(output, labels) 415 | loss1 = cs_loss.update(corre, loss_mask_num, loss_mask_den, None) 416 | loss = clr_loss + lam*loss1 417 | loss.backward() 418 | optimizer.step() 419 | total_loss += loss.item() 420 | similarity_loss += loss1.item() 421 | total_loss = float(total_loss) / len(trainset_loader) 422 | similarity_loss = float(similarity_loss) / len(trainset_loader) 423 | save_loss.append(total_loss) 424 | save_similatiry_loss.append(similarity_loss) 425 | acc = 0 426 | if epoch % 5==0: 427 | acc = test(net, testset_loader) 428 | print('Epoch', epoch, 'loss: %.4f' % total_loss, 'sc_loss: %.4f' % similarity_loss, 'test accuracy:%.4f' % acc) 429 | if epoch % 100 == 0: 430 | torch.save(net.state_dict(), log_path+'model_%.3d.pth' % (epoch)) 431 | np.savez(log_path+'loss_%.3d.npz'% (epoch), loss=np.array(save_loss), similarity_loss = np.array(save_similatiry_loss), gt=np.array(save_gt)) 432 | # if acc > best_acc: 433 | # best_acc = acc 434 | # torch.save(net.state_dict(), log_path+'model_%.3d_%.4f.pth' % (epoch,best_acc)) 435 | print('Finished Training') 436 | return net 437 | 438 | def offline_spectral_cluster(net, train_data): 439 | net.eval() 440 | f_map = [] 441 | for inputs, labels in train_data: 442 | inputs, labels = inputs.cuda(), labels.cuda() 443 | cur_fmap= net(inputs,eval=True).detach().cpu().numpy() 444 | f_map.append(cur_fmap) 445 | f_map = np.concatenate(f_map,axis=0) 446 | sample, channel,_,_ = f_map.shape 447 | f_map = f_map.reshape((sample,channel,-1)) 448 | mean = np.mean(f_map,axis=0) 449 | cov = np.mean(np.matmul(f_map-mean,np.transpose(f_map-mean,(0,2,1))),axis=0) 450 | diag = np.diag(cov).reshape(channel,-1) 451 | correlation = cov/(np.sqrt(np.matmul(diag,np.transpose(diag,(1,0))))+1e-5)+1 452 | ground_true, loss_mask_num, loss_mask_den = spectral_clustering(correlation,n_cluster=center_num) 453 | 454 | return ground_true, loss_mask_num, loss_mask_den 455 | 456 | def get_feature(): 457 | print('pretrain_model:', pretrain_model) 458 | _,testset_test = 
get_Data(True, DATANAME, BATCHSIZE) 459 | _,testset_feature = get_Data(False, DATANAME, BATCHSIZE) 460 | device = torch.device("cuda") 461 | net = None 462 | if LAYERS == '121': 463 | net = densenet121(num_class=NUM_CLASSES, pretrained=False) 464 | if LAYERS == '161': 465 | net = densenet161(num_class=NUM_CLASSES, pretrained=False) 466 | acc = test(net, testset_test) 467 | f = open(acc_path+dataset+'_test.txt', 'w+') 468 | f.write('%s\n' % dataset) 469 | f.write('acc:%f\n' % acc) 470 | print(acc) 471 | all_feature = [] 472 | for batch_step, input_data in tqdm(enumerate(testset_feature,0),total=len(testset_feature),smoothing=0.9): 473 | inputs, labels = input_data 474 | inputs, labels = inputs.to(device), labels.to(device) 475 | net.eval() 476 | f_map = net(inputs,eval=True) 477 | all_feature.append(f_map.detach().cpu().numpy()) 478 | all_feature = np.concatenate(all_feature,axis=0) 479 | f.write('sample num:%d' % (all_feature.shape[0])) 480 | f.close() 481 | print(all_feature.shape) 482 | np.savez(save_path+LAYERS+'_densenet_'+DATANAME+'_iccnn.npz', f_map=all_feature[...]) 483 | print('Finished Operation!') 484 | return net 485 | 486 | def test(net, testdata): 487 | correct, total = .0, .0 488 | for batch_step, input_data in tqdm(enumerate(testdata,0),total=len(testdata),smoothing=0.9): 489 | inputs, labels = input_data 490 | inputs, labels = inputs.cuda(), labels.cuda() 491 | net.eval() 492 | outputs, _, _ = net(inputs) 493 | _, predicted = torch.max(outputs, 1) 494 | total += labels.size(0) 495 | correct += (predicted == labels).sum() 496 | return (float(correct) / total) 497 | 498 | def densenet_single_train(): 499 | 500 | if IS_TRAIN == 1: 501 | net = net_train() 502 | elif IS_TRAIN == 0: 503 | net = get_feature() 504 | -------------------------------------------------------------------------------- /densenet_ori_train.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | from torch import Tensor 8 | from torch.jit.annotations import List 9 | 10 | #added 11 | import torchvision.transforms as transforms 12 | from torch.utils.data import DataLoader 13 | from load_utils import load_state_dict_from_url 14 | from cub_voc import CUB_VOC 15 | import os 16 | from tqdm import tqdm 17 | import shutil 18 | import numpy as np 19 | from newPad2d import newPad2d 20 | #from torch.autograd import Variable 21 | 22 | 23 | MEMORY_EFFICIENT = True 24 | IS_TRAIN = 0 # 0/1 25 | IS_MULTI = 0 # 0/1 26 | LAYERS = '121' 27 | DATANAME = 'bird' # bird/cat/.../cub/helen/voc_multi 28 | NUM_CLASSES =6 if IS_MULTI else 2 29 | cub_file = '/data/sw/dataset/frac_dataset' 30 | voc_file = '/data/sw/dataset/VOCdevkit/VOC2010/voc2010_crop' 31 | log_path = '/data/fjq/iccnn/densenet/' # for model 32 | save_path = '/data/fjq/iccnn/basic_fmap/densenet/' # for get_feature 33 | acc_path = '/data/fjq/iccnn/basic_fmap/densenet/acc/' 34 | 35 | dataset = '%s_densenet_%s_ori' % (LAYERS, DATANAME) 36 | log_path = log_path + dataset + '/' 37 | pretrain_model = log_path + 'model_2000.pth' 38 | BATCHSIZE = 1 39 | LR = 0.00001 40 | EPOCH = 1000 41 | 42 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 43 | 44 | model_urls = { 45 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 46 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 47 | 
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 48 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 49 | } 50 | 51 | 52 | class _DenseLayer(nn.Module): 53 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=MEMORY_EFFICIENT): 54 | super(_DenseLayer, self).__init__() 55 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 56 | self.add_module('relu1', nn.ReLU(inplace=True)), 57 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 58 | growth_rate, kernel_size=1, stride=1, 59 | bias=False)), 60 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 61 | self.add_module('relu2', nn.ReLU(inplace=True)), 62 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 63 | kernel_size=3, stride=1, padding=0, #new padding 64 | bias=False)), 65 | self.drop_rate = float(drop_rate) 66 | self.memory_efficient = memory_efficient 67 | self.pad2d_1 = newPad2d(1) #nn.ReplicationPad2d(1)#new padding 68 | 69 | def bn_function(self, inputs): 70 | # type: (List[Tensor]) -> Tensor 71 | concated_features = torch.cat(inputs, 1) 72 | bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 73 | return bottleneck_output 74 | 75 | # todo: rewrite when torchscript supports any 76 | def any_requires_grad(self, input): 77 | # type: (List[Tensor]) -> bool 78 | for tensor in input: 79 | if tensor.requires_grad: 80 | return True 81 | return False 82 | 83 | @torch.jit.unused # noqa: T484 84 | def call_checkpoint_bottleneck(self, input): 85 | # type: (List[Tensor]) -> Tensor 86 | def closure(*inputs): 87 | return self.bn_function(inputs) 88 | 89 | return cp.checkpoint(closure, *input) 90 | 91 | @torch.jit._overload_method # noqa: F811 92 | def forward(self, input): 93 | # type: (List[Tensor]) -> (Tensor) 94 | pass 95 | 96 | @torch.jit._overload_method # noqa: F811 97 | def forward(self, input): 98 | # type: (Tensor) -> (Tensor) 99 | pass 100 | 101 | # torchscript does not yet support *args, so we overload method 102 | # allowing it to take either a List[Tensor] or single Tensor 103 | def forward(self, input): # noqa: F811 104 | if isinstance(input, Tensor): 105 | prev_features = [input] 106 | else: 107 | prev_features = input 108 | 109 | if self.memory_efficient and self.any_requires_grad(prev_features): 110 | if torch.jit.is_scripting(): 111 | raise Exception("Memory Efficient not supported in JIT") 112 | 113 | bottleneck_output = self.call_checkpoint_bottleneck(prev_features) 114 | else: 115 | bottleneck_output = self.bn_function(prev_features) 116 | 117 | new_features = self.conv2(self.pad2d_1(self.relu2(self.norm2(bottleneck_output))))#new padding 118 | if self.drop_rate > 0: 119 | new_features = F.dropout(new_features, p=self.drop_rate, 120 | training=self.training) 121 | return new_features 122 | 123 | 124 | class _DenseBlock(nn.ModuleDict): 125 | _version = 2 126 | 127 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=MEMORY_EFFICIENT): 128 | super(_DenseBlock, self).__init__() 129 | for i in range(num_layers): 130 | layer = _DenseLayer( 131 | num_input_features + i * growth_rate, 132 | growth_rate=growth_rate, 133 | bn_size=bn_size, 134 | drop_rate=drop_rate, 135 | memory_efficient=memory_efficient, 136 | ) 137 | self.add_module('denselayer%d' % (i + 1), layer) 138 | 139 | def forward(self, init_features): 140 | features = [init_features] 141 | for name, layer in self.items(): 142 | 
new_features = layer(features) 143 | features.append(new_features) 144 | return torch.cat(features, 1) 145 | 146 | 147 | class _Transition(nn.Sequential): 148 | def __init__(self, num_input_features, num_output_features): 149 | super(_Transition, self).__init__() 150 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 151 | self.add_module('relu', nn.ReLU(inplace=True)) 152 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 153 | kernel_size=1, stride=1, bias=False)) 154 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 155 | 156 | 157 | class DenseNet(nn.Module): 158 | r"""Densenet-BC model class, based on 159 | `"Densely Connected Convolutional Networks" `_ 160 | 161 | Args: 162 | growth_rate (int) - how many filters to add each layer (`k` in paper) 163 | block_config (list of 4 ints) - how many layers in each pooling block 164 | num_init_features (int) - the number of filters to learn in the first convolution layer 165 | bn_size (int) - multiplicative factor for number of bottle neck layers 166 | (i.e. bn_size * k features in the bottleneck layer) 167 | drop_rate (float) - dropout rate after each dense layer 168 | num_classes (int) - number of classification classes 169 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 170 | but slower. Default: *False*. See `"paper" `_ 171 | """ 172 | 173 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 174 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=2, memory_efficient=MEMORY_EFFICIENT): 175 | 176 | super(DenseNet, self).__init__() 177 | 178 | # First convolution 179 | self.features = nn.Sequential(OrderedDict([ 180 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 181 | padding=0, bias=False)), # new padding 182 | ('norm0', nn.BatchNorm2d(num_init_features)), 183 | ('relu0', nn.ReLU(inplace=True)), 184 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), # new padding 185 | ])) 186 | 187 | self.pad2d_1 = newPad2d(1)#nn.ZeroPad2d(1) #new padding 188 | self.pad2d_3 = newPad2d(3)#nn.ZeroPad2d(3) #new padding 189 | 190 | 191 | # Each denseblock 192 | num_features = num_init_features 193 | for i, num_layers in enumerate(block_config): 194 | block = _DenseBlock( 195 | num_layers=num_layers, 196 | num_input_features=num_features, 197 | bn_size=bn_size, 198 | growth_rate=growth_rate, 199 | drop_rate=drop_rate, 200 | memory_efficient=memory_efficient 201 | ) 202 | self.features.add_module('denseblock%d' % (i + 1), block) 203 | num_features = num_features + num_layers * growth_rate 204 | if i != len(block_config) - 1: 205 | trans = _Transition(num_input_features=num_features, 206 | num_output_features=num_features // 2) 207 | self.features.add_module('transition%d' % (i + 1), trans) 208 | num_features = num_features // 2 209 | 210 | # Final batch norm 211 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 212 | 213 | # Linear layer 214 | self.classifier = nn.Linear(num_features, num_classes) 215 | 216 | # Official init from torch repo. 
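
The running `num_features` bookkeeping above fixes the width of `norm5` and of `self.classifier`. A quick plain-Python check of that arithmetic, using the densenet121 configuration from the constructors below (growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64):

```
num_features = 64                                 # num_init_features
for i, num_layers in enumerate((6, 12, 24, 16)):  # densenet121 block_config
    num_features += num_layers * 32               # each dense layer concatenates growth_rate maps
    if i != 3:                                    # all but the last block
        num_features //= 2                        # _Transition halves the channels
assert num_features == 1024                       # input width of norm5 / self.classifier
```
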
217 | for m in self.modules(): 218 | if isinstance(m, nn.Conv2d): 219 | nn.init.kaiming_normal_(m.weight) 220 | elif isinstance(m, nn.BatchNorm2d): 221 | nn.init.constant_(m.weight, 1) 222 | nn.init.constant_(m.bias, 0) 223 | elif isinstance(m, nn.Linear): 224 | nn.init.constant_(m.bias, 0) 225 | 226 | def forward(self, x): 227 | for i, layer in enumerate(self.features): 228 | if i == 0: 229 | x = self.pad2d_3(x) # new padding 230 | if i == 3: 231 | x = self.pad2d_1(x) # new padding 232 | x = layer(x) 233 | out = F.relu(x, inplace=True) 234 | f_map = out.detach() # get_feature 235 | out = F.adaptive_avg_pool2d(out, (1, 1)) 236 | out = torch.flatten(out, 1) 237 | out = self.classifier(out) 238 | return out, f_map #out 239 | 240 | def _load_state_dict(model, model_url, progress): 241 | # '.'s are no longer allowed in module names, but previous _DenseLayer 242 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 243 | # They are also in the checkpoints in model_urls. This pattern is used 244 | # to find such keys. 245 | pattern = re.compile( 246 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 247 | 248 | state_dict = load_state_dict_from_url(model_url, progress=progress) 249 | for key in list(state_dict.keys()): 250 | # print(key) 251 | res = pattern.match(key) 252 | if res: 253 | new_key = res.group(1) + res.group(2) 254 | state_dict[new_key] = state_dict[key] 255 | del state_dict[key] 256 | pretrained_dict = {k: v for k, v in state_dict.items() if 'classifier' not in k} 257 | model_dict = model.state_dict() 258 | model_dict.update(pretrained_dict) 259 | model.load_state_dict(model_dict, strict=False) 260 | 261 | 262 | def _densenet(arch, growth_rate, block_config, num_init_features, num_class, pretrained, progress, 263 | **kwargs): 264 | model = DenseNet(growth_rate, block_config, num_init_features, num_classes=num_class, **kwargs) 265 | if pretrained: 266 | _load_state_dict(model, model_urls[arch], progress) 267 | else: 268 | if pretrain_model is not None: 269 | device = torch.device("cuda") 270 | model = nn.DataParallel(model).to(device) 271 | model.load_state_dict(torch.load(pretrain_model)) 272 | else: 273 | print('Error: pretrain_model == None') 274 | return model 275 | 276 | 277 | def densenet121(num_class, pretrained=False, progress=True, **kwargs): 278 | r"""Densenet-121 model from 279 | `"Densely Connected Convolutional Networks" `_ 280 | 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | progress (bool): If True, displays a progress bar of the download to stderr 284 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 285 | but slower. Default: *False*. See `"paper" `_ 286 | """ 287 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, num_class, pretrained, progress, 288 | **kwargs) 289 | 290 | 291 | def densenet161(num_class, pretrained=False, progress=True, **kwargs): 292 | r"""Densenet-161 model from 293 | `"Densely Connected Convolutional Networks" `_ 294 | 295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | progress (bool): If True, displays a progress bar of the download to stderr 298 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 299 | but slower. Default: *False*. 
See `"paper" `_ 300 | """ 301 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, num_class, pretrained, progress, 302 | **kwargs) 303 | 304 | 305 | def densenet169(num_class, pretrained=False, progress=True, **kwargs): 306 | r"""Densenet-169 model from 307 | `"Densely Connected Convolutional Networks" `_ 308 | 309 | Args: 310 | pretrained (bool): If True, returns a model pre-trained on ImageNet 311 | progress (bool): If True, displays a progress bar of the download to stderr 312 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 313 | but slower. Default: *False*. See `"paper" `_ 314 | """ 315 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, num_class, pretrained, progress, 316 | **kwargs) 317 | 318 | 319 | def densenet201(num_class, pretrained=False, progress=True, **kwargs): 320 | r"""Densenet-201 model from 321 | `"Densely Connected Convolutional Networks" `_ 322 | 323 | Args: 324 | pretrained (bool): If True, returns a model pre-trained on ImageNet 325 | progress (bool): If True, displays a progress bar of the download to stderr 326 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 327 | but slower. Default: *False*. See `"paper" `_ 328 | """ 329 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, num_class, pretrained, progress, 330 | **kwargs) 331 | 332 | def get_Data(is_train, dataset_name, batch_size): 333 | transform = transforms.Compose([ 334 | transforms.RandomResizedCrop((224, 224), scale=(0.5, 1.0)), 335 | transforms.RandomHorizontalFlip(), 336 | transforms.ToTensor(), 337 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 338 | std=[0.229, 0.224, 0.225]) 339 | ]) 340 | val_transform = transforms.Compose([ 341 | transforms.Resize((224,224)), 342 | transforms.ToTensor(), 343 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 344 | std=[0.229, 0.224, 0.225]) 345 | ]) 346 | voc_helen = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi'] 347 | ##cub dataset### 348 | label = None if is_train else 0 349 | if not is_train: 350 | batch_size = 1 351 | if dataset_name == 'cub': 352 | trainset = CUB_VOC(cub_file, dataset_name, 'ori', train=True, transform=transform, is_frac=label) 353 | testset = CUB_VOC(cub_file, dataset_name, 'ori', train=False, transform=val_transform, is_frac=label) 354 | ###cropped voc dataset### 355 | elif dataset_name in voc_helen: 356 | trainset = CUB_VOC(voc_file, dataset_name, 'ori', train=True, transform=transform, is_frac=label) 357 | testset = CUB_VOC(voc_file, dataset_name, 'ori', train=False, transform=val_transform, is_frac=label) 358 | ###celeb dataset### 359 | #elif dataset_name == 'celeb': 360 | # trainset = Celeb(training = True, transform=None) 361 | # testset = Celeb(training = False, transform=None) 362 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 363 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 364 | return train_loader, test_loader 365 | 366 | def net_train(): 367 | trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE) 368 | if os.path.exists(log_path): 369 | shutil.rmtree(log_path);os.makedirs(log_path) 370 | else: 371 | os.makedirs(log_path) 372 | device = torch.device("cuda") 373 | net = None 374 | if LAYERS == '121': 375 | net = densenet121(num_class=NUM_CLASSES, pretrained=True) 376 | if LAYERS == '161': 377 | net = densenet161(num_class=NUM_CLASSES, pretrained=True) 378 | net = nn.DataParallel(net).to(device) 379 | # Loss and Optimizer 380 | criterion = 
nn.CrossEntropyLoss() 381 | optimizer = torch.optim.Adam(net.module.parameters(), lr=LR) 382 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.6) 383 | 384 | # Train the model 385 | best_acc = 0.0; save_loss = []; test_loss = []; train_acc = []; test_acc = []; 386 | for epoch in range(EPOCH+1): 387 | scheduler.step() 388 | net.train() 389 | total_loss = 0.0; correct = .0; total = .0; 390 | for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9): 391 | inputs, labels = input_data 392 | inputs, labels = inputs.to(device), labels.to(device) 393 | optimizer.zero_grad() 394 | output, _ = net(inputs) 395 | #print(output) 396 | _, predicted = torch.max(output.data, 1) 397 | correct += (predicted == labels).sum() 398 | total += labels.size(0) 399 | loss = criterion(output, labels) 400 | #print(module.features.conv0.weight) 401 | loss.backward() 402 | #if batch_step>0: 403 | # return 404 | #for name, parms in net.named_parameters(): 405 | # print('after* name:', name, 'grad_value:',parms.grad) 406 | optimizer.step() 407 | total_loss += loss.item() 408 | total_loss = float(total_loss) / (batch_step+1) 409 | correct = float(correct) / total 410 | testacc, testloss = test(net, testset_loader) 411 | save_loss.append(total_loss); train_acc.append(correct); 412 | test_loss.append(testloss); test_acc.append(testacc); 413 | np.savez(log_path+'loss.npz', train_loss=np.array(save_loss), test_loss=np.array(test_loss),\ 414 | train_acc=np.array(train_acc), test_acc=np.array(test_acc)) 415 | print('Epoch', epoch, 'train loss: %.4f' % total_loss, 'train accuracy:%.4f' % correct, \ 416 | 'test loss: %.4f' % testloss, 'test accuracy:%.4f' % testacc) 417 | if epoch % 50 == 0: 418 | torch.save(net.state_dict(), log_path+'model_%.3d.pth' % epoch) 419 | if epoch % 1 == 0: 420 | if testacc > best_acc: 421 | best_acc = testacc 422 | torch.save(net.state_dict(), log_path+'model_%.3d_%.4f.pth' % (epoch, best_acc)) 423 | print('Finished Training') 424 | return net 425 | 426 | def get_feature(): 427 | print('pretrain_model:', pretrain_model) 428 | _, testset_test = get_Data(True, DATANAME, BATCHSIZE) 429 | _, testset_feature = get_Data(False, DATANAME, BATCHSIZE) 430 | device = torch.device("cuda") 431 | net = None 432 | if LAYERS == '121': 433 | net = densenet121(num_class=NUM_CLASSES, pretrained=False) 434 | if LAYERS == '161': 435 | net = densenet161(num_class=NUM_CLASSES, pretrained=False) 436 | net = nn.DataParallel(net).to(device) 437 | # Test the model 438 | acc, _ = test(net, testset_test) 439 | f = open(acc_path+dataset+'_test.txt', 'w+') 440 | f.write('%s\n' % dataset) 441 | f.write('acc:%f\n' % acc) 442 | print('test acc:', acc) 443 | all_feature = [] 444 | testset = testset_test if DATANAME == 'voc_multi' else testset_feature 445 | for batch_step, input_data in tqdm(enumerate(testset,0),total=len(testset),smoothing=0.9): 446 | inputs, labels = input_data 447 | inputs, labels = inputs.to(device), labels.to(device) 448 | net.eval() 449 | output, f_map = net(inputs) 450 | all_feature.append(f_map.cpu().numpy()) 451 | all_feature = np.concatenate(all_feature,axis=0) 452 | f.write('sample num:%d' % (all_feature.shape[0])) 453 | f.close() 454 | print(all_feature.shape) 455 | np.savez_compressed(save_path+LAYERS+'_densenet_'+DATANAME+'_ori.npz', f_map=all_feature[...]) 456 | print('Finished Operation!') 457 | return net 458 | 459 | 460 | def test(net, testdata): 461 | criterion = nn.CrossEntropyLoss() 462 | correct, total = .0, .0 463 | 
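
The `test()` helper being defined here (like its counterparts elsewhere in the repo) runs the network in eval mode but never disables autograd, so evaluation still builds computation graphs. A minimal sketch of an equivalent loop under `torch.no_grad()`, assuming the same two-output forward `(logits, feature map)` as this file; the helper name is ours:

```
import torch

def test_no_grad(net, testdata):
    # same bookkeeping as test(), but no autograd graphs are built during evaluation
    criterion = torch.nn.CrossEntropyLoss()
    net.eval()
    correct, total, total_loss = 0, 0, 0.0
    with torch.no_grad():
        for step, (inputs, labels) in enumerate(testdata):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs, _ = net(inputs)                  # (logits, feature map)
            total_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total, total_loss / (step + 1)
```
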
total_loss = .0 464 | for batch_step, input_data in tqdm(enumerate(testdata,0),total=len(testdata),smoothing=0.9): 465 | inputs, labels = input_data 466 | inputs, labels = inputs.cuda(), labels.cuda() 467 | net.eval() 468 | outputs, _ = net(inputs) 469 | loss = criterion(outputs, labels) 470 | total_loss += loss.item() 471 | _, predicted = torch.max(outputs, 1) 472 | total += labels.size(0) 473 | correct += (predicted == labels).sum() 474 | total_loss = float(total_loss)/(batch_step+1) 475 | return float(correct)/total, total_loss 476 | 477 | def densenet_ori_train(): 478 | if IS_TRAIN == 1: 479 | net = net_train() 480 | elif IS_TRAIN == 0: 481 | net = get_feature() 482 | -------------------------------------------------------------------------------- /draw_fmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import cv2 as cv 4 | import os 5 | import shutil 6 | from PIL import Image 7 | import h5py 8 | import argparse 9 | 10 | def channel_max_min_whole(f_map): 11 | T, C, H, W = f_map.shape 12 | max_v = np.max(f_map,axis=(0,2,3),keepdims=True) 13 | min_v = np.min(f_map,axis=(0,2,3),keepdims=True) 14 | print(max_v.shape,min_v.shape) 15 | return (f_map - min_v)/(max_v - min_v + 1e-6) 16 | 17 | def self_max_min(f_map): 18 | if np.max(f_map) - np.mean(f_map) != 0: 19 | return (f_map-np.min(f_map))/(np.max(f_map)-np.mean(f_map))*255.0 20 | else: 21 | return (f_map-np.min(f_map))/(np.max(f_map)-np.mean(f_map)+1e-5)*255.0 22 | 23 | def get_file_path(path): 24 | paths = [] 25 | for root, dirs, files in os.walk(path): 26 | for file in files: 27 | paths.append(os.path.join(root,file)) 28 | return paths 29 | 30 | # every channel 31 | def draw_fmap_from_npz(data, save_dir,SHOW_NUM,save_channel): 32 | N, C, H, W = data.shape 33 | print('data shape:', data.shape) 34 | for i in range(N): 35 | if i in SHOW_NUM: 36 | for j in save_channel:#range(10): 37 | print("-----") 38 | print(i) 39 | print(j) 40 | fig = data[i,j] 41 | fig = cv.resize(fig,(112,112)) 42 | # to visualize more clear, do max min norm 43 | # fig = self_max_min(fig) 44 | print(i,j) 45 | cv.imwrite(save_dir + 'sample'+str(i) + '_channel'+str(j) + '.bmp', fig*255.0) 46 | 47 | # mean 48 | def draw_fmap_from_npz_mean(data, save_dir): 49 | if os.path.exists(save_dir): 50 | shutil.rmtree(save_dir) 51 | os.makedirs(save_dir) 52 | else: 53 | os.makedirs(save_dir) 54 | 55 | T, C, H, W = data.shape 56 | 57 | for i in range(T): 58 | 59 | mean_f_map = data[i].sum(axis=0)/C 60 | mean_f_map = cv.resize(mean_f_map,(112,112)) 61 | # to visualize more clear, do max min norm 62 | mean_f_map = self_max_min(mean_f_map) 63 | cv.imwrite(save_dir + 'voc_'+str(i) + '_mean' + '.bmp', mean_f_map) 64 | 65 | def addTransparency(img, factor = 0.3): 66 | img = img.convert('RGBA') 67 | img_blender = Image.new('RGBA', img.size, (0,0,0,0)) 68 | img = Image.blend(img_blender, img, factor) 69 | return img 70 | 71 | def put_mask(img_path,mask_path,output_fold,Th,factor): 72 | img = Image.open(img_path) 73 | img = addTransparency(img, factor) 74 | mask_img = cv.resize(cv.cvtColor(np.asarray(img),cv.COLOR_RGB2BGR),(224,224)) 75 | print('----') 76 | print(img_path) 77 | print(mask_path) 78 | ori_img = cv.resize(cv.imread(img_path),(224,224)) 79 | 80 | 81 | zeros_mask = cv.resize(cv.imread(mask_path),(224,224)) 82 | mask_for_red = np.zeros((224,224)) 83 | # mask_for_red = pct_max_min(zeros_mask,Th) 84 | for i in range(zeros_mask.shape[0]): 85 | for j in range(zeros_mask.shape[1]): 86 
| if np.sum((zeros_mask[i][j]/255.0)>Th): # vgg/cub 0.5 # VOC animal 0.5 87 | mask_for_red[i][j] = 1 88 | mask_img[i][j] = ori_img[i][j] 89 | else: 90 | mask_for_red[i][j] = 0 91 | red = np.zeros((224,224)) 92 | for i in range(mask_for_red.shape[0]): 93 | for j in range(mask_for_red.shape[1]): 94 | if j > 2 and mask_for_red[i][j-1] == 0 and mask_for_red[i][j] == 1: 95 | red[i][j] = 1 96 | red[i][j-1] = 1 97 | red[i][j-2] = 1 98 | red[i][j-3] = 1 99 | if j < (mask_for_red.shape[1]-2): 100 | red[i][j+1] = 1 101 | red[i][j+2] = 1 102 | #red[i][j+3] = 1 103 | if j < (mask_for_red.shape[1]-3) and mask_for_red[i][j] == 1 and mask_for_red[i][j+1] == 0: 104 | red[i][j] = 1 105 | if j > 1: 106 | red[i][j-1] = 1 107 | red[i][j-2] = 1 108 | #red[i][j-3] = 1 109 | red[i][j+1] = 1 110 | red[i][j+2] = 1 111 | red[i][j+3] = 1 112 | if i > 2 and mask_for_red[i-1][j] == 0 and mask_for_red[i][j] == 1: 113 | red[i-1][j] = 1 114 | red[i-2][j] = 1 115 | red[i-3][j] = 1 116 | red[i][j] = 1 117 | if i < (mask_for_red.shape[0]-2): 118 | red[i+1][j] = 1 119 | red[i+2][j] = 1 120 | #red[i+3][j] = 1 121 | if i < (mask_for_red.shape[0]-3) and mask_for_red[i][j] == 1 and mask_for_red[i+1][j] == 0: 122 | if i > 1: 123 | red[i-1][j] = 1 124 | red[i-2][j] = 1 125 | #red[i-3][j] = 1 126 | red[i][j] = 1 127 | red[i+1][j] = 1 128 | red[i+2][j] = 1 129 | red[i+3][j] = 1 130 | 131 | 132 | for i in range(mask_for_red.shape[0]): 133 | for j in range(mask_for_red.shape[1]): 134 | if red[i][j] == 1: 135 | mask_img[i][j] = [0,0,255] 136 | return mask_img 137 | 138 | # image add mask 139 | def image_add_mask(show_num,image_dir,mask_dir,save_dir,save_channel,factor,animal,show_num_per_center): 140 | for i in show_num: 141 | if animal == 'bird': 142 | image_paths = image_dir + 'vocbird_' + str(i) + '.jpg' 143 | else: 144 | image_paths = image_dir + str(i) + '.jpg' 145 | for j,channel in enumerate(save_channel): 146 | mask_path = mask_dir + 'sample'+str(i) + '_channel'+ str(channel) + '.bmp' 147 | mask_img = put_mask(img_path = image_paths,mask_path=mask_path,output_fold=save_dir,Th=Th,factor=factor) 148 | mask_img = cv.resize(mask_img,(112,112)) 149 | cv.imwrite(os.path.join(save_dir+'factor'+str(factor)+'_Th'+str(Th)+'_sample'+str(i)+'_center'+str(j//show_num_per_center)+'_channel'+str(channel)+'.bmp'), mask_img) 150 | 151 | 152 | # randomly shuffle feature maps of N samples 153 | def permute_fmaps_N(data,file_name): 154 | N,_,_,_ = data.shape 155 | permute_idx = np.random.permutation(np.arange(N)) 156 | data = data[permute_idx,...] 
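
The nested per-pixel loops in `put_mask` above trace the red band one pixel at a time in pure Python. A rough vectorized equivalent using OpenCV morphology is sketched here: a dilate/erode band around the mask boundary, so it is not pixel-identical to the loops (especially at image borders), and the helper name is ours:

```
import numpy as np
import cv2 as cv

def red_outline(mask_img, ori_img, mask_bin, thickness=3):
    # mask_bin: HxW uint8 in {0, 1}; the band is dilation minus erosion of the mask
    k = np.ones((2 * thickness + 1, 2 * thickness + 1), np.uint8)
    band = cv.dilate(mask_bin, k) - cv.erode(mask_bin, k)
    out = mask_img.copy()
    out[mask_bin == 1] = ori_img[mask_bin == 1]  # keep original pixels inside the mask
    out[band == 1] = (0, 0, 255)                 # BGR red ring
    return out
```
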
157 | print (data.shape,data.dtype) 158 | np.savez(file_name + '_pert', f_map = data) 159 | 160 | def get_cluster(matrix): 161 | cluser = [] 162 | visited = np.zeros(matrix.shape[0]) 163 | for i in range(matrix.shape[0]): 164 | tmp = [] 165 | if(visited[i]==0): 166 | for j in range(matrix.shape[1]): 167 | if(matrix[i][j]==1 ): 168 | tmp.append(j) 169 | visited[j]=1; 170 | cluser.append(tmp) 171 | for i,channels in enumerate(cluser): 172 | print('Group',i,'contains',len(channels),'channels.') 173 | return cluser 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser() # add positional arguments 177 | 178 | parser.add_argument('-Th', type=int, default=0.2) 179 | parser.add_argument('-factor', type=int, default=0.5) 180 | parser.add_argument('-show_num', type=int, default=10) 181 | parser.add_argument('-model', type=str) 182 | parser.add_argument('-animal', type=str) 183 | parser.add_argument('-fmap_path', type=str) 184 | parser.add_argument('-loss_path', type=str) 185 | parser.add_argument('-folder_name', default=None, type=str) 186 | 187 | args = parser.parse_args() 188 | 189 | 190 | # fixed 191 | Th = args.Th # >Th --> in the red circle 192 | factor = args.factor # the smaller the factor, the darker the area outside the red circle 193 | animals = ['bird','cat','dog','cow','horse','sheep','cub', 'celeba'] 194 | show_num_per_center = args.show_num 195 | file_path = args.fmap_path 196 | # the No. of sample to visualize; the No. starts from 0 197 | # The id of images that you want to visualize 198 | voc = [ 199 | [1],#voc_bird 200 | [1],#voc_cat 201 | [1],#voc_dog 202 | [1],#voc_cow 203 | [1],#voc_horse 204 | [1],#voc_sheep 205 | [1],#cub 206 | [1] #celeba 207 | ] 208 | 209 | # if args.loss_path == None: 210 | # cluster_label = [[],[],[],[],[]] 211 | # cluster_label[0] = np.array(range(100)) 212 | # cluster_label[1] = np.array(range(100,200)) 213 | # cluster_label[2] = np.array(range(200,300)) 214 | # cluster_label[3] = np.array(range(300,400)) 215 | # cluster_label[4] = np.array(range(400,512)) 216 | # else: 217 | loss = np.load(args.loss_path) 218 | gt = loss['gt'][-1] # show channel id of different groups 219 | cluster_label = get_cluster(gt) 220 | print ('groups and channels', cluster_label) 221 | 222 | 223 | 224 | animal_id = animals.index(args.animal)# 0~5; which category you want to draw feature maps 225 | save_channel = [] 226 | for i in range(len(cluster_label)): 227 | for j in range(show_num_per_center): 228 | save_channel.append(cluster_label[i][j]) 229 | 230 | if args.folder_name == None: 231 | model_name = args.model+'_'+args.animal 232 | else: 233 | model_name = args.folder_name 234 | 235 | SHOW_NUM = voc[animal_id] 236 | animal = animals[animal_id] 237 | # load data 238 | data = np.load(file_path)['f_map'] 239 | print('data shape:', data.shape,data.dtype) # verify the data.shape, e.g. 
bird category has 421 samples 240 | # channel normalization 241 | data = channel_max_min_whole(data) # 242 | 243 | save_dir='./fmap/'+model_name+'/'+animal+'/' 244 | # 245 | if os.path.exists(save_dir): 246 | shutil.rmtree(save_dir) 247 | os.makedirs(save_dir) 248 | else: 249 | os.makedirs(save_dir) 250 | # draw feature map and save feature maps 251 | draw_fmap_from_npz(data,save_dir=save_dir,SHOW_NUM=SHOW_NUM,save_channel=save_channel) #############iccnn 252 | # draw_fmap_from_npz_mean(data, save_dir=save_dir) 253 | 254 | if args.animal == 'cub': 255 | img_dir = './images/hook_cub_test/' 256 | elif args.animal == 'celeba': 257 | img_dir = './images/hook_celeba_test/' 258 | else: 259 | img_dir = './images/voc'+animal+'_test/' 260 | mask_dir = './fmap/'+model_name+'/'+animal+'/' # i.e. the dir of feature maps (same with the 'save_dir' above) 261 | masked_save_dir = './fmap/'+model_name+'/'+animal+'_masked/' # save dir of images with the red circle we want! 262 | if os.path.exists(masked_save_dir): 263 | shutil.rmtree(masked_save_dir) 264 | os.makedirs(masked_save_dir) 265 | else: 266 | os.makedirs(masked_save_dir) 267 | image_add_mask(show_num=SHOW_NUM,image_dir=img_dir,mask_dir=mask_dir,save_dir=masked_save_dir,save_channel=save_channel,factor=factor,animal=animal,show_num_per_center=show_num_per_center) 268 | -------------------------------------------------------------------------------- /hyperparameters.txt: -------------------------------------------------------------------------------- 1 | need to be completed ... -------------------------------------------------------------------------------- /load_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | try: 4 | from torch.hub import load_state_dict_from_url 5 | except ImportError: 6 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 7 | -------------------------------------------------------------------------------- /newPad2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | import torch.nn as nn 4 | import copy 5 | 6 | class newPad2d(Module): 7 | def __init__(self,length): 8 | super(newPad2d,self).__init__() 9 | self.length = length 10 | self.zeroPad = nn.ZeroPad2d(self.length) 11 | 12 | def forward(self, input): 13 | b,c,h,w = input.shape 14 | output = self.zeroPad(input) 15 | 16 | #output = torch.FloatTensor(b,c,h+self.length*2,w+self.length*2) 17 | #output[:,:,self.length:self.length+h,self.length:self.length+w] = input 18 | 19 | for i in range(self.length): 20 | # 一层的四个切片 21 | output[:, :, self.length:self.length+h, i] = output[:, :, self.length:self.length+h, self.length] 22 | output[:, :, self.length:self.length + h, w+ self.length+i] = output[:, :, self.length:self.length + h, 23 | self.length-1+w] 24 | output[:, :, i, self.length:self.length+w] = output[:, :, self.length, self.length:self.length+w] 25 | output[:, :, h+self.length+i, self.length:self.length + w] = output[:, :, h + self.length-1, 26 | self.length:self.length + w] 27 | # 对角进行特别处理 28 | for j in range(self.length): 29 | for k in range(self.length): 30 | output[:,:,j,k]=output[:,:,self.length,self.length] 31 | output[:, :, j, w+ self.length+k] = output[:, :, self.length, self.length-1+w] 32 | output[:, :, h+self.length+j, k] = output[:, :, h + self.length-1, self.length] 33 | output[:, :, h+self.length+j, w + self.length + k] = output[:, :, h + self.length-1, self.length - 1 + w] 34 | 
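
Both the loop-based `forward` above (its `return output` follows) and the vectorized variant commented out below implement replication padding: edge rows and columns are copied outward, and each corner region takes the nearest corner pixel. A quick equivalence check against `torch.nn.functional.pad`:

```
import torch
import torch.nn.functional as F
from newPad2d import newPad2d

x = torch.randn(2, 3, 5, 5)
out = newPad2d(2)(x)
assert out.shape == (2, 3, 9, 9)
assert torch.allclose(out, F.pad(x, (2, 2, 2, 2), mode='replicate'))
```
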
return output 35 | ''' 36 | class newPad2d(Module): 37 | def __init__(self,length): 38 | super(newPad2d,self).__init__() 39 | self.length = length 40 | self.zeroPad = nn.ZeroPad2d(self.length) 41 | 42 | def forward(self, input): 43 | b,c,h,w = input.shape 44 | output = self.zeroPad(input) 45 | out_cp = torch.zeros_like(output) 46 | #output = torch.FloatTensor(b,c,h+self.length*2,w+self.length*2) 47 | #output[:,:,self.length:self.length+h,self.length:self.length+w] = input 48 | 49 | # 一层的四个切片 50 | out_cp[:, :, self.length:self.length+h, 0:self.length] = output[:, :, self.length:self.length+h, self.length].view(b,c,h,1).repeat(1,1,1,self.length) 51 | out_cp[:, :, self.length:self.length + h, w+self.length: 2*self.length+w] = output[:, :, self.length:self.length + h, self.length-1+w].view(b,c,h,1).repeat(1,1,1,self.length) 52 | out_cp[:, :, 0:self.length, self.length:self.length+w] = output[:, :, self.length, self.length:self.length+w].view(b,c,1,w).repeat(1,1,self.length,1) 53 | out_cp[:, :, h+self.length:h+2*self.length, self.length:self.length+w] = output[:, :, h + self.length-1, self.length:self.length + w].view(b,c,1,w).repeat(1,1,self.length,1) 54 | # 对角进行特别处理 55 | out_cp[:,:, 0:self.length, 0:self.length] = output[:,:,self.length,self.length].view(b,c,1,1).repeat(1,1,self.length,self.length) 56 | out_cp[:, :, 0:self.length, w+self.length: 2*self.length+w] = output[:, :, self.length, self.length-1+w].view(b,c,1,1).repeat(1,1,self.length,self.length) 57 | out_cp[:, :, h+self.length:h+2*self.length, 0:self.length] = output[:, :, h + self.length-1, self.length].view(b,c,1,1).repeat(1,1,self.length,self.length) 58 | out_cp[:, :, h+self.length:h+2*self.length, w+self.length: 2*self.length+w] = output[:, :, h + self.length-1, self.length - 1 + w].view(b,c,1,1).repeat(1,1,self.length,self.length) 59 | out_cp[:, :, self.length:self.length+h, self.length:w+self.length] = output[:, :, self.length:self.length+h, self.length:w+self.length] 60 | return out_cp 61 | ''' 62 | -------------------------------------------------------------------------------- /resnet_iccnn_multi_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import torchvision.transforms as transforms 6 | import torchvision.models as models 7 | from torch.utils.data import DataLoader 8 | from load_utils import load_state_dict_from_url 9 | from cub_voc import CUB_VOC 10 | import os 11 | from tqdm import tqdm 12 | import shutil 13 | from utils.utils import Cluster_loss, Multiclass_loss 14 | import numpy as np 15 | from Similar_Mask_Generate import SMGBlock 16 | from SpectralClustering import spectral_clustering 17 | from newPad2d import newPad2d 18 | 19 | IS_TRAIN = 0 # 0/1 20 | LAYERS = '18' 21 | DATANAME = 'voc_multi' # voc_multi 22 | NUM_CLASSES = 6 23 | cub_file = '/data/sw/dataset/frac_dataset' 24 | voc_file = '/data/sw/dataset/VOCdevkit/VOC2010/voc2010_crop' 25 | log_path = '/data/fjq/iccnn/resnet/' # for model 26 | save_path = '/data/fjq/iccnn/basic_fmap/resnet/' # for get_feature 27 | acc_path = '/data/fjq/iccnn/basic_fmap/resnet/acc/' 28 | 29 | dataset = '%s_resnet_%s_iccnn' % (LAYERS, DATANAME) 30 | log_path = log_path + dataset + '/' 31 | pretrain_model = log_path + 'model_2000.pth' 32 | BATCHSIZE = 16 33 | LR = 0.00001 34 | EPOCH = 3000 35 | center_num = 16 36 | lam1 = 0.1 37 | lam2 = 0.1 38 | T = 2 # T = 2 ===> do sc each epoch 39 | F_MAP_SIZE = 196 40 | STOP_CLUSTERING = 200 41 | if LAYERS == 
'18': 42 | CHANNEL_NUM = 256 43 | elif LAYERS == '50': 44 | CHANNEL_NUM = 1024 45 | 46 | _all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 47 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 48 | 'wide_resnet50_2', 'wide_resnet101_2'] 49 | model_urls = { 50 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 51 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 52 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 53 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 54 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 55 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 56 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 57 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 58 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 59 | } 60 | 61 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 62 | """3x3 convolution with padding""" 63 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 64 | padding=0, groups=groups, bias=False, dilation=dilation)#new padding 65 | 66 | def conv1x1(in_planes, out_planes, stride=1): 67 | """1x1 convolution""" 68 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 69 | 70 | class BasicBlock(nn.Module): 71 | expansion = 1 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 74 | base_width=64, dilation=1, norm_layer=None): 75 | super(BasicBlock, self).__init__() 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | if groups != 1 or base_width != 64: 79 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 80 | if dilation > 1: 81 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 82 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 83 | self.conv1 = conv3x3(inplanes, planes, stride) 84 | self.bn1 = norm_layer(planes) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.conv2 = conv3x3(planes, planes) 87 | self.bn2 = norm_layer(planes) 88 | self.downsample = downsample 89 | self.stride = stride 90 | self.pad2d = newPad2d(1)#new paddig 91 | 92 | def forward(self, x): 93 | identity = x 94 | out = self.pad2d(x) #new padding 95 | out = self.conv1(out) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.pad2d(out) #new padding 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | class Bottleneck(nn.Module): 112 | expansion = 4 113 | 114 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 115 | base_width=64, dilation=1, norm_layer=None): 116 | super(Bottleneck, self).__init__() 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | width = int(planes * (base_width / 64.)) * groups 120 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 121 | self.conv1 = conv1x1(inplanes, width) 122 | self.bn1 = norm_layer(width) 123 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 124 | self.bn2 = norm_layer(width) 125 | self.conv3 = conv1x1(width, planes * self.expansion) 126 | self.bn3 = norm_layer(planes * self.expansion) 
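
As in the DenseNet files, `conv3x3` is declared with `padding=0` and the residual blocks apply `newPad2d(1)` explicitly before each 3x3 convolution, trading zero padding for replication padding while keeping spatial sizes unchanged. A small shape check of the pattern:

```
import torch
import torch.nn as nn
from newPad2d import newPad2d

pad = newPad2d(1)
conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0, bias=False)
x = torch.randn(1, 64, 56, 56)
assert pad(x).shape[-2:] == (58, 58)   # one replicated ring of pixels
assert conv(pad(x)).shape == x.shape   # same output size as padding=1, without zeros at the border
```
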
127 | self.relu = nn.ReLU(inplace=True) 128 | self.downsample = downsample 129 | self.stride = stride 130 | self.pad2d = newPad2d(1)#new paddig 131 | 132 | def forward(self, x): 133 | identity = x 134 | 135 | out = self.conv1(x) 136 | out = self.bn1(out) 137 | out = self.relu(out) 138 | 139 | out = self.pad2d(out) #new padding 140 | out = self.conv2(out) 141 | out = self.bn2(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv3(out) 145 | out = self.bn3(out) 146 | 147 | if self.downsample is not None: 148 | identity = self.downsample(x) 149 | 150 | out += identity 151 | out = self.relu(out) 152 | 153 | return out 154 | 155 | class ResNet(nn.Module): 156 | def __init__(self, block, layers, num_classes=NUM_CLASSES, zero_init_residual=False, 157 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 158 | norm_layer=None): 159 | super(ResNet, self).__init__() 160 | if norm_layer is None: 161 | norm_layer = nn.BatchNorm2d 162 | self._norm_layer = norm_layer 163 | 164 | self.inplanes = 64 165 | self.dilation = 1 166 | if replace_stride_with_dilation is None: 167 | # each element in the tuple indicates if we should replace 168 | # the 2x2 stride with a dilated convolution instead 169 | replace_stride_with_dilation = [False, False, False] 170 | if len(replace_stride_with_dilation) != 3: 171 | raise ValueError("replace_stride_with_dilation should be None " 172 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 173 | self.groups = groups 174 | self.base_width = width_per_group 175 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=0, 176 | bias=False)#new padding 177 | self.bn1 = norm_layer(self.inplanes) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)#new padding 180 | self.layer1 = self._make_layer(block, 64, layers[0]) 181 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 182 | dilate=replace_stride_with_dilation[0]) 183 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 184 | dilate=replace_stride_with_dilation[1]) 185 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 186 | dilate=replace_stride_with_dilation[2]) 187 | self.smg = SMGBlock(channel_size = CHANNEL_NUM,f_map_size=F_MAP_SIZE) 188 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 189 | self.fc = nn.Linear(512 * block.expansion, num_classes) 190 | self.pad2d_1 = newPad2d(1)#new paddig 191 | self.pad2d_3 = newPad2d(3)#new paddig 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 196 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 197 | nn.init.constant_(m.weight, 1) 198 | nn.init.constant_(m.bias, 0) 199 | 200 | # Zero-initialize the last BN in each residual branch, 201 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
202 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 203 | if zero_init_residual: 204 | for m in self.modules(): 205 | if isinstance(m, Bottleneck): 206 | nn.init.constant_(m.bn3.weight, 0) 207 | elif isinstance(m, BasicBlock): 208 | nn.init.constant_(m.bn2.weight, 0) 209 | 210 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 211 | norm_layer = self._norm_layer 212 | downsample = None 213 | previous_dilation = self.dilation 214 | if dilate: 215 | self.dilation *= stride 216 | stride = 1 217 | if stride != 1 or self.inplanes != planes * block.expansion: 218 | downsample = nn.Sequential( 219 | conv1x1(self.inplanes, planes * block.expansion, stride), 220 | norm_layer(planes * block.expansion), 221 | ) 222 | 223 | layers = [] 224 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 225 | self.base_width, previous_dilation, norm_layer)) 226 | self.inplanes = planes * block.expansion 227 | for _ in range(1, blocks): 228 | layers.append(block(self.inplanes, planes, groups=self.groups, 229 | base_width=self.base_width, dilation=self.dilation, 230 | norm_layer=norm_layer)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x, eval=False): 235 | # See note [TorchScript super()] 236 | x = self.pad2d_3(x) #new padding 237 | x = self.conv1(x) 238 | x = self.bn1(x) 239 | x = self.relu(x) 240 | x = self.pad2d_1(x) 241 | x = self.maxpool(x) 242 | 243 | x = self.layer1(x) 244 | x = self.layer2(x) 245 | x = self.layer3(x) 246 | if eval: 247 | return x 248 | corre_matrix = self.smg(x) 249 | f_map = x 250 | x = self.layer4(x) 251 | # if eval: 252 | # return x 253 | # corre_matrix = self.smg(x,ground_true) 254 | # f_map = x 255 | x = self.avgpool(x) 256 | x = torch.flatten(x, 1) 257 | x = self.fc(x) 258 | return x, f_map, corre_matrix 259 | 260 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 261 | model = ResNet(block, layers, **kwargs) 262 | if pretrained: 263 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 264 | pretrained_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}#'fc' not in k and 'layer4.1' not in k and 265 | model_dict = model.state_dict() 266 | model_dict.update(pretrained_dict) 267 | model.load_state_dict(model_dict) 268 | else: 269 | if pretrain_model is not None: 270 | print("Load pretrained model",pretrain_model) 271 | device = torch.device("cuda") 272 | model = nn.DataParallel(model).to(device) 273 | pretrained_dict = torch.load(pretrain_model) 274 | pretrained_dict = {k[k.find('.')+1:]: v for k, v in pretrained_dict.items()} 275 | model.load_state_dict(pretrained_dict) 276 | return model 277 | 278 | def ResNet18(pretrained=False, progress=True, **kwargs): 279 | return _resnet('resnet18', BasicBlock, [2,2,2,2], pretrained, progress, **kwargs) 280 | 281 | def ResNet34(pretrained=False, progress=True, **kwargs): 282 | return _resnet('resnet34', BasicBlock, [3,4,6,3], pretrained, progress, **kwargs) 283 | 284 | def ResNet50(pretrained=False, progress=True, **kwargs): 285 | return _resnet('resnet50', Bottleneck, [3,4,6,3], pretrained, progress, **kwargs) 286 | 287 | def ResNet101(pretrained=False, progress=True, **kwargs): 288 | return _resnet('resnet101', Bottleneck, [3,4,23,3], pretrained, progress, **kwargs) 289 | 290 | def ResNet152(pretrained=False, progress=True, **kwargs): 291 | return _resnet('resnet152', Bottleneck, [3,8,36,3], pretrained, progress, **kwargs) 292 | 293 | def get_Data(is_train, dataset_name, batch_size): 
294 | val_transform = transforms.Compose([ 295 | transforms.Resize((224,224)), 296 | transforms.ToTensor(), 297 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 298 | std=[0.229, 0.224, 0.225]) 299 | ]) 300 | voc_helen = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi'] 301 | ##cub dataset### 302 | label = None if is_train else 0 303 | if not is_train: 304 | batch_size = 1 305 | if dataset_name == 'cub': 306 | trainset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 307 | testset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 308 | ###cropped voc dataset### 309 | elif dataset_name in voc_helen: 310 | trainset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 311 | testset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 312 | ###celeb dataset### 313 | #elif dataset_name == 'celeb': 314 | # trainset = Celeb(training = True, transform=None) 315 | # testset = Celeb(training = False, transform=None) 316 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 317 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 318 | return train_loader, test_loader 319 | 320 | def net_train(): 321 | trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE) 322 | if os.path.exists(log_path): 323 | shutil.rmtree(log_path);os.makedirs(log_path) 324 | else: 325 | os.makedirs(log_path) 326 | device = torch.device("cuda") 327 | 328 | if LAYERS == '18': 329 | net = ResNet18(pretrained=False) 330 | elif LAYERS == '50': 331 | net = ResNet50(pretrained=False) 332 | 333 | net = nn.DataParallel(net).to(device) 334 | # Loss and Optimizer 335 | criterion = nn.CrossEntropyLoss() 336 | optimizer = torch.optim.Adam(net.module.parameters(), lr=LR) 337 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.6) 338 | 339 | # Train the model 340 | save_loss = [];save_similatiry_loss = [];save_gt=[];save_class_loss= [];save_total_loss = []; 341 | cs_loss = Cluster_loss() 342 | mc_loss = Multiclass_loss(class_num=NUM_CLASSES) 343 | for epoch in range(EPOCH+1): 344 | if epoch % T==0 and epoch < STOP_CLUSTERING: 345 | with torch.no_grad(): 346 | Ground_true, loss_mask_num, loss_mask_den = offline_spectral_cluster(net, trainset_loader) 347 | save_gt.append(Ground_true.cpu().numpy()) 348 | else: 349 | scheduler.step() 350 | net.train() 351 | all_feature = []; total_loss = 0.0;similarity_loss = 0.0;class_loss = 0.0 352 | for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9): 353 | inputs, labels = input_data 354 | inputs, labels = inputs.to(device), labels.to(device) 355 | optimizer.zero_grad() 356 | output, f_map, corre = net(inputs, eval=False) 357 | 358 | clr_loss = criterion(output, labels) 359 | loss1 = cs_loss.update(corre, loss_mask_num, loss_mask_den, None) 360 | loss2 = mc_loss.update(f_map, loss_mask_num, labels) 361 | loss = clr_loss + lam1 *loss1 + lam2*loss2 362 | loss.backward() 363 | optimizer.step() 364 | total_loss += loss.item() 365 | similarity_loss += loss1.item() 366 | class_loss += loss2.item() 367 | 368 | ### loss save code ##### 369 | total_loss = float(total_loss) / len(trainset_loader) 370 | similarity_loss = float(similarity_loss) / len(trainset_loader) 371 | class_loss = float(class_loss) / len(trainset_loader) 372 | save_total_loss.append(total_loss) 373 | 
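
The clustering targets consumed by `Cluster_loss` in the training loop come from `offline_spectral_cluster`, defined just below: channel-wise Pearson correlation of the layer-3 feature maps, shifted by +1 so every affinity is non-negative before spectral clustering. A toy NumPy illustration of that affinity computation (shapes here are hypothetical):

```
import numpy as np

f_map = np.random.rand(8, 16, 196)   # (samples, channels, H*W) -- toy shapes
mean = np.mean(f_map, axis=0)
cov = np.mean(np.matmul(f_map - mean, np.transpose(f_map - mean, (0, 2, 1))), axis=0)
diag = np.diag(cov).reshape(16, -1)
correlation = cov / (np.sqrt(diag @ diag.T) + 1e-5) + 1   # entries roughly in [0, 2]
assert correlation.shape == (16, 16)
```
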
save_similatiry_loss.append(similarity_loss) 374 | save_class_loss.append(class_loss) 375 | acc = 0 376 | #if epoch % 5==0: 377 | # acc = test(net, testset_loader) 378 | 379 | print('Epoch', epoch, 'loss: %.4f' % total_loss, 'sc_loss: %.4f' % similarity_loss, 'class_loss: %.4f' % class_loss, 'test accuracy:%.4f' % acc) 380 | if epoch % 100 == 0: 381 | torch.save(net.state_dict(), log_path+'model_%.3d.pth' % (epoch)) 382 | np.savez(log_path+'loss_%.3d.npz'% (epoch), loss=np.array(save_total_loss), similarity_loss = np.array(save_similatiry_loss), class_loss = np.array(save_class_loss),gt=np.array(save_gt)) 383 | print('Finished Training') 384 | 385 | def offline_spectral_cluster(net, train_data): 386 | net.eval() 387 | f_map = [] 388 | for inputs, labels in train_data: 389 | inputs, labels = inputs.cuda(), labels.cuda() 390 | cur_fmap= net(inputs,eval=True).detach().cpu().numpy() 391 | f_map.append(cur_fmap) 392 | f_map = np.concatenate(f_map,axis=0) 393 | sample, channel,_,_ = f_map.shape 394 | f_map = f_map.reshape((sample,channel,-1)) 395 | mean = np.mean(f_map,axis=0) 396 | cov = np.mean(np.matmul(f_map-mean,np.transpose(f_map-mean,(0,2,1))),axis=0) 397 | diag = np.diag(cov).reshape(channel,-1) 398 | correlation = cov/(np.sqrt(np.matmul(diag,np.transpose(diag,(1,0))))+1e-5)+1 399 | ground_true, loss_mask_num, loss_mask_den = spectral_clustering(correlation,n_cluster=center_num) 400 | 401 | return ground_true, loss_mask_num, loss_mask_den 402 | 403 | def get_feature(): 404 | print('pretrain_model:', pretrain_model) 405 | _, testset_test = get_Data(True, DATANAME, BATCHSIZE) 406 | _, testset_feature = get_Data(False, DATANAME, BATCHSIZE) 407 | device = torch.device("cuda") 408 | net = None 409 | if LAYERS == '50': 410 | net = ResNet50(pretrained=False) 411 | elif LAYERS == '18': 412 | net = ResNet18(pretrained=False) 413 | 414 | acc = test(net, testset_test) 415 | f = open(acc_path+dataset+'_test.txt', 'w+') 416 | f.write('%s\n' % dataset) 417 | f.write('acc:%f\n' %acc) 418 | all_feature = [] 419 | for batch_step, input_data in tqdm(enumerate(testset_feature,0),total=len(testset_feature),smoothing=0.9): 420 | inputs, labels = input_data 421 | inputs, labels = inputs.to(device), labels.to(device) 422 | net.eval() 423 | f_map = net(inputs,eval=True) 424 | all_feature.append(f_map.detach().cpu().numpy()) 425 | all_feature = np.concatenate(all_feature,axis=0) 426 | print(all_feature.shape) 427 | f.write('sample num:%d' % (all_feature.shape[0])) 428 | f.close() 429 | np.savez(save_path+LAYERS+'_resnet_'+DATANAME+'iccnn_.npz', f_map=all_feature[...]) 430 | print('Finished Operation!') 431 | 432 | def test(net, testdata): 433 | correct, total = .0, .0 434 | for inputs, labels in testdata: 435 | inputs, labels = inputs.cuda(), labels.cuda() 436 | net.eval() 437 | outputs, _,_ = net(inputs) 438 | _, predicted = torch.max(outputs, 1) 439 | total += labels.size(0) 440 | correct += (predicted == labels).sum() 441 | print('test acc = ',float(correct) / total) 442 | return float(correct) / total 443 | 444 | def resnet_multi_train(): 445 | if IS_TRAIN ==1: 446 | net_train() 447 | elif IS_TRAIN == 0: 448 | get_feature() 449 | -------------------------------------------------------------------------------- /resnet_iccnn_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | import torchvision.models as models 6 | from torch.utils.data import 
DataLoader 7 | from load_utils import load_state_dict_from_url 8 | from cub_voc import CUB_VOC 9 | import os 10 | from tqdm import tqdm 11 | import shutil 12 | from utils.utils import Cluster_loss 13 | import numpy as np 14 | from celeb import Celeb 15 | from Similar_Mask_Generate import SMGBlock 16 | from SpectralClustering import spectral_clustering 17 | from newPad2d import newPad2d 18 | 19 | IS_TRAIN = 0 # 0/1 20 | LAYERS = '18' 21 | DATANAME = 'bird' 22 | NUM_CLASSES = 80 if DATANAME == 'celeb' else 2 23 | cub_file = '/data/sw/dataset/frac_dataset' 24 | voc_file = '/data/sw/dataset/VOCdevkit/VOC2010/voc2010_crop' 25 | celeb_file = '/home/user05/fjq/dataset/CelebA/' 26 | log_path = '/data/fjq/iccnn/resnet/' # for model 27 | save_path = '/data/fjq/iccnn/basic_fmap/resnet/' # for get_feature 28 | acc_path = '/data/fjq/iccnn/basic_fmap/resnet/acc/' 29 | 30 | dataset = '%s_resnet_%s_iccnn' % (LAYERS, DATANAME) 31 | log_path = log_path + dataset + '/' 32 | pretrain_model = log_path + 'model_2000.pth' 33 | BATCHSIZE = 16 34 | LR = 0.00001 35 | EPOCH = 2500 36 | center_num = 16 37 | lam = 1 38 | T = 2 # T = 2 ===> do sc each epoch 39 | F_MAP_SIZE = 196 40 | STOP_CLUSTERING = 200 41 | if LAYERS == '18': 42 | CHANNEL_NUM = 256 43 | elif LAYERS == '50': 44 | CHANNEL_NUM = 1024 45 | 46 | _all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 47 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 48 | 'wide_resnet50_2', 'wide_resnet101_2'] 49 | model_urls = { 50 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 51 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 52 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 53 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 54 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 55 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 56 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 57 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 58 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 59 | } 60 | 61 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 62 | """3x3 convolution with padding""" 63 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 64 | padding=0, groups=groups, bias=False, dilation=dilation)#new padding 65 | 66 | def conv1x1(in_planes, out_planes, stride=1): 67 | """1x1 convolution""" 68 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 69 | 70 | class BasicBlock(nn.Module): 71 | expansion = 1 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 74 | base_width=64, dilation=1, norm_layer=None): 75 | super(BasicBlock, self).__init__() 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | if groups != 1 or base_width != 64: 79 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 80 | if dilation > 1: 81 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 82 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 83 | self.conv1 = conv3x3(inplanes, planes, stride) 84 | self.bn1 = norm_layer(planes) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.conv2 = conv3x3(planes, planes) 87 | self.bn2 = norm_layer(planes) 88 | self.downsample = downsample 89 | self.stride = 
stride 90 | self.pad2d = newPad2d(1)#new paddig 91 | 92 | def forward(self, x): 93 | identity = x 94 | out = self.pad2d(x) #new padding 95 | out = self.conv1(out) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.pad2d(out) #new padding 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | class Bottleneck(nn.Module): 112 | expansion = 4 113 | 114 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 115 | base_width=64, dilation=1, norm_layer=None): 116 | super(Bottleneck, self).__init__() 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | width = int(planes * (base_width / 64.)) * groups 120 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 121 | self.conv1 = conv1x1(inplanes, width) 122 | self.bn1 = norm_layer(width) 123 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 124 | self.bn2 = norm_layer(width) 125 | self.conv3 = conv1x1(width, planes * self.expansion) 126 | self.bn3 = norm_layer(planes * self.expansion) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.downsample = downsample 129 | self.stride = stride 130 | self.pad2d = newPad2d(1)#new paddig 131 | 132 | def forward(self, x): 133 | identity = x 134 | out = self.conv1(x) 135 | 136 | out = self.bn1(out) 137 | out = self.relu(out) 138 | 139 | out = self.pad2d(out) #new padding 140 | out = self.conv2(out) 141 | out = self.bn2(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv3(out) 145 | out = self.bn3(out) 146 | 147 | if self.downsample is not None: 148 | identity = self.downsample(x) 149 | 150 | out += identity 151 | out = self.relu(out) 152 | 153 | return out 154 | 155 | class ResNet(nn.Module): 156 | def __init__(self, block, layers, num_classes, zero_init_residual=False, 157 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 158 | norm_layer=None): 159 | super(ResNet, self).__init__() 160 | if norm_layer is None: 161 | norm_layer = nn.BatchNorm2d 162 | self._norm_layer = norm_layer 163 | 164 | self.inplanes = 64 165 | self.dilation = 1 166 | if replace_stride_with_dilation is None: 167 | # each element in the tuple indicates if we should replace 168 | # the 2x2 stride with a dilated convolution instead 169 | replace_stride_with_dilation = [False, False, False] 170 | if len(replace_stride_with_dilation) != 3: 171 | raise ValueError("replace_stride_with_dilation should be None " 172 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 173 | self.groups = groups 174 | self.base_width = width_per_group 175 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=0, 176 | bias=False)#new padding 177 | self.bn1 = norm_layer(self.inplanes) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)#new padding 180 | self.layer1 = self._make_layer(block, 64, layers[0]) 181 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 182 | dilate=replace_stride_with_dilation[0]) 183 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 184 | dilate=replace_stride_with_dilation[1]) 185 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 186 | dilate=replace_stride_with_dilation[2]) 187 | self.smg = SMGBlock(channel_size = CHANNEL_NUM,f_map_size=F_MAP_SIZE) 188 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 189 | 
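
A hedged sanity check of the SMG hook point implied by this file's constants (LAYERS = '18', CHANNEL_NUM = 256, F_MAP_SIZE = 196): for 224x224 inputs, `layer3` of ResNet-18 emits 256 channels on a 14x14 grid, which is exactly what `eval=True` returns from the `forward` below. Constructing `ResNet` directly sidesteps the checkpoint loading in `_resnet`; this sketch assumes `SMGBlock` can be instantiated on CPU:

```
import torch

net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=2)
fmap = net(torch.randn(1, 3, 224, 224), eval=True)  # returns layer3 features
assert fmap.shape[1:] == (256, 14, 14)              # CHANNEL_NUM channels, 14*14 = F_MAP_SIZE
```
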
self.fc = nn.Linear(512 * block.expansion, num_classes) 190 | self.pad2d_1 = newPad2d(1)#new padding 191 | self.pad2d_3 = newPad2d(3)#new padding 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 196 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 197 | nn.init.constant_(m.weight, 1) 198 | nn.init.constant_(m.bias, 0) 199 | 200 | # Zero-initialize the last BN in each residual branch, 201 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 202 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 203 | if zero_init_residual: 204 | for m in self.modules(): 205 | if isinstance(m, Bottleneck): 206 | nn.init.constant_(m.bn3.weight, 0) 207 | elif isinstance(m, BasicBlock): 208 | nn.init.constant_(m.bn2.weight, 0) 209 | 210 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 211 | norm_layer = self._norm_layer 212 | downsample = None 213 | previous_dilation = self.dilation 214 | if dilate: 215 | self.dilation *= stride 216 | stride = 1 217 | if stride != 1 or self.inplanes != planes * block.expansion: 218 | downsample = nn.Sequential( 219 | conv1x1(self.inplanes, planes * block.expansion, stride), 220 | norm_layer(planes * block.expansion), 221 | ) 222 | 223 | layers = [] 224 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 225 | self.base_width, previous_dilation, norm_layer)) 226 | self.inplanes = planes * block.expansion 227 | for _ in range(1, blocks): 228 | layers.append(block(self.inplanes, planes, groups=self.groups, 229 | base_width=self.base_width, dilation=self.dilation, 230 | norm_layer=norm_layer)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x, eval=False): 235 | # See note [TorchScript super()] 236 | x = self.pad2d_3(x) #new padding 237 | x = self.conv1(x) 238 | x = self.bn1(x) 239 | x = self.relu(x) 240 | x = self.pad2d_1(x) 241 | x = self.maxpool(x) 242 | 243 | x = self.layer1(x) 244 | x = self.layer2(x) 245 | x = self.layer3(x) 246 | if eval: 247 | return x 248 | corre_matrix = self.smg(x) 249 | f_map = x.detach() 250 | x = self.layer4(x) 251 | # if eval: 252 | # return x 253 | # corre_matrix = self.smg(x,ground_true) 254 | # f_map = x.detach() 255 | x = self.avgpool(x) 256 | x = torch.flatten(x, 1) 257 | x = self.fc(x) 258 | return x, f_map.detach(), corre_matrix 259 | 260 | def _resnet(arch, block, layers, num_class, pretrained, progress, **kwargs): 261 | model = ResNet(block, layers, num_class, **kwargs) 262 | if pretrained: 263 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 264 | pretrained_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}#'fc' not in k and 'layer4.1' not in k and 265 | model_dict = model.state_dict() 266 | model_dict.update(pretrained_dict) 267 | model.load_state_dict(model_dict) 268 | else: 269 | if pretrain_model is not None: 270 | print("Load pretrained model") 271 | device = torch.device("cuda") 272 | model = nn.DataParallel(model).to(device) 273 | pretrained_dict = torch.load(pretrain_model) 274 | if IS_TRAIN == 0: 275 | pretrained_dict = {k[k.find('.')+1:]: v for k, v in pretrained_dict.items()} 276 | model.load_state_dict(pretrained_dict) 277 | return model 278 | 279 | def ResNet18(num_class, pretrained=False, progress=True, **kwargs): 280 | return _resnet('resnet18', BasicBlock, [2,2,2,2], num_class, pretrained, progress, **kwargs) 281 |
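# --- illustrative usage sketch (added; not from the original file) -----------
# With pretrained=True, _resnet() above loads torchvision ImageNet weights
# minus the 'fc' entries; with pretrained=False it restores the checkpoint at
# `pretrain_model` and wraps the model in nn.DataParallel itself. A minimal,
# hedged sketch of how the constructors are driven (assuming a CUDA device):
#
#   net = nn.DataParallel(ResNet18(num_class=2, pretrained=True)).cuda()
#   logits, f_map, corre = net(images)   # training path returns three outputs
#   feats = net(images, eval=True)       # eval=True returns layer3 features only
# ------------------------------------------------------------------------------
282 | def 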
ResNet34(num_class, pretrained=False, progress=True, **kwargs): 283 | return _resnet('resnet34', BasicBlock, [3,4,6,3], num_class, pretrained, progress, **kwargs) 284 | 285 | def ResNet50(num_class, pretrained=False, progress=True, **kwargs): 286 | return _resnet('resnet50', Bottleneck, [3,4,6,3], num_class, pretrained, progress, **kwargs) 287 | 288 | def ResNet101(num_class, pretrained=False, progress=True, **kwargs): 289 | return _resnet('resnet101', Bottleneck, [3,4,23,3], num_class, pretrained, progress, **kwargs) 290 | 291 | def ResNet152(num_class, pretrained=False, progress=True, **kwargs): 292 | return _resnet('resnet152', Bottleneck, [3,8,36,3], num_class, pretrained, progress, **kwargs) 293 | 294 | def get_Data(is_train, dataset_name, batch_size): 295 | val_transform = transforms.Compose([ 296 | transforms.Resize((224,224)), 297 | transforms.ToTensor(), 298 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 299 | std=[0.229, 0.224, 0.225]) 300 | ]) 301 | celeb_transform = transforms.Compose([ 302 | transforms.ToTensor(), 303 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 304 | ]) 305 | voc_helen = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi'] 306 | ##cub dataset### 307 | label = None if is_train else 0 308 | if not is_train: 309 | batch_size = 1 310 | if dataset_name == 'cub': 311 | trainset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 312 | testset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 313 | ###cropped voc dataset### 314 | elif dataset_name in voc_helen: 315 | trainset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 316 | testset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 317 | ###celeb dataset### 318 | elif dataset_name == 'celeb': 319 | trainset = Celeb(celeb_file, training = True, transform=celeb_transform, train_num=10240) 320 | testset = Celeb(celeb_file, training = False, transform=celeb_transform, train_num=19962) 321 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 322 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 323 | return train_loader, test_loader 324 | 325 | def net_train(): 326 | trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE) 327 | if os.path.exists(log_path): 328 | shutil.rmtree(log_path);os.makedirs(log_path) 329 | else: 330 | os.makedirs(log_path) 331 | device = torch.device("cuda") 332 | net = None 333 | if LAYERS == '18': 334 | net = ResNet18(num_class=NUM_CLASSES, pretrained=False) 335 | elif LAYERS == '50': 336 | net = ResNet50(num_class=NUM_CLASSES, pretrained=False) 337 | 338 | net = nn.DataParallel(net).to(device) 339 | # Loss and Optimizer 340 | criterion = nn.CrossEntropyLoss() 341 | optimizer = torch.optim.Adam(net.module.parameters(), lr=LR) 342 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.6) 343 | 344 | # Train the model 345 | #best_acc = 0.0 346 | save_loss = [];save_similatiry_loss = [];save_gt=[] 347 | cs_loss = Cluster_loss() 348 | for epoch in range(EPOCH+1): 349 | if epoch % T==0 and epoch < STOP_CLUSTERING: 350 | with torch.no_grad(): 351 | Ground_true, loss_mask_num, loss_mask_den = offline_spectral_cluster(net, trainset_loader, DATANAME) 352 | save_gt.append(Ground_true.cpu().numpy()) 353 | else: 354 | scheduler.step() 355 | net.train() 356 | all_feature = []; total_loss = 
0.0;similarity_loss = 0.0 357 | for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9): 358 | inputs, labels = input_data 359 | inputs, labels = inputs.to(device), labels.long().to(device) 360 | optimizer.zero_grad() 361 | output, f_map, corre = net(inputs, eval=False) 362 | 363 | if DATANAME != 'celeb': 364 | clr_loss = criterion(output, labels) 365 | else: 366 | clr_loss = .0 367 | for attribution in range(NUM_CLASSES//2): 368 | clr_loss += criterion(output[:, 2*attribution:2*attribution+2], labels[:, attribution]) 369 | labels = None 370 | 371 | loss_ = cs_loss.update(corre, loss_mask_num, loss_mask_den, labels) 372 | loss = clr_loss + lam * loss_ 373 | loss.backward() 374 | optimizer.step() 375 | total_loss += loss.item() 376 | similarity_loss += loss_.item() 377 | 378 | ### loss save code ##### 379 | total_loss = float(total_loss) / len(trainset_loader) 380 | similarity_loss = float(similarity_loss) / len(trainset_loader) 381 | save_loss.append(total_loss) 382 | save_similatiry_loss.append(similarity_loss) 383 | # acc = test(net, testset_loader) 384 | acc = 0 385 | print('Epoch', epoch, 'loss: %.4f' % total_loss, 'cs_loss: %.4f' % similarity_loss, 'test accuracy:%.4f' % acc) 386 | if epoch % 100 == 0: 387 | torch.save(net.state_dict(), log_path+'model_%.3d.pth' % (epoch)) 388 | np.savez(log_path+'loss_%.3d.npz'% (epoch), loss=np.array(save_loss), similarity_loss = np.array(save_similatiry_loss),gt=np.array(save_gt)) 389 | #if epoch % 1 == 0: 390 | # if acc > best_acc: 391 | # best_acc = acc 392 | # torch.save(net.state_dict(), log_path+'model_%.3d_%.4f.pth' % (epoch,best_acc)) 393 | print('Finished Training') 394 | return net 395 | 396 | def offline_spectral_cluster(net, train_data, dataname=None): 397 | net.eval() 398 | f_map = [] 399 | for inputs, labels in train_data: 400 | inputs, labels = inputs.cuda(), labels.cuda() 401 | cur_fmap= net(inputs,eval=True).detach().cpu().numpy() 402 | f_map.append(cur_fmap) 403 | if dataname == 'celeb' and len(f_map) >= 1024: 404 | break 405 | f_map = np.concatenate(f_map,axis=0) 406 | sample, channel,_,_ = f_map.shape 407 | f_map = f_map.reshape((sample,channel,-1)) 408 | mean = np.mean(f_map,axis=0) 409 | cov = np.mean(np.matmul(f_map-mean,np.transpose(f_map-mean,(0,2,1))),axis=0) 410 | diag = np.diag(cov).reshape(channel,-1) 411 | correlation = cov/(np.sqrt(np.matmul(diag,np.transpose(diag,(1,0))))+1e-5)+1 412 | ground_true, loss_mask_num, loss_mask_den = spectral_clustering(correlation,n_cluster=center_num) 413 | 414 | return ground_true, loss_mask_num, loss_mask_den 415 | 416 | def test_ori(net, testdata, n_cls): 417 | correct, total = .0, .0 418 | for batch_step, input_data in tqdm(enumerate(testdata,0),total=len(testdata),smoothing=0.9): 419 | inputs, labels = input_data 420 | inputs, labels = inputs.cuda(), labels.cuda() 421 | net.eval() 422 | outputs, _, _ = net(inputs) 423 | _, predicted = torch.max(outputs, 1) 424 | total += labels.size(0) 425 | correct += (predicted == labels).sum() 426 | return float(correct) / total 427 | 428 | def test_celeb(net, testdata, n_cls): 429 | correct, total = .0, .0 430 | ATTRIBUTION_NUM = n_cls//2 431 | running_correct = np.zeros(ATTRIBUTION_NUM) 432 | for inputs, labels in tqdm(testdata): 433 | inputs, labels = inputs.cuda(), labels.cuda().long() 434 | net.eval() 435 | outputs, _, _ = net(inputs) 436 | out = outputs.data 437 | total += labels.size(0) 438 | for attribution in range(ATTRIBUTION_NUM): 439 | _, predicted = torch.max(out[:, 
2*attribution:2*attribution+2], 1) 440 | correct = (predicted == labels[:, attribution]).sum().item() 441 | running_correct[attribution] += correct 442 | attr_acc = running_correct / float(total) 443 | return np.mean(attr_acc) 444 | 445 | def get_feature(): 446 | _,testset_test = get_Data(True, DATANAME, BATCHSIZE) 447 | _,testset_feature = get_Data(False, DATANAME, BATCHSIZE) 448 | device = torch.device("cuda") 449 | net = None 450 | if LAYERS == '18': 451 | net = ResNet18(num_class=NUM_CLASSES, pretrained=False) 452 | elif LAYERS == '50': 453 | net = ResNet50(num_class=NUM_CLASSES, pretrained=False) 454 | global pretrain_model 455 | print(pretrain_model) 456 | 457 | test = test_celeb if DATANAME=='celeb' else test_ori 458 | acc = test(net, testset_test, NUM_CLASSES) 459 | f = open(acc_path+dataset+'_test.txt', 'w+') 460 | f.write('%s\n' % dataset) 461 | f.write('acc:%f\n' % acc) 462 | print(acc) 463 | all_feature = [] 464 | for batch_step, input_data in tqdm(enumerate(testset_feature,0),total=len(testset_feature),smoothing=0.9): 465 | inputs, labels = input_data 466 | inputs, labels = inputs.to(device), labels.to(device) 467 | net.eval() 468 | f_map = net(inputs,eval=True) 469 | all_feature.append(f_map.detach().cpu().numpy()) 470 | all_feature = np.concatenate(all_feature,axis=0) 471 | f.write('sample num:%d' % (all_feature.shape[0])) 472 | f.close() 473 | print(all_feature.shape) 474 | np.savez(save_path+LAYERS+'_resnet_'+DATANAME+'_iccnn.npz', f_map=all_feature[...]) 475 | print('Finished Operation!') 476 | 477 | def resnet_single_train(): 478 | if IS_TRAIN == 1: 479 | net_train() 480 | elif IS_TRAIN == 0: 481 | get_feature() 482 | 483 | -------------------------------------------------------------------------------- /resnet_ori_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as transforms 6 | import torchvision.models as models 7 | from torch.utils.data import DataLoader 8 | from load_utils import load_state_dict_from_url 9 | from cub_voc import CUB_VOC 10 | from celeb import Celeb 11 | import os 12 | from tqdm import tqdm 13 | import shutil 14 | import numpy as np 15 | from newPad2d import newPad2d 16 | 17 | IS_TRAIN = 0 # 0/1 18 | IS_MULTI = 0 19 | LAYERS = '18' 20 | DATANAME = 'bird' # 21 | NUM_CLASSES = 6 if IS_MULTI else 2 22 | if DATANAME == 'celeb': 23 | NUM_CLASSES = 80 24 | cub_file = '/data/sw/dataset/frac_dataset' 25 | voc_file = '/data/sw/dataset/VOCdevkit/VOC2010/voc2010_crop' 26 | celeb_file = '/home/user05/fjq/dataset/CelebA/' 27 | log_path = '/data/fjq/iccnn/resnet/' # for model 28 | save_path = '/data/fjq/iccnn/basic_fmap/resnet/' # for get_feature 29 | acc_path = '/data/fjq/iccnn/basic_fmap/resnet/acc/' 30 | 31 | dataset = '%s_resnet_%s_ori' % (LAYERS, DATANAME) 32 | log_path = log_path + dataset + '/' 33 | pretrain_model = log_path + 'model_2000.pth' 34 | BATCHSIZE = 16 35 | LR = 0.000001 36 | EPOCH = 200 37 | 38 | 39 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 40 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 41 | 'wide_resnet50_2', 'wide_resnet101_2'] 42 | 43 | model_urls = { 44 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 45 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 46 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 47 | 'resnet101': 
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 48 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 49 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 50 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 51 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 52 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 53 | } 54 | 55 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 56 | """3x3 convolution; padding is applied externally via newPad2d""" 57 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 58 | padding=0, groups=groups, bias=False, dilation=dilation)#new padding 59 | 60 | def conv1x1(in_planes, out_planes, stride=1): 61 | """1x1 convolution""" 62 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 63 | 64 | class BasicBlock(nn.Module): 65 | expansion = 1 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 68 | base_width=64, dilation=1, norm_layer=None): 69 | super(BasicBlock, self).__init__() 70 | if norm_layer is None: 71 | norm_layer = nn.BatchNorm2d 72 | if groups != 1 or base_width != 64: 73 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 74 | if dilation > 1: 75 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 76 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv3x3(inplanes, planes, stride) 78 | self.bn1 = norm_layer(planes) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.conv2 = conv3x3(planes, planes) 81 | self.bn2 = norm_layer(planes) 82 | self.downsample = downsample 83 | self.stride = stride 84 | self.pad2d = newPad2d(1) #nn.ReplicationPad2d(1)#new padding 85 | 86 | def forward(self, x): 87 | identity = x 88 | out = self.pad2d(x) #new padding 89 | out = self.conv1(out) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.pad2d(out) #new padding 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | 97 | if self.downsample is not None: 98 | identity = self.downsample(x) 99 | 100 | out += identity 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | class Bottleneck(nn.Module): 106 | expansion = 4 107 | 108 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 109 | base_width=64, dilation=1, norm_layer=None): 110 | super(Bottleneck, self).__init__() 111 | if norm_layer is None: 112 | norm_layer = nn.BatchNorm2d 113 | width = int(planes * (base_width / 64.)) * groups 114 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 115 | self.conv1 = conv1x1(inplanes, width) 116 | self.bn1 = norm_layer(width) 117 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 118 | self.bn2 = norm_layer(width) 119 | self.conv3 = conv1x1(width, planes * self.expansion) 120 | self.bn3 = norm_layer(planes * self.expansion) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.downsample = downsample 123 | self.stride = stride 124 | self.pad2d = newPad2d(1) #nn.ReplicationPad2d(1)#new padding 125 | 126 | def forward(self, x): 127 | identity = x 128 | 129 | out = self.conv1(x) 130 | out = self.bn1(out) 131 | out = self.relu(out) 132 | 133 | out = self.pad2d(out) #new padding 134 | out = self.conv2(out) 135 | out = self.bn2(out) 136 | out = self.relu(out) 137 | 138 | out = self.conv3(out) 139 | out = self.bn3(out) 140 |
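# (added annotation) only the 3x3 conv2 needs the explicit newPad2d(1) above;
# conv1 and conv3 are 1x1 convolutions and keep the spatial size unpadded.
141 | if 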
self.downsample is not None: 142 | identity = self.downsample(x) 143 | 144 | out += identity 145 | out = self.relu(out) 146 | 147 | return out 148 | 149 | class ResNet(nn.Module): 150 | def __init__(self, block, layers, num_classes=2, zero_init_residual=False, 151 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 152 | norm_layer=None): 153 | super(ResNet, self).__init__() 154 | if norm_layer is None: 155 | norm_layer = nn.BatchNorm2d 156 | self._norm_layer = norm_layer 157 | 158 | self.inplanes = 64 159 | self.dilation = 1 160 | if replace_stride_with_dilation is None: 161 | # each element in the tuple indicates if we should replace 162 | # the 2x2 stride with a dilated convolution instead 163 | replace_stride_with_dilation = [False, False, False] 164 | if len(replace_stride_with_dilation) != 3: 165 | raise ValueError("replace_stride_with_dilation should be None " 166 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 167 | self.groups = groups 168 | self.base_width = width_per_group 169 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=0, 170 | bias=False)#new padding 171 | self.bn1 = norm_layer(self.inplanes) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)#new padding 174 | self.layer1 = self._make_layer(block, 64, layers[0]) 175 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 176 | dilate=replace_stride_with_dilation[0]) 177 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 178 | dilate=replace_stride_with_dilation[1]) 179 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 180 | dilate=replace_stride_with_dilation[2]) 181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 182 | self.fc = nn.Linear(512 * block.expansion, num_classes) 183 | self.pad2d_1 = newPad2d(1)#new padding 184 | self.pad2d_3 = newPad2d(3)#new padding 185 | 186 | for m in self.modules(): 187 | if isinstance(m, nn.Conv2d): 188 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 189 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 190 | nn.init.constant_(m.weight, 1) 191 | nn.init.constant_(m.bias, 0) 192 | 193 | # Zero-initialize the last BN in each residual branch, 194 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
195 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 196 | if zero_init_residual: 197 | for m in self.modules(): 198 | if isinstance(m, Bottleneck): 199 | nn.init.constant_(m.bn3.weight, 0) 200 | elif isinstance(m, BasicBlock): 201 | nn.init.constant_(m.bn2.weight, 0) 202 | 203 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 204 | norm_layer = self._norm_layer 205 | downsample = None 206 | previous_dilation = self.dilation 207 | if dilate: 208 | self.dilation *= stride 209 | stride = 1 210 | if stride != 1 or self.inplanes != planes * block.expansion: 211 | downsample = nn.Sequential( 212 | conv1x1(self.inplanes, planes * block.expansion, stride), 213 | norm_layer(planes * block.expansion), 214 | ) 215 | 216 | layers = [] 217 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 218 | self.base_width, previous_dilation, norm_layer)) 219 | self.inplanes = planes * block.expansion 220 | for _ in range(1, blocks): 221 | layers.append(block(self.inplanes, planes, groups=self.groups, 222 | base_width=self.base_width, dilation=self.dilation, 223 | norm_layer=norm_layer)) 224 | 225 | return nn.Sequential(*layers) 226 | 227 | def forward(self, x): 228 | # See note [TorchScript super()] 229 | x = self.pad2d_3(x) #new padding 230 | x = self.conv1(x) 231 | x = self.bn1(x) 232 | x = self.relu(x) 233 | x = self.pad2d_1(x) 234 | x = self.maxpool(x) 235 | 236 | x = self.layer1(x) 237 | x = self.layer2(x) 238 | x = self.layer3(x) 239 | f_map = x.detach() 240 | x = self.layer4(x) 241 | x = self.avgpool(x) 242 | x = torch.flatten(x, 1) 243 | x = self.fc(x) 244 | return x, f_map 245 | 246 | def _resnet(arch, block, layers, n_cls, pretrained, progress, **kwargs): 247 | model = ResNet(block, layers, n_cls, **kwargs) 248 | if pretrained: 249 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 250 | pretrained_dict = {k: v for k, v in state_dict.items() if 'fc' not in k and 'layer4.1' not in k}#'fc' not in k and 'layer4.1' not in k and 251 | model_dict = model.state_dict() 252 | model_dict.update(pretrained_dict) 253 | model.load_state_dict(model_dict) 254 | else: 255 | device = torch.device("cuda") 256 | model = nn.DataParallel(model).to(device) 257 | model.load_state_dict(torch.load(pretrain_model)) 258 | return model 259 | 260 | def ResNet18(n_cls, pretrained=False, progress=True, **kwargs): 261 | return _resnet('resnet18', BasicBlock, [2,2,2,2], n_cls, pretrained, progress, **kwargs) 262 | 263 | def ResNet34(n_cls, pretrained=False, progress=True, **kwargs): 264 | return _resnet('resnet34', BasicBlock, [3,4,6,3], n_cls, pretrained, progress, **kwargs) 265 | 266 | def ResNet50(n_cls, pretrained=False, progress=True, **kwargs): 267 | return _resnet('resnet50', Bottleneck, [3,4,6,3], n_cls, pretrained, progress, **kwargs) 268 | 269 | def ResNet101(n_cls, pretrained=False, progress=True, **kwargs): 270 | return _resnet('resnet101', Bottleneck, [3,4,23,3], n_cls, pretrained, progress, **kwargs) 271 | 272 | def ResNet152(n_cls, pretrained=False, progress=True, **kwargs): 273 | return _resnet('resnet152', Bottleneck, [3,8,36,3], n_cls, pretrained, progress, **kwargs) 274 | 275 | def get_Data(is_train, dataset_name,batch_size): 276 | transform = transforms.Compose([ 277 | transforms.RandomResizedCrop((224,224),scale=(0.5,1.0)), 278 | transforms.RandomHorizontalFlip(), 279 | transforms.ToTensor(), 280 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 281 | std=[0.229, 0.224, 0.225]) 282 | ]) 283 | 
val_transform = transforms.Compose([ 284 | transforms.Resize((224,224)), 285 | transforms.ToTensor(), 286 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 287 | std=[0.229, 0.224, 0.225]) 288 | ]) 289 | voc_helen_name = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi'] 290 | ##cub dataset### 291 | label = None if is_train else 0 292 | if dataset_name == 'cub': 293 | trainset = CUB_VOC(cub_file, dataset_name, 'ori', train=True, transform=transform, is_frac=label) 294 | testset = CUB_VOC(cub_file, dataset_name, 'ori', train=False, transform=val_transform, is_frac=label) 295 | ###cropped voc dataset### 296 | elif dataset_name in voc_helen_name: 297 | trainset = CUB_VOC(voc_file, dataset_name, 'ori', train=True, transform=transform, is_frac=label) 298 | testset = CUB_VOC(voc_file, dataset_name, 'ori', train=False, transform=val_transform, is_frac=label) 299 | ###celeb dataset### 300 | elif dataset_name == 'celeb': 301 | trainset = Celeb(celeb_file, training = True, transform=None, train_num=162770) 302 | testset = Celeb(celeb_file, training = False, transform=None, train_num=19962) 303 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 304 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=False) 305 | return train_loader, test_loader 306 | 307 | def net_train(): 308 | trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME,BATCHSIZE) 309 | device = torch.device("cuda") 310 | if LAYERS == '50': 311 | net = ResNet50(NUM_CLASSES, pretrained=True) 312 | else: 313 | net = ResNet18(NUM_CLASSES, pretrained=True) 314 | net = nn.DataParallel(net).to(device) 315 | model_path = os.path.join(log_path, '%s_resnet_%s_ori' % (LAYERS, DATANAME)) 316 | if os.path.exists(model_path): 317 | shutil.rmtree(model_path) 318 | os.makedirs(model_path) 319 | else: 320 | os.makedirs(model_path) 321 | 322 | # Loss and Optimizer 323 | criterion = nn.CrossEntropyLoss() 324 | optimizer = torch.optim.Adam(net.module.parameters(), lr=LR) 325 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.7) 326 | test = test_celeb if DATANAME=='celeb' else test_ori 327 | # Train the model 328 | best_acc = 0.0; save_loss = []; 329 | for epoch in range(0, EPOCH+1): 330 | scheduler.step() 331 | net.train() 332 | total_loss = 0.0; 333 | for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9): 334 | inputs, labels = input_data 335 | inputs, labels = inputs.to(device), labels.long().to(device) 336 | optimizer.zero_grad() 337 | output, _ = net(inputs) 338 | if DATANAME != 'celeb': 339 | loss = criterion(output, labels) 340 | else: 341 | loss = .0 342 | for attribution in range(NUM_CLASSES//2): 343 | loss += criterion(output[:, 2*attribution:2*attribution+2], labels[:, attribution]) 344 | loss.backward() 345 | optimizer.step() 346 | total_loss += loss.item() 347 | ### loss save code ##### 348 | total_loss = float(total_loss) / len(trainset_loader) 349 | save_loss.append(total_loss) 350 | np.savez(os.path.join(model_path, 'loss.npz'), loss=np.array(save_loss)) 351 | if epoch % 50 == 0: 352 | train_acc = test(net, trainset_loader, NUM_CLASSES) 353 | print('Epoch', epoch, 'train accuracy:%.4f' % train_acc) 354 | torch.save(net.state_dict(), model_path+'/model_%.3d.pth' % (epoch)) 355 | if epoch % 1 == 0: 356 | acc = test(net, testset_loader, NUM_CLASSES) 357 | print('Epoch', epoch, 'loss: %.4f' % total_loss, 'test accuracy:%.4f' % acc) 358 | if acc > best_acc and epoch >= 10: 359 | 
best_acc = acc 360 | torch.save(net.state_dict(), model_path+'/model_%.3d_%.4f.pth' % (epoch, best_acc)) 361 | print('Finished Training') 362 | return net 363 | 364 | def get_feature(): 365 | _, testset_test = get_Data(True, DATANAME, BATCHSIZE) 366 | _, testset_feature = get_Data(False, DATANAME, BATCHSIZE) 367 | device = torch.device("cuda") 368 | if not os.path.exists(pretrain_model): 369 | raise Exception("No such pretrained model!") 370 | if LAYERS == '50': 371 | net = ResNet50(NUM_CLASSES) 372 | else: 373 | net = ResNet18(NUM_CLASSES) 374 | net = nn.DataParallel(net).to(device) 375 | test = test_celeb if DATANAME=='celeb' else test_ori 376 | acc = test(net, testset_test, NUM_CLASSES) 377 | print('test acc:', acc) 378 | #if not os.path.exists(acc_path): 379 | # os.makedirs(acc_path) 380 | f = open(os.path.join(acc_path, 'res'+str(LAYERS)+'_'+DATANAME+'_test.txt'), 'w+') 381 | f.write('%s%s\n' % ('res', str(LAYERS))) 382 | f.write('%s\n' % DATANAME) 383 | f.write('acc:%f\n' % acc) 384 | #if not os.path.exists(save_path): 385 | # os.makedirs(save_path) 386 | all_feature = [] 387 | for batch_step, input_data in tqdm(enumerate(testset_feature,0),total=len(testset_feature),smoothing=0.9): 388 | inputs, labels = input_data 389 | inputs, labels = inputs.to(device), labels.long().to(device) 390 | net.eval() 391 | output, f_map = net(inputs) 392 | all_feature.append(f_map.cpu().numpy()) 393 | all_feature = np.concatenate(all_feature, axis=0) 394 | f.write('sample num:%d' % (all_feature.shape[0])) 395 | f.close() 396 | np.savez_compressed(save_path+LAYERS+'_resnet_'+DATANAME+'_ori.npz', f_map=all_feature[...]) 397 | print('Finished Getting Feature!') 398 | return net 399 | 400 | def test_ori(net, testdata, n_cls): 401 | correct, total = .0, .0 402 | for inputs, labels in testdata: 403 | inputs, labels = inputs.cuda(), labels.cuda().long() 404 | net.eval() 405 | outputs, _ = net(inputs) 406 | _, predicted = torch.max(outputs, 1) 407 | total += labels.size(0) 408 | correct += (predicted == labels).sum() 409 | return float(correct) / total 410 | 411 | def test_celeb(net, testdata, n_cls): 412 | correct, total = .0, .0 413 | ATTRIBUTION_NUM = n_cls//2 414 | running_correct = np.zeros(ATTRIBUTION_NUM) 415 | for inputs, labels in tqdm(testdata): 416 | inputs, labels = inputs.cuda(), labels.cuda().long() 417 | net.eval() 418 | outputs, _ = net(inputs) 419 | out = outputs.data 420 | total += labels.size(0) 421 | for attribution in range(ATTRIBUTION_NUM): 422 | _, predicted = torch.max(out[:, 2*attribution:2*attribution+2], 1) 423 | correct = (predicted == labels[:, attribution]).sum().item() 424 | running_correct[attribution] += correct 425 | attr_acc = running_correct / float(total) 426 | return np.mean(attr_acc) 427 | 428 | def resnet_ori_train(): 429 | if IS_TRAIN: 430 | net = net_train() 431 | else: 432 | net = get_feature() 433 | return net 434 | 435 | if __name__ == '__main__': 436 | net = resnet_ori_train() 437 | -------------------------------------------------------------------------------- /train_all.py: -------------------------------------------------------------------------------- 1 | ## train files 2 | from densenet_ori_train import densenet_ori_train 3 | from densenet_iccnn_train import densenet_single_train 4 | from densenet_iccnn_multi_train import densenet_multi_train 5 | #from vgg_train import 6 | #from resnet_train import 7 | #resnet 8 | from resnet_iccnn_multi_train import resnet_multi_train 9 | from resnet_iccnn_train import resnet_single_train 10 | from resnet_ori_train import 
resnet_ori_train 11 | #vgg 12 | from vgg_iccnn_train import vgg_single_train 13 | from vgg_iccnn_multi_train import vgg_multi_train 14 | from vgg_ori_train import vgg_ori_train 15 | ## 16 | import argparse 17 | import random 18 | import os 19 | import numpy as np 20 | import torch 21 | 22 | def set_seed(seed=0): 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | 31 | if __name__ == '__main__': 32 | 33 | torch.set_num_threads(5) 34 | 35 | set_seed(0) 36 | 37 | parser = argparse.ArgumentParser() 38 | 39 | # command-line options 40 | parser.add_argument('-type', type=str, help='the type of train model ori/iccnn') 41 | parser.add_argument('-is_multi', type=int, help='single/multi 0/1') 42 | parser.add_argument('-model', type=str, help='vgg/resnet/densenet') 43 | parser.add_argument('-gpu', type=str, default='0', help='value for CUDA_VISIBLE_DEVICES') 44 | args = parser.parse_args() 45 | 46 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 47 | 48 | if args.model == 'densenet': 49 | if args.type == 'ori': 50 | densenet_ori_train() 51 | elif args.type == 'iccnn': 52 | if args.is_multi == 0: 53 | densenet_single_train() 54 | elif args.is_multi == 1: 55 | densenet_multi_train() 56 | elif args.model == 'resnet': 57 | if args.type == 'ori': 58 | resnet_ori_train() 59 | elif args.type == 'iccnn': 60 | if args.is_multi == 0: 61 | resnet_single_train() 62 | else: 63 | resnet_multi_train() 64 | elif args.model == 'vgg': 65 | if args.type == 'ori': 66 | vgg_ori_train() 67 | elif args.type == 'iccnn': 68 | if args.is_multi == 0: 69 | vgg_single_train() 70 | else: 71 | vgg_multi_train() 72 | else: 73 | raise Exception("Not Implemented!") 74 | 75 | -------------------------------------------------------------------------------- /tutorial.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ada-shen/icCNN/6f6d7bd31a437a3e39c33cb53967ea5e7f1b26f2/tutorial.pdf -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | import torch 5 | 6 | class EMA_FM(): 7 | def __init__(self, decay=0.9, first_decay=0.0, channel_size=512, f_map_size=196, is_use = False): 8 | self.decay = decay 9 | self.first_decay = first_decay 10 | self.is_use = is_use 11 | self.shadow = {} 12 | self.epsilon = 1e-5 13 | if is_use: 14 | self._register(channel_size=channel_size, f_map_size= f_map_size) 15 | 16 | def _register(self, channel_size=512, f_map_size=196): 17 | Init_FM = torch.zeros((f_map_size, channel_size),dtype=torch.float) 18 | self.shadow['FM'] = Init_FM.cuda().clone() 19 | self.is_first = True 20 | 21 | def update(self, input): 22 | B, C, _ = input.size() 23 | if not(self.is_use): 24 | return torch.ones((C,C), dtype=torch.float) 25 | decay = self.first_decay if self.is_first else self.decay 26 | ####### FEATURE SIMILARITY MATRIX EMA ######## 27 | # Mu = torch.mean(input,dim=0) 28 | self.shadow['FM'] = (1.0 - decay) * torch.mean(input,dim=0) + decay * self.shadow['FM'] 29 | self.is_first = False 30 | return self.shadow['FM'] 31 | 32 | class Cluster_loss(): 33 | def __init__(self): 34 | pass 35 | 36 | def update(self, correlation, loss_mask_num, loss_mask_den, labels): 37 | batch, channel, _ = correlation.shape 38 | c, _, _ = loss_mask_num.shape
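# (added annotation) `correlation` is the channel-similarity matrix produced
# by SMGBlock during the forward pass; loss_mask_num / loss_mask_den come from
# spectral_clustering() on the channel correlation matrix, which the train
# scripts build as cov_ij / sqrt(cov_ii * cov_jj) + 1, i.e. shifted into [0, 2].
# For each of the c clusters, loss_mask_num selects within-cluster channel
# pairs, while loss_mask_den appears to cover all pairs involving that cluster,
# so the value computed below is L = -sum_c num_c / (den_c + 1e-5): it rewards
# within-cluster similarity relative to the cluster's total similarity.
39 | if 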
labels is not None: 40 | label_mask = (1 - labels).view(batch, 1, 1) 41 | ## smg_loss is only available for positive samples 42 | correlation = correlation * label_mask 43 | correlation = (correlation / batch).view(1, batch, channel, channel).repeat(c, 1, 1, 1) 44 | 45 | new_Num = torch.sum(correlation * loss_mask_num.view(c, 1, channel, channel).repeat(1, batch, 1, 1), 46 | dim=(1, 2, 3)) 47 | new_Den = torch.sum(correlation * (loss_mask_den).view(c, 1, channel, channel).repeat(1, batch, 1, 1), 48 | dim=(1, 2, 3)) 49 | ret_loss = -torch.sum(new_Num / (new_Den + 1e-5)) 50 | return ret_loss 51 | 52 | class Multiclass_loss(): 53 | def __init__(self, class_num=None): 54 | self.class_num = class_num 55 | 56 | def get_label_mask(self, label): 57 | label = label.cpu().numpy() 58 | sz = label.shape[0] 59 | label_mask_num = [] 60 | label_mask_den = [] 61 | for i in range(self.class_num): 62 | idx = np.where(label == i)[0] 63 | cur_mask_num = np.zeros((sz, sz)) 64 | cur_mask_den = np.zeros((sz, sz)) 65 | for j in idx: 66 | cur_mask_num[j][idx] = 1 67 | cur_mask_den[j][:] = 1 68 | label_mask_num.append(np.expand_dims(cur_mask_num, 0)) 69 | label_mask_den.append(np.expand_dims(cur_mask_den, 0)) 70 | label_mask_num = np.concatenate(label_mask_num, axis=0) 71 | label_mask_den = np.concatenate(label_mask_den, axis=0) 72 | return torch.from_numpy(label_mask_num).float().cuda(), torch.from_numpy(label_mask_den).float().cuda() 73 | 74 | def update(self, fmap, loss_mask_num, label): 75 | B, C, _, _ = fmap.shape 76 | center, _, _ = loss_mask_num.shape 77 | fmap = fmap.view(1, B, C, -1).repeat(center, 1, 1, 1) 78 | mean_activate = torch.mean(torch.matmul(loss_mask_num.view(center, 1, C, C).repeat(1, B, 1, 1), fmap), 79 | dim=(2, 3)) 80 | # cosine 81 | mean_activate = torch.div(mean_activate, torch.norm(mean_activate, p=2, dim=0, keepdim=True) + 1e-5) 82 | inner_dot = torch.matmul(mean_activate.permute(1, 0), mean_activate).view(-1, B, B).repeat(self.class_num, 1, 1) 83 | label_mask, label_mask_intra = self.get_label_mask(label) 84 | 85 | new_Num = torch.mean(inner_dot * label_mask, dim=(1, 2)) 86 | new_Den = torch.mean(inner_dot * label_mask_intra, dim=(1, 2)) 87 | ret_loss = -torch.sum(new_Num / (new_Den + 1e-5)) 88 | return ret_loss 89 | 90 | def Cal_Center(fmap, gt): 91 | f_1map = fmap.detach().cpu().numpy() 92 | matrix = gt.detach().cpu().numpy() 93 | B, C, H, W = f_1map.shape 94 | cluster = [] 95 | visited = np.zeros(C) 96 | for i in range(matrix.shape[0]): 97 | tmp = [] 98 | if(visited[i]==0): 99 | for j in range(matrix.shape[1]): 100 | if(matrix[i][j]==1 ): 101 | tmp.append(j) 102 | visited[j]=1 103 | cluster.append(tmp) 104 | center = [] 105 | for i in range(len(cluster)): 106 | cur_clustet_fmap = f_1map[:,cluster[i],...] 
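# (added annotation) each center below is the mean over that cluster's
# channels, one B x H x W map per cluster; after stacking and transposing,
# Cal_Center returns a B x n_cluster x H x W tensor of cluster centers.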
107 | cluster_center = np.mean(cur_clustet_fmap,axis=1) 108 | center.append(cluster_center) 109 | center = np.transpose(np.array(center),[1,0,2,3]) 110 | center = torch.from_numpy(center).float() 111 | return center 112 | -------------------------------------------------------------------------------- /vgg_iccnn_multi_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import torchvision.transforms as transforms 7 | import torchvision as tv 8 | import torchvision.models as models 9 | from torch.utils.data import DataLoader 10 | from load_utils import load_state_dict_from_url 11 | from cub_voc import CUB_VOC 12 | import os 13 | from tqdm import tqdm 14 | import shutil 15 | from utils.utils import Cluster_loss, Multiclass_loss 16 | import numpy as np 17 | from Similar_Mask_Generate import SMGBlock 18 | from SpectralClustering import spectral_clustering 19 | from newPad2d import newPad2d 20 | 21 | IS_TRAIN = 0 # 0/1 22 | LAYERS = '19' 23 | DATANAME = 'voc_multi' # voc_multi 24 | NUM_CLASSES = 6 25 | cub_file = './dataset/frac_dataset' 26 | voc_file = './dataset/VOCdevkit/VOC2010/voc2010_crop' 27 | log_path = './icCNN/run/vgg/' # for model 28 | save_path = './icCNN/run/basic_fmap/vgg/' # for get_feature 29 | acc_path = './icCNN/run/basic_fmap/vgg/acc/' 30 | 31 | dataset = '%s_vgg_%s_iccnn' % (LAYERS, DATANAME) 32 | log_path = log_path + dataset + '/' 33 | pretrain_model = log_path + 'model_1000.pth' 34 | BATCHSIZE = 1 35 | LR = 0.00001 36 | EPOCH = 1000 37 | center_num = 16 38 | lam1 = 0.1 39 | lam2 = 0.1 40 | T = 2 # T = 2 ===> do sc each epoch 41 | F_MAP_SIZE = 196 42 | STOP_CLUSTERING = 200 43 | if LAYERS == '13': 44 | CHANNEL_NUM = 512 45 | elif LAYERS == '16': 46 | CHANNEL_NUM = 512 47 | elif LAYERS == '19': 48 | CHANNEL_NUM = 512 49 | 50 | __all__ = ['VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19',] 51 | cfgs = { 52 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 53 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 54 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 55 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],} 56 | model_urls = { 57 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 58 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 59 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 60 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 61 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 62 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 63 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 64 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',} 65 | 66 | class VGG(nn.Module): 67 | def __init__(self, features, num_classes=NUM_CLASSES, init_weights=True, cfg=None): 68 | super(VGG, self).__init__() 69 | self.features = features 70 | # define the layer number of the relu after the top conv layer of VGG 71 | if cfg=='D': # VGG16 72 | self.target_layer = 42 73 | if cfg=='B': # VGG13 74 | self.target_layer = 33 75 | if cfg=='E': # VGG19 76 | self.target_layer = 51 77 | self.layer_num = self.target_layer 78 | self.smg = SMGBlock(channel_size = 
CHANNEL_NUM, f_map_size=F_MAP_SIZE) 79 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 80 | self.pad2d = newPad2d(1) 81 | self.classifier = nn.Sequential( 82 | #fc6 83 | nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(0.5), 84 | #fc7 85 | nn.Linear(4096, 512),nn.ReLU(True),nn.Dropout(0.5), 86 | #fc8 87 | nn.Linear(512, num_classes)) 88 | 89 | if init_weights: 90 | self._initialize_weights() 91 | 92 | def forward(self, x, eval=False): 93 | for layer in self.features[:self.target_layer+1]: 94 | if isinstance(layer,nn.Conv2d): 95 | x = self.pad2d(x) 96 | x = layer(x) 97 | 98 | if eval: 99 | return x 100 | corre_matrix = self.smg(x) 101 | f_map = x.detach() 102 | fd = x 103 | for layer in self.features[self.target_layer+1:]: 104 | if isinstance(layer,nn.Conv2d): 105 | x = self.pad2d(x) 106 | x = layer(x) 107 | x = self.avgpool(x) 108 | x = x.view(x.size(0), -1) 109 | x = self.classifier(x) 110 | return x, f_map, corre_matrix , fd 111 | 112 | def _initialize_weights(self): 113 | for layer, m in enumerate(self.modules()): 114 | #print(layer,m) 115 | if layer > self.layer_num: 116 | if isinstance(m, nn.Conv2d): 117 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 118 | if m.bias is not None: 119 | nn.init.constant_(m.bias, 0) 120 | elif isinstance(m, nn.BatchNorm2d): 121 | nn.init.constant_(m.weight, 1) 122 | nn.init.constant_(m.bias, 0) 123 | if isinstance(m, nn.Linear): 124 | nn.init.normal_(m.weight, 0, 0.01) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | 128 | def make_layers(cfg, batch_norm=False): 129 | layers = [] 130 | in_channels = 3 131 | for v in cfg: 132 | if v == 'M': 133 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 134 | else: 135 | conv2d = nn.Conv2d(in_channels, v, 3, padding=0) 136 | if batch_norm: 137 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 138 | else: 139 | layers += [conv2d, nn.ReLU(inplace=True)] 140 | in_channels = v 141 | return nn.Sequential(*layers) 142 | 143 | def vgg16(arch, cfg, device=None, pretrained=False, progress=True, **kwargs): 144 | if pretrained: 145 | kwargs['init_weights'] = False 146 | kwargs['cfg'] = cfg 147 | model = VGG(make_layers(cfgs[cfg], batch_norm=True), **kwargs) 148 | if pretrained: 149 | if pretrain_model is None: 150 | state_dict = load_state_dict_from_url(model_urls[arch],progress=progress) 151 | pretrained_dict = {k: v for k, v in state_dict.items() if 'classifier' not in k} 152 | model_dict = model.state_dict() 153 | model_dict.update(pretrained_dict) 154 | model.load_state_dict(model_dict) 155 | else: 156 | device = torch.device("cuda") 157 | model = nn.DataParallel(model).to(device) 158 | pretrained_dict = torch.load(pretrain_model) 159 | if IS_TRAIN == 0: 160 | pretrained_dict = {k[k.find('.')+1:]: v for k, v in pretrained_dict.items()} # imagenet need comment 161 | # model_dict = model.state_dict() 162 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 163 | # model_dict.update(pretrained_dict) 164 | model.load_state_dict(pretrained_dict) 165 | 166 | if device is not None: 167 | model = nn.DataParallel(model).to(device) 168 | return model 169 | 170 | def get_Data(is_train, dataset_name, batch_size): 171 | val_transform = transforms.Compose([ 172 | transforms.Resize((224,224)), 173 | transforms.ToTensor(), 174 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 175 | std=[0.229, 0.224, 0.225]) 176 | ]) 177 | voc_helen = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi'] 178 | ##cub dataset### 179 | label = None if is_train else 0 180 | 
if not is_train: 181 | batch_size = 1 182 | if dataset_name == 'cub': 183 | trainset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 184 | testset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 185 | ###cropped voc dataset### 186 | elif dataset_name in voc_helen: 187 | trainset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label) 188 | testset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label) 189 | ###celeb dataset### 190 | #elif dataset_name == 'celeb': 191 | # trainset = Celeb(training = True, transform=None) 192 | # testset = Celeb(training = False, transform=None) 193 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 194 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 195 | return train_loader, test_loader 196 | 197 | def net_train(): 198 | trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE) 199 | 200 | if os.path.exists(log_path): 201 | shutil.rmtree(log_path);os.makedirs(log_path) 202 | else: 203 | os.makedirs(log_path) 204 | device = torch.device("cuda") 205 | 206 | net = None 207 | if LAYERS == '13': 208 | net = vgg16(arch='vgg13_bn',cfg='B', device=device, pretrained=True, progress=True) 209 | elif LAYERS == '16': 210 | net = vgg16(arch='vgg16_bn',cfg='D', device=device, pretrained=True, progress=True) 211 | elif LAYERS == '19': 212 | net = vgg16(arch='vgg19_bn',cfg='E', device=device, pretrained=True, progress=True) 213 | # Loss and Optimizer 214 | criterion_ce = nn.CrossEntropyLoss() 215 | optimizer = torch.optim.Adam(net.module.parameters(), lr=LR) 216 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.6) 217 | 218 | # Train the model 219 | best_acc = 0.0; save_total_loss = []; save_similatiry_loss = [];save_gt=[] 220 | save_class_loss= [] 221 | save_num = [] 222 | save_label = [] 223 | cs_loss = Cluster_loss() 224 | mc_loss = Multiclass_loss(class_num= NUM_CLASSES) 225 | for epoch in range(EPOCH+1): 226 | if epoch % T==0 and epoch < STOP_CLUSTERING: 227 | with torch.no_grad(): 228 | Ground_true, loss_mask_num, loss_mask_den = offline_spectral_cluster(net, trainset_loader) 229 | save_gt.append(Ground_true.cpu().numpy()) 230 | else: 231 | scheduler.step() 232 | net.train() 233 | all_feature = [] 234 | all_grad = [] 235 | clr_loss_grad = [] 236 | loss1_grad = [] 237 | loss2_grad = [] 238 | all_label = [] 239 | all_f_map = [] 240 | total_loss = 0.0 241 | similarity_loss = 0.0 242 | class_loss = 0.0 243 | for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9): 244 | inputs, labels = input_data 245 | inputs, labels = inputs.to(device), labels.to(device) 246 | inputs.requires_grad=True 247 | all_label.append(labels.detach().cpu().numpy()) 248 | 249 | output, f_map, corre, fd = net(inputs) 250 | 251 | fd.retain_grad() 252 | clr_loss = criterion_ce(output, labels) 253 | loss1 = cs_loss.update(corre, loss_mask_num, loss_mask_den, None) 254 | loss2 = mc_loss.update(fd, loss_mask_num, labels) 255 | loss = clr_loss + lam1 * loss1 + lam2 * loss2 256 | if epoch <= 0: 257 | # optimizer.zero_grad() 258 | # clr_loss.backward(retain_graph=True) 259 | # clr_loss_grad.append(fd.grad.detach().cpu().numpy()) 260 | # print(fd.grad.shape) 261 | # print('-'*10) 262 | # optimizer.zero_grad() 263 | # loss1.backward(retain_graph=True) 264 | # 
loss1_grad.append(fd.grad.detach().cpu().numpy()) 265 | # print(fd.grad.shape) 266 | # print('-'*10) 267 | optimizer.zero_grad() 268 | loss2.requires_grad_(True) 269 | loss2.backward(retain_graph=True) 270 | print(inputs.grad.shape) 271 | if fd.grad != None: 272 | loss2_grad.append(fd.grad.detach().cpu().numpy()) 273 | else: 274 | loss2_grad.append(fd.grad) 275 | print('-'*10) 276 | optimizer.zero_grad() 277 | loss.backward() 278 | if epoch <= 0: 279 | all_f_map.append(f_map.detach().cpu().numpy()) 280 | all_grad.append(fd.grad.detach().cpu().numpy()) 281 | optimizer.step() 282 | total_loss += loss.item() 283 | similarity_loss += loss1.item() 284 | class_loss += loss2.item() 285 | ### loss save code ##### 286 | total_loss = float(total_loss) / len(trainset_loader) 287 | similarity_loss = float(similarity_loss) / len(trainset_loader) 288 | class_loss = float(class_loss) / len(trainset_loader) 289 | save_total_loss.append(total_loss) 290 | save_similatiry_loss.append(similarity_loss) 291 | save_class_loss.append(class_loss) 292 | acc = test(net, testset_loader) 293 | print('Epoch', epoch, 'loss: %.4f' % total_loss, 'sc_loss: %.4f' % similarity_loss, 'class_loss: %.4f' % class_loss, 'test accuracy:%.4f' % acc) 294 | if epoch <= 0: 295 | np.savez(log_path+'grad_%.3d.npz'% (epoch), grad=np.array(all_grad),clr_loss_grad=np.array(clr_loss_grad),loss1_grad=np.array(loss1_grad),loss2_grad=np.array(loss2_grad),label=np.array(all_label),f_map=np.array(all_f_map)) 296 | if epoch % 100 == 0: 297 | torch.save(net.state_dict(), log_path+'model_%.3d.pth' % (epoch)) 298 | np.savez(log_path+'loss_%.3d.npz'% (epoch), loss=np.array(save_total_loss), similarity_loss = np.array(save_similatiry_loss), class_loss = np.array(save_class_loss),gt=np.array(save_gt)) 299 | 300 | print('Finished Training') 301 | return net 302 | 303 | def offline_spectral_cluster(net, train_data): 304 | net.eval() 305 | f_map = [] 306 | for inputs, labels in train_data: 307 | inputs, labels = inputs.cuda(), labels.cuda() 308 | cur_fmap= net(inputs,eval=True).detach().cpu().numpy() 309 | f_map.append(cur_fmap) 310 | f_map = np.concatenate(f_map,axis=0) 311 | sample, channel,_,_ = f_map.shape 312 | f_map = f_map.reshape((sample,channel,-1)) 313 | mean = np.mean(f_map,axis=0) 314 | cov = np.mean(np.matmul(f_map-mean,np.transpose(f_map-mean,(0,2,1))),axis=0) 315 | diag = np.diag(cov).reshape(channel,-1) 316 | correlation = cov/(np.sqrt(np.matmul(diag,np.transpose(diag,(1,0))))+1e-5)+1 317 | ground_true, loss_mask_num, loss_mask_den = spectral_clustering(correlation,n_cluster=center_num) 318 | 319 | return ground_true, loss_mask_num, loss_mask_den 320 | 321 | def test(net, testdata): 322 | correct, total = .0, .0 323 | for inputs, labels in testdata: 324 | inputs, labels = inputs.cuda(), labels.cuda() 325 | net.eval() 326 | outputs, _,_,_ = net(inputs) 327 | _, predicted = torch.max(outputs, 1) 328 | total += labels.size(0) 329 | correct += (predicted == labels).sum() 330 | return float(correct) / total 331 | 332 | def get_feature(): 333 | print('pretrain_model:', pretrain_model) 334 | _,testset_test = get_Data(True, DATANAME, BATCHSIZE) 335 | device = torch.device("cuda") 336 | net = None 337 | if LAYERS == '13': 338 | net = vgg16(arch='vgg13_bn',cfg='B', device=device, pretrained=True, progress=True) 339 | elif LAYERS == '16': 340 | net = vgg16(arch='vgg16_bn',cfg='D', device=device, pretrained=True, progress=True) 341 | elif LAYERS == '19': 342 | net = vgg16(arch='vgg19_bn',cfg='E', device=device, pretrained=True, progress=True) 343 | 
net = nn.DataParallel(net).to(device) 344 | acc = test(net, testset_test)## 345 | f = open(acc_path+dataset+'_test.txt', 'w+') 346 | f.write('%s\n' % dataset) 347 | f.write('acc:%f\n' % acc) 348 | print(acc) 349 | all_feature = [] 350 | for batch_step, input_data in tqdm(enumerate(testset_test,0),total=len(testset_test),smoothing=0.9): 351 | inputs, labels = input_data 352 | inputs, labels = inputs.to(device), labels.to(device) 353 | net.eval() 354 | f_map = net(inputs,eval=True) 355 | all_feature.append(f_map.detach().cpu().numpy()) 356 | all_feature = np.concatenate(all_feature,axis=0) 357 | f.write('sample num:%d' % (all_feature.shape[0])) 358 | f.close() 359 | print(all_feature.shape) 360 | np.savez(save_path+LAYERS+'_vgg_'+DATANAME+'_iccnn.npz', f_map=all_feature[...]) 361 | print('Finished Operation!') 362 | return net 363 | 364 | def vgg_multi_train(): 365 | if IS_TRAIN == 1: 366 | net_train() 367 | elif IS_TRAIN == 0: 368 | get_feature() 369 | -------------------------------------------------------------------------------- /vgg_iccnn_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import torchvision.transforms as transforms 7 | import torchvision as tv 8 | import torchvision.models as models 9 | from torch.utils.data import DataLoader 10 | from load_utils import load_state_dict_from_url 11 | from cub_voc import CUB_VOC 12 | import os 13 | from tqdm import tqdm 14 | import shutil 15 | import numpy as np 16 | from celeb import Celeb 17 | from Similar_Mask_Generate import SMGBlock 18 | from SpectralClustering import spectral_clustering 19 | from utils.utils import Cluster_loss 20 | from newPad2d import newPad2d 21 | 22 | IS_TRAIN = 0 # 0/1 23 | LAYERS = '13' 24 | DATANAME = 'bird' 25 | NUM_CLASSES = 80 if DATANAME == 'celeb' else 2 26 | cub_file = '/data/sw/dataset/frac_dataset' 27 | voc_file = '/data/sw/dataset/VOCdevkit/VOC2010/voc2010_crop' 28 | celeb_file = '/home/user05/fjq/dataset/CelebA/' 29 | log_path = '/data/fjq/iccnn/vgg/' # for model 30 | save_path = '/data/fjq/iccnn/basic_fmap/vgg/' # for get_feature 31 | acc_path = '/data/fjq/iccnn/basic_fmap/vgg/acc/' 32 | 33 | dataset = '%s_vgg_%s_iccnn' % (LAYERS, DATANAME) 34 | log_path = log_path + dataset + '/' 35 | pretrain_model = log_path + 'model_2000.pth' 36 | BATCHSIZE = 1 37 | LR = 0.00001 38 | EPOCH = 2500 39 | center_num = 5 40 | lam = 0.1 41 | T = 2 # T = 2 ===> do sc each epoch 42 | F_MAP_SIZE = 196 43 | STOP_CLUSTERING = 200 44 | if LAYERS == '13': 45 | CHANNEL_NUM = 512 46 | elif LAYERS == '16': 47 | CHANNEL_NUM = 512 48 | 49 | 50 | __all__ = ['VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19',] 51 | cfgs = { 52 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 53 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 54 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 55 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],} 56 | model_urls = { 57 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 58 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 59 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 60 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 61 | 'vgg11_bn': 
'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 62 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 63 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 64 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',} 65 | 66 | class VGG(nn.Module): 67 | def __init__(self, features, num_classes, init_weights=True, cfg=None): 68 | super(VGG, self).__init__() 69 | 70 | self.features = features 71 | if cfg=='D': # VGG16 72 | self.target_layer = 42 73 | if cfg=='B': # VGG13 74 | self.target_layer = 33 75 | self.layer_num = self.target_layer 76 | self.pad2d = newPad2d(1) #nn.ReplicationPad2d(1) 77 | self.smg = SMGBlock(channel_size = CHANNEL_NUM, f_map_size=F_MAP_SIZE) 78 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 79 | self.classifier = nn.Sequential( 80 | #fc6 81 | nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(0.5), 82 | #fc7 83 | nn.Linear(4096, 512),nn.ReLU(True),nn.Dropout(0.5), 84 | #fc8 85 | nn.Linear(512, num_classes)) 86 | 87 | if init_weights: 88 | self._initialize_weights() 89 | 90 | def forward(self, x, eval=False): 91 | for layer in self.features[:self.target_layer+1]: 92 | if isinstance(layer,nn.Conv2d): 93 | x = self.pad2d(x) 94 | x = layer(x) 95 | if eval: 96 | return x 97 | corre_matrix = self.smg(x) 98 | f_map = x.detach() 99 | for layer in self.features[self.target_layer+1:]: 100 | if isinstance(layer,nn.Conv2d): 101 | x = self.pad2d(x) 102 | x = layer(x) 103 | x = self.avgpool(x) 104 | x = x.view(x.size(0), -1) 105 | x = self.classifier(x) 106 | return x, f_map, corre_matrix 107 | 108 | def _initialize_weights(self): 109 | for layer, m in enumerate(self.modules()): 110 | if layer > self.layer_num: 111 | if isinstance(m, nn.Conv2d): 112 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 113 | if m.bias is not None: 114 | nn.init.constant_(m.bias, 0) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | nn.init.constant_(m.weight, 1) 117 | nn.init.constant_(m.bias, 0) 118 | if isinstance(m, nn.Linear): 119 | nn.init.normal_(m.weight, 0, 0.01) 120 | nn.init.constant_(m.bias, 0) 121 | 122 | 123 | def make_layers(cfg, batch_norm=False): 124 | layers = [] 125 | in_channels = 3 126 | for v in cfg: 127 | if v == 'M': 128 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 129 | else: 130 | conv2d = nn.Conv2d(in_channels, v, 3, padding=0) 131 | if batch_norm: 132 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 133 | else: 134 | layers += [conv2d, nn.ReLU(inplace=True)] 135 | in_channels = v 136 | return nn.Sequential(*layers) 137 | 138 | def vgg16(arch, cfg, num_class, device=None, pretrained=False, progress=True, **kwargs): 139 | if pretrained: 140 | kwargs['init_weights'] = False 141 | kwargs['cfg'] = cfg 142 | model = VGG(make_layers(cfgs[cfg], batch_norm=True), num_class, **kwargs) 143 | if pretrained: 144 | if pretrain_model is None: 145 | state_dict = load_state_dict_from_url(model_urls[arch],progress=progress) 146 | pretrained_dict = {k: v for k, v in state_dict.items() if 'classifier' not in k} 147 | model_dict = model.state_dict() 148 | model_dict.update(pretrained_dict) 149 | model.load_state_dict(model_dict) 150 | else: 151 | device = torch.device("cuda") 152 | model = nn.DataParallel(model).to(device) 153 | pretrained_dict = torch.load(pretrain_model) 154 | if IS_TRAIN == 0: 155 | pretrained_dict = {k[k.find('.')+1:]: v for k, v in pretrained_dict.items()} 156 | model.load_state_dict(pretrained_dict) 157 | if device is not None: 158 | model = 
        model = nn.DataParallel(model).to(device)
    return model

def get_Data(is_train, dataset_name, batch_size):
    val_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
        ])
    celeb_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    voc_helen = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi']
    ##cub dataset###
    label = None if is_train else 0
    if not is_train:
        batch_size = 1
    if dataset_name == 'cub':
        trainset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label)
        testset = CUB_VOC(cub_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label)
    ###cropped voc dataset###
    elif dataset_name in voc_helen:
        trainset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=True, transform=val_transform, is_frac=label)
        testset = CUB_VOC(voc_file, dataset_name, 'iccnn', train=False, transform=val_transform, is_frac=label)
    ###celeb dataset###
    elif dataset_name == 'celeb':
        trainset = Celeb(celeb_file, training = True, transform=celeb_transform, train_num=10240)
        testset = Celeb(celeb_file, training = False, transform=celeb_transform, train_num=19962)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

def net_train():
    trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE)
    if os.path.exists(log_path):
        shutil.rmtree(log_path);os.makedirs(log_path)
    else:
        os.makedirs(log_path)
    device = torch.device("cuda")
    test = test_celeb if DATANAME=='celeb' else test_ori

    net = None
    if LAYERS == '13':
        net = vgg16(arch='vgg13_bn',cfg='B', num_class=NUM_CLASSES, device=device, pretrained=True, progress=True)
    elif LAYERS == '16':
        net = vgg16(arch='vgg16_bn',cfg='D', num_class=NUM_CLASSES, device=device, pretrained=True, progress=True)

    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.module.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.6)

    # Train the model
    best_acc = 0.0
    save_total_loss = []; save_similarity_loss = []; save_gt = []
    cs_loss = Cluster_loss()
    for epoch in range(EPOCH+1):
        # every T epochs (while epoch < STOP_CLUSTERING) refresh the channel clusters
        if (epoch) % T==0 and epoch < STOP_CLUSTERING:
            with torch.no_grad():
                Ground_true, loss_mask_num, loss_mask_den = offline_spectral_cluster(net, trainset_loader, DATANAME)
            save_gt.append(Ground_true.cpu().numpy())
        else:
            scheduler.step()
        net.train()
        total_loss = 0.0; similarity_loss = 0.0

        for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9):
            inputs, labels = input_data
            inputs, labels = inputs.to(device), labels.long().to(device)
            optimizer.zero_grad()
            output, f_map, corre = net(inputs, eval=False)

            if DATANAME != 'celeb':
                clr_loss = criterion(output, labels)
            else:
                clr_loss = .0
                for attribution in range(NUM_CLASSES//2):
                    clr_loss += criterion(output[:, 2*attribution:2*attribution+2], labels[:, attribution])
                labels = None

            loss_ = cs_loss.update(corre, loss_mask_num, loss_mask_den, labels)
            loss = clr_loss + lam * loss_
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            similarity_loss += loss_.item()

        ### loss save code #####
        total_loss = float(total_loss) / len(trainset_loader)
        similarity_loss = float(similarity_loss) / len(trainset_loader)
        save_total_loss.append(total_loss)
        save_similarity_loss.append(similarity_loss)
        acc = 0 #test(net, testset_loader, NUM_CLASSES)
        print('Epoch', epoch, 'loss: %.4f' % total_loss,'sc_loss: %.4f' % similarity_loss, 'test accuracy:%.4f' % acc)

        if epoch % 100 == 0:
            torch.save(net.state_dict(), log_path+'model_%.3d.pth' % (epoch))
            np.savez(log_path+'loss_%.3d.npz'% (epoch), loss=np.array(save_total_loss), similarity_loss = np.array(save_similarity_loss),gt=np.array(save_gt))
        if epoch % 1 == 0:
            if acc > best_acc:
                best_acc = acc
                torch.save(net.state_dict(), log_path+'model_%.3d_%.4f.pth' % (epoch,best_acc))
    print('Finished Training')
    return net

def offline_spectral_cluster(net, train_data, dataname):
    net.eval()
    f_map = []
    for inputs, labels in train_data:
        inputs, labels = inputs.cuda(), labels.cuda()
        cur_fmap = net(inputs,eval=True).detach().cpu().numpy()
        f_map.append(cur_fmap)
        if dataname == 'celeb' and len(f_map)>=1024:
            break
    f_map = np.concatenate(f_map,axis=0)
    sample, channel,_,_ = f_map.shape
    f_map = f_map.reshape((sample,channel,-1))
    mean = np.mean(f_map,axis=0)
    cov = np.mean(np.matmul(f_map-mean,np.transpose(f_map-mean,(0,2,1))),axis=0)
    diag = np.diag(cov).reshape(channel,-1)
    correlation = cov/(np.sqrt(np.matmul(diag,np.transpose(diag,(1,0))))+1e-5)+1
    ground_true, loss_mask_num, loss_mask_den = spectral_clustering(correlation,n_cluster=center_num)
    return ground_true, loss_mask_num, loss_mask_den

def test_ori(net, testdata, n_cls):
    correct, total = .0, .0
    for inputs, labels in tqdm(testdata):
        inputs, labels = inputs.cuda(), labels.cuda().long()
        net.eval()
        outputs, _, _ = net(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    return float(correct) / total

def test_celeb(net, testdata, n_cls):
    correct, total = .0, .0
    ATTRIBUTION_NUM = n_cls//2
    running_correct = np.zeros(ATTRIBUTION_NUM)
    for inputs, labels in tqdm(testdata):
        inputs, labels = inputs.cuda(), labels.cuda().long()
        net.eval()
        outputs, _, _ = net(inputs)
        out = outputs.data
        total += labels.size(0)
        for attribution in range(ATTRIBUTION_NUM):
            _, predicted = torch.max(out[:, 2*attribution:2*attribution+2], 1)
            correct = (predicted == labels[:, attribution]).sum().item()
            running_correct[attribution] += correct
    attr_acc = running_correct / float(total)
    return np.mean(attr_acc)


def get_feature():
    print('pretrain_model:', pretrain_model)
    _,testset_test = get_Data(True, DATANAME, BATCHSIZE)
    _,testset_feature = get_Data(False, DATANAME, BATCHSIZE)
    device = torch.device("cuda")
    net = None
    if LAYERS == '13':
        net = vgg16(arch='vgg13_bn',cfg='B', num_class=NUM_CLASSES, device=device, pretrained=True, progress=True)
    elif LAYERS == '16':
        net = vgg16(arch='vgg16_bn',cfg='D', num_class=NUM_CLASSES, device=device, pretrained=True, progress=True)

    net = nn.DataParallel(net).to(device)
    test = test_celeb if DATANAME=='celeb' else test_ori
    acc = test(net, testset_test, NUM_CLASSES)
    f = open(acc_path+dataset+'_test.txt', 'w+')
    f.write('%s\n' % dataset)
    f.write('acc:%f\n' % acc)
    print(acc)
    all_feature = []
    for batch_step, input_data in tqdm(enumerate(testset_feature,0),total=len(testset_feature),smoothing=0.9):
        inputs, labels = input_data
        inputs, labels = inputs.to(device), labels.to(device)
        net.eval()
        f_map = net(inputs,eval=True)
        all_feature.append(f_map.detach().cpu().numpy())
    all_feature = np.concatenate(all_feature,axis=0)
    f.write('sample num:%d' % (all_feature.shape[0]))
    f.close()
    print(all_feature.shape)
    np.savez(save_path+LAYERS+'_vgg_'+DATANAME+'_iccnn.npz', f_map=all_feature[...])
    print('Finished Operation!')
    return net

def vgg_single_train():
    if IS_TRAIN == 1:
        net_train()
    elif IS_TRAIN == 0:
        get_feature()

--------------------------------------------------------------------------------
/vgg_ori_train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from load_utils import load_state_dict_from_url
from cub_voc import CUB_VOC
from celeb import Celeb
import os
from tqdm import tqdm
import shutil
import numpy as np
from newPad2d import newPad2d

IS_TRAIN = 0 # 0/1
IS_MULTI = 0
LAYERS = '13'
DATANAME = 'bird'
NUM_CLASSES = 6 if IS_MULTI else 2
if DATANAME == 'celeb':
    NUM_CLASSES = 80
cub_file = '/data/sw/dataset/frac_dataset'
voc_file = '/data/sw/dataset/VOCdevkit/VOC2010/voc2010_crop'
celeb_file = '/home/user05/fjq/dataset/CelebA/'
log_path = '/data/fjq/iccnn/vgg/' # for model
save_path = '/data/fjq/iccnn/basic_fmap/vgg/' # for get_feature
acc_path = '/data/fjq/iccnn/basic_fmap/vgg/acc/'

dataset = '%s_vgg_%s_ori' % (LAYERS, DATANAME)
log_path = log_path + dataset + '/'
pretrain_model = log_path + 'model_2000.pth'
BATCHSIZE = 1
LR = 0.000001
EPOCH = 200


__all__ = ['VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19']

cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],}

model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',}

class VGG(nn.Module):
    def __init__(self, features, num_classes=2, cfg='D', init_weights=True):
        super(VGG, self).__init__()

        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.pad2d = newPad2d(1) #nn.ReplicationPad2d(1)
        self.cfg = cfg
        self.classifier = nn.Sequential( # classifier structure
            #fc6
            nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(0.5),
            #fc7
            nn.Linear(4096, 512),nn.ReLU(True),nn.Dropout(0.5),
            #fc8
            nn.Linear(512, num_classes))

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        target_layer = 42 if self.cfg=='D' else 33
        f_map = None
        for i, layer in enumerate(self.features):
            if isinstance(layer, nn.Conv2d):
                x = self.pad2d(x)
            x = layer(x)
            if i == target_layer:
                f_map = x.detach()
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x, f_map

    def _initialize_weights(self):
        for layer, m in enumerate(self.modules()):
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, 3, padding=0)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] #new padding
            else:
                layers += [conv2d, nn.ReLU(inplace=True)] #new padding
            in_channels = v
    return nn.Sequential(*layers)

def vgg(arch, cfg, num_class, device=None, pretrained=False, progress=True, **kwargs):
    model = VGG(make_layers(cfgs[cfg], batch_norm=True), num_class, cfg, **kwargs)
    if pretrained:
        pretrain_layer = 39 if cfg=='D' else 30
        state_dict = load_state_dict_from_url(model_urls[arch],progress=progress)
        pretrained_dict = {k: v for k, v in state_dict.items() if 'classifier' not in k and int(k.split('.')[1])<=pretrain_layer}
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    else:
        device = torch.device("cuda")
        model = nn.DataParallel(model).to(device)
        model.load_state_dict(torch.load(pretrain_model))
    return model

def get_Data(is_train, dataset_name, batch_size):
    transform = transforms.Compose([
        transforms.RandomResizedCrop((224,224),scale=(0.5,1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
        ])
    val_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
        ])
    voc_helen_name = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'helen', 'voc_multi']
    ##cub dataset###
    label = None if is_train else 0
    if dataset_name == 'cub':
        trainset = CUB_VOC(cub_file, dataset_name, 'ori', train=True, transform=transform, is_frac=label)
        testset = CUB_VOC(cub_file, dataset_name, 'ori', train=False, transform=val_transform, is_frac=label)
    ###cropped voc dataset###
    elif dataset_name in voc_helen_name:
        trainset = CUB_VOC(voc_file, dataset_name, 'ori', train=True, transform=transform, is_frac=label)
        testset = CUB_VOC(voc_file, dataset_name, 'ori', train=False, transform=val_transform, is_frac=label)
    ###celeb dataset###
    elif dataset_name == 'celeb':
        trainset = Celeb(celeb_file, training = True, transform=None, train_num=162770)
        testset = Celeb(celeb_file, training = False, transform=None, train_num=19962)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, test_loader

def net_train():
    trainset_loader, testset_loader = get_Data(IS_TRAIN, DATANAME, BATCHSIZE)
    device = torch.device("cuda")
    layer_arch = 'vgg13_bn' if LAYERS=='13' else 'vgg16_bn'
    layer_cfg = 'B' if LAYERS=='13' else 'D'
    #model_path = os.path.join(log_path, '%s_vgg_%s_%s' % (layers, dataset_name, mytype))
    if os.path.exists(log_path):
        shutil.rmtree(log_path);os.makedirs(log_path)
    else:
        os.makedirs(log_path)

    net = vgg(arch=layer_arch,cfg=layer_cfg,num_class=NUM_CLASSES,device=device,pretrained=True,progress=True, )
    net = nn.DataParallel(net).to(device)

    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.module.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.6)
    test = test_celeb if DATANAME=='celeb' else test_ori
    # Train the model
    best_acc = 0.0; save_loss = [];
    for epoch in range(0, EPOCH+1):
        scheduler.step()
        net.train()
        total_loss = 0.0;
        for batch_step, input_data in tqdm(enumerate(trainset_loader,0),total=len(trainset_loader),smoothing=0.9):
            inputs, labels = input_data
            inputs, labels = inputs.to(device), labels.to(device).long()
            optimizer.zero_grad()
            output, _ = net(inputs)
            #print(output)
            if DATANAME != 'celeb':
                loss = criterion(output, labels)
            else:
                loss = .0
                for attribution in range(NUM_CLASSES//2):
                    loss += criterion(output[:, 2*attribution:2*attribution+2], labels[:, attribution])
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        ### loss save code #####
        total_loss = float(total_loss) / len(trainset_loader)
        save_loss.append(total_loss)
        np.savez(os.path.join(log_path, 'loss.npz'), loss=np.array(save_loss))
        if epoch % 50 == 0:
            train_acc = test(net, trainset_loader, NUM_CLASSES)
            print('Epoch', epoch, 'train accuracy:%.4f' % train_acc)
            torch.save(net.state_dict(), log_path+'/model_%.3d.pth' % (epoch))
        if epoch % 1 == 0:
            acc = test(net, testset_loader, NUM_CLASSES)
            print('Epoch', epoch, 'loss: %.4f' % total_loss, 'test accuracy:%.4f' % acc)
            if acc > best_acc and epoch >= 10:
                best_acc = acc
                torch.save(net.state_dict(), log_path+'/model_%.3d_%.4f.pth' % (epoch, best_acc))
    print('Finished Training')
    return net

def get_feature():
    _, testset_test = get_Data(True, DATANAME, BATCHSIZE)
    _, testset_feature = get_Data(False, DATANAME, BATCHSIZE)
    device = torch.device("cuda")
    layer_arch = 'vgg13_bn' if LAYERS=='13' else 'vgg16_bn'
    layer_cfg = 'B' if LAYERS=='13' else 'D'

    if not os.path.exists(pretrain_model):
        raise Exception("No such pretrained model!")
    net = vgg(arch=layer_arch,cfg=layer_cfg,num_class=NUM_CLASSES,device=device,pretrained=False,progress=True, )
    net = nn.DataParallel(net).to(device)
    test = test_celeb if DATANAME=='celeb' else test_ori
    acc = test(net, testset_test, NUM_CLASSES)
    print('test acc:', acc)
    #if not os.path.exists(acc_path):
    #    os.makedirs(acc_path)
    f = open(os.path.join(acc_path, layer_arch+'_'+DATANAME+'_test.txt'), 'w+')
    f.write('%s\n' % layer_arch)
    f.write('%s\n' % DATANAME)
    f.write('acc:%f\n' % acc)
    #if not os.path.exists(save_path):
    #    os.makedirs(save_path)
    testset = testset_test if DATANAME == 'voc_multi' else testset_feature
    all_feature = []
    for batch_step, input_data in tqdm(enumerate(testset,0),total=len(testset),smoothing=0.9):
        inputs, labels = input_data
        inputs, labels = inputs.cuda(), labels.cuda()
        net.eval()
        _, f_map = net(inputs)
        all_feature.append(f_map.cpu().numpy())
    all_feature = np.concatenate(all_feature, axis=0)
    f.write('sample num:%d' % (all_feature.shape[0]))
    f.close()
    print(all_feature.shape)
    np.savez_compressed(save_path+LAYERS+'_vgg_'+DATANAME+'_ori.npz', f_map=all_feature[...])
    print('Finished Getting Feature!')
    return net

def test_ori(net, testdata, n_cls):
    correct, total = .0, .0
    for inputs, labels in testdata:
        inputs, labels = inputs.cuda(), labels.cuda().long()
        net.eval()
        outputs, _ = net(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    return float(correct) / total

def test_celeb(net, testdata, n_cls):
    correct, total = .0, .0
    ATTRIBUTION_NUM = n_cls//2
    running_correct = np.zeros(ATTRIBUTION_NUM)
    for inputs, labels in tqdm(testdata):
        inputs, labels = inputs.cuda(), labels.cuda().long()
        net.eval()
        outputs, _ = net(inputs)
        out = outputs.data
        total += labels.size(0)
        for attribution in range(ATTRIBUTION_NUM):
            _, predicted = torch.max(out[:, 2*attribution:2*attribution+2], 1)
            correct = (predicted == labels[:, attribution]).sum().item()
            running_correct[attribution] += correct
    attr_acc = running_correct / float(total)
    return np.mean(attr_acc)

def vgg_ori_train():
    if IS_TRAIN:
        net = net_train()
    else:
        net = get_feature()

if __name__ == '__main__':
    vgg_ori_train()
--------------------------------------------------------------------------------
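
A note on the clustering step shared by the scripts above: `offline_spectral_cluster` in vgg_iccnn_train.py turns sampled feature maps into a channel-by-channel affinity matrix and hands it to `spectral_clustering` from SpectralClustering.py. The snippet below is a minimal, self-contained NumPy sketch of that affinity computation only, not repository code: the random `f_map` array stands in for real VGG feature maps (512 channels, 14x14 spatial maps, so 14*14 = 196 = F_MAP_SIZE), and the final clustering call is left commented out because it needs the repository module.

```python
import numpy as np

# Dummy stand-in for the feature maps that offline_spectral_cluster()
# collects from the network: (sample, channel, H, W).
f_map = np.random.rand(8, 512, 14, 14).astype(np.float32)
sample, channel, _, _ = f_map.shape
f_map = f_map.reshape((sample, channel, -1))  # flatten each 14x14 map to 196 positions

# Channel-by-channel covariance, averaged over samples (as in the script).
mean = np.mean(f_map, axis=0)
cov = np.mean(np.matmul(f_map - mean, np.transpose(f_map - mean, (0, 2, 1))), axis=0)

# Normalize to a Pearson-style correlation, then shift by +1 so entries
# land in roughly [0, 2] and can serve as a non-negative affinity matrix.
diag = np.diag(cov).reshape(channel, -1)
correlation = cov / (np.sqrt(np.matmul(diag, np.transpose(diag, (1, 0)))) + 1e-5) + 1

print(correlation.shape)  # (512, 512) channel-affinity matrix
# With the repository module on the path:
# from SpectralClustering import spectral_clustering
# ground_true, loss_mask_num, loss_mask_den = spectral_clustering(correlation, n_cluster=5)
```

The `n_cluster=5` in the commented call mirrors `center_num = 5` in vgg_iccnn_train.py; the masks returned by `spectral_clustering` are what `Cluster_loss.update` consumes on each training epoch.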