├── .gitignore ├── LICENSE ├── LICENSE-3RD-PARTY ├── README.md ├── environment.yaml └── treegraph ├── IO ├── __init__.py └── io.py ├── __init__.py ├── attribute_centres.py ├── build_graph.py ├── build_skeleton.py ├── calculate_voxel_length.py ├── common.py ├── cyl2ply.py ├── cylinder_fitting.py ├── distance_from_base.py ├── distance_from_tip.py ├── downsample.py ├── estimate_radius.py ├── fit_cylinders.py ├── generate_cylinder_model.py ├── graph_process.py ├── main.py ├── plots.py ├── scripts ├── batch_tree2qsm.py ├── generate_inputs.py ├── print_results.py └── tree2qsm.py ├── split_furcation.py ├── taper.py └── third_party ├── available_cpu_count.py ├── closestDistanceBetweenLines.py ├── cyl2ply.py ├── cylinder_fitting.py ├── ply_io.py ├── point2line.py ├── ransac_cyl_fit.py └── shortpath.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # vim 107 | *.swp 108 | *.swo 109 | *~ 110 | -------------------------------------------------------------------------------- /LICENSE-3RD-PARTY: -------------------------------------------------------------------------------- 1 | ----------------------------------------------------------------------------- 2 | The 3-Clause BSD License 3 | applies to: 4 | - third_party/cylinder_fitting.py, Copyright (c) Xingjie Pan, 2017 5 | 6 | ----------------------------------------------------------------------------- 7 | 8 | BSD 3-Clause License 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | 13 | * Redistributions of source code must retain the above copyright notice, this 14 | list of conditions and the following disclaimer. 
15 | 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | * Neither the name of the copyright holder nor the names of its 21 | contributors may be used to endorse or promote products derived from 22 | this software without specific prior written permission. 23 | 24 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TreeGraph 2 | 3 | Treegraph is a Python library for extracting structural parameters from terrestrial LiDAR point clouds of individual trees. 4 | 5 | ## Usage 6 | 7 | ### File structure 8 | 9 | It is assumed the following folders have been created: 10 | 11 | ``` 12 | treegraph_tutorial/ 13 | ├── clouds/ 14 | │ └── 15 | ├── inputs/ 16 | │ └── 17 | └── results/ 18 | └── 19 | 20 | ``` 21 | 22 | ### Generate input files 23 | 24 | The required input arguments are the path to the input data and the path for the outputs. For the remaining optional arguments, see `treegraph/scripts/generate_inputs.py`.
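Each generated `.yml` file holds the settings for one tree, which `tree2qsm.py` reads when run (see the commands below). As a quick sanity check after generation, a file can be loaded and printed — a minimal sketch, where the filename is a placeholder and the listed keys are inferred from the args recorded by `treegraph/IO/io.py`, not an authoritative schema:

```
import yaml

# hypothetical filename: whichever input file generate_inputs.py produced
with open('inputs/tree1.yml') as f:
    cfg = yaml.safe_load(f)

# inferred keys (illustrative only): data_path, output_path, min_pts,
# cluster_size, tip_width, base_corr, dbh_height, save_graph, ...
print(cfg)
```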
25 | 26 | ``` 27 | conda activate treegraph 28 | cd treegraph_tutorial/inputs/ 29 | python /PATH/TO/treegraph/scripts/generate_inputs.py -d '/PATH/TO/clouds/*.ply' -o '/PATH/TO/results/' 30 | 31 | ``` 32 | 33 | ### Run Treegraph 34 | 35 | #### Option 1: Run on a single tree 36 | 37 | `python treegraph/scripts/tree2qsm.py -i 'inputs/XXX.yml'` 38 | 39 | #### Option 2: Run all the trees one after another 40 | 41 | `python treegraph/scripts/batch_tree2qsm.py -i 'inputs/*.yml'` 42 | 43 | #### Option 3: Batch process on HPC 44 | 45 | An example `job_script.sh` for a SLURM system is given below; the `tree` variable holds the path to one input file and is expected to be set at submission time, e.g. `sbatch --export=ALL,tree='inputs/XXX.yml' job_script.sh`. 46 | 47 | #!/bin/bash 48 | # scheduling queue 49 | #SBATCH --partition=high-mem 50 | # max runtime limit 51 | #SBATCH --time=10:00:00 52 | # job name 53 | #SBATCH --job-name=treegraph 54 | # job output and error output 55 | #SBATCH --output %j.out 56 | #SBATCH --error %j.err 57 | # required memory, unit MB 58 | #SBATCH --mem=102400 59 | # working dir 60 | #SBATCH -D /PATH/TO/OUTPUTS/ 61 | # Number of CPU cores 62 | #SBATCH -n 1 63 | 64 | # executable 65 | conda activate treegraph 66 | echo "python ~/miniconda3/envs/treegraph/lib/python3.7/treegraph/scripts/tree2qsm.py -i '$tree'" 67 | python ~/miniconda3/envs/treegraph/lib/python3.7/treegraph/scripts/tree2qsm.py -i "${tree}" -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: treegraph 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.7 7 | - numpy 8 | - scikit-learn 9 | - networkx 10 | - pandas=1.2.5 11 | - pandarallel 12 | - yaml 13 | - tqdm 14 | - matplotlib 15 | - plotly 16 | - seaborn -------------------------------------------------------------------------------- /treegraph/IO/__init__.py: -------------------------------------------------------------------------------- 1 | from treegraph.IO.io import * 2 | from treegraph.third_party.ply_io import * -------------------------------------------------------------------------------- /treegraph/IO/io.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import datetime 4 | from treegraph.third_party import ply_io 5 | from treegraph.common import * 6 | from treegraph import estimate_radius 7 | from treegraph.third_party.cyl2ply import pandas2ply 8 | 9 | def save_centres(centres, path, verbose=False): 10 | 11 | drop = [c for c, d in zip(centres.columns, centres.dtypes) if d in ['object']] 12 | ply_io.write_ply(path, centres.drop(columns=drop).rename(columns={'cx':'x', 'cy':'y', 'cz':'z'})) 13 | if verbose: print('skeleton points saved to:', path) 14 | 15 | def save_pc(pc, path, downsample=False, verbose=False): 16 | 17 | drop = [c for c, d in zip(pc.columns, pc.dtypes) if d in ['object']] 18 | ply_io.write_ply(path, pc.drop(columns=drop).loc[pc.downsample if downsample else pc.index]) 19 | if verbose: print('point cloud saved to:', path) 20 | 21 | 22 | def to_ply(cyls, path, attribute='nbranch', verbose=False): 23 | 24 | cols = ['length', 'radius', 'sx', 'sy', 'sz', 'ax', 'ay', 'az', attribute] 25 | pandas2ply(cyls[cols], attribute, path) 26 | if verbose: print('cylinders saved to:', path) 27 | 28 | 29 | def qsm2json(self, path, name=None, graph=False): 30 | 31 | ### internode data 32 | self.cyls.ncyl = self.cyls.ncyl.astype(int) 33 | # self.cyls.loc[:, 'surface_area'] = 2 * np.pi * self.cyls.radius * self.cyls.length #+ 2 * np.pi * self.cyls.radius**2 34 | 35 | internodes =
pd.DataFrame(data=self.cyls.groupby('ninternode').length.sum(), 36 | columns=['length', 'volume', 'ncyl', 'mean_radius', 'is_tip', 37 | 'distal_radius', 'proximal_radius', 'surface_area']) 38 | 39 | internodes.loc[:, 'ncyl'] = self.cyls.groupby('ninternode').vol.count() 40 | internodes.loc[:, 'volume'] = self.cyls.groupby('ninternode').vol.sum() 41 | internodes.loc[:, 'surface_area'] = self.cyls.groupby('ninternode').surface_area.sum() 42 | internodes.loc[:, 'mean_radius'] = self.cyls.groupby('ninternode').radius.mean() 43 | internodes.loc[:, 'parent'] = self.centres.groupby('ninternode').pinternode.min() 44 | internodes.loc[:, 'is_tip'] = self.cyls.groupby('ninternode').is_tip.max().astype(bool) 45 | 46 | first_and_last = self.cyls.groupby('ninternode').ncyl.agg([min, max]).reset_index().rename(columns={'min':'First', 'max':'Last'}) 47 | 48 | # distal radius (ends) 49 | distal_radius_f = lambda row: self.cyls.loc[(self.cyls.ninternode == row.ninternode) & 50 | (self.cyls.ncyl.isin([row.First, row.Last]))].radius.mean() 51 | internodes.loc[:, 'distal_radius'] = first_and_last.apply(distal_radius_f, axis=1) 52 | 53 | 54 | # proximal radius (centre) 55 | centre_cyl = first_and_last[['First', 'Last']].mean(axis=1).astype(int).reset_index().rename(columns={'index':'ninternode', 0:'ncyl'}) 56 | proximal_radius_f = lambda row: self.cyls.loc[(self.cyls.ninternode == row.ninternode) & 57 | (self.cyls.ncyl == row.ncyl)].radius.mean() 58 | internodes.loc[:, 'proximal_radius'] = centre_cyl.apply(proximal_radius_f, axis=1) 59 | 60 | # radius before furcation ("parent" if measured by hand) 61 | b4fur_radius = lambda row: self.centres.loc[(self.centres.ninternode == row.ninternode) & 62 | (self.centres.ncyl == row.Last)].m_radius.item() 63 | internodes.loc[:, 'b4fur_radius'] = first_and_last.apply(b4fur_radius, axis=1) 64 | internodes.loc[internodes.is_tip, 'b4fur_radius'] = np.nan 65 | 66 | # radius after furcation ("child" if measured by hand) 67 | after_fur_radius = lambda row: self.centres.loc[(self.centres.ninternode == row.ninternode) & 68 | (self.centres.ncyl == row.First)].m_radius.item() 69 | internodes.loc[:, 'after_fur_radius'] = first_and_last.apply(after_fur_radius, axis=1) 70 | 71 | ### node data 72 | nodes = self.centres[(self.centres.nbranch != 0) & 73 | (self.centres.ncyl == 0)].set_index('ninternode')[['node_id', 'parent', 'parent_node']] 74 | nodes.rename(columns={'node_id':'child_node', 'parent':'nbranch', 'parent_node':'node_id'}, inplace=True) 75 | nodes.reset_index(inplace=True) 76 | 77 | for ix, row in nodes.iterrows(): 78 | if len(self.centres[self.centres.nbranch == row.nbranch]) != 0: 79 | tip_id = self.centres.loc[(self.centres.nbranch == row.nbranch) & 80 | (self.centres.is_tip)].node_id.values[0] 81 | branch_path = np.array(self.path_ids[int(tip_id)], dtype=int) 82 | idx = np.where(branch_path == int(row.node_id))[0][0] 83 | next_node = branch_path[idx + 1] 84 | row = row.append(pd.Series(index=['next_node'], data=next_node)) 85 | angle = node_angle_f(self.centres[self.centres.node_id == row.child_node][['cx', 'cy', 'cz']].values, 86 | self.centres[self.centres.node_id == row.node_id][['cx', 'cy', 'cz']].values, 87 | self.centres[self.centres.node_id == row.next_node][['cx', 'cy', 'cz']].values)[0][0] 88 | 89 | 90 | nodes.loc[ix, 'surface_area_b'] = self.cyls[self.cyls.p1.isin(branch_path[idx:])].surface_area.sum() 91 | nodes.loc[ix, 'length_b'] = self.cyls[self.cyls.p1.isin(branch_path[idx:])].length.sum() 92 | nodes.loc[ix, 'volums_b'] = 
self.cyls[self.cyls.p1.isin(branch_path[idx:])].vol.sum() 93 | nodes.loc[ix, 'child_branch'] = self.centres[self.centres.node_id == row.child_node].nbranch.unique() 94 | 95 | # for test 96 | nodes.loc[ix, 'angle'] = angle * 180 / np.pi 97 | 98 | ### input arguments 99 | args = {'data_path': self.data_path, 'output_path': self.output_path, 100 | 'base_idx': self.base_idx, 'min_pts': self.min_pts, 'cluster_size': self.cluster_size, 101 | 'tip_width': self.tip_width, 'verbose': self.verbose, 'base_corr': self.base_corr, 102 | 'dbh_height': self.dbh_height, 'txt_file': self.txt_file, 'save_graph': self.save_graph} 103 | 104 | ### processing time 105 | run_time = {'run_time': self.time} 106 | 107 | # final skeleton graph nodes and edges 108 | G_skel = dict(nodes=[[int(n), self.G_skel_sf.nodes[n]] for n in self.G_skel_sf.nodes()], \ 109 | edges=[[int(u), int(v), self.G_skel_sf.edges[u,v]] for u,v in self.G_skel_sf.edges()]) 110 | 111 | JSON = {'name':name, 112 | 'created':datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 113 | 'args':args, 114 | 'run_time':run_time, 115 | 'tree':self.tree.to_json(), 116 | 'internode':internodes.to_json(), 117 | 'node':nodes.to_json(), 118 | 'cyls':self.cyls.to_json(), 119 | 'centres':self.centres.to_json(), 120 | 'pc':self.pc.to_json(), 121 | 'path_ids':self.path_ids, 122 | 'G_skel':G_skel} 123 | 124 | ### if save initial graph information, json file size would be doubled 125 | if graph: 126 | # initial graph nodes and edges 127 | G_init = dict(nodes=[[int(n), self.G.nodes[n]] for n in self.G.nodes()], \ 128 | edges=[[int(u), int(v), self.G.edges[u,v]] for u,v in self.G.edges()]) 129 | 130 | 131 | JSON = {'name':name, 132 | 'created':datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 133 | 'args':args, 134 | 'run_time':run_time, 135 | 'tree':self.tree.to_json(), 136 | 'internode':internodes.to_json(), 137 | 'node':nodes.to_json(), 138 | 'cyls':self.cyls.to_json(), 139 | 'centres':self.centres.to_json(), 140 | 'pc':self.pc.to_json(), 141 | 'path_ids':self.path_ids, 142 | 'G_skel':G_skel, 143 | 'G_init':G_init} 144 | 145 | with open(path, 'w') as fh: fh.write(json.dumps(JSON)) 146 | 147 | class read_json: 148 | 149 | def __init__ (self, 150 | path, 151 | pretty_printing=False, 152 | attributes=['tree', 'internode', 'node', 'cyls', 'centres', 'pc'], 153 | graph=False): 154 | 155 | JSON = json.load(open(path)) 156 | setattr(self, 'name', JSON['name']) 157 | setattr(self, 'args', JSON['args']) 158 | run_time = JSON['run_time']['run_time'] 159 | setattr(self, 'run_time', run_time) 160 | setattr(self, 'path_ids', JSON['path_ids']) 161 | setattr(self, 'G_skel', JSON['G_skel']) 162 | 163 | 164 | if pretty_printing: 165 | 166 | tree = pd.read_json(JSON['tree']) 167 | 168 | print(f"name:\t\t{JSON['name']}") 169 | print(f"date:\t\t{JSON['created']}") 170 | print(f"H from clouds:\t{tree.loc[0]['H_from_clouds']:.2f} m") 171 | print(f"H from qsm:\t{tree.loc[0]['H_from_qsm']:.2f} m") 172 | print(f"DBH from clouds: {tree.loc[0]['DBH_from_clouds']:.3f} m") 173 | print(f"DBH from qsm:\t{tree.loc[0]['DBH_from_qsm']:.3f} m") 174 | print(f"Tot. branch len: {tree.loc[0]['length']:.2f} m") 175 | print(f"Tot. volume:\t{tree.loc[0]['vol']:.4f} m³ = {tree.loc[0]['vol']*1e3:.1f} L") 176 | print(f"Tot. 
surface area: {tree.loc[0]['surface_area']:.4f} m2") 177 | print(f"Trunk len:\t{tree.loc[0]['trunk_length']:.2f} m") 178 | print(f"Trunk volume:\t{tree.loc[0]['trunk_vol']:.4f} m³ = {tree.loc[0]['trunk_vol']*1e3:.1f} L") 179 | # print(f"Stem len:\t{tree.loc[0]['stem_length']:.2f} m") 180 | # print(f"Stem volume:\t{tree.loc[0]['stem_vol']:.4f} m³ = {tree.loc[0]['stem_vol']*1e3:.1f} L") 181 | print(f"N tips:\t\t{tree.loc[0]['N_tip']:.0f}") 182 | print(f"Avg tip width:\t{tree.loc[0]['tip_rad_mean']*2:.3f} ± {tree.loc[0]['tip_rad_std']*2:.3f} m") 183 | print(f"Avg distance between tips: {tree.loc[0]['dist_between_tips']:.3f} m") 184 | m, s = divmod(run_time, 60) 185 | h, m = divmod(m, 60) 186 | print(f"Programme running time:\t{h:.0f}h:{m:02.0f}m:{s:02.0f}s") 187 | 188 | for att in attributes: 189 | try: 190 | setattr(self, att, pd.read_json(JSON[att])) 191 | except: 192 | raise Exception('Field "{}" not in {}'.format(att, path)) 193 | 194 | if graph: 195 | setattr(self, 'G_init', JSON['G_init']) 196 | -------------------------------------------------------------------------------- /treegraph/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=Warning) 3 | 4 | from treegraph.main import initialise 5 | -------------------------------------------------------------------------------- /treegraph/attribute_centres.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pandas as pd 3 | import numpy as np 4 | import treegraph.distance_from_base 5 | from tqdm.autonotebook import trange 6 | from pandas.api.types import CategoricalDtype 7 | 8 | 9 | def run(centres, path_ids, verbose=False, branch_hierarchy=False): 10 | 11 | with trange(6 if branch_hierarchy else 5, 12 | disable=False if verbose else True, 13 | desc='steps') as pbar: 14 | 15 | # remove nodes that are not graphed - prob outlying clusters 16 | centres = centres.loc[centres.node_id.isin(path_ids.keys())] 17 | 18 | # identify previous node in the graph 19 | previous_node = lambda nid: np.nan if len(path_ids[nid]) == 1 else path_ids[nid][-2] 20 | centres.loc[:, 'pnode'] = centres.node_id.apply(previous_node) 21 | 22 | # if node is a tip 23 | centres.loc[:, 'is_tip'] = False 24 | unique_nodes = np.unique([v for p in path_ids.values() for v in p], return_counts=True) 25 | centres.loc[centres.node_id.isin(unique_nodes[0][unique_nodes[1] == 1]), 'is_tip'] = True 26 | 27 | pbar.set_description("identified tips", refresh=True) 28 | pbar.update(1) # update progress bar 29 | 30 | # calculate branch lengths and numbers 31 | tip_paths = pd.DataFrame(index=centres[centres.is_tip].node_id.values, 32 | columns=['tip2base', 'length', 'nbranch']) 33 | 34 | for k, v in path_ids.items(): 35 | 36 | v = v[::-1] 37 | if v[0] in centres[centres.is_tip].node_id.values: 38 | c1 = centres.set_index('node_id').loc[v[:-1]][['cx', 'cy', 'cz']].values 39 | c2 = centres.set_index('node_id').loc[v[1:]][['cx', 'cy', 'cz']].values 40 | tip_paths.loc[tip_paths.index == v[0], 'tip2base'] = np.linalg.norm(c1 - c2, axis=1).sum() 41 | 42 | pbar.set_description("calculated tip to base lengths", refresh=True) 43 | pbar.update(1) 44 | 45 | centres.sort_values(['slice_id', 'distance_from_base'], inplace=True) 46 | centres.loc[:, 'nbranch'] = -1 47 | centres.loc[:, 'ncyl'] = -1 48 | 49 | for i, row in enumerate(tip_paths.sort_values('tip2base', ascending=False).itertuples()): 50 | 51 | tip_paths.loc[row.Index, 
'nbranch'] = i 52 | cyls = path_ids[row.Index] 53 | # sort branch node_id by the path list to avoid a branch becoming its own parent 54 | sorter = CategoricalDtype(cyls, ordered=True) 55 | bnodes = centres[centres.node_id.isin(cyls)] 56 | bnodes['node_id'] = bnodes['node_id'].astype(sorter) 57 | bnodes = bnodes.sort_values('node_id') 58 | bnodes.loc[bnodes.nbranch == -1, 'nbranch'] = i 59 | bnodes.loc[bnodes.nbranch == i, 'ncyl'] = np.arange(len(bnodes[bnodes.nbranch == i])) 60 | centres.loc[centres.node_id.isin(bnodes.node_id), 'nbranch'] = bnodes.nbranch 61 | centres.loc[centres.node_id.isin(bnodes.node_id), 'ncyl'] = bnodes.ncyl 62 | 63 | v = centres.loc[centres.nbranch == i].sort_values('ncyl').node_id 64 | c1 = centres.set_index('node_id').loc[v[:-1]][['cx', 'cy', 'cz']].values 65 | c2 = centres.set_index('node_id').loc[v[1:]][['cx', 'cy', 'cz']].values 66 | tip_paths.loc[row.Index, 'length'] = np.linalg.norm(c1 - c2, axis=1).sum() 67 | 68 | # reattribute branch numbers starting with the longest 69 | new_branch_nums = {bn:i for i, bn in enumerate(tip_paths.sort_values('length', ascending=False).nbranch)} 70 | tip_paths.loc[:, 'nbranch'] = tip_paths.nbranch.map(new_branch_nums) 71 | centres.loc[:, 'nbranch'] = centres.nbranch.map(new_branch_nums) 72 | 73 | pbar.set_description("identified individual branches", refresh=True) 74 | pbar.update(1) 75 | 76 | centres.loc[:, 'n_furcation'] = 0 77 | centres.loc[:, 'parent'] = -1 78 | centres.loc[:, 'parent_node'] = np.nan 79 | 80 | # loop over branch bases and identify parents 81 | for nbranch in centres.nbranch.unique(): 82 | 83 | if nbranch == 0: continue # main branch does not furcate 84 | furcation_node = -1 85 | branch_base_idx = centres.loc[centres.nbranch == nbranch].ncyl.idxmin() 86 | branch_base_idx = centres.loc[branch_base_idx].node_id 87 | 88 | for path in path_ids.values(): 89 | if path[-1] == branch_base_idx: 90 | if len(path) > 1: 91 | furcation_node = path[-2] 92 | else: 93 | furcation_node = path[-1] 94 | centres.loc[centres.node_id == furcation_node, 'n_furcation'] += 1 95 | break 96 | 97 | if furcation_node != -1: 98 | parent = centres.loc[centres.node_id == furcation_node].nbranch.item() 99 | centres.loc[(centres.nbranch == nbranch), 'parent'] = parent 100 | centres.loc[(centres.nbranch == nbranch), 'parent_node'] = furcation_node 101 | 102 | pbar.set_description('attributed nodes and identified parents', refresh=True) 103 | pbar.update(1) 104 | 105 | # loop over branches and attribute internode 106 | # centres.sort_values(['nbranch', 'slice_id', 'distance_from_base'], inplace=True) 107 | centres.sort_values(['nbranch', 'ncyl'], inplace=True) 108 | centres.loc[:, 'ninternode'] = -1 109 | internode_n = 0 110 | 111 | for ix, row in centres.iterrows(): 112 | centres.loc[centres.node_id == row.node_id, 'ninternode'] = internode_n 113 | if row.n_furcation > 0 or row.is_tip: internode_n += 1 114 | 115 | for internode in centres.ninternode.unique(): 116 | if internode == 0: continue # first internode so ignore 117 | # current nodes belonging to this segment 118 | cnode = centres[centres.ninternode == internode] 119 | # pnode is the internode of the previous node of this segment 120 | # the first node of a segment is ncyl=0 121 | pnode = cnode[cnode.ncyl == cnode.ncyl.min()].pnode.values[0] 122 | centres.loc[centres.ninternode == internode, 'pinternode'] = centres.loc[centres.node_id == pnode].ninternode.item() 123 | 124 | ## define branch order (wx adds) 125 | centres.loc[:, 'norder'] = -1 126 | # stem (branch order = 0) 127 |
centres.loc[(centres.nbranch == 0) & (centres.ninternode == 0), 'norder'] = 0 128 | node_list = [0] 129 | # branch order +1 after a new furcation 130 | i = 1 131 | while -1 in centres.norder.unique(): 132 | centres.loc[centres.pinternode.isin(node_list), 'norder'] = i 133 | node_list = centres[centres.pinternode.isin(node_list)].ninternode.unique() 134 | i += 1 135 | 136 | pbar.set_description('attributed internodes', refresh=True) 137 | pbar.update(1) 138 | 139 | centres = centres.reset_index(drop=True) 140 | 141 | if branch_hierarchy: 142 | 143 | branch_hierarchy = {0:{'parent_branch':np.array([0]), 'above':centres.nbranch.unique()[1:]}} 144 | # loop over each branch and store its parent branch id into dict 145 | for b in np.sort(centres.nbranch.unique()): 146 | if b == 0: continue 147 | parent = centres.loc[(centres.nbranch == b) & (centres.ncyl == 0)].parent.item() 148 | branch_hierarchy[b] = {} 149 | if parent in branch_hierarchy.keys(): 150 | branch_hierarchy[b]['parent_branch'] = np.hstack([[b], branch_hierarchy[parent]['parent_branch']]) 151 | else: 152 | branch_hierarchy[parent] = {} 153 | branch_hierarchy[parent]['parent_branch'] = [parent] 154 | branch_hierarchy[b]['parent_branch'] = np.hstack([[b], branch_hierarchy[parent]['parent_branch']]) 155 | 156 | for b in centres.nbranch.unique(): 157 | if b == 0: continue 158 | ba = set() 159 | for k, v in branch_hierarchy.items(): 160 | if b not in list(v['parent_branch']): continue 161 | ba.update(set(v['parent_branch'][v['parent_branch'] > b])) 162 | if len(ba) > 0: 163 | branch_hierarchy[b]['above'] = list(ba) 164 | else: 165 | branch_hierarchy[b]['above'] = [] 166 | 167 | pbar.set_description('created branch hierarchy', refresh=True) 168 | pbar.update(1) 169 | 170 | return centres, branch_hierarchy 171 | 172 | else: 173 | return centres -------------------------------------------------------------------------------- /treegraph/build_graph.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import pandas as pd 3 | import numpy as np 4 | import json 5 | import matplotlib.pyplot as plt 6 | from scipy.spatial import ConvexHull 7 | from sklearn.neighbors import NearestNeighbors 8 | from tqdm.autonotebook import tqdm 9 | from pandarallel import pandarallel 10 | 11 | 12 | def run(pc, centres, n_neighbours=100, verbose=False): 13 | # find convex hull points of each cluster 14 | group_pc = pc.groupby('node_id') 15 | pandarallel.initialize(nb_workers=min(24, len(group_pc)+1), progress_bar=verbose) 16 | try: 17 | chull = group_pc.parallel_apply(convexHull) 18 | except OverflowError: 19 | if verbose: 20 | print('!pandarallel could not initiate progress bars, running without') 21 | pandarallel.initialize(progress_bar=False) 22 | chull = group_pc.parallel_apply(convexHull) 23 | 24 | # find shortest path from each cluster to the base 25 | # and build skeleton graph 26 | path_dist, path_list, G_skel = generate_path(chull, centres, n_neighbours=n_neighbours) 27 | 28 | return G_skel, path_dist, path_list 29 | 30 | 31 | def convexHull(pc): 32 | if len(pc) > 5: 33 | try: 34 | vertices = ConvexHull(pc[['x', 'y', 'z']]).vertices 35 | idx = np.random.choice(vertices, size=len(vertices), replace=False) 36 | return pc.loc[pc.index[idx]] 37 | except: 38 | return pc 39 | else: 40 | return pc 41 | 42 | 43 | def generate_path(samples, centres, n_neighbours=200, max_length=np.inf, not_base=-1): 44 | # compute nearest neighbours for each vertex in cluster convex hull 45 | nn = 
NearestNeighbors(n_neighbors=n_neighbours).fit(samples[['x', 'y', 'z']]) 46 | distances, indices = nn.kneighbors() 47 | from_to_all = pd.DataFrame(np.vstack([np.repeat(samples.node_id.values, n_neighbours), 48 | samples.iloc[indices.ravel()].node_id.values, 49 | distances.ravel(), 50 | np.repeat(samples.slice_id.values, n_neighbours), 51 | samples.iloc[indices.ravel()].slice_id.values]).T, 52 | columns=['source', 'target', 'length', 's_sliceid', 't_sliceid']) 53 | 54 | # remove X-X connections 55 | from_to_all = from_to_all.loc[from_to_all.target != from_to_all.source] 56 | 57 | # build edge list based on min distance and 58 | # number of chull pts in nearest neighbour clusters 59 | groups = from_to_all.groupby(['source', 'target']) 60 | edges = groups.length.apply(lambda x: x.min() / (np.log10(x.count())+0.001)).reset_index() 61 | 62 | # remove edges that are likely leaps between trees 63 | edges = edges.loc[edges.length <= max_length] 64 | 65 | # removes isolated origin points i.e. > edge.length 66 | for nid in np.sort(samples.node_id.unique()): 67 | if nid in edges.source.values: 68 | origin = [nid] 69 | break 70 | # origins = [s for s in origins if s in edges.source.values] ## old method 71 | 72 | # compute graph that connect all clusters 73 | G = nx.from_pandas_edgelist(edges, edge_attr=['length']) 74 | # retrieve shortest path list (sp) of each cluster 75 | # to the base node and its corresponding distance 76 | distance, sp = nx.multi_source_dijkstra(G, 77 | sources=origin, 78 | weight='length') 79 | # build skeleton graph 80 | G_skeleton = nx.Graph() 81 | for i, nid in enumerate(G.nodes()): 82 | if nid in sp.keys(): 83 | if len(sp[nid]) > 1: 84 | x1 = float(centres[centres.node_id == nid].cx) 85 | y1 = float(centres[centres.node_id == nid].cy) 86 | z1 = float(centres[centres.node_id == nid].cz) 87 | node1_coor = np.array([x1,y1,z1]) 88 | sid1 = int(centres[centres.node_id == nid].slice_id) 89 | G_skeleton.add_node(nid, pos=[x1,y1,z1], node_id=int(nid), slice_id=sid1) 90 | 91 | x2 = float(centres[centres.node_id == sp[nid][-2]].cx) 92 | y2 = float(centres[centres.node_id == sp[nid][-2]].cy) 93 | z2 = float(centres[centres.node_id == sp[nid][-2]].cz) 94 | node2_coor = np.array([x2,y2,z2]) 95 | sid2 = int(centres[centres.node_id == sp[nid][-2]].slice_id) 96 | G_skeleton.add_node(sp[nid][-2], pos=[x2,y2,z2], 97 | node_id=int(sp[nid][-2]), slice_id=sid2) 98 | 99 | d = np.linalg.norm(node1_coor - node2_coor) 100 | G_skeleton.add_weighted_edges_from([(int(sp[nid][-2]), int(nid), float(d))]) 101 | 102 | paths = pd.DataFrame(index=distance.keys(), data=distance.values(), columns=['distance']) 103 | paths.loc[:, 'base'] = not_base 104 | for p in paths.index: paths.loc[p, 'base'] = sp[p][0] 105 | paths.reset_index(inplace=True) 106 | paths.columns = ['node_id', 'distance', 'base_node_id'] # t_node_id is the base node 107 | 108 | # identify nodes that are branch tips 109 | node_occurance = {} 110 | for v in sp.values(): 111 | for n in v: 112 | if n in node_occurance.keys(): node_occurance[n] += 1 113 | else: node_occurance[n] = 1 114 | 115 | tips = [k for k, v in node_occurance.items() if v == 1] 116 | 117 | paths.loc[:, 'is_tip'] = False 118 | paths.loc[paths.node_id.isin(tips), 'is_tip'] = True 119 | 120 | return paths, sp, G_skeleton 121 | -------------------------------------------------------------------------------- /treegraph/build_skeleton.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import random 
4 | import struct 5 | from sklearn.cluster import DBSCAN 6 | from sklearn.metrics import silhouette_score 7 | from sklearn.metrics import calinski_harabasz_score 8 | from tqdm.autonotebook import tqdm 9 | from treegraph.third_party import shortpath as p2g 10 | from treegraph.downsample import * 11 | from treegraph import common 12 | from pandarallel import pandarallel 13 | from sklearn.decomposition import PCA 14 | 15 | 16 | def run(self, verbose=False): 17 | 18 | columns = self.pc.columns.to_list() + ['node_id'] 19 | 20 | # run pandarallel on points grouped by slice_id 21 | groupby = self.pc.groupby('slice_id') 22 | pandarallel.initialize(nb_workers=min(24, len(groupby)), progress_bar=verbose) 23 | try: 24 | sent_back = groupby.parallel_apply(find_centre, self).values 25 | except OverflowError: 26 | if verbose: print('!pandarallel could not initiate progress bars, running without') 27 | pandarallel.initialize(progress_bar=False) 28 | sent_back = groupby.parallel_apply(find_centre, self).values 29 | 30 | # create and append clusters and filtered pc 31 | centres = pd.DataFrame() 32 | self.pc = pd.DataFrame() 33 | for x in sent_back: 34 | if len(x) == 0: continue 35 | centres = centres.append(x[0]) 36 | self.pc = self.pc.append(x[1]) 37 | 38 | # reset index as appended df have common values 39 | centres.reset_index(inplace=True, drop=True) 40 | self.pc.reset_index(inplace=True, drop=True) 41 | 42 | if 'node_id' in self.pc.columns: self.pc = self.pc.drop(columns=['node_id']) 43 | 44 | # convert binary cluster reference to int 45 | MAP = {v:i for i, v in enumerate(centres.idx.unique())} 46 | if 'level_0' in self.pc.columns: self.pc = self.pc.drop(columns='level_0') 47 | if 'index' in self.pc.columns: self.pc = self.pc.drop(columns='index') 48 | self.pc.loc[:, 'node_id'] = self.pc.idx.map(MAP) 49 | centres.loc[:, 'node_id'] = centres.idx.map(MAP) 50 | 51 | return centres 52 | 53 | 54 | def find_centre(dslice, self): 55 | if len(dslice) < 2: 56 | return [] 57 | 58 | centres = pd.DataFrame() 59 | s = dslice.slice_id.unique()[0] 60 | X = dslice[['x', 'y', 'z']] 61 | 62 | group_slice = self.pc[self.pc.slice_id != 0].groupby('slice_id') 63 | max_pts_sid = group_slice.apply(lambda x: len(x)).idxmax() 64 | nn = 10 if s <= max_pts_sid else 5 65 | 66 | results = common.nn_dist(dslice, n_neighbours=nn) 67 | if type(results) == float: 68 | return [] 69 | 70 | dnn, indices = results 71 | dists = np.sort(dnn, axis=0)[:,-1] 72 | idx = np.argsort(np.diff(dists))[::-1] 73 | knee = dists[idx[1]] if idx[0] == 0 else dists[idx[0]] 74 | 75 | mdnn = group_slice.apply(common.mean_dNN, n_neighbours=nn) 76 | conf_85 = np.nanmean(mdnn) + 1.44 * np.nanstd(mdnn) 77 | 78 | dnn_per_point = np.mean(dnn, axis=1) 79 | conf95 = np.nanmean(dnn_per_point) + 2 * np.nanstd(dnn_per_point) 80 | 81 | def dbscan_cluster(eps): 82 | dbscan = DBSCAN(eps=eps, min_samples=nn, 83 | algorithm='kd_tree', metric='euclidean', 84 | n_jobs=-1).fit(X) 85 | labels = np.unique(dbscan.labels_) 86 | return labels, len(labels[labels >= 0]), dbscan 87 | 88 | labels_knee, c_num_knee, dbscan_knee = dbscan_cluster(knee) 89 | labels_fix, c_num_fix, dbscan_fix = dbscan_cluster(conf_85) 90 | labels_conf95, c_num_conf95, dbscan_conf95 = dbscan_cluster(conf95) 91 | 92 | if c_num_knee > 1: 93 | # Calculate internal evaluation metrics for the clusters 94 | clusters = dbscan_knee.fit_predict(X) 95 | # silhouette score 96 | knee_s1 = silhouette_score(X, clusters) 97 | # Calinski-Harabasz index 98 | knee_s2 = calinski_harabasz_score(X, clusters) 99 | 100 | if 
c_num_fix > 1: 101 | # Calculate internal evaluation metrics for the clusters 102 | clusters = dbscan_fix.fit_predict(X) 103 | # silhouette score 104 | fix_s1 = silhouette_score(X, clusters) 105 | # Calinski-Harabasz index 106 | fix_s2 = calinski_harabasz_score(X, clusters) 107 | 108 | 109 | eps_candidate = [knee, conf_85, conf95] 110 | cnum = np.array([c_num_knee, c_num_fix, c_num_conf95]) 111 | 112 | if c_num_knee < 2: 113 | eps_ = conf_85 if c_num_fix >= 10 * c_num_knee else knee 114 | elif c_num_fix < 2: 115 | eps_ = knee if c_num_knee >= 10 * c_num_fix else conf_85 116 | elif c_num_knee <= 10 and s <= max_pts_sid: 117 | s1, s2 = max(knee_s1, fix_s1), max(knee_s2, fix_s2) 118 | 119 | if knee_s1 == fix_s1: 120 | eps_ = knee if knee_s2 == s2 else conf_85 121 | elif knee_s1 == s1: 122 | eps_ = knee 123 | else: 124 | eps_ = conf_85 125 | else: 126 | eps_ = eps_candidate[np.argmax(cnum)] 127 | 128 | dbscan = DBSCAN(eps=eps_, 129 | min_samples=nn, 130 | algorithm='kd_tree', 131 | metric='euclidean', 132 | n_jobs=-1).fit(X) 133 | dslice.loc[:, 'centre_id'] = dbscan.labels_ 134 | 135 | for c in np.unique(dbscan.labels_): 136 | # working on each cluster 137 | nvoxel = dslice.loc[dslice.centre_id == c] 138 | if c == -1: 139 | dslice = dslice.loc[~dslice.index.isin(nvoxel.index)] 140 | else: 141 | if len(nvoxel.index) < self.min_pts: 142 | dslice = dslice.loc[~dslice.index.isin(nvoxel.index)] 143 | continue # required so centre is added after points are deleted 144 | 145 | pca = PCA(n_components=3) 146 | pca.fit(nvoxel[['x', 'y', 'z']].to_numpy()) 147 | ratios = pca.explained_variance_ratio_ 148 | cyl_metric = (ratios[0] / ratios[1] + ratios[0] / ratios[2]) / 2 149 | 150 | if (cyl_metric <= 5) & (len(nvoxel) > 100): 151 | centre_coords = common.CPC(nvoxel[['x', 'y', 'z']]).x 152 | centre_coords = pd.Series(centre_coords, index=['x', 'y', 'z']) 153 | else: 154 | centre_coords = nvoxel[['x', 'y', 'z']].median() 155 | 156 | centres = centres.append(pd.Series({'slice_id':int(s), 157 | 'centre_id':int(c), # id within a cluster 158 | 'cx':centre_coords.x, 159 | 'cy':centre_coords.y, 160 | 'cz':centre_coords.z, 161 | 'distance_from_base':nvoxel.distance_from_base.mean(), 162 | 'n_points':len(nvoxel), 163 | 'idx':struct.pack('ii', int(s), int(c))}), 164 | ignore_index=True) 165 | 166 | dslice.loc[(dslice.slice_id == s) & 167 | (dslice.centre_id == c), 'idx'] = struct.pack('ii', int(s), int(c)) 168 | 169 | if (len(centres) != 0) & (isinstance(centres, pd.DataFrame)): 170 | return [centres, dslice] 171 | else: 172 | return [] -------------------------------------------------------------------------------- /treegraph/calculate_voxel_length.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def run(pc, exponent=2, minbin=.005, maxbin=.02): 5 | 6 | """ 7 | qunatises `distance_from_base` dependent on an exponential function 8 | 9 | TODO: allow for any function to be used 10 | 11 | """ 12 | 13 | # normalise the distance 14 | pc.loc[:, 'normalised_distance'] = pc.distance_from_base / pc.distance_from_base.max() 15 | 16 | # exponential function to map smaller bin with increased distance from base 17 | bins, n = np.array([]), 50 18 | while not pc.distance_from_base.max() <= bins.sum() < pc.distance_from_base.max() * 1.05: 19 | bins = -np.exp(exponent * np.linspace(0, 1, n)) 20 | bins = (bins - bins.min()) / bins.ptp() # normalise to multiply by bin width 21 | bins = (((maxbin - minbin) * bins) + minbin) 22 | if bins.sum() < 
pc.distance_from_base.max(): 23 | n += 1 24 | else: n -= 1 25 | 26 | # merge the first two bin widths to enlarge the range of distance_from_base in 1st slice (the base) 27 | bins_w = bins[3:] 28 | bins_w = np.insert(bins_w, 0, [bins[0]+bins[1]+bins[2]]) 29 | 30 | # generate unique id "slice_id" for bins 31 | pc.loc[:, 'slice_id'] = np.digitize(pc.distance_from_base, bins_w.cumsum()) 32 | 33 | bins = {i: f for i, f in enumerate(bins_w)} 34 | 35 | pc = pc.drop(columns=['normalised_distance']) 36 | 37 | return pc, bins 38 | -------------------------------------------------------------------------------- /treegraph/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.neighbors import NearestNeighbors 4 | from scipy.spatial.distance import cdist 5 | from scipy import optimize 6 | 7 | def node_angle_f(a, b, c): 8 | 9 | # normalise distance between coordinate pairs where b is the central coordinate 10 | ba = a - b 11 | bc = c - b 12 | 13 | # calculate angle between and length of each vector pair 14 | angle_pair = lambda ba, bc: np.arccos(np.dot(bc, ba) / (np.linalg.norm(ba) * np.linalg.norm(bc))) 15 | 16 | return angle_pair(bc.T, ba)#[0][0] 17 | 18 | 19 | def nn(arr, N): 20 | 21 | nbrs = NearestNeighbors(n_neighbors=N+1, algorithm='kd_tree').fit(arr) 22 | distances, indices = nbrs.kneighbors(arr) 23 | 24 | return distances[:, 1] 25 | 26 | 27 | def update_slice_id(centres, branch_hierarchy, node_id, X): 28 | 29 | node = centres.loc[centres.node_id == node_id] 30 | nbranch = node.nbranch.values[0] 31 | ncyl = node.ncyl.values[0] 32 | 33 | # update slices of same branch above ncyls 34 | centres.loc[(centres.nbranch == nbranch) & (centres.ncyl >= ncyl), 'slice_id'] += X 35 | 36 | # update branches above nbranch 37 | centres.loc[centres.nbranch.isin(branch_hierarchy[nbranch]['above']), 'slice_id'] += X 38 | 39 | return centres, branch_hierarchy 40 | 41 | 42 | def CPC(pc): 43 | ''' 44 | Input: 45 | pc: pd.DataFrame, point clouds of a cluster (with same node_id) 46 | Output: 47 | opt: OptimizeResult object, opt.x is the optimal centre coordinates array 48 | ''' 49 | pc_coor = pc[['x','y','z']] 50 | 51 | # cost function 52 | def costf(x): 53 | d = cdist([x], pc_coor) 54 | mu = d.mean() 55 | sigma = np.power((d-mu),2).sum() / d.shape[1] 56 | penalty = d.shape[1]**2 * mu 57 | # penalty = 1e4 58 | # print(penalty) 59 | cost = d.sum() + penalty * sigma 60 | return cost 61 | 62 | # initial guess of the cluster centre coords 63 | centroid = pc_coor.median() 64 | 65 | # minimise cost fun to get optimal para est 66 | opt = optimize.minimize(costf, centroid, method='BFGS') 67 | 68 | return opt 69 | 70 | 71 | def nn_dist(pc, n_neighbours=10): 72 | ''' 73 | Calculate distance of the K-neighbours of each point. 
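When a slice has n_neighbours points or fewer, the search falls back to 2 neighbours; slices of two or fewer points return np.nan instead (see the branches below).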
74 | 75 | Inputs: 76 | pc: pd.DataFrame, 77 | input point coordinates 78 | n_neighbours: int, 79 | number of neighbours to use by default for kneighbours queries 80 | 81 | Outputs: 82 | dists: ndarray of shape (len(pc), n_neighbours) 83 | distances to the neighbours of each point 84 | indices: ndarray of shape (len(pc), n_neighbours) 85 | indices of the nearest points in the population matrix 86 | ''' 87 | if len(pc) <= 2: 88 | return np.nan 89 | elif len(pc) <= n_neighbours: 90 | nn = NearestNeighbors(n_neighbors=2).fit(pc[['x','y','z']]) 91 | dists, indices = nn.kneighbors() 92 | else: 93 | nn = NearestNeighbors(n_neighbors=n_neighbours).fit(pc[['x','y','z']]) 94 | dists, indices = nn.kneighbors() 95 | return dists, indices 96 | 97 | 98 | def mean_dNN(pc, n_neighbours=10): 99 | ''' 100 | Calculate the average distance between each point to its K-neighbours, 101 | and take the mean of these average distances for each slice of points. 102 | ''' 103 | if len(pc) <= 2: 104 | mean_dnn_per_slice = np.nan 105 | elif len(pc) <= n_neighbours: 106 | nn = NearestNeighbors(n_neighbors=2).fit(pc[['x','y','z']]) 107 | dists, indices = nn.kneighbors() 108 | mean_dnn_per_point = np.mean(dists, axis=1) 109 | mean_dnn_per_slice = np.mean(mean_dnn_per_point) 110 | else: 111 | nn = NearestNeighbors(n_neighbors=n_neighbours).fit(pc[['x','y','z']]) 112 | dists, indices = nn.kneighbors() 113 | mean_dnn_per_point = np.mean(dists, axis=1) 114 | mean_dnn_per_slice = np.mean(mean_dnn_per_point) 115 | 116 | return mean_dnn_per_slice 117 | 118 | 119 | # filter out large jump at the end of a branch 120 | def filt_large_jump(centres, bin_dict=None): 121 | ''' 122 | Input: 123 | centres: pd.DataFrame 124 | centres attributes of a specific branch 125 | bin_dict: dict 126 | segment bin width (value) of each slice (key) 127 | Output: 128 | centres_filt: pd.DataFrame 129 | centres attributes after filtering connections with large jump 130 | ''' 131 | # print(f'branch {np.unique(centres.nbranch)[0]}') 132 | if len(centres) < 2: 133 | return [] 134 | else: 135 | # calculate the difference of distance from base 136 | dfb = centres.distance_from_base.values 137 | dfb_diff = np.diff(dfb) 138 | centres.loc[centres.index.values[1]:, 'dfb_diff'] = dfb_diff 139 | 140 | # segment bin width of correpsonding slice 141 | centres.loc[:, 'bin_width'] = centres.slice_id.apply(lambda x: bin_dict[x]) 142 | 143 | # ratio of increased distance to slice width 144 | centres.loc[centres.index.values[1]:, 'ratio'] = centres.dfb_diff / centres.bin_width 145 | # large jump nodes, excluding the trunk base 146 | if centres.nbranch.unique()[0] == 0: 147 | # cut = centres[(centres.ratio >= 1.5) & (centres.ncyl != 1)].ncyl.values 148 | cut = centres[(centres.ratio >= 5) & (centres.ncyl != 1)].ncyl.values 149 | else: 150 | # cut = centres[centres.ratio >= 1.5].ncyl.values 151 | cut = centres[centres.ratio >= 5].ncyl.values 152 | if len(cut) > 0: 153 | cut = cut[0] 154 | centres.loc[(centres.ncyl >= cut), 'nbranch'] = -1 155 | centres = centres[centres.nbranch != -1] 156 | centres.drop(columns=['dfb_diff', 'bin_width', 'ratio'], inplace=True) 157 | 158 | # delete isolated branch whose parent branch has been filtered out 159 | for ix, row in centres.iterrows(): 160 | if row.nbranch != 0: 161 | if len(centres[centres.node_id == row.pnode]) == 0: 162 | centres.loc[centres.nbranch == row.nbranch, 'nbranch'] = -1 163 | centres = centres[centres.nbranch != -1] 164 | 165 | return [centres] 166 | 167 | 168 | ## least squares circle fitting 169 | def 
distance(centre, xp, yp): 170 | """ calculate the distance of each 2D points from the center (xc, yc) """ 171 | xc, yc = centre 172 | dist = np.sqrt((xp-xc)**2 + (yp-yc)**2) 173 | return dist 174 | 175 | def func(centre, xp, yp): 176 | """ 177 | calculate the algebraic distance between the 2D points 178 | and the mean circle centered at c=(xc, yc) 179 | """ 180 | Ri = distance(centre, xp, yp) 181 | return Ri - Ri.mean() 182 | 183 | def least_squares_circle(points): 184 | # Extract x and y coordinates of the points 185 | if type(points) == pd.DataFrame: 186 | xp = points['x'].values 187 | yp = points['y'].values 188 | cen_est = points.x.mean(), points.y.mean() 189 | else: 190 | xp, yp = [], [] 191 | for i in range(len(points)): 192 | xp.append(points[i][0]) 193 | yp.append(points[i][1]) 194 | cen_est = np.mean(xp), np.mean(yp) 195 | 196 | centre, ier = optimize.leastsq(func, cen_est, args=(xp,yp)) 197 | Ri = distance(centre, xp, yp) 198 | R = Ri.mean() 199 | residual = np.mean(Ri - R) 200 | # residual = np.sqrt(np.mean((Ri - R)**2)) 201 | 202 | return centre, R, residual 203 | 204 | 205 | # function to estimate DBH at a given height, default between 1.27-1.33m 206 | def dbh_est(self, h=1.3, verbose=False, plot=False): 207 | trunk_nids = self.centres[self.centres.nbranch == 0].node_id.values 208 | zmin = self.pc.z.min() 209 | zmax = self.pc.z.max() 210 | 211 | zstart = zmin + h - .03 212 | zstop = zmin + h + .03 213 | 214 | pc_slice = self.pc[(self.pc.z.between(zstart, zstop)) & 215 | (self.pc.node_id.isin(trunk_nids))] 216 | 217 | # if pc_slice contains too few points, increase the slice height 218 | if len(self.pc)*.01 < 5: 219 | minpts = self.min_pts 220 | else: 221 | minpts = len(self.pc)*.01 222 | while (len(pc_slice) < min(50, minpts)) & (zmin <= zmax): 223 | zstop += .01 224 | pc_slice = self.pc[(self.pc.z.between(zstart, zstop)) & 225 | (self.pc.node_id.isin(trunk_nids))] 226 | 227 | centre, radius, residual = least_squares_circle(pc_slice) 228 | if verbose: 229 | print(f'measure height = {zstart-zmin:.3f} ~ {zstop-zmin:.3f} m') 230 | print(f'xc = {centre[0]:.3f} m, yc = {centre[1]:.3f} m') 231 | print(f'radius = {radius:.3f} m, residual = {residual:.3f} m') 232 | 233 | 234 | # DBH est from point clouds 235 | dbh_clouds = round(2*radius, 3) 236 | 237 | # DBH est from QSM 238 | sids = pc_slice.slice_id.unique() 239 | nids = self.centres[(self.centres['slice_id'].isin(sids)) 240 | & (self.centres.nbranch == 0)].node_id.unique() 241 | cyl_r = self.cyls[(self.cyls['p2'].isin(nids)) 242 | & (self.cyls['p1'].isin(trunk_nids))].radius 243 | dbh_qsm = round(np.nanmean(2*cyl_r), 3) 244 | 245 | if verbose: 246 | print(f'DBH_from_clouds = {dbh_clouds} m') 247 | print(f'DBH_from_qsm_cyls = {dbh_qsm:.3f} m') 248 | 249 | if plot: 250 | # plot extracted trunk slice and the fitted circle 251 | ax1 = pc_slice.plot.scatter(x='x',y='z') 252 | ax2 = pc_slice.plot.scatter(x='x',y='y') 253 | 254 | theta_fit = np.linspace(-np.pi, np.pi, 180) 255 | x_fit = centre[0] + radius * np.cos(theta_fit) 256 | y_fit = centre[1] + radius * np.sin(theta_fit) 257 | ax2.scatter(centre[0], centre[1], s=10, c='r') 258 | ax2.plot(x_fit, y_fit, 'r--', lw=2) 259 | ax2.axis('equal') 260 | 261 | return dbh_clouds, dbh_qsm 262 | 263 | 264 | ## function to estimate DAH (diameter above-butress height) 265 | def dah_est(self, verbose=False, plot=False): 266 | trunk_nids = self.centres[self.centres.nbranch == 0].node_id.values 267 | zmin = self.pc.z.min() 268 | zmax = self.pc.z.max() 269 | residual, radius = 1, 1 270 | 271 | while 
(residual/radius > 0.001) & (zmin <= zmax): 272 | pc_slice = self.pc[(self.pc.z.between(zmin+1.27, zmin+1.33)) & 273 | (self.pc.node_id.isin(trunk_nids))] 274 | zmin += .006 275 | # if len(pc_slice) < 50: 276 | if len(self.pc)*.01 < 5: 277 | minpts = self.min_pts 278 | else: 279 | minpts = len(self.pc)*.01 280 | if len(pc_slice) < min(50, minpts): 281 | continue 282 | centre, radius, residual = least_squares_circle(pc_slice) 283 | 284 | if verbose: 285 | print(f'measure height = {zmin+1.27-self.pc.z.min():.3f} ~ {zmin+1.33-self.pc.z.min():.3f} m') 286 | print(f'xc = {centre[0]:.3f} m, yc = {centre[1]:.3f} m') 287 | print(f'radius = {radius:.3f} m, residual = {residual:.3f} m') 288 | print(f'ratio = {residual/radius:.4f}') 289 | 290 | # DAH est from point clouds 291 | dah_clouds = round(2*radius, 3) 292 | # DAH est from QSM 293 | sids = pc_slice.slice_id.unique() 294 | nids = self.centres[(self.centres['slice_id'].isin(sids)) 295 | & (self.centres.nbranch == 0)].node_id.unique() 296 | cyl_r = self.cyls[(self.cyls['p2'].isin(nids)) 297 | & (self.cyls['p1'].isin(trunk_nids))].radius 298 | dah_qsm = np.nanmean(2*cyl_r) 299 | if verbose: 300 | print(f'DAH_from_clouds = {dah_clouds} m') 301 | print(f'DAH_from_qsm_cyls = {dah_qsm:.3f} m') 302 | 303 | if plot: 304 | # plot extracted trunk slice and the fitted circle 305 | ax1 = pc_slice.plot.scatter(x='x',y='z') 306 | ax2 = pc_slice.plot.scatter(x='x',y='y') 307 | 308 | theta_fit = np.linspace(-np.pi, np.pi, 180) 309 | x_fit = centre[0] + radius * np.cos(theta_fit) 310 | y_fit = centre[1] + radius * np.sin(theta_fit) 311 | ax2.scatter(centre[0], centre[1], s=10, c='r') 312 | ax2.plot(x_fit, y_fit, 'r--', lw=2) 313 | ax2.axis('equal') 314 | 315 | return dah_clouds, dah_qsm 316 | 317 | 318 | class treegraph: 319 | 320 | def __init__(self, pc, slice_interval=.05, min_pts=10, base_location=None, verbose=False): 321 | 322 | self.pc = pc.copy() 323 | self.slice_interval=slice_interval 324 | self.min_pts = min_pts 325 | 326 | if base_location == None: 327 | self.base_location = self.pc.z.idxmin() 328 | else: 329 | self.base_location = base_location 330 | 331 | self.verbose = verbose 332 | 333 | -------------------------------------------------------------------------------- /treegraph/cyl2ply.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import numpy as np 4 | import sys 5 | import argparse 6 | import pandas as pd 7 | 8 | # header needed in ply-file 9 | header = ["ply", 10 | "format ascii 1.0", 11 | "comment Author: Cornelis", 12 | "obj_info Generated using Python", 13 | "element vertex 50", 14 | "property float x", 15 | "property float y", 16 | "property float z", 17 | "property float 0", 18 | "element face 96", 19 | "property list uchar int vertex_indices", 20 | "end_header"] 21 | 22 | # faces as needed in ply-file face is expressed by the 4 vertice IDs of the face 23 | faces = [[3, 0, 3, 2], 24 | [3, 0, 4, 3], 25 | [3, 0, 5, 4], 26 | [3, 0, 6, 5], 27 | [3, 0, 7, 6], 28 | [3, 0, 8, 7], 29 | [3, 0, 9, 8], 30 | [3, 0, 10, 9], 31 | [3, 0, 11, 10], 32 | [3, 0, 12, 11], 33 | [3, 0, 13, 12], 34 | [3, 0, 14, 13], 35 | [3, 0, 15, 14], 36 | [3, 0, 16, 15], 37 | [3, 0, 17, 16], 38 | [3, 0, 18, 17], 39 | [3, 0, 19, 18], 40 | [3, 0, 20, 19], 41 | [3, 0, 21, 20], 42 | [3, 0, 22, 21], 43 | [3, 0, 23, 22], 44 | [3, 0, 24, 23], 45 | [3, 0, 25, 24], 46 | [3, 0, 2, 25], 47 | [3, 1, 26, 27], 48 | [3, 1, 27, 28], 49 | [3, 1, 28, 29], 50 | [3, 1, 29, 30], 51 | [3, 1, 30, 31], 52 | [3, 1, 31, 32], 53 | [3, 1, 
32, 33], 54 | [3, 1, 33, 34], 55 | [3, 1, 34, 35], 56 | [3, 1, 35, 36], 57 | [3, 1, 36, 37], 58 | [3, 1, 37, 38], 59 | [3, 1, 38, 39], 60 | [3, 1, 39, 40], 61 | [3, 1, 40, 41], 62 | [3, 1, 41, 42], 63 | [3, 1, 42, 43], 64 | [3, 1, 43, 44], 65 | [3, 1, 44, 45], 66 | [3, 1, 45, 46], 67 | [3, 1, 46, 47], 68 | [3, 1, 47, 48], 69 | [3, 1, 48, 49], 70 | [3, 1, 49, 26], 71 | [3, 2, 3, 26], 72 | [3, 26, 3, 27], 73 | [3, 3, 4, 27], 74 | [3, 27, 4, 28], 75 | [3, 4, 5, 28], 76 | [3, 28, 5, 29], 77 | [3, 5, 6, 29], 78 | [3, 29, 6, 30], 79 | [3, 6, 7, 30], 80 | [3, 30, 7, 31], 81 | [3, 7, 8, 31], 82 | [3, 31, 8, 32], 83 | [3, 8, 9, 32], 84 | [3, 32, 9, 33], 85 | [3, 9, 10, 33], 86 | [3, 33, 10, 34], 87 | [3, 10, 11, 34], 88 | [3, 34, 11, 35], 89 | [3, 11, 12, 35], 90 | [3, 35, 12, 36], 91 | [3, 12, 13, 36], 92 | [3, 36, 13, 37], 93 | [3, 13, 14, 37], 94 | [3, 37, 14, 38], 95 | [3, 14, 15, 38], 96 | [3, 38, 15, 39], 97 | [3, 15, 16, 39], 98 | [3, 39, 16, 40], 99 | [3, 16, 17, 40], 100 | [3, 40, 17, 41], 101 | [3, 17, 18, 41], 102 | [3, 41, 18, 42], 103 | [3, 18, 19, 42], 104 | [3, 42, 19, 43], 105 | [3, 19, 20, 43], 106 | [3, 43, 20, 44], 107 | [3, 20, 21, 44], 108 | [3, 44, 21, 45], 109 | [3, 21, 22, 45], 110 | [3, 45, 22, 46], 111 | [3, 22, 23, 46], 112 | [3, 46, 23, 47], 113 | [3, 23, 24, 47], 114 | [3, 47, 24, 48], 115 | [3, 24, 25, 48], 116 | [3, 48, 25, 49], 117 | [3, 25, 2, 49], 118 | [3, 49, 2, 26]] 119 | 120 | def dot(v1,v2): 121 | '''returns dot-product of two vectors''' 122 | return sum(p*q for p,q in zip(v1,v2)) 123 | 124 | def rotation_matrix(A,angle): 125 | '''returns the rotation matrix''' 126 | c = math.cos(angle) 127 | s = math.sin(angle) 128 | R = [[A[0]**2+(1-A[0]**2)*c, A[0]*A[1]*(1-c)-A[2]*s, A[0]*A[2]*(1-c)+A[1]*s], 129 | [A[0]*A[1]*(1-c)+A[2]*s, A[1]**2+(1-A[1]**2)*c, A[1]*A[2]*(1-c)-A[0]*s], 130 | [A[0]*A[2]*(1-c)-A[1]*s, A[1]*A[2]*(1-c)+A[0]*s, A[2]**2+(1-A[2]**2)*c]] 131 | return R 132 | 133 | def load_cyls(cylfile, args): 134 | 135 | cyls = pd.read_csv(cylfile, 136 | sep='\t', 137 | names=['radius', 'length', 'sx', 'sy', 'sz', 'ax', 'ay', 'az', 'parent', 'extension', 138 | 'branch', 'BranchOrder', 'PositionInBranch', 'added', 'UnmodRadius']) 139 | 140 | if not args.no_branch: 141 | branch = pd.read_csv(cylfile.replace('cyl', 'branch'), 142 | sep='\t', 143 | names=['BOrd', 'BPar', 'BVol', 'BLen', 'BAng', 'BHei', 'BAzi', 'BDia']) 144 | 145 | branch.set_index(branch.index + 1, inplace=True) # otherwise branches are lablelled from 0 146 | branch_ids = branch[(branch.BLen >= args.min_length) & (branch.BDia >= args.min_radius * 2)].index 147 | 148 | if args.random: 149 | 150 | values = cyls[args.field].unique() 151 | MAP = {V:i for i, V in enumerate(np.random.choice(values, size=len(values), replace=False))} 152 | cyls.loc[:, 'COL'] = cyls[args.field].map(MAP) 153 | args.field = 'COL' 154 | 155 | if args.verbose: print(cyls.head()) 156 | 157 | pandas2ply(cyls, args.field, cylfile[:-4] + '.ply') 158 | 159 | def pandas2ply(cyls, field, out): 160 | 161 | n = len(cyls) 162 | n_vertices = 50 * n 163 | n_faces = 96 * n 164 | 165 | tempvertices = [] 166 | tempfaces = [] 167 | 168 | add = 0 169 | for i, (ix, cyl) in enumerate(cyls.iterrows()): 170 | 171 | nvertex = 48 # number of vertices, do not change! 
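# mesh layout (descriptive note): together with the two circle centres, the 48
# rim vertices created below (24 per circle, one every 15 degrees) give the 50
# vertices per cylinder declared in the ply header; the 96 triangular faces are
# 24 fan faces on each end cap plus 48 along the side. The mesh is built
# upright along +z and then rotated onto the cylinder axis via rotation_matrix().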
172 | rad = cyl.radius # cylinder radius 173 | l = cyl.length # cylinder length 174 | startp = [cyl.sx, cyl.sy, cyl.sz] # startpoint 175 | axis = [cyl.ax, cyl.ay, cyl.az] # axis relative to startpoint 176 | 177 | # first the cylinder is created without rotation 178 | # starting with center of bottom and top circle 179 | 180 | p1 = [0.0, 0.0, 0.0] 181 | p2 = [0.0, 0.0, l] 182 | 183 | degs = np.deg2rad(np.arange(0, 360, 15)) 184 | ps = [p1,p2] 185 | 186 | # add vertices on the bottom and top circle 187 | for p0 in [p1, p2]: 188 | for deg in degs: 189 | x0 = rad*math.cos(deg)+p0[0] 190 | y0 = rad*math.sin(deg)+p0[1] 191 | z0 = p0[2] 192 | 193 | ps += [[x0,y0,z0]] 194 | 195 | # the following part is adjusted from script in Matlab that does rotation 196 | u = [0,0,1] 197 | raxis = [u[1]*axis[2]-axis[1]*u[2], 198 | u[2]*axis[0]-axis[2]*u[0], 199 | u[0]*axis[1]-axis[0]*u[1]] 200 | 201 | eucl = (axis[0]**2+axis[1]**2+axis[2]**2)**0.5 202 | euclr = (raxis[0]**2+raxis[1]**2+raxis[2]**2)**0.5 203 | 204 | for i in range(3): 205 | raxis[i] /= euclr 206 | 207 | angle = math.acos(dot(u,axis)/eucl) 208 | 209 | M = rotation_matrix(raxis,angle) 210 | 211 | for i in range(len(ps)): 212 | p = ps[i] 213 | x = p[0]*M[0][0]+p[1]*M[0][1]+p[2]*M[0][2] 214 | y = p[0]*M[1][0]+p[1]*M[1][1]+p[2]*M[1][2] 215 | z = p[0]*M[2][0]+p[1]*M[2][1]+p[2]*M[2][2] 216 | 217 | # add start position 218 | x += startp[0] 219 | y += startp[1] 220 | z += startp[2] 221 | ps[i] = [x,y,z, cyl[field]] 222 | 223 | tempvertices += ps 224 | for row in faces: 225 | tempfaces += [[row[0]]+[row[i]+add for i in [1,2,3]]] 226 | 227 | add += 50 228 | 229 | header[4] = "element vertex " + str(n_vertices) 230 | header[8] = "property float {}".format(field) 231 | header[9] = "element face " + str(n_faces) 232 | 233 | with open(out, 'w') as theFile: 234 | for i in header: 235 | theFile.write(i+'\n') 236 | for p in tempvertices: 237 | theFile.write(str(p[0])+' '+str(p[1])+' '+str(p[2])+' '+str(p[3])+'\n') 238 | for f in tempfaces: 239 | #print f 240 | theFile.write(str(f[0])+' '+str(f[1])+' '+str(f[2])+' '+str(f[3])+'\n') 241 | 242 | if __name__ == '__main__': 243 | 244 | parser = argparse.ArgumentParser() 245 | parser.add_argument('-c','--cyl', nargs='*', help='list of *cyl.txt files') 246 | parser.add_argument('-f', '--field', default='branch', help='field with which to colour cylinders by') 247 | parser.add_argument('-rc', '--random', default=False, action='store_true', help='randomise colours') 248 | parser.add_argument('-r', '--min_radius', default=0, type=float, help='filter branhces by minimum radius') 249 | parser.add_argument('-l', '--min_length', default=0, type=float, help='filter branches by minimum length') 250 | parser.add_argument('--no_branch', action='store_true', help='use if no corresponding branch file is available') 251 | parser.add_argument('--verbose', action='store_true', help='print some stuff to screen') 252 | args = parser.parse_args() 253 | 254 | for x,line in enumerate(args.cyl): # loops through treelistfile 255 | name = line.split()[0] 256 | load_cyls(name, args) 257 | -------------------------------------------------------------------------------- /treegraph/cylinder_fitting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2017, Xingjie Pan 3 | All rights reserved. 
4 | 
5 | DISCLAIMER:
6 | This is a slightly modified version of the code from the cylinder_fitting
7 | package (more specifically fitting.py and geometry.py) authored by Xingjie Pan (2017)
8 | and available at: https://github.com/xingjiepan/cylinder_fitting
9 | """
10 | 
11 | 
12 | import numpy as np
13 | from scipy.optimize import minimize
14 | 
15 | 
16 | def direction(theta, phi):
17 |     '''Return the direction vector of a cylinder defined
18 |     by the spherical coordinates theta and phi.
19 |     '''
20 |     return np.array([np.cos(phi) * np.sin(theta), np.sin(phi) * np.sin(theta),
21 |                      np.cos(theta)])
22 | 
23 | def projection_matrix(w):
24 |     '''Return the projection matrix of a direction w.'''
25 |     return np.identity(3) - np.dot(np.reshape(w, (3,1)), np.reshape(w, (1, 3)))
26 | 
27 | def skew_matrix(w):
28 |     '''Return the skew matrix of a direction w.'''
29 |     return np.array([[0, -w[2], w[1]],
30 |                      [w[2], 0, -w[0]],
31 |                      [-w[1], w[0], 0]])
32 | 
33 | def calc_A(Ys):
34 |     '''Return the matrix A from a list of Y vectors.'''
35 |     return sum(np.dot(np.reshape(Y, (3,1)), np.reshape(Y, (1, 3)))
36 |                for Y in Ys)
37 | 
38 | def calc_A_hat(A, S):
39 |     '''Return the A_hat matrix of A given the skew matrix S'''
40 |     return np.dot(S, np.dot(A, np.transpose(S)))
41 | 
42 | def preprocess_data(Xs_raw):
43 |     '''Translate the center of mass (COM) of the data to the origin.
44 |     Return the processed data and the shift of the COM'''
45 |     n = len(Xs_raw)
46 |     Xs_raw_mean = sum(X for X in Xs_raw) / n
47 | 
48 |     return [X - Xs_raw_mean for X in Xs_raw], Xs_raw_mean
49 | 
50 | def G(w, Xs):
51 |     '''Calculate the G function given a cylinder direction w and a
52 |     list of data points Xs to be fitted.'''
53 |     n = len(Xs)
54 |     P = projection_matrix(w)
55 |     Ys = [np.dot(P, X) for X in Xs]
56 |     A = calc_A(Ys)
57 |     A_hat = calc_A_hat(A, skew_matrix(w))
58 | 
59 | 
60 |     u = sum(np.dot(Y, Y) for Y in Ys) / n
61 |     v = np.dot(A_hat, sum(np.dot(Y, Y) * Y for Y in Ys)) / np.trace(np.dot(A_hat, A))
62 | 
63 |     return sum((np.dot(Y, Y) - u - 2 * np.dot(Y, v)) ** 2 for Y in Ys)
64 | 
65 | def C(w, Xs):
66 |     '''Calculate the cylinder center given the cylinder direction and
67 |     a list of data points.
68 |     '''
69 | 
70 |     P = projection_matrix(w)
71 |     Ys = [np.dot(P, X) for X in Xs]
72 |     A = calc_A(Ys)
73 |     A_hat = calc_A_hat(A, skew_matrix(w))
74 | 
75 |     return (np.dot(A_hat, sum(np.dot(Y, Y) * Y for Y in Ys)) /
76 |             np.trace(np.dot(A_hat, A)))
77 | 
78 | def r(w, Xs):
79 |     '''Calculate the radius given the cylinder direction and a list
80 |     of data points.
81 |     '''
82 |     n = len(Xs)
83 |     P = projection_matrix(w)
84 |     c = C(w, Xs)
85 | 
86 |     return np.sqrt(sum(np.dot(c - X, np.dot(P, c - X)) for X in Xs) / n)
87 | 
88 | def fit(data, guess_angles=None):
89 |     '''Fit a list of data points to a cylinder surface. The algorithm
90 |     implemented here is from David Eberly's paper "Fitting 3D Data with a
91 |     Cylinder" from
92 |     https://www.geometrictools.com/Documentation/CylinderFitting.pdf
93 |     Arguments:
94 |         data - A list of 3D data points to be fitted.
95 |         guess_angles[0] - Guess of the theta angle of the axis direction
96 |         guess_angles[1] - Guess of the phi angle of the axis direction
97 | 
98 |     Return:
99 |         Direction of the cylinder axis
100 |         A point on the cylinder axis
101 |         Radius of the cylinder
102 |         Fitting error (G function)
103 |     '''
104 |     Xs, t = preprocess_data(data)
105 | 
106 |     # Set the start points
107 | 
108 |     start_points = [(0, 0)] #, (np.pi / 2, 0), (np.pi / 2, np.pi / 2)]
109 |     if guess_angles:
110 |         start_points = guess_angles
111 | 
112 |     # Fit the cylinder from the chosen start point
113 | 
114 |     best_fit = None
115 |     best_score = float('inf')
116 | 
117 |     # NB: only the first start point is used; the multi-start loop is disabled in this modified version
118 |     method, tol = 'Powell', 1e-6
119 | 
120 |     fitted = minimize(lambda x : G(direction(x[0], x[1]), Xs),
121 |                       start_points[0], method=method, tol=tol)
122 | 
123 |     # best (and only) candidate from the single fit
124 |     best_score = fitted.fun
125 |     best_fit = fitted
126 | 
127 |     w = direction(best_fit.x[0], best_fit.x[1])
128 | 
129 |     return w, C(w, Xs) + t, r(w, Xs), best_fit.fun
130 | 
131 | 
132 | def normalize(v):
133 |     '''Normalize a vector based on its 2 norm.'''
134 |     if np.linalg.norm(v) == 0:
135 |         return v
136 |     return v / np.linalg.norm(v)
137 | 
138 | def rotation_matrix_from_axis_and_angle(u, theta):
139 |     '''Calculate a rotation matrix from an axis and an angle.'''
140 | 
141 |     x = u[0]
142 |     y = u[1]
143 |     z = u[2]
144 |     s = np.sin(theta)
145 |     c = np.cos(theta)
146 | 
147 |     return np.array([[c + x**2 * (1 - c), x * y * (1 - c) - z * s, x * z * (1 - c) + y * s],
148 |                      [y * x * (1 - c) + z * s, c + y**2 * (1 - c), y * z * (1 - c) - x * s ],
149 |                      [z * x * (1 - c) - y * s, z * y * (1 - c) + x * s, c + z**2 * (1 - c) ]])
150 | 
151 | def point_line_distance(p, l_p, l_v):
152 |     '''Calculate the distance between a point and a line defined
153 |     by a point and a direction vector.
154 |     '''
155 |     l_v = normalize(l_v)
156 |     u = p - l_p
157 |     return np.linalg.norm(u - np.dot(u, l_v) * l_v)
158 | 
--------------------------------------------------------------------------------
/treegraph/distance_from_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import math
5 | import matplotlib.pyplot as plt
6 | from treegraph import common
7 | from treegraph.third_party import shortpath as p2g
8 | from treegraph import downsample
9 | from treegraph.third_party import cylinder_fitting as cyl_fit
10 | from sklearn.cluster import DBSCAN
11 | from sklearn.neighbors import NearestNeighbors
12 | 
13 | def run(pc, base_location=None, cluster_size=False, knn=100, verbose=False,\
14 |         base_correction=True, plot=False):
15 | 
16 |     """
17 |     Purpose: Attributes each point with distance_from_base
18 | 
19 |     Inputs:
20 |         pc: pd.DataFrame
21 |             Input point clouds
22 |         base_location: None or idx (default None)
23 |             Index of the base point, i.e. the point that distances are measured to.
24 |         cluster_size: False or float (default False)
25 |             Downsample points with vlen=cluster_size to speed up graph generation.
26 |             The redundant points are retained for later processing.
27 |         knn: int (default 100)
28 |             Number of neighbours to search around each point in the neighbourhood phase.
29 |             The higher the better (careful, it's memory intensive).
30 |         base_correction: boolean (default True)
31 |             Generate a new base node located at the centre of the tree base cross-section.
32 |             Update the initial graph by connecting points in the base slice to the new base node.
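        plot: boolean (default False)
            Plot the extracted stem points and fitted cylinder
            (only used when base_correction is True).

    Returns:
        pc: pd.DataFrame with a 'distance_from_base' column added
        G: the initial networkx graph
        new_base, fit_r: (only when base_correction is True) the new base
            node coordinates and the fitted basal cylinder radius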
33 | """ 34 | 35 | columns = pc.columns.to_list() + ['distance_from_base'] 36 | 37 | if cluster_size: 38 | pc, base_location = downsample.run(pc, cluster_size, 39 | base_location=base_location, 40 | remove_noise=True, 41 | keep_columns=['VX']) 42 | 43 | c = ['x', 'y', 'z', 'pid'] 44 | # generate initial graph 45 | G = p2g.array_to_graph(pc.loc[pc.downsample][c] if 'downsample' in pc.columns else pc[c], 46 | base_id=pc.loc[pc.pid == base_location].index[0], 47 | kpairs=3, 48 | knn=knn, 49 | nbrs_threshold=.2, 50 | nbrs_threshold_step=.1, 51 | # graph_threshold=.05 52 | ) 53 | 54 | if base_correction: 55 | # identify and extract lower part of the stem 56 | stem, fit_cyl, new_base = identify_stem(pc, plot=plot) 57 | # fit a cylinder to the lower stem 58 | fit_C, axis_dir, fit_r, fit_err = fit_cyl 59 | 60 | # select points in the lowest slice 61 | low_slice_len = 0.2 # unit in metre 62 | if 'downsample' in pc.columns: 63 | low_slice = pc[(pc.z <= (min(stem.z)+low_slice_len)) &\ 64 | (pc.downsample == True)] 65 | else: 66 | low_slice = pc[pc.z <= (min(stem.z)+low_slice_len)] 67 | 68 | 69 | # calculate distance between new_base_node and each point in the lowest slice 70 | index = [] 71 | distance = [] 72 | for i, row in low_slice.iterrows(): 73 | index.append(i) 74 | coor = row[['x','y','z']].values 75 | distance.append(np.linalg.norm(new_base - coor)) 76 | 77 | # add the new base node to the graph 78 | new_base_id = np.max(pc.index) + 1 79 | G.add_node(int(new_base_id), 80 | pos=[float(new_base[0]), float(new_base[1]), float(new_base[2])], 81 | pid=int(new_base_id)) 82 | 83 | # add edges (weighted by distance) between new base node and the low_slice nodes in graph 84 | p2g.add_nodes(G, new_base_id, index, distance, np.inf) 85 | # print(f'add new edges: {len(index)}') 86 | 87 | # add new_base_node attributes to pc 88 | if 'downsample' in pc.columns: 89 | base_coords = pd.Series({'x':new_base[0], \ 90 | 'y':new_base[1], \ 91 | 'z':new_base[2], \ 92 | 'pid':new_base_id,\ 93 | 'downsample':True},\ 94 | name=new_base_id) 95 | else: 96 | base_coords = pd.Series({'x':new_base[0], \ 97 | 'y':new_base[1], \ 98 | 'z':new_base[2], \ 99 | 'pid':new_base_id},\ 100 | name=new_base_id) 101 | pc = pc.append(base_coords) 102 | 103 | # extracts shortest path information from the updated initial graph 104 | node_ids, distance, path_dict = p2g.extract_path_info(G, new_base_id) 105 | 106 | else: # do not generate new base node nor update the initial graph 107 | # extracts shortest path information from the initial graph 108 | node_ids, distance, path_dict = p2g.extract_path_info(G, pc.loc[pc.pid == base_location].index[0]) 109 | 110 | if 'distance_from_base' in pc.columns: 111 | del pc['distance_from_base'] 112 | 113 | # if pc is downsampled to generate graph then reindex downsampled pc 114 | # and join distances... 
115 |     if cluster_size:
116 |         dpc = pc.loc[pc.downsample]
117 |         # dpc.reset_index(inplace=True)
118 |         dpc.loc[node_ids, 'distance_from_base'] = np.array(list(distance))
119 |         pc = pd.merge(pc, dpc[['VX', 'distance_from_base']], on='VX', how='left')
120 |     # ...or else just join distances to pc
121 |     else:
122 |         pc.loc[node_ids, 'distance_from_base'] = np.array(list(distance))
123 | 
124 |     if base_correction:
125 |         return pc[columns], G, new_base, fit_r
126 |     else:
127 |         return pc[columns], G
128 | 
129 | 
130 | 
131 | def identify_stem(pc, plot=False):
132 |     # create empty df for later stem point collection
133 |     stem = pd.DataFrame(columns=['x','y','z'])
134 | 
135 |     # tree height from point clouds
136 |     h = pc.z.max() - pc.z.min()
137 |     print(f'tree height (from point clouds) = {h:.2f} m')
138 | 
139 |     if h > 20: # a tall tree
140 |         stop = h / 10.
141 |         step = h / 30.
142 |     else: # a small tree
143 |         stop = 2.
144 |         step = .5
145 | 
146 |     # loop over slices of the point cloud with a vertical interval of 'step';
147 |     # the loop stops at 1/10 of the tree height
148 |     for i in np.arange(0, stop, step):
149 |         zmin = pc.z.min()
150 |         if i < (2*step):
151 |             pc_slice = pc[pc.z < (zmin+i+step)]
152 |         else:
153 |             pc_slice = pc[((zmin+i) <= pc.z) & (pc.z < (zmin+i+step))]
154 |         pc_coor = pc_slice.reset_index()[['x','y','z']]
155 |         if len(pc_coor) <= 10:
156 |             continue
157 | 
158 |         # filter out outliers
159 |         nn = NearestNeighbors(n_neighbors=10).fit(pc_coor)
160 |         dnn, indices = nn.kneighbors()
161 |         mean_dnn = np.mean(dnn, axis=1)
162 |         noise = np.where(mean_dnn > 0.05)[0]
163 |         pc_coor = pc_coor.drop(noise)
164 | 
165 |         # cluster the sliced points
166 |         dbscan = DBSCAN(eps=.1, min_samples=50).fit(pc_coor)
167 |         pc_coor.loc[:, 'clstr'] = dbscan.labels_
168 | 
169 |         # calculate normal vector of each cluster and find stem points
170 |         if len(np.unique(dbscan.labels_)) > 1:
171 |             # for the lowest slice, the stem part is the cluster with the most points
172 |             if len(stem) == 0:
173 |                 group = pc_coor.groupby(by='clstr').count()
174 |                 max_clstr = group[group.x == group.x.max()].index[0]
175 |                 stem = pc_coor[pc_coor.clstr == max_clstr][['x','y','z']]
176 |             # for upper slices, ignore slices with multiple clusters
177 |             # to avoid mixed branch points as well as furcation
178 |             else:
179 |                 break
180 |         else:
181 |             for c in np.unique(dbscan.labels_):
182 |                 xyz = pc_coor[pc_coor.clstr == c][['x','y','z']]
183 |                 nv, d = normal_vector(xyz)
184 |                 # if the cluster's normal is near-horizontal (d = |nv.n0| close to 0),
185 |                 # the surface is near-vertical, so regard xyz as stem points
186 |                 if d < 0.5:
187 |                     stem = pd.concat([stem, xyz])
188 |                     stem = stem.drop_duplicates()
189 | 
190 |     # if no stem point is found
191 |     if len(stem) == 0:
192 |         stem = pc_coor[['x','y','z']]
193 | 
194 | 
195 |     # fit a cylinder to the extracted stem points
196 |     pts = stem.to_numpy()
197 |     axis_dir, fit_C, r_fit, fit_err = cyl_fit.fit(pts)
198 | 
199 |     # find the min Z coordinate of the stem points
200 |     lowest_z = np.array(stem.z[stem.z == min(stem.z)])[0]
201 | 
202 |     # define the new base node as the point on the cylinder axis line
203 |     # whose Z coord equals that of the lowest stem point;
204 |     # calculate its X,Y coords from the line equation defined by the axis direction
205 |     base_x = (lowest_z - fit_C[2]) / axis_dir[2] * axis_dir[0] + fit_C[0]
206 |     base_y = (lowest_z - fit_C[2]) / axis_dir[2] * axis_dir[1] + fit_C[1]
207 |     base_node = [base_x, base_y, lowest_z]
208 | 
209 |     cpc_cen = common.CPC(stem).x
210 |     if stem.x.min() <= base_node[0] <= stem.x.max():
211 |         if stem.y.min() <= base_node[1] <= stem.y.max():
212 |             new_base = base_node
213 |         else:
214 |             new_base = cpc_cen
215 |     else:
216 |         new_base = cpc_cen
217 | 
218 |     # plot extracted stem points and fitted cylinder
219 |     if plot:
220 |         fig, axs = plt.subplots(1,3,figsize=(12,4))
221 |         ax = axs.flatten()
222 |         # top view
223 |         stem.plot.scatter(x='x',y='y',s=1,ax=ax[0], c='grey')
224 |         ax[0].scatter(fit_C[0], fit_C[1], s=50, c='blue', label='fitted cyl centre')
225 |         ax[0].scatter(new_base[0], new_base[1], s=50, c='red', label='new base node')
226 | 
227 |         # front view
228 |         stem.plot.scatter(x='x',y='z',s=1,ax=ax[1], c='grey')
229 |         # fitted cyl centre
230 |         ax[1].scatter(fit_C[0], fit_C[2], s=50, c='blue', label='fitted cyl centre')
231 |         # fitted cyl axis
232 |         z = np.arange(stem.z.min(), stem.z.max(), 0.01)
233 |         x = (axis_dir[0]/axis_dir[2]) * (z - fit_C[2]) + fit_C[0]
234 |         ax[1].plot(x, z, linestyle='dashed', label='fitted cyl axis')
235 |         # new base node
236 |         ax[1].scatter(new_base[0], new_base[2], s=50, c='red', label='new base node')
237 | 
238 |         # side view
239 |         stem.plot.scatter(x='y',y='z',s=1,ax=ax[2], c='grey')
240 |         # fitted cyl centre
241 |         ax[2].scatter(fit_C[1], fit_C[2], s=50, c='blue', label='fitted cyl centre')
242 |         # fitted cyl axis
243 |         z = np.arange(stem.z.min(), stem.z.max(), 0.01)
244 |         y = (axis_dir[1]/axis_dir[2]) * (z - fit_C[2]) + fit_C[1]
245 |         ax[2].plot(y, z, linestyle='dashed', label='fitted cyl axis')
246 |         # new base node
247 |         ax[2].scatter(new_base[1], new_base[2], s=50, c='red', label='new base node')
248 |         ax[2].legend(bbox_to_anchor=(1.05, 1))
249 | 
250 |         # fig.suptitle(f'{treeid}')
251 |         fig.tight_layout()
252 | 
253 |     return stem, [fit_C, axis_dir, r_fit, fit_err], new_base
254 | 
255 | 
256 | def normal_vector(pc):
257 |     '''
258 |     Calculate the normal vector of the input point cloud,
259 |     and the difference between this normal and the ground normal.
260 | 
261 |     Input:
262 |         - pc: n×3 array, X,Y,Z coordinates of points
263 | 
264 |     Output:
265 |         - nv: normal vector of the input point cloud
266 |         - d: |nv·n0|, quantifying how the pc normal compares to the ground normal (near 0 means near-vertical surface)
267 |     '''
268 |     # calculate centroid coordinates of the points in 'pc'
269 |     centroid = np.average(pc, axis=0)
270 | 
271 |     # run SVD on the centered points from 'pc'
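    # (the right-singular vector associated with the smallest singular value
    # is the direction of least variance, i.e. the best-fit plane normal)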
272 |     _, evals, evecs = np.linalg.svd(pc - centroid, full_matrices=False)
273 | 
274 |     # the normal vector of the input pc is the singular vector
275 |     # associated with the smallest singular value
276 |     nv = evecs[np.argmin(evals)]
277 | 
278 |     # normal vector of the ground
279 |     n0 = np.array([0,0,1])
280 | 
281 |     # difference between pc normal and ground normal
282 |     d = np.abs(np.dot(nv, n0))
283 | 
284 |     return nv, d
285 | 
--------------------------------------------------------------------------------
/treegraph/distance_from_tip.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from treegraph import distance_from_base
4 | from tqdm.autonotebook import tqdm
5 | 
6 | def run(pc, centres, bins, vlength=.005, cluster_size=.04, verbose=False, min_pts=0):
7 | 
8 |     pc.loc[:, 'modified_distance'] = pc.distance_from_base
9 |     PC_nodes = pd.DataFrame(columns=['new_parent'])
10 |     PC_nodes.loc[:, 'parent_node'] = centres.loc[centres.n_furcation != 0].node_id
11 |     PC_nodes = pd.merge(centres.loc[~np.isnan(centres.parent_node)][['node_id', 'parent_node', 'nbranch']],
12 |                         PC_nodes, on='parent_node', how='left')
13 | 
14 |     new_pc = pd.DataFrame()
15 | 
16 |     if verbose: print('reattributing branches...')
17 |     for nbranch in tqdm(np.sort(centres.nbranch.unique()),
18 |                         total=len(centres.nbranch.unique()),
19 |                         disable=False if verbose else True):
20 | 
21 |         # nodes to identify points
22 |         branch_nodes = centres.loc[centres.nbranch == nbranch].node_id.values
23 |         parent_node = list(centres.loc[centres.nbranch == nbranch].parent_node.unique())[0]
24 |         parent_branch = centres.loc[centres.nbranch == nbranch].parent.unique()[0]
25 |         idx = list(pc.loc[pc.node_id.isin(branch_nodes)].index) # index of nodes
26 |         branch_pc = pc.loc[idx]
27 | 
28 |         # correct for some errors in distance_from_base
29 |         if len(branch_pc) > 1000:
30 |             dfb_min = branch_pc['distance_from_base'].min()
31 |             try:
32 |                 branch_pc, branch_G = distance_from_base.run(branch_pc,
33 |                                                              base_location=branch_pc.distance_from_base.idxmin(),
34 |                                                              cluster_size=cluster_size,
35 |                                                              base_correction=False)
36 |             except: pass
37 |             branch_pc.distance_from_base += dfb_min
38 | 
39 |         if nbranch == 0:
40 |             branch_pc.loc[:, 'modified_distance'] = branch_pc.distance_from_base
41 |         else:
42 |             # normalising distance so tip is equal to maximum distance
43 |             tip_diff = pc.distance_from_base.max() - branch_pc.distance_from_base.max()
44 |             branch_pc.loc[:, 'modified_distance'] = branch_pc.distance_from_base + tip_diff
45 | 
46 |         # regenerating slice_ids
47 |         branch_pc.loc[:, 'slice_id'] = np.digitize(branch_pc.modified_distance, np.array(list(bins.values())).cumsum())
48 | 
49 |         # check new clusters are not smaller than min_pts, if they
50 |         # are cluster them with the next one
51 |         N = branch_pc.groupby('slice_id').x.count()
52 |         slice_plus = {n:0 if N[n] > min_pts else -1 if n == N.max() else 1 for n in N.index}
53 |         branch_pc.slice_id += branch_pc.slice_id.map(slice_plus)
54 | 
55 |         # normalise slice_id to 0
56 |         branch_pc.slice_id = branch_pc.slice_id - branch_pc.slice_id.min()
57 | 
58 |         # reattribute centres
59 |         new_centres = branch_pc.groupby('slice_id')[['x', 'y', 'z']].median().rename(columns={'x':'cx', 'y':'cy', 'z':'cz'})
60 |         centre_path_dist = branch_pc.groupby('slice_id').distance_from_base.mean()
61 |         npoints = branch_pc.groupby('slice_id').x.count()
62 |         npoints.name = 'n_points'
63 |         new_centres = new_centres.join(centre_path_dist).join(npoints).reset_index()
64 | 
65 |         # update pc node_id and slice_id
66 |         new_centres.loc[:, 'node_id'] = np.arange(len(new_centres)) + centres.node_id.max() + 1
67 |         branch_pc = branch_pc[branch_pc.columns.drop('node_id')].join(new_centres[['slice_id', 'node_id']],
68 |                                                                        on='slice_id',
69 |                                                                        how='left',
70 |                                                                        rsuffix='x')
71 | 
72 |         if nbranch != 0: # main branch does not have a parent
73 |             parent_slice_id = PC_nodes.loc[(PC_nodes.parent_node == parent_node) &
74 |                                            (PC_nodes.nbranch == nbranch)].slice_id.values[0]
75 |             new_centres.slice_id += parent_slice_id + 1
76 |             branch_pc.slice_id += parent_slice_id + 1
77 | 
78 |         # if branch furcates identify new node_id and slice_id
79 |         for _, row in centres.loc[(centres.nbranch == nbranch) & (centres.n_furcation > 0)].iterrows():
80 | 
81 |             new_centres.loc[:, 'dist2fur'] = np.linalg.norm(row[['cx', 'cy', 'cz']].astype(float) -
82 |                                                             new_centres[['cx', 'cy', 'cz']],
83 |                                                             axis=1)
84 |             PC_nodes.loc[PC_nodes.parent_node == row.node_id, 'new_parent'] = new_centres.loc[new_centres.dist2fur.idxmin()].node_id
85 |             PC_nodes.loc[PC_nodes.parent_node == row.node_id, 'slice_id'] = new_centres.loc[new_centres.dist2fur.idxmin()].slice_id
86 | 
87 |         centres = centres.loc[~centres.node_id.isin(branch_nodes)]
88 |         centres = centres.append(new_centres.loc[new_centres.n_points > min_pts])
89 | 
90 |         # update dict that is used to identify new nodes in parent branch
91 |         # node_ids[nbranch] = new_centres.node_id.values
92 | 
93 |         new_pc = new_pc.append(branch_pc)
94 | 
95 |     new_pc.reset_index(inplace=True, drop=True)
96 |     centres.reset_index(inplace=True, drop=True)
97 | 
98 |     return centres, new_pc
99 | 
--------------------------------------------------------------------------------
/treegraph/downsample.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import string
4 | import struct
5 | 
6 | # from scipy.spatial.distance import cdist
7 | from sklearn.neighbors import NearestNeighbors
8 | 
9 | def voxelise(tmp, length, method='random'):
10 | 
11 |     tmp.loc[:, 'xx'] = tmp.x // length * length
12 |     tmp.loc[:, 'yy'] = tmp.y // length * length
13 |     tmp.loc[:, 'zz'] = tmp.z // length * length
14 | 
15 |     if method == 'random':
16 | 
17 |         code = lambda: ''.join(np.random.choice([x for x in string.ascii_letters], size=8))
18 | 
19 |         xD = {x:code() for x in tmp.xx.unique()}
20 |         yD = {y:code() for y in tmp.yy.unique()}
21 |         zD = {z:code() for z in tmp.zz.unique()}
22 | 
23 |         tmp.loc[:, 'VX'] = tmp.xx.map(xD) + tmp.yy.map(yD) + tmp.zz.map(zD)
24 | 
25 |     elif method == 'bytes':
26 | 
27 |         code = lambda row: np.array([row.xx, row.yy, row.zz]).tobytes()
28 |         tmp.loc[:, 'VX'] = tmp.apply(code, axis=1)
29 | 
30 |     else:
31 |         raise Exception('method {} not recognised: choose "random" or "bytes"'.format(method))
32 | 
33 |     return tmp
34 | 
35 | 
36 | def run(pc, vlength,
37 |         base_location=None,
38 |         remove_noise=False,
39 |         min_pts=1,
40 |         delete=False,
41 |         keep_columns=[],
42 |         voxel_method='random',
43 |         verbose=False):
44 | 
45 |     """
46 |     Downsamples a point cloud so that there is one point per voxel.
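    If remove_noise is True, voxels containing min_pts points or fewer are
    first dissolved into the nearest sufficiently populated voxel.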
47 |     Points are selected as the point closest to the median xyz value
48 | 
49 |     Parameters
50 |     ----------
51 | 
52 |     pc: pd.DataFrame with x, y, z columns
53 |     vlength: float
54 | 
55 | 
56 |     Returns
57 |     -------
58 | 
59 |     pd.DataFrame with boolean downsample column
60 | 
61 |     """
62 | 
63 |     pc_length = len(pc)
64 |     pc = pc.drop(columns=[c for c in ['downsample', 'VX'] if c in pc.columns])
65 | 
66 |     if base_location is None:
67 |         base_location = pc.loc[pc.z.idxmin()].pid.values[0]
68 | 
69 |     columns = pc.columns.to_list() + keep_columns # required for tidy up later
70 |     pc = voxelise(pc, vlength, method=voxel_method)
71 | 
72 |     if remove_noise:
73 |         # dissolve voxels with too few points into neighbouring voxels
74 |         # compute N points per voxel
75 |         # rename to count
76 |         # join with df of voxel median xyz
77 |         # reset index
78 |         VX = pd.DataFrame(pc.groupby('VX').x.count()) \
79 |                .rename(columns={'x':'cnt'}) \
80 |                .join(pc.groupby('VX')[['x', 'y', 'z']].median()) \
81 |                .reset_index()
82 | 
83 |         nbrs = NearestNeighbors(n_neighbors=10, leaf_size=15, n_jobs=-1).fit(VX[['x', 'y', 'z']])
84 |         distances, indices = nbrs.kneighbors(VX[['x', 'y', 'z']])
85 |         idx = np.argmax(np.isin(indices, VX.loc[VX.cnt > min_pts].index.to_numpy()), axis=1)
86 |         idx = [indices[i, ix] for i, ix in zip(range(len(idx)), idx)]
87 |         VX_map = {vx:vxn for vx, vxn in zip(VX.VX.values, VX.loc[idx].VX.values)}
88 |         pc.VX = pc.VX.map(VX_map)
89 | 
90 |     # groupby to find the central (closest to median) point
91 |     groupby = pc.groupby('VX')
92 |     pc.loc[:, 'mx'] = groupby.x.transform(np.median)
93 |     pc.loc[:, 'my'] = groupby.y.transform(np.median)
94 |     pc.loc[:, 'mz'] = groupby.z.transform(np.median)
95 |     pc.loc[:, 'dist'] = np.linalg.norm(pc[['x', 'y', 'z']].to_numpy(dtype=np.float32) -
96 |                                        pc[['mx', 'my', 'mz']].to_numpy(dtype=np.float32), axis=1)
97 | 
98 |     # need to keep all points for cylinder fitting so when downsampling
99 |     # just adding a column to select by
100 |     pc.loc[:, 'downsample'] = False
101 |     pc.loc[~pc.sort_values(['VX', 'dist']).duplicated('VX'), 'downsample'] = True
102 |     pc.sort_values('downsample', ascending=False, inplace=True) # sorting so that the base_location index is correct
103 | 
104 |     # update base_id
105 |     if base_location not in pc.loc[pc.downsample].pid.values:
106 |         pc.loc[pc.downsample, 'nndist'] = np.linalg.norm(pc.loc[pc.pid == base_location][['x', 'y', 'z']].values -
107 |                                                          pc.loc[pc.downsample][['x', 'y', 'z']], axis=1)
108 |         base_location = pc.loc[pc.nndist == pc.nndist.min()].pid.values[0]
109 | 
110 |     pc.reset_index(inplace=True, drop=True)
111 | 
112 |     if delete:
113 |         pc = pc.loc[pc.downsample][columns].reset_index(drop=True)
114 |     else:
115 |         pc = pc[columns + ['downsample']]
116 | 
117 |     if verbose: print('downsampled point cloud from {} to {} with edge length {}. Points deleted: {}'.format(pc_length, len(pc), vlength, delete))
118 | 
119 |     return pc, base_location
120 | 
--------------------------------------------------------------------------------
/treegraph/fit_cylinders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | from matplotlib.patches import Circle
5 | from sklearn.decomposition import PCA
6 | from scipy import optimize
7 | from scipy.spatial.transform import Rotation
8 | from scipy.stats import variation
9 | from tqdm.autonotebook import tqdm
10 | from treegraph.third_party.available_cpu_count import available_cpu_count
11 | from pandarallel import pandarallel
12 | 
13 | def run(pc, centres,
14 |         min_pts=10,
15 |         ransac_iterations=50,
16 |         sample=100,
17 |         nb_workers=available_cpu_count(),
18 |         verbose=False):
19 | 
20 |     for c in centres.columns:
21 |         if 'sf' in c: del centres[c]
22 | 
23 |     node_id = centres[centres.n_points > min_pts].sort_values('n_points').node_id.values
24 | 
25 |     groupby_ = pc.loc[pc.node_id.isin(node_id)].groupby('node_id')
26 |     pandarallel.initialize(progress_bar=verbose,
27 |                            use_memory_fs=True,
28 |                            nb_workers=min(len(centres), nb_workers))
29 | 
30 |     # cyl = groupby_.parallel_apply(RANSAC_helper, ransac_iterations)
31 |     cyl = groupby_.parallel_apply(RANSAC_helper_2, ransac_iterations, pc, centres)
32 | 
33 |     cyl = cyl.reset_index()
34 |     cyl.columns=['node_id', 'result']
35 |     cyl.loc[:, 'sf_radius'] = cyl.result.apply(lambda c: c[0])
36 |     cyl.loc[:, 'sf_cx'] = cyl.result.apply(lambda c: c[1][0])
37 |     cyl.loc[:, 'sf_cy'] = cyl.result.apply(lambda c: c[1][1])
38 |     cyl.loc[:, 'sf_cz'] = cyl.result.apply(lambda c: c[1][2])
39 | 
40 |     centres = pd.merge(centres,
41 |                        cyl[['node_id', 'sf_radius', 'sf_cx', 'sf_cy', 'sf_cz']],
42 |                        on='node_id',
43 |                        how='left')
44 | 
45 |     return centres
46 | 
47 | def other_cylinder_fit2(xyz, xm=0, ym=0, xr=0, yr=0, r=1):
48 | 
49 |     from scipy.optimize import leastsq
50 | 
51 |     """
52 |     https://stackoverflow.com/a/44164662/1414831
53 | 
54 |     This fits a vertical cylinder.
55 |     Reference:
56 |     http://www.int-arch-photogramm-remote-sens-spatial-inf-sci.net/XXXIX-B5/169/2012/isprsarchives-XXXIX-B5-169-2012.pdf
57 |     xyz is a matrix containing at least 5 rows; each row stores x y z of a point on the cylindrical surface
58 |     p holds the initial values of the parameters;
59 |     p[0] = Xc, x coordinate of the cylinder centre
60 |     P[1] = Yc, y coordinate of the cylinder centre
61 |     P[2] = alpha, rotation angle (radian) about the x-axis
62 |     P[3] = beta, rotation angle (radian) about the y-axis
63 |     P[4] = r, radius of the cylinder
64 |     th, threshold for the convergence of the least squares
65 |     """
66 | 
67 |     x = xyz.x
68 |     y = xyz.y
69 |     z = xyz.z
70 | 
71 |     p = np.array([xm, ym, xr, yr, r])
72 | 
73 |     fitfunc = lambda p, x, y, z: (- np.cos(p[3])*(p[0] - x) - z*np.cos(p[2])*np.sin(p[3]) - np.sin(p[2])*np.sin(p[3])*(p[1] - y))**2 + (z*np.sin(p[2]) - np.cos(p[2])*(p[1] - y))**2 #fit function
74 |     errfunc = lambda p, x, y, z: fitfunc(p, x, y, z) - p[4]**2 #error function
75 | 
76 |     est_p, success = leastsq(errfunc, p, args=(x, y, z), maxfev=1000)
77 | 
78 |     return est_p
79 | 
80 | def RANSACcylinderFitting4(xyz_, iterations=50, plot=False):
81 | 
82 |     if plot:
83 |         ax = plt.subplot(111)
84 | 
85 |     bestFit, bestErr = None, np.inf
86 |     xyz_mean = xyz_.mean(axis=0)
87 |     xyz_ -= xyz_mean
88 | 
89 |     for i in range(iterations):
90 | 
91 |         xyz = xyz_.copy()
92 | 
93 |         # prepare sample
94 |         sample = xyz.sample(n=20)
95 |         # sample = xyz.sample(n=max(20, int(len(xyz)*.2)))
96 |         xyz = xyz.loc[~xyz.index.isin(sample.index)]
97 | 
98 |         x, y, a, b, radius = other_cylinder_fit2(sample, 0, 0, 0, 0, 0)
99 |         centre = (x, y)
100 |         if not np.all(np.isclose(centre, 0, atol=radius*1.05)): continue
101 | 
102 |         MX = Rotation.from_euler('xy', [a, b]).inv()
103 |         xyz[['x', 'y', 'z']] = MX.apply(xyz)
104 |         xyz.loc[:, 'error'] = np.linalg.norm(xyz[['x', 'y']] - centre, axis=1) / radius
105 |         idx = xyz.loc[xyz.error.between(.8, 1.2)].index # a 40% band of the radius is probably quite large
106 | 
107 |         # select points which best fit model from original dataset
108 |         alsoInliers = xyz_.loc[idx].copy()
109 |         if len(alsoInliers) < len(xyz_) * .2: continue # skip if not enough points were chosen
110 | 
111 |         # refit model using new params
112 |         x, y, a, b, radius = other_cylinder_fit2(alsoInliers, x, y, a, b, radius)
113 |         centre = [x, y]
114 |         if not np.all(np.isclose(centre, 0, atol=radius*1.05)): continue
115 | 
116 |         MX = Rotation.from_euler('xy', [a, b]).inv()
117 |         alsoInliers[['x', 'y', 'z']] = MX.apply(alsoInliers[['x', 'y', 'z']])
118 |         # calculate error for "best" subset
119 |         alsoInliers.loc[:, 'error'] = np.linalg.norm(alsoInliers[['x', 'y']] - centre, axis=1) / radius
120 | 
121 |         if variation(alsoInliers.error) < bestErr:
122 | 
123 |             # for testing uncomment
124 |             c = Circle(centre, radius=radius, facecolor='none', edgecolor='g')
125 | 
126 |             bestFit = [radius, centre, c, alsoInliers, MX]
127 |             bestErr = variation(alsoInliers.error)
128 | 
129 |     if bestFit is None:
130 |         # usually caused by a low number of ransac iterations
131 |         return np.nan, xyz[['x', 'y', 'z']].mean(axis=0).values, np.inf, len(xyz_)
132 | 
133 |     radius, centre, c, alsoInliers, MX = bestFit
134 |     centre[0] += xyz_mean.x
135 |     centre[1] += xyz_mean.y
136 |     centre = centre + [xyz_mean.z]
137 | 
138 |     # for testing uncomment
139 |     if plot:
140 | 
141 |         radius, Centre, c, alsoInliers, MX = bestFit
142 | 
143 |         xyz_[['x', 'y', 'z']] = MX.apply(xyz_)
144 |         xyz_ += xyz_mean
145 |         ax.scatter(xyz_.x, xyz_.y, s=1, c='grey')
146 | 
147 |         alsoInliers[['x', 'y', 'z']] += xyz_mean
148 |         cbar = ax.scatter(alsoInliers.x, alsoInliers.y, s=10, c=alsoInliers.error)
149 |         plt.colorbar(cbar)
150 | 
151 |         ax.scatter(Centre[0], Centre[1], marker='+', s=100, c='r')
152 |         ax.add_patch(c)
153 |         ax.axis('equal')
154 | 
155 | 
156 |     return [radius, centre, bestErr, len(xyz_)]
157 | 
158 | def NotRANSAC(xyz):
159 | 
160 |     try:
161 |         xyz = xyz[['x', 'y', 'z']]
162 |         pca = PCA(n_components=3, svd_solver='auto').fit(xyz)
163 |         xyz[['x', 'y', 'z']] = pca.transform(xyz)
164 |         x, y, a, b, radius = other_cylinder_fit2(xyz) # unpack all 5 fitted parameters
165 |         centre = (x, y)
166 |         if xyz.z.min() - radius < centre[0] < xyz.z.max() + radius or \
167 |            xyz.y.min() - radius < centre[1] < xyz.y.max() + radius:
168 |             centre = np.hstack([xyz.x.mean(), centre])
169 |         else:
170 |             centre = xyz.mean().values
171 | 
172 |         centre = pca.inverse_transform(centre)
173 |     except:
174 |         radius, centre = np.nan, xyz[['x', 'y', 'z']].mean(axis=0).values
175 | 
176 |     return [radius, centre, np.inf, len(xyz)]
177 | 
178 | 
179 | def RANSAC_helper_2(dcluster, ransac_iterations, pc, centres, plot=False):
180 |     # node_id of current cluster
181 |     nid = np.unique(dcluster.node_id)[0]
182 |     # branch_id of current centre node
183 |     nbranch = centres[centres.node_id == nid].nbranch.values[0]
184 |     # number of centres (segments) of this branch
185 |     nseg = len(centres[centres.nbranch == nbranch])
186 |     # the sequence id of current centre node in its branch
187 |     ncyl =
centres[centres.node_id == nid].ncyl.values[0] 188 | # print(f'node_id = {nid}, nbranch = {nbranch}, nseg = {nseg}, ncyl = {ncyl}') 189 | 190 | # sample points for cyl fitting 191 | if nseg == 1: 192 | samples = dcluster 193 | if ncyl == 0: # the first segment of this branch 194 | node_list = centres[(centres.nbranch == nbranch) & (centres.ncyl.isin([0,1]))].node_id.values 195 | samples = pc[pc.node_id.isin(node_list)] 196 | if ncyl == (nseg-1): # the last segment of this branch 197 | node_list = centres[(centres.nbranch == nbranch) & (centres.ncyl.isin([nseg-2,nseg-1]))].node_id.values 198 | samples = pc[pc.node_id.isin(node_list)] 199 | else: 200 | node_list = centres[(centres.nbranch == nbranch) & (centres.ncyl.isin([ncyl-1,ncyl+1]))].node_id.values 201 | samples = pc[pc.node_id.isin(node_list)] 202 | 203 | # fit cyl to samples using RANSAC 204 | if len(samples) == 0: # don't think this is required but.... 205 | cylinder = [np.nan, np.array([np.inf, np.inf, np.inf]), np.inf, len(samples)] 206 | elif len(samples) <= 10: 207 | cylinder = [np.nan, samples[['x', 'y', 'z']].mean(axis=0).values, np.inf, len(samples)] 208 | elif len(samples) <= 20: 209 | cylinder = NotRANSAC(samples) 210 | else: 211 | cylinder = RANSACcylinderFitting4(samples[['x', 'y', 'z']], iterations=ransac_iterations, plot=plot) 212 | 213 | return cylinder 214 | 215 | 216 | def RANSAC_helper(xyz, ransac_iterations, plot=False): 217 | # try: 218 | if len(xyz) == 0: # don't think this is required but.... 219 | cylinder = [np.nan, np.array([np.inf, np.inf, np.inf]), np.inf, len(xyz)] 220 | elif len(xyz) <= 10: 221 | cylinder = [np.nan, xyz[['x', 'y', 'z']].mean(axis=0).values, np.inf, len(xyz)] 222 | elif len(xyz) <= 20: 223 | cylinder = NotRANSAC(xyz) 224 | else: 225 | cylinder = RANSACcylinderFitting4(xyz[['x', 'y', 'z']], iterations=ransac_iterations, plot=plot) 226 | # if cylinder == None: # again not sure if this is necessary... 
227 |     #         cylinder = [np.nan, xyz[['x', 'y', 'z']].mean(axis=0)]
228 | 
229 |     # except:
230 |     #     cylinder = [np.nan, xyz[['x', 'y', 'z']].mean(axis=0), np.inf, np.inf]
231 | 
232 |     return cylinder
233 | 
--------------------------------------------------------------------------------
/treegraph/generate_cylinder_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from treegraph.common import node_angle_f
4 | from tqdm.autonotebook import tqdm
5 | import math # required by end_of_branch() and rotation_matrix() below
6 | # updated version
7 | def run(self, radius_value='m_radius'):
8 |     self.cyls = pd.DataFrame(columns=['p1', 'p2',
9 |                                       'sx', 'sy', 'sz',
10 |                                       'ax', 'ay', 'az',
11 |                                       'radius', 'length', 'vol', 'surface_area', 'point_density',
12 |                                       'nbranch', 'ninternode', 'ncyl', 'is_tip', 'branch_order', 'branch_order2'])
13 | 
14 |     for ix, row in tqdm(self.centres.sort_values(['nbranch', 'ncyl']).iterrows(),
15 |                         total=len(self.centres)):
16 |         if row.node_id not in self.path_ids.keys(): continue
17 |         # path from current node to the base node
18 |         k_path = self.path_ids[row.node_id][::-1]
19 |         k1 = k_path[0]
20 | 
21 |         if len(k_path) > 1:
22 |             k2 = k_path[1]
23 |             # current node coords
24 |             c1 = np.array([row.cx, row.cy, row.cz])
25 | 
26 |             if len(self.centres[self.centres.node_id == k2]) == 0: continue
27 |             # previous node coords
28 |             c2 = np.array([self.centres.loc[self.centres.node_id == k2].cx.values[0],
29 |                            self.centres.loc[self.centres.node_id == k2].cy.values[0],
30 |                            self.centres.loc[self.centres.node_id == k2].cz.values[0]])
31 | 
32 |             correction = 1
33 |             length = np.linalg.norm(c1 - c2)
34 |             L = length
35 |             length *= correction
36 | 
37 |             if length < 0: continue
38 | 
39 |             if isinstance(radius_value, str):
40 |                 # rad = self.centres[self.centres.node_id.isin([k1, k2])][radius_value].mean()
41 | 
42 |                 if self.centres[self.centres.node_id == k2].n_furcation.values[0] > 0:
43 |                     # if prev node is a furcation node
44 |                     rad = self.centres[self.centres.node_id == k1][radius_value].values[0]
45 |                 else:
46 |                     rad = self.centres[self.centres.node_id == k2][radius_value].values[0]
47 | 
48 |                 # mask NaN radius
49 |                 is_null = np.isnan(rad)
50 |                 if np.all(is_null):
51 |                     continue
52 | 
53 |             elif isinstance(radius_value, int) or isinstance(radius_value, float):
54 |                 rad = radius_value
55 |             else:
56 |                 rad = .05
57 | 
58 |             volume = np.pi * (rad ** 2) * length
59 |             surface_area = 2 * np.pi * rad * length #+ 2 * np.pi * rad**2
60 | 
61 |             if np.isnan(rad): print(k1, k2)
62 | 
63 |             direction = direction_vector(c1, c2)
64 | 
65 |             point_density = ((row.n_points + self.centres.loc[self.centres.node_id == k2].n_points.values) / 2) / volume
66 |             row = row.append(pd.Series(index=['point_density'], data=point_density))
67 | 
68 |             # branch section order: +1 whenever after a furcation node
69 |             branch_order = row.norder
70 |             # branch order of the complete branch (ending at a tip node) = number of its parent branches
71 |             branch_order2 = len(self.branch_hierarchy[row.nbranch]['parent_branch'])
72 | 
73 | 
74 |             self.cyls.loc[ix] = [k1, k2,
75 |                                  c1[0], c1[1], c1[2],
76 |                                  direction[0], direction[1], direction[2],
77 |                                  rad, length, volume, surface_area, row.point_density,
78 |                                  row.nbranch, row.ninternode, int(row.ncyl), row.is_tip, branch_order, branch_order2]
79 | 
80 | 
81 | def end_of_branch(l, axis, start):
82 | 
83 |     # the following part is adapted from a Matlab script that does the rotation
84 |     u = [0,0,1]
85 |     raxis = [u[1]*axis[2]-axis[1]*u[2],
86 |              u[2]*axis[0]-axis[2]*u[0],
87 |              u[0]*axis[1]-axis[0]*u[1]]
88 | 
89 |     eucl = (axis[0]**2+axis[1]**2+axis[2]**2)**0.5
90 |     euclr = (raxis[0]**2+raxis[1]**2+raxis[2]**2)**0.5
91 | 
92 |     for i in range(3):
93 |         raxis[i] /= euclr
94 | 
95 |     angle = math.acos(np.dot(u, axis) / eucl)
96 | 
97 |     M = rotation_matrix(raxis, angle)
98 |     p = [0.0, 0.0, l]
99 |     x = (p[0]*M[0][0]+p[1]*M[0][1]+p[2]*M[0][2]) + start[0]
100 |     y = (p[0]*M[1][0]+p[1]*M[1][1]+p[2]*M[1][2]) + start[1]
101 |     z = (p[0]*M[2][0]+p[1]*M[2][1]+p[2]*M[2][2]) + start[2]
102 | 
103 |     return pd.Series([x, y, z])
104 | 
105 | def rotation_matrix(A, angle):
106 |     '''returns the rotation matrix'''
107 |     c = math.cos(angle)
108 |     s = math.sin(angle)
109 |     R = [[A[0]**2+(1-A[0]**2)*c, A[0]*A[1]*(1-c)-A[2]*s, A[0]*A[2]*(1-c)+A[1]*s],
110 |          [A[0]*A[1]*(1-c)+A[2]*s, A[1]**2+(1-A[1]**2)*c, A[1]*A[2]*(1-c)-A[0]*s],
111 |          [A[0]*A[2]*(1-c)-A[1]*s, A[1]*A[2]*(1-c)+A[0]*s, A[2]**2+(1-A[2]**2)*c]]
112 |     return R
113 | 
114 | def direction_vector(p1, p2):
115 |     return (p2 - p1) / np.linalg.norm(p2 - p1)
116 | 
--------------------------------------------------------------------------------
/treegraph/graph_process.py:
--------------------------------------------------------------------------------
1 | import json
2 | import networkx as nx
3 | import numpy as np
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | import matplotlib.image as mpimg
7 | from datetime import *
8 | 
9 | def save_graph(G, fname):
10 |     '''
11 |     Generate a json file to save the node and edge information of a networkx graph.
12 | 
13 |     Parameters
14 |     ----------
15 |     G : networkx graph
16 |         Graph to be saved.
17 | 
18 |     fname : string
19 |         Output path for the saved json file.
20 | 
21 |     '''
22 | 
23 |     dt = datetime.now().strftime('%Y-%m-%d_%H-%M')
24 |     fname = fname + '_' + dt + '.json'
25 | 
26 |     json.dump(dict(nodes=[[int(n), G.nodes[n]] for n in G.nodes()], \
27 |                    edges=[[int(u), int(v)] for u,v in G.edges()]),\
28 |               open(fname, 'w'), indent=2)
29 | 
30 |     print(f'Graph has been successfully saved in \n{fname}')
31 | 
32 | 
33 | 
34 | def load_graph(fname):
35 |     '''
36 |     Load a networkx graph with nodes and edges information from a json file.
37 | 
38 |     Parameters
39 |     ----------
40 |     fname : string
41 |         Path of the file.
42 | 
43 | 
44 |     Returns
45 |     ----------
46 |     G : networkx graph
47 |         Graph with nodes and edges information.
48 |     '''
49 | 
50 |     G = nx.DiGraph()
51 |     d = json.load(open(fname))
52 |     G.add_nodes_from(d['nodes'])
53 |     G.add_edges_from(d['edges'])
54 | 
55 |     return G
56 | 
57 | 
58 | 
59 | def save_centres_for_graph(centres, fname):
60 |     '''
61 |     Generate a csv file to save centres coordinates and other info.
62 | 
63 |     Parameters
64 |     -----------
65 |     centres: pandas dataframe
66 |         includes x,y,z coords and node_id of each node
67 | 
68 |     fname: string
69 |         output path of the file
70 |     '''
71 |     dt = datetime.now().strftime('%Y-%m-%d_%H-%M')
72 |     fname = fname + '_' + dt + '.csv'
73 | 
74 |     centres.to_csv(fname, index=False)
75 | 
76 |     print(f'\ncentres has been successfully saved in \n{fname}')
77 | 
78 | 
79 | def load_centres_for_graph(fname):
80 |     '''
81 |     Load dataframe of centres from a csv file.
82 | 
83 |     Parameters
84 |     ----------
85 |     fname: string
86 |         csv file which stores x,y,z coords and node_id of each node
87 | 
88 |     Output:
89 |     ----------
90 |     centres: pandas dataframe
91 |     '''
92 |     centres = pd.read_csv(fname)
93 | 
94 |     return centres
95 | 
--------------------------------------------------------------------------------
/treegraph/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | 
4 | import treegraph.IO
5 | from treegraph.third_party.ply_io import *
6 | 
7 | class initialise:
8 | 
9 |     def __init__(self,
10 |                  data_path='/path/to/pointclouds.ply',
11 |                  output_path='/path/to/outputs/',
12 |                  base_idx=None,
13 |                  min_pts=5,
14 |                  downsample=.01,
15 |                  cluster_size=.04,
16 |                  tip_width=None,
17 |                  verbose=False,
18 |                  base_corr=True,
19 |                  dbh_height=1.3,
20 |                  txt_file=True,
21 |                  save_graph=False, columns=['x', 'y', 'z']
22 |                  ):
23 | 
24 |         """
25 |         data_path: pandas dataframe or path to point cloud in .ply or .txt format
26 |             If pandas dataframe, columns ['x', 'y', 'z'] must be present.
27 |         downsample: None or float.
28 |             If value is a float the point cloud will be downsampled before
29 |             running treegraph.
30 |         columns: list default ['x', 'y', 'z']
31 |             If pc is a path to a text file then column names can also
32 |             be passed.
33 | 
34 |         """
35 | 
36 |         self.verbose = verbose
37 |         if self.verbose: import time
38 | 
39 |         self.downsample = downsample
40 | 
41 |         # read in data
42 |         pc = data_path
43 |         if isinstance(pc, pd.DataFrame):
44 |             if np.all([c in pc.columns for c in ['x', 'y', 'z']]):
45 |                 self.pc = pc
46 |             else:
47 |                 raise Exception('pc columns need to be x, y, z, columns found {}'.format(pc.columns))
48 |         elif isinstance(pc, str) and pc.endswith('.ply'):
49 |             self.pc = read_ply(pc)
50 |         elif isinstance(pc, str) and pc.endswith('.txt'):
51 |             sep = ',' if ',' in open(pc, 'r').readline() else ' '
52 |             self.pc = pd.read_csv(pc, sep=sep)
53 |             if len(columns) != len(self.pc.columns):
54 |                 raise Exception('pc read from {} has columns {}, expecting {}'.format(pc, self.pc.columns, columns))
55 |             else:
56 |                 self.pc.columns = columns
57 |         else:
58 |             raise Exception('pc is not a pandas dataframe nor a path to point cloud')
59 | 
60 |         self.data_path = data_path
61 |         self.output_path = output_path
62 |         self.min_pts = min_pts
63 |         self.cluster_size = cluster_size
64 |         self.tip_width = tip_width
65 |         self.base_corr = base_corr
66 |         self.dbh_height = dbh_height
67 |         self.txt_file = txt_file
68 |         self.save_graph = save_graph
69 | 
70 |         # add unique point id
71 |         self.pc.loc[:, 'pid'] = np.arange(len(self.pc))
72 |         self.base_idx = self.pc.loc[self.pc.z.idxmin() if base_idx is None else base_idx].pid
--------------------------------------------------------------------------------
/treegraph/scripts/batch_tree2qsm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | from glob import glob
4 | from treegraph.scripts import tree2qsm
5 | 
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--inputs', '-i', type=str, required=True, help='path to input files')
8 | a = parser.parse_args()
9 | 
10 | # run tree2qsm.py on each inputs combination, one after the other
11 | inputs_f = glob(a.inputs)
12 | for m in range(len(inputs_f)):
13 |     with open(inputs_f[m]) as fr:
14 |         args = yaml.safe_load(fr)
15 |     for key, item in args.items():
16 |         print(f'{key}: {item}')
17 | 
18 |     tree2qsm.run(data_path=args['data_path'],
19 |                  output_path=args['output_path'],
20 |                  base_idx=args['base_idx'],
21 |                  min_pts=args['min_pts'],
22 |                  cluster_size=args['cluster_size'],
23 |                  tip_width=args['tip_width'],
24 |                  verbose=args['verbose'],
25 |                  base_corr=args['base_corr'],
26 |                  dbh_height=args['dbh_height'],
27 |                  txt_file=args['txt_file'],
28 |                  save_graph=args['save_graph'])
--------------------------------------------------------------------------------
/treegraph/scripts/generate_inputs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | from glob import glob
4 | 
5 | parser = argparse.ArgumentParser()
6 | # required input arguments
7 | parser.add_argument('--data', '-d', type=str, required=True,
8 |                     help='path to point clouds')
9 | parser.add_argument('--outputs', '-o', type=str, required=True,
10 |                     help='path to save output files')
11 | # optional input arguments
12 | parser.add_argument('--min_pts', type=int, default=5, required=False,
13 |                     help='min number of points to pass the filtering')
14 | parser.add_argument('--dbh_height', type=float, default=1.3, required=False,
15 |                     help='height of DBH estimate, unit in metre, default 1.3 m')
16 | parser.add_argument('--cluster_size', type=float, default=.04, required=False,
17 |                     help='voxel length for downsampling points when generating the initial graph')
18 | parser.add_argument('--tip_width', type=float, default=None, required=False,
19 |                     help='average branch tip diameter (if known), float, unit in metre')
20 | parser.add_argument('--no-base_corr', dest='base_corr', action='store_false', required=False,
21 |                     help='disable the base fitting correction (enabled by default)')
22 | parser.add_argument('--base_idx', type=int, default=None, required=False,
23 |                     help='index of the base point, used if the base is not the lowest point; if used, base_corr should be False')
24 | parser.add_argument('--no-txt_file', dest='txt_file', action='store_false', required=False,
25 |                     help='disable the text file report (produced by default)')
26 | parser.add_argument('--save_graph', action='store_true', required=False,
27 |                     help='save the initial distance graph (development option, off by default)')
28 | parser.add_argument('--verbose', action='store_true', required=False,
29 |                     help='print progress information')
30 | 
31 | args = parser.parse_args()
32 | 
33 | '''
34 | Purpose: generate input yaml files for all point clouds in the given path with the given parameters.
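Example (illustrative paths; quote the glob so the shell does not expand it):
    python treegraph/scripts/generate_inputs.py -d '../clouds/*.ply' -o '../results/' --cluster_size 0.04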
35 | ''' 36 | 37 | for fp in glob(args.data): 38 | # fp: path to input tree point cloud, str 39 | print(f'clouds path: {fp}') 40 | 41 | inputs = {'data_path':fp, 42 | 'output_path':args.outputs, 43 | 'base_idx':args.base_idx, 44 | 'min_pts':args.min_pts, 45 | 'cluster_size':args.cluster_size, 46 | 'tip_width':args.tip_width, 47 | 'verbose':args.verbose, 48 | 'base_corr':args.base_corr, 49 | 'dbh_height':args.dbh_height, 50 | 'txt_file':args.txt_file, 51 | 'save_graph':args.save_graph} 52 | 53 | treeid = fp.split('/')[-1].split('.')[0] 54 | ofn = f'{treeid}-inputs-cs{args.cluster_size}-tip{args.tip_width}.yml' 55 | 56 | with open(ofn, 'w') as f: 57 | f.write(yaml.safe_dump(inputs)) 58 | print(f'generate input file: {ofn}\n') 59 | -------------------------------------------------------------------------------- /treegraph/scripts/print_results.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from treegraph.IO import * 4 | 5 | read_json(sys.argv[1], pretty_printing=True, attributes=[]) 6 | -------------------------------------------------------------------------------- /treegraph/scripts/tree2qsm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | import treegraph 7 | from treegraph import downsample 8 | from treegraph import distance_from_base 9 | from treegraph import calculate_voxel_length 10 | from treegraph import build_skeleton 11 | from treegraph import build_graph 12 | from treegraph import attribute_centres 13 | from treegraph import distance_from_tip 14 | from treegraph import split_furcation 15 | from treegraph import estimate_radius 16 | from treegraph import taper 17 | from treegraph import fit_cylinders 18 | from treegraph import generate_cylinder_model 19 | from treegraph import IO 20 | from treegraph import common 21 | from treegraph.third_party import ransac_cyl_fit as fc 22 | from datetime import * 23 | 24 | 25 | def run(data_path='/path/to/pointclouds.ply', output_path='../results/TreeID/', 26 | base_idx=None, min_pts=5, cluster_size=.04, tip_width=None, verbose=False, 27 | base_corr=True, dbh_height=1.3, txt_file=True, save_graph=False): 28 | 29 | self = treegraph.initialise(data_path=data_path, 30 | output_path=output_path, 31 | base_idx=base_idx, 32 | min_pts=min_pts, 33 | downsample=.01, 34 | cluster_size=cluster_size, 35 | tip_width=tip_width, 36 | verbose=verbose, 37 | base_corr=base_corr, 38 | dbh_height=dbh_height, 39 | txt_file=txt_file, 40 | save_graph=save_graph) 41 | 42 | ### open a file to store result summary ### 43 | treeid = os.path.splitext(data_path)[0].split('/')[-1] 44 | dt = datetime.now() 45 | print(f'treeid: {treeid}\nProgramme starts at: {dt}') 46 | cs = f'cs{cluster_size}-' 47 | tip = f'tip{tip_width}' 48 | o_f = output_path + treeid + '-' + cs + tip 49 | 50 | if txt_file: 51 | print(f'Outputs are written in a txt file:\n{o_f}.txt') 52 | inputs = f"data_path = {data_path}\noutput_path = {output_path}\n\ 53 | base_idx = {base_idx}\nmin_pts = {min_pts}\ncluster_size = {cluster_size}\n\ 54 | tip_width = {tip_width}\nverbose = {verbose}\nbase_corr = {base_corr}\n\ 55 | dbh_height = {dbh_height}\ntxt_file = {txt_file}\nsave_graph = {save_graph}" 56 | 57 | with open(o_f+'.txt', 'w') as f: 58 | f.write('='*20 + 'Inputs' + '='*20 + f'\n{inputs}\n\n') 59 | f.write('='*20 + 'Processing' + '='*20 +'\n') 60 | f.write(f'treeid: {treeid}\nProgramme starts at: {dt}') 61 | 62 
|
63 |     ### downsample ###
64 |     if self.downsample:
65 |         self.pc, self.base_idx = downsample.run(self.pc,
66 |                                                 self.downsample,
67 |                                                 base_location=self.base_idx,
68 |                                                 delete=True,
69 |                                                 verbose=self.verbose)
70 |     else:
71 |         self.pc = downsample.voxelise(self.pc, length=.01) # voxelise() requires an edge length; .01 m assumed here, matching the downsample default
72 |     if txt_file:
73 |         with open(o_f+'.txt', 'a') as f:
74 |             f.write("\n\n----Downsample----")
75 |             f.write(f"\nPoints after downsampling (vlength = {self.downsample} m): {len(np.unique(self.pc.index))}")
76 | 
77 |     self.pc = self.pc[['x','y','z','pid']]
78 | 
79 |     ### build distance graph and calculate shortest path distance ###
80 |     if base_corr:
81 |         self.pc, self.G, new_base, basal_r = distance_from_base.run(self.pc,
82 |                                                                     base_location=self.base_idx,
83 |                                                                     cluster_size=self.cluster_size,
84 |                                                                     knn=100,
85 |                                                                     verbose=False,
86 |                                                                     base_correction=base_corr)
87 |         maxbin = basal_r * 1.6 # determine maxbin by basal radius
88 |     else:
89 |         self.pc, self.G = distance_from_base.run(self.pc,
90 |                                                  base_location=self.base_idx,
91 |                                                  cluster_size=self.cluster_size,
92 |                                                  knn=100, verbose=False,
93 |                                                  base_correction=base_corr)
94 |         maxbin = 0.35
95 |     if txt_file:
96 |         with open(o_f+'.txt', 'a') as f:
97 |             f.write('\n\n----Build initial graph----')
98 |             f.write(f'\nInitial graph has {len(self.G.nodes)} nodes and {len(self.G.edges)} edges.')
99 | 
100 | 
101 |     ### identify skeleton nodes ###
102 |     # slice input point clouds
103 |     minbin = 0.02 # unit in metre
104 |     self.pc, self.bins = calculate_voxel_length.run(self.pc, exponent=1, maxbin=maxbin, minbin=minbin)
105 |     self.pc = self.pc[~self.pc.distance_from_base.isna()]
106 | 
107 | 
108 |     # identify skeleton nodes by DBSCAN clustering
109 |     self.centres = build_skeleton.run(self, verbose=False)
110 | 
111 |     if txt_file:
112 |         with open(o_f+'.txt', 'a') as f:
113 |             f.write('\n\n----Identify skeleton nodes----')
114 |             f.write(f"\nTotal bin numbers: {len(self.bins)}")
115 |             f.write(f"\nTotal valid slice segments: {len(np.unique(self.pc.slice_id))}")
116 |             f.write('\n\n----Refine skeleton----')
117 | 
118 | 
119 |     ### refine skeleton nodes ###
120 |     self.pc, self.centres = split_furcation.run(self.pc.copy(), self.centres.copy())
121 | 
122 |     ### build skeleton graph ###
123 |     self.G_skel_sf, self.path_dist, self.path_ids = build_graph.run(self.pc, self.centres, verbose=self.verbose)
124 | 
125 |     ### attribute skeleton ###
126 |     self.centres, self.branch_hierarchy = attribute_centres.run(self.centres, self.path_ids,
127 |                                                                 branch_hierarchy=True, verbose=False)
128 | 
129 |     if base_corr:
130 |         # adjust the coords of the 1st slice centre to the coords of new_base_node
131 |         if self.centres.slice_id.min() == 0:
132 |             idx = self.centres[self.centres.slice_id == 0].index.values[0]
133 |             self.centres.loc[idx, ('cx','cy','cz','distance_from_base')] = [new_base[0], new_base[1], new_base[2], 0]
134 |     else:
135 |         self.centres.loc[:, 'distance_from_base'] = self.centres.distance_from_base - self.centres.distance_from_base.min()
136 |         self.pc.loc[:, 'distance_from_base'] = self.pc.distance_from_base - self.pc.distance_from_base.min()
137 | 
138 | 
139 |     if txt_file:
140 |         with open(o_f+'.txt', 'a') as f:
141 |             f.write('\n\n----Rebuild furcation nodes----')
142 |             f.write('\nAttribute of rebuilt skeleton...')
143 |             f.write(f"\nSlices segments: {len(np.unique(self.centres.slice_id))}")
144 |             f.write(f"\nSkeleton points: {len(np.unique(self.centres.node_id))}")
145 | 
146 | 
147 |     ### estimate branch radius ###
148 |     # determine the z-interval for radius estimation (unit: metre)
149 |     trunk = self.centres[self.centres.nbranch == 0]
150 |     dfb_max = trunk.distance_from_base.max() - trunk.distance_from_base.min()
151 |     if dfb_max <= 5:
152 |         dz2 = dfb_max / 5.
153 |         if dfb_max <= 1.5:
154 |             dz2 = .3
155 |     else:
156 |         dz2 = 1.
157 | 
158 |     # estimate radius for individual branches
159 |     self.centres = estimate_radius.run(self.pc, self.centres, self.path_ids,
160 |                                        dz1=.3, dz2=dz2, branch_list=None, plot=False)
161 |     # apply constraints to the smoothed radius to avoid significantly overestimated radii
162 |     self.centres = taper.run(self.centres, self.path_ids, tip_radius=None, est_rad='sm_radius',
163 |                              branch_list=None, plot=False, verbose=False)
164 |     self.centres.loc[:,'sm_radius'] = self.centres.m_radius
165 | 
166 |     # identify outliers in raw estimates based on smoothed radius
167 |     self.centres.loc[:,'zscore'] = (self.centres.sf_radius - self.centres.sm_radius) / self.centres.sm_radius
168 |     threshold = np.nanpercentile(self.centres.zscore, 95)
169 |     outlier_id = self.centres[np.abs(self.centres.zscore) >= threshold].node_id.values
170 | 
171 |     # adjust outliers
172 |     self.centres.loc[self.centres.node_id.isin(outlier_id), 'sf_radius'] = self.centres[self.centres.node_id.isin(outlier_id)].sm_radius
173 | 
174 |     # apply constraints to the filtered raw radius to avoid significantly overestimated radii
175 |     self.centres = taper.run(self.centres, self.path_ids, tip_radius=tip_width, est_rad='sf_radius',
176 |                              branch_list=None, plot=False, verbose=False)
177 | 
178 |     # use point-to-node distance to adjust overestimated twig radii
179 |     # find node_id of the last 4 cylinders of each branch
180 |     twig_nid = []
181 |     for nb in self.centres.nbranch.unique():
182 |         branch = self.centres[self.centres.nbranch == nb]
183 |         ncyls = np.sort(branch.ncyl)
184 |         if len(ncyls) <= 4:
185 |             nids = branch[branch.ncyl.isin(ncyls)].node_id.unique()
186 |         else:
187 |             nids = branch[branch.ncyl.isin(ncyls[-4:])].node_id.unique()
188 |         twig_nid.extend(nids)
189 |     # calculate point-to-node distance for all twig nodes
190 |     for n in twig_nid:
191 |         node_pts = np.array(self.pc[self.pc.node_id == n][['x','y','z']])
192 |         node_cen = np.array(self.centres[self.centres.node_id == n][['cx','cy','cz']])
193 |         dist = np.nanmedian(np.linalg.norm((node_pts - node_cen), axis=1))
194 |         self.centres.loc[(self.centres.node_id == n), 'p2c_dist'] = dist
195 |     self.centres.loc[(self.centres.m_radius > self.centres.p2c_dist), 'm_radius'] = self.centres[self.centres.m_radius > self.centres.p2c_dist].p2c_dist
196 | 
197 |     # delete branches whose sf_radius values are all NaN
198 |     del_nid = np.array([])
199 |     for n in self.centres.nbranch.unique():
200 |         br = self.centres[self.centres.nbranch == n]
201 |         if br.sf_radius.isnull().values.all():
202 |             del_nid = np.append(del_nid, br.node_id.values)
203 |     self.centres = self.centres.loc[~self.centres.node_id.isin(del_nid)]
204 | 
205 |     if txt_file:
206 |         with open(o_f+'.txt', 'a') as f:
207 |             f.write('\n\n----Estimate branch radius----')
208 |             f.write(f"\nEstimated radius:\n{self.centres.m_radius.describe()}")
209 | 
210 |     self.pc = self.pc[self.pc.node_id.isin(self.centres.node_id.values)]
211 |     # re-build skeleton graph
212 |     self.G_skel_sf, self.path_dist, self.path_ids = build_graph.run(self.pc, self.centres, verbose=self.verbose)
213 |     # re-attribute skeleton
214 |     self.centres, self.branch_hierarchy = attribute_centres.run(self.centres, self.path_ids,
215 |                                                                 branch_hierarchy=True, verbose=self.verbose)
216 | 
217 | 
218 |     ### generate cylinder model ###
219 |     generate_cylinder_model.run(self, radius_value='m_radius')
220 | 
221 |     if txt_file:
222 |         with
open(o_f+'.txt', 'a') as f: 223 | f.write('\n\n----Generate cylinder model----') 224 | f.write('\nModelling complete.\n') 225 | 226 | 227 | ### Result Summary ### 228 | ## tree-level statistics 229 | tree = self.cyls[['length', 'vol', 'surface_area']].sum().to_dict() 230 | tree['H_from_clouds'] = round((self.pc.z.max() - self.pc.z.min()), 2) 231 | tree['H_from_qsm'] = round((self.cyls.sz.max() - self.cyls.sz.min()), 2) 232 | tree['N_tip'] = len(self.cyls[self.cyls.is_tip]) 233 | tree['tip_rad_mean'] = self.cyls[self.cyls.is_tip].radius.mean() 234 | tree['tip_rad_std'] = self.cyls[self.cyls.is_tip].radius.std() 235 | 236 | if len(self.centres.loc[self.centres.is_tip]) > 1: 237 | tree['dist_between_tips'] = common.nn(self.centres.loc[self.centres.is_tip][['cx', 'cy', 'cz']].values, N=1).mean() 238 | else: tree['dist_between_tips'] = np.nan 239 | 240 | ## DBH estimation 241 | dbh_clouds, dbh_qsm = common.dbh_est(self, h=dbh_height, verbose=False, plot=False) 242 | tree['DBH_from_clouds'] = dbh_clouds 243 | tree['DBH_from_qsm'] = dbh_qsm 244 | 245 | ## trunk info 246 | trunk_nid = self.centres[self.centres.nbranch == 0].node_id.values 247 | for i in range(len(trunk_nid)-1): 248 | if i == 0: 249 | trunk = self.cyls[(self.cyls.p1 == trunk_nid[i+1]) & (self.cyls.p2 == trunk_nid[i])] 250 | else: 251 | trunk = trunk.append(self.cyls[(self.cyls.p1 == trunk_nid[i+1]) & (self.cyls.p2 == trunk_nid[i])]) 252 | tree['trunk_vol'] = trunk.vol.sum() 253 | tree['trunk_length'] = trunk.length.sum() 254 | 255 | self.tree = pd.DataFrame(data=tree, index=[0]) 256 | 257 | ## programme running time 258 | e_dt = datetime.now() 259 | self.time = (e_dt - dt).total_seconds() 260 | 261 | if txt_file: 262 | with open(o_f+'.txt', 'a') as f: 263 | f.write('\n\n' + '='*20 + 'Statistical summary' + '='*20 ) 264 | f.write(f"\nH from clouds: {tree['H_from_clouds']:.2f} m") 265 | f.write(f"\nH from qsm: {tree['H_from_qsm']:.2f} m") 266 | f.write(f"\nDBH from clouds: {tree['DBH_from_clouds']:.3f} m") 267 | f.write(f"\nDBH from qsm: {tree['DBH_from_qsm']:.3f} m") 268 | f.write(f"\nTot. branch len: {tree['length']:.2f} m") 269 | f.write(f"\nTot. volume: {tree['vol']:.4f} m³ = {tree['vol']*1e3:.1f} L") 270 | f.write(f"\nTot. 
surface area: {tree['surface_area']:.4f} m2") 271 | f.write(f"\nTrunk len: {tree['trunk_length']:.2f} m") 272 | f.write(f"\nTrunk volume: {tree['trunk_vol']:.4f} m³ = {tree['trunk_vol']*1e3:.1f} L") 273 | f.write(f"\nN tips: {tree['N_tip']:.0f}") 274 | f.write(f"\nAvg tip width: {tree['tip_rad_mean']*2:.3f} ± {tree['tip_rad_std']*2:.3f} m") 275 | f.write(f"\nAvg distance between tips: {tree['dist_between_tips']:.3f} m") 276 | f.write(f"\nTotal internodes (furcation nodes + tip nodes): {len(np.unique(self.centres.ninternode))}") 277 | f.write(f"\n2-children furcation nodes: {len(self.centres.loc[self.centres['n_furcation'] == 1])}" ) 278 | f.write(f"\n3-children furcation nodes: {len(self.centres.loc[self.centres['n_furcation'] == 2])}" ) 279 | f.write(f"\n4-children + furcation nodes: {len(self.centres.loc[self.centres['n_furcation'] >= 3])}") 280 | f.write('\n' + '='*40 + '\n') 281 | f.write(f'\nProgramme successfully completed.') 282 | m, s = divmod(self.time, 60) 283 | h, m = divmod(m, 60) 284 | f.write(f'\nTotal running time: {self.time:.0f}s = {h:.0f}h:{m:02.0f}m:{s:02.0f}s\n') 285 | 286 | ### save results ### 287 | # save cyl model into a .ply file 288 | fn_cyls = o_f + '.mesh.ply' 289 | IO.to_ply(self.cyls, fn_cyls) 290 | if txt_file: 291 | with open(o_f+'.txt', 'a') as f: 292 | f.write('\n\n----Save results----') 293 | f.write(f'\nMesh (cylinder) model has been saved in:\n{fn_cyls}\n') 294 | 295 | # save skeleton nodes into a .ply file 296 | fn_centres = o_f + '.centres.ply' 297 | IO.save_centres(self.centres, fn_centres) 298 | if txt_file: 299 | with open(o_f+'.txt', 'a') as f: 300 | f.write(f'\nSkeleton points have been saved in:\n{fn_centres}\n') 301 | 302 | # save all results into a json file 303 | for col in ['idx', 'scalar_intensity', 'pid', 'centre_id', 'zscore']: 304 | if col in self.centres.columns: 305 | self.centres = self.centres.drop(columns=[col]) 306 | if col in self.pc.columns: 307 | self.pc = self.pc.drop(columns=[col]) 308 | self.path_ids = {float(key): [float(i) for i in value] for key, value in self.path_ids.items()} 309 | 310 | fn_json = o_f + '.json' 311 | IO.qsm2json(self, fn_json, name=treeid, graph=save_graph) 312 | if txt_file: 313 | with open(o_f+'.txt', 'a') as f: 314 | f.write(f'\nJson file:\n{fn_json}\n') 315 | 316 | 317 | if __name__ == "__main__": 318 | 319 | parser = argparse.ArgumentParser() 320 | parser.add_argument('--inputs', '-i', type=str, required=True, help='path to inputs file') 321 | a = parser.parse_args() 322 | 323 | with open(a.inputs) as f: 324 | args = yaml.safe_load(f) 325 | for key, item in args.items(): 326 | print(f'{key}: {item}') 327 | 328 | run(data_path=args['data_path'], 329 | output_path=args['output_path'], 330 | base_idx=args['base_idx'], 331 | min_pts=args['min_pts'], 332 | cluster_size=args['cluster_size'], 333 | tip_width=args['tip_width'], 334 | verbose=args['verbose'], 335 | base_corr=args['base_corr'], 336 | dbh_height=args['dbh_height'], 337 | txt_file=args['txt_file'], 338 | save_graph=args['save_graph']) 339 | -------------------------------------------------------------------------------- /treegraph/split_furcation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import struct 4 | from tqdm import tqdm 5 | from sklearn.cluster import KMeans 6 | from scipy.spatial.distance import pdist, squareform 7 | from scipy.sparse import csgraph 8 | from scipy.sparse.linalg import eigsh 9 | from scipy.stats import variation 10 | from numpy 
import linalg as LA 11 | from sklearn.cluster import SpectralClustering 12 | from pandarallel import pandarallel 13 | from treegraph.fit_cylinders import * 14 | from treegraph.build_skeleton import * 15 | from treegraph.build_graph import * 16 | from treegraph.attribute_centres import * 17 | from treegraph.common import * 18 | from treegraph.third_party.point2line import * 19 | from treegraph.third_party.closestDistanceBetweenLines import * 20 | from treegraph.third_party import ransac_cyl_fit 21 | 22 | 23 | def run(pc, centres): 24 | ''' 25 | Refine skeleton nodes using a self-tuning spectral clustering method. 26 | 27 | Inputs: 28 | pc: pd.DataFrame 29 | original points and their attributes 30 | centres: pd.DataFrame 31 | skeleton nodes and their attributes 32 | 33 | Outputs: 34 | pc: pd.DataFrame 35 | new pc after updating clustering 36 | centres: pd.DataFrame 37 | new skeleton nodes after updating clustering 38 | ''' 39 | for i, sid in enumerate(np.sort(centres.slice_id.unique())): 40 | nid = centres[centres.slice_id == sid].node_id.values 41 | if len(nid) > 1: 42 | break 43 | if i == 0: 44 | sid_start = int(np.sort(centres.slice_id.unique())[i]) 45 | else: 46 | sid_start = int(np.sort(centres.slice_id.unique())[i-1]) 47 | 48 | # run spectral clustering on all nodes except for the lower trunk before the first furcation 49 | nodes = centres[centres.slice_id >= sid_start].node_id.unique() 50 | samples = pc[pc.node_id.isin(nodes)] 51 | 52 | # run recursive spectral clustering on each group of points 53 | groupby = samples.groupby('node_id') 54 | sent_back = groupby.apply(stsc_recursive, centres, sid_start=sid_start).values 55 | 56 | 57 | # update pc and centres 58 | new_pc = pd.DataFrame() 59 | new_centres = pd.DataFrame() 60 | nid_max = pc.node_id.max() + 1 61 | 62 | for x in sent_back: 63 | if len(x) == 0: 64 | continue 65 | new_pc = new_pc.append(x[0]) 66 | new_centres = new_centres.append(x[1]) 67 | # remove split nodes 68 | centres = centres.loc[centres.node_id != x[2]] 69 | 70 | new_centres.reset_index(inplace=True) 71 | if len(new_centres) < 1: 72 | return pc, centres 73 | 74 | # re-arrange node_id 75 | MAP = {v : i+nid_max for i, v in enumerate(new_centres.idx.unique())} 76 | new_pc.loc[:, 'node_id'] = new_pc.idx.map(MAP) 77 | new_centres.loc[:, 'node_id'] = new_centres.idx.map(MAP) 78 | 79 | # update centres df 80 | centres = centres.append(new_centres).sort_values(by=['slice_id','node_id']) 81 | centres.reset_index(inplace=True, drop=True) 82 | if 'index' in centres.columns: 83 | centres = centres.drop(columns=['index']) 84 | # update centre_id 85 | for s in centres.slice_id.unique(): 86 | sn = len(centres[centres.slice_id == s]) 87 | centres.loc[centres.slice_id == s, 'centre_id'] = np.arange(0, sn) 88 | # update pc 89 | pc.loc[pc.index.isin(new_pc.index), 'idx'] = new_pc.idx 90 | pc.loc[pc.index.isin(new_pc.index), 'node_id'] = new_pc.node_id 91 | 92 | def update_cid(x): 93 | x.loc[:, 'centre_id'] = centres[centres.node_id == x.node_id.values[0]].centre_id.values[0] 94 | return x 95 | pc = pc[pc.node_id.isin(centres.node_id.unique())] 96 | pc = pc.groupby('node_id').apply(update_cid).reset_index(drop=True) 97 | 98 | return pc, centres 99 | 100 | 101 | 102 | def stsc(clus_pc, centres, nn=None, sid_start=None, plot=False, point_size=8): 103 | 104 | sid = clus_pc.slice_id.unique()[0] # slice_id 105 | nid = clus_pc.node_id.unique()[0] # node_id 106 | cid = clus_pc.centre_id.unique()[0] # centre_id 107 | pts = len(clus_pc) 108 | 109 | if sid_start is None: 110 | for i, s in
enumerate(np.sort(centres.slice_id.unique())): 111 | nids = centres[centres.slice_id == s].node_id.values 112 | if len(nids) > 1: 113 | break 114 | if i == 0: 115 | sid_start = int(np.sort(centres.slice_id.unique())[i]) 116 | else: 117 | sid_start = int(np.sort(centres.slice_id.unique())[i-1]) 118 | 119 | if nn is None: 120 | # determine n_neighbours for spectral clustering 121 | if pts <= 5: return [] 122 | elif pts <= 20: nn = 5 123 | elif pts <= 50: nn = 10 124 | elif pts <= 200: nn = 20 125 | elif pts <= 500: nn = int(pts*0.1) 126 | else: nn = int(pts*0.2) 127 | 128 | # auto determine cluster number from eigen gaps 129 | W = getAffinityMatrix(clus_pc[['x','y','z']], k = nn) 130 | k, eigenvalues, eigenvectors = eigenDecomposition(W, plot=False, topK=3) 131 | 132 | # spectral clustering 133 | spectral = SpectralClustering(n_clusters=k[0], 134 | random_state=0, 135 | affinity='nearest_neighbors', 136 | n_neighbors=nn).fit(clus_pc[['x', 'y', 'z']]) 137 | clus_pc.loc[:, 'klabels'] = spectral.labels_ 138 | c_num = len(np.unique(spectral.labels_)) 139 | 140 | # stop if current cluster cannot be segmented into more sub-clusters 141 | if c_num == 1: 142 | return [] 143 | 144 | # otherwise, try to merge over-segmented sub-clusters that belong to a single branch 145 | n = np.unique(clus_pc.node_id)[0] 146 | orig_cen = centres[centres.node_id == n][['cx','cy','cz']].values[0] 147 | dist_mean = cdist([orig_cen], clus_pc[['x','y','z']]).mean() 148 | dist_std = cdist([orig_cen], clus_pc[['x','y','z']]).std() 149 | cv = dist_std / dist_mean 150 | 151 | tmp = pd.DataFrame() 152 | 153 | if cv > .4: 154 | clus_pc.loc[:, 'adj_k'] = clus_pc.klabels 155 | else: 156 | count = 1 157 | for klabel in np.unique(spectral.labels_): 158 | mind = centres[centres.slice_id == sid_start].distance_from_base.values[0] 159 | maxd = centres.distance_from_base.max() 160 | currd = centres[centres.node_id == nid].distance_from_base.values[0] 161 | ratio = 1- ((maxd - currd) / (maxd - mind)) 162 | p = -4 ** (0.5 * ratio) + 2 163 | threshold = dist_mean * p 164 | 165 | subclus_pc = clus_pc[clus_pc.klabels == klabel] 166 | if len(subclus_pc) <= 100: 167 | new_cen = subclus_pc[['x','y','z']].median().to_numpy() 168 | else: 169 | new_cen = CPC(subclus_pc).x 170 | 171 | d_cc = np.linalg.norm(new_cen - orig_cen) 172 | if d_cc <= threshold: 173 | tmp = tmp.append(clus_pc[clus_pc.klabels == klabel]) 174 | else: 175 | clus_pc.loc[clus_pc.klabels == klabel, 'adj_k'] = int(count) 176 | count += 1 177 | 178 | clus_pc.loc[clus_pc.index.isin(tmp.index), 'adj_k'] = int(count) 179 | 180 | c_num_merge = len(clus_pc[clus_pc.adj_k >= 0].adj_k.unique()) 181 | 182 | if plot: 183 | pc_stsc = clus_pc 184 | fig, axs = plt.subplots(1,4,figsize=(15,4.8)) 185 | ax = axs.flatten() 186 | # spectral clustering results 187 | ax[0].scatter(pc_stsc.y, pc_stsc.z, c=pc_stsc.klabels, cmap='Pastel1', s=point_size) 188 | ax[1].scatter(pc_stsc.x, pc_stsc.y, c=pc_stsc.klabels, cmap='Pastel1', s=point_size) 189 | ax[0].set_xlabel('Y coords (m)') 190 | ax[0].set_ylabel('Z coords (m)') 191 | ax[1].set_xlabel('X coords (m)') 192 | ax[1].set_ylabel('Y coords (m)') 193 | ax[0].set_title('Before merge') 194 | ax[1].set_title('Before merge') 195 | 196 | # after adjustment based on threshold 197 | ax[2].scatter(pc_stsc.y, pc_stsc.z, c=pc_stsc.adj_k, cmap='Pastel1', s=point_size) 198 | ax[3].scatter(pc_stsc.x, pc_stsc.y, c=pc_stsc.adj_k, cmap='Pastel1', s=point_size) 199 | ax[2].set_xlabel('Y coords (m)') 200 | ax[2].set_ylabel('Z coords (m)') 201 | ax[3].set_xlabel('X coords 
(m)') 202 | ax[3].set_ylabel('Y coords (m)') 203 | ax[2].set_title('After merge') 204 | ax[3].set_title('After merge') 205 | 206 | fig.suptitle(f'sid={sid}, nid={nid}, pts={pts}, nn={nn}, c_num={c_num}, c_num_m={c_num_merge}, dist_p2c={dist_mean:.2f}m ± {dist_std:.2f}m, cv={cv:.2f}', 207 | fontsize=15) 208 | fig.tight_layout() 209 | 210 | return clus_pc 211 | 212 | 213 | def stsc_recursive(pc, centres, nn=None, sid_start=None, plot=False): 214 | sid = pc.slice_id.unique()[0] # slice_id 215 | nid = pc.node_id.unique()[0] # node_id 216 | cid = pc.centre_id.unique()[0] # centre_id 217 | 218 | if sid_start is None: 219 | for i, s in enumerate(np.sort(centres.slice_id.unique())): 220 | nids = centres[centres.slice_id == s].node_id.values 221 | if len(nids) > 1: 222 | break 223 | if i == 0: 224 | sid_start = int(np.sort(centres.slice_id.unique())[i]) 225 | else: 226 | sid_start = int(np.sort(centres.slice_id.unique())[i-1]) 227 | 228 | subclus_1 = stsc(pc, centres, nn=nn, sid_start=sid_start, plot=False, point_size=5) 229 | 230 | if len(subclus_1) < 1: 231 | return [] 232 | 233 | kmax = subclus_1.adj_k.max() 234 | 235 | if len(np.sort(subclus_1.adj_k.unique())) < 2: 236 | pc.loc[:, 'adj_k'] = subclus_1.adj_k 237 | else: 238 | for k in np.sort(subclus_1.adj_k.unique()): 239 | subpc = subclus_1[subclus_1.adj_k == k] 240 | if len(subpc) <= 20: 241 | continue 242 | if len(subpc) <= 100: 243 | new_cen = subpc[['x','y','z']].median().to_numpy() 244 | else: 245 | new_cen = CPC(subpc).x 246 | 247 | dist_mean = cdist([new_cen], subpc[['x','y','z']]).mean() 248 | dist_std = cdist([new_cen], subpc[['x','y','z']]).std() 249 | cv = dist_std / dist_mean 250 | 251 | if (cv > .3) or (dist_mean >= .5): 252 | subclus_2 = stsc(subpc, centres, nn=nn, sid_start=sid_start, plot=False, point_size=10) 253 | if len(subclus_2) < 1: 254 | continue 255 | subclus_2.loc[:, 'adj_k'] = subclus_2.adj_k + kmax 256 | kmax += len(np.unique(subclus_2.adj_k)) 257 | pc.loc[pc.index.isin(subclus_2.index), 'adj_k'] = subclus_2.adj_k 258 | 259 | for kk in np.sort(subclus_2.adj_k.unique()): 260 | subpc_ = subclus_2[subclus_2.adj_k == kk] 261 | if len(subpc_) <= 20: 262 | continue 263 | else: 264 | new_cen_ = subpc_[['x','y','z']].median().to_numpy() 265 | dist_mean_ = cdist([new_cen_], subpc_[['x','y','z']]).mean() 266 | 267 | if dist_mean >= .5: 268 | subclus_3 = stsc(subpc_, centres, nn=nn, sid_start=sid_start, plot=False, point_size=10) 269 | if len(subclus_3) < 1: 270 | continue 271 | subclus_3.loc[:, 'adj_k'] = subclus_3.adj_k + kmax 272 | kmax += len(np.unique(subclus_3.adj_k)) 273 | pc.loc[pc.index.isin(subclus_3.index), 'adj_k'] = subclus_3.adj_k 274 | 275 | pc.loc[:, 'adj_k'] = pc.adj_k.apply(lambda x: np.where(np.array(np.unique(pc.adj_k)) == x)[0][0]) 276 | 277 | 278 | if plot: 279 | pc_stsc = pc 280 | point_size = 10 281 | pts = len(pc_stsc) 282 | c_num = len(pc_stsc.adj_k.unique()) 283 | 284 | fig, axs = plt.subplots(1,4,figsize=(15,4.8)) 285 | ax = axs.flatten() 286 | # spectral clustering results 287 | ax[0].scatter(pc_stsc.y, pc_stsc.z, c=pc_stsc.klabels, cmap='Pastel1', s=point_size) 288 | ax[1].scatter(pc_stsc.x, pc_stsc.y, c=pc_stsc.klabels, cmap='Pastel1', s=point_size) 289 | ax[0].set_xlabel('Y coords (m)') 290 | ax[0].set_ylabel('Z coords (m)') 291 | ax[1].set_xlabel('X coords (m)') 292 | ax[1].set_ylabel('Y coords (m)') 293 | ax[0].set_title('Before recursion') 294 | ax[1].set_title('Before recursion') 295 | 296 | # after adjustment based on threshold 297 | ax[2].scatter(pc_stsc.y, pc_stsc.z, c=pc_stsc.adj_k, 
cmap='Pastel1', s=point_size) 298 | ax[3].scatter(pc_stsc.x, pc_stsc.y, c=pc_stsc.adj_k, cmap='Pastel1', s=point_size) 299 | ax[2].set_xlabel('Y coords (m)') 300 | ax[2].set_ylabel('Z coords (m)') 301 | ax[3].set_xlabel('X coords (m)') 302 | ax[3].set_ylabel('Y coords (m)') 303 | ax[2].set_title('After recursion') 304 | ax[3].set_title('After recursion') 305 | 306 | fig.suptitle(f'sid={sid}, nid={nid}, pts={pts}, c_num={c_num}', 307 | fontsize=15) 308 | fig.tight_layout() 309 | 310 | 311 | new_centres = pd.DataFrame() 312 | # loop over adjusted new clusters 313 | for kn in np.unique(pc.adj_k): 314 | dcluster = pc[pc.adj_k == kn] 315 | if len(dcluster) <= 100: 316 | new_cen_coords = dcluster[['x','y','z']].median().to_numpy() 317 | else: 318 | new_cen_coords = CPC(dcluster).x 319 | new_idx = struct.pack('iii', int(sid), int(cid), int(kn)) 320 | 321 | pc.loc[pc.index.isin(dcluster.index), 'centre_id'] = int(kn) 322 | pc.loc[pc.index.isin(dcluster.index), 'idx'] = new_idx 323 | 324 | new_centres = new_centres.append(pd.Series({'slice_id': int(sid), 325 | 'centre_id':int(kn), 326 | 'cx':new_cen_coords[0], 327 | 'cy':new_cen_coords[1], 328 | 'cz':new_cen_coords[2], 329 | 'distance_from_base':dcluster.distance_from_base.mean(), 330 | 'n_points':len(dcluster), 331 | 'idx':new_idx}), ignore_index=True) 332 | pc = pc.drop(columns=['klabels','adj_k']) 333 | if (len(new_centres) != 0) & (isinstance(new_centres, pd.DataFrame)): 334 | return [pc, new_centres, nid] 335 | else: 336 | return [] 337 | 338 | 339 | # https://github.com/ciortanmadalina/high_noise_clustering/blob/master/spectral_clustering.ipynb 340 | def getAffinityMatrix(coordinates, k = 5): 341 | """ 342 | Calculate affinity matrix based on input coordinates matrix and the number 343 | of nearest neighbours. 344 | 345 | Apply local scaling based on the k-th nearest neighbour distance. 346 | References: 347 | https://papers.nips.cc/paper/2619-self-tuning-spectral-clustering.pdf 348 | """ 349 | # calculate Euclidean distance matrix 350 | dists = squareform(pdist(coordinates)) 351 | 352 | # for each row, sort the distances in ascending order and take the distance 353 | # at the k-th position (k-th nearest neighbour) 354 | knn_distances = np.sort(dists, axis=0)[k] 355 | knn_distances = knn_distances[np.newaxis].T 356 | 357 | # calculate sigma_i * sigma_j 358 | local_scale = knn_distances.dot(knn_distances.T) 359 | 360 | affinity_matrix = dists * dists 361 | affinity_matrix = -affinity_matrix / local_scale 362 | # divide square distance matrix by local scale 363 | affinity_matrix[np.where(np.isnan(affinity_matrix))] = 0.0 364 | # apply exponential 365 | affinity_matrix = np.exp(affinity_matrix) 366 | np.fill_diagonal(affinity_matrix, 0) 367 | return affinity_matrix 368 | 369 | 370 | # https://github.com/ciortanmadalina/high_noise_clustering/blob/master/spectral_clustering.ipynb 371 | def eigenDecomposition(A, plot = True, topK=10): 372 | """ 373 | :param A: Affinity matrix 374 | :param plot: plots the sorted eigen values for visual inspection 375 | :return A tuple containing: 376 | - the optimal number of clusters by eigengap heuristic 377 | - all eigenvalues 378 | - all eigenvectors 379 | 380 | This method performs the eigendecomposition of a given affinity matrix, 381 | following the steps recommended in the paper: 382 | 1. Construct the normalized affinity matrix: L = D^(-1/2) A D^(-1/2). 383 | 2. Find the eigenvalues and their associated eigenvectors 384 | 3.
Identify the maximum gap which corresponds to the number of clusters 385 | by eigengap heuristic 386 | 387 | References: 388 | https://papers.nips.cc/paper/2619-self-tuning-spectral-clustering.pdf 389 | http://www.kyb.mpg.de/fileadmin/user_upload/files/publications/attachments/Luxburg07_tutorial_4488%5b0%5d.pdf 390 | """ 391 | L = csgraph.laplacian(A, normed=True) 392 | n_components = A.shape[0] 393 | 394 | # LM parameter : Eigenvalues with largest magnitude (eigs, eigsh), that is, largest eigenvalues in 395 | # the euclidean norm of complex numbers. 396 | # eigenvalues, eigenvectors = eigsh(L, k=n_components, which="LM", sigma=1.0, maxiter=5000) 397 | eigenvalues, eigenvectors = LA.eig(L) 398 | 399 | if plot: 400 | plt.title('Largest eigen values of input matrix') 401 | plt.scatter(np.arange(len(eigenvalues)), eigenvalues, s=2) 402 | plt.grid() 403 | 404 | # Identify the optimal number of clusters as the index corresponding 405 | # to the larger gap between eigen values 406 | index_largest_gap = np.argsort(np.diff(eigenvalues))[::-1][:topK] 407 | nb_clusters = index_largest_gap + 1 408 | 409 | return nb_clusters, eigenvalues, eigenvectors 410 | 411 | 412 | 413 | def intersection(A0, A1, B0, B1, clampA0=True, clampA=True): 414 | 415 | pA, pB, D = closestDistanceBetweenLines(A0, A1, B0, B1, clampA0=clampA, clampA1=clampA) 416 | if np.isnan(D): D = np.inf 417 | return pA, pB, D 418 | -------------------------------------------------------------------------------- /treegraph/taper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | from scipy import optimize 5 | from pandarallel import pandarallel 6 | from tqdm.autonotebook import tqdm 7 | 8 | 9 | def run(centres, path_ids, tip_radius=None, est_rad='sf_radius', 10 | branch_list=None, verbose=False, plot=False): 11 | ''' 12 | Inputs: 13 | tip_radius: None or float, unit in mm 14 | ''' 15 | if 'm_radius' in centres.columns: 16 | centres = centres.drop(columns=['m_radius']) 17 | 18 | if branch_list is None: 19 | samples = centres.copy() 20 | else: 21 | samples = centres[centres.nbranch.isin(branch_list)].copy() 22 | 23 | samples.loc[:, 'm_radius'] = samples[est_rad].copy() * 1e3 # unit in mm 24 | 25 | # estimate mean tip radius from sf_radius if not given 26 | if tip_radius is None: 27 | tip_radius = centres[centres.is_tip].sf_radius.mean(skipna=True) * 1e3 # unit in mm 28 | 29 | # run pandarallel on groups of points 30 | groupby = samples.groupby('nbranch') 31 | pandarallel.initialize(nb_workers=min(24, len(groupby)), progress_bar=verbose) 32 | try: 33 | sent_back = groupby.parallel_apply(radius_correct, centres, path_ids, 34 | tip_radius, est_rad).values 35 | except OverflowError: 36 | print('!pandarallel could not initiate progress bars, running without') 37 | pandarallel.initialize(progress_bar=False) 38 | sent_back = groupby.parallel_apply(radius_correct, centres, path_ids, 39 | tip_radius, est_rad).values 40 | 41 | samples_new = pd.DataFrame() 42 | for x in sent_back: 43 | if len(x[0]) == 0: continue 44 | samples_new = samples_new.append(x[0]) 45 | 46 | centres.loc[centres.node_id.isin(samples_new.node_id.values), 47 | 'm_radius'] = samples_new.m_radius / 1e3 # unit in meter 48 | 49 | return centres 50 | 51 | 52 | def radius_correct(samples, centres, path_ids, tip_radius, est_rad, 53 | plot=False, xlim=None, ylim=None): 54 | branch = samples[['nbranch', 'node_id', 'distance_from_base', 'sf_radius', 'm_radius', 'cv']] 55 | 
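
For reference, here is a minimal sketch of the eigengap heuristic that `eigenDecomposition()` above relies on, using hypothetical eigenvalues (plain numpy, not part of the package): the cluster count is read off the position of the largest gap in the sorted Laplacian spectrum.

```python
import numpy as np

# Toy Laplacian spectrum: three near-zero eigenvalues suggest three clusters.
eigenvalues = np.array([0.0, 0.01, 0.02, 0.85, 0.90, 0.95])

# Largest gaps between consecutive eigenvalues, biggest first (topK=3).
index_largest_gap = np.argsort(np.diff(eigenvalues))[::-1][:3]
nb_clusters = index_largest_gap + 1

print(nb_clusters)  # -> [3 5 4]; the first entry is the suggested cluster count
```
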
nbranch = np.unique(branch.nbranch)[0] 56 | if nbranch != 0: 57 | # ensure child branch radius doesn't exceed twice that of its parent 58 | parent_node = centres[centres.nbranch == nbranch].parent_node.values[0] 59 | if len(centres[centres.node_id == parent_node]) != 0: 60 | max_radius = centres[centres.node_id == parent_node][est_rad].values[0] * 1e2 # unit in cm 61 | branch.loc[branch.m_radius > 2*max_radius, 'm_radius'] = 2*max_radius 62 | 63 | # find stem furcation node 64 | ncyl = centres[centres.ninternode == 0].ncyl.max() 65 | stem_fur_node = centres[(centres.nbranch == 0) & (centres.ncyl == ncyl)].node_id.values[0] 66 | # segments from stem furcation node to branch tip 67 | tip = samples[(samples.nbranch == nbranch)].sort_values('ncyl').node_id.values[-1] 68 | path = path_ids[tip] 69 | fur_id = path.index(stem_fur_node) 70 | path = path[fur_id+1:] 71 | 72 | if est_rad == 'sf_radius': 73 | path = centres.loc[centres.node_id.isin(path)][['node_id', 'distance_from_base', 74 | 'sf_radius', 'cv']] 75 | path = path.loc[~np.isnan(path.sf_radius)] 76 | elif est_rad == 'sm_radius': 77 | path = centres.loc[centres.node_id.isin(path)][['node_id', 'distance_from_base', 78 | 'sm_radius', 'cv']] 79 | path = path.loc[~np.isnan(path.sm_radius)] 80 | if len(path) < 4: 81 | samples = samples.loc[~(samples.nbranch == nbranch)] 82 | centres = centres.loc[~(centres.nbranch == nbranch)] 83 | return [samples] 84 | path.loc[:, 'm_radius'] = path[est_rad] * 1e2 # unit in cm 85 | 86 | # segment path into sections and calculate initial upper bound points 87 | X = np.linspace(path.distance_from_base.min(), path.distance_from_base.max(), 20) 88 | cut = pd.cut(path.distance_from_base, X) 89 | bounds = path.groupby(cut).mean().drop(columns=['node_id', est_rad]) 90 | bounds.distance_from_base = path.groupby(cut).distance_from_base.max() 91 | bounds.set_index(np.arange(len(bounds)), inplace=True) 92 | bounds.loc[:, 'upp'] = bounds.m_radius * 1.2 93 | bounds.loc[bounds.cv>0, 'weight'] = 1. / bounds[bounds.cv>0].cv 94 | bounds.loc[~(bounds.cv>0), 'weight'] = 0. 
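
As a side note, here is a minimal, self-contained sketch (synthetic numbers, hypothetical `path` frame) of the binning step just above: `pd.cut` splits the branch path into equal-width sections along `distance_from_base`, and the per-bin means become the candidate points for the upper-bound curve fit.

```python
import numpy as np
import pandas as pd

# Hypothetical branch path: radius tapering from 5 cm to 0.5 cm over 10 m.
path = pd.DataFrame({'distance_from_base': np.linspace(0, 10, 50),
                     'm_radius': np.linspace(5, 0.5, 50)})

X = np.linspace(path.distance_from_base.min(), path.distance_from_base.max(), 20)
cut = pd.cut(path.distance_from_base, X)        # 19 equal-width sections
bounds = path.groupby(cut).mean()               # per-bin mean radius
bounds['distance_from_base'] = path.groupby(cut).distance_from_base.max()
bounds['upp'] = bounds.m_radius * 1.2           # initial upper bound, 20% above the mean

print(bounds[['distance_from_base', 'm_radius', 'upp']].head())
```
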
95 | idx = bounds.index.max() 96 | bounds.loc[idx, 'upp'] = tip_radius 97 | bounds.loc[idx, 'low'] = tip_radius 98 | bounds = bounds[~np.isnan(bounds.m_radius)] 99 | if len(bounds) < 4: 100 | return [samples] 101 | 102 | # fit an upper bound curve 103 | L, C = 'upp', 'g' 104 | f_power = lambda x, a, b, c: a * np.power(x,b) + c # power 105 | f_exp = lambda x, a, b, c, d: a * np.exp(-b * x + c) + d # exponential 106 | f_para = lambda x, a, b, c: a + b*x + c*np.power(x,2) # parabola 107 | functions = [f_power, f_exp, f_para] 108 | 109 | best_func = None 110 | best_para = None 111 | best_err = np.inf 112 | 113 | for func in functions: 114 | try: 115 | popt, pcov = optimize.curve_fit(func, bounds.distance_from_base, 116 | bounds[L], sigma=bounds.weight, maxfev=1000) 117 | y_pred = func(bounds.distance_from_base, *popt) 118 | rmse = np.sqrt(np.mean((y_pred - bounds[L])**2)) 119 | if rmse < best_err: 120 | best_func, best_para, best_err = func, popt, rmse 121 | except Exception: 122 | pass 123 | if best_func is None: return [samples] # all candidate fits failed 124 | branch.loc[:, L] = best_func(branch.distance_from_base, *best_para) 125 | branch.loc[branch[L] <= 0, 'upp'] = .0015 126 | 127 | # adjust radii that are NaN or fall beyond the upper bound 128 | branch.loc[np.isnan(branch.m_radius), 'm_radius'] = branch.loc[np.isnan(branch.m_radius)].upp 129 | branch.loc[branch.m_radius > branch.upp, 'm_radius'] = branch.loc[branch.m_radius > branch.upp].upp 130 | 131 | # update centres 132 | if nbranch == 0: 133 | samples.loc[samples.node_id.isin(path.node_id.values), 'm_radius'] = branch.m_radius 134 | else: 135 | samples.loc[samples.node_id.isin(branch.node_id.values), 'm_radius'] = branch.m_radius 136 | 137 | if plot: 138 | fig, axs = plt.subplots(1,1,figsize=(8,4)) 139 | ax = [axs] 140 | ax[0].plot(bounds['distance_from_base'], bounds['upp'], 'go', 141 | markerfacecolor='none', markersize=2, 142 | label='upper bound candidate pts') 143 | X = np.linspace(bounds.distance_from_base.min(), bounds.distance_from_base.max(), 20) 144 | ax[0].plot(X, best_func(X, *best_para), 'g--', linewidth=1, label='fitted upper bound') 145 | 146 | # original radius estimates 147 | ax[0].plot(branch['distance_from_base'], branch['sf_radius']*1e2, 'r-', 148 | linewidth=1, alpha=0.5, label='Original estimates') 149 | ax[0].plot(branch['distance_from_base'], branch['sf_radius']*1e2, 'ro', 150 | markerfacecolor='none', markersize=1) 151 | # corrected radius estimates 152 | ax[0].plot(branch['distance_from_base'], branch['m_radius'], 'b-', 153 | linewidth=1, alpha=0.5, label='Corrected estimates') 154 | ax[0].plot(branch['distance_from_base'], branch['m_radius'], 'bo', 155 | markerfacecolor='none', markersize=1) 156 | 157 | ax[0].set_xlabel('Distance from base (m)') 158 | ax[0].set_ylabel('Estimated radius (cm)') 159 | if xlim is not None: 160 | ax[0].set_xlim(xlim[0], xlim[1]) 161 | if ylim is not None: 162 | ax[0].set_ylim(ylim[0], ylim[1]) 163 | ax[0].set_title(f'Branch {nbranch}') 164 | ax[0].legend(loc='upper right') 165 | 166 | fig.tight_layout() 167 | 168 | return [samples] 169 | -------------------------------------------------------------------------------- /treegraph/third_party/available_cpu_count.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | 5 | """ 6 | https://stackoverflow.com/a/1006301/1414831 7 | """ 8 | 9 | def available_cpu_count(): 10 | """ Number of available virtual or physical CPUs on this system, i.e.
11 | user/real as output by time(1) when called with an optimally scaling 12 | userspace-only program""" 13 | 14 | # cpuset 15 | # cpuset may restrict the number of *available* processors 16 | try: 17 | m = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', 18 | open('/proc/self/status').read()) 19 | if m: 20 | res = bin(int(m.group(1).replace(',', ''), 16)).count('1') 21 | if res > 0: 22 | return res 23 | except IOError: 24 | pass 25 | 26 | # Python 2.6+ 27 | try: 28 | import multiprocessing 29 | return multiprocessing.cpu_count() 30 | except (ImportError, NotImplementedError): 31 | pass 32 | 33 | # https://github.com/giampaolo/psutil 34 | try: 35 | import psutil 36 | return psutil.cpu_count() # psutil.NUM_CPUS on old versions 37 | except (ImportError, AttributeError): 38 | pass 39 | 40 | # POSIX 41 | try: 42 | res = int(os.sysconf('SC_NPROCESSORS_ONLN')) 43 | 44 | if res > 0: 45 | return res 46 | except (AttributeError, ValueError): 47 | pass 48 | 49 | # Windows 50 | try: 51 | res = int(os.environ['NUMBER_OF_PROCESSORS']) 52 | 53 | if res > 0: 54 | return res 55 | except (KeyError, ValueError): 56 | pass 57 | 58 | # jython 59 | try: 60 | from java.lang import Runtime 61 | runtime = Runtime.getRuntime() 62 | res = runtime.availableProcessors() 63 | if res > 0: 64 | return res 65 | except ImportError: 66 | pass 67 | 68 | # BSD 69 | try: 70 | sysctl = subprocess.Popen(['sysctl', '-n', 'hw.ncpu'], 71 | stdout=subprocess.PIPE) 72 | scStdout = sysctl.communicate()[0] 73 | res = int(scStdout) 74 | 75 | if res > 0: 76 | return res 77 | except (OSError, ValueError): 78 | pass 79 | 80 | # Linux 81 | try: 82 | res = open('/proc/cpuinfo').read().count('processor\t:') 83 | 84 | if res > 0: 85 | return res 86 | except IOError: 87 | pass 88 | 89 | # Solaris 90 | try: 91 | pseudoDevices = os.listdir('/devices/pseudo/') 92 | res = 0 93 | for pd in pseudoDevices: 94 | if re.match(r'^cpuid@[0-9]+$', pd): 95 | res += 1 96 | 97 | if res > 0: 98 | return res 99 | except OSError: 100 | pass 101 | 102 | # Other UNIXes (heuristic) 103 | try: 104 | try: 105 | dmesg = open('/var/run/dmesg.boot').read() 106 | except IOError: 107 | dmesgProcess = subprocess.Popen(['dmesg'], stdout=subprocess.PIPE) 108 | dmesg = dmesgProcess.communicate()[0] 109 | 110 | res = 0 111 | while '\ncpu' + str(res) + ':' in dmesg: 112 | res += 1 113 | 114 | if res > 0: 115 | return res 116 | except OSError: 117 | pass 118 | 119 | raise Exception('Can not determine number of CPUs on this system') 120 | -------------------------------------------------------------------------------- /treegraph/third_party/closestDistanceBetweenLines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # copied from https://stackoverflow.com/a/18994296/1414831, thank you to Fnord 4 | 5 | def closestDistanceBetweenLines(a0,a1,b0,b1, 6 | clampAll=False,clampA0=False,clampA1=False,clampB0=False,clampB1=False): 7 | 8 | ''' Given two lines defined by numpy.array pairs (a0,a1,b0,b1) 9 | Return the closest points on each segment and their distance 10 | ''' 11 | 12 | # If clampAll=True, set all clamps to True 13 | if clampAll: 14 | clampA0=True 15 | clampA1=True 16 | clampB0=True 17 | clampB1=True 18 | 19 | 20 | # Calculate denomitator 21 | A = a1 - a0 22 | B = b1 - b0 23 | magA = np.linalg.norm(A) 24 | magB = np.linalg.norm(B) 25 | 26 | _A = A / magA 27 | _B = B / magB 28 | 29 | cross = np.cross(_A, _B); 30 | denom = np.linalg.norm(cross)**2 31 | 32 | 33 | # If lines are parallel (denom=0) test if lines overlap. 
34 | # If they don't overlap then there is a closest point solution. 35 | # If they do overlap, there are infinite closest positions, but there is a closest distance 36 | if not denom: 37 | d0 = np.dot(_A,(b0-a0)) 38 | 39 | # Overlap only possible with clamping 40 | if clampA0 or clampA1 or clampB0 or clampB1: 41 | d1 = np.dot(_A,(b1-a0)) 42 | 43 | # Is segment B before A? 44 | if d0 <= 0 >= d1: 45 | if clampA0 and clampB1: 46 | if np.absolute(d0) < np.absolute(d1): 47 | return a0,b0,np.linalg.norm(a0-b0) 48 | return a0,b1,np.linalg.norm(a0-b1) 49 | 50 | 51 | # Is segment B after A? 52 | elif d0 >= magA <= d1: 53 | if clampA1 and clampB0: 54 | if np.absolute(d0) < np.absolute(d1): 55 | return a1,b0,np.linalg.norm(a1-b0) 56 | return a1,b1,np.linalg.norm(a1-b1) 57 | 58 | 59 | # Segments overlap, return distance between parallel segments 60 | return None,None,np.linalg.norm(((d0*_A)+a0)-b0) 61 | 62 | 63 | 64 | # Lines criss-cross: Calculate the projected closest points 65 | t = (b0 - a0); 66 | detA = np.linalg.det([t, _B, cross]) 67 | detB = np.linalg.det([t, _A, cross]) 68 | 69 | t0 = detA/denom; 70 | t1 = detB/denom; 71 | 72 | pA = a0 + (_A * t0) # Projected closest point on segment A 73 | pB = b0 + (_B * t1) # Projected closest point on segment B 74 | 75 | 76 | # Clamp projections 77 | if clampA0 or clampA1 or clampB0 or clampB1: 78 | if clampA0 and t0 < 0: 79 | pA = a0 80 | elif clampA1 and t0 > magA: 81 | pA = a1 82 | 83 | if clampB0 and t1 < 0: 84 | pB = b0 85 | elif clampB1 and t1 > magB: 86 | pB = b1 87 | 88 | # Clamp projection A 89 | if (clampA0 and t0 < 0) or (clampA1 and t0 > magA): 90 | dot = np.dot(_B,(pA-b0)) 91 | if clampB0 and dot < 0: 92 | dot = 0 93 | elif clampB1 and dot > magB: 94 | dot = magB 95 | pB = b0 + (_B * dot) 96 | 97 | # Clamp projection B 98 | if (clampB0 and t1 < 0) or (clampB1 and t1 > magB): 99 | dot = np.dot(_A,(pB-a0)) 100 | if clampA0 and dot < 0: 101 | dot = 0 102 | elif clampA1 and dot > magA: 103 | dot = magA 104 | pA = a0 + (_A * dot) 105 | 106 | 107 | return pA,pB,np.linalg.norm(pA-pB) -------------------------------------------------------------------------------- /treegraph/third_party/cyl2ply.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import numpy as np 4 | import sys 5 | import argparse 6 | import pandas as pd 7 | 8 | from tqdm.autonotebook import tqdm 9 | 10 | # header needed in ply-file 11 | header = ["ply", 12 | "format ascii 1.0", 13 | "comment Author: Cornelis", 14 | "obj_info Generated using Python", 15 | "element vertex 50", 16 | "property float x", 17 | "property float y", 18 | "property float z", 19 | "property float 0", 20 | "element face 96", 21 | "property list uchar int vertex_indices", 22 | "end_header"] 23 | 24 | # faces as needed in ply-file face is expressed by the 4 vertice IDs of the face 25 | faces = [[3, 0, 3, 2], 26 | [3, 0, 4, 3], 27 | [3, 0, 5, 4], 28 | [3, 0, 6, 5], 29 | [3, 0, 7, 6], 30 | [3, 0, 8, 7], 31 | [3, 0, 9, 8], 32 | [3, 0, 10, 9], 33 | [3, 0, 11, 10], 34 | [3, 0, 12, 11], 35 | [3, 0, 13, 12], 36 | [3, 0, 14, 13], 37 | [3, 0, 15, 14], 38 | [3, 0, 16, 15], 39 | [3, 0, 17, 16], 40 | [3, 0, 18, 17], 41 | [3, 0, 19, 18], 42 | [3, 0, 20, 19], 43 | [3, 0, 21, 20], 44 | [3, 0, 22, 21], 45 | [3, 0, 23, 22], 46 | [3, 0, 24, 23], 47 | [3, 0, 25, 24], 48 | [3, 0, 2, 25], 49 | [3, 1, 26, 27], 50 | [3, 1, 27, 28], 51 | [3, 1, 28, 29], 52 | [3, 1, 29, 30], 53 | [3, 1, 30, 31], 54 | [3, 1, 31, 32], 55 | [3, 1, 32, 33], 56 | [3, 1, 33, 34], 57 | [3, 1, 34, 35], 
58 | [3, 1, 35, 36], 59 | [3, 1, 36, 37], 60 | [3, 1, 37, 38], 61 | [3, 1, 38, 39], 62 | [3, 1, 39, 40], 63 | [3, 1, 40, 41], 64 | [3, 1, 41, 42], 65 | [3, 1, 42, 43], 66 | [3, 1, 43, 44], 67 | [3, 1, 44, 45], 68 | [3, 1, 45, 46], 69 | [3, 1, 46, 47], 70 | [3, 1, 47, 48], 71 | [3, 1, 48, 49], 72 | [3, 1, 49, 26], 73 | [3, 2, 3, 26], 74 | [3, 26, 3, 27], 75 | [3, 3, 4, 27], 76 | [3, 27, 4, 28], 77 | [3, 4, 5, 28], 78 | [3, 28, 5, 29], 79 | [3, 5, 6, 29], 80 | [3, 29, 6, 30], 81 | [3, 6, 7, 30], 82 | [3, 30, 7, 31], 83 | [3, 7, 8, 31], 84 | [3, 31, 8, 32], 85 | [3, 8, 9, 32], 86 | [3, 32, 9, 33], 87 | [3, 9, 10, 33], 88 | [3, 33, 10, 34], 89 | [3, 10, 11, 34], 90 | [3, 34, 11, 35], 91 | [3, 11, 12, 35], 92 | [3, 35, 12, 36], 93 | [3, 12, 13, 36], 94 | [3, 36, 13, 37], 95 | [3, 13, 14, 37], 96 | [3, 37, 14, 38], 97 | [3, 14, 15, 38], 98 | [3, 38, 15, 39], 99 | [3, 15, 16, 39], 100 | [3, 39, 16, 40], 101 | [3, 16, 17, 40], 102 | [3, 40, 17, 41], 103 | [3, 17, 18, 41], 104 | [3, 41, 18, 42], 105 | [3, 18, 19, 42], 106 | [3, 42, 19, 43], 107 | [3, 19, 20, 43], 108 | [3, 43, 20, 44], 109 | [3, 20, 21, 44], 110 | [3, 44, 21, 45], 111 | [3, 21, 22, 45], 112 | [3, 45, 22, 46], 113 | [3, 22, 23, 46], 114 | [3, 46, 23, 47], 115 | [3, 23, 24, 47], 116 | [3, 47, 24, 48], 117 | [3, 24, 25, 48], 118 | [3, 48, 25, 49], 119 | [3, 25, 2, 49], 120 | [3, 49, 2, 26]] 121 | 122 | def dot(v1,v2): 123 | '''returns dot-product of two vectors''' 124 | return sum(p*q for p,q in zip(v1,v2)) 125 | 126 | def rotation_matrix(A,angle): 127 | '''returns the rotation matrix''' 128 | c = math.cos(angle) 129 | s = math.sin(angle) 130 | R = [[A[0]**2+(1-A[0]**2)*c, A[0]*A[1]*(1-c)-A[2]*s, A[0]*A[2]*(1-c)+A[1]*s], 131 | [A[0]*A[1]*(1-c)+A[2]*s, A[1]**2+(1-A[1]**2)*c, A[1]*A[2]*(1-c)-A[0]*s], 132 | [A[0]*A[2]*(1-c)-A[1]*s, A[1]*A[2]*(1-c)+A[0]*s, A[2]**2+(1-A[2]**2)*c]] 133 | return R 134 | 135 | def load_cyls(cylfile, args): 136 | 137 | cyls = pd.read_csv(cylfile, 138 | sep='\t', 139 | names=['radius', 'length', 'sx', 'sy', 'sz', 'ax', 'ay', 'az', 'parent', 'extension', 140 | 'branch', 'BranchOrder', 'PositionInBranch', 'added', 'UnmodRadius']) 141 | 142 | if not args.no_branch: 143 | branch = pd.read_csv(cylfile.replace('cyl', 'branch'), 144 | sep='\t', 145 | names=['BOrd', 'BPar', 'BVol', 'BLen', 'BAng', 'BHei', 'BAzi', 'BDia']) 146 | 147 | branch.set_index(branch.index + 1, inplace=True) # otherwise branches are lablelled from 0 148 | branch_ids = branch[(branch.BLen >= args.min_length) & (branch.BDia >= args.min_radius * 2)].index 149 | 150 | if args.random: 151 | 152 | values = cyls[args.field].unique() 153 | MAP = {V:i for i, V in enumerate(np.random.choice(values, size=len(values), replace=False))} 154 | cyls.loc[:, 'COL'] = cyls[args.field].map(MAP) 155 | args.field = 'COL' 156 | 157 | if args.verbose: print(cyls.head()) 158 | 159 | pandas2ply(cyls, args.field, cylfile[:-4] + '.ply') 160 | 161 | def pandas2ply(cyls, field, out): 162 | 163 | n = len(cyls) 164 | n_vertices = 50 * n 165 | n_faces = 96 * n 166 | 167 | tempvertices = [] 168 | tempfaces = [] 169 | 170 | add = 0 171 | for i, (ix, cyl) in tqdm(enumerate(cyls.iterrows()), total=len(cyls)): 172 | 173 | nvertex = 48 # number of vertices, do not change! 
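
A quick sanity check (assumed usage, not part of the file) that `rotation_matrix()` defined above maps the unit z-vector onto an arbitrary cylinder axis, given the same axis-angle pair that `pandas2ply()` computes:

```python
import math
import numpy as np
from treegraph.third_party.cyl2ply import rotation_matrix

axis = np.array([1.0, 1.0, 1.0])
axis /= np.linalg.norm(axis)          # unit target axis for the cylinder
u = np.array([0.0, 0.0, 1.0])         # cylinders are first built along +z

raxis = np.cross(u, axis)             # rotation axis: z crossed with the target axis
raxis /= np.linalg.norm(raxis)
angle = math.acos(np.dot(u, axis))    # angle between z and the target axis

M = np.array(rotation_matrix(raxis, angle))
print(np.allclose(M @ u, axis))       # expect True
```
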
174 | rad = cyl.radius # cylinder radius 175 | l = cyl.length # cylinder length 176 | startp = [cyl.sx, cyl.sy, cyl.sz] # startpoint 177 | axis = [cyl.ax, cyl.ay, cyl.az] # axis relative to startpoint 178 | 179 | # first the cylinder is created without rotation 180 | # starting with center of bottom and top circle 181 | 182 | p1 = [0.0, 0.0, 0.0] 183 | p2 = [0.0, 0.0, l] 184 | 185 | degs = np.deg2rad(np.arange(0, 360, 15)) 186 | ps = [p1,p2] 187 | 188 | # add vertices on the bottom and top circle 189 | for p0 in [p1, p2]: 190 | for deg in degs: 191 | x0 = rad*math.cos(deg)+p0[0] 192 | y0 = rad*math.sin(deg)+p0[1] 193 | z0 = p0[2] 194 | 195 | ps += [[x0,y0,z0]] 196 | 197 | # the following part is adjusted from script in Matlab that does rotation 198 | u = [0,0,1] 199 | raxis = [u[1]*axis[2]-axis[1]*u[2], 200 | u[2]*axis[0]-axis[2]*u[0], 201 | u[0]*axis[1]-axis[0]*u[1]] 202 | 203 | eucl = (axis[0]**2+axis[1]**2+axis[2]**2)**0.5 204 | euclr = (raxis[0]**2+raxis[1]**2+raxis[2]**2)**0.5 205 | 206 | if euclr == 0: euclr = np.nan# not sure why this happens 207 | for i in range(3): 208 | raxis[i] /= euclr 209 | 210 | angle = math.acos(dot(u,axis)/eucl) 211 | 212 | M = rotation_matrix(raxis, angle) 213 | 214 | for i in range(len(ps)): 215 | p = ps[i] 216 | x = p[0]*M[0][0]+p[1]*M[0][1]+p[2]*M[0][2] 217 | y = p[0]*M[1][0]+p[1]*M[1][1]+p[2]*M[1][2] 218 | z = p[0]*M[2][0]+p[1]*M[2][1]+p[2]*M[2][2] 219 | 220 | # add start position 221 | x += startp[0] 222 | y += startp[1] 223 | z += startp[2] 224 | ps[i] = [x,y,z, cyl[field]] 225 | #if np.any(np.isnan([x, y, z])): print(cyl) 226 | 227 | tempvertices += ps 228 | for row in faces: 229 | tempfaces += [[row[0]]+[row[i]+add for i in [1,2,3]]] 230 | 231 | add += 50 232 | 233 | header[4] = "element vertex " + str(n_vertices) 234 | header[8] = "property float {}".format(field) 235 | header[9] = "element face " + str(n_faces) 236 | 237 | with open(out, 'w') as theFile: 238 | for i in header: 239 | theFile.write(i+'\n') 240 | for p in tempvertices: 241 | theFile.write(str(p[0])+' '+str(p[1])+' '+str(p[2])+' '+str(p[3])+'\n') 242 | for f in tempfaces: 243 | #print f 244 | 245 | 246 | theFile.write(str(f[0])+' '+str(f[1])+' '+str(f[2])+' '+str(f[3])+'\n') 247 | 248 | if __name__ == '__main__': 249 | 250 | parser = argparse.ArgumentParser() 251 | parser.add_argument('-c','--cyl', nargs='*', help='list of *cyl.txt files') 252 | parser.add_argument('-f', '--field', default='branch', help='field with which to colour cylinders by') 253 | parser.add_argument('-rc', '--random', default=False, action='store_true', help='randomise colours') 254 | parser.add_argument('-r', '--min_radius', default=0, type=float, help='filter branhces by minimum radius') 255 | parser.add_argument('-l', '--min_length', default=0, type=float, help='filter branches by minimum length') 256 | parser.add_argument('--no_branch', action='store_true', help='use if no corresponding branch file is available') 257 | parser.add_argument('--verbose', action='store_true', help='print some stuff to screen') 258 | args = parser.parse_args() 259 | 260 | for x,line in enumerate(args.cyl): # loops through treelistfile 261 | name = line.split()[0] 262 | load_cyls(name, args) 263 | -------------------------------------------------------------------------------- /treegraph/third_party/cylinder_fitting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2017, Xingjie Pan 3 | All rights reserved. 
4 | 5 | DISCLAIMER: 6 | This is a slightly modified version of the code from cylinder_fitting.py 7 | more specifically fitting.py and geometry.py) authored by Xingjie Pan (2017) 8 | and available at: https://github.com/xingjiepan/cylinder_fitting 9 | """ 10 | 11 | 12 | import numpy as np 13 | from scipy.optimize import minimize 14 | 15 | 16 | def direction(theta, phi): 17 | '''Return the direction vector of a cylinder defined 18 | by the spherical coordinates theta and phi. 19 | ''' 20 | return np.array([np.cos(phi) * np.sin(theta), np.sin(phi) * np.sin(theta), 21 | np.cos(theta)]) 22 | 23 | def projection_matrix(w): 24 | '''Return the projection matrix of a direction w.''' 25 | return np.identity(3) - np.dot(np.reshape(w, (3,1)), np.reshape(w, (1, 3))) 26 | 27 | def skew_matrix(w): 28 | '''Return the skew matrix of a direction w.''' 29 | return np.array([[0, -w[2], w[1]], 30 | [w[2], 0, -w[0]], 31 | [-w[1], w[0], 0]]) 32 | 33 | def calc_A(Ys): 34 | '''Return the matrix A from a list of Y vectors.''' 35 | return sum(np.dot(np.reshape(Y, (3,1)), np.reshape(Y, (1, 3))) 36 | for Y in Ys) 37 | 38 | def calc_A_hat(A, S): 39 | '''Return the A_hat matrix of A given the skew matrix S''' 40 | return np.dot(S, np.dot(A, np.transpose(S))) 41 | 42 | def preprocess_data(Xs_raw): 43 | '''Translate the center of mass (COM) of the data to the origin. 44 | Return the prossed data and the shift of the COM''' 45 | n = len(Xs_raw) 46 | Xs_raw_mean = sum(X for X in Xs_raw) / n 47 | 48 | return [X - Xs_raw_mean for X in Xs_raw], Xs_raw_mean 49 | 50 | def G(w, Xs): 51 | '''Calculate the G function given a cylinder direction w and a 52 | list of data points Xs to be fitted.''' 53 | n = len(Xs) 54 | P = projection_matrix(w) 55 | Ys = [np.dot(P, X) for X in Xs] 56 | A = calc_A(Ys) 57 | A_hat = calc_A_hat(A, skew_matrix(w)) 58 | 59 | 60 | u = sum(np.dot(Y, Y) for Y in Ys) / n 61 | v = np.dot(A_hat, sum(np.dot(Y, Y) * Y for Y in Ys)) / np.trace(np.dot(A_hat, A)) 62 | 63 | return sum((np.dot(Y, Y) - u - 2 * np.dot(Y, v)) ** 2 for Y in Ys) 64 | 65 | def C(w, Xs): 66 | '''Calculate the cylinder center given the cylinder direction and 67 | a list of data points. 68 | ''' 69 | 70 | P = projection_matrix(w) 71 | Ys = [np.dot(P, X) for X in Xs] 72 | A = calc_A(Ys) 73 | A_hat = calc_A_hat(A, skew_matrix(w)) 74 | 75 | return (np.dot(A_hat, sum(np.dot(Y, Y) * Y for Y in Ys)) / 76 | np.trace(np.dot(A_hat, A))) 77 | 78 | def r(w, Xs): 79 | '''Calculate the radius given the cylinder direction and a list 80 | of data points. 81 | ''' 82 | n = len(Xs) 83 | P = projection_matrix(w) 84 | c = C(w, Xs) 85 | 86 | return np.sqrt(sum(np.dot(c - X, np.dot(P, c - X)) for X in Xs) / n) 87 | 88 | def fit(data, guess_angles=None): 89 | '''Fit a list of data points to a cylinder surface. The algorithm 90 | implemented here is from David Eberly's paper "Fitting 3D Data with a 91 | Cylinder" from 92 | https://www.geometrictools.com/Documentation/CylinderFitting.pdf 93 | Arguments: 94 | data - A list of 3D data points to be fitted. 
95 | guess_angles[0] - Guess of the theta angle of the axis direction 96 | guess_angles[1] - Guess of the phi angle of the axis direction 97 | 98 | Return: 99 | Direction of the cylinder axis 100 | A point on the cylinder axis 101 | Radius of the cylinder 102 | Fitting error (G function) 103 | ''' 104 | Xs, t = preprocess_data(data) 105 | 106 | # Set the start points 107 | 108 | start_points = [(0, 0)] #, (np.pi / 2, 0), (np.pi / 2, np.pi / 2)] 109 | if guess_angles: 110 | start_points = guess_angles 111 | 112 | # Fit the cylinder from different start points 113 | 114 | best_fit = None 115 | best_score = float('inf') 116 | 117 | # for sp in start_points: 118 | method, tol = 'Powell', 1e-6 119 | # print(method, tol) 120 | fitted = minimize(lambda x : G(direction(x[0], x[1]), Xs), 121 | (0,0), method=method, tol=tol) 122 | 123 | # if fitted.fun < best_score: 124 | best_score = fitted.fun 125 | best_fit = fitted 126 | 127 | w = direction(best_fit.x[0], best_fit.x[1]) 128 | 129 | return w, C(w, Xs) + t, r(w, Xs), best_fit.fun 130 | 131 | 132 | def normalize(v): 133 | '''Normalize a vector based on its 2 norm.''' 134 | if 0 == np.linalg.norm(v): 135 | return v 136 | return v / np.linalg.norm(v) 137 | 138 | def rotation_matrix_from_axis_and_angle(u, theta): 139 | '''Calculate a rotation matrix from an axis and an angle.''' 140 | 141 | x = u[0] 142 | y = u[1] 143 | z = u[2] 144 | s = np.sin(theta) 145 | c = np.cos(theta) 146 | 147 | return np.array([[c + x**2 * (1 - c), x * y * (1 - c) - z * s, x * z * (1 - c) + y * s], 148 | [y * x * (1 - c) + z * s, c + y**2 * (1 - c), y * z * (1 - c) - x * s ], 149 | [z * x * (1 - c) - y * s, z * y * (1 - c) + x * s, c + z**2 * (1 - c) ]]) 150 | 151 | def point_line_distance(p, l_p, l_v): 152 | '''Calculate the distance between a point and a line defined 153 | by a point and a direction vector. 
154 | ''' 155 | l_v = normalize(l_v) 156 | u = p - l_p 157 | return np.linalg.norm(u - np.dot(u, l_v) * l_v) 158 | -------------------------------------------------------------------------------- /treegraph/third_party/ply_io.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import sys 4 | 5 | def read_ply(fp, newline=None): 6 | 7 | line = open(fp, encoding='ISO-8859-1').readline() 8 | newline = '\n' if line == 'ply\n' else None 9 | 10 | return read_ply_(fp, newline) 11 | 12 | def read_ply_(fp, newline): 13 | 14 | open_file = open(fp, 15 | encoding='ISO-8859-1', 16 | newline=newline) 17 | 18 | with open_file as ply: 19 | 20 | length = 0 21 | prop = [] 22 | dtype_map = {'uint16':'uint16', 'uint8':'uint8', 'double':'d', 'float64':'f8', 23 | 'float32':'f4', 'float': 'f4', 'uchar': 'B', 'int':'i'} 24 | dtype = [] 25 | fmt = 'binary' 26 | 27 | for i, line in enumerate(ply.readlines()): 28 | length += len(line) 29 | if i == 1: 30 | if 'ascii' in line: 31 | fmt = 'ascii' 32 | if 'element vertex' in line: N = int(line.split()[2]) 33 | if 'property' in line: 34 | dtype.append(dtype_map[line.split()[1]]) 35 | prop.append(line.split()[2]) 36 | if 'element face' in line: 37 | raise Exception('.ply appears to be a mesh') 38 | if 'end_header' in line: break 39 | 40 | ply.seek(length) 41 | 42 | if fmt == 'binary': 43 | arr = np.fromfile(ply, dtype=','.join(dtype)) 44 | else: 45 | arr = np.loadtxt(ply) 46 | df = pd.DataFrame(data=arr) 47 | df.columns = prop 48 | 49 | return df 50 | 51 | def write_ply(output_name, pc, comments=[]): 52 | 53 | cols = ['x', 'y', 'z'] 54 | pc[['x', 'y', 'z']] = pc[['x', 'y', 'z']].astype('f8') 55 | 56 | with open(output_name, 'w') as ply: 57 | 58 | ply.write("ply\n") 59 | ply.write('format binary_little_endian 1.0\n') 60 | ply.write("comment Author: Phil Wilkes\n") 61 | for comment in comments: 62 | ply.write("comment {}\n".format(comment)) 63 | ply.write("obj_info generated with pcd2ply.py\n") 64 | ply.write("element vertex {}\n".format(len(pc))) 65 | ply.write("property float64 x\n") 66 | ply.write("property float64 y\n") 67 | ply.write("property float64 z\n") 68 | if 'red' in pc.columns: 69 | cols += ['red', 'green', 'blue'] 70 | pc[['red', 'green', 'blue']] = pc[['red', 'green', 'blue']].astype('i') 71 | ply.write("property int red\n") 72 | ply.write("property int green\n") 73 | ply.write("property int blue\n") 74 | for col in pc.columns: 75 | if col in cols: continue 76 | try: 77 | pc[col] = pc[col].astype('f8') 78 | ply.write("property float64 {}\n".format(col)) 79 | cols += [col] 80 | except: 81 | pass 82 | ply.write("end_header\n") 83 | 84 | with open(output_name, 'ab') as ply: 85 | ply.write(pc[cols].to_records(index=False).tobytes()) 86 | 87 | if __name__ == '__main__': 88 | 89 | import sys 90 | print(read_ply(sys.argv[1]).head()) -------------------------------------------------------------------------------- /treegraph/third_party/point2line.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # copied from https://stackoverflow.com/a/50728570/1414831, thank you to Hans Musgrave 4 | 5 | def t(p, q, r): 6 | x = p-q 7 | return np.dot(r-q, x)/np.dot(x, x) 8 | 9 | def d(p, q, rs): 10 | x = p - q 11 | return np.linalg.norm(np.outer(np.dot(rs-q, x)/np.dot(x, x), x)+q-rs, axis=1) -------------------------------------------------------------------------------- /treegraph/third_party/ransac_cyl_fit.py: 
-------------------------------------------------------------------------------- 1 | # This is from https://github.com/philwilkes/FSCT/blob/main/fsct/fit_cylinders.py 2 | 3 | from sklearn.decomposition import PCA 4 | from scipy import optimize 5 | from scipy.spatial.transform import Rotation 6 | from scipy.stats import variation 7 | import numpy as np 8 | 9 | from matplotlib.patches import Circle 10 | import matplotlib.pyplot as plt 11 | 12 | from tqdm.auto import tqdm 13 | 14 | 15 | def other_cylinder_fit2(xyz, xm=0, ym=0, xr=0, yr=0, r=1): 16 | 17 | from scipy.optimize import leastsq 18 | 19 | """ 20 | https://stackoverflow.com/a/44164662/1414831 21 | 22 | This is a fitting for a vertical cylinder fitting 23 | Reference: 24 | http://www.int-arch-photogramm-remote-sens-spatial-inf-sci.net/XXXIX-B5/169/2012/isprsarchives-XXXIX-B5-169-2012.pdf 25 | xyz is a matrix contain at least 5 rows, and each row stores x y z of a cylindrical surface 26 | p is initial values of the parameter; 27 | p[0] = Xc, x coordinate of the cylinder centre 28 | P[1] = Yc, y coordinate of the cylinder centre 29 | P[2] = alpha, rotation angle (radian) about the x-axis 30 | P[3] = beta, rotation angle (radian) about the y-axis 31 | P[4] = r, radius of the cylinder 32 | th, threshold for the convergence of the least squares 33 | """ 34 | 35 | x = xyz.x 36 | y = xyz.y 37 | z = xyz.z 38 | 39 | p = np.array([xm, ym, xr, yr, r]) 40 | 41 | fitfunc = lambda p, x, y, z: (- np.cos(p[3])*(p[0] - x) - z*np.cos(p[2])*np.sin(p[3]) - np.sin(p[2])*np.sin(p[3])*(p[1] - y))**2 + (z*np.sin(p[2]) - np.cos(p[2])*(p[1] - y))**2 #fit function 42 | errfunc = lambda p, x, y, z: fitfunc(p, x, y, z) - p[4]**2 #error function 43 | 44 | est_p, success = leastsq(errfunc, p, args=(x, y, z), maxfev=1000) 45 | 46 | return est_p 47 | 48 | def RANSACcylinderFitting4(xyz_, iterations=50, plot=False): 49 | 50 | if plot: 51 | ax = plt.subplot(111) 52 | 53 | bestFit, bestErr = None, np.inf 54 | xyz_mean = xyz_.mean(axis=0) 55 | xyz_ -= xyz_mean 56 | 57 | # for i in tqdm(range(iterations), total=iterations, display=plot): 58 | for i in range(iterations): 59 | 60 | xyz = xyz_.copy() 61 | 62 | # prepare sample 63 | sample = xyz.sample(n=20) 64 | # sample = xyz.sample(n=max(10, int(len(xyz)*.2))) 65 | xyz = xyz.loc[~xyz.index.isin(sample.index)] 66 | 67 | x, y, a, b, radius = other_cylinder_fit2(sample, 0, 0, 0, 0, 0) 68 | centre = (x, y) 69 | if not np.all(np.isclose(centre, 0, atol=radius*1.05)): continue 70 | 71 | MX = Rotation.from_euler('xy', [a, b]).inv() 72 | xyz[['x', 'y', 'z']] = MX.apply(xyz) 73 | xyz.loc[:, 'error'] = np.linalg.norm(xyz[['x', 'y']] - centre, axis=1) / radius 74 | idx = xyz.loc[xyz.error.between(.8, 1.2)].index # 40% of radius is prob quite large 75 | 76 | # select points which best fit model from original dataset 77 | alsoInliers = xyz_.loc[idx].copy() 78 | if len(alsoInliers) < len(xyz_) * .2: continue # skip if no enough points chosen 79 | 80 | # refit model using new params 81 | x, y, a, b, radius = other_cylinder_fit2(alsoInliers, x, y, a, b, radius) 82 | centre = [x, y] 83 | if not np.all(np.isclose(centre, 0, atol=radius*1.05)): continue 84 | 85 | MX = Rotation.from_euler('xy', [a, b]).inv() 86 | alsoInliers[['x', 'y', 'z']] = MX.apply(alsoInliers[['x', 'y', 'z']]) 87 | # calculate error for "best" subset 88 | alsoInliers.loc[:, 'error'] = np.linalg.norm(alsoInliers[['x', 'y']] - centre, axis=1) / radius 89 | 90 | if variation(alsoInliers.error) < bestErr: 91 | 92 | # for testing uncomment 93 | c = Circle(centre, 
radius=radius, facecolor='none', edgecolor='g') 94 | 95 | bestFit = [radius, centre, c, alsoInliers, MX] 96 | bestErr = variation(alsoInliers.error) 97 | 98 | if bestFit == None: 99 | # usually caused by low number of ransac iterations 100 | return np.nan, xyz[['x', 'y', 'z']].mean(axis=0).values, np.inf, len(xyz_) 101 | 102 | radius, centre, c, alsoInliers, MX = bestFit 103 | centre[0] += xyz_mean.x 104 | centre[1] += xyz_mean.y 105 | centre = centre + [xyz_mean.z] 106 | 107 | # for testing uncomment 108 | if plot: 109 | 110 | radius, Centre, c, alsoInliers, MX = bestFit 111 | 112 | xyz_[['x', 'y', 'z']] = MX.apply(xyz_) 113 | xyz_ += xyz_mean 114 | ax.scatter(xyz_.x, xyz_.y, s=1, c='grey') 115 | 116 | alsoInliers[['x', 'y', 'z']] += xyz_mean 117 | cbar = ax.scatter(alsoInliers.x, alsoInliers.y, s=10, c=alsoInliers.error) 118 | plt.colorbar(cbar) 119 | 120 | ax.scatter(Centre[0], Centre[1], marker='+', s=100, c='r') 121 | ax.add_patch(c) 122 | 123 | return [radius, centre, bestErr, len(xyz_)] 124 | 125 | def NotRANSAC(xyz): 126 | 127 | try: 128 | xyz = xyz[['x', 'y', 'z']] 129 | pca = PCA(n_components=3, svd_solver='auto').fit(xyz) 130 | xyz[['x', 'y', 'z']] = pca.transform(xyz) 131 | radius, centre = other_cylinder_fit2(xyz) 132 | 133 | if xyz.z.min() - radius < centre[0] < xyz.z.max() + radius or \ 134 | xyz.y.min() - radius < centre[1] < xyz.y.max() + radius: 135 | centre = np.hstack([xyz.x.mean(), centre]) 136 | else: 137 | centre = xyz.mean().values 138 | 139 | centre = pca.inverse_transform(centre) 140 | except: 141 | radius, centre = np.nan, xyz[['x', 'y', 'z']].mean(axis=0).values 142 | 143 | return [radius, centre, np.inf, len(xyz)] 144 | 145 | def RANSAC_helper(xyz, ransac_iterations, plot=False): 146 | 147 | # try: 148 | if len(xyz) == 0: # don't think this is required but.... 149 | cylinder = [np.nan, np.array([np.inf, np.inf, np.inf]), np.inf, len(xyz)] 150 | elif len(xyz) < 10: 151 | cylinder = [np.nan, xyz[['x', 'y', 'z']].mean(axis=0).values, np.inf, len(xyz)] 152 | elif len(xyz) < 50: 153 | cylinder = NotRANSAC(xyz) 154 | else: 155 | cylinder = RANSACcylinderFitting4(xyz[['x', 'y', 'z']], iterations=ransac_iterations, plot=plot) 156 | # if cylinder == None: # again not sure if this is necessary... 157 | # cylinder = [np.nan, xyz[['x', 'y', 'z']].mean(axis=0)] 158 | 159 | # except: 160 | # cylinder = [np.nan, xyz[['x', 'y', 'z']].mean(axis=0), np.inf, np.inf] 161 | 162 | return cylinder -------------------------------------------------------------------------------- /treegraph/third_party/shortpath.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2019, Matheus Boni Vicari, pc2graph 2 | # All rights reserved. 3 | # 4 | # 5 | # This program is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation, either version 3 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with this program. If not, see . 
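
Before the graph utilities below, a hedged usage sketch for `RANSAC_helper()` from `ransac_cyl_fit.py` above, fitting synthetic points sampled on a vertical cylinder surface (the data and the expected output are illustrative only; the fit is stochastic):

```python
import numpy as np
import pandas as pd
from treegraph.third_party.ransac_cyl_fit import RANSAC_helper

rng = np.random.default_rng(0)
theta = rng.uniform(0, 2 * np.pi, 200)
r_true = 0.05                                    # hypothetical 5 cm radius

# Noisy points on a vertical cylinder of radius r_true.
pts = pd.DataFrame({'x': r_true * np.cos(theta) + rng.normal(0, 0.002, 200),
                    'y': r_true * np.sin(theta) + rng.normal(0, 0.002, 200),
                    'z': rng.uniform(0, 1.0, 200)})

radius, centre, err, n = RANSAC_helper(pts, ransac_iterations=50)
print(radius, centre, err, n)                    # |radius| should land near 0.05
```
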
17 | 18 | __author__ = "Matheus Boni Vicari" 19 | __copyright__ = "Copyright 2018-2019" 20 | __credits__ = ["Matheus Boni Vicari"] 21 | __license__ = "GPL3" 22 | __version__ = "1.0.1" 23 | __maintainer__ = "Matheus Boni Vicari" 24 | __email__ = "matheus.boni.vicari@gmail.com" 25 | __status__ = "Development" 26 | 27 | import networkx as nx 28 | import numpy as np 29 | from sklearn.neighbors import NearestNeighbors 30 | 31 | 32 | def array_to_graph(arr, base_id, kpairs, knn, nbrs_threshold, 33 | nbrs_threshold_step, graph_threshold=np.inf, 34 | return_step=False): 35 | 36 | """ 37 | Converts a numpy.array of points coordinates into a Weighted BiDirectional 38 | NetworkX Graph. 39 | This funcions uses a NearestNeighbor search to determine points adajency. 40 | The NNsearch results are used to select pairs of points (or nodes) that 41 | have a common edge. 42 | Parameters 43 | ---------- 44 | arr : array 45 | n-dimensional array of points. 46 | base_id : 47 | index of base id (root) in the graph. 48 | kpairs : int 49 | number of points around each point in arr to select in order to build 50 | edges. 51 | knn : int 52 | Number of neighbors to search around each point in the neighborhood 53 | phase. The higher the better (careful, it's memory intensive). 54 | nbrs_threshold : float 55 | Maximum valid distance between neighbors points. 56 | nbrs_threshold_step : float 57 | Distance increment used in the final phase of edges generation. It's 58 | used to make sure that in the end, every point in arr will be 59 | translated to nodes in the graph. 60 | graph_threshold : float 61 | Maximum distance between pairs of nodes (edge distance) accepted in 62 | the graph generation. 63 | return_step : bool 64 | Option to select if function should output the step register, which 65 | can be used to debug the creationg of graph 'G'. 66 | Returns 67 | ------- 68 | G : networkx graph 69 | Graph containing all points in 'arr' as nodes. 70 | step_register : array 71 | 1D array with the same number of entries as 'arr'. Stores the step 72 | number of which each point in 'arr' was added to 'G'. 73 | 74 | """ 75 | 76 | # Initializing graph. 77 | G = nx.Graph() 78 | 79 | # Generating array of all indices from 'arr' and all indices to process 80 | # 'idx'. 81 | idx_base = np.arange(arr.shape[0], dtype=int) 82 | idx = np.arange(arr.shape[0], dtype=int) 83 | 84 | # Initializing NearestNeighbors search and searching for all 'knn' 85 | # neighboring points arround each point in 'arr'. 86 | nbrs = NearestNeighbors(n_neighbors=knn, metric='euclidean', 87 | leaf_size=15, n_jobs=-1).fit(arr[['x','y','z']]) 88 | distances, indices = nbrs.kneighbors(arr[['x','y','z']]) 89 | indices = indices.astype(int) 90 | 91 | # Initializing variables for current ids being processed (current_idx) 92 | # and all ids already processed (processed_idx). 93 | current_idx = [base_id] 94 | processed_idx = [base_id] 95 | 96 | # Setting up the register of at which step each point was added to the 97 | # graph. 98 | step_register = np.full(arr.shape[0], np.nan) 99 | current_step = 0 100 | step_register[base_id] = current_step 101 | 102 | # Looping while there are still indices (idx) left to process. 103 | # while idx.shape[0] > 0: 104 | # wx updated to avoid infinite loop 105 | while (idx.shape[0] > 0) & (nbrs_threshold < .5): 106 | 107 | # Increasing a single step count. 108 | current_step += 1 109 | 110 | # If current_idx is a list containing several indices. 
    # Looping while there are still indices (idx) left to process.
    # Originally 'while idx.shape[0] > 0:'; the threshold cap was added (wx)
    # to avoid an infinite loop when the remaining points can never be
    # reached.
    while (idx.shape[0] > 0) & (nbrs_threshold < .5):

        # Increasing the step count.
        current_step += 1

        # If current_idx contains indices to grow from (frontier phase).
        if len(current_idx) > 0:

            # Selecting the NearestNeighbors indices and distances for the
            # current indices being processed.
            nn = indices[current_idx]
            dd = distances[current_idx]

            # Masking out indices already contained in processed_idx.
            mask1 = np.in1d(nn, processed_idx, invert=True).reshape(nn.shape)

            # Initializing a temporary list of nearest neighbors. This list
            # is later used to accumulate the points that will be added to
            # the processed points list.
            nntemp = []

            # Looping over each current index's set of nn points and
            # selecting up to kpairs+1 points that haven't been
            # added/processed yet (mask1).
            for i, (n, d, g) in enumerate(zip(nn, dd, current_idx)):
                nn_idx = n[mask1[i]][0:kpairs+1]
                dd_idx = d[mask1[i]][0:kpairs+1]
                nntemp.append(nn_idx)

                # wx adds: attach attributes to the node.
                G.add_node(g, pos=[float(arr.x[g]), float(arr.y[g]), float(arr.z[g])])  # coordinates
                if 'pid' in arr.columns:
                    G.add_node(g, pid=int(arr.pid[g]))  # index in the original point cloud

                # Adding the currently selected points as nodes of graph G.
                add_nodes(G, g, nn_idx, dd_idx, graph_threshold)

            # Obtaining a unique array of the points currently being
            # processed.
            current_idx = np.unique([t2 for t1 in nntemp for t2 in t1])

        # If current_idx is empty (bridging phase).
        else:

            # Getting the NearestNeighbors indices and distances for all
            # indices that remain to be processed.
            idx2 = indices[idx]
            dist2 = distances[idx]

            # Masking indices in idx2 that have already been processed. The
            # idea is to connect the remaining points to existing graph
            # nodes.
            mask1 = np.in1d(idx2, processed_idx).reshape(idx2.shape)
            # Masking neighboring points that are within threshold distance.
            mask2 = dist2 < nbrs_threshold
            # mask1 AND mask2 masks only the indices that are both part of
            # the graph and within threshold distance.
            mask = np.logical_and(mask1, mask2)

            # Getting a unique array of the indices that match the criteria
            # from mask1 and mask2.
            temp_idx = np.unique(np.where(mask)[0])
            # Assigning the remaining indices (idx) matched in temp_idx to
            # current_idx.
            current_idx = idx[temp_idx]

            # Selecting the NearestNeighbors indices and distances for the
            # current indices being processed.
            nn = indices[current_idx]
            dd = distances[current_idx]

            # Masking points in nn that have NOT been processed yet: mask is
            # True for neighbors still outside the graph, so ~mask selects
            # neighbors already in G, which is what keeps the continuity of
            # the graph across this phase.
            mask = np.in1d(nn, processed_idx, invert=True).reshape(nn.shape)

            # Looping over each current index's set of nn points. First,
            # select up to kpairs+1 points that haven't been processed yet
            # (mask), to keep the frontier moving; then select another
            # kpairs+1 points already in G (~mask), to tie the new points
            # back into the existing graph.
            for i, (n, d, g) in enumerate(zip(nn, dd, current_idx)):
                nn_idx = n[mask[i]][0:kpairs+1]
                dd_idx = d[mask[i]][0:kpairs+1]

                # wx adds: attach attributes to the node.
                G.add_node(g, pos=[float(arr.x[g]), float(arr.y[g]), float(arr.z[g])])  # coordinates
                if 'pid' in arr.columns:
                    G.add_node(g, pid=int(arr.pid[g]))  # index in the original point cloud

                # Adding the currently selected points as nodes of graph G.
                add_nodes(G, g, nn_idx, dd_idx, graph_threshold)

                nn_idx = n[~mask[i]][0:kpairs+1]
                dd_idx = d[~mask[i]][0:kpairs+1]

                # Adding the already-processed neighbors as nodes of graph G.
                add_nodes(G, g, nn_idx, dd_idx, graph_threshold)

        # Checking if current_idx is still empty. If so, increasing
        # nbrs_threshold to try to include more points in the next
        # iteration.
        if len(current_idx) == 0:
            nbrs_threshold += nbrs_threshold_step

        # Appending current_idx to processed_idx.
        processed_idx = np.append(processed_idx, current_idx)
        processed_idx = np.unique(processed_idx).astype(int)

        # Generating the list of remaining points to process.
        idx = idx_base[np.in1d(idx_base, processed_idx, invert=True)]

        # Adding the new nodes to the step register.
        current_idx = np.array(current_idx).astype(int)
        step_register[current_idx] = current_step

    if return_step is True:
        return G, step_register
    else:
        return G
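

# Usage sketch (added for illustration; not part of the original module).
# 'pc' is assumed to be a pandas DataFrame of a single tree's point cloud
# with columns 'x', 'y', 'z' and a default RangeIndex; the parameter values
# are arbitrary choices. Taking the lowest point as the root:
#
#     base_id = pc.z.idxmin()
#     G = array_to_graph(pc, base_id, kpairs=3, knn=100,
#                        nbrs_threshold=.15, nbrs_threshold_step=.05)
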
def extract_path_info(G, base_id, return_path=True):

    """
    Extracts shortest-path information from a NetworkX graph.

    Parameters
    ----------
    G : networkx graph
        NetworkX graph object from which to extract the information.
    base_id : int
        Base (root) node id from which to calculate the shortest path to
        every other node.
    return_path : bool
        Option to select if the function should also output the path from
        every node in G to base_id.

    Returns
    -------
    nodes_ids : dict_keys
        Indices of all nodes in graph G.
    distance : dict_values
        Shortest-path distance (accumulated) from every node in G to the
        base_id node.
    path_list : dict
        Dictionary of the nodes that comprise the path from every node in G
        to the base_id node.

    """

    # Calculating the shortest-path length from base_id to every other node.
    shortpath = nx.single_source_dijkstra_path_length(G, base_id)

    # Obtaining the node ids and their respective distances from the base
    # point.
    nodes_ids = shortpath.keys()
    distance = shortpath.values()

    # Checking if the function should also return the path of each node and,
    # if so, generating the path list and returning it.
    if return_path is True:
        path_list = nx.single_source_dijkstra_path(G, base_id)
        return nodes_ids, distance, path_list
    else:
        return nodes_ids, distance
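

# Usage sketch (added for illustration; not part of the original module):
# spreading the shortest-path distances back onto the original points, e.g.
# to colour the cloud by its distance from the base. 'G', 'base_id' and 'pc'
# are as in the sketch after array_to_graph above.
#
#     node_ids, distance = extract_path_info(G, base_id, return_path=False)
#     dist = np.full(len(pc), np.nan)
#     dist[list(node_ids)] = list(distance)
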
def add_nodes(G, base_node, indices, distance, threshold):

    """
    Adds a set of nodes and weighted edges based on pairs of indices between
    base_node and all entries in 'indices'. Each node pair shares an edge
    with weight equal to the distance between the two nodes.

    Parameters
    ----------
    G : networkx graph
        NetworkX graph object to which all nodes/edges will be added.
    base_node : int
        Id of the base node to be added. All other nodes will be paired with
        base_node to form the edges.
    indices : list or array
        Set of node indices to be paired with base_node.
    distance : list or array
        Set of distances between each node in 'indices' and base_node.
    threshold : float
        Edge distance threshold. Edges longer than 'threshold' will not be
        added to G.

    """

    for c in np.arange(len(indices)):
        if distance[c] <= threshold:
            # If the distance between the two vertices is within the
            # threshold, add the edge (base_node, indices[c]) to the graph.
            G.add_weighted_edges_from([(base_node, indices[c],
                                        float(distance[c]))])
--------------------------------------------------------------------------------