├── .gitignore
├── Definition.def
├── README.md
├── RandLA-Net-pytorch_Visualization.gif
├── data
│   ├── pc_id=636
│   │   ├── metadata
│   │   │   └── metadata.pickle
│   │   └── pc.pickle
│   └── pc_id=637
│       ├── metadata
│       │   └── metadata.pickle
│       └── pc.pickle
├── model
│   ├── __init__.py
│   ├── dataset.py
│   ├── hyperparameters.py
│   ├── model.py
│   ├── sampler.py
│   ├── testing.py
│   ├── training.py
│   └── utils.py
├── test.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/Definition.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
3 |
4 | %setup
5 |
6 | %files
7 |
8 | %post
9 |
10 |     export DEBIAN_FRONTEND=noninteractive
11 |     apt update && apt -y install git wget nano python3 python3-pip python3-opencv htop curl
12 |
13 |     python3 -m pip install tqdm matplotlib mlflow pandas scikit-learn==0.23.1 scikit-image
14 |     python3 -m pip install seaborn scipy datetime numpy==1.19.1
15 |     python3 -m pip install torch==1.6.0 msgpack k3d
16 |
17 | %environment
18 |     export LC_ALL=C
19 |
20 | %runscript
21 |
22 | %startscript
23 |
24 | %test
25 |
26 | %labels
27 |
28 | %help
29 |
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RandLA-Net-pytorch
2 | [![Results example viz](https://github.com/idsia-robotics/RandLA-Net-pytorch/blob/main/RandLA-Net-pytorch_Visualization.gif)](https://youtu.be/qE3vvh8aY00)
3 |
4 | Our PyTorch implementation of [RandLA-Net](https://github.com/QingyongHu/RandLA-Net).
5 |
6 | We tried to stay as close as possible to the original Tensorflow model implementation.
7 | However, some changes in the pipeline and input format were made to adapt the model to our own data format.
8 |
9 | [model.py](https://github.com/idsia-robotics/RandLA-Net-pytorch/blob/main/model/model.py), [sampler.py](https://github.com/idsia-robotics/RandLA-Net-pytorch/blob/main/model/sampler.py), and [dataset.py](https://github.com/idsia-robotics/RandLA-Net-pytorch/blob/main/model/dataset.py) contain all the relevant PyTorch code to be reused and adapted for different data formats. [hyperparameters.py](https://github.com/idsia-robotics/RandLA-Net-pytorch/blob/main/model/hyperparameters.py) contains the hyperparameters that can be set for training.
10 |
11 | Instructions are provided to run the complete pipeline on a data sample and to prepare your own data.
12 |
13 | ## Input Format
14 | Each pointcloud must be stored in its own folder named **pc_id=integer_id**.
15 | This folder must contain a pickle file (**pc.pickle**) for the pointcloud itself and another folder called **metadata**.
16 | Finally, the **metadata** folder must contain a file called **metadata.pickle**, which contains a python dictionary like this:
17 |
18 |     {
19 |      "pc_id": int id for the pointcloud,
20 |      "labels": list of floats representing the classes available in the pointcloud,
21 |      "name": string name for the pointcloud (can be None)
22 |     }
23 | Here is an example of what the dataset folder looks like:
24 | ```
25 | data/
26 | |________pc_id=20/
27 | |        |________pc.pickle
28 | |        |________metadata/
29 | |                 |________metadata.pickle
30 | |________pc_id=30/
31 |          |________pc.pickle
32 |          |________metadata/
33 |                   |________metadata.pickle
34 | ```
35 | Each **pc.pickle** file must contain a numpy array of shape *(n, 7)*, where *n* is the number of points and the 7 columns are (in this order):
36 | - x, y, z float coordinates
37 | - r, g, b colors, which are integers in [0, 255] (stored, however, as floats)
38 | - ground truth label
39 | A sample dataset is available in the [data](https://github.com/idsia-robotics/RandLA-Net-pytorch/tree/main/data) folder. It contains two low-density pointclouds from the [Nomoko](https://nomoko.world/) dataset, available [here](https://zenodo.org/record/4390295#.YIEin3UzY5k).
40 | ## Output Format
41 | Training produces PyTorch model checkpoints as output (by default under the data folder).
42 |
43 | Testing will produce the following files (inside the selected model folder):
44 | - **xyz_tile.pickle**: array of shape (n, 3) containing xyz coordinates
45 | - **xyz_probs.pickle**: array of shape (n, n_classes) containing the predicted score for each class
46 | - **xyz_labels.pickle**: array of shape (n,) containing the most probable class for each point (i.e. argmax(xyz_probs))
47 | - **true_rgb.pickle**: array of shape (n, 3) containing rgb colors
48 | - **gt_labels.pickle**: array of shape (n,) containing the ground truth class for each point
49 |
50 | Testing also produces two visualizations of the pointcloud to allow for quick inspection of the results:
51 | - **snapshot_predictions.html**: K3D html snapshot to visualize the pointcloud colored according to model predictions (each predicted class can be toggled in the UI)
52 | - **snapshot_rgb.html**: K3D html snapshot to visualize the pointcloud colored with real rgb (each predicted class can be toggled in the UI)
53 |
54 | ## How To Run
55 | We provide a [singularity definition](https://github.com/idsia-robotics/RandLA-Net-pytorch/blob/main/Definition.def) to build the needed environment.
56 | Please note that if you store your dataset in another folder, you will need to change the **DATA_ROOT_PATH** variable in [utils.py](https://github.com/idsia-robotics/RandLA-Net-pytorch/blob/main/model/utils.py). The default points to *data/* within this repo.
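If you need to convert your own pointclouds into the layout described above, the following minimal sketch shows one way to do it (the `export_pointcloud` helper and the two example points are ours, not part of this repo):
```
import os
import pickle

import numpy as np


def export_pointcloud(root, pc_id, points, labels, name=None):
    # points: (n, 7) array of x, y, z, r, g, b, label as described above
    assert points.ndim == 2 and points.shape[1] == 7
    pc_path = os.path.join(root, f"pc_id={pc_id}")
    os.makedirs(os.path.join(pc_path, "metadata"), exist_ok=True)
    with open(os.path.join(pc_path, "pc.pickle"), "wb") as f:
        pickle.dump(points.astype(np.float64), f)
    metadata = {"pc_id": pc_id, "labels": labels, "name": name}
    with open(os.path.join(pc_path, "metadata", "metadata.pickle"), "wb") as f:
        pickle.dump(metadata, f)


# two placeholder points with classes [0., 1.]
points = np.array([[0., 0., 0., 255., 0., 0., 0.],
                   [1., 1., 1., 0., 255., 0., 1.]])
export_pointcloud("data", 20, points, labels=[0., 1.], name="example")
```
`RandlanetDataset` expects exactly this layout when it is given pc path entries such as *data/pc_id=20/*.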
57 |
58 | ```
59 | # Clone repo and cd into it
60 | git clone https://github.com/idsia-robotics/RandLA-Net-pytorch.git
61 | cd RandLA-Net-pytorch
62 |
63 | # Build singularity container
64 | singularity build --fakeroot RandLA-Net-pytorch.sif Definition.def
65 |
66 | # Start the singularity container
67 | singularity instance start --nv RandLA-Net-pytorch.sif randlanet
68 |
69 | # Open a shell inside the container:
70 | singularity shell instance://randlanet
71 |
72 | # From within the shell, run training:
73 | python3 train.py
74 |
75 | # and testing:
76 | python3 test.py
77 | ```
78 | ## Citation
79 |
80 | ### Swiss3DCities: Aerial Photogrammetric 3D Pointcloud Dataset with Semantic Labels
81 | - [paper](https://arxiv.org/abs/2012.12996)
82 | - [dataset](https://zenodo.org/record/4390295)
83 |
84 | ### RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds
85 | - [paper](http://arxiv.org/abs/1911.11236)
86 | - [repo](https://github.com/QingyongHu/RandLA-Net) with the original Tensorflow 1 implementation
87 |
--------------------------------------------------------------------------------
/RandLA-Net-pytorch_Visualization.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idsia-robotics/RandLA-Net-pytorch/322e0f8f9d2c1443ae180bb4c48f3b54446546e6/RandLA-Net-pytorch_Visualization.gif
--------------------------------------------------------------------------------
/data/pc_id=636/metadata/metadata.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idsia-robotics/RandLA-Net-pytorch/322e0f8f9d2c1443ae180bb4c48f3b54446546e6/data/pc_id=636/metadata/metadata.pickle
--------------------------------------------------------------------------------
/data/pc_id=636/pc.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idsia-robotics/RandLA-Net-pytorch/322e0f8f9d2c1443ae180bb4c48f3b54446546e6/data/pc_id=636/pc.pickle
--------------------------------------------------------------------------------
/data/pc_id=637/metadata/metadata.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idsia-robotics/RandLA-Net-pytorch/322e0f8f9d2c1443ae180bb4c48f3b54446546e6/data/pc_id=637/metadata/metadata.pickle
--------------------------------------------------------------------------------
/data/pc_id=637/pc.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idsia-robotics/RandLA-Net-pytorch/322e0f8f9d2c1443ae180bb4c48f3b54446546e6/data/pc_id=637/pc.pickle
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idsia-robotics/RandLA-Net-pytorch/322e0f8f9d2c1443ae180bb4c48f3b54446546e6/model/__init__.py
--------------------------------------------------------------------------------
/model/dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | from collections import defaultdict
4 | import pickle
5 |
6 | import torch
7 | import numpy as np
8 | from scipy.spatial import cKDTree
9 |
10 | from torch.utils import data
11 |
12 | from .utils import read_metadata, rotate
13 |
14 |
15 | class RandlanetDataset(data.Dataset):
16
| 17 | def __init__(self, pc_path_list, **kwargs): 18 | self.cfg = kwargs 19 | self.size = 0 20 | pc_labels = read_metadata(pc_path_list[0])['labels'] 21 | self.test = [-99.] == pc_labels 22 | if self.test: 23 | assert len(pc_path_list) == 1, "Only one pc can be used as test" 24 | pc_labels = [-99.] 25 | else: 26 | print("Using labels from first dataset provided") 27 | assert len(pc_labels) == self.cfg['num_classes'], \ 28 | f"self.cfg['num_classes'] {self.cfg['num_classes']} is different" \ 29 | f"from len(pc_labels) {len(pc_labels)}" 30 | for pc_path in pc_path_list: 31 | o_pc_labels = read_metadata(pc_path)['labels'] 32 | assert set(o_pc_labels).issubset(set(pc_labels)), \ 33 | "Point clouds must be created considering a subset " \ 34 | "of labels from the first pc provided" 35 | 36 | self.mapping = {label: i for i, label in enumerate(sorted(pc_labels))} 37 | self.kdtrees = dict() 38 | self.colors = dict() 39 | self.labels = dict() 40 | self.pc_class_count = dict() 41 | self.total_class_count = defaultdict(int) 42 | self.total_class_weight = dict() 43 | self.n_points = 0 44 | 45 | for pc_path in pc_path_list: 46 | with open(f"{pc_path}pc.pickle", "rb") as f: 47 | pc = pickle.load(f) 48 | metadata = read_metadata(pc_path) 49 | pc_id = metadata["pc_id"] 50 | pc_name = metadata["name"] 51 | kdtree_f = f"{pc_path}/kdtree.pickle" 52 | if os.path.isfile(kdtree_f): 53 | with open(kdtree_f, 'rb') as f: 54 | kdtree = pickle.load(f) 55 | else: 56 | print(f"KDtree for pc {pc_id} {pc_name} not found, creating it") 57 | kdtree = cKDTree(pc[:, :3], leafsize=50) 58 | with open(kdtree_f, "wb") as f: 59 | pickle.dump(kdtree, f) 60 | self.kdtrees[pc_id] = kdtree 61 | self.colors[pc_id] = pc[:, 3:6]/255. 62 | self.labels[pc_id] = pc[:, 6] 63 | self.size += len(self.kdtrees[pc_id].data) 64 | 65 | labels, counters = np.unique(self.labels[pc_id], return_counts=True) 66 | self.pc_class_count[pc_id] = dict() 67 | for label, counter in zip(labels, counters): 68 | self.pc_class_count[pc_id][label] = counter 69 | self.total_class_count[label] += counter 70 | self.n_points += counter 71 | 72 | for label, counter in self.total_class_count.items(): 73 | self.total_class_weight[label] = counter/self.n_points 74 | 75 | def __getitem__(self, _tuple): 76 | pc_id = _tuple[0] 77 | pick_point = _tuple[1] 78 | # center_point = _tuple[1].reshape(1, -1) 79 | # Get all points within the cloud from tree structure 80 | points = np.array(self.kdtrees[pc_id].data, copy=False) 81 | 82 | query_idx = self.kdtrees[pc_id].query(pick_point, 83 | k=self.cfg['num_points'])[1][0] 84 | # shuffle index inplace 85 | random.shuffle(query_idx) 86 | 87 | # Get corresponding points and colors based on the index 88 | queried_pc_xyz = points[query_idx] 89 | 90 | queried_pc_xyz[:, 0:3] = queried_pc_xyz[:, 0:3] - pick_point[:, 0:3] 91 | queried_pc_colors = self.colors[pc_id][query_idx] 92 | queried_pc_labels = self.labels[pc_id][query_idx] 93 | 94 | queried_pc_labels = np.array( 95 | [self.mapping[lbl] for lbl in queried_pc_labels]) 96 | 97 | input_list = self.build_input(queried_pc_xyz, queried_pc_colors, 98 | queried_pc_labels, query_idx, pc_id) 99 | 100 | return input_list 101 | 102 | def __len__(self): 103 | return self.size 104 | 105 | def build_input(self, xyz, rgb, labels, query_idx, pc_id): 106 | features = torch.tensor(self.augment_input(xyz, rgb), dtype=torch.float32) 107 | labels = torch.tensor(labels, dtype=torch.long) 108 | query_idx = torch.tensor(query_idx, dtype=torch.int32) 109 | pc_id = torch.tensor(pc_id, dtype=torch.int32) 110 | 
input_points = [] 111 | input_neighbors = [] 112 | input_pools = [] 113 | input_up_samples = [] 114 | 115 | for i in range(self.cfg['num_layers']): 116 | _, neigh_idx = cKDTree(xyz).query(xyz, k=self.cfg['k_n']) 117 | sub_sampling_idx = len(xyz)//self.cfg['sub_sampling_ratio'][i] 118 | sub_points = xyz[:sub_sampling_idx] 119 | pool_i = neigh_idx[:sub_sampling_idx] 120 | _, up_i = cKDTree(sub_points).query(xyz, k=1) 121 | input_points.append(torch.tensor(xyz, dtype=torch.float32)) 122 | input_neighbors.append(torch.tensor(neigh_idx, dtype=torch.int32)) 123 | input_pools.append(torch.tensor(pool_i, dtype=torch.int32)) 124 | input_up_samples.append(torch.tensor(up_i, dtype=torch.int32)) 125 | xyz = sub_points 126 | 127 | inputs = input_points + input_neighbors + input_pools + input_up_samples 128 | inputs += [features, labels, query_idx, pc_id] 129 | return inputs 130 | 131 | def augment_input(self, xyz, rgb): 132 | theta = np.random.uniform(0.0, 2 * np.pi) 133 | transformed_xyz = rotate(xyz, [0., 0., theta]) 134 | 135 | # Choose random scales for each example 136 | min_s = self.cfg['augment_scale_min'] 137 | max_s = self.cfg['augment_scale_max'] 138 | if self.cfg['augment_scale_anisotropic']: 139 | scales = np.random.uniform(min_s, max_s, size=(3,)) 140 | else: 141 | scales = np.random.uniform(min_s, max_s) 142 | scales = np.array([scales, scales, scales]) 143 | 144 | symmetries = [] 145 | for i in range(3): 146 | if self.cfg['augment_symmetries'][i]: 147 | symmetries.append(np.round( 148 | np.random.uniform()) * 2 - 1) 149 | else: 150 | symmetries.append(1.) 151 | scales *= np.array(symmetries) 152 | 153 | # Apply scales 154 | transformed_xyz = transformed_xyz * scales 155 | 156 | noise = np.random.normal(scale=self.cfg['augment_noise'], 157 | size=transformed_xyz.shape) 158 | transformed_xyz = transformed_xyz + noise 159 | 160 | stacked_features = np.concatenate([transformed_xyz, rgb], axis=-1) 161 | return stacked_features 162 | -------------------------------------------------------------------------------- /model/hyperparameters.py: -------------------------------------------------------------------------------- 1 | hyp = { 2 | 'k_n': 16, # KNN 3 | 'num_layers': 5, # Number of layers 4 | 'num_points': 40000, # Number of input points 5 | 'num_classes': 5, # Number of valid classes 6 | 'sub_grid_size': 0.001, # preprocess_parameter 7 | 'batch_size': 4, # batch_size during training 8 | 'val_batch_size': 16, # batch_size during validation and test 9 | 'train_steps': 500, # Number of steps per epochs 10 | 'val_steps': 100, # Number of validation steps per epoch 11 | 'sub_sampling_ratio': [4, 4, 4, 4, 2], 12 | # sampling ratio of random sampling at each layer 13 | 'd_out': [16, 64, 128, 256, 512], # feature dimension 14 | 'noise_init': 3.5, # noise initial parameter 15 | 'max_epoch': 100, # maximum epoch during training 16 | 'learning_rate': 1e-3, # initial learning rate 17 | 'lr_decays': {i: 0.95 for i in range(0, 500)}, # decay rate of learning rate 18 | 'augment_scale_anisotropic': True, 19 | 'augment_symmetries': [True, False, False], 20 | 'augment_rotation': 'vertical', 21 | 'augment_scale_min': 0.8, 22 | 'augment_scale_max': 1.2, 23 | 'augment_noise': 0.001, 24 | 'augment_occlusion': 'none', 25 | 'augment_color': 0.8 26 | } 27 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 
import numpy as np 5 | 6 | 7 | class RandlaNet(nn.Module): 8 | def __init__(self, d_out, n_layers, n_classes): 9 | super(RandlaNet, self).__init__() 10 | self.n_classes = n_classes 11 | dilate_block_in = 8 12 | self.fc1 = nn.Linear(6, dilate_block_in) 13 | self.bn1 = nn.BatchNorm1d(dilate_block_in, eps=1e-6, momentum=0.01) 14 | self.f_encoders = nn.ModuleList() 15 | decoder_in_list = [d_out[0]*2] 16 | for i in range(n_layers): 17 | self.f_encoders.append(DilatedResidualBlock(dilate_block_in, d_out[i])) 18 | dilate_block_in = d_out[i]*2 19 | decoder_in_list.append(dilate_block_in) 20 | 21 | self.conv2 = nn.Conv2d(dilate_block_in, dilate_block_in, 22 | kernel_size=[1, 1], stride=[1, 1]) 23 | self.bn2 = nn.BatchNorm2d(dilate_block_in, eps=1e-6, momentum=0.01) 24 | 25 | self.f_decoders = nn.ModuleList() 26 | for i in range(n_layers): 27 | self.f_decoders.append(FeatureDecoder(decoder_in_list[-i-1] + 28 | decoder_in_list[-i-2], 29 | decoder_in_list[-i-2])) 30 | self.conv3 = nn.Conv2d(decoder_in_list[0], 64, kernel_size=[1, 1], 31 | stride=[1, 1]) 32 | self.bn3 = nn.BatchNorm2d(64, eps=1e-6, momentum=0.01) 33 | self.conv4 = nn.Conv2d(64, 32, kernel_size=[1, 1], stride=[1, 1]) 34 | self.bn4 = nn.BatchNorm2d(32, eps=1e-6, momentum=0.01) 35 | self.drop4 = nn.Dropout2d(p=0.5) 36 | self.conv5 = nn.Conv2d(32, self.n_classes, kernel_size=[1, 1], 37 | stride=[1, 1]) 38 | 39 | def forward(self, inputs): 40 | x = inputs['features'] 41 | x = self.fc1(x) 42 | x = x.permute(0, 2, 1).contiguous() 43 | x = self.bn1(x) 44 | x = F.leaky_relu(x) 45 | x = x[:, :, :, None] 46 | encoded_list = [] 47 | for i, encoder in enumerate(self.f_encoders): 48 | x = encoder(x, inputs['xyz'][i], inputs['neigh_idx'][i]) 49 | if i == 0: 50 | encoded_list.append(x.clone()) 51 | x = random_sample(x, inputs['sub_idx'][i]) 52 | encoded_list.append(x.clone()) 53 | x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2) 54 | for i, decoder in enumerate(self.f_decoders): 55 | x = decoder(x, encoded_list[-i-2], inputs['interp_idx'][-i-1]) 56 | x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.2) 57 | x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.2) 58 | x = self.drop4(x) 59 | x = self.conv5(x) 60 | x = x.squeeze(-1).permute(0, 2, 1).reshape([-1, self.n_classes]).contiguous() 61 | return x 62 | 63 | 64 | class FeatureDecoder(nn.Module): 65 | def __init__(self, f_in, f_out): 66 | super(FeatureDecoder, self).__init__() 67 | self.trconv1 = nn.ConvTranspose2d(f_in, f_out, kernel_size=[1, 1], 68 | stride=[1, 1]) 69 | self.bn1 = nn.BatchNorm2d(f_out, eps=1e-6, momentum=0.01) 70 | 71 | def forward(self, feature, encoded_feature, interp_idx): 72 | f_interp_i = nearest_interpolation(feature, interp_idx) 73 | f_decoded = self.trconv1(torch.cat([encoded_feature, f_interp_i], 74 | dim=1)) 75 | f_decoded = self.bn1(f_decoded) 76 | return f_decoded 77 | 78 | 79 | class DilatedResidualBlock(nn.Module): 80 | def __init__(self, f_in, d_out): 81 | super(DilatedResidualBlock, self).__init__() 82 | self.conv1 = nn.Conv2d(f_in, d_out//2, kernel_size=[1, 1], 83 | stride=[1, 1]) 84 | self.bn1 = nn.BatchNorm2d(d_out//2, eps=1e-6, momentum=0.01) 85 | self.bb = BuildingBlock(d_out) 86 | self.conv2 = nn.Conv2d(d_out, d_out*2, kernel_size=[1, 1], 87 | stride=[1, 1]) 88 | self.bn2 = nn.BatchNorm2d(d_out*2, eps=1e-6, momentum=0.01) 89 | self.shortcut = nn.Conv2d(f_in, d_out*2, kernel_size=[1, 1], 90 | stride=[1, 1]) 91 | self.bn_shortcut = nn.BatchNorm2d(d_out*2, eps=1e-6, momentum=0.01) 92 | 93 | def forward(self, feature, xyz, neigh_idx): 
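        # 1x1 conv down to d_out/2, local aggregation over the k-NN graph
        # (BuildingBlock), then 1x1 conv up to d_out*2; the input is projected
        # to d_out*2 by the 1x1 shortcut conv and added back to the main path,
        # and the sum goes through a leaky ReLU (the residual connection).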
94 | f_pc = F.leaky_relu(self.bn1(self.conv1(feature)), negative_slope=0.2) 95 | f_pc = self.bb(xyz, f_pc, neigh_idx) 96 | f_pc = self.bn2(self.conv2(f_pc)) 97 | shortcut = self.bn_shortcut(self.shortcut(feature)) 98 | return F.leaky_relu(f_pc + shortcut) 99 | 100 | 101 | class BuildingBlock(nn.Module): 102 | def __init__(self, d_out): 103 | super(BuildingBlock, self).__init__() 104 | self.conv1 = nn.Conv2d(10, d_out//2, kernel_size=[1, 1], stride=[1, 1]) 105 | self.bn1 = nn.BatchNorm2d(d_out//2, eps=1e-6, momentum=0.01) 106 | self.attpool1 = AttentivePooling(2*(d_out//2), d_out//2) 107 | self.conv2 = nn.Conv2d(d_out//2, d_out//2, kernel_size=[1, 1], 108 | stride=[1, 1]) 109 | self.bn2 = nn.BatchNorm2d(d_out//2, eps=1e-6, momentum=0.01) 110 | self.attpool2 = AttentivePooling(2*(d_out//2), d_out) 111 | 112 | def forward(self, xyz, feature, neigh_idx): 113 | f_xyz = relative_pos_encoding(xyz, neigh_idx) 114 | f_xyz = F.leaky_relu(self.bn1(self.conv1(f_xyz)), negative_slope=0.2) 115 | feature = torch.squeeze(feature, dim=-1).permute(0, 2, 1).contiguous() 116 | f_neighbours = gather_neighbour(feature, neigh_idx) 117 | f_concat = torch.cat([f_neighbours, f_xyz], dim=1) 118 | f_pc_agg = self.attpool1(f_concat) 119 | 120 | f_xyz = F.leaky_relu(self.bn2(self.conv2(f_xyz)), negative_slope=0.2) 121 | f_pc_agg = torch.squeeze(f_pc_agg, dim=-1).permute(0, 2, 1).contiguous() 122 | f_neighbours = gather_neighbour(f_pc_agg, neigh_idx) 123 | f_concat = torch.cat([f_neighbours, f_xyz], dim=1) 124 | f_pc_agg = self.attpool2(f_concat) 125 | return f_pc_agg 126 | 127 | 128 | class AttentivePooling(nn.Module): 129 | def __init__(self, n_feature, d_out): 130 | super(AttentivePooling, self).__init__() 131 | self.n_feature = n_feature 132 | self.fc1 = nn.Linear(n_feature, n_feature, bias=False) 133 | self.conv1 = nn.Conv2d(n_feature, d_out, kernel_size=[1, 1], 134 | stride=[1, 1]) 135 | self.bn1 = nn.BatchNorm2d(d_out, eps=1e-6, momentum=0.01) 136 | 137 | def forward(self, x): 138 | batch_size = x.shape[0] 139 | num_points = x.shape[2] 140 | num_neigh = x.shape[3] 141 | x = x.permute(0, 2, 3, 1).contiguous() 142 | x = torch.reshape(x, [-1, num_neigh, self.n_feature]) 143 | att_activation = self.fc1(x) 144 | att_score = F.softmax(att_activation, dim=1) 145 | x = x * att_score 146 | x = torch.sum(x, dim=1) 147 | x = torch.reshape(x, [batch_size, num_points, self.n_feature])[:, :, :, None].permute(0, 2, 1, 3).contiguous() 148 | x = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.2) 149 | return x 150 | 151 | 152 | def relative_pos_encoding(xyz, neighbor_idx): 153 | neighbor_xyz = gather_neighbour(xyz, neighbor_idx) 154 | xyz = xyz[:, :, None, :].permute(0, 3, 1, 2).contiguous() 155 | repeated_xyz = xyz.repeat(1, 1, 1, 16) 156 | relative_xyz = repeated_xyz - neighbor_xyz 157 | relative_dist = torch.sqrt(torch.sum(relative_xyz**2, dim=1, keepdim=True)) 158 | relative_feature = torch.cat([relative_dist, relative_xyz, repeated_xyz, neighbor_xyz], dim=1) 159 | return relative_feature 160 | 161 | 162 | def gather_neighbour(point_features, neighbor_idx): 163 | batch_size = point_features.shape[0] 164 | n_points = point_features.shape[1] 165 | n_features = point_features.shape[2] 166 | index_input = torch.reshape(neighbor_idx, shape=[batch_size, -1]) 167 | features = batch_gather(point_features, index_input) 168 | features = torch.reshape(features, [batch_size, 169 | n_points, 170 | neighbor_idx.shape[-1], 171 | n_features]) 172 | return features.permute(0, 3, 1, 2).contiguous() 173 | 174 | 175 | def 
random_sample(feature, pool_idx): 176 | feature = torch.squeeze(feature, dim=3) 177 | num_neigh = pool_idx.shape[-1] 178 | batch_size = pool_idx.shape[0] 179 | d = feature.shape[1] 180 | feature = feature.permute(0, 2, 1).contiguous() 181 | pool_idx = torch.reshape(pool_idx, [batch_size, -1]) 182 | pool_features = batch_gather(feature, pool_idx) 183 | pool_features = torch.reshape(pool_features, [batch_size, -1, num_neigh, d]) 184 | pool_features = torch.max(pool_features, dim=2, keepdim=True)[0] 185 | return pool_features.permute(0, 3, 1, 2).contiguous() 186 | 187 | 188 | def nearest_interpolation(feature, interp_idx): 189 | feature = torch.squeeze(feature, dim=3) 190 | batch_size = interp_idx.shape[0] 191 | up_num_points = interp_idx.shape[1] 192 | interp_idx = torch.reshape(interp_idx, [batch_size, up_num_points]) 193 | feature = feature.permute(0, 2, 1).contiguous() 194 | interp_features = batch_gather(feature, interp_idx) 195 | return interp_features.permute(0, 2, 1)[:, :, :, None].contiguous() 196 | 197 | 198 | def batch_gather(tensor, indices): 199 | shape = list(tensor.shape) 200 | device = tensor.device 201 | flat_first = torch.reshape( 202 | tensor, [shape[0] * shape[1]] + shape[2:]) 203 | offset = torch.reshape( 204 | torch.arange(shape[0], device=device) * shape[1], 205 | [shape[0]] + [1] * (len(indices.shape) - 1)) 206 | output = flat_first[indices.long() + offset] 207 | return output 208 | -------------------------------------------------------------------------------- /model/sampler.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | 3 | import numpy as np 4 | from torch.utils import data 5 | 6 | 7 | class RandlanetWeightedSampler(data.Sampler): 8 | 9 | def __init__(self, dataset, n_steps): 10 | self.weights = dataset.total_class_weight 11 | self.kdtrees = dataset.kdtrees 12 | self.labels = dataset.labels 13 | self.possibility = dict() 14 | self.min_possibility = dict() 15 | self.n_steps = n_steps 16 | self.cfg = dataset.cfg 17 | 18 | def __iter__(self): 19 | np.random.seed() 20 | for pc_id, kdtree in self.kdtrees.items(): 21 | self.possibility[pc_id] = np.random.rand(len(kdtree.data)) * 1e-3 22 | self.min_possibility[pc_id] = float(np.min(self.possibility[pc_id])) 23 | 24 | for _ in range(self.n_steps): 25 | pc_id = min(self.min_possibility, key=self.min_possibility.get) 26 | point_idx = np.argmin(self.possibility[pc_id]) 27 | # Get all points within the cloud from tree structure 28 | points = np.array(self.kdtrees[pc_id].data, copy=False) 29 | 30 | # Center point of input region 31 | center_point = points[point_idx, :].reshape(1, -1) 32 | 33 | # Add noise to the center point 34 | noise = np.random.normal(scale=self.cfg['noise_init'] / 10, 35 | size=center_point.shape) 36 | pick_point = center_point + noise.astype(center_point.dtype) 37 | # takes the indices of num_points neighbours 38 | query_idx = self.kdtrees[pc_id].query(pick_point, 39 | k=self.cfg['num_points'])[1][0] 40 | # Get corresponding points and colors based on the index 41 | queried_pc_labels = self.labels[pc_id][query_idx] 42 | queried_pt_weight = np.array( 43 | [self.weights[lbl] for lbl in queried_pc_labels]) 44 | # Update the possibility of the selected points 45 | dists = np.sum( 46 | np.square((points[query_idx] - pick_point).astype(np.float32)), 47 | axis=1) 48 | delta = np.square(1 - dists / np.max(dists)) * queried_pt_weight 49 | self.possibility[pc_id][query_idx] += delta 50 | self.min_possibility[pc_id] = 
float(np.min(self.possibility[pc_id])) 51 | yield pc_id, pick_point 52 | 53 | def __len__(self): 54 | return self.n_steps 55 | -------------------------------------------------------------------------------- /model/testing.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import math 3 | import glob 4 | 5 | import numpy as np 6 | import datetime 7 | 8 | import torch.nn.functional as F 9 | from torch import nn 10 | from torch.utils import data 11 | from tqdm import tqdm 12 | import torch 13 | 14 | from .dataset import RandlanetDataset 15 | from .sampler import RandlanetWeightedSampler 16 | from .model import RandlaNet 17 | from .utils import create_metadata, read_metadata, check_create_folder, generate_k3d_plot, color_mapping, name_mapping 18 | from .training import unpack_input 19 | 20 | 21 | def segment(test_loader, model, device, inv_mapping, cfg, max_epoch=150): 22 | """ 23 | Args: 24 | test_loader: todo 25 | model: pytorch loaded model 26 | device: pytorch device 27 | inv_mapping: dict to map net outputs to original labels 28 | """ 29 | test_logger = tqdm(test_loader, 30 | desc="Segmentation", 31 | total=len(test_loader)) 32 | n_points = len(test_loader.dataset) 33 | n_classes = len(inv_mapping) 34 | xyz_probs = np.zeros((n_points, n_classes)) 35 | xyz_probs[:] = np.nan 36 | visited = np.zeros((n_points,), dtype=np.int32) 37 | model.eval() 38 | #test_smooth = 0.98 39 | n_votes = 2 40 | with torch.no_grad(): 41 | for step in range(max_epoch): 42 | test_logger = tqdm(test_loader, 43 | desc="Segmentation", 44 | total=len(test_loader)) 45 | print(f"Round {step}") 46 | for input_list in test_logger: 47 | inputs = unpack_input(input_list, cfg['num_layers'], device) 48 | outputs = model(inputs) 49 | outputs = F.log_softmax(outputs, dim=1) 50 | outputs = torch.reshape(outputs, [cfg['val_batch_size'], -1, cfg['num_classes']]) 51 | 52 | for j in range(outputs.shape[0]): 53 | probs = outputs[j, :, :].cpu().detach().float().numpy() 54 | # probs = np.swapaxes(np.squeeze(probs), 0, 1) 55 | ids = inputs['input_inds'][j, :].cpu().detach().int().numpy() 56 | xyz_probs[ids] = np.nanmean([xyz_probs[ids], np.exp(probs)], axis=0) 57 | # xyz_probs[ids] = test_smooth * xyz_probs[ids] \ 58 | # + (1 - test_smooth) * probs 59 | visited[ids] += 1 60 | least_visited = np.min(np.unique(visited)) 61 | if least_visited >= n_votes: 62 | print(f"Each point was visited at least {n_votes}") 63 | break 64 | else: 65 | print(least_visited) 66 | for pc_id in test_loader.dataset.kdtrees: 67 | xyz_tile = test_loader.dataset.kdtrees[pc_id].data 68 | true_rgb = test_loader.dataset.colors[pc_id]*255. 69 | gt_labels = test_loader.dataset.labels[pc_id] 70 | xyz_labels = np.argmax(xyz_probs, axis=1) 71 | 72 | return xyz_tile, xyz_labels, xyz_probs, true_rgb, gt_labels 73 | 74 | 75 | def store_results(model_path, xyz_tile, xyz_labels, xyz_probs, true_rgb, 76 | gt_labels, pc_path, segmentation_name): 77 | """ 78 | Stores segmentation results that will be used to generate analysis 79 | and to upload data to ISIN db 80 | 81 | Args: 82 | model_path: path to the model used for segmentation 83 | xyz_tile: [x,y,z] array of points 84 | xyz_labels: array of labels for each point 85 | xyz_probs: array of model outputs for each point 86 | true_rgb: [r,g,b] array for each point 87 | gt_labels: ground truth labels for each point 88 | pc_path: path containing segmented pc file 89 | segmented pc 90 | segmentation_name: name for the segmentation folder. 
If None, timestamp 91 | will be used 92 | 93 | """ 94 | print("Storing segmentation results") 95 | date = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 96 | if segmentation_name is None: 97 | segmentation_name = date 98 | 99 | results_path = f"{model_path}output/segmentations/{segmentation_name}/" 100 | check_create_folder(results_path) 101 | 102 | with open(f"{results_path}/xyz_tile.pickle", "wb") as pickle_out: 103 | pickle.dump(xyz_tile, pickle_out) 104 | 105 | with open(f"{results_path}/xyz_probs.pickle", "wb") as pickle_out: 106 | pickle.dump(xyz_probs, pickle_out) 107 | 108 | with open(f"{results_path}/xyz_labels.pickle", "wb") as pickle_out: 109 | pickle.dump(xyz_labels, pickle_out) 110 | 111 | with open(f"{results_path}/true_rgb.pickle", "wb") as pickle_out: 112 | pickle.dump(true_rgb, pickle_out) 113 | 114 | with open(f"{results_path}/gt_labels.pickle", "wb") as pickle_out: 115 | pickle.dump(gt_labels, pickle_out) 116 | 117 | metadata = read_metadata(pc_path) 118 | metadata['timestamp'] = date 119 | create_metadata(results_path, **metadata) 120 | print(f"Results stored at: {results_path}") 121 | return results_path 122 | 123 | 124 | def segment_randlanet(model_path, pc_path, cfg, num_workers, segmentation_name=None): 125 | """Classify all the points contained in the provided pc using the best 126 | checkpoint of the selected model. It stores the results inside the 127 | model folder. 128 | 129 | Args: 130 | model_path: path to the folder containing all data generated during 131 | model training for a given model 132 | pc_path: path to the folder containing the pc 133 | segmentation_name: name for the segmentation folder. If None, timestamp 134 | will be used 135 | """ 136 | with open(f"{model_path}output/metadata.pkl", "rb") as f: 137 | metadata = pickle.load(f) 138 | mapping = metadata["label_mapping"] 139 | best_epoch = metadata["best_epoch"] 140 | model_name = glob.glob(f"{model_path}checkpoints/{best_epoch}_*.pth")[0] 141 | print(f"Loading model checkpoint: {model_name}") 142 | inv_mapping = {mapping[l]: l for l in mapping} 143 | print(f"Label inverse mapping: {inv_mapping}") 144 | 145 | n_classes = len(mapping) 146 | print(f"Best epoch was {metadata['best_epoch']}") 147 | print("Setting up pytorch") 148 | use_cuda = torch.cuda.is_available() 149 | print(f"Use cuda: {use_cuda}") 150 | device = torch.device("cuda:0" if use_cuda else "cpu") 151 | 152 | test_params = {"batch_size": cfg['val_batch_size'], 153 | "shuffle": False, 154 | "num_workers": num_workers} 155 | 156 | test_set = RandlanetDataset([pc_path], **cfg) 157 | test_sampler = RandlanetWeightedSampler(test_set, 158 | cfg['val_batch_size'] * cfg[ 159 | 'val_steps']) 160 | 161 | test_loader = data.DataLoader(test_set, sampler=test_sampler, **test_params) 162 | nice_model = model_name 163 | model = RandlaNet(n_layers=cfg['num_layers'], n_classes=cfg['num_classes'], d_out=cfg['d_out']) 164 | 165 | if not use_cuda: 166 | map_location = torch.device("cpu") 167 | model.load_state_dict(torch.load(nice_model, map_location=map_location)) 168 | else: 169 | model.load_state_dict(torch.load(nice_model)) 170 | model = model.to(device) 171 | 172 | xyz_tile, xyz_labels, xyz_probs, true_rgb, gt_labels = \ 173 | segment(test_loader, model, device, inv_mapping, cfg) 174 | 175 | results_path = store_results(model_path, xyz_tile, xyz_labels, xyz_probs, true_rgb, 176 | gt_labels, pc_path, segmentation_name) 177 | 178 | mask_map = {} 179 | for label in mapping.values(): 180 | mask = xyz_labels == label 181 | mask_map[label] = mask 
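    # mask_map holds one boolean selector per predicted class;
    # generate_k3d_plot uses it to build a separate, named k3d points object
    # per class, so each class can be toggled in the snapshot UI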
182 | 183 | plot = generate_k3d_plot(xyz_tile, mask_map=mask_map, mask_color=color_mapping, name_map=name_mapping) 184 | snapshot = plot.get_snapshot(9) 185 | snap_path = f"{results_path}snapshot_predictions.html" 186 | with open(snap_path, 'w') as fp: 187 | fp.write(snapshot) 188 | print(f"Labelled snapshot save at {snap_path}") 189 | 190 | plot = generate_k3d_plot(xyz_tile, rgb=true_rgb, mask_map=mask_map, name_map=name_mapping) 191 | snapshot = plot.get_snapshot(9) 192 | snap_path = f"{results_path}snapshot_rgb.html" 193 | with open(snap_path, 'w') as fp: 194 | fp.write(snapshot) 195 | print(f"RGB snapshot save at {snap_path}") 196 | print("Segmentation Done") 197 | -------------------------------------------------------------------------------- /model/training.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from datetime import datetime 3 | 4 | import mlflow 5 | import torch 6 | from torch import nn 7 | from torch.utils import data 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 9 | import torch.nn.functional as F 10 | from tqdm import tqdm 11 | import numpy as np 12 | from sklearn.metrics import confusion_matrix 13 | import pandas as pd 14 | 15 | from .dataset import RandlanetDataset 16 | from .model import RandlaNet 17 | from .sampler import RandlanetWeightedSampler 18 | from .utils import MODEL_SAVES_PATH, check_create_folder, separated_multi_auc 19 | 20 | 21 | def train_model(model, max_epochs, train_loader, test_loader, device, 22 | output_path, checkpoint_path, lr, use_mlflow, n_layers, n_classes, ith_kfold=None): 23 | """ 24 | Function used to train a model 25 | 26 | Args: 27 | model: PyTorch model used being trained 28 | max_epochs: maximum number of epochs 29 | train_loader: PyTorch Dataloader for train set 30 | test_loader: PyTorch Dataloader for test/validation set 31 | device: PyTorch computing device (e.g. 
'cpu','cuda') 32 | output_path: where to save training information and plots 33 | checkpoint_path: where to save model checkpoint 34 | lr: learning rate 35 | use_mlflow: if True, logs metrics to mlflow 36 | 37 | """ 38 | class_weight = torch.tensor( 39 | list(train_loader.dataset.total_class_count.values())).to(device) 40 | class_weight = class_weight.float()/torch.sum(class_weight).float() 41 | class_weight = 1 / (class_weight+0.02) 42 | loss = nn.CrossEntropyLoss(reduction='none') 43 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-8) 44 | scheduler = ReduceLROnPlateau(optimizer, 45 | mode='max', 46 | verbose=True, 47 | patience=1, 48 | cooldown=2, 49 | factor=0.95 50 | ) 51 | max_patience = max(max_epochs, 1) 52 | epochs_logger = tqdm(range(1, max_epochs + 1), desc="epoch") 53 | num_labels = len(test_loader.dataset.mapping) 54 | print("Saving Metadata") 55 | inv_map = {v: k for k, v in test_loader.dataset.mapping.items()} 56 | print(inv_map) 57 | metadata = dict() 58 | metadata['label_mapping'] = test_loader.dataset.mapping 59 | metadata['inv_map'] = inv_map 60 | metadata['best_epoch'] = -1 61 | index_best = -1 62 | best_checkpoint_path = '' 63 | with open(f'{output_path}metadata.pkl', 'wb') as file: 64 | pickle.dump(metadata, file) 65 | if use_mlflow: 66 | mlflow.set_tracking_uri("http://localhost:9999") 67 | mlflow.set_experiment('randlanet') 68 | mlflow.start_run() 69 | ml_flow_run_id = mlflow.active_run().info.run_id 70 | if ith_kfold is not None: 71 | mlflow.log_param("k-fold iteration", ith_kfold) 72 | mlflow.log_param("device", device) 73 | mlflow.log_param("max_epochs", max_epochs) 74 | mlflow.log_param("max_patience", max_patience) 75 | mean_iou_list = [0] 76 | history = pd.DataFrame() 77 | print("Start Training") 78 | 79 | for epoch in epochs_logger: 80 | # Training 81 | train_loss, train_acc = train_epoch(device, loss, model, 82 | optimizer, train_loader, n_layers, 83 | class_weight) 84 | val_acc, val_iou, val_mean_iou, val_aucs, val_mean_auc = validation(device, model, test_loader, n_layers, 85 | n_classes, scheduler) 86 | 87 | if val_mean_iou > np.max(mean_iou_list): 88 | index_best = len(mean_iou_list) 89 | mean_iou_list.append(val_mean_iou) 90 | checkpoint_name = checkpoint_save(checkpoint_path, epoch, val_acc, val_mean_iou, 91 | model) 92 | 93 | for param_group in optimizer.param_groups: 94 | current_lr = param_group['lr'] 95 | 96 | iou_dict = {f"iou_{inv_map[i]}": val_iou[i] 97 | for i in range(len(val_iou))} 98 | auc_dict = {f"auc_{inv_map[i]}": val_aucs[i] 99 | for i in range(len(val_aucs))} 100 | 101 | history = history.append({"epoch": epoch, 102 | "train_loss": train_loss, 103 | "train_av_acc": train_acc, 104 | "val_av_acc": val_acc, 105 | "val_av_iou": val_mean_iou, 106 | "val_auc": val_mean_auc, 107 | **iou_dict, 108 | **auc_dict}, ignore_index=True) 109 | 110 | 111 | if use_mlflow: 112 | mlflow.log_metric("train_loss", train_loss, epoch) 113 | mlflow.log_metric("train_acc", train_acc, epoch) 114 | mlflow.log_metric("val_acc", val_acc, epoch) 115 | mlflow.log_metric("val_mean_iou", val_mean_iou, epoch) 116 | mlflow.log_metric("val_mean_auc", val_mean_auc, epoch) 117 | 118 | mlflow.log_metrics(iou_dict, epoch) 119 | mlflow.log_metrics(auc_dict, epoch) 120 | mlflow.log_metric("lr", current_lr, epoch) 121 | 122 | if metadata['best_epoch'] != index_best: 123 | metadata['best_epoch'] = index_best 124 | best_checkpoint_path = checkpoint_name 125 | if use_mlflow: 126 | mlflow.log_param("best_epoch", index_best) 127 | 
mlflow.log_param("best_epoch_checkpoint_path", best_checkpoint_path) 128 | with open(f'{output_path}metadata.pkl', 'wb') as file: 129 | pickle.dump(metadata, file) 130 | if epoch - index_best > max_patience: 131 | print("\n\nearly stopping!") 132 | break 133 | epochs_logger.set_postfix_str(f"t_loss={train_loss:.5f}, " 134 | f"t_acc={train_acc:.5f}, " 135 | f"v_acc={val_acc:.5f}, " 136 | f"v_iou={val_mean_iou:.5f}") 137 | history_save_path = f"{output_path}history.csv" 138 | history.to_csv(history_save_path) 139 | print(f"best epoch:{index_best}") 140 | print("Finished Training") 141 | if use_mlflow: 142 | mlflow.end_run() 143 | return best_checkpoint_path, history_save_path, ml_flow_run_id 144 | else: 145 | return best_checkpoint_path, history_save_path, None 146 | 147 | 148 | 149 | def train_epoch(device, loss_function, model, optimizer, 150 | train_loader, n_layers, class_weight): 151 | """ 152 | Function that train a single epoch 153 | 154 | Args: 155 | device: PyTorch computing device (e.g. 'cpu','cuda') 156 | loss_function: PyTorch loss function 157 | model: PyTorch model used being trained 158 | optimizer: PyTorch optimizer used for training 159 | train_loader: PyTorch Dataloader for train set 160 | 161 | Returns: 162 | updated history dictionary 163 | 164 | """ 165 | train_logger = tqdm(train_loader, 166 | desc="Train", 167 | total=len(train_loader)) 168 | 169 | train_losses = [] 170 | train_accs = [] 171 | model.train() 172 | 173 | for input_list in train_logger: 174 | inputs = unpack_input(input_list, n_layers, device) 175 | 176 | # zero the parameter gradients 177 | optimizer.zero_grad() 178 | 179 | # forward + backward + optimize 180 | outputs = model(inputs) 181 | labels = torch.reshape(inputs['labels'], [-1]) 182 | 183 | one_hot = torch.zeros_like(outputs) 184 | one_hot[range(labels.shape[0]), labels] = 1 185 | one_hot = one_hot * class_weight 186 | one_hot = torch.sum(one_hot, dim=1) 187 | loss = loss_function(outputs, labels) 188 | loss = loss*one_hot 189 | loss = loss.mean() 190 | loss.backward() 191 | preds = F.log_softmax(outputs, dim=-1).argmax(-1) 192 | train_acc = (preds == labels).to(torch.float32).mean() 193 | optimizer.step() 194 | 195 | train_logger.set_postfix_str(f"t_loss={loss.item():.5f}, " 196 | f"t_acc={train_acc.item():.5f}") 197 | 198 | train_losses.append(loss.item()) 199 | train_accs.append(train_acc.item()) 200 | 201 | return np.mean(train_losses), np.mean(train_accs) 202 | 203 | 204 | def validation(device, model, test_loader, n_layers, n_classes, scheduler): 205 | """ 206 | Given a model a dataset and a set of parameters, the function returns 207 | all the metrics relative to the validation set 208 | 209 | Args: 210 | device: PyTorch computing device (e.g. 
'cpu','cuda') 211 | model: PyTorch model to be validated 212 | test_loader: PyTorch Dataloader for test/validation set 213 | 214 | Returns: 215 | updated history dictionary and multiple validation metrics 216 | 217 | """ 218 | model.eval() 219 | gt_classes = [0 for _ in range(n_classes)] 220 | positive_classes = [0 for _ in range(n_classes)] 221 | true_positive_classes = [0 for _ in range(n_classes)] 222 | val_total_correct = 0 223 | val_total_seen = 0 224 | all_preds = torch.Tensor().cpu() 225 | all_gts = torch.Tensor().cpu() 226 | test_logger = tqdm(test_loader, 227 | desc="Validation", 228 | total=len(test_loader)) 229 | auc_every = len(test_loader)//5 230 | with torch.no_grad(): 231 | for cnt, input_list in enumerate(test_logger): 232 | inputs = unpack_input(input_list, n_layers, device) 233 | outputs = model(inputs) 234 | logits = F.log_softmax(outputs, dim=-1) 235 | pred = logits.argmax(1).cpu().numpy() 236 | labels = torch.reshape(inputs['labels'], [-1]).cpu() 237 | if cnt % auc_every == 0: 238 | all_gts = torch.cat([all_gts, labels]) 239 | all_preds = torch.cat([all_preds, logits.cpu()]) 240 | labels = labels.numpy() 241 | correct = np.sum(pred == labels) 242 | val_total_correct += correct 243 | val_total_seen += len(labels) 244 | 245 | conf_matrix = confusion_matrix(labels, pred, 246 | np.arange(0, n_classes, 1)) 247 | gt_classes += np.sum(conf_matrix, axis=1) 248 | positive_classes += np.sum(conf_matrix, axis=0) 249 | true_positive_classes += np.diagonal(conf_matrix) 250 | 251 | iou_list = [] 252 | for n in range(0, n_classes, 1): 253 | iou = true_positive_classes[n] / float( 254 | gt_classes[n] + positive_classes[n] - true_positive_classes[n]) 255 | iou_list.append(iou) 256 | mean_iou = sum(iou_list) / float(n_classes) 257 | val_acc = val_total_correct / float(val_total_seen) 258 | val_aucs = separated_multi_auc( 259 | pred=all_preds, label=all_gts, num_labels=n_classes) 260 | mean_val_auc = np.mean(list(val_aucs.values())) 261 | print('eval accuracy: {}'.format(val_acc)) 262 | print('mean IOU:{}'.format(mean_iou)) 263 | print('mean AUC:{}'.format(mean_val_auc)) 264 | # update the lr scheduler step 265 | # scheduler.step() 266 | scheduler.step(mean_iou) 267 | #mean_iou = 100 * mean_iou 268 | print('Mean IoU = {:.1f}%'.format(100*mean_iou)) 269 | s = '{:5.2f} | '.format(100*mean_iou) 270 | for IoU in iou_list: 271 | s += '{:5.2f} '.format(100 * IoU) 272 | print('-' * len(s)) 273 | print(s) 274 | print('-' * len(s) + '\n') 275 | return val_acc, iou_list, mean_iou, val_aucs, mean_val_auc 276 | 277 | 278 | def unpack_input(input_list, n_layers, device): 279 | inputs = dict() 280 | inputs['xyz'] = input_list[:n_layers] 281 | inputs['neigh_idx'] = input_list[n_layers: 2 * n_layers] 282 | inputs['sub_idx'] = input_list[2 * n_layers:3 * n_layers] 283 | inputs['interp_idx'] = input_list[3 * n_layers:4 * n_layers] 284 | for key, val in inputs.items(): 285 | inputs[key] = [x.to(device) for x in val] 286 | inputs['features'] = input_list[4 * n_layers].to(device) 287 | inputs['labels'] = input_list[4 * n_layers + 1].to(device) 288 | inputs['input_inds'] = input_list[4 * n_layers + 2].to(device) 289 | inputs['cloud_inds'] = input_list[4 * n_layers + 3].to(device) 290 | return inputs 291 | 292 | 293 | def checkpoint_save(checkpoint_path, epoch, mean_v_acc, 294 | mean_v_iou, model): 295 | """ 296 | This function saves a model checkpoint using PyTorch formats 297 | 298 | Args: 299 | checkpoint_path: where to save model checkpoint 300 | epoch: epoch number 301 | mean_v_acc: mean validation 
accuracy (if only one per epoch than that 302 | value is used) 303 | mean_val_loss: mean validation loss (if only one per epoch than that 304 | value is used) 305 | mean_v_iou: mean validation iou (if only one per epoch than that 306 | value is used) 307 | model: PyTorch model to be saved 308 | 309 | """ 310 | checkpoint_name = f"{epoch}" \ 311 | f"_v_acc={mean_v_acc:.3f}" \ 312 | f"_v_iou={mean_v_iou}" \ 313 | f"_state_dict.pth" 314 | checkpoint_filename = f"{checkpoint_path}{checkpoint_name}" 315 | torch.save(model.state_dict(), checkpoint_filename) 316 | return checkpoint_filename 317 | 318 | 319 | def train_randlanet_model(train_set_list, test_set_list, hyperpars, use_mlflow=False, 320 | num_workers=4, model_name=None): 321 | """ 322 | Function for training randlanet using provided point clouds filepath 323 | as train and test sets. Logs metrics to mlflow if use_mlflow==True. 324 | 325 | Args: 326 | train_set_list: list of path to full pc folders to use as train 327 | test_set_list: list of path to full pc folders to use as test 328 | use_mlflow: if True, logs metrics to mlflow 329 | max_epochs: maximum number of epochs 330 | batch_size: batch size 331 | num_workers: number of parallel pytorch workers to load data 332 | learning_rate: learning rate for training 333 | model_name: name of the model folder. If None, timestamp is used 334 | 335 | """ 336 | 337 | # CUDA for PyTorch 338 | use_cuda = torch.cuda.is_available() 339 | available_gpu = "cuda:0" 340 | device = torch.device(available_gpu if use_cuda else "cpu") 341 | 342 | # Parameters for pytorch dataloaders 343 | train_params = {"batch_size": hyperpars['batch_size'], 344 | "shuffle": False, 345 | "num_workers": num_workers, 346 | "pin_memory": False} 347 | test_params = {"batch_size": hyperpars['val_batch_size'], 348 | "shuffle": False, 349 | "num_workers": num_workers, 350 | "pin_memory": False} 351 | 352 | if model_name is None: 353 | model_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 354 | 355 | model_save_folder = (MODEL_SAVES_PATH + f'{model_name}/') 356 | output_path = f"{model_save_folder}output/" 357 | checkpoint_path = f"{model_save_folder}checkpoints/" 358 | check_create_folder(output_path) 359 | check_create_folder(checkpoint_path) 360 | 361 | train_set = RandlanetDataset(train_set_list, **hyperpars) 362 | train_sampler = RandlanetWeightedSampler( 363 | train_set, hyperpars['batch_size'] * hyperpars['train_steps']) 364 | train_loader = data.DataLoader( 365 | train_set, sampler=train_sampler, **train_params) 366 | test_set = RandlanetDataset(test_set_list, **hyperpars) 367 | test_sampler = RandlanetWeightedSampler( 368 | test_set, hyperpars['val_batch_size'] * hyperpars['val_steps']) 369 | test_loader = data.DataLoader( 370 | test_set, sampler=test_sampler, **test_params) 371 | 372 | with open(f"{output_path}datasets_used.txt", "a") as fl: 373 | fl.write(f'Datasets used:\n' 374 | f' train: {train_set_list}\n' 375 | f' test: {test_set_list}') 376 | Warning("Re mapping of labels values from original to 0 to max_num_labels") 377 | 378 | model = RandlaNet( 379 | hyperpars['d_out'], hyperpars['num_layers'], hyperpars['num_classes']) 380 | model = model.to(device) 381 | train_model(model, hyperpars['max_epoch'], train_loader, test_loader, device, 382 | output_path, checkpoint_path, hyperpars['learning_rate'], use_mlflow, 383 | hyperpars['num_layers'], hyperpars['num_classes']) 384 | -------------------------------------------------------------------------------- /model/utils.py: 
--------------------------------------------------------------------------------
1 | import errno
2 | import os
3 | import pickle
4 | from scipy.spatial.transform import Rotation as R
5 | import numpy as np
6 | from sklearn.metrics import roc_auc_score
7 | import k3d
8 | import seaborn as sns
9 |
10 | DATA_ROOT_PATH = 'data/'
11 | MODEL_SAVES_PATH = DATA_ROOT_PATH + "saved_models/"
12 |
13 | def check_create_folder(folder_path):
14 |     if not os.path.exists(os.path.dirname(folder_path)):
15 |         try:
16 |             os.makedirs(os.path.dirname(folder_path))
17 |         except OSError as exc:  # Guard against race condition
18 |             if exc.errno != errno.EEXIST:
19 |                 raise
20 |
21 |
22 | def create_metadata(path, **kwargs):
23 |     """
24 |     Creates the metadata file at path/metadata/metadata.pickle
25 |
26 |     Args:
27 |         path: path to the pc folder (the metadata subfolder is created inside it)
28 |         kwargs: each kwarg is stored in the dict saved as pickle file
29 |
30 |     """
31 |     meta_path = f'{path}metadata/'
32 |     check_create_folder(meta_path)
33 |     with open(f'{meta_path}metadata.pickle', 'wb') as f:
34 |         d = {x[0]: x[1] for x in kwargs.items()}
35 |         pickle.dump(d, f)
36 |     print(f'Metadata Stored : {d}')
37 |
38 |
39 | def read_metadata(path):
40 |     """
41 |     Read the metadata file at path/metadata/metadata.pickle
42 |
43 |     Args:
44 |         path: path to the pc folder containing the metadata subfolder
45 |
46 |     Returns:
47 |         dict contained in metadata file
48 |
49 |     """
50 |     meta_file = f'{path}metadata/metadata.pickle'
51 |     with open(meta_file, 'rb') as f:
52 |         return pickle.load(f)
53 |
54 | def rotate(points, angles):
55 |     """
56 |     Rotates a set of points around the 'xyz' axes by the angles given in radians.
57 |
58 |     Args:
59 |         points: array of 3d points, with shape (n_points, 3)
60 |         angles: angles[0]: rotation around x
61 |                 angles[1]: rotation around y
62 |                 angles[2]: rotation around z
63 |
64 |     Returns:
65 |         rotated points
66 |
67 |     """
68 |     r = R.from_euler('xyz', angles, degrees=False)
69 |     return r.apply(points)
70 |
71 | def separated_multi_auc(pred, label, num_labels):
72 |     """
73 |     Computes a one-vs-rest AUC for each class separately
74 |
75 |     Args:
76 |         pred: torch tensor of predictions with dimension [n_samples, n_labels];
77 |               each prediction has to be a probability
78 |         label: torch tensor of dimension `[n_samples]` of ground truth
79 |                labels NOT one-hot encoded
80 |
81 |     Returns:
82 |         dict mapping each class index to its one-vs-rest AUC computed with sklearn roc_auc_score
83 |     """
84 |     np_pred = pred.cpu().detach().numpy()
85 |     np_label = label.cpu().numpy().astype(np.int64)
86 |     # num_labels = len(np.unique(np_label))
87 |     np_label_one_hot = np.zeros((np_label.size, num_labels))
88 |     # print(num_labels)
89 |     np_label_one_hot[np.arange(np_label.size), np_label] = 1
90 |     ret = {}
91 |     for label_ind in range(num_labels):
92 |         ret[label_ind] = roc_auc_score(y_score=np_pred[:, label_ind],
93 |                                        y_true=np_label_one_hot[:, label_ind])
94 |     return ret
95 |
96 |
97 | name_mapping = {
98 |     0: "terrain",
99 |     1: "construction",
100 |     2: "urban_asset",
101 |     3: "vegetation",
102 |     4: "vehicle",
103 | }
104 | palette = sns.color_palette("pastel")
105 | # map at least green to vegetation
106 | g = palette[2]
107 | palette[2] = palette[4]
108 | palette[4] = g
109 | # create color mapping in a format which is suitable for k3d
110 | color_mapping = {
111 |     l: [int(palette[i][0] * 255),
112 |         int(palette[i][1] * 255),
113 |         int(palette[i][2] * 255)]
114 |     for i, l in enumerate(name_mapping)
115 | }
116 |
117 | def pack(r, g, b):
118 |     """
119 |     (r,g,b) tuple to hex.
Each r,g,b can be column arrays 120 | 121 | """ 122 | return ( 123 | (np.array(r).astype(np.uint32) << 16) 124 | + (np.array(g).astype(np.uint32) << 8) 125 | + np.array(b).astype(np.uint32) 126 | ) 127 | 128 | 129 | def pack_single(rgb): 130 | """ 131 | [r,g,b] one line array to hex 132 | """ 133 | return ( 134 | (rgb[0] << 16) 135 | + (rgb[1] << 8) 136 | + rgb[2] 137 | ) 138 | 139 | 140 | def generate_k3d_plot(xyz, rgb=None, mask_map=None, mask_color=None, 141 | name_map=None, old_plot=None): 142 | """ 143 | Generates a k3d snapshot of a set of 3d points, mapping them either 144 | to their true rgb color or a colour corresponding to their label. 145 | Labels are also mapped to names, so that they can be easily toggled 146 | inside the visualization tool. 147 | 148 | Args: 149 | xyz: array of [x, y, z] points 150 | rgb: array of [r, g, b] points inside [0, 255] 151 | mask_map: dict mapping each label to a mask over xyz, which allows to 152 | select points from each class 153 | mask_color: dict mapping each label to a single color 154 | name_map: map each numeric label to a descriptive string name 155 | 156 | Returns: 157 | k3d snapshot that can be saved as html for visualization 158 | 159 | """ 160 | kwargs = dict() 161 | if rgb is None: 162 | pass 163 | else: 164 | assert mask_color is None 165 | kwargs["colors"] = pack(rgb[:, 0], rgb[:, 1], rgb[:, 2]) 166 | 167 | if old_plot is None: 168 | plot = k3d.plot() 169 | else: 170 | plot = old_plot 171 | 172 | if mask_map is None: 173 | plt_points = k3d.points(positions=xyz, 174 | point_size=1., 175 | shader="flat", 176 | **kwargs) 177 | plot += plt_points 178 | else: 179 | for label in mask_map: 180 | mask = mask_map[label] 181 | if name_map is None: 182 | legend_label = f"label {label}" 183 | else: 184 | legend_label = f"{name_map[label]}" 185 | if mask_color is None: 186 | colors = kwargs["colors"][mask] 187 | plt_points = k3d.points(positions=xyz[mask], 188 | point_size=1., 189 | shader="flat", 190 | name=legend_label, 191 | colors=colors) 192 | plot += plt_points 193 | else: 194 | color = pack_single(mask_color[label]) 195 | plt_points = k3d.points( 196 | positions=xyz[mask], 197 | point_size=1., 198 | shader="flat", 199 | name=legend_label, 200 | color=color, 201 | ) 202 | plot += plt_points 203 | plot.camera_mode = 'orbit' 204 | plot.grid_auto_fit = False 205 | # plot.grid = np.concatenate((np.min(xyz, axis=0), np.max(xyz, axis=0))) 206 | plot.grid_visible = False 207 | return plot -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from model.testing import segment_randlanet 2 | from model.hyperparameters import hyp 3 | from model.dataset import RandlanetDataset 4 | from model.training import train_randlanet_model 5 | # media/gabri/ext_ssd/nomoko 6 | # train_set = RandlanetDataset(["/media/gabri/ext_ssd/nomoko/datasets/full_pc/pc_id=39/"], **hyp) 7 | # train_randlanet_model(train_set_list =["data/pc_id=636/"], 8 | # test_set_list = ["data/pc_id=637/"], 9 | # hyperpars=hyp, 10 | # use_mlflow=False, 11 | # num_workers=4, 12 | # model_name="repo_example") 13 | 14 | 15 | segment_randlanet(model_path="data/saved_models/repo_example/", 16 | pc_path="data/pc_id=637/", 17 | cfg=hyp, 18 | num_workers=4, 19 | segmentation_name='example') 20 | # import pickle 21 | # import torch 22 | # import glob 23 | # from segmentation.deep_learning.randlanet.model.torch_model import RandlaNet 24 | # from 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from model.testing import segment_randlanet
from model.hyperparameters import hyp
from model.dataset import RandlanetDataset
from model.training import train_randlanet_model

# train_randlanet_model(train_set_list=["data/pc_id=636/"],
#                       test_set_list=["data/pc_id=637/"],
#                       hyperpars=hyp,
#                       use_mlflow=False,
#                       num_workers=4,
#                       model_name="repo_example")


segment_randlanet(model_path="data/saved_models/repo_example/",
                  pc_path="data/pc_id=637/",
                  cfg=hyp,
                  num_workers=4,
                  segmentation_name='example')

# Legacy manual version of the same pipeline, kept for reference; it uses the
# package layout of the original codebase (segmentation.*), which is not part
# of this repository.
# import pickle
# import torch
# import glob
# from segmentation.deep_learning.randlanet.model.torch_model import RandlaNet
# from segmentation.deep_learning.randlanet.model.torch_dataset import RandlanetDataset
# from segmentation.deep_learning.randlanet.model.torch_sampler import (
#     RandlanetWeightedSampler,
# )
# from segmentation.utils.path_utils import MODEL_SAVES_PATH
# from torch.utils import data
# from segmentation.deep_learning.randlanet.model.asegmentation import (
#     segment,
#     store_results,
# )
# from segmentation.utils.k3d_utils import generate_k3d_plot, color_mapping, name_mapping


# model_name = "all_pc_lr_sched"
# pc_path = "/data/datasets/full_pc/pc_id=50/"
# pc_id = 50
# model_root_folder = MODEL_SAVES_PATH + f"model_randlanet_{model_name}/"
# with open(f"{model_root_folder}output/metadata.pkl", "rb") as f:
#     metadata = pickle.load(f)
# mapping = metadata["label_mapping"]
# best_epoch = metadata["best_epoch"]

# path_to_model = glob.glob(f"{model_root_folder}checkpoints/{best_epoch}_*.pth")[0]
# print(f"Loading model checkpoint: {model_name}")
# inv_mapping = {mapping[l]: l for l in mapping}
# print(f"Label inverse mapping: {inv_mapping}")

# n_classes = len(mapping)
# print(f"Best epoch was {metadata['best_epoch']}")
# print("Setting up pytorch")
# use_cuda = torch.cuda.is_available()
# print(f"Use cuda: {use_cuda}")
# device = torch.device("cuda:1" if use_cuda else "cpu")


# test_params = {
#     "batch_size": hyp["val_batch_size"],
#     "shuffle": False,
#     "num_workers": 4
# }


# test_set = RandlanetDataset([pc_path], **hyp)
# test_sampler = RandlanetWeightedSampler(
#     test_set, hyp["val_batch_size"] * hyp["val_steps"]
# )

# test_loader = data.DataLoader(test_set, sampler=test_sampler, **test_params)
# model = RandlaNet(
#     n_layers=hyp["num_layers"], n_classes=hyp["num_classes"], d_out=hyp["d_out"]
# )

# if device.type == "cpu":
#     map_location = torch.device("cpu")
#     model.load_state_dict(torch.load(path_to_model, map_location=map_location))
# else:
#     model.load_state_dict(torch.load(path_to_model))
# model = model.to(device)

# # segment tiles, tile by tile, classifying samples in batches of chunk_size,
# # sampling each single point
# xyz_tile, xyz_labels, xyz_probs, true_rgb, gt_labels = segment(
#     test_loader, model, device, inv_mapping, hyp, 150
# )
# store_results(
#     model_root_folder,
#     xyz_tile,
#     xyz_labels,
#     xyz_probs,
#     true_rgb,
#     gt_labels,
#     pc_path,
#     f"all_pc_{pc_id}"
# )
# mask_map = {}
# for label in mapping.values():
#     mask = xyz_labels == label
#     mask_map[label] = mask

# plot = generate_k3d_plot(xyz_tile, mask_map=mask_map, mask_color=color_mapping, name_map=name_mapping)
# snapshot = plot.get_snapshot(9)
# snap_path = f"{model_root_folder}output/segmentations/all_pc_{pc_id}/predictions.html"
# with open(snap_path, 'w') as fp:
#     fp.write(snapshot)
# print(f"Snapshot saved at {snap_path}")

# plot = generate_k3d_plot(xyz_tile, rgb=true_rgb, mask_map=mask_map, name_map=name_mapping)
# snapshot = plot.get_snapshot(9)
# snap_path = f"{model_root_folder}output/segmentations/all_pc_{pc_id}/rgb.html"
# with open(snap_path, 'w') as fp:
#     fp.write(snapshot)
# print(f"Snapshot saved at {snap_path}")

# print("Segmentation Done")
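

# Added sketch (not part of the original script): inspect the metadata stored
# alongside the test cloud. read_metadata expects the pointcloud folder path
# with a trailing slash (see model/utils.py), and the repo ships
# data/pc_id=637/metadata/metadata.pickle.
from model.utils import read_metadata
print(read_metadata("data/pc_id=637/"))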
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from model.testing import segment_randlanet
from model.hyperparameters import hyp
from model.dataset import RandlanetDataset
from model.training import train_randlanet_model

train_randlanet_model(train_set_list=["data/pc_id=636/"],
                      test_set_list=["data/pc_id=637/"],
                      hyperpars=hyp,
                      use_mlflow=False,
                      num_workers=4,
                      model_name="repo_example")


# Example of running segmentation on a trained model (same call as in test.py):
# segment_randlanet(model_path="/data/saved_models/model_randlanet_all_pc_lr_sched/",
#                   pc_path="/data/datasets/full_pc/pc_id=50/",
#                   cfg=hyp,
#                   num_workers=8,
#                   segmentation_name='all_pc_50')
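

# Hedged note (added): hyp is a plain dict defined in model/hyperparameters.py
# (it is subscripted and unpacked with **hyp elsewhere), so individual runs can
# override entries before calling train_randlanet_model. The keys below appear
# in the legacy snippet in test.py; the values are placeholders, not defaults:
#
#     hyp["val_batch_size"] = 2   # validation batch size
#     hyp["num_layers"] = 4       # encoder depth of RandLA-Net
#     hyp["num_classes"] = 5      # must match the label mapping in model/utils.py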
--------------------------------------------------------------------------------