├── aiib23 └── images │ ├── 1 │ ├── main.png │ └── organizers.png ├── README.md └── utils ├── airway_metric.py └── tree_parse.py /aiib23/images/1: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /aiib23/images/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nandayang/FANN-for-airway-segmentation/HEAD/aiib23/images/main.png -------------------------------------------------------------------------------- /aiib23/images/organizers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nandayang/FANN-for-airway-segmentation/HEAD/aiib23/images/organizers.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FANN-for-airway-segmentation 2 | This is the release code for the paper "Fuzzy Attention Neural Network to Tackle Discontinuity in Airway Segmentation" 3 | ## Updates 4 | **04/11/2022**: 5 | Uploaded the docker download link for reproducing the module. 6 | Released the evaluation metrics. 7 | **Coming soon**: 8 | Release the FANN code when the paper is accepted. 9 | ## Reproduce the work in the paper 10 | The file format must be *.nii.gz 11 | 1. Download the docker file from the link below (not available at the moment due to AIIB23) 12 | ```https://drive.google.com/file/d/1K3JZsEOVBYX1QCnhNhrW6xOFwcZJrlwu/view?usp=sharing``` 13 | 2. 
# README.md (continued):
#   Use the docker file for prediction
#   ```docker image load < yang.tar.gz```
#   ```docker container run --gpus "device=0" --name yang --rm -v "your_test_data_path":/workspace/inputs/ -v "your_test_output_path":/workspace/outputs/ yang:latest /bin/bash -c "sh predict.sh" ```
#   ## Evaluation metrics
#   Please find the evaluation metrics from ```utils/airway_metric.py```
#
# The Python below holds the contents of /utils/airway_metric.py followed by
# /utils/tree_parse.py. airway_metric.py originally obtained get_parsing via
# `from utils.tree_parse import get_parsing`; here it is defined further down.

import math
import os

import numpy as np
from scipy import ndimage

# torch and nibabel are imported by the original modules but are not used by
# any function here, so a missing installation must not block the metrics.
try:
    import torch
except ImportError:
    pass
try:
    import nibabel
except ImportError:
    pass

EPSILON = 1e-32


# -----------------------------------------------------------------------------
# shared private helpers
# -----------------------------------------------------------------------------
def _label_components(volume):
    """Label face-connected components; return ``(label_map, count)``."""
    # Imported lazily so the NumPy/SciPy-only functions work without
    # scikit-image installed.
    import skimage.measure as measure

    return measure.label(volume, return_num=True, connectivity=1)


def _skeletonize_volume(volume):
    """Return the skeleton of ``volume`` as a uint8 array holding 0/1."""
    try:
        from skimage.morphology import skeletonize_3d as _skeletonize
    except ImportError:
        # skeletonize_3d is absent from recent scikit-image releases;
        # skeletonize() handles 3-D input there.
        from skimage.morphology import skeletonize as _skeletonize
    # The raw result may be bool, 0/1 or 0/255 depending on version and input
    # dtype; normalise so arithmetic on it counts voxels.
    return (_skeletonize(volume) > 0).astype(np.uint8)


def _label_volumes(labels, num):
    """Return a float array ``v`` where ``v[k]`` is the voxel count of label ``k + 1``.

    ``labels`` is expected to hold integer label ids; ids outside ``1..num``
    are ignored. One pass over the array replaces one pass per label.
    """
    labels = np.asarray(labels)
    in_range = labels[(labels >= 1) & (labels <= num)]
    counts = np.bincount(in_range.astype(np.intp), minlength=num + 1)
    # float dtype matches the array the original loop filled, so argsort on it
    # behaves the same.
    return counts[1:num + 1].astype(float)


# -----------------------------------------------------------------------------
# /utils/airway_metric.py
# -----------------------------------------------------------------------------
def compute_binary_iou(y_true, y_pred):
    """Return the intersection-over-union of two binary arrays.

    EPSILON is added to the intersection, so two non-overlapping masks give a
    tiny positive value rather than exactly 0.
    """
    intersection = np.sum(y_true * y_pred) + EPSILON
    union = np.sum(y_true) + np.sum(y_pred) - intersection + EPSILON
    return intersection / union


def select_airway_component(fid, label, cd, num):
    """Pick the predicted component that represents the airway.

    Components of ``cd`` (ids ``1..num``) are tried from largest to smallest
    and the first whose IoU with ``label`` reaches 0.1 is returned. Each
    rejected candidate prints the "failed cases" message. If no candidate
    qualifies, the largest component is returned with its (low) IoU; if there
    are no components at all, an all-zero mask is returned.

    :param fid: case identifier, used only in the printed message
    :param label: binary ground-truth array
    :param cd: integer component label map of the prediction
    :param num: number of components in ``cd``
    :return: ``(mask, iou)`` where ``mask`` is a uint8 array shaped like ``cd``
    """
    cd = np.asarray(cd)
    # argsort is ascending; reverse it to walk from the largest component down.
    candidate_ids = np.argsort(_label_volumes(cd, num))[::-1] + 1

    large_cd = np.zeros(cd.shape, dtype=np.uint8)
    iou = compute_binary_iou(label, large_cd)
    if candidate_ids.size == 0:
        print(fid," failed cases, require post-processing")
        return large_cd, iou

    for rank, component_id in enumerate(candidate_ids):
        candidate = (cd == component_id).astype(np.uint8)
        candidate_iou = compute_binary_iou(label, candidate)
        accepted = candidate_iou >= 0.1
        if accepted or rank == 0:
            # rank 0 (the largest) is kept as the fallback answer
            large_cd, iou = candidate, candidate_iou
        if accepted:
            break
        print(fid," failed cases, require post-processing")
    return large_cd, iou


