├── .ipynb_checkpoints └── train-checkpoint.py ├── LICENSE.md ├── README.md ├── RobustPointSet.png ├── data_utils ├── ModelNetDataLoader.py └── provider.py ├── models ├── pointnet.py ├── pointnet_cls.py └── pointnet_util.py ├── test.py └── train.py /.ipynb_checkpoints/train-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | import datetime 6 | import logging 7 | import shutil 8 | import importlib 9 | import numpy as np 10 | from tqdm import tqdm 11 | import multiprocessing 12 | from pathlib import Path 13 | from torch.optim.lr_scheduler import CosineAnnealingLR 14 | 15 | sys.path.append("..") 16 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 17 | from data_utils import provider 18 | 19 | BASE_DIR = os.path.dirname(os.path.abspath('__file__')) 20 | ROOT_DIR = BASE_DIR 21 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 22 | sys.path.append(os.path.join(ROOT_DIR, 'log')) 23 | 24 | 25 | 26 | 27 | def test(model, loader, num_class=40): 28 | mean_correct = [] 29 | class_acc = np.zeros((num_class,3)) 30 | for j, data in tqdm(enumerate(loader), total=len(loader)): 31 | points, target = data 32 | target = target[:, 0] 33 | points = points.transpose(2, 1) 34 | points, target = points.cuda(), target.cuda() 35 | classifier = model.eval() 36 | pred, _ = classifier(points.float()) 37 | pred_choice = pred.data.max(1)[1] 38 | for cat in np.unique(target.cpu()): 39 | cat = int(cat) 40 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum() 41 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0]) 42 | class_acc[cat,1]+=1 43 | correct = pred_choice.eq(target.long().data).cpu().sum() 44 | mean_correct.append(correct.item()/float(points.size()[0])) 45 | class_acc[:,2] = class_acc[:,0]/ class_acc[:,1] 46 | class_acc = np.mean(class_acc[:,2]) 47 | instance_acc = np.mean(mean_correct) 48 | return instance_acc, class_acc 49 | 50 | 51 | 52 | def main(args): 53 | def log_string(str): 54 | logger.info(str) 55 | print(str) 56 | 57 | ### Hyper Parameters ### 58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 59 | 60 | ### Create Dir ### 61 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))+'-'+args.model+'-'+args.task 62 | experiment_dir = Path('log/') 63 | experiment_dir.mkdir(exist_ok=True) 64 | experiment_dir = experiment_dir.joinpath('classification') 65 | experiment_dir.mkdir(exist_ok=True) 66 | if args.log_dir is None: 67 | experiment_dir = experiment_dir.joinpath(timestr) 68 | else: 69 | experiment_dir = experiment_dir.joinpath(args.log_dir) 70 | experiment_dir.mkdir(exist_ok=True) 71 | checkpoints_dir = experiment_dir.joinpath('checkpoints/') 72 | checkpoints_dir.mkdir(exist_ok=True) 73 | log_dir = experiment_dir.joinpath('logs/') 74 | log_dir.mkdir(exist_ok=True) 75 | 76 | ### Log ### 77 | logger = logging.getLogger("Model") 78 | logger.setLevel(logging.INFO) 79 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 80 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model)) 81 | file_handler.setLevel(logging.INFO) 82 | file_handler.setFormatter(formatter) 83 | logger.addHandler(file_handler) 84 | log_string('PARAMETER ...') 85 | log_string(args) 86 | 87 | 88 | ### Data Loading ### 89 | log_string('Load dataset ...') 90 | TRAIN_DATASET = ModelNetDataLoader(root=args.data_root, 91 | tasks=args.train_tasks, 92 | labels=args.train_labels, 93 | partition='train', 94 | npoint=args.num_point, 95 | normal_channel=args.normal) 96 | TEST_DATASET = ModelNetDataLoader(root=args.data_root, 97 | tasks=args.test_tasks, 98 | labels=args.test_labels, 99 | partition='test', 100 | npoint=args.num_point, 101 | normal_channel=args.normal) 102 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=0) 103 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=0) 104 | 105 | 106 | ### Model Loading ### 107 | num_class = 40 108 | MODEL = importlib.import_module(args.model) 109 | shutil.copy('models/%s.py' % args.model, str(experiment_dir)) 110 | # pointnet_util pointnet_util, 111 | shutil.copy('models/%s.py' % args.ults, str(experiment_dir)) 112 | 113 | classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda() 114 | criterion = MODEL.get_loss().cuda() 115 | 116 | try: 117 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 118 | start_epoch = checkpoint['epoch'] 119 | classifier.load_state_dict(checkpoint['model_state_dict']) 120 | log_string('Use pretrain model') 121 | except: 122 | log_string('No existing model, starting training from scratch...') 123 | start_epoch = 0 124 | 125 | if args.optimizer == 'Adam': 126 | optimizer = torch.optim.Adam( 127 | classifier.parameters(), 128 | lr=args.learning_rate, 129 | betas=(0.9, 0.999), 130 | eps=1e-08, 131 | weight_decay=args.decay_rate 132 | ) 133 | else: 134 | optimizer = torch.optim.SGD(classifier.parameters(), 135 | lr=0.01, 136 | momentum=0.9) 137 | 138 | 139 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay_rate) 140 | 141 | global_epoch = 0 142 | global_step = 0 143 | best_instance_acc = 0.0 144 | best_class_acc = 0.0 145 | mean_correct = [] 146 | 147 | 148 | ### Training ### 149 | logger.info('Start training ...') 150 | for epoch in range(start_epoch, args.epoch): 151 | 152 | 153 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) 154 | 155 | for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): 156 | points, target = data 157 | 158 | points = points.data.numpy() 159 | 160 | points = torch.Tensor(points) 161 | target = target[:, 0] 162 | 163 | points = points.transpose(2, 1) 164 | points, target = points.cuda(), target.cuda() 165 | optimizer.zero_grad() 166 | 167 | classifier = classifier.train() 168 | pred, trans_feat = classifier(points) 169 | loss = criterion(pred, target.long(), trans_feat) 170 | pred_choice = pred.data.max(1)[1] 171 | correct = pred_choice.eq(target.long().data).cpu().sum() 172 | mean_correct.append(correct.item() / float(points.size()[0])) 173 | loss.backward() 174 | optimizer.step() 175 | global_step += 1 176 | 177 | train_instance_acc = np.mean(mean_correct) 178 | log_string('Train Instance Accuracy: %f' % train_instance_acc) 179 | 180 | 181 | if (epoch % 2 == 0) or (epoch == args.epoch): 182 | with torch.no_grad(): 183 | instance_acc, class_acc = test(classifier.eval(), testDataLoader) 184 | 185 | if (instance_acc >= best_instance_acc): 186 | best_instance_acc = instance_acc 187 | best_epoch = epoch + 1 188 | 189 | if (class_acc >= best_class_acc): 190 | best_class_acc = class_acc 191 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc)) 192 | log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc)) 193 | 194 | if (instance_acc >= best_instance_acc): 195 | logger.info('Save model...') 196 | savepath = str(checkpoints_dir) + '/best_model.pth' 197 | log_string('Saving at %s'% savepath) 198 | state = { 199 | 'epoch': best_epoch, 200 | 'instance_acc': instance_acc, 201 | 'class_acc': class_acc, 202 | 'model_state_dict': classifier.state_dict(), 203 | 'optimizer_state_dict': optimizer.state_dict(), 204 | } 205 | torch.save(state, savepath) 206 | global_epoch += 1 207 | 208 | # adjust lr 209 | scheduler.step() 210 | 211 | logger.info('End of training...') 212 | 213 | 214 | 215 | 216 | 217 | def parseArgs(): 218 | """ Argument parser for configuring model training """ 219 | parser = argparse.ArgumentParser(description='RobustPointSet trainer') 220 | parser.add_argument('--batch_size', type=int, default=32) 221 | parser.add_argument('--model', type=str, default='pointnet_cls', help='point cloud model') 222 | parser.add_argument('--ults', type=str, default='pointnet_util', help='help functions for point cloud model') 223 | parser.add_argument('--task', type=str, default='s1', help='Stragety 1 or 2') 224 | parser.add_argument('--epoch', type=int, default=300, help='training epoch') 225 | parser.add_argument('--learning_rate', type=float, default=0.001) 226 | parser.add_argument('--gpu', type=str, default='0', help='gpu device index') 227 | parser.add_argument('--num_point', type=int, default=2048, help='number of points') 228 | parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer') 229 | parser.add_argument('--log_dir', type=str, default=None, help='directory to store training log') 230 | parser.add_argument('--decay_rate', type=float, default=1e-4) 231 | parser.add_argument('--lr_decay_rate', type=float, default=0.5) 232 | parser.add_argument('--lr_decay_step', type=int, default=100) 233 | parser.add_argument('--lr_clip', type=float, default=1e-7) 234 | parser.add_argument('--normal', default=False, help='Whether to use normal information [default: False]') 235 | parser.add_argument('--data_root', type=str, default='data/', help='data directory') 236 | parser.add_argument('--train_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be trained on") 237 | parser.add_argument('--test_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be tested on during training") 238 | return parser.parse_args() 239 | 240 | 241 | 242 | if __name__ == '__main__': 243 | """ 244 | Available data: original, jitter, translate, missing_part, sparse, rotation, occlusion 245 | Example command for strategy 1: 246 | python train.py --train_tasks train_original.npy --test_tasks test_original.npy 247 | python train.py --train_tasks train_original.npy --test_tasks test_sparse.npy 248 | Example command for strategy 2: 249 | python train.py --train_tasks train_original.npy train_jitter.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy --test_tasks test_occlusion.npy 250 | python train.py --train_tasks train_original.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy train_occlusion.npy --test_tasks test_jitter.npy 251 | """ 252 | args = parseArgs() 253 | args.train_labels = ['train_labels.npy']*len(args.train_tasks) 254 | args.test_labels =['test_labels.npy']*len(args.test_tasks) 255 | main(args) -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # Robust Point Set 2 | 3 | This dataset is provided for the convenience of academic research only, and is provided without any representations or warranties, including warranties of non-infringement or fitness for a particular purpose. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RobustPointSet: A Dataset for Benchmarking Robustness of Point Cloud Classifiers, RobustML - ICLR 2021 2 | _Robust and Reliable Machine Learning in the Real World Workshop (RobustML), ICLR 2021_ 3 | 4 | A benchmark dataset to facilitate augmentation-independent robustness analysis of point cloud classifiers. RobustPointSet comes with 6 different transformations: Noise, Translation, Missing part, Sparse, Rotation, and Occlusion. 5 | 6 |
7 | 8 |
9 | 10 | 11 | ### Evaluation Strategies 12 | 13 | We test two different evaluation strategies on more than 10 models: 14 | 15 | #### Strategy 1 (training-domain validation) 16 | For this strategy, we train on `train_original.npy` without applying any data-augmentation, select the best performing model on `test_original.npy` (to be consistent with baselines), and test on each test set (i.e. `test_*.npy` ) separately. 17 | 18 | #### Strategy 2 (leave-one-out validation) 19 | For this strategy, each time we concatenate 6 train sets (i.e. the `train_*.npy` ones), and test on the test set (i.e. `test_*.npy` ) of the taken-out group. We repeat this process for all the groups. For example, we train with concatenation of `{train_original.npy, train_noise.npy, train_missing_part.npy, train_occlusion.npy, train_rotation.npy, train_sparse.npy}` and test on `test_translate.npy`. Similar to strategy 1, we don't apply any data-augmentation here. For both the strategies, the same label files can be used i.e. `labels_train.npy` and `labels_test.npy`. 20 | 21 | ----------------- 22 | 23 | ### Benchmarks 24 | 25 | Table 1: Training-domain validation results on our RobustPointSet test sets. The *Noise* column, for example, shows the result of training on the *Original* train set and testing with the *Noise* test set. RotInv refers to rotation-invariant models. 26 | 27 | 28 | 29 | | Type | Method | Original | Noise | Translation | Missing part | Sparse | Rotation | Occlusion | Average | 30 | |:----:|:--------------------|:---------:|:---------:|:-----------:|:------------:|:---------:|:--------:|:---------:|:---------:| 31 | |General | PointNet | 89.06 | **74.72** | 79.66 | 81.52 | **60.53** | 8.83 | 39.47 | **61.97** | 32 | |General | PointNet++ (MSG) | 91.27 | 5.73 | 91.31 | 53.69 | 6.65 | 13.02 | 64.18 | 46.55 | 33 | |General | PointNet++ (SSG) | 91.47 | 14.90 | 91.07 | 50.24 | 8.85 | 12.70 | 70.23 | 48.49 | 34 | |General | DGCNN | **92.52** | 57.56 | **91.99** | 85.40 | 9.34 | 13.43 | **78.72** | 61.28 | 35 | |General | PointMask | 88.53 | 73.14 | 78.20 | 81.48 | 58.23 | 8.02 | 39.18 | 60.97 | 36 | |General | DensePoint | 90.96 | 53.28 | 90.72 | 84.49 | 15.52 | 12.76 | 67.67 | 59.40 | 37 | |General | PointCNN | 87.66 | 45.55 | 82.85 | 77.60 | 4.01 | 11.50 | 59.50 | 52.67 | 38 | |General | PointConv | 91.15 | 20.71 | 90.99 | 84.09 | 8.65 | 12.38 | 45.83 | 50.54 | 39 | |General | Relation-Shape-CNN | 91.77 | 48.06 | 91.29 | **85.98** | 23.18 | 11.51 | 75.61 | 61.06 | 40 | |RotInv | SPHnet | 79.18 | 7.22 | **79.18** | 4.22 | 1.26 | 79.18 | 34.33 | 40.65 | 41 | |RotInv | PRIN | 73.66 | 30.19 | 41.21 | 44.17 | 4.17 | 68.56 | 31.56 | 41.93 | 42 | 43 | 44 | ---------------- 45 | 46 | Table 2: Leave-one-out validation strategy classification results on our RobustPointSet test sets. For example, the *Noise* column shows the result of training on *{Original, Translation, Missing part, Sparse, Rotation,Occlusion}* train sets and testing with the *Noise* test set. RotInv refers to rotation-invariant models. 47 | 48 | 49 | | Type | Method | Original | Noise | Translation | Missing part | Sparse | Rotation | Occlusion | Average | 50 | |:----:|:--------------------|:---------:|:---------:|:-----------:|:------------:|:---------:|:---------:|:---------:|:---------:| 51 | |General | PointNet | 88.35 | 72.61 | 81.53 | 82.87 | **69.28** | 9.42 | 35.96 | **62.86** | 52 | |General | PointNet++ (MSG) | 91.55 | 50.92 | 91.43 | 77.16 | 16.19 | 12.26 | **70.39** | 58.56 | 53 | |General | PointNet++ (SSG) | 91.76 | 49.33 | 91.10 | 78.36 | 16.72 | 11.27 | 68.33 | 58.12 | 54 | |General | DGCNN | **92.38** | 66.95 | 91.17 | 85.40 | 6.49 | 14.03 | 68.79 | 60.74 | 55 | |General | PointMask | 88.03 | **73.95** | 80.80 | 82.83 | 63.64 | 8.97 | 36.69 | 62.13 | 56 | |General | DensePoint | 91.00 | 42.38 | 90.64 | 85.70 | 20.66 | 8.55 | 47.89 | 55.26 | 57 | |General | PointCNN | 88.91 | 73.10 | 87.46 | 82.06 | 7.18 | 13.95 | 52.66 | 57.90 | 58 | |General | PointConv | 91.07 | 66.19 | **91.51** | 84.01 | 19.63 | 11.62 | 44.07 | 58.30 | 59 | |General | Relation-Shape-CNN | 90.52 | 36.95 | 91.33 | **85.82** | 24.59 | 8.23 | 60.09 | 56.79 | 60 | |RotInv | SPHnet | 79.30 | 8.24 | 76.02 | 17.94 | 6.33 | **78.86** | 35.96 | 43.23 | 61 | |RotInv | PRIN | 76.54 | 55.35 | 56.36 | 59.20 | 4.05 | 73.30 | 36.91 | 51.67 | 62 | 63 | 64 | ----------------- 65 | 66 | ### Sample train and test codes 67 | 68 | The `trian.py` and `test.py` files are sample codes to train and test PointNet. The data loaders etc. can be used to train other models. Codes are adopted from [here](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). 69 | 70 | ----------------- 71 | 72 | ### Publication 73 | 74 | Please cite the paper below if you use RobustPointSet in your research. 75 | 76 | [RobustPointSet: A Dataset for Benchmarking Robustness of Point Cloud Classifiers](https://arxiv.org/abs/2011.11572) 77 | 78 | ``` 79 | @article{taghanaki2020robustpointset, 80 | title={RobustPointSet: A Dataset for Benchmarking Robustness of Point Cloud Classifiers}, 81 | author={{Asgari Taghanaki}, Saeid and Luo, Jieliang and Zhang, Ran and Wang, Ye and Jayaraman, {Pradeep Kumar} and Jatavallabhula, {Krishna Murthy}}, 82 | year={2020}, 83 | journal={arXiv preprint arXiv:2011.11572} 84 | } 85 | ``` 86 | ----------------- 87 | ### Download 88 | The dateset consists of two parts: [Part I](https://github.com/AutodeskAILab/RobustPointSet/releases/download/v1.0/RobustPointSet.z01) and [Part II](https://github.com/AutodeskAILab/RobustPointSet/releases/download/v1.0/RobustPointSet.zip). Please download both parts and unzip Part I, which will automatically extract the two parts into the same folder. 89 | 90 | ----------------- 91 | ### Reference 92 | We use the following implementations with minor modifications for our evaluations. 93 | 94 | * PointNet: [https://github.com/TianzhongSong/PointNet-Keras](https://github.com/TianzhongSong/PointNet-Keras) 95 | * PointNet++: [https://github.com/yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch) 96 | * DGCNN: [https://github.com/WangYueFt/dgcnn](https://github.com/WangYueFt/dgcnn) 97 | * DensePoint: [https://github.com/Yochengliu/DensePoint](https://github.com/Yochengliu/DensePoint) 98 | * PointCNN: [https://github.com/hxdengBerkeley/PointCNN.Pytorch](https://github.com/hxdengBerkeley/PointCNN.Pytorch) 99 | * PointConv: [https://github.com/DylanWusee/pointconv_pytorch](https://github.com/DylanWusee/pointconv_pytorch) 100 | * Relation-Shape-CNN: [https://github.com/Yochengliu/Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN) 101 | * SPHnet: [https://github.com/adrienPoulenard/SPHnet](https://github.com/adrienPoulenard/SPHnet) 102 | * PRIN: [https://github.com/qq456cvb/PRIN](https://github.com/qq456cvb/PRIN) 103 | * PointMask: [https://github.com/asgsaeid/PointMask](https://github.com/asgsaeid/PointMask) 104 | 105 | Note: 106 | For DensePoint, Relation-Shape-CNN, and PRIN, you will need to run the code on an older version (0.3 & 0.4) of PyTorch. Below are the steps to create a Conda envrionment with an older version of PyTorch: 107 | 108 | ``` 109 | - wget -c https://repo.continuum.io/archive/Anaconda3-5.2.0-Linux-x86_64.sh 110 | - chmod +x Anaconda3-5.2.0-Linux-x86_64.sh 111 | - ./Anaconda3-5.2.0-Linux-x86_64.sh 112 | - conda create -n myenv python=3.6 113 | - conda install pytorch=0.4.1 cuda92 -c pytorch -y 114 | ``` 115 | 116 | ----------------- 117 | ### License 118 | 119 | Please refer to the [dataset license](https://github.com/AutodeskAILab/RobustPointSet/blob/main/LICENSE.md). 120 | 121 | 122 | -------------------------------------------------------------------------------- /RobustPointSet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutodeskAILab/RobustPointSet/f5439a242da5076d2ca632654a3236ec93a2f197/RobustPointSet.png -------------------------------------------------------------------------------- /data_utils/ModelNetDataLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | 6 | def pc_normalize(pc): 7 | centroid = np.mean(pc, axis=0) 8 | pc = pc - centroid 9 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 10 | pc = pc / m 11 | return pc 12 | 13 | def farthest_point_sample(point, npoint): 14 | """ 15 | Input: 16 | xyz: pointcloud data, [N, D] 17 | npoint: number of samples 18 | Return: 19 | centroids: sampled pointcloud index, [npoint, D] 20 | """ 21 | N, D = point.shape 22 | xyz = point[:,:3] 23 | centroids = np.zeros((npoint,)) 24 | distance = np.ones((N,)) * 1e10 25 | farthest = np.random.randint(0, N) 26 | for i in range(npoint): 27 | centroids[i] = farthest 28 | centroid = xyz[farthest, :] 29 | dist = np.sum((xyz - centroid) ** 2, -1) 30 | mask = dist < distance 31 | distance[mask] = dist[mask] 32 | farthest = np.argmax(distance, -1) 33 | point = point[centroids.astype(np.int32)] 34 | return point 35 | 36 | 37 | 38 | class ModelNetDataLoader(Dataset): 39 | def __init__(self, root, tasks, labels, partition='train', npoint=2048, uniform=False, normal_channel=False, cache_size=15000): 40 | self.root = root 41 | self.npoints = npoint 42 | self.uniform = uniform 43 | self.normal_channel = normal_channel 44 | self.data, self.label = load_data(root, tasks, labels) 45 | self.partition = partition 46 | print('The number of ' + partition + ' data: ' + str(self.data.shape[0])) 47 | 48 | def __len__(self): 49 | return self.data.shape[0] 50 | 51 | def _get_item(self, index): 52 | pointcloud = self.data[index][:self.npoints] 53 | label = self.label[index] 54 | if self.partition == 'train': 55 | np.random.shuffle(pointcloud) 56 | 57 | return pointcloud, label 58 | 59 | def __getitem__(self, index): 60 | return self._get_item(index) 61 | 62 | 63 | 64 | def load_data(root, tasks, labels): 65 | all_data = [] 66 | all_label = [] 67 | for i in range(len(tasks)): 68 | data = np.load(os.path.join(root, tasks[i])) 69 | label = np.load(os.path.join(root, labels[i])) 70 | 71 | all_data.append(data) 72 | all_label.append(label) 73 | all_data = np.concatenate(all_data, axis=0) 74 | all_label = np.concatenate(all_label, axis=0) 75 | 76 | 77 | 78 | 79 | return all_data, all_label -------------------------------------------------------------------------------- /data_utils/provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize the batch data, use coordinates of the block centered at origin, 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | label: B,... numpy array 27 | Return: 28 | shuffled data, label and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | def shuffle_points(batch_data): 35 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 36 | Use the same shuffling idx for the entire batch. 37 | Input: 38 | BxNxC array 39 | Output: 40 | BxNxC array 41 | """ 42 | idx = np.arange(batch_data.shape[1]) 43 | np.random.shuffle(idx) 44 | return batch_data[:,idx,:] 45 | 46 | def rotate_point_cloud(batch_data): 47 | """ Randomly rotate the point clouds to augument the dataset 48 | rotation is per shape based along up direction 49 | Input: 50 | BxNx3 array, original batch of point clouds 51 | Return: 52 | BxNx3 array, rotated batch of point clouds 53 | """ 54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 55 | for k in range(batch_data.shape[0]): 56 | rotation_angle = np.random.uniform() * 2 * np.pi 57 | cosval = np.cos(rotation_angle) 58 | sinval = np.sin(rotation_angle) 59 | rotation_matrix = np.array([[cosval, 0, sinval], 60 | [0, 1, 0], 61 | [-sinval, 0, cosval]]) 62 | shape_pc = batch_data[k, ...] 63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 64 | return rotated_data 65 | 66 | def rotate_point_cloud_z(batch_data): 67 | """ Randomly rotate the point clouds to augument the dataset 68 | rotation is per shape based along up direction 69 | Input: 70 | BxNx3 array, original batch of point clouds 71 | Return: 72 | BxNx3 array, rotated batch of point clouds 73 | """ 74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 75 | for k in range(batch_data.shape[0]): 76 | rotation_angle = np.random.uniform() * 2 * np.pi 77 | cosval = np.cos(rotation_angle) 78 | sinval = np.sin(rotation_angle) 79 | rotation_matrix = np.array([[cosval, sinval, 0], 80 | [-sinval, cosval, 0], 81 | [0, 0, 1]]) 82 | shape_pc = batch_data[k, ...] 83 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 84 | return rotated_data 85 | 86 | def rotate_point_cloud_with_normal(batch_xyz_normal): 87 | ''' Randomly rotate XYZ, normal point cloud. 88 | Input: 89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal 90 | Output: 91 | B,N,6, rotated XYZ, normal point cloud 92 | ''' 93 | for k in range(batch_xyz_normal.shape[0]): 94 | rotation_angle = np.random.uniform() * 2 * np.pi 95 | cosval = np.cos(rotation_angle) 96 | sinval = np.sin(rotation_angle) 97 | rotation_matrix = np.array([[cosval, 0, sinval], 98 | [0, 1, 0], 99 | [-sinval, 0, cosval]]) 100 | shape_pc = batch_xyz_normal[k,:,0:3] 101 | shape_normal = batch_xyz_normal[k,:,3:6] 102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 104 | return batch_xyz_normal 105 | 106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 107 | """ Randomly perturb the point clouds by small rotations 108 | Input: 109 | BxNx6 array, original batch of point clouds and point normals 110 | Return: 111 | BxNx3 array, rotated batch of point clouds 112 | """ 113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 114 | for k in range(batch_data.shape[0]): 115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 116 | Rx = np.array([[1,0,0], 117 | [0,np.cos(angles[0]),-np.sin(angles[0])], 118 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 120 | [0,1,0], 121 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 123 | [np.sin(angles[2]),np.cos(angles[2]),0], 124 | [0,0,1]]) 125 | R = np.dot(Rz, np.dot(Ry,Rx)) 126 | shape_pc = batch_data[k,:,0:3] 127 | shape_normal = batch_data[k,:,3:6] 128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 129 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 130 | return rotated_data 131 | 132 | 133 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 134 | """ Rotate the point cloud along up direction with certain angle. 135 | Input: 136 | BxNx3 array, original batch of point clouds 137 | Return: 138 | BxNx3 array, rotated batch of point clouds 139 | """ 140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 141 | for k in range(batch_data.shape[0]): 142 | #rotation_angle = np.random.uniform() * 2 * np.pi 143 | cosval = np.cos(rotation_angle) 144 | sinval = np.sin(rotation_angle) 145 | rotation_matrix = np.array([[cosval, 0, sinval], 146 | [0, 1, 0], 147 | [-sinval, 0, cosval]]) 148 | shape_pc = batch_data[k,:,0:3] 149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 150 | return rotated_data 151 | 152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 153 | """ Rotate the point cloud along up direction with certain angle. 154 | Input: 155 | BxNx6 array, original batch of point clouds with normal 156 | scalar, angle of rotation 157 | Return: 158 | BxNx6 array, rotated batch of point clouds iwth normal 159 | """ 160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 161 | for k in range(batch_data.shape[0]): 162 | #rotation_angle = np.random.uniform() * 2 * np.pi 163 | cosval = np.cos(rotation_angle) 164 | sinval = np.sin(rotation_angle) 165 | rotation_matrix = np.array([[cosval, 0, sinval], 166 | [0, 1, 0], 167 | [-sinval, 0, cosval]]) 168 | shape_pc = batch_data[k,:,0:3] 169 | shape_normal = batch_data[k,:,3:6] 170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) 172 | return rotated_data 173 | 174 | 175 | 176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 177 | """ Randomly perturb the point clouds by small rotations 178 | Input: 179 | BxNx3 array, original batch of point clouds 180 | Return: 181 | BxNx3 array, rotated batch of point clouds 182 | """ 183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 184 | for k in range(batch_data.shape[0]): 185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 186 | Rx = np.array([[1,0,0], 187 | [0,np.cos(angles[0]),-np.sin(angles[0])], 188 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 190 | [0,1,0], 191 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 193 | [np.sin(angles[2]),np.cos(angles[2]),0], 194 | [0,0,1]]) 195 | R = np.dot(Rz, np.dot(Ry,Rx)) 196 | shape_pc = batch_data[k, ...] 197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 198 | return rotated_data 199 | 200 | 201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 202 | """ Randomly jitter points. jittering is per point. 203 | Input: 204 | BxNx3 array, original batch of point clouds 205 | Return: 206 | BxNx3 array, jittered batch of point clouds 207 | """ 208 | B, N, C = batch_data.shape 209 | assert(clip > 0) 210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 211 | jittered_data += batch_data 212 | return jittered_data 213 | 214 | def shift_point_cloud(batch_data, shift_range=0.1): 215 | """ Randomly shift point cloud. Shift is per point cloud. 216 | Input: 217 | BxNx3 array, original batch of point clouds 218 | Return: 219 | BxNx3 array, shifted batch of point clouds 220 | """ 221 | B, N, C = batch_data.shape 222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 223 | for batch_index in range(B): 224 | batch_data[batch_index,:,:] += shifts[batch_index,:] 225 | return batch_data 226 | 227 | 228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 229 | """ Randomly scale the point cloud. Scale is per point cloud. 230 | Input: 231 | BxNx3 array, original batch of point clouds 232 | Return: 233 | BxNx3 array, scaled batch of point clouds 234 | """ 235 | B, N, C = batch_data.shape 236 | scales = np.random.uniform(scale_low, scale_high, B) 237 | for batch_index in range(B): 238 | batch_data[batch_index,:,:] *= scales[batch_index] 239 | return batch_data 240 | 241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 242 | ''' batch_pc: BxNx3 ''' 243 | for b in range(batch_pc.shape[0]): 244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 246 | if len(drop_idx)>0: 247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 248 | return batch_pc -------------------------------------------------------------------------------- /models/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | class STN3d(nn.Module): 10 | def __init__(self, channel): 11 | super(STN3d, self).__init__() 12 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 13 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 14 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 15 | self.fc1 = nn.Linear(1024, 512) 16 | self.fc2 = nn.Linear(512, 256) 17 | self.fc3 = nn.Linear(256, 9) 18 | self.relu = nn.ReLU() 19 | 20 | self.bn1 = nn.BatchNorm1d(64) 21 | self.bn2 = nn.BatchNorm1d(128) 22 | self.bn3 = nn.BatchNorm1d(1024) 23 | self.bn4 = nn.BatchNorm1d(512) 24 | self.bn5 = nn.BatchNorm1d(256) 25 | 26 | def forward(self, x): 27 | batchsize = x.size()[0] 28 | x = F.relu(self.bn1(self.conv1(x))) 29 | x = F.relu(self.bn2(self.conv2(x))) 30 | x = F.relu(self.bn3(self.conv3(x))) 31 | x = torch.max(x, 2, keepdim=True)[0] 32 | x = x.view(-1, 1024) 33 | 34 | x = F.relu(self.bn4(self.fc1(x))) 35 | x = F.relu(self.bn5(self.fc2(x))) 36 | x = self.fc3(x) 37 | 38 | iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat( 39 | batchsize, 1) 40 | if x.is_cuda: 41 | iden = iden.cuda() 42 | x = x + iden 43 | x = x.view(-1, 3, 3) 44 | return x 45 | 46 | 47 | 48 | 49 | class STNkd(nn.Module): 50 | def __init__(self, k=64): 51 | super(STNkd, self).__init__() 52 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 53 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 54 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 55 | self.fc1 = nn.Linear(1024, 512) 56 | self.fc2 = nn.Linear(512, 256) 57 | self.fc3 = nn.Linear(256, k * k) 58 | self.relu = nn.ReLU() 59 | 60 | self.bn1 = nn.BatchNorm1d(64) 61 | self.bn2 = nn.BatchNorm1d(128) 62 | self.bn3 = nn.BatchNorm1d(1024) 63 | self.bn4 = nn.BatchNorm1d(512) 64 | self.bn5 = nn.BatchNorm1d(256) 65 | 66 | self.k = k 67 | 68 | def forward(self, x): 69 | batchsize = x.size()[0] 70 | x = F.relu(self.bn1(self.conv1(x))) 71 | x = F.relu(self.bn2(self.conv2(x))) 72 | x = F.relu(self.bn3(self.conv3(x))) 73 | x = torch.max(x, 2, keepdim=True)[0] 74 | x = x.view(-1, 1024) 75 | 76 | x = F.relu(self.bn4(self.fc1(x))) 77 | x = F.relu(self.bn5(self.fc2(x))) 78 | x = self.fc3(x) 79 | 80 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat( 81 | batchsize, 1) 82 | if x.is_cuda: 83 | iden = iden.cuda() 84 | x = x + iden 85 | x = x.view(-1, self.k, self.k) 86 | return x 87 | 88 | 89 | 90 | class PointNetEncoder(nn.Module): 91 | def __init__(self, global_feat=True, feature_transform=False, channel=3): 92 | super(PointNetEncoder, self).__init__() 93 | self.stn = STN3d(channel) 94 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 95 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 96 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 97 | self.bn1 = nn.BatchNorm1d(64) 98 | self.bn2 = nn.BatchNorm1d(128) 99 | self.bn3 = nn.BatchNorm1d(1024) 100 | self.global_feat = global_feat 101 | self.feature_transform = feature_transform 102 | if self.feature_transform: 103 | self.fstn = STNkd(k=64) 104 | 105 | def forward(self, x): 106 | B, D, N = x.size() 107 | trans = self.stn(x) 108 | x = x.transpose(2, 1) 109 | if D > 3 : 110 | x, feature = x.split(3,dim=2) 111 | x = torch.bmm(x, trans) 112 | if D > 3: 113 | x = torch.cat([x,feature],dim=2) 114 | x = x.transpose(2, 1) 115 | x = F.relu(self.bn1(self.conv1(x))) 116 | 117 | if self.feature_transform: 118 | trans_feat = self.fstn(x) 119 | x = x.transpose(2, 1) 120 | x = torch.bmm(x, trans_feat) 121 | x = x.transpose(2, 1) 122 | else: 123 | trans_feat = None 124 | 125 | pointfeat = x 126 | x = F.relu(self.bn2(self.conv2(x))) 127 | x = F.relu(self.bn3(self.conv3(x))) 128 | x = torch.max(x, 2, keepdim=True)[0] 129 | x = x.view(-1, 1024) 130 | if self.global_feat: 131 | return x, trans, trans_feat 132 | else: 133 | x = x.view(-1, 1024, 1).repeat(1, 1, N) 134 | return torch.cat([x, pointfeat], 1), trans, trans_feat 135 | 136 | 137 | 138 | 139 | 140 | def feature_transform_reguliarzer(trans): 141 | d = trans.size()[1] 142 | I = torch.eye(d)[None, :, :] 143 | if trans.is_cuda: 144 | I = I.cuda() 145 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1) - I), dim=(1, 2), p=2)) 146 | return loss -------------------------------------------------------------------------------- /models/pointnet_cls.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.data 3 | import torch.nn.functional as F 4 | from pointnet import PointNetEncoder, feature_transform_reguliarzer 5 | 6 | 7 | class get_model(nn.Module): 8 | def __init__(self, k=40, normal_channel=True): 9 | super(get_model, self).__init__() 10 | if normal_channel: 11 | channel = 6 12 | else: 13 | channel = 3 14 | self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=channel) 15 | self.fc1 = nn.Linear(1024, 512) 16 | self.fc2 = nn.Linear(512, 256) 17 | self.fc3 = nn.Linear(256, k) 18 | self.dropout = nn.Dropout(p=0.4) 19 | self.bn1 = nn.BatchNorm1d(512) 20 | self.bn2 = nn.BatchNorm1d(256) 21 | self.relu = nn.ReLU() 22 | 23 | def forward(self, x): 24 | x, trans, trans_feat = self.feat(x) 25 | x = F.relu(self.bn1(self.fc1(x))) 26 | x = F.relu(self.bn2(self.dropout(self.fc2(x)))) 27 | x = self.fc3(x) 28 | x = F.log_softmax(x, dim=1) 29 | return x, trans_feat 30 | 31 | 32 | class get_loss(torch.nn.Module): 33 | def __init__(self, mat_diff_loss_scale=0.001): 34 | super(get_loss, self).__init__() 35 | self.mat_diff_loss_scale = mat_diff_loss_scale 36 | 37 | def forward(self, pred, target, trans_feat): 38 | loss = F.nll_loss(pred, target) 39 | mat_diff_loss = feature_transform_reguliarzer(trans_feat) 40 | 41 | total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale 42 | return total_loss -------------------------------------------------------------------------------- /models/pointnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | 8 | def timeit(tag, t): 9 | print("{}: {}s".format(tag, time() - t)) 10 | return time() 11 | 12 | 13 | def pc_normalize(pc): 14 | """ 15 | normalize point clouds into a unit sphere 16 | """ 17 | l = pc.shape[0] 18 | centroid = np.mean(pc, axis=0) 19 | pc = pc - centroid 20 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 21 | pc = pc / m 22 | return pc 23 | 24 | 25 | def square_distance(src, dst): 26 | """ 27 | Calculate Euclid distance between each two points. 28 | 29 | src^T * dst = xn * xm + yn * ym + zn * zm; 30 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 31 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 32 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 33 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 34 | Input: 35 | src: source points, [B, N, C] 36 | dst: target points, [B, M, C] 37 | Output: 38 | dist: per-point square distance, [B, N, M] 39 | """ 40 | B, N, _ = src.shape 41 | _, M, _ = dst.shape 42 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 43 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 44 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 45 | return dist 46 | 47 | 48 | def index_points(points, idx): 49 | """ 50 | Input: 51 | points: input points data, [B, N, C] 52 | idx: sample index data, [B, S] 53 | Return: 54 | new_points:, indexed points data, [B, S, C] 55 | """ 56 | device = points.device 57 | B = points.shape[0] 58 | view_shape = list(idx.shape) 59 | view_shape[1:] = [1] * (len(view_shape) - 1) 60 | repeat_shape = list(idx.shape) 61 | repeat_shape[0] = 1 62 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 63 | new_points = points[batch_indices, idx, :] 64 | return new_points 65 | 66 | 67 | def farthest_point_sample(xyz, npoint): 68 | """ 69 | Input: 70 | xyz: pointcloud data, [B, N, 3] 71 | npoint: number of samples 72 | Return: 73 | centroids: sampled pointcloud index, [B, npoint] 74 | """ 75 | device = xyz.device 76 | B, N, C = xyz.shape 77 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 78 | distance = torch.ones(B, N).to(device) * 1e10 79 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 80 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 81 | for i in range(npoint): 82 | centroids[:, i] = farthest 83 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 84 | dist = torch.sum((xyz - centroid) ** 2, -1) 85 | mask = dist < distance 86 | distance[mask] = dist[mask] 87 | farthest = torch.max(distance, -1)[1] 88 | return centroids 89 | 90 | 91 | 92 | def query_ball_point(radius, nsample, xyz, new_xyz): 93 | """ 94 | # Find n neighbors within r for new_xyz in xyz 95 | Input: 96 | radius: local region radius 97 | nsample: max sample number in local region 98 | xyz: all points, [B, N, 3] 99 | new_xyz: query points, [B, S, 3] 100 | Return: 101 | group_idx: grouped points index, [B, S, nsample] 102 | """ 103 | device = xyz.device 104 | B, N, C = xyz.shape 105 | _, S, _ = new_xyz.shape 106 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 107 | sqrdists = square_distance(new_xyz, xyz) 108 | group_idx[sqrdists > radius ** 2] = N 109 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 110 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 111 | mask = group_idx == N 112 | group_idx[mask] = group_first[mask] 113 | return group_idx 114 | 115 | 116 | 117 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 118 | """ 119 | Input: 120 | npoint: 121 | radius: 122 | nsample: 123 | xyz: input points position data, [B, N, 3] 124 | points: input points data, [B, N, D] 125 | Return: 126 | new_xyz: sampled points position data, [B, npoint, 3] 127 | new_points: sampled points data, [B, npoint, nsample, 3+D] 128 | """ 129 | B, N, C = xyz.shape 130 | S = npoint 131 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 132 | torch.cuda.empty_cache() 133 | new_xyz = index_points(xyz, fps_idx) 134 | torch.cuda.empty_cache() 135 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 136 | torch.cuda.empty_cache() 137 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 138 | torch.cuda.empty_cache() 139 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 140 | torch.cuda.empty_cache() 141 | 142 | if points is not None: 143 | grouped_points = index_points(points, idx) 144 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 145 | else: 146 | new_points = grouped_xyz_norm 147 | if returnfps: 148 | return new_xyz, new_points, grouped_xyz, fps_idx 149 | else: 150 | return new_xyz, new_points 151 | 152 | 153 | 154 | 155 | def sample_and_group_all(xyz, points): 156 | """ 157 | Input: 158 | xyz: input points position data, [B, N, 3] 159 | points: input points data, [B, N, D] 160 | Return: 161 | new_xyz: sampled points position data, [B, 1, 3] 162 | new_points: sampled points data, [B, 1, N, 3+D] 163 | """ 164 | device = xyz.device 165 | B, N, C = xyz.shape 166 | new_xyz = torch.zeros(B, 1, C).to(device) 167 | grouped_xyz = xyz.view(B, 1, N, C) 168 | if points is not None: 169 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 170 | else: 171 | new_points = grouped_xyz 172 | return new_xyz, new_points 173 | 174 | 175 | 176 | 177 | class PointNetSetAbstraction(nn.Module): 178 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 179 | super(PointNetSetAbstraction, self).__init__() 180 | self.npoint = npoint # number of points sampled from farthest point sampling 181 | self.radius = radius 182 | self.nsample = nsample # the number of points in each local region 183 | self.mlp_convs = nn.ModuleList() 184 | self.mlp_bns = nn.ModuleList() 185 | last_channel = in_channel 186 | for out_channel in mlp: 187 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 188 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 189 | last_channel = out_channel 190 | self.group_all = group_all # group all points into one PC if set true 191 | 192 | 193 | def forward(self, xyz, points): 194 | """ 195 | Input: 196 | xyz: input points position data, [B, C, N] 197 | points: input points data, [B, D, N] 198 | Return: 199 | new_xyz: sampled points position data, [B, C, S] 200 | new_points_concat: sample points feature data, [B, D', S] 201 | """ 202 | xyz = xyz.permute(0, 2, 1) 203 | if points is not None: 204 | points = points.permute(0, 2, 1) 205 | 206 | if self.group_all: 207 | new_xyz, new_points = sample_and_group_all(xyz, points) 208 | else: 209 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 210 | # new_xyz: sampled points position data, [B, npoint, C] 211 | # new_points: sampled points data, [B, npoint, nsample, C+D] 212 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 213 | for i, conv in enumerate(self.mlp_convs): 214 | bn = self.mlp_bns[i] 215 | new_points = F.relu(bn(conv(new_points))) 216 | 217 | new_points = torch.max(new_points, 2)[0] 218 | new_xyz = new_xyz.permute(0, 2, 1) 219 | return new_xyz, new_points 220 | 221 | 222 | 223 | 224 | class PointNetSetAbstractionMsg(nn.Module): 225 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 226 | super(PointNetSetAbstractionMsg, self).__init__() 227 | self.npoint = npoint 228 | self.radius_list = radius_list 229 | self.nsample_list = nsample_list 230 | self.conv_blocks = nn.ModuleList() 231 | self.bn_blocks = nn.ModuleList() 232 | for i in range(len(mlp_list)): 233 | convs = nn.ModuleList() 234 | bns = nn.ModuleList() 235 | last_channel = in_channel + 3 236 | for out_channel in mlp_list[i]: 237 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 238 | bns.append(nn.BatchNorm2d(out_channel)) 239 | last_channel = out_channel 240 | self.conv_blocks.append(convs) 241 | self.bn_blocks.append(bns) 242 | 243 | 244 | def forward(self, xyz, points): 245 | """ 246 | Input: 247 | xyz: input points position data, [B, C, N] 248 | points: input points data, [B, D, N] 249 | Return: 250 | new_xyz: sampled points position data, [B, C, S] 251 | new_points_concat: sample points feature data, [B, D', S] 252 | """ 253 | xyz = xyz.permute(0, 2, 1) 254 | if points is not None: 255 | points = points.permute(0, 2, 1) 256 | 257 | B, N, C = xyz.shape 258 | S = self.npoint 259 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 260 | new_points_list = [] 261 | for i, radius in enumerate(self.radius_list): 262 | K = self.nsample_list[i] 263 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 264 | grouped_xyz = index_points(xyz, group_idx) 265 | grouped_xyz -= new_xyz.view(B, S, 1, C) # normalize neighborhood 266 | if points is not None: 267 | grouped_points = index_points(points, group_idx) 268 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 269 | else: 270 | grouped_points = grouped_xyz 271 | 272 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D-channels, K-nsample, S-npoint] 273 | 274 | for j in range(len(self.conv_blocks[i])): 275 | conv = self.conv_blocks[i][j] 276 | bn = self.bn_blocks[i][j] 277 | grouped_points = F.relu(bn(conv(grouped_points))) 278 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 279 | new_points_list.append(new_points) 280 | 281 | new_xyz = new_xyz.permute(0, 2, 1) 282 | new_points_concat = torch.cat(new_points_list, dim=1) 283 | return new_xyz, new_points_concat 284 | 285 | 286 | 287 | class PointNetFeaturePropagation(nn.Module): 288 | def __init__(self, in_channel, mlp): 289 | super(PointNetFeaturePropagation, self).__init__() 290 | self.mlp_convs = nn.ModuleList() 291 | self.mlp_bns = nn.ModuleList() 292 | last_channel = in_channel 293 | for out_channel in mlp: 294 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 295 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 296 | last_channel = out_channel 297 | 298 | 299 | def forward(self, xyz1, xyz2, points1, points2): 300 | """ 301 | Input: 302 | xyz1: input points position data, [B, C, N] 303 | xyz2: sampled input points position data, [B, C, S] 304 | points1: input points data, [B, D, N] 305 | points2: input points data, [B, D, S] 306 | Return: 307 | new_points: upsampled points data, [B, D', N] 308 | """ 309 | xyz1 = xyz1.permute(0, 2, 1) 310 | xyz2 = xyz2.permute(0, 2, 1) 311 | 312 | points2 = points2.permute(0, 2, 1) 313 | B, N, C = xyz1.shape 314 | _, S, _ = xyz2.shape 315 | 316 | if S == 1: 317 | interpolated_points = points2.repeat(1, N, 1) 318 | else: 319 | dists = square_distance(xyz1, xyz2) 320 | dists, idx = dists.sort(dim=-1) 321 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 322 | 323 | dist_recip = 1.0 / (dists + 1e-8) 324 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 325 | weight = dist_recip / norm 326 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 327 | 328 | if points1 is not None: 329 | points1 = points1.permute(0, 2, 1) 330 | new_points = torch.cat([points1, interpolated_points], dim=-1) 331 | else: 332 | new_points = interpolated_points 333 | 334 | new_points = new_points.permute(0, 2, 1) 335 | for i, conv in enumerate(self.mlp_convs): 336 | bn = self.mlp_bns[i] 337 | new_points = F.relu(bn(conv(new_points))) 338 | return new_points -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | import importlib 6 | import numpy as np 7 | from tqdm import tqdm 8 | import multiprocessing 9 | from pathlib import Path 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | 12 | sys.path.append("..") 13 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 14 | from data_utils import provider 15 | 16 | BASE_DIR = os.path.dirname(os.path.abspath('__file__')) 17 | ROOT_DIR = BASE_DIR 18 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 19 | sys.path.append(os.path.join(ROOT_DIR, 'log')) 20 | 21 | 22 | 23 | def test(model, loader, num_class=40, vote_num=1): 24 | 25 | mean_correct = [] 26 | class_acc = np.zeros((num_class,3)) 27 | for j, data in tqdm(enumerate(loader), total=len(loader)): 28 | points, target = data 29 | target = target[:, 0] 30 | points = points.transpose(2, 1) 31 | points, target = points.cuda(), target.cuda() 32 | classifier = model.eval() 33 | vote_pool = torch.zeros(target.size()[0],num_class).cuda() 34 | for _ in range(vote_num): 35 | pred, _ = classifier(points.float()) 36 | vote_pool += pred 37 | pred = vote_pool/vote_num 38 | pred_choice = pred.data.max(1)[1] 39 | for cat in np.unique(target.cpu()): 40 | cat = int(cat) 41 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum() 42 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0]) 43 | class_acc[cat,1]+=1 44 | correct = pred_choice.eq(target.long().data).cpu().sum() 45 | mean_correct.append(correct.item()/float(points.size()[0])) 46 | class_acc[:,2] = class_acc[:,0]/ class_acc[:,1] 47 | class_acc = np.mean(class_acc[:,2]) 48 | instance_acc = np.mean(mean_correct) 49 | return instance_acc, class_acc 50 | 51 | 52 | def main(args): 53 | def log_string(str): 54 | logger.info(str) 55 | print(str) 56 | 57 | '''HYPER PARAMETER''' 58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 59 | torch.backends.cudnn.enabled = False 60 | 61 | '''DATA LOADING''' 62 | TEST_DATASET = ModelNetDataLoader(root=args.data_root, 63 | tasks=args.test_tasks, 64 | labels=args.test_labels, 65 | partition='test', 66 | npoint=args.num_point, 67 | normal_channel=args.normal) 68 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=2) 69 | 70 | '''MODEL LOADING''' 71 | num_class = 40 72 | files = os.listdir(args.model_dir+'/logs') 73 | for f in files: 74 | if f.endswith('txt'): 75 | model_name = f.split('.')[0] 76 | 77 | MODEL = importlib.import_module(model_name) 78 | 79 | classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda() 80 | checkpoint = torch.load(str(args.model_dir) + '/checkpoints/best_model.pth') 81 | classifier.load_state_dict(checkpoint['model_state_dict']) 82 | 83 | with torch.no_grad(): 84 | instance_acc, class_acc = test(classifier.eval(), testDataLoader, vote_num=args.num_votes) 85 | print('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc)) 86 | 87 | 88 | def parseArgs(): 89 | """ Argument parser for configuring model testing """ 90 | parser = argparse.ArgumentParser(description='RobustPointSet trainer') 91 | parser.add_argument('--batch_size', type=int, default=32) 92 | parser.add_argument('--model', type=str, default='pointnet_cls', help='point cloud model') 93 | parser.add_argument('--ults', type=str, default='pointnet_util', help='help functions for point cloud model') 94 | parser.add_argument('--gpu', type=str, default='0', help='gpu device index') 95 | parser.add_argument('--num_point', type=int, default=2048, help='number of points') 96 | parser.add_argument('--num_votes', type=int, default=1, help='number of time to run testing and doing majority vote') 97 | parser.add_argument('--normal', default=False, help='Whether to use normal information [default: False]') 98 | parser.add_argument('--data_root', type=str, default='data/', help='data directory') 99 | parser.add_argument('--test_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be tested on during training") 100 | parser.add_argument('--model_dir', type=str, required=True, help="model checkpoint") 101 | return parser.parse_args() 102 | 103 | if __name__ == '__main__': 104 | """ 105 | Available data: original, jitter, translate, missing_part, sparse, rotation, occlusion 106 | Example command for strategy 1 & 2: 107 | python test.py --test_tasks test_original.npy --model_dir log/classification/2021-02-17_10-37-pointnet_cls-s1 108 | python test.py --test_tasks test_rotation.npy --model_dir log/classification/2021-02-17_10-37-pointnet_cls-s1 109 | 110 | """ 111 | args = parseArgs() 112 | args.test_labels =['test_labels.npy']*len(args.test_tasks) 113 | main(args) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | import datetime 6 | import logging 7 | import shutil 8 | import importlib 9 | import numpy as np 10 | from tqdm import tqdm 11 | import multiprocessing 12 | from pathlib import Path 13 | from torch.optim.lr_scheduler import CosineAnnealingLR 14 | 15 | sys.path.append("..") 16 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 17 | from data_utils import provider 18 | 19 | BASE_DIR = os.path.dirname(os.path.abspath('__file__')) 20 | ROOT_DIR = BASE_DIR 21 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 22 | sys.path.append(os.path.join(ROOT_DIR, 'log')) 23 | 24 | 25 | 26 | 27 | def test(model, loader, num_class=40): 28 | mean_correct = [] 29 | class_acc = np.zeros((num_class,3)) 30 | for j, data in tqdm(enumerate(loader), total=len(loader)): 31 | points, target = data 32 | target = target[:, 0] 33 | points = points.transpose(2, 1) 34 | points, target = points.cuda(), target.cuda() 35 | classifier = model.eval() 36 | pred, _ = classifier(points.float()) 37 | pred_choice = pred.data.max(1)[1] 38 | for cat in np.unique(target.cpu()): 39 | cat = int(cat) 40 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum() 41 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0]) 42 | class_acc[cat,1]+=1 43 | correct = pred_choice.eq(target.long().data).cpu().sum() 44 | mean_correct.append(correct.item()/float(points.size()[0])) 45 | class_acc[:,2] = class_acc[:,0]/ class_acc[:,1] 46 | class_acc = np.mean(class_acc[:,2]) 47 | instance_acc = np.mean(mean_correct) 48 | return instance_acc, class_acc 49 | 50 | 51 | 52 | def main(args): 53 | def log_string(str): 54 | logger.info(str) 55 | print(str) 56 | 57 | ### Hyper Parameters ### 58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 59 | 60 | ### Create Dir ### 61 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))+'-'+args.model+'-'+args.task 62 | experiment_dir = Path('log/') 63 | experiment_dir.mkdir(exist_ok=True) 64 | experiment_dir = experiment_dir.joinpath('classification') 65 | experiment_dir.mkdir(exist_ok=True) 66 | if args.log_dir is None: 67 | experiment_dir = experiment_dir.joinpath(timestr) 68 | else: 69 | experiment_dir = experiment_dir.joinpath(args.log_dir) 70 | experiment_dir.mkdir(exist_ok=True) 71 | checkpoints_dir = experiment_dir.joinpath('checkpoints/') 72 | checkpoints_dir.mkdir(exist_ok=True) 73 | log_dir = experiment_dir.joinpath('logs/') 74 | log_dir.mkdir(exist_ok=True) 75 | 76 | ### Log ### 77 | logger = logging.getLogger("Model") 78 | logger.setLevel(logging.INFO) 79 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 80 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model)) 81 | file_handler.setLevel(logging.INFO) 82 | file_handler.setFormatter(formatter) 83 | logger.addHandler(file_handler) 84 | log_string('PARAMETER ...') 85 | log_string(args) 86 | 87 | 88 | ### Data Loading ### 89 | log_string('Load dataset ...') 90 | TRAIN_DATASET = ModelNetDataLoader(root=args.data_root, 91 | tasks=args.train_tasks, 92 | labels=args.train_labels, 93 | partition='train', 94 | npoint=args.num_point, 95 | normal_channel=args.normal) 96 | TEST_DATASET = ModelNetDataLoader(root=args.data_root, 97 | tasks=args.test_tasks, 98 | labels=args.test_labels, 99 | partition='test', 100 | npoint=args.num_point, 101 | normal_channel=args.normal) 102 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=0) 103 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=0) 104 | 105 | 106 | ### Model Loading ### 107 | num_class = 40 108 | MODEL = importlib.import_module(args.model) 109 | shutil.copy('models/%s.py' % args.model, str(experiment_dir)) 110 | # pointnet_util pointnet_util, 111 | shutil.copy('models/%s.py' % args.ults, str(experiment_dir)) 112 | 113 | classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda() 114 | criterion = MODEL.get_loss().cuda() 115 | 116 | try: 117 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 118 | start_epoch = checkpoint['epoch'] 119 | classifier.load_state_dict(checkpoint['model_state_dict']) 120 | log_string('Use pretrain model') 121 | except: 122 | log_string('No existing model, starting training from scratch...') 123 | start_epoch = 0 124 | 125 | if args.optimizer == 'Adam': 126 | optimizer = torch.optim.Adam( 127 | classifier.parameters(), 128 | lr=args.learning_rate, 129 | betas=(0.9, 0.999), 130 | eps=1e-08, 131 | weight_decay=args.decay_rate 132 | ) 133 | else: 134 | optimizer = torch.optim.SGD(classifier.parameters(), 135 | lr=0.01, 136 | momentum=0.9) 137 | 138 | 139 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay_rate) 140 | 141 | global_epoch = 0 142 | global_step = 0 143 | best_instance_acc = 0.0 144 | best_class_acc = 0.0 145 | mean_correct = [] 146 | 147 | 148 | ### Training ### 149 | logger.info('Start training ...') 150 | for epoch in range(start_epoch, args.epoch): 151 | 152 | 153 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) 154 | 155 | for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): 156 | points, target = data 157 | 158 | points = points.data.numpy() 159 | 160 | points = torch.Tensor(points) 161 | target = target[:, 0] 162 | 163 | points = points.transpose(2, 1) 164 | points, target = points.cuda(), target.cuda() 165 | optimizer.zero_grad() 166 | 167 | classifier = classifier.train() 168 | pred, trans_feat = classifier(points) 169 | loss = criterion(pred, target.long(), trans_feat) 170 | pred_choice = pred.data.max(1)[1] 171 | correct = pred_choice.eq(target.long().data).cpu().sum() 172 | mean_correct.append(correct.item() / float(points.size()[0])) 173 | loss.backward() 174 | optimizer.step() 175 | global_step += 1 176 | 177 | train_instance_acc = np.mean(mean_correct) 178 | log_string('Train Instance Accuracy: %f' % train_instance_acc) 179 | 180 | 181 | if (epoch % 2 == 0) or (epoch == args.epoch): 182 | with torch.no_grad(): 183 | instance_acc, class_acc = test(classifier.eval(), testDataLoader) 184 | 185 | if (instance_acc >= best_instance_acc): 186 | best_instance_acc = instance_acc 187 | best_epoch = epoch + 1 188 | 189 | if (class_acc >= best_class_acc): 190 | best_class_acc = class_acc 191 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc)) 192 | log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc)) 193 | 194 | if (instance_acc >= best_instance_acc): 195 | logger.info('Save model...') 196 | savepath = str(checkpoints_dir) + '/best_model.pth' 197 | log_string('Saving at %s'% savepath) 198 | state = { 199 | 'epoch': best_epoch, 200 | 'instance_acc': instance_acc, 201 | 'class_acc': class_acc, 202 | 'model_state_dict': classifier.state_dict(), 203 | 'optimizer_state_dict': optimizer.state_dict(), 204 | } 205 | torch.save(state, savepath) 206 | global_epoch += 1 207 | 208 | # adjust lr 209 | scheduler.step() 210 | 211 | logger.info('End of training...') 212 | 213 | 214 | 215 | 216 | 217 | def parseArgs(): 218 | """ Argument parser for configuring model training """ 219 | parser = argparse.ArgumentParser(description='RobustPointSet trainer') 220 | parser.add_argument('--batch_size', type=int, default=32) 221 | parser.add_argument('--model', type=str, default='pointnet_cls', help='point cloud model') 222 | parser.add_argument('--ults', type=str, default='pointnet_util', help='help functions for point cloud model') 223 | parser.add_argument('--task', type=str, default='s1', help='Stragety 1 or 2') 224 | parser.add_argument('--epoch', type=int, default=300, help='training epoch') 225 | parser.add_argument('--learning_rate', type=float, default=0.001) 226 | parser.add_argument('--gpu', type=str, default='0', help='gpu device index') 227 | parser.add_argument('--num_point', type=int, default=2048, help='number of points') 228 | parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer') 229 | parser.add_argument('--log_dir', type=str, default=None, help='directory to store training log') 230 | parser.add_argument('--decay_rate', type=float, default=1e-4) 231 | parser.add_argument('--lr_decay_rate', type=float, default=0.5) 232 | parser.add_argument('--lr_decay_step', type=int, default=100) 233 | parser.add_argument('--lr_clip', type=float, default=1e-7) 234 | parser.add_argument('--normal', default=False, help='Whether to use normal information [default: False]') 235 | parser.add_argument('--data_root', type=str, default='data/', help='data directory') 236 | parser.add_argument('--train_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be trained on") 237 | parser.add_argument('--test_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be tested on during training") 238 | return parser.parse_args() 239 | 240 | 241 | 242 | if __name__ == '__main__': 243 | """ 244 | Available data: original, jitter, translate, missing_part, sparse, rotation, occlusion 245 | Example command for strategy 1: 246 | python train.py --train_tasks train_original.npy --test_tasks test_original.npy 247 | python train.py --train_tasks train_original.npy --test_tasks test_sparse.npy 248 | Example command for strategy 2: 249 | python train.py --train_tasks train_original.npy train_jitter.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy --test_tasks test_occlusion.npy 250 | python train.py --train_tasks train_original.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy train_occlusion.npy --test_tasks test_jitter.npy 251 | """ 252 | args = parseArgs() 253 | args.train_labels = ['train_labels.npy']*len(args.train_tasks) 254 | args.test_labels =['test_labels.npy']*len(args.test_tasks) 255 | main(args) --------------------------------------------------------------------------------