├── LICENSE
├── README.md
├── generate_case2_rotation_candidates.m
├── get_modelnet_png.sh
├── link_images.sh
├── link_images_MIRO.sh
├── train_rotationnet.py
├── vcand_case1.npy
├── vcand_case2.npy
└── vcand_case3.npy

/LICENSE:
--------------------------------------------------------------------------------
BSD 2-Clause License

Copyright (c) 2018, Asako Kanezaki
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch RotationNet

This is a PyTorch implementation of RotationNet.

Asako Kanezaki, Yasuyuki Matsushita and Yoshifumi Nishida.
**RotationNet: Joint Object Categorization and Pose Estimation Using Multiviews from Unsupervised Viewpoints.**
*CVPR*, 2018.
([pdf](https://arxiv.org/abs/1603.06208))
([project](https://kanezaki.github.io/rotationnet/))

We used Caffe for the CVPR submission.
Please see the [rotationnet](https://github.com/kanezaki/rotationnet) repository for more details, including how to reproduce the results in our paper.

## Training/testing the ModelNet dataset

### 1. Download multi-view images
#### 1-1. Download the multi-view images generated in [Su et al. 2015]
    $ bash get_modelnet_png.sh
[Su et al. 2015] H. Su, S. Maji, E. Kalogerakis, E. Learned-Miller. Multi-view Convolutional Neural Networks for 3D Shape Recognition. ICCV 2015.
This is a subset of ModelNet40.
#### 1-2. Download our multi-view images
    $ wget https://data.airc.aist.go.jp/kanezaki.asako/data/modelnet40v2png_ori4.tar; tar xvf modelnet40v2png_ori4.tar
Our BEST results are reported on this dataset.

### 2. Prepare dataset directories for training
    $ bash link_images.sh ./modelnet40v1png ./ModelNet40v1 1
    $ bash link_images.sh ./modelnet40v2png ./ModelNet40_20 2
Or, if you downloaded our images in step 1-2:

    $ bash link_images.sh ./modelnet40v2png_ori4 ./ModelNet40_20

### 3. Train your own RotationNet models
#### 3-1. Case (2): Train the model without upright orientation (RECOMMENDED)
    $ python train_rotationnet.py --pretrained -a alexnet -b 400 --lr 0.01 --epochs 1500 ./ModelNet40_20 | tee log_ModelNet40_20_rotationnet.txt
#### 3-2. Case (1): Train the model with upright orientation
    $ python train_rotationnet.py --case 1 --pretrained -a alexnet -b 240 --lr 0.01 --epochs 1500 ./ModelNet40v1 | tee log_ModelNet40v1_rotationnet.txt

## Training/testing the MIRO dataset

### 1. Download the MIRO dataset (414 MB)
    $ wget https://data.airc.aist.go.jp/kanezaki.asako/data/MIRO.zip
    $ unzip MIRO.zip

### 2. Prepare dataset directories for training
    $ bash link_images_MIRO.sh ./MIRO ./data_MIRO

### 3. Train your own RotationNet models
#### 3-1. Case (3): Train the model with upright orientation
    $ python train_rotationnet.py --case 3 --pretrained -a alexnet -b 480 --lr 0.01 --epochs 1500 ./data_MIRO | tee log_MIRO_160_rotationnet.txt
--------------------------------------------------------------------------------
/generate_case2_rotation_candidates.m:
--------------------------------------------------------------------------------
function generate_case2_rotation_candidates

phi = (1+sqrt(5))/2;

% the 20 vertices of a regular dodecahedron (the case (2) viewpoints)
vertices = [
    1, 1, 1;
    1, 1, -1;
    1, -1, 1;
    1, -1, -1;
    -1, 1, 1;
    -1, 1, -1;
    -1, -1, 1;
    -1, -1, -1;

    0, 1/phi, phi;
    0, 1/phi, -phi;
    0, -1/phi, phi;
    0, -1/phi, -phi;

    phi, 0, 1/phi;
    phi, 0, -1/phi;
    -phi, 0, 1/phi;
    -phi, 0, -1/phi;

    1/phi, phi, 0;
    -1/phi, phi, 0;
    1/phi, -phi, 0;
    -1/phi, -phi, 0;];

%% edges (one edge per antipodal pair, i.e. 15 rotation axes)
edges = zeros(15, 2);
len = norm(vertices(1,:)-vertices(9,:));
idx = 1;
for i = 1:size(vertices,1)-1
    for j = i+1:size(vertices,1)
        if abs( norm(vertices(i,:) - vertices(j,:)) - len ) < 0.0001
            break_flg = false;
            for k = 1:idx-1
                if norm( vertices(i,:) + vertices(j,:) + vertices(edges(k,1),:) + vertices(edges(k,2),:)) < 0.0001
                    break_flg = true;
                    break
                end
            end
            if break_flg
                break
            end
            edges(idx,1) = i;
            edges(idx,2) = j;
            idx = idx + 1;
        end
    end
end

% %% edges_all
% edges_all = zeros(30, 2);
% idx = 1;
% for i = 1:size(vertices,1)-1
%     for j = i+1:size(vertices,1)
%         if abs( norm(vertices(i,:) - vertices(j,:)) - len ) < 0.0001
%             edges_all(idx,1) = i;
%             edges_all(idx,2) = j;
%             idx = idx + 1;
%         end
%     end
% end

%% faces
faces(1,:) = [1 13 3 11 9];
faces(2,:) = [1 17 2 14 13];
faces(3,:) = [1 9 5 18 17];
faces(4,:) = [10 6 18 17 2];
faces(5,:) = [13 14 4 19 3];
faces(6,:) = [9 11 7 15 5];

%% 0. original
% each row of inds is one rotation of the dodecahedron, written as a
% permutation of the 20 viewpoint indices (60 rows in total)
inds = 1:20;

%% 1. axis: vertex, angle: 2/3 * pi
idx = 2;
for x=1:size(vertices,1)
    vert_b = vertices(x,:);
    vert_b = vert_b ./ norm(vert_b);
    vertices_new = my_rotate( vert_b', 2 * pi / 3 ) * vertices';
    vertices_new = vertices_new';

    for i=1:size(vertices,1)
        for j=1:size(vertices,1)
            if sum(abs(vertices_new(i,:) - vertices(j,:))) < 0.0001
                inds(idx,i) = j;
            end
        end
    end
    idx = idx + 1;
end

%% 2. 
axis: middle point of edge, angle: pi 97 | for x=1:size(edges,1) 98 | vert_b = vertices(edges(x,1),:) + vertices(edges(x,2),:); 99 | vert_b = vert_b ./ norm(vert_b); 100 | vertices_new = my_rotate( vert_b', pi ) * vertices'; 101 | vertices_new = vertices_new'; 102 | 103 | for i=1:size(vertices,1) 104 | for j=1:size(vertices,1) 105 | if sum(abs(vertices_new(i,:) - vertices(j,:))) < 0.0001 106 | inds(idx,i) = j; 107 | end 108 | end 109 | end 110 | idx = idx + 1; 111 | end 112 | 113 | %% 3. axis: center point of face, angle: (1, 2, 3, 4) * 2/5 pi 114 | for x=1:size(faces,1) 115 | vert_b = vertices(faces(x,1),:) + vertices(faces(x,2),:) + vertices(faces(x,3),:) + vertices(faces(x,4),:) + vertices(faces(x,5),:) ; 116 | vert_b = vert_b ./ norm(vert_b); 117 | for y=1:4 118 | vertices_new = my_rotate( vert_b', y * 2 * pi / 5 ) * vertices'; 119 | vertices_new = vertices_new'; 120 | 121 | for i=1:size(vertices,1) 122 | for j=1:size(vertices,1) 123 | if sum(abs(vertices_new(i,:) - vertices(j,:))) < 0.0001 124 | inds(idx,i) = j; 125 | end 126 | end 127 | end 128 | idx = idx + 1; 129 | end 130 | end 131 | 132 | %% show rotation candidates 133 | inds 134 | 135 | end 136 | 137 | 138 | 139 | function R = my_rotate(k,fi) 140 | x = k(1); 141 | y = k(2); 142 | z = k(3); 143 | 144 | R = zeros(3,3); 145 | 146 | R(1,1) = cos(fi)+x^2*(1-cos(fi)); 147 | R(1,2) = x*y*(1-cos(fi))-z*sin(fi); 148 | R(1,3) = x*z*(1-cos(fi))+y*sin(fi); 149 | 150 | R(2,1) = y*x*(1-cos(fi))+z*sin(fi); 151 | R(2,2) = cos(fi)+y^2*(1-cos(fi)); 152 | R(2,3) = y*z*(1-cos(fi))-x*sin(fi); 153 | 154 | R(3,1) = z*x*(1-cos(fi))-y*sin(fi); 155 | R(3,2) = z*y*(1-cos(fi))+x*sin(fi); 156 | R(3,3) = cos(fi)+z^2*(1-cos(fi)); 157 | end -------------------------------------------------------------------------------- /get_modelnet_png.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # script for downloading multi-view images generated in [Su et al. 2015] 4 | # [Su et al. 2015] H. Su, S. Maji, E. Kalogerakis, E. Learned-Miller. Multi-view Convolutional Neural Networks for 3D Shape Recognition. ICCV2015. 5 | 6 | # download tar files 7 | wget http://maxwell.cs.umass.edu/mvcnn-data/modelnet40v1png.tar 8 | wget http://maxwell.cs.umass.edu/mvcnn-data/modelnet40v2png.tar 9 | 10 | # extract tar files 11 | tar xvf modelnet40v1png.tar 12 | tar xvf modelnet40v2png.tar 13 | -------------------------------------------------------------------------------- /link_images.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dataset=$1 4 | output=$2 5 | 6 | mkdir -p $output/{train,test} 7 | 8 | for cls in `ls $dataset` 9 | do 10 | for subset in train test 11 | do 12 | mkdir -p $output/$subset/$cls 13 | cd $output/$subset/$cls 14 | 15 | for f in `ls ../../../$dataset/$cls/$subset/` 16 | do 17 | ln -s ../../../$dataset/$cls/$subset/$f . 18 | done 19 | cd ../../.. 20 | done 21 | done 22 | 23 | cd $output/ 24 | ln -s test val 25 | cd .. 
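# The clean-up below first removes any .off mesh files linked along with the PNGs.
# When the third argument is 2 (used for modelnet40v2png), the loop that follows
# deletes every view whose index is not 1 modulo 4 (keeping *_001.png, *_005.png,
# ..., *_077.png), leaving 20 views per model for the case (2) setup.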
26 | 27 | rm -f $output/{train,test}/*/*.off 28 | if [ $3 == 2 ] 29 | then 30 | for ((i=0;i<20;i++)) 31 | do 32 | for ((j=2;j<5;j++)) 33 | do 34 | n=`expr $i \* 4 + $j` 35 | fn=`printf "%03d" $n` 36 | rm -f $output/{train,test}/*/*_$fn.png 37 | done 38 | done 39 | fi 40 | -------------------------------------------------------------------------------- /link_images_MIRO.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dataset=$1 4 | output=$2 5 | 6 | mkdir -p $output/{train,val} 7 | 8 | for cls in `ls $dataset` 9 | do 10 | echo $cls 11 | 12 | mkdir -p $output/val/$cls 13 | cd $output/val/$cls 14 | for f in `ls ../../../$dataset/$cls/${cls}_1_*.png` 15 | do 16 | ln -s $f . 17 | done 18 | cd ../../.. 19 | 20 | mkdir -p $output/train/$cls 21 | cd $output/train/$cls 22 | for f in `ls ../../../$dataset/$cls/*.png | grep -v "${cls}_1_"` 23 | do 24 | ln -s $f . 25 | done 26 | cd ../../.. 27 | done 28 | -------------------------------------------------------------------------------- /train_rotationnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.nn.functional 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | import numpy as np 19 | 20 | model_names = sorted(name for name in models.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and callable(models.__dict__[name])) 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 25 | parser.add_argument('data', metavar='DIR', 26 | help='path to dataset') 27 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 28 | choices=model_names, 29 | help='model architecture: ' + 30 | ' | '.join(model_names) + 31 | ' (default: resnet18)') 32 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | metavar='N', help='mini-batch size (default: 256)') 40 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 41 | metavar='LR', help='initial learning rate') 42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 43 | help='momentum') 44 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 45 | metavar='W', help='weight decay (default: 1e-4)') 46 | parser.add_argument('--print-freq', '-p', default=10, type=int, 47 | metavar='N', help='print frequency (default: 10)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | 
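# Note: each group of nview consecutive images is treated as one object, so the
# batch size must be a multiple of the number of views (12 for --case 1, 20 for
# --case 2, 160 for --case 3); main() exits with an error otherwise.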
parser.add_argument('--world-size', default=1, type=int, 55 | help='number of distributed processes') 56 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 57 | help='url used to set up distributed training') 58 | parser.add_argument('--dist-backend', default='gloo', type=str, 59 | help='distributed backend') 60 | parser.add_argument('--case', default='2', type=str, 61 | help='viewpoint setup case (1 or 2)') 62 | 63 | best_prec1 = 0 64 | vcand = np.load('vcand_case2.npy') 65 | nview = 20 66 | 67 | class FineTuneModel(nn.Module): 68 | def __init__(self, original_model, arch, num_classes): 69 | super(FineTuneModel, self).__init__() 70 | 71 | if arch.startswith('alexnet') : 72 | self.features = original_model.features 73 | self.classifier = nn.Sequential( 74 | nn.Dropout(), 75 | nn.Linear(256 * 6 * 6, 4096), 76 | nn.ReLU(inplace=True), 77 | nn.Dropout(), 78 | nn.Linear(4096, 4096), 79 | nn.ReLU(inplace=True), 80 | nn.Linear(4096, num_classes), 81 | ) 82 | self.modelName = 'alexnet' 83 | elif arch.startswith('resnet') : 84 | # Everything except the last linear layer 85 | self.features = nn.Sequential(*list(original_model.children())[:-1]) 86 | self.classifier = nn.Sequential( 87 | nn.Linear(512, num_classes) 88 | ) 89 | self.modelName = 'resnet' 90 | elif arch.startswith('vgg16'): 91 | self.features = original_model.features 92 | self.classifier = nn.Sequential( 93 | nn.Dropout(), 94 | nn.Linear(25088, 4096), 95 | nn.ReLU(inplace=True), 96 | nn.Dropout(), 97 | nn.Linear(4096, 4096), 98 | nn.ReLU(inplace=True), 99 | nn.Linear(4096, num_classes), 100 | ) 101 | self.modelName = 'vgg16' 102 | else : 103 | raise("Finetuning not supported on this architecture yet") 104 | 105 | # # Freeze those weights 106 | # for p in self.features.parameters(): 107 | # p.requires_grad = False 108 | 109 | 110 | def forward(self, x): 111 | f = self.features(x) 112 | if self.modelName == 'alexnet' : 113 | f = f.view(f.size(0), 256 * 6 * 6) 114 | elif self.modelName == 'vgg16': 115 | f = f.view(f.size(0), -1) 116 | elif self.modelName == 'resnet' : 117 | f = f.view(f.size(0), -1) 118 | y = self.classifier(f) 119 | return y 120 | 121 | 122 | def main(): 123 | global args, best_prec1, nview, vcand 124 | args = parser.parse_args() 125 | 126 | args.distributed = args.world_size > 1 127 | 128 | if args.case == '1': 129 | vcand = np.load('vcand_case1.npy') 130 | nview = 12 131 | elif args.case == '3': 132 | vcand = np.load('vcand_case3.npy') 133 | nview = 160 134 | 135 | if args.batch_size % nview != 0: 136 | print ('Error: batch size should be multiplication of the number of views,', nview) 137 | exit() 138 | 139 | if args.distributed: 140 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 141 | world_size=args.world_size) 142 | 143 | traindir = os.path.join(args.data, 'train') 144 | valdir = os.path.join(args.data, 'val') 145 | # Get number of classes from train directory 146 | num_classes = len([name for name in os.listdir(traindir)]) 147 | print("num_classes = '{}'".format(num_classes)) 148 | 149 | # create model 150 | if args.pretrained: 151 | print("=> using pre-trained model '{}'".format(args.arch)) 152 | model = models.__dict__[args.arch](pretrained=True) 153 | else: 154 | print("=> creating model '{}'".format(args.arch)) 155 | model = models.__dict__[args.arch]() 156 | 157 | model = FineTuneModel(model, args.arch, (num_classes+1) * nview ) 158 | 159 | if not args.distributed: 160 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 161 | 
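            # As in the PyTorch ImageNet example this script follows, only the
            # convolutional features of AlexNet/VGG are wrapped in DataParallel
            # (their fully connected classifier is comparatively heavy to replicate);
            # other architectures are wrapped as a whole in the else branch.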
model.features = torch.nn.DataParallel(model.features) 162 | model.cuda() 163 | else: 164 | model = torch.nn.DataParallel(model).cuda() 165 | else: 166 | model.cuda() 167 | model = torch.nn.parallel.DistributedDataParallel(model) 168 | 169 | # define loss function (criterion) and optimizer 170 | criterion = nn.CrossEntropyLoss().cuda() 171 | 172 | ##optimizer = torch.optim.SGD(model.parameters(), args.lr, 173 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), # Only finetunable params 174 | args.lr, 175 | momentum=args.momentum, 176 | weight_decay=args.weight_decay) 177 | 178 | # optionally resume from a checkpoint 179 | if args.resume: 180 | if os.path.isfile(args.resume): 181 | print("=> loading checkpoint '{}'".format(args.resume)) 182 | checkpoint = torch.load(args.resume) 183 | args.start_epoch = checkpoint['epoch'] 184 | best_prec1 = checkpoint['best_prec1'] 185 | model.load_state_dict(checkpoint['state_dict']) 186 | optimizer.load_state_dict(checkpoint['optimizer']) 187 | print("=> loaded checkpoint '{}' (epoch {})" 188 | .format(args.resume, checkpoint['epoch'])) 189 | else: 190 | print("=> no checkpoint found at '{}'".format(args.resume)) 191 | 192 | cudnn.benchmark = True 193 | 194 | # Data loading code 195 | traindir = os.path.join(args.data, 'train') 196 | valdir = os.path.join(args.data, 'val') 197 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 198 | std=[0.229, 0.224, 0.225]) 199 | 200 | train_dataset = datasets.ImageFolder( 201 | traindir, 202 | transforms.Compose([ 203 | # transforms.CenterCrop(224), 204 | transforms.ToTensor(), 205 | normalize, 206 | ])) 207 | 208 | if args.distributed: 209 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 210 | else: 211 | train_sampler = None 212 | 213 | train_loader = torch.utils.data.DataLoader( 214 | train_dataset, batch_size=args.batch_size, shuffle=False, 215 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 216 | sorted_imgs = sorted(train_loader.dataset.imgs) 217 | train_nsamp = int( len(sorted_imgs) / nview ) 218 | 219 | val_loader = torch.utils.data.DataLoader( 220 | datasets.ImageFolder(valdir, transforms.Compose([ 221 | # transforms.Scale(256), 222 | # transforms.CenterCrop(224), 223 | transforms.ToTensor(), 224 | normalize, 225 | ])), 226 | batch_size=args.batch_size, shuffle=False, 227 | num_workers=args.workers, pin_memory=True) 228 | val_loader.dataset.imgs = sorted(val_loader.dataset.imgs) 229 | 230 | if args.evaluate: 231 | validate(val_loader, model, criterion) 232 | return 233 | 234 | for epoch in range(args.start_epoch, args.epochs): 235 | if args.distributed: 236 | train_sampler.set_epoch(epoch) 237 | 238 | adjust_learning_rate(optimizer, epoch) 239 | 240 | # random permutation 241 | inds = np.zeros( ( nview, train_nsamp ) ).astype('int') 242 | inds[ 0 ] = np.random.permutation(range(train_nsamp)) * nview 243 | for i in range(1,nview): 244 | inds[ i ] = inds[ 0 ] + i 245 | inds = inds.T.reshape( nview * train_nsamp ) 246 | train_loader.dataset.imgs = [sorted_imgs[ i ] for i in inds] 247 | train_loader.dataset.samples = train_loader.dataset.imgs 248 | 249 | # train for one epoch 250 | train(train_loader, model, criterion, optimizer, epoch) 251 | 252 | # evaluate on validation set 253 | prec1 = validate(val_loader, model, criterion) 254 | 255 | # remember best prec@1 and save checkpoint 256 | is_best = prec1 > best_prec1 257 | best_prec1 = max(prec1, best_prec1) 258 | fname='rotationnet_checkpoint.pth.tar' 259 | 
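        # case (1) runs (nview == 12) switch to the *_case1 file names below so they
        # do not overwrite checkpoints from case (2)/(3) runs, which use the defaults.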
fname2='rotationnet_model_best.pth.tar' 260 | if nview == 12: 261 | fname='rotationnet_checkpoint_case1.pth.tar' 262 | fname2='rotationnet_model_best_case1.pth.tar' 263 | save_checkpoint({ 264 | 'epoch': epoch + 1, 265 | 'arch': args.arch, 266 | 'state_dict': model.state_dict(), 267 | 'best_prec1': best_prec1, 268 | 'optimizer' : optimizer.state_dict(), 269 | }, is_best,fname,fname2) 270 | 271 | 272 | 273 | def train(train_loader, model, criterion, optimizer, epoch): 274 | batch_time = AverageMeter() 275 | data_time = AverageMeter() 276 | losses = AverageMeter() 277 | top1 = AverageMeter() 278 | top5 = AverageMeter() 279 | 280 | # switch to train mode 281 | model.train() 282 | 283 | end = time.time() 284 | for i, (input, target) in enumerate(train_loader): 285 | nsamp = int( input.size(0) / nview ) 286 | 287 | # measure data loading time 288 | data_time.update(time.time() - end) 289 | 290 | input_var = torch.autograd.Variable(input) 291 | target_ = torch.LongTensor( target.size(0) * nview ) 292 | 293 | # compute output 294 | output = model(input_var) 295 | num_classes = int( output.size( 1 ) / nview ) - 1 296 | output = output.view( -1, num_classes + 1 ) 297 | 298 | ########################################### 299 | # compute scores and decide target labels # 300 | ########################################### 301 | output_ = torch.nn.functional.log_softmax( output ) 302 | # divide object scores by the scores for "incorrect view label" (see Eq.(5)) 303 | output_ = output_[ :, :-1 ] - torch.t( output_[ :, -1 ].repeat( 1, output_.size(1)-1 ).view( output_.size(1)-1, -1 ) ) 304 | # reshape output matrix 305 | output_ = output_.view( -1, nview * nview, num_classes ) 306 | output_ = output_.data.cpu().numpy() 307 | output_ = output_.transpose( 1, 2, 0 ) 308 | # initialize target labels with "incorrect view label" 309 | for j in range(target_.size(0)): 310 | target_[ j ] = num_classes 311 | # compute scores for all the candidate poses (see Eq.(5)) 312 | scores = np.zeros( ( vcand.shape[ 0 ], num_classes, nsamp ) ) 313 | for j in range(vcand.shape[0]): 314 | for k in range(vcand.shape[1]): 315 | scores[ j ] = scores[ j ] + output_[ vcand[ j ][ k ] * nview + k ] 316 | # for each sample #n, determine the best pose that maximizes the score for the target class (see Eq.(2)) 317 | for n in range( nsamp ): 318 | j_max = np.argmax( scores[ :, target[ n * nview ], n ] ) 319 | # assign target labels 320 | for k in range(vcand.shape[1]): 321 | target_[ n * nview * nview + vcand[ j_max ][ k ] * nview + k ] = target[ n * nview ] 322 | ########################################### 323 | 324 | target_ = target_.cuda() 325 | target_var = torch.autograd.Variable(target_) 326 | 327 | # compute loss 328 | loss = criterion(output, target_var) 329 | losses.update(loss.item(), input.size(0)) 330 | 331 | # compute gradient and do SGD step 332 | optimizer.zero_grad() 333 | loss.backward() 334 | optimizer.step() 335 | 336 | # measure elapsed time 337 | batch_time.update(time.time() - end) 338 | end = time.time() 339 | 340 | if i % args.print_freq == 0: 341 | print('Epoch: [{0}][{1}/{2}]\t' 342 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 343 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 344 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format( 345 | epoch, i, len(train_loader), batch_time=batch_time, 346 | data_time=data_time, loss=losses)) 347 | 348 | 349 | def validate(val_loader, model, criterion): 350 | batch_time = AverageMeter() 351 | losses = AverageMeter() 352 | top1 = AverageMeter() 353 | top5 = 
AverageMeter() 354 | 355 | # switch to evaluate mode 356 | model.eval() 357 | 358 | end = time.time() 359 | for i, (input, target) in enumerate(val_loader): 360 | target = target.cuda() 361 | input_var = torch.autograd.Variable(input, volatile=True) 362 | target_var = torch.autograd.Variable(target, volatile=True) 363 | 364 | # compute output 365 | output = model(input_var) 366 | loss = criterion(output, target_var) 367 | 368 | # log_softmax and reshape output 369 | num_classes = int( output.size( 1 ) / nview ) - 1 370 | output = output.view( -1, num_classes + 1 ) 371 | output = torch.nn.functional.log_softmax( output ) 372 | output = output[ :, :-1 ] - torch.t( output[ :, -1 ].repeat( 1, output.size(1)-1 ).view( output.size(1)-1, -1 ) ) 373 | output = output.view( -1, nview * nview, num_classes ) 374 | 375 | # measure accuracy and record loss 376 | prec1, prec5 = my_accuracy(output.data, target, topk=(1, 5)) 377 | losses.update(loss.item(), input.size(0)) 378 | top1.update(prec1.item(), input.size(0)/nview) 379 | top5.update(prec5.item(), input.size(0)/nview) 380 | 381 | # measure elapsed time 382 | batch_time.update(time.time() - end) 383 | end = time.time() 384 | 385 | if i % args.print_freq == 0: 386 | print('Test: [{0}/{1}]\t' 387 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 388 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 389 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 390 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 391 | i, len(val_loader), batch_time=batch_time, loss=losses, 392 | top1=top1, top5=top5)) 393 | 394 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 395 | .format(top1=top1, top5=top5)) 396 | 397 | return top1.avg 398 | 399 | 400 | def save_checkpoint(state, is_best, filename='rotationnet_checkpoint.pth.tar', filename2='rotationnet_model_best.pth.tar'): 401 | torch.save(state, filename) 402 | if is_best: 403 | shutil.copyfile(filename, filename2) 404 | 405 | 406 | class AverageMeter(object): 407 | """Computes and stores the average and current value""" 408 | def __init__(self): 409 | self.reset() 410 | 411 | def reset(self): 412 | self.val = 0 413 | self.avg = 0 414 | self.sum = 0 415 | self.count = 0 416 | 417 | def update(self, val, n=1): 418 | self.val = val 419 | self.sum += val * n 420 | self.count += n 421 | self.avg = self.sum / self.count 422 | 423 | 424 | def adjust_learning_rate(optimizer, epoch): 425 | """Sets the learning rate to the initial LR decayed by 10 every 200 epochs""" 426 | lr = args.lr * (0.1 ** (epoch // 200)) 427 | for param_group in optimizer.param_groups: 428 | param_group['lr'] = lr 429 | print ('Learning Rate: {lr:.6f}'.format(lr=param_group['lr'])) 430 | 431 | 432 | def accuracy(output, target, topk=(1,)): 433 | """Computes the precision@k for the specified values of k""" 434 | maxk = max(topk) 435 | batch_size = target.size(0) 436 | 437 | _, pred = output.topk(maxk, 1, True, True) 438 | pred = pred.t() 439 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 440 | 441 | res = [] 442 | for k in topk: 443 | #correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 444 | correct_k = correct[:k].view(-1).float().sum(0) 445 | res.append(correct_k.mul_(100.0 / batch_size)) 446 | return res 447 | 448 | def my_accuracy(output_, target, topk=(1,)): 449 | """Computes the precision@k for the specified values of k""" 450 | maxk = max(topk) 451 | target = target[0:-1:nview] 452 | batch_size = target.size(0) 453 | 454 | num_classes = output_.size(2) 455 | output_ = output_.cpu().numpy() 456 | output_ = output_.transpose( 
1, 2, 0 )
    scores = np.zeros( ( vcand.shape[ 0 ], num_classes, batch_size ) )
    output = torch.zeros( ( batch_size, num_classes ) )
    # compute scores for all the candidate poses (see Eq.(6))
    for j in range(vcand.shape[0]):
        for k in range(vcand.shape[1]):
            scores[ j ] = scores[ j ] + output_[ vcand[ j ][ k ] * nview + k ]
    # for each sample #n, determine the best pose that maximizes the score (for the top class)
    for n in range( batch_size ):
        j_max = int( np.argmax( scores[ :, :, n ] ) / scores.shape[ 1 ] )
        output[ n ] = torch.FloatTensor( scores[ j_max, :, n ] )
    output = output.cuda()

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.contiguous().view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/vcand_case1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-rotationnet/418ff43a05cfb9a96b3981d9e8beab2ea73211c8/vcand_case1.npy
--------------------------------------------------------------------------------
/vcand_case2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-rotationnet/418ff43a05cfb9a96b3981d9e8beab2ea73211c8/vcand_case2.npy
--------------------------------------------------------------------------------
/vcand_case3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-rotationnet/418ff43a05cfb9a96b3981d9e8beab2ea73211c8/vcand_case3.npy
--------------------------------------------------------------------------------
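As a small usage sketch of the pose-aggregation step in `my_accuracy()` (Eq. (6)): given the per-view class scores of a single object in the `(nview*nview, num_classes)` layout produced by the reshape in `validate()`, the pose candidates in `vcand_case2.npy` can be scored with plain NumPy. The scores below are random placeholders for real network outputs, and `num_classes = 40` simply assumes ModelNet40.

    import numpy as np

    nview = 20                          # case (2)
    vcand = np.load('vcand_case2.npy')  # view permutations, one row per pose candidate

    # Dummy scores for ONE object in the (nview*nview, num_classes) layout produced
    # by the reshape in validate(); a real run would take these from the network.
    num_classes = 40
    rng = np.random.default_rng(0)
    log_scores = rng.standard_normal((nview * nview, num_classes))

    # Sum the selected rows for every candidate pose (cf. Eq.(6) / my_accuracy()).
    scores = np.zeros((vcand.shape[0], num_classes))
    for j in range(vcand.shape[0]):
        for k in range(vcand.shape[1]):
            scores[j] += log_scores[vcand[j][k] * nview + k]

    # The class of the best (pose, class) pair is the prediction.
    j_best, c_best = np.unravel_index(np.argmax(scores), scores.shape)
    print('predicted class:', c_best, 'best pose candidate:', j_best)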