def evaluation_branch_metrics(fid, label, pred, refine=False):
    """Compute airway segmentation metrics for one case.

    :param fid: case identifier, used only in printed messages
    :param label: binary ground-truth airway array
    :param pred: predicted airway array
    :param refine: forwarded to ``get_parsing`` (which does not use it)
    :return: tuple ``(iou, DLR, DBR, precision, leakages)`` where
        iou       - IoU between ``label`` and the selected predicted component,
        DLR       - detected length ratio: selected voxels on the ground-truth
                    skeleton / skeleton voxels,
        DBR       - detected branch ratio: branches with at least 80% of their
                    skeleton voxels detected / number of branches,
        precision - selected voxels inside ``label`` / selected voxels
                    (0.0 when nothing was selected),
        leakages  - selected voxels outside ``label`` / voxels in ``label``.
    """
    # branch parsing of the ground truth
    parsing_gt = get_parsing(label, refine)

    # locate the airway among the connected components of the prediction
    cd, num = _label_components(pred)
    large_cd, iou = select_airway_component(fid, label, cd, num)

    skeleton = _skeletonize_volume(label)

    DLR = (large_cd * skeleton).sum() / skeleton.sum()
    selected_voxels = large_cd.sum()
    precision = (large_cd * label).sum() / selected_voxels if selected_voxels else 0.0
    leakages = ((large_cd - label) == 1).sum() / label.sum()

    num_branch = int(parsing_gt.max())
    on_skeleton = skeleton > 0
    detected_num = 0
    for branch_id in range(1, num_branch + 1):
        branch_skeleton = (parsing_gt == branch_id) & on_skeleton
        branch_length = np.count_nonzero(branch_skeleton)
        if branch_length == 0:
            # no skeleton voxel carries this id: cannot be counted as detected
            continue
        if np.count_nonzero(large_cd[branch_skeleton]) / branch_length >= 0.8:
            detected_num += 1
    DBR = detected_num / num_branch if num_branch else float("nan")
    return iou, DLR, DBR, precision, leakages


# -----------------------------------------------------------------------------
# /utils/tree_parse.py
# -----------------------------------------------------------------------------
def large_connected_domain(label):
    """Return the largest face-connected component of ``label`` with holes filled.

    :param label: binary array
    :return: uint8 array of 0/1
    :raises IndexError: if ``label`` has no foreground component
    """
    cd, num = _label_components(label)
    largest_id = np.argsort(_label_volumes(cd, num))[-1] + 1
    filled = ndimage.binary_fill_holes(cd == largest_id)
    return filled.astype(np.uint8)


def skeleton_parsing(skeleton):
    """Split a 3-D skeleton into branches.

    Voxels whose 3x3x3 neighbourhood (including the voxel itself) holds more
    than 3 skeleton voxels are removed, the remainder is labelled with
    26-connectivity, and pieces shorter than 5 voxels are discarded.

    :param skeleton: 3-D array, non-zero on the skeleton; not modified
    :return: ``(skeleton_parse, cd, num)`` - the pruned uint8 0/1 skeleton,
        its integer branch label map and the number of branches
    """
    # Normalise to 0/1 uint8 so the convolution below counts neighbours
    # (a bool or 0/255 input would not).
    skeleton_parse = (np.asarray(skeleton) > 0).astype(np.uint8)
    neighborhood = ndimage.generate_binary_structure(3, 3)

    # separate the skeleton at its junctions
    neighbor_count = ndimage.convolve(skeleton_parse, neighborhood) * skeleton_parse
    skeleton_parse[neighbor_count > 3] = 0
    cd, num = ndimage.label(skeleton_parse, structure=neighborhood)

    # remove small branches
    sizes = np.bincount(cd.ravel(), minlength=num + 1)
    small_ids = np.flatnonzero(sizes < 5)
    small_ids = small_ids[small_ids > 0]  # id 0 is background
    if small_ids.size:
        skeleton_parse[np.isin(cd, small_ids)] = 0
        cd, num = ndimage.label(skeleton_parse, structure=neighborhood)
    return skeleton_parse, cd, num


def tree_parsing_func(skeleton_parse, label, cd):
    """Give every voxel of ``label`` the branch id of its nearest skeleton voxel.

    :param skeleton_parse: array, non-zero on the pruned skeleton
    :param label: binary airway mask
    :param cd: branch label map of ``skeleton_parse``
    :return: array shaped like ``label`` with branch ids inside the mask and 0
        outside
    """
    # For each voxel, the index of the closest skeleton voxel (the distances
    # themselves are not needed).
    nearest = ndimage.distance_transform_edt(
        np.asarray(skeleton_parse) == 0,
        return_distances=False,
        return_indices=True,
    )
    return cd[tuple(nearest)] * label


def loc_trachea(tree_parsing, num):
    """Return the id (1-based) of the branch with the most voxels.

    :param tree_parsing: integer branch label map
    :param num: number of branches
    :raises IndexError: if ``num`` is 0
    """
    return np.argsort(_label_volumes(tree_parsing, num))[-1] + 1


def get_parsing(mask, refine=False):
    """Parse an airway mask into branches.

    The mask is binarised, reduced to its largest connected component,
    skeletonised, split into branches and propagated back to the full mask.

    :param mask: airway array; values > 0 are foreground
    :param refine: accepted for interface compatibility; not used
    :return: branch label map of the largest connected component
    """
    mask = (mask > 0).astype(np.uint8)
    mask = large_connected_domain(mask)
    skeleton = _skeletonize_volume(mask)
    skeleton_parse, cd, num = skeleton_parsing(skeleton)
    return tree_parsing_func(skeleton_parse, mask, cd)