├── requirements.txt
├── gdown_folder.py
├── README.md
├── get_boxes.py
├── LICENSE
├── CVPR25_text_eval.py
├── SurfaceDice.py
└── CVPR25_iter_eval.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | connected-components-3d==3.12.4
3 | pandas==2.2.1
4 | numpy==1.26.3
5 | scipy==1.12.0
6 | cupy-cuda12x
7 | cucim==23.10.0
8 | tqdm
9 | scikit-image
10 |
--------------------------------------------------------------------------------
/gdown_folder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | import time
5 | import requests
6 | import contextlib
7 | from concurrent.futures import ThreadPoolExecutor, as_completed
8 | from tqdm import tqdm # Progress bar for downloads
9 |
def recursive_gdown(folder_id, current_path='', max_workers=4, quiet_gdown=False):
    """Download every .npz file of a public Google Drive folder tree.

    Fetches the folder's ``embeddedfolderview`` HTML page and extracts four
    things with regular expressions: the .npz file names, the file IDs, the
    folder title and the sub-folder IDs. The .npz files are downloaded in
    parallel with the ``gdown`` command-line tool into
    ``<current_path>/<folder title>``. The function then recurses into every
    sub-folder.

    Args:
        folder_id: Google Drive folder ID.
        current_path: Local directory under which the folder's own directory
            is created.
        max_workers: Number of concurrent gdown downloads.
        quiet_gdown: If True, gdown's stdout/stderr are discarded (on every OS).

    Raises:
        SystemExit: if a file still fails to download after 5 attempts.
    """
    import subprocess  # local import: only this function launches gdown

    url = f"https://drive.google.com/embeddedfolderview?id={folder_id}"
    # A timeout keeps one stalled connection from hanging the whole crawl.
    response = requests.get(url, timeout=60)
    if response.status_code != 200:
        print(f"Error {response.status_code}: {response.content}")
        return

    data = response.text

    # Extract .npz file names from the folder listing.
    # NOTE(review): the HTML markup in this pattern and in the <title> pattern
    # below was reconstructed; confirm it against the live listing page.
    npz_pattern = r'<div class="flip-entry-title">([^<]*?\.npz)</div>'
    npz_files = re.findall(npz_pattern, data)

    folder_title_match = re.search(r"<title>(.*?)</title>", data)
    # Fall back to "Unknown" for a missing or empty title so files never land
    # directly in the parent directory.
    folder_title = (folder_title_match.group(1) if folder_title_match else "") or "Unknown"

    # Regex patterns to find file links and sub-folders
    file_pattern = r"https://drive\.google\.com/file/d/([-\w]{25,})/view"
    folder_pattern = r"https://drive\.google\.com/drive/folders/([-\w]{25,})"

    files = re.findall(file_pattern, data)
    folders = re.findall(folder_pattern, data)
    if files:
        print(f"Found {len(files)} files and {len(folders)} folders in '{folder_title}'")
        print(f"Found {len(npz_files)} .npz files")

    # File IDs are paired with .npz names by position. This is only reliable
    # when every file in the folder is an .npz file, so flag any mismatch.
    if len(files) != len(npz_files):
        print(f"Warning: found {len(files)} file links but {len(npz_files)} .npz names in "
              f"'{folder_title}'; pairing names with IDs by position may be wrong")
    downloads = list(zip(npz_files, files))

    # Create directory for current folder
    path = os.path.join(current_path, folder_title)
    os.makedirs(path, exist_ok=True)

    sink = subprocess.DEVNULL if quiet_gdown else None
    max_attempts = 5  # first try + 4 retries

    def download_file(npz_filename, file_id):
        """Download one file with gdown unless it already exists; return its name."""
        file_path = os.path.join(path, npz_filename)

        if os.path.exists(file_path):
            print(f"File '{npz_filename}' already exists, skipping...")
            return npz_filename

        file_url = f"https://drive.google.com/uc?id={file_id}"
        # Pass an argument list and no shell: the file name comes from remote
        # HTML and must never be interpreted by a shell.
        command = ["gdown", file_url, "-O", file_path]
        for attempt in range(1, max_attempts + 1):
            if subprocess.run(command, stdout=sink, stderr=sink).returncode == 0:
                return npz_filename
            if attempt < max_attempts:
                time.sleep(30)

        print(f'Tried downloading {npz_filename} already {max_attempts} times unsuccessfully. Please re-download your cookies.txt and put them in ~/.cache/gdown/')
        # SystemExit raised in a worker is re-raised in the caller by future.result().
        sys.exit(1)

    with tqdm(total=len(downloads), desc="Downloading .npz Files", unit="file") as progress_bar:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(download_file, name, file_id) for name, file_id in downloads]
            for future in as_completed(futures):
                future.result()  # re-raises any error from the worker thread
                progress_bar.update(1)

    # Recursively process each sub-folder
    for sub_folder_id in folders:
        recursive_gdown(sub_folder_id, path, max_workers, quiet_gdown)
83 |
if __name__ == "__main__":
    # Command line: <folder_id> [max_workers] [--quiet-gdown]
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <folder_id> [max_workers] [--quiet-gdown]")
        sys.exit(1)

    folder_id = sys.argv[1]
    if not re.match(r"^[-\w]{25,}$", folder_id):
        print(f"Invalid ID: {folder_id}")
        sys.exit(1)

    # Defaults: 4 parallel downloads, gdown output shown.
    max_workers = 4
    quiet_gdown = False

    for arg in sys.argv[2:]:
        if arg.isdigit():
            max_workers = int(arg)
        elif arg == "--quiet-gdown":
            quiet_gdown = True
        else:
            print(f"Ignoring unrecognized argument: {arg}")

    if max_workers == 0:  # 0 means one worker per CPU core
        # os.cpu_count() returns None (instead of raising) when undetermined.
        max_workers = os.cpu_count() or 1

    recursive_gdown(folder_id, "./SegFM3D", max_workers, quiet_gdown)
110 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CVPR-MedSegFMCompetition
2 | Foundation Models for Biomedical Image Segmentation
3 |
4 | ## Evaluation
5 | The evaluation script `CVPR25_iter_eval.py` evaluates Docker submissions for the **CVPR25: Foundation Models for Interactive 3D Biomedical Image Segmentation Challenge** using an iterative refinement approach.
6 |
7 | ### Installation
8 | Installation of packages for the evaluation script:
9 | ```
10 | conda create -n cvpr_segfm_eval python=3.11 -y
11 | conda activate cvpr_segfm_eval
12 | pip install -r requirements.txt
13 | ```
14 |
15 | Run the script as follows:
16 |
17 | ```bash
18 | python CVPR25_iter_eval.py --docker_folder path/to/docker_submissions --test_img_path path/to/test_images --save_path path/to/output --verbose
19 | ```
20 |
21 | ### Arguments
22 | - `--docker_folder` : Path to the directory containing submitted Docker containers (`.tar.gz`).
23 | - `--test_img_path` : Path to the directory containing `.npz` test images.
24 | - `--save_path` : Directory to save segmentation outputs and evaluation metrics.
25 | - `--verbose` *(optional)* : Enables detailed output, including generated click coordinates.
26 | - `--validation_gts_path` : Path to the validation / test set GT files. This is needed to prevent label leakage (val/test) during the challenge.
27 |
28 | ### Evaluation Process
29 | 1. **Loads Docker submissions** and processes test images one by one.
30 | 2. **Initial Prediction:** Uses a bounding box prompt to generate the first segmentation.
31 | 3. **Iterative Refinement:** Simulates up to 5 refinement clicks based on segmentation errors.
32 | 4. **Performance Metrics:** Computes **Dice Similarity Coefficient AUC (DSC_AUC), Normalized Surface Dice AUC (NSD_AUC), Final DSC, Final NSD, and Inference Time**.
33 | 5. **Outputs results** as `.npz` files and a CSV summary.
34 |
35 | ### Output
36 | - Segmentation results are saved in the specified output directory.
37 | - Final prediction in the `segs` key
38 | - All 6 intermediate predictions in the `all_segs` key
39 | - Metrics for each test case are compiled into a CSV file.
40 |
41 | For more details, refer to the challenge page: https://www.codabench.org/competitions/5263/
42 |
43 |
44 | ### Clicks Accumulation in Image Input
45 |
46 | During the prediction process, clicks are accumulated in the `clicks` key within the input `.npz` file.
47 |
48 | An example of the list stored in the `clicks` key for an image with 4 targets after all 5 clicks (one dictionary per target):
49 |
50 | ```json
51 | [
52 | {"fg": [[46, 336, 343], [28, 233, 365], [28, 233, 365], [28, 233, 365]], "bg": [[28, 233, 366]]},
53 | {"fg": [[38, 210, 148]], "bg": [[6, 230, 284], [6, 230, 284], [6, 230, 284], [6, 230, 284]]},
54 | {"fg": [[12, 287, 262], [12, 287, 262], [12, 287, 262], [12, 287, 262], [12, 287, 262]], "bg": []},
55 | {"fg": [[28, 199, 180], [28, 199, 180], [28, 199, 180], [28, 199, 180], [28, 199, 180]], "bg": []}
56 | ]
57 | ```
58 | ### Clicks Order
59 | We also provide the order in which the clicks were generated in an ancillary key, `clicks_order`: a simple list with values `fg` and `bg`. For example, `['fg', 'fg', 'bg']` indicates that the first two clicks were foreground clicks and the last one was a background click.
60 | ### Previous Prediction in Image Input
61 |
62 | The input image also contains the `prev_pred` key which stores the prediction from the previous iteration. This is used only to help with submissions that are using the previous prediction as an additional input.
63 |
64 | ### No Bounding Box key
65 | We also omit the `boxes` key in some of the validation and test samples because a bounding box is a poor prompt for some structures, such as vessels. In this case we skip the initial bounding-box prediction and evaluate the models on the 5 clicks only, using the same evaluation metrics.
66 |
67 |
68 | ### Upper Time Bound During Testing
69 | We set a limit of 90 seconds per class during inference (whole docker run). If the inference time exceeds this bound, the corresponding DSC and NSD scores will be set as 0. When participants evaluate their models using the `CVPR25_iter_eval.py` script they will receive a warning if their models exceed this limit.
70 |
71 | There are two motivations for this setting:
72 | - The main focus of this competition is to promote interactive segmentation algorithm design. Inference time should not be a major concern/constraint for participants.
73 | - It is very hard to evaluate the real inference time within docker since implementations also affect the docker overhead.
74 |
75 | ### Final Script Output
76 | The `CVPR25_iter_eval.py` script will produce the following outputs in the `--save_path` argument:
77 | - `{teamname}_metrics.csv` that contains the following columns
78 | - `CaseName`: Test / Validation image filename
79 | - `TotalRunningTime`: Inference time taken for the image (all interactions)
80 | - `RunningTime_{i}`: Inference time for interactions [1-6], 1: bbox, 2-6: clicks
81 | - `DSC_AUC`: Area under DSC-to-Click curve metric
82 | - `NSD_AUC`: Area under NSD-to-Click curve metric
83 | - `DSC_Final`: DSC after final click
84 | - `NSD_Final`: NSD after final click
85 | - `CASE_{i}.npz` - model output with keys:
86 | - `segs`: Final prediction for all classes
87 | - `all_segs`: All intermediate predictions of the model for interactions [1-6]
88 |
89 |
--------------------------------------------------------------------------------
/get_boxes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | np.random.seed(2025)
5 | import cc3d
6 | from skimage import segmentation
7 | import copy
8 | import multiprocessing as mp
9 | import glob
10 |
def show_box_cv2(image, box, color=(255, 0, 0), thickness=2):
    """
    Draw the in-plane outline of a bounding box on an image using OpenCV.

    Args:
        image: The input image (numpy array); cv2.rectangle draws onto it.
        box: A bounding box, either 2D ([x_min, y_min, x_max, y_max]) or 3D
            ([x_min, y_min, z_min, x_max, y_max, z_max]). For a 3D box only the
            x/y extent is drawn; the z coordinates are ignored.
        color: Color of the rectangle in BGR (default is blue).
        thickness: Thickness of the rectangle border (default is 2).
    Returns:
        The image with the rectangle drawn.
    """
    color = tuple(map(int, color))
    if len(box) == 4:  # 2D bounding box
        x_min, y_min, x_max, y_max = box
    else:  # 3D bounding box: project onto the x/y plane
        x_min, y_min, _, x_max, y_max, _ = box
    # cv2 requires integer point coordinates; coerce float / numpy values.
    cv2.rectangle(image, (int(x_min), int(y_min)), (int(x_max), int(y_max)), color, thickness)
    return image
30 |
def show_mask_cv2(mask, image, color=None, alpha=0.5):
    """
    Blend a solid-colour overlay of a mask onto an image.

    Args:
        mask: 2D 8-bit mask (used as the cv2.bitwise_and mask); its non-zero
            pixels are coloured. Must contain at least one foreground pixel.
        image: H x W x 3 image matching the mask's height and width.
        color: Per-channel overlay colour; drawn at random when None.
        alpha: Weight of the overlay in the blend (the image gets 1 - alpha).
    Returns:
        The blended image.
    Raises:
        ValueError: if the mask has no foreground pixels.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not mask.sum() > 0:
        raise ValueError('show_mask_cv2 expects a mask with at least one foreground pixel')
    if color is None:
        color = np.random.randint(0, 255, 3)
    overlay = np.zeros_like(image)
    for channel in range(3):
        overlay[:, :, channel] = color[channel]
    overlay = cv2.bitwise_and(overlay, overlay, mask=mask)
    combined = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)

    return combined
43 |
def mask2D_to_bbox(gt2D, file):
    """
    Compute a randomly enlarged 2D bounding box around the foreground of a mask.

    A random shift is drawn from [0, 5] and rescaled by the image size / 256
    per axis. On an axis shorter than 256 pixels the shift may round to 0.
    Each side of the tight box is pushed outwards by that shift, and the result
    is clipped to the image.

    Args:
        gt2D: 2D array; pixels > 0 are foreground.
        file: Source file name, used only in the error message.
    Returns:
        np.ndarray [x_min, y_min, x_max, y_max] of inclusive pixel indices.
    Raises:
        Exception: if the box cannot be computed (e.g. an empty mask); the
            original error is chained as the cause.
    """
    try:
        y_indices, x_indices = np.where(gt2D > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates
        H, W = gt2D.shape
        bbox_shift = np.random.randint(0, 6, 1)[0]
        # The shift is defined for a 256-pixel reference size; scale it per axis.
        bbox_shift_x = int(bbox_shift * W / 256)
        bbox_shift_y = int(bbox_shift * H / 256)
        x_min = max(0, x_min - bbox_shift_x)
        x_max = min(W - 1, x_max + bbox_shift_x)
        y_min = max(0, y_min - bbox_shift_y)
        y_max = min(H - 1, y_max + bbox_shift_y)
        return np.array([x_min, y_min, x_max, y_max])
    except Exception as e:
        raise Exception(f'error {e} with file {file}') from e
64 |
65 |
def mask3D_to_bbox(gt3D, file):
    """
    Summarise a 3D mask as its z-extent plus a 2D box on a middle slice.

    Args:
        gt3D: 3D array indexed (z, y, x); voxels > 0 are foreground.
        file: Source file name, passed on to `mask2D_to_bbox` for error messages.
    Returns:
        dict with, in this order:
            'z_min', 'z_max': first and last slice containing foreground;
            'z_mid': the middle (upper middle for an even count) of the
                distinct slices containing foreground, so it is never empty;
            'z_mid_x_min', 'z_mid_y_min', 'z_mid_x_max', 'z_mid_y_max': the
                randomly enlarged 2D box of slice 'z_mid' from `mask2D_to_bbox`.
    Raises:
        ValueError: if `gt3D` has no foreground voxel.
    """
    z_indices, _, _ = np.where(gt3D > 0)
    # Indices returned by np.where always lie in [0, D-1], so no clipping
    # (or range assertion) is needed.
    z_min, z_max = np.min(z_indices), np.max(z_indices)
    # Picking among the distinct foreground slices guarantees a non-empty
    # middle slice even when the structure has gaps along z.
    z_slices = np.unique(z_indices)
    z_middle = z_slices[len(z_slices) // 2]

    x_min, y_min, x_max, y_max = mask2D_to_bbox(gt3D[z_middle], file)
    return {
        'z_min': z_min,
        'z_max': z_max,
        'z_mid': z_middle,
        'z_mid_x_min': x_min,
        'z_mid_y_min': y_min,
        'z_mid_x_max': x_max,
        'z_mid_y_max': y_max,
    }
91 |
# Input root (placeholder) and output root for the processed cases.
path = 'path-to-npz-files'
path_dest = 'destination-path'
# makedirs on the nested sanity folder also creates path_dest itself.
sanity_dir = os.path.join(path_dest, 'sanity')
os.makedirs(sanity_dir, exist_ok=True)

# Cases live two directory levels below the root; Microscopy cases are excluded.
files = sorted(
    candidate
    for candidate in glob.glob(os.path.join(path, '*/*/*.npz'))
    if 'Microscopy' not in candidate
)

print(f'number of files {len(files)}')
102 |
def process(file):
    """
    Attach bounding-box prompts to one .npz case and write sanity-check images.

    Steps:
    - Read `imgs`, `gts` and `spacing` from `file`.
    - Relabel `gts` sequentially; background stays 0.
    - Build one box dict per foreground label with `mask3D_to_bbox`.
    - Write one PNG per label (box outline plus mask overlay on the label's
      middle slice) into the module-level `sanity_dir`.
    - Save a compressed .npz with keys imgs, gts, boxes and spacing into the
      module-level `path_dest`.

    Raises:
        ValueError: if `gts` contains no foreground label.
    """
    print(f'processing file {file}')

    # Use the NpzFile as a context manager so its file handle is closed;
    # indexing it reads each array fully into memory.
    with np.load(file, allow_pickle=True) as npz:
        imgs = npz['imgs']
        gts = npz['gts']
        spacing = npz['spacing']

    gts, _, _ = segmentation.relabel_sequential(gts)
    unique_labs = np.unique(gts)
    # Drop 0 explicitly rather than slicing off the first value, which would
    # discard a real label in a volume without background voxels.
    unique_labs = unique_labs[unique_labs != 0]
    if unique_labs.size == 0:
        raise ValueError(f'no foreground labels in {file}')

    # One box dict per label (the random perturbation happens in mask2D_to_bbox).
    boxes_list = [mask3D_to_bbox(gts == lab, file) for lab in unique_labs]

    for j, (lab, box_dict) in enumerate(zip(unique_labs, boxes_list)):
        color = np.random.randint(0, 255, 3)
        z_mid = box_dict['z_mid']
        # Single-channel slice -> 3 channels so a coloured box/mask can be drawn.
        img_mid = np.repeat(imgs[z_mid][..., np.newaxis], 3, axis=-1)
        box2D = [box_dict['z_mid_x_min'], box_dict['z_mid_y_min'], box_dict['z_mid_x_max'], box_dict['z_mid_y_max']]
        img_mid = show_box_cv2(img_mid, box2D, color=color, thickness=2)
        img_mid = show_mask_cv2((gts[z_mid] == lab).astype(np.uint8), img_mid.astype(np.uint8), color=color, alpha=0.5)
        png_name = os.path.basename(file).replace('.npz', f'_boxIdx{j}.png')
        cv2.imwrite(os.path.join(sanity_dir, png_name), img_mid)

    np.savez_compressed(os.path.join(path_dest, os.path.basename(file)), imgs=imgs, gts=gts, boxes=boxes_list, spacing=spacing)
131 |
if __name__ == '__main__':
    # Spread the per-file work over a fixed pool of worker processes.
    worker_count = 16
    with mp.Pool(processes=worker_count) as pool:
        pool.map(process, files)
135 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/CVPR25_text_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | The code was adapted from the MICCAI FLARE Challenge
3 | https://github.com/JunMa11/FLARE
4 |
5 | The testing images will be evaluated one by one.
6 |
7 | Folder structure:
8 | CVPR25_text_eval.py
9 | - team_docker
10 | - teamname.tar.gz # submitted docker containers from participants
11 | - test_demo
12 | - imgs
13 | - case1.npz # testing image
14 | - case2.npz
15 | - ...
16 | - demo_seg # segmentation results *******segmentation key: ['segs']*******
17 | - case1.npz # segmentation file name is the same as the testing image name
18 | - case2.npz
19 | - ...
20 | """
21 |
22 | import os
23 | join = os.path.join
24 | import shutil
25 | import time
26 | import torch
27 | import argparse
28 | from collections import OrderedDict
29 | import pandas as pd
30 | import numpy as np
31 | from skimage import segmentation
32 | from scipy.optimize import linear_sum_assignment
33 | import cc3d
34 | import SimpleITK as sitk
35 |
36 | from SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance, compute_dice_coefficient
37 |
def compute_multi_class_dsc(gt, seg, label_ids):
    """
    Mean Dice similarity coefficient over the requested labels present in `gt`.

    Args:
        gt: Ground-truth label map; 0 is background.
        seg: Predicted label map with the same shape as `gt`.
        label_ids: Labels to evaluate; only those that also occur in `gt` are scored.
    Returns:
        NaN-ignoring mean of the per-label Dice scores (NaN, with a NumPy
        RuntimeWarning, when no requested label occurs in `gt`).
    """
    # Exclude background explicitly (np.unique(gt)[1:] silently drops a real
    # label when gt has no 0) and sort for a deterministic summation order.
    present_labels = sorted((set(np.unique(gt)) - {0}) & set(label_ids))
    dsc = [compute_dice_coefficient(gt == i, seg == i) for i in present_labels]
    return np.nanmean(dsc)
46 |
def compute_multi_class_nsd(gt, seg, spacing, label_ids, tolerance=2.0):
    """
    Mean normalized surface Dice (NSD) over the requested labels present in `gt`.

    Args:
        gt: Ground-truth label map; 0 is background.
        seg: Predicted label map with the same shape as `gt`.
        spacing: Voxel spacing, passed to `compute_surface_distances` as `spacing_mm`.
        label_ids: Labels to evaluate; only those that also occur in `gt` are scored.
        tolerance: Surface tolerance passed to `compute_surface_dice_at_tolerance`.
    Returns:
        NaN-ignoring mean of the per-label NSD scores (NaN, with a NumPy
        RuntimeWarning, when no requested label occurs in `gt`).
    """
    # Exclude background explicitly (np.unique(gt)[1:] silently drops a real
    # label when gt has no 0) and sort for a deterministic summation order.
    present_labels = sorted((set(np.unique(gt)) - {0}) & set(label_ids))
    nsd = []
    for i in present_labels:
        surface_distance = compute_surface_distances(gt == i, seg == i, spacing_mm=spacing)
        nsd.append(compute_surface_dice_at_tolerance(surface_distance, tolerance))
    return np.nanmean(nsd)
56 |
57 | def _label_overlap(x, y):
58 | """ fast function to get pixel overlaps between masks in x and y
59 |
60 | Parameters
61 | ------------
62 |
63 | x: ND-array, int
64 | where 0=NO masks; 1,2... are mask labels
65 | y: ND-array, int
66 | where 0=NO masks; 1,2... are mask labels
67 |
68 | Returns
69 | ------------
70 |
71 | overlap: ND-array, int
72 | matrix of pixel overlaps of size [x.max()+1, y.max()+1]
73 |
74 | """
75 | x = x.ravel()
76 | y = y.ravel()
77 |
78 | # preallocate a 'contact map' matrix
79 | overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
80 |
81 | # loop over the labels in x and add to the corresponding
82 | # overlap entry. If label A in x and label B in y share P
83 | # pixels, then the resulting overlap is P
84 | # len(x)=len(y), the number of pixels in the whole image
85 | for i in range(len(x)):
86 | overlap[x[i],y[i]] += 1
87 | return overlap
88 |
def _intersection_over_union(masks_true, masks_pred):
    """ intersection over union of all mask pairs

    Parameters
    ------------

    masks_true: ND-array, int
        ground truth masks, where 0=NO masks; 1,2... are mask labels
    masks_pred: ND-array, int
        predicted masks, where 0=NO masks; 1,2... are mask labels

    Returns
    ------------
    iou: ND-array, float
        matrix of IOU pairs of size [masks_true.max()+1, masks_pred.max()+1].
        iou[i, j] is the IoU between ground-truth label i and predicted label j;
        row/column 0 is the background. Callers drop it with [1:, 1:], after
        which index k refers to label k+1. Pairs with an empty union
        (labels absent from both masks) are 0.
    """
    overlap = _label_overlap(masks_true, masks_pred)
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    # union >= overlap element-wise, so this unsigned subtraction cannot wrap.
    union = n_pixels_pred + n_pixels_true - overlap
    # Labels skipped in the numbering give 0/0; silence the warning and map
    # the resulting NaNs to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = overlap / union
    iou[np.isnan(iou)] = 0.0
    return iou
112 |
113 | def _true_positive(iou, th):
114 | """ true positive at threshold th
115 |
116 | Parameters
117 | ------------
118 |
119 | iou: float, ND-array
120 | array of IOU pairs
121 | th: float
122 | threshold on IOU for positive label
123 |
124 | Returns
125 | ------------
126 |
127 | tp: float
128 | number of true positives at threshold
129 | """
130 | n_min = min(iou.shape[0], iou.shape[1])
131 | costs = -(iou >= th).astype(float) - iou / (2*n_min)
132 | true_ind, pred_ind = linear_sum_assignment(costs)
133 | match_ok = iou[true_ind, pred_ind] >= th
134 | tp = match_ok.sum()
135 | matched_pairs = [(t, p) for t, p, ok in zip(true_ind, pred_ind, match_ok) if ok]
136 | return tp, matched_pairs
137 |
def eval_tp_fp_fn(masks_true, masks_pred, threshold=0.5):
    """
    Count instance-level true positives, false positives and false negatives.

    Ground-truth and predicted instances are matched one-to-one by IoU (see
    `_true_positive`). A match counts when its IoU is >= `threshold`. Instance
    counts are taken as the maximum label value.
    NOTE(review): this assumes sequential labels 1..N — confirm callers
    relabel their inputs.

    Returns:
        (tp, fp, fn, matched_pairs). matched_pairs holds 0-based
        (gt_index, pred_index) pairs, i.e. label - 1, and is None when the
        prediction contains no instance.
    """
    num_inst_gt = np.max(masks_true)
    num_inst_seg = np.max(masks_pred)
    if num_inst_seg > 0:
        # Drop the background row/column so index k refers to label k+1.
        iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
        tp, matched_pairs = _true_positive(iou, threshold)
        fp = num_inst_seg - tp
        fn = num_inst_gt - tp
    else:
        # Empty prediction: nothing is matched or wrongly predicted, but every
        # ground-truth instance is missed.
        tp = 0
        fp = 0
        fn = num_inst_gt
        matched_pairs = None

    return tp, fp, fn, matched_pairs
154 |
# Command-line options. The first positional argument of ArgumentParser is
# `prog` (the program name shown in usage), so the human-readable text belongs
# in `description`.
parser = argparse.ArgumentParser(description='Segmentation evaluation for docker containers', add_help=False)
parser.add_argument('-i', '--test_img_path', default='./3D_val_npz', type=str, help='testing data path')
parser.add_argument('-val_gts','--validation_gts_path', default='./3D_val_gt_text_seg', type=str, help='path to validation set (or final test set) GT files')
parser.add_argument('-o','--save_path', default='./outputs', type=str, help='segmentation output path')
parser.add_argument('-d','--docker_folder_path', default='./team_dockers', type=str, help='team docker path')
args = parser.parse_args()

test_img_path = args.test_img_path
validation_gts_path = args.validation_gts_path
save_path = args.save_path
docker_path = args.docker_folder_path

# Scratch folders mounted into each container; the loop below deletes and
# recreates them for every docker.
input_temp = './inputs/'
output_temp = './outputs'
os.makedirs(save_path, exist_ok=True)

# The default save_path ('./outputs') is the same folder as output_temp, which
# is deleted before each docker run, so results of earlier dockers can be
# wiped. Warn instead of changing the CLI default.
_abs_save = os.path.abspath(save_path) + os.sep
_abs_out = os.path.abspath(output_temp) + os.sep
if _abs_save.startswith(_abs_out):
    print(f'Warning: --save_path ({save_path}) is inside {output_temp}, which is deleted before each docker run; '
          'results of earlier dockers may be lost. Consider using a different --save_path.')

dockers = sorted(os.listdir(docker_path))
test_cases = sorted(os.listdir(test_img_path))
173 |
for docker in dockers:
    try:
        # Start every team from empty temporary folders; they are bind-mounted
        # into the container as /workspace/inputs and /workspace/outputs.
        if os.path.exists(input_temp):
            shutil.rmtree(input_temp)
        if os.path.exists(output_temp):
            shutil.rmtree(output_temp)
        os.makedirs(input_temp)
        os.makedirs(output_temp)

        # load the docker image; the lower-cased file name up to the first '.'
        # is used both as the image tag and as the container name
        teamname = docker.split('.')[0].lower()
        print('teamname docker: ', docker)
        os.system('docker image load -i {}'.format(join(docker_path, docker)))

        # create a new folder to save segmentation results
        team_outpath = join(save_path, teamname)
        if os.path.exists(team_outpath):
            shutil.rmtree(team_outpath)
        os.mkdir(team_outpath)
        os.system('chmod -R 777 ./* ')  # give permission to all files

        # initialize the metric dictionary: one list per column, and every
        # column must receive exactly one value per case (see below)
        metric = OrderedDict()
        metric['CaseName'] = []
        metric['RunningTime'] = []
        metric['DSC'] = []
        metric['NSD'] = []
        metric['F1'] = []
        metric['DSC_TP'] = []

        missing_files = []

        # To obtain the running time for each case, testing cases are inferred one-by-one
        for case in test_cases:
            shutil.copy(join(test_img_path, case), input_temp)
            cmd = 'docker container run --gpus "device=0" -m 32G --name {} --rm -v $PWD/inputs/:/workspace/inputs/ -v $PWD/outputs/:/workspace/outputs/ {}:latest /bin/bash -c "sh predict.sh" '.format(teamname, teamname)
            print(teamname, ' docker command:', cmd, '\n', 'testing image name:', case)

            # run the docker container and measure inference time
            start_time = time.time()
            exit_status = None
            try:
                exit_status = os.system(cmd)
            except Exception as e:
                print('inference error!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                print(case, e)
            real_running_time = time.time() - start_time
            if exit_status not in (None, 0):
                # os.system reports a failing container through its exit
                # status instead of raising, so check it explicitly
                print('inference error!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                print(case, 'docker exit status:', exit_status)
            print(f"{case} finished! Inference time: {real_running_time}")

            # save metrics
            metric['CaseName'].append(case)
            metric['RunningTime'].append(real_running_time)

            # Metric calculation (DSC/NSD for semantic masks, F1/DSC_TP for instance masks)
            seg_name = case
            gt_path = join(validation_gts_path, seg_name)
            seg_path = join(output_temp, seg_name)

            try:
                # Load ground truth and segmentation masks
                gt_npz = np.load(gt_path, allow_pickle=True)['gts']
                # NOTE(security): seg_path is written by the team's container;
                # with allow_pickle=True a crafted .npz could execute code here
                # on load. Plain segmentation arrays should not need pickle --
                # consider allow_pickle=False.
                seg_npz = np.load(seg_path, allow_pickle=True)['segs']

                # spacing and prompt metadata of the input case (file closed after reading)
                with np.load(join(input_temp, case), allow_pickle=True) as img_npz:
                    spacing = img_npz['spacing']
                    text_prompts = img_npz['text_prompts'].item()
                instance_label = text_prompts['instance_label']

                class_ids = sorted([int(k) for k in text_prompts if k != "instance_label"])
                class_ids_array = np.array(class_ids, dtype=np.int32)

                if instance_label == 0:  # semantic masks
                    gt_npz = gt_npz.astype(np.uint8)
                    seg_npz = seg_npz.astype(np.uint8)
                    # note: the semantic labels may not be sequential
                    dsc = compute_multi_class_dsc(gt_npz, seg_npz, class_ids_array)
                    nsd = compute_multi_class_nsd(gt_npz, seg_npz, spacing, class_ids_array)
                    f1_score = np.nan
                    dsc_tp = np.nan
                elif instance_label == 1:  # instance masks
                    # Use a wider integer type than uint8: casting to uint8
                    # would wrap instance ids above 255 and merge instances.
                    gt_npz = gt_npz.astype(np.int32)
                    seg_npz = seg_npz.astype(np.int32)
                    # Calculate F1 instead
                    if len(np.unique(seg_npz)) == 2:
                        print("converting segmentation to instance masks")
                        # convert prediction masks from binary to instance
                        tumor_inst, tumor_n = cc3d.connected_components(seg_npz, connectivity=6, return_N=True)

                        # shift component ids above the current maximum label
                        seg_npz[tumor_inst > 0] = (tumor_inst[tumor_inst > 0] + np.max(seg_npz))

                    gt_npz = segmentation.relabel_sequential(gt_npz)[0]
                    seg_npz = segmentation.relabel_sequential(seg_npz)[0]

                    tp, fp, fn, matched_pairs = eval_tp_fp_fn(gt_npz, seg_npz)  # default f1 overlap threshold is 0.5
                    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
                    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
                    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

                    # compute DSC for TP cases (matched indices are 0-based,
                    # instance labels are 1-based after relabelling)
                    if matched_pairs:
                        dsc_list = []
                        for gt_idx, pred_idx in matched_pairs:
                            gt_mask = gt_npz == (gt_idx + 1)
                            pred_mask = seg_npz == (pred_idx + 1)
                            dsc_list.append(compute_dice_coefficient(gt_mask, pred_mask))
                        dsc_tp = np.mean(dsc_list)
                    else:
                        dsc_tp = 0

                    # DSC and NSD are not reported for instance masks
                    dsc = None
                    nsd = None
                else:
                    # Without this, the metric variables of the previous case
                    # (module-level names) would be silently reused.
                    raise ValueError(f"unsupported instance_label: {instance_label!r}")

                case_scores = [round(value, 4) if value is not None else np.nan
                               for value in (dsc, nsd, f1_score, dsc_tp)]

                print(f"{case}: DSC={dsc if dsc is not None else np.nan}, NSD={nsd if nsd is not None else np.nan}, F1={f1_score}, DSC_TP={dsc_tp if dsc_tp is not None else np.nan}")

            except Exception as e:
                print(f"ERROR processing {case}: {e}")
                # every metric column still needs a value for this case,
                # otherwise pd.DataFrame(metric) fails on unequal lengths
                case_scores = [np.nan] * 4
                missing_files.append(f"{case}: {e}")

            for column, score in zip(('DSC', 'NSD', 'F1', 'DSC_TP'), case_scores):
                metric[column].append(score)

            # the segmentation file name should be the same as the testing image name;
            # move it from the temporary output folder into the team's result folder
            try:
                shutil.move(join(output_temp, seg_name), join(team_outpath, seg_name))
            except OSError:
                print(f"{join(output_temp, seg_name)}, {join(team_outpath, seg_name)}")
                print("Wrong segmentation name!!! It should be the same as image_name")

            # remove the input so the next container run only sees the next case
            os.remove(join(input_temp, case))

        # save the metrics to a CSV file
        metric_df = pd.DataFrame(metric)
        metric_df.to_csv(join(team_outpath, teamname + '_metrics.csv'), index=False)
        print(f"Metrics saved to {join(team_outpath, teamname + '_metrics.csv')}")

        # Save missing files log
        if missing_files:
            missing_file_path = os.path.join(team_outpath, f"{teamname}_error_files.txt")
            with open(missing_file_path, 'w') as f:
                f.write("\n".join(missing_files))
            print(f"Error files logged to {missing_file_path}")

        # clean up
        torch.cuda.empty_cache()
        os.system("docker rmi {}:latest".format(teamname))
        shutil.rmtree(input_temp)
        shutil.rmtree(output_temp)

    except Exception as e:
        print(e)
331 |
--------------------------------------------------------------------------------
/SurfaceDice.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Fri Apr 15 13:01:08 2022
4 |
5 | @author: 12593
6 | """
7 |
8 | import numpy as np
9 | import scipy.ndimage
10 |
# neighbour_code_to_normals is a lookup table.
# For every binary neighbour code
# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)
# it contains the surface normals of the triangles (called "surfel" for
# "surface element" in the following). The length of the normal
# vector encodes the surfel area.
#
# The list index is the neighbour code (0..255) that compute_surface_distances
# obtains by correlating a mask with a 2x2x2 kernel of powers of two. Codes 0
# (all eight voxels outside) and 255 (all inside) hold a single zero normal,
# i.e. no surface. The normals assume unit voxel spacing:
# compute_surface_distances scales each component by the voxel spacing before
# taking the vector length, which turns them into areas in mm^2.
#
# created by compute_surface_area_lookup_table.ipynb using the
# marching_cube algorithm, see e.g. https://en.wikipedia.org/wiki/Marching_cubes
# credit to: http://medicaldecathlon.com/files/Surface_distance_based_measures.ipynb
neighbour_code_to_normals = [
[[0,0,0]],
[[0.125,0.125,0.125]],
[[-0.125,-0.125,0.125]],
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
[[0.125,-0.125,0.125]],
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
[[-0.125,0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,0.125,0.125]],
[[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
[[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]],
[[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
[[0.5,0.0,0.0],[0.5,0.0,0.0]],
[[0.125,-0.125,-0.125]],
[[0.0,-0.25,-0.25],[0.0,0.25,0.25]],
[[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
[[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]],
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]],
[[-0.125,0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]],
[[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]],
[[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]],
[[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
[[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]],
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]],
[[0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
[[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]],
[[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
[[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]],
[[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]],
[[-0.125,0.125,0.125],[0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]],
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
[[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]],
[[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]],
[[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]],
[[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
[[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
[[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
[[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]],
[[0.0,0.5,0.0],[0.0,-0.5,0.0]],
[[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]],
[[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
[[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
[[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
[[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
[[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
[[-0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
[[0.0,-0.25,0.25],[0.0,-0.25,0.25]],
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]],
[[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]],
[[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
[[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]],
[[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]],
[[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
[[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
[[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
[[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]],
[[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
[[-0.0,0.0,0.5],[0.0,0.0,0.5]],
[[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
[[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
[[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]],
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[-0.25,0.0,0.25],[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
[[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
[[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]],
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
[[0.25,0.0,0.25],[0.25,0.0,0.25]],
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]],
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
[[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]],
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,-0.125]],
[[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]],
[[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
[[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
[[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
[[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]],
[[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]],
[[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]],
[[0.0,0.25,0.25],[0.0,0.25,0.25]],
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]],
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[0.125,0.125,0.125]],
[[0.125,0.125,0.125]],
[[0.125,0.125,0.125]],
[[0.125,0.125,0.125],[0.125,0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]],
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]],
[[0.0,0.25,0.25],[0.0,0.25,0.25]],
[[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]],
[[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]],
[[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]],
[[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
[[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
[[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
[[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]],
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.0,0.25,0.25],[0.0,0.25,0.25]],
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]],
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
[[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.25,0.0,0.25],[0.25,0.0,0.25]],
[[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
[[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]],
[[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.25,0.0,0.25],[0.25,0.0,0.25]],
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]],
[[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
[[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
[[-0.0,0.0,0.5],[0.0,0.0,0.5]],
[[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
[[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]],
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
[[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
[[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
[[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
[[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]],
[[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]],
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
[[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
[[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]],
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]],
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
[[0.0,-0.25,0.25],[0.0,-0.25,0.25]],
[[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
[[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.125,-0.125,0.125]],
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
[[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]],
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
[[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
[[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
[[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
[[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
[[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]],
[[0.0,0.5,0.0],[0.0,-0.5,0.0]],
[[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]],
[[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
[[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
[[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
[[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]],
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]],
[[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
[[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]],
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]],
[[-0.125,0.125,0.125],[0.125,-0.125,0.125]],
[[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]],
[[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]],
[[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
[[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]],
[[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
[[0.125,-0.125,0.125]],
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]],
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]],
[[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
[[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
[[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]],
[[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]],
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]],
[[-0.125,0.125,0.125],[0.125,-0.125,-0.125]],
[[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]],
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]],
[[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
[[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
[[0.0,-0.25,-0.25],[0.0,0.25,0.25]],
[[0.125,-0.125,-0.125]],
[[0.5,0.0,0.0],[0.5,0.0,0.0]],
[[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]],
[[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
[[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
[[0.125,0.125,0.125],[-0.125,0.125,0.125]],
[[-0.125,0.125,0.125]],
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
[[0.125,-0.125,0.125]],
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
[[-0.125,-0.125,0.125]],
[[0.125,0.125,0.125]],
[[0,0,0]]]
278 |
279 |
def compute_surface_distances(mask_gt, mask_pred, spacing_mm):
    """Compute closest distances from all surface points to the other surface.

    Finds all surface elements "surfels" in the ground truth mask `mask_gt` and
    the predicted mask `mask_pred`, computes their area in mm^2 and the distance
    to the closest point on the other surface. It returns two sorted lists of
    distances together with the corresponding surfel areas. If one of the masks
    is empty, the corresponding lists are empty and all distances in the other
    list are `inf`. If both masks are empty, all four arrays are empty.

    Args:
      mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
      mask_pred: 3-dim Numpy array of type bool. The predicted mask.
      spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2
          direction

    Returns:
      A dict with
      "distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm
          from all ground truth surface elements to the predicted surface,
          sorted from smallest to largest
      "distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm
          from all predicted surface elements to the ground truth surface,
          sorted from smallest to largest
      "surfel_areas_gt": 1-dim numpy array of type float. The area in mm^2 of
          the ground truth surface elements in the same order as
          distances_gt_to_pred
      "surfel_areas_pred": 1-dim numpy array of type float. The area in mm^2 of
          the predicted surface elements in the same order as
          distances_pred_to_gt
    """
    # compute the bounding box of the masks to trim
    # the volume to the smallest possible processing subvolume
    mask_all = mask_gt | mask_pred
    bbox_min = np.zeros(3, np.int64)
    bbox_max = np.zeros(3, np.int64)

    # max projection to the x0-axis
    proj_0 = np.max(np.max(mask_all, axis=2), axis=1)
    idx_nonzero_0 = np.nonzero(proj_0)[0]
    if len(idx_nonzero_0) == 0:
        # both masks are empty: there are no surface elements at all
        return {"distances_gt_to_pred": np.array([]),
                "distances_pred_to_gt": np.array([]),
                "surfel_areas_gt": np.array([]),
                "surfel_areas_pred": np.array([])}

    bbox_min[0] = np.min(idx_nonzero_0)
    bbox_max[0] = np.max(idx_nonzero_0)

    # max projection to the x1-axis
    proj_1 = np.max(np.max(mask_all, axis=2), axis=0)
    idx_nonzero_1 = np.nonzero(proj_1)[0]
    bbox_min[1] = np.min(idx_nonzero_1)
    bbox_max[1] = np.max(idx_nonzero_1)

    # max projection to the x2-axis
    proj_2 = np.max(np.max(mask_all, axis=1), axis=0)
    idx_nonzero_2 = np.nonzero(proj_2)[0]
    bbox_min[2] = np.min(idx_nonzero_2)
    bbox_max[2] = np.max(idx_nonzero_2)

    # compute the area for all 256 possible surface elements
    # (given a 2x2x2 neighbourhood) according to the spacing_mm.
    # Done after the empty-mask early return above, which does not need it.
    neighbour_code_to_surface_area = np.zeros([256])
    for code in range(256):
        normals = np.array(neighbour_code_to_normals[code])
        sum_area = 0
        for normal_idx in range(normals.shape[0]):
            # scale each normal component by the area of the voxel face
            # perpendicular to that axis, then take the vector length
            n = np.zeros([3])
            n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2]
            n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2]
            n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1]
            sum_area += np.linalg.norm(n)
        neighbour_code_to_surface_area[code] = sum_area

    # crop the processing subvolume.
    # we need to zeropad the cropped region with 1 voxel at the lower,
    # the right and the back side. This is required to obtain the "full"
    # convolution result with the 2x2x2 kernel
    cropmask_gt = np.zeros((bbox_max - bbox_min) + 2, np.uint8)
    cropmask_pred = np.zeros((bbox_max - bbox_min) + 2, np.uint8)

    cropmask_gt[0:-1, 0:-1, 0:-1] = mask_gt[bbox_min[0]:bbox_max[0] + 1,
                                            bbox_min[1]:bbox_max[1] + 1,
                                            bbox_min[2]:bbox_max[2] + 1]

    cropmask_pred[0:-1, 0:-1, 0:-1] = mask_pred[bbox_min[0]:bbox_max[0] + 1,
                                                bbox_min[1]:bbox_max[1] + 1,
                                                bbox_min[2]:bbox_max[2] + 1]

    # compute the neighbour code (local binary pattern) for each voxel
    # the resulting arrays are spatially shifted by minus half a voxel in each axis.
    # i.e. the points are located at the corners of the original voxels
    kernel = np.array([[[128, 64],
                        [32, 16]],
                       [[8, 4],
                        [2, 1]]])
    # scipy.ndimage.filters / scipy.ndimage.morphology are deprecated aliases
    # (since SciPy 1.8); the same functions live directly in scipy.ndimage.
    # The crop masks are already uint8, so the codes (0..255) stay uint8.
    neighbour_code_map_gt = scipy.ndimage.correlate(cropmask_gt, kernel, mode="constant", cval=0)
    neighbour_code_map_pred = scipy.ndimage.correlate(cropmask_pred, kernel, mode="constant", cval=0)

    # create masks with the surface voxels
    borders_gt = (neighbour_code_map_gt != 0) & (neighbour_code_map_gt != 255)
    borders_pred = (neighbour_code_map_pred != 0) & (neighbour_code_map_pred != 255)

    # compute the distance transform (closest distance of each voxel to the surface voxels)
    if borders_gt.any():
        distmap_gt = scipy.ndimage.distance_transform_edt(~borders_gt, sampling=spacing_mm)
    else:
        distmap_gt = np.full(borders_gt.shape, np.inf)

    if borders_pred.any():
        distmap_pred = scipy.ndimage.distance_transform_edt(~borders_pred, sampling=spacing_mm)
    else:
        distmap_pred = np.full(borders_pred.shape, np.inf)

    # compute the area of each surface element
    surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]
    surface_area_map_pred = neighbour_code_to_surface_area[neighbour_code_map_pred]

    # create a list of all surface elements with distance and area
    distances_gt_to_pred = distmap_pred[borders_gt]
    distances_pred_to_gt = distmap_gt[borders_pred]
    surfel_areas_gt = surface_area_map_gt[borders_gt]
    surfel_areas_pred = surface_area_map_pred[borders_pred]

    # sort them by distance, ties by area: the same order as sorting the
    # (distance, area) tuples in Python, but vectorized. np.lexsort uses the
    # LAST key as the primary sort key. Empty inputs give empty index arrays.
    order_gt = np.lexsort((surfel_areas_gt, distances_gt_to_pred))
    distances_gt_to_pred = distances_gt_to_pred[order_gt]
    surfel_areas_gt = surfel_areas_gt[order_gt]

    order_pred = np.lexsort((surfel_areas_pred, distances_pred_to_gt))
    distances_pred_to_gt = distances_pred_to_gt[order_pred]
    surfel_areas_pred = surfel_areas_pred[order_pred]

    return {"distances_gt_to_pred": distances_gt_to_pred,
            "distances_pred_to_gt": distances_pred_to_gt,
            "surfel_areas_gt": surfel_areas_gt,
            "surfel_areas_pred": surfel_areas_pred}
428 |
429 |
def compute_average_surface_distance(surface_distances):
    """Return the area-weighted mean surface distance in both directions.

    Args:
      surface_distances: dict as returned by compute_surface_distances.

    Returns:
      A tuple (gt_to_pred, pred_to_gt). Each entry is the sum of
      distance * surfel area divided by the total surfel area of that
      surface. A direction with no surfels divides 0 by 0 and gives nan.
    """
    def weighted_mean(dist_key, area_key):
        dists = surface_distances[dist_key]
        areas = surface_distances[area_key]
        return np.sum(dists * areas) / np.sum(areas)

    return (weighted_mean("distances_gt_to_pred", "surfel_areas_gt"),
            weighted_mean("distances_pred_to_gt", "surfel_areas_pred"))
438 |
def compute_robust_hausdorff(surface_distances, percent):
    """Return the symmetric ``percent``-th percentile surface distance.

    For each direction, the distances are expected in ascending order, as
    returned by compute_surface_distances. The selected distance is the one
    at which the cumulative surfel area first reaches ``percent`` % of that
    surface's total area. The result is the larger of the two directions, so
    with ``percent=100`` it is (up to rounding) the Hausdorff distance.

    Args:
      surface_distances: dict as returned by compute_surface_distances.
      percent: percentile in [0, 100].

    Returns:
      The robust Hausdorff distance; ``inf`` if either surface has no surfels.
    """
    def _percentile_distance(distances, surfel_areas):
        if len(distances) == 0:
            return np.inf
        surfel_areas_cum = np.cumsum(surfel_areas) / np.sum(surfel_areas)
        idx = np.searchsorted(surfel_areas_cum, percent / 100.0)
        # rounding can leave the last cumulative value slightly below 1.0,
        # which would push idx one past the end; clamp it
        return distances[min(idx, len(distances) - 1)]

    perc_distance_gt_to_pred = _percentile_distance(
        surface_distances["distances_gt_to_pred"], surface_distances["surfel_areas_gt"])
    perc_distance_pred_to_gt = _percentile_distance(
        surface_distances["distances_pred_to_gt"], surface_distances["surfel_areas_pred"])

    return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt)
459 |
def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
    """Return the fraction of each surface lying within ``tolerance_mm`` of the other.

    Args:
      surface_distances: dict as returned by ``compute_surface_distances``.
      tolerance_mm: maximum distance (inclusive) for a surface element to count
        as overlapping.

    Returns:
      ``(rel_overlap_gt, rel_overlap_pred)``: for each surface, the area of its
      elements within tolerance divided by its total area.
    """
    keys = ("distances_gt_to_pred", "distances_pred_to_gt",
            "surfel_areas_gt", "surfel_areas_pred")
    dist_gt, dist_pred, area_gt, area_pred = (surface_distances[k] for k in keys)

    directions = ((dist_gt, area_gt), (dist_pred, area_pred))
    return tuple(
        np.sum(areas[dists <= tolerance_mm]) / np.sum(areas)
        for dists, areas in directions
    )
468 |
def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
    """Compute the surface Dice (NSD) at a given tolerance.

    Args:
      surface_distances: dict as returned by ``compute_surface_distances``.
      tolerance_mm: maximum distance (inclusive) for a surface element to count
        as matched.

    Returns:
      (area of gt surface within tolerance + area of predicted surface within
      tolerance) / (total gt surface area + total predicted surface area).
    """
    dist_gt = surface_distances["distances_gt_to_pred"]
    dist_pred = surface_distances["distances_pred_to_gt"]
    area_gt = surface_distances["surfel_areas_gt"]
    area_pred = surface_distances["surfel_areas_pred"]

    matched_area = (np.sum(area_gt[dist_gt <= tolerance_mm])
                    + np.sum(area_pred[dist_pred <= tolerance_mm]))
    total_area = np.sum(area_gt) + np.sum(area_pred)
    return matched_area / total_area
479 |
480 |
def compute_dice_coefficient(mask_gt, mask_pred):
    """Compute the Soerensen-Dice coefficient.

    Compute the Soerensen-Dice coefficient between the ground truth mask
    `mask_gt` and the predicted mask `mask_pred`.

    Args:
      mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
      mask_pred: 3-dim Numpy array of type bool. The predicted mask.

    Returns:
      The Dice coefficient as float. If both masks are empty, the result is NaN.
    """
    volume_sum = mask_gt.sum() + mask_pred.sum()
    if volume_sum == 0:
        # `np.nan` rather than the `np.NaN` alias, which NumPy 2.0 removed.
        return np.nan
    volume_intersect = (mask_gt & mask_pred).sum()
    return 2 * volume_intersect / volume_sum
499 |
500 |
--------------------------------------------------------------------------------
/CVPR25_iter_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | The code was adapted from the CVPR24 Segment Anything in Medical Images on a Laptop Challenge
3 | https://www.codabench.org/competitions/1847/
4 |
5 | pip install connected-components-3d
6 | pip install cupy-cuda12x
7 | pip install cucim-cu12
8 |
9 |
10 | The testing images will be evaluated one by one.
11 |
12 | Folder structure:
13 | CVPR25_iter_eval.py
--docker_folder_path # submitted docker containers from participants
15 | - docker_dir
16 | - teamname_1.tar.gz
17 | - teamname_2.tar.gz
18 | - ...
19 | --test_img_path # test images
20 | - imgs
21 | - case1.npz # test image
22 | - case2.npz
23 | - ...
24 | --save_path # segmentation results
25 | - output
26 | - case1.npz # segmentation file name is the same as the testing image name
27 | - case2.npz
28 | - ...
29 | --validation_gts_path # path to validation / test set GT files
30 | - Contains the npz files with the same name as the images but only 'gts' key is available in each file instead of storing it in the image itself. This is done to prevent label leakage during the challenge.
31 | - validation_gts
32 | - case1.npz # file containing only the 'gts' key
33 | - case2.npz
34 | - ...
35 | --verbose
36 | - Whether to have a more detailed output, e.g. coordinates of generated clicks
37 |
38 |
This script is designed for evaluating docker submissions for the CVPR25: Foundation Models for Interactive 3D Biomedical Image Segmentation Challenge
40 |
41 | ##########################################################
42 | ######### Docker Submission Evaluation Process ###########
43 | ##########################################################
44 | Submissions for the CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation Challenge will be evaluated using an iterative refinement approach.
45 | Each participant's Docker container will be tested on a set of medical images provided as .npz files.
46 | The evaluation process follows these key steps:
47 | - Initial Prediction: Image + Bounding Box Prompt (1 prediction)
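- Each test case begins with a bounding box prompt, specified in the 'boxes' key of the test image. This serves as the starting point for the segmentation. The bounding box is optional: if an image has no 'boxes' key, this step is skipped and its DSC, NSD and running time are recorded as 0.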
49 | - Iterative Click Refinements: Image + Bounding Box + 1-5 Clicks (5 predictions)
50 | - After the initial segmentation, we iteratively simulate 5 refinement clicks to address segmentation errors. These clicks are automatically generated based on the center of the largest error region in the current prediction:
51 | - If the center of the largest error is an undersegmentation, we simulate and place a foreground click.
52 | - If the center of the largest error is an oversegmentation, we simulate and place a background click.
- The clicks are stored in the 'clicks' key (and their order in the 'clicks_order' key) of the input .npz file and are progressively updated during the second step of the evaluation.
54 |
55 | ###############################################################
56 | ######### How are interactions (bbox, clicks) stored? #########
57 | ###############################################################
The interactions are stored in the 'boxes', 'clicks' and 'clicks_order' keys of each input .npz image.
- The bounding box is stored in the 'boxes' key as a list of dictionaries [{'z_min': 27, 'z_max': 396, 'z_mid': 311, 'z_mid_x_min': 175, 'z_mid_y_min': 94, 'z_mid_x_max': 278, 'z_mid_y_max': 233}, ...] containing bbox coordinates for each class.
- The clicks are provided in the 'clicks' key as a list of dictionaries [{'fg': [click_fg_1, click_fg_2,...], 'bg': [click_bg_1, click_bg_2,...]}, ...], one dictionary per foreground class,
where click_fg_i and click_bg_i are 3-element lists of voxel indices given in the same axis order as the image array. The 'clicks_order' key holds, per class, the sequence of click types added at each step ('fg', 'bg', or None when that class had no error and no click was added).
62 |
63 | #######################################
64 | ######### Performance Metrics #########
65 | #######################################
66 | For each image, multi-class segmentation quality is evaluated using:
- Dice Similarity Coefficient (DSC) and Normalized Surface Dice (NSD), calculated iteratively over the 6 steps (bounding box + 5 clicks).
- AUC (Area Under the Curve) for DSC and NSD over the 5 click steps (the bounding-box step is excluded because the bbox prompt is optional), to measure cumulative improvement with more interactions.
69 | - Final DSC and NSD after all interactions.
- Inference time: the total over all executed steps (TotalRunningTime), together with the time of each step (RunningTime_1 ... RunningTime_6).
71 |
72 | ##########################
73 | ######### Output #########
74 | ##########################
Results are saved in .npz format with metrics compiled into a CSV file for each submission. The main metrics are stored in the columns DSC_AUC, NSD_AUC, DSC_Final, NSD_Final and TotalRunningTime; per-step values (DSC_i, NSD_i, RunningTime_i) are stored as well.
76 |
77 |
78 | ################################
79 | ######### Script Steps #########
80 | ################################
81 | This script executes the following steps:
82 | 1. Docker Submission Handling:
83 | - Loads docker containers submitted by participants.
- Executes inference for each test image using the participant's docker container. Images are inferred one by one.
85 |
86 | 2. Iterative Refinement:
87 | - The initial bounding box prediction is refined iteratively by simulating user clicks at the centers of segmentation errors for each class in the image.
88 | - The Euclidean Distance Transform (EDT) is computed for error regions to identify the distance to the boundary of each error component,
89 | ensuring clicks are placed at locations at the center of the largest error for refinement.
90 | - For each image, the docker is run 6 times for inference:
91 | - 1) Bounding Box initial prediction
92 | - 2)-6) Click refinement predictions (each new click is placed in the center of the largest error component)
93 | - If the center of the largest error is part of the background --> a background click is placed
94 | - Otherwise, a foreground click is placed
95 | - Steps 1)-6) are done in parallel for all segmentation classes in 6 interaction steps (6 docker runs)
96 |
97 | 3. GPU vs. CPU Computation:
98 | - If a GPU is available, the script uses `cupy` and `cucim` for accelerated EDT computation.
99 | - For CPU-only environments, `scipy.ndimage.distance_transform_edt` is used as a fallback.
100 |
101 | 4. Metrics Calculation:
102 | - Computes multi-class DSC and NSD for each image.
- For the final metrics, the AUC (Area Under the Curve) of the DSC and NSD is computed over the 5 click iterations (the bounding-box iteration is excluded because the bbox prompt is optional).
- The AUC quantifies the cumulative performance improvement over the successive click refinements, providing a holistic view of the segmentation refinement process.
105 | - The final DSC and NSD after all 6 interactive steps are also computed.
106 | - These metrics reflect the final segmentation quality achieved after all refinements, indicating the model's final performance.
- The last metric is the inference time, recorded as the total over the interactive steps (TotalRunningTime) and per step (RunningTime_1 ... RunningTime_6).
108 |
109 | 5. Output:
110 | - Segmentation results are saved in the specified output directory.
111 | - Final prediction in the 'segs' key
112 | - Intermediate prediction in the 'all_segs' key
113 | - Metrics for each test case are compiled into a CSV file.
114 |
115 | #################################
116 | ############## Misc##############
117 | #################################
- From the second iteration on, the input image also contains the 'prev_pred' key, which stores the prediction from the previous iteration. It is provided only to help submissions that use the previous prediction as an additional input and is not
a mandatory input.
120 | """
121 |
122 | import os
123 | join = os.path.join
124 | import shutil
125 | import time
126 | import torch
127 | import argparse
128 | from collections import OrderedDict
129 | import pandas as pd
130 | import numpy as np
131 | import traceback
132 |
133 | from scipy.ndimage import distance_transform_edt
134 | import cc3d
135 | from SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance, compute_dice_coefficient
136 | from scipy import integrate
137 | from tqdm import tqdm
138 |
139 | # Taken from CVPR24 challenge code with change to np.unique
140 | def compute_multi_class_dsc(gt, seg):
141 | dsc = []
142 | for i in np.sort(pd.unique(gt.ravel()))[1:]: # skip bg
143 | gt_i = gt == i
144 | seg_i = seg == i
145 | dsc.append(compute_dice_coefficient(gt_i, seg_i))
146 | return np.mean(dsc)
147 |
# Taken from CVPR24 challenge code, with np.unique replaced by pd.unique (+ np.sort)
def compute_multi_class_nsd(gt, seg, spacing, tolerance=2.0):
    """Return the mean normalized surface Dice over the foreground labels of ``gt``.

    Args:
        gt: ground-truth label volume.
        seg: predicted label volume with the same shape.
        spacing: voxel spacing, forwarded as ``spacing_mm`` to
            ``compute_surface_distances``.
        tolerance: surface tolerance forwarded to
            ``compute_surface_dice_at_tolerance`` (its ``tolerance_mm``).

    Label 0 is background and is skipped explicitly, not by dropping the
    smallest label.
    """
    labels = np.sort(pd.unique(gt.ravel()))
    nsd = []
    for label in labels[labels != 0]:
        surface_distance = compute_surface_distances(
            gt == label, seg == label, spacing_mm=spacing
        )
        nsd.append(compute_surface_dice_at_tolerance(surface_distance, tolerance))
    return np.mean(nsd)
159 |
def patched_np_load(*args, **kwargs):
    """Eagerly load an ``.npz`` archive and return its arrays as a plain dict.

    All arguments are forwarded to ``np.load``. Every member is read while the
    archive is open, and the file handle is closed before returning.
    """
    with np.load(*args, **kwargs) as archive:
        return {name: archive[name] for name in archive.files}
163 |
def sample_coord(edt):
    """Return the index tuple of a voxel where ``edt`` attains its maximum.

    If several voxels share the maximum, one is chosen uniformly at random with
    a fixed seed (42), so the simulated click is reproducible. A private
    ``RandomState(42)`` draws exactly the same index as
    ``np.random.seed(42); np.random.choice(...)``, but leaves NumPy's global
    random state untouched. The previous implementation reseeded the global
    generator on every call.
    """
    max_coords = np.argwhere(edt == edt.max())
    rng = np.random.RandomState(42)
    chosen_index = max_coords[rng.choice(len(max_coords))]
    return tuple(chosen_index)
176 |
# Compute the EDT with same shape as the image
def compute_edt(error_component):
    """Euclidean distance transform of ``error_component``, restricted to its bounding box.

    The EDT is evaluated only on the bounding box of the non-zero voxels. The
    box is enlarged by a zero margin of 25 % of its size per side (at least one
    voxel), so distances are measured to the component's own boundary rather
    than to the edge of the crop. The result is written back into an array
    shaped like ``error_component``; voxels outside the bounding box are 0.

    If the padded crop exceeds 60M voxels, it is max-pooled 2x per axis before
    the EDT and the result is repeated back. Distances are then in units of the
    downsampled grid. The GPU path (cupy/cucim) is used when CUDA is available
    and no downsampling happened; otherwise SciPy's EDT is used.

    Args:
        error_component: 3-D mask (the caller passes a boolean component). It
            must contain at least one non-zero voxel; an empty mask makes the
            ``min`` below raise ValueError.

    Returns:
        Array of EDT values with the shape of ``error_component``.
    """
    nonzero = np.argwhere(error_component)
    lower = nonzero.min(axis=0)
    upper = nonzero.max(axis=0) + 1
    box_shape = upper - lower

    # Zero margin on every side: 25% of the box size, at least one voxel.
    margin = np.maximum((box_shape * 0.25).astype(int), 1)

    # Slices over the first three axes: the bounding box inside the full volume,
    # and the same box inside the padded crop.
    box = tuple(slice(lower[ax], upper[ax]) for ax in range(3))
    inner = tuple(slice(margin[ax], margin[ax] + box_shape[ax]) for ax in range(3))

    padded = np.zeros(box_shape + 2 * margin, dtype=np.uint8)
    padded[inner] = error_component[box]

    downsampled = padded.shape[0] * padded.shape[1] * padded.shape[2] > 60000000
    if downsampled:
        from skimage.measure import block_reduce
        print(f'ROI too large {padded.shape} --> 2x downsampling for EDT')
        padded = block_reduce(padded, block_size=(2, 2, 2), func=np.max)

    if torch.cuda.is_available() and not downsampled:  # GPU available
        import cupy as cp
        from cucim.core.operations import morphology
        edt_gpu = morphology.distance_transform_edt(cp.array(padded), return_distances=True)
        edt = cp.asnumpy(edt_gpu)
    else:  # CPU only, or the ROI was downsampled
        edt = distance_transform_edt(padded)

    if downsampled:  # nearest-neighbour upsampling back to the padded grid
        edt = edt.repeat(2, axis=0).repeat(2, axis=1).repeat(2, axis=2)

    # Drop the margin and splat the box back into a full-size array.
    full = np.zeros_like(error_component, dtype=edt.dtype)
    full[box] = edt[inner]
    return full
245 |
# ---------------------------------------------------------------------------
# Command-line interface and run setup
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    'Segmentation iterative refinement with clicks eavluation for docker containers',
    add_help=False,
)
parser.add_argument('-i', '--test_img_path', type=str, default='3D_val_npz',
                    help='testing data path')
parser.add_argument('-o', '--save_path', type=str, default='./seg',
                    help='segmentation output path')
parser.add_argument('-d', '--docker_folder_path', type=str, default='./team_docker',
                    help='team docker path')
parser.add_argument('-val_gts', '--validation_gts_path', type=str,
                    default='3D_val_gt_interactive_seg',
                    help='path to validation set (or final test set) GT files')
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help="Verbose output, e.g., print coordinates of generated clicks")

args = parser.parse_args()

test_img_path, save_path, docker_path, validation_gts_path, verbose = (
    args.test_img_path,
    args.save_path,
    args.docker_folder_path,
    args.validation_gts_path,
    args.verbose,
)

# Without a separate GT folder, ground truth is read from the 'gts' key of each input.
if not os.path.exists(validation_gts_path):
    validation_gts_path = None
    print('[WARNING] Validation path does not exist for your GT data! Make sure you supplied the correct path or your .npz inputs have a gts key!')

# Scratch folders; the docker run command mounts $PWD/inputs/ and $PWD/outputs/.
input_temp, output_temp = './inputs/', './outputs'
os.makedirs(save_path, exist_ok=True)

dockers, test_cases = sorted(os.listdir(docker_path)), sorted(os.listdir(test_img_path))
271 |
# ---------------------------------------------------------------------------
# Evaluation loop: for every team docker, run bbox + click-refinement inference
# on every test case and collect per-case metrics into a CSV.
# ---------------------------------------------------------------------------


def _simulate_click(segs_cls, gts_cls, verbose=False):
    """Simulate one corrective click for a single class.

    The click is placed at the voxel of maximal EDT value (the "center") of the
    largest 26-connected region where ``segs_cls`` and ``gts_cls`` disagree.

    Args:
        segs_cls: binary (uint8) prediction mask for the class.
        gts_cls: binary (uint8) ground-truth mask for the class.
        verbose: print a note when the EDT is empty and uniform sampling is used.

    Returns:
        ``('bg', center)`` for an over-segmentation, ``('fg', center)`` for an
        under-segmentation, or ``None`` if the class prediction has no error.
    """
    error_mask = (segs_cls != gts_cls).astype(np.uint8)
    if np.sum(error_mask) == 0:
        return None

    errors = cc3d.connected_components(error_mask, connectivity=26)  # 26 for 3D connectivity
    component_sizes = np.bincount(errors.flat)
    component_sizes[0] = 0  # label 0 marks voxels without error
    largest_component = errors == np.argmax(component_sizes)

    edt = compute_edt(largest_component)
    edt *= largest_component  # make sure correct voxels have a distance of 0
    if np.sum(edt) == 0:
        # Extremely small errors (e.g. resizing artifacts) can give an all-zero
        # EDT; then sample uniformly from the component instead.
        if verbose:
            print("Error is extremely small --> Sampling uniformly instead of using EDT")
        edt = largest_component

    center = sample_coord(edt)
    if gts_cls[center] == 0:  # oversegmentation -> background click
        assert segs_cls[center] == 1
        kind = 'bg'
    else:  # undersegmentation -> foreground click
        assert segs_cls[center] == 0
        kind = 'fg'
    assert largest_component[center]  # click within error
    return kind, center


def _docker_run_command(teamname):
    """Build the shell command that runs the team's ``predict.sh`` on ./inputs -> ./outputs.

    NOTE(review): ``teamname`` comes from the archive file name and is
    interpolated unquoted into a shell command; only evaluate trusted names.
    """
    if torch.cuda.is_available():  # GPU available
        return 'docker container run --gpus "device=0" -m 32G --name {} --rm -v $PWD/inputs/:/workspace/inputs/ -v $PWD/outputs/:/workspace/outputs/ {}:latest /bin/bash -c "sh predict.sh" '.format(teamname.replace('/', '_'), teamname.split('_')[0])
    return 'docker container run -m 32G --name {} --rm -v $PWD/inputs/:/workspace/inputs/ -v $PWD/outputs/:/workspace/outputs/ {}:latest /bin/bash -c "sh predict.sh" '.format(teamname.replace('/', '_'), teamname.split('_')[0])


for docker in dockers:
    case = None  # keeps the error message below from using an undefined / stale case
    try:
        # create temp folders for inference one-by-one
        if os.path.exists(input_temp):
            shutil.rmtree(input_temp)
        if os.path.exists(output_temp):
            shutil.rmtree(output_temp)
        os.makedirs(input_temp)
        os.makedirs(output_temp)

        # load docker and create a new folder to save segmentation results
        teamname = docker.split('.')[0].lower()
        print('teamname docker: ', docker)
        os.system('docker image load -i {}'.format(join(docker_path, docker)))
        team_outpath = join(save_path, teamname)
        if os.path.exists(team_outpath):
            shutil.rmtree(team_outpath)
        os.makedirs(team_outpath)
        os.system('chmod -R 777 ./* >/dev/null 2>&1')  # ignore output warnings/errors of this command
        cmd = _docker_run_command(teamname)

        # Evaluation metrics (CSV columns, in this order)
        n_clicks = 5
        n_steps = n_clicks + 1  # bbox step + click steps
        step_ids = range(1, n_steps + 1)
        metric = OrderedDict(
            (column, [])
            for column in (
                ['CaseName', 'TotalRunningTime']
                + [f'RunningTime_{s}' for s in step_ids]
                + ['DSC_AUC', 'NSD_AUC', 'DSC_Final', 'NSD_Final']
                + [f'DSC_{s}' for s in step_ids]
                + [f'NSD_{s}' for s in step_ids]
                + ['num_class', 'runtime_upperbound']
            )
        )
        time_warning = False

        # To obtain the running time for each case, testing cases are inferred one-by-one
        for case in tqdm(test_cases):
            metric_temp = {}
            real_running_time = 0
            dscs, nsds, all_segs = [], [], []
            case_input_path = join(input_temp, case)
            case_output_path = join(output_temp, case)

            # copy input image to accumulate clicks in its dict
            shutil.copy(join(test_img_path, case), input_temp)
            case_input = patched_np_load(case_input_path, allow_pickle=True)
            if validation_gts_path is None:  # training images carry their own gts
                gts = case_input['gts']
            else:  # validation / test gts are stored separately to avoid label leakage
                gts = patched_np_load(join(validation_gts_path, case), allow_pickle=True)['gts']
            # Image, spacing and boxes are read once from the pristine copy.
            spacing = case_input['spacing']
            boxes = case_input.get('boxes')  # the bounding box prompt is optional
            no_bbox = boxes is None

            unique_gts = np.sort(pd.unique(gts.ravel()))
            foreground_labels = unique_gts[unique_gts != 0]  # skip background label 0
            num_classes = len(foreground_labels)
            metric_temp['num_class'] = num_classes
            metric_temp['runtime_upperbound'] = num_classes * 90

            # foreground and background clicks for each class
            clicks_cls = [{'fg': [], 'bg': []} for _ in foreground_labels]
            clicks_order = [[] for _ in foreground_labels]

            for it in range(n_steps):  # iteration 0 is the bbox prediction
                step = it + 1
                if it == 0:
                    if no_bbox:
                        if verbose:
                            print(f'This sample does not use a Bounding Box for the initial iteration {it}')
                        metric_temp["RunningTime_1"] = 0
                        metric_temp["DSC_1"] = 0
                        metric_temp["NSD_1"] = 0
                        dscs.append(0)
                        nsds.append(0)
                        continue
                    if verbose:
                        print(f'Using Bounding Box for iteration {it}')
                else:
                    if verbose:
                        print(f'Using Clicks for iteration {it}')
                    if os.path.isfile(case_output_path):
                        segs = patched_np_load(case_output_path, allow_pickle=True)['segs'].astype(np.uint8)  # previous prediction
                    else:
                        segs = np.zeros_like(gts).astype(np.uint8)  # previous step produced no result

                    # Refinement clicks: at most one new click per class
                    for ind, cls in enumerate(foreground_labels):
                        click = _simulate_click(
                            (segs == cls).astype(np.uint8),
                            (gts == cls).astype(np.uint8),
                            verbose,
                        )
                        if click is None:
                            clicks_order[ind].append(None)
                            if verbose:
                                print(f"Class {cls}: No error connected components found. Prediction is perfect! No clicks were added.")
                            continue
                        kind, center = click
                        clicks_cls[ind][kind].append(list(center))
                        clicks_order[ind].append(kind)
                        if verbose:
                            print(f"Class {cls}: Largest error component center is at {center}")

                    # update model input with the new clicks
                    model_input = {'imgs': case_input['imgs']}
                    if validation_gts_path is None:
                        model_input['gts'] = case_input['gts']  # only for training images
                    model_input.update(
                        spacing=spacing,
                        clicks=clicks_cls,
                        clicks_order=clicks_order,
                        prev_pred=segs,
                    )
                    if not no_bbox:
                        model_input['boxes'] = boxes
                    np.savez_compressed(case_input_path, **model_input)

                # Remove the previous step's prediction so that a failed or skipped
                # run is detected below instead of being re-scored as this step.
                if os.path.isfile(case_output_path):
                    os.remove(case_output_path)

                # Model inference on the current input
                if verbose:
                    print(teamname, ' docker command:', cmd, '\n', 'testing image name:', case)
                start_time = time.time()
                os.system(cmd)
                infer_time = time.time() - start_time
                real_running_time += infer_time  # only inference time, not click generation
                print(f"{case} finished! Inference time: {infer_time}")
                metric_temp[f"RunningTime_{step}"] = infer_time

                if os.path.isfile(case_output_path):
                    segs = patched_np_load(case_output_path, allow_pickle=True)['segs']
                else:
                    print(f"[WARNING] Failed / Skipped prediction for iteration {it}! Setting prediction to zeros...")
                    segs = np.zeros_like(gts).astype(np.uint8)
                all_segs.append(segs.astype(np.uint8))

                dsc = compute_multi_class_dsc(gts, segs)
                # only compute NSD when Dice > 0.2, because NSD is also low when Dice is that low
                nsd = compute_multi_class_nsd(gts, segs, spacing) if dsc > 0.2 else 0.0
                dscs.append(dsc)
                nsds.append(nsd)
                metric_temp[f'DSC_{step}'] = dsc
                metric_temp[f'NSD_{step}'] = nsd
                print('Dice', dsc, 'NSD', nsd)

            # Copy the final prediction (plus all intermediate ones) to the team folder
            seg_name = case
            final_pred_path = join(team_outpath, seg_name)
            try:
                shutil.copy(join(output_temp, seg_name), final_pred_path)
                segs = patched_np_load(final_pred_path, allow_pickle=True)['segs']
                np.savez_compressed(
                    final_pred_path,
                    segs=segs,
                    all_segs=all_segs,  # store all intermediate predictions
                )
            except Exception:
                print(f"{join(output_temp, seg_name)}, {final_pred_path}")
                if os.path.exists(final_pred_path):
                    os.remove(final_pred_path)  # clean up cached files if model has failed
                print("Final prediction could not be copied!")

            if real_running_time > 90 * num_classes:
                print("[WARNING] Your model seems to take more than 90 seconds per class during inference! The final test set will have a time constraint of 90s per class --> Make sure to optimize your approach!")
                time_warning = True

            # Compute interactive metrics. The AUC covers only the click steps,
            # since the bbox prompt is optional.
            dsc_auc = integrate.cumulative_trapezoid(np.array(dscs[-n_clicks:]), np.arange(n_clicks))[-1]
            nsd_auc = integrate.cumulative_trapezoid(np.array(nsds[-n_clicks:]), np.arange(n_clicks))[-1]
            if os.path.exists(final_pred_path):  # add to csv only if final prediction is successful
                for k, v in metric_temp.items():
                    metric[k].append(v)
                metric['CaseName'].append(case)
                metric['TotalRunningTime'].append(real_running_time)
                metric['DSC_AUC'].append(dsc_auc)
                metric['NSD_AUC'].append(nsd_auc)
                metric['DSC_Final'].append(dscs[-1])
                metric['NSD_Final'].append(nsds[-1])
            os.remove(case_input_path)

        pd.DataFrame(metric).to_csv(join(team_outpath, teamname + '_metrics.csv'), index=False)

        # Clean up for next docker
        torch.cuda.empty_cache()
        os.system("docker rmi {}:latest".format(teamname.split('_')[0]))
        shutil.rmtree(input_temp)
        shutil.rmtree(output_temp)
        if time_warning:  # repeat warning at the end as well
            print("[WARNING] Your model seems to take more than 90 seconds per class during inference for some images! The final test set will have a time constraint of 90s per class --> Make sure to optimize your approach!")
    except Exception as e:
        print(e)
        traceback.print_exc()
        print(f"Error processing {case} with docker {docker}. Skipping this docker.")
559 |
--------------------------------------------------------------------------------