├── LICENSE ├── README.md ├── core ├── data.py ├── main_cls.py ├── main_normal.py ├── main_partseg.py ├── models │ ├── curvenet_cls.py │ ├── curvenet_normal.py │ ├── curvenet_seg.py │ ├── curvenet_util.py │ └── walk.py ├── start_cls.sh ├── start_normal.sh ├── start_part.sh ├── test_cls.sh ├── test_normal.sh ├── test_part.sh ├── util.py └── visualize_curves.py ├── poster3.png └── teaser.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tiange Xiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CurveNet 2 | Official implementation of "Walk in the Cloud: Learning Curves for Point Clouds Shape Analysis", ICCV 2021 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/walk-in-the-cloud-learning-curves-for-point/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=walk-in-the-cloud-learning-curves-for-point) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/walk-in-the-cloud-learning-curves-for-point/3d-part-segmentation-on-shapenet-part)](https://paperswithcode.com/sota/3d-part-segmentation-on-shapenet-part?p=walk-in-the-cloud-learning-curves-for-point) 6 | 7 | Paper: https://arxiv.org/abs/2105.01288 8 | 9 | ![CurveNet](./poster3.png) 10 | 11 | ## Requirements 12 | - Python>=3.7 13 | - PyTorch>=1.2 14 | - Packages: glob, h5py, sklearn 15 | 16 | ## Contents 17 | - [Point Cloud Classification](#point-cloud-classification) 18 | - [Point Cloud Part Segmentation](#point-cloud-part-segmentation) 19 | - [Point Cloud Normal Estimation](#point-cloud-normal-estimation) 20 | - [Point Cloud Classification Under Corruptions](#point-cloud-classification-under-corruptions) 21 | 22 | **NOTE:** Please change your current directory to ```core/``` before executing the following commands. 23 | 24 | ## Point Cloud Classification 25 | ### Data 26 | 27 | The ModelNet40 dataset is primarily used for the classification experiments. On the first run, the program will automatically download the data if it is not in ```data/```. Or, you can manually download the [official data](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip) and unzip it to ```data/```. 
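To check that the data is where the loader expects it before starting a run, a minimal sketch along the following lines may help. It reuses the glob pattern from `load_data_cls` in `core/data.py`; the helper name `find_modelnet40_files` is hypothetical and not part of this repo, and `'../data/'` is assumed to be the unchanged default `DATA_DIR` when running from `core/`.

```python
import glob
import os


def find_modelnet40_files(data_dir, partition):
    """Return the sorted .h5 files for one partition ('train' or 'test').

    Uses the same glob pattern as load_data_cls in core/data.py.
    The helper name is hypothetical (not part of the repo).
    """
    pattern = os.path.join(data_dir, 'modelnet40*hdf5_2048', '*%s*.h5' % partition)
    return sorted(glob.glob(pattern))


if __name__ == '__main__':
    # '../data/' is assumed to be the default DATA_DIR, relative to core/.
    for partition in ('train', 'test'):
        files = find_modelnet40_files('../data/', partition)
        print('%s: %d file(s) found' % (partition, len(files)))
```

If a partition reports zero files, the data is probably not where the loader looks for it.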
28 | 29 | Alternatively, you can place your downloaded data anywhere you like, and set ```DATA_DIR``` in ```core/data.py``` to that path. Otherwise, the download will still be automatically triggered. 30 | 31 | ### Train 32 | 33 | Train with our default settings (same as in the paper): 34 | 35 | ``` 36 | python3 main_cls.py --exp_name=curvenet_cls_1 37 | ``` 38 | 39 | Train with customized settings with the flags: ```--lr```, ```--scheduler```, ```--batch_size```. 40 | 41 | Alternatively, you can directly modify ```core/start_cls.sh``` and simply run: 42 | 43 | ``` 44 | ./start_cls.sh 45 | ``` 46 | 47 | **NOTE:** Our reported model achieves **93.8%/94.2%** accuracy (see sections below). However, due to randomness, reproducing the best result might require several training runs. Hence, we also provide another benchmark result here (the average of 5 runs with different random seeds), which is **93.65%** accuracy. 48 | 49 | 50 | 51 | ### Evaluation 52 | 53 | 54 | Evaluate without voting: 55 | ``` 56 | python3 main_cls.py --exp_name=curvenet_cls_1 --eval=True --model_path=PATH_TO_YOUR_MODEL 57 | ``` 58 | 59 | Alternatively, you can directly modify ```core/test_cls.sh``` and simply run: 60 | ``` 61 | ./test_cls.sh 62 | ``` 63 | 64 | For voting, we used the ```voting_evaluate_cls.py``` script provided in [RSCNN](https://github.com/Yochengliu/Relation-Shape-CNN). Please refer to their license for usage. 65 | 66 | ### Evaluation with our pretrained model: 67 | 68 | Please download our pretrained model ```cls/``` at [google drive](https://drive.google.com/drive/folders/1kX-zIipyzB0iMaopcijzdTRuHeTzfTSz?usp=sharing). 69 | 70 | And then run: 71 | 72 | ``` 73 | python3 main_cls.py --exp_name=curvenet_cls_pretrained --eval=True --model_path=PATH_TO_PRETRAINED/cls/models/model.t7 74 | ``` 75 | 76 |   77 | ## Point Cloud Part Segmentation 78 | ### Data 79 | 80 | The ShapeNet Part dataset is primarily used for the part segmentation experiments. 
On the first run, the program will automatically download the data if it is not in ```data/```. Or, you can manually download the [official data](https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip) and unzip it to ```data/```. 81 | 82 | Alternatively, you can place your downloaded data anywhere you like, and set ```DATA_DIR``` in ```core/data.py``` to that path. Otherwise, the download will still be automatically triggered. 83 | 84 | ### Train 85 | 86 | Train with our default settings (same as in the paper): 87 | 88 | ``` 89 | python3 main_partseg.py --exp_name=curvenet_seg_1 90 | ``` 91 | 92 | Train with customized settings with the flags: ```--lr```, ```--scheduler```, ```--batch_size```. 93 | 94 | Alternatively, you can directly modify ```core/start_part.sh``` and simply run: 95 | 96 | ``` 97 | ./start_part.sh 98 | ``` 99 | 100 | **NOTE:** Our reported model achieves **86.6%/86.8%** mIoU (see sections below). However, due to randomness, reproducing the best result might require several training runs. Hence, we also provide another benchmark result here (the average of 5 runs with different random seeds), which is **86.46%** mIoU. 101 | 102 | 103 | 104 | ### Evaluation 105 | 106 | Evaluate without voting: 107 | ``` 108 | python3 main_partseg.py --exp_name=curvenet_seg_1 --eval=True --model_path=PATH_TO_YOUR_MODEL 109 | ``` 110 | 111 | Alternatively, you can directly modify ```core/test_part.sh``` and simply run: 112 | ``` 113 | ./test_part.sh 114 | ``` 115 | 116 | For voting, we used the ```voting_evaluate_partseg.py``` script provided in [RSCNN](https://github.com/Yochengliu/Relation-Shape-CNN). Please refer to their license for usage. 117 | 118 | ### Evaluation with our pretrained model: 119 | 120 | Please download our pretrained model ```partseg/``` at [google drive](https://drive.google.com/drive/folders/1kX-zIipyzB0iMaopcijzdTRuHeTzfTSz?usp=sharing). 
121 | 122 | And then run: 123 | 124 | ``` 125 | python3 main_partseg.py --exp_name=curvenet_seg_pretrained --eval=True --model_path=PATH_TO_PRETRAINED/partseg/models/model.t7 126 | ``` 127 | 128 |   129 | ## Point Cloud Normal Estimation 130 | 131 | ### Data 132 | 133 | The ModelNet40 dataset is used for the normal estimation experiments. We have preprocessed the raw ModelNet40 dataset into ```.h5``` files. Each point cloud instance contains 2048 randomly sampled points and point-to-point normal ground truths. 134 | 135 | Please download our processed data [here](https://drive.google.com/file/d/1j6lB3ZOF0_x_l9bqdchAxIYBi7Devie8/view?usp=sharing) and place it to ```data/```, or you need to specify the data root path in ```core/data.py```. 136 | 137 | ### Train 138 | 139 | Train with our default settings (same as in the paper): 140 | 141 | ``` 142 | python3 main_normal.py --exp_name=curvenet_normal_1 143 | ``` 144 | 145 | Train with customized settings with the flags: ```--multiplier```, ```--lr```, ```--scheduler```, ```--batch_size```. 146 | 147 | Alternatively, you can directly modify ```core/start_normal.sh``` and simply run: 148 | 149 | ``` 150 | ./start_normal.sh 151 | ``` 152 | 153 | ### Evaluation 154 | 155 | Evaluate without voting: 156 | ``` 157 | python3 main_normal.py --exp_name=curvenet_normal_1 --eval=True --model_path=PATH_TO_YOUR_MODEL 158 | ``` 159 | 160 | Alternatively, you can directly modify ```core/test_normal.sh``` and simply run: 161 | ``` 162 | ./test_normal.sh 163 | ``` 164 | 165 | ### Evaluation with our pretrained model: 166 | 167 | Please download our pretrained model ```normal/``` at [google drive](https://drive.google.com/drive/folders/1kX-zIipyzB0iMaopcijzdTRuHeTzfTSz?usp=sharing). 
168 | 169 | And then run: 170 | 171 | ``` 172 | python3 main_normal.py --exp_name=curvenet_normal_pretrained --eval=True --model_path=PATH_TO_PRETRAINED/normal/models/model.t7 173 | ``` 174 | 175 |   176 | ## Point Cloud Classification Under Corruptions 177 | In [a recent work](https://arxiv.org/abs/2201.12296), Sun et al. studied the robustness of state-of-the-art point cloud processing architectures under common corruptions. **CurveNet was verified by them to be the best architecture to function on common corruptions.** 178 | Please refer to [their official repo](https://github.com/jiachens/ModelNet40-C) for details. 179 | 180 | ## Citation 181 | 182 | If you find this repo useful in your work or research, please cite: 183 | 184 | ``` 185 | @InProceedings{Xiang_2021_ICCV, 186 | author = {Xiang, Tiange and Zhang, Chaoyi and Song, Yang and Yu, Jianhui and Cai, Weidong}, 187 | title = {Walk in the Cloud: Learning Curves for Point Clouds Shape Analysis}, 188 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 189 | month = {October}, 190 | year = {2021}, 191 | pages = {915-924} 192 | } 193 | ``` 194 | 195 | ## Acknowledgement 196 | 197 | Our code borrows a lot from: 198 | - [DGCNN](https://github.com/WangYueFt/dgcnn) 199 | - [DGCNN.pytorch](https://github.com/AnTao97/dgcnn.pytorch) 200 | - [CloserLook3D](https://github.com/zeliu98/CloserLook3D) 201 | -------------------------------------------------------------------------------- /core/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Yue Wang 3 | @Contact: yuewangx@mit.edu 4 | @File: data.py 5 | @Time: 2018/10/13 6:21 PM 6 | 7 | Modified by 8 | @Author: Tiange Xiang 9 | @Contact: txia7609@uni.sydney.edu.au 10 | @Time: 2021/1/21 3:10 PM 11 | """ 12 | 13 | 14 | import os 15 | import sys 16 | import glob 17 | import h5py 18 | import numpy as np 19 | import torch 20 | from torch.utils.data import Dataset 21 | 22 | 23 | # 
change this to your data root 24 | DATA_DIR = '../data/' 25 | 26 | def download_modelnet40(): 27 | if not os.path.exists(DATA_DIR): 28 | os.mkdir(DATA_DIR) 29 | if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): 30 | os.mkdir(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')) 31 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 32 | zipfile = os.path.basename(www) 33 | os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile)) 34 | os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) 35 | os.system('rm %s' % (zipfile)) 36 | 37 | 38 | def download_shapenetpart(): 39 | if not os.path.exists(DATA_DIR): 40 | os.mkdir(DATA_DIR) 41 | if not os.path.exists(os.path.join(DATA_DIR, 'shapenet_part_seg_hdf5_data')): 42 | os.mkdir(os.path.join(DATA_DIR, 'shapenet_part_seg_hdf5_data')) 43 | www = 'https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip' 44 | zipfile = os.path.basename(www) 45 | os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile)) 46 | os.system('mv %s %s' % (zipfile[:-4], os.path.join(DATA_DIR, 'shapenet_part_seg_hdf5_data'))) 47 | os.system('rm %s' % (zipfile)) 48 | 49 | 50 | def load_data_normal(partition): 51 | f = h5py.File(os.path.join(DATA_DIR, 'modelnet40_normal', 'normal_%s.h5'%partition), 'r+') 52 | data = f['xyz'][:].astype('float32') 53 | label = f['normal'][:].astype('float32') 54 | f.close() 55 | return data, label 56 | 57 | 58 | def load_data_cls(partition): 59 | download_modelnet40() 60 | all_data = [] 61 | all_label = [] 62 | for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40*hdf5_2048', '*%s*.h5'%partition)): 63 | f = h5py.File(h5_name, 'r+') 64 | data = f['data'][:].astype('float32') 65 | label = f['label'][:].astype('int64') 66 | f.close() 67 | all_data.append(data) 68 | all_label.append(label) 69 | all_data = np.concatenate(all_data, axis=0) 70 | all_label = np.concatenate(all_label, axis=0) 71 | return all_data, all_label 72 | 73 | 74 | def 
load_data_partseg(partition): 75 | download_shapenetpart() 76 | all_data = [] 77 | all_label = [] 78 | all_seg = [] 79 | if partition == 'trainval': 80 | file = glob.glob(os.path.join(DATA_DIR, 'shapenet_part_seg_hdf5_data', 'hdf5_data', '*train*.h5')) \ 81 | + glob.glob(os.path.join(DATA_DIR, 'shapenet_part_seg_hdf5_data', 'hdf5_data', '*val*.h5')) 82 | else: 83 | file = glob.glob(os.path.join(DATA_DIR, 'shapenet_part_seg_hdf5_data', 'hdf5_data', '*%s*.h5'%partition)) 84 | for h5_name in file: 85 | f = h5py.File(h5_name, 'r+') 86 | data = f['data'][:].astype('float32') 87 | label = f['label'][:].astype('int64') 88 | seg = f['pid'][:].astype('int64') 89 | f.close() 90 | all_data.append(data) 91 | all_label.append(label) 92 | all_seg.append(seg) 93 | all_data = np.concatenate(all_data, axis=0) 94 | all_label = np.concatenate(all_label, axis=0) 95 | all_seg = np.concatenate(all_seg, axis=0) 96 | return all_data, all_label, all_seg 97 | 98 | 99 | def translate_pointcloud(pointcloud): 100 | xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) 101 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 102 | 103 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 104 | return translated_pointcloud 105 | 106 | 107 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): 108 | N, C = pointcloud.shape 109 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 110 | return pointcloud 111 | 112 | 113 | def rotate_pointcloud(pointcloud): 114 | theta = np.pi*2 * np.random.uniform() 115 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]]) 116 | pointcloud[:,[0,2]] = pointcloud[:,[0,2]].dot(rotation_matrix) # random rotation (x,z) 117 | return pointcloud 118 | 119 | 120 | class ModelNet40(Dataset): 121 | def __init__(self, num_points, partition='train'): 122 | self.data, self.label = load_data_cls(partition) 123 | self.num_points = num_points 124 | self.partition = partition 125 
| 126 | def __getitem__(self, item): 127 | pointcloud = self.data[item][:self.num_points] 128 | label = self.label[item] 129 | if self.partition == 'train': 130 | pointcloud = translate_pointcloud(pointcloud) 131 | #pointcloud = rotate_pointcloud(pointcloud) 132 | np.random.shuffle(pointcloud) 133 | return pointcloud, label 134 | 135 | def __len__(self): 136 | return self.data.shape[0] 137 | 138 | class ModelNetNormal(Dataset): 139 | def __init__(self, num_points, partition='train'): 140 | self.data, self.label = load_data_normal(partition) 141 | self.num_points = num_points 142 | self.partition = partition 143 | 144 | def __getitem__(self, item): 145 | pointcloud = self.data[item][:self.num_points] 146 | label = self.label[item][:self.num_points] 147 | if self.partition == 'train': 148 | #pointcloud = translate_pointcloud(pointcloud) 149 | idx = np.arange(0, pointcloud.shape[0], dtype=np.int64) 150 | np.random.shuffle(idx) 151 | pointcloud = self.data[item][idx] 152 | label = self.label[item][idx] 153 | return pointcloud, label 154 | 155 | def __len__(self): 156 | return self.data.shape[0] 157 | 158 | class ShapeNetPart(Dataset): 159 | def __init__(self, num_points=2048, partition='train', class_choice=None): 160 | self.data, self.label, self.seg = load_data_partseg(partition) 161 | self.cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4, 162 | 'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9, 163 | 'motor': 10, 'mug': 11, 'pistol': 12, 'rocket': 13, 'skateboard': 14, 'table': 15} 164 | self.seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3] 165 | self.index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47] 166 | self.num_points = num_points 167 | self.partition = partition 168 | self.class_choice = class_choice 169 | 170 | if self.class_choice != None: 171 | id_choice = self.cat2id[self.class_choice] 172 | indices = (self.label == id_choice).squeeze() 173 | self.data = self.data[indices] 174 | self.label = 
self.label[indices] 175 | self.seg = self.seg[indices] 176 | self.seg_num_all = self.seg_num[id_choice] 177 | self.seg_start_index = self.index_start[id_choice] 178 | else: 179 | self.seg_num_all = 50 180 | self.seg_start_index = 0 181 | 182 | def __getitem__(self, item): 183 | pointcloud = self.data[item][:self.num_points] 184 | label = self.label[item] 185 | seg = self.seg[item][:self.num_points] 186 | if self.partition == 'trainval': 187 | pointcloud = translate_pointcloud(pointcloud) 188 | indices = list(range(pointcloud.shape[0])) 189 | np.random.shuffle(indices) 190 | pointcloud = pointcloud[indices] 191 | seg = seg[indices] 192 | return pointcloud, label, seg 193 | 194 | def __len__(self): 195 | return self.data.shape[0] 196 | -------------------------------------------------------------------------------- /core/main_cls.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Yue Wang 3 | @Contact: yuewangx@mit.edu 4 | @File: main_cls.py 5 | @Time: 2018/10/13 10:39 PM 6 | 7 | Modified by 8 | @Author: Tiange Xiang 9 | @Contact: txia7609@uni.sydney.edu.au 10 | @Time: 2021/01/21 3:10 PM 11 | """ 12 | 13 | from __future__ import print_function 14 | import os 15 | import argparse 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR 21 | from data import ModelNet40 22 | from models.curvenet_cls import CurveNet 23 | import numpy as np 24 | from torch.utils.data import DataLoader 25 | from util import cal_loss, IOStream 26 | import sklearn.metrics as metrics 27 | 28 | 29 | def _init_(): 30 | # fix random seed 31 | torch.manual_seed(seed) 32 | np.random.seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | torch.cuda.manual_seed(seed) 35 | torch.set_printoptions(10) 36 | torch.backends.cudnn.benchmark = False 37 | torch.backends.cudnn.deterministic = True 38 | 
os.environ['PYTHONHASHSEED'] = str(seed) 39 | 40 | # prepare file structures 41 | if not os.path.exists('../checkpoints'): 42 | os.makedirs('../checkpoints') 43 | if not os.path.exists('../checkpoints/'+args.exp_name): 44 | os.makedirs('../checkpoints/'+args.exp_name) 45 | if not os.path.exists('../checkpoints/'+args.exp_name+'/'+'models'): 46 | os.makedirs('../checkpoints/'+args.exp_name+'/'+'models') 47 | os.system('cp main_cls.py ../checkpoints/'+args.exp_name+'/main_cls.py.backup') 48 | os.system('cp models/curvenet_cls.py ../checkpoints/'+args.exp_name+'/curvenet_cls.py.backup') 49 | 50 | def train(args, io): 51 | train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=8, 52 | batch_size=args.batch_size, shuffle=True, drop_last=True) 53 | test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8, 54 | batch_size=args.test_batch_size, shuffle=False, drop_last=False) 55 | 56 | device = torch.device("cuda" if args.cuda else "cpu") 57 | io.cprint("Let's use" + str(torch.cuda.device_count()) + "GPUs!") 58 | 59 | # create model 60 | model = CurveNet().to(device) 61 | model = nn.DataParallel(model) 62 | 63 | if args.use_sgd: 64 | io.cprint("Use SGD") 65 | opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4) 66 | else: 67 | io.cprint("Use Adam") 68 | opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 69 | 70 | if args.scheduler == 'cos': 71 | scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3) 72 | elif args.scheduler == 'step': 73 | scheduler = MultiStepLR(opt, [120, 160], gamma=0.1) 74 | 75 | criterion = cal_loss 76 | 77 | best_test_acc = 0 78 | for epoch in range(args.epochs): 79 | #################### 80 | # Train 81 | #################### 82 | train_loss = 0.0 83 | count = 0.0 84 | model.train() 85 | train_pred = [] 86 | train_true = [] 87 | for data, label in train_loader: 88 | data, label = data.to(device), 
label.to(device).squeeze() 89 | data = data.permute(0, 2, 1) 90 | batch_size = data.size()[0] 91 | opt.zero_grad() 92 | logits = model(data) 93 | loss = criterion(logits, label) 94 | loss.backward() 95 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 96 | opt.step() 97 | preds = logits.max(dim=1)[1] 98 | count += batch_size 99 | train_loss += loss.item() * batch_size 100 | train_true.append(label.cpu().numpy()) 101 | train_pred.append(preds.detach().cpu().numpy()) 102 | if args.scheduler == 'cos': 103 | scheduler.step() 104 | elif args.scheduler == 'step': 105 | if opt.param_groups[0]['lr'] > 1e-5: 106 | scheduler.step() 107 | if opt.param_groups[0]['lr'] < 1e-5: 108 | for param_group in opt.param_groups: 109 | param_group['lr'] = 1e-5 110 | 111 | train_true = np.concatenate(train_true) 112 | train_pred = np.concatenate(train_pred) 113 | outstr = 'Train %d, loss: %.6f, train acc: %.6f' % (epoch, train_loss*1.0/count, 114 | metrics.accuracy_score( 115 | train_true, train_pred)) 116 | io.cprint(outstr) 117 | 118 | #################### 119 | # Test 120 | #################### 121 | test_loss = 0.0 122 | count = 0.0 123 | model.eval() 124 | test_pred = [] 125 | test_true = [] 126 | for data, label in test_loader: 127 | data, label = data.to(device), label.to(device).squeeze() 128 | data = data.permute(0, 2, 1) 129 | batch_size = data.size()[0] 130 | logits = model(data) 131 | loss = criterion(logits, label) 132 | preds = logits.max(dim=1)[1] 133 | count += batch_size 134 | test_loss += loss.item() * batch_size 135 | test_true.append(label.cpu().numpy()) 136 | test_pred.append(preds.detach().cpu().numpy()) 137 | test_true = np.concatenate(test_true) 138 | test_pred = np.concatenate(test_pred) 139 | test_acc = metrics.accuracy_score(test_true, test_pred) 140 | outstr = 'Test %d, loss: %.6f, test acc: %.6f' % (epoch, test_loss*1.0/count, test_acc) 141 | io.cprint(outstr) 142 | if test_acc >= best_test_acc: 143 | best_test_acc = test_acc 144 | 
torch.save(model.state_dict(), '../checkpoints/%s/models/model.t7' % args.exp_name) 145 | io.cprint('best: %.3f' % best_test_acc) 146 | 147 | def test(args, io): 148 | test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), 149 | batch_size=args.test_batch_size, shuffle=False, drop_last=False) 150 | 151 | device = torch.device("cuda" if args.cuda else "cpu") 152 | 153 | #Try to load models 154 | model = CurveNet().to(device) 155 | model = nn.DataParallel(model) 156 | model.load_state_dict(torch.load(args.model_path)) 157 | 158 | model = model.eval() 159 | test_acc = 0.0 160 | count = 0.0 161 | test_true = [] 162 | test_pred = [] 163 | for data, label in test_loader: 164 | 165 | data, label = data.to(device), label.to(device).squeeze() 166 | data = data.permute(0, 2, 1) 167 | batch_size = data.size()[0] 168 | logits = model(data) 169 | preds = logits.max(dim=1)[1] 170 | test_true.append(label.cpu().numpy()) 171 | test_pred.append(preds.detach().cpu().numpy()) 172 | test_true = np.concatenate(test_true) 173 | test_pred = np.concatenate(test_pred) 174 | test_acc = metrics.accuracy_score(test_true, test_pred) 175 | outstr = 'Test :: test acc: %.6f'%(test_acc) 176 | io.cprint(outstr) 177 | 178 | 179 | if __name__ == "__main__": 180 | # Training settings 181 | parser = argparse.ArgumentParser(description='Point Cloud Recognition') 182 | parser.add_argument('--exp_name', type=str, default='exp', metavar='N', 183 | help='Name of the experiment') 184 | parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N', 185 | choices=['modelnet40']) 186 | parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', 187 | help='Size of batch)') 188 | parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size', 189 | help='Size of batch)') 190 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 191 | help='number of episode to train ') 192 | parser.add_argument('--use_sgd', 
type=bool, default=True, 193 | help='Use SGD') 194 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 195 | help='learning rate (default: 0.001, 0.1 if using sgd)') 196 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 197 | help='SGD momentum (default: 0.9)') 198 | parser.add_argument('--scheduler', type=str, default='cos', metavar='N', 199 | choices=['cos', 'step'], 200 | help='Scheduler to use, [cos, step]') 201 | parser.add_argument('--no_cuda', type=bool, default=False, 202 | help='enables CUDA training') 203 | parser.add_argument('--eval', type=bool, default=False, 204 | help='evaluate the model') 205 | parser.add_argument('--num_points', type=int, default=1024, 206 | help='num of points to use') 207 | parser.add_argument('--model_path', type=str, default='', metavar='N', 208 | help='Pretrained model path') 209 | args = parser.parse_args() 210 | 211 | seed = np.random.randint(1, 10000) 212 | 213 | _init_() 214 | 215 | if args.eval: 216 | io = IOStream('../checkpoints/' + args.exp_name + '/eval.log') 217 | else: 218 | io = IOStream('../checkpoints/' + args.exp_name + '/run.log') 219 | io.cprint(str(args)) 220 | io.cprint('random seed is: ' + str(seed)) 221 | 222 | args.cuda = not args.no_cuda and torch.cuda.is_available() 223 | 224 | if args.cuda: 225 | io.cprint( 226 | 'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices') 227 | else: 228 | io.cprint('Using CPU') 229 | 230 | if not args.eval: 231 | train(args, io) 232 | else: 233 | with torch.no_grad(): 234 | test(args, io) 235 | -------------------------------------------------------------------------------- /core/main_normal.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Tiange Xiang 3 | @Contact: txia7609@uni.sydney.edu.au 4 | @File: main_normal.py 5 | @Time: 2021/01/21 3:10 PM 6 | """ 7 | 8 | 9 | from __future__ import print_function 10 | import os 11 
| import argparse 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR 17 | from data import ModelNetNormal 18 | from models.curvenet_normal import CurveNet 19 | import numpy as np 20 | from torch.utils.data import DataLoader 21 | from util import IOStream 22 | 23 | 24 | def _init_(): 25 | # fix random seed 26 | torch.manual_seed(seed) 27 | np.random.seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | torch.cuda.manual_seed(seed) 30 | torch.set_printoptions(10) 31 | torch.backends.cudnn.benchmark = False 32 | torch.backends.cudnn.deterministic = True 33 | os.environ['PYTHONHASHSEED'] = str(seed) 34 | 35 | # prepare file structures 36 | if not os.path.exists('../checkpoints'): 37 | os.makedirs('../checkpoints') 38 | if not os.path.exists('../checkpoints/'+args.exp_name): 39 | os.makedirs('../checkpoints/'+args.exp_name) 40 | if not os.path.exists('../checkpoints/'+args.exp_name+'/'+'models'): 41 | os.makedirs('../checkpoints/'+args.exp_name+'/'+'models') 42 | os.system('cp main_normal.py ../checkpoints/'+args.exp_name+'/main_normal.py.backup') 43 | os.system('cp models/curvenet_normal.py ../checkpoints/'+args.exp_name+'/curvenet_normal.py.backup') 44 | 45 | def train(args, io): 46 | train_loader = DataLoader(ModelNetNormal(args.num_points, partition='train'), 47 | num_workers=8, batch_size=args.batch_size, shuffle=True, drop_last=True) 48 | test_loader = DataLoader(ModelNetNormal(args.num_points, partition='test'), 49 | num_workers=8, batch_size=args.test_batch_size, shuffle=False, drop_last=False) 50 | 51 | device = torch.device("cuda" if args.cuda else "cpu") 52 | 53 | # create model 54 | model = CurveNet(args.multiplier).to(device) 55 | model = nn.DataParallel(model) 56 | io.cprint("Let's use" + str(torch.cuda.device_count()) + "GPUs!") 57 | 58 | if args.use_sgd: 59 | io.cprint("Use SGD") 60 | opt = optim.SGD(model.parameters(), 
lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4) 61 | else: 62 | io.cprint("Use Adam") 63 | opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 64 | 65 | if args.scheduler == 'cos': 66 | scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3) 67 | elif args.scheduler == 'step': 68 | scheduler = MultiStepLR(opt, [140, 180], gamma=0.1) 69 | 70 | criterion = torch.nn.CosineEmbeddingLoss() 71 | 72 | best_test_loss = 99 73 | for epoch in range(args.epochs): 74 | #################### 75 | # Train 76 | #################### 77 | train_loss = 0.0 78 | count = 0.0 79 | model.train() 80 | for data, seg in train_loader: 81 | data, seg = data.to(device), seg.to(device) 82 | data = data.permute(0, 2, 1) 83 | batch_size = data.size()[0] 84 | opt.zero_grad() 85 | seg_pred = model(data) 86 | seg_pred = seg_pred.permute(0, 2, 1).contiguous() 87 | #print(seg_pred.shape, seg.shape) 88 | loss = criterion(seg_pred.view(-1, 3), seg.view(-1,3).squeeze(), torch.tensor(1).to(device)) 89 | loss.backward() 90 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 91 | opt.step() 92 | count += batch_size 93 | train_loss += loss.item() * batch_size 94 | 95 | if args.scheduler == 'cos': 96 | scheduler.step() 97 | elif args.scheduler == 'step': 98 | if opt.param_groups[0]['lr'] > 1e-5: 99 | scheduler.step() 100 | if opt.param_groups[0]['lr'] < 1e-5: 101 | for param_group in opt.param_groups: 102 | param_group['lr'] = 1e-5 103 | 104 | outstr = 'Train %d, loss: %.6f' % (epoch, train_loss/count) 105 | io.cprint(outstr) 106 | 107 | #################### 108 | # Test 109 | #################### 110 | test_loss = 0.0 111 | count = 0.0 112 | model.eval() 113 | for data, seg in test_loader: 114 | data, seg = data.to(device), seg.to(device) 115 | data = data.permute(0, 2, 1) 116 | batch_size = data.size()[0] 117 | seg_pred = model(data) 118 | seg_pred = seg_pred.permute(0, 2, 1).contiguous() 119 | 120 | loss = criterion(seg_pred.view(-1, 3), seg.view(-1,3).squeeze(), 
torch.tensor(1).to(device)) 121 | count += batch_size 122 | test_loss += loss.item() * batch_size 123 | 124 | if test_loss*1.0/count <= best_test_loss: 125 | best_test_loss = test_loss*1.0/count 126 | torch.save(model.state_dict(), '../checkpoints/%s/models/model.t7' % args.exp_name) 127 | outstr = 'Test %d, loss: %.6f, best loss %.6f' % (epoch, test_loss/count, best_test_loss) 128 | io.cprint(outstr) 129 | 130 | def test(args, io): 131 | test_loader = DataLoader(ModelNetNormal(args.num_points, partition='test'), 132 | batch_size=args.test_batch_size, shuffle=False, drop_last=False) 133 | 134 | device = torch.device("cuda" if args.cuda else "cpu") 135 | 136 | #Try to load models 137 | model = CurveNet(args.multiplier).to(device) 138 | model = nn.DataParallel(model) 139 | model.load_state_dict(torch.load(args.model_path)) 140 | 141 | criterion = torch.nn.CosineEmbeddingLoss() 142 | 143 | model = model.eval() 144 | test_loss = 0.0 145 | count = 0 146 | for data, seg in test_loader: 147 | data, seg = data.to(device), seg.to(device) 148 | #print(data.shape, seg.shape) 149 | data = data.permute(0, 2, 1) 150 | batch_size = data.size()[0] 151 | seg_pred = model(data) 152 | seg_pred = seg_pred.permute(0, 2, 1).contiguous() 153 | loss = criterion(seg_pred.view(-1, 3), seg.view(-1,3).squeeze(), torch.tensor(1).to(device)) 154 | count += batch_size 155 | test_loss += loss.item() * batch_size 156 | outstr = 'Test :: test loss: %.6f' % (test_loss*1.0/count) 157 | io.cprint(outstr) 158 | 159 | 160 | if __name__ == "__main__": 161 | # Training settings 162 | parser = argparse.ArgumentParser(description='Point Cloud Normal Estimation') 163 | parser.add_argument('--exp_name', type=str, default='exp', metavar='N', 164 | help='Name of the experiment') 165 | parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', 166 | help='Size of batch)') 167 | parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size', 168 | help='Size of batch)') 169 | 
parser.add_argument('--epochs', type=int, default=200, metavar='N', 170 | help='number of epochs to train') 171 | parser.add_argument('--use_sgd', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, 172 | help='Use SGD') 173 | parser.add_argument('--lr', type=float, default=0.0005, metavar='LR', 174 | help='learning rate') 175 | parser.add_argument('--multiplier', type=float, default=2.0, metavar='MP', 176 | help='network expansion multiplier') 177 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 178 | help='SGD momentum (default: 0.9)') 179 | parser.add_argument('--scheduler', type=str, default='cos', metavar='N', 180 | choices=['cos', 'step'], 181 | help='Scheduler to use, [cos, step]') 182 | parser.add_argument('--no_cuda', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False, 183 | help='disables CUDA training') 184 | parser.add_argument('--eval', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False, 185 | help='evaluate the model') 186 | parser.add_argument('--num_points', type=int, default=1024, 187 | help='num of points to use') 188 | parser.add_argument('--model_path', type=str, default='', metavar='N', 189 | help='Pretrained model path') 190 | args = parser.parse_args() 191 | 192 | seed = np.random.randint(1, 10000) 193 | 194 | _init_() 195 | 196 | io = IOStream('../checkpoints/' + args.exp_name + '/run.log') 197 | io.cprint(str(args)) 198 | io.cprint('random seed is: ' + str(seed)) 199 | 200 | args.cuda = not args.no_cuda and torch.cuda.is_available() 201 | if args.cuda: 202 | io.cprint( 203 | 'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices') 204 | else: 205 | io.cprint('Using CPU') 206 | 207 | if not args.eval: 208 | train(args, io) 209 | else: 210 | with torch.no_grad(): 211 | test(args, io) 212 | -------------------------------------------------------------------------------- /core/main_partseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: An Tao 3 | @Contact: ta19@mails.tsinghua.edu.cn 4 | 
@File: main_partseg.py 5 | @Time: 2019/12/31 11:17 AM 6 | 7 | Modified by 8 | @Author: Tiange Xiang 9 | @Contact: txia7609@uni.sydney.edu.au 10 | @Time: 2021/01/21 3:10 PM 11 | """ 12 | 13 | 14 | from __future__ import print_function 15 | import os 16 | import argparse 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.optim as optim 21 | from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, MultiStepLR 22 | from data import ShapeNetPart 23 | from models.curvenet_seg import CurveNet 24 | import numpy as np 25 | from torch.utils.data import DataLoader 26 | from util import cal_loss, IOStream 27 | import sklearn.metrics as metrics 28 | 29 | seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3] 30 | index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47] 31 | 32 | def _init_(): 33 | # fix random seed 34 | torch.manual_seed(seed) 35 | np.random.seed(seed) 36 | torch.cuda.manual_seed_all(seed) 37 | torch.cuda.manual_seed(seed) 38 | torch.set_printoptions(10) 39 | torch.backends.cudnn.benchmark = False 40 | torch.backends.cudnn.deterministic = True 41 | os.environ['PYTHONHASHSEED'] = str(seed) 42 | 43 | # prepare file structures 44 | if not os.path.exists('../checkpoints'): 45 | os.makedirs('../checkpoints') 46 | if not os.path.exists('../checkpoints/'+args.exp_name): 47 | os.makedirs('../checkpoints/'+args.exp_name) 48 | if not os.path.exists('../checkpoints/'+args.exp_name+'/'+'models'): 49 | os.makedirs('../checkpoints/'+args.exp_name+'/'+'models') 50 | os.system('cp main_partseg.py ../checkpoints/'+args.exp_name+'/main_partseg.py.backup') 51 | os.system('cp models/curvenet_seg.py ../checkpoints/'+args.exp_name+'/curvenet_seg.py.backup') 52 | 53 | def calculate_shape_IoU(pred_np, seg_np, label, class_choice, eva=False): 54 | label = label.squeeze() 55 | shape_ious = [] 56 | category = {} 57 | for shape_idx in range(seg_np.shape[0]): 58 | if not class_choice: 59 | start_index = 
index_start[label[shape_idx]] 60 | num = seg_num[label[shape_idx]] 61 | parts = range(start_index, start_index + num) 62 | else: 63 | parts = range(seg_num[label[0]]) 64 | part_ious = [] 65 | for part in parts: 66 | I = np.sum(np.logical_and(pred_np[shape_idx] == part, seg_np[shape_idx] == part)) 67 | U = np.sum(np.logical_or(pred_np[shape_idx] == part, seg_np[shape_idx] == part)) 68 | if U == 0: 69 | iou = 1 # If the union of groundtruth and prediction points is empty, then count part IoU as 1 70 | else: 71 | iou = I / float(U) 72 | part_ious.append(iou) 73 | shape_ious.append(np.mean(part_ious)) 74 | if label[shape_idx] not in category: 75 | category[label[shape_idx]] = [shape_ious[-1]] 76 | else: 77 | category[label[shape_idx]].append(shape_ious[-1]) 78 | 79 | if eva: 80 | return shape_ious, category 81 | else: 82 | return shape_ious 83 | 84 | def train(args, io): 85 | train_dataset = ShapeNetPart(partition='trainval', num_points=args.num_points, class_choice=args.class_choice) 86 | if (len(train_dataset) < 100): 87 | drop_last = False 88 | else: 89 | drop_last = True 90 | train_loader = DataLoader(train_dataset, num_workers=8, batch_size=args.batch_size, shuffle=True, drop_last=drop_last) 91 | test_loader = DataLoader(ShapeNetPart(partition='test', num_points=args.num_points, class_choice=args.class_choice), 92 | num_workers=8, batch_size=args.test_batch_size, shuffle=False, drop_last=False) 93 | 94 | device = torch.device("cuda" if args.cuda else "cpu") 95 | io.cprint("Let's use" + str(torch.cuda.device_count()) + "GPUs!") 96 | 97 | seg_num_all = train_loader.dataset.seg_num_all 98 | seg_start_index = train_loader.dataset.seg_start_index 99 | 100 | # create model 101 | model = CurveNet().to(device) 102 | model = nn.DataParallel(model) 103 | 104 | if args.use_sgd: 105 | print("Use SGD") 106 | opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4) 107 | else: 108 | print("Use Adam") 109 | opt = 
optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 110 | 111 | if args.scheduler == 'cos': 112 | scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3) 113 | elif args.scheduler == 'step': 114 | scheduler = MultiStepLR(opt, [140, 180], gamma=0.1) 115 | criterion = cal_loss 116 | 117 | best_test_iou = 0 118 | for epoch in range(args.epochs): 119 | #################### 120 | # Train 121 | #################### 122 | train_loss = 0.0 123 | count = 0.0 124 | model.train() 125 | train_true_cls = [] 126 | train_pred_cls = [] 127 | train_true_seg = [] 128 | train_pred_seg = [] 129 | train_label_seg = [] 130 | for data, label, seg in train_loader: 131 | seg = seg - seg_start_index 132 | label_one_hot = np.zeros((label.shape[0], 16)) 133 | for idx in range(label.shape[0]): 134 | label_one_hot[idx, label[idx]] = 1 135 | label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) 136 | data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device) 137 | data = data.permute(0, 2, 1) 138 | batch_size = data.size()[0] 139 | opt.zero_grad() 140 | seg_pred = model(data, label_one_hot) 141 | seg_pred = seg_pred.permute(0, 2, 1).contiguous() 142 | loss = criterion(seg_pred.view(-1, seg_num_all), seg.view(-1,1).squeeze()) 143 | loss.backward() 144 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 145 | opt.step() 146 | pred = seg_pred.max(dim=2)[1] # (batch_size, num_points) 147 | count += batch_size 148 | train_loss += loss.item() * batch_size 149 | seg_np = seg.cpu().numpy() # (batch_size, num_points) 150 | pred_np = pred.detach().cpu().numpy() # (batch_size, num_points) 151 | train_true_cls.append(seg_np.reshape(-1)) # (batch_size * num_points) 152 | train_pred_cls.append(pred_np.reshape(-1)) # (batch_size * num_points) 153 | train_true_seg.append(seg_np) 154 | train_pred_seg.append(pred_np) 155 | train_label_seg.append(label.reshape(-1)) 156 | if args.scheduler == 'cos': 157 | scheduler.step() 158 | elif args.scheduler == 
'step': 159 | if opt.param_groups[0]['lr'] > 1e-5: 160 | scheduler.step() 161 | if opt.param_groups[0]['lr'] < 1e-5: 162 | for param_group in opt.param_groups: 163 | param_group['lr'] = 1e-5 164 | train_true_cls = np.concatenate(train_true_cls) 165 | train_pred_cls = np.concatenate(train_pred_cls) 166 | train_acc = metrics.accuracy_score(train_true_cls, train_pred_cls) 167 | avg_per_class_acc = metrics.balanced_accuracy_score(train_true_cls, train_pred_cls) 168 | train_true_seg = np.concatenate(train_true_seg, axis=0) 169 | train_pred_seg = np.concatenate(train_pred_seg, axis=0) 170 | train_label_seg = np.concatenate(train_label_seg) 171 | train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg, train_label_seg, args.class_choice) 172 | outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (epoch, 173 | train_loss*1.0/count, 174 | train_acc, 175 | avg_per_class_acc, 176 | np.mean(train_ious)) 177 | io.cprint(outstr) 178 | 179 | #################### 180 | # Test 181 | #################### 182 | test_loss = 0.0 183 | count = 0.0 184 | model.eval() 185 | test_true_cls = [] 186 | test_pred_cls = [] 187 | test_true_seg = [] 188 | test_pred_seg = [] 189 | test_label_seg = [] 190 | for data, label, seg in test_loader: 191 | seg = seg - seg_start_index 192 | label_one_hot = np.zeros((label.shape[0], 16)) 193 | for idx in range(label.shape[0]): 194 | label_one_hot[idx, label[idx]] = 1 195 | label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) 196 | data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device) 197 | data = data.permute(0, 2, 1) 198 | batch_size = data.size()[0] 199 | seg_pred = model(data, label_one_hot) 200 | seg_pred = seg_pred.permute(0, 2, 1).contiguous() 201 | loss = criterion(seg_pred.view(-1, seg_num_all), seg.view(-1,1).squeeze()) 202 | pred = seg_pred.max(dim=2)[1] 203 | count += batch_size 204 | test_loss += loss.item() * batch_size 205 | seg_np = seg.cpu().numpy() 206 
| pred_np = pred.detach().cpu().numpy() 207 | test_true_cls.append(seg_np.reshape(-1)) 208 | test_pred_cls.append(pred_np.reshape(-1)) 209 | test_true_seg.append(seg_np) 210 | test_pred_seg.append(pred_np) 211 | test_label_seg.append(label.reshape(-1)) 212 | test_true_cls = np.concatenate(test_true_cls) 213 | test_pred_cls = np.concatenate(test_pred_cls) 214 | test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls) 215 | avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls, test_pred_cls) 216 | test_true_seg = np.concatenate(test_true_seg, axis=0) 217 | test_pred_seg = np.concatenate(test_pred_seg, axis=0) 218 | test_label_seg = np.concatenate(test_label_seg) 219 | test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg, args.class_choice) 220 | outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, test iou: %.6f, best iou %.6f' % (epoch, 221 | test_loss*1.0/count, 222 | test_acc, 223 | avg_per_class_acc, 224 | np.mean(test_ious), best_test_iou) 225 | io.cprint(outstr) 226 | if np.mean(test_ious) >= best_test_iou: 227 | best_test_iou = np.mean(test_ious) 228 | torch.save(model.state_dict(), '../checkpoints/%s/models/model.t7' % args.exp_name) 229 | 230 | 231 | def test(args, io): 232 | test_loader = DataLoader(ShapeNetPart(partition='test', num_points=args.num_points, class_choice=args.class_choice), 233 | batch_size=args.test_batch_size, shuffle=True, drop_last=False) 234 | 235 | device = torch.device("cuda" if args.cuda else "cpu") 236 | 237 | #Try to load models 238 | seg_start_index = test_loader.dataset.seg_start_index 239 | model = CurveNet().to(device) 240 | model = nn.DataParallel(model) 241 | model.load_state_dict(torch.load(args.model_path)) 242 | 243 | model = model.eval() 244 | test_acc = 0.0 245 | test_true_cls = [] 246 | test_pred_cls = [] 247 | test_true_seg = [] 248 | test_pred_seg = [] 249 | test_label_seg = [] 250 | category = {} 251 | for data, label, seg in test_loader: 252 | seg = seg - 
seg_start_index 253 | label_one_hot = np.zeros((label.shape[0], 16)) 254 | for idx in range(label.shape[0]): 255 | label_one_hot[idx, label[idx]] = 1 256 | label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) 257 | data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device) 258 | data = data.permute(0, 2, 1) 259 | seg_pred = model(data, label_one_hot) 260 | seg_pred = seg_pred.permute(0, 2, 1).contiguous() 261 | pred = seg_pred.max(dim=2)[1] 262 | seg_np = seg.cpu().numpy() 263 | pred_np = pred.detach().cpu().numpy() 264 | test_true_cls.append(seg_np.reshape(-1)) 265 | test_pred_cls.append(pred_np.reshape(-1)) 266 | test_true_seg.append(seg_np) 267 | test_pred_seg.append(pred_np) 268 | test_label_seg.append(label.reshape(-1)) 269 | 270 | test_true_cls = np.concatenate(test_true_cls) 271 | test_pred_cls = np.concatenate(test_pred_cls) 272 | test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls) 273 | avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls, test_pred_cls) 274 | test_true_seg = np.concatenate(test_true_seg, axis=0) 275 | test_pred_seg = np.concatenate(test_pred_seg, axis=0) 276 | test_label_seg = np.concatenate(test_label_seg) 277 | test_ious,category = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg, args.class_choice, eva=True) 278 | outstr = 'Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (test_acc, 279 | avg_per_class_acc, 280 | np.mean(test_ious)) 281 | io.cprint(outstr) 282 | results = [] 283 | for key in category.keys(): 284 | results.append((int(key), np.mean(category[key]), len(category[key]))) 285 | results.sort(key=lambda x:x[0]) 286 | for re in results: 287 | io.cprint('idx: %d mIoU: %.3f num: %d' % (re[0], re[1], re[2])) 288 | 289 | 290 | if __name__ == "__main__": 291 | # Training settings 292 | parser = argparse.ArgumentParser(description='Point Cloud Part Segmentation') 293 | parser.add_argument('--exp_name', type=str, default='exp', metavar='N', 
294 | help='Name of the experiment') 295 | parser.add_argument('--dataset', type=str, default='shapenetpart', metavar='N', 296 | choices=['shapenetpart']) 297 | parser.add_argument('--class_choice', type=str, default=None, metavar='N', 298 | choices=['airplane', 'bag', 'cap', 'car', 'chair', 299 | 'earphone', 'guitar', 'knife', 'lamp', 'laptop', 300 | 'motor', 'mug', 'pistol', 'rocket', 'skateboard', 'table']) 301 | parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', 302 | help='Size of training batch') 303 | parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size', 304 | help='Size of test batch') 305 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 306 | help='number of epochs to train') 307 | parser.add_argument('--use_sgd', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, 308 | help='Use SGD') 309 | parser.add_argument('--lr', type=float, default=0.0005, metavar='LR', 310 | help='learning rate (default: 0.0005; multiplied by 100 when using SGD)') 311 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 312 | help='SGD momentum (default: 0.9)') 313 | parser.add_argument('--scheduler', type=str, default='step', metavar='N', 314 | choices=['cos', 'step'], 315 | help='Scheduler to use, [cos, step]') 316 | parser.add_argument('--no_cuda', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False, 317 | help='disables CUDA training') 318 | parser.add_argument('--eval', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False, 319 | help='evaluate the model') 320 | parser.add_argument('--num_points', type=int, default=2048, 321 | help='num of points to use') 322 | parser.add_argument('--model_path', type=str, default='', metavar='N', 323 | help='Pretrained model path') 324 | args = parser.parse_args() 325 | 326 | seed = np.random.randint(1, 10000) 327 | 328 | _init_() 329 | 330 | if args.eval: 331 | io = IOStream('../checkpoints/' + args.exp_name + '/eval.log') 332 | else: 333 | io = IOStream('../checkpoints/' + args.exp_name + '/run.log') 334 | io.cprint(str(args)) 335 | 
io.cprint('random seed is: ' + str(seed)) 336 | 337 | args.cuda = not args.no_cuda and torch.cuda.is_available() 338 | 339 | if args.cuda: 340 | io.cprint( 341 | 'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices') 342 | else: 343 | io.cprint('Using CPU') 344 | 345 | if not args.eval: 346 | train(args, io) 347 | else: 348 | with torch.no_grad(): 349 | test(args, io) 350 | -------------------------------------------------------------------------------- /core/models/curvenet_cls.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Tiange Xiang 3 | @Contact: txia7609@uni.sydney.edu.au 4 | @File: curvenet_cls.py 5 | @Time: 2021/01/21 3:10 PM 6 | """ 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from .curvenet_util import * 11 | 12 | 13 | curve_config = { 14 | 'default': [[100, 5], [100, 5], None, None], 15 | 'long': [[10, 30], None, None, None] 16 | } 17 | 18 | class CurveNet(nn.Module): 19 | def __init__(self, num_classes=40, k=20, setting='default'): 20 | super(CurveNet, self).__init__() 21 | 22 | assert setting in curve_config 23 | 24 | additional_channel = 32 25 | self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True) 26 | 27 | # encoder 28 | self.cic11 = CIC(npoint=1024, radius=0.05, k=k, in_channels=additional_channel, output_channels=64, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][0]) 29 | self.cic12 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64, output_channels=64, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][0]) 30 | 31 | self.cic21 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64, output_channels=128, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][1]) 32 | self.cic22 = CIC(npoint=1024, radius=0.1, k=k, in_channels=128, output_channels=128, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][1]) 33 | 34 | self.cic31 = 
CIC(npoint=256, radius=0.1, k=k, in_channels=128, output_channels=256, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][2]) 35 | self.cic32 = CIC(npoint=256, radius=0.2, k=k, in_channels=256, output_channels=256, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][2]) 36 | 37 | self.cic41 = CIC(npoint=64, radius=0.2, k=k, in_channels=256, output_channels=512, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][3]) 38 | self.cic42 = CIC(npoint=64, radius=0.4, k=k, in_channels=512, output_channels=512, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][3]) 39 | 40 | self.conv0 = nn.Sequential( 41 | nn.Conv1d(512, 1024, kernel_size=1, bias=False), 42 | nn.BatchNorm1d(1024), 43 | nn.ReLU(inplace=True)) 44 | self.conv1 = nn.Linear(1024 * 2, 512, bias=False) 45 | self.conv2 = nn.Linear(512, num_classes) 46 | self.bn1 = nn.BatchNorm1d(512) 47 | self.dp1 = nn.Dropout(p=0.5) 48 | 49 | def forward(self, xyz, get_flatten_curve_idxs=False): 50 | flatten_curve_idxs = {} 51 | l0_points = self.lpfa(xyz, xyz) 52 | 53 | l1_xyz, l1_points, flatten_curve_idxs_11 = self.cic11(xyz, l0_points) 54 | flatten_curve_idxs['flatten_curve_idxs_11'] = flatten_curve_idxs_11 55 | l1_xyz, l1_points, flatten_curve_idxs_12 = self.cic12(l1_xyz, l1_points) 56 | flatten_curve_idxs['flatten_curve_idxs_12'] = flatten_curve_idxs_12 57 | 58 | l2_xyz, l2_points, flatten_curve_idxs_21 = self.cic21(l1_xyz, l1_points) 59 | flatten_curve_idxs['flatten_curve_idxs_21'] = flatten_curve_idxs_21 60 | l2_xyz, l2_points, flatten_curve_idxs_22 = self.cic22(l2_xyz, l2_points) 61 | flatten_curve_idxs['flatten_curve_idxs_22'] = flatten_curve_idxs_22 62 | 63 | l3_xyz, l3_points, flatten_curve_idxs_31 = self.cic31(l2_xyz, l2_points) 64 | flatten_curve_idxs['flatten_curve_idxs_31'] = flatten_curve_idxs_31 65 | l3_xyz, l3_points, flatten_curve_idxs_32 = self.cic32(l3_xyz, l3_points) 66 | flatten_curve_idxs['flatten_curve_idxs_32'] = flatten_curve_idxs_32 67 | 68 | 
l4_xyz, l4_points, flatten_curve_idxs_41 = self.cic41(l3_xyz, l3_points) 69 | flatten_curve_idxs['flatten_curve_idxs_41'] = flatten_curve_idxs_41 70 | l4_xyz, l4_points, flatten_curve_idxs_42 = self.cic42(l4_xyz, l4_points) 71 | flatten_curve_idxs['flatten_curve_idxs_42'] = flatten_curve_idxs_42 72 | 73 | x = self.conv0(l4_points) 74 | x_max = F.adaptive_max_pool1d(x, 1) 75 | x_avg = F.adaptive_avg_pool1d(x, 1) 76 | 77 | x = torch.cat((x_max, x_avg), dim=1).squeeze(-1) 78 | x = F.relu(self.bn1(self.conv1(x).unsqueeze(-1)), inplace=True).squeeze(-1) 79 | x = self.dp1(x) 80 | x = self.conv2(x) 81 | if get_flatten_curve_idxs: 82 | return x, flatten_curve_idxs 83 | else: 84 | return x 85 | -------------------------------------------------------------------------------- /core/models/curvenet_normal.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Tiange Xiang 3 | @Contact: txia7609@uni.sydney.edu.au 4 | @File: curvenet_normal.py 5 | @Time: 2021/01/21 3:10 PM 6 | """ 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from .curvenet_util import * 11 | 12 | 13 | curve_config = { 14 | 'default': [[100, 5], [100, 5], None, None] 15 | } 16 | 17 | class CurveNet(nn.Module): 18 | def __init__(self, num_classes=3, k=20, multiplier=1.0, setting='default'): 19 | super(CurveNet, self).__init__() 20 | 21 | assert setting in curve_config 22 | 23 | additional_channel = 64 24 | channels = [128, 256, 512, 1024] 25 | channels = [int(c * multiplier) for c in channels] 26 | 27 | self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True) 28 | 29 | # encoder 30 | self.cic11 = CIC(npoint=1024, radius=0.1, k=k, in_channels=additional_channel, output_channels=channels[0], bottleneck_ratio=2, curve_config=curve_config[setting][0]) 31 | self.cic12 = CIC(npoint=1024, radius=0.1, k=k, in_channels=channels[0], output_channels=channels[0], bottleneck_ratio=4, curve_config=curve_config[setting][0]) 32 | 33 | 
self.cic21 = CIC(npoint=256, radius=0.2, k=k, in_channels=channels[0], output_channels=channels[1], bottleneck_ratio=2, curve_config=curve_config[setting][1]) 34 | self.cic22 = CIC(npoint=256, radius=0.2, k=k, in_channels=channels[1], output_channels=channels[1], bottleneck_ratio=4, curve_config=curve_config[setting][1]) 35 | 36 | self.cic31 = CIC(npoint=64, radius=0.4, k=k, in_channels=channels[1], output_channels=channels[2], bottleneck_ratio=2, curve_config=curve_config[setting][2]) 37 | self.cic32 = CIC(npoint=64, radius=0.4, k=k, in_channels=channels[2], output_channels=channels[2], bottleneck_ratio=4, curve_config=curve_config[setting][2]) 38 | 39 | self.cic41 = CIC(npoint=16, radius=0.8, k=15, in_channels=channels[2], output_channels=channels[3], bottleneck_ratio=2, curve_config=curve_config[setting][3]) 40 | self.cic42 = CIC(npoint=16, radius=0.8, k=15, in_channels=channels[3], output_channels=channels[3], bottleneck_ratio=4, curve_config=curve_config[setting][3]) 41 | #self.cic43 = CIC(npoint=16, radius=0.8, k=15, in_channels=2048, output_channels=2048, bottleneck_ratio=4, curve_config=curve_config[setting][3]) 42 | # decoder 43 | self.fp3 = PointNetFeaturePropagation(in_channel=channels[3] + channels[2], mlp=[channels[2], channels[2]], att=[channels[3], channels[3]//2, channels[3]//8]) 44 | self.up_cic4 = CIC(npoint=64, radius=0.8, k=k, in_channels=channels[2], output_channels=channels[2], bottleneck_ratio=4) 45 | 46 | self.fp2 = PointNetFeaturePropagation(in_channel=channels[2] + channels[1], mlp=[channels[1], channels[1]], att=[channels[2], channels[2]//2, channels[2]//8]) 47 | self.up_cic3 = CIC(npoint=256, radius=0.4, k=k, in_channels=channels[1], output_channels=channels[1], bottleneck_ratio=4) 48 | 49 | self.fp1 = PointNetFeaturePropagation(in_channel=channels[1] + channels[0], mlp=[channels[0], channels[0]], att=[channels[1], channels[1]//2, channels[1]//8]) 50 | self.up_cic2 = CIC(npoint=1024, radius=0.1, k=k, in_channels=channels[0]+3, 
output_channels=channels[0], bottleneck_ratio=4) 51 | self.up_cic1 = CIC(npoint=1024, radius=0.1, k=k, in_channels=channels[0], output_channels=channels[0], bottleneck_ratio=4) 52 | 53 | self.point_conv = nn.Sequential( 54 | nn.Conv2d(9, additional_channel, kernel_size=1, bias=False), 55 | nn.BatchNorm2d(additional_channel), 56 | nn.LeakyReLU(negative_slope=0.2, inplace=True)) 57 | 58 | self.conv1 = nn.Conv1d(channels[0], num_classes, 1) 59 | 60 | def forward(self, xyz): 61 | l0_points = self.lpfa(xyz, xyz) 62 | 63 | l1_xyz, l1_points = self.cic11(xyz, l0_points) 64 | l1_xyz, l1_points = self.cic12(l1_xyz, l1_points) 65 | 66 | l2_xyz, l2_points = self.cic21(l1_xyz, l1_points) 67 | l2_xyz, l2_points = self.cic22(l2_xyz, l2_points) 68 | 69 | l3_xyz, l3_points = self.cic31(l2_xyz, l2_points) 70 | l3_xyz, l3_points = self.cic32(l3_xyz, l3_points) 71 | 72 | l4_xyz, l4_points = self.cic41(l3_xyz, l3_points) 73 | l4_xyz, l4_points = self.cic42(l4_xyz, l4_points) 74 | #l4_xyz, l4_points = self.cic43(l4_xyz, l4_points) 75 | 76 | l3_points = self.fp3(l3_xyz, l4_xyz, l3_points, l4_points) 77 | l3_xyz, l3_points = self.up_cic4(l3_xyz, l3_points) 78 | l2_points = self.fp2(l2_xyz, l3_xyz, l2_points, l3_points) 79 | l2_xyz, l2_points = self.up_cic3(l2_xyz, l2_points) 80 | l1_points = self.fp1(l1_xyz, l2_xyz, l1_points, l2_points) 81 | 82 | x = torch.cat((l1_xyz, l1_points), dim=1) 83 | 84 | xyz, x = self.up_cic2(l1_xyz, x) 85 | xyz, x = self.up_cic1(xyz, x) 86 | 87 | x = self.conv1(x) 88 | return x 89 | -------------------------------------------------------------------------------- /core/models/curvenet_seg.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Tiange Xiang 3 | @Contact: txia7609@uni.sydney.edu.au 4 | @File: curvenet_seg.py 5 | @Time: 2021/01/21 3:10 PM 6 | """ 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from .curvenet_util import * 11 | 12 | 13 | curve_config = { 14 | 'default': 
[[100, 5], [100, 5], None, None, None] 15 | } 16 | 17 | class CurveNet(nn.Module): 18 | def __init__(self, num_classes=50, category=16, k=32, setting='default'): 19 | super(CurveNet, self).__init__() 20 | 21 | assert setting in curve_config 22 | 23 | additional_channel = 32 24 | self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True) 25 | 26 | # encoder 27 | self.cic11 = CIC(npoint=2048, radius=0.2, k=k, in_channels=additional_channel, output_channels=64, bottleneck_ratio=2, curve_config=curve_config[setting][0]) 28 | self.cic12 = CIC(npoint=2048, radius=0.2, k=k, in_channels=64, output_channels=64, bottleneck_ratio=4, curve_config=curve_config[setting][0]) 29 | 30 | self.cic21 = CIC(npoint=512, radius=0.4, k=k, in_channels=64, output_channels=128, bottleneck_ratio=2, curve_config=curve_config[setting][1]) 31 | self.cic22 = CIC(npoint=512, radius=0.4, k=k, in_channels=128, output_channels=128, bottleneck_ratio=4, curve_config=curve_config[setting][1]) 32 | 33 | self.cic31 = CIC(npoint=128, radius=0.8, k=k, in_channels=128, output_channels=256, bottleneck_ratio=2, curve_config=curve_config[setting][2]) 34 | self.cic32 = CIC(npoint=128, radius=0.8, k=k, in_channels=256, output_channels=256, bottleneck_ratio=4, curve_config=curve_config[setting][2]) 35 | 36 | self.cic41 = CIC(npoint=32, radius=1.2, k=31, in_channels=256, output_channels=512, bottleneck_ratio=2, curve_config=curve_config[setting][3]) 37 | self.cic42 = CIC(npoint=32, radius=1.2, k=31, in_channels=512, output_channels=512, bottleneck_ratio=4, curve_config=curve_config[setting][3]) 38 | 39 | self.cic51 = CIC(npoint=8, radius=2.0, k=7, in_channels=512, output_channels=1024, bottleneck_ratio=2, curve_config=curve_config[setting][4]) 40 | self.cic52 = CIC(npoint=8, radius=2.0, k=7, in_channels=1024, output_channels=1024, bottleneck_ratio=4, curve_config=curve_config[setting][4]) 41 | self.cic53 = CIC(npoint=8, radius=2.0, k=7, in_channels=1024, output_channels=1024, bottleneck_ratio=4, 
curve_config=curve_config[setting][4]) 42 | 43 | # decoder 44 | self.fp4 = PointNetFeaturePropagation(in_channel=1024 + 512, mlp=[512, 512], att=[1024, 512, 256]) 45 | self.up_cic5 = CIC(npoint=32, radius=1.2, k=31, in_channels=512, output_channels=512, bottleneck_ratio=4) 46 | 47 | self.fp3 = PointNetFeaturePropagation(in_channel=512 + 256, mlp=[256, 256], att=[512, 256, 128]) 48 | self.up_cic4 = CIC(npoint=128, radius=0.8, k=k, in_channels=256, output_channels=256, bottleneck_ratio=4) 49 | 50 | self.fp2 = PointNetFeaturePropagation(in_channel=256 + 128, mlp=[128, 128], att=[256, 128, 64]) 51 | self.up_cic3 = CIC(npoint=512, radius=0.4, k=k, in_channels=128, output_channels=128, bottleneck_ratio=4) 52 | 53 | self.fp1 = PointNetFeaturePropagation(in_channel=128 + 64, mlp=[64, 64], att=[128, 64, 32]) 54 | self.up_cic2 = CIC(npoint=2048, radius=0.2, k=k, in_channels=128+64+64+category+3, output_channels=256, bottleneck_ratio=4) 55 | self.up_cic1 = CIC(npoint=2048, radius=0.2, k=k, in_channels=256, output_channels=256, bottleneck_ratio=4) 56 | 57 | 58 | self.global_conv2 = nn.Sequential( 59 | nn.Conv1d(1024, 128, kernel_size=1, bias=False), 60 | nn.BatchNorm1d(128), 61 | nn.LeakyReLU(negative_slope=0.2)) 62 | self.global_conv1 = nn.Sequential( 63 | nn.Conv1d(512, 64, kernel_size=1, bias=False), 64 | nn.BatchNorm1d(64), 65 | nn.LeakyReLU(negative_slope=0.2)) 66 | 67 | self.conv1 = nn.Conv1d(256, 256, 1, bias=False) 68 | self.bn1 = nn.BatchNorm1d(256) 69 | self.drop1 = nn.Dropout(0.5) 70 | self.conv2 = nn.Conv1d(256, num_classes, 1) 71 | self.se = nn.Sequential(nn.AdaptiveAvgPool1d(1), 72 | nn.Conv1d(256, 256//8, 1, bias=False), 73 | nn.BatchNorm1d(256//8), 74 | nn.LeakyReLU(negative_slope=0.2), 75 | nn.Conv1d(256//8, 256, 1, bias=False), 76 | nn.Sigmoid()) 77 | 78 | def forward(self, xyz, l=None): 79 | batch_size = xyz.size(0) 80 | 81 | l0_points = self.lpfa(xyz, xyz) 82 | 83 | l1_xyz, l1_points = self.cic11(xyz, l0_points) 84 | l1_xyz, l1_points = self.cic12(l1_xyz, 
l1_points) 85 | 86 | l2_xyz, l2_points = self.cic21(l1_xyz, l1_points) 87 | l2_xyz, l2_points = self.cic22(l2_xyz, l2_points) 88 | 89 | l3_xyz, l3_points = self.cic31(l2_xyz, l2_points) 90 | l3_xyz, l3_points = self.cic32(l3_xyz, l3_points) 91 | 92 | l4_xyz, l4_points = self.cic41(l3_xyz, l3_points) 93 | l4_xyz, l4_points = self.cic42(l4_xyz, l4_points) 94 | 95 | l5_xyz, l5_points = self.cic51(l4_xyz, l4_points) 96 | l5_xyz, l5_points = self.cic52(l5_xyz, l5_points) 97 | l5_xyz, l5_points = self.cic53(l5_xyz, l5_points) 98 | 99 | # global features 100 | emb1 = self.global_conv1(l4_points) 101 | emb1 = emb1.max(dim=-1, keepdim=True)[0] # bs, 64, 1 102 | emb2 = self.global_conv2(l5_points) 103 | emb2 = emb2.max(dim=-1, keepdim=True)[0] # bs, 128, 1 104 | 105 | # Feature Propagation layers 106 | l4_points = self.fp4(l4_xyz, l5_xyz, l4_points, l5_points) 107 | l4_xyz, l4_points = self.up_cic5(l4_xyz, l4_points) 108 | 109 | l3_points = self.fp3(l3_xyz, l4_xyz, l3_points, l4_points) 110 | l3_xyz, l3_points = self.up_cic4(l3_xyz, l3_points) 111 | 112 | l2_points = self.fp2(l2_xyz, l3_xyz, l2_points, l3_points) 113 | l2_xyz, l2_points = self.up_cic3(l2_xyz, l2_points) 114 | 115 | l1_points = self.fp1(l1_xyz, l2_xyz, l1_points, l2_points) 116 | 117 | if l is not None: 118 | l = l.view(batch_size, -1, 1) 119 | emb = torch.cat((emb1, emb2, l), dim=1) # bs, 128 + 16, 1 120 | l = emb.expand(-1,-1, xyz.size(-1)) 121 | x = torch.cat((l1_xyz, l1_points, l), dim=1) 122 | 123 | xyz, x = self.up_cic2(l1_xyz, x) 124 | xyz, x = self.up_cic1(xyz, x) 125 | 126 | x = F.leaky_relu(self.bn1(self.conv1(x)), 0.2, inplace=True) 127 | se = self.se(x) 128 | x = x * se 129 | x = self.drop1(x) 130 | x = self.conv2(x) 131 | return x 132 | -------------------------------------------------------------------------------- /core/models/curvenet_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Yue Wang 3 | @Contact: yuewangx@mit.edu 4 | @File: 
pointnet_util.py 5 | @Time: 2018/10/13 10:39 PM 6 | 7 | Modified by 8 | @Author: Tiange Xiang 9 | @Contact: txia7609@uni.sydney.edu.au 10 | @Time: 2021/01/21 3:10 PM 11 | """ 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from time import time 17 | import numpy as np 18 | 19 | from .walk import Walk 20 | 21 | 22 | def knn(x, k): 23 | k = k + 1 24 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 25 | xx = torch.sum(x**2, dim=1, keepdim=True) 26 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 27 | 28 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 29 | return idx 30 | 31 | def normal_knn(x, k): 32 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 33 | xx = torch.sum(x**2, dim=1, keepdim=True) 34 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 35 | 36 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 37 | return idx 38 | 39 | def pc_normalize(pc): 40 | l = pc.shape[0] 41 | centroid = np.mean(pc, axis=0) 42 | pc = pc - centroid 43 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 44 | pc = pc / m 45 | return pc 46 | 47 | def square_distance(src, dst): 48 | """ 49 | Calculate the squared Euclidean distance between each pair of points. 
50 | """ 51 | B, N, _ = src.shape 52 | _, M, _ = dst.shape 53 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 54 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 55 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 56 | return dist 57 | 58 | def index_points(points, idx): 59 | """ 60 | 61 | Input: 62 | points: input points data, [B, N, C] 63 | idx: sample index data, [B, S] 64 | Return: 65 | new_points:, indexed points data, [B, S, C] 66 | """ 67 | device = points.device 68 | B = points.shape[0] 69 | view_shape = list(idx.shape) 70 | view_shape[1:] = [1] * (len(view_shape) - 1) 71 | repeat_shape = list(idx.shape) 72 | repeat_shape[0] = 1 73 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 74 | new_points = points[batch_indices, idx, :] 75 | return new_points 76 | 77 | 78 | def farthest_point_sample(xyz, npoint): 79 | """ 80 | Input: 81 | xyz: pointcloud data, [B, N, 3] 82 | npoint: number of samples 83 | Return: 84 | centroids: sampled pointcloud index, [B, npoint] 85 | """ 86 | device = xyz.device 87 | B, N, C = xyz.shape 88 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 89 | distance = torch.ones(B, N).to(device) * 1e10 90 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) * 0 91 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 92 | for i in range(npoint): 93 | centroids[:, i] = farthest 94 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 95 | dist = torch.sum((xyz - centroid) ** 2, -1) 96 | mask = dist < distance 97 | distance[mask] = dist[mask] 98 | farthest = torch.max(distance, -1)[1] 99 | return centroids 100 | 101 | def query_ball_point(radius, nsample, xyz, new_xyz): 102 | """ 103 | Input: 104 | radius: local region radius 105 | nsample: max sample number in local region 106 | xyz: all points, [B, N, 3] 107 | new_xyz: query points, [B, S, 3] 108 | Return: 109 | group_idx: grouped points index, [B, S, nsample] 110 | """ 111 | device = 
xyz.device 112 | B, N, C = xyz.shape 113 | _, S, _ = new_xyz.shape 114 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 115 | sqrdists = square_distance(new_xyz, xyz) 116 | group_idx[sqrdists > radius ** 2] = N 117 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 118 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 119 | mask = group_idx == N 120 | group_idx[mask] = group_first[mask] 121 | return group_idx 122 | 123 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 124 | """ 125 | Input: 126 | npoint: 127 | radius: 128 | nsample: 129 | xyz: input points position data, [B, N, 3] 130 | points: input points data, [B, N, D] 131 | Return: 132 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 133 | new_points: sampled points data, [B, npoint, nsample, 3+D] 134 | """ 135 | new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint)) 136 | torch.cuda.empty_cache() 137 | 138 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 139 | torch.cuda.empty_cache() 140 | 141 | new_points = index_points(points, idx) 142 | torch.cuda.empty_cache() 143 | 144 | if returnfps: 145 | return new_xyz, new_points, idx 146 | else: 147 | return new_xyz, new_points 148 | 149 | class Attention_block(nn.Module): 150 | ''' 151 | Used in attention U-Net. 
152 | ''' 153 | def __init__(self,F_g,F_l,F_int): 154 | super(Attention_block,self).__init__() 155 | self.W_g = nn.Sequential( 156 | nn.Conv1d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 157 | nn.BatchNorm1d(F_int) 158 | ) 159 | 160 | self.W_x = nn.Sequential( 161 | nn.Conv1d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 162 | nn.BatchNorm1d(F_int) 163 | ) 164 | 165 | self.psi = nn.Sequential( 166 | nn.Conv1d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 167 | nn.BatchNorm1d(1), 168 | nn.Sigmoid() 169 | ) 170 | 171 | def forward(self,g,x): 172 | g1 = self.W_g(g) 173 | x1 = self.W_x(x) 174 | psi = F.leaky_relu(g1+x1, negative_slope=0.2) 175 | psi = self.psi(psi) 176 | 177 | return psi, 1. - psi 178 | 179 | 180 | class LPFA(nn.Module): 181 | def __init__(self, in_channel, out_channel, k, mlp_num=2, initial=False): 182 | super(LPFA, self).__init__() 183 | self.k = k 184 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 185 | self.initial = initial 186 | 187 | if not initial: 188 | self.xyz2feature = nn.Sequential( 189 | nn.Conv2d(9, in_channel, kernel_size=1, bias=False), 190 | nn.BatchNorm2d(in_channel)) 191 | 192 | self.mlp = [] 193 | for _ in range(mlp_num): 194 | self.mlp.append(nn.Sequential(nn.Conv2d(in_channel, out_channel, 1, bias=False), 195 | nn.BatchNorm2d(out_channel), 196 | nn.LeakyReLU(0.2))) 197 | in_channel = out_channel 198 | self.mlp = nn.Sequential(*self.mlp) 199 | 200 | def forward(self, x, xyz, idx=None): 201 | x = self.group_feature(x, xyz, idx) 202 | x = self.mlp(x) 203 | 204 | if self.initial: 205 | x = x.max(dim=-1, keepdim=False)[0] 206 | else: 207 | x = x.mean(dim=-1, keepdim=False) 208 | 209 | return x 210 | 211 | def group_feature(self, x, xyz, idx): 212 | batch_size, num_dims, num_points = x.size() 213 | 214 | if idx is None: 215 | idx = knn(xyz, k=self.k)[:,:,:self.k] # (batch_size, num_points, k) 216 | 217 | idx_base = torch.arange(0, batch_size, 
device=self.device).view(-1, 1, 1) * num_points 218 | idx = idx + idx_base 219 | idx = idx.view(-1) 220 | 221 | xyz = xyz.transpose(2, 1).contiguous() # bs, n, 3 222 | point_feature = xyz.view(batch_size * num_points, -1)[idx, :] 223 | point_feature = point_feature.view(batch_size, num_points, self.k, -1) # bs, n, k, 3 224 | points = xyz.view(batch_size, num_points, 1, 3).expand(-1, -1, self.k, -1) # bs, n, k, 3 225 | 226 | point_feature = torch.cat((points, point_feature, point_feature - points), 227 | dim=3).permute(0, 3, 1, 2).contiguous() 228 | 229 | if self.initial: 230 | return point_feature 231 | 232 | x = x.transpose(2, 1).contiguous() # bs, n, c 233 | feature = x.view(batch_size * num_points, -1)[idx, :] 234 | feature = feature.view(batch_size, num_points, self.k, num_dims) #bs, n, k, c 235 | x = x.view(batch_size, num_points, 1, num_dims) 236 | feature = feature - x 237 | 238 | feature = feature.permute(0, 3, 1, 2).contiguous() 239 | point_feature = self.xyz2feature(point_feature) #bs, c, n, k 240 | feature = F.leaky_relu(feature + point_feature, 0.2) 241 | return feature #bs, c, n, k 242 | 243 | 244 | class PointNetFeaturePropagation(nn.Module): 245 | def __init__(self, in_channel, mlp, att=None): 246 | super(PointNetFeaturePropagation, self).__init__() 247 | self.mlp_convs = nn.ModuleList() 248 | self.mlp_bns = nn.ModuleList() 249 | last_channel = in_channel 250 | self.att = None 251 | if att is not None: 252 | self.att = Attention_block(F_g=att[0],F_l=att[1],F_int=att[2]) 253 | 254 | for out_channel in mlp: 255 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 256 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 257 | last_channel = out_channel 258 | 259 | def forward(self, xyz1, xyz2, points1, points2): 260 | """ 261 | Input: 262 | xyz1: input points position data, [B, C, N] 263 | xyz2: sampled input points position data, [B, C, S], skipped xyz 264 | points1: input points data, [B, D, N] 265 | points2: input points data, [B, D, S], 
skipped features 266 | Return: 267 | new_points: upsampled points data, [B, D', N] 268 | """ 269 | xyz1 = xyz1.permute(0, 2, 1) 270 | xyz2 = xyz2.permute(0, 2, 1) 271 | 272 | points2 = points2.permute(0, 2, 1) 273 | B, N, C = xyz1.shape 274 | _, S, _ = xyz2.shape 275 | 276 | if S == 1: 277 | interpolated_points = points2.repeat(1, N, 1) 278 | else: 279 | dists = square_distance(xyz1, xyz2) 280 | dists, idx = dists.sort(dim=-1) 281 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 282 | 283 | dist_recip = 1.0 / (dists + 1e-8) 284 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 285 | weight = dist_recip / norm 286 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 287 | 288 | # skip attention 289 | if self.att is not None: 290 | psix, psig = self.att(interpolated_points.permute(0, 2, 1), points1) 291 | points1 = points1 * psix 292 | 293 | if points1 is not None: 294 | points1 = points1.permute(0, 2, 1) 295 | new_points = torch.cat([points1, interpolated_points], dim=-1) 296 | else: 297 | new_points = interpolated_points 298 | 299 | new_points = new_points.permute(0, 2, 1) 300 | 301 | for i, conv in enumerate(self.mlp_convs): 302 | bn = self.mlp_bns[i] 303 | new_points = F.leaky_relu(bn(conv(new_points)), 0.2) 304 | 305 | return new_points 306 | 307 | 308 | class CIC(nn.Module): 309 | def __init__(self, npoint, radius, k, in_channels, output_channels, bottleneck_ratio=2, mlp_num=2, curve_config=None): 310 | super(CIC, self).__init__() 311 | self.in_channels = in_channels 312 | self.output_channels = output_channels 313 | self.bottleneck_ratio = bottleneck_ratio 314 | self.radius = radius 315 | self.k = k 316 | self.npoint = npoint 317 | 318 | planes = in_channels // bottleneck_ratio 319 | 320 | self.use_curve = curve_config is not None 321 | if self.use_curve: 322 | self.curveaggregation = CurveAggregation(planes) 323 | self.curvegrouping = CurveGrouping(planes, k, curve_config[0], curve_config[1]) 324 | 325 | 
self.conv1 = nn.Sequential( 326 | nn.Conv1d(in_channels, 327 | planes, 328 | kernel_size=1, 329 | bias=False), 330 | nn.BatchNorm1d(in_channels // bottleneck_ratio), 331 | nn.LeakyReLU(negative_slope=0.2, inplace=True)) 332 | 333 | self.conv2 = nn.Sequential( 334 | nn.Conv1d(planes, output_channels, kernel_size=1, bias=False), 335 | nn.BatchNorm1d(output_channels)) 336 | 337 | if in_channels != output_channels: 338 | self.shortcut = nn.Sequential( 339 | nn.Conv1d(in_channels, 340 | output_channels, 341 | kernel_size=1, 342 | bias=False), 343 | nn.BatchNorm1d(output_channels)) 344 | 345 | self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 346 | 347 | self.maxpool = MaskedMaxPool(npoint, radius, k) 348 | 349 | self.lpfa = LPFA(planes, planes, k, mlp_num=mlp_num, initial=False) 350 | 351 | def forward(self, xyz, x): 352 | 353 | # max pool 354 | if xyz.size(-1) != self.npoint: 355 | xyz, x = self.maxpool( 356 | xyz.transpose(1, 2).contiguous(), x) 357 | xyz = xyz.transpose(1, 2) 358 | 359 | shortcut = x 360 | x = self.conv1(x) # bs, c', n 361 | 362 | idx = knn(xyz, self.k) 363 | 364 | if self.use_curve: 365 | # curve grouping 366 | curves, flatten_curve_idxs = self.curvegrouping(x, xyz, idx[:,:,1:]) # avoid self-loop 367 | 368 | # curve aggregation 369 | x = self.curveaggregation(x, curves) 370 | else: 371 | flatten_curve_idxs = None 372 | 373 | x = self.lpfa(x, xyz, idx=idx[:,:,:self.k]) #bs, c', n, k 374 | 375 | x = self.conv2(x) # bs, c, n 376 | 377 | if self.in_channels != self.output_channels: 378 | shortcut = self.shortcut(shortcut) 379 | 380 | x = self.relu(x + shortcut) 381 | 382 | return xyz, x, flatten_curve_idxs 383 | 384 | 385 | class CurveAggregation(nn.Module): 386 | def __init__(self, in_channel): 387 | super(CurveAggregation, self).__init__() 388 | self.in_channel = in_channel 389 | mid_feature = in_channel // 2 390 | self.conva = nn.Conv1d(in_channel, 391 | mid_feature, 392 | kernel_size=1, 393 | bias=False) 394 | self.convb = 
nn.Conv1d(in_channel, 395 | mid_feature, 396 | kernel_size=1, 397 | bias=False) 398 | self.convc = nn.Conv1d(in_channel, 399 | mid_feature, 400 | kernel_size=1, 401 | bias=False) 402 | self.convn = nn.Conv1d(mid_feature, 403 | mid_feature, 404 | kernel_size=1, 405 | bias=False) 406 | self.convl = nn.Conv1d(mid_feature, 407 | mid_feature, 408 | kernel_size=1, 409 | bias=False) 410 | self.convd = nn.Sequential( 411 | nn.Conv1d(mid_feature * 2, 412 | in_channel, 413 | kernel_size=1, 414 | bias=False), 415 | nn.BatchNorm1d(in_channel)) 416 | self.line_conv_att = nn.Conv2d(in_channel, 417 | 1, 418 | kernel_size=1, 419 | bias=False) 420 | 421 | def forward(self, x, curves): 422 | curves_att = self.line_conv_att(curves) # bs, 1, c_n, c_l 423 | 424 | curver_inter = torch.sum(curves * F.softmax(curves_att, dim=-1), dim=-1) #bs, c, c_n 425 | curves_intra = torch.sum(curves * F.softmax(curves_att, dim=-2), dim=-2) #bs, c, c_l 426 | 427 | curver_inter = self.conva(curver_inter) # bs, mid, n 428 | curves_intra = self.convb(curves_intra) # bs, mid ,n 429 | 430 | x_logits = self.convc(x).transpose(1, 2).contiguous() 431 | x_inter = F.softmax(torch.bmm(x_logits, curver_inter), dim=-1) # bs, n, c_n 432 | x_intra = F.softmax(torch.bmm(x_logits, curves_intra), dim=-1) # bs, l, c_l 433 | 434 | 435 | curver_inter = self.convn(curver_inter).transpose(1, 2).contiguous() 436 | curves_intra = self.convl(curves_intra).transpose(1, 2).contiguous() 437 | 438 | x_inter = torch.bmm(x_inter, curver_inter) 439 | x_intra = torch.bmm(x_intra, curves_intra) 440 | 441 | curve_features = torch.cat((x_inter, x_intra),dim=-1).transpose(1, 2).contiguous() 442 | x = x + self.convd(curve_features) 443 | 444 | return F.leaky_relu(x, negative_slope=0.2) 445 | 446 | 447 | class CurveGrouping(nn.Module): 448 | def __init__(self, in_channel, k, curve_num, curve_length): 449 | super(CurveGrouping, self).__init__() 450 | self.curve_num = curve_num 451 | self.curve_length = curve_length 452 | self.in_channel = 
in_channel 453 | self.k = k 454 | 455 | self.att = nn.Conv1d(in_channel, 1, kernel_size=1, bias=False) 456 | 457 | self.walk = Walk(in_channel, k, curve_num, curve_length) 458 | 459 | def forward(self, x, xyz, idx): 460 | # starting point selection in self attention style 461 | x_att = torch.sigmoid(self.att(x)) 462 | x = x * x_att 463 | 464 | _, start_index = torch.topk(x_att, 465 | self.curve_num, 466 | dim=2, 467 | sorted=False) 468 | start_index = start_index.squeeze(1).unsqueeze(2) 469 | 470 | curves, flatten_curve_idxs = self.walk(xyz, x, idx, start_index) #bs, c, c_n, c_l 471 | 472 | return curves, flatten_curve_idxs 473 | 474 | 475 | class MaskedMaxPool(nn.Module): 476 | def __init__(self, npoint, radius, k): 477 | super(MaskedMaxPool, self).__init__() 478 | self.npoint = npoint 479 | self.radius = radius 480 | self.k = k 481 | 482 | def forward(self, xyz, features): 483 | sub_xyz, neighborhood_features = sample_and_group(self.npoint, self.radius, self.k, xyz, features.transpose(1,2)) 484 | 485 | neighborhood_features = neighborhood_features.permute(0, 3, 1, 2).contiguous() 486 | sub_features = F.max_pool2d( 487 | neighborhood_features, kernel_size=[1, neighborhood_features.shape[3]] 488 | ) # bs, c, n, 1 489 | sub_features = torch.squeeze(sub_features, -1) # bs, c, n 490 | return sub_xyz, sub_features 491 | -------------------------------------------------------------------------------- /core/models/walk.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Tiange Xiang 3 | @Contact: txia7609@uni.sydney.edu.au 4 | @File: walk.py 5 | @Time: 2021/01/21 3:10 PM 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def batched_index_select(input, dim, index): 15 | views = [input.shape[0]] + \ 16 | [1 if i != dim else -1 for i in range(1, len(input.shape))] 17 | expanse = list(input.shape) 18 | expanse[0] = -1 19 | expanse[dim] = -1 
20 | index = index.view(views).expand(expanse) 21 | return torch.gather(input, dim, index) 22 | 23 | def gumbel_softmax(logits, dim, temperature=1): 24 | """ 25 | Straight-through (ST) Gumbel-softmax without random Gumbel sampling 26 | input: [*, n_class] 27 | return: flatten --> [*, n_class] a one-hot vector 28 | """ 29 | y = F.softmax(logits / temperature, dim=dim) 30 | 31 | shape = y.size() 32 | _, ind = y.max(dim=-1) 33 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 34 | y_hard.scatter_(1, ind.view(-1, 1), 1) 35 | y_hard = y_hard.view(*shape) 36 | 37 | y_hard = (y_hard - y).detach() + y 38 | return y_hard 39 | 40 | class Walk(nn.Module): 41 | ''' 42 | Walk in the cloud 43 | ''' 44 | def __init__(self, in_channel, k, curve_num, curve_length): 45 | super(Walk, self).__init__() 46 | self.curve_num = curve_num 47 | self.curve_length = curve_length 48 | self.k = k 49 | 50 | self.agent_mlp = nn.Sequential( 51 | nn.Conv2d(in_channel * 2, 52 | 1, 53 | kernel_size=1, 54 | bias=False), nn.BatchNorm2d(1)) 55 | self.momentum_mlp = nn.Sequential( 56 | nn.Conv1d(in_channel * 2, 57 | 2, 58 | kernel_size=1, 59 | bias=False), nn.BatchNorm1d(2)) 60 | 61 | def crossover_suppression(self, cur, neighbor, bn, n, k): 62 | # cur: bs*n, 3 63 | # neighbor: bs*n, 3, k 64 | neighbor = neighbor.detach() 65 | cur = cur.unsqueeze(-1).detach() 66 | dot = torch.bmm(cur.transpose(1,2), neighbor) # bs*n, 1, k 67 | norm1 = torch.norm(cur, dim=1, keepdim=True) 68 | norm2 = torch.norm(neighbor, dim=1, keepdim=True) 69 | divider = torch.clamp(norm1 * norm2, min=1e-8) 70 | ans = torch.div(dot, divider).squeeze() # bs*n, k 71 | 72 | # shift the cosine by 1, then clamp to [0, 1] 73 | ans = 1. 
+ ans 74 | ans = torch.clamp(ans, 0., 1.0) 75 | 76 | return ans.detach() 77 | 78 | def forward(self, xyz, x, adj, cur): 79 | bn, c, tot_points = x.size() 80 | 81 | # raw point coordinates 82 | xyz = xyz.transpose(1,2).contiguous() # bs, n, 3 83 | 84 | # point features 85 | x = x.transpose(1,2).contiguous() # bs, n, c 86 | 87 | flatten_x = x.view(bn * tot_points, -1) 88 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 89 | batch_offset = torch.arange(0, bn, device=device).detach() * tot_points 90 | 91 | # indices of neighbors for the starting points 92 | tmp_adj = (adj + batch_offset.view(-1,1,1)).view(adj.size(0)*adj.size(1),-1) # bs*n, k 93 | 94 | # batch flattened indices for the starting points 95 | flatten_cur = (cur + batch_offset.view(-1,1,1)).view(-1) 96 | 97 | curves = [] 98 | flatten_curve_idxs = [flatten_cur.unsqueeze(1)] 99 | 100 | # one step at a time 101 | for step in range(self.curve_length): 102 | 103 | if step == 0: 104 | # get starting point features using flattened indices 105 | starting_points = flatten_x[flatten_cur, :].contiguous() 106 | pre_feature = starting_points.view(bn, self.curve_num, -1, 1).transpose(1,2) # bs, c, n, 1 107 | else: 108 | # dynamic momentum 109 | cat_feature = torch.cat((cur_feature.squeeze(-1), pre_feature.squeeze(-1)),dim=1) 110 | att_feature = F.softmax(self.momentum_mlp(cat_feature),dim=1).view(bn, 1, self.curve_num, 2) # bs, 1, n, 2 111 | cat_feature = torch.cat((cur_feature, pre_feature),dim=-1) # bs, c, n, 2 112 | 113 | # update curve descriptor 114 | pre_feature = torch.sum(cat_feature * att_feature, dim=-1, keepdim=True) # bs, c, n, 1 115 | pre_feature_cos = pre_feature.transpose(1,2).contiguous().view(bn * self.curve_num, -1) 116 | 117 | pick_idx = tmp_adj[flatten_cur] # bs*n, k 118 | 119 | # get the neighbors of current points 120 | pick_values = flatten_x[pick_idx.view(-1),:] 121 | 122 | # reshape to fit crossover suppression below 123 | pick_values_cos = pick_values.view(bn * 
self.curve_num, self.k, c) 124 | pick_values = pick_values_cos.view(bn, self.curve_num, self.k, c) 125 | pick_values_cos = pick_values_cos.transpose(1,2).contiguous() 126 | 127 | pick_values = pick_values.permute(0,3,1,2) # bs, c, n, k 128 | 129 | pre_feature_expand = pre_feature.expand_as(pick_values) 130 | 131 | # concat current point features with curve descriptors 132 | pre_feature_expand = torch.cat((pick_values, pre_feature_expand),dim=1) 133 | 134 | # which node to pick next? 135 | pre_feature_expand = self.agent_mlp(pre_feature_expand) # bs, 1, n, k 136 | 137 | if step != 0: 138 | # crossover suppression 139 | d = self.crossover_suppression(cur_feature_cos - pre_feature_cos, 140 | pick_values_cos - cur_feature_cos.unsqueeze(-1), 141 | bn, self.curve_num, self.k) 142 | d = d.view(bn, self.curve_num, self.k).unsqueeze(1) # bs, 1, n, k 143 | pre_feature_expand = torch.mul(pre_feature_expand, d) 144 | 145 | pre_feature_expand = gumbel_softmax(pre_feature_expand, -1) #bs, 1, n, k 146 | 147 | cur_feature = torch.sum(pick_values * pre_feature_expand, dim=-1, keepdim=True) # bs, c, n, 1 148 | 149 | cur_feature_cos = cur_feature.transpose(1,2).contiguous().view(bn * self.curve_num, c) 150 | 151 | cur = torch.argmax(pre_feature_expand, dim=-1).view(-1, 1) # bs * n, 1 152 | 153 | flatten_cur = batched_index_select(pick_idx, 1, cur).squeeze() # bs * n 154 | 155 | # collect curve progress 156 | curves.append(cur_feature) 157 | flatten_curve_idxs.append(flatten_cur.unsqueeze(1)) 158 | return torch.cat(curves,dim=-1), torch.cat(flatten_curve_idxs, dim=1) -------------------------------------------------------------------------------- /core/start_cls.sh: -------------------------------------------------------------------------------- 1 | python3 main_cls.py --exp_name=curvenet_cls_1 2 | -------------------------------------------------------------------------------- /core/start_normal.sh: -------------------------------------------------------------------------------- 1 | 
python3 main_normal.py --exp_name=curvenet_normal_1 2 | -------------------------------------------------------------------------------- /core/start_part.sh: -------------------------------------------------------------------------------- 1 | python3 main_partseg.py --exp_name=curveunet_seg_1 2 | -------------------------------------------------------------------------------- /core/test_cls.sh: -------------------------------------------------------------------------------- 1 | python3 main_cls.py --eval=True --model_path=../pretrained/cls/models/model.t7 2 | -------------------------------------------------------------------------------- /core/test_normal.sh: -------------------------------------------------------------------------------- 1 | python3 main_normal.py --eval=True --model_path=../pretrained/normal/models/model.t7 2 | -------------------------------------------------------------------------------- /core/test_part.sh: -------------------------------------------------------------------------------- 1 | python3 main_partseg.py --eval=True --model_path=../pretrained/seg/models/model.t7 2 | -------------------------------------------------------------------------------- /core/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Yue Wang 3 | @Contact: yuewangx@mit.edu 4 | @File: util 5 | @Time: 4/5/19 3:47 PM 6 | """ 7 | 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | 14 | def cal_loss(pred, gold, smoothing=True): 15 | ''' Calculate cross entropy loss, apply label smoothing if needed. 
''' 16 | 17 | gold = gold.contiguous().view(-1) 18 | 19 | if smoothing: 20 | eps = 0.2 21 | n_class = pred.size(1) 22 | 23 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 24 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 25 | log_prb = F.log_softmax(pred, dim=1) 26 | 27 | loss = -(one_hot * log_prb).sum(dim=1).mean() 28 | else: 29 | loss = F.cross_entropy(pred, gold, reduction='mean') 30 | 31 | return loss 32 | 33 | 34 | class IOStream(): 35 | def __init__(self, path): 36 | self.f = open(path, 'a') 37 | 38 | def cprint(self, text): 39 | print(text) 40 | self.f.write(text+'\n') 41 | self.f.flush() 42 | 43 | def close(self): 44 | self.f.close() 45 | -------------------------------------------------------------------------------- /core/visualize_curves.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Vinit Sarode 3 | @Contact: vinitsarode5@gmail.com 4 | @File: visualize_curves.py 5 | @Time: 2025/03/03 11:17 AM 6 | """ 7 | 8 | import argparse 9 | import torch 10 | import torch.nn as nn 11 | from data import ModelNet40 12 | import plotly.graph_objects as go 13 | from torch.utils.data import DataLoader 14 | from models.curvenet_cls import CurveNet 15 | 16 | def visualize_point_cloud(pcd, curves, axis=False, title=""): 17 | x, y, z= pcd[..., 0], pcd[..., 1], pcd[..., 2] 18 | fig = go.Figure( 19 | layout=dict( 20 | scene=dict( 21 | xaxis=dict(visible=axis), 22 | yaxis=dict(visible=axis), 23 | zaxis=dict(visible=axis) 24 | ), 25 | title=title, 26 | title_x=0.5 27 | ) 28 | ) 29 | fig.add_trace(go.Scatter3d( 30 | x=x, y=y, z=z, 31 | mode='markers', 32 | marker=dict(size=1) 33 | )) 34 | 35 | for curve in curves: 36 | x, y, z= curve[..., 0], curve[..., 1], curve[..., 2] 37 | fig.add_trace(go.Scatter3d( 38 | x=x, y=y, z=z, 39 | # mode='markers', 40 | marker=dict(size=1), 41 | line=dict( 42 | color='darkred', 43 | width=2 44 | ) 45 | )) 46 | fig.show() 47 | 48 | def visualize(args): 
49 | test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), 50 | batch_size=1, shuffle=False, drop_last=False) 51 | 52 | device = torch.device("cuda" if args.cuda else "cpu") 53 | 54 | # Load pretrained weights; k[7:] drops a 7-character key prefix (presumably 'module.' from nn.DataParallel) 55 | model = CurveNet().to(device) 56 | weights = torch.load(args.model_path, map_location='cpu') 57 | weights = {k[7:]: v for k, v in weights.items()} 58 | model.load_state_dict(weights) 59 | 60 | model = model.eval() 61 | for idx, (data, label) in enumerate(test_loader): 62 | if idx >= args.no_of_samples: 63 | break 64 | 65 | data, label = data.to(device), label.to(device).squeeze() 66 | data = data.permute(0, 2, 1) 67 | logits, flatten_cur = model(data, get_flatten_curve_idxs=True) 68 | data = data.permute(0, 2, 1).detach().cpu().numpy()[0] 69 | 70 | curves_dict = {} 71 | for key, val in flatten_cur.items(): 72 | if val is not None: 73 | curves = [] 74 | val_np = val.cpu().detach().numpy() 75 | for curve_idx in range(val_np.shape[0]): 76 | curves.append(data[val_np[curve_idx]]) 77 | curves_dict[key] = curves 78 | 79 | visualize_point_cloud(data, curves_dict[args.visualize_curve], title=args.visualize_curve) 80 | 81 | 82 | if __name__ == "__main__": 83 | # Visualization settings 84 | parser = argparse.ArgumentParser(description='Point Cloud Recognition') 85 | parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N', 86 | choices=['modelnet40']) 87 | parser.add_argument('--no_cuda', type=bool, default=False, 88 | help='disables CUDA') 89 | parser.add_argument('--num_points', type=int, default=1024, 90 | help='num of points to use') 91 | parser.add_argument('--model_path', type=str, default='', metavar='N', 92 | help='Pretrained model path') 93 | parser.add_argument('--visualize_curve', type=str, default='flatten_curve_idxs_11', 94 | help='Choose which curve to visualize based on model architecture', 95 | choices=['flatten_curve_idxs_11', 'flatten_curve_idxs_12', 96 | 'flatten_curve_idxs_21', 'flatten_curve_idxs_22']) 97 | 
parser.add_argument('--no_of_samples', type=int, default=3, 98 | help='No of point clouds to visualize with curves') 99 | args = parser.parse_args() 100 | 101 | args.cuda = not args.no_cuda and torch.cuda.is_available() 102 | 103 | visualize(args=args) -------------------------------------------------------------------------------- /poster3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiangexiang/CurveNet/0d2f62562e906bd78ec82e5ee56f5cec6f5c2b8c/poster3.png -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiangexiang/CurveNet/0d2f62562e906bd78ec82e5ee56f5cec6f5c2b8c/teaser.png --------------------------------------------------------------------------------