├── README.md ├── loss.py ├── config.py ├── data.py ├── GVCNN.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # GVCNN 2 | a pytorch implement of GVCNN 3 | 4 | 5 | the classification result on Modelnet40 is 93.07%(within 50 epochs) 6 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def cal_loss(pred, gold, smoothing=True): 7 | ''' Calculate cross entropy loss, apply label smoothing if needed. ''' 8 | 9 | gold = gold.contiguous().view(-1) 10 | 11 | if smoothing: 12 | eps = 0.2 13 | n_class = pred.size(1) 14 | 15 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 16 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 17 | # print('pred: {}'.format(pred)) 18 | log_prb = F.log_softmax(pred, dim=1) 19 | 20 | loss = -(one_hot * log_prb).sum(dim=1).mean() 21 | else: 22 | loss = F.cross_entropy(pred, gold, reduction='mean') 23 | 24 | return loss -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from logging import Formatter 4 | import torch 5 | import os 6 | 7 | 8 | parser = argparse.ArgumentParser("GVCNN") 9 | parser.add_argument("--data_dir", type=str, default="", help="dataset path") 10 | parser.add_argument('--gpu_device', type=str, default="0", help="gpu id") 11 | parser.add_argument('--seed', type=int, default=6, help="random seed") 12 | parser.add_argument('--mv_backbone', type=str, default='GOOGLENET', help=('ALEXNET', 'VGG13', 'VGG13BN', 'VGG11BN', 'RESNET50', 'GOOGLENET' 13 | ,'INCEPTION_V3')) 14 | parser.add_argument('--learning_rate', type=float, default=0.0001) 15 | parser.add_argument('--weight_decay', type=float, default=0.0001) 16 | parser.add_argument('--momentum', type=float, default=0.9) 17 | parser.add_argument('--optimizer', type=str, default='SGD', help='SGD, Adam. [default: SGD]') 18 | parser.add_argument('--num_views', type=int, default=12) 19 | parser.add_argument('--batch_size', type=int, default=8) 20 | parser.add_argument('--test_batch_size', type=int, default=2) 21 | parser.add_argument('--num_epochs', type=int, default=500) 22 | parser.add_argument('--dropout', type=float, default=0.5) 23 | parser.add_argument('--valid_freq', type=int, default=1) 24 | parser.add_argument('--save_freq', type=int, default=10) 25 | parser.add_argument('--pretrain_model_dir', type=str, default='./pretrain') 26 | parser.add_argument('--save_dir', type=str, default=None, help='The saving directory of training process.') 27 | parser.add_argument('--group_num', type=int, default=8) 28 | args = parser.parse_args() 29 | 30 | # Main logger 31 | main_logger = logging.getLogger() 32 | main_logger.setLevel(logging.INFO) 33 | log_console_format = "[%(levelname)s] - %(asctime)s : %(message)s" 34 | console_handler = logging.StreamHandler() 35 | console_handler.setLevel(logging.INFO) 36 | console_handler.setFormatter(Formatter(log_console_format)) 37 | main_logger.addHandler(console_handler) 38 | logger = logging.getLogger() 39 | 40 | # device 41 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device 42 | torch.cuda.manual_seed_all(args.seed) 43 | device = torch.device("cuda" if torch.cuda.is_available() else 'cpu') 44 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from PIL import Image 4 | import torch 5 | from torchvision import transforms 6 | from torch.utils.data import Dataset 7 | 8 | 9 | 10 | class ModelNet40Views(Dataset): 11 | def __init__(self, data_root, base_model, mode='train'): 12 | super(ModelNet40Views, self).__init__() 13 | 14 | self.data_root = data_root 15 | self.mode = mode 16 | self.image_dirs = [] 17 | self.labels = [] 18 | self.image_dict = {} 19 | 20 | if base_model in ('ALEXNET', 'VGG13', 'VGG13BN', 'VGG11BN', 'RESNET50', 'GOOGLENET'): 21 | self.img_size = 224 22 | elif base_model in ('RESNET101'): 23 | self.img_size = 227 24 | elif base_model in ('INCEPTION_V3'): 25 | self.img_size = 299 26 | else: 27 | raise NotImplementedError 28 | 29 | self.transform = transforms.Compose([ 30 | transforms.Resize(self.img_size), 31 | transforms.ToTensor() 32 | ]) 33 | 34 | class_list = os.listdir(self.data_root) 35 | if self.mode == 'train': 36 | for oneclass in class_list: 37 | self.image_dict[oneclass] = glob(os.path.join(data_root, oneclass, 'train', '*.jpg')) 38 | elif self.mode == 'val': 39 | for oneclass in class_list: 40 | self.image_dict[oneclass] = glob(os.path.join(data_root, oneclass, 'test', '*.jpg')) 41 | else: 42 | raise NotImplementedError 43 | 44 | for class_key in self.image_dict: 45 | name_dict = {} 46 | for image_dir in self.image_dict[class_key]: 47 | image_class = '_'.join(os.path.split(image_dir)[1].split('.')[0].split('_')[:-1]) 48 | if image_class in name_dict: 49 | name_dict[image_class].append(image_dir) 50 | else: 51 | name_dict[image_class] = [image_dir] 52 | 53 | for image_class, dirs in name_dict.items(): 54 | self.image_dirs.append(dirs) 55 | self.labels.append(class_list.index(class_key)) 56 | 57 | self.image_num = len(self.image_dirs) if len(self.image_dirs)==len(self.labels) else print("labels don't match") 58 | 59 | 60 | def __getitem__(self, idx): 61 | images = [self.transform(Image.open(image)) for image in self.image_dirs[idx]] 62 | return torch.stack(images).float(), self.labels[idx] 63 | 64 | def __len__(self): 65 | return self.image_num 66 | 67 | 68 | if __name__ == '__main__': 69 | train = ModelNet40Views(data_root='', base_model='ALEXNET', mode="train") 70 | data, label = train[0] 71 | print('data: {}'.format(data.shape)) 72 | print('label: {}'.format(label)) 73 | -------------------------------------------------------------------------------- /GVCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | from torchsummary import summary 5 | 6 | 7 | def fc_bn_block(input, output): 8 | return nn.Sequential( 9 | nn.Linear(input, output), 10 | nn.BatchNorm1d(output), 11 | nn.ReLU(inplace=True)) 12 | 13 | 14 | def cal_scores(scores): 15 | n = len(scores) 16 | s = 0 17 | for score in scores: 18 | s += torch.ceil(score*n) 19 | s /= n 20 | return s 21 | 22 | 23 | def group_fusion(view_group, weight_group): 24 | shape_des = map(lambda a, b: a*b, view_group, weight_group) 25 | shape_des = sum(shape_des)/sum(weight_group) 26 | return shape_des 27 | 28 | 29 | def group_pooling(final_views, views_score, group_num): 30 | interval = 1.0 / group_num 31 | 32 | def onebatch_grouping(onebatch_views, onebatch_scores): 33 | viewgroup_onebatch = [[] for i in range(group_num)] 34 | scoregroup_onebatch = [[] for i in range(group_num)] 35 | 36 | for i in range(group_num): 37 | left = i*interval 38 | right = (i+1)*interval 39 | for j, score in enumerate(onebatch_scores): 40 | if left<=score0] 47 | weight_group = [cal_scores(scores) for scores in scoregroup_onebatch if len(scores)>0] 48 | onebatch_shape_des = group_fusion(view_group, weight_group) 49 | return onebatch_shape_des 50 | 51 | shape_descriptors = [] 52 | for (onebatch_views,onebatch_scores) in zip(final_views,views_score): 53 | shape_descriptors.append(onebatch_grouping(onebatch_views,onebatch_scores)) 54 | shape_descriptor = torch.stack(shape_descriptors, 0) 55 | # shape_descriptor: [B, 1024] 56 | return shape_descriptor 57 | 58 | 59 | class GVCNN(nn.Module): 60 | def __init__(self, num_classes=40, group_num=8, model_name='GOOGLENET', pretrained=True): 61 | super(GVCNN, self).__init__() 62 | 63 | self.num_classes = num_classes 64 | self.group_num = group_num 65 | 66 | if model_name=='GOOGLENET': 67 | base_model = torchvision.models.googlenet(pretrained=pretrained) 68 | 69 | self.FCN = nn.Sequential(*list(base_model.children())[:6]) 70 | self.CNN = nn.Sequential(*list(base_model.children())[:-2]) 71 | self.FC = nn.Sequential(fc_bn_block(256*28*28, 256), 72 | fc_bn_block(256,1)) 73 | self.fc_block_1 = fc_bn_block(1024, 512) 74 | self.drop_1 = nn.Dropout(0.5) 75 | self.fc_block_2 = fc_bn_block(512, 256) 76 | self.drop_2 = nn.Dropout(0.5) 77 | self.linear = nn.Linear(256, self.num_classes) 78 | 79 | 80 | def forward(self, views): 81 | ''' 82 | params views: B V C H W (B 12 3 224 224) 83 | return result: B num_classes 84 | ''' 85 | # print(views.size()) 86 | # views = views.cpu() 87 | batch_size, num_views, channel, image_size = views.size(0), views.size(1), views.size(2), views.size(3) 88 | 89 | views = views.view(batch_size*num_views, channel, image_size, image_size) 90 | raw_views = self.FCN(views) 91 | # print(raw_views.size()) 92 | # raw_views: [B*V 256 28 28] 93 | final_views = self.CNN(views) 94 | # final_views: [B*V 1024 1 1] 95 | final_views = final_views.view(batch_size, num_views, 1024) 96 | views_score = self.FC(raw_views.view(batch_size*num_views, -1)) 97 | views_score = torch.sigmoid(torch.tanh(torch.abs(views_score))) 98 | views_score = views_score.view(batch_size, num_views, -1) 99 | # views_score: [B V] 100 | shape_descriptor = group_pooling(final_views, views_score, self.group_num) 101 | # print(shape_descriptor.size()) 102 | 103 | out = self.fc_block_1(shape_descriptor) 104 | out = self.drop_1(out) 105 | out = self.fc_block_2(out) 106 | viewcnn_feature = out 107 | out = self.drop_2(out) 108 | pred = self.linear(out) 109 | 110 | return pred, viewcnn_feature 111 | 112 | 113 | if __name__ == '__main__': 114 | net = GVCNN(num_classes=40, group_num=8) 115 | summary(net, (12, 3, 224, 224)) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import shutil 5 | import sklearn.metrics as metrics 6 | import torch 7 | from torch import nn 8 | from torch.backends import cudnn 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | from torch.optim import Adam, SGD 12 | from torch.optim import lr_scheduler 13 | from torch.optim.lr_scheduler import CosineAnnealingLR 14 | from tensorboardX import SummaryWriter 15 | import time 16 | 17 | from config import args, logger, device 18 | from data import ModelNet40Views 19 | from loss import cal_loss 20 | from GVCNN import GVCNN 21 | 22 | cudnn.benchmark = True 23 | torch.autograd.set_detect_anomaly(True) 24 | 25 | 26 | def train(): 27 | 28 | save_dir = 'exp_gvcnn_{}_{}'.format(args.mv_backbone.lower(), time.strftime("%Y%m%d-%H%M%S")) 29 | 30 | if not os.path.exists(save_dir): 31 | os.makedirs(save_dir) 32 | checkpoint_dir = os.path.join(save_dir, 'checkpoint') 33 | if not os.path.exists(checkpoint_dir): 34 | os.makedirs(checkpoint_dir) 35 | summary_dir = os.path.join(save_dir, 'summary') 36 | if not os.path.exists(summary_dir): 37 | os.makedirs(summary_dir) 38 | 39 | logger.info("Loading dataset...") 40 | logger.info('ModelNet40') 41 | train_loader = DataLoader(ModelNet40Views(args.data_dir, args.mv_backbone, mode='train'), num_workers=8, 42 | batch_size=args.batch_size, shuffle=True, drop_last=True) 43 | test_loader = DataLoader(ModelNet40Views(args.data_dir, args.mv_backbone, mode='val'), num_workers=8, 44 | batch_size=args.test_batch_size, shuffle=False, drop_last=False) 45 | num_classes = 40 46 | logger.info('classes: {}'.format(num_classes)) 47 | 48 | logger.info('Creating model...') 49 | model = GVCNN(num_classes=40, group_num=args.group_num, model_name=args.mv_backbone).to(device) 50 | model = nn.DataParallel(model) 51 | criterion = cal_loss 52 | 53 | if args.optimizer == 'Adam': 54 | optimizer = Adam(model.parameters(), lr=args.learning_rate, 55 | betas=(0.9, 0.999), weight_decay=args.weight_decay) 56 | elif args.optimizer == 'SGD': 57 | optimizer = SGD(model.parameters(), lr=args.learning_rate * 100, 58 | momentum=args.momentum, weight_decay=args.weight_decay) 59 | else: 60 | raise RuntimeError('optimizer type not supported.({})'.format(args.optimizer)) 61 | scheduler= CosineAnnealingLR(optimizer, args.num_epochs, eta_min=args.learning_rate) 62 | 63 | summary_writer = SummaryWriter(log_dir=os.path.join(summary_dir, args.mv_backbone)) 64 | 65 | logger.info('start training.') 66 | best_test_acc = 0 67 | for epoch in range(1, args.num_epochs + 1): 68 | #################### 69 | # Train 70 | #################### 71 | tqdm_batch = tqdm(train_loader, desc='Epoch-{} training'.format(epoch)) 72 | model.train() 73 | train_loss = 0.0 74 | train_pred = [] 75 | count = 0.0 76 | train_true = [] 77 | for data, label in tqdm_batch: 78 | data, label = data.to(device), label.to(device) 79 | batch_size = data.size(0) 80 | optimizer.zero_grad() 81 | 82 | pred, feature = model(data) 83 | loss = criterion(pred, label) 84 | 85 | loss.backward() 86 | optimizer.step() 87 | 88 | preds = pred.max(dim=1)[1] 89 | train_loss += loss.item() * batch_size 90 | train_pred.append(preds.detach().cpu().numpy()) 91 | 92 | count += batch_size 93 | train_true.append(label.cpu().numpy()) 94 | 95 | scheduler.step() 96 | 97 | train_true = np.concatenate(train_true) 98 | train_pred = np.concatenate(train_pred) 99 | outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (epoch, train_loss * 1.0 / count, 100 | metrics.accuracy_score(train_true, train_pred), 101 | metrics.balanced_accuracy_score(train_true, train_pred)) 102 | logger.info(outstr) 103 | 104 | summary_writer.add_scalar('train/loss', train_loss * 1.0 / count, epoch) 105 | summary_writer.add_scalar('train/overall_acc', metrics.accuracy_score(train_true, train_pred), epoch) 106 | summary_writer.add_scalar('train/avg_acc', metrics.balanced_accuracy_score(train_true, train_pred), epoch) 107 | 108 | 109 | #################### 110 | # Test 111 | #################### 112 | tqdm_batch = tqdm(test_loader, desc='Epoch-{} testing'.format(epoch)) 113 | model.eval() 114 | test_loss = 0.0 115 | test_pred = [] 116 | count = 0.0 117 | test_true = [] 118 | for data, label in tqdm_batch: 119 | data, label = data.to(device), label.to(device) 120 | batch_size = data.size(0) 121 | 122 | pred, feature = model(data) 123 | 124 | loss = criterion(pred, label) 125 | 126 | preds = pred.max(dim=1)[1] 127 | test_loss += loss.item() * batch_size 128 | test_pred.append(preds.detach().cpu().numpy()) 129 | 130 | count += batch_size 131 | test_true.append(label.cpu().numpy()) 132 | 133 | test_true = np.concatenate(test_true) 134 | test_pred = np.concatenate(test_pred) 135 | test_acc = metrics.accuracy_score(test_true, test_pred) 136 | 137 | outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (epoch, test_loss * 1.0 / count, 138 | metrics.accuracy_score(test_true, 139 | test_pred), 140 | metrics.balanced_accuracy_score(test_true, 141 | test_pred)) 142 | logger.info(outstr) 143 | 144 | summary_writer.add_scalar('test/loss', test_loss * 1.0 / count, epoch) 145 | summary_writer.add_scalar('test/overall_acc', metrics.accuracy_score(test_true, test_pred), epoch) 146 | summary_writer.add_scalar('test/avg_acc', metrics.balanced_accuracy_score(test_true, test_pred), epoch) 147 | 148 | if test_acc >= best_test_acc: 149 | best_test_acc = test_acc 150 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, args.mv_backbone+'best_model.pth')) 151 | 152 | if epoch % args.save_freq == 0: 153 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, args.mv_backbone+'model_{}.pth'.format(epoch))) 154 | 155 | logger.info('best_test_acc: {:.6f}'.format(best_test_acc)) 156 | 157 | 158 | 159 | if __name__ == '__main__': 160 | train() 161 | --------------------------------------------------------------------------------