├── LICENSE ├── README.md ├── __init__.py ├── models ├── MVCNN.py ├── Model.py └── __init__.py ├── tools ├── ImgDataset.py ├── Trainer.py └── __init__.py └── train_mvcnn.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Jong-Chyi, Matheus Gadelha, Rui Wang, and Subhransu Maji 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch code for MVCNN 2 | Code is tested on Python 3.6 and PyTorch 0.4.1 3 | 4 | First, download images and put it under ```modelnet40_images_new_12x```: 5 | [Shaded Images (1.6GB)](http://supermoe.cs.umass.edu/shape_recog/shaded_images.tar.gz) 6 | 7 | Command for training: 8 | ```python train_mvcnn.py -name mvcnn -num_models 1000 -weight_decay 0.001 -num_views 12 -cnn_name vgg11``` 9 | 10 | 11 | 12 | 13 | [Project webpage](https://people.cs.umass.edu/~jcsu/papers/shape_recog/) 14 | [Depth Images (1.6GB)](http://supermoe.cs.umass.edu/shape_recog/depth_images.tar.gz) 15 | 16 | [Blender script for rendering shaded images](http://people.cs.umass.edu/~jcsu/papers/shape_recog/render_shaded_black_bg.blend) 17 | [Blender script for rendering depth images](http://people.cs.umass.edu/~jcsu/papers/shape_recog/render_depth.blend) 18 | 19 | ## Reference 20 | **A Deeper Look at 3D Shape Classifiers** 21 | Jong-Chyi Su, Matheus Gadelha, Rui Wang, and Subhransu Maji 22 | *Second Workshop on 3D Reconstruction Meets Semantics, ECCV, 2018* 23 | 24 | **Multi-view Convolutional Neural Networks for 3D Shape Recognition** 25 | Hang Su, Subhransu Maji, Evangelos Kalogerakis, and Erik Learned-Miller, 26 | *International Conference on Computer Vision, ICCV, 2015* 27 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jongchyisu/mvcnn_pytorch/09a3b5134d92a35da31e4247b20c3c814b41f753/__init__.py -------------------------------------------------------------------------------- /models/MVCNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | from .Model import Model 9 | 10 | mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda() 11 | std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda() 12 | 13 | def flip(x, dim): 14 | xsize = x.size() 15 | dim = x.dim() + dim if dim < 0 else dim 16 | x = x.view(-1, *xsize[dim:]) 17 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, 18 | -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] 19 | return x.view(xsize) 20 | 21 | 22 | class SVCNN(Model): 23 | 24 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='vgg11'): 25 | super(SVCNN, self).__init__(name) 26 | 27 | self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 28 | 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 29 | 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 30 | 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 31 | 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 32 | 33 | self.nclasses = nclasses 34 | self.pretraining = pretraining 35 | self.cnn_name = cnn_name 36 | self.use_resnet = cnn_name.startswith('resnet') 37 | self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda() 38 | self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda() 39 | 40 | if self.use_resnet: 41 | if self.cnn_name == 'resnet18': 42 | self.net = models.resnet18(pretrained=self.pretraining) 43 | self.net.fc = nn.Linear(512,40) 44 | elif self.cnn_name == 'resnet34': 45 | self.net = models.resnet34(pretrained=self.pretraining) 46 | self.net.fc = nn.Linear(512,40) 47 | elif self.cnn_name == 'resnet50': 48 | self.net = models.resnet50(pretrained=self.pretraining) 49 | self.net.fc = nn.Linear(2048,40) 50 | else: 51 | if self.cnn_name == 'alexnet': 52 | self.net_1 = models.alexnet(pretrained=self.pretraining).features 53 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier 54 | elif self.cnn_name == 'vgg11': 55 | self.net_1 = models.vgg11(pretrained=self.pretraining).features 56 | self.net_2 = models.vgg11(pretrained=self.pretraining).classifier 57 | elif self.cnn_name == 'vgg16': 58 | self.net_1 = models.vgg16(pretrained=self.pretraining).features 59 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier 60 | 61 | self.net_2._modules['6'] = nn.Linear(4096,40) 62 | 63 | def forward(self, x): 64 | if self.use_resnet: 65 | return self.net(x) 66 | else: 67 | y = self.net_1(x) 68 | return self.net_2(y.view(y.shape[0],-1)) 69 | 70 | 71 | class MVCNN(Model): 72 | 73 | def __init__(self, name, model, nclasses=40, cnn_name='vgg11', num_views=12): 74 | super(MVCNN, self).__init__(name) 75 | 76 | self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 77 | 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 78 | 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 79 | 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 80 | 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 81 | 82 | self.nclasses = nclasses 83 | self.num_views = num_views 84 | self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda() 85 | self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda() 86 | 87 | self.use_resnet = cnn_name.startswith('resnet') 88 | 89 | if self.use_resnet: 90 | self.net_1 = nn.Sequential(*list(model.net.children())[:-1]) 91 | self.net_2 = model.net.fc 92 | else: 93 | self.net_1 = model.net_1 94 | self.net_2 = model.net_2 95 | 96 | def forward(self, x): 97 | y = self.net_1(x) 98 | y = y.view((int(x.shape[0]/self.num_views),self.num_views,y.shape[-3],y.shape[-2],y.shape[-1]))#(8,12,512,7,7) 99 | return self.net_2(torch.max(y,1)[0].view(y.shape[0],-1)) 100 | 101 | -------------------------------------------------------------------------------- /models/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import glob 5 | 6 | 7 | class Model(nn.Module): 8 | 9 | def __init__(self, name): 10 | super(Model, self).__init__() 11 | self.name = name 12 | 13 | 14 | def save(self, path, epoch=0): 15 | complete_path = os.path.join(path, self.name) 16 | if not os.path.exists(complete_path): 17 | os.makedirs(complete_path) 18 | torch.save(self.state_dict(), 19 | os.path.join(complete_path, 20 | "model-{}.pth".format(str(epoch).zfill(5)))) 21 | 22 | 23 | def save_results(self, path, data): 24 | raise NotImplementedError("Model subclass must implement this method.") 25 | 26 | 27 | def load(self, path, modelfile=None): 28 | complete_path = os.path.join(path, self.name) 29 | if not os.path.exists(complete_path): 30 | raise IOError("{} directory does not exist in {}".format(self.name, path)) 31 | 32 | if modelfile is None: 33 | model_files = glob.glob(complete_path+"/*") 34 | mf = max(model_files) 35 | else: 36 | mf = os.path.join(complete_path, modelfile) 37 | 38 | self.load_state_dict(torch.load(mf)) 39 | 40 | 41 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jongchyisu/mvcnn_pytorch/09a3b5134d92a35da31e4247b20c3c814b41f753/models/__init__.py -------------------------------------------------------------------------------- /tools/ImgDataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import torch.utils.data 4 | import os 5 | import math 6 | from skimage import io, transform 7 | from PIL import Image 8 | import torch 9 | import torchvision as vision 10 | from torchvision import transforms, datasets 11 | import random 12 | 13 | class MultiviewImgDataset(torch.utils.data.Dataset): 14 | 15 | def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=False, \ 16 | num_models=0, num_views=12, shuffle=True): 17 | self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 18 | 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 19 | 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 20 | 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 21 | 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 22 | self.root_dir = root_dir 23 | self.scale_aug = scale_aug 24 | self.rot_aug = rot_aug 25 | self.test_mode = test_mode 26 | self.num_views = num_views 27 | 28 | set_ = root_dir.split('/')[-1] 29 | parent_dir = root_dir.rsplit('/',2)[0] 30 | self.filepaths = [] 31 | for i in range(len(self.classnames)): 32 | all_files = sorted(glob.glob(parent_dir+'/'+self.classnames[i]+'/'+set_+'/*.png')) 33 | ## Select subset for different number of views 34 | stride = int(12/self.num_views) # 12 6 4 3 2 1 35 | all_files = all_files[::stride] 36 | 37 | if num_models == 0: 38 | # Use the whole dataset 39 | self.filepaths.extend(all_files) 40 | else: 41 | self.filepaths.extend(all_files[:min(num_models,len(all_files))]) 42 | 43 | if shuffle==True: 44 | # permute 45 | rand_idx = np.random.permutation(int(len(self.filepaths)/num_views)) 46 | filepaths_new = [] 47 | for i in range(len(rand_idx)): 48 | filepaths_new.extend(self.filepaths[rand_idx[i]*num_views:(rand_idx[i]+1)*num_views]) 49 | self.filepaths = filepaths_new 50 | 51 | 52 | if self.test_mode: 53 | self.transform = transforms.Compose([ 54 | transforms.ToTensor(), 55 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 56 | std=[0.229, 0.224, 0.225]) 57 | ]) 58 | else: 59 | self.transform = transforms.Compose([ 60 | transforms.RandomHorizontalFlip(), 61 | transforms.ToTensor(), 62 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 63 | std=[0.229, 0.224, 0.225]) 64 | ]) 65 | 66 | 67 | def __len__(self): 68 | return int(len(self.filepaths)/self.num_views) 69 | 70 | 71 | def __getitem__(self, idx): 72 | path = self.filepaths[idx*self.num_views] 73 | class_name = path.split('/')[-3] 74 | class_id = self.classnames.index(class_name) 75 | # Use PIL instead 76 | imgs = [] 77 | for i in range(self.num_views): 78 | im = Image.open(self.filepaths[idx*self.num_views+i]).convert('RGB') 79 | if self.transform: 80 | im = self.transform(im) 81 | imgs.append(im) 82 | 83 | return (class_id, torch.stack(imgs), self.filepaths[idx*self.num_views:(idx+1)*self.num_views]) 84 | 85 | 86 | 87 | class SingleImgDataset(torch.utils.data.Dataset): 88 | 89 | def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=False, \ 90 | num_models=0, num_views=12): 91 | self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 92 | 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 93 | 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 94 | 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 95 | 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 96 | self.root_dir = root_dir 97 | self.scale_aug = scale_aug 98 | self.rot_aug = rot_aug 99 | self.test_mode = test_mode 100 | 101 | set_ = root_dir.split('/')[-1] 102 | parent_dir = root_dir.rsplit('/',2)[0] 103 | self.filepaths = [] 104 | for i in range(len(self.classnames)): 105 | all_files = sorted(glob.glob(parent_dir+'/'+self.classnames[i]+'/'+set_+'/*shaded*.png')) 106 | if num_models == 0: 107 | # Use the whole dataset 108 | self.filepaths.extend(all_files) 109 | else: 110 | self.filepaths.extend(all_files[:min(num_models,len(all_files))]) 111 | 112 | self.transform = transforms.Compose([ 113 | transforms.RandomHorizontalFlip(), 114 | transforms.ToTensor(), 115 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 116 | std=[0.229, 0.224, 0.225]) 117 | ]) 118 | 119 | 120 | def __len__(self): 121 | return len(self.filepaths) 122 | 123 | 124 | def __getitem__(self, idx): 125 | path = self.filepaths[idx] 126 | class_name = path.split('/')[-3] 127 | class_id = self.classnames.index(class_name) 128 | 129 | # Use PIL instead 130 | im = Image.open(self.filepaths[idx]).convert('RGB') 131 | if self.transform: 132 | im = self.transform(im) 133 | 134 | return (class_id, im, path) 135 | 136 | -------------------------------------------------------------------------------- /tools/Trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import pickle 7 | import os 8 | from tensorboardX import SummaryWriter 9 | import time 10 | 11 | class ModelNetTrainer(object): 12 | 13 | def __init__(self, model, train_loader, val_loader, optimizer, loss_fn, \ 14 | model_name, log_dir, num_views=12): 15 | 16 | self.optimizer = optimizer 17 | self.model = model 18 | self.train_loader = train_loader 19 | self.val_loader = val_loader 20 | self.loss_fn = loss_fn 21 | self.model_name = model_name 22 | self.log_dir = log_dir 23 | self.num_views = num_views 24 | 25 | self.model.cuda() 26 | if self.log_dir is not None: 27 | self.writer = SummaryWriter(log_dir) 28 | 29 | 30 | def train(self, n_epochs): 31 | 32 | best_acc = 0 33 | i_acc = 0 34 | self.model.train() 35 | for epoch in range(n_epochs): 36 | # permute data for mvcnn 37 | rand_idx = np.random.permutation(int(len(self.train_loader.dataset.filepaths)/self.num_views)) 38 | filepaths_new = [] 39 | for i in range(len(rand_idx)): 40 | filepaths_new.extend(self.train_loader.dataset.filepaths[rand_idx[i]*self.num_views:(rand_idx[i]+1)*self.num_views]) 41 | self.train_loader.dataset.filepaths = filepaths_new 42 | 43 | # plot learning rate 44 | lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 45 | self.writer.add_scalar('params/lr', lr, epoch) 46 | 47 | # train one epoch 48 | out_data = None 49 | in_data = None 50 | for i, data in enumerate(self.train_loader): 51 | 52 | if self.model_name == 'mvcnn': 53 | N,V,C,H,W = data[1].size() 54 | in_data = Variable(data[1]).view(-1,C,H,W).cuda() 55 | else: 56 | in_data = Variable(data[1].cuda()) 57 | target = Variable(data[0]).cuda().long() 58 | 59 | self.optimizer.zero_grad() 60 | 61 | out_data = self.model(in_data) 62 | 63 | loss = self.loss_fn(out_data, target) 64 | 65 | self.writer.add_scalar('train/train_loss', loss, i_acc+i+1) 66 | 67 | pred = torch.max(out_data, 1)[1] 68 | results = pred == target 69 | correct_points = torch.sum(results.long()) 70 | 71 | acc = correct_points.float()/results.size()[0] 72 | self.writer.add_scalar('train/train_overall_acc', acc, i_acc+i+1) 73 | 74 | loss.backward() 75 | self.optimizer.step() 76 | 77 | log_str = 'epoch %d, step %d: train_loss %.3f; train_acc %.3f' % (epoch+1, i+1, loss, acc) 78 | if (i+1)%1==0: 79 | print(log_str) 80 | i_acc += i 81 | 82 | # evaluation 83 | if (epoch+1)%1==0: 84 | with torch.no_grad(): 85 | loss, val_overall_acc, val_mean_class_acc = self.update_validation_accuracy(epoch) 86 | self.writer.add_scalar('val/val_mean_class_acc', val_mean_class_acc, epoch+1) 87 | self.writer.add_scalar('val/val_overall_acc', val_overall_acc, epoch+1) 88 | self.writer.add_scalar('val/val_loss', loss, epoch+1) 89 | 90 | # save best model 91 | if val_overall_acc > best_acc: 92 | best_acc = val_overall_acc 93 | self.model.save(self.log_dir, epoch) 94 | 95 | # adjust learning rate manually 96 | if epoch > 0 and (epoch+1) % 10 == 0: 97 | for param_group in self.optimizer.param_groups: 98 | param_group['lr'] = param_group['lr']*0.5 99 | 100 | # export scalar data to JSON for external processing 101 | self.writer.export_scalars_to_json(self.log_dir+"/all_scalars.json") 102 | self.writer.close() 103 | 104 | def update_validation_accuracy(self, epoch): 105 | all_correct_points = 0 106 | all_points = 0 107 | 108 | # in_data = None 109 | # out_data = None 110 | # target = None 111 | 112 | wrong_class = np.zeros(40) 113 | samples_class = np.zeros(40) 114 | all_loss = 0 115 | 116 | self.model.eval() 117 | 118 | avgpool = nn.AvgPool1d(1, 1) 119 | 120 | total_time = 0.0 121 | total_print_time = 0.0 122 | all_target = [] 123 | all_pred = [] 124 | 125 | for _, data in enumerate(self.val_loader, 0): 126 | 127 | if self.model_name == 'mvcnn': 128 | N,V,C,H,W = data[1].size() 129 | in_data = Variable(data[1]).view(-1,C,H,W).cuda() 130 | else:#'svcnn' 131 | in_data = Variable(data[1]).cuda() 132 | target = Variable(data[0]).cuda() 133 | 134 | out_data = self.model(in_data) 135 | pred = torch.max(out_data, 1)[1] 136 | all_loss += self.loss_fn(out_data, target).cpu().data.numpy() 137 | results = pred == target 138 | 139 | for i in range(results.size()[0]): 140 | if not bool(results[i].cpu().data.numpy()): 141 | wrong_class[target.cpu().data.numpy().astype('int')[i]] += 1 142 | samples_class[target.cpu().data.numpy().astype('int')[i]] += 1 143 | correct_points = torch.sum(results.long()) 144 | 145 | all_correct_points += correct_points 146 | all_points += results.size()[0] 147 | 148 | print ('Total # of test models: ', all_points) 149 | val_mean_class_acc = np.mean((samples_class-wrong_class)/samples_class) 150 | acc = all_correct_points.float() / all_points 151 | val_overall_acc = acc.cpu().data.numpy() 152 | loss = all_loss / len(self.val_loader) 153 | 154 | print ('val mean class acc. : ', val_mean_class_acc) 155 | print ('val overall acc. : ', val_overall_acc) 156 | print ('val loss : ', loss) 157 | 158 | self.model.train() 159 | 160 | return loss, val_overall_acc, val_mean_class_acc 161 | 162 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jongchyisu/mvcnn_pytorch/09a3b5134d92a35da31e4247b20c3c814b41f753/tools/__init__.py -------------------------------------------------------------------------------- /train_mvcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import os,shutil,json 6 | import argparse 7 | 8 | from tools.Trainer import ModelNetTrainer 9 | from tools.ImgDataset import MultiviewImgDataset, SingleImgDataset 10 | from models.MVCNN import MVCNN, SVCNN 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("-name", "--name", type=str, help="Name of the experiment", default="MVCNN") 14 | parser.add_argument("-bs", "--batchSize", type=int, help="Batch size for the second stage", default=8)# it will be *12 images in each batch for mvcnn 15 | parser.add_argument("-num_models", type=int, help="number of models per class", default=1000) 16 | parser.add_argument("-lr", type=float, help="learning rate", default=5e-5) 17 | parser.add_argument("-weight_decay", type=float, help="weight decay", default=0.0) 18 | parser.add_argument("-no_pretraining", dest='no_pretraining', action='store_true') 19 | parser.add_argument("-cnn_name", "--cnn_name", type=str, help="cnn model name", default="vgg11") 20 | parser.add_argument("-num_views", type=int, help="number of views", default=12) 21 | parser.add_argument("-train_path", type=str, default="modelnet40_images_new_12x/*/train") 22 | parser.add_argument("-val_path", type=str, default="modelnet40_images_new_12x/*/test") 23 | parser.set_defaults(train=False) 24 | 25 | def create_folder(log_dir): 26 | # make summary folder 27 | if not os.path.exists(log_dir): 28 | os.mkdir(log_dir) 29 | else: 30 | print('WARNING: summary folder already exists!! It will be overwritten!!') 31 | shutil.rmtree(log_dir) 32 | os.mkdir(log_dir) 33 | 34 | if __name__ == '__main__': 35 | args = parser.parse_args() 36 | 37 | pretraining = not args.no_pretraining 38 | log_dir = args.name 39 | create_folder(args.name) 40 | config_f = open(os.path.join(log_dir, 'config.json'), 'w') 41 | json.dump(vars(args), config_f) 42 | config_f.close() 43 | 44 | # STAGE 1 45 | log_dir = args.name+'_stage_1' 46 | create_folder(log_dir) 47 | cnet = SVCNN(args.name, nclasses=40, pretraining=pretraining, cnn_name=args.cnn_name) 48 | 49 | optimizer = optim.Adam(cnet.parameters(), lr=args.lr, weight_decay=args.weight_decay) 50 | 51 | n_models_train = args.num_models*args.num_views 52 | 53 | train_dataset = SingleImgDataset(args.train_path, scale_aug=False, rot_aug=False, num_models=n_models_train, num_views=args.num_views) 54 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0) 55 | 56 | val_dataset = SingleImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True) 57 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0) 58 | print('num_train_files: '+str(len(train_dataset.filepaths))) 59 | print('num_val_files: '+str(len(val_dataset.filepaths))) 60 | trainer = ModelNetTrainer(cnet, train_loader, val_loader, optimizer, nn.CrossEntropyLoss(), 'svcnn', log_dir, num_views=1) 61 | trainer.train(30) 62 | 63 | # STAGE 2 64 | log_dir = args.name+'_stage_2' 65 | create_folder(log_dir) 66 | cnet_2 = MVCNN(args.name, cnet, nclasses=40, cnn_name=args.cnn_name, num_views=args.num_views) 67 | del cnet 68 | 69 | optimizer = optim.Adam(cnet_2.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.999)) 70 | 71 | train_dataset = MultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False, num_models=n_models_train, num_views=args.num_views) 72 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=0)# shuffle needs to be false! it's done within the trainer 73 | 74 | val_dataset = MultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, num_views=args.num_views) 75 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batchSize, shuffle=False, num_workers=0) 76 | print('num_train_files: '+str(len(train_dataset.filepaths))) 77 | print('num_val_files: '+str(len(val_dataset.filepaths))) 78 | trainer = ModelNetTrainer(cnet_2, train_loader, val_loader, optimizer, nn.CrossEntropyLoss(), 'mvcnn', log_dir, num_views=args.num_views) 79 | trainer.train(30) 80 | 81 | 82 | --------------------------------------------------------------------------------