├── README.md ├── crossx.png ├── crossxresnetavg.py ├── crossxresnetmix.py ├── crossxsenetavg.py ├── crossxsenetmix.py ├── demo.py ├── initialization.py ├── modellearning.py ├── prediction.py ├── utils ├── imdb.py ├── modelserial.py ├── mydataloader.py ├── myimagefolder.py └── receptivesize.py └── x-imdb ├── cubbirds-imdb.py ├── nabirds-imdb.py ├── stcars-imdb.py ├── stdogs-imdb.py └── vggaircraft-imdb.py /README.md: -------------------------------------------------------------------------------- 1 | # CrossX 2 | 3 | This is the PyTorch implementation of our ICCV 2019 paper ["Cross-X Learning for Fine-Grained Visual Categorization"](https://arxiv.org/abs/1909.04412). We experimented on 5 fine-grained benchmark datasets --- NABirds, CUB-200-2011, Stanford Cars, Stanford Dogs, and VGG-Aircraft. You should first download these datasets from their project homepages before running CrossX. 4 | 5 | 6 | ## Approach 7 | 8 | ![alt text](https://github.com/cswluo/CrossX/blob/crossx/crossx.png) 9 | 10 | ## Implementation 11 | 12 | Our implementation is based on PyTorch (>1.0), CUDA 9.0, and Python 3.5. 13 | 14 | An "x-imdb.py" script is provided for each dataset to generate Python pickle files, which are then used to prepare the train/val/trainval/test data. The very first step should be to run "x-imdb.py" in the folder of your dataset to generate the corresponding pickle file (imdb.pkl). 15 | 16 | - demo.py is used to train your own CrossX model from scratch. 17 | 18 | - prediction.py outputs classification accuracy using pretrained CrossX models. 19 | 20 | Due to the random generation of train/val/test data on some datasets, the classification accuracy may fluctuate slightly, but it should stay within a reasonable range. 21 | 22 | The pretrained CrossX models can be downloaded from [HERE](https://pan.baidu.com/s/1k6NaffqmbakH9Vng-CLxlg#list/path=%2F). 
If you plan to train your own CrossX model from scratch by using the SENet backbone, you need to download the pretrained SENet-50 weights from [HERE](https://pan.baidu.com/s/1803G5v0KDU0B_NS62Ril3A#list/path=%2F). 23 | 24 | ## Results 25 | 26 | | | CrossX-SENet-50 | CrossX-ResNet-50 | 27 | |:-------------|:---------------:|:----------------:| 28 | |NABirds |86.4% |86.2% | 29 | |CUB-200-2011 |87.5% |87.7% | 30 | |Stanford Cars |94.5% |94.6% | 31 | |Stanford Dogs |88.2% |88.9% | 32 | |VGG-Aircraft |92.7% |92.6% | 33 | 34 | 35 | ## Citation 36 | 37 | If you use CrossX in your research, please cite the paper: 38 | ``` 39 | @inproceedings{luowei@19iccv, 40 | author = {Wei Luo and Xitong Yang and Xianjie Mo and Yuheng Lu and Larry S. Davis and Ser-Nam Lim}, 41 | title = {Cross-X learning for fine-grained visual categorization}, 42 | booktitle = {ICCV}, 43 | year = {2019}, 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /crossx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cswluo/CrossX/0ed68c11e6d4e061bf8439a8ba028ba31e01f497/crossx.png -------------------------------------------------------------------------------- /crossxresnetavg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | 10 | eps = np.finfo(float).eps 11 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") # NOTE(review): torch.cuda.is_available() returns a bool, so the '> 0' comparison is redundant 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152'] # NOTE(review): only ResNet and resnet50 are defined in this module, so `from crossxresnetavg import *` would raise AttributeError on the other names - confirm and trim 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50':
'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class RegularLoss(nn.Module): 33 | 34 | def __init__(self, gamma=0, part_features=None, nparts=1): 35 | """ 36 | :param gamma: weight multiplied onto the summed correlation penalty returned by forward 37 | :param nparts: number of part feature tensors forward expects in its list argument; part_features is only registered as the 'part_features' buffer and is not read by forward 38 | """ 39 | super(RegularLoss, self).__init__() 40 | self.register_buffer('part_features', part_features) 41 | self.nparts = nparts 42 | self.gamma = gamma 43 | # self.batchsize = bs 44 | # self.ncrops = ncrops 45 | 46 | def forward(self, x): 47 | assert isinstance( 48 | x, list), "parts features should be presented in a list" 49 | corr_matrix = torch.zeros(self.nparts, self.nparts) # built on the CPU; only the final scalar is moved to `device` 50 | # x = [torch.div(xx, xx.norm(dim=1, keepdim=True)) for xx in x] 51 | for i in range(self.nparts): 52 | x[i] = x[i].squeeze() # NOTE(review): overwrites the caller's list entries in place; squeeze() also drops the batch dim when the batch size is 1, which breaks norm(dim=1) below - TODO confirm 53 | # x[i] = x[i].view(self.batchsize, self.ncrops, -1).mean(1) 54 | x[i] = torch.div(x[i], x[i].norm(dim=1, keepdim=True)) 55 | 56 | # # original design 57 | for i in range(self.nparts): 58 | for j in range(self.nparts): 59 | corr_matrix[i, j] = torch.mean(torch.mm(x[i], x[j].t())) 60 | if i == j: 61 | corr_matrix[i, j] = 1.0 - corr_matrix[i, j] 62 | 63 | regloss = torch.mul(torch.sum(torch.triu(corr_matrix)), self.gamma).to(device) # gamma * sum of the upper triangle (diagonal included): entry (i, j) is the mean row-wise cosine similarity between parts i and j, stored as 1 - that value on the diagonal 64 | 65 | return regloss 66 | 67 | class SELayer(nn.Module): # NOTE(review): not referenced by Bottleneck or ResNet in this module 68 | def __init__(self, channel, reduction=16): 69 | super(SELayer, self).__init__() 70 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 71 | self.fc = nn.Sequential( 72 | nn.Linear(channel, channel // reduction), 73 | nn.ReLU(inplace=True), 74 | nn.Linear(channel // reduction, channel), 75 | nn.Sigmoid() 76 | ) 77 | 78 | def forward(self,
x): 79 | b, c, _, _ = x.size() 80 | y = self.avg_pool(x).view(b, c) 81 | y = self.fc(y).view(b, c, 1, 1) 82 | return x * y 83 | 84 | 85 | class MELayer(nn.Module): # one squeeze, nparts excitations: forward returns a list of nparts channel-reweighted copies of x 86 | def __init__(self, channel, reduction=16, nparts=1): 87 | super(MELayer, self).__init__() 88 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 89 | self.nparts = nparts 90 | parts = list() 91 | for part in range(self.nparts): 92 | parts.append(nn.Sequential( 93 | nn.Linear(channel, channel // reduction), 94 | nn.ReLU(inplace=True), 95 | nn.Linear(channel // reduction, channel), 96 | nn.Sigmoid() 97 | )) 98 | self.parts = nn.Sequential(*parts) # used only as an indexable container (self.parts[i]); never called as a pipeline 99 | 100 | def forward(self, x): 101 | b, c, _, _ = x.size() 102 | y = self.avg_pool(x).view(b, c) 103 | 104 | meouts = list() 105 | for i in range(self.nparts): 106 | meouts.append(x * self.parts[i](y).view(b, c, 1, 1)) 107 | 108 | return meouts 109 | 110 | 111 | class Bottleneck(nn.Module): 112 | expansion = 4 113 | 114 | def __init__(self, inplanes, planes, stride=1, downsample=None, meflag=False, nparts=1, reduction=1): 115 | super(Bottleneck, self).__init__() 116 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 117 | self.bn1 = nn.BatchNorm2d(planes) 118 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 119 | self.bn2 = nn.BatchNorm2d(planes) 120 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 121 | self.bn3 = nn.BatchNorm2d(planes * 4) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.meflag = meflag 124 | if self.meflag: 125 | self.me = MELayer(planes * 4, nparts=nparts, reduction=reduction) 126 | 127 | self.downsample = downsample 128 | self.stride = stride 129 | 130 | def forward(self, x): 131 | residual = x 132 | 133 | out = self.conv1(x) 134 | out = self.bn1(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv2(out) 138 | out = self.bn2(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv3(out) 142 | out = self.bn3(out) 143 | 144 | if self.downsample is
not None: 145 | residual = self.downsample(x) 146 | 147 | if self.meflag: 148 | 149 | outreach = out.clone() # clone so the part branches see the pre-residual features; `out += residual` below modifies out in place 150 | parts = self.me(outreach) 151 | 152 | out += residual 153 | out = self.relu(out) 154 | 155 | for i in range(len(parts)): 156 | parts[i] = self.relu(parts[i] + residual) 157 | return out, parts 158 | else: 159 | out += residual 160 | out = self.relu(out) 161 | return out 162 | 163 | 164 | class ResNet(nn.Module): 165 | 166 | def __init__(self, block, layers, nparts=1, meflag=False, num_classes=1000): 167 | self.nparts = nparts 168 | self.nclass = num_classes 169 | self.meflag = meflag 170 | self.inplanes = 64 171 | super(ResNet, self).__init__() 172 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 173 | bias=False) 174 | self.bn1 = nn.BatchNorm2d(64) 175 | self.relu = nn.ReLU(inplace=True) 176 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 177 | self.layer1 = self._make_layer(block, 64, layers[0]) 178 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 179 | self.layer3 = self._make_layer(block, 256, layers[2], meflag=meflag, stride=2, nparts=nparts, reduction=256) 180 | self.layer4 = self._make_layer(block, 512, layers[3], meflag=meflag, stride=2, nparts=nparts, reduction=256) 181 | self.adpavgpool = nn.AdaptiveAvgPool2d(1) 182 | self.fc_ulti = nn.Linear(512 * block.expansion * nparts, num_classes) 183 | 184 | # if meflag: extra heads and fusion layers, built whenever nparts > 1 185 | if self.nparts > 1: 186 | self.adpmaxpool = nn.AdaptiveMaxPool2d(1) # NOTE(review): not used by forward in this module (all parts are average-pooled) 187 | self.fc_plty = nn.Linear(256 * block.expansion * nparts, num_classes) 188 | self.fc_cmbn = nn.Linear(256 * block.expansion * nparts, num_classes) 189 | 190 | # for the last convolutional layer 191 | self.conv2_1 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 192 | self.conv2_2 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 193 | 194 | # for the penultimate layer 195 | self.conv3_1 = nn.Conv2d(256 * block.expansion, 256 *
block.expansion, kernel_size=3, padding=1, bias=False) 196 | self.conv3_2 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 197 | self.bn3_1 = nn.BatchNorm2d(256 * block.expansion) 198 | self.bn3_2 = nn.BatchNorm2d(256 * block.expansion) 199 | 200 | if nparts == 3: # NOTE(review): per-part fusion layers exist for at most 3 parts; forward only handles i == 0, 1, 2 201 | self.conv2_3 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 202 | self.conv3_3 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 203 | self.bn3_3 = nn.BatchNorm2d(256 * block.expansion) 204 | 205 | 206 | for m in self.modules(): # conv weights ~ N(0, sqrt(2 / (kh * kw * out_channels))); BatchNorm weight 1, bias 0 207 | if isinstance(m, nn.Conv2d): 208 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 209 | m.weight.data.normal_(0, math.sqrt(2. / n)) 210 | elif isinstance(m, nn.BatchNorm2d): 211 | m.weight.data.fill_(1) 212 | m.bias.data.zero_() 213 | 214 | def _make_layer(self, block, planes, blocks, meflag=False, stride=1, nparts=1, reduction=1): 215 | downsample = None 216 | if stride != 1 or self.inplanes != planes * block.expansion: 217 | downsample = nn.Sequential( 218 | nn.Conv2d(self.inplanes, planes * block.expansion, 219 | kernel_size=1, stride=stride, bias=False), 220 | nn.BatchNorm2d(planes * block.expansion), 221 | ) 222 | 223 | layers = [] 224 | layers.append(block(self.inplanes, planes, stride, downsample)) 225 | self.inplanes = planes * block.expansion 226 | for i in range(1, blocks): 227 | if i == blocks - 1: # only the last block of the stage gets meflag/nparts, so the stage returns (out, parts) when meflag is True 228 | layers.append(block(self.inplanes, planes, meflag=meflag, nparts=nparts, reduction=reduction)) 229 | else: 230 | layers.append(block(self.inplanes, planes)) 231 | return nn.Sequential(*layers) 232 | 233 | def forward(self, x): 234 | x = self.conv1(x) 235 | x = self.bn1(x) 236 | x = self.relu(x) 237 | x = self.maxpool(x) 238 | 239 | x = self.layer1(x) 240 | x = self.layer2(x) 241 | if self.meflag: 242 | x, plty_parts = self.layer3(x) 243 | _, ulti_parts = self.layer4(x) # the fused layer4 output is discarded; only the per-part outputs are used 244 | 245 | cmbn_ftres = list() 246 | for i in
range(self.nparts): 247 | # pdb.set_trace() 248 | if i == 0: 249 | ulti_parts_iplt = F.interpolate(self.conv2_1(ulti_parts[i]), 28) # NOTE(review): the fixed size 28 must equal plty_parts[i]'s spatial size for torch.add below (e.g. a 448x448 input) - TODO confirm; plty_parts[i].shape[-2:] would follow the input size 250 | cmbn_ftres.append(self.adpavgpool(self.bn3_1(self.conv3_1(torch.add(plty_parts[i], ulti_parts_iplt))))) 251 | elif i == 1: 252 | ulti_parts_iplt = F.interpolate(self.conv2_2(ulti_parts[i]), 28) 253 | cmbn_ftres.append(self.adpavgpool(self.bn3_2(self.conv3_2(torch.add(plty_parts[i], ulti_parts_iplt))))) 254 | elif i == 2: 255 | ulti_parts_iplt = F.interpolate(self.conv2_3(ulti_parts[i]), 28) 256 | cmbn_ftres.append(self.adpavgpool(self.bn3_3(self.conv3_3(torch.add(plty_parts[i], ulti_parts_iplt))))) 257 | 258 | plty_parts[i] = self.adpavgpool(plty_parts[i]) 259 | ulti_parts[i] = self.adpavgpool(ulti_parts[i]) 260 | 261 | # for the penultimate layer 262 | #pdb.set_trace() 263 | xp = torch.cat(plty_parts, 1) 264 | xp = xp.view(xp.size(0), -1) 265 | xp = self.fc_plty(xp) 266 | 267 | # for the final layer 268 | xf = torch.cat(ulti_parts, 1) 269 | xf = xf.view(xf.size(0), -1) 270 | xf = self.fc_ulti(xf) 271 | 272 | # for the combined feature 273 | xc = torch.cat(cmbn_ftres, 1) 274 | xc = xc.view(xc.size(0), -1) 275 | xc = self.fc_cmbn(xc) 276 | 277 | return xf, xp, xc, ulti_parts, plty_parts, cmbn_ftres # logits (final, penultimate, combined) followed by the pooled per-part features 278 | 279 | else: 280 | x = self.layer3(x) 281 | x = self.layer4(x) 282 | 283 | x = self.adpavgpool(x) 284 | x = x.view(x.size(0), -1) 285 | x = self.fc_ulti(x) 286 | 287 | return x 288 | 289 | 290 | 291 | def resnet50(pretrained=False, modelpath=None, **kwargs): 292 | """Constructs a ResNet-50 model.
293 | 294 | Args: 295 | pretrained (bool): If True, loads the ImageNet resnet50 checkpoint with strict=False, so layers whose names are not in the checkpoint keep their initialization. modelpath is passed to model_zoo.load_url as its download directory; kwargs are forwarded to ResNet and must include nparts (meflag defaults to nparts > 1). 296 | """ 297 | #pdb.set_trace() 298 | if kwargs['nparts'] > 1: # NOTE(review): raises KeyError if the caller omits nparts 299 | # resnet with osme 300 | kwargs.setdefault('meflag', True) 301 | else: 302 | # the normal resnet 303 | kwargs.setdefault('meflag', False) 304 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 305 | if pretrained: 306 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], modelpath), strict=False) 307 | return model 308 | 309 | -------------------------------------------------------------------------------- /crossxresnetmix.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | 10 | eps = np.finfo(float).eps 11 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152'] # NOTE(review): only ResNet and resnet50 are defined in this module, so `from crossxresnetmix import *` would raise AttributeError on the other names - confirm and trim 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class RegularLoss(nn.Module): 33 | 34 | def __init__(self, gamma=0, part_features=None, nparts=1): 35 | """ 36 | :param gamma: weight multiplied onto the summed correlation penalty returned by forward 37 | :param nparts: number of part feature tensors forward expects in its list argument; part_features is accepted but unused here (its register_buffer call is commented out) 38 | """ 39 | super(RegularLoss, self).__init__() 40 |
#self.register_buffer('part_features', part_features) 41 | self.nparts = nparts 42 | self.gamma = gamma 43 | 44 | def forward(self, x): 45 | assert isinstance(x, list), "parts features should be presented in a list" 46 | corr_matrix = torch.zeros(self.nparts, self.nparts) 47 | for i in range(self.nparts): 48 | x[i] = x[i].squeeze() # NOTE(review): overwrites the caller's list entries in place; squeeze() also drops the batch dim when the batch size is 1, which breaks norm(dim=1) below - TODO confirm 49 | x[i] = torch.div(x[i], x[i].norm(dim=1, keepdim=True)) 50 | 51 | # original design 52 | for i in range(self.nparts): 53 | for j in range(self.nparts): 54 | corr_matrix[i, j] = torch.mean(torch.mm(x[i], x[j].t())) 55 | if i == j: 56 | corr_matrix[i, j] = 1.0 - corr_matrix[i, j] 57 | regloss = torch.mul(torch.sum(torch.triu(corr_matrix)), self.gamma).to(device) # gamma * sum of the upper triangle (diagonal included); corr_matrix stays on the CPU until this .to(device) 58 | 59 | return regloss 60 | 61 | 62 | class SELayer(nn.Module): # NOTE(review): not referenced by Bottleneck or ResNet in this module 63 | def __init__(self, channel, reduction=16): 64 | super(SELayer, self).__init__() 65 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 66 | self.fc = nn.Sequential( 67 | nn.Linear(channel, channel // reduction), 68 | nn.ReLU(inplace=True), 69 | nn.Linear(channel // reduction, channel), 70 | nn.Sigmoid() 71 | ) 72 | 73 | def forward(self, x): 74 | b, c, _, _ = x.size() 75 | y = self.avg_pool(x).view(b, c) 76 | y = self.fc(y).view(b, c, 1, 1) 77 | return x * y 78 | 79 | 80 | class MELayer(nn.Module): # one squeeze, nparts excitations: forward returns a list of nparts channel-reweighted copies of x 81 | def __init__(self, channel, reduction=16, nparts=1): 82 | super(MELayer, self).__init__() 83 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 84 | self.nparts = nparts 85 | parts = list() 86 | for part in range(self.nparts): 87 | parts.append(nn.Sequential( 88 | nn.Linear(channel, channel // reduction), 89 | nn.ReLU(inplace=True), 90 | nn.Linear(channel // reduction, channel), 91 | nn.Sigmoid() 92 | )) 93 | self.parts = nn.Sequential(*parts) 94 | 95 | def forward(self, x): 96 | b, c, _, _ = x.size() 97 | y = self.avg_pool(x).view(b, c) 98 | 99 | meouts = list() 100 | for i in range(self.nparts): 101 | meouts.append(x * self.parts[i](y).view(b, c, 1, 1)) 102 | 103 | return meouts 104 | 105 | 106 | class Bottleneck(nn.Module):
107 | expansion = 4 108 | 109 | def __init__(self, inplanes, planes, stride=1, downsample=None, meflag=False, nparts=1, reduction=1): 110 | super(Bottleneck, self).__init__() 111 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 112 | self.bn1 = nn.BatchNorm2d(planes) 113 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 114 | self.bn2 = nn.BatchNorm2d(planes) 115 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(planes * 4) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.meflag = meflag 119 | if self.meflag: 120 | self.me = MELayer(planes * 4, nparts=nparts, reduction=reduction) 121 | 122 | self.downsample = downsample 123 | self.stride = stride 124 | 125 | def forward(self, x): 126 | residual = x 127 | 128 | out = self.conv1(x) 129 | out = self.bn1(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv2(out) 133 | out = self.bn2(out) 134 | out = self.relu(out) 135 | 136 | out = self.conv3(out) 137 | out = self.bn3(out) 138 | 139 | if self.downsample is not None: 140 | residual = self.downsample(x) 141 | 142 | if self.meflag: 143 | 144 | outreach = out.clone() # clone so the part branches see the pre-residual features; `out += residual` below modifies out in place 145 | parts = self.me(outreach) 146 | 147 | out += residual 148 | out = self.relu(out) 149 | 150 | for i in range(len(parts)): 151 | parts[i] = self.relu(parts[i] + residual) 152 | return out, parts 153 | else: 154 | out += residual 155 | out = self.relu(out) 156 | return out 157 | 158 | 159 | class ResNet(nn.Module): 160 | 161 | def __init__(self, block, layers, nparts=1, meflag=False, num_classes=1000): 162 | self.nparts = nparts 163 | self.nclass = num_classes 164 | self.meflag = meflag 165 | self.inplanes = 64 166 | super(ResNet, self).__init__() 167 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 168 | bias=False) 169 | self.bn1 = nn.BatchNorm2d(64) 170 | self.relu = nn.ReLU(inplace=True) 171 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
172 | self.layer1 = self._make_layer(block, 64, layers[0]) 173 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 174 | self.layer3 = self._make_layer(block, 256, layers[2], meflag=meflag, stride=2, nparts=nparts, reduction=256) 175 | self.layer4 = self._make_layer(block, 512, layers[3], meflag=meflag, stride=2, nparts=nparts, reduction=256) 176 | self.adpavgpool = nn.AdaptiveAvgPool2d(1) 177 | 178 | # if meflag == False, vanilla resnet 179 | self.fc_ulti = nn.Linear(512 * block.expansion * nparts, num_classes) 180 | 181 | # if meflag == True, multiple branch outputs 182 | if self.nparts > 1: 183 | self.adpmaxpool = nn.AdaptiveMaxPool2d(1) 184 | self.fc_plty = nn.Linear(256 * block.expansion * nparts, num_classes) 185 | self.fc_cmbn = nn.Linear(256 * block.expansion * nparts, num_classes) 186 | 187 | # dimension reducing for the last convolutional layer 188 | self.conv2_1 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 189 | self.conv2_2 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 190 | 191 | # combining feature maps from the penultimate layer and the dimension-reduced final layer 192 | self.conv3_1 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 193 | self.conv3_2 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 194 | self.bn3_1 = nn.BatchNorm2d(256 * block.expansion) 195 | self.bn3_2 = nn.BatchNorm2d(256 * block.expansion) 196 | 197 | if nparts == 3: # NOTE(review): per-part fusion layers exist for at most 3 parts; forward only handles i == 0, 1, 2 198 | self.conv2_3 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 199 | self.conv3_3 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 200 | self.bn3_3 = nn.BatchNorm2d(256 * block.expansion) 201 | 202 | 203 | for m in self.modules(): 204 | if isinstance(m, nn.Conv2d): 205 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 206 |
m.weight.data.normal_(0, math.sqrt(2. / n)) 207 | elif isinstance(m, nn.BatchNorm2d): 208 | m.weight.data.fill_(1) 209 | m.bias.data.zero_() 210 | 211 | def _make_layer(self, block, planes, blocks, meflag=False, stride=1, nparts=1, reduction=1): 212 | downsample = None 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | nn.Conv2d(self.inplanes, planes * block.expansion, 216 | kernel_size=1, stride=stride, bias=False), 217 | nn.BatchNorm2d(planes * block.expansion), 218 | ) 219 | 220 | layers = [] 221 | layers.append(block(self.inplanes, planes, stride, downsample)) 222 | self.inplanes = planes * block.expansion 223 | for i in range(1, blocks): 224 | if i == blocks - 1: # only the last block of the stage gets meflag/nparts, so the stage returns (out, parts) when meflag is True 225 | layers.append(block(self.inplanes, planes, meflag=meflag, nparts=nparts, reduction=reduction)) 226 | else: 227 | layers.append(block(self.inplanes, planes)) 228 | return nn.Sequential(*layers) 229 | 230 | def forward(self, x): 231 | x = self.conv1(x) 232 | x = self.bn1(x) 233 | x = self.relu(x) 234 | x = self.maxpool(x) 235 | 236 | x = self.layer1(x) 237 | x = self.layer2(x) 238 | if self.meflag: 239 | x, plty_parts = self.layer3(x) 240 | _, ulti_parts = self.layer4(x) 241 | 242 | cmbn_ftres = list() 243 | for i in range(self.nparts): 244 | # pdb.set_trace() 245 | if i == 0: 246 | ulti_parts_iplt = F.interpolate(self.conv2_1(ulti_parts[i]), 28) # NOTE(review): the fixed size 28 must equal plty_parts[i]'s spatial size for torch.add below (e.g. a 448x448 input) - TODO confirm; plty_parts[i].shape[-2:] would follow the input size 247 | cmbn_ftres.append(self.adpavgpool(self.bn3_1(self.conv3_1(torch.add(plty_parts[i], ulti_parts_iplt))))) 248 | elif i == 1: 249 | ulti_parts_iplt = F.interpolate(self.conv2_2(ulti_parts[i]), 28) 250 | cmbn_ftres.append(self.adpavgpool(self.bn3_2(self.conv3_2(torch.add(plty_parts[i], ulti_parts_iplt))))) 251 | elif i == 2: 252 | ulti_parts_iplt = F.interpolate(self.conv2_3(ulti_parts[i]), 28) 253 | cmbn_ftres.append(self.adpavgpool(self.bn3_3(self.conv3_3(torch.add(plty_parts[i], ulti_parts_iplt))))) 254 | 255 | plty_parts[i] = self.adpmaxpool(plty_parts[i]) # penultimate parts are max-pooled in this variant (crossxresnetavg.py average-pools them) 256 | ulti_parts[i] =
self.adpavgpool(ulti_parts[i]) 257 | 258 | # for the penultimate layer 259 | #pdb.set_trace() 260 | xp = torch.cat(plty_parts, 1) 261 | xp = xp.view(xp.size(0), -1) 262 | xp = self.fc_plty(xp) 263 | 264 | # for the final layer 265 | xf = torch.cat(ulti_parts, 1) 266 | xf = xf.view(xf.size(0), -1) 267 | xf = self.fc_ulti(xf) 268 | 269 | # for the combined feature 270 | xc = torch.cat(cmbn_ftres, 1) 271 | xc = xc.view(xc.size(0), -1) 272 | xc = self.fc_cmbn(xc) 273 | 274 | return xf, xp, xc, ulti_parts, plty_parts, cmbn_ftres # logits (final, penultimate, combined) followed by the pooled per-part features 275 | 276 | else: 277 | x = self.layer3(x) 278 | x = self.layer4(x) 279 | 280 | x = self.adpavgpool(x) 281 | x = x.view(x.size(0), -1) 282 | x = self.fc_ulti(x) 283 | 284 | return x 285 | 286 | 287 | 288 | def resnet50(pretrained=False, modelpath=None, **kwargs): 289 | """Constructs a ResNet-50 model. 290 | 291 | Args: 292 | pretrained (bool): If True, initialize crossx using params of resnet50 pre-trained on ImageNet (loaded with strict=False, so layers whose names are not in the checkpoint keep their initialization; modelpath is passed to model_zoo.load_url as its download directory). kwargs are forwarded to ResNet and must include nparts. 293 | """ 294 | #pdb.set_trace() 295 | if kwargs['nparts'] > 1: # NOTE(review): raises KeyError if the caller omits nparts 296 | # resnet with osme 297 | kwargs.setdefault('meflag', True) 298 | else: 299 | # the normal resnet 300 | kwargs.setdefault('meflag', False) 301 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], modelpath), strict=False) 304 | return model 305 | -------------------------------------------------------------------------------- /crossxsenetavg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import ResNet # NOTE(review): unused here; this module defines its own SeNet 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | eps = np.finfo(float).eps 10 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | 16
| ##################################### Loss functions 17 | class RegularLoss(nn.Module): 18 | 19 | def __init__(self, gamma=0, part_features=None, nparts=1): 20 | """ 21 | :param gamma: weight multiplied onto the summed correlation penalty returned by forward 22 | :param nparts: number of part feature tensors forward expects in its list argument; part_features is only registered as the 'part_features' buffer and is not read by forward 23 | """ 24 | super(RegularLoss, self).__init__() 25 | self.register_buffer('part_features', part_features) 26 | self.nparts = nparts 27 | self.gamma = gamma 28 | # self.batchsize = bs 29 | # self.ncrops = ncrops 30 | 31 | def forward(self, x): 32 | assert isinstance(x, list), "parts features should be presented in a list" 33 | corr_matrix = torch.zeros(self.nparts, self.nparts) 34 | # x = [torch.div(xx, xx.norm(dim=1, keepdim=True)) for xx in x] 35 | for i in range(self.nparts): 36 | x[i] = x[i].squeeze() # NOTE(review): overwrites the caller's list entries in place; squeeze() also drops the batch dim when the batch size is 1, which breaks norm(dim=1) below - TODO confirm 37 | x[i] = torch.div(x[i], x[i].norm(dim=1, keepdim=True)) 38 | 39 | for i in range(self.nparts): 40 | for j in range(self.nparts): 41 | corr_matrix[i, j] = torch.mean(torch.mm(x[i], x[j].t())) 42 | if i == j: 43 | corr_matrix[i, j] = 1.0 - corr_matrix[i, j] 44 | 45 | return torch.mul(torch.sum(torch.triu(corr_matrix)), self.gamma).to(device) # gamma * sum of the upper triangle (diagonal included); corr_matrix stays on the CPU until this .to(device) 46 | 47 | 48 | ##################################### Squeeze-and-Excitation modules 49 | class SELayer(nn.Module): 50 | def __init__(self, channel, reduction=16): 51 | super(SELayer, self).__init__() 52 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 53 | self.fc = nn.Sequential( 54 | nn.Linear(channel, channel // reduction), 55 | nn.ReLU(inplace=True), 56 | nn.Linear(channel // reduction, channel), 57 | nn.Sigmoid() 58 | ) 59 | 60 | def forward(self, x): 61 | b, c, _, _ = x.size() 62 | y = self.avg_pool(x).view(b, c) 63 | y = self.fc(y).view(b, c, 1, 1) 64 | return x * y 65 | 66 | 67 | class MELayer(nn.Module): 68 | def __init__(self, channel, reduction=16, nparts=1): 69 | super(MELayer, self).__init__() 70 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 71 | self.nparts = nparts 72 | parts = list() 73 | for part in range(self.nparts): 74 | parts.append(nn.Sequential( 75 |
nn.Linear(channel, channel // reduction), 76 | nn.ReLU(inplace=True), 77 | nn.Linear(channel // reduction, channel), 78 | nn.Sigmoid() 79 | )) 80 | self.parts = nn.Sequential(*parts) 81 | self.dresponse = nn.Sequential( # extra SE-style excitation whose output re-weights x for the block's main path 82 | nn.Linear(channel, channel // reduction), 83 | nn.ReLU(inplace=True), 84 | nn.Linear(channel // reduction, channel), 85 | nn.Sigmoid() 86 | ) 87 | 88 | def forward(self, x): 89 | b, c, _, _ = x.size() 90 | y = self.avg_pool(x).view(b, c) 91 | 92 | meouts = list() 93 | for i in range(self.nparts): 94 | meouts.append(x * self.parts[i](y).view(b, c, 1, 1)) 95 | 96 | y = self.dresponse(y).view(b, c, 1, 1) 97 | return x * y, meouts # (x re-weighted by dresponse, list of nparts per-part re-weighted copies of x) 98 | 99 | 100 | ##################################### SEBlocks 101 | class SEBottleneck(nn.Module): 102 | expansion = 4 103 | 104 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16, meflag=False, nparts=1): 105 | super(SEBottleneck, self).__init__() 106 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 107 | self.bn1 = nn.BatchNorm2d(planes) 108 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 109 | self.bn2 = nn.BatchNorm2d(planes) 110 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(planes * 4) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.meflag = meflag 114 | if self.meflag: 115 | self.se = MELayer(planes * 4, reduction=reduction, nparts=nparts) 116 | else: 117 | self.se = SELayer(planes * 4, reduction=reduction) 118 | self.downsample = downsample 119 | self.stride = stride 120 | 121 | def forward(self, x): 122 | residual = x 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv3(out) 133 | out = self.bn3(out) 134 | 135 | if self.downsample is not None: 136 | residual = self.downsample(x) 137 | 138 | if self.meflag: 139 |
out, parts = self.se(out) # MELayer returns a new tensor for out, so the in-place `out += residual` below does not alias the part tensors 140 | 141 | out += residual 142 | out = self.relu(out) 143 | 144 | for i in range(len(parts)): 145 | parts[i] = self.relu(parts[i] + residual) 146 | 147 | return out, parts 148 | else: 149 | out = self.se(out) 150 | out += residual 151 | out = self.relu(out) 152 | return out 153 | 154 | ###################################### ResNet framework 155 | class SeNet(nn.Module): 156 | 157 | def __init__(self, block, layers, num_classes=1000, rd=[16, 16, 16, 16], nparts=1, meflag=False): 158 | """ 159 | :param rd: SE reduction ratios for layer1..layer4 (only indexed, never mutated, so the list default is not shared mutable state) 160 | :param meflag: True for CrossX SENet (multi-part MELayer in the last block of layer3/layer4), False for the default SENet 161 | :param nparts: number of parts; must be 1 when meflag is False 162 | """ 163 | super(SeNet, self).__init__() 164 | 165 | self.inplanes = 64 166 | self.meflag = meflag 167 | self.rd = rd 168 | self.nparts = nparts 169 | if not self.meflag: 170 | assert self.nparts == 1 171 | 172 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 173 | self.bn1 = nn.BatchNorm2d(64) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 176 | self.layer1 = self._make_layer(block, 64, layers[0], reduction=self.rd[0]) 177 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, reduction=self.rd[1]) 178 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, reduction=self.rd[2], meflag=meflag, nparts=nparts) 179 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, reduction=self.rd[3], meflag=meflag, nparts=nparts) 180 | self.adpavgpool = nn.AdaptiveAvgPool2d(1) 181 | self.fc_ulti = nn.Linear(512 * block.expansion * nparts, num_classes) 182 | 183 | if self.nparts > 1: 184 | self.adpmaxpool = nn.AdaptiveMaxPool2d(1) # NOTE(review): not used by forward in this module (all parts are average-pooled) 185 | self.fc_plty = nn.Linear(256 * block.expansion * nparts, num_classes) 186 | self.fc_cmbn = nn.Linear(256 * block.expansion * nparts, num_classes) 187 | 188 | # for the last convolutional layer 189 | self.conv2_1 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False)
190 | self.conv2_2 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 191 | # self.bn2_1 = nn.BatchNorm2d(256 * block.expansion) 192 | # self.bn2_2 = nn.BatchNorm2d(256 * block.expansion) 193 | 194 | # for the penultimate layer 195 | self.conv3_1 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 196 | self.conv3_2 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 197 | self.bn3_1 = nn.BatchNorm2d(256 * block.expansion) 198 | self.bn3_2 = nn.BatchNorm2d(256 * block.expansion) 199 | 200 | if nparts == 3: # NOTE(review): per-part fusion layers exist for at most 3 parts; forward only handles i == 0, 1, 2 201 | self.conv2_3 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 202 | self.conv3_3 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 203 | self.bn3_3 = nn.BatchNorm2d(256 * block.expansion) 204 | 205 | 206 | # initializing params: conv weights ~ N(0, sqrt(2 / (kh * kw * out_channels))); BatchNorm weight 1, bias 0 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 210 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 211 | elif isinstance(m, nn.BatchNorm2d): 212 | m.weight.data.fill_(1) 213 | m.bias.data.zero_() 214 | 215 | def _make_layer(self, block, planes, blocks, stride=1, reduction=16, meflag=False, nparts=1): 216 | downsample = None 217 | if stride != 1 or self.inplanes != planes * block.expansion: 218 | downsample = nn.Sequential( 219 | nn.Conv2d(self.inplanes, planes * block.expansion, 220 | kernel_size=1, stride=stride, bias=False), 221 | nn.BatchNorm2d(planes * block.expansion), 222 | ) 223 | 224 | layers = [] 225 | layers.append(block(self.inplanes, planes, stride, downsample, reduction)) 226 | self.inplanes = planes * block.expansion 227 | for i in range(1, blocks): 228 | if i == blocks - 1 and meflag is True: # only the last block of the stage gets meflag/nparts, so the stage returns (out, parts) when meflag is True 229 | layers.append(block(self.inplanes, planes, reduction=reduction, meflag=meflag, nparts=nparts)) 230 | else: 231 | layers.append(block(self.inplanes, planes, reduction=reduction)) 232 | 233 | return nn.Sequential(*layers) 234 | 235 | def forward(self, x): 236 | x = self.conv1(x) 237 | x = self.bn1(x) 238 | x = self.relu(x) 239 | x = self.maxpool(x) 240 | x = self.layer1(x) 241 | x = self.layer2(x) 242 | 243 | 244 | if self.meflag: 245 | 246 | x, plty_parts = self.layer3(x) 247 | _, ulti_parts = self.layer4(x) 248 | 249 | cmbn_ftres = list() 250 | for i in range(self.nparts): 251 | if i == 0: 252 | ulti_parts_iplt = F.interpolate(self.conv2_1(ulti_parts[i]), 28) # NOTE(review): the fixed size 28 must equal plty_parts[i]'s spatial size for torch.add below (e.g. a 448x448 input) - TODO confirm; plty_parts[i].shape[-2:] would follow the input size 253 | cmbn_ftres.append(self.adpavgpool(self.bn3_1(self.conv3_1(torch.add(plty_parts[i], ulti_parts_iplt))))) 254 | elif i == 1: 255 | ulti_parts_iplt = F.interpolate(self.conv2_2(ulti_parts[i]), 28) 256 | cmbn_ftres.append(self.adpavgpool(self.bn3_2(self.conv3_2(torch.add(plty_parts[i], ulti_parts_iplt))))) 257 | elif i == 2: 258 | ulti_parts_iplt = F.interpolate(self.conv2_3(ulti_parts[i]), 28) 259 | cmbn_ftres.append(self.adpavgpool(self.bn3_3(self.conv3_3(torch.add(plty_parts[i], ulti_parts_iplt))))) 260 | 261 | plty_parts[i] = self.adpavgpool(plty_parts[i]) 262 | ulti_parts[i] =
self.adpavgpool(ulti_parts[i]) 263 | 264 | # for the penultimate layer 265 | xp = torch.cat(plty_parts, 1) 266 | xp = xp.view(xp.size(0), -1) 267 | xp = self.fc_plty(xp) 268 | 269 | # for the final layer 270 | xf = torch.cat(ulti_parts, 1) 271 | xf = xf.view(xf.size(0), -1) 272 | xf = self.fc_ulti(xf) 273 | 274 | # for the combined feature 275 | xc = torch.cat(cmbn_ftres, 1) 276 | xc = xc.view(xc.size(0), -1) 277 | xc = self.fc_cmbn(xc) 278 | 279 | return xf, xp, xc, ulti_parts, plty_parts, cmbn_ftres # logits (final, penultimate, combined) followed by the pooled per-part features 280 | else: 281 | x = self.layer3(x) 282 | x = self.layer4(x) 283 | 284 | x = self.adpavgpool(x) 285 | x = x.view(x.size(0), -1) 286 | x = self.fc_ulti(x) 287 | return x 288 | 289 | 290 | 291 | ########################################## Models 292 | def senet50(num_classes=200, nparts=1, **kwargs): # builds SeNet with the [3, 4, 6, 3] SEBottleneck layout; meflag defaults to nparts > 1. NOTE(review): only kwargs['meflag'] is read (other keyword arguments are silently ignored), and an explicit meflag=False with nparts > 1 trips SeNet's nparts == 1 assertion 293 | 294 | if nparts > 1: 295 | # resnet with osme 296 | kwargs.setdefault('meflag', True) 297 | else: 298 | # the normal resnet 299 | kwargs.setdefault('meflag', False) 300 | 301 | 302 | rd = [16, 32, 64, 128] # SE reduction ratios for layer1..layer4 303 | 304 | if kwargs['meflag']: 305 | model = SeNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes, rd=rd, nparts=nparts, meflag=True) 306 | else: 307 | # vanilla senet 308 | model = SeNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes, rd=rd, nparts=nparts, meflag=False) 309 | 310 | return model 311 | 312 | -------------------------------------------------------------------------------- /crossxsenetmix.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import ResNet 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | eps = np.finfo(float).eps 10 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | 16 |
##################################### Loss functions 17 | class RegularLoss(nn.Module): 18 | 19 | def __init__(self, gamma=0, part_features=None, nparts=1): 20 | """ 21 | :param bs: batch size 22 | :param ncrops: number of crops used at constructing dataset 23 | """ 24 | super(RegularLoss, self).__init__() 25 | self.register_buffer('part_features', part_features) 26 | self.nparts = nparts 27 | self.gamma = gamma 28 | # self.batchsize = bs 29 | # self.ncrops = ncrops 30 | 31 | def forward(self, x): 32 | assert isinstance(x, list), "parts features should be presented in a list" 33 | corr_matrix = torch.zeros(self.nparts, self.nparts) 34 | # x = [torch.div(xx, xx.norm(dim=1, keepdim=True)) for xx in x] 35 | for i in range(self.nparts): 36 | x[i] = x[i].squeeze() 37 | x[i] = torch.div(x[i], x[i].norm(dim=1, keepdim=True)) 38 | 39 | for i in range(self.nparts): 40 | for j in range(self.nparts): 41 | corr_matrix[i, j] = torch.mean(torch.mm(x[i], x[j].t())) 42 | if i == j: 43 | corr_matrix[i, j] = 1.0 - corr_matrix[i, j] 44 | 45 | return torch.mul(torch.sum(torch.triu(corr_matrix)), self.gamma).to(device) 46 | 47 | 48 | ##################################### Squeeze-and-Excitation modules 49 | class SELayer(nn.Module): 50 | def __init__(self, channel, reduction=16): 51 | super(SELayer, self).__init__() 52 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 53 | self.fc = nn.Sequential( 54 | nn.Linear(channel, channel // reduction), 55 | nn.ReLU(inplace=True), 56 | nn.Linear(channel // reduction, channel), 57 | nn.Sigmoid() 58 | ) 59 | 60 | def forward(self, x): 61 | b, c, _, _ = x.size() 62 | y = self.avg_pool(x).view(b, c) 63 | y = self.fc(y).view(b, c, 1, 1) 64 | return x * y 65 | 66 | 67 | class MELayer(nn.Module): 68 | def __init__(self, channel, reduction=16, nparts=1): 69 | super(MELayer, self).__init__() 70 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 71 | self.nparts = nparts 72 | parts = list() 73 | for part in range(self.nparts): 74 | parts.append(nn.Sequential( 75 | 
nn.Linear(channel, channel // reduction), 76 | nn.ReLU(inplace=True), 77 | nn.Linear(channel // reduction, channel), 78 | nn.Sigmoid() 79 | )) 80 | self.parts = nn.Sequential(*parts) 81 | self.dresponse = nn.Sequential( 82 | nn.Linear(channel, channel // reduction), 83 | nn.ReLU(inplace=True), 84 | nn.Linear(channel // reduction, channel), 85 | nn.Sigmoid() 86 | ) 87 | 88 | def forward(self, x): 89 | b, c, _, _ = x.size() 90 | y = self.avg_pool(x).view(b, c) 91 | 92 | meouts = list() 93 | for i in range(self.nparts): 94 | meouts.append(x * self.parts[i](y).view(b, c, 1, 1)) 95 | 96 | y = self.dresponse(y).view(b, c, 1, 1) 97 | return x * y, meouts 98 | 99 | 100 | class SEBottleneck(nn.Module): 101 | expansion = 4 102 | 103 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16, meflag=False, nparts=1): 104 | super(SEBottleneck, self).__init__() 105 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 106 | self.bn1 = nn.BatchNorm2d(planes) 107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 108 | self.bn2 = nn.BatchNorm2d(planes) 109 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 110 | self.bn3 = nn.BatchNorm2d(planes * 4) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.meflag = meflag 113 | if self.meflag: 114 | self.se = MELayer(planes * 4, reduction=reduction, nparts=nparts) 115 | else: 116 | self.se = SELayer(planes * 4, reduction=reduction) 117 | self.downsample = downsample 118 | self.stride = stride 119 | 120 | def forward(self, x): 121 | residual = x 122 | 123 | out = self.conv1(x) 124 | out = self.bn1(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv2(out) 128 | out = self.bn2(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv3(out) 132 | out = self.bn3(out) 133 | 134 | if self.downsample is not None: 135 | residual = self.downsample(x) 136 | 137 | if self.meflag: 138 | out, parts = self.se(out) 139 | 140 | out += residual 
141 | out = self.relu(out) 142 | 143 | for i in range(len(parts)): 144 | parts[i] = self.relu(parts[i] + residual) 145 | 146 | return out, parts 147 | else: 148 | out = self.se(out) 149 | out += residual 150 | out = self.relu(out) 151 | return out 152 | 153 | ###################################### ResNet framework 154 | class SeNet(nn.Module): 155 | 156 | def __init__(self, block, layers, num_classes=1000, rd=[16, 16, 16, 16], nparts=1, meflag=False): 157 | """ 158 | :param rd: reductions in SENet 159 | :param meflag: Ture for crossx senet, Flase for default senet 160 | 161 | """ 162 | super(SeNet, self).__init__() 163 | 164 | self.inplanes = 64 165 | self.meflag = meflag 166 | self.rd = rd 167 | self.nparts = nparts 168 | if not self.meflag: 169 | assert self.nparts == 1 170 | 171 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 172 | self.bn1 = nn.BatchNorm2d(64) 173 | self.relu = nn.ReLU(inplace=True) 174 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 175 | self.layer1 = self._make_layer(block, 64, layers[0], reduction=self.rd[0]) 176 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, reduction=self.rd[1]) 177 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, reduction=self.rd[2], meflag=meflag, nparts=nparts) 178 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, reduction=self.rd[3], meflag=meflag, nparts=nparts) 179 | self.adpavgpool = nn.AdaptiveAvgPool2d(1) 180 | self.fc_ulti = nn.Linear(512 * block.expansion * nparts, num_classes) 181 | if self.nparts > 1: 182 | self.adpmaxpool = nn.AdaptiveMaxPool2d(1) 183 | self.fc_plty = nn.Linear(256 * block.expansion * nparts, num_classes) 184 | self.fc_cmbn = nn.Linear(256 * block.expansion * nparts, num_classes) 185 | 186 | # for the last convolutional layer 187 | self.conv2_1 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 188 | self.conv2_2 = nn.Conv2d(512 * block.expansion, 256 * 
block.expansion, kernel_size=1, bias=False) 189 | # self.bn2_1 = nn.BatchNorm2d(256 * block.expansion) 190 | # self.bn2_2 = nn.BatchNorm2d(256 * block.expansion) 191 | 192 | # for the penultimate layer 193 | self.conv3_1 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 194 | self.conv3_2 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 195 | self.bn3_1 = nn.BatchNorm2d(256 * block.expansion) 196 | self.bn3_2 = nn.BatchNorm2d(256 * block.expansion) 197 | 198 | if nparts == 3: 199 | self.conv2_3 = nn.Conv2d(512 * block.expansion, 256 * block.expansion, kernel_size=1, bias=False) 200 | self.conv3_3 = nn.Conv2d(256 * block.expansion, 256 * block.expansion, kernel_size=3, padding=1, bias=False) 201 | self.bn3_3 = nn.BatchNorm2d(256 * block.expansion) 202 | 203 | # initializing params 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 207 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 208 | elif isinstance(m, nn.BatchNorm2d): 209 | m.weight.data.fill_(1) 210 | m.bias.data.zero_() 211 | 212 | def _make_layer(self, block, planes, blocks, stride=1, reduction=16, meflag=False, nparts=1): 213 | downsample = None 214 | if stride != 1 or self.inplanes != planes * block.expansion: 215 | downsample = nn.Sequential( 216 | nn.Conv2d(self.inplanes, planes * block.expansion, 217 | kernel_size=1, stride=stride, bias=False), 218 | nn.BatchNorm2d(planes * block.expansion), 219 | ) 220 | 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, reduction)) 223 | self.inplanes = planes * block.expansion 224 | for i in range(1, blocks): 225 | if i == blocks - 1 and meflag is True: 226 | layers.append(block(self.inplanes, planes, reduction=reduction, meflag=meflag, nparts=nparts)) 227 | else: 228 | layers.append(block(self.inplanes, planes, reduction=reduction)) 229 | 230 | return nn.Sequential(*layers) 231 | 232 | def forward(self, x): 233 | x = self.conv1(x) 234 | x = self.bn1(x) 235 | x = self.relu(x) 236 | x = self.maxpool(x) 237 | 238 | x = self.layer1(x) 239 | x = self.layer2(x) 240 | 241 | if self.meflag: 242 | x, plty_parts = self.layer3(x) 243 | _, ulti_parts = self.layer4(x) 244 | 245 | cmbn_ftres = list() 246 | for i in range(self.nparts): 247 | # pdb.set_trace() 248 | if i == 0: 249 | ulti_parts_iplt = F.interpolate(self.conv2_1(ulti_parts[i]), 28) 250 | cmbn_ftres.append(self.adpavgpool(self.bn3_1(self.conv3_1(torch.add(plty_parts[i], ulti_parts_iplt))))) 251 | elif i == 1: 252 | ulti_parts_iplt = F.interpolate(self.conv2_2(ulti_parts[i]), 28) 253 | cmbn_ftres.append(self.adpavgpool(self.bn3_2(self.conv3_2(torch.add(plty_parts[i], ulti_parts_iplt))))) 254 | elif i == 2: 255 | ulti_parts_iplt = F.interpolate(self.conv2_3(ulti_parts[i]), 28) 256 | cmbn_ftres.append(self.adpavgpool(self.bn3_3(self.conv3_3(torch.add(plty_parts[i], ulti_parts_iplt))))) 257 | 258 | plty_parts[i] = self.adpmaxpool(plty_parts[i]) 259 | 
ulti_parts[i] = self.adpavgpool(ulti_parts[i]) 260 | 261 | # for the penultimate layer 262 | xp = torch.cat(plty_parts, 1) 263 | xp = xp.view(xp.size(0), -1) 264 | xp = self.fc_plty(xp) 265 | 266 | # for the final layer 267 | xf = torch.cat(ulti_parts, 1) 268 | xf = xf.view(xf.size(0), -1) 269 | xf = self.fc_ulti(xf) 270 | 271 | # for the combined feature 272 | xc = torch.cat(cmbn_ftres, 1) 273 | xc = xc.view(xc.size(0), -1) 274 | xc = self.fc_cmbn(xc) 275 | 276 | return xf, xp, xc, ulti_parts, plty_parts, cmbn_ftres 277 | else: 278 | x = self.layer3(x) 279 | x = self.layer4(x) 280 | 281 | x = self.adpavgpool(x) 282 | x = x.view(x.size(0), -1) 283 | x = self.fc_ulti(x) 284 | return x 285 | 286 | 287 | 288 | ########################################## Models 289 | def senet50(num_classes=200, nparts=1, **kwargs): 290 | 291 | if nparts > 1: 292 | # resnet with osme 293 | kwargs.setdefault('meflag', True) 294 | else: 295 | # the normal resnet 296 | kwargs.setdefault('meflag', False) 297 | 298 | rd = [16, 32, 64, 128] 299 | 300 | if kwargs['meflag']: 301 | model = SeNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes, rd=rd, nparts=nparts, meflag=True) 302 | else: 303 | # vanilla senet 304 | model = SeNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes, rd=rd, nparts=nparts, meflag=False) 305 | 306 | return model 307 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pickle as pk 3 | import pdb 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms, datasets 9 | import torch.optim as opt 10 | from torch.optim import lr_scheduler 11 | import torch.utils.model_zoo as model_zoo 12 | 13 | from utils import imdb #, myimagefolder, mydataloader 14 | progpath = os.path.dirname(os.path.realpath(__file__)) # /home/luowei/Codes/feasc-msc 15 | sys.path.append(progpath) 16 | import 
modellearning 17 | from initialization import init_crossx_params, data_transform 18 | 19 | 20 | 21 | """ user defined variables """ 22 | backbone = "resnet" # or "senet" 23 | datasetname = "vggaircraft" # we experiment on 5 datasets: "nabirds", "cubbirds", "stcars", "stdogs", and "vggaircraft" 24 | batchsize = 32 25 | 26 | #################### model zoo: it's a folder to place vanilla models, like ResNet-50 27 | modelzoopath = "/home/luowei/Codes/pymodels" 28 | sys.path.append(os.path.dirname(modelzoopath)) 29 | import pymodels 30 | 31 | ##################### Dataset path 32 | datasets_path = os.path.expanduser("/home/luowei/Datasets") 33 | datasetpath = os.path.join(datasets_path, datasetname) 34 | 35 | 36 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") 37 | 38 | # organizing data 39 | assert imdb.creatDataset(datasetpath, datasetname=datasetname) == True, "Failing to creat train/val/test sets" 40 | data_transform = data_transform(datasetname) 41 | 42 | # using ground truth data 43 | datasplits = {x: datasets.ImageFolder(os.path.join(datasetpath, x), data_transform[x]) 44 | for x in ['trainval', 'test']} 45 | 46 | dataloader = {x: torch.utils.data.DataLoader(datasplits[x], batch_size=batchsize, shuffle=True, num_workers=8) 47 | for x in ['trainval', 'test']} 48 | 49 | datasplit_sizes = {x: len(datasplits[x]) for x in ['trainval', 'test']} 50 | class_names = datasplits['trainval'].classes 51 | num_classes = len(class_names) 52 | 53 | 54 | 55 | 56 | ################################### constructing or loading model 57 | if datasetname is 'stdogs' and backbone is 'senet': 58 | nparts = 3 59 | else: 60 | nparts = 2 # number of parts you want to use for your dataset 61 | 62 | if backbone is 'senet': 63 | if datasetname in ['cubbirds', 'nabirds']: 64 | import crossxsenetmix as crossxmodel 65 | model = crossxmodel.senet50(num_classes=num_classes, nparts=nparts) 66 | else: 67 | import crossxsenetavg as crossxmodel 68 | model = 
crossxmodel.senet50(num_classes=num_classes, nparts=nparts) 69 | elif backbone is 'resnet': 70 | if datasetname in ['cubbirds', 'nabirds']: 71 | import crossxresnetmix as crossxmodel 72 | model = crossxmodel.resnet50(pretrained=True, modelpath=modelzoopath, num_classes=num_classes, nparts=nparts) 73 | else: 74 | import crossxresnetavg as crossxmodel 75 | model = crossxmodel.resnet50(pretrained=True, modelpath=modelzoopath, num_classes=num_classes, nparts=nparts) 76 | 77 | 78 | if torch.cuda.device_count() > 0: 79 | model = nn.DataParallel(model) 80 | model.to(device) 81 | 82 | 83 | if backbone is 'senet': 84 | # load pretrained senet weights 85 | state_dict_path = "pretrained-weights.pkl" 86 | state_params = torch.load(state_dict_path, map_location=device) 87 | state_params['weight'].pop('module.fc.weight') 88 | state_params['weight'].pop('module.fc.bias') 89 | model.load_state_dict(state_params['weight'], strict=False) 90 | 91 | 92 | # creating loss functions 93 | gamma1, gamma2, gamma3, lr, epochs = init_crossx_params(backbone, datasetname) 94 | cls_loss = nn.CrossEntropyLoss() 95 | reg_loss_ulti = crossxmodel.RegularLoss(gamma=gamma1, nparts=nparts) 96 | reg_loss_plty = crossxmodel.RegularLoss(gamma=gamma2, nparts=nparts) 97 | reg_loss_cmbn = crossxmodel.RegularLoss(gamma=gamma3, nparts=nparts) 98 | kl_loss = nn.KLDivLoss(reduction='sum') 99 | criterion = [cls_loss, reg_loss_ulti, reg_loss_plty, reg_loss_cmbn, kl_loss] 100 | 101 | 102 | # creating optimizer 103 | optmeth = 'sgd' 104 | optimizer = opt.SGD(model.parameters(), lr=lr, momentum=0.9) 105 | 106 | 107 | # creating optimization scheduler 108 | #scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1) 109 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.1) 110 | 111 | 112 | # training the model 113 | isckpt = False # True for restoring model from checking point 114 | # print parameters 115 | print("{}: {}, gamma: {}_{}_{}, nparts: {}, epochs: {}".format(optmeth, 
lr, gamma1, gamma2, gamma3, nparts, epochs)) 116 | 117 | model, train_rsltparams = modellearning.train(model, dataloader, criterion, optimizer, scheduler, backbone=backbone, datasetname=datasetname, isckpt=isckpt, epochs=epochs) 118 | 119 | 120 | #### save model 121 | modelpath = './models' 122 | if backbone is 'senet': 123 | modelname = r"{}_parts{}-sc{}_{}_{}-{}{}-SeNet50-crossx.model".format(datasetname, nparts, gamma1, gamma2, gamma3, optmeth, lr) 124 | else: 125 | modelname = r"{}_parts{}-sc{}_{}_{}-{}{}-ResNet50-crossx.model".format(datasetname, nparts, gamma1, gamma2, gamma3, optmeth, lr) 126 | torch.save(model.state_dict(), os.path.join(modelpath, modelname)) 127 | 128 | 129 | ########################### evaluation 130 | #testsplit = datasets.ImageFolder(os.path.join(datasetpath, 'test'), data_transform['val']) 131 | #testloader = torch.utils.data.DataLoader(testsplit, batch_size=64, shuffle=False, num_workers=8) 132 | #test_rsltparams = modellearning.eval(model, testloader) 133 | 134 | 135 | ########################### record results 136 | #filename = r"{}-parts{}-sc{}_{}_{}-{}{}.pkl".format(datasetname, nparts, gamma1, gamma2, gamma3, optmeth, lr) 137 | #rsltpath = os.path.join(progpath, 'results', filename) 138 | #with open(rsltpath, 'wb') as f: 139 | # pk.dump({'train': train_rsltparams, 'test': test_rsltparams}, f) 140 | -------------------------------------------------------------------------------- /initialization.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | def init_crossx_params(backbone, datasetname): 4 | 5 | epochs = 30 6 | gamma1, gamma2, gamma3 = 0.0, 0.0, 0.0 7 | lr = 0.0 8 | 9 | if backbone is 'senet': 10 | if datasetname is 'nabirds': 11 | gamma1 = 0.1 12 | gamma2 = 0.25 13 | gamma3 = 0.5 14 | elif datasetname in ['cubbirds', 'stcars']: 15 | gamma1 = 1 16 | gamma2 = 0.25 17 | gamma3 = 1 18 | elif datasetname is 'stdogs': 19 | gamma1 = 1 20 | gamma2 = 0.5 
21 | gamma3 = 1 22 | elif datasetname is 'vggaricraft': 23 | gamma1 = 0.5 24 | gamma2 = 0.1 25 | gamma3 = 0.1 26 | else: 27 | pass 28 | elif backbone is 'resnet': 29 | if datasetname in ['nabirds', 'cubbirds']: 30 | gamma1 = 0.5 31 | gamma2 = 0.25 32 | gamma3 = 0.5 33 | elif datasetname is 'stcars': 34 | gamma1 = 1 35 | gamma2 = 0.25 36 | gamma3 = 1 37 | elif datasetname is 'stdogs': 38 | gamma1 = 0.01 39 | gamma2 = 0.01 40 | gamma3 = 1 41 | elif datasetname is 'vggaricraft': 42 | gamma1 = 0.5 43 | gamma2 = 0.1 44 | gamma3 = 0.5 45 | else: 46 | pass 47 | else: 48 | pass 49 | 50 | if datasetname is 'stdogs': 51 | lr = 0.001 52 | else: 53 | lr = 0.01 54 | 55 | return gamma1, gamma2, gamma3, lr, epochs 56 | 57 | def data_transform(datasetname=None): 58 | if datasetname in ['cubbirds', 'nabirds', 'vggaircraft']: 59 | return { 60 | 'trainval': transforms.Compose([ 61 | transforms.Resize((600, 600)), 62 | transforms.RandomCrop((448, 448)), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 66 | ]), 67 | 'test': transforms.Compose([ 68 | transforms.Resize((600, 600)), 69 | transforms.CenterCrop((448, 448)), 70 | transforms.ToTensor(), 71 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 72 | ])} 73 | else: 74 | return { 75 | 'trainval': transforms.Compose([ 76 | transforms.Resize((448, 448)), 77 | transforms.RandomHorizontalFlip(), 78 | transforms.ToTensor(), 79 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 80 | ]), 81 | 'test': transforms.Compose([ 82 | transforms.Resize((448, 448)), 83 | transforms.ToTensor(), 84 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 85 | ])} 86 | 87 | 88 | if __name__ == "__main__": 89 | pass -------------------------------------------------------------------------------- /modellearning.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 
import torch 3 | import time 4 | import torch.nn.functional as F 5 | import pdb 6 | from utils import modelserial 7 | 8 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") 9 | 10 | def train(model, dataloader, criterion, optimizer, scheduler, backbone='resnet', datasetname=None, isckpt=False, epochs=30): 11 | 12 | # get the size of train and evaluation data 13 | if isinstance(dataloader, dict): 14 | dataset_sizes = {x: len(dataloader[x].dataset) for x in dataloader.keys()} 15 | print(dataset_sizes) 16 | else: 17 | dataset_size = len(dataloader.dataset) 18 | 19 | if not isinstance(criterion, list): 20 | criterion = [criterion] 21 | 22 | best_model_params = copy.deepcopy(model.state_dict()) 23 | best_acc = 0.0 24 | global_step = 0 25 | global_step_resume = 0 26 | best_epoch = 0 27 | best_step = 0 28 | start_epoch = -1 29 | 30 | if isckpt: 31 | checkpoint = modelserial.loadCheckpoint(datasetname) 32 | start_epoch = checkpoint['epoch'] 33 | best_acc = checkpoint['best_acc'] 34 | model.load_state_dict(checkpoint['state_dict']) 35 | best_model_params = checkpoint['best_state_dict'] 36 | best_epoch = checkpoint['best_epoch'] 37 | 38 | since = time.time() 39 | for epoch in range(start_epoch+1, epochs): 40 | print('Epoch {}/{}'.format(epoch, epochs)) 41 | print('-' * 10) 42 | 43 | for phase in ['trainval', 'test']: 44 | if phase == 'trainval': 45 | scheduler.step() 46 | model.train() # Set model to training mode 47 | global_step = global_step_resume 48 | else: 49 | model.eval() # Set model to evaluate mode 50 | global_step_resume = global_step 51 | 52 | running_cls_loss = 0.0 53 | running_reg_loss = 0.0 54 | running_corrects = 0 55 | 56 | # Iterate over data. 
57 | for inputs, labels in dataloader[phase]: 58 | 59 | inputs = inputs.to(device) 60 | labels = labels.to(device) 61 | 62 | # zero the parameter gradients 63 | optimizer.zero_grad() 64 | 65 | # forward 66 | with torch.set_grad_enabled(phase == 'trainval'): 67 | 68 | if model.module.nparts == 1: 69 | outputs = model(inputs) 70 | _, preds = torch.max(outputs, 1) 71 | all_loss = criterion[0](outputs, labels) 72 | else: 73 | if datasetname is 'stdogs' and backbone is 'resnet': 74 | outputs_ulti, outputs_plty, _, ulti_ftrs, plty_ftrs, _ = model(inputs) 75 | _, preds = torch.max(outputs_ulti+outputs_plty, 1) 76 | cls_loss = criterion[0](outputs_ulti+outputs_plty, labels) 77 | else: 78 | outputs_ulti, outputs_plty, outputs_cmbn, ulti_ftrs, plty_ftrs, cmbn_ftrs = model(inputs) 79 | _, preds = torch.max(outputs_ulti+outputs_plty+outputs_cmbn, 1) 80 | cls_loss = criterion[0](outputs_ulti+outputs_plty+outputs_cmbn, labels) 81 | 82 | reg_loss_cmbn = criterion[3](cmbn_ftrs) 83 | outputs_cmbn = F.log_softmax(outputs_cmbn, 1) 84 | 85 | reg_loss_ulti = criterion[1](ulti_ftrs) 86 | reg_loss_plty = criterion[2](plty_ftrs) 87 | 88 | outputs_plty = F.log_softmax(outputs_plty, 1) 89 | outputs_ulti = F.softmax(outputs_ulti, 1) 90 | 91 | if datasetname is 'stdogs' and backbone is 'resnet': 92 | kl_loss = (criterion[4](outputs_plty, outputs_ulti)) / inputs.size(0) 93 | all_loss = reg_loss_ulti + reg_loss_plty + kl_loss + cls_loss 94 | else: 95 | kl_loss = (criterion[4](outputs_plty, outputs_ulti) + criterion[4](outputs_cmbn, outputs_ulti)) / inputs.size(0) 96 | all_loss = reg_loss_ulti + reg_loss_plty + reg_loss_cmbn + kl_loss + cls_loss 97 | 98 | # backward + optimize only if in training phase 99 | if phase == 'trainval': 100 | all_loss.backward() 101 | optimizer.step() 102 | 103 | # statistics 104 | if model.module.nparts == 1: 105 | running_cls_loss += all_loss.item() * inputs.size(0) 106 | else: 107 | running_cls_loss += cls_loss.item() * inputs.size(0) 108 | if datasetname is 
'stdogs' and backbone is 'resnet': 109 | running_reg_loss += (reg_loss_ulti.item() + reg_loss_plty.item() + kl_loss.item()) * inputs.size(0) 110 | else: 111 | running_reg_loss += (reg_loss_ulti.item() + reg_loss_plty.item() + reg_loss_cmbn.item() + kl_loss.item()) * inputs.size(0) 112 | 113 | running_corrects += torch.sum(preds == labels.data) 114 | 115 | 116 | if model.module.nparts == 1: 117 | epoch_loss = running_cls_loss / dataset_sizes[phase] 118 | else: 119 | epoch_loss = (running_cls_loss + running_reg_loss) / dataset_sizes[phase] 120 | 121 | epoch_acc = running_corrects.double() / dataset_sizes[phase] 122 | 123 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 124 | 125 | # deep copy the model 126 | if phase == 'test' and epoch_acc > best_acc: 127 | best_acc = epoch_acc 128 | best_epoch = epoch 129 | best_step = global_step_resume 130 | best_model_params = copy.deepcopy(model.state_dict()) 131 | 132 | if phase == 'test' and epoch % 2 == 1: 133 | modelserial.saveCheckpoint({'epoch': epoch, 134 | 'best_epoch': best_epoch, 135 | 'state_dict': model.state_dict(), 136 | 'best_state_dict': best_model_params, 137 | 'best_acc': best_acc}, datasetname) 138 | print() 139 | 140 | time_elapsed = time.time() - since 141 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 142 | print('Best test Acc: {:4f}'.format(best_acc)) 143 | 144 | rsltparams = dict() 145 | rsltparams['val_acc'] = best_acc.item() 146 | rsltparams['gamma1'] = criterion[1].gamma 147 | rsltparams['gamma2'] = criterion[2].gamma 148 | rsltparams['gamma3'] = criterion[3].gamma 149 | rsltparams['lr'] = optimizer.param_groups[0]['lr'] 150 | rsltparams['best_epoch'] = best_epoch 151 | rsltparams['best_step'] = best_step 152 | 153 | # load best model weights 154 | model.load_state_dict(best_model_params) 155 | return model, rsltparams 156 | 157 | 158 | def eval(model, dataloader=None, datasetname=None): 159 | model.eval() 160 | datasize = 
len(dataloader.dataset) 161 | running_corrects = 0 162 | num_label_counts = dict() 163 | pred_label_counts = dict() 164 | 165 | for inputs, labels in dataloader: 166 | 167 | for label in labels.data: 168 | num_label_counts.setdefault(label.item(), 0) 169 | num_label_counts[label.item()] += 1 170 | 171 | inputs = inputs.to(device) 172 | labels = labels.to(device) 173 | 174 | with torch.no_grad(): 175 | if model.module.nparts == 1: 176 | outputs = model(inputs) 177 | preds = torch.argmax(outputs, dim=1) 178 | else: 179 | outputs_ulti, outputs_plty, outputs_cmbn, _, _, _ = model(inputs) 180 | if datasetname is 'stdogs': 181 | preds = torch.argmax(outputs_ulti + outputs_plty, dim=1) 182 | else: 183 | preds = torch.argmax(outputs_ulti + outputs_plty + outputs_cmbn, dim=1) 184 | 185 | if datasetname is 'vggaircraft': 186 | for i, label in enumerate(preds.data): 187 | if label == labels[i]: 188 | pred_label_counts.setdefault(label.item(), 0) 189 | pred_label_counts[label.item()] += 1 190 | 191 | running_corrects += torch.sum(preds == labels.data) 192 | 193 | acc = torch.div(running_corrects.double(), datasize).item() 194 | print("{}: Test Accuracy: {}".format(datasetname, acc)) 195 | 196 | if datasetname is 'vggaircraft': 197 | running_corrects_ = 0 198 | for key in pred_label_counts.keys(): 199 | running_corrects_ += pred_label_counts[key] / num_label_counts[key] 200 | avg_acc = running_corrects_ / len(num_label_counts) 201 | print("{}: Class Average Accuracy: - {}".format(datasetname, avg_acc)) 202 | 203 | 204 | rsltparams = dict() 205 | rsltparams['test_acc'] = acc 206 | return rsltparams 207 | 208 | 209 | -------------------------------------------------------------------------------- /prediction.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pickle as pk 3 | import pdb 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import transforms, datasets 8 | import torch.optim as opt 9 | 
from torch.optim import lr_scheduler 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | from initialization import data_transform 13 | from utils import imdb 14 | progpath = os.path.dirname(os.path.realpath(__file__)) 15 | sys.path.append(progpath) 16 | import modellearning 17 | 18 | 19 | """ user params """ 20 | datasetname = "vggaircraft" # 'cubbirds', 'nabirds', 'stdogs', 'stcars' 21 | batchsize = 8 22 | backbone = 'resnet' # or 'senet' 23 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") 24 | 25 | 26 | 27 | #################### model zoo 28 | modelzoopath = "/home/luowei/Codes/pymodels" 29 | sys.path.append(os.path.dirname(modelzoopath)) 30 | # import pymodels 31 | 32 | ##################### Dataset path 33 | datasets_path = os.path.expanduser("/home/luowei/Datasets") 34 | datasetpath = os.path.join(datasets_path, datasetname) 35 | 36 | 37 | ################### organizing data 38 | assert imdb.creatDataset(datasetpath, datasetname=datasetname) == True, "Failing to creat train/val/test sets" 39 | data_transform = data_transform(datasetname) 40 | 41 | 42 | testsplit = datasets.ImageFolder(os.path.join(datasetpath, 'test'), data_transform['test']) 43 | testloader = torch.utils.data.DataLoader(testsplit, batch_size=batchsize, shuffle=False, num_workers=8) 44 | 45 | 46 | datasplit_sizes = len(testsplit) 47 | class_names = testsplit.classes 48 | num_classes = len(class_names) 49 | 50 | ################################### constructing or loading model 51 | if datasetname is 'stdogs' and backbone is 'senet': 52 | nparts = 3 53 | else: 54 | nparts = 2 # number of parts you want to use for your dataset 55 | 56 | if backbone is 'senet': 57 | if datasetname in ['cubbirds', 'nabirds']: 58 | import mysenetmodelsmix as crossxmodel 59 | model = crossxmodel.senet50(num_classes=num_classes, nparts=nparts) 60 | else: 61 | import mysenetmodelsavg as crossxmodel 62 | model = crossxmodel.senet50(num_classes=num_classes, nparts=nparts) 63 | elif 
backbone is 'resnet': 64 | if datasetname in ['cubbirds', 'nabirds']: 65 | import myresnetmodelsmix as crossxmodel 66 | model = crossxmodel.resnet50(num_classes=num_classes, nparts=nparts) 67 | else: 68 | import myresnetmodelsavg as crossxmodel 69 | model = crossxmodel.resnet50(num_classes=num_classes, nparts=nparts) 70 | 71 | 72 | 73 | if torch.cuda.device_count() > 0: 74 | model = nn.DataParallel(model) 75 | model.to(device) 76 | 77 | if backbone is 'senet': 78 | if datasetname is 'nabirds': 79 | state_dict_path = "/your/local/path/nabirds_CrossX-SENet50.model" 80 | elif datasetname is 'cubbirds': 81 | state_dict_path = "/your/local/path/cubbirds_CrossX-SENet50.model" 82 | elif datasetname is 'stcars': 83 | state_dict_path = "/your/local/path/stcars_CrossX-SENet50.model" 84 | elif datasetname is 'stdogs': 85 | state_dict_path = "/your/local/path/stdogs_CrossX-SENet50.model" 86 | elif datasetname is 'vggaircraft': 87 | state_dict_path = "/your/local/path/vggaircraft_CrossX-SENet50.model" 88 | elif backbone is 'resnet': 89 | if datasetname is 'nabirds': 90 | state_dict_path = "/your/local/path/nabirds_CrossX-ResNet50.model" 91 | elif datasetname is 'cubbirds': 92 | state_dict_path = "/your/local/path/cubbirds_CrossX-ResNet50.model" 93 | elif datasetname is 'stcars': 94 | state_dict_path = "/your/local/path/stcars_CrossX-ResNet50.model" 95 | elif datasetname is 'stdogs': 96 | state_dict_path = "/your/local/path/stdogs_CrossX-ResNet50.model" 97 | elif datasetname is 'vggaircraft': 98 | state_dict_path = "/your/local/path/vggaircraft_CrossX-ResNet50.model" 99 | 100 | 101 | state_params = torch.load(state_dict_path, map_location=device) 102 | model.load_state_dict(state_params, strict=False) 103 | 104 | 105 | # ########################### evaluation 106 | test_rsltparams = lwmodellearning.eval(model, testloader, datasetname) 107 | 108 | 109 | # ########################### record results 110 | # filename = r"{}-CrossX-{}.pkl".format(datasetname, backbone) 111 | # 
rsltpath = os.path.join(progpath, 'results', filename) 112 | # with open(rsltpath, 'wb') as f: 113 | # pk.dump({'test': test_rsltparams}, f) 114 | -------------------------------------------------------------------------------- /utils/imdb.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | import pandas as pd 5 | import numpy as np 6 | import pickle as pk 7 | import shutil 8 | import os.path as osp 9 | 10 | 11 | class CubBirds(object): 12 | 13 | def __init__(self, root): 14 | self.file_dict = {'Cid':'classes.txt', 15 | 'imageCid':'image_class_labels.txt', 16 | 'imageId':'images.txt', 17 | 'imageTVT':'train_test_split.txt' 18 | } 19 | self.root = root 20 | 21 | def _className(self): 22 | with open(os.path.join(self.root, self.file_dict['Cid'])) as f: 23 | coxt = f.readlines() 24 | class_names = [x.split()[-1] for x in coxt] 25 | self.class_names = class_names 26 | return self.class_names 27 | 28 | def _imdb(self): 29 | 30 | with open(os.path.join(self.root, self.file_dict['imageId'])) as f: 31 | coxt = f.readlines() 32 | imageId = [int(x.split()[0]) for x in coxt] 33 | imagePath = [x.split()[-1] for x in coxt] 34 | df = {'imageId':imageId, 'imagePath':imagePath} 35 | imdb = pd.DataFrame(data=df) 36 | imageCid = [x.split('/')[0] for x in imdb['imagePath']] 37 | with open(os.path.join(self.root, self.file_dict['imageTVT'])) as f: 38 | coxt = f.readlines() 39 | imageTVT = [int(x.split()[-1]) for x in coxt] 40 | imdb['imageCid'] = imageCid 41 | imdb['imageTVT'] = imageTVT 42 | 43 | return imdb 44 | 45 | 46 | class StCars(object): 47 | 48 | def __init__(self, root): 49 | with open(osp.join(root, 'imdb.pkl'), 'rb') as handle: 50 | annos = pk.load(handle) 51 | self.classnames = annos['classnames'] 52 | self.classdict = annos['classdict'] 53 | self.testdata = annos['annos_test'] 54 | self.traindata = annos['annos_train'] 55 | self.root = root 56 | 57 | 
def _createTVTFolders(self): 58 | if not osp.exists(osp.join(self.root, 'train', self.classnames[0])): 59 | for classname in self.classnames: 60 | os.makedirs(osp.join(self.root, 'train', classname)) 61 | os.makedirs(osp.join(self.root, 'val', classname)) 62 | os.makedirs(osp.join(self.root, 'trainval', classname)) 63 | os.makedirs(osp.join(self.root, 'test', classname)) 64 | 65 | 66 | class StDogs(object): 67 | 68 | def __init__(self, root): 69 | self.root = root 70 | with open(osp.join(root, 'imdb.pkl'), 'rb') as handle: 71 | annos = pk.load(handle) 72 | 73 | self.classdict = annos['classdict'] 74 | self.classnames = annos['classnames'] 75 | self.traindata = annos['traindata'] 76 | self.trainlabels = annos['trainlabels'] 77 | self.testdata = annos['testdata'] 78 | self.testlabels = annos['testlabels'] 79 | 80 | def _createTVTFolders(self): 81 | if not osp.exists(osp.join(self.root, 'train', self.classnames[0])): 82 | for classname in self.classnames: 83 | os.makedirs(osp.join(self.root, 'train', classname)) 84 | os.makedirs(osp.join(self.root, 'val', classname)) 85 | os.makedirs(osp.join(self.root, 'trainval', classname)) 86 | os.makedirs(osp.join(self.root, 'test', classname)) 87 | 88 | 89 | class VggAircraft(object): 90 | 91 | def __init__(self, root): 92 | with open(os.path.join(root, 'imdb.pkl'), 'rb') as handle: 93 | annos = pk.load(handle) 94 | 95 | self.classdict_variant = annos['classdict_variant'] 96 | self.classdict_family = annos['classdict_variant'] 97 | self.classdict_manufacturer = annos['classdict_variant'] 98 | self.classnames_variant = annos['classnames_variant'] 99 | self.classnames_family = annos['classnames_family'] 100 | self.classnames_manufacturer = annos['classnames_manufacturer'] 101 | self.traindata = annos['traindata'] 102 | self.trainvaldata = annos['trainvaldata'] 103 | self.testdata = annos['testdata'] 104 | self.valdata = annos['valdata'] 105 | self.root = root 106 | 107 | def _createTVTFolders(self): 108 | if not 
osp.exists(osp.join(self.root, 'train', self.classnames_variant[0])): 109 | for classname in self.classnames_variant: 110 | os.makedirs(osp.join(self.root, 'train', classname)) 111 | os.makedirs(osp.join(self.root, 'val', classname)) 112 | os.makedirs(osp.join(self.root, 'trainval', classname)) 113 | os.makedirs(osp.join(self.root, 'test', classname)) 114 | 115 | 116 | class NaBirds(object): 117 | def __init__(self, root): 118 | with open(os.path.join(root, 'imdb.pkl'), 'rb') as handle: 119 | annos = pk.load(handle) 120 | 121 | self.prntclassid = annos['prntclassid'] 122 | self.classnames = annos['classnames'] 123 | self.subclassid = annos['subclassid'] 124 | self.testdata = annos['testdata'] 125 | self.traindata = annos['traindata'] 126 | self.root = root 127 | 128 | def _createTVTFolders(self): 129 | subclassid = list(set(self.subclassid.values())) 130 | if not osp.exists(osp.join(self.root, 'train', subclassid[0])): 131 | for classname in subclassid: 132 | os.makedirs(osp.join(self.root, 'train', classname)) 133 | os.makedirs(osp.join(self.root, 'val', classname)) 134 | os.makedirs(osp.join(self.root, 'trainval', classname)) 135 | os.makedirs(osp.join(self.root, 'test', classname)) 136 | 137 | 138 | class WdDogs(object): 139 | 140 | def __init__(self, root): 141 | with open(os.path.join(root, 'imdb.pkl'), 'rb') as handle: 142 | annos = pk.load(handle) 143 | 144 | self.classdict = annos['classdict'] 145 | self.classnames = annos['classnames'] 146 | self.traindata = annos['traindata'] 147 | self.trainvaldata = annos['trainvaldata'] 148 | self.testdata = annos['testdata'] 149 | self.valdata = annos['valdata'] 150 | self.root = root 151 | 152 | def _createTVTFolders(self): 153 | if not osp.exists(osp.join(self.root, 'train', self.classnames[0])): 154 | for classname in self.classnames: 155 | os.makedirs(osp.join(self.root, 'train', classname)) 156 | os.makedirs(osp.join(self.root, 'val', classname)) 157 | os.makedirs(osp.join(self.root, 'trainval', classname)) 158 | 
os.makedirs(osp.join(self.root, 'test', classname)) 159 | 160 | def creatDataset(root, datasetname=None): 161 | 162 | if datasetname is not None: 163 | trainpath = os.path.join(root, 'train') 164 | valpath = os.path.join(root, 'val') 165 | testpath = os.path.join(root, 'test') 166 | trainvalpath = os.path.join(root, 'trainval') 167 | 168 | if not os.path.exists(trainpath): 169 | os.makedirs(trainpath) 170 | os.makedirs(valpath) 171 | os.makedirs(testpath) 172 | os.makedirs(trainvalpath) 173 | 174 | # checking the train/val/test integrity 175 | train_folders = os.listdir(trainpath) 176 | val_folders = os.listdir(valpath) 177 | test_folders = os.listdir(testpath) 178 | trainval_folders = os.listdir(trainvalpath) 179 | assert len(train_folders) == len(val_folders) == len( 180 | test_folders), "The train/val/test datasets are not complete" 181 | num_train_data = sum([len(os.listdir(os.path.join(trainpath, subfolder))) for subfolder in train_folders]) 182 | num_val_data = sum([len(os.listdir(os.path.join(valpath, subfolder))) for subfolder in val_folders]) 183 | num_test_data = sum([len(os.listdir(os.path.join(testpath, subfolder))) for subfolder in test_folders]) 184 | num_trainval_data = sum( 185 | [len(os.listdir(os.path.join(trainvalpath, subfolder))) for subfolder in trainval_folders]) 186 | 187 | if datasetname is "cubbirds": 188 | 189 | if num_test_data+num_train_data+num_val_data == 11788 and num_train_data+num_val_data == num_trainval_data: 190 | print("train/val/test sets are already exist.") 191 | return True 192 | 193 | # if the train/val/test datasets are not exist 194 | birds = CubBirds(root) 195 | class_names = birds._className() 196 | 197 | if os.path.exists(os.path.join(root, 'imdb.pkl')): 198 | with open(os.path.join(root, 'imdb.pkl'),'rb') as f: 199 | pdata = pk.load(f) 200 | train_pd, val_pd,test_pd, = pdata['train'], pdata['val'], pdata['test'] 201 | else: 202 | imdb = birds._imdb() 203 | 204 | test_pd = imdb.loc[imdb['imageTVT']==0] 205 | train_tpd 
= imdb.loc[imdb['imageTVT'] == 1] 206 | 207 | train_pd = pd.DataFrame() 208 | val_pd = pd.DataFrame() 209 | 210 | for class_name in class_names: 211 | trainval = train_tpd.loc[train_tpd['imageCid'] == class_name] 212 | permuind = np.random.permutation(trainval.index) 213 | # print(permuind) 214 | train_pd = train_pd.append(trainval.loc[permuind[:-3]], ignore_index=True) 215 | val_pd = val_pd.append(trainval.loc[permuind[-3:]], ignore_index=True) 216 | # print(trainval.index) 217 | with open(os.path.join(root, 'imdb.pkl'),'wb') as f: 218 | pk.dump({'train':train_pd, 'val':val_pd, 'test':test_pd},f) 219 | 220 | for class_name in class_names: 221 | 222 | if not os.path.exists(os.path.join(trainvalpath, class_name)): 223 | os.mkdir(os.path.join(trainpath, class_name)) 224 | os.mkdir(os.path.join(valpath, class_name)) 225 | os.mkdir(os.path.join(testpath, class_name)) 226 | os.mkdir(os.path.join(trainvalpath, class_name)) 227 | 228 | 229 | train_dst_path = os.path.join(trainpath, class_name) 230 | val_dst_path = os.path.join(valpath, class_name) 231 | test_dst_path = os.path.join(testpath, class_name) 232 | trainval_dst_path = os.path.join(trainvalpath, class_name) 233 | 234 | newtrainpd = train_pd.loc[train_pd['imageCid'] == class_name] 235 | newtrainpd_index = newtrainpd.index 236 | for i_ in newtrainpd_index: 237 | src_path = os.path.join(root, 'images', newtrainpd.loc[i_,'imagePath']) 238 | shutil.copy(src_path, train_dst_path) 239 | shutil.copy(src_path, trainval_dst_path) 240 | 241 | newvalpd = val_pd.loc[val_pd['imageCid'] == class_name] 242 | newvalpd_index = newvalpd.index 243 | for i_ in newvalpd_index: 244 | src_path = os.path.join(root, 'images', newvalpd.loc[i_,'imagePath']) 245 | shutil.copy(src_path, val_dst_path) 246 | shutil.copy(src_path, trainval_dst_path) 247 | 248 | newtestpd = test_pd.loc[test_pd['imageCid'] == class_name] 249 | newtestpd_index = newtestpd.index 250 | for i_ in newtestpd.index: 251 | print(i_) 252 | src_path = os.path.join(root, 
'images', newtestpd.loc[i_,'imagePath']) 253 | shutil.copy(src_path, test_dst_path) 254 | print("Successfully creating train/val/test sets.") 255 | return True 256 | elif datasetname is "stcars": 257 | 258 | if num_test_data+num_train_data+num_val_data == 16185 and num_train_data+num_val_data == num_trainval_data: 259 | print("train/val/test sets are already exist.") 260 | return True 261 | 262 | # if the train/val/test datasets are not exist 263 | cars = StCars(root) 264 | class_names = cars.classnames 265 | class_dict = cars.classdict 266 | traindata = cars.traindata 267 | testdata = cars.testdata 268 | 269 | cars._createTVTFolders() 270 | 271 | for line in traindata: 272 | train_src_path = osp.join(cars.root, line['relative_im_path']) 273 | class_name = class_dict[line['class']] 274 | trainval_dst_path = osp.join(trainvalpath, class_name) 275 | shutil.copy(train_src_path, trainval_dst_path) 276 | for line in testdata: 277 | test_src_path = osp.join(cars.root, line['relative_im_path']) 278 | class_name = class_dict[line['class']] 279 | test_dst_path = osp.join(testpath, class_name) 280 | shutil.copy(test_src_path, test_dst_path) 281 | 282 | # build train and validation sets from the trainval set 283 | subfolders = os.listdir(trainvalpath) 284 | for subfolder in subfolders: 285 | imgs = os.listdir(osp.join(trainvalpath, subfolder)) 286 | num_imgs = len(imgs) 287 | rndidx = np.random.permutation(num_imgs) 288 | num_val = int(np.floor(0.1 * num_imgs)) 289 | num_train = num_imgs - num_val 290 | train_dst_path = osp.join(trainpath, subfolder) 291 | val_dst_path = osp.join(valpath, subfolder) 292 | for idx in rndidx[:num_train]: 293 | shutil.copy(osp.join(trainvalpath, subfolder, imgs[idx]), train_dst_path) 294 | for idx in rndidx[num_train:]: 295 | shutil.copy(osp.join(trainvalpath, subfolder, imgs[idx]), val_dst_path) 296 | 297 | print("Successfully creating train/val/test sets.") 298 | return True 299 | elif datasetname is 'stdogs': 300 | 301 | if 
num_test_data+num_train_data+num_val_data == 20580 and num_train_data+num_val_data == num_trainval_data: 302 | print("train/val/test sets are already exist.") 303 | return True 304 | 305 | # if the train/val/test datasets are not exist 306 | dogs = StDogs(root) 307 | class_names = dogs.classnames 308 | class_dict = dogs.classdict 309 | train_data = dogs.traindata 310 | train_labels = dogs.trainlabels 311 | test_data = dogs.testdata 312 | test_labels = dogs.testlabels 313 | 314 | dogs._createTVTFolders() 315 | 316 | for imgpath in train_data: 317 | class_name = imgpath.split('/')[0] 318 | train_src_path = osp.join(dogs.root, 'Images', imgpath) 319 | trainval_dst_path = osp.join(trainvalpath, class_name) 320 | shutil.copy(train_src_path, trainval_dst_path) 321 | for imgpath in test_data: 322 | class_name = imgpath.split('/')[0] 323 | test_src_path = osp.join(dogs.root, 'Images', imgpath) 324 | test_dst_path = osp.join(testpath, class_name) 325 | shutil.copy(test_src_path, test_dst_path) 326 | 327 | # build train and validation sets from the trainval set 328 | subfolders = os.listdir(trainvalpath) 329 | for subfolder in subfolders: 330 | imgs = os.listdir(osp.join(trainvalpath, subfolder)) 331 | num_imgs = len(imgs) 332 | rndidx = np.random.permutation(num_imgs) 333 | num_val = int(np.floor(0.1 * num_imgs)) 334 | num_train = num_imgs - num_val 335 | train_dst_path = osp.join(trainpath, subfolder) 336 | val_dst_path = osp.join(valpath, subfolder) 337 | for idx in rndidx[:num_train]: 338 | shutil.copy(osp.join(trainvalpath, subfolder, imgs[idx]), train_dst_path) 339 | for idx in rndidx[num_train:]: 340 | shutil.copy(osp.join(trainvalpath, subfolder, imgs[idx]), val_dst_path) 341 | 342 | print("Successfully creating train/val/test sets.") 343 | return True 344 | elif datasetname is "vggaircraft": 345 | if num_test_data+num_train_data+num_val_data == 10000 and num_train_data+num_val_data == num_trainval_data: 346 | print("train/val/test sets are already exist.") 347 | 
return True 348 | aircrafts = VggAircraft(root) 349 | traindata = aircrafts.traindata 350 | trainvaldata = aircrafts.trainvaldata 351 | valdata = aircrafts.valdata 352 | testdata = aircrafts.testdata 353 | classnames = aircrafts.classnames_variant 354 | classdict = aircrafts.classdict_variant 355 | aircrafts._createTVTFolders() 356 | 357 | for row in traindata: 358 | img_src_path = osp.join(root, 'images', row[0]+'.jpg') 359 | img_dst_path = osp.join(root, 'train', row[1]) 360 | shutil.copy(img_src_path, img_dst_path) 361 | 362 | for row in trainvaldata: 363 | img_src_path = osp.join(root, 'images', row[0]+'.jpg') 364 | img_dst_path = osp.join(root, 'trainval', row[1]) 365 | shutil.copy(img_src_path, img_dst_path) 366 | 367 | for row in valdata: 368 | img_src_path = osp.join(root, 'images', row[0]+'.jpg') 369 | img_dst_path = osp.join(root, 'val', row[1]) 370 | shutil.copy(img_src_path, img_dst_path) 371 | 372 | for row in testdata: 373 | img_src_path = osp.join(root, 'images', row[0]+'.jpg') 374 | img_dst_path = osp.join(root, 'test', row[1]) 375 | shutil.copy(img_src_path, img_dst_path) 376 | 377 | print("Successfully creating train/val/test sets.") 378 | return True 379 | elif datasetname is "nabirds": 380 | if num_test_data+num_train_data+num_val_data == 48562 and num_train_data+num_val_data == num_trainval_data: 381 | print("train/val/test sets are already exist.") 382 | return True 383 | nabirds = NaBirds(root) 384 | traindata = nabirds.traindata 385 | testdata = nabirds.testdata 386 | classnames = nabirds.classnames 387 | prntclassid = nabirds.prntclassid 388 | subclassid = nabirds.subclassid 389 | nabirds._createTVTFolders() 390 | 391 | # trainval data 392 | for row in traindata: 393 | img_src_path = osp.join(root, row[0]) 394 | img_dst_path = osp.join(root, 'trainval', row[1]) 395 | shutil.copy(img_src_path, img_dst_path) 396 | 397 | # testing data 398 | for row in testdata: 399 | img_src_path = osp.join(root, row[0]) 400 | img_dst_path = osp.join(root, 
'test', row[1]) 401 | shutil.copy(img_src_path, img_dst_path) 402 | 403 | # build train and validation sets from the trainval set 404 | subfolders = os.listdir(trainvalpath) 405 | for subfolder in subfolders: 406 | imgs = os.listdir(osp.join(trainvalpath, subfolder)) 407 | num_imgs = len(imgs) 408 | rndidx = np.random.permutation(num_imgs) 409 | num_val = int(np.floor(0.1 * num_imgs)) 410 | num_train = num_imgs - num_val 411 | train_dst_path = osp.join(trainpath, subfolder) 412 | val_dst_path = osp.join(valpath, subfolder) 413 | for idx in rndidx[:num_train]: 414 | shutil.copy(osp.join(trainvalpath, subfolder, imgs[idx]), train_dst_path) 415 | for idx in rndidx[num_train:]: 416 | shutil.copy(osp.join(trainvalpath, subfolder, imgs[idx]), val_dst_path) 417 | 418 | print("Successfully creating train/val/test sets.") 419 | return True 420 | elif datasetname is "wddogs": 421 | if num_test_data+num_train_data+num_val_data == 299458 and num_train_data+num_val_data == num_trainval_data: 422 | print("train/val/test sets are already exist.") 423 | return True 424 | wddogs = WdDogs(root) 425 | traindata = wddogs.traindata 426 | testdata = wddogs.testdata 427 | valdata = wddogs.valdata 428 | trainvaldata = wddogs.trainvaldata 429 | classnames = wddogs.classnames 430 | classdict = wddogs.classdict 431 | wddogs._createTVTFolders() 432 | 433 | for elmt in traindata: 434 | img_src_path = osp.join(root, elmt['imgname']) 435 | img_dst_path = osp.join(trainpath, elmt['imgCname']) 436 | shutil.copy(img_src_path, img_dst_path) 437 | 438 | for elmt in valdata: 439 | img_src_path = osp.join(root, elmt['imgname']) 440 | img_dst_path = osp.join(valpath, elmt['imgCname']) 441 | shutil.copy(img_src_path, img_dst_path) 442 | 443 | for elmt in trainvaldata: 444 | img_src_path = osp.join(root, elmt['imgname']) 445 | img_dst_path = osp.join(trainvalpath, elmt['imgCname']) 446 | shutil.copy(img_src_path, img_dst_path) 447 | 448 | for elmt in testdata: 449 | img_src_path = osp.join(root, 
elmt['imgname']) 450 | img_dst_path = osp.join(testpath, elmt['imgCname']) 451 | shutil.copy(img_src_path, img_dst_path) 452 | 453 | print("Successfully creating train/val/test sets.") 454 | return True 455 | else: 456 | print("This dataset has not been implemented.") 457 | return False 458 | 459 | else: 460 | print("You should provide the dataset name for proceeding.\n") 461 | return False 462 | 463 | 464 | if __name__ == "__main__": 465 | pass 466 | -------------------------------------------------------------------------------- /utils/modelserial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu") 3 | 4 | def saveCheckpoint(state, datasetname=None): 5 | """Save checkpoint if a new best is achieved""" 6 | 7 | filename = './ckpt/{}-checkpoint.pth.tar' 8 | 9 | # if is_best: 10 | print("=> Saving a new best") 11 | torch.save(state, filename.format(datasetname)) # save checkpoint 12 | # else: 13 | # print("=> Validation Accuracy did not improve") 14 | 15 | def loadCheckpoint(datasetname): 16 | filename = './ckpt/{}-checkpoint.pth.tar' 17 | checkpoint = torch.load(filename.format(datasetname), map_location=device) 18 | return checkpoint -------------------------------------------------------------------------------- /utils/mydataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.multiprocessing as multiprocessing 4 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ 5 | _remove_worker_pids, _error_if_any_worker_fails 6 | import signal 7 | import functools 8 | import collections 9 | import re 10 | import sys 11 | import threading 12 | import traceback 13 | import os 14 | import time 15 | from torch._six import string_classes, int_classes, FileNotFoundError 16 | from torch._six import int_classes as _int_classes 17 | 18 | 
IS_WINDOWS = (sys.platform == "win32")
if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

# The standard `queue` module was named `Queue` on Python 2.
if sys.version_info[0] != 2:
    import queue
else:
    import Queue as queue


class Sampler(object):
    r"""Abstract base for every sampler in this module.

    A subclass must supply ``__iter__``, yielding dataset indices, and
    ``__len__``, the number of indices that iterator produces.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class RandomSampler(Sampler):
    r"""Yield each index of ``data_source`` exactly once, in random order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        order = torch.randperm(len(self.data_source))
        return iter(order.tolist())

    def __len__(self):
        return len(self.data_source)


class SequentialSampler(Sampler):
    r"""Yield ``0, 1, ..., len(data_source) - 1`` in increasing order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        size = len(self.data_source)
        return iter(range(size))

    def __len__(self):
        return len(self.data_source)


