├── MANIFEST.in ├── requirements.txt ├── data ├── img2d.png ├── 2d_example.png ├── ISIC_546.jpg ├── img3d.nii.gz ├── ranster_scan.png ├── image3d_dis1.nii.gz ├── image3d_dis2.nii.gz ├── image3d_dis3.nii.gz └── image3d_sub.nii.gz ├── conda ├── conda_build_config.yaml └── meta.yaml ├── cpp ├── wrap_py2.cpp ├── wrap_py3.cpp ├── geodesic_distance_2d.h ├── geodesic_distance_3d.h ├── util.h ├── util.cpp ├── geodesic_distance.cpp ├── geodesic_distance_2d.cpp └── geodesic_distance_3d.cpp ├── .travis.yml ├── LICENSE ├── setup.py ├── demo2d.py ├── demo3d.py └── README.md /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include cpp/*h 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.1 2 | -------------------------------------------------------------------------------- /data/img2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/img2d.png -------------------------------------------------------------------------------- /conda/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | - 3.6 3 | - 3.7 4 | - 3.8 5 | -------------------------------------------------------------------------------- /data/2d_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/2d_example.png -------------------------------------------------------------------------------- /data/ISIC_546.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/ISIC_546.jpg -------------------------------------------------------------------------------- /data/img3d.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/img3d.nii.gz -------------------------------------------------------------------------------- /data/ranster_scan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/ranster_scan.png -------------------------------------------------------------------------------- /data/image3d_dis1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/image3d_dis1.nii.gz -------------------------------------------------------------------------------- /data/image3d_dis2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/image3d_dis2.nii.gz -------------------------------------------------------------------------------- /data/image3d_dis3.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/image3d_dis3.nii.gz -------------------------------------------------------------------------------- /data/image3d_sub.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigw/GeodisTK/HEAD/data/image3d_sub.nii.gz -------------------------------------------------------------------------------- /cpp/wrap_py2.cpp: -------------------------------------------------------------------------------- 1 | #include "geodesic_distance.cpp" 2 | 3 | PyMODINIT_FUNC 4 | initGeodisTK(void) { 5 | (void) Py_InitModule("GeodisTK", Methods); 6 | import_array(); 7 | } 8 | -------------------------------------------------------------------------------- /cpp/wrap_py3.cpp: -------------------------------------------------------------------------------- 1 | #include "geodesic_distance.cpp" 2 | 3 | 4 | static struct PyModuleDef cGeosDis = 5 | { 6 | PyModuleDef_HEAD_INIT, 7 | "GeodisTK", /* name of module */ 8 | "", /* module documentation, may be NULL */ 9 | -1, /* size of per-interpreter state of the module, or -1 if the module keeps state in global variables. */ 10 | Methods 11 | }; 12 | 13 | 14 | PyMODINIT_FUNC PyInit_GeodisTK(void) { 15 | import_array(); 16 | return PyModule_Create(&cGeosDis); 17 | } 18 | -------------------------------------------------------------------------------- /cpp/geodesic_distance_2d.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using namespace std; 4 | 5 | struct Point2D 6 | { 7 | float distance; 8 | int w; 9 | int h; 10 | }; 11 | 12 | float get_l2_distance(std::vector p1, std::vector p2); 13 | 14 | void geodesic2d_fast_marching(const float * img, const unsigned char * seeds, float * distance, 15 | int height, int width, int channel); 16 | 17 | void geodesic2d_raster_scan(const float * img, const unsigned char * seeds, float * distance, 18 | int height, int width, int channel, float lambda, int iteration); 19 | 20 | -------------------------------------------------------------------------------- /conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set name = 'geodistk' %} 2 | {% set version = '0.1.6' %} 3 | 4 | package: 5 | name: {{ name | lower }} 6 | version: {{ version }} 7 | 8 | source: 9 | git_url: ../ 10 | 11 | requirements: 12 | build: 13 | - python {{ python }} 14 | - numpy 15 | run: 16 | - python {{ python }} 17 | - numpy 18 | 19 | build: 20 | number: 0 21 | script: 22 | - {{ PYTHON }} -m pip install . 23 | 24 | test: 25 | imports: 26 | - GeodisTK 27 | 28 | app: 29 | summary: "Geodesic distance transform of 2d/3d images" 30 | 31 | about: 32 | home: https://github.com/taigw/GeodisTK 33 | license: MIT 34 | license_file: LICENSE 35 | -------------------------------------------------------------------------------- /cpp/geodesic_distance_3d.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using namespace std; 4 | 5 | struct Point3D 6 | { 7 | float distance; 8 | int d; 9 | int w; 10 | int h; 11 | }; 12 | 13 | void geodesic3d_fast_marching(const float * img, const unsigned char * seeds, float * distance, 14 | int depth, int height, int width, int channel, std::vector spacing); 15 | 16 | void geodesic3d_raster_scan(const float * img, const unsigned char * seeds, float * distance, 17 | int depth, int height, int width, int channel, 18 | std::vector spacing, float lambda, int iteration); 19 | -------------------------------------------------------------------------------- /cpp/util.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using namespace std; 4 | 5 | // for 2d images 6 | template 7 | T get_pixel(const T * data, int height, int width, int h, int w); 8 | 9 | template 10 | std::vector get_pixel_vector(const T * data, int height, int width, int channel, int h, int w); 11 | 12 | template 13 | void set_pixel(T * data, int height, int width, int h, int w, T value); 14 | 15 | // for 3d images 16 | template 17 | T get_pixel(const T * data, int depth, int height, int width, int d, int h, int w); 18 | 19 | template 20 | std::vector get_pixel_vector(const T * data, int depth, int height, int width, int channel, int d, int h, int w); 21 | 22 | template 23 | void set_pixel(T * data, int depth, int height, int width, int d, int h, int w, T value); -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | 3 | # Makes Travis clone the full repo so that the conda package can get the 4 | # correct string from git describe and correctly label the package 5 | git: 6 | depth: false 7 | 8 | dist: bionic 9 | 10 | env: 11 | global: 12 | - MINICONDA_DIR=$HOME/miniconda 13 | 14 | before_script: 15 | # download miniconda 16 | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O $HOME/miniconda.sh 17 | - bash $HOME/miniconda.sh -b -p $MINICONDA_DIR 18 | - source $MINICONDA_DIR/bin/activate $MINICONDA_DIR 19 | - conda config --set always_yes yes 20 | - conda config --set anaconda_upload yes 21 | - conda install anaconda conda-build 22 | 23 | 24 | jobs: 25 | include: 26 | - stage: build 27 | name: Build 28 | if: branch != master 29 | script: 30 | - conda build ./conda 31 | 32 | - stage: build and deploy 33 | name: Build and Deploy 34 | if: branch = master 35 | script: 36 | - conda build ./conda --user $UPLOAD_USER --token $ANACONDA_API_TOKEN 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Guotai Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import setuptools 4 | from distutils.core import setup 5 | from distutils.extension import Extension 6 | 7 | package_name = 'GeodisTK' 8 | module_name = 'GeodisTK' 9 | version = sys.version[0] 10 | wrap_source = './cpp/wrap_py{0:}.cpp'.format(version) 11 | module1 = Extension(module_name, 12 | include_dirs = [np.get_include(), './cpp'], 13 | sources = ['./cpp/util.cpp', 14 | './cpp/geodesic_distance_2d.cpp', 15 | './cpp/geodesic_distance_3d.cpp', 16 | './cpp/geodesic_distance.cpp', 17 | wrap_source]) 18 | 19 | # Get the summary 20 | description = 'An open-source toolkit to calculate geodesic distance' + \ 21 | ' for 2D and 3D images' 22 | 23 | # Get the long description 24 | if(sys.version[0] == '2'): 25 | import io 26 | with io.open('README.md', 'r', encoding='utf-8') as f: 27 | long_description = f.read() 28 | else: 29 | with open('README.md', encoding='utf-8') as f: 30 | long_description = f.read() 31 | 32 | setup( 33 | name = package_name, 34 | version = "0.1.7", 35 | author ='Guotai Wang', 36 | author_email = 'wguotai@gmail.com', 37 | description = description, 38 | long_description = long_description, 39 | long_description_content_type = 'text/markdown', 40 | url = 'https://github.com/taigw/GeodisTK', 41 | license = 'MIT', 42 | packages = setuptools.find_packages(), 43 | ext_modules = [module1], 44 | classifiers=[ 45 | 'License :: OSI Approved :: MIT License', 46 | 'Programming Language :: Python', 47 | 'Programming Language :: Python :: 2', 48 | 'Programming Language :: Python :: 3', 49 | ], 50 | python_requires = '>=3.6', 51 | ) 52 | 53 | # to build, run python setup.py build or python setup.py build_ext --inplace 54 | # to install, run python setup.py install 55 | -------------------------------------------------------------------------------- /demo2d.py: -------------------------------------------------------------------------------- 1 | import GeodisTK 2 | import numpy as np 3 | import time 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def geodesic_distance_2d(I, S, lamb, iter): 9 | ''' 10 | get 2d geodesic disntance by raser scanning. 11 | I: input image, can have multiple channels. Type should be np.float32. 12 | S: binary image where non-zero pixels are used as seeds. Type should be np.uint8. 13 | lamb: weighting betwween 0.0 and 1.0 14 | if lamb==0.0, return spatial euclidean distance without considering gradient 15 | if lamb==1.0, the distance is based on gradient only without using spatial distance 16 | iter: number of iteration for raster scanning. 17 | ''' 18 | return GeodisTK.geodesic2d_raster_scan(I, S, lamb, iter) 19 | 20 | def demo_geodesic_distance2d(img, seed_pos): 21 | I = np.asanyarray(img, np.float32) 22 | S = np.zeros((I.shape[0], I.shape[1]), np.uint8) 23 | S[seed_pos[0]][seed_pos[1]] = 1 24 | t0 = time.time() 25 | D1 = GeodisTK.geodesic2d_fast_marching(I,S) 26 | t1 = time.time() 27 | D2 = geodesic_distance_2d(I, S, 1.0, 2) 28 | dt1 = t1 - t0 29 | dt2 = time.time() - t1 30 | D3 = geodesic_distance_2d(I, S, 0.0, 2) 31 | D4 = geodesic_distance_2d(I, S, 0.5, 2) 32 | print("runtime(s) of fast marching {0:}".format(dt1)) 33 | print("runtime(s) of raster scan {0:}".format(dt2)) 34 | 35 | plt.figure(figsize=(15,5)) 36 | plt.subplot(1,5,1); plt.imshow(img) 37 | plt.autoscale(False); plt.plot([seed_pos[0]], [seed_pos[1]], 'ro') 38 | plt.axis('off'); plt.title('(a) input image \n with a seed point') 39 | 40 | plt.subplot(1,5,2); plt.imshow(D1) 41 | plt.axis('off'); plt.title('(b) Geodesic distance \n based on fast marching') 42 | 43 | plt.subplot(1,5,3); plt.imshow(D2) 44 | plt.axis('off'); plt.title('(c) Geodesic distance \n based on ranster scan') 45 | 46 | plt.subplot(1,5,4); plt.imshow(D3) 47 | plt.axis('off'); plt.title('(d) Euclidean distance') 48 | 49 | plt.subplot(1,5,5); plt.imshow(D4) 50 | plt.axis('off'); plt.title('(e) Mexture of Geodesic \n and Euclidean distance') 51 | plt.show() 52 | 53 | def demo_geodesic_distance2d_gray_scale_image(): 54 | img = Image.open('data/img2d.png').convert('L') 55 | seed_position = [100, 100] 56 | demo_geodesic_distance2d(img, seed_position) 57 | 58 | def demo_geodesic_distance2d_RGB_image(): 59 | img = Image.open('data/ISIC_546.jpg') 60 | seed_position = [128, 128] 61 | demo_geodesic_distance2d(img, seed_position) 62 | 63 | if __name__ == '__main__': 64 | print("example list") 65 | print(" 0 -- example for gray scale image") 66 | print(" 1 -- example for RB image") 67 | print("please enter the index of an example:") 68 | method = input() 69 | method = '{0:}'.format(method) 70 | if(method == '0'): 71 | demo_geodesic_distance2d_gray_scale_image() 72 | elif(method == '1'): 73 | demo_geodesic_distance2d_RGB_image() 74 | else: 75 | print("invalid number : {0:}".format(method)) 76 | -------------------------------------------------------------------------------- /demo3d.py: -------------------------------------------------------------------------------- 1 | import GeodisTK 2 | import time 3 | import psutil 4 | import numpy as np 5 | import SimpleITK as sitk 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | 9 | def geodesic_distance_3d(I, S, spacing, lamb, iter): 10 | ''' 11 | Get 3D geodesic disntance by raser scanning. 12 | I: input image array, can have multiple channels, with shape [D, H, W] or [D, H, W, C] 13 | Type should be np.float32. 14 | S: binary image where non-zero pixels are used as seeds, with shape [D, H, W] 15 | Type should be np.uint8. 16 | spacing: a tuple of float numbers for pixel spacing along D, H and W dimensions respectively. 17 | lamb: weighting betwween 0.0 and 1.0 18 | if lamb==0.0, return spatial euclidean distance without considering gradient 19 | if lamb==1.0, the distance is based on gradient only without using spatial distance 20 | iter: number of iteration for raster scanning. 21 | ''' 22 | return GeodisTK.geodesic3d_raster_scan(I, S, spacing, lamb, iter) 23 | 24 | def demo_geodesic_distance3d(): 25 | input_name = "data/img3d.nii.gz" 26 | img = sitk.ReadImage(input_name) 27 | I = sitk.GetArrayFromImage(img) 28 | spacing_raw = img.GetSpacing() 29 | spacing = [spacing_raw[2], spacing_raw[1],spacing_raw[0]] 30 | I = np.asarray(I, np.float32) 31 | I = I[18:38, 63:183, 93:233 ] 32 | S = np.zeros_like(I, np.uint8) 33 | S[10][60][70] = 1 34 | t0 = time.time() 35 | D1 = GeodisTK.geodesic3d_fast_marching(I,S, spacing) 36 | t1 = time.time() 37 | D2 = geodesic_distance_3d(I,S, spacing, 1.0, 4) 38 | dt1 = t1 - t0 39 | dt2 = time.time() - t1 40 | D3 = geodesic_distance_3d(I,S, spacing, 0.0, 4) 41 | print("runtime(s) fast marching {0:}".format(dt1)) 42 | print("runtime(s) raster scan {0:}".format(dt2)) 43 | 44 | img_d1 = sitk.GetImageFromArray(D1) 45 | img_d1.SetSpacing(spacing_raw) 46 | sitk.WriteImage(img_d1, "data/image3d_dis1.nii.gz") 47 | 48 | img_d2 = sitk.GetImageFromArray(D2) 49 | img_d2.SetSpacing(spacing_raw) 50 | sitk.WriteImage(img_d2, "data/image3d_dis2.nii.gz") 51 | 52 | img_d3 = sitk.GetImageFromArray(D3) 53 | img_d3.SetSpacing(spacing_raw) 54 | sitk.WriteImage(img_d3, "data/image3d_dis3.nii.gz") 55 | 56 | I_sub = sitk.GetImageFromArray(I) 57 | I_sub.SetSpacing(spacing_raw) 58 | sitk.WriteImage(I_sub, "data/image3d_sub.nii.gz") 59 | 60 | I = I*255/I.max() 61 | I = np.asarray(I, np.uint8) 62 | 63 | I_slice = I[10] 64 | D1_slice = D1[10] 65 | D2_slice = D2[10] 66 | D3_slice = D3[10] 67 | plt.subplot(1,4,1); plt.imshow(I_slice, cmap='gray') 68 | plt.autoscale(False); plt.plot([70], [60], 'ro') 69 | plt.axis('off'); plt.title('input image') 70 | 71 | plt.subplot(1,4,2); plt.imshow(D1_slice) 72 | plt.axis('off'); plt.title('fast marching') 73 | 74 | plt.subplot(1,4,3); plt.imshow(D2_slice) 75 | plt.axis('off'); plt.title('ranster scan') 76 | 77 | plt.subplot(1,4,4); plt.imshow(D3_slice) 78 | plt.axis('off'); plt.title('Euclidean distance') 79 | plt.show() 80 | 81 | if __name__ == '__main__': 82 | demo_geodesic_distance3d() 83 | -------------------------------------------------------------------------------- /cpp/util.cpp: -------------------------------------------------------------------------------- 1 | #include "util.h" 2 | #include 3 | // for 2d images 4 | template 5 | T get_pixel(const T * data, int height, int width, int h, int w) 6 | { 7 | return data[h * width + w]; 8 | } 9 | 10 | template 11 | std::vector get_pixel_vector(const T * data, int height, int width, int channel, int h, int w) 12 | { 13 | std::vector pixel_vector(channel); 14 | for (int c = 0; c < channel; c++){ 15 | pixel_vector[c]= data[h * width * channel + w * channel + c]; 16 | } 17 | return pixel_vector; 18 | } 19 | 20 | template 21 | void set_pixel(T * data, int height, int width, int h, int w, T value) 22 | { 23 | data[h * width + w] = value; 24 | } 25 | 26 | template 27 | float get_pixel(const float * data, int height, int width, int h, int w); 28 | 29 | template 30 | int get_pixel(const int * data, int height, int width, int h, int w); 31 | 32 | template 33 | std::vector get_pixel_vector(const float * data, int height, int width, int channel, int h, int w); 34 | 35 | template 36 | unsigned char get_pixel(const unsigned char * data, int height, int width, int h, int w); 37 | 38 | 39 | template 40 | void set_pixel(float * data, int height, int width, int h, int w, float value); 41 | 42 | template 43 | void set_pixel(int * data, int height, int width, int h, int w, int value); 44 | 45 | template 46 | void set_pixel(unsigned char * data, int height, int width, int h, int w, unsigned char value); 47 | 48 | // for 3d images 49 | template 50 | T get_pixel(const T * data, int depth, int height, int width, int d, int h, int w) 51 | { 52 | return data[(d*height + h) * width + w]; 53 | } 54 | 55 | template 56 | std::vector get_pixel_vector(const T * data, int depth, int height, int width, int channel, int d, int h, int w) 57 | { 58 | std::vector pixel_vector(channel); 59 | for (int c = 0; c < channel; c++){ 60 | pixel_vector[c]= data[d*height*width*channel + h * width * channel + w * channel + c]; 61 | } 62 | return pixel_vector; 63 | } 64 | 65 | template 66 | void set_pixel(T * data, int depth, int height, int width, int d, int h, int w, T value) 67 | { 68 | data[(d*height + h) * width + w] = value; 69 | } 70 | 71 | template 72 | float get_pixel(const float * data, int depth, int height, int width, int d, int h, int w); 73 | 74 | template 75 | std::vector get_pixel_vector(const float * data, int depth, int height, int width, int channel, int d, int h, int w); 76 | 77 | template 78 | int get_pixel(const int * data, int depth, int height, int width, int d, int h, int w); 79 | 80 | template 81 | unsigned char get_pixel(const unsigned char * data, 82 | int depth, int height, int width, 83 | int d, int h, int w); 84 | 85 | 86 | template 87 | void set_pixel(float * data, int depth, int height, int width, int d, int h, int w, float value); 88 | 89 | template 90 | void set_pixel(int * data, int depth, int height, int width, int d, int h, int w, int value); 91 | 92 | template 93 | void set_pixel(unsigned char * data, 94 | int depth, int height, int width, 95 | int d, int h, int w, unsigned char value); -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GeodisTK: Geodesic Distance Transform Toolkit for 2D and 3D Images 2 | This repository provides source codes for geodesic distance transforms used in the following papers. If you use our code, please cite them. 3 | 4 | ``` 5 | @article{Wang2018Deepigeos, 6 | author={Wang, Guotai and Zuluaga, Maria A. and Li, Wenqi and Pratt, Rosalind and Patel, Premal A. and Aertsen, Michael and Doel, Tom and David, Anna L. and Deprest, Jan and Ourselin, Sébastien and Vercauteren, Tom}, 7 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 8 | title={DeepIGeoS: A Deep Interactive Geodesic Framework for Medical Image Segmentation}, 9 | year={2019}, 10 | volume={41}, 11 | number={7}, 12 | pages={1559-1572}} 13 | 14 | @article{LUO2021102102, 15 | author = {Luo, Xiangde and Wang, Guotai and Song, Tao and Zhang, Jingyang and Aertsen,Michael and Deprest, Jan and Ourselin, Sebastien and Vercauteren, Tom and Zhang, Shaoting}, 16 | title = {MIDeepSeg: Minimally interactive segmentation of unseen objects from medical images using deep learning}, 17 | journal = {Medical Image Analysis}, 18 | volume = {72}, 19 | pages = {102102}, 20 | year = {2021}} 21 | ``` 22 | 23 | ## Introduction 24 | Geodesic transformation of images can be implementated with two approaches: fast marching and raster scan. Fast marching is based on the iterative propagation of a pixel front with velocity F [1]. Raster scan is based on kernel operations that are sequentially applied over the image in multiple passes [2][3]. In GeoS [4], the authors proposed to use a 3x3 kernel for forward and backward passes for efficient geodesic distance transform, which was used for image segmentation. 25 | 26 | ![ranster scan](./data/ranster_scan.png) 27 | *Raster scan for geodesic distance transform. Image from [4]* 28 | 29 | DeepIGeoS [5] proposed to combine geodesic distance transforms with convolutional neural networks for efficient interactive segmentation of 2D and 3D images. 30 | 31 | * [1] Sethian, James A. "Fast marching methods." SIAM review 41, no. 2 (1999): 199-235. 32 | * [2] Borgefors, Gunilla. "Distance transformations in digital images." CVPR, 1986 33 | * [3] Toivanen, Pekka J. "New geodesic distance transforms for gray-scale images." Pattern Recognition Letters 17, no. 5 (1996): 437-450. 34 | * [4] Criminisi, Antonio, Toby Sharp, and Andrew Blake. "Geos: Geodesic image segmentation." ECCV, 2008. 35 | * [5] Wang, Guotai, et al. "[`DeepIGeoS: A deep interactive geodesic framework for medical image segmentation`](https://ieeexplore.ieee.org/document/8370732)." TPAMI, 2019. 36 | 37 | ![2D example](./data/2d_example.png) 38 | *A comparison of fast marching and ranster scan for 2D geodesic distance transform. (d) shows the Euclidean distance and (e) is a mixture of Geodesic and Euclidean distance.* 39 | 40 | This repository provides a cpp implementation of fast marching and raster scan for 2D/3D geodesic and Euclidean distance transforms and a mixture of them, and proivdes a python interface to use it. 41 | 42 | ## How to install 43 | 1. Install this toolkit easily by typing [`pip install GeodisTK`](https://pypi.org/project/GeodisTK/) 44 | 45 | 2. Alternatively, if you want to build from source files, download this package and run the following: 46 | ```bash 47 | python setup.py build 48 | python setup.py install 49 | ``` 50 | 51 | ## How to use 52 | 1. See a 2D example, run `python demo2d.py` 53 | 54 | 2. See a 3D example, run `python demo3d.py` 55 | -------------------------------------------------------------------------------- /cpp/geodesic_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "numpy/arrayobject.h" 4 | #include "geodesic_distance_2d.h" 5 | #include "geodesic_distance_3d.h" 6 | #include 7 | using namespace std; 8 | 9 | // example to use numpy object: http://blog.debao.me/2013/04/my-first-c-extension-to-numpy/ 10 | // write a c extension ot Numpy: http://folk.uio.no/hpl/scripting/doc/python/NumPy/Numeric/numpy-13.html 11 | static PyObject * 12 | geodesic2d_fast_marching_wrapper(PyObject *self, PyObject *args) 13 | { 14 | PyObject *I=NULL, *Seed=NULL; 15 | PyArrayObject *arr_I=NULL, *arr_Seed=NULL; 16 | 17 | if (!PyArg_ParseTuple(args, "OO", &I, &Seed)) return NULL; 18 | 19 | arr_I = (PyArrayObject*)PyArray_FROM_OTF(I, NPY_FLOAT32, NPY_IN_ARRAY); 20 | if (arr_I == NULL) return NULL; 21 | 22 | arr_Seed = (PyArrayObject*)PyArray_FROM_OTF(Seed, NPY_UINT8, NPY_IN_ARRAY); 23 | if (arr_Seed == NULL) return NULL; 24 | 25 | 26 | int nd = PyArray_NDIM(arr_I); //number of dimensions 27 | npy_intp * shape = PyArray_DIMS(arr_I); // npy_intp array of length nd showing length in each dim. 28 | npy_intp * shape_seed = PyArray_DIMS(arr_Seed); 29 | // cout<<"input shape "; 30 | // for(int i=0; idata, (const unsigned char *)arr_Seed->data, 51 | (float *) distance->data, shape[0], shape[1], channel); 52 | 53 | Py_DECREF(arr_I); 54 | Py_DECREF(arr_Seed); 55 | //Py_INCREF(distance); 56 | return PyArray_Return(distance); 57 | } 58 | 59 | static PyObject * 60 | geodesic2d_raster_scan_wrapper(PyObject *self, PyObject *args) 61 | { 62 | PyObject *I=NULL, *Seed=NULL; 63 | float lambda, iteration; 64 | PyArrayObject *arr_I=NULL, *arr_Seed=NULL; 65 | if (!PyArg_ParseTuple(args, "OOff", &I, &Seed, &lambda, &iteration)) return NULL; 66 | 67 | arr_I = (PyArrayObject*)PyArray_FROM_OTF(I, NPY_FLOAT32, NPY_IN_ARRAY); 68 | if (arr_I == NULL) return NULL; 69 | 70 | arr_Seed = (PyArrayObject*)PyArray_FROM_OTF(Seed, NPY_UINT8, NPY_IN_ARRAY); 71 | if (arr_Seed == NULL) return NULL; 72 | 73 | 74 | int nd = PyArray_NDIM(arr_I); //number of dimensions 75 | npy_intp * shape = PyArray_DIMS(arr_I); // npy_intp array of length nd showing length in each dim. 76 | npy_intp * shape_seed = PyArray_DIMS(arr_Seed); 77 | // cout<<"input shape "; 78 | // for(int i=0; idata, (const unsigned char *)arr_Seed->data, 99 | (float *) distance->data, shape[0], shape[1], channel, lambda, (int)iteration); 100 | 101 | Py_DECREF(arr_I); 102 | Py_DECREF(arr_Seed); 103 | //Py_INCREF(distance); 104 | return PyArray_Return(distance); 105 | } 106 | 107 | static PyObject * 108 | geodesic3d_fast_marching_wrapper(PyObject *self, PyObject *args) 109 | { 110 | PyObject *I=NULL, *Seed=NULL, *Spacing=NULL; 111 | PyArrayObject *arr_I=NULL, *arr_Seed=NULL, *arr_Space=NULL; 112 | 113 | if (!PyArg_ParseTuple(args, "OOO", &I, &Seed, &Spacing)) return NULL; 114 | 115 | arr_I = (PyArrayObject*)PyArray_FROM_OTF(I, NPY_FLOAT32, NPY_IN_ARRAY); 116 | if (arr_I == NULL) return NULL; 117 | 118 | arr_Seed = (PyArrayObject*)PyArray_FROM_OTF(Seed, NPY_UINT8, NPY_IN_ARRAY); 119 | if (arr_Seed == NULL) return NULL; 120 | 121 | arr_Space = (PyArrayObject*)PyArray_FROM_OTF(Spacing, NPY_FLOAT32, NPY_IN_ARRAY); 122 | if (arr_Space == NULL) return NULL; 123 | 124 | int nd = PyArray_NDIM(arr_I); //number of dimensions 125 | npy_intp * shape = PyArray_DIMS(arr_I); // npy_intp array of length nd showing length in each dim. 126 | npy_intp * shape_seed = PyArray_DIMS(arr_Seed); 127 | // cout<<"input shape "; 128 | // for(int i=0; idata; 148 | // cout<<"spacing: "< sp_vec(3); 150 | for(int i = 0; i<3; i++){ 151 | sp_vec[i] = sp[i]; 152 | } 153 | 154 | PyArrayObject * distance = (PyArrayObject*) PyArray_SimpleNew(3, output_shape, NPY_FLOAT32); 155 | geodesic3d_fast_marching((const float *)arr_I->data, (const unsigned char *)arr_Seed->data, (float *) distance->data, 156 | shape[0], shape[1], shape[2], channel, sp_vec); 157 | 158 | Py_DECREF(arr_I); 159 | Py_DECREF(arr_Seed); 160 | //Py_INCREF(distance); 161 | return PyArray_Return(distance); 162 | } 163 | 164 | static PyObject * 165 | geodesic3d_raster_scan_wrapper(PyObject *self, PyObject *args) 166 | { 167 | PyObject *I=NULL, *Seed=NULL, *Spacing=NULL; 168 | float lambda, iteration; 169 | 170 | PyArrayObject *arr_I=NULL, *arr_Seed=NULL, *arr_Space=NULL; 171 | 172 | if (!PyArg_ParseTuple(args, "OOOff", &I, &Seed, &Spacing, &lambda, &iteration)) return NULL; 173 | 174 | arr_I = (PyArrayObject*)PyArray_FROM_OTF(I, NPY_FLOAT32, NPY_IN_ARRAY); 175 | if (arr_I == NULL) return NULL; 176 | 177 | arr_Seed = (PyArrayObject*)PyArray_FROM_OTF(Seed, NPY_UINT8, NPY_IN_ARRAY); 178 | if (arr_Seed == NULL) return NULL; 179 | 180 | arr_Space = (PyArrayObject*)PyArray_FROM_OTF(Spacing, NPY_FLOAT32, NPY_IN_ARRAY); 181 | if (arr_Space == NULL) return NULL; 182 | 183 | int nd = PyArray_NDIM(arr_I); //number of dimensions 184 | npy_intp * shape = PyArray_DIMS(arr_I); // npy_intp array of length nd showing length in each dim. 185 | npy_intp * shape_seed = PyArray_DIMS(arr_Seed); 186 | // cout<<"input shape "; 187 | // for(int i=0; idata; 207 | // cout<<"spacing: "< sp_vec(3); 209 | for(int i = 0; i<3; i++){ 210 | sp_vec[i] = sp[i]; 211 | } 212 | PyArrayObject * distance = (PyArrayObject*) PyArray_SimpleNew(3, output_shape, NPY_FLOAT32); 213 | geodesic3d_raster_scan((const float *)arr_I->data, (const unsigned char *)arr_Seed->data, (float *) distance->data, 214 | shape[0], shape[1], shape[2], channel, sp_vec, lambda, (int) iteration); 215 | 216 | Py_DECREF(arr_I); 217 | Py_DECREF(arr_Seed); 218 | //Py_INCREF(distance); 219 | return PyArray_Return(distance); 220 | } 221 | 222 | static PyMethodDef Methods[] = { 223 | {"geodesic2d_fast_marching", geodesic2d_fast_marching_wrapper, METH_VARARGS, "computing 2d geodesic distance"}, 224 | {"geodesic2d_raster_scan", geodesic2d_raster_scan_wrapper, METH_VARARGS, "computing 2d geodesic distance"}, 225 | {"geodesic3d_fast_marching", geodesic3d_fast_marching_wrapper, METH_VARARGS, "computing 3d geodesic distance"}, 226 | {"geodesic3d_raster_scan", geodesic3d_raster_scan_wrapper, METH_VARARGS, "computing 3d geodesic distance"}, 227 | {NULL, NULL, 0, NULL} 228 | }; 229 | -------------------------------------------------------------------------------- /cpp/geodesic_distance_2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "util.h" 6 | #include "geodesic_distance_2d.h" 7 | using namespace std; 8 | 9 | float get_l2_distance(std::vector p1, std::vector p2) 10 | { 11 | float sq_sum = 0.0; 12 | for(size_t d = 0; d < p1.size(); d++) 13 | { 14 | sq_sum = sq_sum + (p1[d] - p2[d]) * (p1[d] - p2[d]); 15 | } 16 | float dis = sqrt(sq_sum); 17 | return dis; 18 | } 19 | 20 | void insert_point_to_list(std::vector * list, int start_position, Point2D p) 21 | { 22 | int insert_idx = list->size(); 23 | for(size_t i = start_position; i < list->size(); i++) 24 | { 25 | if(list->at(i).distance < p.distance) 26 | { 27 | insert_idx = i; 28 | break; 29 | } 30 | } 31 | list->insert(insert_idx + list->begin(), p); 32 | } 33 | 34 | void update_point_in_list(std::vector * list, Point2D p) 35 | { 36 | int remove_idx = -1; 37 | for(size_t i = 0; i < list->size(); i++) 38 | { 39 | if(list->at(i).w == p.w && list->at(i).h == p.h) 40 | { 41 | remove_idx = i; 42 | break; 43 | } 44 | } 45 | list->erase(remove_idx + list->begin()); 46 | insert_point_to_list(list, remove_idx, p); 47 | } 48 | 49 | void add_new_accepted_point(const float * img, float * distance, int * state, 50 | std::vector * list, int height, int width, int channel, Point2D p) 51 | { 52 | int w = p.w; 53 | int h = p.h; 54 | int nh, nw, temp_state; 55 | std::vector p_value, q_value; 56 | float p_dis, space_dis, delta_dis, old_dis, new_dis; 57 | 58 | p_value = get_pixel_vector(img, height, width, channel, h, w); 59 | p_dis = get_pixel(distance, height, width, h, w); 60 | for(int dh = -1; dh <= 1; dh++) 61 | { 62 | for(int dw = -1; dw <= 1; dw++) 63 | { 64 | if(dh == 0 && dw == 0) continue; 65 | nh = dh + h; 66 | nw = dw + w; 67 | if(nh >=0 && nh < height && nw >=0 && nw < width) 68 | { 69 | temp_state = get_pixel(state, height, width, nh, nw); 70 | if(temp_state == 0) continue; 71 | space_dis = sqrt(dw*dw + dh*dh); 72 | q_value = get_pixel_vector(img, height, width, channel, nh, nw); 73 | delta_dis = space_dis*get_l2_distance(p_value, q_value); 74 | old_dis = get_pixel(distance, height, width, nh, nw); 75 | 76 | new_dis = p_dis + delta_dis; 77 | if(new_dis < old_dis) 78 | { 79 | set_pixel(distance, height, width, nh, nw, new_dis); 80 | Point2D new_point; 81 | new_point.distance = new_dis; 82 | new_point.h = nh; 83 | new_point.w = nw; 84 | if(temp_state == 2){ 85 | set_pixel(state, height, width, nh, nw, 1); 86 | insert_point_to_list(list, 0, new_point); 87 | } 88 | else{ 89 | update_point_in_list(list, new_point); 90 | } 91 | } 92 | } 93 | } // end for dw 94 | } // end for dh 95 | } 96 | 97 | void geodesic2d_fast_marching(const float * img, const unsigned char * seeds, float * distance, 98 | int height, int width, int channel) 99 | { 100 | int * state = new int[width * height]; 101 | 102 | // point state: 0--acceptd, 1--temporary, 2--far away 103 | // get initial accepted set and far away set 104 | float init_dis; 105 | int init_state; 106 | unsigned char seed_type; 107 | for(int h = 0; h < height; h++) 108 | { 109 | for (int w = 0; w < width; w++) 110 | { 111 | seed_type = get_pixel(seeds, height, width, h, w); 112 | if(seed_type > 0){ 113 | init_dis = 0.0; 114 | init_state = 0; 115 | } 116 | else{ 117 | init_dis = 1.0e10; 118 | init_state = 2; 119 | } 120 | set_pixel(distance, height, width, h, w, init_dis); 121 | set_pixel(state, height, width, h, w, init_state); 122 | } 123 | } 124 | 125 | // get initial temporary set, and save them in a list 126 | std::vector temporary_list; 127 | temporary_list.reserve(width*height); 128 | int temp_state; 129 | for(int h = 0; h < height; h++) 130 | { 131 | for (int w = 0; w < width; w++) 132 | { 133 | temp_state = get_pixel(state, height, width, h, w); 134 | if(temp_state == 0) 135 | { 136 | Point2D accepted_p; 137 | accepted_p.distance = 0.0; 138 | accepted_p.h = h; 139 | accepted_p.w = w; 140 | add_new_accepted_point(img, distance, state, &temporary_list, height, width, channel, accepted_p); 141 | } 142 | } 143 | } 144 | 145 | // update temporary set until it is empty 146 | while(temporary_list.size() > 0){ 147 | Point2D temp_point = temporary_list[temporary_list.size() - 1]; 148 | temporary_list.pop_back(); 149 | set_pixel(state, height, width, temp_point.h, temp_point.w, 0); 150 | add_new_accepted_point(img, distance, state, &temporary_list, height, width, channel, temp_point); 151 | } 152 | delete [] state; 153 | } 154 | 155 | void geodesic2d_raster_scan(const float * img, const unsigned char * seeds, float * distance, 156 | int height, int width, int channel, float lambda, int iteration) 157 | { 158 | float init_dis; 159 | unsigned char seed_type; 160 | for(int h = 0; h < height; h++) 161 | { 162 | for (int w = 0; w < width; w++) 163 | { 164 | seed_type = get_pixel(seeds, height, width, h, w); 165 | init_dis = seed_type > 0 ? 0.0 : 1.0e10; 166 | set_pixel(distance, height, width, h, w, init_dis); 167 | } 168 | } 169 | 170 | for(int it = 0; it(distance, height, width, h, w); 182 | std::vector p_value = get_pixel_vector(img, height, width, channel, h, w); 183 | for(int i = 0; i < 4; i++) 184 | { 185 | int nh = h + dh_f[i]; 186 | int nw = w + dw_f[i]; 187 | if(nh < 0 || nh >= height || nw < 0 || nw >= width) continue; 188 | float q_dis = get_pixel(distance, height, width, nh, nw); 189 | std::vector q_value = get_pixel_vector(img, height, width, channel, nh, nw);; 190 | float l2dis = get_l2_distance(p_value, q_value); 191 | float speed = (1.0 - lambda) + lambda / (l2dis + 1e-5); 192 | float delta_d = local_dis_f[i] /speed; 193 | float temp_dis = q_dis + delta_d; 194 | if(temp_dis < p_dis) p_dis = temp_dis; 195 | } 196 | set_pixel(distance, height, width, h, w, p_dis); 197 | } 198 | } 199 | 200 | // backward scann 201 | int dh_b[4] = {0, 1, 1, 1}; 202 | int dw_b[4] = {1, -1, 0, 1}; 203 | float local_dis_b[4] = {1.0, sqrtf(2.0), 1.0, sqrtf(2.0)}; 204 | for(int h = height -1; h >= 0; h--) 205 | { 206 | for (int w = width - 1; w >= 0; w--) 207 | { 208 | float p_dis = get_pixel(distance, height, width, h, w); 209 | std::vector p_value = get_pixel_vector(img, height, width, channel, h, w); 210 | for(int i = 0; i < 4; i++) 211 | { 212 | int nh = h + dh_b[i]; 213 | int nw = w + dw_b[i]; 214 | if(nh < 0 || nh >= height || nw < 0 || nw >= width) continue; 215 | float q_dis = get_pixel(distance, height, width, nh, nw); 216 | std::vector q_value = get_pixel_vector(img, height, width, channel, nh, nw); 217 | float l2dis = get_l2_distance(p_value, q_value); 218 | float speed = (1.0 - lambda) + lambda / (l2dis + 1e-5); 219 | float delta_d = local_dis_b[i] / speed; 220 | float temp_dis = q_dis + delta_d; 221 | if(temp_dis < p_dis) p_dis = temp_dis; 222 | } 223 | set_pixel(distance, height, width, h, w, p_dis); 224 | } 225 | } 226 | } 227 | } 228 | -------------------------------------------------------------------------------- /cpp/geodesic_distance_3d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "util.h" 6 | #include "geodesic_distance_2d.h" 7 | #include "geodesic_distance_3d.h" 8 | using namespace std; 9 | 10 | void insert_point_to_list(std::vector * list, int start_position, Point3D p) 11 | { 12 | int insert_idx = list->size(); 13 | for(int i = start_position; i < list->size(); i++) 14 | { 15 | if(list->at(i).distance < p.distance) 16 | { 17 | insert_idx = i; 18 | break; 19 | } 20 | } 21 | list->insert(insert_idx + list->begin(), p); 22 | } 23 | 24 | void update_point_in_list(std::vector * list, Point3D p) 25 | { 26 | int remove_idx = -1; 27 | for(int i = 0; i < list->size(); i++) 28 | { 29 | if(list->at(i).d == p.d && list->at(i).h == p.h && list->at(i).w == p.w) 30 | { 31 | remove_idx = i; 32 | break; 33 | } 34 | } 35 | list->erase(remove_idx + list->begin()); 36 | insert_point_to_list(list, remove_idx, p); 37 | } 38 | 39 | void add_new_accepted_point(const float * img, float * distance, int * state, 40 | std::vector * list, 41 | int depth, int height, int width, int channel, 42 | std::vector spacing, Point3D p) 43 | { 44 | int d = p.d; 45 | int h = p.h; 46 | int w = p.w; 47 | std::vector p_value, q_value; 48 | p_value = get_pixel_vector(img, depth, height, width, channel, d, h, w); 49 | float p_dis = get_pixel(distance, depth, height, width, d, h, w); 50 | for(int dd = -1; dd <= 1; dd++) 51 | { 52 | for(int dh = -1; dh <= 1; dh++) 53 | { 54 | for(int dw = -1; dw <= 1; dw++) 55 | { 56 | if(dd == 0 && dh == 0 && dw == 0) continue; 57 | int nd = dd + d; 58 | int nh = dh + h; 59 | int nw = dw + w; 60 | 61 | if(nd >=0 && nd < depth && nh >=0 && nh < height && nw >=0 && nw < width ) 62 | { 63 | int temp_state = get_pixel(state, depth, height, width, nd, nh, nw); 64 | if(temp_state == 0) continue; 65 | float dd_sp = dd * spacing[0]; 66 | float dh_sp = dh * spacing[1]; 67 | float dw_sp = dw * spacing[2]; 68 | float space_dis = sqrt(dd_sp*dd_sp + dh_sp*dh_sp + dw_sp*dw_sp); 69 | q_value = get_pixel_vector(img, depth, height, width, channel, nd, nh, nw); 70 | float delta_dis = space_dis*get_l2_distance(p_value, q_value); 71 | float old_dis = get_pixel(distance, depth, height, width, nd, nh, nw); 72 | 73 | float new_dis = p_dis + delta_dis; 74 | if(new_dis < old_dis) 75 | { 76 | set_pixel(distance, depth, height, width, nd, nh, nw, new_dis); 77 | Point3D new_point; 78 | new_point.distance = new_dis; 79 | new_point.d = nd; 80 | new_point.h = nh; 81 | new_point.w = nw; 82 | if(temp_state == 2){ 83 | set_pixel(state, depth, height, width, nd, nh, nw, 1); 84 | insert_point_to_list(list, 0, new_point); 85 | } 86 | else{ 87 | update_point_in_list(list, new_point); 88 | } 89 | } 90 | } 91 | } // end for dw 92 | } // end for dh 93 | }// end for dd 94 | } 95 | 96 | void geodesic3d_fast_marching(const float * img, const unsigned char * seeds, float * distance, 97 | int depth, int height, int width, int channel, std::vector spacing) 98 | { 99 | int * state = new int[depth * height * width]; 100 | 101 | // point state: 0--acceptd, 1--temporary, 2--far away 102 | // get initial accepted set and far away set 103 | float init_dis; 104 | int init_state; 105 | for(int d = 0; d < depth; d++) 106 | { 107 | for(int h = 0; h < height; h++) 108 | { 109 | for (int w = 0; w < width; w++) 110 | { 111 | unsigned char seed_type = get_pixel(seeds, depth, height, width, d, h, w); 112 | if(seed_type > 0){ 113 | init_dis = 0.0; 114 | init_state = 0; 115 | } 116 | else{ 117 | init_dis = 1.0e10; 118 | init_state = 2; 119 | } 120 | set_pixel(distance, depth, height, width, d, h, w, init_dis); 121 | set_pixel(state, depth, height, width, d, h, w, init_state); 122 | } 123 | } 124 | } 125 | 126 | // get initial temporary set, and save them in a list 127 | std::vector temporary_list; 128 | temporary_list.reserve(depth * height * width); 129 | for(int d = 0; d < depth; d++) 130 | { 131 | for(int h = 0; h < height; h++) 132 | { 133 | for (int w = 0; w < width; w++) 134 | { 135 | int temp_state = get_pixel(state, depth, height, width, d, h, w); 136 | if(temp_state == 0) 137 | { 138 | Point3D accepted_p; 139 | accepted_p.distance = 0.0; 140 | accepted_p.d = d; 141 | accepted_p.h = h; 142 | accepted_p.w = w; 143 | add_new_accepted_point(img, distance, state, &temporary_list, depth, height, width, channel, spacing, accepted_p); 144 | } 145 | } 146 | } 147 | } 148 | // update temporary set until it is empty 149 | while(temporary_list.size() > 0){ 150 | Point3D temp_point = temporary_list[temporary_list.size() - 1]; 151 | temporary_list.pop_back(); 152 | set_pixel(state, depth, height, width, temp_point.d, temp_point.h, temp_point.w, 0); 153 | add_new_accepted_point(img, distance, state, &temporary_list, depth, height, width, channel, spacing, temp_point); 154 | } 155 | delete [] state; 156 | } 157 | 158 | void geodesic3d_raster_scan(const float * img, const unsigned char * seeds, float * distance, 159 | int depth, int height, int width, int channel, 160 | std::vector spacing, float lambda, int iteration) 161 | { 162 | float init_dis; 163 | float p_value, q_value; 164 | std::vector p_value_v, q_value_v; 165 | float l2dis; 166 | unsigned char seed_type; 167 | 168 | for(int d = 0; d < depth; d++) 169 | { 170 | for(int h = 0; h < height; h++) 171 | { 172 | for (int w = 0; w < width; w++) 173 | { 174 | seed_type = get_pixel(seeds, depth, height, width, d, h, w); 175 | init_dis = seed_type > 0 ? 0.0 : 1.0e10; 176 | set_pixel(distance, depth, height, width, d, h, w, init_dis); 177 | } 178 | } 179 | } 180 | // float sqrt3 = sqrt(3.0); 181 | // float sqrt2 = sqrt(2.0); 182 | // float sqrt1 = 1.0; 183 | // distance for forward pass 184 | int dd_f[13] = {-1, -1, -1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1}; 185 | int dh_f[13] = {-1, -1, -1, 0, 0, -1, -1, -1, 0, -1, -1, -1, 0}; 186 | int dw_f[13] = {-1, 0, 1, -1, 0, -1, 0, 1, -1, -1, 0, 1, -1}; 187 | float local_dis_f[13]; // = {sqrt3, sqrt2, sqrt3, sqrt2, sqrt1, 188 | // sqrt2, sqrt1, sqrt2, sqrt1, 189 | // sqrt3, sqrt2, sqrt3, sqrt2}; 190 | for(int i = 0; i< 13; i++){ 191 | float distance = 0.0; 192 | if(dd_f[i] !=0) distance += spacing[0] *spacing[0]; 193 | if(dh_f[i] !=0) distance += spacing[1] *spacing[1]; 194 | if(dw_f[i] !=0) distance += spacing[2] *spacing[2]; 195 | distance = sqrt(distance); 196 | local_dis_f[i] = distance; 197 | } 198 | // distance for backward pass 199 | int dd_b[13] = {-1, -1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1, 1}; 200 | int dh_b[13] = { 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1}; 201 | int dw_b[13] = { 1, -1, 0, 1, 1, -1, 0, 1, 0, 1, -1, 0, 1}; 202 | float local_dis_b[13]; // = {sqrt2, sqrt3, sqrt2, sqrt3, sqrt1, 203 | // sqrt2, sqrt1, sqrt2, sqrt1, 204 | // sqrt2, sqrt3, sqrt2, sqrt3}; 205 | for(int i = 0; i< 13; i++){ 206 | float distance = 0.0; 207 | if(dd_b[i] !=0) distance += spacing[0] *spacing[0]; 208 | if(dh_b[i] !=0) distance += spacing[1] *spacing[1]; 209 | if(dw_b[i] !=0) distance += spacing[2] *spacing[2]; 210 | distance = sqrt(distance); 211 | local_dis_b[i] = distance; 212 | } 213 | 214 | for(int it = 0; it(distance, depth, height, width, d, h, w); 224 | if(channel == 1){ 225 | p_value = get_pixel(img, depth, height, width, d, h, w); 226 | } 227 | else{ 228 | p_value_v = get_pixel_vector(img, depth, height, width, channel, d, h, w); 229 | } 230 | 231 | for(int i = 0; i < 13; i++) 232 | { 233 | int nd = d + dd_f[i]; 234 | int nh = h + dh_f[i]; 235 | int nw = w + dw_f[i]; 236 | if(nd < 0 || nd >= depth || nh < 0 || nh >= height || nw < 0 || nw >= width) continue; 237 | float q_dis = get_pixel(distance, depth, height, width, nd, nh, nw); 238 | if(channel == 1){ 239 | q_value = get_pixel(img, depth, height, width, nd, nh, nw); 240 | l2dis = abs(p_value - q_value); 241 | } 242 | else{ 243 | q_value_v = get_pixel_vector(img, depth, height, width, channel, nd, nh, nw); 244 | l2dis = get_l2_distance(p_value_v, q_value_v); 245 | } 246 | 247 | float speed = (1.0 - lambda) + lambda/(l2dis + 1e-5); 248 | float delta_d = local_dis_f[i] / speed; 249 | float temp_dis = q_dis + delta_d; 250 | if(temp_dis < p_dis) p_dis = temp_dis; 251 | } 252 | set_pixel(distance, depth, height, width, d, h, w, p_dis); 253 | } 254 | } 255 | } 256 | 257 | // backward pass 258 | for(int d = depth - 1; d >= 0; d--) 259 | { 260 | for(int h = height - 1; h >= 0; h--) 261 | { 262 | for (int w = width - 1; w >= 0; w--) 263 | { 264 | float p_dis = get_pixel(distance, depth, height, width, d, h, w); 265 | if(channel == 1){ 266 | p_value = get_pixel(img, depth, height, width, d, h, w); 267 | } 268 | else{ 269 | p_value_v = get_pixel_vector(img, depth, height, width, channel, d, h, w); 270 | } 271 | for(int i = 0; i < 13; i++) 272 | { 273 | int nd = d + dd_b[i]; 274 | int nh = h + dh_b[i]; 275 | int nw = w + dw_b[i]; 276 | if(nd < 0 || nd >= depth || nh < 0 || nh >= height || nw < 0 || nw >= width) continue; 277 | float q_dis = get_pixel(distance, depth, height, width, nd, nh, nw); 278 | if(channel == 1){ 279 | q_value = get_pixel(img, depth, height, width, nd, nh, nw); 280 | l2dis = abs(p_value - q_value); 281 | } 282 | else{ 283 | q_value_v = get_pixel_vector(img, depth, height, width, channel, nd, nh, nw); 284 | l2dis = get_l2_distance(p_value_v, q_value_v); 285 | } 286 | float speed = (1.0 - lambda) + lambda/(l2dis + 1e-5); 287 | float delta_d = local_dis_b[i] / speed; 288 | float temp_dis = q_dis + delta_d; 289 | if(temp_dis < p_dis) p_dis = temp_dis; 290 | } 291 | set_pixel(distance, depth, height, width, d, h, w, p_dis); 292 | } 293 | } 294 | } 295 | } 296 | } 297 | --------------------------------------------------------------------------------