├── src
│   ├── network
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── base.py
│   │   ├── vgg.py
│   │   ├── resnet.py
│   │   ├── densenet.py
│   │   └── inception.py
│   └── representation
│       ├── __init__.py
│       ├── GAvP.py
│       ├── INVCOV.py
│       ├── COV.py
│       └── SICE.py
├── isice.png
├── model_init.py
├── isice.yml
├── functions.py
├── train_iSICE_model.sh
├── LICENCE
├── imagepreprocess.py
├── README.md
└── main.py
/src/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | 3 | 4 | -------------------------------------------------------------------------------- /isice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csiro-robotics/iSICE/HEAD/isice.png -------------------------------------------------------------------------------- /src/representation/__init__.py: -------------------------------------------------------------------------------- 1 | from .GAvP import GAvP 2 | from .COV import COV 3 | from .INVCOV import INVCOV 4 | from .SICE import SICE
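Usage note: the representation modules exported above (GAvP, COV, INVCOV, SICE) are wired into a backbone by model_init.get_model, which takes a dict whose 'function' entry is the representation class and whose remaining keys are forwarded to it as keyword arguments (Newmodel fills in input_dim from the backbone automatically). A minimal sketch of that wiring; the hyperparameter values here are illustrative, not the paper's settings:

from model_init import get_model
from src.representation import SICE

representation = {'function': SICE,      # which representation module to use
                  'iterNum': 5,          # Newton-Schulz iterations (illustrative)
                  'is_sqrt': True,
                  'is_vec': True,
                  'dimension_reduction': 256,
                  'sparsity_val': 0.01,  # illustrative sparsity strength
                  'sice_lrate': 5.0}     # illustrative SICE step size
model = get_model('resnet50', representation, num_classes=100,
                  freezed_layer=0, pretrained=True)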
-------------------------------------------------------------------------------- /src/representation/GAvP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class GAvP(nn.Module): 5 | """Global Average pooling 6 | Widely used in ResNet, Inception, DenseNet, etc. 7 | """ 8 | def __init__(self, input_dim=2048, dimension_reduction=None): 9 | super(GAvP, self).__init__() 10 | self.dr = dimension_reduction 11 | if self.dr is not None: 12 | self.conv_dr_block = nn.Sequential( 13 | nn.Conv2d(input_dim, self.dr, kernel_size=1, stride=1, bias=False), 14 | nn.BatchNorm2d(self.dr), 15 | nn.ReLU(inplace=True) 16 | ) 17 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 18 | self.output_dim = self.dr if self.dr else input_dim 19 | self._init_weight() 20 | 21 | def _init_weight(self): 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 25 | elif isinstance(m, nn.BatchNorm2d): 26 | nn.init.constant_(m.weight, 1) 27 | nn.init.constant_(m.bias, 0) 28 | 29 | def forward(self, x): 30 | if self.dr is not None: 31 | x = self.conv_dr_block(x) 32 | x = self.avgpool(x) 33 | return x 34 | -------------------------------------------------------------------------------- /src/network/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['AlexNet', 'alexnet'] 6 | 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | } 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=1000): 16 | super(AlexNet, self).__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.classifier = nn.Sequential( 33 | nn.Dropout(), 34 | nn.Linear(256 * 6 * 6, 4096), 35 | nn.ReLU(inplace=True), 36 | nn.Dropout(), 37 | nn.Linear(4096, 4096), 38 | nn.ReLU(inplace=True), 39 | nn.Linear(4096, num_classes), 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.features(x) 44 | x = x.view(x.size(0), 256 * 6 * 6) 45 | x = self.classifier(x) 46 | return x 47 | 48 | 49 | def alexnet(pretrained=False, **kwargs): 50 | r"""AlexNet model architecture from the 51 | `"One weird trick..." `_ paper. 
52 | 53 | Args: 54 | pretrained (bool): If True, returns a model pre-trained on ImageNet 55 | """ 56 | model = AlexNet(**kwargs) 57 | if pretrained: 58 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 59 | return model 60 |
-------------------------------------------------------------------------------- /model_init.py: -------------------------------------------------------------------------------- 1 | from src.network import * 2 | import torch 3 | import torch.nn as nn 4 | import warnings 5 | __all__ = ['Newmodel', 'get_model'] 6 | 7 | class Newmodel(Basemodel): 8 | """Replace the image representation method and the classifier. 9 | 10 | Args: 11 | modeltype: model architecture 12 | representation: image representation method 13 | num_classes: the number of classes 14 | freezed_layer: freeze the first freezed_layer children of the feature extractor 15 | pretrained: whether to use pretrained weights 16 | """ 17 | def __init__(self, modeltype, representation, num_classes, freezed_layer, pretrained=False): 18 | super(Newmodel, self).__init__(modeltype, pretrained) 19 | if representation is not None: 20 | representation_method = representation['function'] 21 | representation.pop('function') 22 | representation_args = representation 23 | representation_args['input_dim'] = self.representation_dim 24 | self.representation = representation_method(**representation_args) 25 | fc_input_dim = self.representation.output_dim 26 | if not pretrained: 27 | if isinstance(self.classifier, nn.Sequential): # for alexnet and vgg* 28 | conv6_index = 0 29 | for m in self.classifier.children(): 30 | if isinstance(m, nn.Linear): 31 | output_dim = m.weight.size(0) # 4096 32 | self.classifier[conv6_index] = nn.Linear(fc_input_dim, output_dim) #conv6 33 | break 34 | conv6_index += 1 35 | self.classifier[-1] = nn.Linear(output_dim, num_classes) 36 | else: 37 | self.classifier = nn.Linear(fc_input_dim, num_classes) 38 | else: 39 | self.classifier = nn.Linear(fc_input_dim, num_classes) 40 | else: 41 | if modeltype.startswith('alexnet') or modeltype.startswith('vgg'): 42 | output_dim = self.classifier[-1].weight.size(1) # 4096 43 | self.classifier[-1] = nn.Linear(output_dim, num_classes) 44 | else: 45 | self.classifier = nn.Linear(self.representation_dim, num_classes) 46 | index_before_freezed_layer = 0 47 | if freezed_layer: 48 | for m in self.features.children(): 49 | if index_before_freezed_layer < freezed_layer: 50 | m = self._freeze(m) 51 | index_before_freezed_layer += 1 52 | 53 | def _freeze(self, modules): 54 | for param in modules.parameters(): 55 | param.requires_grad = False 56 | return modules 57 | 58 | 59 | def get_model(modeltype, representation, num_classes, freezed_layer, pretrained=False): 60 | _model = Newmodel(modeltype, representation, num_classes, freezed_layer, pretrained=pretrained) 61 | return _model 62 |
-------------------------------------------------------------------------------- /isice.yml: -------------------------------------------------------------------------------- 1 | name: isice 2 | channels: 3 | - anaconda 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - blas=1.0=mkl 10 | - brotli=1.0.9=h166bdaf_7 11 | - brotli-bin=1.0.9=h166bdaf_7 12 | - ca-certificates=2022.9.24=ha878542_0 13 | - certifi=2022.9.24=pyhd8ed1ab_0 14 | - code-server=3.10.2=ha770c72_0 15 | - cycler=0.11.0=pyhd8ed1ab_0 16 | - dbus=1.13.18=hb2f20db_0 17 | - expat=2.4.8=h27087fc_0 18 | - fontconfig=2.14.0=h8e229c2_0 19 | - fonttools=4.25.0=pyhd3eb1b0_0 20 | -
freetype=2.10.4=h0708190_1 21 | - gettext=0.19.8.1=h0b5b191_1005 22 | - glib=2.68.4=h9c3ff4c_0 23 | - glib-tools=2.68.4=h9c3ff4c_0 24 | - gst-plugins-base=1.18.4=h29181c9_0 25 | - gstreamer=1.18.5=h76c114f_0 26 | - icu=68.2=h9c3ff4c_0 27 | - intel-openmp=2021.4.0=h06a4308_3561 28 | - jpeg=9e=h166bdaf_1 29 | - keyutils=1.6.1=h166bdaf_0 30 | - kiwisolver=1.4.2=py38h295c915_0 31 | - krb5=1.19.3=h3790be6_0 32 | - ld_impl_linux-64=2.38=h1181459_1 33 | - libbrotlicommon=1.0.9=h166bdaf_7 34 | - libbrotlidec=1.0.9=h166bdaf_7 35 | - libbrotlienc=1.0.9=h166bdaf_7 36 | - libclang=11.1.0=default_ha53f305_1 37 | - libedit=3.1.20191231=he28a2e2_2 38 | - libevent=2.1.10=h9b69904_4 39 | - libffi=3.3=he6710b0_2 40 | - libgcc-ng=12.1.0=h8d9b700_16 41 | - libgfortran-ng=7.5.0=ha8ba4b0_17 42 | - libgfortran4=7.5.0=ha8ba4b0_17 43 | - libglib=2.68.4=h3e27bee_0 44 | - libgomp=12.1.0=h8d9b700_16 45 | - libiconv=1.16=h516909a_0 46 | - libllvm11=11.1.0=hf817b99_3 47 | - libpng=1.6.37=h21135ba_2 48 | - libpq=13.5=hd57d9b9_1 49 | - libsodium=1.0.18=h36c2ea0_1 50 | - libstdcxx-ng=11.2.0=h1234567_1 51 | - libtiff=4.0.10=hc3755c2_1005 52 | - libuuid=2.32.1=h7f98852_1000 53 | - libuv=1.42.0=h7f98852_0 54 | - libxcb=1.13=h7f98852_1004 55 | - libxkbcommon=1.0.3=he3ba5ed_0 56 | - libxml2=2.9.12=h72842e0_0 57 | - libzlib=1.2.12=h166bdaf_2 58 | - lz4-c=1.9.3=h9c3ff4c_1 59 | - matplotlib=3.5.2=py38h578d9bd_0 60 | - matplotlib-base=3.5.2=py38h826bfd8_0 61 | - mkl=2021.4.0=h06a4308_640 62 | - mkl-service=2.4.0=py38h7f8727e_0 63 | - mkl_fft=1.3.1=py38hd3c417c_0 64 | - mkl_random=1.2.2=py38h51133e4_0 65 | - munkres=1.1.4=pyh9f0ad1d_0 66 | - mysql-common=8.0.25=ha770c72_0 67 | - mysql-libs=8.0.25=h935591d_0 68 | - ncurses=6.3=h5eee18b_3 69 | - nodejs=12.22.6=h8b53aa1_0 70 | - nspr=4.32=h9c3ff4c_1 71 | - nss=3.77=h2350873_0 72 | - numpy-base=1.22.3=py38hf524024_0 73 | - olefile=0.46=pyh9f0ad1d_1 74 | - openssl=1.1.1s=h166bdaf_0 75 | - packaging=21.3=pyhd8ed1ab_0 76 | - pcre=8.45=h9c3ff4c_0 77 | - pip=22.1.2=py38h06a4308_0 78 | - pthread-stubs=0.4=h36c2ea0_1001 79 | - pyparsing=3.0.9=pyhd8ed1ab_0 80 | - pyqt=5.12.3=py38ha8c2ead_4 81 | - python=3.8.13=h12debd9_0 82 | - python-dateutil=2.8.2=pyhd8ed1ab_0 83 | - python_abi=3.8=2_cp38 84 | - qt=5.12.9=hda022c4_4 85 | - readline=8.1.2=h7f8727e_1 86 | - scipy=1.7.3=py38hc147768_0 87 | - setuptools=61.2.0=py38h06a4308_0 88 | - sip=4.19.13=py38h295c915_0 89 | - six=1.16.0=pyh6c4a22f_0 90 | - sqlite=3.38.5=hc218d9a_0 91 | - time=1.8=h516909a_0 92 | - tk=8.6.12=h1ccaba5_0 93 | - tornado=5.1.1=py38h1e0a361_2 94 | - vscode-jupyter=2021.3.1=py38h38764e9_1 95 | - vscode-python=2021.4.765268190=py38h4d516c6_1 96 | - wheel=0.37.1=pyhd3eb1b0_0 97 | - xorg-libxau=1.0.9=h7f98852_0 98 | - xorg-libxdmcp=1.1.3=h7f98852_0 99 | - xz=5.2.5=h7f8727e_1 100 | - zeromq=4.3.4=h9c3ff4c_1 101 | - zlib=1.2.12=h7f8727e_2 102 | - zstd=1.4.9=ha95c52a_0 103 | - pip: 104 | - charset-normalizer==2.1.0 105 | - idna==3.3 106 | - numpy==1.23.1 107 | - pillow==9.2.0 108 | - protobuf==3.20.1 109 | - pyqt5-sip==4.19.18 110 | - pyqtchart==5.12 111 | - pyqtwebengine==5.12.1 112 | - requests==2.28.1 113 | - tensorboardx==2.5.1 114 | - timer==0.2.2 115 | - timm==0.6.5 116 | - torch==1.12.0+cu113 117 | - torchaudio==0.12.0+cu113 118 | - torchvision==0.13.0+cu113 119 | - typing-extensions==4.3.0 120 | - urllib3==1.26.11 121 | prefix: /home/rah025/.conda/envs/isice 122 | -------------------------------------------------------------------------------- /functions.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib as mpl 3 | if os.environ.get('DISPLAY','') == '': 4 | print('no display found. Using non-interactive Agg backend') 5 | mpl.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import scipy.io as sio 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | from torchviz import make_dot # used by vizNet below; assumes the torchviz package is installed 12 | 13 | class stats: 14 | def __init__(self, path, start_epoch): 15 | if start_epoch != 0: 16 | stats_ = sio.loadmat(os.path.join(path,'stats.mat')) 17 | data = stats_['data'] 18 | content = data[0,0] 19 | self.trainObj = content['trainObj'][:,:start_epoch].squeeze().tolist() 20 | self.trainTop1 = content['trainTop1'][:,:start_epoch].squeeze().tolist() 21 | self.trainTop5 = content['trainTop5'][:,:start_epoch].squeeze().tolist() 22 | self.valObj = content['valObj'][:,:start_epoch].squeeze().tolist() 23 | self.valTop1 = content['valTop1'][:,:start_epoch].squeeze().tolist() 24 | self.valTop5 = content['valTop5'][:,:start_epoch].squeeze().tolist() 25 | if start_epoch == 1: 26 | self.trainObj = [self.trainObj] 27 | self.trainTop1 = [self.trainTop1] 28 | self.trainTop5 = [self.trainTop5] 29 | self.valObj = [self.valObj] 30 | self.valTop1 = [self.valTop1] 31 | self.valTop5 = [self.valTop5] 32 | else: 33 | self.trainObj = [] 34 | self.trainTop1 = [] 35 | self.trainTop5 = [] 36 | self.valObj = [] 37 | self.valTop1 = [] 38 | self.valTop5 = [] 39 | def _update(self, trainObj, top1, top5, valObj, prec1, prec5): 40 | self.trainObj.append(trainObj) 41 | self.trainTop1.append(top1.cpu().numpy()) 42 | self.trainTop5.append(top5.cpu().numpy()) 43 | self.valObj.append(valObj) 44 | self.valTop1.append(prec1.cpu().numpy()) 45 | self.valTop5.append(prec5.cpu().numpy()) 46 | 47 | 48 | def vizNet(model, path): 49 | model.eval() 50 | x = torch.randn(10,3,224,224) 51 | y = model(x) 52 | g = make_dot(y) 53 | g.render(os.path.join(path,'graph'), view=False) 54 | 55 | def plot_curve(stats, path, iserr): 56 | trainObj = np.array(stats.trainObj) 57 | valObj = np.array(stats.valObj) 58 | if iserr: 59 | trainTop1 = 100 - np.array(stats.trainTop1) 60 | trainTop5 = 100 - np.array(stats.trainTop5) 61 | valTop1 = 100 - np.array(stats.valTop1) 62 | valTop5 = 100 - np.array(stats.valTop5) 63 | titleName = 'error' 64 | else: 65 | trainTop1 = np.array(stats.trainTop1) 66 | trainTop5 = np.array(stats.trainTop5) 67 | valTop1 = np.array(stats.valTop1) 68 | valTop5 = np.array(stats.valTop5) 69 | titleName = 'accuracy' 70 | epoch = len(trainObj) 71 | figure = plt.figure() 72 | obj = plt.subplot(1,3,1) 73 | obj.plot(range(1,epoch+1),trainObj,'o-',label = 'train') 74 | obj.plot(range(1,epoch+1),valObj,'o-',label = 'val') 75 | plt.xlabel('epoch') 76 | plt.title('objective') 77 | handles, labels = obj.get_legend_handles_labels() 78 | obj.legend(handles[::-1], labels[::-1]) 79 | top1 = plt.subplot(1,3,2) 80 | top1.plot(range(1,epoch+1),trainTop1,'o-',label = 'train') 81 | top1.plot(range(1,epoch+1),valTop1,'o-',label = 'val') 82 | plt.title('top1'+titleName) 83 | plt.xlabel('epoch') 84 | handles, labels = top1.get_legend_handles_labels() 85 | top1.legend(handles[::-1], labels[::-1]) 86 | top5 = plt.subplot(1,3,3) 87 | top5.plot(range(1,epoch+1),trainTop5,'o-',label = 'train') 88 | top5.plot(range(1,epoch+1),valTop5,'o-',label = 'val') 89 | plt.title('top5'+titleName) 90 | plt.xlabel('epoch') 91 | handles, labels = top5.get_legend_handles_labels() 92 | top5.legend(handles[::-1], labels[::-1]) 93 | filename = os.path.join(path, 'net-train.pdf') 94 | figure.savefig(filename, bbox_inches='tight') 95 | plt.close() 96 | 97 | def decode_params(input_params): 98 | params = input_params[0] 99 | out_params = [] 100 | _start=0 101 | _end=0 102 | for i in range(len(params)): 103 | if params[i] == ',': 104 | out_params.append(float(params[_start:_end])) 105 | _start=_end+1 106 | _end+=1 107 | out_params.append(float(params[_start:_end])) 108 | return out_params 109 |
-------------------------------------------------------------------------------- /train_iSICE_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # load anaconda and activate the environment 4 | eval "$(command conda 'shell.bash' 'hook' 2> /dev/null)" 5 | conda activate your_environment_name 6 | 7 | # please provide the path of the source code 8 | cd /path/to/your/files 9 | 10 | set -e 11 | :< transform method of '{}' does not exist!".format(dataset)) 141 | return train_transforms, val_transforms, evaluate_transforms 142 |
-------------------------------------------------------------------------------- /src/network/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import convnext_tiny, swin_t 4 | import warnings as warn 5 | from .alexnet import * 6 | from .vgg import * 7 | from .resnet import * 8 | from .inception import * 9 | from .densenet import * 10 | 11 | def get_basemodel(modeltype, pretrained=False): 12 | modeltype = globals()[modeltype] 13 | if not pretrained: 14 | warn.warn('You are using a model with randomly initialized weights!') 15 | return modeltype(pretrained=pretrained) 16 | 17 | class Basemodel(nn.Module): 18 | """Load a backbone model and reconstruct it into three parts: 19 | 1) feature extractor 20 | 2) global image representation 21 | 3) classifier 22 | """ 23 | def __init__(self, modeltype, pretrained=False): 24 | super(Basemodel, self).__init__() 25 | # ConvNeXt and Swin backbones are loaded directly from torchvision below 26 | if not modeltype.startswith('convnext') and not modeltype.startswith('swin'): 27 | basemodel = get_basemodel(modeltype, pretrained) 28 | self.pretrained = pretrained 29 | if modeltype.startswith('alexnet'): 30 | basemodel = self._reconstruct_alexnet(basemodel) 31 | if modeltype.startswith('vgg'): 32 | basemodel = self._reconstruct_vgg(basemodel) 33 | if modeltype.startswith('resnet'): 34 | basemodel = self._reconstruct_resnet(basemodel) 35 | if modeltype.startswith('inception'): 36 | basemodel = self._reconstruct_inception(basemodel) 37 | if modeltype.startswith('densenet'): 38 | basemodel = self._reconstruct_densenet(basemodel) 39 | if modeltype.startswith('mpncovresnet'): 40 | basemodel = self._reconstruct_mpncovresnet(basemodel) 41 | if modeltype.startswith('mpncovvgg'): 42 | basemodel = self._reconstruct_mpncov_vgg(basemodel) 43 | if modeltype.startswith('convnext'): 44 | basemodel = nn.Module() 45 | model = convnext_tiny(pretrained=pretrained) 46 | basemodel.features = model.features 47 | basemodel.representation = model.avgpool 48 | basemodel.classifier = model.classifier 49 | basemodel.representation_dim = model.classifier[2].weight.size(1) 50 | if modeltype.startswith('swin'): 51 | basemodel = nn.Module() 52 | model = swin_t(weights='DEFAULT') 53 | basemodel.features = model.features 54 | basemodel.representation = model.avgpool 55 | basemodel.classifier = model.head 56 | basemodel.representation_dim = model.head.weight.size(1) 57 | self.features = basemodel.features 58 | self.representation =
basemodel.representation 59 | self.classifier = basemodel.classifier 60 | self.representation_dim = basemodel.representation_dim 61 | def _reconstruct_alexnet(self, basemodel): 62 | model = nn.Module() 63 | model.features = basemodel.features[:-1] 64 | model.representation = basemodel.features[-1] 65 | if self.pretrained: 66 | model.classifier = basemodel.classifier[-1] 67 | else: 68 | model.classifier = basemodel.classifier 69 | model.representation_dim = 256 70 | return model 71 | def _reconstruct_vgg(self, basemodel): 72 | model = nn.Module() 73 | model.features = basemodel.features[:-1] 74 | model.representation = basemodel.features[-1] 75 | if self.pretrained: 76 | model.classifier = basemodel.classifier[-1] 77 | else: 78 | model.classifier = basemodel.classifier 79 | model.representation_dim = 512 80 | return model 81 | def _reconstruct_resnet(self, basemodel): 82 | model = nn.Module() 83 | model.features = nn.Sequential(*list(basemodel.children())[:-2]) 84 | model.representation = basemodel.avgpool 85 | model.classifier = basemodel.fc 86 | model.representation_dim=basemodel.fc.weight.size(1) 87 | return model 88 | def _reconstruct_inception(self, basemodel): 89 | model = nn.Module() 90 | model.features = nn.Sequential(basemodel.Conv2d_1a_3x3, 91 | basemodel.Conv2d_2a_3x3, 92 | basemodel.Conv2d_2b_3x3, 93 | nn.MaxPool2d(kernel_size=3, stride=2), 94 | basemodel.Conv2d_3b_1x1, 95 | basemodel.Conv2d_4a_3x3, 96 | nn.MaxPool2d(kernel_size=3, stride=2), 97 | basemodel.Mixed_5b, 98 | basemodel.Mixed_5c, 99 | basemodel.Mixed_5d, 100 | basemodel.Mixed_6a, 101 | basemodel.Mixed_6b, 102 | basemodel.Mixed_6c, 103 | basemodel.Mixed_6d, 104 | basemodel.Mixed_6e, 105 | basemodel.Mixed_7a, 106 | basemodel.Mixed_7b, 107 | basemodel.Mixed_7c) 108 | model.representation = nn.AdaptiveAvgPool2d((1, 1)) 109 | model.classifier = basemodel.fc 110 | model.representation_dim=basemodel.fc.weight.size(1) 111 | return model 112 | def _reconstruct_densenet(self, basemodel): 113 | model = nn.Module() 114 | model.features = basemodel.features 115 | model.features.add_module('last_relu', nn.ReLU(inplace=True)) 116 | model.representation = nn.AdaptiveAvgPool2d((1, 1)) 117 | model.classifier = basemodel.classifier 118 | model.representation_dim=basemodel.classifier.weight.size(1) 119 | return model 120 | def _reconstruct_mpncovresnet(self, basemodel): 121 | model = nn.Module() 122 | if self.pretrained: 123 | model.features = nn.Sequential(*list(basemodel.children())[:-1]) 124 | model.representation_dim=basemodel.layer_reduce.weight.size(0) 125 | else: 126 | model.features = nn.Sequential(*list(basemodel.children())[:-4]) 127 | model.representation_dim=basemodel.layer_reduce.weight.size(1) 128 | model.representation = None 129 | model.classifier = basemodel.fc 130 | return model 131 | 132 | def _reconstruct_mpncov_vgg(self, basemodel): 133 | model = nn.Module() 134 | model.features = basemodel.features 135 | model.representation = basemodel.representation 136 | model.classifier = basemodel.classifier 137 | #model.representation_dim = model.representation.output_dim 138 | model.representation_dim = 512 139 | return model 140 | 141 | def forward(self, x): 142 | x = self.features(x) 143 | x = self.representation(x) 144 | x = x.view(x.size(0), -1) 145 | out = self.classifier(x) 146 | return out 147 | -------------------------------------------------------------------------------- /src/network/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import 
torch.utils.model_zoo as model_zoo 3 | import math 4 | 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class VGG(nn.Module): 25 | 26 | def __init__(self, features, num_classes=1000, init_weights=True): 27 | super(VGG, self).__init__() 28 | self.features = features 29 | self.classifier = nn.Sequential( 30 | nn.Linear(512 * 7 * 7, 4096), 31 | nn.ReLU(True), 32 | nn.Dropout(), 33 | nn.Linear(4096, 4096), 34 | nn.ReLU(True), 35 | nn.Dropout(), 36 | nn.Linear(4096, num_classes), 37 | ) 38 | if init_weights: 39 | self._initialize_weights() 40 | 41 | def forward(self, x): 42 | x = self.features(x) 43 | x = x.view(x.size(0), -1) 44 | x = self.classifier(x) 45 | return x 46 | 47 | def _initialize_weights(self): 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 51 | if m.bias is not None: 52 | nn.init.constant_(m.bias, 0) 53 | elif isinstance(m, nn.BatchNorm2d): 54 | nn.init.constant_(m.weight, 1) 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, nn.Linear): 57 | nn.init.normal_(m.weight, 0, 0.01) 58 | nn.init.constant_(m.bias, 0) 59 | 60 | 61 | def make_layers(cfg, batch_norm=False): 62 | layers = [] 63 | in_channels = 3 64 | for v in cfg: 65 | if v == 'M': 66 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 67 | else: 68 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 69 | if batch_norm: 70 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 71 | else: 72 | layers += [conv2d, nn.ReLU(inplace=True)] 73 | in_channels = v 74 | return nn.Sequential(*layers) 75 | 76 | 77 | cfg = { 78 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 79 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 80 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 81 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 82 | } 83 | 84 | 85 | def vgg11(pretrained=False, **kwargs): 86 | """VGG 11-layer model (configuration "A") 87 | 88 | Args: 89 | pretrained (bool): If True, returns a model pre-trained on ImageNet 90 | """ 91 | if pretrained: 92 | kwargs['init_weights'] = False 93 | model = VGG(make_layers(cfg['A']), **kwargs) 94 | if pretrained: 95 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 96 | return model 97 | 98 | 99 | def vgg11_bn(pretrained=False, **kwargs): 100 | """VGG 11-layer model (configuration "A") with batch normalization 101 | 102 | Args: 103 | pretrained (bool): If True, returns a model pre-trained on ImageNet 104 | """ 105 | if pretrained: 106 | kwargs['init_weights'] = False 107 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 108 | if pretrained: 109 | 
model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 110 | return model 111 | 112 | 113 | def vgg13(pretrained=False, **kwargs): 114 | """VGG 13-layer model (configuration "B") 115 | 116 | Args: 117 | pretrained (bool): If True, returns a model pre-trained on ImageNet 118 | """ 119 | if pretrained: 120 | kwargs['init_weights'] = False 121 | model = VGG(make_layers(cfg['B']), **kwargs) 122 | if pretrained: 123 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 124 | return model 125 | 126 | 127 | def vgg13_bn(pretrained=False, **kwargs): 128 | """VGG 13-layer model (configuration "B") with batch normalization 129 | 130 | Args: 131 | pretrained (bool): If True, returns a model pre-trained on ImageNet 132 | """ 133 | if pretrained: 134 | kwargs['init_weights'] = False 135 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 136 | if pretrained: 137 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 138 | return model 139 | 140 | 141 | def vgg16(pretrained=False, **kwargs): 142 | """VGG 16-layer model (configuration "D") 143 | 144 | Args: 145 | pretrained (bool): If True, returns a model pre-trained on ImageNet 146 | """ 147 | if pretrained: 148 | kwargs['init_weights'] = False 149 | model = VGG(make_layers(cfg['D']), **kwargs) 150 | if pretrained: 151 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 152 | return model 153 | 154 | 155 | def vgg16_bn(pretrained=False, **kwargs): 156 | """VGG 16-layer model (configuration "D") with batch normalization 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | """ 161 | if pretrained: 162 | kwargs['init_weights'] = False 163 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 166 | return model 167 | 168 | 169 | def vgg19(pretrained=False, **kwargs): 170 | """VGG 19-layer model (configuration "E") 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | if pretrained: 176 | kwargs['init_weights'] = False 177 | model = VGG(make_layers(cfg['E']), **kwargs) 178 | if pretrained: 179 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 180 | return model 181 | 182 | 183 | def vgg19_bn(pretrained=False, **kwargs): 184 | """VGG 19-layer model (configuration 'E') with batch normalization 185 | 186 | Args: 187 | pretrained (bool): If True, returns a model pre-trained on ImageNet 188 | """ 189 | if pretrained: 190 | kwargs['init_weights'] = False 191 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 192 | if pretrained: 193 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 194 | return model 195 | -------------------------------------------------------------------------------- /src/representation/INVCOV.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | 5 | class INVCOV(nn.Module): 6 | def __init__(self, iterNum=3, is_sqrt=True, is_vec=True, input_dim=2048, dimension_reduction=None): 7 | 8 | super(INVCOV, self).__init__() 9 | self.iterNum=iterNum 10 | self.is_sqrt = is_sqrt 11 | self.is_vec = is_vec 12 | self.dr = dimension_reduction 13 | if self.dr is not None: 14 | self.conv_dr_block = nn.Sequential( 15 | nn.Conv2d(input_dim, self.dr, kernel_size=1, stride=1, bias=False), 16 | nn.BatchNorm2d(self.dr), 17 | nn.ReLU(inplace=True) 18 
| ) 19 | output_dim = self.dr if self.dr else input_dim 20 | if self.is_vec: 21 | self.output_dim = int(output_dim*(output_dim+1)/2) 22 | else: 23 | self.output_dim = int(output_dim*output_dim) 24 | self._init_weight() 25 | 26 | def _init_weight(self): 27 | for m in self.modules(): 28 | if isinstance(m, nn.Conv2d): 29 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 30 | elif isinstance(m, nn.BatchNorm2d): 31 | nn.init.constant_(m.weight, 1) 32 | nn.init.constant_(m.bias, 0) 33 | 34 | def _cov_pool(self, x): 35 | return Covpool.apply(x) 36 | def _sqrtm(self, x): 37 | return Sqrtm.apply(x, self.iterNum) 38 | def _triuvec(self, x): 39 | return Triuvec.apply(x) 40 | 41 | def forward(self, x): 42 | if self.dr is not None: 43 | x = self.conv_dr_block(x) 44 | x = self._cov_pool(x) 45 | if self.is_sqrt: 46 | I3 = 1e-10+1e-9* torch.diag(torch.rand(x.shape[1], device=x.device)).view(1, x.shape[1], x.shape[1]).repeat(x.shape[0], 1, 1).type(x.dtype) 47 | x = self._sqrtm(x+I3) 48 | x = x.bmm(x) #inverse 49 | if self.is_vec: 50 | x = self._triuvec(x) 51 | return x 52 | 53 | 54 | class Covpool(Function): 55 | @staticmethod 56 | def forward(ctx, input): 57 | x = input 58 | batchSize = x.data.shape[0] 59 | dim = x.data.shape[1] 60 | h = x.data.shape[2] 61 | w = x.data.shape[3] 62 | M = h*w 63 | x = x.reshape(batchSize,dim,M) 64 | I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device) 65 | I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype) 66 | y = x.bmm(I_hat).bmm(x.transpose(1,2)) 67 | ctx.save_for_backward(input,I_hat) 68 | return y 69 | @staticmethod 70 | def backward(ctx, grad_output): 71 | input,I_hat = ctx.saved_tensors 72 | x = input 73 | batchSize = x.data.shape[0] 74 | dim = x.data.shape[1] 75 | h = x.data.shape[2] 76 | w = x.data.shape[3] 77 | M = h*w 78 | x = x.reshape(batchSize,dim,M) 79 | grad_input = grad_output + grad_output.transpose(1,2) 80 | grad_input = grad_input.bmm(x).bmm(I_hat) 81 | grad_input = grad_input.reshape(batchSize,dim,h,w) 82 | return grad_input 83 | 84 | class Sqrtm(Function): 85 | @staticmethod 86 | def forward(ctx, input, iterN): 87 | x = input 88 | batchSize = x.data.shape[0] 89 | dim = x.data.shape[1] 90 | dtype = x.dtype 91 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 92 | normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1) 93 | A = x.div(normA.view(batchSize,1,1).expand_as(x)) 94 | Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device).type(dtype) 95 | Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1).type(dtype) 96 | if iterN < 2: 97 | ZY = 0.5*(I3 - A) 98 | YZY = A.bmm(ZY) 99 | else: 100 | ZY = 0.5*(I3 - A) 101 | Y[:,0,:,:] = A.bmm(ZY) 102 | Z[:,0,:,:] = ZY 103 | for i in range(1, iterN-1): 104 | ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:])) 105 | Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY) 106 | Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:]) 107 | ZYZ = 0.5 * (I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:])).bmm(Z[:,iterN-2,:,:]) 108 | y = ZYZ * torch.pow(normA,-0.5).view(batchSize, 1, 1).expand_as(x) 109 | ctx.save_for_backward(input, A, ZYZ, normA, Y, Z) 110 | ctx.iterN = iterN 111 | return y 112 | @staticmethod 113 | def backward(ctx, grad_output): 114 | input, A, ZY, normA, Y, Z = ctx.saved_tensors 115 | iterN = ctx.iterN 116 | x = input 117 | batchSize = x.data.shape[0] 118 | dim = x.data.shape[1] 119 | dtype = x.dtype 120 | der_postCom = grad_output*torch.pow(normA, -0.5).view(batchSize, 1, 
1).expand_as(x) 121 | der_postComAux = -0.5*torch.pow(normA, -1.5)*((grad_output*ZY).sum(dim=1).sum(dim=1)) 122 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 123 | if iterN < 2: 124 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_postCom)) 125 | else: 126 | dldZ = 0.5*((I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])).bmm(der_postCom) - 127 | der_postCom.bmm(Z[:,iterN-2,:,:]).bmm(Y[:,iterN-2,:,:])) 128 | dldY = -0.5*Z[:,iterN-2,:,:].bmm(der_postCom).bmm(Z[:,iterN-2,:,:]) 129 | for i in range(iterN-3, -1, -1): 130 | YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:]) 131 | ZY = Z[:,i,:,:].bmm(Y[:,i,:,:]) 132 | dldY_ = 0.5*(dldY.bmm(YZ) - 133 | Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - 134 | ZY.bmm(dldY)) 135 | dldZ_ = 0.5*(YZ.bmm(dldZ) - 136 | Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) - 137 | dldZ.bmm(ZY)) 138 | dldY = dldY_ 139 | dldZ = dldZ_ 140 | der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY)) 141 | der_NSiter = der_NSiter.transpose(1, 2) 142 | grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x)) 143 | grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1) 144 | for i in range(batchSize): 145 | grad_input[i,:,:] += (der_postComAux[i] \ 146 | - grad_aux[i] / (normA[i] * normA[i])) \ 147 | *torch.ones(dim,device = x.device).diag().type(dtype) 148 | return grad_input, None 149 | 150 | class Triuvec(Function): 151 | @staticmethod 152 | def forward(ctx, input): 153 | x = input 154 | batchSize = x.data.shape[0] 155 | dim = x.data.shape[1] 156 | dtype = x.dtype 157 | x = x.reshape(batchSize, dim*dim) 158 | I = torch.ones(dim,dim).triu().reshape(dim*dim) 159 | index = I.nonzero() 160 | y = torch.zeros(batchSize,int(dim*(dim+1)/2),device = x.device).type(dtype) 161 | y = x[:,index] 162 | ctx.save_for_backward(input,index) 163 | return y 164 | @staticmethod 165 | def backward(ctx, grad_output): 166 | input,index = ctx.saved_tensors 167 | x = input 168 | batchSize = x.data.shape[0] 169 | dim = x.data.shape[1] 170 | dtype = x.dtype 171 | grad_input = torch.zeros(batchSize,dim*dim,device = x.device,requires_grad=False).type(dtype) 172 | grad_input[:,index] = grad_output 173 | grad_input = grad_input.reshape(batchSize,dim,dim) 174 | return grad_input 175 | 176 | def CovpoolLayer(var): 177 | return Covpool.apply(var) 178 | 179 | def SqrtmLayer(var, iterN): 180 | return Sqrtm.apply(var, iterN) 181 | 182 | def TriuvecLayer(var): 183 | return Triuvec.apply(var) 184 | -------------------------------------------------------------------------------- /src/network/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return 
nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | identity = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | identity = self.downsample(x) 54 | 55 | out += identity 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = conv1x1(inplanes, planes) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = conv3x3(planes, planes, stride) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = conv1x1(planes, planes * self.expansion) 71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | identity = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | identity = self.downsample(x) 92 | 93 | out += identity 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 102 | super(ResNet, self).__init__() 103 | self.inplanes = 64 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | # Zero-initialize the last BN in each residual branch, 124 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
125 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 126 | if zero_init_residual: 127 | for m in self.modules(): 128 | if isinstance(m, Bottleneck): 129 | nn.init.constant_(m.bn3.weight, 0) 130 | elif isinstance(m, BasicBlock): 131 | nn.init.constant_(m.bn2.weight, 0) 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | conv1x1(self.inplanes, planes * block.expansion, stride), 138 | nn.BatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for _ in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | x = self.fc(x) 163 | 164 | return x 165 | 166 | 167 | def resnet18(pretrained=False, **kwargs): 168 | """Constructs a ResNet-18 model. 169 | 170 | Args: 171 | pretrained (bool): If True, returns a model pre-trained on ImageNet 172 | """ 173 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 174 | if pretrained: 175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 176 | return model 177 | 178 | 179 | def resnet34(pretrained=False, **kwargs): 180 | """Constructs a ResNet-34 model. 181 | 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 186 | if pretrained: 187 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 188 | return model 189 | 190 | 191 | def resnet50(pretrained=False, **kwargs): 192 | """Constructs a ResNet-50 model. 193 | 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 198 | if pretrained: 199 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 200 | return model 201 | 202 | 203 | def resnet101(pretrained=False, **kwargs): 204 | """Constructs a ResNet-101 model. 205 | 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 212 | return model 213 | 214 | 215 | def resnet152(pretrained=False, **kwargs): 216 | """Constructs a ResNet-152 model. 
217 | 218 | Args: 219 | pretrained (bool): If True, returns a model pre-trained on ImageNet 220 | """ 221 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 222 | if pretrained: 223 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 224 | return model 225 | 226 | -------------------------------------------------------------------------------- /src/representation/COV.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | 5 | class COV(nn.Module): 6 | """Matrix power normalized Covariance pooling (MPNCOV) 7 | implementation of fast MPN-COV (i.e.,iSQRT-COV) 8 | https://arxiv.org/abs/1712.01034 9 | 10 | Args: 11 | iterNum: #iteration of Newton-schulz method 12 | is_sqrt: whether perform matrix square root or not 13 | is_vec: whether the output is a vector or not 14 | input_dim: the #channel of input feature 15 | dimension_reduction: if None, it will not use 1x1 conv to 16 | reduce the #channel of feature. 17 | if 256 or others, the #channel of feature 18 | will be reduced to 256 or others. 19 | """ 20 | def __init__(self, iterNum=3, is_sqrt=True, is_vec=True, input_dim=2048, dimension_reduction=None): 21 | 22 | super(COV, self).__init__() 23 | self.iterNum=iterNum 24 | self.is_sqrt = is_sqrt 25 | self.is_vec = is_vec 26 | self.dr = dimension_reduction 27 | if self.dr is not None: 28 | self.conv_dr_block = nn.Sequential( 29 | nn.Conv2d(input_dim, self.dr, kernel_size=1, stride=1, bias=False), 30 | nn.BatchNorm2d(self.dr), 31 | nn.ReLU(inplace=True) 32 | ) 33 | output_dim = self.dr if self.dr else input_dim 34 | if self.is_vec: 35 | self.output_dim = int(output_dim*(output_dim+1)/2) 36 | else: 37 | self.output_dim = int(output_dim*output_dim) 38 | self._init_weight() 39 | 40 | def _init_weight(self): 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 44 | elif isinstance(m, nn.BatchNorm2d): 45 | nn.init.constant_(m.weight, 1) 46 | nn.init.constant_(m.bias, 0) 47 | 48 | def _cov_pool(self, x): 49 | return Covpool.apply(x) 50 | def _sqrtm(self, x): 51 | return Sqrtm.apply(x, self.iterNum) 52 | def _sqrtm_autograd(self, x): 53 | batchSize = x.shape[0] 54 | dim = x.shape[1] 55 | dtype = x.dtype 56 | iterN = self.iterNum 57 | I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype) 58 | normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1) 59 | A = x.div(normA.view(batchSize, 1, 1).expand_as(x)) 60 | ZY = 0.5 * (I3 - A) 61 | if iterN < 2: 62 | ZY = 0.5*(I3 - A) 63 | YZY = A.bmm(ZY) 64 | else: 65 | Y = A.bmm(ZY) 66 | Z = ZY 67 | for _ in range(iterN - 2): 68 | ZY = 0.5 * (I3 - Z.bmm(Y)) 69 | Y = Y.bmm(ZY) 70 | Z = ZY.bmm(Z) 71 | #print(torch.norm(Y[0]*Z[0]-torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype))) 72 | YZY = 0.5 * Y.bmm(I3 - Z.bmm(Y)) #original version 73 | y = YZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) #original version 74 | return y 75 | def _triuvec(self, x): 76 | return Triuvec.apply(x) 77 | 78 | def forward(self, x): 79 | if self.dr is not None: 80 | x = self.conv_dr_block(x) 81 | x = self._cov_pool(x) 82 | if self.is_sqrt: 83 | x = self._sqrtm(x) 84 | if self.is_vec: 85 | x = self._triuvec(x) 86 | return x 87 | 88 | 89 | class Covpool(Function): 90 | @staticmethod 91 | def forward(ctx, input): 92 | x = input 93 | batchSize = x.data.shape[0] 94 | dim = 
x.data.shape[1] 95 | h = x.data.shape[2] 96 | w = x.data.shape[3] 97 | M = h*w 98 | x = x.reshape(batchSize,dim,M) 99 | I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device) 100 | I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype) 101 | y = x.bmm(I_hat).bmm(x.transpose(1,2)) 102 | ctx.save_for_backward(input,I_hat) 103 | return y 104 | @staticmethod 105 | def backward(ctx, grad_output): 106 | input,I_hat = ctx.saved_tensors 107 | x = input 108 | batchSize = x.data.shape[0] 109 | dim = x.data.shape[1] 110 | h = x.data.shape[2] 111 | w = x.data.shape[3] 112 | M = h*w 113 | x = x.reshape(batchSize,dim,M) 114 | grad_input = grad_output + grad_output.transpose(1,2) 115 | grad_input = grad_input.bmm(x).bmm(I_hat) 116 | grad_input = grad_input.reshape(batchSize,dim,h,w) 117 | return grad_input 118 | 119 | class Sqrtm(Function): 120 | @staticmethod 121 | def forward(ctx, input, iterN): 122 | x = input 123 | batchSize = x.data.shape[0] 124 | dim = x.data.shape[1] 125 | dtype = x.dtype 126 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 127 | normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1) 128 | A = x.div(normA.view(batchSize,1,1).expand_as(x)) 129 | Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device).type(dtype) 130 | Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1).type(dtype) 131 | if iterN < 2: 132 | ZY = 0.5*(I3 - A) 133 | YZY = A.bmm(ZY) 134 | else: 135 | ZY = 0.5*(I3 - A) 136 | Y[:,0,:,:] = A.bmm(ZY) 137 | Z[:,0,:,:] = ZY 138 | for i in range(1, iterN-1): 139 | ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:])) 140 | Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY) 141 | Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:]) 142 | YZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:])) 143 | y = YZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 144 | ctx.save_for_backward(input, A, YZY, normA, Y, Z) 145 | ctx.iterN = iterN 146 | return y 147 | @staticmethod 148 | def backward(ctx, grad_output): 149 | input, A, ZY, normA, Y, Z = ctx.saved_tensors 150 | iterN = ctx.iterN 151 | x = input 152 | batchSize = x.data.shape[0] 153 | dim = x.data.shape[1] 154 | dtype = x.dtype 155 | der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 156 | der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA)) 157 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 158 | if iterN < 2: 159 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_postCom)) 160 | else: 161 | dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) - 162 | Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom)) 163 | dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:]) 164 | for i in range(iterN-3, -1, -1): 165 | YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:]) 166 | ZY = Z[:,i,:,:].bmm(Y[:,i,:,:]) 167 | dldY_ = 0.5*(dldY.bmm(YZ) - 168 | Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - 169 | ZY.bmm(dldY)) 170 | dldZ_ = 0.5*(YZ.bmm(dldZ) - 171 | Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) - 172 | dldZ.bmm(ZY)) 173 | dldY = dldY_ 174 | dldZ = dldZ_ 175 | der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY)) 176 | der_NSiter = der_NSiter.transpose(1, 2) 177 | grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x)) 178 | grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1) 179 | for i in range(batchSize): 180 | grad_input[i,:,:] += (der_postComAux[i] \ 181 | - 
grad_aux[i] / (normA[i] * normA[i])) \ 182 | *torch.ones(dim,device = x.device).diag().type(dtype) 183 | return grad_input, None 184 | 185 | class Triuvec(Function): 186 | @staticmethod 187 | def forward(ctx, input): 188 | x = input 189 | batchSize = x.data.shape[0] 190 | dim = x.data.shape[1] 191 | dtype = x.dtype 192 | x = x.reshape(batchSize, dim*dim) 193 | I = torch.ones(dim,dim).triu().reshape(dim*dim) 194 | index = I.nonzero() 195 | y = torch.zeros(batchSize,int(dim*(dim+1)/2),device = x.device).type(dtype) 196 | y = x[:,index] 197 | ctx.save_for_backward(input,index) 198 | return y 199 | @staticmethod 200 | def backward(ctx, grad_output): 201 | input,index = ctx.saved_tensors 202 | x = input 203 | batchSize = x.data.shape[0] 204 | dim = x.data.shape[1] 205 | dtype = x.dtype 206 | grad_input = torch.zeros(batchSize,dim*dim,device = x.device,requires_grad=False).type(dtype) 207 | grad_input[:,index] = grad_output 208 | grad_input = grad_input.reshape(batchSize,dim,dim) 209 | return grad_input 210 | 211 | def CovpoolLayer(var): 212 | return Covpool.apply(var) 213 | 214 | def SqrtmLayer(var, iterN): 215 | return Sqrtm.apply(var, iterN) 216 | 217 | def TriuvecLayer(var): 218 | return Triuvec.apply(var) 219 | -------------------------------------------------------------------------------- /src/network/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | 8 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 9 | 10 | 11 | model_urls = { 12 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 13 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 14 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 15 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 16 | } 17 | 18 | 19 | class _DenseLayer(nn.Sequential): 20 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 21 | super(_DenseLayer, self).__init__() 22 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 23 | self.add_module('relu1', nn.ReLU(inplace=True)), 24 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 25 | growth_rate, kernel_size=1, stride=1, bias=False)), 26 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 27 | self.add_module('relu2', nn.ReLU(inplace=True)), 28 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 29 | kernel_size=3, stride=1, padding=1, bias=False)), 30 | self.drop_rate = drop_rate 31 | 32 | def forward(self, x): 33 | new_features = super(_DenseLayer, self).forward(x) 34 | if self.drop_rate > 0: 35 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 36 | return torch.cat([x, new_features], 1) 37 | 38 | 39 | class _DenseBlock(nn.Sequential): 40 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 41 | super(_DenseBlock, self).__init__() 42 | for i in range(num_layers): 43 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 44 | self.add_module('denselayer%d' % (i + 1), layer) 45 | 46 | 47 | class _Transition(nn.Sequential): 48 | def __init__(self, num_input_features, num_output_features): 49 | super(_Transition, self).__init__() 50 | 
self.add_module('norm', nn.BatchNorm2d(num_input_features)) 51 | self.add_module('relu', nn.ReLU(inplace=True)) 52 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 53 | kernel_size=1, stride=1, bias=False)) 54 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 55 | 56 | 57 | class DenseNet(nn.Module): 58 | r"""Densenet-BC model class, based on 59 | `"Densely Connected Convolutional Networks" `_ 60 | 61 | Args: 62 | growth_rate (int) - how many filters to add each layer (`k` in paper) 63 | block_config (list of 4 ints) - how many layers in each pooling block 64 | num_init_features (int) - the number of filters to learn in the first convolution layer 65 | bn_size (int) - multiplicative factor for number of bottle neck layers 66 | (i.e. bn_size * k features in the bottleneck layer) 67 | drop_rate (float) - dropout rate after each dense layer 68 | num_classes (int) - number of classification classes 69 | """ 70 | 71 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 72 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 73 | 74 | super(DenseNet, self).__init__() 75 | 76 | # First convolution 77 | self.features = nn.Sequential(OrderedDict([ 78 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 79 | ('norm0', nn.BatchNorm2d(num_init_features)), 80 | ('relu0', nn.ReLU(inplace=True)), 81 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 82 | ])) 83 | 84 | # Each denseblock 85 | num_features = num_init_features 86 | for i, num_layers in enumerate(block_config): 87 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 88 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 89 | self.features.add_module('denseblock%d' % (i + 1), block) 90 | num_features = num_features + num_layers * growth_rate 91 | if i != len(block_config) - 1: 92 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 93 | self.features.add_module('transition%d' % (i + 1), trans) 94 | num_features = num_features // 2 95 | 96 | # Final batch norm 97 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 98 | 99 | # Linear layer 100 | self.classifier = nn.Linear(num_features, num_classes) 101 | 102 | # Official init from torch repo. 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | nn.init.kaiming_normal_(m.weight) 106 | elif isinstance(m, nn.BatchNorm2d): 107 | nn.init.constant_(m.weight, 1) 108 | nn.init.constant_(m.bias, 0) 109 | elif isinstance(m, nn.Linear): 110 | nn.init.constant_(m.bias, 0) 111 | 112 | def forward(self, x): 113 | features = self.features(x) 114 | out = F.relu(features, inplace=True) 115 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 116 | out = self.classifier(out) 117 | return out 118 | 119 | 120 | def densenet121(pretrained=False, **kwargs): 121 | r"""Densenet-121 model from 122 | `"Densely Connected Convolutional Networks" `_ 123 | 124 | Args: 125 | pretrained (bool): If True, returns a model pre-trained on ImageNet 126 | """ 127 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 128 | **kwargs) 129 | if pretrained: 130 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 131 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 132 | # They are also in the checkpoints in model_urls. This pattern is used 133 | # to find such keys. 
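# Example of the key rewrite performed by the pattern below:
#   'features.denseblock1.denselayer1.norm.1.weight'
#     -> 'features.denseblock1.denselayer1.norm1.weight'
# (group 1 + group 2 with the separating dot removed, so legacy checkpoint
# keys match the attribute names registered by the current _DenseLayer).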
134 | pattern = re.compile( 135 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 136 | state_dict = model_zoo.load_url(model_urls['densenet121']) 137 | for key in list(state_dict.keys()): 138 | res = pattern.match(key) 139 | if res: 140 | new_key = res.group(1) + res.group(2) 141 | state_dict[new_key] = state_dict[key] 142 | del state_dict[key] 143 | model.load_state_dict(state_dict) 144 | return model 145 | 146 | 147 | def densenet169(pretrained=False, **kwargs): 148 | r"""Densenet-169 model from 149 | `"Densely Connected Convolutional Networks" `_ 150 | 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on ImageNet 153 | """ 154 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 155 | **kwargs) 156 | if pretrained: 157 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 158 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 159 | # They are also in the checkpoints in model_urls. This pattern is used 160 | # to find such keys. 161 | pattern = re.compile( 162 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 163 | state_dict = model_zoo.load_url(model_urls['densenet169']) 164 | for key in list(state_dict.keys()): 165 | res = pattern.match(key) 166 | if res: 167 | new_key = res.group(1) + res.group(2) 168 | state_dict[new_key] = state_dict[key] 169 | del state_dict[key] 170 | model.load_state_dict(state_dict) 171 | return model 172 | 173 | 174 | def densenet201(pretrained=False, **kwargs): 175 | r"""Densenet-201 model from 176 | `"Densely Connected Convolutional Networks" `_ 177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 182 | **kwargs) 183 | if pretrained: 184 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 185 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 186 | # They are also in the checkpoints in model_urls. This pattern is used 187 | # to find such keys. 188 | pattern = re.compile( 189 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 190 | state_dict = model_zoo.load_url(model_urls['densenet201']) 191 | for key in list(state_dict.keys()): 192 | res = pattern.match(key) 193 | if res: 194 | new_key = res.group(1) + res.group(2) 195 | state_dict[new_key] = state_dict[key] 196 | del state_dict[key] 197 | model.load_state_dict(state_dict) 198 | return model 199 | 200 | 201 | def densenet161(pretrained=False, **kwargs): 202 | r"""Densenet-161 model from 203 | `"Densely Connected Convolutional Networks" `_ 204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 209 | **kwargs) 210 | if pretrained: 211 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 212 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 213 | # They are also in the checkpoints in model_urls. This pattern is used 214 | # to find such keys. 
215 | pattern = re.compile( 216 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 217 | state_dict = model_zoo.load_url(model_urls['densenet161']) 218 | for key in list(state_dict.keys()): 219 | res = pattern.match(key) 220 | if res: 221 | new_key = res.group(1) + res.group(2) 222 | state_dict[new_key] = state_dict[key] 223 | del state_dict[key] 224 | model.load_state_dict(state_dict) 225 | return model 226 | -------------------------------------------------------------------------------- /src/representation/SICE.py: -------------------------------------------------------------------------------- 1 | import profile 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Function 5 | from numpy import linalg as LA 6 | import numpy as np 7 | 8 | class SICE(nn.Module): 9 | def __init__(self, iterNum=3, is_sqrt=True, is_vec=True, input_dim=2048, dimension_reduction=None, sparsity_val=0.0, sice_lrate=0.0): 10 | 11 | super(SICE, self).__init__() 12 | self.iterNum=iterNum 13 | self.is_sqrt = is_sqrt 14 | self.is_vec = is_vec 15 | self.dr = dimension_reduction 16 | self.sparsity = sparsity_val 17 | self.learingRate = sice_lrate 18 | if self.dr is not None: 19 | self.conv_dr_block = nn.Sequential( 20 | nn.Conv2d(input_dim, self.dr, kernel_size=1, stride=1, bias=False), 21 | nn.BatchNorm2d(self.dr), 22 | nn.ReLU(inplace=True) 23 | ) 24 | output_dim = self.dr if self.dr else input_dim 25 | if self.is_vec: 26 | self.output_dim = int(output_dim*(output_dim+1)/2) 27 | else: 28 | self.output_dim = int(output_dim*output_dim) 29 | self._init_weight() 30 | 31 | def _init_weight(self): 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 35 | elif isinstance(m, nn.BatchNorm2d): 36 | nn.init.constant_(m.weight, 1) 37 | nn.init.constant_(m.bias, 0) 38 | 39 | def _cov_pool(self, x): 40 | return Covpool.apply(x) 41 | 42 | 43 | def _inv_sqrtm(self, x, iterN): 44 | return Sqrtm.apply(x, iterN) 45 | 46 | def _sqrtm(self, x, iterN): 47 | batchSize = x.shape[0] 48 | dim = x.shape[1] 49 | dtype = x.dtype 50 | I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype) 51 | normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1) 52 | A = x.div(normA.view(batchSize, 1, 1).expand_as(x)) 53 | ZY = 0.5 * (I3 - A) 54 | if iterN < 2: 55 | ZY = 0.5*(I3 - A) 56 | YZY = A.bmm(ZY) 57 | else: 58 | Y = A.bmm(ZY) 59 | Z = ZY 60 | for _ in range(iterN - 2): 61 | ZY = 0.5 * (I3 - Z.bmm(Y)) 62 | Y = Y.bmm(ZY) 63 | Z = ZY.bmm(Z) 64 | YZY = 0.5 * Y.bmm(I3 - Z.bmm(Y)) 65 | y = ZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 66 | return y 67 | 68 | def _sice(self, mfX, fLR=5.0, fSparsity=0.07, nSteps=10000): 69 | mfC = self._cov_pool(mfX) 70 | mfC=mfC/torch.diagonal(mfC, dim1=-2, dim2=-1).sum(-1).view(-1,1,1) 71 | I = 1e-10+1e-9*torch.diag(torch.rand(mfC.shape[1],device = mfC.device)).view(1, mfC.shape[1], mfC.shape[2]).repeat(mfC.shape[0],1,1).type(mfC.dtype) 72 | zz=self._inv_sqrtm(mfC+I, 7) 73 | 74 | mfInvC=zz.bmm(zz) 75 | 76 | mfCov=mfC*1.0 77 | mfLLT=mfInvC*1.0 #+1 78 | 79 | mfCov=mfCov 80 | mfLLT=mfLLT 81 | mfLLT_prev=1e10*torch.ones(mfLLT.size(), device=mfC.device) 82 | 83 | nCounter=0 84 | for i in range(nSteps): 85 | mfLLT_plus = torch.relu(mfLLT) 86 | mfLLT_minus = torch.relu(-mfLLT) 87 | 88 | zz = self._inv_sqrtm(mfLLT+I, 7) 89 | mfGradPart1=-zz.bmm(zz) 90 | 91 | mfGradPart2 = 0.5*(mfCov.transpose(1,2) + mfCov) 92 | 
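# added note: mfGradPart1 = -inv(mfLLT + I) and mfGradPart2 = sym(mfCov), so their sum below is
# the gradient of the smooth SICE terms tr(C*Theta) - logdet(Theta) with respect to Theta (= mfLLT)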
mfGradPart12 = mfGradPart1+mfGradPart2 93 | 94 | mfGradPart3_plus = mfGradPart12 + fSparsity 95 | mfGradPart3_minus = -mfGradPart12 + fSparsity 96 | 97 | fDec=(1-i/(nSteps-1.0) ) 98 | 99 | mfLLT_plus = mfLLT_plus - fLR*fDec*mfGradPart3_plus 100 | mfLLT_minus = mfLLT_minus - fLR*fDec*mfGradPart3_minus 101 | 102 | mfLLT_plus = torch.relu(mfLLT_plus) 103 | mfLLT_minus = torch.relu(mfLLT_minus) 104 | 105 | mfLLT = mfLLT_plus-mfLLT_minus 106 | mfLLT = 0.5*(mfLLT+mfLLT.transpose(1,2)) 107 | 108 | fSolDiff = (mfLLT-mfLLT_prev).abs().mean() 109 | fSparseCount = ((mfLLT.abs()>2e-8)*1.0).mean() 110 | 111 | mfLLT_prev = mfLLT*1.0 112 | mfLLT_prev = mfLLT_prev 113 | mfOut = mfLLT 114 | mfOut = mfOut/torch.sqrt(torch.diagonal(mfOut, dim1=-2, dim2=-1).sum(-1)).view(-1,1,1) #works better and faster convergence 115 | return mfOut 116 | 117 | def _triuvec(self, x): 118 | return Triuvec.apply(x) 119 | 120 | 121 | def forward(self, x): 122 | if self.dr is not None: 123 | x = self.conv_dr_block(x) 124 | x = self._sice(x, fLR=self.learingRate, fSparsity=self.sparsity, nSteps=self.iterNum) 125 | if self.is_vec: 126 | x = self._triuvec(x) 127 | return x 128 | 129 | 130 | 131 | class Covpool(Function): 132 | @staticmethod 133 | def forward(ctx, input): 134 | x = input 135 | batchSize = x.data.shape[0] 136 | dim = x.data.shape[1] 137 | h = x.data.shape[2] 138 | w = x.data.shape[3] 139 | M = h*w 140 | x = x.reshape(batchSize,dim,M) 141 | I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device) 142 | I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype) 143 | y = x.bmm(I_hat).bmm(x.transpose(1,2)) 144 | ctx.save_for_backward(input,I_hat) 145 | return y 146 | @staticmethod 147 | def backward(ctx, grad_output): 148 | input,I_hat = ctx.saved_tensors 149 | x = input 150 | batchSize = x.data.shape[0] 151 | dim = x.data.shape[1] 152 | h = x.data.shape[2] 153 | w = x.data.shape[3] 154 | M = h*w 155 | x = x.reshape(batchSize,dim,M) 156 | grad_input = grad_output + grad_output.transpose(1,2) 157 | grad_input = grad_input.bmm(x).bmm(I_hat) 158 | grad_input = grad_input.reshape(batchSize,dim,h,w) 159 | return grad_input 160 | 161 | class Sqrtm(Function): 162 | @staticmethod 163 | def forward(ctx, input, iterN): 164 | x = input 165 | batchSize = x.data.shape[0] 166 | dim = x.data.shape[1] 167 | dtype = x.dtype 168 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 169 | normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1) 170 | A = x.div(normA.view(batchSize,1,1).expand_as(x)) 171 | Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device).type(dtype) 172 | Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1).type(dtype) 173 | if iterN < 2: 174 | ZY = 0.5*(I3 - A) 175 | YZY = A.bmm(ZY) 176 | else: 177 | ZY = 0.5*(I3 - A) 178 | Y[:,0,:,:] = A.bmm(ZY) 179 | Z[:,0,:,:] = ZY 180 | for i in range(1, iterN-1): 181 | ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:])) 182 | Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY) 183 | Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:]) 184 | ZYZ = 0.5 * (I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:])).bmm(Z[:,iterN-2,:,:]) 185 | y = ZYZ * torch.pow(normA,-0.5).view(batchSize, 1, 1).expand_as(x) 186 | ctx.save_for_backward(input, A, ZYZ, normA, Y, Z) 187 | ctx.iterN = iterN 188 | return y 189 | @staticmethod 190 | def backward(ctx, grad_output): 191 | input, A, ZY, normA, Y, Z = ctx.saved_tensors 192 | iterN = ctx.iterN 193 | x = input 194 | batchSize = x.data.shape[0] 195 | dim = 
x.data.shape[1] 196 | dtype = x.dtype 197 | der_postCom = grad_output*torch.pow(normA, -0.5).view(batchSize, 1, 1).expand_as(x) 198 | der_postComAux = -0.5*torch.pow(normA, -1.5)*((grad_output*ZY).sum(dim=1).sum(dim=1)) 199 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 200 | if iterN < 2: 201 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_postCom)) 202 | else: 203 | dldZ = 0.5*((I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])).bmm(der_postCom) - 204 | der_postCom.bmm(Z[:,iterN-2,:,:]).bmm(Y[:,iterN-2,:,:])) 205 | dldY = -0.5*Z[:,iterN-2,:,:].bmm(der_postCom).bmm(Z[:,iterN-2,:,:]) 206 | for i in range(iterN-3, -1, -1): 207 | YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:]) 208 | ZY = Z[:,i,:,:].bmm(Y[:,i,:,:]) 209 | dldY_ = 0.5*(dldY.bmm(YZ) - 210 | Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - 211 | ZY.bmm(dldY)) 212 | dldZ_ = 0.5*(YZ.bmm(dldZ) - 213 | Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) - 214 | dldZ.bmm(ZY)) 215 | dldY = dldY_ 216 | dldZ = dldZ_ 217 | der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY)) 218 | der_NSiter = der_NSiter.transpose(1, 2) 219 | grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x)) 220 | grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1) 221 | for i in range(batchSize): 222 | grad_input[i,:,:] += (der_postComAux[i] \ 223 | - grad_aux[i] / (normA[i] * normA[i])) \ 224 | *torch.ones(dim,device = x.device).diag().type(dtype) 225 | return grad_input, None 226 | 227 | class Triuvec(Function): 228 | @staticmethod 229 | def forward(ctx, input): 230 | x = input 231 | batchSize = x.data.shape[0] 232 | dim = x.data.shape[1] 233 | dtype = x.dtype 234 | x = x.reshape(batchSize, dim*dim) 235 | I = torch.ones(dim,dim).triu().reshape(dim*dim) 236 | index = I.nonzero() 237 | y = torch.zeros(batchSize,int(dim*(dim+1)/2),device = x.device).type(dtype) 238 | y = x[:,index] 239 | ctx.save_for_backward(input,index) 240 | return y 241 | @staticmethod 242 | def backward(ctx, grad_output): 243 | input,index = ctx.saved_tensors 244 | x = input 245 | batchSize = x.data.shape[0] 246 | dim = x.data.shape[1] 247 | dtype = x.dtype 248 | grad_input = torch.zeros(batchSize,dim*dim,device = x.device,requires_grad=False).type(dtype) 249 | grad_input[:,index] = grad_output 250 | grad_input = grad_input.reshape(batchSize,dim,dim) 251 | return grad_input 252 | 253 | def CovpoolLayer(var): 254 | return Covpool.apply(var) 255 | 256 | def SqrtmLayer(var, iterN): 257 | return Sqrtm.apply(var, iterN) 258 | 259 | def InvcovpoolLayer(var): 260 | return InverseCOV.apply(var) 261 | 262 | def TriuvecLayer(var): 263 | return Triuvec.apply(var) 264 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Learning Partial Correlation based Deep Visual Representation for Image Classification
Official implementation with PyTorch 2 | 3 | ### [Paper](https://arxiv.org/abs/2304.11597v2) | [Project Website](https://csiro-robotics.github.io/iSICE) 4 | ![iSICE](isice.png) 5 | This repository contains the model definitions, training/evaluation code and pre-trained model weights for our paper exploring partial correlation based deep SPD visual representation. More information is available on our [project website](https://csiro-robotics.github.io/iSICE). 6 | 7 | > Learning Partial Correlation based Deep Visual Representation for Image Classification
8 | > [Saimunur Rahman](#), [Piotr Koniusz](http://users.cecs.anu.edu.au/~koniusz), [Lei Wang](https://sites.google.com/view/lei-hs-wang), [Luping Zhou](https://www.sydney.edu.au/engineering/about/our-people/academic-staff/luping-zhou.html), [Peyman Moghadam](https://people.csiro.au/m/p/peyman-moghadam), [Changming Sun](https://vision-cdc.csiro.au/changming.sun)
9 | > CSIRO Data61, University of Wollongong, Australian National University, University of Sydney, Queensland University of Technology 10 | 11 | Visual representation based on the covariance matrix has demonstrated its efficacy for image classification by characterising the pairwise correlation of different channels in convolutional feature maps. However, pairwise correlation becomes misleading once a third channel correlates with both channels of interest, resulting in a "confounding" effect. In this case, "partial correlation", which removes the confounding effect, should be estimated instead. Nevertheless, reliably estimating partial correlation requires solving a symmetric positive definite matrix optimisation known as sparse inverse covariance estimation (SICE). How to incorporate this process into a CNN remains an open issue. In this work, we formulate SICE as a novel structured layer of a CNN. To keep the CNN end-to-end trainable, we develop an iterative method based on Newton-Schulz iteration to solve the above matrix optimisation during the forward and backward propagation steps. Our work not only obtains a partial correlation based deep visual representation but also mitigates the small sample problem frequently encountered by covariance matrix estimation in CNNs. Computationally, our model can be trained effectively on GPUs and works well with the large number of channels of advanced CNN models. Experimental results confirm the efficacy of the proposed deep visual representation and its superior classification performance over that of its covariance matrix based counterparts. 12 | 13 | This repository contains: 14 | 15 | :heavy_check_mark: A simple implementation of our method with PyTorch
16 | :heavy_check_mark: A script useful for training/evaluating our method on various datasets
17 | :heavy_check_mark: Pre-trained model weights on several datasets 18 | 19 | ## Repository Setup Guide 20 | To run our code on your machine, the first step is to download the repository, which can be done with the following commands: 21 | ```bash 22 | cd /the/path/where/you/want/to/copy/the/code 23 | git clone https://github.com/csiro-robotics/iSICE.git 24 | cd iSICE 25 | ``` 26 | The second step is to create a conda environment with the necessary Python packages, which can be done with the following commands: 27 | 28 | ```bash 29 | conda create --name iSICE 30 | conda install -n iSICE pytorch torchvision cudatoolkit torchaudio scipy matplotlib -c pytorch 31 | ``` 32 | 33 | For ease of use, we rely only on common Python packages so that users can run our code with little difficulty. If you do not have Anaconda installed, you can either install Anaconda or its lighter version Miniconda, or use a Python virtual environment. In the latter case, the packages can be installed with `pip`. Please see [here](https://pip.pypa.io/en/stable/cli/pip_install) for details. We also provide the `isice.yml` file for creating a conda environment matching ours (via `conda env create -f isice.yml`). 34 | 35 | Note that we have evaluated our code with PyTorch 1.9.0. However, there should be no problems with other versions released after PyTorch 0.4.0. The above command provides GPU support via CUDA; CPU execution is supported by default. 36 | 37 | The third step is to activate the above conda environment with the following command: 38 | 39 | ```bash 40 | conda activate iSICE 41 | ``` 42 | 43 | The fourth step is downloading the datasets. All datasets should be prepared as follows. 44 | 45 | ```bash 46 | . 47 | ├── train 48 | │   ├── class 1 49 | │   │   ├── image_001.format 50 | │   │   ├── image_002.format 51 | │   │   └── ... 52 | │   ├── class 2 53 | │   ├── class 3 54 | │   ├── ... 55 | │   ├── ... 56 | │   └── class N 57 | └── val 58 | ├── class 1 59 | │   ├── image_001.format 60 | │   ├── image_002.format 61 | │   └── ... 62 | ├── class 2 63 | ├── class 3 64 | ├── ... 65 | ├── ... 66 | └── class N 67 | ``` 68 | 69 | ## Repository Overview 70 | We use a modular design for this repository. In our experience, such a design is easy to manage and extend. Our code repository is organised as follows. 71 | 72 | ```bash 73 | ├── main.py 74 | ├── imagepreprocess.py 75 | ├── functions.py 76 | ├── model_init.py 77 | ├── src 78 | │   ├── network 79 | │   │   ├── __init__.py 80 | │   │   ├── base.py 81 | │   │   ├── inception.py 82 | │   │   ├── alexnet.py 83 | │   │   ├── resnet.py 84 | │   │   └── vgg.py 85 | │   └── representation 86 | │       ├── __init__.py 87 | │       ├── SICE.py 88 | │       ├── INVCOV.py 89 | │       ├── COV.py 90 | │       └── GAvP.py 91 | ├── train_iSICE_model.sh 92 | ``` 93 | 94 | ## How to use our code 95 | Our `main.py` is the entry point for reproducing the results reported in the paper.
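Since `main.py` is built on `argparse`, the full list of supported options (backbone, representation, solver settings, etc.) can be printed with:

```bash
python main.py --help
```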
Suppose we want to train a partial correlation representation model based on a VGG-16 backbone with the CUB-200 dataset (referred to as the Birds dataset in the paper) and evaluate it on the same dataset; the following command can be used: 96 | 97 | ```bash 98 | python main.py /path/to/CUB --benchmark CUB --pretrained -a vgg16_bn --epochs 100 --lr 1.2e-4 --lr-method step --lr-params 15 30 -j 10 -b 65 --num-classes 200 --representation SICE --freezed-layer 0 --classifier-factor 5 --modeldir /path/to/save/the/model/and/meta/information 99 | ``` 100 | 101 | As training progresses, the loss, top-1 error and top-5 error for both training and test evaluation will be automatically saved in the path specified with the `--modeldir` parameter above. 102 | 103 | For training on computing clusters such as HPC systems, please use the `train_iSICE_model.sh` script, changing its various fields as per the instructions given in the script (we show how to train on the MIT Indoor dataset). Our code is compatible with multi-GPU training. 104 | 105 | ## Pre-trained models 106 | For convenience, we provide our VGG-16 and ResNet-50 based partial correlation models trained on fine-grained and scene datasets. They can be downloaded via the links below. 107 | 108 | #### Pairwise correlation based models (computed via iSQRT-COV pooling) 109 |

| Backbone | MIT top1 acc. (%) | Model | Airplane top1 acc. (%) | Model | Birds top1 acc. (%) | Model | Cars top1 acc. (%) | Model |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| VGG-16 | 76.1 | TBA | 90.0 | TBA | 84.5 | TBA | 91.2 | TBA |
| ResNet-50 | 78.8 | TBA | 90.9 | TBA | 84.3 | TBA | 92.1 | TBA |
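The pairwise correlation baselines above use the `COV` representation that `main.py` registers alongside `SICE`. As a minimal sketch (assuming the remaining flags match the Birds command shown earlier; we do not claim these were the exact settings used for this table), a COV model can be trained by swapping the representation flag:

```bash
python main.py /path/to/CUB --benchmark CUB --pretrained -a vgg16_bn --epochs 100 --lr 1.2e-4 --lr-method step --lr-params 15 30 -j 10 -b 65 --num-classes 200 --representation COV --freezed-layer 0 --classifier-factor 5 --modeldir /path/to/save/the/model/and/meta/information
```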
155 | 156 | #### Partial correlation based models (computed via Precision Matrix described in Algorithm 1 of the paper) 157 |

| Backbone | MIT top1 acc. (%) | Model | Airplane top1 acc. (%) | Model | Birds top1 acc. (%) | Model | Cars top1 acc. (%) | Model |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| VGG-16 | 80.2 | TBA | 89.4 | TBA | 83.4 | TBA | 92.0 | TBA |
| ResNet-50 | 80.8 | TBA | 91.2 | TBA | 84.7 | TBA | 92.0 | TBA |
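Analogously, the precision matrix models of Algorithm 1 plausibly correspond to passing `--representation INVCOV` to `main.py`; we state this as an assumption based on the representation options registered there, not a confirmed mapping.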
203 | 204 | #### Partial correlation based models (computed via iSICE described in Algorithm 2 of the paper) 205 |

| Backbone | MIT top1 acc. (%) | Model | Airplane top1 acc. (%) | Model | Birds top1 acc. (%) | Model | Cars top1 acc. (%) | Model |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| VGG-16 | 78.7 | TBA | 92.2 | TBA | 86.5 | TBA | 94.0 | TBA |
| ResNet-50 | 80.5 | TBA | 92.7 | TBA | 85.9 | TBA | 93.5 | TBA |
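The iSICE solver itself is controlled by three dedicated `main.py` flags: `--sparsity` (sparsity value, default 0.01), `--iterations` (solver iterations, default 5) and `--sicelr` (solver learning rate, default 5.0). As a sketch, the Birds command with these settings made explicit (here set to their defaults) looks as follows:

```bash
python main.py /path/to/CUB --benchmark CUB --pretrained -a vgg16_bn --epochs 100 --lr 1.2e-4 --lr-method step --lr-params 15 30 -j 10 -b 65 --num-classes 200 --representation SICE --sparsity 0.01 --iterations 5 --sicelr 5.0 --freezed-layer 0 --classifier-factor 5 --modeldir /path/to/save/the/model/and/meta/information
```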
251 | 252 | Pre-trained models can be used as checkpoints for further training or evaluation using the following command (pass `-e`/`--evaluate` in addition to `--resume` to run evaluation only): 253 | 254 | ```bash 255 | python main.py /path/to/CUB --benchmark CUB --pretrained -a vgg16_bn --epochs 100 --lr 1.2e-4 --lr-method step --lr-params 15 30 -j 10 -b 65 --num-classes 200 --representation SICE --freezed-layer 0 --classifier-factor 5 --resume /path/to/downloaded/model 256 | ``` 257 | 258 | ## How to cite our paper 259 | Please use the following BibTeX reference to cite our paper. 260 | ```bibtex 261 | @InProceedings{isice_cvpr, 262 | author = {Rahman, Saimunur and Koniusz, Piotr and Wang, Lei and Zhou, Luping and Moghadam, Peyman and Sun, Changming}, 263 | title = {Learning Partial Correlation based Deep Visual Representation for Image Classification}, 264 | booktitle = {IEEE/CVF Int. Conf. on Computer Vision and Pattern Recognition (CVPR)}, 265 | month = {June}, 266 | year = {2023} 267 | } 268 | ``` 269 | 270 | ## Acknowledgments 271 | 272 | This codebase borrows from the [iSQRT-COV repository](https://github.com/jiangtaoxie/fast-MPN-COV); we thank the authors for maintaining it. 273 | 274 | ## Contact 275 | If you have any questions or suggestions, please contact `saimun.rahman@data61.csiro.au`. 276 | 277 | -------------------------------------------------------------------------------- /src/network/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['Inception3', 'inception_v3'] 8 | 9 | 10 | model_urls = { 11 | # Inception v3 ported from TensorFlow 12 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 13 | } 14 | 15 | 16 | def inception_v3(pretrained=False, **kwargs): 17 | r"""Inception v3 model architecture from 18 | `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
19 | 20 | Args: 21 | pretrained (bool): If True, returns a model pre-trained on ImageNet 22 | """ 23 | if pretrained: 24 | if 'transform_input' not in kwargs: 25 | kwargs['transform_input'] = True 26 | model = Inception3(**kwargs) 27 | model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google'])) 28 | return model 29 | 30 | return Inception3(**kwargs) 31 | 32 | 33 | class Inception3(nn.Module): 34 | 35 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): 36 | super(Inception3, self).__init__() 37 | self.aux_logits = aux_logits 38 | self.transform_input = transform_input 39 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 40 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 41 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 42 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 43 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 44 | self.Mixed_5b = InceptionA(192, pool_features=32) 45 | self.Mixed_5c = InceptionA(256, pool_features=64) 46 | self.Mixed_5d = InceptionA(288, pool_features=64) 47 | self.Mixed_6a = InceptionB(288) 48 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 49 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 50 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 51 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 52 | if aux_logits: 53 | self.AuxLogits = InceptionAux(768, num_classes) 54 | self.Mixed_7a = InceptionD(768) 55 | self.Mixed_7b = InceptionE(1280) 56 | self.Mixed_7c = InceptionE(2048) 57 | self.fc = nn.Linear(2048, num_classes) 58 | 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 61 | import scipy.stats as stats 62 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 63 | X = stats.truncnorm(-2, 2, scale=stddev) 64 | values = torch.Tensor(X.rvs(m.weight.numel())) 65 | values = values.view(m.weight.size()) 66 | m.weight.data.copy_(values) 67 | elif isinstance(m, nn.BatchNorm2d): 68 | nn.init.constant_(m.weight, 1) 69 | nn.init.constant_(m.bias, 0) 70 | 71 | def forward(self, x): 72 | if self.transform_input: 73 | x = x.clone() 74 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 75 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 76 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 77 | # 299 x 299 x 3 78 | x = self.Conv2d_1a_3x3(x) 79 | # 149 x 149 x 32 80 | x = self.Conv2d_2a_3x3(x) 81 | # 147 x 147 x 32 82 | x = self.Conv2d_2b_3x3(x) 83 | # 147 x 147 x 64 84 | x = F.max_pool2d(x, kernel_size=3, stride=2) 85 | # 73 x 73 x 64 86 | x = self.Conv2d_3b_1x1(x) 87 | # 73 x 73 x 80 88 | x = self.Conv2d_4a_3x3(x) 89 | # 71 x 71 x 192 90 | x = F.max_pool2d(x, kernel_size=3, stride=2) 91 | # 35 x 35 x 192 92 | x = self.Mixed_5b(x) 93 | # 35 x 35 x 256 94 | x = self.Mixed_5c(x) 95 | # 35 x 35 x 288 96 | x = self.Mixed_5d(x) 97 | # 35 x 35 x 288 98 | x = self.Mixed_6a(x) 99 | # 17 x 17 x 768 100 | x = self.Mixed_6b(x) 101 | # 17 x 17 x 768 102 | x = self.Mixed_6c(x) 103 | # 17 x 17 x 768 104 | x = self.Mixed_6d(x) 105 | # 17 x 17 x 768 106 | x = self.Mixed_6e(x) 107 | # 17 x 17 x 768 108 | if self.training and self.aux_logits: 109 | aux = self.AuxLogits(x) 110 | # 17 x 17 x 768 111 | x = self.Mixed_7a(x) 112 | # 8 x 8 x 1280 113 | x = self.Mixed_7b(x) 114 | # 8 x 8 x 2048 115 | x = self.Mixed_7c(x) 116 | # 8 x 8 x 2048 117 | x = F.avg_pool2d(x, kernel_size=8) 118 | # 1 x 1 x 2048 119 | x = F.dropout(x, training=self.training) 120 | # 1 x 1 x 2048 121 | x = x.view(x.size(0), -1) 122 | # 2048 
123 | x = self.fc(x) 124 | # 1000 (num_classes) 125 | if self.training and self.aux_logits: 126 | return x, aux 127 | return x 128 | 129 | 130 | class InceptionA(nn.Module): 131 | 132 | def __init__(self, in_channels, pool_features): 133 | super(InceptionA, self).__init__() 134 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 135 | 136 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 137 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 138 | 139 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 140 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 141 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 142 | 143 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 144 | 145 | def forward(self, x): 146 | branch1x1 = self.branch1x1(x) 147 | 148 | branch5x5 = self.branch5x5_1(x) 149 | branch5x5 = self.branch5x5_2(branch5x5) 150 | 151 | branch3x3dbl = self.branch3x3dbl_1(x) 152 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 153 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 154 | 155 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 156 | branch_pool = self.branch_pool(branch_pool) 157 | 158 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 159 | return torch.cat(outputs, 1) 160 | 161 | 162 | class InceptionB(nn.Module): 163 | 164 | def __init__(self, in_channels): 165 | super(InceptionB, self).__init__() 166 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 167 | 168 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 169 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 170 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 171 | 172 | def forward(self, x): 173 | branch3x3 = self.branch3x3(x) 174 | 175 | branch3x3dbl = self.branch3x3dbl_1(x) 176 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 177 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 178 | 179 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 180 | 181 | outputs = [branch3x3, branch3x3dbl, branch_pool] 182 | return torch.cat(outputs, 1) 183 | 184 | 185 | class InceptionC(nn.Module): 186 | 187 | def __init__(self, in_channels, channels_7x7): 188 | super(InceptionC, self).__init__() 189 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 190 | 191 | c7 = channels_7x7 192 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 193 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 194 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 195 | 196 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 197 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 198 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 199 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 200 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 201 | 202 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 203 | 204 | def forward(self, x): 205 | branch1x1 = self.branch1x1(x) 206 | 207 | branch7x7 = self.branch7x7_1(x) 208 | branch7x7 = self.branch7x7_2(branch7x7) 209 | branch7x7 = self.branch7x7_3(branch7x7) 210 | 211 | branch7x7dbl = self.branch7x7dbl_1(x) 212 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 213 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 214 | branch7x7dbl = 
self.branch7x7dbl_4(branch7x7dbl) 215 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 216 | 217 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 218 | branch_pool = self.branch_pool(branch_pool) 219 | 220 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 221 | return torch.cat(outputs, 1) 222 | 223 | 224 | class InceptionD(nn.Module): 225 | 226 | def __init__(self, in_channels): 227 | super(InceptionD, self).__init__() 228 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 229 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 230 | 231 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 232 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 233 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 234 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 235 | 236 | def forward(self, x): 237 | branch3x3 = self.branch3x3_1(x) 238 | branch3x3 = self.branch3x3_2(branch3x3) 239 | 240 | branch7x7x3 = self.branch7x7x3_1(x) 241 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 242 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 243 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 244 | 245 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 246 | outputs = [branch3x3, branch7x7x3, branch_pool] 247 | return torch.cat(outputs, 1) 248 | 249 | 250 | class InceptionE(nn.Module): 251 | 252 | def __init__(self, in_channels): 253 | super(InceptionE, self).__init__() 254 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 255 | 256 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 257 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 258 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 259 | 260 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 261 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 262 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 263 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 264 | 265 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 266 | 267 | def forward(self, x): 268 | branch1x1 = self.branch1x1(x) 269 | 270 | branch3x3 = self.branch3x3_1(x) 271 | branch3x3 = [ 272 | self.branch3x3_2a(branch3x3), 273 | self.branch3x3_2b(branch3x3), 274 | ] 275 | branch3x3 = torch.cat(branch3x3, 1) 276 | 277 | branch3x3dbl = self.branch3x3dbl_1(x) 278 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 279 | branch3x3dbl = [ 280 | self.branch3x3dbl_3a(branch3x3dbl), 281 | self.branch3x3dbl_3b(branch3x3dbl), 282 | ] 283 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 284 | 285 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 286 | branch_pool = self.branch_pool(branch_pool) 287 | 288 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 289 | return torch.cat(outputs, 1) 290 | 291 | 292 | class InceptionAux(nn.Module): 293 | 294 | def __init__(self, in_channels, num_classes): 295 | super(InceptionAux, self).__init__() 296 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 297 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 298 | self.conv1.stddev = 0.01 299 | self.fc = nn.Linear(768, num_classes) 300 | self.fc.stddev = 0.001 301 | 302 | def forward(self, x): 303 | # 17 x 17 x 768 304 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 305 | # 5 x 5 x 768 306 | x = self.conv0(x) 307 | # 5 x 5 x 128 
308 | x = self.conv1(x) 309 | # 1 x 1 x 768 310 | x = x.view(x.size(0), -1) 311 | # 768 312 | x = self.fc(x) 313 | # 1000 314 | return x 315 | 316 | 317 | class BasicConv2d(nn.Module): 318 | 319 | def __init__(self, in_channels, out_channels, **kwargs): 320 | super(BasicConv2d, self).__init__() 321 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 322 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 323 | 324 | def forward(self, x): 325 | x = self.conv(x) 326 | x = self.bn(x) 327 | return F.relu(x, inplace=True) 328 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import numpy as np 9 | 10 | 11 | from torchvision import datasets 12 | from functions import * 13 | from imagepreprocess import * 14 | from model_init import * 15 | from src.representation import * 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.parallel 19 | import torch.backends.cudnn as cudnn 20 | import torch.distributed as dist 21 | import torch.optim 22 | import torch.utils.data 23 | import torch.utils.data.distributed 24 | 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('data', metavar='DIR', 28 | help='path to dataset') 29 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 30 | help='model architecture: ') 31 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 32 | help='number of data loading workers (default: 4)') 33 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 34 | help='number of total epochs to run') 35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 36 | help='manual epoch number (useful on restarts)') 37 | parser.add_argument('-b', '--batch-size', default=256, type=int, 38 | metavar='N', help='mini-batch size (default: 256)') 39 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 40 | metavar='LR', help='initial learning rate') 41 | parser.add_argument('--lr-method', default='step', type=str, 42 | help='method of learning rate') 43 | parser.add_argument('--lr-params', default=[], dest='lr_params',nargs='*',type=float, 44 | action='append', help='params of lr method') 45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 46 | help='momentum') 47 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 48 | metavar='W', help='weight decay (default: 1e-4)') 49 | parser.add_argument('--print-freq', '-p', default=10, type=int, 50 | metavar='N', help='print frequency (default: 10)') 51 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 52 | help='path to latest checkpoint (default: none)') 53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 54 | help='evaluate model on validation set') 55 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 56 | help='use pre-trained model') 57 | parser.add_argument('--world-size', default=1, type=int, 58 | help='number of distributed processes') 59 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:50', type=str, 60 | help='url used to set up distributed training') 61 | parser.add_argument('--dist-backend', default='gloo', type=str, 62 | help='distributed backend') 63 | parser.add_argument('--seed', 
default=None, type=int, 64 | help='seed for initializing training. ') 65 | parser.add_argument('--gpu', default=None, type=int, 66 | help='GPU id to use.') 67 | parser.add_argument('--modeldir', default=None, type=str, 68 | help='director of checkpoint') 69 | parser.add_argument('--representation', default=None, type=str, 70 | help='define the representation method') 71 | parser.add_argument('--num-classes', default=None, type=int, 72 | help='define the number of classes') 73 | parser.add_argument('--freezed-layer', default=None, type=int, 74 | help='define the end of freezed layer') 75 | parser.add_argument('--store-model-everyepoch', dest='store_model_everyepoch', action='store_true', 76 | help='store checkpoint in every epoch') 77 | parser.add_argument('--classifier-factor', default=None, type=int, 78 | help='define the multiply factor of classifier') 79 | parser.add_argument('--benchmark', default=None, type=str, 80 | help='name of dataset') 81 | parser.add_argument('--sparsity', default=0.01, type=float, 82 | help='sparsity value (default: 0.01)') 83 | parser.add_argument('--iterations', default=5, type=int, 84 | help='solver iterations (default: 5)') 85 | parser.add_argument('--sicelr', default=5.0, type=float, 86 | help='solver learning rate (default: 5)') 87 | best_prec1 = 0 88 | 89 | 90 | def main(): 91 | global args, best_prec1 92 | args = parser.parse_args() 93 | print(args) 94 | if args.seed is not None: 95 | random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | cudnn.deterministic = True 98 | warnings.warn('You have chosen to seed training. ' 99 | 'This will turn on the CUDNN deterministic setting, ' 100 | 'which can slow down your training considerably! ' 101 | 'You may see unexpected behavior when restarting ' 102 | 'from checkpoints.') 103 | 104 | if args.gpu is not None: 105 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 106 | 'disable data parallelism.') 107 | 108 | args.distributed = args.world_size > 1 109 | 110 | if args.distributed: 111 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 112 | world_size=args.world_size) 113 | 114 | # create model 115 | if args.representation == 'GAvP': 116 | representation = {'function':GAvP, 117 | 'input_dim':2048, 118 | 'dimension_reduction':256} 119 | elif args.representation == 'COV': 120 | representation = {'function':COV, 121 | 'iterNum':5, 122 | 'is_sqrt':True, 123 | 'is_vec':True, 124 | 'input_dim':2048, 125 | 'dimension_reduction':None} 126 | elif args.representation == 'INVCOV': 127 | representation = {'function':INVCOV, 128 | 'iterNum':7, 129 | 'is_sqrt':True, 130 | 'is_vec':True, 131 | 'input_dim':2048, 132 | 'dimension_reduction':256} 133 | elif args.representation == 'SICE': 134 | representation = {'function':SICE, 135 | 'iterNum':args.iterations, 136 | 'is_sqrt':True, 137 | 'is_vec':True, 138 | 'input_dim':2048, 139 | 'dimension_reduction':256, 140 | 'sparsity_val':args.sparsity, 141 | 'sice_lrate':args.sicelr} 142 | else: 143 | warnings.warn('=> You did not choose a global image representation method!') 144 | representation = None # which for original vgg or alexnet 145 | 146 | model = get_model(args.arch, 147 | representation, 148 | args.num_classes, 149 | args.freezed_layer, 150 | pretrained=args.pretrained) 151 | 152 | # obtain learning rate 153 | LR = Learning_rate_generater(args.lr_method, args.lr_params, args.epochs) 154 | if args.pretrained: 155 | params_list = [{'params': model.features.parameters(), 'lr': args.lr, 156 | 'weight_decay': args.weight_decay},] 157 | params_list.append({'params': model.representation.parameters(), 'lr': args.lr, 158 | 'weight_decay': args.weight_decay}) 159 | params_list.append({'params': model.classifier.parameters(), 160 | 'lr': args.lr*args.classifier_factor, 161 | 'weight_decay': 0. 
if args.arch.startswith('vgg') else args.weight_decay}) 162 | else: 163 | params_list = [{'params': model.features.parameters(), 'lr': args.lr, 164 | 'weight_decay': args.weight_decay},] 165 | params_list.append({'params': model.representation.parameters(), 'lr': args.lr, 166 | 'weight_decay': args.weight_decay}) 167 | params_list.append({'params': model.classifier.parameters(), 168 | 'lr': args.lr*args.classifier_factor, 169 | 'weight_decay':args.weight_decay}) 170 | 171 | optimizer = torch.optim.AdamW(params_list, lr=args.lr) 172 | 173 | if args.gpu is not None: 174 | model = model.cuda(args.gpu) 175 | elif args.distributed: 176 | model.cuda() 177 | model = torch.nn.parallel.DistributedDataParallel(model) 178 | else: 179 | model.features = torch.nn.DataParallel(model.features) 180 | model.cuda() 181 | 182 | # define loss function (criterion) and optimizer 183 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 184 | 185 | 186 | # optionally resume from a checkpoint 187 | if args.resume: 188 | if os.path.isfile(args.resume): 189 | print("=> loading checkpoint '{}'".format(args.resume)) 190 | checkpoint = torch.load(args.resume) 191 | args.start_epoch = checkpoint['epoch'] 192 | best_prec1 = checkpoint['best_prec1'] 193 | model.load_state_dict(checkpoint['state_dict']) 194 | optimizer.load_state_dict(checkpoint['optimizer']) 195 | print("=> loaded checkpoint '{}' (epoch {})" 196 | .format(args.resume, checkpoint['epoch'])) 197 | else: 198 | print("=> no checkpoint found at '{}'".format(args.resume)) 199 | 200 | cudnn.benchmark = True 201 | 202 | # Data loading code 203 | traindir = os.path.join(args.data, 'train') 204 | valdir = os.path.join(args.data, 'val') 205 | train_transforms, val_transforms, evaluate_transforms = preprocess_strategy(args.benchmark) 206 | 207 | train_dataset = datasets.ImageFolder( 208 | traindir, 209 | train_transforms) 210 | 211 | if args.distributed: 212 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 213 | else: 214 | train_sampler = None 215 | 216 | train_loader = torch.utils.data.DataLoader( 217 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 218 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 219 | 220 | val_loader = torch.utils.data.DataLoader( 221 | datasets.ImageFolder(valdir, val_transforms), 222 | batch_size=args.batch_size, shuffle=False, 223 | num_workers=args.workers, pin_memory=True) 224 | 225 | ## init evaluation data loader 226 | if evaluate_transforms is not None: 227 | evaluate_loader = torch.utils.data.DataLoader( 228 | datasets.ImageFolder(valdir, evaluate_transforms), 229 | batch_size=args.batch_size, shuffle=False, 230 | num_workers=args.workers, pin_memory=True) 231 | 232 | if args.evaluate: 233 | if evaluate_transforms is not None: 234 | validate(evaluate_loader, model, criterion) 235 | validate(val_loader, model, criterion) 236 | return 237 | # make directory for storing checkpoint files 238 | if os.path.exists(args.modeldir) is not True: 239 | os.mkdir(args.modeldir) 240 | stats_ = stats(args.modeldir, args.start_epoch) 241 | for epoch in range(args.start_epoch, args.epochs): 242 | if args.distributed: 243 | train_sampler.set_epoch(epoch) 244 | adjust_learning_rate(optimizer, LR.lr_factor, epoch) 245 | # train for one epoch 246 | trainObj, top1, top5 = train(train_loader, model, criterion, optimizer, epoch) 247 | # evaluate on validation set 248 | valObj, prec1, prec5 = validate(val_loader, model, criterion) 249 | # update stats 250 | stats_._update(trainObj, 
top1, top5, valObj, prec1, prec5) 251 | # remember best prec@1 and save checkpoint 252 | is_best = prec1 > best_prec1 253 | best_prec1 = max(prec1, best_prec1) 254 | filename = [] 255 | if args.store_model_everyepoch: 256 | filename.append(os.path.join(args.modeldir, 'net-epoch-%s.pth.tar' % (epoch + 1))) 257 | else: 258 | filename.append(os.path.join(args.modeldir, 'checkpoint.pth.tar')) 259 | filename.append(os.path.join(args.modeldir, 'model_best.pth.tar')) 260 | save_checkpoint({ 261 | 'epoch': epoch + 1, 262 | 'arch': args.arch, 263 | 'state_dict': model.state_dict(), 264 | 'best_prec1': best_prec1, 265 | 'optimizer' : optimizer.state_dict(), 266 | }, is_best, filename) 267 | plot_curve(stats_, args.modeldir, True) 268 | data = stats_ 269 | sio.savemat(os.path.join(args.modeldir,'stats.mat'), {'data':data}) 270 | if evaluate_transforms is not None: 271 | model_file = os.path.join(args.modeldir, 'model_best.pth.tar') 272 | print("=> loading best model '{}'".format(model_file)) 273 | print("=> start evaluation") 274 | best_model = torch.load(model_file) 275 | model.load_state_dict(best_model['state_dict']) 276 | validate(evaluate_loader, model, criterion) 277 | 278 | 279 | 280 | 281 | def train(train_loader, model, criterion, optimizer, epoch): 282 | batch_time = AverageMeter() 283 | data_time = AverageMeter() 284 | losses = AverageMeter() 285 | top1 = AverageMeter() 286 | top5 = AverageMeter() 287 | 288 | # switch to train mode 289 | model.train() 290 | 291 | end = time.time() 292 | for i, (input, target) in enumerate(train_loader): 293 | # measure data loading time 294 | data_time.update(time.time() - end) 295 | 296 | if args.gpu is not None: 297 | input = input.cuda(args.gpu, non_blocking=True) 298 | target = target.cuda(args.gpu, non_blocking=True) 299 | 300 | # compute output 301 | output = model(input) 302 | loss = criterion(output, target) 303 | 304 | # measure accuracy and record loss 305 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 306 | losses.update(loss.item(), input.size(0)) 307 | top1.update(prec1[0], input.size(0)) 308 | top5.update(prec5[0], input.size(0)) 309 | 310 | # compute gradient and do SGD step 311 | optimizer.zero_grad() 312 | loss.backward() 313 | optimizer.step() 314 | 315 | # measure elapsed time 316 | batch_time.update(time.time() - end) 317 | end = time.time() 318 | 319 | if i % args.print_freq == 0: 320 | print('Epoch: [{0}][{1}/{2}]\t' 321 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 322 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 323 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 324 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 325 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 326 | epoch, i, len(train_loader), batch_time=batch_time, 327 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 328 | return losses.avg, top1.avg, top5.avg 329 | 330 | 331 | def validate(val_loader, model, criterion): 332 | batch_time = AverageMeter() 333 | losses = AverageMeter() 334 | top1 = AverageMeter() 335 | top5 = AverageMeter() 336 | 337 | # switch to evaluate mode 338 | model.eval() 339 | 340 | with torch.no_grad(): 341 | end = time.time() 342 | for i, (input, target) in enumerate(val_loader): 343 | if args.gpu is not None: 344 | input = input.cuda(args.gpu, non_blocking=True) 345 | target = target.cuda(args.gpu, non_blocking=True) 346 | 347 | # compute output 348 | if len(input.size()) > 4:# 5-D tensor 349 | bs, crops, ch, h, w = input.size() 350 | output = model(input.view(-1, ch, h, w)) 351 | # fuse scores among all crops 352 | 
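# added note: a 5-D batch comes from multi-crop evaluation transforms (e.g. ten-crop style);
# each crop is scored independently and the per-crop scores are averaged per image below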
output = output.view(bs, crops, -1).mean(dim=1) 353 | else: 354 | output = model(input) 355 | loss = criterion(output, target) 356 | 357 | # measure accuracy and record loss 358 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 359 | losses.update(loss.item(), input.size(0)) 360 | top1.update(prec1[0], input.size(0)) 361 | top5.update(prec5[0], input.size(0)) 362 | 363 | # measure elapsed time 364 | batch_time.update(time.time() - end) 365 | end = time.time() 366 | 367 | if i % args.print_freq == 0: 368 | print('Test: [{0}/{1}]\t' 369 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 370 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 371 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 372 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 373 | i, len(val_loader), batch_time=batch_time, loss=losses, 374 | top1=top1, top5=top5)) 375 | 376 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 377 | .format(top1=top1, top5=top5)) 378 | 379 | return losses.avg, top1.avg, top5.avg 380 | 381 | 382 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 383 | torch.save(state, filename[0]) 384 | if is_best: 385 | shutil.copyfile(filename[0], filename[1]) 386 | 387 | 388 | class AverageMeter(object): 389 | """Computes and stores the average and current value""" 390 | def __init__(self): 391 | self.reset() 392 | 393 | def reset(self): 394 | self.val = 0 395 | self.avg = 0 396 | self.sum = 0 397 | self.count = 0 398 | 399 | def update(self, val, n=1): 400 | self.val = val 401 | self.sum += val * n 402 | self.count += n 403 | self.avg = self.sum / self.count 404 | 405 | 406 | class Learning_rate_generater(object): 407 | """Generates a list of learning rate for each training epoch""" 408 | def __init__(self, method, params, total_epoch): 409 | if method == 'step': 410 | lr_factor, lr = self.step(params, total_epoch) 411 | elif method == 'log': 412 | lr_factor, lr = self.log(params, total_epoch) 413 | else: 414 | raise KeyError("=> undefined learning rate method '{}'" .format(method)) 415 | self.lr_factor = lr_factor 416 | self.lr = lr 417 | def step(self, params, total_epoch): 418 | decrease_until = params[0] 419 | decrease_num = len(decrease_until) 420 | base_factor = 0.1 421 | lr_factor = torch.ones(total_epoch, dtype=torch.double) 422 | lr = [args.lr] 423 | for num in range(decrease_num): 424 | if decrease_until[num] < total_epoch: 425 | lr_factor[int(decrease_until[num])] = base_factor 426 | for epoch in range(1,total_epoch): 427 | lr.append(lr[-1]*lr_factor[epoch]) 428 | return lr_factor, lr 429 | def log(self, params, total_epoch): 430 | params = params[0] 431 | left_range = params[0] 432 | right_range = params[1] 433 | np_lr = np.logspace(left_range, right_range, total_epoch) 434 | lr_factor = [1] 435 | lr = [np_lr[0]] 436 | for epoch in range(1, total_epoch): 437 | lr.append(np_lr[epoch]) 438 | lr_factor.append(np_lr[epoch]/np_lr[epoch-1]) 439 | if lr[0] != args.lr: 440 | args.lr = lr[0] 441 | return lr_factor, lr 442 | 443 | 444 | def adjust_learning_rate(optimizer, lr_factor, epoch): 445 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 446 | #lr = args.lr * (0.1 ** (epoch // 30)) 447 | groups = ['features'] 448 | groups.append('representation') 449 | groups.append('classifier') 450 | num_group = 0 451 | for param_group in optimizer.param_groups: 452 | param_group['lr'] *= lr_factor[epoch] 453 | print('the learning rate is set to {0:.8f} in {1:} part'.format(param_group['lr'], groups[num_group])) 454 | num_group += 1 455 | 456 | 457 | def 
accuracy(output, target, topk=(1,)): 458 | """Computes the precision@k for the specified values of k""" 459 | with torch.no_grad(): 460 | maxk = max(topk) 461 | batch_size = target.size(0) 462 | 463 | _, pred = output.topk(maxk, 1, True, True) 464 | pred = pred.t() 465 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 466 | 467 | res = [] 468 | for k in topk: 469 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 470 | res.append(correct_k.mul_(100.0 / batch_size)) 471 | return res 472 | 473 | 474 | if __name__ == '__main__': 475 | main() 476 | --------------------------------------------------------------------------------