class BatchSampler(Sampler):
    r"""Group the indices produced by another sampler into lists.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, a final batch shorter than
            ``batch_size`` is discarded instead of yielded.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # bool is a subclass of int, so it has to be rejected explicitly.
        valid_batch_size = (isinstance(batch_size, _int_classes)
                            and not isinstance(batch_size, bool)
                            and batch_size > 0)
        if not valid_batch_size:
            raise ValueError("batch_size should be a positive integeral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        # Emit the trailing partial batch unless it should be dropped.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        total = len(self.sampler)
        if self.drop_last:
            return total // self.batch_size
        # Ceiling division: the partial last batch counts as one more.
        return (total + self.batch_size - 1) // self.batch_size


class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info):
        # exc_info is a sys.exc_info()-style (type, value, traceback) triple.
        self.exc_type = exc_info[0]
        formatted = traceback.format_exception(*exc_info)
        self.exc_msg = "".join(formatted)


# Whether default_collate builds its output in shared memory; worker
# processes set this to True in _worker_loop.
_use_shared_memory = False

# Timeout (seconds) a worker waits on its index queue before checking
# whether the manager process is still alive.
MANAGER_STATUS_CHECK_INTERVAL = 5.0
# collections.Mapping / collections.Sequence were removed in Python 3.10;
# the ABCs live in collections.abc (available since Python 3.3).
import collections.abc

if IS_WINDOWS:
    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog(object):
        """Reports whether the process that spawned this worker is still alive."""

        def __init__(self):
            self.manager_pid = os.getppid()

            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

        def is_alive(self):
            # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
else:
    class ManagerWatchdog(object):
        """Reports whether the process that spawned this worker is still alive."""

        def __init__(self):
            self.manager_pid = os.getppid()

        def is_alive(self):
            # The parent pid changes once the original parent has exited.
            return os.getppid() == self.manager_pid


def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    """Body of a loader worker process.

    Repeatedly takes ``(idx, batch_indices)`` from ``index_queue``, collates
    ``dataset[i]`` for those indices and puts ``(idx, samples)`` (or
    ``(idx, ExceptionWrapper)`` on failure) on ``data_queue``. It stops on a
    ``None`` message or when the manager process has gone away.
    """
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples


def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    """Forward batches from ``in_queue`` to ``out_queue``, pinning them if asked.

    Wrapped exceptions are forwarded unchanged. Returns quietly if reading
    fails after ``done_event`` was set; stops on a ``None`` message.
    """
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


# numpy dtype name -> legacy torch tensor constructor, used for numpy scalars.
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size

    Tensors and numpy arrays are stacked, Python/numpy numbers become 1-D
    tensors, strings are returned as-is, and mappings/sequences are collated
    field by field (recursively).

    Raises:
        TypeError: for element types that cannot be collated.
    """

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.abc.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))


