├── .ipynb_checkpoints
└── train-checkpoint.py
├── LICENSE.md
├── README.md
├── RobustPointSet.png
├── data_utils
├── ModelNetDataLoader.py
└── provider.py
├── models
├── pointnet.py
├── pointnet_cls.py
└── pointnet_util.py
├── test.py
└── train.py
/.ipynb_checkpoints/train-checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import argparse
5 | import datetime
6 | import logging
7 | import shutil
8 | import importlib
9 | import numpy as np
10 | from tqdm import tqdm
11 | import multiprocessing
12 | from pathlib import Path
13 | from torch.optim.lr_scheduler import CosineAnnealingLR
14 |
15 | sys.path.append("..")
16 | from data_utils.ModelNetDataLoader import ModelNetDataLoader
17 | from data_utils import provider
18 |
# '__file__' below is a *string literal*, not the module's __file__ variable,
# so abspath() resolves it against the current working directory: BASE_DIR is
# simply the CWD. That is consistent with the CWD-relative paths used later
# ('models/...', 'log/') and also works inside Jupyter, where __file__ is
# undefined.
BASE_DIR = os.getcwd()
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))
sys.path.append(os.path.join(ROOT_DIR, 'log'))
23 |
24 |
25 |
26 |
def test(model, loader, num_class=40):
    """Evaluate ``model`` on every batch of ``loader``.

    Args:
        model: classifier module. It is switched to eval mode and called as
            ``model(points)``, returning a ``(logits, aux)`` pair; predictions
            are the argmax of ``logits`` over dim 1.
        loader: iterable of ``(points, target)`` batches that supports
            ``len()``. ``points`` is transposed with ``transpose(2, 1)``
            before the forward pass and ``target[:, 0]`` is the label.
        num_class: number of classes; labels must lie in ``[0, num_class)``.

    Returns:
        ``(instance_acc, class_acc)``: the fraction of all samples classified
        correctly, and the mean per-class accuracy over the classes that
        actually occur in ``loader``. Both are ``nan`` if ``loader`` is empty.
    """
    classifier = model.eval()
    # Accumulate raw counts rather than per-batch ratios: averaging ratios
    # over-weights a short final batch and, for class accuracy, weights every
    # batch equally regardless of how many samples of the class it held.
    total_correct = 0.0
    total_seen = 0
    class_correct = np.zeros(num_class)
    class_seen = np.zeros(num_class)
    with torch.no_grad():
        for data in tqdm(loader, total=len(loader)):
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            pred, _ = classifier(points.float())
            pred_choice = pred.data.max(1)[1]
            labels = target.long()
            hits = pred_choice.eq(labels.data).cpu().numpy().astype(np.float64)
            labels = labels.cpu().numpy()
            total_correct += hits.sum()
            total_seen += labels.shape[0]
            # np.add.at accumulates correctly when a label repeats in a batch.
            np.add.at(class_seen, labels, 1)
            np.add.at(class_correct, labels, hits)
    if total_seen == 0:
        return float('nan'), float('nan')
    instance_acc = total_correct / total_seen
    # Skip classes that never appear instead of producing 0/0 = nan.
    present = class_seen > 0
    class_acc = np.mean(class_correct[present] / class_seen[present])
    return instance_acc, class_acc
49 |
50 |
51 |
def main(args):
    """Train a point-cloud classifier on RobustPointSet and keep the best model.

    Creates ``log/classification/<run>/{checkpoints,logs}`` relative to the
    current working directory and copies the model sources there. It resumes
    from ``checkpoints/best_model.pth`` when that file can be loaded, then
    alternates training epochs with periodic evaluation on the test split.
    The checkpoint is overwritten whenever the test instance accuracy matches
    or beats the best value seen so far.

    Args:
        args: namespace from ``parseArgs()``. The caller must also set
            ``args.train_labels`` and ``args.test_labels`` (lists parallel to
            ``args.train_tasks`` / ``args.test_tasks``).
    """
    def log_string(msg):
        logger.info(msg)
        print(msg)

    ### Hyper Parameters ###
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    ### Create Dir ###
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) + '-' + args.model + '-' + args.task
    experiment_dir = Path('log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    ### Log ###
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    ### Data Loading ###
    log_string('Load dataset ...')
    TRAIN_DATASET = ModelNetDataLoader(root=args.data_root,
                                       tasks=args.train_tasks,
                                       labels=args.train_labels,
                                       partition='train',
                                       npoint=args.num_point,
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=args.data_root,
                                      tasks=args.test_tasks,
                                      labels=args.test_labels,
                                      partition='test',
                                      npoint=args.num_point,
                                      normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=0)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=0)

    ### Model Loading ###
    num_class = 40
    MODEL = importlib.import_module(args.model)
    # Snapshot the model and its helper module next to this run's logs.
    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/%s.py' % args.ults, str(experiment_dir))

    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    best_instance_acc = 0.0
    best_class_acc = 0.0
    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        # Restore the best scores as well; otherwise the first evaluation
        # after resuming would overwrite best_model.pth even if it is worse.
        best_instance_acc = checkpoint.get('instance_acc', 0.0)
        best_class_acc = checkpoint.get('class_acc', 0.0)
        log_string('Use pretrain model')
    except Exception as err:
        # Best-effort resume: any load failure means a fresh run. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit out.
        logger.info('Checkpoint not loaded: %r', err)
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        best_instance_acc = 0.0
        best_class_acc = 0.0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # NOTE: SGD uses a fixed lr of 0.01; --learning_rate only affects Adam.
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay_rate)

    global_epoch = 0
    global_step = 0
    best_epoch = start_epoch

    ### Training ###
    logger.info('Start training ...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))

        # Reset every epoch; a list shared across epochs would turn the
        # reported train accuracy into a cumulative mean over all epochs.
        mean_correct = []
        classifier = classifier.train()

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            points, target = data
            points = points.float()
            target = target[:, 0]

            # Conv1d layout: (B, N, C) -> (B, C, N).
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        # Evaluate every other epoch and always on the last one; range()
        # stops at args.epoch - 1, so that is the final epoch index.
        if epoch % 2 == 0 or epoch == args.epoch - 1:
            with torch.no_grad():
                instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)

            is_best = instance_acc >= best_instance_acc
            if is_best:
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if class_acc >= best_class_acc:
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

            if is_best:
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)

        global_epoch += 1
        # adjust lr
        scheduler.step()

    logger.info('End of training...')
