├── voxtell
│   ├── model
│   │   ├── __init__.py
│   │   ├── transformer.py
│   │   └── voxtell_model.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── text_embedding.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── predict_from_raw_data.py
│   │   └── predictor.py
│   └── __init__.py
├── documentation
│   └── assets
│       ├── VoxTellLogo.png
│       ├── VoxTellConcepts.png
│       └── VoxTellArchitecture.png
├── pyproject.toml
├── .gitignore
├── readme.md
└── LICENSE
/voxtell/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/voxtell/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/voxtell/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/voxtell/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version as _version
2 |
3 | __version__ = _version("voxtell")
--------------------------------------------------------------------------------
/documentation/assets/VoxTellLogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/VoxTell/HEAD/documentation/assets/VoxTellLogo.png
--------------------------------------------------------------------------------
/documentation/assets/VoxTellConcepts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/VoxTell/HEAD/documentation/assets/VoxTellConcepts.png
--------------------------------------------------------------------------------
/documentation/assets/VoxTellArchitecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/VoxTell/HEAD/documentation/assets/VoxTellArchitecture.png
--------------------------------------------------------------------------------
/voxtell/utils/text_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def last_token_pool(last_hidden_states: torch.Tensor,
4 | attention_mask: torch.Tensor) -> torch.Tensor:
5 | left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
6 | if left_padding:
7 | return last_hidden_states[:, -1]
8 | else:
9 | sequence_lengths = attention_mask.sum(dim=1) - 1
10 | batch_size = last_hidden_states.shape[0]
11 | return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
12 |
13 |
14 | def wrap_with_instruction(text_prompts):
15 | instruct = 'Given an anatomical term query, retrieve the precise anatomical entity and location it represents'
16 |
17 | instruct_text_prompts = []
18 | for text in text_prompts:
19 | instruct_text_prompts.append(f'Instruct: {instruct}\nQuery: {text}')
20 | return instruct_text_prompts
--------------------------------------------------------------------------------
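Usage sketch for the helpers above, following the way `voxtell/inference/predictor.py` loads the Qwen3-Embedding-4B tokenizer and backbone (left padding, 8192-token limit). The exact tokenizer arguments here are assumptions for illustration, not the verbatim inference code:

```python
import torch
from transformers import AutoModel, AutoTokenizer
from voxtell.utils.text_embedding import last_token_pool, wrap_with_instruction

# Load the frozen text backbone as done in VoxTellPredictor
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-4B", padding_side="left")
model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-4B").eval()

# Wrap raw prompts with the retrieval instruction, tokenize, and pool the last token
prompts = wrap_with_instruction(["liver", "right kidney"])
batch = tokenizer(prompts, padding=True, truncation=True, max_length=8192, return_tensors="pt")
with torch.inference_mode():
    last_hidden = model(**batch).last_hidden_state
    embeddings = last_token_pool(last_hidden, batch["attention_mask"])  # (num_prompts, 2560)
```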
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "voxtell"
3 | version = "0.1.0"
4 | requires-python = ">=3.10"
5 | description = "Free-Text Promptable Universal 3D Medical Image Segmentation"
6 | readme = "readme.md"
7 | license = { file = "LICENSE" }
8 | authors = [
9 | { name = "Maximilian Rokuss", email = "maximilian.rokuss@dkfz-heidelberg.de"},
10 | { name = "Moritz Langenberg", email = "moritz.langenberg@dkfz-heidelberg.de"},
11 | { name = "Yannick Kirchhoff" },
12 | { name = "Fabian Isensee" },
13 | { name = "Benjamin Hamm" },
14 | { name = "Constantin Ulrich" },
15 | { name = "Sebastian Regnery" },
16 | { name = "Lukas Bauer" },
17 | { name = "Efthimios Katsigiannopulos" },
18 | { name = "Tobias Norajitra" },
19 | { name = "Klaus Maier-Hein" }
20 | ]
21 | classifiers = [
22 | "Development Status :: 3 - Alpha",
23 | "Intended Audience :: Developers",
24 | "Intended Audience :: Science/Research",
25 | "Intended Audience :: Healthcare Industry",
26 | "Programming Language :: Python :: 3",
27 | "License :: OSI Approved :: Apache Software License",
28 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
29 | "Topic :: Scientific/Engineering :: Image Recognition",
30 | "Topic :: Scientific/Engineering :: Medical Science Apps.",
31 | ]
32 | keywords = [
33 | 'deep learning',
34 | 'image segmentation',
35 | 'semantic segmentation',
36 | 'medical image analysis',
37 | 'medical image segmentation',
38 | 'vision-language model',
39 | 'text-prompted segmentation',
40 | '3D segmentation',
41 | 'volumetric segmentation'
42 | ]
43 | dependencies = [
44 | "torch>=2.0,<2.9",
45 | "nnunetv2>=2.6",
46 | "acvl-utils>=0.2.3,<0.3",
47 | "dynamic-network-architectures>=0.4.1,<0.5",
48 | "numpy>=1.24",
49 | "SimpleITK",
50 | "nibabel",
51 | "positional-encodings",
52 | "einops",
53 | "tqdm",
54 | "transformers>=4.54.0,<5",
55 | ]
56 |
57 | [project.urls]
58 | repository = "https://github.com/MIC-DKFZ/VoxTell"
59 | paper = "https://arxiv.org/abs/2511.11450"
60 |
61 | [project.scripts]
62 | voxtell-predict = "voxtell.inference.predict_from_raw_data:main"
63 |
64 | [project.optional-dependencies]
65 | dev = [
66 | "black",
67 | "ruff",
68 | "pre-commit"
69 | ]
70 |
71 | [build-system]
72 | requires = ["setuptools>=67.8.0", "wheel"]
73 | build-backend = "setuptools.build_meta"
74 |
75 | [tool.codespell]
76 | skip = '.git,*.pdf,*.svg,*.png'
77 | #
78 | # ignore-words-list = ''
79 |
80 | [tool.setuptools.packages.find]
81 | where = ["."]
82 | include = ["voxtell*"]
83 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Python
3 | # ==============================================================================
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Logs
57 | *.log
58 |
59 | # ==============================================================================
60 | # Virtual Environments
61 | # ==============================================================================
62 |
63 | venv/
64 | env/
65 | ENV/
66 | .venv/
67 | .python-version
68 |
69 | # ==============================================================================
70 | # IDEs and Editors
71 | # ==============================================================================
72 |
73 | # PyCharm / IntelliJ
74 | .idea/
75 |
76 | # Spyder
77 | .spyderproject
78 |
79 | # Rope
80 | .ropeproject
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # VS Code
86 | .vscode/
87 |
88 | # ==============================================================================
89 | # Environment Variables
90 | # ==============================================================================
91 |
92 | .env
93 | .env.local
94 |
95 | # ==============================================================================
96 | # Medical Imaging & Data Files
97 | # ==============================================================================
98 |
99 | # NIfTI files
100 | *.nii
101 | *.nii.gz
102 |
103 | # Image formats
104 | *.tif
105 | *.tiff
106 | *.bmp
107 | *.dcm
108 |
109 | # Numpy arrays
110 | *.npy
111 | *.npz
112 | *.memmap
113 |
114 | # Model files
115 | *.model
116 | *.pth
117 | *.pt
118 | *.ckpt
119 | *.h5
120 | *.weights
121 |
122 | # Pickle files
123 | *.pkl
124 | *.pickle
125 |
126 | # ==============================================================================
127 | # Documentation
128 | # ==============================================================================
129 |
130 | # Sphinx
131 | docs/_build/
132 | docs/_static/
133 | docs/_templates/
134 |
135 | # PDF documentation
136 | *.pdf
137 |
138 | # ==============================================================================
139 | # Miscellaneous
140 | # ==============================================================================
141 |
142 | # Compressed files
143 | *.zip
144 | *.tar
145 | *.tar.gz
146 | *.rar
147 |
148 | # XML files
149 | *.xml
150 |
151 | # Temporary files
152 | *.tmp
153 | *.temp
154 | *~
155 |
156 | # OS generated files
157 | .DS_Store
158 | .DS_Store?
159 | ._*
160 | .Spotlight-V100
161 | .Trashes
162 | ehthumbs.db
163 | Thumbs.db
164 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation
2 |
3 |
4 |
5 | This repository contains the official implementation of our paper:
6 |
7 | ### **VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation**
8 |
9 | VoxTell is a **3D vision–language segmentation model** that directly maps free-form text prompts, from single words to full clinical sentences, to volumetric masks. By leveraging **multi-stage vision–language fusion**, VoxTell achieves state-of-the-art performance on anatomical and pathological structures across CT, PET, and MRI modalities, excelling on familiar concepts while generalizing to related unseen classes.
10 |
11 | > **Authors**: Maximilian Rokuss*, Moritz Langenberg*, Yannick Kirchhoff, Fabian Isensee, Benjamin Hamm, Constantin Ulrich, Sebastian Regnery, Lukas Bauer, Efthimios Katsigiannopulos, Tobias Norajitra, Klaus Maier-Hein
12 | > **Paper**: [arXiv:2511.11450](https://arxiv.org/abs/2511.11450)
13 |
14 | ---
15 |
16 | ## Overview
17 |
18 | VoxTell is trained on a **large-scale, multi-modality 3D medical imaging dataset**, aggregating **158 public sources** with over **62,000 volumetric images**. The data covers:
19 |
20 | - Brain, head & neck, thorax, abdomen, pelvis
21 | - Musculoskeletal system and extremities
22 | - Vascular structures, major organs, substructures, and lesions
23 |
24 |
25 |
26 | This rich semantic diversity enables **language-conditioned 3D reasoning**, allowing VoxTell to generate volumetric masks from flexible textual descriptions, from coarse anatomical labels to fine-grained pathological findings.
27 |
28 | ---
29 |
30 | ## Architecture
31 |
32 | VoxTell combines **3D image encoding** with **text-prompt embeddings** and **multi-stage vision–language fusion** (see the sketch after this list):
33 |
34 | - **Image Encoder**: Processes 3D volumetric input into latent feature representations
35 | - **Prompt Encoder**: Embeds text prompts using the frozen [Qwen3-Embedding-4B](https://huggingface.co/Qwen/Qwen3-Embedding-4B) model
36 | - **Prompt Decoder**: Transforms text queries and image latents into multi-scale text features
37 | - **Image Decoder**: Fuses visual and textual information at multiple resolutions using MaskFormer-style query-image fusion with deep supervision
38 |
39 |
40 |
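The data flow through these components can be summarized in a short sketch. This is purely illustrative: the module names (`image_encoder`, `prompt_decoder`, `image_decoder`) are hypothetical placeholders rather than the actual classes in `voxtell/model/voxtell_model.py`; only the frozen Qwen3-Embedding-4B prompt encoder and the MaskFormer-style multi-scale fusion are taken from the description above.

```python
# Illustrative sketch of the VoxTell forward pass (hypothetical module names, not the real API).
def voxtell_forward(volume, text_prompts, image_encoder, prompt_encoder, prompt_decoder, image_decoder):
    # 3D image encoder: (B, 1, X, Y, Z) volume -> multi-scale latent feature maps
    image_feats = image_encoder(volume)                   # list of tensors, coarse to fine
    # Frozen Qwen3-Embedding-4B: free-text prompts -> one embedding per prompt
    text_emb = prompt_encoder(text_prompts)               # (num_prompts, 2560)
    # Prompt decoder: text queries attend to image latents -> multi-scale text features
    text_queries = prompt_decoder(text_emb, image_feats)
    # Image decoder: MaskFormer-style query-image fusion at multiple resolutions
    masks = image_decoder(image_feats, text_queries)      # (B, num_prompts, X, Y, Z)
    return masks
```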
41 | ---
42 |
43 | ## 🛠 Installation
44 |
45 | ### 1. Create a Virtual Environment
46 |
47 | VoxTell supports Python 3.10+ and works with Conda, pip, or any other virtual environment manager. Here's an example using Conda:
48 |
49 | ```bash
50 | conda create -n voxtell python=3.12
51 | conda activate voxtell
52 | ```
53 |
54 | ### 2. Install PyTorch
55 |
56 | > [!WARNING]
57 | > **Temporary Compatibility Warning**
58 | > There is a known issue with **PyTorch 2.9.0** that causes **OOM errors during inference** with 3D convolutions (see the PyTorch issue [here](https://github.com/pytorch/pytorch/issues/166122)).
59 | > **Until this is resolved, please use PyTorch 2.8.0 or earlier.**
60 |
61 | Install PyTorch compatible with your CUDA version. For example, for Ubuntu with a modern NVIDIA GPU:
62 |
63 | ```bash
64 | pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu126
65 | ```
66 |
67 | *For other configurations (macOS, CPU, different CUDA versions), please refer to the [PyTorch Get Started](https://pytorch.org/get-started/previous-versions/) page.*
68 |
69 | Then install VoxTell via pip (you can also use [uv](https://docs.astral.sh/uv/)):
70 |
71 | ```bash
72 | pip install voxtell
73 | ```
74 |
75 | or install directly from the repository:
76 |
77 | ```bash
78 | git clone https://github.com/MIC-DKFZ/VoxTell
79 | cd VoxTell
80 | pip install -e .
81 | ```
82 |
83 | ---
84 |
85 | ## 🚀 Getting Started
86 |
87 | > [!NOTE]
88 | > The model will soon be available on Hugging Face! 🤗 Stay tuned for the official release.
89 |
90 | ### Command-Line Interface (CLI)
91 |
92 | VoxTell provides a convenient command-line interface for running predictions:
93 |
94 | ```bash
95 | voxtell-predict -i input.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" "kidney"
96 | ```
97 |
98 | **Single prompt:**
99 | ```bash
100 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver"
101 | # Output: output_folder/case001_liver.nii.gz
102 | ```
103 |
104 | **Multiple prompts (saves individual files by default):**
105 | ```bash
106 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" "right kidney"
107 | # Outputs:
108 | # output_folder/case001_liver.nii.gz
109 | # output_folder/case001_spleen.nii.gz
110 | # output_folder/case001_right_kidney.nii.gz
111 | ```
112 |
113 | **Save combined multi-label file:**
114 | ```bash
115 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" --save-combined
116 | # Output: output_folder/case001.nii.gz (multi-label: 1=liver, 2=spleen)
117 | # ⚠️ WARNING: Overlapping structures will be overwritten by later prompts
118 | ```
119 |
120 | #### CLI Options
121 |
122 | | Argument | Short | Required | Description |
123 | |----------|-------|----------|-------------|
124 | | `--input` | `-i` | Yes | Path to input NIfTI file |
125 | | `--output` | `-o` | Yes | Path to output folder |
126 | | `--model` | `-m` | Yes | Path to VoxTell model directory |
127 | | `--prompts` | `-p` | Yes | Text prompt(s) for segmentation |
128 | | `--device` | | No | Device to use: `cuda` (default) or `cpu` |
129 | | `--gpu` | | No | GPU device ID (default: 0) |
130 | | `--save-combined` | | No | Save multi-label file instead of individual files |
131 | | `--verbose` | | No | Enable verbose output |
132 |
133 | ---
134 |
135 | ### Python API
136 |
137 | For more control or integration into Python workflows, use the Python API:
138 |
139 | ```python
140 | import torch
141 | from voxtell.inference.predictor import VoxTellPredictor
142 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient
143 |
144 | # Select device
145 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
146 |
147 | # Load image
148 | image_path = "/path/to/your/image.nii.gz"
149 | img, _ = NibabelIOWithReorient().read_images([image_path])
150 |
151 | # Define text prompts
152 | text_prompts = ["liver", "right kidney", "left kidney", "spleen"]
153 |
154 | # Initialize predictor
155 | predictor = VoxTellPredictor(
156 | model_dir="/path/to/voxtell_model_directory",
157 | device=device,
158 | )
159 |
160 | # Run prediction
161 | # Output shape: (num_prompts, x, y, z)
162 | voxtell_seg = predictor.predict_single_image(img, text_prompts)
163 | ```
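
To write the predicted masks back to NIfTI, you can reuse the same reader/writer that the CLI entrypoint uses. A minimal sketch, assuming you keep the image `properties` returned by `read_images` (discarded as `_` above) and that the output paths are placeholders:

```python
import numpy as np
from pathlib import Path
from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient

io = NibabelIOWithReorient()
img, props = io.read_images([image_path])  # keep the spatial properties this time

output_dir = Path("/path/to/output_folder")
output_dir.mkdir(parents=True, exist_ok=True)

# Write one binary mask per prompt, mirroring the CLI's per-prompt output
for prompt, mask in zip(text_prompts, voxtell_seg):
    safe_name = prompt.replace(" ", "_")
    io.write_seg(mask.astype(np.uint8), str(output_dir / f"image_{safe_name}.nii.gz"), props)
```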
164 |
165 | #### Optional: Visualize Results
166 |
167 | You can visualize the segmentation results using [napari](https://napari.org/):
168 |
169 | ```bash
170 | pip install napari[all]
171 | ```
172 |
173 | ```python
174 | import napari
175 | import numpy as np
176 |
177 | # Create a napari viewer and add the original image
178 | viewer = napari.Viewer()
179 | viewer.add_image(img, name='Image')
180 |
181 | # Add segmentation results as label layers for each prompt
182 | for i, prompt in enumerate(text_prompts):
183 | viewer.add_labels(voxtell_seg[i].astype(np.uint8), name=prompt)
184 |
185 | # Run napari
186 | napari.run()
187 | ```
188 |
189 | ## Important: Image Orientation and Spacing
190 |
191 | - ⚠️ **Image Orientation (Critical)**: For correct anatomical localization (e.g., distinguishing left from right), images **must be in RAS orientation**. VoxTell was trained on data reoriented using [this specific reader](https://github.com/MIC-DKFZ/nnUNet/blob/86606c53ef9f556d6f024a304b52a48378453641/nnunetv2/imageio/nibabel_reader_writer.py#L101), so orientation mismatches are a common source of errors. A telltale sign is a simple prompt like "liver" failing and segmenting parts of the spleen instead. Make sure your image metadata is correct (see the sketch below for a quick check).
192 |
193 | - **Image Spacing**: To keep inference fast, the model does not resample images to a standardized spacing. Performance may therefore degrade on images with very uncommon voxel spacings (e.g., very high-resolution brain MRI). In such cases, consider resampling the image to a more typical clinical spacing (e.g., 1.5×1.5×1.5 mm³) before segmentation.
194 |
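A quick way to check and, if necessary, reorient an image to RAS is with nibabel. This is a minimal sketch: `as_closest_canonical` reorients based on the stored affine, so it only helps if the image metadata is correct in the first place.

```python
import nibabel as nib

img = nib.load("/path/to/your/image.nii.gz")
print(nib.aff2axcodes(img.affine))        # e.g. ('L', 'P', 'S') if the image is not in RAS

# Reorient to the closest canonical (RAS) orientation and save a corrected copy
ras_img = nib.as_closest_canonical(img)
print(nib.aff2axcodes(ras_img.affine))    # ('R', 'A', 'S')
nib.save(ras_img, "/path/to/your/image_ras.nii.gz")
```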
195 | ---
196 |
197 | ## 🗺️ Roadmap
198 |
199 | - [x] **Paper Published**: [arXiv:2511.11450](https://arxiv.org/abs/2511.11450)
200 | - [x] **Code Release**: Official implementation published
201 | - [x] **PyPI Package**: Package downloadable via pip
202 | - [ ] **Napari Plugin**: Integration into the napari viewer
203 | - [ ] **Model Release**: Public availability of pretrained weights
204 | - [ ] **Fine-Tuning**: Support and scripts for custom fine-tuning
205 |
206 | ---
207 |
208 | ## Citation
209 |
210 | ```bibtex
211 | @misc{rokuss2025voxtell,
212 | title={VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation},
213 | author={Maximilian Rokuss and Moritz Langenberg and Yannick Kirchhoff and Fabian Isensee and Benjamin Hamm and Constantin Ulrich and Sebastian Regnery and Lukas Bauer and Efthimios Katsigiannopulos and Tobias Norajitra and Klaus Maier-Hein},
214 | year={2025},
215 | eprint={2511.11450},
216 | archivePrefix={arXiv},
217 | primaryClass={cs.CV},
218 | url={https://arxiv.org/abs/2511.11450},
219 | }
220 | ```
221 |
222 | ---
223 |
224 | ## 📬 Contact
225 |
226 | For questions, issues, or collaborations, please contact:
227 |
228 | 📧 maximilian.rokuss@dkfz-heidelberg.de / moritz.langenberg@dkfz-heidelberg.de
--------------------------------------------------------------------------------
/voxtell/inference/predict_from_raw_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Command-line entrypoint for VoxTell segmentation prediction.
4 |
5 | This script provides a CLI interface to run VoxTell predictions on medical images
6 | with free-text prompts.
7 | """
8 |
9 | import argparse
10 | import sys
11 | from pathlib import Path
12 | from typing import List, Optional
13 |
14 | import numpy as np
15 | import torch
16 |
17 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient
18 | from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
19 |
20 | from voxtell.inference.predictor import VoxTellPredictor
21 |
22 |
23 | def get_reader_writer(file_path: str):
24 | """
25 | Determine the appropriate reader/writer based on file extension.
26 |
27 | Args:
28 | file_path: Path to the input file.
29 |
30 | Returns:
31 | Appropriate reader/writer instance.
32 | """
33 | suffix = Path(file_path).suffix.lower()
34 | if suffix in ['.nii', '.gz']:
35 | return NibabelIOWithReorient()
36 | else:
37 | raise ValueError(
38 | f"Unsupported file format: {suffix}. "
39 | "Only NIfTI format (.nii, .nii.gz) is currently supported. "
40 | "Images must be reorientable to standard orientation with correct metadata."
41 | )
42 |
43 |
44 | def save_segmentation(
45 | segmentation: np.ndarray,
46 | output_folder: Path,
47 | input_filename: str,
48 | properties: dict,
49 | prompt_name: str = None,
50 | suffix: str = '.nii.gz'
51 | ) -> None:
52 | """
53 | Save segmentation mask to file.
54 |
55 | Args:
56 | segmentation: Segmentation array to save.
57 | output_folder: Output folder path.
58 | input_filename: Original input filename (without extension).
59 | properties: Image properties from the reader.
60 | prompt_name: Optional prompt name to include in filename.
61 | suffix: File extension to use.
62 | """
63 | if prompt_name:
64 | # Clean prompt name for filename
65 | safe_name = "".join(c if c.isalnum() or c in (' ', '_') else '_' for c in prompt_name)
66 | safe_name = safe_name.replace(' ', '_')
67 | output_file = output_folder / f"{input_filename}_{safe_name}{suffix}"
68 | else:
69 | output_file = output_folder / f"{input_filename}{suffix}"
70 |
71 | # Use NIfTI writer
72 | reader_writer = NibabelIOWithReorient()
73 | reader_writer.write_seg(segmentation, str(output_file), properties)
74 | print(f"Saved segmentation to: {output_file}")
75 |
76 |
77 | def parse_args() -> argparse.Namespace:
78 | """Parse command-line arguments."""
79 | parser = argparse.ArgumentParser(
80 | description="VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation",
81 | formatter_class=argparse.RawDescriptionHelpFormatter,
82 | epilog="""
83 | Examples:
84 | # Single prompt (saves to output_folder/case001_liver.nii.gz)
85 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver"
86 |
87 | # Multiple prompts (saves individual files by default)
88 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" "kidney"
89 |
90 | # Save combined multi-label file (with overlap warning)
91 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" --save-combined
92 |
93 | # Use CPU
94 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" --device cpu
95 | """
96 | )
97 |
98 | parser.add_argument(
99 | '-i', '--input',
100 | type=str,
101 | required=True,
102 | help='Path to input image file (NIfTI format recommended)'
103 | )
104 |
105 | parser.add_argument(
106 | '-o', '--output',
107 | type=str,
108 | required=True,
109 | help='Path to output folder where segmentation files will be saved'
110 | )
111 |
112 | parser.add_argument(
113 | '-m', '--model',
114 | type=str,
115 | required=True,
116 | help='Path to VoxTell model directory containing plans.json and fold_0/'
117 | )
118 |
119 | parser.add_argument(
120 | '-p', '--prompts',
121 | type=str,
122 | nargs='+',
123 | required=True,
124 | help='Text prompt(s) for segmentation (e.g., "liver" "spleen" "tumor")'
125 | )
126 |
127 | parser.add_argument(
128 | '--device',
129 | type=str,
130 | default='cuda',
131 | choices=['cuda', 'cpu'],
132 | help='Device to use for inference (default: cuda)'
133 | )
134 |
135 | parser.add_argument(
136 | '--gpu',
137 | type=int,
138 | default=0,
139 | help='GPU device ID to use (default: 0)'
140 | )
141 |
142 | parser.add_argument(
143 | '--save-combined',
144 | action='store_true',
145 | help='Save all prompts in a single multi-label file (WARNING: overlapping structures will be overwritten by later prompts)'
146 | )
147 |
148 | parser.add_argument(
149 | '--verbose',
150 | action='store_true',
151 | help='Enable verbose output'
152 | )
153 |
154 | return parser.parse_args()
155 |
156 |
157 | def main() -> int:
158 | """Main entrypoint function."""
159 | args = parse_args()
160 |
161 | # Validate inputs
162 | input_path = Path(args.input)
163 | if not input_path.exists():
164 | raise FileNotFoundError(f"Input file does not exist: {input_path}")
165 |
166 | model_path = Path(args.model)
167 | if not model_path.exists():
168 | raise FileNotFoundError(f"Model directory does not exist: {model_path}")
169 |
170 | if not (model_path / 'plans.json').exists():
171 | raise FileNotFoundError(f"plans.json not found in model directory: {model_path}")
172 |
173 | if not (model_path / 'fold_0' / 'checkpoint_final.pth').exists():
174 | raise FileNotFoundError(f"checkpoint_final.pth not found in {model_path / 'fold_0'}")
175 |
176 | # Setup device
177 | if args.device == 'cuda':
178 | if not torch.cuda.is_available():
179 | print("Warning: CUDA not available, falling back to CPU", file=sys.stderr)
180 | device = torch.device('cpu')
181 | else:
182 | device = torch.device(f'cuda:{args.gpu}')
183 | if args.verbose:
184 | print(f"Using GPU: {args.gpu} ({torch.cuda.get_device_name(args.gpu)})")
185 | else:
186 | device = torch.device('cpu')
187 | if args.verbose:
188 | print("Using CPU")
189 |
190 | # Load image
191 | if args.verbose:
192 | print(f"Loading image: {input_path}")
193 |
194 | try:
195 | reader_writer = get_reader_writer(str(input_path))
196 | img, props = reader_writer.read_images([str(input_path)])
197 | except Exception as e:
198 | print(f"Error loading image: {e}", file=sys.stderr)
199 | return 1
200 |
201 | if args.verbose:
202 | print(f"Image shape: {img.shape}")
203 | print(f"Text prompts: {args.prompts}")
204 | print(f"Loading VoxTell model from: {model_path}")
205 |
206 | predictor = VoxTellPredictor(
207 | model_dir=str(model_path),
208 | device=device
209 | )
210 |
211 | # Run prediction
212 | if args.verbose:
213 | print("Running prediction...")
214 |
215 | segmentations = predictor.predict_single_image(img, args.prompts)
216 |
217 | # Save results
218 | output_folder = Path(args.output)
219 | output_folder.mkdir(parents=True, exist_ok=True)
220 |
221 | # Get input filename without extension
222 | input_filename = input_path.stem
223 | if input_filename.endswith('.nii'):
224 | input_filename = input_filename[:-4]
225 |
226 | # Determine file suffix from input
227 | if input_path.suffix == '.gz' and input_path.stem.endswith('.nii'):
228 | suffix = '.nii.gz'
229 | else:
230 | suffix = input_path.suffix
231 |
232 | if args.save_combined:
233 | # Show warning about overlapping structures
234 | if len(args.prompts) > 1:
235 | print("\n" + "="*80)
236 | print("WARNING: Saving combined multi-label segmentation.")
237 | print("If prompts generate overlapping structures, later prompts will overwrite")
238 | print("earlier ones. This may result in loss of segmentation information.")
239 | print("Consider using individual file output (default) for overlapping structures.")
240 | print("="*80 + "\n")
241 |
242 | # Save all prompts in a single multi-label file
243 | if len(args.prompts) == 1:
244 | # Single prompt - save as-is
245 | save_segmentation(segmentations[0], output_folder, input_filename, props, suffix=suffix)
246 | else:
247 | # Multiple prompts - create multi-label segmentation
248 | # Each prompt gets a different label value (1, 2, 3, ...)
249 | # Later prompts overwrite earlier ones in case of overlap
250 | combined_seg = np.zeros_like(segmentations[0], dtype=np.uint8)
251 | for i, seg in enumerate(segmentations):
252 | combined_seg[seg > 0] = i + 1
253 | save_segmentation(combined_seg, output_folder, input_filename, props, suffix=suffix)
254 |
255 | print("\nLabel mapping:")
256 | for i, prompt in enumerate(args.prompts):
257 | print(f" {i + 1}: {prompt}")
258 | else:
259 | # Default: Save each prompt as a separate file
260 | for i, prompt in enumerate(args.prompts):
261 | save_segmentation(
262 | segmentations[i],
263 | output_folder,
264 | input_filename,
265 | props,
266 | prompt_name=prompt,
267 | suffix=suffix
268 | )
269 |
270 | if args.verbose:
271 | print("\nPrediction completed successfully!")
272 |
273 | return 0
274 |
275 |
276 | if __name__ == '__main__':
277 | sys.exit(main())
278 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2019] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/voxtell/model/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Text prompt decoder implementation for VoxTell.
3 |
4 | Code modified from DETR transformer:
5 | https://github.com/facebookresearch/detr
6 | Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | """
8 |
9 | import copy
10 | from typing import Callable, List, Optional, Tuple, Union
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch import nn, Tensor
15 |
16 |
17 | class TransformerDecoder(nn.Module):
18 | """
19 | Transformer decoder consisting of multiple decoder layers.
20 |
21 | This decoder processes target sequences with attention to memory (encoder output),
22 | optionally returning intermediate layer outputs. It receives the text prompt embeddings
23 | as queries and attends to image features from the encoder. It outputs refined text-image
24 | fused features for segmentation mask prediction.
25 |
26 | Args:
27 | decoder_layer: A single transformer decoder layer to be cloned.
28 | num_layers: Number of decoder layers to stack.
29 | norm: Optional normalization layer applied to the final output.
30 | return_intermediate: If True, returns outputs from all layers stacked together.
31 | """
32 |
33 | def __init__(
34 | self,
35 | decoder_layer: nn.Module,
36 | num_layers: int,
37 | norm: Optional[nn.Module] = None,
38 | return_intermediate: bool = False
39 | ) -> None:
40 | super().__init__()
41 | self.layers = _get_clones(decoder_layer, num_layers)
42 | self.num_layers = num_layers
43 | self.norm = norm
44 | self.return_intermediate = return_intermediate
45 |
46 | def forward(
47 | self,
48 | tgt: Tensor,
49 | memory: Tensor,
50 | tgt_mask: Optional[Tensor] = None,
51 | memory_mask: Optional[Tensor] = None,
52 | tgt_key_padding_mask: Optional[Tensor] = None,
53 | memory_key_padding_mask: Optional[Tensor] = None,
54 | pos: Optional[Tensor] = None,
55 | query_pos: Optional[Tensor] = None
56 | ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
57 | """
58 | Forward pass through all decoder layers.
59 |
60 | Args:
61 | tgt: Target sequence tensor of shape [T, B, C].
62 | memory: Memory (encoder output) tensor of shape [T, B, C].
63 | tgt_mask: Attention mask for target self-attention.
64 | memory_mask: Attention mask for cross-attention to memory.
65 | tgt_key_padding_mask: Padding mask for target keys.
66 | memory_key_padding_mask: Padding mask for memory keys.
67 | pos: Positional embeddings for memory.
68 | query_pos: Positional embeddings for queries.
69 |
70 | Returns:
71 | If return_intermediate is True, returns stacked intermediate outputs.
72 | Otherwise, returns tuple of (final_output, attention_weights_list).
73 | """
74 | output = tgt
75 | T, B, C = memory.shape
76 | intermediate = []
77 | atten_layers = []
78 |
79 | for n, layer in enumerate(self.layers):
80 | residual = True
81 | output, ws = layer(
82 | output, memory,
83 | tgt_mask=tgt_mask,
84 | memory_mask=memory_mask,
85 | tgt_key_padding_mask=tgt_key_padding_mask,
86 | memory_key_padding_mask=memory_key_padding_mask,
87 | pos=pos,
88 | query_pos=query_pos,
89 | residual=residual
90 | )
91 | atten_layers.append(ws)
92 | if self.return_intermediate:
93 | intermediate.append(self.norm(output))
94 |
95 | if self.norm is not None:
96 | output = self.norm(output)
97 | if self.return_intermediate:
98 | intermediate.pop()
99 | intermediate.append(output)
100 |
101 | if self.return_intermediate:
102 | return torch.stack(intermediate)
103 | return output, atten_layers
104 |
105 |
106 |
107 | class TransformerDecoderLayer(nn.Module):
108 | """
109 | Single transformer decoder layer with self-attention, cross-attention, and FFN.
110 |
111 | This layer implements:
112 | 1. Self-attention on the target sequence
113 | 2. Cross-attention between target and memory (image encoder output)
114 | 3. Position-wise feed-forward network
115 |
116 | Args:
117 | d_model: Dimension of the model (embedding dimension).
118 | nhead: Number of attention heads.
119 | dim_feedforward: Dimension of the feedforward network.
120 | dropout: Dropout probability.
121 | activation: Activation function name ('relu', 'gelu', or 'glu').
122 | normalize_before: If True, applies layer norm before attention/FFN (pre-norm).
123 | If False, applies after (post-norm).
124 | """
125 |
126 | def __init__(
127 | self,
128 | d_model: int,
129 | nhead: int,
130 | dim_feedforward: int = 2048,
131 | dropout: float = 0.1,
132 | activation: str = "relu",
133 | normalize_before: bool = False
134 | ) -> None:
135 | super().__init__()
136 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
137 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
138 |
139 | # Feed-forward network
140 | self.linear1 = nn.Linear(d_model, dim_feedforward)
141 | self.dropout = nn.Dropout(dropout)
142 | self.linear2 = nn.Linear(dim_feedforward, d_model)
143 |
144 | # Normalization layers
145 | self.norm1 = nn.LayerNorm(d_model)
146 | self.norm2 = nn.LayerNorm(d_model)
147 | self.norm3 = nn.LayerNorm(d_model)
148 |
149 | # Dropout layers
150 | self.dropout1 = nn.Dropout(dropout)
151 | self.dropout2 = nn.Dropout(dropout)
152 | self.dropout3 = nn.Dropout(dropout)
153 |
154 | self.activation = _get_activation_fn(activation)
155 | self.normalize_before = normalize_before
156 |
157 | def with_pos_embed(self, tensor: Tensor, pos: Optional[Tensor]) -> Tensor:
158 | """
159 | Add positional embeddings to tensor if provided.
160 |
161 | Args:
162 | tensor: Input tensor.
163 | pos: Optional positional embeddings to add.
164 |
165 | Returns:
166 | Tensor with positional embeddings added, or original tensor if pos is None.
167 | """
168 | return tensor if pos is None else tensor + pos
169 |
170 | def forward_post(
171 | self,
172 | tgt: Tensor,
173 | memory: Tensor,
174 | tgt_mask: Optional[Tensor] = None,
175 | memory_mask: Optional[Tensor] = None,
176 | tgt_key_padding_mask: Optional[Tensor] = None,
177 | memory_key_padding_mask: Optional[Tensor] = None,
178 | pos: Optional[Tensor] = None,
179 | query_pos: Optional[Tensor] = None,
180 | residual: bool = True
181 | ) -> Tuple[Tensor, Tensor]:
182 | """
183 | Post-norm forward pass: attention/FFN first, then normalization.
184 |
185 | Args:
186 | tgt: Target sequence.
187 | memory: Memory (encoder output).
188 | tgt_mask: Self-attention mask for target.
189 | memory_mask: Cross-attention mask.
190 | tgt_key_padding_mask: Padding mask for target keys.
191 | memory_key_padding_mask: Padding mask for memory keys.
192 | pos: Positional embeddings for memory.
193 | query_pos: Positional embeddings for queries.
194 | residual: Whether to use residual connections.
195 |
196 | Returns:
197 | Tuple of (output_tensor, attention_weights).
198 | """
199 | q = k = self.with_pos_embed(tgt, query_pos)
200 | tgt2, ws = self.self_attn(
201 | q, k, value=tgt,
202 | attn_mask=tgt_mask,
203 | key_padding_mask=tgt_key_padding_mask
204 | )
205 | tgt = self.norm1(tgt)
206 | tgt2, ws = self.multihead_attn(
207 | query=self.with_pos_embed(tgt, query_pos),
208 | key=self.with_pos_embed(memory, pos),
209 | value=memory,
210 | attn_mask=memory_mask,
211 | key_padding_mask=memory_key_padding_mask
212 | )
213 |
214 | # Cross-attention with residual connection
215 | tgt = tgt + self.dropout2(tgt2)
216 | tgt = self.norm2(tgt)
217 |
218 | # Feed-forward network
219 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
220 | tgt = tgt + self.dropout3(tgt2)
221 | tgt = self.norm3(tgt)
222 | return tgt, ws
223 |
224 | def forward_pre(
225 | self,
226 | tgt: Tensor,
227 | memory: Tensor,
228 | tgt_mask: Optional[Tensor] = None,
229 | memory_mask: Optional[Tensor] = None,
230 | tgt_key_padding_mask: Optional[Tensor] = None,
231 | memory_key_padding_mask: Optional[Tensor] = None,
232 | pos: Optional[Tensor] = None,
233 | query_pos: Optional[Tensor] = None
234 | ) -> Tuple[Tensor, Tensor]:
235 | """
236 | Pre-norm forward pass: normalization first, then attention/FFN.
237 |
238 | Args:
239 | tgt: Target sequence.
240 | memory: Memory (encoder output).
241 | tgt_mask: Self-attention mask for target.
242 | memory_mask: Cross-attention mask.
243 | tgt_key_padding_mask: Padding mask for target keys.
244 | memory_key_padding_mask: Padding mask for memory keys.
245 | pos: Positional embeddings for memory.
246 | query_pos: Positional embeddings for queries.
247 |
248 | Returns:
249 | Tuple of (output_tensor, attention_weights).
250 | """
251 | tgt2 = self.norm2(tgt)
252 | tgt2, attn_weights = self.multihead_attn(
253 | query=self.with_pos_embed(tgt2, query_pos),
254 | key=self.with_pos_embed(memory, pos),
255 | value=memory,
256 | attn_mask=memory_mask,
257 | key_padding_mask=memory_key_padding_mask
258 | )
259 | tgt = tgt + self.dropout2(tgt2)
260 |
261 | # Feed-forward network with pre-norm
262 | tgt2 = self.norm3(tgt)
263 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
264 | tgt = tgt + self.dropout3(tgt2)
265 | return tgt, attn_weights
266 |
267 |
268 | def forward_pre_selfattention(
269 | self,
270 | tgt: Tensor,
271 | memory: Tensor,
272 | tgt_mask: Optional[Tensor] = None,
273 | memory_mask: Optional[Tensor] = None,
274 | tgt_key_padding_mask: Optional[Tensor] = None,
275 | memory_key_padding_mask: Optional[Tensor] = None,
276 | pos: Optional[Tensor] = None,
277 | query_pos: Optional[Tensor] = None
278 | ) -> Tuple[Tensor, Tensor]:
279 | """
280 | Alternative forward pass with cross-attention before self-attention.
281 |
282 | This variant applies operations in the order:
283 | 1. Cross-attention (without normalization)
284 | 2. Self-attention (with pre-norm)
285 | 3. Feed-forward network (with pre-norm)
286 |
287 | Args:
288 | tgt: Target sequence.
289 | memory: Memory (encoder output).
290 | tgt_mask: Self-attention mask for target.
291 | memory_mask: Cross-attention mask.
292 | tgt_key_padding_mask: Padding mask for target keys.
293 | memory_key_padding_mask: Padding mask for memory keys.
294 | pos: Positional embeddings for memory.
295 | query_pos: Positional embeddings for queries.
296 |
297 | Returns:
298 | Tuple of (output_tensor, attention_weights).
299 | """
300 | # Cross-attention without pre-normalization
301 | tgt2, attn_weights = self.multihead_attn(
302 | query=self.with_pos_embed(tgt, query_pos),
303 | key=self.with_pos_embed(memory, pos),
304 | value=memory,
305 | attn_mask=memory_mask,
306 | key_padding_mask=memory_key_padding_mask
307 | )
308 | tgt = tgt + tgt2
309 |
310 | # Self-attention with pre-norm
311 | tgt2 = self.norm1(tgt)
312 | q = k = self.with_pos_embed(tgt2, query_pos)
313 | tgt2, ws = self.self_attn(q, k, value=tgt2)
314 | tgt = tgt + tgt2
315 |
316 | # Feed-forward network with pre-norm
317 | tgt2 = self.norm2(tgt)
318 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
319 | tgt = tgt + tgt2
320 | tgt = self.norm3(tgt)
321 | return tgt, attn_weights
322 |
323 |
324 | def forward(
325 | self,
326 | tgt: Tensor,
327 | memory: Tensor,
328 | tgt_mask: Optional[Tensor] = None,
329 | memory_mask: Optional[Tensor] = None,
330 | tgt_key_padding_mask: Optional[Tensor] = None,
331 | memory_key_padding_mask: Optional[Tensor] = None,
332 | pos: Optional[Tensor] = None,
333 | query_pos: Optional[Tensor] = None,
334 | residual: bool = True
335 | ) -> Tuple[Tensor, Tensor]:
336 | """
337 | Forward pass through the decoder layer.
338 |
339 | Dispatches to either pre-norm or post-norm variant based on configuration.
340 |
341 | Args:
342 | tgt: Target sequence.
343 | memory: Memory (encoder output).
344 | tgt_mask: Self-attention mask for target.
345 | memory_mask: Cross-attention mask.
346 | tgt_key_padding_mask: Padding mask for target keys.
347 | memory_key_padding_mask: Padding mask for memory keys.
348 | pos: Positional embeddings for memory.
349 | query_pos: Positional embeddings for queries.
350 | residual: Whether to use residual connections (used in post-norm variant).
351 |
352 | Returns:
353 | Tuple of (output_tensor, attention_weights).
354 | """
355 | if self.normalize_before:
356 | return self.forward_pre(
357 | tgt, memory, tgt_mask, memory_mask,
358 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
359 | )
360 | return self.forward_post(
361 | tgt, memory, tgt_mask, memory_mask,
362 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, residual
363 | )
364 |
365 |
366 | def _get_clones(module: nn.Module, N: int) -> nn.ModuleList:
367 | """
368 | Create N identical copies of a module.
369 |
370 | Args:
371 | module: The module to clone.
372 | N: Number of clones to create.
373 |
374 | Returns:
375 | ModuleList containing N deep copies of the input module.
376 | """
377 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
378 |
379 |
380 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
381 | """
382 | Return an activation function given a string identifier.
383 |
384 | Args:
385 | activation: Name of the activation function ('relu', 'gelu', or 'glu').
386 |
387 | Returns:
388 | The corresponding activation function.
389 |
390 | Raises:
391 | RuntimeError: If the activation name is not recognized.
392 | """
393 | if activation == "relu":
394 | return F.relu
395 | if activation == "gelu":
396 | return F.gelu
397 | if activation == "glu":
398 | return F.glu
399 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
--------------------------------------------------------------------------------
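A small standalone sketch of how the decoder above can be exercised. The sizes are made up for illustration and do not reflect VoxTell's actual configuration (see `voxtell/inference/predictor.py` for the real hyperparameters):

```python
import torch
from torch import nn
from voxtell.model.transformer import TransformerDecoder, TransformerDecoderLayer

d_model, nhead = 256, 8  # illustrative sizes only
layer = TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=1024)
decoder = TransformerDecoder(layer, num_layers=4, norm=nn.LayerNorm(d_model))

tgt = torch.randn(3, 2, d_model)       # 3 text queries, batch size 2, [T, B, C]
memory = torch.randn(100, 2, d_model)  # 100 flattened image tokens, [S, B, C]
out, attn_per_layer = decoder(tgt, memory)  # final output plus per-layer attention weights
print(out.shape)                       # torch.Size([3, 2, 256])
```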
/voxtell/inference/predictor.py:
--------------------------------------------------------------------------------
1 | import pydoc
2 | from queue import Queue
3 | from threading import Thread
4 | from typing import List, Tuple, Union
5 |
6 | import numpy as np
7 | import torch
8 | from torch._dynamo import OptimizedModule
9 | from tqdm import tqdm
10 | from transformers import AutoModel, AutoTokenizer
11 |
12 | from acvl_utils.cropping_and_padding.bounding_boxes import insert_crop_into_image
13 | from acvl_utils.cropping_and_padding.padding import pad_nd_image
14 | from batchgenerators.utilities.file_and_folder_operations import join, load_json
15 |
16 | from nnunetv2.inference.sliding_window_prediction import compute_gaussian, compute_steps_for_sliding_window
17 | from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
18 | from nnunetv2.preprocessing.normalization.default_normalization_schemes import ZScoreNormalization
19 | from nnunetv2.utilities.helpers import dummy_context, empty_cache
20 |
21 | from voxtell.model.voxtell_model import VoxTellModel
22 | from voxtell.utils.text_embedding import last_token_pool, wrap_with_instruction
23 |
24 |
25 | class VoxTellPredictor:
26 | """
27 | Predictor for VoxTell segmentation model.
28 |
29 | This class handles loading the VoxTell model, preprocessing images,
30 | embedding text prompts, and performing sliding window inference to generate
31 | segmentation masks based on free-text anatomical descriptions.
32 |
33 | Attributes:
34 | device: PyTorch device for inference.
35 | network: The VoxTell model.
36 | tokenizer: Text tokenizer for prompt encoding.
37 | text_backbone: Text embedding model.
38 | patch_size: Patch size for sliding window inference.
39 | tile_step_size: Step size for sliding window (default: 0.5 = 50% overlap).
40 | perform_everything_on_device: Keep all tensors on device during inference.
41 | max_text_length: Maximum text prompt length in tokens.
42 | """
43 | def __init__(self, model_dir: str, device: torch.device = torch.device('cuda'),
44 | text_encoding_model: str = 'Qwen/Qwen3-Embedding-4B') -> None:
45 | """
46 | Initialize the VoxTell predictor.
47 |
48 | Args:
49 | model_dir: Path to model directory containing plans.json and checkpoint.
50 | device: PyTorch device to use for inference (default: cuda).
51 | text_encoding_model: Pretrained text encoding model (Qwen/Qwen3-Embedding-4B).
52 |
53 | Raises:
54 | FileNotFoundError: If model files are not found.
55 | RuntimeError: If model loading fails.
56 | """
57 | # Device setup
58 | self.device = device
59 | if device.type == 'cuda':
60 | torch.backends.cudnn.benchmark = True
61 | self.normalization = ZScoreNormalization(intensityproperties={})
62 |
63 | # Predictor settings
64 | self.tile_step_size = 0.5
65 | self.perform_everything_on_device = True
66 |
67 | # Embedding model
68 | self.tokenizer = AutoTokenizer.from_pretrained(text_encoding_model, padding_side='left')
69 | self.text_backbone = AutoModel.from_pretrained(text_encoding_model).eval()
70 | self.max_text_length = 8192
71 |
72 | # Load network settings
73 | plans = load_json(join(model_dir, 'plans.json'))
74 | arch_kwargs = plans['configurations']['3d_fullres']['architecture']['arch_kwargs']
75 | self.patch_size = plans['configurations']['3d_fullres']['patch_size']
76 |
77 | arch_kwargs = dict(**arch_kwargs)
78 | for required_import_key in plans['configurations']['3d_fullres']['architecture']['_kw_requires_import']:
79 | if arch_kwargs[required_import_key] is not None:
80 | arch_kwargs[required_import_key] = pydoc.locate(arch_kwargs[required_import_key])
81 |
82 | # Instantiate network
83 | network = VoxTellModel(
84 | input_channels=1,
85 | **arch_kwargs,
86 | decoder_layer=4,
87 | text_embedding_dim=2560,
88 | num_maskformer_stages=5,
89 | num_heads=32,
90 | query_dim=2048,
91 | project_to_decoder_hidden_dim=2048,
92 | deep_supervision=False
93 | )
94 |
95 | # Load weights
96 | checkpoint = torch.load(
97 | join(model_dir, 'fold_0', 'checkpoint_final.pth'),
98 | map_location=torch.device('cpu'),
99 | weights_only=False
100 | )
101 |
102 | if not isinstance(network, OptimizedModule):
103 | network.load_state_dict(checkpoint['network_weights'])
104 | else:
105 | network._orig_mod.load_state_dict(checkpoint['network_weights'])
106 |
107 | network.eval()
108 | self.network = network
109 |
110 | def preprocess(self, data: np.ndarray) -> Tuple[torch.Tensor, Tuple, Tuple[int, ...]]:
111 | """
112 | Preprocess a single image for inference.
113 |
114 | This function preprocesses an image already in RAS orientation by performing
115 | cropping to non-zero regions and z-score normalization.
116 |
117 | Args:
118 | data: Image data in RAS orientation (3D or 4D with channel dimension).
119 |
120 | Returns:
121 | Tuple containing:
122 | - Preprocessed image tensor
123 | - Bounding box of cropped region
124 | - Original image shape
125 | """
126 |
127 | if data.ndim == 3:
128 | data = data[None] # add channel axis
129 | data = data.astype(np.float32) # this creates a copy
130 | original_shape = data.shape[1:]
131 | data, _, bbox = crop_to_nonzero(data, None)
132 | data = self.normalization.run(data, None)
133 | data = torch.from_numpy(data)
134 | return data, bbox, original_shape
135 |
136 | def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]) -> List[Tuple]:
137 | """
138 | Generate sliding window slicers for patch-based inference.
139 |
140 | Args:
141 | image_size: Shape of the input image.
142 |
143 | Returns:
144 | List of slice tuples for extracting patches.
145 | """
146 | slicers = []
147 | if len(self.patch_size) < len(image_size):
148 | assert len(self.patch_size) == len(image_size) - 1, (
149 |                 'if patch_size has fewer entries than image_size, '
150 |                 'len(patch_size) must be exactly one shorter than len(image_size) '
151 |                 '(only a dimension discrepancy of 1 is allowed).'
152 | )
153 | steps = compute_steps_for_sliding_window(image_size[1:], self.patch_size,
154 | self.tile_step_size)
155 | for d in range(image_size[0]):
156 | for sx in steps[0]:
157 | for sy in steps[1]:
158 | slicers.append(
159 | tuple([slice(None), d, *[slice(si, si + ti) for si, ti in
160 | zip((sx, sy), self.patch_size)]]))
161 | else:
162 | steps = compute_steps_for_sliding_window(image_size, self.patch_size,
163 | self.tile_step_size)
164 | for sx in steps[0]:
165 | for sy in steps[1]:
166 | for sz in steps[2]:
167 | slicers.append(
168 | tuple([slice(None), *[slice(si, si + ti) for si, ti in
169 | zip((sx, sy, sz), self.patch_size)]]))
170 | return slicers
171 |
172 | @torch.inference_mode()
173 | def embed_text_prompts(self, text_prompts: Union[List[str], str]) -> torch.Tensor:
174 | """
175 | Embed text prompts into vector representations.
176 |
177 | This function converts free-text anatomical descriptions into embeddings
178 | using the text backbone model.
179 |
180 | Args:
181 | text_prompts: Single text prompt or list of text prompts.
182 |
183 | Returns:
184 | Text embeddings tensor of shape (1, num_prompts, embedding_dim).
185 | """
186 | if isinstance(text_prompts, str):
187 | text_prompts = [text_prompts]
188 | n_prompts = len(text_prompts)
189 | self.text_backbone = self.text_backbone.to(self.device)
190 |
191 | text_prompts = wrap_with_instruction(text_prompts)
192 | text_tokens = self.tokenizer(
193 | text_prompts,
194 | padding=True,
195 | truncation=True,
196 | max_length=self.max_text_length,
197 | return_tensors="pt",
198 | )
199 | text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()}
200 | text_embed = self.text_backbone(**text_tokens)
201 | embeddings = last_token_pool(text_embed.last_hidden_state, text_tokens['attention_mask'])
202 | embeddings = embeddings.view(1, n_prompts, -1)
203 | self.text_backbone = self.text_backbone.to('cpu')
204 | empty_cache(self.device)
205 | return embeddings
206 |
207 | @torch.inference_mode()
208 | def predict_sliding_window_return_logits(
209 | self,
210 | input_image: torch.Tensor,
211 | text_embeddings: torch.Tensor
212 | ) -> torch.Tensor:
213 | """
214 | Perform sliding window inference to generate segmentation logits.
215 |
216 | Args:
217 | input_image: Input image tensor of shape (C, X, Y, Z).
218 | text_embeddings: Text embeddings from embed_text_prompts.
219 |
220 | Returns:
221 | Predicted logits tensor.
222 |
223 | Raises:
224 | ValueError: If input_image is not 4D or not a torch.Tensor.
225 | """
226 | if not isinstance(input_image, torch.Tensor):
227 | raise ValueError(f"input_image must be a torch.Tensor, got {type(input_image)}")
228 | if input_image.ndim != 4:
229 | raise ValueError(
230 | f"input_image must be 4D (C, X, Y, Z), got shape {input_image.shape}"
231 | )
232 |
233 | self.network = self.network.to(self.device)
234 |
235 | empty_cache(self.device)
236 | with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
237 |
238 |             # if input_image is smaller than patch_size we need to pad it to patch_size.
239 | data, slicer_revert_padding = pad_nd_image(input_image, self.patch_size,
240 | 'constant', {'value': 0}, True, None)
241 |
242 | slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
243 |
244 | predicted_logits = self._internal_predict_sliding_window_return_logits(
245 | data, text_embeddings, slicers, self.perform_everything_on_device
246 | )
247 |
248 | empty_cache(self.device)
249 | # Revert padding
250 | predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
251 | return predicted_logits
252 |
253 | @torch.inference_mode()
254 | def _internal_predict_sliding_window_return_logits(
255 | self,
256 | data: torch.Tensor,
257 | text_embeddings: torch.Tensor,
258 | slicers: List[Tuple],
259 | do_on_device: bool = True,
260 | ) -> torch.Tensor:
261 | """
262 | Internal method for sliding window prediction with Gaussian weighting.
263 |
264 | Uses a producer-consumer pattern with threading to overlap data loading
265 | and model inference.
266 |
267 | Args:
268 | data: Preprocessed image data.
269 | text_embeddings: Text embeddings for prompts.
270 | slicers: List of slice tuples for patch extraction.
271 | do_on_device: If True, keep all tensors on GPU during computation.
272 |
273 | Returns:
274 | Aggregated prediction logits.
275 |
276 | Raises:
277 | RuntimeError: If inf values are encountered in predictions.
278 | """
279 | results_device = self.device if do_on_device else torch.device('cpu')
280 |
281 | def producer(data_tensor, slicer_list, queue):
282 | """Producer thread that loads patches into queue."""
283 | for slicer in slicer_list:
284 | patch = torch.clone(
285 | data_tensor[slicer][None],
286 | memory_format=torch.contiguous_format
287 | ).to(self.device)
288 | queue.put((patch, slicer))
289 | queue.put('end')
290 |
291 | empty_cache(self.device)
292 |
293 | # move data to device
294 | data = data.to(results_device)
295 | queue = Queue(maxsize=2)
296 | t = Thread(target=producer, args=(data, slicers, queue))
297 | t.start()
298 |
299 |         # Preallocate accumulation tensors (summed logits and per-voxel Gaussian weights)
300 | predicted_logits = torch.zeros((text_embeddings.shape[1], *data.shape[1:]),
301 | dtype=torch.half,
302 | device=results_device)
303 | n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
304 |
305 | gaussian = compute_gaussian(
306 | tuple(self.patch_size),
307 | sigma_scale=1. / 8,
308 | value_scaling_factor=10,
309 | device=results_device
310 | )
311 |
312 | with tqdm(desc=None, total=len(slicers)) as pbar:
313 | while True:
314 | item = queue.get()
315 | if item == 'end':
316 | queue.task_done()
317 | break
318 | patch, tile_slice = item
319 | prediction = self.network(patch, text_embeddings)[0].to(results_device)
320 | prediction *= gaussian
321 | predicted_logits[tile_slice] += prediction
322 | n_predictions[tile_slice[1:]] += gaussian
323 | queue.task_done()
324 | pbar.update()
325 | queue.join()
326 |
327 |         # Normalize by the accumulated Gaussian weight at each voxel
328 | torch.div(predicted_logits, n_predictions, out=predicted_logits)
329 |
330 | # Check for inf values
331 | if torch.any(torch.isinf(predicted_logits)):
332 | raise RuntimeError(
333 | 'Encountered inf in predicted array. Aborting... '
334 | 'If this problem persists, reduce value_scaling_factor in '
335 | 'compute_gaussian or increase the dtype of predicted_logits to fp32.'
336 | )
337 | return predicted_logits
338 |
339 | def predict_single_image(
340 | self,
341 | data: np.ndarray,
342 | text_prompts: Union[str, List[str]]
343 | ) -> np.ndarray:
344 | """
345 | Predict segmentation masks for a single image with text prompts.
346 |
347 | This is the main prediction method that orchestrates preprocessing,
348 | text embedding, sliding window inference, and postprocessing.
349 |
350 | Args:
351 | data: Image data in RAS orientation (3D or 4D with channel dimension).
352 | text_prompts: Single text prompt or list of text prompts describing
353 | anatomical structures to segment.
354 |
355 | Returns:
356 | Segmentation masks as numpy array of shape (num_prompts, X, Y, Z)
357 | with binary values (0 or 1) indicating the segmented regions.
358 | """
359 |
360 | # Preprocess image
361 | data, bbox, orig_shape = self.preprocess(data)
362 |
363 | # Embed text prompts
364 | embeddings = self.embed_text_prompts(text_prompts)
365 |
366 | # Predict segmentation logits
367 | prediction = self.predict_sliding_window_return_logits(data, embeddings).to('cpu')
368 |
369 | # Postprocess logits to get binary segmentation masks
370 | with torch.no_grad():
371 | prediction = torch.sigmoid(prediction.float()) > 0.5
372 |
373 | segmentation_reverted_cropping = np.zeros(
374 | [prediction.shape[0], *orig_shape],
375 | dtype=np.uint8
376 | )
377 | segmentation_reverted_cropping = insert_crop_into_image(
378 | segmentation_reverted_cropping, prediction, bbox
379 | )
380 |
381 | return segmentation_reverted_cropping
382 |
383 |
384 | if __name__ == '__main__':
385 | from pathlib import Path
386 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient
387 |
388 | # Default paths - modify these as needed
389 | DEFAULT_IMAGE_PATH = "/path/to/your/image.nii.gz"
390 | DEFAULT_MODEL_DIR = "/path/to/your/model/directory"
391 |
392 | # Configuration
393 | image_path = DEFAULT_IMAGE_PATH
394 | model_dir = DEFAULT_MODEL_DIR
395 | text_prompts = ["liver", "right kidney", "left kidney", "spleen"]
396 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
397 |
398 | # Load image
399 | img, props = NibabelIOWithReorient().read_images([image_path])
400 |
401 | # Initialize predictor and run inference
402 | predictor = VoxTellPredictor(model_dir=model_dir, device=device)
403 | voxtell_seg = predictor.predict_single_image(img, text_prompts)
404 |
405 |     # Visualize results; we recommend napari for 3D visualization
406 | import napari
407 | viewer = napari.Viewer()
408 | viewer.add_image(img, name='image')
409 | for i, prompt in enumerate(text_prompts):
410 | viewer.add_labels(voxtell_seg[i], name=f'voxtell_{prompt}')
411 | napari.run()
--------------------------------------------------------------------------------
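
Note on the sliding-window inference above: _internal_predict_sliding_window_return_logits overlaps patch staging and model execution with a producer thread, and accumulates overlapping tiles with a Gaussian importance map so that voxels near patch centres dominate the final logits. The following is a minimal, self-contained sketch of that aggregation pattern on a toy single-channel volume. The gaussian_weight helper, the identity stand-in model, and the toy shapes and step size are illustrative assumptions, not VoxTell code (the predictor itself uses nnU-Net's compute_gaussian and compute_steps_for_sliding_window).

import torch
from queue import Queue
from threading import Thread


def gaussian_weight(patch_size, sigma_scale=1. / 8):
    # Separable Gaussian importance map, peaking at the patch centre
    coords = [torch.arange(s, dtype=torch.float32) for s in patch_size]
    grids = torch.meshgrid(*coords, indexing='ij')
    centres = [(s - 1) / 2 for s in patch_size]
    sigmas = [s * sigma_scale for s in patch_size]
    dist = sum(((g - c) / sig) ** 2 for g, c, sig in zip(grids, centres, sigmas))
    weights = torch.exp(-0.5 * dist)
    return weights / weights.max()


def sliding_window_predict(volume, model, patch_size, step=0.5):
    """Toy Gaussian-weighted aggregation (assumes volume >= patch in every axis)."""
    gaussian = gaussian_weight(patch_size)
    logits = torch.zeros_like(volume)
    weight_sum = torch.zeros_like(volume)

    # Start indices per axis with the requested overlap, always including the last valid position
    starts = []
    for size, patch in zip(volume.shape, patch_size):
        stride = max(int(patch * step), 1)
        axis_starts = list(range(0, size - patch + 1, stride))
        if axis_starts[-1] != size - patch:
            axis_starts.append(size - patch)
        starts.append(axis_starts)
    slicers = [tuple(slice(s, s + p) for s, p in zip((sx, sy, sz), patch_size))
               for sx in starts[0] for sy in starts[1] for sz in starts[2]]

    # Producer thread stages the next patch while the consumer runs the model
    queue = Queue(maxsize=2)

    def producer():
        for sl in slicers:
            queue.put((volume[sl].clone(), sl))
        queue.put(None)

    Thread(target=producer, daemon=True).start()

    while True:
        item = queue.get()
        if item is None:
            break
        patch, sl = item
        prediction = model(patch)                # same shape as the patch
        logits[sl] += prediction * gaussian      # centre-weighted accumulation
        weight_sum[sl] += gaussian
    return logits / weight_sum                   # per-voxel normalisation


volume = torch.rand(40, 40, 40)
fused = sliding_window_predict(volume, model=lambda p: p, patch_size=(16, 16, 16))
print(torch.allclose(fused, volume, atol=1e-5))  # True: the identity model round-trips

With the identity model, the weighted average reproduces the input exactly, which is a quick sanity check that the Gaussian weights and the per-voxel normalisation cancel correctly.
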
/voxtell/model/voxtell_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.modules.conv import _ConvNd
4 | from torch.nn.modules.dropout import _DropoutNd
5 |
6 | from typing import List, Type, Union, Tuple
7 | from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp
8 | from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
9 | from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD
10 | from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
11 | from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
12 | from dynamic_network_architectures.initialization.weight_init import InitWeights_He, init_last_bn_before_add_to_0
13 | from einops import rearrange, repeat
14 | from positional_encodings.torch_encodings import PositionalEncoding3D
15 |
16 | from voxtell.model.transformer import TransformerDecoder, TransformerDecoderLayer
17 |
18 |
19 | class VoxTellModel(nn.Module):
20 | """
21 | VoxTell segmentation model with text-prompted decoder.
22 |
23 | This model combines a ResidualEncoder backbone with a transformer-based decoder
24 | that uses text embeddings to generate segmentation masks. It supports multi-stage
25 | decoding with optional deep supervision.
26 |
27 | Attributes:
28 | encoder: ResidualEncoder backbone for feature extraction.
29 | decoder: VoxTellDecoder for multi-scale feature decoding.
30 | transformer_decoder: Transformer for fusing text and image features.
31 | deep_supervision: Whether to return multi-scale predictions.
32 | """
33 |
34 | # Class constants for transformer architecture (text prompt decoder)
35 | TRANSFORMER_NUM_HEADS = 8
36 | TRANSFORMER_NUM_LAYERS = 6
37 |
38 |     # Feature channels and spatial shapes per resolution stage (for a 192x192x192 input patch)
39 | DECODER_CONFIGS = {
40 | 0: {"channels": 32, "shape": (192, 192, 192)},
41 | 1: {"channels": 64, "shape": (96, 96, 96)},
42 | 2: {"channels": 128, "shape": (48, 48, 48)},
43 | 3: {"channels": 256, "shape": (24, 24, 24)},
44 | 4: {"channels": 320, "shape": (12, 12, 12)},
45 | 5: {"channels": 320, "shape": (6, 6, 6)},
46 | }
47 |
48 | def __init__(
49 | self,
50 | input_channels: int,
51 | n_stages: int,
52 | features_per_stage: Union[int, List[int], Tuple[int, ...]],
53 | conv_op: Type[_ConvNd],
54 | kernel_sizes: Union[int, List[int], Tuple[int, ...]],
55 | strides: Union[int, List[int], Tuple[int, ...]],
56 | n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
57 | n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
58 | conv_bias: bool = False,
59 | norm_op: Union[None, Type[nn.Module]] = None,
60 | norm_op_kwargs: dict = None,
61 | dropout_op: Union[None, Type[_DropoutNd]] = None,
62 | dropout_op_kwargs: dict = None,
63 | nonlin: Union[None, Type[torch.nn.Module]] = None,
64 | nonlin_kwargs: dict = None,
65 | deep_supervision: bool = False,
66 | block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
67 | bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
68 | stem_channels: int = None,
69 | # Text-prompted segmentation parameters
70 | num_maskformer_stages: int = 5,
71 | query_dim: int = 864,
72 | decoder_layer: int = 4,
73 | text_embedding_dim: int = 1024,
74 | num_heads: int = 1,
75 | project_to_decoder_hidden_dim: int = 432
76 | ) -> None:
77 | """
78 | Initialize the VoxTell model.
79 |
80 | Args:
81 | input_channels: Number of input channels.
82 | n_stages: Number of encoder stages.
83 | features_per_stage: Number of features at each stage.
84 | conv_op: Convolution operation type.
85 | kernel_sizes: Kernel sizes for convolutions.
86 | strides: Strides for downsampling.
87 | n_blocks_per_stage: Number of residual blocks per stage.
88 |             n_conv_per_stage_decoder: Number of convolutions per decoder stage.
89 | conv_bias: Whether to use bias in convolutions.
90 | norm_op: Normalization operation.
91 | norm_op_kwargs: Normalization operation keyword arguments.
92 | dropout_op: Dropout operation.
93 | dropout_op_kwargs: Dropout operation keyword arguments.
94 | nonlin: Non-linearity operation.
95 | nonlin_kwargs: Non-linearity keyword arguments.
96 | deep_supervision: Whether to use deep supervision.
97 | block: Residual block type (BasicBlockD or BottleneckD).
98 | bottleneck_channels: Channels in bottleneck layers.
99 | stem_channels: Channels in stem layer.
100 | num_maskformer_stages: Number of stages to fuse text-image embeddings in decoder.
101 | query_dim: Dimension of query embeddings.
102 |             decoder_layer: Index of the encoder feature map (0-5) used as memory for the text prompt decoder.
103 | text_embedding_dim: Dimension of text embeddings.
104 | num_heads: Number of channels added per U-Net stage for mask embedding fusion.
105 | project_to_decoder_hidden_dim: Hidden dimension for projection to decoder.
106 | """
107 | super().__init__()
108 | if isinstance(n_blocks_per_stage, int):
109 | n_blocks_per_stage = [n_blocks_per_stage] * n_stages
110 | if isinstance(n_conv_per_stage_decoder, int):
111 | n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
112 |
113 | assert len(n_blocks_per_stage) == n_stages, (
114 | f"n_blocks_per_stage must have as many entries as we have resolution stages. "
115 | f"Expected: {n_stages}, got: {len(n_blocks_per_stage)} ({n_blocks_per_stage})"
116 | )
117 | assert len(n_conv_per_stage_decoder) == (n_stages - 1), (
118 | f"n_conv_per_stage_decoder must have one less entry than resolution stages. "
119 | f"Expected: {n_stages - 1}, got: {len(n_conv_per_stage_decoder)} ({n_conv_per_stage_decoder})"
120 | )
121 | if num_maskformer_stages != 5:
122 | assert not deep_supervision, (
123 | "Deep supervision is not supported for num_maskformer_stages != 5."
124 | )
125 | self.deep_supervision = deep_supervision
126 | self.num_heads = num_heads
127 | self.query_dim = query_dim
128 | self.project_to_decoder_hidden_dim = project_to_decoder_hidden_dim
129 | self.text_embedding_dim = text_embedding_dim
130 |
131 | # Initialize encoder backbone
132 | self.encoder = ResidualEncoder(
133 | input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
134 | n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
135 | dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels,
136 | return_skips=True, disable_default_stem=False, stem_channels=stem_channels
137 | )
138 |
139 | # Initialize decoder
140 | self.decoder = VoxTellDecoder(
141 | self.encoder, 1, n_conv_per_stage_decoder, deep_supervision,
142 | num_maskformer_stages, num_heads=self.num_heads
143 | )
144 |
145 | # Select decoder layer configuration
146 | self.selected_decoder_layer = decoder_layer
147 | if decoder_layer not in self.DECODER_CONFIGS:
148 | raise ValueError(
149 | f"decoder_layer must be in {list(self.DECODER_CONFIGS.keys())}, got {decoder_layer}"
150 | )
151 | selected_config = self.DECODER_CONFIGS[decoder_layer]
152 |
153 | h, w, d = selected_config["shape"]
154 |
155 | # Project bottleneck embeddings to query dimension
156 | self.project_bottleneck_embed = nn.Sequential(
157 | nn.Linear(selected_config["channels"], query_dim),
158 | nn.GELU(),
159 | nn.Linear(query_dim, query_dim),
160 | )
161 |
162 | # Project text embeddings to query dimension
163 | text_hidden_dim = 2048
164 | self.project_text_embed = nn.Sequential(
165 | nn.Linear(self.text_embedding_dim, text_hidden_dim),
166 | nn.GELU(),
167 | nn.Linear(text_hidden_dim, query_dim),
168 | )
169 |
170 | # Project decoder output to image channels for each mask-former stage
171 | self.project_to_decoder_channels = nn.ModuleList([
172 | nn.Sequential(
173 | nn.Linear(query_dim, self.project_to_decoder_hidden_dim),
174 | nn.GELU(),
175 | nn.Linear(
176 | self.project_to_decoder_hidden_dim,
177 | decoder_config["channels"] * self.num_heads if stage_idx != 0 else decoder_config["channels"]
178 | )
179 | )
180 | for stage_idx, decoder_config in enumerate(
181 | list(self.DECODER_CONFIGS.values())[:num_maskformer_stages]
182 | )
183 | ])
184 |
185 | # Initialize 3D positional encoding
186 | # Shape: (H*W*D, batch_size, query_dim)
187 | pos_embed = PositionalEncoding3D(query_dim)(torch.zeros(1, h, w, d, query_dim))
188 | pos_embed = rearrange(pos_embed, 'b h w d c -> (h w d) b c')
189 | self.register_buffer('pos_embed', pos_embed)
190 |
191 | # Initialize transformer decoder for fusing text and image features (prompt decoder)
192 | transformer_layer = TransformerDecoderLayer(
193 | d_model=query_dim,
194 | nhead=self.TRANSFORMER_NUM_HEADS,
195 | normalize_before=True
196 | )
197 | decoder_norm = nn.LayerNorm(query_dim)
198 | self.transformer_decoder = TransformerDecoder(
199 | decoder_layer=transformer_layer,
200 | num_layers=self.TRANSFORMER_NUM_LAYERS,
201 | norm=decoder_norm
202 | )
203 |
204 | def forward(
205 | self,
206 | img: torch.Tensor,
207 | text_embedding: torch.Tensor = None
208 | ) -> Union[torch.Tensor, List[torch.Tensor]]:
209 | """
210 | Forward pass through VoxTell model.
211 |
212 | Args:
213 | img: Input image tensor of shape (B, C, D, H, W).
214 | text_embedding: Pre-computed text embeddings of shape (B, N, D) where
215 | N is number of prompts and D is embedding dimension.
216 |
217 | Returns:
218 | If deep_supervision is False, returns single prediction tensor of shape (B, N, D, H, W).
219 | If deep_supervision is True, returns list of prediction tensors at different scales.
220 | """
221 | # Extract multi-scale features from encoder
222 | skips = self.encoder(img)
223 |
224 | # Select encoder features
225 | selected_feature = skips[self.selected_decoder_layer]
226 |
227 | # Reshape and project features to query dimension
228 | # Shape: (B, C, D, H, W) -> (B, H, W, D, C) -> (B, H, W, D, query_dim)
229 | bottleneck_embed = rearrange(selected_feature, 'b c d h w -> b h w d c')
230 | bottleneck_embed = self.project_bottleneck_embed(bottleneck_embed)
231 | # Shape: (B, H, W, D, query_dim) -> (H*W*D, B, query_dim) for transformer
232 | bottleneck_embed = rearrange(bottleneck_embed, 'b h w d c -> (h w d) b c')
233 |
234 |         # Remove a possible singleton dimension from the text embeddings and project
235 |         # Shape: (B, N, 1, D) -> (B, N, D); no-op if embeddings are already (B, N, D)
236 | text_embedding = text_embedding.squeeze(2)
237 | # Shape: (B, N, D) -> (N, B, D) as required by transformer decoder
238 | text_embed = repeat(text_embedding, 'b n dim -> n b dim')
239 | text_embed = self.project_text_embed(text_embed)
240 |
241 | # Fuse text and image features through transformer decoder
242 | # Output shape: (N, B, query_dim)
243 | mask_embedding, _ = self.transformer_decoder(
244 | tgt=text_embed,
245 | memory=bottleneck_embed,
246 | pos=self.pos_embed,
247 | memory_key_padding_mask=None
248 | )
249 | # Shape: (N, B, query_dim) -> (B, N, query_dim)
250 | mask_embedding = repeat(mask_embedding, 'n b dim -> b n dim')
251 |
252 | # Project mask embeddings to decoder channel dimensions for each stage
253 | mask_embeddings = [
254 | projection(mask_embedding)
255 | for projection in self.project_to_decoder_channels
256 | ]
257 |
258 | # Generate segmentation outputs for each text prompt
259 | outs = []
260 | num_prompts = text_embedding.shape[1]
261 | for prompt_idx in range(num_prompts):
262 | # Extract embeddings for this prompt across all stages
263 | prompt_embeds = [m[:, prompt_idx:prompt_idx + 1] for m in mask_embeddings]
264 | outs.append(self.decoder(skips, prompt_embeds))
265 |
266 | # Concatenate outputs across prompts for each scale
267 | outs = [torch.cat(scale_outs, dim=1) for scale_outs in zip(*outs)]
268 |
269 | if not self.deep_supervision:
270 | outs = outs[0]
271 |
272 | return outs
273 |
274 | @staticmethod
275 | def initialize(module):
276 | InitWeights_He(1e-2)(module)
277 | init_last_bn_before_add_to_0(module)
278 |
279 |
280 | class VoxTellDecoder(nn.Module):
281 | """
282 | Decoder for VoxTell with mask-embedding fusion.
283 |
284 | This decoder upsamples features from the encoder and fuses them with
285 | mask embeddings from text prompts at multiple scales. It supports
286 | deep supervision for multi-scale training.
287 |
288 | The decoder processes features from bottleneck to highest resolution,
289 | incorporating mask embeddings through einsum operations at each stage.
290 | """
291 |
292 | def __init__(
293 | self,
294 | encoder: Union[PlainConvEncoder, ResidualEncoder],
295 | num_classes: int,
296 | n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
297 | deep_supervision: bool,
298 | num_maskformer_stages: int = 5,
299 | nonlin_first: bool = False,
300 | norm_op: Union[None, Type[nn.Module]] = None,
301 | norm_op_kwargs: dict = None,
302 | dropout_op: Union[None, Type[_DropoutNd]] = None,
303 | dropout_op_kwargs: dict = None,
304 | nonlin: Union[None, Type[torch.nn.Module]] = None,
305 | nonlin_kwargs: dict = None,
306 | conv_bias: bool = None,
307 | num_heads: int = 1
308 | ) -> None:
309 | """
310 | Initialize VoxTell decoder.
311 |
312 | The decoder upsamples features from encoder stages and fuses them with
313 | mask embeddings. Each stage consists of:
314 | 1) Transpose convolution to upsample lower resolution features
315 | 2) Concatenation with skip connections from encoder
316 | 3) Convolutional blocks to merge features
317 | 4) Fusion with mask embeddings via einsum
318 | 5) Optional segmentation output for deep supervision
319 |
320 | Args:
321 | encoder: Encoder module (PlainConvEncoder or ResidualEncoder).
322 | num_classes: Number of output classes (typically 1 for binary segmentation).
323 | n_conv_per_stage: Number of convolution blocks per decoder stage.
324 | deep_supervision: Whether to output predictions at multiple scales.
325 | num_maskformer_stages: Number of stages to fuse mask embeddings.
326 | nonlin_first: Whether to apply non-linearity before convolution.
327 | norm_op: Normalization operation (inherited from encoder if None).
328 | norm_op_kwargs: Normalization keyword arguments.
329 | dropout_op: Dropout operation (inherited from encoder if None).
330 | dropout_op_kwargs: Dropout keyword arguments.
331 | nonlin: Non-linearity operation (inherited from encoder if None).
332 | nonlin_kwargs: Non-linearity keyword arguments.
333 | conv_bias: Whether to use bias in convolutions (inherited from encoder if None).
334 |             num_heads: Number of mask-embedding heads; each head contributes one fused similarity channel to the decoder features.
335 | """
336 | super().__init__()
337 | self.deep_supervision = deep_supervision
338 | self.encoder = encoder
339 | self.num_classes = num_classes
340 | self.num_heads = num_heads
341 |
342 | n_stages_encoder = len(encoder.output_channels)
343 | if isinstance(n_conv_per_stage, int):
344 | n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
345 |
346 | assert len(n_conv_per_stage) == n_stages_encoder - 1, (
347 | f"n_conv_per_stage must have one less entry than encoder stages. "
348 | f"Expected: {n_stages_encoder - 1}, got: {len(n_conv_per_stage)}"
349 | )
350 |
351 | # Inherit hyperparameters from encoder if not specified
352 | transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
353 | conv_bias = encoder.conv_bias if conv_bias is None else conv_bias
354 | norm_op = encoder.norm_op if norm_op is None else norm_op
355 | norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs
356 | dropout_op = encoder.dropout_op if dropout_op is None else dropout_op
357 | dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs
358 | nonlin = encoder.nonlin if nonlin is None else nonlin
359 | nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs
360 |
361 | # Build decoder stages from bottleneck to highest resolution
362 | stages = []
363 | transpconvs = []
364 | seg_layers = []
365 | for stage_idx in range(1, n_stages_encoder):
366 | # Determine input channels: add num_heads for stages with mask embedding fusion
367 | if stage_idx <= n_stages_encoder - num_maskformer_stages:
368 | input_features_below = encoder.output_channels[-stage_idx]
369 | else:
370 | input_features_below = encoder.output_channels[-stage_idx] + num_heads
371 |
372 | input_features_skip = encoder.output_channels[-(stage_idx + 1)]
373 | stride_for_transpconv = encoder.strides[-stage_idx]
374 |
375 | # Transpose convolution for upsampling
376 | transpconvs.append(transpconv_op(
377 | input_features_below, input_features_skip,
378 | stride_for_transpconv, stride_for_transpconv,
379 | bias=conv_bias
380 | ))
381 |
382 | # Convolutional blocks for feature merging
383 | # Input features: 2x input_features_skip (concatenated skip + upsampled features)
384 | stages.append(StackedConvBlocks(
385 | n_conv_per_stage[stage_idx - 1], encoder.conv_op,
386 | 2 * input_features_skip, input_features_skip,
387 | encoder.kernel_sizes[-(stage_idx + 1)], 1,
388 | conv_bias, norm_op, norm_op_kwargs,
389 | dropout_op, dropout_op_kwargs,
390 | nonlin, nonlin_kwargs, nonlin_first
391 | ))
392 |
393 | # Segmentation output layer (always built for parameter loading compatibility)
394 | # This allows models trained with deep_supervision=True to be loaded
395 | # for inference with deep_supervision=False
396 | seg_layers.append(encoder.conv_op(
397 | input_features_skip + num_heads, num_classes,
398 | 1, 1, 0, bias=True
399 | ))
400 |
401 | self.stages = nn.ModuleList(stages)
402 | self.transpconvs = nn.ModuleList(transpconvs)
403 | self.seg_layers = nn.ModuleList(seg_layers)
404 |
405 | def forward(
406 | self,
407 | skips: List[torch.Tensor],
408 | mask_embeddings: List[torch.Tensor]
409 | ) -> List[torch.Tensor]:
410 | """
411 | Forward pass through decoder with mask embedding fusion.
412 |
413 | Processes features from bottleneck to highest resolution, fusing
414 | mask embeddings at multiple stages via einsum operations.
415 |
416 | Args:
417 | skips: List of encoder skip connections in computation order.
418 | Last entry should be bottleneck features.
419 | mask_embeddings: List of mask embeddings for each decoder stage,
420 | in order from lowest to highest resolution.
421 |
422 | Returns:
423 | List of segmentation predictions. If deep_supervision=False,
424 | returns single-element list with highest resolution prediction.
425 | If deep_supervision=True, returns predictions at all scales
426 | from highest to lowest resolution.
427 | """
428 | lres_input = skips[-1]
429 | seg_outputs = []
430 |
431 | # Reverse mask embeddings to match decoder stage order (bottleneck first)
432 | mask_embeddings = mask_embeddings[::-1]
433 |
434 | for stage_idx in range(len(self.stages)):
435 | # Upsample and concatenate with skip connection
436 | x = self.transpconvs[stage_idx](lres_input)
437 | x = torch.cat((x, skips[-(stage_idx + 2)]), dim=1)
438 | x = self.stages[stage_idx](x)
439 |
440 | # Apply mask embedding fusion for relevant stages
441 | if stage_idx == (len(self.stages) - 1):
442 | # Final stage: generate segmentation via einsum
443 | # x: (B, C, H, W, D), mask_embeddings[-1]: (B, N, C)
444 | # Output: (B, N, H, W, D)
445 | seg_pred = torch.einsum('b c h w d, b n c -> b n h w d', x, mask_embeddings[-1])
446 | seg_outputs.append(seg_pred)
447 | elif stage_idx >= len(self.stages) - len(mask_embeddings):
448 | # Intermediate stages with mask embedding fusion
449 | mask_embedding = mask_embeddings.pop(0)
450 | batch_size, _, channels = mask_embedding.shape
451 |
452 |                 # Split the mask embedding into num_heads chunks and compute one
453 |                 # similarity map per head: (B, 1, C * num_heads) -> (B, num_heads, C)
454 | mask_embedding_reshaped = mask_embedding.view(batch_size, self.num_heads, -1)
455 | fusion_features = torch.einsum(
456 | 'b c h w d, b n c -> b n h w d',
457 | x, mask_embedding_reshaped
458 | )
459 |
460 | # Concatenate fused features with spatial features
461 | x = torch.cat((x, fusion_features), dim=1)
462 | seg_outputs.append(self.seg_layers[stage_idx](x))
463 |
464 | lres_input = x
465 |
466 | # Reverse outputs to have highest resolution first
467 | seg_outputs = seg_outputs[::-1]
468 |
469 | if not self.deep_supervision:
470 | return seg_outputs[:1]
471 | else:
472 | return seg_outputs
--------------------------------------------------------------------------------
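
To make the text-to-mask path in VoxTellModel and VoxTellDecoder easier to follow, here is a condensed sketch of the core mechanism: projected text embeddings act as queries, a flattened low-resolution feature map acts as the memory they attend to, and the resulting per-prompt mask embeddings are turned into segmentation logits with the same kind of channel-wise einsum used in VoxTellDecoder.forward. All dimensions are toy values, and nn.MultiheadAttention stands in for the custom TransformerDecoder (positional encodings, per-stage projections, and the multi-head channel fusion are omitted), so treat this as an illustration of the idea rather than the actual VoxTell forward pass.

import torch
from torch import nn
from einops import rearrange

# Toy dimensions (illustrative only; far smaller than the real model)
B, N, C = 1, 3, 32            # batch, text prompts, feature channels
D, H, W = 6, 6, 6             # low-resolution feature map size
query_dim, text_dim = 64, 128

# 1) Project text embeddings to the query dimension
project_text = nn.Sequential(nn.Linear(text_dim, query_dim), nn.GELU(), nn.Linear(query_dim, query_dim))
text_embedding = torch.randn(B, N, text_dim)
queries = project_text(text_embedding)                       # (B, N, query_dim)

# 2) Flatten the low-resolution feature map into a token sequence ("memory")
feature_map = torch.randn(B, C, D, H, W)
project_feat = nn.Linear(C, query_dim)
tokens = rearrange(feature_map, 'b c d h w -> b (d h w) c')
memory = project_feat(tokens)                                # (B, D*H*W, query_dim)

# 3) Text queries attend to the volume tokens and become mask embeddings
attention = nn.MultiheadAttention(query_dim, num_heads=4, batch_first=True)
mask_embedding, _ = attention(queries, memory, memory)       # (B, N, query_dim)

# 4) Project mask embeddings to decoder channels and form per-prompt logits
#    via a channel-wise dot product at every voxel
to_channels = nn.Linear(query_dim, C)
mask_channels = to_channels(mask_embedding)                  # (B, N, C)
decoder_features = torch.randn(B, C, D, H, W)                # stand-in for a decoder stage output
logits = torch.einsum('b c d h w, b n c -> b n d h w', decoder_features, mask_channels)
print(logits.shape)                                          # torch.Size([1, 3, 6, 6, 6]) -- one mask per prompt

In the real model this dot product is applied per decoder stage, with the intermediate per-head similarity maps concatenated onto the feature maps before each stage's segmentation layer, and only the final full-resolution einsum producing the returned logits when deep supervision is disabled.
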