def pin_memory_batch(batch):
    """Return ``batch`` with every tensor inside it moved to pinned memory.

    Mappings and sequences are handled recursively; strings and other
    objects are returned unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.abc.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.abc.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one 301 | handler needs to be set for all DataLoaders in a process.""" 302 | 303 | 304 | def _set_SIGCHLD_handler(): 305 | # Windows doesn't support SIGCHLD handler 306 | if sys.platform == 'win32': 307 | return 308 | # can't set signal in child threads 309 | if not isinstance(threading.current_thread(), threading._MainThread): 310 | return 311 | global _SIGCHLD_handler_set 312 | if _SIGCHLD_handler_set: 313 | return 314 | previous_handler = signal.getsignal(signal.SIGCHLD) 315 | if not callable(previous_handler): 316 | previous_handler = None 317 | 318 | def handler(signum, frame): 319 | # This following call uses `waitid` with WNOHANG from C side. Therefore, 320 | # Python can still get and update the process status successfully. 321 | _error_if_any_worker_fails() 322 | if previous_handler is not None: 323 | previous_handler(signum, frame) 324 | 325 | signal.signal(signal.SIGCHLD, handler) 326 | _SIGCHLD_handler_set = True 327 | 328 | 329 | class _DataLoaderIter(object): 330 | r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" 331 | 332 | def __init__(self, loader): 333 | self.dataset = loader.dataset 334 | self.collate_fn = loader.collate_fn 335 | self.batch_sampler = loader.batch_sampler 336 | self.num_workers = loader.num_workers 337 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 338 | self.timeout = loader.timeout 339 | self.done_event = threading.Event() 340 | 341 | self.sample_iter = iter(self.batch_sampler) 342 | 343 | base_seed = torch.LongTensor(1).random_().item() 344 | 345 | if self.num_workers > 0: 346 | self.worker_init_fn = loader.worker_init_fn 347 | self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)] 348 | self.worker_queue_idx = 0 349 | self.worker_result_queue = multiprocessing.SimpleQueue() 350 | self.batches_outstanding = 0 351 | self.worker_pids_set = False 352 | self.shutdown = False 353 | 