212 |
213 |
214 |
215 |
216 |
def parseArgs():
    """Build and parse the command-line arguments for training.

    Returns:
        argparse.Namespace with the options below. ``--train_tasks`` and
        ``--test_tasks`` are required and take one or more ``.npy`` file
        names relative to ``--data_root``.
    """
    def _str2bool(value):
        # Without a type, argparse keeps '--normal False' as the *string*
        # 'False', which is truthy. Parse common spellings into a real bool.
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))

    parser = argparse.ArgumentParser(description='RobustPointSet trainer')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model', type=str, default='pointnet_cls', help='point cloud model')
    parser.add_argument('--ults', type=str, default='pointnet_util', help='help functions for point cloud model')
    parser.add_argument('--task', type=str, default='s1', help='Strategy 1 or 2')
    parser.add_argument('--epoch', type=int, default=300, help='training epoch')
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--gpu', type=str, default='0', help='gpu device index')
    parser.add_argument('--num_point', type=int, default=2048, help='number of points')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer')
    parser.add_argument('--log_dir', type=str, default=None, help='directory to store training log')
    parser.add_argument('--decay_rate', type=float, default=1e-4)
    parser.add_argument('--lr_decay_rate', type=float, default=0.5)
    parser.add_argument('--lr_decay_step', type=int, default=100)
    parser.add_argument('--lr_clip', type=float, default=1e-7)
    parser.add_argument('--normal', type=_str2bool, default=False, help='Whether to use normal information [default: False]')
    parser.add_argument('--data_root', type=str, default='data/', help='data directory')
    parser.add_argument('--train_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be trained on")
    parser.add_argument('--test_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be tested on during training")
    return parser.parse_args()
239 |
240 |
241 |
if __name__ == '__main__':
    # Available data: original, jitter, translate, missing_part, sparse, rotation, occlusion
    #
    # Example commands for strategy 1:
    #   python train.py --train_tasks train_original.npy --test_tasks test_original.npy
    #   python train.py --train_tasks train_original.npy --test_tasks test_sparse.npy
    #
    # Example commands for strategy 2:
    #   python train.py --train_tasks train_original.npy train_jitter.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy --test_tasks test_occlusion.npy
    #   python train.py --train_tasks train_original.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy train_occlusion.npy --test_tasks test_jitter.npy
    args = parseArgs()
    # Each task file of a split is paired with the same label file.
    # NOTE(review): the README names these labels_train.npy / labels_test.npy;
    # confirm which file names the dataset release actually ships.
    args.train_labels = ['train_labels.npy' for _ in args.train_tasks]
    args.test_labels = ['test_labels.npy' for _ in args.test_tasks]
    main(args)
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # Robust Point Set
2 |
3 | This dataset is provided for the convenience of academic research only, and is provided without any representations or warranties, including warranties of non-infringement or fitness for a particular purpose.
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RobustPointSet: A Dataset for Benchmarking Robustness of Point Cloud Classifiers, RobustML - ICLR 2021
2 | _Robust and Reliable Machine Learning in the Real World Workshop (RobustML), ICLR 2021_
3 |
4 | A benchmark dataset to facilitate augmentation-independent robustness analysis of point cloud classifiers. RobustPointSet comes with 6 different transformations: Noise, Translation, Missing part, Sparse, Rotation, and Occlusion.
5 |
6 |
7 |

8 |
9 |
10 |
11 | ### Evaluation Strategies
12 |
13 | We test two different evaluation strategies on more than 10 models:
14 |
15 | #### Strategy 1 (training-domain validation)
16 | For this strategy, we train on `train_original.npy` without applying any data augmentation, select the best-performing model on `test_original.npy` (to be consistent with baselines), and test on each test set (i.e., `test_*.npy`) separately.
17 |
18 | #### Strategy 2 (leave-one-out validation)
19 | For this strategy, each time we concatenate 6 of the 7 train sets (i.e., the `train_*.npy` files), leaving one group out, and test on the test set (i.e., `test_*.npy`) of the left-out group. We repeat this process for every group. For example, we train on the concatenation of `{train_original.npy, train_noise.npy, train_missing_part.npy, train_occlusion.npy, train_rotation.npy, train_sparse.npy}` and test on `test_translate.npy`. As in strategy 1, we don't apply any data augmentation here. For both strategies, the same label files can be used, i.e., `labels_train.npy` and `labels_test.npy`.
20 |
21 | -----------------
22 |
23 | ### Benchmarks
24 |
25 | Table 1: Training-domain validation results on our RobustPointSet test sets. The *Noise* column, for example, shows the result of training on the *Original* train set and testing with the *Noise* test set. RotInv refers to rotation-invariant models.
26 |
27 |
28 |
29 | | Type | Method | Original | Noise | Translation | Missing part | Sparse | Rotation | Occlusion | Average |
30 | |:----:|:--------------------|:---------:|:---------:|:-----------:|:------------:|:---------:|:--------:|:---------:|:---------:|
31 | |General | PointNet | 89.06 | **74.72** | 79.66 | 81.52 | **60.53** | 8.83 | 39.47 | **61.97** |
32 | |General | PointNet++ (MSG) | 91.27 | 5.73 | 91.31 | 53.69 | 6.65 | 13.02 | 64.18 | 46.55 |
33 | |General | PointNet++ (SSG) | 91.47 | 14.90 | 91.07 | 50.24 | 8.85 | 12.70 | 70.23 | 48.49 |
34 | |General | DGCNN | **92.52** | 57.56 | **91.99** | 85.40 | 9.34 | 13.43 | **78.72** | 61.28 |
35 | |General | PointMask | 88.53 | 73.14 | 78.20 | 81.48 | 58.23 | 8.02 | 39.18 | 60.97 |
36 | |General | DensePoint | 90.96 | 53.28 | 90.72 | 84.49 | 15.52 | 12.76 | 67.67 | 59.40 |
37 | |General | PointCNN | 87.66 | 45.55 | 82.85 | 77.60 | 4.01 | 11.50 | 59.50 | 52.67 |
38 | |General | PointConv | 91.15 | 20.71 | 90.99 | 84.09 | 8.65 | 12.38 | 45.83 | 50.54 |
39 | |General | Relation-Shape-CNN | 91.77 | 48.06 | 91.29 | **85.98** | 23.18 | 11.51 | 75.61 | 61.06 |
40 | |RotInv | SPHnet | 79.18 | 7.22 | **79.18** | 4.22 | 1.26 | 79.18 | 34.33 | 40.65 |
41 | |RotInv | PRIN | 73.66 | 30.19 | 41.21 | 44.17 | 4.17 | 68.56 | 31.56 | 41.93 |
42 |
43 |
44 | ----------------
45 |
46 | Table 2: Leave-one-out validation strategy classification results on our RobustPointSet test sets. For example, the *Noise* column shows the result of training on *{Original, Translation, Missing part, Sparse, Rotation, Occlusion}* train sets and testing with the *Noise* test set. RotInv refers to rotation-invariant models.
47 |
48 |
49 | | Type | Method | Original | Noise | Translation | Missing part | Sparse | Rotation | Occlusion | Average |
50 | |:----:|:--------------------|:---------:|:---------:|:-----------:|:------------:|:---------:|:---------:|:---------:|:---------:|
51 | |General | PointNet | 88.35 | 72.61 | 81.53 | 82.87 | **69.28** | 9.42 | 35.96 | **62.86** |
52 | |General | PointNet++ (MSG) | 91.55 | 50.92 | 91.43 | 77.16 | 16.19 | 12.26 | **70.39** | 58.56 |
53 | |General | PointNet++ (SSG) | 91.76 | 49.33 | 91.10 | 78.36 | 16.72 | 11.27 | 68.33 | 58.12 |
54 | |General | DGCNN | **92.38** | 66.95 | 91.17 | 85.40 | 6.49 | 14.03 | 68.79 | 60.74 |
55 | |General | PointMask | 88.03 | **73.95** | 80.80 | 82.83 | 63.64 | 8.97 | 36.69 | 62.13 |
56 | |General | DensePoint | 91.00 | 42.38 | 90.64 | 85.70 | 20.66 | 8.55 | 47.89 | 55.26 |
57 | |General | PointCNN | 88.91 | 73.10 | 87.46 | 82.06 | 7.18 | 13.95 | 52.66 | 57.90 |
58 | |General | PointConv | 91.07 | 66.19 | **91.51** | 84.01 | 19.63 | 11.62 | 44.07 | 58.30 |
59 | |General | Relation-Shape-CNN | 90.52 | 36.95 | 91.33 | **85.82** | 24.59 | 8.23 | 60.09 | 56.79 |
60 | |RotInv | SPHnet | 79.30 | 8.24 | 76.02 | 17.94 | 6.33 | **78.86** | 35.96 | 43.23 |
61 | |RotInv | PRIN | 76.54 | 55.35 | 56.36 | 59.20 | 4.05 | 73.30 | 36.91 | 51.67 |
62 |
63 |
64 | -----------------
65 |
66 | ### Sample train and test codes
67 |
68 | The `train.py` and `test.py` files contain sample code to train and test PointNet. The data loaders and other utilities can also be used to train other models. The code is adapted from [here](https://github.com/yanx27/Pointnet_Pointnet2_pytorch).
69 |
70 | -----------------
71 |
72 | ### Publication
73 |
74 | Please cite the paper below if you use RobustPointSet in your research.
75 |
76 | [RobustPointSet: A Dataset for Benchmarking Robustness of Point Cloud Classifiers](https://arxiv.org/abs/2011.11572)
77 |
78 | ```
79 | @article{taghanaki2020robustpointset,
80 | title={RobustPointSet: A Dataset for Benchmarking Robustness of Point Cloud Classifiers},
81 | author={{Asgari Taghanaki}, Saeid and Luo, Jieliang and Zhang, Ran and Wang, Ye and Jayaraman, {Pradeep Kumar} and Jatavallabhula, {Krishna Murthy}},
82 | year={2020},
83 | journal={arXiv preprint arXiv:2011.11572}
84 | }
85 | ```
86 | -----------------
87 | ### Download
88 | The dataset consists of two parts: [Part I](https://github.com/AutodeskAILab/RobustPointSet/releases/download/v1.0/RobustPointSet.z01) and [Part II](https://github.com/AutodeskAILab/RobustPointSet/releases/download/v1.0/RobustPointSet.zip). Please download both parts and unzip Part I, which will automatically extract the two parts into the same folder.
89 |
90 | -----------------
91 | ### Reference
92 | We use the following implementations with minor modifications for our evaluations.
93 |
94 | * PointNet: [https://github.com/TianzhongSong/PointNet-Keras](https://github.com/TianzhongSong/PointNet-Keras)
95 | * PointNet++: [https://github.com/yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch)
96 | * DGCNN: [https://github.com/WangYueFt/dgcnn](https://github.com/WangYueFt/dgcnn)
97 | * DensePoint: [https://github.com/Yochengliu/DensePoint](https://github.com/Yochengliu/DensePoint)
98 | * PointCNN: [https://github.com/hxdengBerkeley/PointCNN.Pytorch](https://github.com/hxdengBerkeley/PointCNN.Pytorch)
99 | * PointConv: [https://github.com/DylanWusee/pointconv_pytorch](https://github.com/DylanWusee/pointconv_pytorch)
100 | * Relation-Shape-CNN: [https://github.com/Yochengliu/Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN)
101 | * SPHnet: [https://github.com/adrienPoulenard/SPHnet](https://github.com/adrienPoulenard/SPHnet)
102 | * PRIN: [https://github.com/qq456cvb/PRIN](https://github.com/qq456cvb/PRIN)
103 | * PointMask: [https://github.com/asgsaeid/PointMask](https://github.com/asgsaeid/PointMask)
104 |
105 | Note:
106 | For DensePoint, Relation-Shape-CNN, and PRIN, you will need to run the code on an older version of PyTorch (0.3 or 0.4). Below are the steps to create a Conda environment with an older version of PyTorch:
107 |
108 | ```
109 | - wget -c https://repo.continuum.io/archive/Anaconda3-5.2.0-Linux-x86_64.sh
110 | - chmod +x Anaconda3-5.2.0-Linux-x86_64.sh
111 | - ./Anaconda3-5.2.0-Linux-x86_64.sh
112 | - conda create -n myenv python=3.6
113 | - conda install pytorch=0.4.1 cuda92 -c pytorch -y
114 | ```
115 |
116 | -----------------
117 | ### License
118 |
119 | Please refer to the [dataset license](https://github.com/AutodeskAILab/RobustPointSet/blob/main/LICENSE.md).
120 |
121 |
122 |
--------------------------------------------------------------------------------
/RobustPointSet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutodeskAILab/RobustPointSet/f5439a242da5076d2ca632654a3236ec93a2f197/RobustPointSet.png
--------------------------------------------------------------------------------
/data_utils/ModelNetDataLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 |
def pc_normalize(pc):
    """Center a point cloud on its centroid and scale it into the unit ball.

    Args:
        pc: array-like of shape (N, C). The centroid is the mean over axis 0,
            and the scale is the largest Euclidean row norm after centering.

    Returns:
        A new array of the same shape; ``pc`` itself is left untouched.
    """
    centered = pc - np.mean(pc, axis=0)
    row_norms = np.sqrt(np.sum(centered ** 2, axis=1))
    radius = np.max(row_norms)
    return centered / radius
12 |
def farthest_point_sample(point, npoint):
    """Select ``npoint`` rows of ``point`` by iterative farthest-point sampling.

    Starts from one uniformly random row (``np.random.randint``). Each later
    step picks the row whose squared xyz distance to the already-chosen set
    is largest. Only the first three columns enter the distance, but every
    column is kept in the output.

    Input:
        point: pointcloud data, [N, D] with D >= 3
        npoint: number of samples
    Return:
        sampled rows of ``point`` (not indices), [npoint, D]
    """
    num_points, _ = point.shape
    coords = point[:, :3]
    picked = np.zeros((npoint,))
    nearest_sq = np.ones((num_points,)) * 1e10
    current = np.random.randint(0, num_points)
    for slot in range(npoint):
        picked[slot] = current
        sq_dist = np.sum((coords - coords[current, :]) ** 2, -1)
        improved = sq_dist < nearest_sq
        nearest_sq[improved] = sq_dist[improved]
        current = np.argmax(nearest_sq, -1)
    return point[picked.astype(np.int32)]
35 |
36 |
37 |
class ModelNetDataLoader(Dataset):
    """In-memory dataset over one or more RobustPointSet ``.npy`` files.

    All task files are loaded eagerly with ``load_data`` and concatenated
    along axis 0. Each item is the first ``npoint`` points of one cloud plus
    its label. For the ``'train'`` partition the point order is shuffled.

    Args:
        root: directory holding the ``.npy`` files.
        tasks: data file names, one per task.
        labels: label file names, parallel to ``tasks``.
        partition: ``'train'`` enables per-item point shuffling.
        npoint: number of leading points kept per cloud.
        uniform, normal_channel, cache_size: accepted for interface
            compatibility; none of them changes what this loader returns.
    """

    def __init__(self, root, tasks, labels, partition='train', npoint=2048, uniform=False, normal_channel=False, cache_size=15000):
        self.root = root
        self.npoints = npoint
        self.uniform = uniform
        self.normal_channel = normal_channel
        self.data, self.label = load_data(root, tasks, labels)
        self.partition = partition
        print('The number of ' + partition + ' data: ' + str(self.data.shape[0]))

    def __len__(self):
        return self.data.shape[0]

    def _get_item(self, index):
        # Slicing yields a view into self.data. Copy it so the in-place
        # shuffle below (and any caller-side mutation) never alters the
        # stored dataset.
        pointcloud = self.data[index][:self.npoints].copy()
        label = self.label[index]
        if self.partition == 'train':
            np.random.shuffle(pointcloud)
        return pointcloud, label

    def __getitem__(self, index):
        return self._get_item(index)
61 |
62 |
63 |
def load_data(root, tasks, labels):
    """Load task data files and their label files, concatenated along axis 0.

    Args:
        root: directory containing the ``.npy`` files.
        tasks: sequence of data file names.
        labels: sequence of label file names, parallel to ``tasks``.

    Returns:
        ``(all_data, all_label)``: the loaded arrays concatenated along axis 0
        in ``tasks`` order.

    Raises:
        ValueError: if ``tasks`` and ``labels`` differ in length, or (from
            ``np.concatenate``) if ``tasks`` is empty.
    """
    # Fail fast: a length mismatch would otherwise either ignore extra label
    # files or crash with IndexError after part of the data was loaded.
    if len(tasks) != len(labels):
        raise ValueError(
            'tasks and labels must have the same length, got %d and %d'
            % (len(tasks), len(labels)))
    all_data = [np.load(os.path.join(root, task)) for task in tasks]
    all_label = [np.load(os.path.join(root, label)) for label in labels]
    return np.concatenate(all_data, axis=0), np.concatenate(all_label, axis=0)
--------------------------------------------------------------------------------
/data_utils/provider.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def normalize_data(batch_data):
    """Center every cloud of a batch on its centroid and scale to unit radius.

    Input:
        BxNxC array
    Output:
        BxNxC float64 array (a new array; ``batch_data`` is not modified)
    """
    num_batch, num_pts, num_ch = batch_data.shape
    out = np.zeros((num_batch, num_pts, num_ch))
    for idx, cloud in enumerate(batch_data):
        shifted = cloud - np.mean(cloud, axis=0)
        radius = np.max(np.sqrt(np.sum(shifted ** 2, axis=1)))
        out[idx] = shifted / radius
    return out
20 |
21 |
def shuffle_data(data, labels):
    """Apply one random permutation to both samples and labels.

    Input:
        data: B,N,... numpy array
        labels: B,... numpy array
    Return:
        (shuffled data, shuffled labels, permutation indices used)
    """
    perm = np.arange(len(labels))
    np.random.shuffle(perm)
    shuffled_data = data[perm, ...]
    shuffled_labels = labels[perm]
    return shuffled_data, shuffled_labels, perm
33 |
def shuffle_points(batch_data):
    """Permute point order within each cloud, using one permutation for the
    whole batch (this changes FPS behavior downstream).

    Input:
        BxNxC array
    Output:
        BxNxC array (new array via fancy indexing)
    """
    n_pts = batch_data.shape[1]
    order = np.arange(n_pts)
    np.random.shuffle(order)
    return batch_data[:, order, :]
45 |
def rotate_point_cloud(batch_data):
    """Rotate each cloud by its own random angle about the y axis (augmentation).

    The angle is ``np.random.uniform() * 2 * pi`` per shape. Points, as row
    vectors, are right-multiplied by the rotation matrix.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 float32 array, rotated batch of point clouds
    """
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for k, shape_pc in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        out[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rot)
    return out
65 |
def rotate_point_cloud_z(batch_data):
    """Rotate each cloud by its own random angle about the z axis (augmentation).

    The angle is ``np.random.uniform() * 2 * pi`` per shape; the z coordinate
    is left unchanged.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 float32 array, rotated batch of point clouds
    """
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for k, shape_pc in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, s, 0],
                        [-s, c, 0],
                        [0, 0, 1]])
        out[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rot)
    return out
85 |
def rotate_point_cloud_with_normal(batch_xyz_normal):
    ''' Randomly rotate XYZ and normals of each cloud about the y axis.

    The same per-shape rotation is applied to columns 0-2 (XYZ) and 3-5
    (normals). The input array is modified IN PLACE and also returned.

    Input:
        batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
    Output:
        B,N,6, rotated XYZ, normal point cloud (same object as the input)
    '''
    for k in range(len(batch_xyz_normal)):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        sample = batch_xyz_normal[k]  # view: writes go back into the batch
        xyz = sample[:, 0:3]
        nrm = sample[:, 3:6]
        sample[:, 0:3] = np.dot(xyz.reshape((-1, 3)), rot)
        sample[:, 3:6] = np.dot(nrm.reshape((-1, 3)), rot)
    return batch_xyz_normal
105 |
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """Apply a small random rotation to XYZ and normals of each cloud.

    Three Euler angles are drawn per shape as ``angle_sigma * N(0, 1)``,
    clipped to ``[-angle_clip, angle_clip]``, and combined as ``Rz @ Ry @ Rx``.

    Input:
        BxNx6 array, original batch of point clouds and point normals
    Return:
        float32 array of the input's shape: rotated XYZ in columns 0-2 and
        rotated normals in columns 3-5; any further columns are left as zero.
    """
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        rot_x = np.array([[1, 0, 0],
                          [0, np.cos(ax), -np.sin(ax)],
                          [0, np.sin(ax), np.cos(ax)]])
        rot_y = np.array([[np.cos(ay), 0, np.sin(ay)],
                          [0, 1, 0],
                          [-np.sin(ay), 0, np.cos(ay)]])
        rot_z = np.array([[np.cos(az), -np.sin(az), 0],
                          [np.sin(az), np.cos(az), 0],
                          [0, 0, 1]])
        rot = np.dot(rot_z, np.dot(rot_y, rot_x))
        out[k, :, 0:3] = np.dot(batch_data[k, :, 0:3].reshape((-1, 3)), rot)
        out[k, :, 3:6] = np.dot(batch_data[k, :, 3:6].reshape((-1, 3)), rot)
    return out
131 |
132 |
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """Rotate every cloud by the same fixed angle about the y axis.

    Only columns 0-2 are rotated and written; any further columns of the
    returned array stay zero.

    Input:
        BxNx3 array, original batch of point clouds
        rotation_angle: scalar angle in radians
    Return:
        float32 array of the input's shape, rotated batch of point clouds
    """
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        xyz = batch_data[k, :, 0:3].reshape((-1, 3))
        out[k, :, 0:3] = np.dot(xyz, rot)
    return out
151 |
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """Rotate XYZ and normals of every cloud by one fixed angle about y.

    Input:
        BxNx6 array, original batch of point clouds with normal
        rotation_angle: scalar angle in radians
    Return:
        float32 array of the input's shape, rotated batch of point clouds with
        normal (columns beyond the sixth stay zero)
    """
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        xyz = batch_data[k, :, 0:3].reshape((-1, 3))
        nrm = batch_data[k, :, 3:6].reshape((-1, 3))
        out[k, :, 0:3] = np.dot(xyz, rot)
        out[k, :, 3:6] = np.dot(nrm, rot)
    return out
173 |
174 |
175 |
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """Apply a small random rotation to each cloud.

    Three Euler angles are drawn per shape as ``angle_sigma * N(0, 1)``,
    clipped to ``[-angle_clip, angle_clip]``, and combined as ``Rz @ Ry @ Rx``.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 float32 array, rotated batch of point clouds
    """
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        rot_x = np.array([[1, 0, 0],
                          [0, np.cos(ax), -np.sin(ax)],
                          [0, np.sin(ax), np.cos(ax)]])
        rot_y = np.array([[np.cos(ay), 0, np.sin(ay)],
                          [0, 1, 0],
                          [-np.sin(ay), 0, np.cos(ay)]])
        rot_z = np.array([[np.cos(az), -np.sin(az), 0],
                          [np.sin(az), np.cos(az), 0],
                          [0, 0, 1]])
        rot = np.dot(rot_z, np.dot(rot_y, rot_x))
        out[k, ...] = np.dot(batch_data[k, ...].reshape((-1, 3)), rot)
    return out
199 |
200 |
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Add independent clipped Gaussian noise to every coordinate.

    The noise is ``sigma * N(0, 1)`` clipped to ``[-clip, clip]``.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array (new), jittered batch of point clouds
    """
    B, N, C = batch_data.shape
    assert(clip > 0)
    noise = sigma * np.random.randn(B, N, C)
    noise = np.clip(noise, -1 * clip, clip)
    noise += batch_data
    return noise
213 |
def shift_point_cloud(batch_data, shift_range=0.1):
    """Translate each cloud by its own random offset (modifies IN PLACE).

    Each shape gets one xyz offset drawn uniformly from
    ``[-shift_range, shift_range)``.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, the same (shifted) array object
    """
    B, N, C = batch_data.shape
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    # Broadcast one offset per shape over all of its points.
    batch_data += shifts[:, np.newaxis, :]
    return batch_data
226 |
227 |
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Scale each cloud by its own random factor (modifies IN PLACE).

    Each shape gets one factor drawn uniformly from ``[scale_low, scale_high)``.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, the same (scaled) array object
    """
    B, N, C = batch_data.shape
    scales = np.random.uniform(scale_low, scale_high, B)
    # One factor per shape, broadcast over points and channels.
    batch_data *= scales[:, np.newaxis, np.newaxis]
    return batch_data
240 |
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    ''' Randomly "drop" points by overwriting them with each cloud's first point.

    Per cloud, a ratio is drawn uniformly from ``[0, max_dropout_ratio)`` and
    each point is dropped with that probability. Modifies ``batch_pc`` IN
    PLACE and returns it.

    batch_pc: BxNx3
    '''
    n_pts = batch_pc.shape[1]
    for b, cloud in enumerate(batch_pc):
        ratio = np.random.random() * max_dropout_ratio  # 0~0.875 by default
        dropped = np.where(np.random.random((n_pts)) <= ratio)[0]
        if dropped.size > 0:
            cloud[dropped, :] = cloud[0, :]  # cloud is a view into batch_pc
    return batch_pc
--------------------------------------------------------------------------------
/models/pointnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.utils.data
5 | from torch.autograd import Variable
6 | import numpy as np
7 | import torch.nn.functional as F
8 |
class STN3d(nn.Module):
    """Spatial transformer that predicts one 3x3 matrix per input sample.

    Args:
        channel: number of input channels (``in_channels`` of the first Conv1d).

    ``forward(x)`` takes ``x`` in Conv1d layout ``(B, channel, N)`` and returns
    ``(B, 3, 3)``: the ``fc3`` output reshaped and added to the identity, so a
    zero-initialised ``fc3`` yields the identity transform.
    """

    def __init__(self, channel):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()  # not used in forward; kept so the module tree is unchanged

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max-pool over the point dimension -> (B, 1024).
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the flattened identity directly on x's device. `.cuda()` would
        # target the *current* GPU, which can differ from x's device.
        iden = torch.eye(3, dtype=torch.float32, device=x.device).view(1, 9).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, 3, 3)
        return x
45 |
46 |
47 |
48 |
class STNkd(nn.Module):
    """Predicts a per-sample k x k matrix used to transform k-dim point features.

    Same architecture as STN3d but with ``k`` input channels and ``k*k``
    regressed values; the flattened k x k identity is added before reshaping.

    Input to ``forward``: tensor [B, k, N]. Output: [B, k, k].
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        # point-wise feature extractor
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # regression head producing k*k matrix entries
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        n_batch = x.size(0)
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = F.relu(bn(conv(x)))
        pooled = x.max(dim=2, keepdim=True)[0].view(-1, 1024)

        hidden = F.relu(self.bn4(self.fc1(pooled)))
        hidden = F.relu(self.bn5(self.fc2(hidden)))
        raw = self.fc3(hidden)

        n_entries = self.k * self.k
        identity = torch.eye(self.k, dtype=torch.float32).flatten().view(1, n_entries).repeat(n_batch, 1)
        if raw.is_cuda:
            identity = identity.cuda()
        return (raw + identity).view(-1, self.k, self.k)
87 |
88 |
89 |
class PointNetEncoder(nn.Module):
    """PointNet backbone: input transform, shared MLP, optional feature transform, max-pool.

    ``forward`` takes x of shape [B, D, N]. When D > 3 only the first three
    channels are multiplied by the learned 3x3 transform; the remaining
    channels are concatenated back unchanged.

    Returns ``(features, trans, trans_feat)`` where features is [B, 1024]
    when ``global_feat`` is True, otherwise [B, 1024 + 64, N] (the global
    vector tiled over points, concatenated with 64-d per-point features).
    ``trans_feat`` is None unless ``feature_transform`` is enabled.
    """

    def __init__(self, global_feat=True, feature_transform=False, channel=3):
        super(PointNetEncoder, self).__init__()
        self.stn = STN3d(channel)
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        B, D, N = x.size()
        trans = self.stn(x)

        pts = x.transpose(2, 1)  # [B, N, D]
        if D > 3:
            coords, extra = pts.split(3, dim=2)
            pts = torch.cat([torch.bmm(coords, trans), extra], dim=2)
        else:
            pts = torch.bmm(pts, trans)
        x = F.relu(self.bn1(self.conv1(pts.transpose(2, 1))))

        trans_feat = None
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)

        if not self.global_feat:
            tiled = x.view(-1, 1024, 1).repeat(1, 1, N)
            return torch.cat([tiled, pointfeat], 1), trans, trans_feat
        return x, trans, trans_feat
135 |
136 |
137 |
138 |
139 |
def feature_transform_reguliarzer(trans):
    """Orthogonality penalty for a batch of learned square transforms.

    Computes ``mean_b || A_b A_b^T - I ||_F`` over the batch. The penalty is
    zero exactly when every A_b is orthogonal.

    Args:
        trans: tensor of shape [B, d, d].
    Returns:
        Scalar tensor with the mean penalty.
    """
    d = trans.size()[1]
    I = torch.eye(d, dtype=trans.dtype, device=trans.device)[None, :, :]
    # Fix: subtract I from the product A A^T. The previous form,
    # bmm(A, A^T - I), measured ||A (A^T - I)||. That is zero for A = I but
    # not for other orthogonal matrices (e.g. rotations).
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2), p=2))
    return loss
--------------------------------------------------------------------------------
/models/pointnet_cls.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.data
3 | import torch.nn.functional as F
4 | from pointnet import PointNetEncoder, feature_transform_reguliarzer
5 |
6 |
class get_model(nn.Module):
    """PointNet classifier: PointNetEncoder (feature transform on) + 3-layer FC head.

    ``normal_channel`` selects 6 input channels (xyz + normals) instead of 3.
    ``forward`` returns ``(log_probs, trans_feat)``; log_probs is [B, k] after
    ``log_softmax`` over dim 1, and trans_feat is the encoder's feature transform.
    """

    def __init__(self, k=40, normal_channel=True):
        super(get_model, self).__init__()
        channel = 6 if normal_channel else 3
        self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=channel)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.4)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        global_feat, trans, trans_feat = self.feat(x)
        hidden = F.relu(self.bn1(self.fc1(global_feat)))
        # dropout is applied before the second batch-norm
        hidden = self.dropout(self.fc2(hidden))
        hidden = F.relu(self.bn2(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1), trans_feat
30 |
31 |
class get_loss(torch.nn.Module):
    """NLL classification loss plus a scaled feature-transform regulariser.

    total = nll_loss(pred, target) + feature_transform_reguliarzer(trans_feat) * mat_diff_loss_scale
    ``pred`` is expected to hold log-probabilities (it is fed to ``F.nll_loss``).
    """

    def __init__(self, mat_diff_loss_scale=0.001):
        super(get_loss, self).__init__()
        self.mat_diff_loss_scale = mat_diff_loss_scale

    def forward(self, pred, target, trans_feat):
        classification = F.nll_loss(pred, target)
        regularizer = feature_transform_reguliarzer(trans_feat)
        return classification + regularizer * self.mat_diff_loss_scale
--------------------------------------------------------------------------------
/models/pointnet_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 |
def timeit(tag, t):
    """Print the seconds elapsed since timestamp ``t`` under ``tag``; return a new timestamp."""
    elapsed = time() - t
    print("{}: {}s".format(tag, elapsed))
    return time()
11 |
12 |
def pc_normalize(pc):
    """
    Normalize a point cloud into the unit sphere.

    Input:
        pc: [N, C] array of points (centroid over axis 0, norms over axis 1)
    Return:
        Array of the same shape, centred at the origin and scaled so that the
        farthest point lies at distance 1. A degenerate cloud whose points all
        coincide is returned centred (all zeros) rather than filled with NaN
        from a 0/0 division.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if m > 0:
        pc = pc / m
    return pc
23 |
24 |
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the whole
    [N, M] table comes from a single batched matmul.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    src_sq = torch.sum(src ** 2, -1).view(B, N, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(B, 1, M)
    return -2 * cross + src_sq + dst_sq
46 |
47 |
def index_points(points, idx):
    """
    Gather points per batch element using an index tensor of any trailing rank.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K, ...])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, ..., C])
    """
    # Batch index shaped [B, 1, ..., 1], then tiled to idx's shape so it pairs
    # element-wise with idx in advanced indexing.
    leading = [idx.shape[0]] + [1] * (idx.dim() - 1)
    tiling = [1] + list(idx.shape[1:])
    batch_indices = (
        torch.arange(points.shape[0], dtype=torch.long)
        .to(points.device)
        .view(leading)
        .repeat(tiling)
    )
    return points[batch_indices, idx, :]
65 |
66 |
def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling.

    The first sample of each batch element is drawn with ``torch.randint``.
    Each later sample is the point with the largest squared distance to the
    set already chosen.

    Input:
        xyz: pointcloud data, [B, N, C] (typically C == 3; any C now works)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # running minimum squared distance from every point to the selected set
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # previously .view(B, 1, 3), which crashed for coordinates with C != 3
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
89 |
90 |
91 |
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Find up to ``nsample`` neighbours within ``radius`` of each query point.

    Out-of-radius candidates are replaced by the sentinel N, which sorts after
    every real index. The smallest ``nsample`` indices are kept, and leftover
    sentinel slots are filled with the first (smallest) in-radius index.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    candidates = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    outside = square_distance(new_xyz, xyz) > radius ** 2
    candidates[outside] = N
    group_idx = candidates.sort(dim=-1)[0][:, :, :nsample]
    first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    padding = group_idx == N
    group_idx[padding] = first[padding]
    return group_idx
114 |
115 |
116 |
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Pick ``npoint`` centroids with farthest point sampling and group their ball neighbourhoods.

    Input:
        npoint: number of centroids to sample
        radius: ball-query radius
        nsample: max number of neighbours per centroid
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D], or None
        returnfps: when True, also return grouped_xyz and fps_idx
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
            (neighbour xyz relative to its centroid, then neighbour features)
        grouped_xyz: absolute neighbour xyz, [B, npoint, nsample, 3] (only if returnfps)
        fps_idx: sampled indices, [B, npoint] (only if returnfps)
    """
    B, N, C = xyz.shape
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    torch.cuda.empty_cache()
    new_xyz = index_points(xyz, fps_idx)
    torch.cuda.empty_cache()
    neighbour_idx = query_ball_point(radius, nsample, xyz, new_xyz)
    torch.cuda.empty_cache()
    grouped_xyz = index_points(xyz, neighbour_idx)  # [B, npoint, nsample, C]
    torch.cuda.empty_cache()
    relative_xyz = grouped_xyz - new_xyz.view(B, npoint, 1, C)
    torch.cuda.empty_cache()

    if points is None:
        new_points = relative_xyz
    else:
        neighbour_feats = index_points(points, neighbour_idx)
        new_points = torch.cat([relative_xyz, neighbour_feats], dim=-1)  # [B, npoint, nsample, C+D]

    if not returnfps:
        return new_xyz, new_points
    return new_xyz, new_points, grouped_xyz, fps_idx
151 |
152 |
153 |
154 |
def sample_and_group_all(xyz, points):
    """
    Treat the whole cloud as one group centred at the origin.

    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D], or None
    Return:
        new_xyz: a zero centroid per batch element, [B, 1, 3]
        new_points: all points as a single group, [B, 1, N, 3+D] (or [B, 1, N, 3])
    """
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(xyz.device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is None:
        return new_xyz, grouped_xyz
    return new_xyz, torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
173 |
174 |
175 |
176 |
class PointNetSetAbstraction(nn.Module):
    """Single-scale set-abstraction layer: sample + group, then shared 2-D MLP and max-pool.

    ``in_channel`` is the number of channels of the grouped tensor entering the
    first 1x1 Conv2d. With ``group_all`` the whole cloud forms one group;
    otherwise ``npoint`` centroids and up to ``nsample`` neighbours within
    ``radius`` are used.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint  # number of points sampled from farthest point sampling
        self.radius = radius
        self.nsample = nsample  # the number of points in each local region
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        width = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(width, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            width = out_channel
        self.group_all = group_all  # group all points into one PC if set true

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        points = None if points is None else points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, grouped = sample_and_group_all(xyz, points)
        else:
            new_xyz, grouped = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # grouped: [B, npoint, nsample, C+D] -> [B, C+D, nsample, npoint] for Conv2d
        grouped = grouped.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            grouped = F.relu(bn(conv(grouped)))

        pooled = torch.max(grouped, 2)[0]  # max over the neighbourhood axis
        return new_xyz.permute(0, 2, 1), pooled
220 |
221 |
222 |
223 |
class PointNetSetAbstractionMsg(nn.Module):
    """Multi-scale-grouping set-abstraction layer.

    For each (radius, nsample, mlp) triple, neighbourhoods around the same
    ``npoint`` FPS centroids are encoded by their own shared 2-D MLP and
    max-pooled. The per-scale features are concatenated along channels.
    Each branch's first conv takes ``in_channel + 3`` inputs: the point
    features plus the centroid-relative xyz.
    """

    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for widths in mlp_list:
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            prev = in_channel + 3
            for out_channel in widths:
                convs.append(nn.Conv2d(prev, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                prev = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))

        scale_features = []
        for i, radius in enumerate(self.radius_list):
            group_idx = query_ball_point(radius, self.nsample_list[i], xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)  # normalize neighborhood
            if points is None:
                grouped = grouped_xyz
            else:
                grouped = torch.cat([index_points(points, group_idx), grouped_xyz], dim=-1)

            grouped = grouped.permute(0, 3, 2, 1)  # [B, D-channels, K-nsample, S-npoint]
            for conv, bn in zip(self.conv_blocks[i], self.bn_blocks[i]):
                grouped = F.relu(bn(conv(grouped)))
            scale_features.append(torch.max(grouped, 2)[0])  # [B, D', S]

        return new_xyz.permute(0, 2, 1), torch.cat(scale_features, dim=1)
284 |
285 |
286 |
class PointNetFeaturePropagation(nn.Module):
    """Feature-propagation (upsampling) layer.

    Features of the sparse set ``xyz2`` are interpolated onto the dense set
    ``xyz1``. When there is a single source point they are simply tiled;
    otherwise inverse-distance weighting over the 3 nearest source points is
    used. The result is optionally concatenated with ``points1`` and passed
    through a shared 1-D MLP.
    """

    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        width = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(width, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            width = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N], or None
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)
        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated = points2.repeat(1, N, 1)
        else:
            dists, idx = square_distance(xyz1, xyz2).sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]
            recip = 1.0 / (dists + 1e-8)  # epsilon avoids division by zero
            weight = recip / torch.sum(recip, dim=2, keepdim=True)
            interpolated = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is None:
            new_points = interpolated
        else:
            new_points = torch.cat([points1.permute(0, 2, 1), interpolated], dim=-1)

        new_points = new_points.permute(0, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))
        return new_points
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import argparse
5 | import importlib
6 | import numpy as np
7 | from tqdm import tqdm
8 | import multiprocessing
9 | from pathlib import Path
10 | from torch.optim.lr_scheduler import CosineAnnealingLR
11 |
12 | sys.path.append("..")
13 | from data_utils.ModelNetDataLoader import ModelNetDataLoader
14 | from data_utils import provider
15 |
16 | BASE_DIR = os.path.dirname(os.path.abspath('__file__'))
17 | ROOT_DIR = BASE_DIR
18 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
19 | sys.path.append(os.path.join(ROOT_DIR, 'log'))
20 |
21 |
22 |
def test(model, loader, num_class=40, vote_num=1):
    """
    Evaluate ``model`` over ``loader`` on CUDA, averaging ``vote_num`` forward passes.

    Returns:
        instance_acc: mean over batches of the per-batch accuracy.
        class_acc: mean over classes of each class's per-batch accuracy,
            averaged over the batches in which that class appears.
    """
    batch_accuracies = []
    # columns: [sum of per-batch class accuracies, batches containing class, mean]
    class_stats = np.zeros((num_class, 3))
    for _, (points, target) in tqdm(enumerate(loader), total=len(loader)):
        target = target[:, 0]
        points, target = points.transpose(2, 1).cuda(), target.cuda()
        classifier = model.eval()

        votes = torch.zeros(target.size()[0], num_class).cuda()
        for vote in range(vote_num):
            vote_pred, _ = classifier(points.float())
            votes += vote_pred
        pred_choice = (votes / vote_num).data.max(1)[1]

        for label in np.unique(target.cpu()):
            label = int(label)
            in_class = target == label
            hits = pred_choice[in_class].eq(target[in_class].long().data).cpu().sum()
            class_stats[label, 0] += hits.item() / float(points[in_class].size()[0])
            class_stats[label, 1] += 1
        correct = pred_choice.eq(target.long().data).cpu().sum()
        batch_accuracies.append(correct.item() / float(points.size()[0]))

    class_stats[:, 2] = class_stats[:, 0] / class_stats[:, 1]
    return np.mean(batch_accuracies), np.mean(class_stats[:, 2])
50 |
51 |
def main(args):
    """Evaluate a trained RobustPointSet checkpoint on ``args.test_tasks`` and print accuracies.

    The model module name is taken from the stem of a ``*.txt`` file in
    ``<model_dir>/logs`` (train.py names its log after the model module).
    Weights are loaded from ``<model_dir>/checkpoints/best_model.pth``.

    Raises:
        FileNotFoundError: if ``<model_dir>/logs`` contains no ``.txt`` file,
            so the model module cannot be determined. Previously this surfaced
            as a NameError on ``model_name``.
    """
    # The former nested ``log_string`` helper referenced an undefined
    # ``logger`` and was never called, so it has been removed.

    # Hyper parameters
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.backends.cudnn.enabled = False

    # Data loading
    TEST_DATASET = ModelNetDataLoader(root=args.data_root,
                                      tasks=args.test_tasks,
                                      labels=args.test_labels,
                                      partition='test',
                                      npoint=args.num_point,
                                      normal_channel=args.normal)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # Model loading
    num_class = 40
    log_dir = os.path.join(str(args.model_dir), 'logs')
    log_files = [f for f in os.listdir(log_dir) if f.endswith('txt')]
    if not log_files:
        raise FileNotFoundError('No .txt training log found in %s; cannot determine the model module' % log_dir)
    # Same choice as before: the last matching entry in directory-listing order.
    model_name = log_files[-1].split('.')[0]

    MODEL = importlib.import_module(model_name)

    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    checkpoint = torch.load(os.path.join(str(args.model_dir), 'checkpoints', 'best_model.pth'))
    classifier.load_state_dict(checkpoint['model_state_dict'])

    with torch.no_grad():
        instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class, vote_num=args.num_votes)
    print('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
86 |
87 |
def parseArgs():
    """ Argument parser for configuring model testing

    ``--normal`` is parsed as a real boolean. Previously any value, including
    the string 'False', was kept as a (truthy) string. 'True'/'1'/'yes'
    still enable it.
    """
    def _str2bool(value):
        """Convert a command-line token such as 'True'/'false'/'1'/'no' to a bool."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='RobustPointSet trainer')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model', type=str, default='pointnet_cls', help='point cloud model')
    parser.add_argument('--ults', type=str, default='pointnet_util', help='help functions for point cloud model')
    parser.add_argument('--gpu', type=str, default='0', help='gpu device index')
    parser.add_argument('--num_point', type=int, default=2048, help='number of points')
    parser.add_argument('--num_votes', type=int, default=1, help='number of time to run testing and doing majority vote')
    parser.add_argument('--normal', type=_str2bool, default=False, help='Whether to use normal information [default: False]')
    parser.add_argument('--data_root', type=str, default='data/', help='data directory')
    parser.add_argument('--test_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be tested on during training")
    parser.add_argument('--model_dir', type=str, required=True, help="model checkpoint")
    return parser.parse_args()
102 |
if __name__ == '__main__':
    # Available data: original, jitter, translate, missing_part, sparse, rotation, occlusion
    # Example commands for strategy 1 & 2:
    #   python test.py --test_tasks test_original.npy --model_dir log/classification/2021-02-17_10-37-pointnet_cls-s1
    #   python test.py --test_tasks test_rotation.npy --model_dir log/classification/2021-02-17_10-37-pointnet_cls-s1
    args = parseArgs()
    # every test file is paired with the same label file
    args.test_labels = ['test_labels.npy'] * len(args.test_tasks)
    main(args)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import argparse
5 | import datetime
6 | import logging
7 | import shutil
8 | import importlib
9 | import numpy as np
10 | from tqdm import tqdm
11 | import multiprocessing
12 | from pathlib import Path
13 | from torch.optim.lr_scheduler import CosineAnnealingLR
14 |
15 | sys.path.append("..")
16 | from data_utils.ModelNetDataLoader import ModelNetDataLoader
17 | from data_utils import provider
18 |
19 | BASE_DIR = os.path.dirname(os.path.abspath('__file__'))
20 | ROOT_DIR = BASE_DIR
21 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
22 | sys.path.append(os.path.join(ROOT_DIR, 'log'))
23 |
24 |
25 |
26 |
def test(model, loader, num_class=40):
    """
    Evaluate ``model`` over every batch of ``loader`` on CUDA.

    Returns:
        instance_acc: mean over batches of the per-batch accuracy.
        class_acc: mean over classes of each class's per-batch accuracy,
            averaged over the batches in which that class appears.
    """
    batch_accuracies = []
    # columns: [sum of per-batch class accuracies, batches containing class, mean]
    class_stats = np.zeros((num_class, 3))
    for _, (points, target) in tqdm(enumerate(loader), total=len(loader)):
        target = target[:, 0]
        points, target = points.transpose(2, 1).cuda(), target.cuda()
        classifier = model.eval()
        pred, _ = classifier(points.float())
        pred_choice = pred.data.max(1)[1]

        for label in np.unique(target.cpu()):
            label = int(label)
            in_class = target == label
            hits = pred_choice[in_class].eq(target[in_class].long().data).cpu().sum()
            class_stats[label, 0] += hits.item() / float(points[in_class].size()[0])
            class_stats[label, 1] += 1
        correct = pred_choice.eq(target.long().data).cpu().sum()
        batch_accuracies.append(correct.item() / float(points.size()[0]))

    class_stats[:, 2] = class_stats[:, 0] / class_stats[:, 1]
    return np.mean(batch_accuracies), np.mean(class_stats[:, 2])
49 |
50 |
51 |
def _make_experiment_dirs(args):
    """Create ``log/classification/<run>/{checkpoints,logs}`` and return the three paths.

    ``<run>`` is ``args.log_dir`` when given, otherwise
    ``<YYYY-mm-dd_HH-MM>-<model>-<task>``.
    """
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) + '-' + args.model + '-' + args.task
    run_name = timestr if args.log_dir is None else args.log_dir
    experiment_dir = Path('log/').joinpath('classification').joinpath(run_name)
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    return experiment_dir, checkpoints_dir, log_dir


def _make_logger(log_dir, model_name):
    """Attach a file handler writing ``<log_dir>/<model_name>.txt`` to the "Model" logger and return it.

    test.py recovers the model module name from this file's stem.
    """
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, model_name))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger


def _train_one_epoch(classifier, criterion, optimizer, loader):
    """Run one optimisation pass over ``loader`` on CUDA; return this epoch's mean per-batch accuracy."""
    classifier = classifier.train()
    batch_accuracies = []  # reset every epoch (previously accumulated across all epochs)
    for points, target in tqdm(loader, total=len(loader), smoothing=0.9):
        # .float() casts to float32, as the old numpy -> torch.Tensor round trip did
        points = points.float().transpose(2, 1)
        target = target[:, 0]
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()

        pred, trans_feat = classifier(points)
        loss = criterion(pred, target.long(), trans_feat)
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.long().data).cpu().sum()
        batch_accuracies.append(correct.item() / float(points.size()[0]))
        loss.backward()
        optimizer.step()
    return np.mean(batch_accuracies)


def main(args):
    """Train ``args.model`` on ``args.train_tasks`` and keep the best checkpoint.

    The model is evaluated on ``args.test_tasks`` every other epoch and after the
    final epoch. Whenever the test instance accuracy matches or beats the best so
    far, the state is saved to ``<experiment_dir>/checkpoints/best_model.pth``.
    If that file already exists at start-up (only possible with an explicit
    ``--log_dir``), the model weights are restored from it and training resumes
    at the stored epoch.
    """
    def log_string(msg):
        logger.info(msg)
        print(msg)

    ### Hyper Parameters ###
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    ### Create Dir ###
    experiment_dir, checkpoints_dir, log_dir = _make_experiment_dirs(args)

    ### Log ###
    logger = _make_logger(log_dir, args.model)
    log_string('PARAMETER ...')
    log_string(args)

    ### Data Loading ###
    log_string('Load dataset ...')
    TRAIN_DATASET = ModelNetDataLoader(root=args.data_root,
                                       tasks=args.train_tasks,
                                       labels=args.train_labels,
                                       partition='train',
                                       npoint=args.num_point,
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=args.data_root,
                                      tasks=args.test_tasks,
                                      labels=args.test_labels,
                                      partition='test',
                                      npoint=args.num_point,
                                      normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=0)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=0)

    ### Model Loading ###
    num_class = 40
    MODEL = importlib.import_module(args.model)
    # archive the model definition and its helper module (e.g. pointnet_util) with the run
    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/%s.py' % args.ults, str(experiment_dir))

    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    # Resume only when a checkpoint actually exists. The former bare `except:`
    # also hid corrupt/incompatible checkpoints (and even KeyboardInterrupt),
    # and would then silently overwrite them.
    checkpoint_path = checkpoints_dir.joinpath('best_model.pth')
    if checkpoint_path.exists():
        checkpoint = torch.load(str(checkpoint_path))
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    else:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay_rate)

    global_epoch = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    best_epoch = start_epoch

    ### Training ###
    logger.info('Start training ...')
    last_epoch = args.epoch - 1
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))

        train_instance_acc = _train_one_epoch(classifier, criterion, optimizer, trainDataLoader)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        # Evaluate every other epoch and always after the final one. The old
        # check `epoch == args.epoch` could never be true inside range().
        if epoch % 2 == 0 or epoch == last_epoch:
            with torch.no_grad():
                instance_acc, class_acc = test(classifier.eval(), testDataLoader)

            improved = instance_acc >= best_instance_acc
            if improved:
                best_instance_acc = instance_acc
                best_epoch = epoch + 1
            if class_acc >= best_class_acc:
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

            if improved:
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)

        global_epoch += 1
        # step the LR schedule once per epoch
        scheduler.step()

    logger.info('End of training...')
212 |
213 |
214 |
215 |
216 |
def parseArgs():
    """ Argument parser for configuring model training

    ``--normal`` is parsed as a real boolean. Previously any value, including
    the string 'False', was kept as a (truthy) string. 'True'/'1'/'yes'
    still enable it.
    """
    def _str2bool(value):
        """Convert a command-line token such as 'True'/'false'/'1'/'no' to a bool."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='RobustPointSet trainer')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model', type=str, default='pointnet_cls', help='point cloud model')
    parser.add_argument('--ults', type=str, default='pointnet_util', help='help functions for point cloud model')
    parser.add_argument('--task', type=str, default='s1', help='Stragety 1 or 2')
    parser.add_argument('--epoch', type=int, default=300, help='training epoch')
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--gpu', type=str, default='0', help='gpu device index')
    parser.add_argument('--num_point', type=int, default=2048, help='number of points')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer')
    parser.add_argument('--log_dir', type=str, default=None, help='directory to store training log')
    parser.add_argument('--decay_rate', type=float, default=1e-4)
    parser.add_argument('--lr_decay_rate', type=float, default=0.5)
    parser.add_argument('--lr_decay_step', type=int, default=100)
    parser.add_argument('--lr_clip', type=float, default=1e-7)
    parser.add_argument('--normal', type=_str2bool, default=False, help='Whether to use normal information [default: False]')
    parser.add_argument('--data_root', type=str, default='data/', help='data directory')
    parser.add_argument('--train_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be trained on")
    parser.add_argument('--test_tasks', type=str, nargs='+', required=True, help="List of RobustPointSet files to be tested on during training")
    return parser.parse_args()
239 |
240 |
241 |
if __name__ == '__main__':
    # Available data: original, jitter, translate, missing_part, sparse, rotation, occlusion
    # Example commands for strategy 1:
    #   python train.py --train_tasks train_original.npy --test_tasks test_original.npy
    #   python train.py --train_tasks train_original.npy --test_tasks test_sparse.npy
    # Example commands for strategy 2:
    #   python train.py --train_tasks train_original.npy train_jitter.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy --test_tasks test_occlusion.npy
    #   python train.py --train_tasks train_original.npy train_translate.npy train_missing_part.npy train_sparse.npy train_rotation.npy train_occlusion.npy --test_tasks test_jitter.npy
    args = parseArgs()
    # each task file shares the label file of its partition
    args.train_labels = ['train_labels.npy'] * len(args.train_tasks)
    args.test_labels = ['test_labels.npy'] * len(args.test_tasks)
    main(args)
--------------------------------------------------------------------------------