├── LICENSE ├── README.md ├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── config.cpython-36.pyc └── config.py ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── data_pth.cpython-35.pyc │ ├── data_pth.cpython-36.pyc │ ├── pc_aug.cpython-35.pyc │ └── pc_aug.cpython-36.pyc ├── data_pth.py └── pc_aug.py ├── docs └── pipeline.png ├── models ├── DGCNN.py ├── MVCNN.py ├── PVRNet.py ├── __init__.py ├── __pycache__ │ ├── DGCNN.cpython-36.pyc │ ├── MVCNN.cpython-36.pyc │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── model_helper.cpython-35.pyc │ ├── model_helper.cpython-36.pyc │ ├── model_utils.cpython-36.pyc │ ├── mvcnn_gnet.cpython-35.pyc │ ├── mvcnn_gnet.cpython-36.pyc │ ├── mvcnn_resnet.cpython-35.pyc │ ├── mvcnn_resnet.cpython-36.pyc │ ├── point_with_attention.cpython-35.pyc │ ├── point_with_attention.cpython-36.pyc │ └── zhengyue_mvcnn.cpython-36.pyc └── model_utils.py ├── train_mvcnn.py ├── train_pc.py ├── train_pvrnet.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── config.cpython-35.pyc │ ├── config.cpython-36.pyc │ ├── generate_pc.cpython-35.pyc │ ├── generate_pc.cpython-36.pyc │ ├── split_fun.cpython-35.pyc │ └── split_fun.cpython-36.pyc ├── generate_pc.py └── meter │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── apmeter.cpython-35.pyc │ ├── apmeter.cpython-36.pyc │ ├── aucmeter.cpython-35.pyc │ ├── aucmeter.cpython-36.pyc │ ├── averagevaluemeter.cpython-35.pyc │ ├── averagevaluemeter.cpython-36.pyc │ ├── classerrormeter.cpython-35.pyc │ ├── classerrormeter.cpython-36.pyc │ ├── confusionmeter.cpython-35.pyc │ ├── confusionmeter.cpython-36.pyc │ ├── mapmeter.cpython-35.pyc │ ├── mapmeter.cpython-36.pyc │ ├── meter.cpython-35.pyc │ ├── meter.cpython-36.pyc │ ├── movingaveragevaluemeter.cpython-35.pyc │ ├── movingaveragevaluemeter.cpython-36.pyc │ ├── msemeter.cpython-35.pyc │ ├── msemeter.cpython-36.pyc │ ├── timemeter.cpython-35.pyc │ └── timemeter.cpython-36.pyc │ ├── apmeter.py │ ├── aucmeter.py │ ├── averagevaluemeter.py │ ├── classerrormeter.py │ ├── confusionmeter.py │ ├── mapmeter.py │ ├── meter.py │ ├── movingaveragevaluemeter.py │ ├── msemeter.py │ ├── retrievalmeter.py │ └── timemeter.py ├── val_mvcnn.py ├── val_pc.py └── val_pvrnet.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 iMoon: Intelligent Media and Cognition Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PVRNet 2 | PVRNet: Point-View Relation Neural Network for 3D Shape Recognition (AAAI 2019) 3 | 4 | Created by Haoxuan You, Yifan Feng, Xibin Zhao, Changqing Zou, Rongrong Ji, Yue Gao from Tsinghua University. 5 | 6 | ![](https://github.com/iMoonLab/PVRNet/blob/master/docs/pipeline.png) 7 | ### Introduction 8 | This work will appear in AAAI 2019. We propose a point-view relation neural network called PVRNet for 3D shape recognition and retrieval. You can check our [paper](https://arxiv.org/abs/1812.00333) for more details. 9 | 10 | In this repository, our code and data are released for training our PVRNet on the ModelNet40 dataset. 11 | 12 | ### Citation 13 | If you find our work useful in your research, please cite our paper: 14 | ``` 15 | @article{you2018pvrnet, 16 | title={PVRNet: Point-View Relation Neural Network for 3D Shape Recognition}, 17 | author={You, Haoxuan and Feng, Yifan and Zhao, Xibin and Zou, Changqing and Ji, Rongrong and Gao, Yue}, 18 | journal={AAAI 2019}, 19 | year={2018} 20 | } 21 | ``` 22 | ### Configuration 23 | The code is tested with PyTorch 0.4.1, Python 3.6, and CUDA 9.0 on Ubuntu 16.04. 24 | 25 | Data: [point cloud data](https://drive.google.com/file/d/1DUh_8PQjh3ds4yO0O8q_vb0HPistOJ4y/view?usp=sharing) and [multi-view (12-view) data](https://drive.google.com/file/d/12JbIPLvcSUsMjxb_CZYXI8xQK2UKosio/view?usp=sharing) from the ModelNet40 dataset. 26 | 27 | Pretrained models: [multi-view part (MVCNN)](https://drive.google.com/file/d/1dZG7XojtPS9Cl5aaH4iWXA_N2PximB6i/view?usp=sharing), [point cloud part (DGCNN)](https://drive.google.com/file/d/1fY9E44xuPwUFxJ_BIeP5NXwrB7DQm1tw/view?usp=sharing) and [PVRNet](https://drive.google.com/file/d/1g3Ef68jRSY2mNf54dOeqNFYZTm4cO13d/view?usp=sharing) 28 | 29 | ### Usage 30 | + Download the data and pretrained checkpoints from the links above. Create directories for the data and results, and place the downloads under the corresponding directories (./data/ and ./result/ckpt/): 31 | 32 | ```mkdir -p data result/ckpt``` 33 | 34 | + Train PVRNet. This uses the pretrained MVCNN and DGCNN models saved in ./result/ckpt: 35 | 36 | ``` python train_pvrnet.py``` 37 | 38 | + To validate the performance of PVRNet with our pretrained model: 39 | 40 | `python val_pvrnet.py` 41 | 42 | To validate the performance of the pretrained MVCNN and DGCNN models: 43 | ``` 44 | python val_mvcnn.py 45 | python val_pc.py 46 | ``` 47 | 48 | + To train new MVCNN and DGCNN models: 49 | 50 | 51 | ``` 52 | python train_mvcnn.py 53 | python train_pc.py 54 | ``` 55 | 56 | 57 | ### License 58 | Our code is released under the MIT License (see the LICENSE file for details).
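### Checkpoint format
The training scripts save checkpoints as plain dictionaries via `torch.save`. Below is a minimal sketch of inspecting the released PVRNet checkpoint, assuming the downloaded file is stored at the path configured as `config.pv_net.ckpt_load_file` and follows the same layout that `save_ckpt` in `train_pvrnet.py` produces (keys such as `model`, `epoch`, `best_prec1`); `val_pvrnet.py` remains the supported evaluation entry point:

```python
import torch
import config
from models import PVRNet

# The checkpoint is a dict written by save_ckpt; 'model' holds the state dict
# saved from net.module, i.e. without the DataParallel 'module.' prefix.
# init_weights=False skips reloading the MVCNN/DGCNN part checkpoints
# inside the PVRNet constructor.
ckpt = torch.load(config.pv_net.ckpt_load_file, map_location='cpu')
net = PVRNet(init_weights=False)
net.load_state_dict(ckpt['model'])
print(f"epoch: {ckpt['epoch']}, best accuracy: {ckpt['best_prec1']}")
```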
59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | 3 | # init environment 4 | 5 | def init_env(): 6 | import os 7 | import os.path as osp 8 | os.environ['CUDA_VISIBLE_DEVICES'] = available_gpus 9 | 10 | def check_dir(_dir, create=True): 11 | if not osp.exists(_dir): 12 | if create: 13 | os.makedirs(_dir) 14 | else: 15 | raise FileNotFoundError(f'{_dir} does not exist') 16 | 17 | check_dir(result_root) 18 | check_dir(ckpt_folder) 19 | 20 | # check_dir(pv_net.ckpt_record_folder) 21 | # check_dir(view_net.ckpt_record_folder) 22 | # check_dir(pc_net.ckpt_record_folder) 23 | 24 | check_dir(view_net.data_root, create=False) 25 | check_dir(pc_net.data_root, create=False) -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/config/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /config/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/config/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import models 4 | 5 | # configuration file 6 | description = 'pvrnet' 7 | version_string = '0.1' 8 | 9 | # device can be "cuda" or "cpu" 10 | device = 'cuda' 11 | num_workers = 4 12 | available_gpus = '1,2,3' 13 | # available_gpus = '0,2,3' 14 | print_freq = 15 15 | 16 | work_root = os.getcwd() 17 | result_root = osp.join(work_root, 'result') 18 | # result_root = '/repository/yhx/result/aaai2018_result' 19 | # result_sub_folder = osp.join(result_root, f'{description}_{version_string}_torch') 20 | ckpt_folder = osp.join(result_root, 'ckpt') 21 | 22 | base_model_name = models.ALEXNET 23 | # base_model_name = models.VGG13BN 24 | # base_model_name = models.VGG11BN 25 | # base_model_name = models.RESNET50 26 | # base_model_name = models.INCEPTION_V3 27 | 28 | 29 | class pc_net: 30 | data_root = osp.join(work_root, 'data', 'pc') 31 | # data_root = '/home/youhaoxuan/data/pc' 32 | n_neighbor = 20 33 | num_classes = 40 34 | pre_trained_model = None 35 | # ckpt_file = osp.join(ckpt_folder, 'PCNet-ckpt.pth') 36 | ckpt_file = osp.join(ckpt_folder, 'PCNet-save-ckpt.pth') 37 | ckpt_load_file = osp.join(ckpt_folder, 'PCNet-release-v1-ckpt.pth') 38 | 39 | class train: 40 | batch_sz = 24*4 41 | resume = False 42 | resume_epoch = None 43 | 44 | lr = 0.001 45 | momentum = 0.9 46 | weight_decay = 0 47 | max_epoch = 250 48 | data_aug = True 49 | 50 | class validation: 51 | batch_sz = 32 52 | 53 | class test: 54 | batch_sz = 32 55 | 56 | 57 | class view_net: 58 | num_classes = 40 59 | 60 | # multi-view cnn 61 | data_root = osp.join(work_root, 'data', '12_ModelNet40') 62 | # data_root = '/home/youhaoxuan/data/12_ModelNet40' 63 | 64 | pre_trained_model = None 65 | if base_model_name in (models.ALEXNET, models.RESNET50): 66 | ckpt_file = osp.join(ckpt_folder,
f'MVCNN-{base_model_name}-save-ckpt.pth') 67 | ckpt_load_file = osp.join(ckpt_folder, f'MVCNN-{base_model_name}-ckpt.pth') 68 | else: 69 | ckpt_file = osp.join(ckpt_folder, f'{base_model_name}-12VIEWS-MAX_POOLING-save-ckpt.pth') 70 | ckpt_load_file = osp.join(ckpt_folder, f'{base_model_name}-12VIEWS-MAX_POOLING-ckpt.pth') 71 | 72 | class train: 73 | if base_model_name == models.ALEXNET: 74 | batch_sz = 128 # AlexNet 2 gpus 75 | elif base_model_name == models.INCEPTION_V3: 76 | batch_sz = 2 77 | else: 78 | batch_sz = 32 79 | resume = False 80 | resume_epoch = None 81 | 82 | lr = 0.001 83 | momentum = 0.9 84 | weight_decay = 1e-4 85 | max_epoch = 200 86 | data_aug = True 87 | 88 | class validation: 89 | batch_sz = 256 90 | 91 | class test: 92 | batch_sz = 32 93 | 94 | class pv_net: 95 | num_classes = 40 96 | 97 | # pointcloud 98 | pc_root = osp.join(work_root, 'data', 'pc') 99 | n_neighbor = 20 100 | 101 | # multi-view cnn 102 | view_root = osp.join(work_root, 'data', '12_ModelNet40') 103 | 104 | pre_trained_model = False 105 | ckpt_file = osp.join(ckpt_folder, f'PVNet2-{base_model_name}.pth') 106 | ckpt_load_file = osp.join(ckpt_folder, f'PVNet2-{base_model_name}-v94-ckpt.pth') 107 | 108 | 109 | class train: 110 | # optim = 'Adam' 111 | optim = 'SGD' 112 | # batch_sz = 18*2 113 | batch_sz = 20 114 | batch_sz_res = 5*1 115 | resume = False 116 | resume_epoch = None 117 | 118 | iter_train = True 119 | # iter_train = False 120 | 121 | fc_lr = 0.01 122 | all_lr = 0.0009 123 | momentum = 0.9 124 | # weight_decay = 5e-4 125 | weight_decay = 1e-5 126 | max_epoch = 100 127 | data_aug = True 128 | 129 | class validation: 130 | batch_sz = 40 131 | 132 | class test: 133 | batch_sz = 32 134 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .pc_aug import rotate_point_cloud_by_angle, rotation_point_cloud, jitter_point_cloud, pc_aug_funs, normal_pc 2 | 3 | STATUS_TRAIN = "train" 4 | STATUS_TEST = "test" 5 | from .data_pth import * -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/datasets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_pth.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/datasets/__pycache__/data_pth.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_pth.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/datasets/__pycache__/data_pth.cpython-36.pyc 
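For context, here is a minimal sketch of how this nested-class configuration is consumed by the training scripts: everything is a plain module- or class-level attribute, read directly after calling `config.init_env()`. The printed values reflect the defaults in `config/config.py`:

```python
import config

# init_env() sets CUDA_VISIBLE_DEVICES from config.available_gpus, creates
# ./result/ckpt if missing, and raises FileNotFoundError if the data roots
# (./data/pc and ./data/12_ModelNet40) are absent.
config.init_env()

print(config.pv_net.train.batch_sz)   # 20 (per-GPU batch size for PVRNet)
print(config.pv_net.train.optim)      # 'SGD'
print(config.view_net.data_root)      # <repo>/data/12_ModelNet40
print(config.pc_net.n_neighbor)       # 20 (k used for DGCNN edge features)
```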
-------------------------------------------------------------------------------- /datasets/__pycache__/pc_aug.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/datasets/__pycache__/pc_aug.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/pc_aug.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/datasets/__pycache__/pc_aug.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/data_pth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import models 4 | import numpy as np 5 | import os.path as osp 6 | from PIL import Image 7 | from itertools import groupby 8 | from glob import glob 9 | from torchvision import transforms 10 | from torch.utils.data import Dataset 11 | from datasets import normal_pc, STATUS_TEST, STATUS_TRAIN, pc_aug_funs 12 | import pdb 13 | 14 | name_list = ['night_stand', 'range_hood', 'plant', 'chair', 'tent', 'curtain', 'piano', 'dresser', 'desk', 'bed', 15 | 'sink', 'laptop', 'flower_pot', 'car', 'stool', 'vase', 'monitor', 'airplane', 'stairs', 'glass_box', 16 | 'bottle', 'guitar', 'cone', 'toilet', 'bathtub', 'wardrobe', 'radio', 'person', 'xbox', 'bowl', 'cup', 17 | 'door', 'tv_stand', 'mantel', 'sofa', 'keyboard', 'bookshelf', 'bench', 'table', 'lamp'] 18 | 19 | def get_info(shapes_dir, isView=False): 20 | names_dict = {} 21 | if isView: 22 | for shape_dir in shapes_dir: 23 | name = '_'.join(osp.split(shape_dir)[1].split('.')[0].split('_')[:-1]) 24 | if name in names_dict: 25 | names_dict[name].append(shape_dir) 26 | else: 27 | names_dict[name] = [shape_dir] 28 | else: 29 | for shape_dir in shapes_dir: 30 | name = osp.split(shape_dir)[1].split('.')[0] 31 | names_dict[name] = shape_dir 32 | 33 | return names_dict 34 | 35 | 36 | class pc_data(Dataset): 37 | def __init__(self, pc_root, status='train', pc_input_num=1024): 38 | super(pc_data, self).__init__() 39 | 40 | self.status = status 41 | self.pc_list = [] 42 | self.lbl_list = [] 43 | self.pc_input_num = pc_input_num 44 | 45 | if status == STATUS_TRAIN: 46 | npy_list = glob(osp.join(pc_root, '*', 'train', '*.npy')) 47 | else: 48 | npy_list = glob(osp.join(pc_root, '*', 'test', '*.npy')) 49 | names_dict = get_info(npy_list, isView=False) 50 | 51 | for name, _dir in names_dict.items(): 52 | self.pc_list.append(_dir) 53 | self.lbl_list.append(name_list.index('_'.join(name.split('_')[:-1]))) 54 | 55 | print(f'{status} data num: {len(self.pc_list)}') 56 | 57 | def __getitem__(self, idx): 58 | lbl = self.lbl_list[idx] 59 | pc = np.load(self.pc_list[idx])[:self.pc_input_num].astype(np.float32) 60 | pc = normal_pc(pc) 61 | if self.status == STATUS_TRAIN: 62 | pc = pc_aug_funs(pc) 63 | pc = np.expand_dims(pc.transpose(), axis=2) 64 | return torch.from_numpy(pc).float(), lbl 65 | 66 | def __len__(self): 67 | return len(self.pc_list) 68 | 69 | 70 | class view_data(Dataset): 71 | def __init__(self, view_root, base_model_name=models.ALEXNET, status=STATUS_TRAIN): 72 | super(view_data, self).__init__() 73 | 74 | self.status = status 75 | self.view_list = [] 76 | self.lbl_list = [] 77 | 78 | if base_model_name in (models.ALEXNET, models.VGG13, models.VGG13BN, models.VGG11BN, 
models.RESNET50): 79 | self.img_sz = 224 80 | elif base_model_name == models.RESNET101: 81 | self.img_sz = 227 82 | elif base_model_name == models.INCEPTION_V3: 83 | self.img_sz = 299 84 | else: 85 | raise NotImplementedError 86 | 87 | self.transform = transforms.Compose([ 88 | transforms.Resize(self.img_sz), 89 | transforms.ToTensor() 90 | ]) 91 | 92 | if status == STATUS_TRAIN: 93 | jpg_list = glob(osp.join(view_root, '*', 'train', '*.jpg')) 94 | else: 95 | jpg_list = glob(osp.join(view_root, '*', 'test', '*.jpg')) 96 | names_dict = get_info(jpg_list, isView=True) 97 | 98 | for name, _dirs in names_dict.items(): 99 | self.view_list.append(_dirs) 100 | self.lbl_list.append(name_list.index('_'.join(name.split('_')[:-1]))) 101 | 102 | self.view_num = len(self.view_list[0]) 103 | 104 | print(f'{status} data num: {len(self.view_list)}') 105 | 106 | def __getitem__(self, idx): 107 | views = [self.transform(Image.open(v)) for v in self.view_list[idx]] 108 | return torch.stack(views).float(), self.lbl_list[idx] 109 | 110 | def __len__(self): 111 | return len(self.view_list) 112 | 113 | 114 | class pc_view_data(Dataset): 115 | def __init__(self, pc_root, view_root, base_model_name=models.ALEXNET, status='train', pc_input_num=1024): 116 | super(pc_view_data, self).__init__() 117 | 118 | self.status = status 119 | self.view_list = [] 120 | self.pc_list = [] 121 | self.lbl_list = [] 122 | self.pc_input_num = pc_input_num 123 | 124 | if base_model_name in (models.ALEXNET, models.VGG13, models.VGG13BN, models.VGG11BN): 125 | self.img_sz = 224 126 | elif base_model_name in (models.RESNET50, models.RESNET101): 127 | self.img_sz = 224 128 | elif base_model_name == models.INCEPTION_V3: 129 | self.img_sz = 299 130 | else: 131 | raise NotImplementedError 132 | 133 | self.transform = transforms.Compose([ 134 | transforms.Resize(self.img_sz), 135 | transforms.ToTensor() 136 | ]) 137 | 138 | if status == STATUS_TRAIN: 139 | jpg_list = glob(osp.join(view_root, '*', 'train', '*.jpg')) 140 | npy_list = glob(osp.join(pc_root, '*', 'train', '*.npy')) 141 | else: 142 | jpg_list = glob(osp.join(view_root, '*', 'test', '*.jpg')) 143 | npy_list = glob(osp.join(pc_root, '*', 'test', '*.npy')) 144 | pc_dict = get_info(npy_list, isView=False) 145 | view_dict = get_info(jpg_list, isView=True) 146 | 147 | for name in pc_dict.keys(): 148 | self.view_list.append(view_dict[name]) 149 | self.pc_list.append(pc_dict[name]) 150 | self.lbl_list.append(name_list.index('_'.join(name.split('_')[:-1]))) 151 | 152 | self.view_num = len(self.view_list[0]) 153 | 154 | print(f'{status} data num: {len(self.view_list)}') 155 | 156 | def __getitem__(self, idx): 157 | names = osp.split(self.pc_list[idx])[1].split('.')[0] 158 | views = [self.transform(Image.open(v)) for v in self.view_list[idx]] 159 | lbl = self.lbl_list[idx] 160 | pc = np.load(self.pc_list[idx])[:self.pc_input_num].astype(np.float32) 161 | pc = normal_pc(pc) 162 | if self.status == STATUS_TRAIN: 163 | pc = pc_aug_funs(pc) 164 | pc = np.expand_dims(pc.transpose(), axis=2) 165 | # return torch.stack(views).float(), torch.from_numpy(pc).float(), lbl, names 166 | return torch.stack(views).float(), torch.from_numpy(pc).float(), lbl 167 | 168 | def __len__(self): 169 | return len(self.pc_list) 170 | 171 | def get_immediate_subdirectories(a_dir): 172 | return [name for name in os.listdir(a_dir) 173 | if os.path.isdir(os.path.join(a_dir, name))] 174 | 175 | class ModelNet(Dataset): 176 | def __init__(self, data_root, status=STATUS_TRAIN, img_size=224): 177 | super(ModelNet,
self).__init__() 178 | self.data_root = data_root 179 | self.status = status 180 | self.img_size = img_size 181 | self.views_list = [] 182 | self.label_list = [] 183 | for i, curr_category in enumerate(sorted(get_immediate_subdirectories(self.data_root))): 184 | if status == STATUS_TEST: 185 | working_dir = os.path.join(data_root, curr_category, 'test') 186 | elif status == STATUS_TRAIN: 187 | working_dir = os.path.join(data_root, curr_category, 'train') 188 | else: 189 | raise NotImplementedError 190 | all_img_list = glob(working_dir + "/*.jpg") 191 | append_views_list = [[v for v in g] for _, g in groupby(sorted(all_img_list), lambda x: x.split('_')[-2])] 192 | self.views_list += append_views_list 193 | self.label_list += [i] * len(append_views_list) 194 | assert len(self.views_list) == len(self.label_list) 195 | self.transform = transforms.Compose([ 196 | transforms.Resize(self.img_size), 197 | transforms.ToTensor() 198 | ]) 199 | 200 | def __getitem__(self, index): 201 | views = [self.transform(Image.open(v)) for v in self.views_list[index]] 202 | return torch.stack(views), self.label_list[index] 203 | 204 | def __len__(self): 205 | return len(self.views_list) 206 | 207 | if __name__ == '__main__': 208 | # vd = view_data(cfg, state='test', batch_size=8, shuffle=True) 209 | # batch_len = len(vd) 210 | # imgs, lbls = vd.get_batch(307) 211 | # print(batch_len) 212 | # print(imgs.shape) 213 | # print(lbls) 214 | # Image.fromarray((imgs[0][0]*255).astype(np.uint8)).show() 215 | 216 | import config  # for the default data roots configured in config/config.py 217 | pvd = pc_view_data(config.pv_net.pc_root, config.pv_net.view_root, status=STATUS_TEST) 218 | batch_len = len(pvd) 219 | # imgs, pcs, lbls = pvd.get_batch(307) 220 | # print(batch_len) 221 | # print(imgs.shape) 222 | # print(lbls) 223 | # Image.fromarray((imgs[0][0]*255).astype(np.uint8)).show() 224 | # utils.generate_pc.draw_pc(pcs[0]) 225 | -------------------------------------------------------------------------------- /datasets/pc_aug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normal_pc(pc): 4 | """ 5 | normalize the point cloud into the unit sphere 6 | :param pc: N x 3 numpy array 7 | :return: N x 3 numpy array, normalized point cloud 8 | """ 9 | pc_mean = pc.mean(axis=0) 10 | pc = pc - pc_mean 11 | pc_L_max = np.max(np.sqrt(np.sum(abs(pc ** 2), axis=-1))) 12 | pc = pc/pc_L_max 13 | return pc 14 | 15 | def rotation_point_cloud(pc): 16 | """ 17 | Randomly rotate the point cloud to augment the dataset 18 | rotation is per shape, about the up direction 19 | :param pc: N x 3 array, original point cloud 20 | :return: N x 3 array, rotated point cloud 21 | """ 22 | # rotated_data = np.zeros(pc.shape, dtype=np.float32) 23 | 24 | rotation_angle = np.random.uniform() * 2 * np.pi 25 | cosval = np.cos(rotation_angle) 26 | sinval = np.sin(rotation_angle) 27 | # rotation_matrix = np.array([[cosval, 0, sinval], 28 | # [0, 1, 0], 29 | # [-sinval, 0, cosval]]) 30 | rotation_matrix = np.array([[cosval, -sinval, 0], 31 | [sinval, cosval, 0], 32 | [0, 0, 1]]) 33 | rotated_data = np.dot(pc.reshape((-1, 3)), rotation_matrix) 34 | 35 | return rotated_data 36 | 37 | 38 | def rotate_point_cloud_by_angle(pc, rotation_angle): 39 | """ 40 | Rotate the point cloud by a given angle 41 | rotation is per shape, about the up direction 42 | :param pc: N x 3 array, original point cloud 43 | :param rotation_angle: rotation angle in radians 44 | :return: N x 3 array, rotated point cloud 45 | """ 46 | # rotated_data = np.zeros(pc.shape, dtype=np.float32) 47 | 48 | # rotation_angle = np.random.uniform() * 2 * np.pi 49 | cosval = np.cos(rotation_angle) 50 | sinval = np.sin(rotation_angle) 51 | rotation_matrix = np.array([[cosval, 0, sinval], 52 | [0, 1, 0], 53 | [-sinval, 0, cosval]]) 54 | rotated_data = np.dot(pc.reshape((-1, 3)), rotation_matrix) 55 | 56 | return rotated_data 57 | 58 | 59 | def jitter_point_cloud(pc, sigma=0.01, clip=0.05): 60 | """ 61 | Randomly jitter points. Jittering is per point. 62 | :param pc: N x 3 array, original point cloud 63 | :param sigma: standard deviation of the Gaussian noise 64 | :param clip: absolute bound for the noise 65 | :return: N x 3 array, jittered point cloud 66 | """ 67 | jittered_data = np.clip(sigma * np.random.randn(*pc.shape), -1 * clip, clip) 68 | jittered_data += pc 69 | return jittered_data 70 | 71 | 72 | def shift_point_cloud(pc, shift_range=0.1): 73 | """ Randomly shift the point cloud. Shift is per point cloud. 74 | Input: 75 | Nx3 array, original point cloud 76 | Return: 77 | Nx3 array, shifted point cloud 78 | """ 79 | N, C = pc.shape 80 | shifts = np.random.uniform(-shift_range, shift_range, 3) 81 | pc += shifts 82 | return pc 83 | 84 | 85 | def random_scale_point_cloud(pc, scale_low=0.8, scale_high=1.25): 86 | """ Randomly scale the point cloud. Scale is per point cloud. 87 | Input: 88 | Nx3 array, original point cloud 89 | Return: 90 | Nx3 array, scaled point cloud 91 | """ 92 | N, C = pc.shape 93 | scales = np.random.uniform(scale_low, scale_high, 1) 94 | pc *= scales 95 | return pc 96 | 97 | 98 | def rotate_perturbation_point_cloud(pc, angle_sigma=0.06, angle_clip=0.18): 99 | """ Randomly perturb the point cloud by small rotations 100 | Input: 101 | Nx3 array, original point cloud 102 | Return: 103 | Nx3 array, rotated point cloud 104 | """ 105 | # rotated_data = np.zeros(pc.shape, dtype=np.float32) 106 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip) 107 | Rx = np.array([[1, 0, 0], 108 | [0, np.cos(angles[0]), -np.sin(angles[0])], 109 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 110 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 111 | [0, 1, 0], 112 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 113 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 114 | [np.sin(angles[2]), np.cos(angles[2]), 0], 115 | [0, 0, 1]]) 116 | R = np.dot(Rz, np.dot(Ry, Rx)) 117 | shape_pc = pc 118 | rotated_data = np.dot(shape_pc.reshape((-1, 3)), R) 119 | return rotated_data 120 | 121 | 122 | def pc_aug_funs(pc): 123 | pc = rotation_point_cloud(pc) 124 | pc = jitter_point_cloud(pc) 125 | pc = random_scale_point_cloud(pc) 126 | pc = rotate_perturbation_point_cloud(pc) 127 | pc = shift_point_cloud(pc) 128 | return pc 129 | -------------------------------------------------------------------------------- /docs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/docs/pipeline.png -------------------------------------------------------------------------------- /models/DGCNN.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | 3 | 4 | class DGCNN(nn.Module): 5 | def __init__(self, n_neighbor=20, num_classes=20): 6 | super(DGCNN, self).__init__() 7 | self.n_neighbor = n_neighbor 8 | self.trans_net = transform_net(6, 3) 9 | self.conv2d1 = conv_2d(6, 64, 1) 10 | self.conv2d2 = conv_2d(128, 64, 1) 11 | self.conv2d3 = conv_2d(128, 64, 1) 12 | self.conv2d4 = conv_2d(128, 128, 1) 13 | self.conv2d5 = conv_2d(320, 1024, 1) 14 | self.mlp1 =
nn.Sequential( 15 | fc_layer(1024, 512, True), 16 | nn.Dropout(p=0.5) 17 | ) 18 | self.mlp2 = nn.Sequential( 19 | fc_layer(512, 256, True), 20 | nn.Dropout(p=0.5) 21 | ) 22 | self.mlp3 = nn.Linear(256, num_classes) 23 | 24 | def forward(self, x): 25 | x_edge = get_edge_feature(x, self.n_neighbor) 26 | x_trans = self.trans_net(x_edge) 27 | x = x.squeeze(-1).transpose(2, 1) 28 | x = torch.bmm(x, x_trans) 29 | x = x.transpose(2, 1) 30 | 31 | x1 = get_edge_feature(x, self.n_neighbor) 32 | x1 = self.conv2d1(x1) 33 | x1, _ = torch.max(x1, dim=-1, keepdim=True) 34 | 35 | x2 = get_edge_feature(x1, self.n_neighbor) 36 | x2 = self.conv2d2(x2) 37 | x2, _ = torch.max(x2, dim=-1, keepdim=True) 38 | 39 | x3 = get_edge_feature(x2, self.n_neighbor) 40 | x3 = self.conv2d3(x3) 41 | x3, _ = torch.max(x3, dim=-1, keepdim=True) 42 | 43 | x4 = get_edge_feature(x3, self.n_neighbor) 44 | x4 = self.conv2d4(x4) 45 | x4, _ = torch.max(x4, dim=-1, keepdim=True) 46 | 47 | x5 = torch.cat((x1, x2, x3, x4), dim=1) 48 | x5 = self.conv2d5(x5) 49 | x5, _ = torch.max(x5, dim=-2, keepdim=True) 50 | 51 | net = x5.view(x5.size(0), -1) 52 | net = self.mlp1(net) 53 | net = self.mlp2(net) 54 | net = self.mlp3(net) 55 | 56 | return net 57 | 58 | 59 | -------------------------------------------------------------------------------- /models/MVCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import config 3 | import torchvision 4 | import torch.nn as nn 5 | import models 6 | 7 | 8 | class BaseFeatureNet(nn.Module): 9 | def __init__(self, base_model_name=models.VGG13, pretrained=True): 10 | super(BaseFeatureNet, self).__init__() 11 | base_model_name = base_model_name.upper() 12 | self.fc_features = None 13 | 14 | if base_model_name == models.VGG13: 15 | base_model = torchvision.models.vgg13(pretrained=pretrained) 16 | self.feature_len = 4096 17 | self.features = base_model.features 18 | self.fc_features = nn.Sequential(*list(base_model.classifier.children())[:-1]) 19 | 20 | elif base_model_name == models.VGG11BN: 21 | base_model = torchvision.models.vgg11_bn(pretrained=pretrained) 22 | self.feature_len = 4096 23 | self.features = base_model.features 24 | self.fc_features = nn.Sequential(*list(base_model.classifier.children())[:-1]) 25 | 26 | elif base_model_name == models.VGG13BN: 27 | base_model = torchvision.models.vgg13_bn(pretrained=pretrained) 28 | self.feature_len = 4096 29 | self.features = base_model.features 30 | self.fc_features = nn.Sequential(*list(base_model.classifier.children())[:-1]) 31 | 32 | elif base_model_name == models.ALEXNET: 33 | # base_model = torchvision.models.alexnet(pretrained=pretrained) 34 | base_model = torchvision.models.alexnet(pretrained=pretrained) 35 | self.feature_len = 4096 36 | self.features = base_model.features 37 | self.fc_features = nn.Sequential(*list(base_model.classifier.children())[:-1]) 38 | 39 | elif base_model_name == models.RESNET50: 40 | base_model = torchvision.models.resnet50(pretrained=pretrained) 41 | self.feature_len = 2048 42 | self.features = nn.Sequential(*list(base_model.children())[:-1]) 43 | elif base_model_name == models.RESNET101: 44 | base_model = torchvision.models.resnet101(pretrained=pretrained) 45 | self.feature_len = 2048 46 | self.features = nn.Sequential(*list(base_model.children())[:-1]) 47 | 48 | elif base_model_name == models.INCEPTION_V3: 49 | base_model = torchvision.models.inception_v3(pretrained=pretrained) 50 | base_model_list = list(base_model.children())[0:13] 51 | 
base_model_list.extend(list(base_model.children())[14:17]) 52 | self.features = nn.Sequential(*base_model_list) 53 | self.feature_len = 2048 54 | 55 | else: 56 | raise NotImplementedError(f'{base_model_name} is not a supported model') 57 | 58 | def forward(self, x): 59 | # x = x[:,0] 60 | # if len(x.size()) == 5: 61 | batch_sz = x.size(0) 62 | view_num = x.size(1) 63 | x = x.view(x.size(0) * x.size(1), x.size(2), x.size(3), x.size(4)) 64 | 65 | with torch.no_grad(): 66 | x = self.features[:1](x) 67 | x = self.features[1:](x) 68 | 69 | x = x.view(x.size(0), -1) 70 | x = self.fc_features(x) if self.fc_features is not None else x 71 | 72 | # max view pooling 73 | x_view = x.view(batch_sz, view_num, -1) 74 | x, _ = torch.max(x_view, 1) 75 | 76 | return x, x_view 77 | 78 | 79 | class BaseClassifierNet(nn.Module): 80 | def __init__(self, base_model_name=models.VGG13, num_classes=40, pretrained=True): 81 | super(BaseClassifierNet, self).__init__() 82 | base_model_name = base_model_name.upper() 83 | if base_model_name in (models.VGG13, models.VGG13BN, models.ALEXNET, models.VGG11BN): 84 | self.feature_len = 4096 85 | elif base_model_name in (models.RESNET50, models.RESNET101, models.INCEPTION_V3): 86 | self.feature_len = 2048 87 | else: 88 | raise NotImplementedError(f'{base_model_name} is not a supported model') 89 | 90 | self.classifier = nn.Linear(self.feature_len, num_classes) 91 | 92 | def forward(self, x): 93 | x = self.classifier(x) 94 | return x 95 | 96 | 97 | class MVCNN(nn.Module): 98 | def __init__(self, pretrained=True): 99 | super(MVCNN, self).__init__() 100 | base_model_name = config.base_model_name 101 | num_classes = config.view_net.num_classes 102 | print(f'\ninit {base_model_name} model...\n') 103 | self.features = BaseFeatureNet(base_model_name, pretrained) 104 | self.classifier = BaseClassifierNet(base_model_name, num_classes, pretrained) 105 | 106 | def forward(self, x): 107 | x, _ = self.features(x) 108 | x = self.classifier(x) 109 | return x 110 | 111 | 112 | -------------------------------------------------------------------------------- /models/PVRNet.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | import config 3 | import numpy as np 4 | 5 | 6 | 7 | class PVRNet(nn.Module): 8 | def __init__(self, n_classes=40, init_weights=True): 9 | super(PVRNet, self).__init__() 10 | 11 | self.fea_dim = 1024 12 | self.num_bottleneck = 512 13 | self.n_scale = [2, 3, 4] 14 | self.n_neighbor = config.pv_net.n_neighbor 15 | 16 | self.mvcnn = BaseFeatureNet(base_model_name=config.base_model_name) 17 | 18 | # Point cloud net 19 | self.trans_net = transform_net(6, 3) 20 | self.conv2d1 = conv_2d(6, 64, 1) 21 | self.conv2d2 = conv_2d(128, 64, 1) 22 | self.conv2d3 = conv_2d(128, 64, 1) 23 | self.conv2d4 = conv_2d(128, 128, 1) 24 | self.conv2d5 = conv_2d(320, 1024, 1) 25 | 26 | self.fusion_fc_mv = nn.Sequential( 27 | fc_layer(4096, 1024, True), 28 | ) 29 | 30 | self.fusion_fc = nn.Sequential( 31 | fc_layer(2048, 512, True), 32 | ) 33 | 34 | self.fusion_conv1 = nn.Sequential( 35 | nn.Linear(2048, 1), 36 | ) 37 | 38 | self.fusion_fc_scales = nn.ModuleList() 39 | 40 | for i in range(len(self.n_scale)): 41 | scale = self.n_scale[i] 42 | fc_fusion = nn.Sequential( 43 | fc_layer((scale+1) * self.fea_dim, self.num_bottleneck, True), 44 | ) 45 | self.fusion_fc_scales += [fc_fusion] 46 | 47 | self.sig = nn.Sigmoid() 48 | 49 | self.fusion_mlp2 = nn.Sequential( 50 | fc_layer(1024, 256, True), 51 | nn.Dropout(p=0.5) 52 | ) 53 | self.fusion_mlp3 =
nn.Linear(256, n_classes) 54 | if init_weights: 55 | self.init_mvcnn() 56 | self.init_dgcnn() 57 | 58 | def init_mvcnn(self): 59 | print(f'init parameter from mvcnn {config.base_model_name}') 60 | mvcnn_state_dict = torch.load(config.view_net.ckpt_load_file)['model'] 61 | pvrnet_state_dict = self.state_dict() 62 | 63 | mvcnn_state_dict = {k.replace('features', 'mvcnn', 1): v for k, v in mvcnn_state_dict.items()} 64 | mvcnn_state_dict = {k: v for k, v in mvcnn_state_dict.items() if k in pvrnet_state_dict.keys()} 65 | pvrnet_state_dict.update(mvcnn_state_dict) 66 | self.load_state_dict(pvrnet_state_dict) 67 | print(f'load ckpt from {config.view_net.ckpt_load_file}') 68 | 69 | def init_dgcnn(self): 70 | print(f'init parameter from dgcnn') 71 | dgcnn_state_dict = torch.load(config.pc_net.ckpt_load_file)['model'] 72 | pvrnet_state_dict = self.state_dict() 73 | 74 | dgcnn_state_dict = {k: v for k, v in dgcnn_state_dict.items() if k in pvrnet_state_dict.keys()} 75 | pvrnet_state_dict.update(dgcnn_state_dict) 76 | self.load_state_dict(pvrnet_state_dict) 77 | print(f'load ckpt from {config.pc_net.ckpt_load_file}') 78 | 79 | 80 | def forward(self, pc, mv, get_fea=False): 81 | batch_size = pc.size(0) 82 | view_num = mv.size(1) 83 | mv, mv_view = self.mvcnn(mv) 84 | 85 | x_edge = get_edge_feature(pc, self.n_neighbor) 86 | x_trans = self.trans_net(x_edge) 87 | x = pc.squeeze(-1).transpose(2, 1) 88 | x = torch.bmm(x, x_trans) 89 | x = x.transpose(2, 1) 90 | 91 | x1 = get_edge_feature(x, self.n_neighbor) 92 | x1 = self.conv2d1(x1) 93 | x1, _ = torch.max(x1, dim=-1, keepdim=True) 94 | 95 | x2 = get_edge_feature(x1, self.n_neighbor) 96 | x2 = self.conv2d2(x2) 97 | x2, _ = torch.max(x2, dim=-1, keepdim=True) 98 | 99 | x3 = get_edge_feature(x2, self.n_neighbor) 100 | x3 = self.conv2d3(x3) 101 | x3, _ = torch.max(x3, dim=-1, keepdim=True) 102 | 103 | x4 = get_edge_feature(x3, self.n_neighbor) 104 | x4 = self.conv2d4(x4) 105 | x4, _ = torch.max(x4, dim=-1, keepdim=True) 106 | 107 | x5 = torch.cat((x1, x2, x3, x4), dim=1) 108 | x5 = self.conv2d5(x5) 109 | x5, _ = torch.max(x5, dim=-2, keepdim=True) 110 | 111 | mv_view = mv_view.view(batch_size * view_num, -1) 112 | mv_view = self.fusion_fc_mv(mv_view) 113 | mv_view_expand = mv_view.view(batch_size, view_num, -1) 114 | 115 | pc = x5.squeeze() 116 | pc_expand = pc.unsqueeze(1).expand(-1, view_num, -1) 117 | pc_expand = pc_expand.contiguous().view(batch_size*view_num, -1) 118 | 119 | # Get Relation Scores 120 | fusion_mask = torch.cat((pc_expand, mv_view), dim=1) 121 | fusion_mask = self.fusion_conv1(fusion_mask) 122 | fusion_mask = fusion_mask.view(batch_size, view_num, -1) 123 | fusion_mask = self.sig(fusion_mask) 124 | 125 | # Rank Relation Scores 126 | mask_val, mask_idx = torch.sort(fusion_mask, dim=1, descending=True) 127 | mask_idx = mask_idx.expand(-1, -1, mv_view.size(-1)) 128 | 129 | # Enhance View Feature 130 | mv_view_enhance = torch.mul(mv_view_expand, fusion_mask) + mv_view_expand 131 | 132 | # Get Point-Single-view Fusion 133 | fusion_global = self.fusion_fc(torch.cat((pc_expand, mv_view_enhance.view(batch_size*view_num, self.fea_dim)), dim=1)) 134 | fusion_global, _ = torch.max(fusion_global.view(batch_size, view_num, self.num_bottleneck), dim=1) 135 | 136 | # Get Point-Multi-view Fusion 137 | scale_out = [] 138 | for i in range(len(self.n_scale)): 139 | mv_scale_fea = torch.gather(mv_view_enhance, 1, mask_idx[:, :self.n_scale[i], :]).view(batch_size, self.n_scale[i]*self.fea_dim) 140 | mv_pc_scale = torch.cat((pc, mv_scale_fea), dim=1) 141 | 
mv_pc_scale = self.fusion_fc_scales[i](mv_pc_scale) 142 | scale_out.append(mv_pc_scale.unsqueeze(2)) 143 | scale_out = torch.cat(scale_out, dim=2).mean(2) 144 | final_out = torch.cat((scale_out, fusion_global),1) 145 | 146 | # Final FC Layers 147 | net_fea = self.fusion_mlp2(final_out) 148 | net = self.fusion_mlp3(net_fea) 149 | 150 | if get_fea: 151 | return net, net_fea 152 | else: 153 | return net 154 | 155 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_utils import * 2 | from .DGCNN import DGCNN 3 | from .MVCNN import MVCNN, BaseFeatureNet 4 | from .PVRNet import PVRNet 5 | -------------------------------------------------------------------------------- /models/__pycache__/DGCNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/DGCNN.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/MVCNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/MVCNN.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_helper.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/model_helper.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_helper.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/model_helper.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/model_utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mvcnn_gnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/mvcnn_gnet.cpython-35.pyc 
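For orientation, here is a toy-shape sketch of the point-view relation scoring that `PVRNet.forward` performs above. The tensor sizes and the stand-in `score_fc` layer are assumptions for illustration only, mirroring `fusion_conv1` and the sort/gather used for the multi-scale fusion:

```python
import torch
import torch.nn as nn

# Assumed toy sizes: batch of 2 shapes, 12 views, 1024-d features.
batch_size, view_num, fea_dim = 2, 12, 1024
pc = torch.randn(batch_size, fea_dim)                  # global point-cloud feature (x5 squeezed)
mv_view = torch.randn(batch_size * view_num, fea_dim)  # per-view features after fusion_fc_mv

# Pair the point-cloud feature with every view and score each pair.
pc_expand = pc.unsqueeze(1).expand(-1, view_num, -1).reshape(batch_size * view_num, fea_dim)
score_fc = nn.Linear(2 * fea_dim, 1)                   # stands in for fusion_conv1
scores = torch.sigmoid(score_fc(torch.cat((pc_expand, mv_view), dim=1)))
scores = scores.view(batch_size, view_num, 1)

# Enhance view features by their relation scores, then gather the top-k
# views per scale (here k=3, matching one entry of n_scale).
mv_exp = mv_view.view(batch_size, view_num, fea_dim)
mv_enhanced = mv_exp * scores + mv_exp
_, idx = torch.sort(scores, dim=1, descending=True)
top3 = torch.gather(mv_enhanced, 1, idx.expand(-1, -1, fea_dim)[:, :3, :])
print(top3.shape)                                      # torch.Size([2, 3, 1024])
```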
-------------------------------------------------------------------------------- /models/__pycache__/mvcnn_gnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/mvcnn_gnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mvcnn_resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/mvcnn_resnet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/mvcnn_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/mvcnn_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/point_with_attention.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/point_with_attention.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/point_with_attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/point_with_attention.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/zhengyue_mvcnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/models/__pycache__/zhengyue_mvcnn.cpython-36.pyc -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import config 3 | import torch.nn as nn 4 | 5 | 6 | ALEXNET = "ALEXNET" 7 | DENSENET121 = "DENSENET121" 8 | VGG13 = "VGG13" 9 | VGG13BN = "VGG13BN" 10 | VGG11BN = 'VGG11BN' 11 | RESNET50 = "RESNET50" 12 | RESNET101 = "RESNET101" 13 | INCEPTION_V3 = 'INCEPTION_V3' 14 | 15 | # MVCNN functions 16 | 17 | 18 | class conv_2d(nn.Module): 19 | def __init__(self, in_ch, out_ch, kernel): 20 | super(conv_2d, self).__init__() 21 | self.conv = nn.Sequential( 22 | nn.Conv2d(in_ch, out_ch, kernel_size=kernel), 23 | nn.BatchNorm2d(out_ch), 24 | nn.ReLU(inplace=True) 25 | ) 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | return x 30 | 31 | 32 | class fc_layer(nn.Module): 33 | def __init__(self, in_ch, out_ch, bn=True): 34 | super(fc_layer, self).__init__() 35 | if bn: 36 | self.fc = nn.Sequential( 37 | nn.Linear(in_ch, out_ch), 38 | nn.BatchNorm1d(out_ch), 39 | nn.ReLU(inplace=True) 40 | ) 41 | else: 42 | self.fc = nn.Sequential( 43 | nn.Linear(in_ch, out_ch), 44 | nn.ReLU(inplace=True) 45 | ) 46 | 47 | def forward(self, x): 48 | x = self.fc(x) 49 | return x 50 | 51 | 52 | 53 | class transform_net(nn.Module): 54 | def __init__(self, in_ch, K=3): 55 | super(transform_net, self).__init__() 56 | self.K = K 57 |
self.conv2d1 = conv_2d(in_ch, 64, 1) 58 | self.conv2d2 = conv_2d(64, 128, 1) 59 | self.conv2d3 = conv_2d(128, 1024, 1) 60 | self.maxpool1 = nn.MaxPool2d(kernel_size=(1024, 1)) 61 | self.fc1 = fc_layer(1024, 512, bn=True) 62 | self.fc2 = fc_layer(512, 256, bn=True) 63 | self.fc3 = nn.Linear(256, K*K) 64 | 65 | 66 | def forward(self, x): 67 | x = self.conv2d1(x) 68 | x = self.conv2d2(x) 69 | x, _ = torch.max(x, dim=-1, keepdim=True) 70 | x = self.conv2d3(x) 71 | x = self.maxpool1(x) 72 | x = x.view(x.size(0), -1) 73 | x = self.fc1(x) 74 | x = self.fc2(x) 75 | x = self.fc3(x) 76 | 77 | iden = torch.eye(self.K).view(1, self.K * self.K).repeat(x.size(0), 1) 78 | iden = iden.to(device=config.device) 79 | x = x + iden 80 | x = x.view(x.size(0), self.K, self.K) 81 | return x 82 | 83 | 84 | def pairwise_distance(x): 85 | batch_size = x.size(0) 86 | point_cloud = torch.squeeze(x) 87 | if batch_size == 1: 88 | point_cloud = torch.unsqueeze(point_cloud, 0) 89 | point_cloud_transpose = torch.transpose(point_cloud, dim0=1, dim1=2) 90 | point_cloud_inner = torch.matmul(point_cloud_transpose, point_cloud) 91 | point_cloud_inner = -2 * point_cloud_inner 92 | point_cloud_square = torch.sum(point_cloud ** 2, dim=1, keepdim=True) 93 | point_cloud_square_transpose = torch.transpose(point_cloud_square, dim0=1, dim1=2) 94 | return point_cloud_square + point_cloud_inner + point_cloud_square_transpose 95 | 96 | 97 | def gather_neighbor(x, nn_idx, n_neighbor): 98 | x = torch.squeeze(x, -1) 99 | batch_size = x.size()[0] 100 | num_dim = x.size()[1] 101 | num_point = x.size()[2] 102 | # point_expand = x.unsqueeze(2).expand(batch_size, num_dim, num_point, num_point) 103 | # nn_idx_expand = nn_idx.unsqueeze(1).expand(batch_size, num_dim, num_point, n_neighbor) 104 | # pc_n = torch.gather(point_expand, -1, nn_idx_expand) 105 | x = x.permute(0,2,1) 106 | a = torch.arange(batch_size).view(batch_size, 1, 1).expand(batch_size, num_point, n_neighbor) 107 | pc_n = x[a, nn_idx, ...]
108 | pc_n = pc_n.permute(0,3,1,2) 109 | return pc_n 110 | 111 | def get_neighbor_feature(x, n_point, n_neighbor): 112 | if len(x.size()) == 3: 113 | x = x.unsqueeze(3) 114 | adj_matrix = pairwise_distance(x) 115 | _, nn_idx = torch.topk(adj_matrix, n_neighbor, dim=2, largest=False) 116 | nn_idx = nn_idx[:, :n_point, :] 117 | batch_size = x.size()[0] 118 | num_dim = x.size()[1] 119 | num_point = x.size()[2] 120 | point_expand = x[:, :, :n_point, :].expand(-1, -1, -1, num_point) 121 | nn_idx_expand = nn_idx.unsqueeze(1).expand(batch_size, num_dim, n_point, n_neighbor) 122 | pc_n = torch.gather(point_expand, -1, nn_idx_expand) 123 | return pc_n 124 | 125 | 126 | def get_edge_feature(x, n_neighbor): 127 | if len(x.size()) == 3: 128 | x = x.unsqueeze(3) 129 | adj_matrix = pairwise_distance(x) 130 | _, nn_idx = torch.topk(adj_matrix, n_neighbor, dim=2, largest=False) 131 | point_cloud_neighbors = gather_neighbor(x, nn_idx, n_neighbor) 132 | point_cloud_center = x.expand(-1, -1, -1, n_neighbor) 133 | edge_feature = torch.cat((point_cloud_center, point_cloud_neighbors-point_cloud_center), dim=1) 134 | return edge_feature 135 | 136 | 137 | -------------------------------------------------------------------------------- /train_mvcnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import config 3 | from utils import meter 4 | from torch import nn 5 | from torch import optim 6 | from models import MVCNN 7 | from torch.utils.data import DataLoader 8 | from datasets import * 9 | 10 | 11 | def train(train_loader, net, criterion, optimizer, epoch): 12 | """ 13 | train for one epoch on the training set 14 | """ 15 | batch_time = meter.TimeMeter(True) 16 | data_time = meter.TimeMeter(True) 17 | losses = meter.AverageValueMeter() 18 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 19 | # training mode 20 | net.train() 21 | 22 | for i, (views, labels) in enumerate(train_loader): 23 | batch_time.reset() 24 | views = views.to(device=config.device) 25 | labels = labels.to(device=config.device) 26 | 27 | preds = net(views) # bz x C x H x W 28 | loss = criterion(preds, labels) 29 | 30 | prec.add(preds.detach(), labels.detach()) 31 | losses.add(loss.item()) # batchsize 32 | 33 | optimizer.zero_grad() 34 | loss.backward() 35 | optimizer.step() 36 | 37 | if i % config.print_freq == 0: 38 | print(f'Epoch: [{epoch}][{i}/{len(train_loader)}]\t' 39 | f'Batch Time {batch_time.value():.3f}\t' 40 | f'Epoch Time {data_time.value():.3f}\t' 41 | f'Loss {losses.value()[0]:.4f} \t' 42 | f'Prec@1 {prec.value(1):.3f}\t') 43 | 44 | print(f'prec at epoch {epoch}: {prec.value(1)} ') 45 | 46 | 47 | def validate(val_loader, net, epoch): 48 | """ 49 | validation for one epoch on the val set 50 | """ 51 | batch_time = meter.TimeMeter(True) 52 | data_time = meter.TimeMeter(True) 53 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 54 | 55 | # testing mode 56 | net.eval() 57 | 58 | for i, (views, labels) in enumerate(val_loader): 59 | batch_time.reset() 60 | # bz x 12 x 3 x 224 x 224 61 | views = views.to(device=config.device) 62 | labels = labels.to(device=config.device) 63 | 64 | preds = net(views) # bz x C x H x W 65 | 66 | prec.add(preds.data, labels.data) 67 | 68 | if i % config.print_freq == 0: 69 | print(f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t' 70 | f'Batch Time {batch_time.value():.3f}\t' 71 | f'Epoch Time {data_time.value():.3f}\t' 72 | f'Prec@1 {prec.value(1):.3f}\t') 73 | 74 | print(f'mean class accuracy at epoch {epoch}: {prec.value(1)} ') 75
| return prec.value(1) 76 | 77 | 78 | def save_record(epoch, prec1, net: nn.Module): 79 | state_dict = net.state_dict() 80 | torch.save(state_dict, osp.join(config.view_net.ckpt_record_folder, f'epoch{epoch}_{prec1:.2f}.pth')) 81 | 82 | 83 | def save_ckpt(epoch, best_prec1, net, optimizer, training_conf=config.view_net): 84 | ckpt = dict( 85 | epoch=epoch, 86 | best_prec1=best_prec1, 87 | model=net.module.state_dict(), 88 | optimizer=optimizer.state_dict(), 89 | training_conf=training_conf 90 | ) 91 | torch.save(ckpt, config.view_net.ckpt_file) 92 | 93 | 94 | def main(): 95 | print('Training Process\nInitializing...\n') 96 | config.init_env() 97 | 98 | train_dataset = data_pth.view_data(config.view_net.data_root, 99 | status=STATUS_TRAIN, 100 | base_model_name=config.base_model_name) 101 | val_dataset = data_pth.view_data(config.view_net.data_root, 102 | status=STATUS_TEST, 103 | base_model_name=config.base_model_name) 104 | 105 | train_loader = DataLoader(train_dataset, batch_size=config.view_net.train.batch_sz, 106 | num_workers=config.num_workers,shuffle = True) 107 | val_loader = DataLoader(val_dataset, batch_size=config.view_net.train.batch_sz, 108 | num_workers=config.num_workers,shuffle=True) 109 | 110 | best_prec1 = 0 111 | resume_epoch = 0 112 | # create model 113 | net = MVCNN() 114 | net = net.to(device=config.device) 115 | net = nn.DataParallel(net) 116 | optimizer = optim.SGD(net.parameters(), config.view_net.train.lr, 117 | momentum=config.view_net.train.momentum, 118 | weight_decay=config.view_net.train.weight_decay) 119 | # optimizer = optim.Adam(net.parameters(), config.view_net.train.lr, 120 | # weight_decay=config.view_net.train.weight_decay) 121 | 122 | if config.view_net.train.resume: 123 | print(f'loading pretrained model from {config.view_net.ckpt_file}') 124 | checkpoint = torch.load(config.view_net.ckpt_file) 125 | net.module.load_state_dict({k[7:]: v for k, v in checkpoint['model'].items()}) 126 | # net.load_state_dict(checkpoint['model']) 127 | optimizer.load_state_dict(checkpoint['optimizer']) 128 | best_prec1 = checkpoint['best_prec1'] 129 | if config.view_net.train.resume_epoch is not None: 130 | resume_epoch = config.view_net.train.resume_epoch 131 | else: 132 | resume_epoch = checkpoint['epoch'] + 1 133 | 134 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.5) 135 | criterion = nn.CrossEntropyLoss() 136 | criterion = criterion.to(device=config.device) 137 | 138 | # for p in net.module.feature.parameters(): 139 | # p.requires_grad = False 140 | 141 | for epoch in range(resume_epoch, config.view_net.train.max_epoch): 142 | if epoch >= 5: 143 | for p in net.parameters(): 144 | p.requires_grad = True 145 | lr_scheduler.step(epoch=epoch) 146 | 147 | train(train_loader, net, criterion, optimizer, epoch) 148 | 149 | with torch.no_grad(): 150 | prec1 = validate(val_loader, net, epoch) 151 | 152 | # save checkpoints 153 | if best_prec1 < prec1: 154 | best_prec1 = prec1 155 | save_ckpt(epoch, best_prec1, net, optimizer) 156 | 157 | save_record(epoch, prec1, net.module) 158 | print('curr accuracy: ', prec1) 159 | print('best accuracy: ', best_prec1) 160 | 161 | print('Train Finished!') 162 | 163 | 164 | if __name__ == '__main__': 165 | main() 166 | 167 | -------------------------------------------------------------------------------- /train_pc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import config 3 | import torch 4 | import os.path as osp 5 | from utils import meter 6 | from 
torch import nn 7 | from torch import optim 8 | from models import DGCNN 9 | from torch.utils.data import DataLoader 10 | from datasets import data_pth, STATUS_TRAIN, STATUS_TEST 11 | 12 | 13 | def train(train_loader, net, criterion, optimizer, epoch): 14 | """ 15 | train for one epoch on the training set 16 | """ 17 | batch_time = meter.TimeMeter(True) 18 | data_time = meter.TimeMeter(True) 19 | losses = meter.AverageValueMeter() 20 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 21 | # training mode 22 | net.train() 23 | 24 | for i, (pcs, labels) in enumerate(train_loader): 25 | batch_time.reset() 26 | pcs = pcs.to(device=config.device) 27 | labels = labels.to(device=config.device) 28 | 29 | preds = net(pcs) # bz x C x H x W 30 | loss = criterion(preds, labels) 31 | 32 | prec.add(preds.data, labels.data) 33 | losses.add(loss.item()) # batchsize 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | optimizer.step() 38 | 39 | if i % config.print_freq == 0: 40 | print(f'Epoch: [{epoch}][{i}/{len(train_loader)}]\t' 41 | f'Batch Time {batch_time.value():.3f}\t' 42 | f'Epoch Time {data_time.value():.3f}\t' 43 | f'Loss {losses.value()[0]:.4f} \t' 44 | f'Prec@1 {prec.value(1):.3f}\t') 45 | 46 | print(f'prec at epoch {epoch}: {prec.value(1)} ') 47 | 48 | 49 | def validate(val_loader, net, epoch): 50 | """ 51 | validation for one epoch on the val set 52 | """ 53 | batch_time = meter.TimeMeter(True) 54 | data_time = meter.TimeMeter(True) 55 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 56 | 57 | # testing mode 58 | net.eval() 59 | 60 | for i, (pcs, labels) in enumerate(val_loader): 61 | batch_time.reset() 62 | # bz x 12 x 3 x 224 x 224 63 | pcs = pcs.to(device=config.device) 64 | labels = labels.to(device=config.device) 65 | 66 | preds = net(pcs) # bz x C x H x W 67 | 68 | prec.add(preds.data, labels.data) 69 | 70 | if i % config.print_freq == 0: 71 | print(f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t' 72 | f'Batch Time {batch_time.value():.3f}\t' 73 | f'Epoch Time {data_time.value():.3f}\t' 74 | f'Prec@1 {prec.value(1):.3f}\t') 75 | 76 | print(f'mean class accuracy at epoch {epoch}: {prec.value(1)} ') 77 | return prec.value(1) 78 | 79 | 80 | def save_record(epoch, prec1, net: nn.Module): 81 | state_dict = net.state_dict() 82 | torch.save(state_dict, osp.join(config.pc_net.ckpt_record_folder, f'epoch{epoch}_{prec1:.2f}.pth')) 83 | 84 | 85 | def save_ckpt(epoch, best_prec1, net, optimizer, training_conf=config.pc_net): 86 | ckpt = dict( 87 | epoch=epoch, 88 | best_prec1=best_prec1, 89 | model=net.module.state_dict(), 90 | optimizer=optimizer.state_dict(), 91 | training_conf=training_conf 92 | ) 93 | torch.save(ckpt, config.pc_net.ckpt_file) 94 | 95 | 96 | def main(): 97 | print('Training Process\nInitializing...\n') 98 | config.init_env() 99 | 100 | train_dataset = data_pth.pc_data(config.pc_net.data_root, status=STATUS_TRAIN) 101 | val_dataset = data_pth.pc_data(config.pc_net.data_root, status=STATUS_TEST) 102 | 103 | train_loader = DataLoader(train_dataset, batch_size=config.pc_net.train.batch_sz, 104 | num_workers=config.num_workers,shuffle = True,drop_last=False) 105 | val_loader = DataLoader(val_dataset, batch_size=config.pc_net.validation.batch_sz, 106 | num_workers=config.num_workers,shuffle=True) 107 | 108 | best_prec1 = 0 109 | resume_epoch = 0 110 | # create model 111 | net = DGCNN(n_neighbor=config.pc_net.n_neighbor,num_classes=config.pc_net.num_classes) 112 | net = torch.nn.DataParallel(net) 113 | net = net.to(device=config.device) 114 | optimizer = 
optim.Adam(net.parameters(), config.pc_net.train.lr, 115 | weight_decay=config.pc_net.train.weight_decay) 116 | 117 | if config.pc_net.train.resume: 118 | print(f'loading pretrained model from {config.pc_net.ckpt_file}') 119 | checkpoint = torch.load(config.pc_net.ckpt_file) 120 | net.module.load_state_dict({k[7:]: v for k, v in checkpoint['model'].items()}) 121 | optimizer.load_state_dict(checkpoint['optimizer']) 122 | best_prec1 = checkpoint['best_prec1'] 123 | if config.pc_net.train.resume_epoch is not None: 124 | resume_epoch = config.pc_net.train.resume_epoch 125 | else: 126 | resume_epoch = checkpoint['epoch'] + 1 127 | 128 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, 0.7) 129 | criterion = nn.CrossEntropyLoss() 130 | criterion = criterion.to(device=config.device) 131 | 132 | for epoch in range(resume_epoch, config.pc_net.train.max_epoch): 133 | 134 | lr_scheduler.step(epoch=epoch) 135 | # train 136 | train(train_loader, net, criterion, optimizer, epoch) 137 | # validation 138 | with torch.no_grad(): 139 | prec1 = validate(val_loader, net, epoch) 140 | 141 | # save checkpoints 142 | if prec1 > best_prec1: 143 | best_prec1 = prec1 144 | save_ckpt(epoch, best_prec1, net, optimizer) 145 | 146 | print('curr accuracy: ', prec1) 147 | print('best accuracy: ', best_prec1) 148 | 149 | print('Train Finished!') 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | 155 | -------------------------------------------------------------------------------- /train_pvrnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import config 3 | from utils import meter 4 | from torch import nn 5 | from torch import optim 6 | from models import PVRNet 7 | from torch.utils.data import DataLoader 8 | from datasets import * 9 | import argparse 10 | 11 | 12 | def train(train_loader, net, criterion, optimizer, epoch): 13 | """ 14 | train for one epoch on the training set 15 | """ 16 | batch_time = meter.TimeMeter(True) 17 | data_time = meter.TimeMeter(True) 18 | losses = meter.AverageValueMeter() 19 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 20 | # training mode 21 | net.train() 22 | 23 | for i, (views, pcs, labels) in enumerate(train_loader): 24 | batch_time.reset() 25 | views = views.to(device=config.device) 26 | pcs = pcs.to(device=config.device) 27 | labels = labels.to(device=config.device) 28 | 29 | preds = net(pcs, views) # bz x C x H x W 30 | loss = criterion(preds, labels) 31 | 32 | prec.add(preds.detach(), labels.detach()) 33 | losses.add(loss.item()) # batchsize 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | optimizer.step() 38 | 39 | if i % config.print_freq == 0: 40 | print(f'Epoch: [{epoch}][{i}/{len(train_loader)}]\t' 41 | f'Batch Time {batch_time.value():.3f}\t' 42 | f'Epoch Time {data_time.value():.3f}\t' 43 | f'Loss {losses.value()[0]:.4f} \t' 44 | f'Prec@1 {prec.value(1):.3f}\t') 45 | 46 | print(f'prec at epoch {epoch}: {prec.value(1)} ') 47 | 48 | 49 | def validate(val_loader, net, epoch): 50 | """ 51 | validation for one epoch on the val set 52 | """ 53 | batch_time = meter.TimeMeter(True) 54 | data_time = meter.TimeMeter(True) 55 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 56 | retrieval_map = meter.RetrievalMAPMeter() 57 | 58 | # testing mode 59 | net.eval() 60 | 61 | total_seen_class = [0 for _ in range(40)] 62 | total_right_class = [0 for _ in range(40)] 63 | 64 | for i, (views, pcs, labels) in enumerate(val_loader): 65 | batch_time.reset() 66 | 67 | views = 
views.to(device=config.device) 68 | pcs = pcs.to(device=config.device) 69 | labels = labels.to(device=config.device) 70 | 71 | preds, fts = net(pcs, views, get_fea=True) # bz x C x H x W 72 | # prec.add(preds.data, labels.data) 73 | 74 | prec.add(preds.data, labels.data) 75 | retrieval_map.add(fts.detach()/torch.norm(fts.detach(), 2, 1, True), labels.detach()) 76 | for j in range(views.size(0)): 77 | total_seen_class[labels.data[j]] += 1 78 | total_right_class[labels.data[j]] += (np.argmax(preds.data,1)[j] == labels.cpu()[j]) 79 | 80 | if i % config.print_freq == 0: 81 | print(f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t' 82 | f'Batch Time {batch_time.value():.3f}\t' 83 | f'Epoch Time {data_time.value():.3f}\t' 84 | f'Prec@1 {prec.value(1):.3f}\t') 85 | 86 | mAP = retrieval_map.mAP() 87 | print(f' instance accuracy at epoch {epoch}: {prec.value(1)} ') 88 | print(f' mean class accuracy at epoch {epoch}: {(np.mean(np.array(total_right_class)/np.array(total_seen_class,dtype=np.float)))} ') 89 | print(f' map at epoch {epoch}: {mAP} ') 90 | return prec.value(1), mAP 91 | 92 | 93 | def save_ckpt(epoch, epoch_pc, epoch_all, best_prec1, net, optimizer_pc, optimizer_all, training_conf=config.pv_net): 94 | ckpt = dict( 95 | epoch=epoch, 96 | epoch_pc=epoch_pc, 97 | epoch_all=epoch_all, 98 | best_prec1=best_prec1, 99 | model=net.module.state_dict(), 100 | optimizer_pc=optimizer_pc.state_dict(), 101 | optimizer_all=optimizer_all.state_dict(), 102 | training_conf=training_conf 103 | ) 104 | torch.save(ckpt, config.pv_net.ckpt_file) 105 | 106 | def parse_args(): 107 | parser = argparse.ArgumentParser( 108 | description="Main", 109 | ) 110 | parser.add_argument("-batch_size", '-b', type=int, default=32, help="Batch size") 111 | parser.add_argument('-gpu', '-g', type=str, default=None, help='GPUS used') 112 | parser.add_argument( 113 | "-epochs", '-e', type=int, default=None, help="Number of epochs to train for" 114 | ) 115 | return parser.parse_args() 116 | 117 | def main(): 118 | print('Training Process\nInitializing...\n') 119 | config.init_env() 120 | args = parse_args() 121 | 122 | total_batch_sz = config.pv_net.train.batch_sz * len(config.available_gpus.split(',')) 123 | total_epoch = config.pv_net.train.max_epoch 124 | 125 | if args.gpu is not None: 126 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 127 | total_batch_sz= config.pv_net.train.batch_sz * len(args.gpu.split(',')) 128 | if args.epochs is not None: 129 | total_epoch = args.epochs 130 | 131 | 132 | train_dataset = pc_view_data(config.pv_net.pc_root, 133 | config.pv_net.view_root, 134 | status=STATUS_TRAIN, 135 | base_model_name=config.base_model_name) 136 | val_dataset = pc_view_data(config.pv_net.pc_root, 137 | config.pv_net.view_root, 138 | status=STATUS_TEST, 139 | base_model_name=config.base_model_name) 140 | 141 | train_loader = DataLoader(train_dataset, batch_size=total_batch_sz, 142 | num_workers=config.num_workers,shuffle = True, drop_last=True) 143 | val_loader = DataLoader(val_dataset, batch_size=total_batch_sz, 144 | num_workers=config.num_workers,shuffle=True) 145 | 146 | best_prec1 = 0 147 | best_map = 0 148 | resume_epoch = 0 149 | 150 | epoch_pc_view = 0 151 | epoch_pc = 0 152 | 153 | # create model 154 | net = PVRNet() 155 | net = net.to(device=config.device) 156 | net = nn.DataParallel(net) 157 | 158 | # optimizer 159 | fc_param = [{'params': v} for k, v in net.named_parameters() if 'fusion' in k] 160 | if config.pv_net.train.optim == 'Adam': 161 | optimizer_fc = optim.Adam(fc_param, config.pv_net.train.fc_lr, 162 | 
weight_decay=config.pv_net.train.weight_decay)
163 | 
164 |         optimizer_all = optim.Adam(net.parameters(), config.pv_net.train.all_lr,
165 |                                    weight_decay=config.pv_net.train.weight_decay)
166 |     elif config.pv_net.train.optim == 'SGD':
167 |         optimizer_fc = optim.SGD(fc_param, config.pv_net.train.fc_lr,
168 |                                  momentum=config.pv_net.train.momentum,
169 |                                  weight_decay=config.pv_net.train.weight_decay)
170 | 
171 |         optimizer_all = optim.SGD(net.parameters(), config.pv_net.train.all_lr,
172 |                                   momentum=config.pv_net.train.momentum,
173 |                                   weight_decay=config.pv_net.train.weight_decay)
174 |     else:
175 |         raise NotImplementedError
176 |     print(f'use {config.pv_net.train.optim} optimizer')
177 |     print(f'Scale: {net.module.n_scale}')
178 | 
179 | 
180 |     if config.pv_net.train.resume:
181 |         print(f'loading pretrained model from {config.pv_net.ckpt_file}')
182 |         checkpoint = torch.load(config.pv_net.ckpt_file)
183 |         state_dict = checkpoint['model']
184 |         net.module.load_state_dict(state_dict)
185 |         optimizer_fc.load_state_dict(checkpoint['optimizer_pc'])
186 |         optimizer_all.load_state_dict(checkpoint['optimizer_all'])
187 |         best_prec1 = checkpoint['best_prec1']
188 |         epoch_pc_view = checkpoint['epoch_all']
189 |         epoch_pc = checkpoint['epoch_pc']
190 |         if config.pv_net.train.resume_epoch is not None:
191 |             resume_epoch = config.pv_net.train.resume_epoch
192 |         else:
193 |             resume_epoch = max(checkpoint['epoch_pc'], checkpoint['epoch_all'])
194 | 
195 |     if not config.pv_net.train.iter_train:
196 |         print('No iter')
197 |         lr_scheduler_fc = torch.optim.lr_scheduler.StepLR(optimizer_fc, 5, 0.3)
198 |         lr_scheduler_all = torch.optim.lr_scheduler.StepLR(optimizer_all, 5, 0.3)
199 |     else:
200 |         print('iter')
201 |         lr_scheduler_fc = torch.optim.lr_scheduler.StepLR(optimizer_fc, 6, 0.3)
202 |         lr_scheduler_all = torch.optim.lr_scheduler.StepLR(optimizer_all, 6, 0.3)
203 | 
204 |     criterion = nn.CrossEntropyLoss()
205 |     criterion = criterion.to(device=config.device)
206 | 
207 |     for epoch in range(resume_epoch, total_epoch):
208 | 
209 | 
210 |         if config.pv_net.train.iter_train:
211 |             if epoch < 12:
212 |                 lr_scheduler_fc.step(epoch=epoch_pc)
213 |                 print(lr_scheduler_fc.get_lr())
214 | 
215 |                 if (epoch_pc + 1) % 3 == 0:
216 |                     print('train score block')
217 |                     for m in net.module.parameters():
218 |                         m.requires_grad = False  # freeze the whole network first
219 |                     for p in net.module.fusion_conv1.parameters(): p.requires_grad = True  # then unfreeze only the fusion (score) block
220 |                 else:
221 |                     print('train all fc block')
222 |                     for m in net.module.parameters():
223 |                         m.requires_grad = True
224 | 
225 |                 train(train_loader, net, criterion, optimizer_fc, epoch)
226 |                 epoch_pc += 1
227 | 
228 |             else:
229 |                 lr_scheduler_all.step(epoch=epoch_pc_view)
230 |                 print(lr_scheduler_all.get_lr())
231 | 
232 |                 if (epoch_pc_view + 1) % 3 == 0:
233 |                     print('train score block')
234 |                     for m in net.module.parameters():
235 |                         m.requires_grad = False  # freeze the whole network first
236 |                     for p in net.module.fusion_conv1.parameters(): p.requires_grad = True  # then unfreeze only the fusion (score) block
237 |                 else:
238 |                     print('train all block')
239 |                     for m in net.module.parameters():
240 |                         m.requires_grad = True
241 | 
242 |                 train(train_loader, net, criterion, optimizer_all, epoch)
243 |                 epoch_pc_view += 1
244 | 
245 | 
246 |         else:
247 |             if epoch < 10:
248 |                 lr_scheduler_fc.step(epoch=epoch_pc)
249 |                 print(lr_scheduler_fc.get_lr())
250 |                 train(train_loader, net, criterion, optimizer_fc, epoch)
251 |                 epoch_pc += 1
252 | 
253 |             else:
254 |                 lr_scheduler_all.step(epoch=epoch_pc_view)
255 |                 print(lr_scheduler_all.get_lr())
256 |                 train(train_loader, net, criterion, optimizer_all, epoch)
257 |                 epoch_pc_view += 1
258 | 
259 | 
260 |         with torch.no_grad():
261 |             prec1, retrieval_map = validate(val_loader, net, epoch)
262 | 
263 |         # save checkpoints
264 |         if best_prec1 < prec1:
265 |             best_prec1 = prec1
266 |             save_ckpt(epoch, epoch_pc, epoch_pc_view, best_prec1, net, optimizer_fc, optimizer_all)
267 |         if best_map < retrieval_map:
268 |             best_map = retrieval_map
269 | 
270 |         print('curr accuracy: ', prec1)
271 |         print('best accuracy: ', best_prec1)
272 |         print('best map: ', best_map)
273 | 
274 |     print('Train Finished!')
275 | 
276 | 
277 | if __name__ == '__main__':
278 |     main()
279 | 
280 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iMoonLab/PVRNet/7b07d62788e67b4052c36b9ca37f6163234a4648/utils/__init__.py
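The alternating schedule in train_pvrnet.py above flips which parameters receive gradients: every third epoch only the fusion ("score") block trains, the remaining epochs train everything. A minimal standalone sketch of that freezing pattern, assuming only that the model exposes a `fusion_conv1` submodule as train_pvrnet.py does; the toy network and all other names are illustrative:

```python
import torch
from torch import nn


class ToyFusionNet(nn.Module):
    """Stand-in for PVRNet; only the `fusion_conv1` attribute name mirrors the real model."""
    def __init__(self):
        super(ToyFusionNet, self).__init__()
        self.backbone = nn.Linear(16, 16)
        self.fusion_conv1 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fusion_conv1(self.backbone(x))


def set_trainable_blocks(net, epoch):
    """Every third epoch train only the fusion block; otherwise train all parameters."""
    fusion_only = (epoch + 1) % 3 == 0
    for p in net.parameters():
        p.requires_grad = not fusion_only
    if fusion_only:
        for p in net.fusion_conv1.parameters():
            p.requires_grad = True


net = ToyFusionNet()
for epoch in range(6):
    set_trainable_blocks(net, epoch)
    n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'epoch {epoch}: {n_trainable} trainable parameters')
```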
-------------------------------------------------------------------------------- /utils/generate_pc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import bisect 3 | import config 4 | import numpy as np 5 | import os.path as osp 6 | import matplotlib.pyplot as plt 7 | from tqdm import tqdm 8 | from glob import glob 9 | from math import isnan 10 | from random import shuffle 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | 14 | def des(a, b): 15 | return np.linalg.norm(a - b) 16 | 17 | 18 | def get_info(shape_dir): 19 | splits = shape_dir.split('/') 20 | class_name = splits[-3] 21 | set_name = splits[-2] 22 | file_name = splits[-1].split('.')[0] 23 | return class_name, set_name, file_name 24 | 25 | 26 | def random_point_triangle(a, b, c): 27 | r1 = np.random.random() 28 | r2 = np.random.random() 29 | p = np.sqrt(r1) * (r2 * c + b * (1-r2)) + a * (1-np.sqrt(r1)) 30 | return p 31 | 32 | 33 | def triangle_area(p1, p2, p3): 34 | a = des(p1, p2) 35 | b = des(p1, p3) 36 | c = des(p2, p3) 37 | p = (a+b+c)/2.0 38 | area = np.sqrt(p*(p-a)*(p-b)*(p-c)) 39 | if isnan(area): 40 | # print('find nan') 41 | area = 1e-6 42 | return area 43 | 44 | 45 | def uniform_sampling(points, faces, n_samples): 46 | sampled_points = [] 47 | total_area = 0 48 | cum_sum = [] 49 | for _idx, face in enumerate(faces): 50 | total_area += triangle_area(points[face[0]], points[face[1]], points[face[2]]) 51 | if isnan(total_area): 52 | print('find nan') 53 | cum_sum.append(total_area) 54 | 55 | for _idx in range(n_samples): 56 | tmp = np.random.random()*total_area 57 | face_idx = bisect.bisect_left(cum_sum, tmp) 58 | pc = random_point_triangle(points[faces[face_idx][0]], 59 | points[faces[face_idx][1]], 60 | points[faces[face_idx][2]]) 61 | sampled_points.append(pc) 62 | return np.array(sampled_points) 63 | 64 | 65 | def resize_pc(pc, L): 66 | """ 67 | normalize point cloud in range L 68 | :param pc: type list 69 | :param L: 70 | :return: type list 71 | """ 72 | pc_L_max = np.sqrt(np.sum(pc ** 2, 1)).max() 73 | return pc/pc_L_max*L 74 | 75 | 76 | def normal_pc(pc): 77 | """ 78 | normalize point cloud in range L 79 | :param pc: type list 80 | :return: type list 81 | """ 82 | pc_mean = pc.mean(axis=0) 83 | pc = pc - pc_mean 84 | pc_L_max = np.max(np.sqrt(np.sum(abs(pc ** 2), axis=-1))) 85 | pc = pc/pc_L_max 86 | return pc 87 | 88 | 89 | def get_pc(shape, point_each): 90 | points = [] 91 | faces = [] 92 | with open(shape, 'r') as f: 93 | line = f.readline().strip() 94 | if line == 'OFF': 95 | num_verts, num_faces, num_edge = f.readline().split() 96 | num_verts = int(num_verts) 97 | num_faces = int(num_faces) 98 | else: 99 | num_verts, num_faces, num_edge = line[3:].split() 100 | num_verts = int(num_verts) 101 | num_faces = int(num_faces) 102 | 103 | for idx in range(num_verts): 104 | line = f.readline() 105 | point = [float(v) for v in line.split()] 106 | points.append(point) 107 | 108 | for idx in range(num_faces): 109 | line = f.readline() 110 | face = [int(t_f) for t_f in line.split()] 111 | faces.append(face[1:]) 112 | 113 | points = np.array(points) 114 | pc = resize_pc(points, 10) 115 | pc = uniform_sampling(pc, faces, point_each) 116 | 117 | pc = normal_pc(pc) 118 | 119 | return pc 120 | 121 | 122 | def generate(raw_off_root, vis_pc=False, num_pc_each=2018): 123 | shape_all = glob(osp.join(raw_off_root, '*', '*', '*.off')) 124 | shuffle(shape_all) 125 | cnt = 0 126 | for shape in tqdm(shape_all): 127 | class_name, set_name, file_name = get_info(shape) 128 | 
new_folder = osp.join(config.pc_net.data_root, class_name, set_name)
129 |         new_dir = osp.join(new_folder, file_name)
130 |         if osp.exists(new_dir + '.npy'):
131 |             if vis_pc and not osp.exists(new_dir + '.jpg'):
132 |                 pc = np.load(new_dir + '.npy')
133 |                 draw_pc(pc, show=False, save_dir=new_dir + '.jpg')
134 |         else:
135 |             pc = get_pc(shape, num_pc_each)
136 |             if not osp.exists(new_folder):
137 |                 os.makedirs(new_folder)
138 |             np.save(new_dir + '.npy', pc)
139 |             if vis_pc:
140 |                 if cnt % 10 == 0:
141 |                     draw_pc(pc, show=False, save_dir=new_dir + '.jpg')
142 |         cnt += 1
143 | 
144 | 
145 | def draw_pc(pc, show=True, save_dir=None):
146 |     ax = plt.figure().add_subplot(111, projection='3d')
147 |     ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], marker='.')
148 |     ax.grid(False)
149 |     # ax.axis('off')
150 |     if show:
151 |         plt.show()
152 |     if save_dir is not None:
153 |         plt.savefig(save_dir)
154 | 
155 | 
156 | if __name__ == '__main__':
157 |     generate('/repository/Modelnet40')
158 |     # file_name = '/home/fyf/data/pc_ModelNet40/airplane/train/airplane_0165.npy'
159 |     # pc = np.load(file_name)
160 |     # draw_pc(pc)
161 | 
162 | 
163 | 
--------------------------------------------------------------------------------
/utils/meter/__init__.py:
--------------------------------------------------------------------------------
1 | from .averagevaluemeter import AverageValueMeter
2 | from .classerrormeter import ClassErrorMeter
3 | from .confusionmeter import ConfusionMeter
4 | from .timemeter import TimeMeter
5 | from .msemeter import MSEMeter
6 | from .movingaveragevaluemeter import MovingAverageValueMeter
7 | from .aucmeter import AUCMeter
8 | from .apmeter import APMeter
9 | from .mapmeter import mAPMeter
10 | from .retrievalmeter import RetrievalMAPMeter
11 | 
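These exports are the meter API used by all the training and validation scripts above. A small self-contained sketch of the usual pattern, with synthetic tensors; only the meter names and their call signatures come from this package:

```python
import torch
from utils import meter  # same import style as the training scripts

prec = meter.ClassErrorMeter(topk=[1], accuracy=True)
losses = meter.AverageValueMeter()
timer = meter.TimeMeter(True)

# Fake a batch of 8 samples over 5 classes.
scores = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

prec.add(scores, labels)
losses.add(1.234)

print(f'Prec@1 {prec.value(1):.3f}')      # top-1 accuracy in percent
print(f'Loss {losses.value()[0]:.4f}')    # running mean of the added values
print(f'Batch Time {timer.value():.3f}')  # seconds since the last reset()
```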
--------------------------------------------------------------------------------
/utils/meter/apmeter.py:
--------------------------------------------------------------------------------
1 | import math
2 | from . import meter
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | class APMeter(meter.Meter):
8 |     """
9 |     The APMeter measures the average precision per class.
10 | 
11 |     The APMeter is designed to operate on `NxK` Tensors `output` and
12 |     `target`, and optionally a `Nx1` Tensor weight where (1) the `output`
13 |     contains model output scores for `N` examples and `K` classes that ought to
14 |     be higher when the model is more convinced that the example should be
15 |     positively labeled, and smaller when the model believes the example should
16 |     be negatively labeled (for instance, the output of a sigmoid function); (2)
17 |     the `target` contains only values 0 (for negative examples) and 1
18 |     (for positive examples); and (3) the `weight` ( > 0) represents weight for
19 |     each sample.
20 |     """
21 |     def __init__(self):
22 |         super(APMeter, self).__init__()
23 |         self.reset()
24 | 
25 |     def reset(self):
26 |         """Resets the meter with empty member variables"""
27 |         self.scores = torch.FloatTensor(torch.FloatStorage())
28 |         self.targets = torch.LongTensor(torch.LongStorage())
29 |         self.weights = torch.FloatTensor(torch.FloatStorage())
30 | 
31 |     def add(self, output, target, weight=None):
32 |         """
33 |         Args:
34 |             output (Tensor): NxK tensor that for each of the N examples
35 |                 indicates the probability of the example belonging to each of
36 |                 the K classes, according to the model. The probabilities should
37 |                 sum to one over all classes
38 |             target (Tensor): binary NxK tensor that encodes which of the K
39 |                 classes are associated with the N-th input
40 |                 (e.g. a row [0, 1, 0, 1] indicates that the example is
41 |                 associated with classes 2 and 4)
42 |             weight (optional, Tensor): Nx1 tensor representing the weight for
43 |                 each example (each weight > 0)
44 |         """
45 |         if not torch.is_tensor(output):
46 |             output = torch.from_numpy(output)
47 |         if not torch.is_tensor(target):
48 |             target = torch.from_numpy(target)
49 | 
50 |         if weight is not None:
51 |             if not torch.is_tensor(weight):
52 |                 weight = torch.from_numpy(weight)
53 |             weight = weight.squeeze()
54 |         if output.dim() == 1:
55 |             output = output.view(-1, 1)
56 |         else:
57 |             assert output.dim() == 2, \
58 |                 'wrong output size (should be 1D or 2D ' \
59 |                 'with one column per class)'
60 |         if target.dim() == 1:
61 |             target = target.view(-1, 1)
62 |         else:
63 |             assert target.dim() == 2, \
64 |                 'wrong target size (should be 1D or 2D ' \
65 |                 'with one column per class)'
66 |         if weight is not None:
67 |             assert weight.dim() == 1, 'Weight dimension should be 1'
68 |             assert weight.numel() == target.size(0), \
69 |                 'Weight dimension 1 should be the same as that of target'
70 |             assert torch.min(weight) >= 0, 'Weight should be non-negative only'
71 |         assert torch.equal(target.float()**2, target.float()), \
72 |             'targets should be binary (0 or 1)'
73 |         if self.scores.numel() > 0:
74 |             assert target.size(1) == self.targets.size(1), \
75 |                 'dimensions for output should match previously added examples.'
76 | 
77 |         # make sure storage is of sufficient size
78 |         if self.scores.storage().size() < self.scores.numel() + output.numel():
79 |             new_size = math.ceil(self.scores.storage().size() * 1.5)
80 |             new_weight_size = math.ceil(self.weights.storage().size() * 1.5)
81 |             self.scores.storage().resize_(int(new_size + output.numel()))
82 |             self.targets.storage().resize_(int(new_size + output.numel()))
83 |             if weight is not None:
84 |                 self.weights.storage().resize_(int(new_weight_size
85 |                                                    + output.size(0)))
86 | 
87 |         # store scores and targets
88 |         offset = self.scores.size(0) if self.scores.dim() > 0 else 0
89 |         self.scores.resize_(offset + output.size(0), output.size(1))
90 |         self.targets.resize_(offset + target.size(0), target.size(1))
91 |         self.scores.narrow(0, offset, output.size(0)).copy_(output)
92 |         self.targets.narrow(0, offset, target.size(0)).copy_(target)
93 | 
94 |         if weight is not None:
95 |             self.weights.resize_(offset + weight.size(0))
96 |             self.weights.narrow(0, offset, weight.size(0)).copy_(weight)
97 | 
98 |     def value(self):
99 |         """Returns the model's average precision for each class
100 | 
101 |         Return:
102 |             ap (FloatTensor): 1xK tensor, with avg precision for each class k
103 |         """
104 | 
105 |         if self.scores.numel() == 0:
106 |             return 0
107 |         ap = torch.zeros(self.scores.size(1))
108 |         rg = torch.arange(1, self.scores.size(0) + 1).float()
109 |         if self.weights.numel() > 0:
110 |             weight = self.weights.new(self.weights.size())
111 |             weighted_truth = self.weights.new(self.weights.size())
112 | 
113 |         # compute average precision for each class
114 |         for k in range(self.scores.size(1)):
115 |             # sort scores
116 |             scores = self.scores[:, k]
117 |             targets = self.targets[:, k]
118 |             _, sortind = torch.sort(scores, 0, True)
119 |             truth = targets[sortind]
120 |             if self.weights.numel() > 0:
121 |                 weight = self.weights[sortind]
122 |                 weighted_truth = truth.float() * weight
123 |                 rg = weight.cumsum(0)
124 | 
125 |             # compute true positive sums
126 |             if self.weights.numel() > 0:
127 |                 tp = weighted_truth.cumsum(0)
128 |             else:
129 |                 tp = truth.float().cumsum(0)
130 | 
131 |             # compute precision curve
132 |             precision = tp.div(rg)
133 | 
134 |             # compute average precision
135 |             ap[k] = precision[truth.byte()].sum() / max(truth.sum(), 1)
136 |         return ap
137 | 
--------------------------------------------------------------------------------
/utils/meter/aucmeter.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | from . import meter
4 | import numpy as np
5 | import torch
6 | 
7 | 
8 | class AUCMeter(meter.Meter):
9 |     """
10 |     The AUCMeter measures the area under the receiver-operating characteristic
11 |     (ROC) curve for binary classification problems. The area under the curve (AUC)
12 |     can be interpreted as the probability that, given a randomly selected positive
13 |     example and a randomly selected negative example, the positive example is
14 |     assigned a higher score by the classification model than the negative example.
15 | 
16 |     The AUCMeter is designed to operate on one-dimensional Tensors `output`
17 |     and `target`, where (1) the `output` contains model output scores that ought to
18 |     be higher when the model is more convinced that the example should be positively
19 |     labeled, and smaller when the model believes the example should be negatively
20 |     labeled (for instance, the output of a sigmoid function); and (2) the `target`
21 |     contains only values 0 (for negative examples) and 1 (for positive examples).
22 | """ 23 | def __init__(self): 24 | super(AUCMeter, self).__init__() 25 | self.reset() 26 | 27 | def reset(self): 28 | self.scores = torch.DoubleTensor(torch.DoubleStorage()).numpy() 29 | self.targets = torch.LongTensor(torch.LongStorage()).numpy() 30 | 31 | def add(self, output, target): 32 | if torch.is_tensor(output): 33 | output = output.cpu().squeeze().numpy() 34 | if torch.is_tensor(target): 35 | target = target.cpu().squeeze().numpy() 36 | elif isinstance(target, numbers.Number): 37 | target = np.asarray([target]) 38 | assert np.ndim(output) == 1, \ 39 | 'wrong output size (1D expected)' 40 | assert np.ndim(target) == 1, \ 41 | 'wrong target size (1D expected)' 42 | assert output.shape[0] == target.shape[0], \ 43 | 'number of outputs and targets does not match' 44 | assert np.all(np.add(np.equal(target, 1), np.equal(target, 0))), \ 45 | 'targets should be binary (0, 1)' 46 | 47 | self.scores = np.append(self.scores, output) 48 | self.targets = np.append(self.targets, target) 49 | 50 | def value(self): 51 | # case when number of elements added are 0 52 | if self.scores.shape[0] == 0: 53 | return 0.5 54 | 55 | # sorting the arrays 56 | scores, sortind = torch.sort(torch.from_numpy(self.scores), dim=0, descending=True) 57 | scores = scores.numpy() 58 | sortind = sortind.numpy() 59 | 60 | # creating the roc curve 61 | tpr = np.zeros(shape=(scores.size + 1), dtype=np.float64) 62 | fpr = np.zeros(shape=(scores.size + 1), dtype=np.float64) 63 | 64 | for i in range(1, scores.size + 1): 65 | if self.targets[sortind[i - 1]] == 1: 66 | tpr[i] = tpr[i - 1] + 1 67 | fpr[i] = fpr[i - 1] 68 | else: 69 | tpr[i] = tpr[i - 1] 70 | fpr[i] = fpr[i - 1] + 1 71 | 72 | tpr /= (self.targets.sum() * 1.0) 73 | fpr /= ((self.targets - 1.0).sum() * -1.0) 74 | 75 | # calculating area under curve using trapezoidal rule 76 | n = tpr.shape[0] 77 | h = fpr[1:n] - fpr[0:n - 1] 78 | sum_h = np.zeros(fpr.shape) 79 | sum_h[0:n - 1] = h 80 | sum_h[1:n] += h 81 | area = (sum_h * tpr).sum() / 2.0 82 | 83 | return (area, tpr, fpr) 84 | -------------------------------------------------------------------------------- /utils/meter/averagevaluemeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter 3 | import numpy as np 4 | 5 | class AverageValueMeter(meter.Meter): 6 | def __init__(self): 7 | super(AverageValueMeter, self).__init__() 8 | self.reset() 9 | 10 | def add(self, value, n = 1): 11 | self.sum += value 12 | self.var += value * value 13 | self.n += n 14 | 15 | def value(self): 16 | n = self.n 17 | if n == 0: 18 | mean, std = np.nan, np.nan 19 | elif n == 1: 20 | return self.sum, np.inf 21 | else: 22 | mean = self.sum / n 23 | std = math.sqrt( (self.var - n * mean * mean) / (n - 1.0) ) 24 | return mean, std 25 | 26 | def reset(self): 27 | self.sum = 0.0 28 | self.n = 0 29 | self.var = 0.0 30 | 31 | -------------------------------------------------------------------------------- /utils/meter/classerrormeter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import numbers 4 | from . 
import meter 5 | 6 | 7 | class ClassErrorMeter(meter.Meter): 8 | def __init__(self, topk=[1], accuracy=False): 9 | super(ClassErrorMeter, self).__init__() 10 | self.topk = np.sort(topk) 11 | self.accuracy = accuracy 12 | self.reset() 13 | 14 | def reset(self): 15 | self.sum = {v: 0 for v in self.topk} 16 | self.n = 0 17 | 18 | def add(self, output, target): 19 | if torch.is_tensor(output): 20 | output = output.cpu().squeeze().numpy() 21 | if torch.is_tensor(target): 22 | target = target.cpu().squeeze().numpy() 23 | elif isinstance(target, numbers.Number): 24 | target = np.asarray([target]) 25 | if np.ndim(output) == 1: 26 | output = output[np.newaxis] 27 | else: 28 | assert np.ndim(output) == 2, \ 29 | 'wrong output size (1D or 2D expected)' 30 | assert np.ndim(target) == 1, \ 31 | 'target and output do not match' 32 | assert target.shape[0] == output.shape[0], \ 33 | 'target and output do not match' 34 | topk = self.topk 35 | maxk = int(topk[-1]) # seems like Python3 wants int and not np.int64 36 | no = output.shape[0] 37 | 38 | pred = torch.from_numpy(output).topk(maxk, 1, True, True)[1].numpy() 39 | correct = pred == target[:, np.newaxis].repeat(pred.shape[1], 1) 40 | 41 | for k in topk: 42 | self.sum[k] += no - correct[:, 0:k].sum() 43 | self.n += no 44 | 45 | def value(self, k=-1): 46 | if k != -1: 47 | assert k in self.sum.keys(), \ 48 | 'invalid k (this k was not provided at construction time)' 49 | if self.accuracy: 50 | return (1. - float(self.sum[k]) / self.n) * 100.0 51 | else: 52 | return float(self.sum[k]) / self.n * 100.0 53 | else: 54 | return [self.value(k_) for k_ in self.topk] 55 | -------------------------------------------------------------------------------- /utils/meter/confusionmeter.py: -------------------------------------------------------------------------------- 1 | from . import meter 2 | import numpy as np 3 | 4 | 5 | class ConfusionMeter(meter.Meter): 6 | """ 7 | The ConfusionMeter constructs a confusion matrix for a multi-class 8 | classification problems. It does not support multi-label, multi-class problems: 9 | for such problems, please use MultiLabelConfusionMeter. 10 | """ 11 | 12 | def __init__(self, k, normalized=False): 13 | """ 14 | Args: 15 | k (int): number of classes in the classification problem 16 | normalized (boolean): Determines whether or not the confusion matrix 17 | is normalized or not 18 | """ 19 | super(ConfusionMeter, self).__init__() 20 | self.conf = np.ndarray((k, k), dtype=np.int32) 21 | self.normalized = normalized 22 | self.k = k 23 | self.reset() 24 | 25 | def reset(self): 26 | self.conf.fill(0) 27 | 28 | def add(self, predicted, target): 29 | """ 30 | Computes the confusion matrix of K x K size where K is no of classes 31 | Args: 32 | predicted (tensor): Can be an N x K tensor of predicted scores obtained from 33 | the model for N examples and K classes or an N-tensor of 34 | integer values between 1 and K. 
35 |             target (tensor): Can be a N-tensor of integer values assumed to be integer
36 |                 values between 1 and K or N x K tensor, where targets are
37 |                 assumed to be provided as one-hot vectors
38 | 
39 |         """
40 |         predicted = predicted.cpu().squeeze().numpy()
41 |         target = target.cpu().squeeze().numpy()
42 | 
43 |         assert predicted.shape[0] == target.shape[0], \
44 |             'number of targets and predicted outputs do not match'
45 | 
46 |         if np.ndim(predicted) != 1:
47 |             assert predicted.shape[1] == self.k, \
48 |                 'number of predictions does not match size of confusion matrix'
49 |             predicted = np.argmax(predicted, 1)
50 |         else:
51 |             assert (predicted.max() < self.k) and (predicted.min() >= 0), \
52 |                 'predicted values are not between 0 and k-1'
53 | 
54 |         onehot_target = np.ndim(target) != 1
55 |         if onehot_target:
56 |             assert target.shape[1] == self.k, \
57 |                 'Onehot target does not match size of confusion matrix'
58 |             assert (target >= 0).all() and (target <= 1).all(), \
59 |                 'in one-hot encoding, target values should be 0 or 1'
60 |             assert (target.sum(1) == 1).all(), \
61 |                 'multi-label setting is not supported'
62 |             target = np.argmax(target, 1)
63 |         else:
64 |             assert (target.max() < self.k) and (target.min() >= 0), \
65 |                 'target values are not between 0 and k-1'
66 | 
67 |         # hack for bincounting 2 arrays together
68 |         x = predicted + self.k * target
69 |         bincount_2d = np.bincount(x.astype(np.int32),
70 |                                   minlength=self.k ** 2)
71 |         assert bincount_2d.size == self.k ** 2
72 |         conf = bincount_2d.reshape((self.k, self.k))
73 | 
74 |         self.conf += conf
75 | 
76 |     def value(self):
77 |         """
78 |         Returns:
79 |             Confusion matrix of K rows and K columns, where rows correspond
80 |             to ground-truth targets and columns correspond to predicted
81 |             targets.
82 |         """
83 |         if self.normalized:
84 |             conf = self.conf.astype(np.float32)
85 |             return conf / conf.sum(1).clip(min=1e-12)[:, None]
86 |         else:
87 |             return self.conf
88 | 
--------------------------------------------------------------------------------
/utils/meter/mapmeter.py:
--------------------------------------------------------------------------------
1 | import math
2 | from . import meter, APMeter
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | class mAPMeter(meter.Meter):
8 |     """
9 |     The mAPMeter measures the mean average precision over all classes.
10 | 
11 |     The mAPMeter is designed to operate on `NxK` Tensors `output` and
12 |     `target`, and optionally a `Nx1` Tensor weight where (1) the `output`
13 |     contains model output scores for `N` examples and `K` classes that ought to
14 |     be higher when the model is more convinced that the example should be
15 |     positively labeled, and smaller when the model believes the example should
16 |     be negatively labeled (for instance, the output of a sigmoid function); (2)
17 |     the `target` contains only values 0 (for negative examples) and 1
18 |     (for positive examples); and (3) the `weight` ( > 0) represents weight for
19 |     each sample.
20 | """ 21 | def __init__(self): 22 | super(mAPMeter, self).__init__() 23 | self.apmeter = APMeter() 24 | 25 | def reset(self): 26 | self.apmeter.reset() 27 | 28 | def add(self, output, target, weight=None): 29 | self.apmeter.add(output, target, weight) 30 | 31 | def value(self): 32 | return self.apmeter.value().mean() 33 | -------------------------------------------------------------------------------- /utils/meter/meter.py: -------------------------------------------------------------------------------- 1 | 2 | class Meter(object): 3 | def reset(self): 4 | pass 5 | 6 | def add(self): 7 | pass 8 | 9 | def value(self): 10 | pass 11 | -------------------------------------------------------------------------------- /utils/meter/movingaveragevaluemeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter 3 | import torch 4 | 5 | 6 | class MovingAverageValueMeter(meter.Meter): 7 | def __init__(self, windowsize): 8 | super(MovingAverageValueMeter, self).__init__() 9 | self.windowsize = windowsize 10 | self.valuequeue = torch.Tensor(windowsize) 11 | self.reset() 12 | 13 | def reset(self): 14 | self.sum = 0.0 15 | self.n = 0 16 | self.var = 0.0 17 | self.valuequeue.fill_(0) 18 | 19 | def add(self, value): 20 | queueid = (self.n % self.windowsize) 21 | oldvalue = self.valuequeue[queueid] 22 | self.sum += value - oldvalue 23 | self.var += value * value - oldvalue * oldvalue 24 | self.valuequeue[queueid] = value 25 | self.n += 1 26 | 27 | def value(self): 28 | n = min(self.n, self.windowsize) 29 | mean = self.sum / max(1, n) 30 | std = math.sqrt(max((self.var - n * mean * mean) / max(1, n-1), 0)) 31 | return mean, std 32 | 33 | -------------------------------------------------------------------------------- /utils/meter/msemeter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from . import meter 3 | import torch 4 | 5 | 6 | class MSEMeter(meter.Meter): 7 | def __init__(self, root=False): 8 | super(MSEMeter, self).__init__() 9 | self.reset() 10 | self.root = root 11 | 12 | def reset(self): 13 | self.n = 0 14 | self.sesum = 0.0 15 | 16 | def add(self, output, target): 17 | if not torch.is_tensor(output) and not torch.is_tensor(target): 18 | output = torch.from_numpy(output) 19 | target = torch.from_numpy(target) 20 | self.n += output.numel() 21 | self.sesum += torch.sum((output - target) ** 2) 22 | 23 | def value(self): 24 | mse = self.sesum / max(1, self.n) 25 | return math.sqrt(mse) if self.root else mse 26 | -------------------------------------------------------------------------------- /utils/meter/retrievalmeter.py: -------------------------------------------------------------------------------- 1 | from . 
import meter
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | class RetrievalMAPMeter(meter.Meter):
7 |     MAP = 0
8 |     PR = 1
9 | 
10 |     def __init__(self, topk=1000):
11 |         self.topk = topk
12 |         self.all_features = []
13 |         self.all_lbs = []
14 |         self.dis_mat = None
15 | 
16 | 
17 | 
18 |     def reset(self):
19 |         self.all_lbs.clear()
20 |         self.all_features.clear()
21 | 
22 |     def add(self, features, lbs):
23 |         self.all_features.append(features.cpu())
24 |         self.all_lbs.append(lbs.cpu())
25 | 
26 |     def value(self, mode=MAP):
27 |         if mode == self.MAP:
28 |             return self.mAP()
29 |         if mode == self.PR:
30 |             return self.pr()
31 |         raise NotImplementedError
32 | 
33 |     def mAP(self):
34 |         fts = torch.cat(self.all_features).numpy()
35 |         lbls = torch.cat(self.all_lbs).numpy()
36 |         self.dis_mat = Eu_dis_mat_fast(np.mat(fts))  # pairwise distances between all stored features
37 |         num = len(lbls)
38 |         mAP = 0
39 |         for i in range(num):
40 |             scores = self.dis_mat[:, i]
41 |             targets = (lbls == lbls[i]).astype(np.uint8)
42 |             sortind = np.argsort(scores, 0)[:self.topk]
43 |             truth = targets[sortind]
44 |             hits = 0
45 |             precision = []
46 |             for j in range(self.topk):
47 |                 if truth[j]:
48 |                     hits += 1
49 |                     precision.append(hits * 1.0 / (j + 1))
50 |             if len(precision) == 0:
51 |                 ap = 0
52 |             else:
53 |                 for ii in range(len(precision)):
54 |                     precision[ii] = max(precision[ii:])  # interpolate the precision curve
55 |                 ap = np.array(precision).mean()
56 |             mAP += ap
57 |             # print(f'{i+1}/{num}\tap:{ap:.3f}\t')
58 |         mAP = mAP / num
59 |         return mAP
60 | 
61 |     def pr(self):
62 |         lbls = torch.cat(self.all_lbs).numpy()
63 |         num = len(lbls)
64 |         precisions = []
65 |         recalls = []
66 |         ans = []
67 |         for i in range(num):
68 |             scores = self.dis_mat[:, i]
69 |             targets = (lbls == lbls[i]).astype(np.uint8)
70 |             sortind = np.argsort(scores, 0)[:self.topk]
71 |             truth = targets[sortind]
72 |             tmp = 0
73 |             n_rel = truth[:self.topk].sum()
74 |             precision = []
75 |             recall = []
76 |             for j in range(self.topk):
77 |                 if truth[j]:
78 |                     tmp += 1
79 |                 # precision.append(n_rel/(j + 1))
80 |                 recall.append(tmp * 1.0 / max(n_rel, 1))  # guard against queries with no relevant items
81 |                 precision.append(tmp * 1.0 / (j + 1))
82 |             precisions.append(precision)
83 |             for j in range(len(precision)):
84 |                 precision[j] = max(precision[j:])
85 |             recalls.append(recall)
86 |             tmp = []
87 |             for ii in range(11):
88 |                 min_des = 100
89 |                 val = 0
90 |                 for j in range(self.topk):
91 |                     if abs(recall[j] - ii * 0.1) < min_des:
92 |                         min_des = abs(recall[j] - ii * 0.1)
93 |                         val = precision[j]
94 |                 tmp.append(val)
95 |             print('%d/%d' % (i + 1, num))
96 |             ans.append(tmp)
97 |         return np.array(ans).mean(0)
98 | 
99 | 
100 | def Eu_dis_mat_fast(X):
101 |     aa = np.sum(np.multiply(X, X), 1)  # squared norms ||x_i||^2
102 |     ab = X * X.T  # Gram matrix of inner products x_i . x_j
103 |     D = aa + aa.T - 2 * ab  # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i.x_j
104 |     D[D < 0] = 0  # clamp tiny negatives from floating-point error
105 |     D = np.sqrt(D)
106 |     D = np.maximum(D, D.T)
107 |     return D
108 | 
--------------------------------------------------------------------------------
/utils/meter/timemeter.py:
--------------------------------------------------------------------------------
1 | import time
2 | from . import meter
3 | 
4 | class TimeMeter(meter.Meter):
5 |     """
6 | 
7 | 
8 | 
9 | 
10 |     The `tnt.TimeMeter` is designed to measure the time between events and can be
11 |     used to measure, for instance, the average processing time per batch of data.
12 |     It is different from most other meters in terms of the methods it provides:
13 | 
14 |     The `tnt.TimeMeter` provides the following methods:
15 | 
16 |     * `reset()` resets the timer, setting the timer and unit counter to zero.
17 |     * `value()` returns the time passed since the last `reset()`.
18 | """ 19 | def __init__(self, unit): 20 | super(TimeMeter, self).__init__() 21 | self.unit = unit 22 | self.reset() 23 | 24 | def reset(self): 25 | self.n = 0 26 | self.time = time.time() 27 | 28 | def value(self): 29 | return time.time() - self.time 30 | -------------------------------------------------------------------------------- /val_mvcnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import config 3 | from utils import meter 4 | from torch import nn 5 | from models import MVCNN 6 | from torch.utils.data import DataLoader 7 | from datasets import * 8 | 9 | 10 | def validate(val_loader, net): 11 | """ 12 | validation for one epoch on the val set 13 | """ 14 | batch_time = meter.TimeMeter(True) 15 | data_time = meter.TimeMeter(True) 16 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 17 | 18 | # testing mode 19 | net.eval() 20 | 21 | for i, (views, labels) in enumerate(val_loader): 22 | batch_time.reset() 23 | # bz x 12 x 3 x 224 x 224 24 | views = views.to(device=config.device) 25 | labels = labels.to(device=config.device) 26 | 27 | preds = net(views) # bz x C x H x W 28 | 29 | prec.add(preds.data, labels.data) 30 | 31 | if i % config.print_freq == 0: 32 | print(f'[{i}/{len(val_loader)}]\t' 33 | f'Batch Time {batch_time.value():.3f}\t' 34 | f'Epoch Time {data_time.value():.3f}\t' 35 | f'Prec@1 {prec.value(1):.3f}\t') 36 | 37 | print(f'mean class accuracy: {prec.value(1)} ') 38 | return prec.value(1) 39 | 40 | 41 | def main(): 42 | print('Training Process\nInitializing...\n') 43 | config.init_env() 44 | 45 | val_dataset = data_pth.view_data(config.view_net.data_root, 46 | status=STATUS_TEST, 47 | base_model_name=config.base_model_name) 48 | 49 | val_loader = DataLoader(val_dataset, batch_size=config.view_net.train.batch_sz, 50 | num_workers=config.num_workers,shuffle=False) 51 | 52 | 53 | # create model 54 | net = MVCNN() 55 | net = net.to(device=config.device) 56 | net = nn.DataParallel(net) 57 | 58 | print(f'loading pretrained model from {config.view_net.ckpt_load_file}') 59 | checkpoint = torch.load(config.view_net.ckpt_load_file) 60 | net.module.load_state_dict(checkpoint['model']) 61 | best_prec1 = checkpoint['best_prec1'] 62 | 63 | with torch.no_grad(): 64 | prec1 = validate(val_loader, net) 65 | 66 | print('curr accuracy: ', prec1) 67 | print('best accuracy: ', best_prec1) 68 | 69 | print('Train Finished!') 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | 75 | -------------------------------------------------------------------------------- /val_pc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import config 3 | import torch 4 | import os.path as osp 5 | from utils import meter 6 | from torch import nn 7 | from torch import optim 8 | from models import DGCNN 9 | from torch.utils.data import DataLoader 10 | from datasets import data_pth, STATUS_TRAIN, STATUS_TEST 11 | 12 | def validate(val_loader, net, epoch): 13 | """ 14 | validation for one epoch on the val set 15 | """ 16 | batch_time = meter.TimeMeter(True) 17 | data_time = meter.TimeMeter(True) 18 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 19 | 20 | # testing mode 21 | net.eval() 22 | 23 | for i, (pcs, labels) in enumerate(val_loader): 24 | batch_time.reset() 25 | # bz x 12 x 3 x 224 x 224 26 | pcs = pcs.to(device=config.device) 27 | labels = labels.to(device=config.device) 28 | 29 | preds = net(pcs) # bz x C x H x W 30 | 31 | prec.add(preds.data, 
labels.data) 32 | 33 | if i % config.print_freq == 0: 34 | print(f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t' 35 | f'Batch Time {batch_time.value():.3f}\t' 36 | f'Epoch Time {data_time.value():.3f}\t' 37 | f'Prec@1 {prec.value(1):.3f}\t') 38 | 39 | print(f'mean class accuracy at epoch {epoch}: {prec.value(1)} ') 40 | return prec.value(1) 41 | 42 | 43 | 44 | 45 | def main(): 46 | print('Training Process\nInitializing...\n') 47 | config.init_env() 48 | 49 | val_dataset = data_pth.pc_data(config.pc_net.data_root, status=STATUS_TEST) 50 | 51 | val_loader = DataLoader(val_dataset, batch_size=config.pc_net.validation.batch_sz, 52 | num_workers=config.num_workers,shuffle=True) 53 | 54 | # create model 55 | net = DGCNN(n_neighbor=config.pc_net.n_neighbor,num_classes=config.pc_net.num_classes) 56 | net = torch.nn.DataParallel(net) 57 | net = net.to(device=config.device) 58 | optimizer = optim.Adam(net.parameters(), config.pc_net.train.lr, 59 | weight_decay=config.pc_net.train.weight_decay) 60 | 61 | print(f'loading pretrained model from {config.pc_net.ckpt_load_file}') 62 | checkpoint = torch.load(config.pc_net.ckpt_load_file) 63 | net.module.load_state_dict(checkpoint['model']) 64 | optimizer.load_state_dict(checkpoint['optimizer']) 65 | best_prec1 = checkpoint['best_prec1'] 66 | resume_epoch = checkpoint['epoch'] 67 | 68 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.5) 69 | criterion = nn.CrossEntropyLoss() 70 | criterion = criterion.to(device=config.device) 71 | 72 | # for p in net.module.feature.parameters(): 73 | # p.requires_grad = False 74 | 75 | with torch.no_grad(): 76 | prec1 = validate(val_loader, net, resume_epoch) 77 | 78 | print('curr accuracy: ', prec1) 79 | print('best accuracy: ', best_prec1) 80 | 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | 86 | -------------------------------------------------------------------------------- /val_pvrnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import config 3 | import torch 4 | import os.path as osp 5 | from utils import meter 6 | from torch import nn 7 | from torch import optim 8 | from models import PVRNet 9 | from torch.utils.data import DataLoader 10 | from datasets import * 11 | import pdb 12 | 13 | def validate(val_loader, net, epoch): 14 | """ 15 | validation for one epoch on the val set 16 | """ 17 | batch_time = meter.TimeMeter(True) 18 | data_time = meter.TimeMeter(True) 19 | prec = meter.ClassErrorMeter(topk=[1], accuracy=True) 20 | retrieval_map = meter.RetrievalMAPMeter() 21 | 22 | # testing mode 23 | net.eval() 24 | 25 | total_seen_class = [0 for _ in range(40)] 26 | total_right_class = [0 for _ in range(40)] 27 | 28 | for i, (views, pcs, labels) in enumerate(val_loader): 29 | batch_time.reset() 30 | 31 | views = views.to(device=config.device) 32 | pcs = pcs.to(device=config.device) 33 | labels = labels.to(device=config.device) 34 | 35 | preds, fts = net(pcs, views, get_fea=True) # bz x C x H x W 36 | 37 | prec.add(preds.data, labels.data) 38 | retrieval_map.add(fts.detach()/torch.norm(fts.detach(), 2, 1, True), labels.detach()) 39 | for j in range(views.size(0)): 40 | total_seen_class[labels.data[j]] += 1 41 | total_right_class[labels.data[j]] += (np.argmax(preds.data,1)[j] == labels.cpu()[j]) 42 | 43 | 44 | if i % config.print_freq == 0: 45 | print(f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t' 46 | f'Batch Time {batch_time.value():.3f}\t' 47 | f'Epoch Time {data_time.value():.3f}\t' 48 | f'Prec@1 {prec.value(1):.3f}\t' 49 | 
f'Mean Class Accuracy {np.mean(np.array(total_right_class) / np.array(total_seen_class, dtype=float))}')
50 | 
51 |     mAP = retrieval_map.mAP()
52 |     print(f' instance accuracy at epoch {epoch}: {prec.value(1)} ')
53 |     print(f' mean class accuracy at epoch {epoch}: {np.mean(np.array(total_right_class) / np.array(total_seen_class, dtype=float))} ')
54 |     print(f' map at epoch {epoch}: {mAP} ')
55 |     return prec.value(1), mAP
56 | 
57 | 
58 | 
59 | 
60 | def main():
61 |     print('Validation Process\nInitializing...\n')
62 |     config.init_env()
63 | 
64 |     val_dataset = pc_view_data(config.pv_net.pc_root,
65 |                                config.pv_net.view_root,
66 |                                status=STATUS_TEST,
67 |                                base_model_name=config.base_model_name)
68 |     val_loader = DataLoader(val_dataset, batch_size=config.pv_net.train.batch_sz,
69 |                             num_workers=config.num_workers, shuffle=True)
70 | 
71 |     # create model
72 |     net = PVRNet()
73 |     net = torch.nn.DataParallel(net)
74 |     net = net.to(device=config.device)
75 |     optimizer_all = optim.SGD(net.parameters(), config.pv_net.train.all_lr,
76 |                               momentum=config.pv_net.train.momentum,
77 |                               weight_decay=config.pv_net.train.weight_decay)
78 | 
79 |     print(f'loading pretrained model from {config.pv_net.ckpt_load_file}')
80 |     checkpoint = torch.load(config.pv_net.ckpt_load_file)
81 |     state_dict = checkpoint['model']
82 |     net.module.load_state_dict(state_dict)
83 |     optimizer_all.load_state_dict(checkpoint['optimizer_all'])
84 |     best_prec1 = checkpoint['best_prec1']
85 |     resume_epoch = checkpoint['epoch']
86 | 
87 |     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_all, 5, 0.5)
88 |     criterion = nn.CrossEntropyLoss()
89 |     criterion = criterion.to(device=config.device)
90 | 
91 |     # for p in net.module.feature.parameters():
92 |     #     p.requires_grad = False
93 | 
94 |     with torch.no_grad():
95 |         prec1, Map = validate(val_loader, net, resume_epoch)
96 | 
97 |     print('curr accuracy: ', prec1)
98 |     print('best accuracy: ', best_prec1)
99 |     print('map: ', Map)
100 | 
101 | 
102 | if __name__ == '__main__':
103 |     main()
104 | 
105 | 
--------------------------------------------------------------------------------
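For completeness, a minimal single-shape inference sketch following val_pvrnet.py above: it assumes the config and checkpoint paths from this repository, and it feeds synthetic inputs, so the tensor shapes (number of points, 12 views at 224x224) are illustrative rather than prescribed:

```python
import torch
import config
from models import PVRNet

config.init_env()
net = PVRNet().to(device=config.device)

# The checkpoint was saved from net.module.state_dict(), so the keys carry
# no 'module.' prefix and load directly into a bare PVRNet.
ckpt = torch.load(config.pv_net.ckpt_load_file, map_location=config.device)
net.load_state_dict(ckpt['model'])
net.eval()

pcs = torch.randn(1, 1024, 3, device=config.device)            # one point cloud (point count illustrative)
views = torch.randn(1, 12, 3, 224, 224, device=config.device)  # its 12 rendered views
with torch.no_grad():
    preds = net(pcs, views)
print('predicted class:', preds.argmax(1).item())
```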