self.send_idx = 0 354 | self.rcvd_idx = 0 355 | self.reorder_dict = {} 356 | 357 | self.workers = [ 358 | multiprocessing.Process( 359 | target=_worker_loop, 360 | args=(self.dataset, self.index_queues[i], 361 | self.worker_result_queue, self.collate_fn, base_seed + i, 362 | self.worker_init_fn, i)) 363 | for i in range(self.num_workers)] 364 | 365 | if self.pin_memory or self.timeout > 0: 366 | self.data_queue = queue.Queue() 367 | if self.pin_memory: 368 | maybe_device_id = torch.cuda.current_device() 369 | else: 370 | # do not initialize cuda context if not necessary 371 | maybe_device_id = None 372 | self.worker_manager_thread = threading.Thread( 373 | target=_worker_manager_loop, 374 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, 375 | maybe_device_id)) 376 | self.worker_manager_thread.daemon = True 377 | self.worker_manager_thread.start() 378 | else: 379 | self.data_queue = self.worker_result_queue 380 | 381 | for w in self.workers: 382 | w.daemon = True # ensure that the worker exits on process exit 383 | w.start() 384 | 385 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) 386 | _set_SIGCHLD_handler() 387 | self.worker_pids_set = True 388 | 389 | # prime the prefetch loop 390 | for _ in range(2 * self.num_workers): 391 | self._put_indices() 392 | 393 | def __len__(self): 394 | return len(self.batch_sampler) 395 | 396 | def _get_batch(self): 397 | if self.timeout > 0: 398 | try: 399 | return self.data_queue.get(timeout=self.timeout) 400 | except queue.Empty: 401 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) 402 | else: 403 | return self.data_queue.get() 404 | 405 | def __next__(self): 406 | if self.num_workers == 0: # same-process loading 407 | indices = next(self.sample_iter) # may raise StopIteration 408 | batch = self.collate_fn([self.dataset[i] for i in indices]) 409 | if self.pin_memory: 410 | batch = pin_memory_batch(batch) 411 | return batch 412 | 413 | # 
check if the next sample has already been generated 414 | if self.rcvd_idx in self.reorder_dict: 415 | batch = self.reorder_dict.pop(self.rcvd_idx) 416 | return self._process_next_batch(batch) 417 | 418 | if self.batches_outstanding == 0: 419 | self._shutdown_workers() 420 | raise StopIteration 421 | 422 | while True: 423 | assert (not self.shutdown and self.batches_outstanding > 0) 424 | idx, batch = self._get_batch() 425 | self.batches_outstanding -= 1 426 | if idx != self.rcvd_idx: 427 | # store out-of-order samples 428 | self.reorder_dict[idx] = batch 429 | continue 430 | return self._process_next_batch(batch) 431 | 432 | next = __next__ # Python 2 compatibility 433 | 434 | def __iter__(self): 435 | return self 436 | 437 | def _put_indices(self): 438 | assert self.batches_outstanding < 2 * self.num_workers 439 | indices = next(self.sample_iter, None) 440 | if indices is None: 441 | return 442 | self.index_queues[self.worker_queue_idx].put((self.send_idx, indices)) 443 | self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers 444 | self.batches_outstanding += 1 445 | self.send_idx += 1 446 | 447 | def _process_next_batch(self, batch): 448 | self.rcvd_idx += 1 449 | self._put_indices() 450 | if isinstance(batch, ExceptionWrapper): 451 | raise batch.exc_type(batch.exc_msg) 452 | return batch 453 | 454 | def __getstate__(self): 455 | # TODO: add limited pickling support for sharing an iterator 456 | # across multiple threads for HOGWILD. 
457 | # Probably the best way to do this is by moving the sample pushing 458 | # to a separate thread and then just sharing the data queue 459 | # but signalling the end is tricky without a non-blocking API 460 | raise NotImplementedError("_DataLoaderIter cannot be pickled") 461 | 462 | def _shutdown_workers(self): 463 | try: 464 | if not self.shutdown: 465 | self.shutdown = True 466 | self.done_event.set() 467 | for q in self.index_queues: 468 | q.put(None) 469 | # if some workers are waiting to put, make place for them 470 | try: 471 | while not self.worker_result_queue.empty(): 472 | self.worker_result_queue.get() 473 | except (FileNotFoundError, ImportError): 474 | # Many weird errors can happen here due to Python 475 | # shutting down. These are more like obscure Python bugs. 476 | # FileNotFoundError can happen when we rebuild the fd 477 | # fetched from the queue but the socket is already closed 478 | # from the worker side. 479 | # ImportError can happen when the unpickler loads the 480 | # resource from `get`. 481 | pass 482 | # done_event should be sufficient to exit worker_manager_thread, 483 | # but be safe here and put another None 484 | self.worker_result_queue.put(None) 485 | finally: 486 | # removes pids no matter what 487 | if self.worker_pids_set: 488 | _remove_worker_pids(id(self)) 489 | self.worker_pids_set = False 490 | 491 | def __del__(self): 492 | if self.num_workers > 0: 493 | self._shutdown_workers() 494 | 495 | 496 | class DataLoader(object): 497 | r""" 498 | Data loader. Combines a dataset and a sampler, and provides 499 | single- or multi-process iterators over the dataset. 500 | 501 | Arguments: 502 | dataset (Dataset): dataset from which to load the data. 503 | batch_size (int, optional): how many samples per batch to load 504 | (default: 1). 505 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 506 | at every epoch (default: False). 
507 | sampler (Sampler, optional): defines the strategy to draw samples from 508 | the dataset. If specified, ``shuffle`` must be False. 509 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 510 | indices at a time. Mutually exclusive with batch_size, shuffle, 511 | sampler, and drop_last. 512 | num_workers (int, optional): how many subprocesses to use for data 513 | loading. 0 means that the data will be loaded in the main process. 514 | (default: 0) 515 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 516 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 517 | into CUDA pinned memory before returning them. 518 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 519 | if the dataset size is not divisible by the batch size. If ``False`` and 520 | the size of dataset is not divisible by the batch size, then the last batch 521 | will be smaller. (default: False) 522 | timeout (numeric, optional): if positive, the timeout value for collecting a batch 523 | from workers. Should always be non-negative. (default: 0) 524 | worker_init_fn (callable, optional): If not None, this will be called on each 525 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as 526 | input, after seeding and before data loading. (default: None) 527 | 528 | .. note:: By default, each worker will have its PyTorch seed set to 529 | ``base_seed + worker_id``, where ``base_seed`` is a long generated 530 | by main process using its RNG. However, seeds for other libraies 531 | may be duplicated upon initializing workers (w.g., NumPy), causing 532 | each worker to return identical random numbers. (See 533 | :ref:`dataloader-workers-random-seed` section in FAQ.) You may 534 | use ``torch.initial_seed()`` to access the PyTorch seed for each 535 | worker in :attr:`worker_init_fn`, and use it to set other seeds 536 | before data loading. 537 | 538 | .. 
warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an 539 | unpicklable object, e.g., a lambda function. 540 | """ 541 | 542 | __initialized = False 543 | 544 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 545 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 546 | timeout=0, worker_init_fn=None): 547 | self.dataset = dataset 548 | self.batch_size = batch_size 549 | self.num_workers = num_workers 550 | self.collate_fn = collate_fn 551 | self.pin_memory = pin_memory 552 | self.drop_last = drop_last 553 | self.timeout = timeout 554 | self.worker_init_fn = worker_init_fn 555 | 556 | if timeout < 0: 557 | raise ValueError('timeout option should be non-negative') 558 | 559 | if batch_sampler is not None: 560 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 561 | raise ValueError('batch_sampler option is mutually exclusive ' 562 | 'with batch_size, shuffle, sampler, and ' 563 | 'drop_last') 564 | self.batch_size = None 565 | self.drop_last = None 566 | 567 | if sampler is not None and shuffle: 568 | raise ValueError('sampler option is mutually exclusive with ' 569 | 'shuffle') 570 | 571 | if self.num_workers < 0: 572 | raise ValueError('num_workers option cannot be negative; ' 573 | 'use num_workers=0 to disable multiprocessing.') 574 | 575 | if batch_sampler is None: 576 | if sampler is None: 577 | if shuffle: 578 | sampler = RandomSampler(dataset) 579 | else: 580 | sampler = SequentialSampler(dataset) 581 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 582 | 583 | self.sampler = sampler 584 | self.batch_sampler = batch_sampler 585 | self.__initialized = True 586 | 587 | def __setattr__(self, attr, val): 588 | if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'): 589 | raise ValueError('{} attribute should not be set after {} is ' 590 | 'initialized'.format(attr, self.__class__.__name__)) 591 | 592 | 
super(DataLoader, self).__setattr__(attr, val) 593 | 594 | def __iter__(self): 595 | return _DataLoaderIter(self) 596 | 597 | def __len__(self): 598 | return len(self.batch_sampler) 599 | 600 | 601 | class MyDataLoader(DataLoader): 602 | 603 | __initialized = False 604 | 605 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 606 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 607 | timeout=0, worker_init_fn=None): 608 | super(MyDataLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler, 609 | num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn) 610 | 611 | self.dataset = dataset 612 | self.batch_size = batch_size 613 | self.num_workers = num_workers 614 | self.collate_fn = collate_fn 615 | self.pin_memory = pin_memory 616 | self.drop_last = drop_last 617 | self.timeout = timeout 618 | self.worker_init_fn = worker_init_fn 619 | 620 | if timeout < 0: 621 | raise ValueError('timeout option should be non-negative') 622 | 623 | if batch_sampler is not None: 624 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 625 | raise ValueError('batch_sampler option is mutually exclusive ' 626 | 'with batch_size, shuffle, sampler, and ' 627 | 'drop_last') 628 | self.batch_size = None 629 | self.drop_last = None 630 | 631 | if sampler is not None and shuffle: 632 | raise ValueError('sampler option is mutually exclusive with ' 633 | 'shuffle') 634 | 635 | if self.num_workers < 0: 636 | raise ValueError('num_workers option cannot be negative; ' 637 | 'use num_workers=0 to disable multiprocessing.') 638 | 639 | if batch_sampler is None: 640 | if sampler is None: 641 | if shuffle: 642 | sampler = RandomSampler(dataset) 643 | else: 644 | sampler = SequentialSampler(dataset) 645 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 646 | 647 | self.sampler = sampler 648 | self.batch_sampler = batch_sampler 649 | self.__initialized = True 650 | 651 | 
652 | def __setattr__(self, attr, val): 653 | if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'): 654 | raise ValueError('{} attribute should not be set after {} is ' 655 | 'initialized'.format(attr, self.__class__.__name__)) 656 | 657 | super(DataLoader, self).__setattr__(attr, val) 658 | 659 | 660 | def __iter__(self): 661 | return _DataLoaderIter(self) 662 | 663 | 664 | def __len__(self): 665 | return len(self.batch_sampler) -------------------------------------------------------------------------------- /utils/myimagefolder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | 5 | import os 6 | import os.path 7 | 8 | 9 | def has_file_allowed_extension(filename, extensions): 10 | """Checks if a file is an allowed extension. 11 | 12 | Args: 13 | filename (string): path to a file 14 | 15 | Returns: 16 | bool: True if the filename ends with a known image extension 17 | """ 18 | filename_lower = filename.lower() 19 | return any(filename_lower.endswith(ext) for ext in extensions) 20 | 21 | 22 | def find_classes(dir): 23 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 24 | classes.sort() 25 | class_to_idx = {classes[i]: i for i in range(len(classes))} 26 | return classes, class_to_idx 27 | 28 | 29 | def make_dataset(dir, class_to_idx, extensions): 30 | images = [] 31 | dir = os.path.expanduser(dir) 32 | for target in sorted(os.listdir(dir)): 33 | d = os.path.join(dir, target) 34 | if not os.path.isdir(d): 35 | continue 36 | 37 | for root, _, fnames in sorted(os.walk(d)): 38 | for fname in sorted(fnames): 39 | if has_file_allowed_extension(fname, extensions): 40 | path = os.path.join(root, fname) 41 | item = (path, class_to_idx[target]) 42 | images.append(item) 43 | 44 | return images 45 | 46 | 47 | class DatasetFolder(data.Dataset): 48 | """A generic data loader where the samples are arranged in this way: :: 49 | 50 
| root/class_x/xxx.ext 51 | root/class_x/xxy.ext 52 | root/class_x/xxz.ext 53 | 54 | root/class_y/123.ext 55 | root/class_y/nsdf3.ext 56 | root/class_y/asd932_.ext 57 | 58 | Args: 59 | root (string): Root directory path. 60 | loader (callable): A function to load a sample given its path. 61 | extensions (list[string]): A list of allowed extensions. 62 | transform (callable, optional): A function/transform that takes in 63 | a sample and returns a transformed version. 64 | E.g, ``transforms.RandomCrop`` for images. 65 | target_transform (callable, optional): A function/transform that takes 66 | in the target and transforms it. 67 | 68 | Attributes: 69 | classes (list): List of the class names. 70 | class_to_idx (dict): Dict with items (class_name, class_index). 71 | samples (list): List of (sample path, class_index) tuples 72 | """ 73 | 74 | def __init__(self, root, loader, extensions, transform=None, target_transform=None): 75 | classes, class_to_idx = find_classes(root) 76 | samples = make_dataset(root, class_to_idx, extensions) 77 | if len(samples) == 0: 78 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" 79 | "Supported extensions are: " + ",".join(extensions))) 80 | 81 | self.root = root 82 | self.loader = loader 83 | self.extensions = extensions 84 | 85 | self.classes = classes 86 | self.class_to_idx = class_to_idx 87 | self.samples = samples 88 | 89 | self.transform = transform 90 | self.target_transform = target_transform 91 | 92 | def __getitem__(self, index): 93 | """ 94 | Args: 95 | index (int): Index 96 | 97 | Returns: 98 | tuple: (sample, target, path) where target is class_index of the target class and path is the file path the sample was loaded from.
99 | """ 100 | path, target = self.samples[index] 101 | sample = self.loader(path) 102 | if self.transform is not None: 103 | sample = self.transform(sample) 104 | if self.target_transform is not None: 105 | target = self.target_transform(target) 106 | 107 | return sample, target, path 108 | 109 | def __len__(self): 110 | return len(self.samples) 111 | 112 | def __repr__(self): 113 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 114 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 115 | fmt_str += ' Root Location: {}\n'.format(self.root) 116 | tmp = ' Transforms (if any): ' 117 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 118 | tmp = ' Target Transforms (if any): ' 119 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 120 | return fmt_str 121 | 122 | 123 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 124 | 125 | 126 | def pil_loader(path): 127 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 128 | with open(path, 'rb') as f: 129 | img = Image.open(f) 130 | return img.convert('RGB') 131 | 132 | 133 | # def accimage_loader(path): 134 | # import accimage 135 | # try: 136 | # return accimage.Image(path) 137 | # except IOError: 138 | # # Potentially a decoding problem, fall back to PIL.Image 139 | # return pil_loader(path) 140 | 141 | 142 | def default_loader(path): 143 | from torchvision import get_image_backend 144 | # if get_image_backend() == 'accimage': 145 | # return accimage_loader(path) 146 | # else: 147 | # return pil_loader(path) 148 | return pil_loader(path) 149 | 150 | 151 | class MyImageFolder(DatasetFolder): 152 | """A generic data loader where the images are arranged in this way: :: 153 | 154 | root/dog/xxx.png 155 | root/dog/xxy.png 156 | root/dog/xxz.png 157 | 158 | root/cat/123.png 159 | root/cat/nsdf3.png 160 | root/cat/asd932_.png 161 | 
162 | Args: 163 | root (string): Root directory path. 164 | transform (callable, optional): A function/transform that takes in an PIL image 165 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 166 | target_transform (callable, optional): A function/transform that takes in the 167 | target and transforms it. 168 | loader (callable, optional): A function to load an image given its path. 169 | 170 | Attributes: 171 | classes (list): List of the class names. 172 | class_to_idx (dict): Dict with items (class_name, class_index). 173 | imgs (list): List of (image path, class_index) tuples 174 | """ 175 | def __init__(self, root, transform=None, target_transform=None, 176 | loader=default_loader): 177 | super(MyImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 178 | transform=transform, 179 | target_transform=target_transform) 180 | self.imgs = self.samples 181 | -------------------------------------------------------------------------------- /utils/receptivesize.py: -------------------------------------------------------------------------------- 1 | import os, sys, math 2 | import os.path as osp 3 | import pprint 4 | from collections import OrderedDict 5 | import pickle as pk 6 | progpath = os.path.dirname(os.path.realpath(__file__)) # /home/luowei/Codes/feasc-msc 7 | sys.path.append(progpath) 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | #################### model zoo 14 | modelzoopath = "/home/luowei/Codes/pymodels" 15 | # modelzoopath = "/vulcan/scratch/cswluo/Codes/pymodels" 16 | sys.path.append(osp.dirname(modelzoopath)) 17 | import pymodels 18 | 19 | 20 | #################### import modules in the current directory 21 | import mymodels 22 | import modellearning 23 | 24 | 25 | #### model params 26 | num_classes = 200 27 | nparts = 2 28 | seflag = True 29 | 30 | model = mymodels.feasc50(num_classes=num_classes, nparts=nparts, seflag=seflag) 31 | 32 | # for name, param in model.named_parameters(): 33 | # print(name, '--->', 
param.size()) 34 | 35 | # for name, module in model.named_modules(): 36 | # print(name, '--->', module) 37 | 38 | 39 | print("\n==========================================================\n") 40 | 41 | def reseq(module, layer_size, name=None, seqnum=None): 42 | if isinstance(module, nn.Sequential): 43 | reseq(module, layer_size, name, seqnum) 44 | else: 45 | for m in module.named_children(): 46 | if isinstance(m[-1], (nn.Conv2d, nn.MaxPool2d)): 47 | kernel_size = m[-1].kernel_size[0] 48 | stride_size = m[-1].stride[0] 49 | padding_size = m[-1].padding[0] 50 | subname = m[0] 51 | print(name+'_'+str(seqnum)+'_'+subname, kernel_size, stride_size, padding_size) 52 | layer_size[name+'_'+str(seqnum)+'_'+subname] = [kernel_size, stride_size, padding_size] 53 | 54 | layer_size = OrderedDict() 55 | # this will not print the model itself 56 | for name, module in model.named_children(): 57 | # print(name, '--->', module) 58 | if isinstance(module, nn.Sequential): 59 | for i in range(len(module)): 60 | reseq(module[i], layer_size, name, i) 61 | 62 | if isinstance(module, (nn.Conv2d, nn.MaxPool2d)): 63 | kernel_size = module.kernel_size 64 | stride_size = module.stride 65 | padding_size = module.padding 66 | print(name, kernel_size, padding_size, stride_size) 67 | if isinstance(module, nn.Conv2d): 68 | layer_size[name] = [kernel_size[0], stride_size[0], padding_size[0]] 69 | else: 70 | layer_size[name] = [kernel_size, stride_size, padding_size] 71 | 72 | # pprint.pprint(layer_size) 73 | # for key, value in layer_size.items(): 74 | # print(key, value) 75 | 76 | def outFromIn(conv, layerIn): 77 | n_in = layerIn[0] # input feature dimension 78 | j_in = layerIn[1] # jumps 79 | r_in = layerIn[2] # receptive field 80 | start_in = layerIn[3] 81 | k = conv[0] # kernel size 82 | s = conv[1] # strides 83 | p = conv[2] # padding 84 | 85 | n_out = math.floor((n_in - k + 2 * p) / s) + 1 86 | actualP = (n_out - 1) * s - n_in + k # the total actual padding size 87 | pR = math.ceil(actualP / 2) 
88 | pL = math.floor(actualP / 2) 89 | 90 | j_out = j_in * s 91 | r_out = r_in + (k - 1) * j_in 92 | start_out = start_in + ((k - 1) / 2 - pL) * j_in 93 | return n_out, j_out, r_out, start_out 94 | 95 | 96 | def printLayer(layer, layer_name): 97 | print(layer_name + ":") 98 | print("\t n features: %s \n \t jump: %s \n \t receptive size: %s \t start: %s " % ( 99 | layer[0], layer[1], layer[2], layer[3])) 100 | 101 | 102 | layerInfos = [] 103 | if __name__ == "__main__": 104 | imgsize = 224 105 | r_in = 1 106 | s_in = 1 107 | start_in = 0 108 | currentLayer = [imgsize, s_in, r_in, start_in] 109 | layerInfos.append(currentLayer) 110 | printLayer(currentLayer, "input image") 111 | 112 | 113 | for key, value in layer_size.items(): 114 | currentLayer = outFromIn(value, currentLayer) 115 | layerInfos.append(currentLayer) 116 | printLayer(currentLayer, key) 117 | print("------------------------") 118 | 119 | pprint.pprint(layerInfos) 120 | 121 | 122 | -------------------------------------------------------------------------------- /x-imdb/cubbirds-imdb.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import pickle as pk 3 | import os, sys 4 | import os.path as osp 5 | import pandas as pd 6 | import numpy as np 7 | 8 | progpath = os.path.dirname(os.path.realpath(__file__)) 9 | sys.path.append(progpath) 10 | 11 | 12 | imdbpath = osp.join(progpath, "imdb.pkl") 13 | 14 | if not os.path.exists(imdbpath): 15 | 16 | file_dict = {'Cid': 'classes.txt', 17 | 'imageCid': 'image_class_labels.txt', 18 | 'imageId': 'images.txt', 19 | 'imageTVT': 'train_test_split.txt' 20 | } 21 | 22 | with open(os.path.join(progpath, file_dict['Cid'])) as f: 23 | coxt = f.readlines() 24 | class_names = [x.split()[-1] for x in coxt] 25 | classnames = class_names 26 | 27 | 28 | with open(os.path.join(progpath, file_dict['imageId'])) as f: 29 | coxt = f.readlines() 30 | imageId = [int(x.split()[0]) for x in coxt] 31 | imagePath = 
[x.split()[-1] for x in coxt] 32 | df = {'imageId': imageId, 'imagePath': imagePath} 33 | imdb = pd.DataFrame(data=df) 34 | imageCid = [x.split('/')[0] for x in imdb['imagePath']] 35 | with open(os.path.join(progpath, file_dict['imageTVT'])) as f: 36 | coxt = f.readlines() 37 | imageTVT = [int(x.split()[-1]) for x in coxt] 38 | imdb['imageCid'] = imageCid 39 | imdb['imageTVT'] = imageTVT 40 | 41 | ####################################################33 42 | test_pd = imdb.loc[imdb['imageTVT'] == 0] 43 | train_tpd = imdb.loc[imdb['imageTVT'] == 1] 44 | 45 | train_pd = pd.DataFrame() 46 | val_pd = pd.DataFrame() 47 | 48 | for class_name in classnames: 49 | trainval = train_tpd.loc[train_tpd['imageCid'] == class_name] 50 | permuind = np.random.permutation(trainval.index) 51 | # print(permuind) 52 | train_pd = train_pd.append(trainval.loc[permuind[:-3]], ignore_index=True) 53 | val_pd = val_pd.append(trainval.loc[permuind[-3:]], ignore_index=True) 54 | # print(trainval.index) 55 | 56 | ####################### save in pickle files 57 | # classdict: index to class name 58 | # classnames: class name 59 | # annos_test: annotations for testing data 60 | # annos_train: annotations for training data 61 | 62 | with open(os.path.join(progpath, 'imdb.pkl'), 'wb') as f: 63 | pk.dump({'train': train_pd, 64 | 'val': val_pd, 65 | 'test': test_pd, 66 | 'classnames': classnames}, f) 67 | 68 | 69 | else: 70 | with open(os.path.join(root, 'imdb.pkl'), 'rb') as f: 71 | pdata = pk.load(f) 72 | train_pd, val_pd, test_pd, classnames = pdata['train'], pdata['val'], pdata['test'], pdata['classnames'] 73 | 74 | 75 | -------------------------------------------------------------------------------- /x-imdb/nabirds-imdb.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import pickle as pk 3 | import os, sys 4 | import os.path as osp 5 | 6 | progpath = os.path.dirname(os.path.realpath(__file__)) 7 | sys.path.append(progpath) 8 | 9 | 
10 | imdbpath = osp.join(progpath, "imdb.pkl") 11 | 12 | if not os.path.exists(imdbpath): 13 | 14 | img_list = "images.txt" 15 | tvt_list = "train_test_split.txt" 16 | cls_list = "classes.txt" 17 | img_cls_list = "image_class_labels.txt" 18 | cls_hir_list = "hierarchy.txt" 19 | 20 | 21 | classnames = {} 22 | with open(os.path.join(progpath, 'classes.txt')) as f: 23 | for line in f: 24 | pieces = line.strip().split() 25 | class_id = pieces[0] 26 | classnames[class_id] = ' '.join(pieces[1:]) 27 | 28 | 29 | classparents = {} 30 | with open(os.path.join(progpath, 'hierarchy.txt')) as f: 31 | for line in f: 32 | pieces = line.strip().split() 33 | child_id, parent_id = pieces 34 | classparents[child_id] = parent_id 35 | 36 | 37 | imgpaths = {} 38 | with open(os.path.join(progpath, img_list)) as f: 39 | for line in f: 40 | pieces = line.strip().split() 41 | image_id = pieces[0] 42 | path = os.path.join('images', pieces[1]) 43 | imgpaths[image_id] = path 44 | 45 | 46 | imglabels = {} 47 | with open(os.path.join(progpath, 'image_class_labels.txt')) as f: 48 | for line in f: 49 | pieces = line.strip().split() 50 | image_id = pieces[0] 51 | class_id = pieces[1] 52 | imglabels[image_id] = class_id 53 | 54 | traindata, testdata = list(), list() 55 | with open(os.path.join(progpath, 'train_test_split.txt')) as f: 56 | for line in f: 57 | pieces = line.strip().split() 58 | image_id = pieces[0] 59 | is_train = int(pieces[1]) 60 | 61 | clsid = imglabels[image_id] 62 | clsname = classnames[clsid] 63 | prntid = classparents[clsid] 64 | prntname = classnames[prntid] 65 | 66 | if is_train: 67 | traindata.append([imgpaths[image_id], clsid, clsname, prntid, prntname]) 68 | else: 69 | testdata.append([imgpaths[image_id], clsid, clsname, prntid, prntname]) 70 | 71 | 72 | num_classes = len(classnames) 73 | num_subclasses = len(set(imglabels.values())) 74 | num_prntclasses = len(set(classparents.values())) 75 | print(num_classes, num_subclasses, num_prntclasses) 76 | ####################### 
save in pickle files 77 | # prntclassid: classparents[sub_class_id] = parent_class_id 78 | # classnames: classnames[sub_class_id/parent_class_id] = class_name 79 | # subclassid: imglabels[img_id] = sub_class_id 80 | # traindata: [imgpath, subclass_id, subclass_name, parent_class_id, parent_class_name] 81 | # testdata: [imgpath, subclass_id, subclass_name, parent_class_id, parent_class_name] 82 | 83 | with open(os.path.join(progpath, 'imdb.pkl'), 'wb') as handle: 84 | pk.dump({'prntclassid': classparents, 85 | 'classnames': classnames, 86 | 'subclassid': imglabels, 87 | 'testdata': testdata, 88 | 'traindata': traindata}, handle) 89 | else: 90 | with open(os.path.join(progpath, 'imdb.pkl'), 'rb') as handle: 91 | annos = pk.load(handle) 92 | 93 | prntclassid = annos['prntclassid'] 94 | classnames = annos['classnames'] 95 | subclassid = annos['subclassid'] 96 | testdata = annos['testdata'] 97 | traindata = annos['traindata'] 98 | 99 | 100 | -------------------------------------------------------------------------------- /x-imdb/stcars-imdb.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import pickle as pk 3 | import os, sys 4 | 5 | progpath = os.path.dirname(os.path.realpath(__file__)) 6 | sys.path.append(progpath) 7 | 8 | imdbpath = os.path.join(parnpath, "imdb.pkl") 9 | 10 | if not os.path.exists(imdbpath): 11 | annos_name = "cars_annos.mat" 12 | annos = sio.loadmat(os.path.join(parnpath, annos_name), 13 | mat_dtype=False, squeeze_me=True, matlab_compatible=False,struct_as_record=True) 14 | 15 | annos_type=['relative_im_path', 'bbox_x1', 'bbox_y1', 'bbox_x2', 'bbox_y2', 'class', 'test'] 16 | annotations = annos['annotations'] 17 | classnames_tmp = annos['class_names'].tolist() 18 | classnames = list() 19 | for classname in classnames_tmp: 20 | classnames.append(classname.replace('/', '')) 21 | 22 | annos_train = [] 23 | annos_test = [] 24 | test_count = 0 25 | for anno in annotations: 26 | dl = {} 
27 | dl['relative_im_path'] = anno[0] 28 | dl['bbox_x1'] = anno[1] 29 | dl['bbox_y1'] = anno[2] 30 | dl['bbox_x2'] = anno[3] 31 | dl['bbox_y2'] = anno[4] 32 | dl['class'] = anno[5] - 1 33 | dl['test'] = anno[-1] 34 | 35 | if dl['test'] == 1: 36 | annos_test.append(dl) 37 | else: 38 | annos_train.append(dl) 39 | 40 | print(len(annos_train), len(annos_test)) 41 | 42 | classdict = {} 43 | for i, classname in enumerate(classnames): 44 | classdict[i] = classname 45 | 46 | 47 | 48 | ####################### save in pickle files 49 | # classdict: index to class name 50 | # classnames: class name 51 | # annos_test: annotations for testing data 52 | # annos_train: annotations for training data 53 | 54 | with open(os.path.join(progpath, 'imdb.pkl'), 'wb') as handle: 55 | pk.dump({'classdict': classdict, 'classnames': classnames, 56 | 'annos_test': annos_test, 'annos_train': annos_train}, handle) 57 | else: 58 | with open(os.path.join(progpath, 'imdb.pkl'), 'rb') as handle: 59 | annos = pk.load(handle) 60 | 61 | classdict = annos['classdict'] 62 | classnames = annos['classnames'] 63 | annos_test = annos['annos_test'] 64 | annos_train =annos['annos_train'] 65 | 66 | 67 | -------------------------------------------------------------------------------- /x-imdb/stdogs-imdb.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import pickle as pk 3 | import os, sys 4 | import os.path as osp 5 | 6 | progpath = os.path.dirname(os.path.realpath(__file__)) 7 | sys.path.append(progpath) 8 | 9 | 10 | imdbpath = osp.join(progpath, "imdb.pkl") 11 | 12 | if not os.path.exists(imdbpath): 13 | file_list = 'file_list.mat' 14 | train_list = "train_list.mat" 15 | test_list = "test_list.mat" 16 | 17 | file_data = sio.loadmat(osp.join(progpath, file_list), 18 | mat_dtype=False, squeeze_me=True, matlab_compatible=False,struct_as_record=True) 19 | train_data = sio.loadmat(osp.join(progpath, train_list), 20 | mat_dtype=False, 
squeeze_me=True, matlab_compatible=False, struct_as_record=True) 21 | test_data = sio.loadmat(osp.join(progpath, test_list), 22 | mat_dtype=False, squeeze_me=True, matlab_compatible=False, struct_as_record=True) 23 | 24 | annos_type=['annotation_list', 'file_list', 'labels'] 25 | 26 | train_file_list = train_data['file_list'].tolist() 27 | train_file_labels = train_data['labels'].tolist() 28 | test_file_list = test_data['file_list'].tolist() 29 | test_file_labels = test_data['labels'].tolist() 30 | 31 | class_names = list() 32 | for file_list in test_file_list: 33 | prefix = file_list.split('/')[0] 34 | if prefix not in class_names: 35 | class_names.append(prefix) 36 | 37 | num_classes = len(class_names) 38 | 39 | class_dict = dict() 40 | for i in range(num_classes): 41 | class_dict[i] = class_names[i] 42 | 43 | 44 | 45 | ####################### save in pickle files 46 | # classdict: index to class name 47 | # classnames: class name 48 | # annos_test: annotations for testing data 49 | # annos_train: annotations for training data 50 | 51 | with open(os.path.join(progpath, 'imdb.pkl'), 'wb') as handle: 52 | pk.dump({'classdict': class_dict, 53 | 'classnames': class_names, 54 | 'traindata': train_file_list, 55 | 'trainlabels': train_file_labels, 56 | 'testdata': test_file_list, 57 | 'testlabels': test_file_labels}, handle) 58 | else: 59 | with open(os.path.join(progpath, 'imdb.pkl'), 'rb') as handle: 60 | annos = pk.load(handle) 61 | 62 | classdict = annos['classdict'] 63 | classnames = annos['classnames'] 64 | traindata = annos['traindata'] 65 | trainlabels =annos['trainlabels'] 66 | testdata = annos['testdata'] 67 | testlabels = annos['testlabels'] 68 | 69 | 70 | -------------------------------------------------------------------------------- /x-imdb/vggaircraft-imdb.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # There are a total of 30 manufacturers, 70 families and 100 variants 
(categories) 3 | # manufacturer is the superclass of families. 4 | # family is the superclass of variants. 5 | # we mainly do classification on the scale of variant. 6 | # 7 | # the label relationship is: 8 | # variant < family < manufacturer < aircraft 9 | # the data structure of the traindata, valdata, trainvaldata and testdata: 10 | # [imgname, variant, family, manufacturer] 11 | ###################################### 12 | 13 | 14 | import scipy.io as sio 15 | import pickle as pk 16 | import os, sys 17 | import os.path as osp 18 | 19 | progpath = os.path.dirname(os.path.realpath(__file__)) 20 | sys.path.append(progpath) 21 | 22 | 23 | imdbpath = osp.join(progpath, "imdb.pkl") 24 | 25 | if not os.path.exists(imdbpath): 26 | 27 | # variant 28 | train_variant_list = "annotations/images_variant_train.txt" 29 | val_variant_list = "annotations/images_variant_val.txt" 30 | trainval_variant_list = "annotations/images_variant_trainval.txt" 31 | test_variant_list = "annotations/images_variant_test.txt" 32 | 33 | # family 34 | train_family_list = "annotations/images_family_train.txt" 35 | val_family_list = "annotations/images_family_val.txt" 36 | trainval_family_list = "annotations/images_family_trainval.txt" 37 | test_family_list = "annotations/images_family_test.txt" 38 | 39 | # manufacturer 40 | train_manufacturer_list = "annotations/images_manufacturer_train.txt" 41 | val_manufacturer_list = "annotations/images_manufacturer_val.txt" 42 | trainval_manufacturer_list = "annotations/images_manufacturer_trainval.txt" 43 | test_manufacturer_list = "annotations/images_manufacturer_test.txt" 44 | 45 | variants_list = "annotations/variants.txt" 46 | families_list = "annotations/families.txt" 47 | manufacturers_list = "annotations/manufacturers.txt" 48 | 49 | # variant, family, manufacturer 50 | with open(osp.join(progpath, variants_list), 'r') as f: 51 | classnames_variant = f.readlines() 52 | with open(osp.join(progpath, families_list), 'r') as f: 53 | classnames_family =
f.readlines() 54 | with open(osp.join(progpath, manufacturers_list), 'r') as f: 55 | classnames_manufacturer = f.readlines() 56 | classnames_variant = [x.rstrip('\n').replace('/', '') for x in classnames_variant] 57 | classnames_family = [x.rstrip('\n').replace('/', '') for x in classnames_family] 58 | classnames_manufacturer = [x.rstrip('\n').replace('/', '') for x in classnames_manufacturer] 59 | 60 | # dictionary 61 | classdict_variant, classdict_family, classdict_manufacturer = dict(), dict(), dict() 62 | for idx, classname in enumerate(classnames_variant): 63 | classdict_variant[idx] = classname 64 | for idx, classname in enumerate(classnames_family): 65 | classdict_family[idx] = classname 66 | for idx, classname in enumerate(classnames_manufacturer): 67 | classdict_manufacturer[idx] = classname 68 | 69 | # for the training set: rows of the variant/family/manufacturer lists are paired by position, and if their lengths differ the split is silently left empty (same for the splits below) 70 | traindata = list() 71 | with open(osp.join(progpath, train_variant_list), 'r') as f: 72 | rows = f.readlines() 73 | traindata_variant = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 74 | with open(osp.join(progpath, train_family_list), 'r') as f: 75 | rows = f.readlines() 76 | traindata_family = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 77 | with open(osp.join(progpath, train_manufacturer_list), 'r') as f: 78 | rows = f.readlines() 79 | traindata_manufacturer = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 80 | if len(traindata_variant) == len(traindata_family) == len(traindata_manufacturer): 81 | for variant, family, manufacturer in zip(traindata_variant, traindata_family, traindata_manufacturer): 82 | # print(variant, family, manufacturer) 83 | variant.append(family[-1]) 84 | variant.append(manufacturer[-1]) 85 | traindata.append(variant) 86 | 87 | # for the validation set 88 | valdata = list() 89 | with open(osp.join(progpath, val_variant_list), 'r') as f: 90 | rows = f.readlines() 91 | valdata_variant = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in
rows] 92 | with open(osp.join(progpath, val_family_list), 'r') as f: 93 | rows = f.readlines() 94 | valdata_family = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 95 | with open(osp.join(progpath, val_manufacturer_list), 'r') as f: 96 | rows = f.readlines() 97 | valdata_manufacturer = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 98 | if len(valdata_variant) == len(valdata_family) == len(valdata_manufacturer): 99 | for variant, family, manufacturer in zip(valdata_variant, valdata_family, valdata_manufacturer): 100 | # print(variant, family, manufacturer) 101 | variant.append(family[-1]) 102 | variant.append(manufacturer[-1]) 103 | valdata.append(variant) 104 | 105 | 106 | # for the trainval dataset 107 | trainvaldata = list() 108 | with open(osp.join(progpath, trainval_variant_list), 'r') as f: 109 | rows = f.readlines() 110 | trainvaldata_variant = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 111 | with open(osp.join(progpath, trainval_family_list), 'r') as f: 112 | rows = f.readlines() 113 | trainvaldata_family = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 114 | with open(osp.join(progpath, trainval_manufacturer_list), 'r') as f: 115 | rows = f.readlines() 116 | trainvaldata_manufacturer = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 117 | if len(trainvaldata_variant) == len(trainvaldata_family) == len(trainvaldata_manufacturer): 118 | for variant, family, manufacturer in zip(trainvaldata_variant, trainvaldata_family, trainvaldata_manufacturer): 119 | # print(variant, family, manufacturer) 120 | variant.append(family[-1]) 121 | variant.append(manufacturer[-1]) 122 | trainvaldata.append(variant) 123 | 124 | # for the testing dataset 125 | testdata = list() 126 | with open(osp.join(progpath, test_variant_list), 'r') as f: 127 | rows = f.readlines() 128 | testdata_variant = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 129 | with
open(osp.join(progpath, test_family_list), 'r') as f: 130 | rows = f.readlines() 131 | testdata_family = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 132 | with open(osp.join(progpath, test_manufacturer_list), 'r') as f: 133 | rows = f.readlines() 134 | testdata_manufacturer = [row.rstrip('\n').replace('/', '').split(' ', 1) for row in rows] 135 | if len(testdata_variant) == len(testdata_family) == len(testdata_manufacturer): 136 | for variant, family, manufacturer in zip(testdata_variant, testdata_family, testdata_manufacturer): 137 | # print(variant, family, manufacturer) 138 | variant.append(family[-1]) 139 | variant.append(manufacturer[-1]) 140 | testdata.append(variant) 141 | 142 | ####################### save in pickle files 143 | # classdict_variant / _family / _manufacturer: index to class name 144 | # classnames_variant / _family / _manufacturer: class names with '/' removed 145 | # traindata: list of [imagename, variant, family, manufacturer] lists 146 | # valdata: list of [imagename, variant, family, manufacturer] lists 147 | # trainvaldata: list of [imagename, variant, family, manufacturer] lists 148 | # testdata: list of [imagename, variant, family, manufacturer] lists 149 | 150 | with open(os.path.join(progpath, 'imdb.pkl'), 'wb') as handle: 151 | pk.dump({'classdict_variant': classdict_variant, 152 | 'classdict_family': classdict_family, 153 | 'classdict_manufacturer': classdict_manufacturer, 154 | 'classnames_variant': classnames_variant, 155 | 'classnames_family': classnames_family, 156 | 'classnames_manufacturer': classnames_manufacturer, 157 | 'traindata': traindata, 158 | 'valdata': valdata, 159 | 'testdata': testdata, 160 | 'trainvaldata': trainvaldata}, handle) 161 | else: 162 | with open(os.path.join(progpath, 'imdb.pkl'), 'rb') as handle: 163 | annos = pk.load(handle) 164 | 165 | classdict_variant = annos['classdict_variant'] 166 | classdict_family = annos['classdict_family'] 167 | classdict_manufacturer = annos['classdict_manufacturer'] 168 | classnames_variant = annos['classnames_variant'] 169 | classnames_family = annos['classnames_family'] 170 | classnames_manufacturer =
annos['classnames_manufacturer'] 171 | traindata = annos['traindata'] 172 | trainvaldata =annos['trainvaldata'] 173 | testdata = annos['testdata'] 174 | valdata = annos['valdata'] 175 | 176 | print("The vgg aircraft dataset has already been loaded successfully.\n") 177 | 178 | 179 | --------------------------------------------------------------------------------