├── voxtell
│   ├── model
│   │   ├── __init__.py
│   │   ├── transformer.py
│   │   └── voxtell_model.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── text_embedding.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── predict_from_raw_data.py
│   │   └── predictor.py
│   └── __init__.py
├── documentation
│   └── assets
│       ├── VoxTellLogo.png
│       ├── VoxTellConcepts.png
│       └── VoxTellArchitecture.png
├── pyproject.toml
├── .gitignore
├── readme.md
└── LICENSE
/voxtell/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voxtell/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voxtell/inference/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voxtell/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version as _version
2 | 
3 | __version__ = _version("voxtell")
--------------------------------------------------------------------------------
/documentation/assets/VoxTellLogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/VoxTell/HEAD/documentation/assets/VoxTellLogo.png
--------------------------------------------------------------------------------
/documentation/assets/VoxTellConcepts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/VoxTell/HEAD/documentation/assets/VoxTellConcepts.png
--------------------------------------------------------------------------------
/documentation/assets/VoxTellArchitecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/VoxTell/HEAD/documentation/assets/VoxTellArchitecture.png
--------------------------------------------------------------------------------
/voxtell/utils/text_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def last_token_pool(last_hidden_states: torch.Tensor,
4 |                     attention_mask: torch.Tensor) -> torch.Tensor:
5 |     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
6 |     if left_padding:
7 |         return last_hidden_states[:, -1]
8 |     else:
9 |         sequence_lengths = attention_mask.sum(dim=1) - 1
10 |         batch_size = last_hidden_states.shape[0]
11 |         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
12 | 
13 | 
14 | def wrap_with_instruction(text_prompts):
15 |     instruct = 'Given an anatomical term query, retrieve the precise anatomical entity and location it represents'
16 | 
17 |     instruct_text_prompts = []
18 |     for text in text_prompts:
19 |         instruct_text_prompts.append(f'Instruct: {instruct}\nQuery: {text}')
20 |     return instruct_text_prompts
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "voxtell"
3 | version = "0.1.0"
4 | requires-python = ">=3.10"
5 | description = "Free-Text Promptable Universal 3D Medical Image Segmentation"
6 | readme = "readme.md"
7 | license = { file = "LICENSE" }
8 | authors = [
9 | { name 
= "Maximilian Rokuss", email = "maximilian.rokuss@dkfz-heidelberg.de"}, 10 | { name = "Moritz Langenberg", email = "moritz.langenberg@dkfz-heidelberg.de"}, 11 | { name = "Yannick Kirchhoff" }, 12 | { name = "Fabian Isensee" }, 13 | { name = "Benjamin Hamm" }, 14 | { name = "Constantin Ulrich" }, 15 | { name = "Sebastian Regnery" }, 16 | { name = "Lukas Bauer" }, 17 | { name = "Efthimios Katsigiannopulos" }, 18 | { name = "Tobias Norajitra" }, 19 | { name = "Klaus Maier-Hein" } 20 | ] 21 | classifiers = [ 22 | "Development Status :: 3 - Alpha", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "Intended Audience :: Healthcare Industry", 26 | "Programming Language :: Python :: 3", 27 | "License :: OSI Approved :: Apache Software License", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | "Topic :: Scientific/Engineering :: Image Recognition", 30 | "Topic :: Scientific/Engineering :: Medical Science Apps.", 31 | ] 32 | keywords = [ 33 | 'deep learning', 34 | 'image segmentation', 35 | 'semantic segmentation', 36 | 'medical image analysis', 37 | 'medical image segmentation', 38 | 'vision-language model', 39 | 'text-prompted segmentation', 40 | '3D segmentation', 41 | 'volumetric segmentation' 42 | ] 43 | dependencies = [ 44 | "torch>=2.0,<2.9", 45 | "nnunetv2>=2.6", 46 | "acvl-utils>=0.2.3,<0.3", 47 | "dynamic-network-architectures>=0.4.1,<0.5", 48 | "numpy>=1.24", 49 | "SimpleITK", 50 | "nibabel", 51 | "positional-encodings", 52 | "einops", 53 | "tqdm", 54 | "transformers>=4.54.0,<5", 55 | ] 56 | 57 | [project.urls] 58 | repository = "https://github.com/MIC-DKFZ/VoxTell" 59 | paper = "https://arxiv.org/abs/2511.11450" 60 | 61 | [project.scripts] 62 | voxtell-predict = "voxtell.inference.predict_from_raw_data:main" 63 | 64 | [project.optional-dependencies] 65 | dev = [ 66 | "black", 67 | "ruff", 68 | "pre-commit" 69 | ] 70 | 71 | [build-system] 72 | requires = ["setuptools>=67.8.0", "wheel"] 73 | build-backend = "setuptools.build_meta" 74 | 75 | [tool.codespell] 76 | skip = '.git,*.pdf,*.svg,*.png' 77 | # 78 | # ignore-words-list = '' 79 | 80 | [tool.setuptools.packages.find] 81 | where = ["."] 82 | include = ["voxtell*"] 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Python 3 | # ============================================================================== 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Logs 57 | *.log 58 | 59 | # ============================================================================== 60 | # Virtual Environments 61 | # 
============================================================================== 62 | 63 | venv/ 64 | env/ 65 | ENV/ 66 | .venv/ 67 | .python-version 68 | 69 | # ============================================================================== 70 | # IDEs and Editors 71 | # ============================================================================== 72 | 73 | # PyCharm / IntelliJ 74 | .idea/ 75 | 76 | # Spyder 77 | .spyderproject 78 | 79 | # Rope 80 | .ropeproject 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # VS Code 86 | .vscode/ 87 | 88 | # ============================================================================== 89 | # Environment Variables 90 | # ============================================================================== 91 | 92 | .env 93 | .env.local 94 | 95 | # ============================================================================== 96 | # Medical Imaging & Data Files 97 | # ============================================================================== 98 | 99 | # NIfTI files 100 | *.nii 101 | *.nii.gz 102 | 103 | # Image formats 104 | *.tif 105 | *.tiff 106 | *.bmp 107 | *.dcm 108 | 109 | # Numpy arrays 110 | *.npy 111 | *.npz 112 | *.memmap 113 | 114 | # Model files 115 | *.model 116 | *.pth 117 | *.pt 118 | *.ckpt 119 | *.h5 120 | *.weights 121 | 122 | # Pickle files 123 | *.pkl 124 | *.pickle 125 | 126 | # ============================================================================== 127 | # Documentation 128 | # ============================================================================== 129 | 130 | # Sphinx 131 | docs/_build/ 132 | docs/_static/ 133 | docs/_templates/ 134 | 135 | # PDF documentation 136 | *.pdf 137 | 138 | # ============================================================================== 139 | # Miscellaneous 140 | # ============================================================================== 141 | 142 | # Compressed files 143 | *.zip 144 | *.tar 145 | *.tar.gz 146 | *.rar 147 | 148 | # XML files 149 | *.xml 150 | 151 | # Temporary files 152 | *.tmp 153 | *.temp 154 | *~ 155 | 156 | # OS generated files 157 | .DS_Store 158 | .DS_Store? 159 | ._* 160 | .Spotlight-V100 161 | .Trashes 162 | ehthumbs.db 163 | Thumbs.db 164 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation 2 | 3 | VoxTell Logo 4 | 5 | This repository contains the official implementation of our paper: 6 | 7 | ### **VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation** 8 | 9 | VoxTell is a **3D vision–language segmentation model** that directly maps free-form text prompts, from single words to full clinical sentences, to volumetric masks. By leveraging **multi-stage vision–language fusion**, VoxTell achieves state-of-the-art performance on anatomical and pathological structures across CT, PET, and MRI modalities, excelling on familiar concepts while generalizing to related unseen classes. 
10 | 
11 | > **Authors**: Maximilian Rokuss*, Moritz Langenberg*, Yannick Kirchhoff, Fabian Isensee, Benjamin Hamm, Constantin Ulrich, Sebastian Regnery, Lukas Bauer, Efthimios Katsigiannopulos, Tobias Norajitra, Klaus Maier-Hein
12 | > **Paper**: [![arXiv](https://img.shields.io/badge/arXiv-2511.11450-B31B1B.svg)](https://arxiv.org/abs/2511.11450)
13 | 
14 | ---
15 | 
16 | ## Overview
17 | 
18 | VoxTell is trained on a **large-scale, multi-modality 3D medical imaging dataset**, aggregating **158 public sources** with over **62,000 volumetric images**. The data covers:
19 | 
20 | - Brain, head & neck, thorax, abdomen, pelvis
21 | - Musculoskeletal system and extremities
22 | - Vascular structures, major organs, substructures, and lesions
23 | 
24 | Concept Coverage
25 | 
26 | This rich semantic diversity enables **language-conditioned 3D reasoning**, allowing VoxTell to generate volumetric masks from flexible textual descriptions, from coarse anatomical labels to fine-grained pathological findings.
27 | 
28 | ---
29 | 
30 | ## Architecture
31 | 
32 | VoxTell combines **3D image encoding** with **text-prompt embeddings** and **multi-stage vision–language fusion**:
33 | 
34 | - **Image Encoder**: Processes 3D volumetric input into latent feature representations
35 | - **Prompt Encoder**: We use the frozen [Qwen3-Embedding-4B](https://huggingface.co/Qwen/Qwen3-Embedding-4B) model to embed text prompts
36 | - **Prompt Decoder**: Transforms text queries and image latents into multi-scale text features
37 | - **Image Decoder**: Fuses visual and textual information at multiple resolutions using MaskFormer-style query-image fusion with deep supervision
38 | 
39 | Architecture Diagram
40 | 
41 | ---
42 | 
43 | ## 🛠 Installation
44 | 
45 | ### 1. Create a Virtual Environment
46 | 
47 | VoxTell supports Python 3.10+ and works with Conda, pip, or any other virtual environment manager. Here's an example using Conda:
48 | 
49 | ```bash
50 | conda create -n voxtell python=3.12
51 | conda activate voxtell
52 | ```
53 | 
54 | ### 2. Install PyTorch
55 | 
56 | > [!WARNING]
57 | > **Temporary Compatibility Warning**
58 | > There is a known issue with **PyTorch 2.9.0** causing **OOM errors during inference** with 3D convolutions (first reported in `nnInteractive`; see the PyTorch issue [here](https://github.com/pytorch/pytorch/issues/166122)).
59 | > **Until this is resolved, please use PyTorch 2.8.0 or earlier.**
60 | 
61 | Install PyTorch compatible with your CUDA version. For example, for Ubuntu with a modern NVIDIA GPU:
62 | 
63 | ```bash
64 | pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu126
65 | ```
66 | 
67 | *For other configurations (macOS, CPU, different CUDA versions), please refer to the [PyTorch Get Started](https://pytorch.org/get-started/previous-versions/) page.*
68 | ### 3. Install VoxTell
69 | Install via pip (you can also use [uv](https://docs.astral.sh/uv/)):
70 | 
71 | ```bash
72 | pip install voxtell
73 | ```
74 | 
75 | or install directly from the repository:
76 | 
77 | ```bash
78 | git clone https://github.com/MIC-DKFZ/VoxTell
79 | cd VoxTell
80 | pip install -e .
81 | ```
82 | 
83 | ---
84 | 
85 | ## 🚀 Getting Started
86 | 
87 | > [!NOTE]
88 | > The model will soon be available on Hugging Face! 🤗 Stay tuned for the official release.
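Once installed, a quick sanity check can confirm that the package and a suitable PyTorch build are in place. This is an optional sketch (not part of the VoxTell package); it only prints versions and GPU visibility:

```python
# Optional environment check (not part of the VoxTell package).
import torch
import voxtell

print("VoxTell version:", voxtell.__version__)       # resolved from package metadata in voxtell/__init__.py
print("PyTorch version:", torch.__version__)         # should be < 2.9.0, see the warning above
print("CUDA available:", torch.cuda.is_available())  # inference runs fastest on a CUDA GPU
```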
89 | 90 | ### Command-Line Interface (CLI) 91 | 92 | VoxTell provides a convenient command-line interface for running predictions: 93 | 94 | ```bash 95 | voxtell-predict -i input.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" "kidney" 96 | ``` 97 | 98 | **Single prompt:** 99 | ```bash 100 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" 101 | # Output: output_folder/case001_liver.nii.gz 102 | ``` 103 | 104 | **Multiple prompts (saves individual files by default):** 105 | ```bash 106 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" "right kidney" 107 | # Outputs: 108 | # output_folder/case001_liver.nii.gz 109 | # output_folder/case001_spleen.nii.gz 110 | # output_folder/case001_right_kidney.nii.gz 111 | ``` 112 | 113 | **Save combined multi-label file:** 114 | ```bash 115 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" --save-combined 116 | # Output: output_folder/case001.nii.gz (multi-label: 1=liver, 2=spleen) 117 | # ⚠️ WARNING: Overlapping structures will be overwritten by later prompts 118 | ``` 119 | 120 | #### CLI Options 121 | 122 | | Argument | Short | Required | Description | 123 | |----------|-------|----------|-------------| 124 | | `--input` | `-i` | Yes | Path to input NIfTI file | 125 | | `--output` | `-o` | Yes | Path to output folder | 126 | | `--model` | `-m` | Yes | Path to VoxTell model directory | 127 | | `--prompts` | `-p` | Yes | Text prompt(s) for segmentation | 128 | | `--device` | | No | Device to use: `cuda` (default) or `cpu` | 129 | | `--gpu` | | No | GPU device ID (default: 0) | 130 | | `--save-combined` | | No | Save multi-label file instead of individual files | 131 | | `--verbose` | | No | Enable verbose output | 132 | 133 | --- 134 | 135 | ### Python API 136 | 137 | For more control or integration into Python workflows, use the Python API: 138 | 139 | ```python 140 | import torch 141 | from voxtell.inference.predictor import VoxTellPredictor 142 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient 143 | 144 | # Select device 145 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 146 | 147 | # Load image 148 | image_path = "/path/to/your/image.nii.gz" 149 | img, _ = NibabelIOWithReorient().read_images([image_path]) 150 | 151 | # Define text prompts 152 | text_prompts = ["liver", "right kidney", "left kidney", "spleen"] 153 | 154 | # Initialize predictor 155 | predictor = VoxTellPredictor( 156 | model_dir="/path/to/voxtell_model_directory", 157 | device=device, 158 | ) 159 | 160 | # Run prediction 161 | # Output shape: (num_prompts, x, y, z) 162 | voxtell_seg = predictor.predict_single_image(img, text_prompts) 163 | ``` 164 | 165 | #### Optional: Visualize Results 166 | 167 | You can visualize the segmentation results using [napari](https://napari.org/): 168 | 169 | ```bash 170 | pip install napari[all] 171 | ``` 172 | 173 | ```python 174 | import napari 175 | import numpy as np 176 | 177 | # Create a napari viewer and add the original image 178 | viewer = napari.Viewer() 179 | viewer.add_image(img, name='Image') 180 | 181 | # Add segmentation results as label layers for each prompt 182 | for i, prompt in enumerate(text_prompts): 183 | viewer.add_labels(voxtell_seg[i].astype(np.uint8), name=prompt) 184 | 185 | # Run napari 186 | napari.run() 187 | ``` 188 | 189 | ## Important: Image Orientation and Spacing 190 | 191 | - ⚠️ **Image Orientation (Critical)**: For correct anatomical 
localization (e.g., distinguishing left from right), images **must be in RAS orientation**. VoxTell was trained on data reoriented using [this specific reader](https://github.com/MIC-DKFZ/nnUNet/blob/86606c53ef9f556d6f024a304b52a48378453641/nnunetv2/imageio/nibabel_reader_writer.py#L101). Orientation mismatches can be a source of error. An easy way to test for this is if a simple prompt like "liver" fails and segments parts of the spleen instead. Make sure your image metadata is correct. 192 | 193 | - **Image Spacing**: The model does not resample images to a standardized spacing for faster inference. Performance may degrade on images with very uncommon voxel spacings (e.g., super high-resolution brain MRI). In such cases, consider resampling the image to a more typical clinical spacing (e.g., 1.5×1.5×1.5 mm³) before segmentation. 194 | 195 | --- 196 | 197 | ## 🗺️ Roadmap 198 | 199 | - [x] **Paper Published**: [arXiv:2511.11450](https://arxiv.org/abs/2511.11450) 200 | - [x] **Code Release**: Official implementation published 201 | - [x] **PyPI Package**: Package downloadable via pip 202 | - [ ] **Napari Plugin**: Integration into the napari viewer 203 | - [ ] **Model Release**: Public availability of pretrained weights 204 | - [ ] **Fine-Tuning**: Support and scripts for custom fine-tuning 205 | 206 | --- 207 | 208 | ## Citation 209 | 210 | ```bibtex 211 | @misc{rokuss2025voxtell, 212 | title={VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation}, 213 | author={Maximilian Rokuss and Moritz Langenberg and Yannick Kirchhoff and Fabian Isensee and Benjamin Hamm and Constantin Ulrich and Sebastian Regnery and Lukas Bauer and Efthimios Katsigiannopulos and Tobias Norajitra and Klaus Maier-Hein}, 214 | year={2025}, 215 | eprint={2511.11450}, 216 | archivePrefix={arXiv}, 217 | primaryClass={cs.CV}, 218 | url={https://arxiv.org/abs/2511.11450}, 219 | } 220 | ``` 221 | 222 | --- 223 | 224 | ## 📬 Contact 225 | 226 | For questions, issues, or collaborations, please contact: 227 | 228 | 📧 maximilian.rokuss@dkfz-heidelberg.de / moritz.langenberg@dkfz-heidelberg.de -------------------------------------------------------------------------------- /voxtell/inference/predict_from_raw_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Command-line entrypoint for VoxTell segmentation prediction. 4 | 5 | This script provides a CLI interface to run VoxTell predictions on medical images 6 | with free-text prompts. 7 | """ 8 | 9 | import argparse 10 | import sys 11 | from pathlib import Path 12 | from typing import List, Optional 13 | 14 | import numpy as np 15 | import torch 16 | 17 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient 18 | from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO 19 | 20 | from voxtell.inference.predictor import VoxTellPredictor 21 | 22 | 23 | def get_reader_writer(file_path: str): 24 | """ 25 | Determine the appropriate reader/writer based on file extension. 26 | 27 | Args: 28 | file_path: Path to the input file. 29 | 30 | Returns: 31 | Appropriate reader/writer instance. 32 | """ 33 | suffix = Path(file_path).suffix.lower() 34 | if suffix in ['.nii', '.gz']: 35 | return NibabelIOWithReorient() 36 | else: 37 | raise ValueError( 38 | f"Unsupported file format: {suffix}. " 39 | "Only NIfTI format (.nii, .nii.gz) is currently supported. " 40 | "Images must be reorientable to standard orientation with correct metadata." 
41 | ) 42 | 43 | 44 | def save_segmentation( 45 | segmentation: np.ndarray, 46 | output_folder: Path, 47 | input_filename: str, 48 | properties: dict, 49 | prompt_name: str = None, 50 | suffix: str = '.nii.gz' 51 | ) -> None: 52 | """ 53 | Save segmentation mask to file. 54 | 55 | Args: 56 | segmentation: Segmentation array to save. 57 | output_folder: Output folder path. 58 | input_filename: Original input filename (without extension). 59 | properties: Image properties from the reader. 60 | prompt_name: Optional prompt name to include in filename. 61 | suffix: File extension to use. 62 | """ 63 | if prompt_name: 64 | # Clean prompt name for filename 65 | safe_name = "".join(c if c.isalnum() or c in (' ', '_') else '_' for c in prompt_name) 66 | safe_name = safe_name.replace(' ', '_') 67 | output_file = output_folder / f"{input_filename}_{safe_name}{suffix}" 68 | else: 69 | output_file = output_folder / f"{input_filename}{suffix}" 70 | 71 | # Use NIfTI writer 72 | reader_writer = NibabelIOWithReorient() 73 | reader_writer.write_seg(segmentation, str(output_file), properties) 74 | print(f"Saved segmentation to: {output_file}") 75 | 76 | 77 | def parse_args() -> argparse.Namespace: 78 | """Parse command-line arguments.""" 79 | parser = argparse.ArgumentParser( 80 | description="VoxTell: Free-Text Promptable Universal 3D Medical Image Segmentation", 81 | formatter_class=argparse.RawDescriptionHelpFormatter, 82 | epilog=""" 83 | Examples: 84 | # Single prompt (saves to output_folder/case001_liver.nii.gz) 85 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" 86 | 87 | # Multiple prompts (saves individual files by default) 88 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" "kidney" 89 | 90 | # Save combined multi-label file (with overlap warning) 91 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" "spleen" --save-combined 92 | 93 | # Use CPU 94 | voxtell-predict -i case001.nii.gz -o output_folder -m /path/to/model -p "liver" --device cpu 95 | """ 96 | ) 97 | 98 | parser.add_argument( 99 | '-i', '--input', 100 | type=str, 101 | required=True, 102 | help='Path to input image file (NIfTI format recommended)' 103 | ) 104 | 105 | parser.add_argument( 106 | '-o', '--output', 107 | type=str, 108 | required=True, 109 | help='Path to output folder where segmentation files will be saved' 110 | ) 111 | 112 | parser.add_argument( 113 | '-m', '--model', 114 | type=str, 115 | required=True, 116 | help='Path to VoxTell model directory containing plans.json and fold_0/' 117 | ) 118 | 119 | parser.add_argument( 120 | '-p', '--prompts', 121 | type=str, 122 | nargs='+', 123 | required=True, 124 | help='Text prompt(s) for segmentation (e.g., "liver" "spleen" "tumor")' 125 | ) 126 | 127 | parser.add_argument( 128 | '--device', 129 | type=str, 130 | default='cuda', 131 | choices=['cuda', 'cpu'], 132 | help='Device to use for inference (default: cuda)' 133 | ) 134 | 135 | parser.add_argument( 136 | '--gpu', 137 | type=int, 138 | default=0, 139 | help='GPU device ID to use (default: 0)' 140 | ) 141 | 142 | parser.add_argument( 143 | '--save-combined', 144 | action='store_true', 145 | help='Save all prompts in a single multi-label file (WARNING: overlapping structures will be overwritten by later prompts)' 146 | ) 147 | 148 | parser.add_argument( 149 | '--verbose', 150 | action='store_true', 151 | help='Enable verbose output' 152 | ) 153 | 154 | return parser.parse_args() 155 | 156 | 157 | def main() -> int: 158 | 
"""Main entrypoint function.""" 159 | args = parse_args() 160 | 161 | # Validate inputs 162 | input_path = Path(args.input) 163 | if not input_path.exists(): 164 | raise FileNotFoundError(f"Input file does not exist: {input_path}") 165 | 166 | model_path = Path(args.model) 167 | if not model_path.exists(): 168 | raise FileNotFoundError(f"Model directory does not exist: {model_path}") 169 | 170 | if not (model_path / 'plans.json').exists(): 171 | raise FileNotFoundError(f"plans.json not found in model directory: {model_path}") 172 | 173 | if not (model_path / 'fold_0' / 'checkpoint_final.pth').exists(): 174 | raise FileNotFoundError(f"checkpoint_final.pth not found in {model_path / 'fold_0'}") 175 | 176 | # Setup device 177 | if args.device == 'cuda': 178 | if not torch.cuda.is_available(): 179 | print("Warning: CUDA not available, falling back to CPU", file=sys.stderr) 180 | device = torch.device('cpu') 181 | else: 182 | device = torch.device(f'cuda:{args.gpu}') 183 | if args.verbose: 184 | print(f"Using GPU: {args.gpu} ({torch.cuda.get_device_name(args.gpu)})") 185 | else: 186 | device = torch.device('cpu') 187 | if args.verbose: 188 | print("Using CPU") 189 | 190 | # Load image 191 | if args.verbose: 192 | print(f"Loading image: {input_path}") 193 | 194 | try: 195 | reader_writer = get_reader_writer(str(input_path)) 196 | img, props = reader_writer.read_images([str(input_path)]) 197 | except Exception as e: 198 | print(f"Error loading image: {e}", file=sys.stderr) 199 | return 1 200 | 201 | if args.verbose: 202 | print(f"Image shape: {img.shape}") 203 | print(f"Text prompts: {args.prompts}") 204 | print(f"Loading VoxTell model from: {model_path}") 205 | 206 | predictor = VoxTellPredictor( 207 | model_dir=str(model_path), 208 | device=device 209 | ) 210 | 211 | # Run prediction 212 | if args.verbose: 213 | print("Running prediction...") 214 | 215 | segmentations = predictor.predict_single_image(img, args.prompts) 216 | 217 | # Save results 218 | output_folder = Path(args.output) 219 | output_folder.mkdir(parents=True, exist_ok=True) 220 | 221 | # Get input filename without extension 222 | input_filename = input_path.stem 223 | if input_filename.endswith('.nii'): 224 | input_filename = input_filename[:-4] 225 | 226 | # Determine file suffix from input 227 | if input_path.suffix == '.gz' and input_path.stem.endswith('.nii'): 228 | suffix = '.nii.gz' 229 | else: 230 | suffix = input_path.suffix 231 | 232 | if args.save_combined: 233 | # Show warning about overlapping structures 234 | if len(args.prompts) > 1: 235 | print("\n" + "="*80) 236 | print("WARNING: Saving combined multi-label segmentation.") 237 | print("If prompts generate overlapping structures, later prompts will overwrite") 238 | print("earlier ones. This may result in loss of segmentation information.") 239 | print("Consider using individual file output (default) for overlapping structures.") 240 | print("="*80 + "\n") 241 | 242 | # Save all prompts in a single multi-label file 243 | if len(args.prompts) == 1: 244 | # Single prompt - save as-is 245 | save_segmentation(segmentations[0], output_folder, input_filename, props, suffix=suffix) 246 | else: 247 | # Multiple prompts - create multi-label segmentation 248 | # Each prompt gets a different label value (1, 2, 3, ...) 
249 | # Later prompts overwrite earlier ones in case of overlap 250 | combined_seg = np.zeros_like(segmentations[0], dtype=np.uint8) 251 | for i, seg in enumerate(segmentations): 252 | combined_seg[seg > 0] = i + 1 253 | save_segmentation(combined_seg, output_folder, input_filename, props, suffix=suffix) 254 | 255 | print("\nLabel mapping:") 256 | for i, prompt in enumerate(args.prompts): 257 | print(f" {i + 1}: {prompt}") 258 | else: 259 | # Default: Save each prompt as a separate file 260 | for i, prompt in enumerate(args.prompts): 261 | save_segmentation( 262 | segmentations[i], 263 | output_folder, 264 | input_filename, 265 | props, 266 | prompt_name=prompt, 267 | suffix=suffix 268 | ) 269 | 270 | if args.verbose: 271 | print("\nPrediction completed successfully!") 272 | 273 | return 0 274 | 275 | 276 | if __name__ == '__main__': 277 | sys.exit(main()) 278 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2019] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /voxtell/model/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Text prompt decoder implementation for VoxTell. 3 | 4 | Code modified from DETR transformer: 5 | https://github.com/facebookresearch/detr 6 | Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 7 | """ 8 | 9 | import copy 10 | from typing import Callable, List, Optional, Tuple, Union 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn, Tensor 15 | 16 | 17 | class TransformerDecoder(nn.Module): 18 | """ 19 | Transformer decoder consisting of multiple decoder layers. 20 | 21 | This decoder processes target sequences with attention to memory (encoder output), 22 | optionally returning intermediate layer outputs. It receives the text prompt embeddings 23 | as queries and attends to image features from the encoder. It outputs refined text-image 24 | fused features for segmentation mask prediction. 25 | 26 | Args: 27 | decoder_layer: A single transformer decoder layer to be cloned. 28 | num_layers: Number of decoder layers to stack. 29 | norm: Optional normalization layer applied to the final output. 30 | return_intermediate: If True, returns outputs from all layers stacked together. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | decoder_layer: nn.Module, 36 | num_layers: int, 37 | norm: Optional[nn.Module] = None, 38 | return_intermediate: bool = False 39 | ) -> None: 40 | super().__init__() 41 | self.layers = _get_clones(decoder_layer, num_layers) 42 | self.num_layers = num_layers 43 | self.norm = norm 44 | self.return_intermediate = return_intermediate 45 | 46 | def forward( 47 | self, 48 | tgt: Tensor, 49 | memory: Tensor, 50 | tgt_mask: Optional[Tensor] = None, 51 | memory_mask: Optional[Tensor] = None, 52 | tgt_key_padding_mask: Optional[Tensor] = None, 53 | memory_key_padding_mask: Optional[Tensor] = None, 54 | pos: Optional[Tensor] = None, 55 | query_pos: Optional[Tensor] = None 56 | ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: 57 | """ 58 | Forward pass through all decoder layers. 59 | 60 | Args: 61 | tgt: Target sequence tensor of shape [T, B, C]. 62 | memory: Memory (encoder output) tensor of shape [T, B, C]. 63 | tgt_mask: Attention mask for target self-attention. 64 | memory_mask: Attention mask for cross-attention to memory. 65 | tgt_key_padding_mask: Padding mask for target keys. 66 | memory_key_padding_mask: Padding mask for memory keys. 67 | pos: Positional embeddings for memory. 68 | query_pos: Positional embeddings for queries. 69 | 70 | Returns: 71 | If return_intermediate is True, returns stacked intermediate outputs. 72 | Otherwise, returns tuple of (final_output, attention_weights_list). 73 | """ 74 | output = tgt 75 | T, B, C = memory.shape 76 | intermediate = [] 77 | atten_layers = [] 78 | 79 | for n, layer in enumerate(self.layers): 80 | residual = True 81 | output, ws = layer( 82 | output, memory, 83 | tgt_mask=tgt_mask, 84 | memory_mask=memory_mask, 85 | tgt_key_padding_mask=tgt_key_padding_mask, 86 | memory_key_padding_mask=memory_key_padding_mask, 87 | pos=pos, 88 | query_pos=query_pos, 89 | residual=residual 90 | ) 91 | atten_layers.append(ws) 92 | if self.return_intermediate: 93 | intermediate.append(self.norm(output)) 94 | 95 | if self.norm is not None: 96 | output = self.norm(output) 97 | if self.return_intermediate: 98 | intermediate.pop() 99 | intermediate.append(output) 100 | 101 | if self.return_intermediate: 102 | return torch.stack(intermediate) 103 | return output, atten_layers 104 | 105 | 106 | 107 | class TransformerDecoderLayer(nn.Module): 108 | """ 109 | Single transformer decoder layer with self-attention, cross-attention, and FFN. 110 | 111 | This layer implements: 112 | 1. Self-attention on the target sequence 113 | 2. 
Cross-attention between target and memory (image encoder output) 114 | 3. Position-wise feed-forward network 115 | 116 | Args: 117 | d_model: Dimension of the model (embedding dimension). 118 | nhead: Number of attention heads. 119 | dim_feedforward: Dimension of the feedforward network. 120 | dropout: Dropout probability. 121 | activation: Activation function name ('relu', 'gelu', or 'glu'). 122 | normalize_before: If True, applies layer norm before attention/FFN (pre-norm). 123 | If False, applies after (post-norm). 124 | """ 125 | 126 | def __init__( 127 | self, 128 | d_model: int, 129 | nhead: int, 130 | dim_feedforward: int = 2048, 131 | dropout: float = 0.1, 132 | activation: str = "relu", 133 | normalize_before: bool = False 134 | ) -> None: 135 | super().__init__() 136 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 137 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 138 | 139 | # Feed-forward network 140 | self.linear1 = nn.Linear(d_model, dim_feedforward) 141 | self.dropout = nn.Dropout(dropout) 142 | self.linear2 = nn.Linear(dim_feedforward, d_model) 143 | 144 | # Normalization layers 145 | self.norm1 = nn.LayerNorm(d_model) 146 | self.norm2 = nn.LayerNorm(d_model) 147 | self.norm3 = nn.LayerNorm(d_model) 148 | 149 | # Dropout layers 150 | self.dropout1 = nn.Dropout(dropout) 151 | self.dropout2 = nn.Dropout(dropout) 152 | self.dropout3 = nn.Dropout(dropout) 153 | 154 | self.activation = _get_activation_fn(activation) 155 | self.normalize_before = normalize_before 156 | 157 | def with_pos_embed(self, tensor: Tensor, pos: Optional[Tensor]) -> Tensor: 158 | """ 159 | Add positional embeddings to tensor if provided. 160 | 161 | Args: 162 | tensor: Input tensor. 163 | pos: Optional positional embeddings to add. 164 | 165 | Returns: 166 | Tensor with positional embeddings added, or original tensor if pos is None. 167 | """ 168 | return tensor if pos is None else tensor + pos 169 | 170 | def forward_post( 171 | self, 172 | tgt: Tensor, 173 | memory: Tensor, 174 | tgt_mask: Optional[Tensor] = None, 175 | memory_mask: Optional[Tensor] = None, 176 | tgt_key_padding_mask: Optional[Tensor] = None, 177 | memory_key_padding_mask: Optional[Tensor] = None, 178 | pos: Optional[Tensor] = None, 179 | query_pos: Optional[Tensor] = None, 180 | residual: bool = True 181 | ) -> Tuple[Tensor, Tensor]: 182 | """ 183 | Post-norm forward pass: attention/FFN first, then normalization. 184 | 185 | Args: 186 | tgt: Target sequence. 187 | memory: Memory (encoder output). 188 | tgt_mask: Self-attention mask for target. 189 | memory_mask: Cross-attention mask. 190 | tgt_key_padding_mask: Padding mask for target keys. 191 | memory_key_padding_mask: Padding mask for memory keys. 192 | pos: Positional embeddings for memory. 193 | query_pos: Positional embeddings for queries. 194 | residual: Whether to use residual connections. 195 | 196 | Returns: 197 | Tuple of (output_tensor, attention_weights). 
198 | """ 199 | q = k = self.with_pos_embed(tgt, query_pos) 200 | tgt2, ws = self.self_attn( 201 | q, k, value=tgt, 202 | attn_mask=tgt_mask, 203 | key_padding_mask=tgt_key_padding_mask 204 | ) 205 | tgt = self.norm1(tgt) 206 | tgt2, ws = self.multihead_attn( 207 | query=self.with_pos_embed(tgt, query_pos), 208 | key=self.with_pos_embed(memory, pos), 209 | value=memory, 210 | attn_mask=memory_mask, 211 | key_padding_mask=memory_key_padding_mask 212 | ) 213 | 214 | # Cross-attention with residual connection 215 | tgt = tgt + self.dropout2(tgt2) 216 | tgt = self.norm2(tgt) 217 | 218 | # Feed-forward network 219 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 220 | tgt = tgt + self.dropout3(tgt2) 221 | tgt = self.norm3(tgt) 222 | return tgt, ws 223 | 224 | def forward_pre( 225 | self, 226 | tgt: Tensor, 227 | memory: Tensor, 228 | tgt_mask: Optional[Tensor] = None, 229 | memory_mask: Optional[Tensor] = None, 230 | tgt_key_padding_mask: Optional[Tensor] = None, 231 | memory_key_padding_mask: Optional[Tensor] = None, 232 | pos: Optional[Tensor] = None, 233 | query_pos: Optional[Tensor] = None 234 | ) -> Tuple[Tensor, Tensor]: 235 | """ 236 | Pre-norm forward pass: normalization first, then attention/FFN. 237 | 238 | Args: 239 | tgt: Target sequence. 240 | memory: Memory (encoder output). 241 | tgt_mask: Self-attention mask for target. 242 | memory_mask: Cross-attention mask. 243 | tgt_key_padding_mask: Padding mask for target keys. 244 | memory_key_padding_mask: Padding mask for memory keys. 245 | pos: Positional embeddings for memory. 246 | query_pos: Positional embeddings for queries. 247 | 248 | Returns: 249 | Tuple of (output_tensor, attention_weights). 250 | """ 251 | tgt2 = self.norm2(tgt) 252 | tgt2, attn_weights = self.multihead_attn( 253 | query=self.with_pos_embed(tgt2, query_pos), 254 | key=self.with_pos_embed(memory, pos), 255 | value=memory, 256 | attn_mask=memory_mask, 257 | key_padding_mask=memory_key_padding_mask 258 | ) 259 | tgt = tgt + self.dropout2(tgt2) 260 | 261 | # Feed-forward network with pre-norm 262 | tgt2 = self.norm3(tgt) 263 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 264 | tgt = tgt + self.dropout3(tgt2) 265 | return tgt, attn_weights 266 | 267 | 268 | def forward_pre_selfattention( 269 | self, 270 | tgt: Tensor, 271 | memory: Tensor, 272 | tgt_mask: Optional[Tensor] = None, 273 | memory_mask: Optional[Tensor] = None, 274 | tgt_key_padding_mask: Optional[Tensor] = None, 275 | memory_key_padding_mask: Optional[Tensor] = None, 276 | pos: Optional[Tensor] = None, 277 | query_pos: Optional[Tensor] = None 278 | ) -> Tuple[Tensor, Tensor]: 279 | """ 280 | Alternative forward pass with cross-attention before self-attention. 281 | 282 | This variant applies operations in the order: 283 | 1. Cross-attention (without normalization) 284 | 2. Self-attention (with pre-norm) 285 | 3. Feed-forward network (with pre-norm) 286 | 287 | Args: 288 | tgt: Target sequence. 289 | memory: Memory (encoder output). 290 | tgt_mask: Self-attention mask for target. 291 | memory_mask: Cross-attention mask. 292 | tgt_key_padding_mask: Padding mask for target keys. 293 | memory_key_padding_mask: Padding mask for memory keys. 294 | pos: Positional embeddings for memory. 295 | query_pos: Positional embeddings for queries. 296 | 297 | Returns: 298 | Tuple of (output_tensor, attention_weights). 
299 | """ 300 | # Cross-attention without pre-normalization 301 | tgt2, attn_weights = self.multihead_attn( 302 | query=self.with_pos_embed(tgt, query_pos), 303 | key=self.with_pos_embed(memory, pos), 304 | value=memory, 305 | attn_mask=memory_mask, 306 | key_padding_mask=memory_key_padding_mask 307 | ) 308 | tgt = tgt + tgt2 309 | 310 | # Self-attention with pre-norm 311 | tgt2 = self.norm1(tgt) 312 | q = k = self.with_pos_embed(tgt2, query_pos) 313 | tgt2, ws = self.self_attn(q, k, value=tgt2) 314 | tgt = tgt + tgt2 315 | 316 | # Feed-forward network with pre-norm 317 | tgt2 = self.norm2(tgt) 318 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 319 | tgt = tgt + tgt2 320 | tgt = self.norm3(tgt) 321 | return tgt, attn_weights 322 | 323 | 324 | def forward( 325 | self, 326 | tgt: Tensor, 327 | memory: Tensor, 328 | tgt_mask: Optional[Tensor] = None, 329 | memory_mask: Optional[Tensor] = None, 330 | tgt_key_padding_mask: Optional[Tensor] = None, 331 | memory_key_padding_mask: Optional[Tensor] = None, 332 | pos: Optional[Tensor] = None, 333 | query_pos: Optional[Tensor] = None, 334 | residual: bool = True 335 | ) -> Tuple[Tensor, Tensor]: 336 | """ 337 | Forward pass through the decoder layer. 338 | 339 | Dispatches to either pre-norm or post-norm variant based on configuration. 340 | 341 | Args: 342 | tgt: Target sequence. 343 | memory: Memory (encoder output). 344 | tgt_mask: Self-attention mask for target. 345 | memory_mask: Cross-attention mask. 346 | tgt_key_padding_mask: Padding mask for target keys. 347 | memory_key_padding_mask: Padding mask for memory keys. 348 | pos: Positional embeddings for memory. 349 | query_pos: Positional embeddings for queries. 350 | residual: Whether to use residual connections (used in post-norm variant). 351 | 352 | Returns: 353 | Tuple of (output_tensor, attention_weights). 354 | """ 355 | if self.normalize_before: 356 | return self.forward_pre( 357 | tgt, memory, tgt_mask, memory_mask, 358 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos 359 | ) 360 | return self.forward_post( 361 | tgt, memory, tgt_mask, memory_mask, 362 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, residual 363 | ) 364 | 365 | 366 | def _get_clones(module: nn.Module, N: int) -> nn.ModuleList: 367 | """ 368 | Create N identical copies of a module. 369 | 370 | Args: 371 | module: The module to clone. 372 | N: Number of clones to create. 373 | 374 | Returns: 375 | ModuleList containing N deep copies of the input module. 376 | """ 377 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 378 | 379 | 380 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: 381 | """ 382 | Return an activation function given a string identifier. 383 | 384 | Args: 385 | activation: Name of the activation function ('relu', 'gelu', or 'glu'). 386 | 387 | Returns: 388 | The corresponding activation function. 389 | 390 | Raises: 391 | RuntimeError: If the activation name is not recognized. 
392 | """ 393 | if activation == "relu": 394 | return F.relu 395 | if activation == "gelu": 396 | return F.gelu 397 | if activation == "glu": 398 | return F.glu 399 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") -------------------------------------------------------------------------------- /voxtell/inference/predictor.py: -------------------------------------------------------------------------------- 1 | import pydoc 2 | from queue import Queue 3 | from threading import Thread 4 | from typing import List, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch._dynamo import OptimizedModule 9 | from tqdm import tqdm 10 | from transformers import AutoModel, AutoTokenizer 11 | 12 | from acvl_utils.cropping_and_padding.bounding_boxes import insert_crop_into_image 13 | from acvl_utils.cropping_and_padding.padding import pad_nd_image 14 | from batchgenerators.utilities.file_and_folder_operations import join, load_json 15 | 16 | from nnunetv2.inference.sliding_window_prediction import compute_gaussian, compute_steps_for_sliding_window 17 | from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero 18 | from nnunetv2.preprocessing.normalization.default_normalization_schemes import ZScoreNormalization 19 | from nnunetv2.utilities.helpers import dummy_context, empty_cache 20 | 21 | from voxtell.model.voxtell_model import VoxTellModel 22 | from voxtell.utils.text_embedding import last_token_pool, wrap_with_instruction 23 | 24 | 25 | class VoxTellPredictor: 26 | """ 27 | Predictor for VoxTell segmentation model. 28 | 29 | This class handles loading the VoxTell model, preprocessing images, 30 | embedding text prompts, and performing sliding window inference to generate 31 | segmentation masks based on free-text anatomical descriptions. 32 | 33 | Attributes: 34 | device: PyTorch device for inference. 35 | network: The VoxTell model. 36 | tokenizer: Text tokenizer for prompt encoding. 37 | text_backbone: Text embedding model. 38 | patch_size: Patch size for sliding window inference. 39 | tile_step_size: Step size for sliding window (default: 0.5 = 50% overlap). 40 | perform_everything_on_device: Keep all tensors on device during inference. 41 | max_text_length: Maximum text prompt length in tokens. 42 | """ 43 | def __init__(self, model_dir: str, device: torch.device = torch.device('cuda'), 44 | text_encoding_model: str = 'Qwen/Qwen3-Embedding-4B') -> None: 45 | """ 46 | Initialize the VoxTell predictor. 47 | 48 | Args: 49 | model_dir: Path to model directory containing plans.json and checkpoint. 50 | device: PyTorch device to use for inference (default: cuda). 51 | text_encoding_model: Pretrained text encoding model (Qwen/Qwen3-Embedding-4B). 52 | 53 | Raises: 54 | FileNotFoundError: If model files are not found. 55 | RuntimeError: If model loading fails. 
56 | """ 57 | # Device setup 58 | self.device = device 59 | if device.type == 'cuda': 60 | torch.backends.cudnn.benchmark = True 61 | self.normalization = ZScoreNormalization(intensityproperties={}) 62 | 63 | # Predictor settings 64 | self.tile_step_size = 0.5 65 | self.perform_everything_on_device = True 66 | 67 | # Embedding model 68 | self.tokenizer = AutoTokenizer.from_pretrained(text_encoding_model, padding_side='left') 69 | self.text_backbone = AutoModel.from_pretrained(text_encoding_model).eval() 70 | self.max_text_length = 8192 71 | 72 | # Load network settings 73 | plans = load_json(join(model_dir, 'plans.json')) 74 | arch_kwargs = plans['configurations']['3d_fullres']['architecture']['arch_kwargs'] 75 | self.patch_size = plans['configurations']['3d_fullres']['patch_size'] 76 | 77 | arch_kwargs = dict(**arch_kwargs) 78 | for required_import_key in plans['configurations']['3d_fullres']['architecture']['_kw_requires_import']: 79 | if arch_kwargs[required_import_key] is not None: 80 | arch_kwargs[required_import_key] = pydoc.locate(arch_kwargs[required_import_key]) 81 | 82 | # Instantiate network 83 | network = VoxTellModel( 84 | input_channels=1, 85 | **arch_kwargs, 86 | decoder_layer=4, 87 | text_embedding_dim=2560, 88 | num_maskformer_stages=5, 89 | num_heads=32, 90 | query_dim=2048, 91 | project_to_decoder_hidden_dim=2048, 92 | deep_supervision=False 93 | ) 94 | 95 | # Load weights 96 | checkpoint = torch.load( 97 | join(model_dir, 'fold_0', 'checkpoint_final.pth'), 98 | map_location=torch.device('cpu'), 99 | weights_only=False 100 | ) 101 | 102 | if not isinstance(network, OptimizedModule): 103 | network.load_state_dict(checkpoint['network_weights']) 104 | else: 105 | network._orig_mod.load_state_dict(checkpoint['network_weights']) 106 | 107 | network.eval() 108 | self.network = network 109 | 110 | def preprocess(self, data: np.ndarray) -> Tuple[torch.Tensor, Tuple, Tuple[int, ...]]: 111 | """ 112 | Preprocess a single image for inference. 113 | 114 | This function preprocesses an image already in RAS orientation by performing 115 | cropping to non-zero regions and z-score normalization. 116 | 117 | Args: 118 | data: Image data in RAS orientation (3D or 4D with channel dimension). 119 | 120 | Returns: 121 | Tuple containing: 122 | - Preprocessed image tensor 123 | - Bounding box of cropped region 124 | - Original image shape 125 | """ 126 | 127 | if data.ndim == 3: 128 | data = data[None] # add channel axis 129 | data = data.astype(np.float32) # this creates a copy 130 | original_shape = data.shape[1:] 131 | data, _, bbox = crop_to_nonzero(data, None) 132 | data = self.normalization.run(data, None) 133 | data = torch.from_numpy(data) 134 | return data, bbox, original_shape 135 | 136 | def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]) -> List[Tuple]: 137 | """ 138 | Generate sliding window slicers for patch-based inference. 139 | 140 | Args: 141 | image_size: Shape of the input image. 142 | 143 | Returns: 144 | List of slice tuples for extracting patches. 145 | """ 146 | slicers = [] 147 | if len(self.patch_size) < len(image_size): 148 | assert len(self.patch_size) == len(image_size) - 1, ( 149 | 'if tile_size has less entries than image_size, ' 150 | 'len(tile_size) must be one shorter than len(image_size) ' 151 | '(only dimension discrepancy of 1 allowed).' 
152 | ) 153 | steps = compute_steps_for_sliding_window(image_size[1:], self.patch_size, 154 | self.tile_step_size) 155 | for d in range(image_size[0]): 156 | for sx in steps[0]: 157 | for sy in steps[1]: 158 | slicers.append( 159 | tuple([slice(None), d, *[slice(si, si + ti) for si, ti in 160 | zip((sx, sy), self.patch_size)]])) 161 | else: 162 | steps = compute_steps_for_sliding_window(image_size, self.patch_size, 163 | self.tile_step_size) 164 | for sx in steps[0]: 165 | for sy in steps[1]: 166 | for sz in steps[2]: 167 | slicers.append( 168 | tuple([slice(None), *[slice(si, si + ti) for si, ti in 169 | zip((sx, sy, sz), self.patch_size)]])) 170 | return slicers 171 | 172 | @torch.inference_mode() 173 | def embed_text_prompts(self, text_prompts: Union[List[str], str]) -> torch.Tensor: 174 | """ 175 | Embed text prompts into vector representations. 176 | 177 | This function converts free-text anatomical descriptions into embeddings 178 | using the text backbone model. 179 | 180 | Args: 181 | text_prompts: Single text prompt or list of text prompts. 182 | 183 | Returns: 184 | Text embeddings tensor of shape (1, num_prompts, embedding_dim). 185 | """ 186 | if isinstance(text_prompts, str): 187 | text_prompts = [text_prompts] 188 | n_prompts = len(text_prompts) 189 | self.text_backbone = self.text_backbone.to(self.device) 190 | 191 | text_prompts = wrap_with_instruction(text_prompts) 192 | text_tokens = self.tokenizer( 193 | text_prompts, 194 | padding=True, 195 | truncation=True, 196 | max_length=self.max_text_length, 197 | return_tensors="pt", 198 | ) 199 | text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()} 200 | text_embed = self.text_backbone(**text_tokens) 201 | embeddings = last_token_pool(text_embed.last_hidden_state, text_tokens['attention_mask']) 202 | embeddings = embeddings.view(1, n_prompts, -1) 203 | self.text_backbone = self.text_backbone.to('cpu') 204 | empty_cache(self.device) 205 | return embeddings 206 | 207 | @torch.inference_mode() 208 | def predict_sliding_window_return_logits( 209 | self, 210 | input_image: torch.Tensor, 211 | text_embeddings: torch.Tensor 212 | ) -> torch.Tensor: 213 | """ 214 | Perform sliding window inference to generate segmentation logits. 215 | 216 | Args: 217 | input_image: Input image tensor of shape (C, X, Y, Z). 218 | text_embeddings: Text embeddings from embed_text_prompts. 219 | 220 | Returns: 221 | Predicted logits tensor. 222 | 223 | Raises: 224 | ValueError: If input_image is not 4D or not a torch.Tensor. 225 | """ 226 | if not isinstance(input_image, torch.Tensor): 227 | raise ValueError(f"input_image must be a torch.Tensor, got {type(input_image)}") 228 | if input_image.ndim != 4: 229 | raise ValueError( 230 | f"input_image must be 4D (C, X, Y, Z), got shape {input_image.shape}" 231 | ) 232 | 233 | self.network = self.network.to(self.device) 234 | 235 | empty_cache(self.device) 236 | with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): 237 | 238 | # if input_image is smaller than tile_size we need to pad it to tile_size. 
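# Illustrative shapes (assumed, not taken from a real case): with patch_size (192, 192, 192),
# an input of shape (1, 120, 400, 400) is padded to (1, 192, 400, 400). slicer_revert_padding
# is applied after aggregation to crop the logits back to the unpadded extent.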
239 | data, slicer_revert_padding = pad_nd_image(input_image, self.patch_size, 240 | 'constant', {'value': 0}, True, None) 241 | 242 | slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) 243 | 244 | predicted_logits = self._internal_predict_sliding_window_return_logits( 245 | data, text_embeddings, slicers, self.perform_everything_on_device 246 | ) 247 | 248 | empty_cache(self.device) 249 | # Revert padding 250 | predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])] 251 | return predicted_logits 252 | 253 | @torch.inference_mode() 254 | def _internal_predict_sliding_window_return_logits( 255 | self, 256 | data: torch.Tensor, 257 | text_embeddings: torch.Tensor, 258 | slicers: List[Tuple], 259 | do_on_device: bool = True, 260 | ) -> torch.Tensor: 261 | """ 262 | Internal method for sliding window prediction with Gaussian weighting. 263 | 264 | Uses a producer-consumer pattern with threading to overlap data loading 265 | and model inference. 266 | 267 | Args: 268 | data: Preprocessed image data. 269 | text_embeddings: Text embeddings for prompts. 270 | slicers: List of slice tuples for patch extraction. 271 | do_on_device: If True, keep all tensors on GPU during computation. 272 | 273 | Returns: 274 | Aggregated prediction logits. 275 | 276 | Raises: 277 | RuntimeError: If inf values are encountered in predictions. 278 | """ 279 | results_device = self.device if do_on_device else torch.device('cpu') 280 | 281 | def producer(data_tensor, slicer_list, queue): 282 | """Producer thread that loads patches into queue.""" 283 | for slicer in slicer_list: 284 | patch = torch.clone( 285 | data_tensor[slicer][None], 286 | memory_format=torch.contiguous_format 287 | ).to(self.device) 288 | queue.put((patch, slicer)) 289 | queue.put('end') 290 | 291 | empty_cache(self.device) 292 | 293 | # move data to device 294 | data = data.to(results_device) 295 | queue = Queue(maxsize=2) 296 | t = Thread(target=producer, args=(data, slicers, queue)) 297 | t.start() 298 | 299 | # preallocate arrays 300 | predicted_logits = torch.zeros((text_embeddings.shape[1], *data.shape[1:]), 301 | dtype=torch.half, 302 | device=results_device) 303 | n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) 304 | 305 | gaussian = compute_gaussian( 306 | tuple(self.patch_size), 307 | sigma_scale=1. / 8, 308 | value_scaling_factor=10, 309 | device=results_device 310 | ) 311 | 312 | with tqdm(desc=None, total=len(slicers)) as pbar: 313 | while True: 314 | item = queue.get() 315 | if item == 'end': 316 | queue.task_done() 317 | break 318 | patch, tile_slice = item 319 | prediction = self.network(patch, text_embeddings)[0].to(results_device) 320 | prediction *= gaussian 321 | predicted_logits[tile_slice] += prediction 322 | n_predictions[tile_slice[1:]] += gaussian 323 | queue.task_done() 324 | pbar.update() 325 | queue.join() 326 | 327 | # Normalize by number of predictions per voxel 328 | torch.div(predicted_logits, n_predictions, out=predicted_logits) 329 | 330 | # Check for inf values 331 | if torch.any(torch.isinf(predicted_logits)): 332 | raise RuntimeError( 333 | 'Encountered inf in predicted array. Aborting... ' 334 | 'If this problem persists, reduce value_scaling_factor in ' 335 | 'compute_gaussian or increase the dtype of predicted_logits to fp32.' 
336 | ) 337 | return predicted_logits 338 | 339 | def predict_single_image( 340 | self, 341 | data: np.ndarray, 342 | text_prompts: Union[str, List[str]] 343 | ) -> np.ndarray: 344 | """ 345 | Predict segmentation masks for a single image with text prompts. 346 | 347 | This is the main prediction method that orchestrates preprocessing, 348 | text embedding, sliding window inference, and postprocessing. 349 | 350 | Args: 351 | data: Image data in RAS orientation (3D or 4D with channel dimension). 352 | text_prompts: Single text prompt or list of text prompts describing 353 | anatomical structures to segment. 354 | 355 | Returns: 356 | Segmentation masks as numpy array of shape (num_prompts, X, Y, Z) 357 | with binary values (0 or 1) indicating the segmented regions. 358 | """ 359 | 360 | # Preprocess image 361 | data, bbox, orig_shape = self.preprocess(data) 362 | 363 | # Embed text prompts 364 | embeddings = self.embed_text_prompts(text_prompts) 365 | 366 | # Predict segmentation logits 367 | prediction = self.predict_sliding_window_return_logits(data, embeddings).to('cpu') 368 | 369 | # Postprocess logits to get binary segmentation masks 370 | with torch.no_grad(): 371 | prediction = torch.sigmoid(prediction.float()) > 0.5 372 | 373 | segmentation_reverted_cropping = np.zeros( 374 | [prediction.shape[0], *orig_shape], 375 | dtype=np.uint8 376 | ) 377 | segmentation_reverted_cropping = insert_crop_into_image( 378 | segmentation_reverted_cropping, prediction, bbox 379 | ) 380 | 381 | return segmentation_reverted_cropping 382 | 383 | 384 | if __name__ == '__main__': 385 | from pathlib import Path 386 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient 387 | 388 | # Default paths - modify these as needed 389 | DEFAULT_IMAGE_PATH = "/path/to/your/image.nii.gz" 390 | DEFAULT_MODEL_DIR = "/path/to/your/model/directory" 391 | 392 | # Configuration 393 | image_path = DEFAULT_IMAGE_PATH 394 | model_dir = DEFAULT_MODEL_DIR 395 | text_prompts = ["liver", "right kidney", "left kidney", "spleen"] 396 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 397 | 398 | # Load image 399 | img, props = NibabelIOWithReorient().read_images([image_path]) 400 | 401 | # Initialize predictor and run inference 402 | predictor = VoxTellPredictor(model_dir=model_dir, device=device) 403 | voxtell_seg = predictor.predict_single_image(img, text_prompts) 404 | 405 | # Visualize results; we recommend using napari for 3D visualization 406 | import napari 407 | viewer = napari.Viewer() 408 | viewer.add_image(img, name='image') 409 | for i, prompt in enumerate(text_prompts): 410 | viewer.add_labels(voxtell_seg[i], name=f'voxtell_{prompt}') 411 | napari.run() -------------------------------------------------------------------------------- /voxtell/model/voxtell_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.modules.conv import _ConvNd 4 | from torch.nn.modules.dropout import _DropoutNd 5 | 6 | from typing import List, Type, Union, Tuple 7 | from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp 8 | from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder 9 | from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD 10 | from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder 11 | from 
dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks 12 | from dynamic_network_architectures.initialization.weight_init import InitWeights_He, init_last_bn_before_add_to_0 13 | from einops import rearrange, repeat 14 | from positional_encodings.torch_encodings import PositionalEncoding3D 15 | 16 | from voxtell.model.transformer import TransformerDecoder, TransformerDecoderLayer 17 | 18 | 19 | class VoxTellModel(nn.Module): 20 | """ 21 | VoxTell segmentation model with text-prompted decoder. 22 | 23 | This model combines a ResidualEncoder backbone with a transformer-based decoder 24 | that uses text embeddings to generate segmentation masks. It supports multi-stage 25 | decoding with optional deep supervision. 26 | 27 | Attributes: 28 | encoder: ResidualEncoder backbone for feature extraction. 29 | decoder: VoxTellDecoder for multi-scale feature decoding. 30 | transformer_decoder: Transformer for fusing text and image features. 31 | deep_supervision: Whether to return multi-scale predictions. 32 | """ 33 | 34 | # Class constants for transformer architecture (text prompt decoder) 35 | TRANSFORMER_NUM_HEADS = 8 36 | TRANSFORMER_NUM_LAYERS = 6 37 | 38 | # Decoder configuration for different stages 39 | DECODER_CONFIGS = { 40 | 0: {"channels": 32, "shape": (192, 192, 192)}, 41 | 1: {"channels": 64, "shape": (96, 96, 96)}, 42 | 2: {"channels": 128, "shape": (48, 48, 48)}, 43 | 3: {"channels": 256, "shape": (24, 24, 24)}, 44 | 4: {"channels": 320, "shape": (12, 12, 12)}, 45 | 5: {"channels": 320, "shape": (6, 6, 6)}, 46 | } 47 | 48 | def __init__( 49 | self, 50 | input_channels: int, 51 | n_stages: int, 52 | features_per_stage: Union[int, List[int], Tuple[int, ...]], 53 | conv_op: Type[_ConvNd], 54 | kernel_sizes: Union[int, List[int], Tuple[int, ...]], 55 | strides: Union[int, List[int], Tuple[int, ...]], 56 | n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], 57 | n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], 58 | conv_bias: bool = False, 59 | norm_op: Union[None, Type[nn.Module]] = None, 60 | norm_op_kwargs: dict = None, 61 | dropout_op: Union[None, Type[_DropoutNd]] = None, 62 | dropout_op_kwargs: dict = None, 63 | nonlin: Union[None, Type[torch.nn.Module]] = None, 64 | nonlin_kwargs: dict = None, 65 | deep_supervision: bool = False, 66 | block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD, 67 | bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None, 68 | stem_channels: int = None, 69 | # Text-prompted segmentation parameters 70 | num_maskformer_stages: int = 5, 71 | query_dim: int = 864, 72 | decoder_layer: int = 4, 73 | text_embedding_dim: int = 1024, 74 | num_heads: int = 1, 75 | project_to_decoder_hidden_dim: int = 432 76 | ) -> None: 77 | """ 78 | Initialize the VoxTell model. 79 | 80 | Args: 81 | input_channels: Number of input channels. 82 | n_stages: Number of encoder stages. 83 | features_per_stage: Number of features at each stage. 84 | conv_op: Convolution operation type. 85 | kernel_sizes: Kernel sizes for convolutions. 86 | strides: Strides for downsampling. 87 | n_blocks_per_stage: Number of residual blocks per stage. 88 | n_conv_per_stage_decoder: Number of convolutions per stage. 89 | conv_bias: Whether to use bias in convolutions. 90 | norm_op: Normalization operation. 91 | norm_op_kwargs: Normalization operation keyword arguments. 92 | dropout_op: Dropout operation. 93 | dropout_op_kwargs: Dropout operation keyword arguments. 94 | nonlin: Non-linearity operation. 
95 | nonlin_kwargs: Non-linearity keyword arguments. 96 | deep_supervision: Whether to use deep supervision. 97 | block: Residual block type (BasicBlockD or BottleneckD). 98 | bottleneck_channels: Channels in bottleneck layers. 99 | stem_channels: Channels in stem layer. 100 | num_maskformer_stages: Number of stages to fuse text-image embeddings in decoder. 101 | query_dim: Dimension of query embeddings. 102 | decoder_layer: Which decoder layer to use as memory for text prompt decoder (0-5). 103 | text_embedding_dim: Dimension of text embeddings. 104 | num_heads: Number of channels added per U-Net stage for mask embedding fusion. 105 | project_to_decoder_hidden_dim: Hidden dimension for projection to decoder. 106 | """ 107 | super().__init__() 108 | if isinstance(n_blocks_per_stage, int): 109 | n_blocks_per_stage = [n_blocks_per_stage] * n_stages 110 | if isinstance(n_conv_per_stage_decoder, int): 111 | n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) 112 | 113 | assert len(n_blocks_per_stage) == n_stages, ( 114 | f"n_blocks_per_stage must have as many entries as we have resolution stages. " 115 | f"Expected: {n_stages}, got: {len(n_blocks_per_stage)} ({n_blocks_per_stage})" 116 | ) 117 | assert len(n_conv_per_stage_decoder) == (n_stages - 1), ( 118 | f"n_conv_per_stage_decoder must have one less entry than resolution stages. " 119 | f"Expected: {n_stages - 1}, got: {len(n_conv_per_stage_decoder)} ({n_conv_per_stage_decoder})" 120 | ) 121 | if num_maskformer_stages != 5: 122 | assert not deep_supervision, ( 123 | "Deep supervision is not supported for num_maskformer_stages != 5." 124 | ) 125 | self.deep_supervision = deep_supervision 126 | self.num_heads = num_heads 127 | self.query_dim = query_dim 128 | self.project_to_decoder_hidden_dim = project_to_decoder_hidden_dim 129 | self.text_embedding_dim = text_embedding_dim 130 | 131 | # Initialize encoder backbone 132 | self.encoder = ResidualEncoder( 133 | input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides, 134 | n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op, 135 | dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels, 136 | return_skips=True, disable_default_stem=False, stem_channels=stem_channels 137 | ) 138 | 139 | # Initialize decoder 140 | self.decoder = VoxTellDecoder( 141 | self.encoder, 1, n_conv_per_stage_decoder, deep_supervision, 142 | num_maskformer_stages, num_heads=self.num_heads 143 | ) 144 | 145 | # Select decoder layer configuration 146 | self.selected_decoder_layer = decoder_layer 147 | if decoder_layer not in self.DECODER_CONFIGS: 148 | raise ValueError( 149 | f"decoder_layer must be in {list(self.DECODER_CONFIGS.keys())}, got {decoder_layer}" 150 | ) 151 | selected_config = self.DECODER_CONFIGS[decoder_layer] 152 | 153 | h, w, d = selected_config["shape"] 154 | 155 | # Project bottleneck embeddings to query dimension 156 | self.project_bottleneck_embed = nn.Sequential( 157 | nn.Linear(selected_config["channels"], query_dim), 158 | nn.GELU(), 159 | nn.Linear(query_dim, query_dim), 160 | ) 161 | 162 | # Project text embeddings to query dimension 163 | text_hidden_dim = 2048 164 | self.project_text_embed = nn.Sequential( 165 | nn.Linear(self.text_embedding_dim, text_hidden_dim), 166 | nn.GELU(), 167 | nn.Linear(text_hidden_dim, query_dim), 168 | ) 169 | 170 | # Project decoder output to image channels for each mask-former stage 171 | self.project_to_decoder_channels = nn.ModuleList([ 172 | nn.Sequential( 173 | nn.Linear(query_dim, 
self.project_to_decoder_hidden_dim), 174 | nn.GELU(), 175 | nn.Linear( 176 | self.project_to_decoder_hidden_dim, 177 | decoder_config["channels"] * self.num_heads if stage_idx != 0 else decoder_config["channels"] 178 | ) 179 | ) 180 | for stage_idx, decoder_config in enumerate( 181 | list(self.DECODER_CONFIGS.values())[:num_maskformer_stages] 182 | ) 183 | ]) 184 | 185 | # Initialize 3D positional encoding 186 | # Shape: (H*W*D, batch_size, query_dim) 187 | pos_embed = PositionalEncoding3D(query_dim)(torch.zeros(1, h, w, d, query_dim)) 188 | pos_embed = rearrange(pos_embed, 'b h w d c -> (h w d) b c') 189 | self.register_buffer('pos_embed', pos_embed) 190 | 191 | # Initialize transformer decoder for fusing text and image features (prompt decoder) 192 | transformer_layer = TransformerDecoderLayer( 193 | d_model=query_dim, 194 | nhead=self.TRANSFORMER_NUM_HEADS, 195 | normalize_before=True 196 | ) 197 | decoder_norm = nn.LayerNorm(query_dim) 198 | self.transformer_decoder = TransformerDecoder( 199 | decoder_layer=transformer_layer, 200 | num_layers=self.TRANSFORMER_NUM_LAYERS, 201 | norm=decoder_norm 202 | ) 203 | 204 | def forward( 205 | self, 206 | img: torch.Tensor, 207 | text_embedding: torch.Tensor = None 208 | ) -> Union[torch.Tensor, List[torch.Tensor]]: 209 | """ 210 | Forward pass through VoxTell model. 211 | 212 | Args: 213 | img: Input image tensor of shape (B, C, D, H, W). 214 | text_embedding: Pre-computed text embeddings of shape (B, N, D) where 215 | N is number of prompts and D is embedding dimension. 216 | 217 | Returns: 218 | If deep_supervision is False, returns single prediction tensor of shape (B, N, D, H, W). 219 | If deep_supervision is True, returns list of prediction tensors at different scales. 220 | """ 221 | # Extract multi-scale features from encoder 222 | skips = self.encoder(img) 223 | 224 | # Select encoder features 225 | selected_feature = skips[self.selected_decoder_layer] 226 | 227 | # Reshape and project features to query dimension 228 | # Shape: (B, C, D, H, W) -> (B, H, W, D, C) -> (B, H, W, D, query_dim) 229 | bottleneck_embed = rearrange(selected_feature, 'b c d h w -> b h w d c') 230 | bottleneck_embed = self.project_bottleneck_embed(bottleneck_embed) 231 | # Shape: (B, H, W, D, query_dim) -> (H*W*D, B, query_dim) for transformer 232 | bottleneck_embed = rearrange(bottleneck_embed, 'b h w d c -> (h w d) b c') 233 | 234 | # Remove singleton dimension from text embeddings and project 235 | # Shape: (B, N, 1, D) -> (B, N, D) 236 | text_embedding = text_embedding.squeeze(2) 237 | # Shape: (B, N, D) -> (N, B, D) as required by transformer decoder 238 | text_embed = repeat(text_embedding, 'b n dim -> n b dim') 239 | text_embed = self.project_text_embed(text_embed) 240 | 241 | # Fuse text and image features through transformer decoder 242 | # Output shape: (N, B, query_dim) 243 | mask_embedding, _ = self.transformer_decoder( 244 | tgt=text_embed, 245 | memory=bottleneck_embed, 246 | pos=self.pos_embed, 247 | memory_key_padding_mask=None 248 | ) 249 | # Shape: (N, B, query_dim) -> (B, N, query_dim) 250 | mask_embedding = repeat(mask_embedding, 'n b dim -> b n dim') 251 | 252 | # Project mask embeddings to decoder channel dimensions for each stage 253 | mask_embeddings = [ 254 | projection(mask_embedding) 255 | for projection in self.project_to_decoder_channels 256 | ] 257 | 258 | # Generate segmentation outputs for each text prompt 259 | outs = [] 260 | num_prompts = text_embedding.shape[1] 261 | for prompt_idx in range(num_prompts): 262 | # Extract 
embeddings for this prompt across all stages 263 | prompt_embeds = [m[:, prompt_idx:prompt_idx + 1] for m in mask_embeddings] 264 | outs.append(self.decoder(skips, prompt_embeds)) 265 | 266 | # Concatenate outputs across prompts for each scale 267 | outs = [torch.cat(scale_outs, dim=1) for scale_outs in zip(*outs)] 268 | 269 | if not self.deep_supervision: 270 | outs = outs[0] 271 | 272 | return outs 273 | 274 | @staticmethod 275 | def initialize(module): 276 | InitWeights_He(1e-2)(module) 277 | init_last_bn_before_add_to_0(module) 278 | 279 | 280 | class VoxTellDecoder(nn.Module): 281 | """ 282 | Decoder for VoxTell with mask-embedding fusion. 283 | 284 | This decoder upsamples features from the encoder and fuses them with 285 | mask embeddings from text prompts at multiple scales. It supports 286 | deep supervision for multi-scale training. 287 | 288 | The decoder processes features from bottleneck to highest resolution, 289 | incorporating mask embeddings through einsum operations at each stage. 290 | """ 291 | 292 | def __init__( 293 | self, 294 | encoder: Union[PlainConvEncoder, ResidualEncoder], 295 | num_classes: int, 296 | n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], 297 | deep_supervision: bool, 298 | num_maskformer_stages: int = 5, 299 | nonlin_first: bool = False, 300 | norm_op: Union[None, Type[nn.Module]] = None, 301 | norm_op_kwargs: dict = None, 302 | dropout_op: Union[None, Type[_DropoutNd]] = None, 303 | dropout_op_kwargs: dict = None, 304 | nonlin: Union[None, Type[torch.nn.Module]] = None, 305 | nonlin_kwargs: dict = None, 306 | conv_bias: bool = None, 307 | num_heads: int = 1 308 | ) -> None: 309 | """ 310 | Initialize VoxTell decoder. 311 | 312 | The decoder upsamples features from encoder stages and fuses them with 313 | mask embeddings. Each stage consists of: 314 | 1) Transpose convolution to upsample lower resolution features 315 | 2) Concatenation with skip connections from encoder 316 | 3) Convolutional blocks to merge features 317 | 4) Fusion with mask embeddings via einsum 318 | 5) Optional segmentation output for deep supervision 319 | 320 | Args: 321 | encoder: Encoder module (PlainConvEncoder or ResidualEncoder). 322 | num_classes: Number of output classes (typically 1 for binary segmentation). 323 | n_conv_per_stage: Number of convolution blocks per decoder stage. 324 | deep_supervision: Whether to output predictions at multiple scales. 325 | num_maskformer_stages: Number of stages to fuse mask embeddings. 326 | nonlin_first: Whether to apply non-linearity before convolution. 327 | norm_op: Normalization operation (inherited from encoder if None). 328 | norm_op_kwargs: Normalization keyword arguments. 329 | dropout_op: Dropout operation (inherited from encoder if None). 330 | dropout_op_kwargs: Dropout keyword arguments. 331 | nonlin: Non-linearity operation (inherited from encoder if None). 332 | nonlin_kwargs: Non-linearity keyword arguments. 333 | conv_bias: Whether to use bias in convolutions (inherited from encoder if None). 334 | num_heads: Number of attention heads for mask embedding fusion. 
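Note:
    Illustrative channel bookkeeping (example numbers): at an intermediate fusion stage with
    64 skip-connection channels and num_heads=32 (the value VoxTellPredictor uses), the einsum
    fusion contributes 32 extra channels, so the concatenated feature map has 64 + 32 = 96
    channels, matching the `input_features_skip + num_heads` input size of that stage's seg layer.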
335 | """ 336 | super().__init__() 337 | self.deep_supervision = deep_supervision 338 | self.encoder = encoder 339 | self.num_classes = num_classes 340 | self.num_heads = num_heads 341 | 342 | n_stages_encoder = len(encoder.output_channels) 343 | if isinstance(n_conv_per_stage, int): 344 | n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) 345 | 346 | assert len(n_conv_per_stage) == n_stages_encoder - 1, ( 347 | f"n_conv_per_stage must have one less entry than encoder stages. " 348 | f"Expected: {n_stages_encoder - 1}, got: {len(n_conv_per_stage)}" 349 | ) 350 | 351 | # Inherit hyperparameters from encoder if not specified 352 | transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op) 353 | conv_bias = encoder.conv_bias if conv_bias is None else conv_bias 354 | norm_op = encoder.norm_op if norm_op is None else norm_op 355 | norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs 356 | dropout_op = encoder.dropout_op if dropout_op is None else dropout_op 357 | dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs 358 | nonlin = encoder.nonlin if nonlin is None else nonlin 359 | nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs 360 | 361 | # Build decoder stages from bottleneck to highest resolution 362 | stages = [] 363 | transpconvs = [] 364 | seg_layers = [] 365 | for stage_idx in range(1, n_stages_encoder): 366 | # Determine input channels: add num_heads for stages with mask embedding fusion 367 | if stage_idx <= n_stages_encoder - num_maskformer_stages: 368 | input_features_below = encoder.output_channels[-stage_idx] 369 | else: 370 | input_features_below = encoder.output_channels[-stage_idx] + num_heads 371 | 372 | input_features_skip = encoder.output_channels[-(stage_idx + 1)] 373 | stride_for_transpconv = encoder.strides[-stage_idx] 374 | 375 | # Transpose convolution for upsampling 376 | transpconvs.append(transpconv_op( 377 | input_features_below, input_features_skip, 378 | stride_for_transpconv, stride_for_transpconv, 379 | bias=conv_bias 380 | )) 381 | 382 | # Convolutional blocks for feature merging 383 | # Input features: 2x input_features_skip (concatenated skip + upsampled features) 384 | stages.append(StackedConvBlocks( 385 | n_conv_per_stage[stage_idx - 1], encoder.conv_op, 386 | 2 * input_features_skip, input_features_skip, 387 | encoder.kernel_sizes[-(stage_idx + 1)], 1, 388 | conv_bias, norm_op, norm_op_kwargs, 389 | dropout_op, dropout_op_kwargs, 390 | nonlin, nonlin_kwargs, nonlin_first 391 | )) 392 | 393 | # Segmentation output layer (always built for parameter loading compatibility) 394 | # This allows models trained with deep_supervision=True to be loaded 395 | # for inference with deep_supervision=False 396 | seg_layers.append(encoder.conv_op( 397 | input_features_skip + num_heads, num_classes, 398 | 1, 1, 0, bias=True 399 | )) 400 | 401 | self.stages = nn.ModuleList(stages) 402 | self.transpconvs = nn.ModuleList(transpconvs) 403 | self.seg_layers = nn.ModuleList(seg_layers) 404 | 405 | def forward( 406 | self, 407 | skips: List[torch.Tensor], 408 | mask_embeddings: List[torch.Tensor] 409 | ) -> List[torch.Tensor]: 410 | """ 411 | Forward pass through decoder with mask embedding fusion. 412 | 413 | Processes features from bottleneck to highest resolution, fusing 414 | mask embeddings at multiple stages via einsum operations. 415 | 416 | Args: 417 | skips: List of encoder skip connections in computation order. 
418 | Last entry should be bottleneck features. 419 | mask_embeddings: List of mask embeddings for each decoder stage, 420 | in order from lowest to highest resolution. 421 | 422 | Returns: 423 | List of segmentation predictions. If deep_supervision=False, 424 | returns single-element list with highest resolution prediction. 425 | If deep_supervision=True, returns predictions at all scales 426 | from highest to lowest resolution. 427 | """ 428 | lres_input = skips[-1] 429 | seg_outputs = [] 430 | 431 | # Reverse mask embeddings to match decoder stage order (bottleneck first) 432 | mask_embeddings = mask_embeddings[::-1] 433 | 434 | for stage_idx in range(len(self.stages)): 435 | # Upsample and concatenate with skip connection 436 | x = self.transpconvs[stage_idx](lres_input) 437 | x = torch.cat((x, skips[-(stage_idx + 2)]), dim=1) 438 | x = self.stages[stage_idx](x) 439 | 440 | # Apply mask embedding fusion for relevant stages 441 | if stage_idx == (len(self.stages) - 1): 442 | # Final stage: generate segmentation via einsum 443 | # x: (B, C, H, W, D), mask_embeddings[-1]: (B, N, C) 444 | # Output: (B, N, H, W, D) 445 | seg_pred = torch.einsum('b c h w d, b n c -> b n h w d', x, mask_embeddings[-1]) 446 | seg_outputs.append(seg_pred) 447 | elif stage_idx >= len(self.stages) - len(mask_embeddings): 448 | # Intermediate stages with mask embedding fusion 449 | mask_embedding = mask_embeddings.pop(0) 450 | batch_size, _, channels = mask_embedding.shape 451 | 452 | # Reshape for multi-head fusion and compute attention-weighted features 453 | # Shape: (B, num_heads, C // num_heads) 454 | mask_embedding_reshaped = mask_embedding.view(batch_size, self.num_heads, -1) 455 | fusion_features = torch.einsum( 456 | 'b c h w d, b n c -> b n h w d', 457 | x, mask_embedding_reshaped 458 | ) 459 | 460 | # Concatenate fused features with spatial features 461 | x = torch.cat((x, fusion_features), dim=1) 462 | seg_outputs.append(self.seg_layers[stage_idx](x)) 463 | 464 | lres_input = x 465 | 466 | # Reverse outputs to have highest resolution first 467 | seg_outputs = seg_outputs[::-1] 468 | 469 | if not self.deep_supervision: 470 | return seg_outputs[:1] 471 | else: 472 | return seg_outputs --------------------------------------------------------------------------------