├── LICENSE ├── README.md ├── configs └── config.yaml ├── data ├── __init__.py ├── example.json ├── seg_dataset.py ├── transforms.py └── utils.py ├── models ├── __init__.py ├── build_conch_v1_5.py ├── conch_v1_5_config.py ├── lora.py ├── losses.py ├── pfm_seg_models.py └── utils.py ├── scripts ├── infer.py └── train.py └── utils ├── __init__.py ├── evaluator.py ├── logs.py ├── metrics.py ├── scheduler.py ├── trainer.py └── visualization.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🩺 Pathology Foundation Models Meet Semantic Segmentation 2 | 3 | A comprehensive semantic segmentation framework based on Pathology Foundation Models (PFMs), designed specifically for pathological image analysis, supporting multiple state-of-the-art pathology foundation models with complete training, inference, and evaluation capabilities. 
4 | 5 | ## 🌟 Features 6 | 7 | - 🧬 **Support for SOTA Pathology Foundation Models**: uni_v1, uni_v2, conch_v1_5, gigapath, virchow_v2 8 | - 🔧 **Flexible Fine-tuning Strategies**: LoRA, full-parameter fine-tuning, frozen backbone 9 | - 📊 **Complete Training Pipeline**: Mixed-precision training, learning rate scheduling, gradient accumulation 10 | - 🎯 **Advanced Data Augmentation**: More than 10 augmentations covering spatial, color, and noise transformations 11 | - 📈 **Comprehensive Evaluation Metrics**: More than 10 evaluation metrics, including IoU and Dice 12 | - ⚡ **Advanced Inference Pipeline**: Sliding-window inference at arbitrary resolutions 13 | 14 | ## 📋 Table of Contents 15 | 16 | - [Dataset Format](#-dataset-format) 17 | - [Configuration File Details](#-configuration-file-details) 18 | - [Training Script Usage](#-training-script-usage) 19 | - [Inference Script Usage](#-inference-script-usage) 20 | - [Pathology Foundation Models Details](#-pathology-foundation-models-details) 21 | 22 | ## 📁 Dataset Format 23 | 24 | ### JSON Configuration File Format 25 | 26 | The dataset is described by a JSON configuration file with train, validation, and test splits: 27 | 28 | ```json 29 | { 30 | "num_classes": 3, 31 | "data": { 32 | "train": [ 33 | { 34 | "image_path": "/path/to/train/image1.jpg", 35 | "mask_path": "/path/to/train/mask1.png" 36 | } 37 | ], 38 | "val": [ 39 | { 40 | "image_path": "/path/to/val/image1.jpg", 41 | "mask_path": "/path/to/val/mask1.png" 42 | } 43 | ], 44 | "test": [ 45 | { 46 | "image_path": "/path/to/test/image1.jpg", 47 | "mask_path": "/path/to/test/mask1.png" 48 | } 49 | ] 50 | } 51 | } 52 | ``` 53 | 54 | During training, only the `train` and `val` fields are used; the `test` field is read by the inference script. The `mask_path` entries in the test split may be null or missing, in which case no metrics are computed. If `mask_path` is present, metrics are calculated automatically after inference.
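
For reference, the snippet below is a minimal sketch of how such a JSON file could be generated programmatically. The directory layout, file extensions, and the `build_split` helper are illustrative assumptions, not part of this repository.

```python
import json
from pathlib import Path
from typing import Dict, List, Optional


def build_split(image_dir: str, mask_dir: Optional[str]) -> List[Dict]:
    """Pair each image with a same-named .png mask (None if the mask is absent)."""
    entries = []
    for img in sorted(Path(image_dir).glob("*.jpg")):
        mask = Path(mask_dir) / f"{img.stem}.png" if mask_dir else None
        entries.append({
            "image_path": str(img),
            "mask_path": str(mask) if mask is not None and mask.exists() else None,
        })
    return entries


dataset = {
    "num_classes": 3,  # must match dataset.num_classes in the YAML config
    "data": {
        "train": build_split("/data/train/images", "/data/train/masks"),
        "val": build_split("/data/val/images", "/data/val/masks"),
        "test": build_split("/data/test/images", None),  # masks are optional for test
    },
}

with open("dataset.json", "w") as f:
    json.dump(dataset, f, indent=2)
```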
55 | 56 | ## ⚙️ Configuration File Details 57 | 58 | The configuration file uses YAML format and includes the following main sections: 59 | 60 | ### Dataset Configuration (dataset) 61 | 62 | ```yaml 63 | dataset: 64 | json_file: "/path/to/dataset.json" # Path to dataset JSON configuration file 65 | num_classes: 3 # Number of classes, must match JSON file 66 | ignore_index: 255 # Pixel value to ignore for uncertain regions 67 | ``` 68 | 69 | ### System Configuration (system) 70 | 71 | ```yaml 72 | system: 73 | num_workers: 4 # Number of processes for data loading 74 | pin_memory: true # Whether to use pin_memory for faster data transfer 75 | seed: 42 # Random seed for reproducible experiments 76 | device: "cuda:0" # Device to use 77 | ``` 78 | 79 | ### Pathology Foundation Model Configuration (model) 🧬 80 | 81 | This is the most important section, controlling the selection and configuration of pathology foundation models: 82 | 83 | ```yaml 84 | model: 85 | # === Base Model Selection === 86 | pfm_name: "uni_v1" # Pathology foundation model name 87 | # Options: 88 | # - "uni_v1" : UNI model version 1 (1024 dim) 89 | # - "uni_v2" : UNI model version 2 (1536 dim) 90 | # - "virchow_v2" : Virchow model version 2 (1280 dim) 91 | # - "conch_v1_5" : Conch model version 1.5 (1024 dim) 92 | # - "gigapath" : Gigapath model (1536 dim) 93 | 94 | # === Model Parameter Configuration === 95 | emb_dim: 1024 # Embedding dimension, must match selected PFM model 96 | # Corresponding embedding dimensions for each model: 97 | # - uni_v1: 1024 - uni_v2: 1536 98 | # - virchow_v2: 1280 - conch_v1_5: 1024 99 | # - gigapath: 1536 100 | 101 | pfm_weights_path: '/path/to/pytorch_model.bin' # Path to pre-trained weights file 102 | 103 | # === Fine-tuning Strategy Configuration === 104 | finetune_mode: 105 | type: "lora" # Fine-tuning mode 106 | # Options: 107 | # - "lora" : LoRA low-rank adaptation, parameter efficient 108 | # - "full" : Full parameter fine-tuning, best performance but requires more memory 109 | # - "frozen" : Frozen backbone, only train segmentation head 110 | 111 | rank: 16 # LoRA rank, only used when type is "lora" 112 | alpha: 1.0 # LoRA scaling factor, only used when type is "lora" 113 | 114 | # === Data Preprocessing Configuration === 115 | mean: [0.485, 0.456, 0.406] # Input normalization mean, must match PFM model training settings 116 | std: [0.229, 0.224, 0.225] # Input normalization std, must match PFM model training settings 117 | 118 | num_classes: 3 # Number of segmentation classes, must match dataset.num_classes 119 | ``` 120 | 121 | ### Training Configuration (training) 122 | 123 | ```yaml 124 | training: 125 | # === Basic Training Parameters === 126 | batch_size: 8 # Batch size 127 | epochs: 100 # Number of training epochs 128 | learning_rate: 0.01 # Initial learning rate 129 | weight_decay: 0.0001 # Weight decay 130 | 131 | # === Training Optimization Settings === 132 | use_amp: true # Whether to use mixed precision training 133 | accumulate_grad_batches: 1 # Number of gradient accumulation steps 134 | clip_grad_norm: 5.0 # Gradient clipping threshold 135 | 136 | # === Data Augmentation Configuration === 137 | augmentation: 138 | RandomResizedCropSize: 512 # Random crop size 139 | # Note: Different PFM models have input size requirements 140 | # - virchow_v2, uni_v2: must be multiple of 14 141 | # - uni_v1, conch_v1_5, gigapath: must be multiple of 16 142 | 143 | # === Optimizer Configuration === 144 | optimizer: 145 | type: "SGD" # Optimizer type: SGD, Adam, AdamW 146 | momentum: 0.9 # 
SGD momentum (SGD only) 147 | nesterov: true # Whether to use Nesterov momentum 148 | 149 | # === Learning Rate Scheduler === 150 | scheduler: 151 | type: "cosine" # Scheduler type: cosine, step 152 | warmup_epochs: 2 # Number of warmup epochs 153 | 154 | # === Loss Function === 155 | loss: 156 | type: "cross_entropy" # Loss function: cross_entropy, dice, ohem, iou 157 | ``` 158 | 159 | ### Validation Configuration (validation) 160 | 161 | ```yaml 162 | validation: 163 | eval_interval: 1 # Validate every N epochs 164 | batch_size: 16 # Validation batch size 165 | augmentation: 166 | ResizedSize: 512 # Image size during validation 167 | ``` 168 | 169 | ### Logging and Visualization Configuration 170 | 171 | ```yaml 172 | logging: 173 | log_dir: "/path/to/logs" # Log save directory 174 | experiment_name: "pfm_segmentation" # Experiment name 175 | 176 | visualization: 177 | save_interval: 2 # Save visualization results every N epochs 178 | num_vis_samples: 8 # Number of visualization samples to save 179 | ``` 180 | 181 | ## 🚀 Training Script Usage 182 | 183 | ### Basic Training Command 184 | 185 | ```bash 186 | python scripts/train.py --config configs/config.yaml 187 | ``` 188 | 189 | ### Training Script Parameters Details 190 | 191 | ```bash 192 | python scripts/train.py \ 193 | --config configs/config.yaml \ # Configuration file path 194 | --resume checkpoints/model.pth \ # Resume training from checkpoint (optional) 195 | --device cuda:0 # Specify device (optional, overrides config file) 196 | ``` 197 | 198 | ### Parameter Description 199 | 200 | - `--config`: **Required** Configuration file path containing all training settings 201 | - `--resume`: **Optional** Checkpoint file path for resuming interrupted training 202 | - `--device`: **Optional** Training device, overrides device setting in config file 203 | 204 | ### Training Output 205 | 206 | During training, the following files will be generated: 207 | 208 | ``` 209 | logs/experiment_name/ 210 | ├── config.yaml # Saved copy of configuration file 211 | ├── training.log # Training log 212 | ├── checkpoints/ # Model checkpoints 213 | │ ├── best_model.pth # Best model 214 | ├── visualizations/ # Visualization results 215 | │ ├── epoch_010_sample_00.png 216 | │ └── ... 217 | └── training_history.png # Training curve plot 218 | ``` 219 | 220 | ### Training Monitoring 221 | 222 | During training, the following will be displayed: 223 | - Training loss and validation loss 224 | - Validation metrics (mIoU, Pixel Accuracy, etc.) 
225 | - Learning rate changes 226 | - Time taken per epoch 227 | 228 | ## 🔍 Inference Script Usage 229 | 230 | ### Basic Inference Command 231 | 232 | ```bash 233 | python scripts/infer.py \ 234 | --config logs/experiment_name/config.yaml \ 235 | --checkpoint logs/experiment_name/checkpoints/best_model.pth \ 236 | --input_json dataset/test.json \ 237 | --output_dir results/ 238 | ``` 239 | 240 | ### Inference Script Parameters Details 241 | 242 | ```bash 243 | python scripts/infer.py \ 244 | --config CONFIG_PATH \ # Configuration file used during training 245 | --checkpoint CHECKPOINT_PATH \ # Trained model weights 246 | --input_json INPUT_JSON \ # Input data JSON file 247 | --output_dir OUTPUT_DIR \ # Results save directory 248 | --device cuda:0 \ # Inference device 249 | --input_size 512 \ # Model input size 250 | --resize_or_windowslide windowslide \ # Inference mode 251 | --batch_size 4 # Inference batch size 252 | ``` 253 | 254 | ### Detailed Parameter Description 255 | 256 | | Parameter | Type | Required | Description | 257 | |-----------|------|----------|-------------| 258 | | `--config` | str | ✅ | Configuration file path used during training | 259 | | `--checkpoint` | str | ✅ | Trained model checkpoint path | 260 | | `--input_json` | str | ✅ | JSON file listing the data to run inference on | 261 | | `--output_dir` | str | ✅ | Directory where inference results are saved | 262 | | `--device` | str | ❌ | Inference device, default cuda:0 | 263 | | `--input_size` | int | ✅ | Model input size, not the original image size | 264 | | `--resize_or_windowslide` | str | ❌ | Inference mode, default windowslide | 265 | | `--batch_size` | int | ❌ | Inference batch size, default 2 | 266 | 267 | ### Inference Mode Selection 268 | 269 | 1. **Resize Mode** (`--resize_or_windowslide resize`) 270 | - Resizes input images to a fixed size (`input_size`) for inference 271 | - Resizes predictions back to the original image size afterwards 272 | 273 | 2. **Window Slide Mode** (`--resize_or_windowslide windowslide`) 274 | - Processes large images with a sliding window of size `input_size` 275 | - Keeps the original resolution, typically giving higher accuracy 276 | - Window predictions are merged back to the original image size after inference 277 | 278 | ### Inference Output 279 | 280 | After inference completes, the following files are generated: 281 | 282 | ``` 283 | output_dir/ 284 | ├── predictions_masks/ # Prediction masks (grayscale images) 285 | │ ├── image001.png 286 | │ ├── image002.png 287 | │ └── ... 288 | └── predictions_overlays/ # Prediction result visualizations (colored overlay images) 289 | ├── image001.png 290 | ├── image002.png 291 | └── ... 
292 | ``` 293 | 294 | ### Inference Result Format 295 | 296 | - **Prediction Masks**: Grayscale PNG images whose pixel values correspond to class indices 297 | - **Visualization Overlays**: Predictions rendered as colored overlays on the original images for quick visual inspection 298 | 299 | ## 🧬 Pathology Foundation Models Details 300 | 301 | ### Supported Models List 302 | 303 | | Model Name | Parameters | Embedding Dim | Token Size | HuggingFace | 304 | |------------|------------|---------------|------------|-------------| 305 | | **uni_v1** | 307M | 1024 | 16×16 | [MahmoodLab/UNI](https://huggingface.co/MahmoodLab/UNI) | 306 | | **uni_v2** | 1.1B | 1536 | 14×14 | [MahmoodLab/UNI2](https://huggingface.co/MahmoodLab/UNI2-h) | 307 | | **virchow_v2** | 632M | 1280 | 14×14 | [paige-ai/Virchow2](https://huggingface.co/paige-ai/Virchow2) | 308 | | **conch_v1_5** | 307M | 1024 | 16×16 | [MahmoodLab/TITAN](https://huggingface.co/MahmoodLab/TITAN) | 309 | | **gigapath** | 1.1B | 1536 | 16×16 | [prov-gigapath/prov-gigapath](https://huggingface.co/prov-gigapath/prov-gigapath) | 310 | 311 | ## 🤝 Contributing 312 | 313 | Issues and feature requests are welcome! Please see the contribution guidelines for more information. 314 | 315 | ## 📞 Contact 316 | 317 | If you have questions or suggestions, please reach out via: 318 | - GitHub issues 319 | - Email: lingxt23@mails.tsinghua.edu.cn 320 | 321 | --- 322 | 323 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # Dataset configuration 2 | dataset: 3 | json_file: "/path/to/dataset.json" # Path to the JSON configuration file 4 | num_classes: 2 # Number of classes; must match num_classes in the JSON file 5 | ignore_index: 255 # Index value to ignore during training 6 | 7 | system: 8 | num_workers: 2 # Number of worker processes for data loading 9 | pin_memory: true # Whether to use pin_memory to accelerate data loading 10 | seed: 42 11 | device: "cuda:1" # Device to use, 'cuda' or 'cpu' 12 | 13 | # Model configuration 14 | model: 15 | pfm_name: "uni_v1" # Options: uni_v1, uni_v2, virchow_v2, conch_v1_5, gigapath 16 | emb_dim: 1024 # uni_v1: 1024, uni_v2: 1536, virchow_v2: 1280, conch_v1_5: 1024, gigapath: 1536 17 | finetune_mode: 18 | type: lora # Options: lora, full, frozen 19 | rank: 16 # only used when finetune_mode.type is lora 20 | alpha: 1.0 # only used when finetune_mode.type is lora 21 | pfm_weights_path: '/path/to/pytorch_model.bin' # Path to the PFM model weights 22 | mean: [0.485, 0.456, 0.406] # match the PFM model's input normalization 23 | std: [0.229, 0.224, 0.225] # match the PFM model's input normalization 24 | num_classes: 2 # Must match num_classes in the JSON file 25 | 26 | # Training configuration 27 | training: 28 | batch_size: 1 29 | epochs: 10 30 | learning_rate: 0.01 31 | weight_decay: 0.0001 32 | use_amp: true # Whether to use automatic mixed precision (AMP) for training 33 | accumulate_grad_batches: 1 # Number of batches to accumulate gradients over before performing an optimizer step 34 | clip_grad_norm: 5.0 # Gradient clipping value to prevent exploding gradients 35 | 36 | # Data augmentation 37 | augmentation: 38 | RandomResizedCropSize: 512 # virchow_v2, uni_v2: must be a multiple of 14 (token_size) / uni_v1, conch_v1_5, gigapath: must be a multiple of 16 (token_size) 39 | 40 | # Optimizer settings 41 | optimizer: 42 | type: "SGD" # Options: SGD, Adam, AdamW 43 | 44 | # 
Learning rate scheduler 45 | scheduler: 46 | type: "cosine" # Options: cosine, step 47 | warmup_epochs: 2 48 | 49 | # Loss function 50 | loss: 51 | type: "cross_entropy" # Options: cross_entropy, dice, ohem, iou 52 | 53 | # Validation configuration 54 | validation: 55 | eval_interval: 1 # Validate every N epochs 56 | batch_size: 16 57 | augmentation: 58 | ResizedSize: 512 # virchow_v2, uni_v2: must be a multiple of 14 (token_size) / uni_v1, conch_v1_5, gigapath: must be a multiple of 16 (token_size) 59 | 60 | logging: 61 | log_dir: "/path/to/logs" 62 | experiment_name: "your_experiment_name" 63 | 64 | visualization: 65 | save_interval: 2 # Save visualization results every N epochs 66 | num_vis_samples: 8 # Number of samples to visualize -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data package for semantic segmentation. 3 | 4 | This package contains dataset classes, data transforms, and utilities 5 | for loading and preprocessing segmentation data. 6 | """ 7 | 8 | from .seg_dataset import ( 9 | JSONSegmentationDataset, get_dataset, 10 | ) 11 | 12 | from .transforms import ( 13 | SegmentationTransforms, parse_transform_config, get_transforms, 14 | MixUp, CutMix, Mosaic, AdvancedAugmentationPipeline 15 | ) 16 | from .utils import ( 17 | create_dataloader, segmentation_collate_fn, compute_class_distribution, 18 | visualize_class_distribution, visualize_sample, create_color_map, 19 | analyze_dataset_quality, save_dataset_info, create_data_split 20 | ) 21 | 22 | __all__ = [ 23 | # Datasets 24 | 'JSONSegmentationDataset', 25 | 'get_dataset', 26 | 27 | # Transforms 28 | 'SegmentationTransforms', 29 | 'parse_transform_config', 30 | 'get_transforms', 31 | 'MixUp', 32 | 'CutMix', 33 | 'Mosaic', 34 | 'AdvancedAugmentationPipeline', 35 | 36 | # Utils 37 | 'create_dataloader', 38 | 'segmentation_collate_fn', 39 | 'compute_class_distribution', 40 | 'visualize_class_distribution', 41 | 'visualize_sample', 42 | 'create_color_map', 43 | 'analyze_dataset_quality', 44 | 'save_dataset_info', 45 | 'create_data_split' 46 | ] 47 | -------------------------------------------------------------------------------- /data/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_classes": 3, 3 | "data": { 4 | "train": [ 5 | { 6 | "image_path": "/images/train/patient001_slide01.jpg", 7 | "mask_path": "/masks/train/patient001_slide01.png" 8 | }, 9 | { 10 | "image_path": "/images/train/patient001_slide02.jpg", 11 | "mask_path": "/masks/train/patient001_slide02.png" 12 | }, 13 | { 14 | "image_path": "/images/train/patient002_slide01.jpg", 15 | "mask_path": "/masks/train/patient002_slide01.png" 16 | } 17 | ], 18 | "val": [ 19 | { 20 | "image_path": "/images/val/patient003_slide01.jpg", 21 | "mask_path": "/masks/val/patient003_slide01.png" 22 | }, 23 | { 24 | "image_path": "/images/val/patient003_slide02.jpg", 25 | "mask_path": "/masks/val/patient003_slide02.png" 26 | } 27 | ], 28 | "test": [ 29 | { 30 | "image_path": "/images/val/patient004_slide01.jpg", 31 | "mask_path": "/masks/val/patient004_slide01.png" 32 | }, 33 | { 34 | "image_path": "/images/val/patient004_slide02.jpg", 35 | "mask_path": "/masks/val/patient004_slide02.png" 36 | } 37 | ] 38 | } 39 | } 40 | 
-------------------------------------------------------------------------------- /data/seg_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simplified JSON-based dataset class 3 | 4 | Supports basic img_path and mask_path format, designed for semantic segmentation tasks. 5 | """ 6 | 7 | import os 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset 11 | from PIL import Image 12 | import numpy as np 13 | from typing import Dict, List, Optional, Callable 14 | 15 | 16 | class JSONSegmentationDataset(Dataset): 17 | """ 18 | Semantic segmentation dataset based on a JSON file. 19 | 20 | Expected JSON format: 21 | { 22 | "num_classes": 3, 23 | "data": { 24 | "train": [ 25 | {"image_path": "/path/to/image1.jpg", "mask_path": "/path/to/mask1.png"}, 26 | {"image_path": "/path/to/image2.jpg", "mask_path": "/path/to/mask2.png"} 27 | ], 28 | "val": [...], 29 | "test": [...] 30 | } 31 | } 32 | 33 | Args: 34 | json_file (str): Path to the JSON config file. 35 | split (str): Dataset split ('train', 'val', or 'test'). 36 | transform (Optional[Callable]): Data transformation/augmentation function. 37 | """ 38 | 39 | def __init__(self, json_file: str, split: str = 'train', 40 | transform: Optional[Callable] = None): 41 | self.json_file = json_file 42 | self.split = split 43 | self.transform = transform 44 | 45 | # Load JSON configuration 46 | self.config = self._load_json_config() 47 | 48 | # Extract basic info 49 | self.num_classes = self.config.get('num_classes') 50 | self.ignore_index = 255 # fixed ignore label 51 | # Load data entries 52 | self.data_items = self._load_data_items() 53 | self.fixed_size = self._check_fixed_size() 54 | self.has_mask = self._check_has_mask() 55 | if not self.has_mask: 56 | self._reset_mask() 57 | print(f"Dataset loaded: split = {split}, samples = {len(self.data_items)}, classes = {self.num_classes}") 58 | 59 | def _load_json_config(self) -> Dict: 60 | """Load the JSON config file.""" 61 | try: 62 | with open(self.json_file, 'r', encoding='utf-8') as f: 63 | config = json.load(f) 64 | return config 65 | except FileNotFoundError: 66 | raise FileNotFoundError(f"JSON config file not found: {self.json_file}") 67 | except json.JSONDecodeError as e: 68 | raise ValueError(f"Invalid JSON format: {e}") 69 | 70 | def _check_has_mask(self) -> bool: 71 | """Check if the dataset has mask paths.""" 72 | for item in self.data_items: 73 | mask_path = item.get('mask_path') 74 | if mask_path == None: 75 | return False 76 | if not os.path.exists(mask_path): 77 | return False 78 | return True 79 | 80 | def _reset_mask(self) -> None: 81 | """Reset mask paths to None if they are not present.""" 82 | new_items = [] 83 | for item in self.data_items: 84 | item['mask_path'] = None 85 | new_items.append(item) 86 | self.data_items = new_items 87 | 88 | 89 | def _check_fixed_size(self) -> bool: 90 | """Check if the dataset has a fixed image size.""" 91 | _img_size = None 92 | for item in self.data_items: 93 | img_path = item.get('img_path', '') 94 | with Image.open(img_path) as img: 95 | if _img_size is None: 96 | _img_size = img.size 97 | elif _img_size != img.size: 98 | return False 99 | return True 100 | 101 | def _load_data_items(self) -> List[Dict]: 102 | """Load the data entries for the given split.""" 103 | data_config = self.config.get('data') 104 | split_data = data_config.get(self.split) 105 | 106 | if not split_data: 107 | raise ValueError(f"No data found for split '{self.split}'") 108 | 109 | processed_items = [] 
110 | for item in split_data: 111 | processed_item = self._process_data_item(item) 112 | if processed_item: 113 | processed_items.append(processed_item) 114 | 115 | if not processed_items: 116 | raise ValueError(f"No valid items found in split '{self.split}'") 117 | 118 | return processed_items 119 | 120 | def _process_data_item(self, item: Dict) -> Optional[Dict]: 121 | """Process a single data entry.""" 122 | img_path = item.get('image_path', '') 123 | mask_path = item.get('mask_path', None) 124 | 125 | if not img_path or not mask_path: 126 | if self.split == 'train' or self.split == 'val': 127 | print(f"Missing image or mask path: {item}") 128 | return None 129 | 130 | return { 131 | 'img_path': img_path, 132 | 'mask_path': mask_path 133 | } 134 | 135 | def __len__(self) -> int: 136 | """Return the dataset size.""" 137 | return len(self.data_items) 138 | 139 | def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: 140 | """ 141 | Retrieve a single data entry. 142 | 143 | Args: 144 | index (int): Index of the data item 145 | 146 | Returns: 147 | Dict[str, torch.Tensor]: Dictionary containing image and label tensors 148 | """ 149 | item = self.data_items[index] 150 | 151 | image = Image.open(item['img_path']).convert('RGB') 152 | ori_size = image.size 153 | 154 | if self.has_mask: 155 | mask = Image.open(item['mask_path']) 156 | if mask.mode != 'L': 157 | mask = mask.convert('L') 158 | mask = np.array(mask, dtype=np.int64) 159 | else: 160 | mask = np.ones((ori_size[1],ori_size[0]), dtype=np.int64) * (-1) 161 | 162 | # Validate mask values (should be within [0, num_classes-1] or 255 as ignore index) 163 | unique_values = np.unique(mask) 164 | valid_values = set(range(self.num_classes)) | {self.ignore_index} 165 | invalid_values = set(unique_values) - valid_values 166 | 167 | if invalid_values and self.has_mask: 168 | print(f"Invalid label values {invalid_values} found in {item['mask_path']}") 169 | for invalid_val in invalid_values: 170 | mask[mask == invalid_val] = self.ignore_index 171 | 172 | # Apply transformation 173 | if self.transform: 174 | transformed = self.transform(image=np.array(image), mask=mask) 175 | image = transformed['image'] 176 | mask = transformed['mask'] 177 | else: 178 | image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0 179 | mask = torch.from_numpy(mask).long() 180 | 181 | return { 182 | 'image': image, 183 | 'label': mask, 184 | 'ori_size': ori_size, 185 | 'image_path': item['img_path'], 186 | 'label_path': item['mask_path'] 187 | } 188 | 189 | def get_class_weights(self) -> torch.Tensor: 190 | """ 191 | Compute class weights to handle class imbalance. 
192 | 193 | Returns: 194 | torch.Tensor: Computed class weights 195 | """ 196 | print("Computing class weights...") 197 | 198 | class_counts = np.zeros(self.num_classes) 199 | total_pixels = 0 200 | 201 | for item in self.data_items: 202 | mask = Image.open(item['mask_path']) 203 | if mask.mode != 'L': 204 | mask = mask.convert('L') 205 | mask_array = np.array(mask) 206 | 207 | for class_id in range(self.num_classes): 208 | class_counts[class_id] += np.sum(mask_array == class_id) 209 | 210 | valid_pixels = mask_array != self.ignore_index 211 | total_pixels += np.sum(valid_pixels) 212 | 213 | class_counts = np.maximum(class_counts, 1) 214 | weights = total_pixels / (self.num_classes * class_counts) 215 | weights = weights / weights.sum() * self.num_classes 216 | 217 | print(f"Class weights: {weights}") 218 | return torch.from_numpy(weights).float() 219 | 220 | 221 | def get_dataset(data_configs, transforms, split): 222 | json_file = data_configs.get('json_file') 223 | return JSONSegmentationDataset( 224 | json_file=json_file, 225 | split=split, 226 | transform=transforms 227 | ) 228 | 229 | -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Transforms for Semantic Segmentation 3 | 4 | This module contains data augmentation and preprocessing transforms 5 | using albumentations library for robust training. 6 | """ 7 | 8 | import albumentations as A 9 | from albumentations.pytorch import ToTensorV2 10 | import cv2 11 | import numpy as np 12 | from typing import List, Dict, Any, Optional, Callable 13 | import torch 14 | 15 | class SegmentationTransforms: 16 | """ 17 | Collection of segmentation-specific transforms. 18 | """ 19 | 20 | @staticmethod 21 | def get_training_transforms(img_size: int = 512, 22 | mean: List[float] = [0.485, 0.456, 0.406], 23 | std: List[float] = [0.229, 0.224, 0.225], 24 | seed: int = 42) -> A.Compose: 25 | """ 26 | Get training transforms with strong augmentations. 
27 | 28 | Args: 29 | img_size (int): Target image size 30 | mean (List[float]): Normalization mean 31 | std (List[float]): Normalization standard deviation 32 | 33 | Returns: 34 | A.Compose: Composed transforms 35 | """ 36 | return A.Compose([ 37 | # Geometric transforms 38 | A.RandomResizedCrop(size=(img_size,img_size), scale=(0.5, 1.0), ratio=(0.75, 1.33), p=1.0), 39 | A.HorizontalFlip(p=0.5), 40 | A.VerticalFlip(p=0.1), 41 | A.RandomRotate90(p=0.3), 42 | A.Transpose(p=0.3), 43 | 44 | # Spatial transforms 45 | A.OneOf([ 46 | A.ElasticTransform(alpha=1, sigma=50, p=1.0), 47 | A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0), 48 | A.OpticalDistortion(distort_limit=0.2, p=1.0), 49 | ], p=0.3), 50 | 51 | # Color transforms 52 | A.OneOf([ 53 | A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=1.0), 54 | A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0), 55 | A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=1.0), 56 | ], p=0.5), 57 | 58 | # Noise and blur 59 | A.OneOf([ 60 | A.GaussNoise(p=1.0), 61 | A.MultiplicativeNoise(multiplier=(0.9, 1.1), per_channel=True, p=1.0), 62 | ], p=0.3), 63 | 64 | A.OneOf([ 65 | A.GaussianBlur(blur_limit=(3, 7), p=1.0), 66 | A.MotionBlur(blur_limit=7, p=1.0), 67 | A.MedianBlur(blur_limit=7, p=1.0), 68 | ], p=0.2), 69 | 70 | # Weather effects 71 | A.OneOf([ 72 | A.RandomRain(brightness_coefficient=0.7, p=1.0), 73 | A.RandomSnow(brightness_coeff=2.5, p=1.0), 74 | A.RandomFog(alpha_coef=0.08, p=1.0), 75 | ], p=0.2), 76 | 77 | # Lighting 78 | A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.4), 79 | A.RandomGamma(gamma_limit=(70, 130), p=0.3), 80 | 81 | # Cutout and mixing 82 | A.CoarseDropout(p=0.3), 83 | 84 | # Normalization 85 | A.Normalize(mean=mean, std=std), 86 | ToTensorV2(), 87 | ], seed=seed) 88 | 89 | @staticmethod 90 | def get_validation_transforms(img_size: int = 512, 91 | mean: List[float] = [0.485, 0.456, 0.406], 92 | std: List[float] = [0.229, 0.224, 0.225]) -> A.Compose: 93 | """ 94 | Get validation transforms with minimal augmentation. 95 | 96 | Args: 97 | img_size (int): Target image size 98 | mean (List[float]): Normalization mean 99 | std (List[float]): Normalization standard deviation 100 | 101 | Returns: 102 | A.Compose: Composed transforms 103 | """ 104 | if img_size != None: 105 | return A.Compose([ 106 | A.Resize(height=img_size, width=img_size), 107 | A.Normalize(mean=mean, std=std), 108 | ToTensorV2(), 109 | ]) 110 | else: 111 | return A.Compose([ 112 | A.Normalize(mean=mean, std=std), 113 | ToTensorV2(), 114 | ]) 115 | 116 | 117 | def parse_transform_config(config: Dict[str, Any]) -> A.Compose: 118 | """ 119 | Parse transform configuration and create albumentations transforms. 
120 | 121 | Args: 122 | config (Dict[str, Any]): Transform configuration 123 | 124 | Returns: 125 | A.Compose: Composed transforms 126 | """ 127 | transforms = [] 128 | 129 | for transform_config in config: 130 | transform_type = transform_config['type'] 131 | transform_params = {k: v for k, v in transform_config.items() if k != 'type'} 132 | 133 | # Get transform class from albumentations 134 | if hasattr(A, transform_type): 135 | transform_class = getattr(A, transform_type) 136 | transforms.append(transform_class(**transform_params)) 137 | elif transform_type == 'ToTensorV2': 138 | transforms.append(ToTensorV2()) 139 | else: 140 | raise ValueError(f"Unknown transform type: {transform_type}") 141 | 142 | return A.Compose(transforms) 143 | 144 | 145 | def get_transforms(transform_config: List[Dict[str, Any]]) -> A.Compose: 146 | """ 147 | Factory function to create transforms from configuration. 148 | 149 | Args: 150 | transform_config (List[Dict[str, Any]]): List of transform configurations 151 | 152 | Returns: 153 | A.Compose: Composed transforms 154 | """ 155 | if isinstance(transform_config, list): 156 | return parse_transform_config(transform_config) 157 | else: 158 | raise ValueError("Transform config must be a list of dictionaries") 159 | 160 | 161 | class MixUp: 162 | """ 163 | MixUp augmentation for semantic segmentation. 164 | """ 165 | 166 | def __init__(self, alpha: float = 1.0, p: float = 0.5): 167 | self.alpha = alpha 168 | self.p = p 169 | 170 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 171 | """ 172 | Apply MixUp to a batch of data. 173 | 174 | Args: 175 | batch (Dict[str, torch.Tensor]): Batch containing images and labels 176 | 177 | Returns: 178 | Dict[str, torch.Tensor]: Mixed batch 179 | """ 180 | if np.random.random() > self.p: 181 | return batch 182 | 183 | images = batch['image'] 184 | labels = batch['label'] 185 | 186 | batch_size = images.size(0) 187 | indices = torch.randperm(batch_size) 188 | 189 | # Sample lambda from Beta distribution 190 | lam = np.random.beta(self.alpha, self.alpha) 191 | 192 | # Mix images 193 | mixed_images = lam * images + (1 - lam) * images[indices] 194 | 195 | # For segmentation, we need to handle labels differently 196 | # We can either use the original labels or create mixed labels 197 | mixed_labels = labels # Keep original labels for simplicity 198 | 199 | return { 200 | 'image': mixed_images, 201 | 'label': mixed_labels, 202 | 'lambda': lam, 203 | 'indices': indices 204 | } 205 | 206 | 207 | class CutMix: 208 | """ 209 | CutMix augmentation for semantic segmentation. 210 | """ 211 | 212 | def __init__(self, alpha: float = 1.0, p: float = 0.5): 213 | self.alpha = alpha 214 | self.p = p 215 | 216 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 217 | """ 218 | Apply CutMix to a batch of data. 
219 | 220 | Args: 221 | batch (Dict[str, torch.Tensor]): Batch containing images and labels 222 | 223 | Returns: 224 | Dict[str, torch.Tensor]: Cut-mixed batch 225 | """ 226 | if np.random.random() > self.p: 227 | return batch 228 | 229 | images = batch['image'] 230 | labels = batch['label'] 231 | 232 | batch_size, _, height, width = images.shape 233 | indices = torch.randperm(batch_size) 234 | 235 | # Sample lambda and bounding box 236 | lam = np.random.beta(self.alpha, self.alpha) 237 | 238 | # Generate random bounding box 239 | cut_ratio = np.sqrt(1.0 - lam) 240 | cut_w = int(width * cut_ratio) 241 | cut_h = int(height * cut_ratio) 242 | 243 | cx = np.random.randint(width) 244 | cy = np.random.randint(height) 245 | 246 | bbx1 = np.clip(cx - cut_w // 2, 0, width) 247 | bby1 = np.clip(cy - cut_h // 2, 0, height) 248 | bbx2 = np.clip(cx + cut_w // 2, 0, width) 249 | bby2 = np.clip(cy + cut_h // 2, 0, height) 250 | 251 | # Apply CutMix 252 | mixed_images = images.clone() 253 | mixed_labels = labels.clone() 254 | 255 | mixed_images[:, :, bby1:bby2, bbx1:bbx2] = images[indices, :, bby1:bby2, bbx1:bbx2] 256 | mixed_labels[:, bby1:bby2, bbx1:bbx2] = labels[indices, bby1:bby2, bbx1:bbx2] 257 | 258 | return { 259 | 'image': mixed_images, 260 | 'label': mixed_labels, 261 | 'lambda': lam, 262 | 'indices': indices, 263 | 'bbox': (bbx1, bby1, bbx2, bby2) 264 | } 265 | 266 | 267 | class Mosaic: 268 | """ 269 | Mosaic augmentation for semantic segmentation. 270 | """ 271 | 272 | def __init__(self, p: float = 0.5): 273 | self.p = p 274 | 275 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 276 | """ 277 | Apply Mosaic to a batch of data. 278 | 279 | Args: 280 | batch (Dict[str, torch.Tensor]): Batch containing images and labels 281 | 282 | Returns: 283 | Dict[str, torch.Tensor]: Mosaic batch 284 | """ 285 | if np.random.random() > self.p or batch['image'].size(0) < 4: 286 | return batch 287 | 288 | images = batch['image'] 289 | labels = batch['label'] 290 | 291 | batch_size, channels, height, width = images.shape 292 | 293 | # Create mosaic for first sample 294 | mosaic_image = torch.zeros(channels, height, width) 295 | mosaic_label = torch.zeros(height, width, dtype=labels.dtype) 296 | 297 | # Divide image into 4 quadrants 298 | h_mid = height // 2 299 | w_mid = width // 2 300 | 301 | indices = torch.randperm(batch_size)[:4] 302 | 303 | # Top-left 304 | mosaic_image[:, :h_mid, :w_mid] = images[indices[0], :, :h_mid, :w_mid] 305 | mosaic_label[:h_mid, :w_mid] = labels[indices[0], :h_mid, :w_mid] 306 | 307 | # Top-right 308 | mosaic_image[:, :h_mid, w_mid:] = images[indices[1], :, :h_mid, w_mid:] 309 | mosaic_label[:h_mid, w_mid:] = labels[indices[1], :h_mid, w_mid:] 310 | 311 | # Bottom-left 312 | mosaic_image[:, h_mid:, :w_mid] = images[indices[2], :, h_mid:, :w_mid] 313 | mosaic_label[h_mid:, :w_mid] = labels[indices[2], h_mid:, :w_mid] 314 | 315 | # Bottom-right 316 | mosaic_image[:, h_mid:, w_mid:] = images[indices[3], :, h_mid:, w_mid:] 317 | mosaic_label[h_mid:, w_mid:] = labels[indices[3], h_mid:, w_mid:] 318 | 319 | # Replace first sample with mosaic 320 | new_images = images.clone() 321 | new_labels = labels.clone() 322 | new_images[0] = mosaic_image 323 | new_labels[0] = mosaic_label 324 | 325 | return { 326 | 'image': new_images, 327 | 'label': new_labels 328 | } 329 | 330 | 331 | class AdvancedAugmentationPipeline: 332 | """ 333 | Advanced augmentation pipeline combining multiple techniques. 
334 | """ 335 | 336 | def __init__(self, mixup_p: float = 0.3, cutmix_p: float = 0.3, mosaic_p: float = 0.2): 337 | self.mixup = MixUp(p=mixup_p) 338 | self.cutmix = CutMix(p=cutmix_p) 339 | self.mosaic = Mosaic(p=mosaic_p) 340 | 341 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 342 | """ 343 | Apply advanced augmentations to batch. 344 | 345 | Args: 346 | batch (Dict[str, torch.Tensor]): Input batch 347 | 348 | Returns: 349 | Dict[str, torch.Tensor]: Augmented batch 350 | """ 351 | # Apply augmentations in random order 352 | augmentations = [self.mixup, self.cutmix, self.mosaic] 353 | np.random.shuffle(augmentations) 354 | 355 | for aug in augmentations: 356 | batch = aug(batch) 357 | 358 | return batch 359 | 360 | 361 | if __name__ == "__main__": 362 | # Test transforms 363 | from PIL import Image 364 | import numpy as np 365 | 366 | # Create dummy data 367 | image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) 368 | mask = np.random.randint(0, 19, (512, 512), dtype=np.uint8) 369 | 370 | # Test training transforms 371 | train_transforms = SegmentationTransforms.get_training_transforms() 372 | transformed = train_transforms(image=image, mask=mask) 373 | 374 | print(f"Original image shape: {image.shape}") 375 | print(f"Transformed image shape: {transformed['image'].shape}") 376 | print(f"Transformed mask shape: {transformed['mask'].shape}") 377 | 378 | # Test validation transforms 379 | val_transforms = SegmentationTransforms.get_validation_transforms() 380 | val_transformed = val_transforms(image=image, mask=mask) 381 | 382 | print(f"Validation image shape: {val_transformed['image'].shape}") 383 | print(f"Validation mask shape: {val_transformed['mask'].shape}") 384 | 385 | # Test TTA transforms 386 | tta_transforms = SegmentationTransforms.get_test_time_augmentation_transforms() 387 | print(f"Number of TTA transforms: {len(tta_transforms)}") 388 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Utilities for Semantic Segmentation 3 | 4 | This module contains utility functions for data loading, preprocessing, 5 | and dataset management. 6 | """ 7 | 8 | import torch 9 | from torch.utils.data import DataLoader, DistributedSampler 10 | import numpy as np 11 | from typing import Dict, List, Optional, Tuple, Any, Callable 12 | import cv2 13 | from PIL import Image 14 | import matplotlib.pyplot as plt 15 | import seaborn as sns 16 | from collections import Counter 17 | import os 18 | 19 | 20 | def create_dataloader(dataset, batch_size: int = 8, shuffle: bool = True, 21 | num_workers: int = 4, pin_memory: bool = True, 22 | drop_last: bool = True, distributed: bool = False, generator = None, worker_init_fn = None) -> DataLoader: 23 | """ 24 | Create DataLoader with appropriate settings. 
25 | 26 | Args: 27 | dataset: PyTorch dataset 28 | batch_size (int): Batch size 29 | shuffle (bool): Whether to shuffle data 30 | num_workers (int): Number of worker processes 31 | pin_memory (bool): Whether to pin memory 32 | drop_last (bool): Whether to drop last incomplete batch 33 | distributed (bool): Whether to use distributed training 34 | generator: Random number generator for reproducibility 35 | worker_init_fn (Callable): Function to initialize workers 36 | 37 | Returns: 38 | DataLoader: Configured data loader 39 | """ 40 | sampler = None 41 | if distributed: 42 | sampler = DistributedSampler(dataset, shuffle=shuffle) 43 | shuffle = False # Disable shuffle when using sampler 44 | 45 | return DataLoader( 46 | dataset=dataset, 47 | batch_size=batch_size, 48 | shuffle=shuffle, 49 | num_workers=num_workers, 50 | pin_memory=pin_memory, 51 | drop_last=drop_last, 52 | sampler=sampler, 53 | generator=generator, 54 | collate_fn=segmentation_collate_fn, 55 | worker_init_fn=worker_init_fn, 56 | ) 57 | 58 | 59 | def segmentation_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 60 | """ 61 | Custom collate function for segmentation data. 62 | 63 | Args: 64 | batch (List[Dict[str, Any]]): List of sample dictionaries 65 | 66 | Returns: 67 | Dict[str, torch.Tensor]: Batched data 68 | """ 69 | images = [] 70 | labels = [] 71 | ori_sizes = [] 72 | image_paths = [] 73 | label_paths = [] 74 | 75 | for sample in batch: 76 | images.append(sample['image']) 77 | labels.append(sample['label']) 78 | image_paths.append(sample['image_path']) 79 | label_paths.append(sample['label_path']) 80 | ori_sizes.append(sample['ori_size']) 81 | 82 | # Stack images and labels 83 | images = torch.stack(images, dim=0) 84 | labels = torch.stack(labels, dim=0) 85 | 86 | return { 87 | 'image': images, 88 | 'label': labels, 89 | 'ori_size': ori_sizes, 90 | 'image_path': image_paths, 91 | 'label_path': label_paths 92 | } 93 | 94 | 95 | def compute_class_distribution(dataset, num_classes: int, ignore_index: int = 255) -> Dict[str, Any]: 96 | """ 97 | Compute class distribution statistics for a dataset. 98 | 99 | Args: 100 | dataset: Segmentation dataset 101 | num_classes (int): Number of classes 102 | ignore_index (int): Index to ignore in calculations 103 | 104 | Returns: 105 | Dict[str, Any]: Class distribution statistics 106 | """ 107 | class_counts = np.zeros(num_classes, dtype=np.int64) 108 | total_pixels = 0 109 | 110 | print("Computing class distribution...") 111 | for i, sample in enumerate(dataset): 112 | if i % 100 == 0: 113 | print(f"Processed {i}/{len(dataset)} samples") 114 | 115 | label = sample['label'].numpy() 116 | mask = (label != ignore_index) 117 | 118 | for c in range(num_classes): 119 | class_counts[c] += np.sum(label == c) 120 | total_pixels += np.sum(mask) 121 | 122 | # Compute statistics 123 | class_frequencies = class_counts / total_pixels 124 | class_weights = 1.0 / (class_frequencies + 1e-8) 125 | class_weights = class_weights / class_weights.sum() * num_classes 126 | 127 | return { 128 | 'class_counts': class_counts, 129 | 'class_frequencies': class_frequencies, 130 | 'class_weights': class_weights, 131 | 'total_pixels': total_pixels 132 | } 133 | 134 | 135 | def visualize_class_distribution(class_stats: Dict[str, Any], class_names: Optional[List[str]] = None, 136 | save_path: Optional[str] = None) -> None: 137 | """ 138 | Visualize class distribution statistics. 
139 | 140 | Args: 141 | class_stats (Dict[str, Any]): Class statistics from compute_class_distribution 142 | class_names (Optional[List[str]]): Names of classes 143 | save_path (Optional[str]): Path to save the plot 144 | """ 145 | num_classes = len(class_stats['class_counts']) 146 | 147 | if class_names is None: 148 | class_names = [f"Class {i}" for i in range(num_classes)] 149 | 150 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) 151 | 152 | # Plot class counts 153 | bars1 = ax1.bar(range(num_classes), class_stats['class_counts']) 154 | ax1.set_xlabel('Class') 155 | ax1.set_ylabel('Pixel Count') 156 | ax1.set_title('Class Distribution (Pixel Counts)') 157 | ax1.set_xticks(range(num_classes)) 158 | ax1.set_xticklabels(class_names, rotation=45, ha='right') 159 | 160 | # Add value labels on bars 161 | for i, bar in enumerate(bars1): 162 | height = bar.get_height() 163 | ax1.text(bar.get_x() + bar.get_width()/2., height, 164 | f'{int(height):,}', ha='center', va='bottom', fontsize=8) 165 | 166 | # Plot class frequencies 167 | bars2 = ax2.bar(range(num_classes), class_stats['class_frequencies']) 168 | ax2.set_xlabel('Class') 169 | ax2.set_ylabel('Frequency') 170 | ax2.set_title('Class Distribution (Frequencies)') 171 | ax2.set_xticks(range(num_classes)) 172 | ax2.set_xticklabels(class_names, rotation=45, ha='right') 173 | 174 | # Add value labels on bars 175 | for i, bar in enumerate(bars2): 176 | height = bar.get_height() 177 | ax2.text(bar.get_x() + bar.get_width()/2., height, 178 | f'{height:.3f}', ha='center', va='bottom', fontsize=8) 179 | 180 | plt.tight_layout() 181 | 182 | if save_path: 183 | plt.savefig(save_path, dpi=300, bbox_inches='tight') 184 | print(f"Class distribution plot saved to: {save_path}") 185 | 186 | plt.show() 187 | 188 | 189 | def visualize_sample(sample: Dict[str, torch.Tensor], class_colors: Optional[List[List[int]]] = None, 190 | class_names: Optional[List[str]] = None, save_path: Optional[str] = None) -> None: 191 | """ 192 | Visualize a single sample with image and label overlay. 
193 | 194 | Args: 195 | sample (Dict[str, torch.Tensor]): Sample containing image and label 196 | class_colors (Optional[List[List[int]]]): RGB colors for each class 197 | class_names (Optional[List[str]]): Names of classes 198 | save_path (Optional[str]): Path to save the visualization 199 | """ 200 | image = sample['image'] 201 | label = sample['label'] 202 | 203 | # Convert tensors to numpy arrays 204 | if isinstance(image, torch.Tensor): 205 | if image.dim() == 3: # C, H, W 206 | image = image.permute(1, 2, 0) 207 | image = image.numpy() 208 | 209 | if isinstance(label, torch.Tensor): 210 | label = label.numpy() 211 | 212 | # Denormalize image if needed 213 | if image.max() <= 1.0: 214 | image = (image * 255).astype(np.uint8) 215 | 216 | # Create colored label map 217 | if class_colors is None: 218 | # Generate random colors 219 | num_classes = int(label.max()) + 1 220 | class_colors = plt.cm.tab20(np.linspace(0, 1, num_classes))[:, :3] * 255 221 | class_colors = class_colors.astype(np.uint8) 222 | 223 | colored_label = np.zeros((*label.shape, 3), dtype=np.uint8) 224 | for class_id, color in enumerate(class_colors): 225 | colored_label[label == class_id] = color 226 | 227 | # Create overlay 228 | alpha = 0.6 229 | overlay = cv2.addWeighted(image, 1 - alpha, colored_label, alpha, 0) 230 | 231 | # Create visualization 232 | fig, axes = plt.subplots(1, 3, figsize=(15, 5)) 233 | 234 | # Original image 235 | axes[0].imshow(image) 236 | axes[0].set_title('Original Image') 237 | axes[0].axis('off') 238 | 239 | # Label map 240 | axes[1].imshow(colored_label) 241 | axes[1].set_title('Label Map') 242 | axes[1].axis('off') 243 | 244 | # Overlay 245 | axes[2].imshow(overlay) 246 | axes[2].set_title('Overlay') 247 | axes[2].axis('off') 248 | 249 | plt.tight_layout() 250 | 251 | if save_path: 252 | plt.savefig(save_path, dpi=300, bbox_inches='tight') 253 | print(f"Sample visualization saved to: {save_path}") 254 | 255 | plt.show() 256 | 257 | 258 | def create_color_map(num_classes: int) -> np.ndarray: 259 | """ 260 | Create a color map for visualization. 261 | 262 | Args: 263 | num_classes (int): Number of classes 264 | 265 | Returns: 266 | np.ndarray: Color map of shape (num_classes, 3) 267 | """ 268 | colors = [] 269 | for i in range(num_classes): 270 | # Generate distinct colors using HSV space 271 | hue = i / num_classes 272 | saturation = 0.7 + 0.3 * (i % 2) # Alternate between high and higher saturation 273 | value = 0.8 + 0.2 * ((i // 2) % 2) # Alternate brightness 274 | 275 | # Convert HSV to RGB 276 | hsv = np.array([hue, saturation, value]).reshape(1, 1, 3) 277 | rgb = cv2.cvtColor((hsv * 255).astype(np.uint8), cv2.COLOR_HSV2RGB)[0, 0] 278 | colors.append(rgb) 279 | 280 | return np.array(colors) 281 | 282 | 283 | def analyze_dataset_quality(dataset, sample_ratio: float = 0.1) -> Dict[str, Any]: 284 | """ 285 | Analyze dataset quality metrics. 
286 | 287 | Args: 288 | dataset: Segmentation dataset 289 | sample_ratio (float): Ratio of samples to analyze 290 | 291 | Returns: 292 | Dict[str, Any]: Quality analysis results 293 | """ 294 | num_samples = int(len(dataset) * sample_ratio) 295 | indices = np.random.choice(len(dataset), num_samples, replace=False) 296 | 297 | image_sizes = [] 298 | label_coverage = [] # Percentage of labeled pixels 299 | class_diversity = [] # Number of unique classes per image 300 | 301 | print(f"Analyzing dataset quality on {num_samples} samples...") 302 | 303 | for i, idx in enumerate(indices): 304 | if i % 50 == 0: 305 | print(f"Processed {i}/{num_samples} samples") 306 | 307 | sample = dataset[idx] 308 | image = sample['image'] 309 | label = sample['label'] 310 | 311 | # Image size 312 | if isinstance(image, torch.Tensor): 313 | h, w = image.shape[-2:] 314 | else: 315 | h, w = image.shape[:2] 316 | image_sizes.append((h, w)) 317 | 318 | # Label coverage 319 | if isinstance(label, torch.Tensor): 320 | label_np = label.numpy() 321 | else: 322 | label_np = label 323 | 324 | valid_pixels = np.sum(label_np != 255) # Assuming 255 is ignore_index 325 | total_pixels = label_np.size 326 | coverage = valid_pixels / total_pixels 327 | label_coverage.append(coverage) 328 | 329 | # Class diversity 330 | unique_classes = len(np.unique(label_np[label_np != 255])) 331 | class_diversity.append(unique_classes) 332 | 333 | # Compute statistics 334 | unique_sizes = list(set(image_sizes)) 335 | size_consistency = len(unique_sizes) == 1 336 | 337 | return { 338 | 'num_samples_analyzed': num_samples, 339 | 'unique_image_sizes': unique_sizes, 340 | 'size_consistency': size_consistency, 341 | 'avg_label_coverage': np.mean(label_coverage), 342 | 'std_label_coverage': np.std(label_coverage), 343 | 'avg_class_diversity': np.mean(class_diversity), 344 | 'std_class_diversity': np.std(class_diversity), 345 | 'label_coverage_distribution': label_coverage, 346 | 'class_diversity_distribution': class_diversity 347 | } 348 | 349 | 350 | def save_dataset_info(dataset, output_dir: str, dataset_name: str = "dataset") -> None: 351 | """ 352 | Save comprehensive dataset information to files. 
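    Example (illustrative output directory and name):
        save_dataset_info(train_dataset, output_dir='dataset_stats', dataset_name='my_dataset')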
353 | 354 | Args: 355 | dataset: Segmentation dataset 356 | output_dir (str): Output directory 357 | dataset_name (str): Name of the dataset 358 | """ 359 | os.makedirs(output_dir, exist_ok=True) 360 | 361 | # Basic info 362 | info = { 363 | 'dataset_name': dataset_name, 364 | 'num_samples': len(dataset), 365 | 'num_classes': getattr(dataset, 'num_classes', 'unknown'), 366 | 'ignore_index': getattr(dataset, 'ignore_index', 255) 367 | } 368 | 369 | # Save basic info 370 | import json 371 | with open(os.path.join(output_dir, f'{dataset_name}_info.json'), 'w') as f: 372 | json.dump(info, f, indent=2) 373 | 374 | # Compute and save class distribution 375 | if hasattr(dataset, 'num_classes'): 376 | class_stats = compute_class_distribution(dataset, dataset.num_classes) 377 | 378 | # Save class statistics 379 | np.save(os.path.join(output_dir, f'{dataset_name}_class_counts.npy'), 380 | class_stats['class_counts']) 381 | np.save(os.path.join(output_dir, f'{dataset_name}_class_weights.npy'), 382 | class_stats['class_weights']) 383 | 384 | # Save class distribution plot 385 | visualize_class_distribution( 386 | class_stats, 387 | save_path=os.path.join(output_dir, f'{dataset_name}_class_distribution.png') 388 | ) 389 | 390 | # Dataset quality analysis 391 | quality_stats = analyze_dataset_quality(dataset) 392 | with open(os.path.join(output_dir, f'{dataset_name}_quality.json'), 'w') as f: 393 | # Convert numpy arrays to lists for JSON serialization 394 | quality_stats_serializable = {} 395 | for k, v in quality_stats.items(): 396 | if isinstance(v, np.ndarray): 397 | quality_stats_serializable[k] = v.tolist() 398 | elif isinstance(v, np.float64): 399 | quality_stats_serializable[k] = float(v) 400 | else: 401 | quality_stats_serializable[k] = v 402 | json.dump(quality_stats_serializable, f, indent=2) 403 | 404 | print(f"Dataset information saved to: {output_dir}") 405 | 406 | 407 | def create_data_split(image_dir: str, label_dir: str, 408 | train_ratio: float = 0.7, val_ratio: float = 0.2, test_ratio: float = 0.1, 409 | output_dir: str = "splits", seed: int = 42) -> None: 410 | """ 411 | Create train/val/test splits for a custom dataset. 
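    Example (illustrative paths):
        create_data_split('raw/images', 'raw/labels',
                          train_ratio=0.7, val_ratio=0.2, test_ratio=0.1,
                          output_dir='splits', seed=42)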
412 | 413 | Args: 414 | image_dir (str): Directory containing images 415 | label_dir (str): Directory containing labels 416 | train_ratio (float): Ratio for training set 417 | val_ratio (float): Ratio for validation set 418 | test_ratio (float): Ratio for test set 419 | output_dir (str): Output directory for split files 420 | seed (int): Random seed for reproducibility 421 | """ 422 | import random 423 | import shutil 424 | 425 | assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1.0" 426 | 427 | # Get all image files 428 | image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))] 429 | 430 | # Filter files that have corresponding labels 431 | valid_files = [] 432 | for img_file in image_files: 433 | label_file = img_file.replace('.jpg', '.png').replace('.jpeg', '.png') 434 | if os.path.exists(os.path.join(label_dir, label_file)): 435 | valid_files.append(img_file) 436 | 437 | print(f"Found {len(valid_files)} valid image-label pairs") 438 | 439 | # Shuffle files 440 | random.seed(seed) 441 | random.shuffle(valid_files) 442 | 443 | # Compute split indices 444 | num_files = len(valid_files) 445 | train_end = int(num_files * train_ratio) 446 | val_end = train_end + int(num_files * val_ratio) 447 | 448 | # Split files 449 | train_files = valid_files[:train_end] 450 | val_files = valid_files[train_end:val_end] 451 | test_files = valid_files[val_end:] 452 | 453 | print(f"Split: Train={len(train_files)}, Val={len(val_files)}, Test={len(test_files)}") 454 | 455 | # Create output directories 456 | for split in ['train', 'val', 'test']: 457 | os.makedirs(os.path.join(output_dir, 'images', split), exist_ok=True) 458 | os.makedirs(os.path.join(output_dir, 'labels', split), exist_ok=True) 459 | 460 | # Copy files to respective directories 461 | for split, files in [('train', train_files), ('val', val_files), ('test', test_files)]: 462 | for img_file in files: 463 | label_file = img_file.replace('.jpg', '.png').replace('.jpeg', '.png') 464 | 465 | # Copy image 466 | shutil.copy2( 467 | os.path.join(image_dir, img_file), 468 | os.path.join(output_dir, 'images', split, img_file) 469 | ) 470 | 471 | # Copy label 472 | shutil.copy2( 473 | os.path.join(label_dir, label_file), 474 | os.path.join(output_dir, 'labels', split, label_file) 475 | ) 476 | 477 | print(f"Data split created in: {output_dir}") 478 | 479 | 480 | if __name__ == "__main__": 481 | # Test utilities 482 | print("Testing data utilities...") 483 | 484 | # Test color map creation 485 | color_map = create_color_map(19) 486 | print(f"Created color map with shape: {color_map.shape}") 487 | 488 | # Test other functions would require actual dataset 489 | print("Data utilities module loaded successfully!") 490 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models package for semantic segmentation. 3 | 4 | This package contains model definitions, loss functions, and utilities 5 | for semantic segmentation tasks. 
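Typical usage (illustrative sketch; ``config`` is the dictionary loaded from configs/config.yaml):

    from models import create_pfm_segmentation_model, get_loss_function
    model = create_pfm_segmentation_model(config['model'])
    criterion = get_loss_function({'type': 'ce', 'ignore_index': 255})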
6 | """ 7 | 8 | from .pfm_seg_models import create_pfm_segmentation_model 9 | from .lora import equip_model_with_lora 10 | from .losses import ( 11 | CrossEntropyLoss, DiceLoss, IoULoss, OHEMLoss,get_loss_function 12 | ) 13 | from .utils import ( 14 | count_parameters, initialize_weights, 15 | save_checkpoint, load_checkpoint, 16 | get_model_complexity, convert_to_onnx, print_model_summary 17 | ) 18 | 19 | __all__ = [ 20 | # Models 21 | 'create_pfm_segmentation_model', 22 | 'equip_model_with_lora', 23 | 24 | # Loss functions 25 | 'CrossEntropyLoss', 26 | 'DiceLoss', 27 | 'IoULoss', 28 | 'OHEMLoss', 29 | 'get_loss_function', 30 | 31 | # Utilities 32 | 'count_parameters', 33 | 'initialize_weights', 34 | 'save_checkpoint', 35 | 'load_checkpoint', 36 | 'get_model_complexity', 37 | 'convert_to_onnx', 38 | 'print_model_summary' 39 | ] 40 | -------------------------------------------------------------------------------- /models/conch_v1_5_config.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any 3 | 4 | from transformers import PretrainedConfig 5 | 6 | class ConchConfig(PretrainedConfig): 7 | model_type = "conch" 8 | 9 | def __init__( 10 | self, 11 | patch_size: int = 16, 12 | context_dim: int = 1024, 13 | embed_dim: int = 768, 14 | depth: int = 24, 15 | num_heads: int = 16, 16 | mlp_ratio: float = 4.0, 17 | qkv_bias: bool = True, 18 | init_values: float = 1e-6, 19 | pooler_n_queries_contrast: int = 1, 20 | **kwargs: Any, 21 | ): 22 | self.patch_size = patch_size 23 | self.context_dim = context_dim 24 | self.embed_dim = embed_dim 25 | self.depth = depth 26 | self.num_heads = num_heads 27 | self.mlp_ratio = mlp_ratio 28 | self.qkv_bias = qkv_bias 29 | self.init_values = init_values 30 | self.pooler_n_queries_contrast = pooler_n_queries_contrast 31 | 32 | super().__init__(**kwargs) -------------------------------------------------------------------------------- /models/lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import timm 3 | from timm.layers import use_fused_attn 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.init as init 9 | from einops import repeat 10 | from .build_conch_v1_5 import Conch_V1_5_Attention 11 | 12 | 13 | 14 | class LoRALinear(nn.Module): 15 | def __init__(self, in_features, out_features, bias=True, r=4, lora_alpha=1.0): 16 | super().__init__() 17 | self.r = r 18 | self.in_features = in_features 19 | self.out_features = out_features 20 | self.bias = bias 21 | self.r = r 22 | self.lora_alpha = lora_alpha 23 | # original linear layer 24 | self.weight = nn.Parameter(torch.empty((out_features, in_features))) 25 | if bias: 26 | self.bias = nn.Parameter(torch.empty(out_features)) 27 | else: 28 | self.register_parameter('bias', None) 29 | 30 | 31 | # LoRA: low-rank adaptor 32 | self.lora_a = nn.Parameter(torch.zeros(in_features, r), requires_grad=True) 33 | self.lora_b = nn.Parameter(torch.zeros(r, out_features), requires_grad=True) 34 | self.scale = lora_alpha 35 | 36 | # initialization 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self) -> None: 40 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 41 | if self.bias is not None: 42 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 43 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 44 | init.uniform_(self.bias, -bound, bound) 45 | 46 | nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5)) 47 | 
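        # Note: lora_b (next line) is zero-initialized, so lora_a @ lora_b is a zero
        # matrix at start-up and the layer initially reproduces the plain linear output;
        # the low-rank delta only grows during fine-tuning.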
nn.init.zeros_(self.lora_b) 48 | 49 | def forward(self, x): # shape [10000, 197, 1024] 50 | # compute original output 51 | ori_output = F.linear(x, self.weight, self.bias) 52 | lora_output = ((x @ self.lora_a) @ self.lora_b) * self.scale 53 | return ori_output + lora_output 54 | 55 | 56 | class LoRA_Attention(nn.Module): 57 | def __init__( 58 | self, 59 | dim: int, 60 | num_heads: int = 8, 61 | qkv_bias: bool = False, 62 | qk_norm: bool = False, 63 | attn_drop: float = 0., 64 | proj_drop: float = 0., 65 | norm_layer: nn.Module = nn.LayerNorm, 66 | lora_r: int = 16, 67 | lora_alpha: float = 1., 68 | ) -> None: 69 | super().__init__() 70 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 71 | self.num_heads = num_heads 72 | self.head_dim = dim // num_heads 73 | self.scale = self.head_dim ** -0.5 74 | self.fused_attn = use_fused_attn() 75 | 76 | self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=lora_r, lora_alpha=lora_alpha) 77 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 78 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 79 | self.attn_drop = nn.Dropout(attn_drop) 80 | self.proj = LoRALinear(dim, dim, r=lora_r, lora_alpha=lora_alpha) 81 | self.proj_drop = nn.Dropout(proj_drop) 82 | 83 | def forward(self, x: torch.Tensor) -> torch.Tensor: 84 | B, N, C = x.shape 85 | 86 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 87 | q, k, v = qkv.unbind(0) 88 | q, k = self.q_norm(q), self.k_norm(k) 89 | 90 | if self.fused_attn: 91 | x = F.scaled_dot_product_attention( 92 | q, k, v, 93 | dropout_p=self.attn_drop.p if self.training else 94 | 0., 95 | ) 96 | else: 97 | q = q * self.scale 98 | attn = q @ k.transpose(-2, -1) 99 | attn = attn.softmax(dim=-1) 100 | attn = self.attn_drop(attn) 101 | x = attn @ v 102 | 103 | x = x.transpose(1, 2).reshape(B, N, C) 104 | x = self.proj(x) 105 | x = self.proj_drop(x) 106 | return x 107 | 108 | 109 | 110 | 111 | def equip_model_with_lora(pfm_name, model, rank, alpha): 112 | """ 113 | Equip a PFM model with LoRA by replacing its attention layers with LoRA_Attention layers. 114 | This version also copies original attention weights into the LoRA-Attention module. 115 | 116 | Args: 117 | pfm_name (str): Name of the PFM model. 118 | model (nn.Module): The PFM model to be equipped with LoRA. 119 | rank (int): Rank of the low-rank adaptation. 120 | alpha (float): Scaling factor for the LoRA output. 121 | 122 | Returns: 123 | nn.Module: The PFM model with LoRA applied to its attention layers. 
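    Example (illustrative; ``backbone`` is one of the supported PFM backbones):
        backbone = equip_model_with_lora('uni_v1', backbone, rank=16, alpha=1.0)
        # afterwards only the lora_a / lora_b parameters require gradients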
124 | """ 125 | def copy_weights(src_attn, dst_lora_attn): 126 | with torch.no_grad(): 127 | dst_lora_attn.qkv.weight.copy_(src_attn.qkv.weight) 128 | if src_attn.qkv.bias is not None: 129 | dst_lora_attn.qkv.bias.copy_(src_attn.qkv.bias) 130 | 131 | dst_lora_attn.proj.weight.copy_(src_attn.proj.weight) 132 | if src_attn.proj.bias is not None: 133 | dst_lora_attn.proj.bias.copy_(src_attn.proj.bias) 134 | 135 | if hasattr(src_attn, 'q_norm') and hasattr(dst_lora_attn, 'q_norm') and isinstance(dst_lora_attn.q_norm, nn.LayerNorm): 136 | dst_lora_attn.q_norm.load_state_dict(src_attn.q_norm.state_dict()) 137 | if hasattr(src_attn, 'k_norm') and hasattr(dst_lora_attn, 'k_norm') and isinstance(dst_lora_attn.k_norm, nn.LayerNorm): 138 | dst_lora_attn.k_norm.load_state_dict(src_attn.k_norm.state_dict()) 139 | 140 | if pfm_name in ['uni_v1', 'uni_v2', 'virchow_v2', 'gigapath']: 141 | for name, module in model.named_modules(): 142 | if isinstance(module, timm.models.vision_transformer.Attention): 143 | lora_attn = LoRA_Attention( 144 | dim=module.qkv.in_features, 145 | num_heads=module.num_heads, 146 | qkv_bias=module.qkv.bias is not None, 147 | qk_norm=isinstance(module.q_norm, nn.LayerNorm), 148 | attn_drop=module.attn_drop.p, 149 | proj_drop=module.proj_drop.p, 150 | lora_r=rank, 151 | lora_alpha=alpha, 152 | ) 153 | 154 | copy_weights(module, lora_attn) 155 | 156 | parent_module = dict(model.named_modules())[name.rsplit(".", 1)[0]] 157 | setattr(parent_module, name.rsplit('.', 1)[-1], lora_attn) 158 | 159 | elif pfm_name == 'conch_v1_5': 160 | for name, module in model.named_modules(): 161 | if isinstance(module, Conch_V1_5_Attention): 162 | lora_attn = LoRA_Attention( 163 | dim=module.qkv.in_features, 164 | num_heads=module.num_heads, 165 | qkv_bias=module.qkv.bias is not None, 166 | qk_norm=isinstance(module.q_norm, nn.LayerNorm), 167 | attn_drop=module.attn_drop.p, 168 | proj_drop=module.proj_drop.p, 169 | lora_r=rank, 170 | lora_alpha=alpha, 171 | ) 172 | 173 | copy_weights(module, lora_attn) 174 | 175 | parent_module = dict(model.named_modules())[name.rsplit(".", 1)[0]] 176 | setattr(parent_module, name.rsplit('.', 1)[-1], lora_attn) 177 | 178 | for param in model.parameters(): 179 | param.requires_grad = False 180 | 181 | for name, param in model.named_parameters(): 182 | if 'lora_a' in name or 'lora_b' in name: 183 | param.requires_grad = True 184 | 185 | return model 186 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss Functions for Semantic Segmentation 3 | 4 | This module contains various loss functions commonly used in semantic segmentation, 5 | including Cross Entropy, Focal Loss, Dice Loss, IoU Loss, and OHEM. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from typing import Optional, List, Tuple 12 | import numpy as np 13 | 14 | 15 | class CrossEntropyLoss(nn.Module): 16 | """ 17 | Standard Cross Entropy Loss for semantic segmentation. 
18 | 19 | Args: 20 | ignore_index (int): Index to ignore in loss calculation 21 | weight (Optional[torch.Tensor]): Class weights for handling imbalanced datasets 22 | reduction (str): Reduction method ('mean', 'sum', 'none') 23 | """ 24 | 25 | def __init__(self, ignore_index: int = 255, weight: Optional[torch.Tensor] = None, 26 | reduction: str = 'mean'): 27 | super(CrossEntropyLoss, self).__init__() 28 | self.ignore_index = ignore_index 29 | self.weight = weight 30 | self.reduction = reduction 31 | 32 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 33 | """ 34 | Forward pass of Cross Entropy Loss. 35 | 36 | Args: 37 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 38 | target (torch.Tensor): Ground truth of shape (B, H, W) 39 | 40 | Returns: 41 | torch.Tensor: Loss value 42 | """ 43 | return F.cross_entropy( 44 | pred, target, 45 | weight=self.weight, 46 | ignore_index=self.ignore_index, 47 | reduction=self.reduction 48 | ) 49 | 50 | 51 | class DiceLoss(nn.Module): 52 | """ 53 | Dice Loss for semantic segmentation, particularly effective for small objects. 54 | 55 | Args: 56 | smooth (float): Smoothing factor to avoid division by zero 57 | ignore_index (int): Index to ignore in loss calculation 58 | reduction (str): Reduction method ('mean', 'sum', 'none') 59 | """ 60 | 61 | def __init__(self, smooth: float = 1e-5, ignore_index: int = 255, reduction: str = 'mean'): 62 | super(DiceLoss, self).__init__() 63 | self.smooth = smooth 64 | self.ignore_index = ignore_index 65 | self.reduction = reduction 66 | 67 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 68 | """ 69 | Forward pass of Dice Loss. 70 | 71 | Args: 72 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 73 | target (torch.Tensor): Ground truth of shape (B, H, W) 74 | 75 | Returns: 76 | torch.Tensor: Loss value 77 | """ 78 | # Convert predictions to probabilities 79 | pred = F.softmax(pred, dim=1) 80 | 81 | # One-hot encode target 82 | num_classes = pred.shape[1] 83 | target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float() 84 | 85 | # Create mask for valid pixels 86 | mask = (target != self.ignore_index).unsqueeze(1).float() 87 | pred = pred * mask 88 | target_one_hot = target_one_hot * mask 89 | 90 | # Compute Dice coefficient 91 | intersection = (pred * target_one_hot).sum(dim=(2, 3)) 92 | union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3)) 93 | 94 | dice_coeff = (2.0 * intersection + self.smooth) / (union + self.smooth) 95 | dice_loss = 1.0 - dice_coeff 96 | 97 | if self.reduction == 'mean': 98 | return dice_loss.mean() 99 | elif self.reduction == 'sum': 100 | return dice_loss.sum() 101 | else: 102 | return dice_loss 103 | 104 | 105 | class IoULoss(nn.Module): 106 | """ 107 | IoU (Intersection over Union) Loss for semantic segmentation. 108 | 109 | Args: 110 | smooth (float): Smoothing factor to avoid division by zero 111 | ignore_index (int): Index to ignore in loss calculation 112 | reduction (str): Reduction method ('mean', 'sum', 'none') 113 | """ 114 | 115 | def __init__(self, smooth: float = 1e-5, ignore_index: int = 255, reduction: str = 'mean'): 116 | super(IoULoss, self).__init__() 117 | self.smooth = smooth 118 | self.ignore_index = ignore_index 119 | self.reduction = reduction 120 | 121 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 122 | """ 123 | Forward pass of IoU Loss. 
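        Note: the soft IoU computed below is (sum(p * t) + smooth) / (sum(p) + sum(t) - sum(p * t) + smooth)
        per class, and the returned loss is 1 - IoU (averaged when reduction='mean').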
124 | 125 | Args: 126 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 127 | target (torch.Tensor): Ground truth of shape (B, H, W) 128 | 129 | Returns: 130 | torch.Tensor: Loss value 131 | """ 132 | # Convert predictions to probabilities 133 | pred = F.softmax(pred, dim=1) 134 | 135 | # One-hot encode target 136 | num_classes = pred.shape[1] 137 | target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float() 138 | 139 | # Create mask for valid pixels 140 | mask = (target != self.ignore_index).unsqueeze(1).float() 141 | pred = pred * mask 142 | target_one_hot = target_one_hot * mask 143 | 144 | # Compute IoU 145 | intersection = (pred * target_one_hot).sum(dim=(2, 3)) 146 | union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3)) - intersection 147 | 148 | iou = (intersection + self.smooth) / (union + self.smooth) 149 | iou_loss = 1.0 - iou 150 | 151 | if self.reduction == 'mean': 152 | return iou_loss.mean() 153 | elif self.reduction == 'sum': 154 | return iou_loss.sum() 155 | else: 156 | return iou_loss 157 | 158 | 159 | class OHEMLoss(nn.Module): 160 | """ 161 | Online Hard Example Mining (OHEM) Loss for focusing on hard examples. 162 | 163 | Args: 164 | thresh (float): Threshold for hard example selection 165 | min_kept (int): Minimum number of pixels to keep 166 | ignore_index (int): Index to ignore in loss calculation 167 | base_loss (str): Base loss function ('ce', 'focal') 168 | """ 169 | 170 | def __init__(self, thresh: float = 0.7, min_kept: int = 100000, 171 | ignore_index: int = 255, base_loss: str = 'ce'): 172 | super(OHEMLoss, self).__init__() 173 | self.thresh = thresh 174 | self.min_kept = min_kept 175 | self.ignore_index = ignore_index 176 | 177 | if base_loss == 'ce': 178 | self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none') 179 | else: 180 | raise ValueError(f"Unsupported base loss: {base_loss}") 181 | 182 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 183 | """ 184 | Forward pass of OHEM Loss. 185 | 186 | Args: 187 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 188 | target (torch.Tensor): Ground truth of shape (B, H, W) 189 | 190 | Returns: 191 | torch.Tensor: Loss value 192 | """ 193 | # Compute pixel-wise loss 194 | pixel_losses = self.criterion(pred, target) 195 | 196 | # Create mask for valid pixels 197 | mask = (target != self.ignore_index).float() 198 | pixel_losses = pixel_losses * mask 199 | 200 | # Sort losses in descending order 201 | sorted_losses, _ = torch.sort(pixel_losses.view(-1), descending=True) 202 | 203 | # Determine number of pixels to keep 204 | valid_pixels = mask.sum().int().item() 205 | keep_num = max(self.min_kept, int(valid_pixels * self.thresh)) 206 | keep_num = min(keep_num, valid_pixels) 207 | 208 | # Keep only the hardest examples 209 | if keep_num < valid_pixels: 210 | threshold = sorted_losses[keep_num] 211 | hard_mask = (pixel_losses >= threshold).float() 212 | return (pixel_losses * hard_mask).sum() / hard_mask.sum() 213 | else: 214 | return pixel_losses.sum() / mask.sum() 215 | 216 | 217 | def get_loss_function(loss_config: dict) -> nn.Module: 218 | """ 219 | Factory function to create loss function based on configuration. 
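    Example (illustrative config):
        criterion = get_loss_function({'type': 'dice', 'ignore_index': 255, 'dice_smooth': 1e-5})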
220 | 221 | Args: 222 | loss_config (dict): Loss configuration dictionary 223 | 224 | Returns: 225 | nn.Module: Loss function 226 | """ 227 | loss_type = loss_config.get('type', 'cross_entropy').lower() 228 | ignore_index = loss_config.get('ignore_index', 255) 229 | 230 | if loss_type == 'cross_entropy' or loss_type == 'ce': 231 | weight = loss_config.get('class_weights') 232 | if weight is not None: 233 | weight = torch.tensor(weight, dtype=torch.float32) 234 | return CrossEntropyLoss(ignore_index=ignore_index, weight=weight) 235 | 236 | elif loss_type == 'dice': 237 | smooth = loss_config.get('dice_smooth', 1e-5) 238 | return DiceLoss(smooth=smooth, ignore_index=ignore_index) 239 | 240 | elif loss_type == 'iou': 241 | smooth = loss_config.get('iou_smooth', 1e-5) 242 | return IoULoss(smooth=smooth, ignore_index=ignore_index) 243 | 244 | elif loss_type == 'ohem': 245 | thresh = loss_config.get('ohem_thresh', 0.7) 246 | min_kept = loss_config.get('ohem_min_kept', 100000) 247 | base_loss = loss_config.get('ohem_base_loss', 'ce') 248 | return OHEMLoss(thresh=thresh, min_kept=min_kept, 249 | ignore_index=ignore_index, base_loss=base_loss) 250 | 251 | else: 252 | raise ValueError(f"Unsupported loss type: {loss_type}") 253 | 254 | 255 | if __name__ == "__main__": 256 | # Test loss functions 257 | batch_size, num_classes, height, width = 2, 19, 64, 64 258 | 259 | # Create dummy data 260 | pred = torch.randn(batch_size, num_classes, height, width) 261 | target = torch.randint(0, num_classes, (batch_size, height, width)) 262 | 263 | # Test different loss functions 264 | losses = { 265 | 'CrossEntropy': CrossEntropyLoss(), 266 | 'Dice': DiceLoss(), 267 | 'IoU': IoULoss(), 268 | 'OHEM': OHEMLoss(), 269 | } 270 | 271 | for loss_name, loss_fn in losses.items(): 272 | loss_value = loss_fn(pred, target) 273 | print(f"{loss_name} Loss: {loss_value.item():.4f}") 274 | -------------------------------------------------------------------------------- /models/pfm_seg_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pathology Foundation Models (PFM) for Semantic Segmentation 3 | 4 | This module integrates multiple pathology foundation models including 5 | Gigapath, UNI v1/v2, Virchow v2, and Conch V1.5 for segmentation tasks. 6 | 7 | Author: @Toby 8 | Function: Segmentation models using PFMs (pathology foundation models) 9 | """ 10 | 11 | import copy 12 | import logging 13 | import math 14 | from os.path import join as pjoin 15 | from collections import OrderedDict 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import timm 20 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 21 | from torch.nn.modules.utils import _pair 22 | from scipy import ndimage 23 | from typing import Optional, Dict, Any, Tuple 24 | from .lora import equip_model_with_lora 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | # Vision Transformer component names for loading pretrained weights 29 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 30 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 31 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 32 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 33 | FC_0 = "MlpBlock_3/Dense_0" 34 | FC_1 = "MlpBlock_3/Dense_1" 35 | ATTENTION_NORM = "LayerNorm_0" 36 | MLP_NORM = "LayerNorm_2" 37 | 38 | 39 | def get_PFM_model(PFM_name: str, PFM_weights_path: str) -> nn.Module: 40 | """ 41 | Load and configure a Pathology Foundation Model. 
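    Example (illustrative; the weights path points to a locally downloaded checkpoint):
        backbone = get_PFM_model('uni_v1', '/path/to/uni_v1_weights.pth')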
42 | 43 | Args: 44 | PFM_name (str): Name of the PFM model 45 | PFM_weights_path (str): Path to model weights 46 | 47 | Returns: 48 | nn.Module: Configured PFM model 49 | """ 50 | if PFM_name == 'gigapath': 51 | gig_config = { 52 | "architecture": "vit_giant_patch14_dinov2", 53 | "num_classes": 0, 54 | "num_features": 1536, 55 | "global_pool": "token", 56 | "model_args": { 57 | "img_size": 224, 58 | "in_chans": 3, 59 | "patch_size": 16, 60 | "embed_dim": 1536, 61 | "depth": 40, 62 | "num_heads": 24, 63 | "init_values": 1e-05, 64 | "mlp_ratio": 5.33334, 65 | "num_classes": 0, 66 | "dynamic_img_size": True 67 | } 68 | } 69 | model = timm.create_model("vit_giant_patch14_dinov2", pretrained=False, **gig_config['model_args']) 70 | state_dict = torch.load(PFM_weights_path, map_location="cpu", weights_only=True) 71 | model.load_state_dict(state_dict, strict=True) 72 | 73 | elif PFM_name == 'uni_v1': 74 | model = timm.create_model( 75 | "vit_large_patch16_224", 76 | img_size=224, 77 | patch_size=16, 78 | init_values=1e-5, 79 | num_classes=0, 80 | dynamic_img_size=True 81 | ) 82 | model.load_state_dict(torch.load(PFM_weights_path, map_location='cpu', weights_only=True), strict=True) 83 | 84 | elif PFM_name == 'virchow_v2': 85 | from timm.layers import SwiGLUPacked 86 | virchow_v2_config = { 87 | "img_size": 224, 88 | "init_values": 1e-5, 89 | "num_classes": 0, 90 | "mlp_ratio": 5.3375, 91 | "reg_tokens": 4, 92 | "global_pool": "", 93 | "dynamic_img_size": True 94 | } 95 | model = timm.create_model( 96 | "vit_huge_patch14_224", 97 | pretrained=False, 98 | mlp_layer=SwiGLUPacked, 99 | act_layer=torch.nn.SiLU, 100 | **virchow_v2_config 101 | ) 102 | state_dict = torch.load(PFM_weights_path, map_location="cpu", weights_only=True) 103 | model.load_state_dict(state_dict, strict=True) 104 | 105 | elif PFM_name == 'conch_v1_5': 106 | try: 107 | from .conch_v1_5_config import ConchConfig 108 | from .build_conch_v1_5 import build_conch_v1_5 109 | conch_v1_5_config = ConchConfig() 110 | model = build_conch_v1_5(conch_v1_5_config, PFM_weights_path) 111 | except ImportError: 112 | raise ImportError("Conch V1.5 dependencies not found.") 113 | 114 | elif PFM_name == 'uni_v2': 115 | timm_kwargs = { 116 | 'img_size': 224, 117 | 'patch_size': 14, 118 | 'depth': 24, 119 | 'num_heads': 24, 120 | 'init_values': 1e-5, 121 | 'embed_dim': 1536, 122 | 'mlp_ratio': 2.66667 * 2, 123 | 'num_classes': 0, 124 | 'no_embed_class': True, 125 | 'mlp_layer': timm.layers.SwiGLUPacked, 126 | 'act_layer': torch.nn.SiLU, 127 | 'reg_tokens': 8, 128 | 'dynamic_img_size': True 129 | } 130 | model = timm.create_model('vit_giant_patch14_224', pretrained=False, **timm_kwargs) 131 | state_dict = torch.load(PFM_weights_path, map_location="cpu", weights_only=True) 132 | model.load_state_dict(state_dict, strict=True) 133 | 134 | else: 135 | raise ValueError(f"Unsupported PFM model: {PFM_name}") 136 | 137 | return model 138 | 139 | 140 | class Conv2dReLU(nn.Sequential): 141 | """Convolution layer with batch normalization and ReLU activation.""" 142 | 143 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int, 144 | padding: int = 0, stride: int = 1, use_batchnorm: bool = True): 145 | conv = nn.Conv2d( 146 | in_channels, 147 | out_channels, 148 | kernel_size, 149 | stride=stride, 150 | padding=padding, 151 | bias=not use_batchnorm, 152 | ) 153 | relu = nn.ReLU(inplace=True) 154 | bn = nn.BatchNorm2d(out_channels) 155 | 156 | super(Conv2dReLU, self).__init__(conv, bn, relu) 157 | 158 | 159 | class DecoderBlock(nn.Module): 160 | 
"""Decoder block for upsampling and feature fusion.""" 161 | 162 | def __init__(self, in_channels: int, out_channels: int, skip_channels: int = 0, 163 | use_batchnorm: bool = True, scale: float = 2): 164 | super().__init__() 165 | self.conv1 = Conv2dReLU( 166 | in_channels + skip_channels, 167 | out_channels, 168 | kernel_size=3, 169 | padding=1, 170 | use_batchnorm=use_batchnorm, 171 | ) 172 | self.conv2 = Conv2dReLU( 173 | out_channels, 174 | out_channels, 175 | kernel_size=3, 176 | padding=1, 177 | use_batchnorm=use_batchnorm, 178 | ) 179 | self.up = nn.UpsamplingBilinear2d(scale_factor=scale) 180 | # self.up = nn.Upsample(scale_factor=scale, mode='nearest') 181 | 182 | 183 | 184 | def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor] = None) -> torch.Tensor: 185 | x = self.up(x) 186 | if skip is not None: 187 | x = torch.cat([x, skip], dim=1) 188 | x = self.conv1(x) 189 | x = self.conv2(x) 190 | return x 191 | 192 | 193 | class SegmentationHead(nn.Sequential): 194 | """Segmentation head for final prediction.""" 195 | 196 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, upsampling: int = 1): 197 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 198 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 199 | # upsampling = nn.Upsample(scale_factor=upsampling, mode='nearest') if upsampling > 1 else nn.Identity() 200 | super().__init__(conv2d, upsampling) 201 | 202 | 203 | class DecoderCup(nn.Module): 204 | """Decoder network for feature reconstruction and upsampling.""" 205 | 206 | def __init__(self, emb_dim: int, decoder_channels: Tuple[int, ...], is_virchow_v2_or_uni_v2: bool = False): 207 | super().__init__() 208 | head_channels = 512 209 | self.decoder_channels = decoder_channels 210 | 211 | self.conv_more = Conv2dReLU( 212 | emb_dim, 213 | head_channels, 214 | kernel_size=3, 215 | padding=1, 216 | use_batchnorm=True, 217 | ) 218 | 219 | in_channels = [head_channels] + list(decoder_channels[:-1]) 220 | out_channels = decoder_channels 221 | skip_channels = [0, 0, 0, 0] # No skip connections in current implementation 222 | 223 | blocks = [ 224 | DecoderBlock(in_ch, out_ch, sk_ch) 225 | for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) 226 | ] 227 | 228 | # Special handling for Virchow v2 model 229 | if is_virchow_v2_or_uni_v2: 230 | blocks[-1] = DecoderBlock(in_channels[-1], out_channels[-1], skip_channels[-1], scale=1.75) 231 | 232 | self.blocks = nn.ModuleList(blocks) 233 | 234 | def forward(self, hidden_states: torch.Tensor, features: Optional[torch.Tensor] = None) -> torch.Tensor: 235 | """ 236 | Forward pass through decoder. 237 | 238 | Args: 239 | hidden_states (torch.Tensor): Encoded features from transformer (B, n_patch, hidden) 240 | features (Optional[torch.Tensor]): Skip connection features (not used currently) 241 | 242 | Returns: 243 | torch.Tensor: Decoded feature maps 244 | """ 245 | B, n_patch, hidden = hidden_states.size() 246 | h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) 247 | 248 | # Reshape from (B, n_patch, hidden) to (B, hidden, h, w) 249 | x = hidden_states.permute(0, 2, 1) 250 | x = x.contiguous().view(B, hidden, h, w) 251 | x = self.conv_more(x) 252 | 253 | for decoder_block in self.blocks: 254 | x = decoder_block(x, skip=None) 255 | 256 | return x 257 | 258 | 259 | class PFMSegmentationModel(nn.Module): 260 | """ 261 | Pathology Foundation Model for Semantic Segmentation. 
262 | 263 | This model integrates various pathology foundation models with a decoder 264 | network for pixel-level segmentation tasks. 265 | 266 | Args: 267 | PFM_name (str): Name of the pathology foundation model 268 | PFM_weights_path (str): Path to pretrained weights 269 | emb_dim (int): Embedding dimension of the PFM 270 | num_classes (int): Number of segmentation classes 271 | """ 272 | 273 | def __init__(self, PFM_name: str, PFM_weights_path: str, emb_dim: int, num_classes: int = 2): 274 | super(PFMSegmentationModel, self).__init__() 275 | 276 | self.num_classes = num_classes 277 | self.PFM_name = PFM_name 278 | self.decoder_channels = (256, 128, 64, 16) 279 | 280 | # Create decoder 281 | if PFM_name == 'virchow_v2' or PFM_name == 'uni_v2': 282 | self.decoder = DecoderCup(emb_dim, self.decoder_channels, is_virchow_v2_or_uni_v2 = True) 283 | else: 284 | self.decoder = DecoderCup(emb_dim, self.decoder_channels) 285 | 286 | # Create segmentation head 287 | self.segmentation_head = SegmentationHead( 288 | in_channels=self.decoder_channels[-1], 289 | out_channels=num_classes, 290 | kernel_size=3, 291 | ) 292 | 293 | # Load pathology foundation model 294 | self.pfm = get_PFM_model(PFM_name, PFM_weights_path) 295 | 296 | def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: 297 | """ 298 | Forward pass through the model. 299 | 300 | Args: 301 | x (torch.Tensor): Input images of shape (B, C, H, W) 302 | 303 | Returns: 304 | Dict[str, torch.Tensor]: Dictionary containing output predictions 305 | """ 306 | # Handle single channel images by repeating to 3 channels 307 | if x.size(1) == 1: 308 | x = x.repeat(1, 3, 1, 1) 309 | 310 | # Extract features from pathology foundation model 311 | if self.PFM_name == 'virchow_v2': 312 | # Skip first 5 tokens (CLS and register tokens) 313 | features = self.pfm(x)[:, 5:, :] 314 | elif self.PFM_name == 'conch_v1_5': 315 | # Skip CLS token 316 | features = self.pfm.trunk.forward_features(x)[:, 1:, :] 317 | elif self.PFM_name == 'uni_v2': 318 | # Skip first 9 tokens (CLS and register tokens) 319 | features = self.pfm.forward_features(x)[:, 9:, :] 320 | else: 321 | # Standard ViT - skip CLS token 322 | features = self.pfm.forward_features(x)[:, 1:, :] 323 | 324 | # Decode features 325 | decoded_features = self.decoder(features) 326 | 327 | # Generate final predictions 328 | logits = self.segmentation_head(decoded_features) 329 | 330 | return {'out': logits} 331 | 332 | def get_feature_maps(self, x: torch.Tensor) -> torch.Tensor: 333 | """ 334 | Extract intermediate feature maps for visualization. 
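    Example (illustrative; for a patch-16 backbone such as 'uni_v1', a 512x512 input
    yields 32x32 patch tokens, so the returned map has shape (1, emb_dim, 32, 32)):
        feats = model.get_feature_maps(torch.randn(1, 3, 512, 512))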
335 | 336 | Args: 337 | x (torch.Tensor): Input images 338 | 339 | Returns: 340 | torch.Tensor: Intermediate feature maps 341 | """ 342 | with torch.no_grad(): 343 | if x.size(1) == 1: 344 | x = x.repeat(1, 3, 1, 1) 345 | 346 | if self.PFM_name == 'virchow_v2': 347 | features = self.pfm(x)[:, 5:, :] 348 | elif self.PFM_name == 'conch_v1_5': 349 | features = self.pfm.trunk.forward_features(x)[:, 1:, :] 350 | elif self.PFM_name == 'uni_v2': 351 | features = self.pfm.forward_features(x)[:, 9:, :] 352 | else: 353 | features = self.pfm.forward_features(x)[:, 1:, :] 354 | 355 | # Reshape to spatial format 356 | B, n_patch, hidden = features.size() 357 | h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) 358 | features = features.permute(0, 2, 1).contiguous().view(B, hidden, h, w) 359 | 360 | return features 361 | 362 | 363 | def create_pfm_segmentation_model(model_config: Dict[str, Any]) -> PFMSegmentationModel: 364 | """ 365 | Factory function to create PFM segmentation model. 366 | 367 | Args: 368 | model_config (Dict[str, Any]): Model configuration dictionary 369 | 370 | Returns: 371 | PFMSegmentationModel: Configured PFM segmentation model 372 | """ 373 | required_keys = ['pfm_name', 'pfm_weights_path', 'emb_dim','num_classes','finetune_mode'] 374 | for key in required_keys: 375 | if key not in model_config: 376 | raise ValueError(f"Missing required configuration key: {key}") 377 | 378 | pfm_seg_model = PFMSegmentationModel( 379 | PFM_name=model_config['pfm_name'], 380 | PFM_weights_path=model_config['pfm_weights_path'], 381 | emb_dim=model_config['emb_dim'], 382 | num_classes=model_config.get('num_classes', 2)) 383 | finetune_mode = model_config['finetune_mode'].get('type') 384 | if finetune_mode == 'frozen': 385 | for param in pfm_seg_model.pfm.parameters(): 386 | param.requires_grad = False 387 | elif finetune_mode == 'lora': 388 | lora_rank = model_config['finetune_mode'].get('rank') 389 | lora_alpha = model_config['finetune_mode'].get('alpha') 390 | for param in pfm_seg_model.pfm.parameters(): 391 | param.requires_grad = False 392 | pfm_seg_model.pfm = equip_model_with_lora(model_config['pfm_name'], pfm_seg_model.pfm, rank=lora_rank, alpha=lora_alpha) 393 | elif finetune_mode == 'full': 394 | pass 395 | return pfm_seg_model 396 | 397 | 398 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Utilities for Semantic Segmentation 3 | 4 | This module contains utility functions for model management, including 5 | model creation, weight initialization, and checkpoint handling. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from typing import Dict, Any, Optional, Union 11 | import os 12 | 13 | 14 | 15 | def count_parameters(model: nn.Module) -> Dict[str, float]: 16 | """ 17 | Count the number of parameters in a model and return in millions (M). 
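    Example (illustrative):
        stats = count_parameters(model)
        print(stats['trainable_parameters(M)'])  # trainable parameter count in millions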
18 | 19 | Args: 20 | model (nn.Module): PyTorch model 21 | 22 | Returns: 23 | Dict[str, float]: Dictionary containing parameter counts in millions (M) 24 | with 2 decimal places precision 25 | """ 26 | total_params = sum(p.numel() for p in model.parameters()) / 1e6 # Convert to millions 27 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 28 | 29 | return { 30 | 'total_parameters(M)': round(total_params, 2), 31 | 'trainable_parameters(M)': round(trainable_params, 2), 32 | 'non_trainable_parameters(M)': round(total_params - trainable_params, 2) 33 | } 34 | 35 | 36 | def initialize_weights(model: nn.Module, init_type: str = 'kaiming') -> None: 37 | """ 38 | Initialize model weights with specified initialization method. 39 | 40 | Args: 41 | model (nn.Module): PyTorch model 42 | init_type (str): Initialization type ('kaiming', 'xavier', 'normal', 'zero') 43 | """ 44 | for m in model.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | if init_type == 'kaiming': 47 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 48 | elif init_type == 'xavier': 49 | nn.init.xavier_normal_(m.weight) 50 | elif init_type == 'normal': 51 | nn.init.normal_(m.weight, 0, 0.01) 52 | elif init_type == 'zero': 53 | nn.init.zeros_(m.weight) 54 | 55 | if m.bias is not None: 56 | nn.init.constant_(m.bias, 0) 57 | 58 | elif isinstance(m, nn.BatchNorm2d): 59 | nn.init.constant_(m.weight, 1) 60 | nn.init.constant_(m.bias, 0) 61 | 62 | elif isinstance(m, nn.Linear): 63 | if init_type == 'kaiming': 64 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 65 | elif init_type == 'xavier': 66 | nn.init.xavier_normal_(m.weight) 67 | elif init_type == 'normal': 68 | nn.init.normal_(m.weight, 0, 0.01) 69 | elif init_type == 'zero': 70 | nn.init.zeros_(m.weight) 71 | 72 | if m.bias is not None: 73 | nn.init.constant_(m.bias, 0) 74 | 75 | 76 | def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, 77 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], 78 | epoch: int, loss: float, metrics: Dict[str, float], 79 | checkpoint_path: str, is_best: bool = False) -> None: 80 | """ 81 | Save model checkpoint. 82 | 83 | Args: 84 | model (nn.Module): PyTorch model 85 | optimizer (torch.optim.Optimizer): Optimizer 86 | scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]): Learning rate scheduler 87 | epoch (int): Current epoch 88 | loss (float): Current loss value 89 | metrics (Dict[str, float]): Evaluation metrics 90 | checkpoint_path (str): Path to save checkpoint 91 | is_best (bool): Whether this is the best checkpoint 92 | """ 93 | checkpoint = { 94 | 'epoch': epoch, 95 | 'model_state_dict': model.state_dict(), 96 | 'optimizer_state_dict': optimizer.state_dict(), 97 | 'scheduler_state_dict': scheduler.state_dict() if scheduler else None, 98 | 'loss': loss, 99 | 'metrics': metrics 100 | } 101 | 102 | torch.save(checkpoint, checkpoint_path) 103 | 104 | if is_best: 105 | best_path = os.path.join(os.path.dirname(checkpoint_path), 'best_model.pth') 106 | torch.save(checkpoint, best_path) 107 | 108 | 109 | def load_checkpoint(model: nn.Module, checkpoint_path: str, 110 | optimizer: Optional[torch.optim.Optimizer] = None, 111 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, 112 | device: str = 'cpu') -> Dict[str, Any]: 113 | """ 114 | Load model checkpoint. 
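    Example (illustrative checkpoint path):
        info = load_checkpoint(model, 'checkpoints/best_model.pth', optimizer=optimizer, device='cuda')
        start_epoch = info['epoch'] + 1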
115 | 116 | Args: 117 | model (nn.Module): PyTorch model 118 | checkpoint_path (str): Path to checkpoint file 119 | optimizer (Optional[torch.optim.Optimizer]): Optimizer to load state 120 | scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]): Scheduler to load state 121 | device (str): Device to load checkpoint on 122 | 123 | Returns: 124 | Dict[str, Any]: Checkpoint information 125 | """ 126 | if not os.path.exists(checkpoint_path): 127 | raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") 128 | 129 | checkpoint = torch.load(checkpoint_path, map_location=device) 130 | 131 | # Load model state 132 | model.load_state_dict(checkpoint['model_state_dict']) 133 | 134 | # Load optimizer state if provided 135 | if optimizer and 'optimizer_state_dict' in checkpoint: 136 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 137 | 138 | # Load scheduler state if provided 139 | if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']: 140 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 141 | 142 | return { 143 | 'epoch': checkpoint.get('epoch', 0), 144 | 'loss': checkpoint.get('loss', float('inf')), 145 | 'metrics': checkpoint.get('metrics', {}) 146 | } 147 | 148 | 149 | def get_model_complexity(model: nn.Module, input_size: tuple = (1, 3, 512, 512)) -> Dict[str, Any]: 150 | """ 151 | Analyze model complexity including parameters, FLOPs, and memory usage. 152 | 153 | Args: 154 | model (nn.Module): PyTorch model 155 | input_size (tuple): Input tensor size 156 | 157 | Returns: 158 | Dict[str, Any]: Model complexity metrics 159 | """ 160 | # Count parameters 161 | param_stats = count_parameters(model) 162 | 163 | # Estimate model size in MB 164 | param_size = sum(p.numel() * p.element_size() for p in model.parameters()) 165 | buffer_size = sum(b.numel() * b.element_size() for b in model.buffers()) 166 | model_size_mb = (param_size + buffer_size) / (1024 ** 2) 167 | 168 | # Create dummy input for memory estimation 169 | dummy_input = torch.randn(input_size) 170 | model.eval() 171 | 172 | with torch.no_grad(): 173 | # Estimate memory usage 174 | torch.cuda.empty_cache() if torch.cuda.is_available() else None 175 | 176 | if torch.cuda.is_available(): 177 | model = model.cuda() 178 | dummy_input = dummy_input.cuda() 179 | 180 | # Measure memory before forward pass 181 | torch.cuda.synchronize() 182 | mem_before = torch.cuda.memory_allocated() 183 | 184 | # Forward pass 185 | _ = model(dummy_input) 186 | 187 | # Measure memory after forward pass 188 | torch.cuda.synchronize() 189 | mem_after = torch.cuda.memory_allocated() 190 | 191 | memory_usage_mb = (mem_after - mem_before) / (1024 ** 2) 192 | else: 193 | # CPU memory estimation (approximate) 194 | output = model(dummy_input) 195 | memory_usage_mb = sum( 196 | tensor.numel() * tensor.element_size() 197 | for tensor in [dummy_input, output['out']] 198 | ) / (1024 ** 2) 199 | 200 | return { 201 | 'parameters': param_stats, 202 | 'model_size_mb': model_size_mb, 203 | 'memory_usage_mb': memory_usage_mb, 204 | 'input_size': input_size 205 | } 206 | 207 | 208 | def convert_to_onnx(model: nn.Module, output_path: str, 209 | input_size: tuple = (1, 3, 512, 512), 210 | opset_version: int = 11) -> None: 211 | """ 212 | Convert PyTorch model to ONNX format. 
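    Example (illustrative output path; requires the onnx and onnxruntime packages):
        convert_to_onnx(model, 'pfm_seg.onnx', input_size=(1, 3, 512, 512), opset_version=11)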
213 | 214 | Args: 215 | model (nn.Module): PyTorch model 216 | output_path (str): Path to save ONNX model 217 | input_size (tuple): Input tensor size 218 | opset_version (int): ONNX opset version 219 | """ 220 | try: 221 | import onnx 222 | import onnxruntime 223 | except ImportError: 224 | raise ImportError("ONNX and ONNXRuntime are required for ONNX conversion") 225 | 226 | model.eval() 227 | dummy_input = torch.randn(input_size) 228 | 229 | # Export to ONNX 230 | torch.onnx.export( 231 | model, 232 | dummy_input, 233 | output_path, 234 | export_params=True, 235 | opset_version=opset_version, 236 | do_constant_folding=True, 237 | input_names=['input'], 238 | output_names=['output'], 239 | dynamic_axes={ 240 | 'input': {0: 'batch_size'}, 241 | 'output': {0: 'batch_size'} 242 | } 243 | ) 244 | 245 | # Verify ONNX model 246 | onnx_model = onnx.load(output_path) 247 | onnx.checker.check_model(onnx_model) 248 | 249 | print(f"Model successfully converted to ONNX: {output_path}") 250 | 251 | 252 | def print_model_summary(model: nn.Module, input_size: tuple = (1, 3, 512, 512)) -> None: 253 | """ 254 | Print a comprehensive model summary. 255 | 256 | Args: 257 | model (nn.Module): PyTorch model 258 | input_size (tuple): Input tensor size 259 | """ 260 | print("=" * 80) 261 | print("MODEL SUMMARY") 262 | print("=" * 80) 263 | 264 | # Model architecture 265 | print(f"Model: {model.__class__.__name__}") 266 | print(f"Input size: {input_size}") 267 | 268 | # Parameter statistics 269 | param_stats = count_parameters(model) 270 | print(f"Total parameters: {param_stats['total_parameters']:,}") 271 | print(f"Trainable parameters: {param_stats['trainable_parameters']:,}") 272 | print(f"Non-trainable parameters: {param_stats['non_trainable_parameters']:,}") 273 | 274 | # Model complexity 275 | complexity = get_model_complexity(model, input_size) 276 | print(f"Model size: {complexity['model_size_mb']:.2f} MB") 277 | print(f"Memory usage: {complexity['memory_usage_mb']:.2f} MB") 278 | 279 | print("=" * 80) 280 | 281 | 282 | 283 | 284 | if __name__ == "__main__": 285 | # Test model utilities 286 | from .PFM import create_pfm_model 287 | 288 | # Create a test model 289 | model = create_pfm_model(num_classes=19, img_size=512) 290 | 291 | # Print model summary 292 | print_model_summary(model) 293 | 294 | # Test other utilities 295 | initialize_weights(model, 'kaiming') 296 | print("Model weights initialized successfully") 297 | -------------------------------------------------------------------------------- /scripts/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Inference Script for Semantic Segmentation with two modes: 4 | 1. Resize-based inference (for fixed-size inputs) 5 | 2. 
Sliding window inference (for large or variable-size inputs) 6 | 7 | Features: 8 | - Supports batch processing with DataLoader 9 | - Handles both resizing and sliding window approaches 10 | - Includes visualization utilities for predictions 11 | 12 | Author: @Toby 13 | Function: Inference for semantic segmentation models 14 | """ 15 | 16 | import argparse 17 | import os 18 | import sys 19 | import yaml 20 | import torch 21 | import torch.nn.functional as F 22 | import numpy as np 23 | import cv2 24 | import json 25 | import logging 26 | from PIL import Image 27 | from typing import Dict, Any, List, Tuple 28 | import tqdm 29 | import warnings 30 | warnings.filterwarnings("ignore", category=UserWarning, module='torchvision') 31 | 32 | # Add project root to path 33 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 34 | from data.seg_dataset import JSONSegmentationDataset 35 | from data.utils import create_dataloader 36 | from data.transforms import SegmentationTransforms 37 | from utils.metrics import SegmentationMetrics 38 | from models.pfm_seg_models import create_pfm_segmentation_model 39 | from utils.visualization import apply_color_map, create_color_palette, put_text_with_bg 40 | from utils.logs import setup_logging 41 | 42 | 43 | def parse_args() -> argparse.Namespace: 44 | """Parse command line arguments for inference configuration.""" 45 | parser = argparse.ArgumentParser(description='Semantic Segmentation Inference Script') 46 | parser.add_argument('--config', type=str, 47 | default='/mnt/sdb/lxt/PFM_Seg/logs/crag_conch_v1_5_v2/config.yaml', 48 | help='Path to config YAML file') 49 | parser.add_argument('--checkpoint', type=str, 50 | default='/mnt/sdb/lxt/PFM_Seg/logs/crag_conch_v1_5_v2/checkpoints/best_model.pth', 51 | help='Path to model checkpoint') 52 | parser.add_argument('--input_json', type=str, 53 | default='/mnt/sdb/lxt/PFM_Seg/CRAG/CRAG_no_mask.json', 54 | help='Path to JSON file containing input data') 55 | parser.add_argument('--output_dir', type=str, 56 | default='/mnt/sdb/lxt/PFM_Seg/logs/crag_conch_v1_5_v2/inference_slidewindow', 57 | help='Directory to save inference results') 58 | parser.add_argument('--device', type=str, default='cuda:0', 59 | help='Device for inference (e.g., "cuda:0" or "cpu")') 60 | parser.add_argument('--input_size', type=int, default=512, 61 | help='Input size for resize or window size for sliding window') 62 | parser.add_argument('--resize_or_windowslide', type=str, 63 | choices=['resize', 'windowslide'], default='windowslide', 64 | help='Inference mode: resize or sliding window') 65 | parser.add_argument('--batch_size', type=int, default=2, 66 | help='Batch size for inference') 67 | return parser.parse_args() 68 | 69 | 70 | def load_config(config_path: str) -> Dict[str, Any]: 71 | """Load YAML configuration file.""" 72 | with open(config_path, 'r') as f: 73 | return yaml.safe_load(f) 74 | 75 | 76 | def get_device(device_str: str) -> torch.device: 77 | """Get PyTorch device from string descriptor.""" 78 | return torch.device(device_str if torch.cuda.is_available() else 'cpu') 79 | 80 | 81 | def load_model(config: Dict[str, Any], checkpoint_path: str, device: torch.device) -> torch.nn.Module: 82 | """ 83 | Load model from checkpoint with configuration. 
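    Example (illustrative paths):
        config = load_config('configs/config.yaml')
        model = load_model(config, 'checkpoints/best_model.pth', get_device('cuda:0'))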
84 | 85 | Args: 86 | config: Model configuration dictionary 87 | checkpoint_path: Path to model checkpoint 88 | device: Target device for model 89 | 90 | Returns: 91 | Loaded and configured model in evaluation mode 92 | """ 93 | model = create_pfm_segmentation_model(config['model']).to(device) 94 | checkpoint = torch.load(checkpoint_path, map_location=device) 95 | model.load_state_dict(checkpoint.get('model_state_dict', checkpoint)) 96 | model.eval() 97 | return model 98 | 99 | 100 | def postprocess(image_paths: List[str], pred_masks: List[np.ndarray], 101 | label_paths: List[str], preds_dir: str, overlap_dir: str, 102 | palette: np.ndarray) -> None: 103 | """ 104 | Post-process and visualize inference results. 105 | 106 | Args: 107 | image_paths: List of input image paths 108 | pred_masks: List of predicted masks (2D numpy arrays) 109 | label_paths: List of ground truth label paths 110 | preds_dir: Directory to save prediction masks 111 | overlap_dir: Directory to save visualization overlays 112 | palette: Color palette for visualization 113 | """ 114 | for i in range(len(image_paths)): 115 | # Process predicted mask 116 | pred_mask = pred_masks[i] 117 | 118 | # Apply color mapping 119 | pred_colored = apply_color_map(pred_mask, palette) 120 | 121 | 122 | # Save prediction mask 123 | Image.fromarray(pred_mask.astype(np.uint8)).save( 124 | os.path.join(preds_dir, os.path.basename(image_paths[i]))) 125 | 126 | if label_paths[i] is not None: 127 | # Load and process original image 128 | original_image = Image.open(image_paths[i]).convert('RGB') 129 | original_np = np.array(original_image) 130 | 131 | # Create overlays 132 | label_mask = np.array(Image.open(label_paths[i])) 133 | label_colored = apply_color_map(label_mask, palette) 134 | overlay_label = cv2.addWeighted(original_np, 0.5, label_colored, 0.5, 0) 135 | overlay_pred = cv2.addWeighted(original_np, 0.5, pred_colored, 0.5, 0) 136 | 137 | # Add annotations 138 | put_text_with_bg(overlay_label, "Label", position=(10, 40)) 139 | put_text_with_bg(overlay_pred, "Prediction", position=(10, 40)) 140 | 141 | # Combine side-by-side 142 | combined = np.concatenate([overlay_label, overlay_pred], axis=1) 143 | 144 | # Save visualization 145 | Image.fromarray(combined).save( 146 | os.path.join(overlap_dir, os.path.basename(image_paths[i]))) 147 | 148 | 149 | def resizeMode_inference(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, 150 | device: torch.device, output_dir: str, palette: np.ndarray, seg_metrics: SegmentationMetrics) -> None: 151 | """ 152 | Perform inference using resize-based approach. 
153 | 154 | Args: 155 | model: Loaded segmentation model 156 | dataloader: DataLoader providing input batches 157 | device: Target device for computation 158 | output_dir: Base directory for saving results 159 | palette: Color palette for visualization 160 | seg_metrics: Segmentation metrics object for evaluation 161 | """ 162 | preds_dir = os.path.join(output_dir, 'predictions_masks') 163 | overlap_dir = os.path.join(output_dir, 'predictions_overlays') 164 | os.makedirs(preds_dir, exist_ok=True) 165 | os.makedirs(overlap_dir, exist_ok=True) 166 | seg_metrics.reset() 167 | with torch.no_grad(): 168 | for batch in tqdm.tqdm(dataloader, desc="Inference Progress"): 169 | images = batch['image'].to(device) 170 | image_paths = batch['image_path'] 171 | label_paths = batch['label_path'] 172 | ori_sizes = batch['ori_size'] 173 | # Forward pass 174 | preds = model(images)['out'] 175 | 176 | # Process predictions 177 | pred_masks = [torch.argmax(pred, dim=0).cpu().numpy() for pred in preds] 178 | pred_masks = [cv2.resize(pred_mask, (ori_sizes[i][0], ori_sizes[i][1]), 179 | interpolation=cv2.INTER_NEAREST) for i, pred_mask in enumerate(pred_masks)] 180 | _pred_masks = [torch.tensor(mask) for mask in pred_masks] 181 | if None not in label_paths: 182 | labels = torch.stack([maskPath2tensor(path, device) for path in label_paths], dim=0) # [B, H, W] 183 | seg_metrics.update(torch.stack(_pred_masks, dim=0).to(device), labels) 184 | 185 | # Save results 186 | postprocess(image_paths, pred_masks, label_paths, preds_dir, overlap_dir, palette) 187 | return seg_metrics.compute() 188 | 189 | 190 | def slideWindow_preprocess(image: torch.Tensor, window_size: int, stride: int) -> Tuple[torch.Tensor, torch.Tensor]: 191 | """ 192 | Split image into sliding window patches. 193 | 194 | Args: 195 | image: Input tensor of shape [B, 3, H, W] 196 | window_size: Size of sliding window (square) 197 | stride: Step size between windows 198 | 199 | Returns: 200 | patches: Tensor of patches [A, 3, window_size, window_size] 201 | coords: Tensor of patch coordinates [A, 2] (x, y) 202 | """ 203 | B, C, H, W = image.shape 204 | all_patches = [] 205 | all_coords = [] 206 | 207 | # Calculate unique window positions 208 | y_positions = [] 209 | for y in range(0, H, stride): 210 | if y + window_size > H: 211 | y = H - window_size 212 | if y not in y_positions: 213 | y_positions.append(y) 214 | 215 | x_positions = [] 216 | for x in range(0, W, stride): 217 | if x + window_size > W: 218 | x = W - window_size 219 | if x not in x_positions: 220 | x_positions.append(x) 221 | 222 | # Extract patches 223 | for b in range(B): 224 | for y in y_positions: 225 | for x in x_positions: 226 | patch = image[b, :, y:y+window_size, x:x+window_size] 227 | all_patches.append(patch.unsqueeze(0)) 228 | all_coords.append([x, y]) 229 | 230 | patches = torch.cat(all_patches, dim=0) 231 | coords = torch.tensor(all_coords, dtype=torch.int) 232 | 233 | return patches, coords 234 | 235 | 236 | def slideWindow_merge(patches_pred: torch.Tensor, window_size: int, stride: int, 237 | coords: torch.Tensor, batch_size: int) -> torch.Tensor: 238 | """ 239 | Merge sliding window predictions into full-size output. 
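    For example (illustrative), a batch of 2 images tiled into 4 windows each gives
    patches_pred of shape [8, num_classes, window_size, window_size]; overlapping
    regions are accumulated and divided by a per-pixel count, i.e. averaged.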
240 | 241 | Args: 242 | patches_pred: Patch predictions [A, num_classes, window_size, window_size] 243 | window_size: Size of sliding window 244 | stride: Step size used between windows 245 | coords: Patch coordinates [A, 2] (x, y) 246 | batch_size: Original number of images in batch 247 | 248 | Returns: 249 | merged: Reconstructed predictions [B, num_classes, H, W] 250 | """ 251 | A, num_classes, _, _ = patches_pred.shape 252 | device = patches_pred.device 253 | patches_per_image = A // batch_size 254 | coords = coords.to(device) 255 | 256 | # Calculate output dimensions 257 | max_x = coords[:, 0].max().item() + window_size 258 | max_y = coords[:, 1].max().item() + window_size 259 | H, W = max_y, max_x 260 | 261 | # Initialize output buffers 262 | merged = torch.zeros((batch_size, num_classes, H, W), 263 | dtype=patches_pred.dtype, device=device) 264 | count = torch.zeros((batch_size, 1, H, W), 265 | dtype=patches_pred.dtype, device=device) 266 | 267 | # Accumulate predictions 268 | for idx in range(A): 269 | b = idx // patches_per_image 270 | x, y = coords[idx] 271 | merged[b, :, y:y+window_size, x:x+window_size] += patches_pred[idx] 272 | count[b, :, y:y+window_size, x:x+window_size] += 1 273 | 274 | # Normalize overlapping regions 275 | count = torch.clamp(count, min=1.0) 276 | merged = merged / count 277 | 278 | return merged 279 | 280 | def maskPath2tensor(mask_path: str, device: torch.device) -> torch.Tensor: 281 | """ 282 | Load a mask image from path and convert to tensor. 283 | 284 | Args: 285 | mask_path: Path to the mask image 286 | device: Target device for tensor 287 | 288 | Returns: 289 | Tensor of shape [1, H, W] with mask values 290 | """ 291 | mask = Image.open(mask_path).convert('L') 292 | mask_tensor = torch.tensor(np.array(mask), dtype=torch.long, device=device) 293 | return mask_tensor.unsqueeze(0) # Add batch dimension 294 | 295 | def slideWindowMode_inference(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, 296 | device: torch.device, output_dir: str, palette: np.ndarray, 297 | seg_metrics: SegmentationMetrics, 298 | window_size: int, overlap: float = 0.2) -> SegmentationMetrics: 299 | """ 300 | Perform inference using sliding window approach. 
301 | 302 | Args: 303 | model: Loaded segmentation model 304 | dataloader: DataLoader providing input batches 305 | device: Target device for computation 306 | output_dir: Base directory for saving results 307 | palette: Color palette for visualization 308 | seg_metrics: Segmentation metrics object for evaluation 309 | window_size: Size of sliding window 310 | overlap: Overlap ratio between windows (0-1) 311 | """ 312 | preds_dir = os.path.join(output_dir, 'predictions_masks') 313 | overlap_dir = os.path.join(output_dir, 'predictions_overlays') 314 | os.makedirs(preds_dir, exist_ok=True) 315 | os.makedirs(overlap_dir, exist_ok=True) 316 | seg_metrics.reset() 317 | with torch.no_grad(): 318 | for batch in tqdm.tqdm(dataloader,desc="Inference Progress"): 319 | images = batch['image'].to(device) 320 | batch_size = images.shape[0] 321 | stride = int(window_size * (1 - overlap)) 322 | 323 | # Process with sliding window 324 | patches, coords = slideWindow_preprocess(images, window_size, stride) 325 | image_paths = batch['image_path'] 326 | label_paths = batch['label_path'] 327 | # Predict and merge 328 | patches_preds = model(patches)['out'] 329 | preds = slideWindow_merge(patches_preds, window_size, stride, coords, batch_size) 330 | # Process results 331 | pred_masks = [torch.argmax(pred, dim=0) for pred in preds] 332 | _pred_masks = torch.stack(pred_masks, dim=0) 333 | if None not in label_paths: 334 | labels = torch.stack([maskPath2tensor(path, device) for path in label_paths], dim=0) # [B, H, W] 335 | seg_metrics.update(_pred_masks, labels) 336 | pred_masks = [pred_mask.cpu().numpy() for pred_mask in pred_masks] 337 | postprocess(image_paths, pred_masks, label_paths, preds_dir, overlap_dir, palette) 338 | return seg_metrics.compute() 339 | 340 | 341 | def run_inference(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, 342 | output_dir: str, num_classes: int, device: torch.device, 343 | resize_or_windowslide: str, input_size: int, ignore_index: int = 255) -> Dict[str, float]: 344 | """ 345 | Main inference runner that dispatches to appropriate mode. 
346 | 347 | Args: 348 | model: Loaded segmentation model 349 | dataloader: DataLoader providing input batches 350 | output_dir: Directory to save results 351 | num_classes: Number of segmentation classes 352 | device: Target device for computation 353 | resize_or_windowslide: Inference mode ('resize' or 'windowslide') 354 | input_size: Size parameter (resize dim or window size) 355 | """ 356 | palette = create_color_palette(num_classes) 357 | os.makedirs(output_dir, exist_ok=True) 358 | seg_metrics = SegmentationMetrics(num_classes, device=device, ignore_index = ignore_index) 359 | 360 | if resize_or_windowslide == 'resize': 361 | metrics = resizeMode_inference(model, dataloader, device, output_dir, palette, seg_metrics) 362 | elif resize_or_windowslide == 'windowslide': 363 | metrics = slideWindowMode_inference(model, dataloader, device, output_dir, palette, seg_metrics, input_size) 364 | return metrics 365 | 366 | def main() -> None: 367 | """Main execution function for inference script.""" 368 | args = parse_args() 369 | log_dir = args.output_dir 370 | setup_logging(log_dir) 371 | logger = logging.getLogger(__name__) 372 | config = load_config(args.config) 373 | device = get_device(args.device) 374 | 375 | logger.info("Loading model...") 376 | model = load_model(config, args.checkpoint, device) 377 | 378 | logger.info("Loading transforms...") 379 | if args.resize_or_windowslide == 'resize': 380 | test_transforms = SegmentationTransforms.get_validation_transforms( 381 | img_size=args.input_size, 382 | mean=config['model']['mean'], 383 | std=config['model']['std'] 384 | ) 385 | elif args.resize_or_windowslide == 'windowslide': 386 | test_transforms = SegmentationTransforms.get_validation_transforms( 387 | img_size=None, 388 | mean=config['model']['mean'], 389 | std=config['model']['std'] 390 | ) 391 | 392 | logger.info("Preparing dataset...") 393 | test_dataset = JSONSegmentationDataset( 394 | json_file=args.input_json, split='test', transform=test_transforms) 395 | 396 | # Adjust batch size for sliding window if needed 397 | infer_batch_size = args.batch_size 398 | if args.resize_or_windowslide == 'windowslide' and not test_dataset.fixed_size: 399 | infer_batch_size = 1 # Force batch size 1 for variable size inputs 400 | 401 | test_dataloader = create_dataloader( 402 | test_dataset, 403 | batch_size=infer_batch_size, 404 | shuffle=False, 405 | num_workers=config['system'].get('num_workers', 4), 406 | pin_memory=config['system'].get('pin_memory', True), 407 | drop_last=False 408 | ) 409 | 410 | logger.info("Running inference...") 411 | metrics = run_inference( 412 | model, test_dataloader, args.output_dir, 413 | config['model']['num_classes'], device, 414 | args.resize_or_windowslide, args.input_size, 415 | config['dataset'].get('ignore_index') 416 | ) 417 | logger.info("Inference completed successfully.") 418 | logger.info(f'Metrics:{metrics}') 419 | with open(os.path.join(args.output_dir, 'metrics.json'), 'w') as f: 420 | json.dump(metrics, f, indent=4) 421 | 422 | if __name__ == '__main__': 423 | main() -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Training Script for Semantic Segmentation 4 | 5 | This script provides a complete training pipeline for semantic segmentation models 6 | with support for various datasets, augmentations, loss functions, and optimization techniques. 
7 | 8 | Author: @Toby 9 | Function: Train a semantic segmentation model using a configuration file. 10 | """ 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | import argparse 14 | import os 15 | import sys 16 | import yaml 17 | import torch 18 | import torch.nn as nn 19 | import torch.optim as optim 20 | from torch.utils.data import DataLoader 21 | import random 22 | import albumentations as A 23 | import numpy as np 24 | import logging 25 | from typing import Dict, Any 26 | 27 | # Add project root to path 28 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 29 | 30 | from models import create_pfm_segmentation_model,count_parameters 31 | from models.losses import get_loss_function 32 | # from data.datasets import get_dataset 33 | # from data.transforms import get_transforms 34 | from data.utils import create_dataloader 35 | from data.transforms import get_transforms,SegmentationTransforms 36 | from data.seg_dataset import get_dataset 37 | from utils.trainer import SegmentationTrainer 38 | from utils.visualization import plot_training_history 39 | from utils.logs import setup_logging 40 | 41 | 42 | 43 | def parse_args(): 44 | """Parse command line arguments.""" 45 | parser = argparse.ArgumentParser(description='Training script for semantic segmentation') 46 | 47 | parser.add_argument('--config', type=str, default='/mnt/sdb/lxt/PFM_Seg/base.yaml', 48 | help='Path to configuration file') 49 | parser.add_argument('--resume', type=str, default=None, 50 | help='Path to checkpoint to resume from') 51 | parser.add_argument('--device', type=str, default='cuda:3', 52 | help='Device to use (cuda/cpu/auto)') 53 | 54 | return parser.parse_args() 55 | 56 | 57 | def load_config(config_path: str) -> Dict[str, Any]: 58 | """ 59 | Load configuration from YAML file. 60 | 61 | Args: 62 | config_path (str): Path to configuration file 63 | 64 | Returns: 65 | Dict[str, Any]: Configuration dictionary 66 | """ 67 | with open(config_path, 'r') as f: 68 | config = yaml.safe_load(f) 69 | return config 70 | 71 | 72 | 73 | 74 | def set_random_seed(seed: int): 75 | """Set random seed for reproducibility.""" 76 | random.seed(seed) 77 | np.random.seed(seed) 78 | torch.manual_seed(seed) 79 | torch.cuda.manual_seed(seed) 80 | torch.cuda.manual_seed_all(seed) 81 | torch.backends.cudnn.deterministic = True 82 | torch.backends.cudnn.benchmark = False 83 | 84 | 85 | 86 | def get_device(device_arg: str) -> str: 87 | """ 88 | Get device for training. 89 | 90 | Args: 91 | device_arg (str): Device argument from command line 92 | 93 | Returns: 94 | str: Device string 95 | """ 96 | return torch.device(device_arg) 97 | 98 | def save_config(config: Dict[str, Any], save_path: str): 99 | """ 100 | Save configuration to a YAML file. 
101 | 102 | Args: 103 | config (Dict[str, Any]): Configuration dictionary 104 | save_path (str): Path to save the configuration file 105 | """ 106 | with open(save_path, 'w') as f: 107 | yaml.dump(config, f, default_flow_style=False) 108 | print(f"Configuration saved to {save_path}") 109 | 110 | 111 | 112 | 113 | def worker_init_fn(worker_id): 114 | """Initialize worker with a random seed based on worker ID.""" 115 | seed = 42 # Base seed 116 | random.seed(seed + worker_id) 117 | np.random.seed(seed + worker_id) 118 | torch.manual_seed(seed + worker_id) 119 | torch.cuda.manual_seed(seed + worker_id) 120 | torch.cuda.manual_seed_all(seed + worker_id) 121 | 122 | 123 | def create_optimizer(model: nn.Module, config: Dict[str, Any]) -> optim.Optimizer: 124 | """ 125 | Create optimizer based on configuration. 126 | 127 | Args: 128 | model (nn.Module): Model to optimize 129 | config (Dict[str, Any]): Training configuration 130 | 131 | Returns: 132 | optim.Optimizer: Optimizer 133 | """ 134 | optimizer_config = config['training'].get('optimizer') 135 | optimizer_type = optimizer_config.get('type', 'SGD').lower() 136 | 137 | lr = config['training']['learning_rate'] 138 | weight_decay = config['training']['optimizer'].get('weight_decay', 1e-4) 139 | 140 | if optimizer_type == 'sgd': 141 | momentum = config['training'].get('momentum', 0.9) 142 | nesterov = optimizer_config.get('nesterov', True) 143 | 144 | return optim.SGD( 145 | model.parameters(), 146 | lr=lr, 147 | momentum=momentum, 148 | weight_decay=weight_decay, 149 | nesterov=nesterov 150 | ) 151 | 152 | elif optimizer_type == 'adam': 153 | betas = optimizer_config.get('betas', (0.9, 0.999)) 154 | eps = optimizer_config.get('eps', 1e-8) 155 | 156 | return optim.Adam( 157 | model.parameters(), 158 | lr=lr, 159 | betas=betas, 160 | eps=eps, 161 | weight_decay=weight_decay 162 | ) 163 | 164 | elif optimizer_type == 'adamw': 165 | betas = optimizer_config.get('betas', (0.9, 0.999)) 166 | eps = optimizer_config.get('eps', 1e-8) 167 | 168 | return optim.AdamW( 169 | model.parameters(), 170 | lr=lr, 171 | betas=betas, 172 | eps=eps, 173 | weight_decay=weight_decay 174 | ) 175 | 176 | else: 177 | raise ValueError(f"Unsupported optimizer type: {optimizer_type}") 178 | 179 | 180 | def main(): 181 | """Main training function.""" 182 | args = parse_args() 183 | 184 | # Load configuration 185 | config = load_config(args.config) 186 | 187 | # Set random seed 188 | seed = config['system'].get('seed', 42) 189 | set_random_seed(seed) 190 | generator = torch.Generator() 191 | generator.manual_seed(seed) 192 | 193 | # Get device 194 | device = get_device(args.device) 195 | 196 | # Setup logging 197 | log_dir = config['logging'].get('log_dir') 198 | experiment_name = config['logging'].get('experiment_name') 199 | log_dir = os.path.join(log_dir, experiment_name) 200 | os.makedirs(log_dir, exist_ok=True) 201 | save_config(config, os.path.join(log_dir, 'config.yaml')) 202 | setup_logging(log_dir) 203 | logger = logging.getLogger(__name__) 204 | 205 | logger.info("Starting training...") 206 | logger.info(f"Configuration file: {args.config}") 207 | logger.info(f"Device: {device}") 208 | logger.info(f"Random seed: {seed}") 209 | 210 | # Create model 211 | logger.info(f"Creating model: {config['model']['pfm_name']}...") 212 | logger.info(f"Model finetune mode: {config['model']['finetune_mode']}") 213 | model = create_pfm_segmentation_model(config['model']) 214 | model = model.to(device) 215 | 216 | # Log model information 217 | model_params_info_dict = 
count_parameters(model) 218 | logger.info(f"Model parameters info: {model_params_info_dict}") 219 | # Create datasets and data loaders 220 | logger.info("Creating datasets...") 221 | dataset_config = config['dataset'] 222 | 223 | # Training dataset 224 | # train_transforms = get_transforms(config['training']['augmentation']) 225 | train_transforms = SegmentationTransforms.get_training_transforms(img_size=config['training']['augmentation']['RandomResizedCropSize'],seed=seed,mean=config['model']['mean'],std=config['model']['std']) 226 | train_dataset = get_dataset(dataset_config, train_transforms, split='train') 227 | 228 | train_loader = create_dataloader( 229 | train_dataset, 230 | batch_size=config['training']['batch_size'], 231 | shuffle=True, 232 | generator=generator, 233 | num_workers=config['system'].get('num_workers', 4), 234 | pin_memory=config['system'].get('pin_memory', True), 235 | worker_init_fn=worker_init_fn, 236 | drop_last=False, 237 | ) 238 | 239 | # Validation dataset 240 | 241 | val_transforms = SegmentationTransforms.get_validation_transforms(img_size=config['validation']['augmentation']['ResizedSize'], mean=config['model']['mean'], std=config['model']['std']) 242 | val_dataset = get_dataset(dataset_config, val_transforms, split='val') 243 | 244 | val_loader = create_dataloader( 245 | val_dataset, 246 | batch_size=config['validation']['batch_size'], 247 | shuffle=False, 248 | num_workers=config['system'].get('num_workers', 4), 249 | pin_memory=config['system'].get('pin_memory', True), 250 | drop_last=False 251 | ) 252 | 253 | logger.info(f"Training samples: {len(train_dataset)}") 254 | logger.info(f"Validation samples: {len(val_dataset)}") 255 | logger.info(f"Training batches: {len(train_loader)}") 256 | logger.info(f"Validation batches: {len(val_loader)}") 257 | 258 | # Create loss function 259 | logger.info("Creating loss function...") 260 | criterion = get_loss_function(config['training']['loss']) 261 | criterion = criterion.to(device) 262 | 263 | # Create optimizer 264 | logger.info("Creating optimizer...") 265 | optimizer = create_optimizer(model, config) 266 | 267 | # Create trainer 268 | logger.info("Creating trainer...") 269 | trainer = SegmentationTrainer( 270 | model=model, 271 | train_loader=train_loader, 272 | val_loader=val_loader, 273 | criterion=criterion, 274 | optimizer=optimizer, 275 | config=config, 276 | device=device 277 | ) 278 | 279 | # Resume from checkpoint if specified 280 | if args.resume: 281 | logger.info(f"Resuming from checkpoint: {args.resume}") 282 | trainer.load_checkpoint(args.resume) 283 | 284 | # Start training 285 | 286 | trainer.train() 287 | 288 | # Plot training history 289 | logger.info("Generating training history plots...") 290 | training_stats = trainer.get_training_stats() 291 | 292 | history_plot_path = os.path.join(log_dir, 'training_history.png') 293 | 294 | plot_training_history( 295 | train_losses=training_stats['train_losses'], 296 | val_losses=training_stats['val_losses'], 297 | val_metrics=training_stats['val_mious'], 298 | metric_name='mIoU', 299 | save_path=history_plot_path 300 | ) 301 | 302 | logger.info("Training completed successfully!") 303 | 304 | 305 | if __name__ == "__main__": 306 | main() 307 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils package for semantic segmentation. 
3 | 4 | This package contains training utilities, evaluation metrics, visualization tools, 5 | and other helper functions for semantic segmentation. 6 | """ 7 | 8 | from .trainer import SegmentationTrainer 9 | from .logs import setup_logging 10 | from .evaluator import SegmentationEvaluator 11 | from .metrics import SegmentationMetrics, StreamingMetrics 12 | from .scheduler import ( 13 | CosineAnnealingWithWarmup, PolynomialLR, WarmupMultiStepLR, 14 | OneCycleLR, CyclicLR, get_scheduler, WarmupScheduler 15 | ) 16 | from .visualization import ( 17 | create_color_palette, tensor_to_image, 18 | apply_color_map, visualize_prediction, save_predictions, 19 | plot_training_history, plot_confusion_matrix, plot_class_metrics, 20 | create_interactive_training_dashboard, visualize_feature_maps 21 | ) 22 | 23 | __all__ = [ 24 | # Training 25 | 'SegmentationTrainer', 26 | 27 | # Evaluation 28 | 'SegmentationEvaluator', 29 | 30 | # Metrics 31 | 'SegmentationMetrics', 32 | 'StreamingMetrics', 33 | 34 | # Schedulers 35 | 'CosineAnnealingWithWarmup', 36 | 'PolynomialLR', 37 | 'WarmupMultiStepLR', 38 | 'OneCycleLR', 39 | 'CyclicLR', 40 | 'get_scheduler', 41 | 'WarmupScheduler', 42 | 43 | # Visualization 44 | # 'setup_matplotlib_for_plotting',  # not imported above 45 | 'create_color_palette', 46 | 'tensor_to_image', 47 | 'apply_color_map', 48 | 'visualize_prediction', 49 | 'save_predictions', 50 | 'plot_training_history', 51 | 'plot_confusion_matrix', 52 | 'plot_class_metrics', 53 | 'create_interactive_training_dashboard', 54 | 'visualize_feature_maps', 55 | 56 | # Utility functions 57 | 'setup_logging' 58 | ] 59 | -------------------------------------------------------------------------------- /utils/evaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Evaluator for Semantic Segmentation 3 | 4 | This module contains comprehensive evaluation utilities including 5 | model testing, inference with TTA, and detailed analysis. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | import time 14 | import os 15 | from typing import Dict, List, Optional, Tuple, Any, Union 16 | from tqdm import tqdm 17 | import json 18 | import cv2 19 | from PIL import Image 20 | 21 | from .metrics import SegmentationMetrics, StreamingMetrics 22 | from .visualization import visualize_prediction, apply_color_map, create_color_palette 23 | 24 | 25 | class SegmentationEvaluator: 26 | """ 27 | Comprehensive evaluator for semantic segmentation models.
28 | 29 | Args: 30 | model (nn.Module): Segmentation model 31 | device (str): Device for evaluation 32 | num_classes (int): Number of classes 33 | ignore_index (int): Index to ignore in evaluation 34 | class_names (Optional[List[str]]): Names of classes 35 | class_colors (Optional[List[List[int]]]): Colors for each class 36 | """ 37 | 38 | def __init__(self, model: nn.Module, device: str = 'cuda', 39 | num_classes: int = 19, ignore_index: int = 255, 40 | class_names: Optional[List[str]] = None, 41 | class_colors: Optional[List[List[int]]] = None): 42 | self.model = model 43 | self.device = device 44 | self.num_classes = num_classes 45 | self.ignore_index = ignore_index 46 | self.class_names = class_names or [f"Class {i}" for i in range(num_classes)] 47 | 48 | # Create color palette 49 | if class_colors is not None: 50 | self.color_palette = np.array(class_colors, dtype=np.uint8) 51 | else: 52 | self.color_palette = create_color_palette(num_classes) 53 | 54 | # Initialize metrics 55 | self.metrics = SegmentationMetrics(num_classes, ignore_index, device) 56 | 57 | def evaluate_dataset(self, data_loader: DataLoader, 58 | use_tta: bool = False, 59 | tta_scales: List[float] = [0.75, 1.0, 1.25], 60 | tta_flip: bool = True, 61 | save_predictions: bool = False, 62 | output_dir: str = "eval_results") -> Dict[str, Any]: 63 | """ 64 | Evaluate model on a dataset. 65 | 66 | Args: 67 | data_loader (DataLoader): Data loader for evaluation 68 | use_tta (bool): Whether to use Test Time Augmentation 69 | tta_scales (List[float]): Scales for TTA 70 | tta_flip (bool): Whether to use horizontal flip in TTA 71 | save_predictions (bool): Whether to save predictions 72 | output_dir (str): Output directory for results 73 | 74 | Returns: 75 | Dict[str, Any]: Evaluation results 76 | """ 77 | self.model.eval() 78 | self.metrics.reset() 79 | 80 | if save_predictions: 81 | os.makedirs(output_dir, exist_ok=True) 82 | pred_dir = os.path.join(output_dir, 'predictions') 83 | vis_dir = os.path.join(output_dir, 'visualizations') 84 | os.makedirs(pred_dir, exist_ok=True) 85 | os.makedirs(vis_dir, exist_ok=True) 86 | 87 | total_time = 0 88 | num_samples = 0 89 | 90 | with torch.no_grad(): 91 | pbar = tqdm(data_loader, desc='Evaluating') 92 | 93 | for batch_idx, batch in enumerate(pbar): 94 | images = batch['image'].to(self.device, non_blocking=True) 95 | labels = batch['label'].to(self.device, non_blocking=True) 96 | 97 | start_time = time.time() 98 | 99 | if use_tta: 100 | predictions = self._predict_with_tta(images, tta_scales, tta_flip) 101 | else: 102 | outputs = self.model(images) 103 | if isinstance(outputs, dict): 104 | predictions = outputs['out'] 105 | else: 106 | predictions = outputs 107 | 108 | inference_time = time.time() - start_time 109 | total_time += inference_time 110 | num_samples += len(images) 111 | 112 | # Update metrics 113 | self.metrics.update(predictions, labels) 114 | 115 | # Save predictions and visualizations 116 | if save_predictions: 117 | self._save_batch_predictions( 118 | batch, predictions, batch_idx, pred_dir, vis_dir 119 | ) 120 | 121 | # Update progress bar 122 | current_metrics = self.metrics.compute() 123 | pbar.set_postfix({ 124 | 'mIoU': f'{current_metrics["mIoU"]:.4f}', 125 | 'Pixel Acc': f'{current_metrics["Pixel_Accuracy"]:.4f}' 126 | }) 127 | 128 | # Compute final metrics 129 | final_metrics = self.metrics.compute() 130 | 131 | # Add timing information 132 | final_metrics['inference_time_per_sample'] = total_time / num_samples 133 | final_metrics['fps'] = num_samples / total_time 
134 | 135 | # Save metrics 136 | if save_predictions: 137 | self._save_evaluation_results(final_metrics, output_dir) 138 | 139 | return final_metrics 140 | 141 | def _predict_with_tta(self, images: torch.Tensor, 142 | scales: List[float], use_flip: bool) -> torch.Tensor: 143 | """ 144 | Perform prediction with Test Time Augmentation. 145 | 146 | Args: 147 | images (torch.Tensor): Input images 148 | scales (List[float]): Scale factors 149 | use_flip (bool): Whether to use horizontal flip 150 | 151 | Returns: 152 | torch.Tensor: Averaged predictions 153 | """ 154 | b, c, h, w = images.shape 155 | 156 | # Initialize aggregated predictions 157 | aggregated_preds = torch.zeros(b, self.num_classes, h, w, device=self.device) 158 | num_predictions = 0 159 | 160 | for scale in scales: 161 | # Resize images 162 | scaled_h, scaled_w = int(h * scale), int(w * scale) 163 | scaled_images = F.interpolate( 164 | images, size=(scaled_h, scaled_w), 165 | mode='bilinear', align_corners=False 166 | ) 167 | 168 | # Normal prediction 169 | outputs = self.model(scaled_images) 170 | if isinstance(outputs, dict): 171 | preds = outputs['out'] 172 | else: 173 | preds = outputs 174 | 175 | # Resize back to original size 176 | preds = F.interpolate( 177 | preds, size=(h, w), 178 | mode='bilinear', align_corners=False 179 | ) 180 | aggregated_preds += preds 181 | num_predictions += 1 182 | 183 | # Flipped prediction 184 | if use_flip: 185 | flipped_images = torch.flip(scaled_images, dims=[3]) 186 | outputs = self.model(flipped_images) 187 | if isinstance(outputs, dict): 188 | preds = outputs['out'] 189 | else: 190 | preds = outputs 191 | 192 | # Flip back and resize 193 | preds = torch.flip(preds, dims=[3]) 194 | preds = F.interpolate( 195 | preds, size=(h, w), 196 | mode='bilinear', align_corners=False 197 | ) 198 | aggregated_preds += preds 199 | num_predictions += 1 200 | 201 | # Average predictions 202 | averaged_preds = aggregated_preds / num_predictions 203 | 204 | return averaged_preds 205 | 206 | def _save_batch_predictions(self, batch: Dict[str, torch.Tensor], 207 | predictions: torch.Tensor, batch_idx: int, 208 | pred_dir: str, vis_dir: str) -> None: 209 | """Save batch predictions and visualizations.""" 210 | batch_size = len(batch['image']) 211 | 212 | for i in range(batch_size): 213 | sample_idx = batch_idx * batch_size + i 214 | 215 | # Get data 216 | image = batch['image'][i].cpu() 217 | label = batch['label'][i].cpu() 218 | pred = torch.argmax(predictions[i], dim=0).cpu() 219 | confidence = torch.max(torch.softmax(predictions[i], dim=0), dim=0)[0].cpu() 220 | 221 | # Save prediction mask 222 | pred_path = os.path.join(pred_dir, f'prediction_{sample_idx:06d}.png') 223 | pred_image = Image.fromarray(pred.numpy().astype(np.uint8)) 224 | pred_image.save(pred_path) 225 | 226 | # Save visualization 227 | vis_path = os.path.join(vis_dir, f'visualization_{sample_idx:06d}.png') 228 | visualize_prediction( 229 | image=image, 230 | label=label, 231 | prediction=pred, 232 | confidence=confidence, 233 | color_palette=self.color_palette, 234 | save_path=vis_path 235 | ) 236 | 237 | def _save_evaluation_results(self, metrics: Dict[str, float], output_dir: str) -> None: 238 | """Save evaluation results to files.""" 239 | # Save metrics as JSON 240 | metrics_path = os.path.join(output_dir, 'metrics.json') 241 | with open(metrics_path, 'w') as f: 242 | json.dump(metrics, f, indent=2) 243 | 244 | # Save detailed per-class metrics 245 | detailed_metrics = {} 246 | for i in range(self.num_classes): 247 | class_name = 
self.class_names[i] if i < len(self.class_names) else f"Class {i}" 248 | detailed_metrics[class_name] = { 249 | 'IoU': metrics.get(f'IoU_Class_{i}', 0.0), 250 | 'Dice': metrics.get(f'Dice_Class_{i}', 0.0), 251 | 'Precision': metrics.get(f'Precision_Class_{i}', 0.0), 252 | 'Recall': metrics.get(f'Recall_Class_{i}', 0.0), 253 | 'F1': metrics.get(f'F1_Class_{i}', 0.0) 254 | } 255 | 256 | detailed_path = os.path.join(output_dir, 'detailed_metrics.json') 257 | with open(detailed_path, 'w') as f: 258 | json.dump(detailed_metrics, f, indent=2) 259 | 260 | # Save confusion matrix 261 | confusion_matrix = self.metrics.get_confusion_matrix() 262 | np.save(os.path.join(output_dir, 'confusion_matrix.npy'), confusion_matrix) 263 | 264 | print(f"Evaluation results saved to: {output_dir}") 265 | 266 | def evaluate_single_image(self, image_path: str, 267 | use_tta: bool = False, 268 | save_result: bool = True, 269 | output_path: Optional[str] = None) -> Dict[str, Any]: 270 | """ 271 | Evaluate model on a single image. 272 | 273 | Args: 274 | image_path (str): Path to input image 275 | use_tta (bool): Whether to use TTA 276 | save_result (bool): Whether to save result 277 | output_path (Optional[str]): Output path for result 278 | 279 | Returns: 280 | Dict[str, Any]: Prediction results 281 | """ 282 | self.model.eval() 283 | 284 | # Load and preprocess image 285 | image = Image.open(image_path).convert('RGB') 286 | original_size = image.size 287 | 288 | # Convert to tensor (assuming normalization is handled in transforms) 289 | image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0 290 | image_tensor = image_tensor.unsqueeze(0).to(self.device) 291 | 292 | with torch.no_grad(): 293 | start_time = time.time() 294 | 295 | if use_tta: 296 | predictions = self._predict_with_tta( 297 | image_tensor, scales=[0.75, 1.0, 1.25], use_flip=True 298 | ) 299 | else: 300 | outputs = self.model(image_tensor) 301 | if isinstance(outputs, dict): 302 | predictions = outputs['out'] 303 | else: 304 | predictions = outputs 305 | 306 | inference_time = time.time() - start_time 307 | 308 | # Process predictions 309 | pred_mask = torch.argmax(predictions[0], dim=0).cpu().numpy() 310 | confidence_map = torch.max(torch.softmax(predictions[0], dim=0), dim=0)[0].cpu().numpy() 311 | 312 | # Create colored prediction 313 | colored_pred = apply_color_map(pred_mask, self.color_palette, self.ignore_index) 314 | 315 | results = { 316 | 'prediction_mask': pred_mask, 317 | 'confidence_map': confidence_map, 318 | 'colored_prediction': colored_pred, 319 | 'inference_time': inference_time, 320 | 'original_size': original_size 321 | } 322 | 323 | # Save results 324 | if save_result: 325 | if output_path is None: 326 | base_name = os.path.splitext(os.path.basename(image_path))[0] 327 | output_path = f"{base_name}_prediction.png" 328 | 329 | # Save colored prediction 330 | colored_pred_image = Image.fromarray(colored_pred) 331 | colored_pred_image.save(output_path) 332 | 333 | # Save raw prediction mask 334 | mask_path = output_path.replace('.png', '_mask.png') 335 | mask_image = Image.fromarray(pred_mask.astype(np.uint8)) 336 | mask_image.save(mask_path) 337 | 338 | print(f"Results saved to: {output_path}") 339 | 340 | return results 341 | 342 | def benchmark_model(self, data_loader: DataLoader, 343 | num_warmup: int = 10, 344 | num_iterations: int = 100) -> Dict[str, float]: 345 | """ 346 | Benchmark model performance. 
347 | 348 | Args: 349 | data_loader (DataLoader): Data loader for benchmarking 350 | num_warmup (int): Number of warmup iterations 351 | num_iterations (int): Number of benchmark iterations 352 | 353 | Returns: 354 | Dict[str, float]: Benchmark results 355 | """ 356 | self.model.eval() 357 | 358 | # Warmup 359 | print("Warming up...") 360 | with torch.no_grad(): 361 | for i, batch in enumerate(data_loader): 362 | if i >= num_warmup: 363 | break 364 | 365 | images = batch['image'].to(self.device, non_blocking=True) 366 | outputs = self.model(images) 367 | 368 | if self.device == 'cuda': 369 | torch.cuda.synchronize() 370 | 371 | # Benchmark 372 | print("Benchmarking...") 373 | times = [] 374 | 375 | with torch.no_grad(): 376 | for i, batch in enumerate(data_loader): 377 | if i >= num_iterations: 378 | break 379 | 380 | images = batch['image'].to(self.device, non_blocking=True) 381 | 382 | if self.device == 'cuda': 383 | torch.cuda.synchronize() 384 | 385 | start_time = time.time() 386 | outputs = self.model(images) 387 | 388 | if self.device == 'cuda': 389 | torch.cuda.synchronize() 390 | 391 | end_time = time.time() 392 | times.append(end_time - start_time) 393 | 394 | # Compute statistics 395 | times = np.array(times) 396 | batch_size = len(batch['image']) 397 | 398 | results = { 399 | 'avg_batch_time': float(np.mean(times)), 400 | 'std_batch_time': float(np.std(times)), 401 | 'min_batch_time': float(np.min(times)), 402 | 'max_batch_time': float(np.max(times)), 403 | 'avg_sample_time': float(np.mean(times) / batch_size), 404 | 'fps': float(batch_size / np.mean(times)), 405 | 'throughput_samples_per_sec': float(num_iterations * batch_size / np.sum(times)) 406 | } 407 | 408 | return results 409 | 410 | def analyze_failure_cases(self, data_loader: DataLoader, 411 | iou_threshold: float = 0.3, 412 | max_cases: int = 50, 413 | output_dir: str = "failure_analysis") -> List[Dict[str, Any]]: 414 | """ 415 | Analyze failure cases where model performs poorly. 
416 | 417 | Args: 418 | data_loader (DataLoader): Data loader 419 | iou_threshold (float): IoU threshold below which samples are considered failures 420 | max_cases (int): Maximum number of failure cases to analyze 421 | output_dir (str): Output directory for analysis 422 | 423 | Returns: 424 | List[Dict[str, Any]]: List of failure case information 425 | """ 426 | self.model.eval() 427 | failure_cases = [] 428 | 429 | os.makedirs(output_dir, exist_ok=True) 430 | 431 | with torch.no_grad(): 432 | pbar = tqdm(data_loader, desc='Analyzing failure cases') 433 | 434 | for batch_idx, batch in enumerate(pbar): 435 | if len(failure_cases) >= max_cases: 436 | break 437 | 438 | images = batch['image'].to(self.device, non_blocking=True) 439 | labels = batch['label'].to(self.device, non_blocking=True) 440 | 441 | outputs = self.model(images) 442 | if isinstance(outputs, dict): 443 | predictions = outputs['out'] 444 | else: 445 | predictions = outputs 446 | 447 | # Compute per-sample IoU 448 | batch_size = len(images) 449 | for i in range(batch_size): 450 | sample_pred = predictions[i:i+1] 451 | sample_label = labels[i:i+1] 452 | 453 | # Compute sample metrics 454 | sample_metrics = SegmentationMetrics(self.num_classes, self.ignore_index, self.device) 455 | sample_metrics.update(sample_pred, sample_label) 456 | metrics = sample_metrics.compute() 457 | 458 | if metrics['mIoU'] < iou_threshold: 459 | # Save failure case 460 | sample_idx = len(failure_cases) 461 | 462 | case_info = { 463 | 'sample_index': sample_idx, 464 | 'batch_index': batch_idx, 465 | 'sample_in_batch': i, 466 | 'miou': metrics['mIoU'], 467 | 'pixel_accuracy': metrics['Pixel_Accuracy'], 468 | 'image_path': batch.get('image_path', [''])[i], 469 | 'label_path': batch.get('label_path', [''])[i] 470 | } 471 | 472 | # Save visualization 473 | vis_path = os.path.join(output_dir, f'failure_case_{sample_idx:03d}.png') 474 | visualize_prediction( 475 | image=images[i].cpu(), 476 | label=labels[i].cpu(), 477 | prediction=torch.argmax(predictions[i], dim=0).cpu(), 478 | confidence=torch.max(torch.softmax(predictions[i], dim=0), dim=0)[0].cpu(), 479 | color_palette=self.color_palette, 480 | save_path=vis_path 481 | ) 482 | 483 | failure_cases.append(case_info) 484 | 485 | # Save failure case summary 486 | summary_path = os.path.join(output_dir, 'failure_cases_summary.json') 487 | with open(summary_path, 'w') as f: 488 | json.dump(failure_cases, f, indent=2) 489 | 490 | print(f"Found {len(failure_cases)} failure cases. Analysis saved to: {output_dir}") 491 | 492 | return failure_cases 493 | 494 | 495 | if __name__ == "__main__": 496 | # Test evaluator functionality 497 | print("Testing evaluator module...") 498 | 499 | # This would require actual model, data loaders, etc. 
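    # A hedged usage sketch (illustrative only; `model` and `val_loader` are assumed to
    # come from the training pipeline and are not defined in this module):
    #   evaluator = SegmentationEvaluator(model, device='cuda', num_classes=19,
    #                                     ignore_index=255)
    #   results = evaluator.evaluate_dataset(val_loader, use_tta=True,
    #                                        save_predictions=True, output_dir='eval_results')
    #   print(f"mIoU: {results['mIoU']:.4f}")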
500 | # For now, just test imports 501 | print("Evaluator module loaded successfully!") 502 | -------------------------------------------------------------------------------- /utils/logs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import sys 4 | 5 | def setup_logging(log_dir: str): 6 | """Setup logging configuration.""" 7 | os.makedirs(log_dir, exist_ok=True) 8 | 9 | level = logging.INFO 10 | 11 | logging.basicConfig( 12 | level=level, 13 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 14 | handlers=[ 15 | logging.FileHandler(os.path.join(log_dir, 'training.log')), 16 | logging.StreamHandler(sys.stdout) 17 | ] 18 | ) -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation Metrics for Semantic Segmentation 3 | 4 | This module contains comprehensive evaluation metrics including 5 | IoU, Pixel Accuracy, Dice Score, and class-wise statistics. 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | from typing import Dict, List, Optional, Tuple 11 | import torch.nn.functional as F 12 | 13 | 14 | class SegmentationMetrics: 15 | """ 16 | Comprehensive metrics for semantic segmentation evaluation. 17 | 18 | Computes: 19 | - Mean IoU (mIoU) 20 | - Pixel Accuracy 21 | - Mean Accuracy 22 | - Frequency Weighted IoU 23 | - Per-class IoU and Accuracy 24 | - Dice Score 25 | - Precision and Recall 26 | 27 | Args: 28 | num_classes (int): Number of classes 29 | ignore_index (int): Index to ignore in calculations 30 | device (str): Device for computations 31 | """ 32 | 33 | def __init__(self, num_classes: int, ignore_index: int = 255, device: str = 'cpu'): 34 | self.num_classes = num_classes 35 | self.ignore_index = ignore_index 36 | self.device = device 37 | 38 | self.reset() 39 | 40 | def reset(self): 41 | """Reset all metrics.""" 42 | self.confusion_matrix = torch.zeros( 43 | (self.num_classes, self.num_classes), 44 | dtype=torch.int64, 45 | device=self.device 46 | ) 47 | self.total_samples = 0 48 | 49 | def update(self, predictions: torch.Tensor, targets: torch.Tensor): 50 | """ 51 | Update metrics with new predictions and targets. 52 | 53 | Args: 54 | predictions (torch.Tensor): Model predictions of shape (B, C, H, W) 55 | targets (torch.Tensor): Ground truth labels of shape (B, H, W) 56 | """ 57 | # Convert predictions to class indices 58 | if predictions.dim() == 4: # (B, C, H, W) 59 | predictions = torch.argmax(predictions, dim=1) 60 | 61 | # Flatten tensors 62 | predictions = predictions.flatten() 63 | targets = targets.flatten() 64 | 65 | # Create mask for valid pixels 66 | mask = (targets != self.ignore_index) 67 | predictions = predictions[mask] 68 | targets = targets[mask] 69 | 70 | # Update confusion matrix 71 | indices = self.num_classes * targets + predictions 72 | cm_update = torch.bincount(indices, minlength=self.num_classes**2) 73 | cm_update = cm_update.reshape(self.num_classes, self.num_classes) 74 | 75 | self.confusion_matrix += cm_update.to(self.device) 76 | self.total_samples += mask.sum().item() 77 | 78 | def compute_iou(self) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Compute IoU metrics. 
81 | 82 | Returns: 83 | Tuple[torch.Tensor, torch.Tensor]: (per_class_iou, mean_iou) 84 | """ 85 | # IoU = TP / (TP + FP + FN) 86 | # TP: diagonal elements 87 | # FP: column sum - diagonal 88 | # FN: row sum - diagonal 89 | 90 | tp = torch.diag(self.confusion_matrix).float() 91 | fp = self.confusion_matrix.sum(dim=0) - tp 92 | fn = self.confusion_matrix.sum(dim=1) - tp 93 | 94 | # Avoid division by zero 95 | denominator = tp + fp + fn 96 | iou = tp / (denominator + 1e-8) 97 | 98 | # Set IoU to 0 for classes that don't appear in ground truth 99 | valid_classes = (denominator > 0) 100 | iou = iou * valid_classes.float() 101 | 102 | mean_iou = iou[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 103 | 104 | return iou, mean_iou 105 | 106 | def compute_pixel_accuracy(self) -> torch.Tensor: 107 | """ 108 | Compute pixel accuracy. 109 | 110 | Returns: 111 | torch.Tensor: Pixel accuracy 112 | """ 113 | correct_pixels = torch.diag(self.confusion_matrix).sum() 114 | total_pixels = self.confusion_matrix.sum() 115 | 116 | return correct_pixels / (total_pixels + 1e-8) 117 | 118 | def compute_mean_accuracy(self) -> torch.Tensor: 119 | """ 120 | Compute mean class accuracy. 121 | 122 | Returns: 123 | torch.Tensor: Mean accuracy 124 | """ 125 | # Class accuracy = TP / (TP + FN) 126 | tp = torch.diag(self.confusion_matrix).float() 127 | total_per_class = self.confusion_matrix.sum(dim=1).float() 128 | 129 | class_accuracy = tp / (total_per_class + 1e-8) 130 | 131 | # Only consider classes that appear in ground truth 132 | valid_classes = (total_per_class > 0) 133 | mean_accuracy = class_accuracy[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 134 | 135 | return mean_accuracy 136 | 137 | def compute_frequency_weighted_iou(self) -> torch.Tensor: 138 | """ 139 | Compute frequency weighted IoU. 140 | 141 | Returns: 142 | torch.Tensor: Frequency weighted IoU 143 | """ 144 | iou, _ = self.compute_iou() 145 | 146 | # Class frequencies 147 | class_frequencies = self.confusion_matrix.sum(dim=1).float() 148 | total_pixels = class_frequencies.sum() 149 | weights = class_frequencies / (total_pixels + 1e-8) 150 | 151 | # Weighted IoU 152 | fwiou = (weights * iou).sum() 153 | 154 | return fwiou 155 | 156 | def compute_dice_score(self) -> Tuple[torch.Tensor, torch.Tensor]: 157 | """ 158 | Compute Dice score metrics. 159 | 160 | Returns: 161 | Tuple[torch.Tensor, torch.Tensor]: (per_class_dice, mean_dice) 162 | """ 163 | # Dice = 2 * TP / (2 * TP + FP + FN) 164 | tp = torch.diag(self.confusion_matrix).float() 165 | fp = self.confusion_matrix.sum(dim=0) - tp 166 | fn = self.confusion_matrix.sum(dim=1) - tp 167 | 168 | dice = (2 * tp) / (2 * tp + fp + fn + 1e-8) 169 | 170 | # Set Dice to 0 for classes that don't appear 171 | valid_classes = ((tp + fp + fn) > 0) 172 | dice = dice * valid_classes.float() 173 | 174 | mean_dice = dice[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 175 | 176 | return dice, mean_dice 177 | 178 | def compute_precision_recall(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 179 | """ 180 | Compute precision and recall metrics. 
181 | 182 | Returns: 183 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 184 | (per_class_precision, mean_precision, per_class_recall, mean_recall) 185 | """ 186 | tp = torch.diag(self.confusion_matrix).float() 187 | fp = self.confusion_matrix.sum(dim=0) - tp 188 | fn = self.confusion_matrix.sum(dim=1) - tp 189 | 190 | # Precision = TP / (TP + FP) 191 | precision = tp / (tp + fp + 1e-8) 192 | valid_precision = ((tp + fp) > 0) 193 | precision = precision * valid_precision.float() 194 | mean_precision = precision[valid_precision].mean() if valid_precision.any() else torch.tensor(0.0) 195 | 196 | # Recall = TP / (TP + FN) 197 | recall = tp / (tp + fn + 1e-8) 198 | valid_recall = ((tp + fn) > 0) 199 | recall = recall * valid_recall.float() 200 | mean_recall = recall[valid_recall].mean() if valid_recall.any() else torch.tensor(0.0) 201 | 202 | return precision, mean_precision, recall, mean_recall 203 | 204 | def compute_f1_score(self) -> Tuple[torch.Tensor, torch.Tensor]: 205 | """ 206 | Compute F1 score metrics. 207 | 208 | Returns: 209 | Tuple[torch.Tensor, torch.Tensor]: (per_class_f1, mean_f1) 210 | """ 211 | precision, _, recall, _ = self.compute_precision_recall() 212 | 213 | f1 = 2 * (precision * recall) / (precision + recall + 1e-8) 214 | 215 | # Only consider valid classes 216 | valid_classes = ((precision + recall) > 0) 217 | f1 = f1 * valid_classes.float() 218 | mean_f1 = f1[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 219 | 220 | return f1, mean_f1 221 | 222 | def compute(self) -> Dict[str, float]: 223 | """ 224 | Compute all metrics and return as dictionary. 225 | 226 | Returns: 227 | Dict[str, float]: Dictionary containing all metrics 228 | """ 229 | # IoU metrics 230 | per_class_iou, mean_iou = self.compute_iou() 231 | 232 | # Accuracy metrics 233 | pixel_accuracy = self.compute_pixel_accuracy() 234 | mean_accuracy = self.compute_mean_accuracy() 235 | fwiou = self.compute_frequency_weighted_iou() 236 | 237 | # Dice score 238 | per_class_dice, mean_dice = self.compute_dice_score() 239 | 240 | # Precision and Recall 241 | per_class_precision, mean_precision, per_class_recall, mean_recall = self.compute_precision_recall() 242 | 243 | # F1 Score 244 | per_class_f1, mean_f1 = self.compute_f1_score() 245 | 246 | # Convert to float for logging 247 | metrics = { 248 | 'mIoU': mean_iou.item(), 249 | 'Pixel_Accuracy': pixel_accuracy.item(), 250 | 'Mean_Accuracy': mean_accuracy.item(), 251 | 'Frequency_Weighted_IoU': fwiou.item(), 252 | 'Mean_Dice': mean_dice.item(), 253 | 'Mean_Precision': mean_precision.item(), 254 | 'Mean_Recall': mean_recall.item(), 255 | 'Mean_F1': mean_f1.item() 256 | } 257 | 258 | # Add per-class metrics 259 | for i in range(self.num_classes): 260 | metrics[f'IoU_Class_{i}'] = per_class_iou[i].item() 261 | metrics[f'Dice_Class_{i}'] = per_class_dice[i].item() 262 | metrics[f'Precision_Class_{i}'] = per_class_precision[i].item() 263 | metrics[f'Recall_Class_{i}'] = per_class_recall[i].item() 264 | metrics[f'F1_Class_{i}'] = per_class_f1[i].item() 265 | 266 | return metrics 267 | 268 | def get_confusion_matrix(self) -> np.ndarray: 269 | """ 270 | Get confusion matrix as numpy array. 271 | 272 | Returns: 273 | np.ndarray: Confusion matrix 274 | """ 275 | return self.confusion_matrix.cpu().numpy() 276 | 277 | def print_class_metrics(self, class_names: Optional[List[str]] = None): 278 | """ 279 | Print detailed per-class metrics. 
280 | 281 | Args: 282 | class_names (Optional[List[str]]): Names of classes 283 | """ 284 | if class_names is None: 285 | class_names = [f"Class {i}" for i in range(self.num_classes)] 286 | 287 | # Compute metrics 288 | per_class_iou, mean_iou = self.compute_iou() 289 | per_class_dice, mean_dice = self.compute_dice_score() 290 | per_class_precision, mean_precision, per_class_recall, mean_recall = self.compute_precision_recall() 291 | per_class_f1, mean_f1 = self.compute_f1_score() 292 | 293 | print("\nPer-Class Metrics:") 294 | print("-" * 80) 295 | print(f"{'Class':<20} {'IoU':<8} {'Dice':<8} {'Precision':<10} {'Recall':<8} {'F1':<8}") 296 | print("-" * 80) 297 | 298 | for i in range(self.num_classes): 299 | class_name = class_names[i] if i < len(class_names) else f"Class {i}" 300 | print( 301 | f"{class_name:<20} " 302 | f"{per_class_iou[i]:.4f} " 303 | f"{per_class_dice[i]:.4f} " 304 | f"{per_class_precision[i]:.4f} " 305 | f"{per_class_recall[i]:.4f} " 306 | f"{per_class_f1[i]:.4f}" 307 | ) 308 | 309 | print("-" * 80) 310 | print( 311 | f"{'Mean':<20} " 312 | f"{mean_iou:.4f} " 313 | f"{mean_dice:.4f} " 314 | f"{mean_precision:.4f} " 315 | f"{mean_recall:.4f} " 316 | f"{mean_f1:.4f}" 317 | ) 318 | print("-" * 80) 319 | 320 | 321 | class StreamingMetrics: 322 | """ 323 | Streaming version of metrics for large datasets that don't fit in memory. 324 | """ 325 | 326 | def __init__(self, num_classes: int, ignore_index: int = 255): 327 | self.num_classes = num_classes 328 | self.ignore_index = ignore_index 329 | self.reset() 330 | 331 | def reset(self): 332 | """Reset metrics.""" 333 | self.tp = np.zeros(self.num_classes, dtype=np.int64) 334 | self.fp = np.zeros(self.num_classes, dtype=np.int64) 335 | self.fn = np.zeros(self.num_classes, dtype=np.int64) 336 | self.total_pixels = 0 337 | self.correct_pixels = 0 338 | 339 | def update(self, predictions: np.ndarray, targets: np.ndarray): 340 | """ 341 | Update metrics with new predictions and targets. 342 | 343 | Args: 344 | predictions (np.ndarray): Model predictions 345 | targets (np.ndarray): Ground truth labels 346 | """ 347 | # Flatten arrays 348 | predictions = predictions.flatten() 349 | targets = targets.flatten() 350 | 351 | # Create mask for valid pixels 352 | mask = (targets != self.ignore_index) 353 | predictions = predictions[mask] 354 | targets = targets[mask] 355 | 356 | # Update pixel counts 357 | self.total_pixels += len(targets) 358 | self.correct_pixels += np.sum(predictions == targets) 359 | 360 | # Update per-class counts 361 | for c in range(self.num_classes): 362 | pred_mask = (predictions == c) 363 | target_mask = (targets == c) 364 | 365 | self.tp[c] += np.sum(pred_mask & target_mask) 366 | self.fp[c] += np.sum(pred_mask & ~target_mask) 367 | self.fn[c] += np.sum(~pred_mask & target_mask) 368 | 369 | def compute_metrics(self) -> Dict[str, float]: 370 | """ 371 | Compute metrics from accumulated counts. 
372 | 373 | Returns: 374 | Dict[str, float]: Computed metrics 375 | """ 376 | # IoU 377 | iou = self.tp / (self.tp + self.fp + self.fn + 1e-8) 378 | valid_classes = (self.tp + self.fp + self.fn) > 0 379 | mean_iou = np.mean(iou[valid_classes]) if np.any(valid_classes) else 0.0 380 | 381 | # Pixel accuracy 382 | pixel_accuracy = self.correct_pixels / (self.total_pixels + 1e-8) 383 | 384 | # Mean accuracy 385 | class_accuracy = self.tp / (self.tp + self.fn + 1e-8) 386 | mean_accuracy = np.mean(class_accuracy[valid_classes]) if np.any(valid_classes) else 0.0 387 | 388 | # Precision and Recall 389 | precision = self.tp / (self.tp + self.fp + 1e-8) 390 | recall = self.tp / (self.tp + self.fn + 1e-8) 391 | 392 | valid_precision = (self.tp + self.fp) > 0 393 | valid_recall = (self.tp + self.fn) > 0 394 | 395 | mean_precision = np.mean(precision[valid_precision]) if np.any(valid_precision) else 0.0 396 | mean_recall = np.mean(recall[valid_recall]) if np.any(valid_recall) else 0.0 397 | 398 | # F1 Score 399 | f1 = 2 * (precision * recall) / (precision + recall + 1e-8) 400 | valid_f1 = (precision + recall) > 0 401 | mean_f1 = np.mean(f1[valid_f1]) if np.any(valid_f1) else 0.0 402 | 403 | return { 404 | 'mIoU': float(mean_iou), 405 | 'Pixel_Accuracy': float(pixel_accuracy), 406 | 'Mean_Accuracy': float(mean_accuracy), 407 | 'Mean_Precision': float(mean_precision), 408 | 'Mean_Recall': float(mean_recall), 409 | 'Mean_F1': float(mean_f1) 410 | } 411 | 412 | 413 | if __name__ == "__main__": 414 | # Test metrics 415 | num_classes = 19 416 | batch_size = 2 417 | height, width = 64, 64 418 | 419 | # Create dummy data 420 | predictions = torch.randn(batch_size, num_classes, height, width) 421 | targets = torch.randint(0, num_classes, (batch_size, height, width)) 422 | 423 | # Test SegmentationMetrics 424 | metrics = SegmentationMetrics(num_classes) 425 | metrics.update(predictions, targets) 426 | 427 | computed_metrics = metrics.compute() 428 | print("Computed metrics:") 429 | for key, value in computed_metrics.items(): 430 | if not key.startswith(('IoU_Class', 'Dice_Class', 'Precision_Class', 'Recall_Class', 'F1_Class')): 431 | print(f"{key}: {value:.4f}") 432 | 433 | # Test class-wise metrics 434 | metrics.print_class_metrics() 435 | 436 | print("\nMetrics module test completed successfully!") 437 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Learning Rate Schedulers for Semantic Segmentation Training 3 | 4 | This module contains various learning rate scheduling strategies 5 | including cosine annealing, polynomial decay, and warmup schedules. 6 | """ 7 | 8 | import torch 9 | import torch.optim as optim 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | import math 12 | from typing import Dict, Any, Optional, List 13 | 14 | 15 | class CosineAnnealingWithWarmup(_LRScheduler): 16 | """ 17 | Cosine Annealing learning rate scheduler with linear warmup. 
18 | 19 | Args: 20 | optimizer (torch.optim.Optimizer): Optimizer 21 | T_max (int): Maximum number of iterations/epochs 22 | eta_min (float): Minimum learning rate 23 | warmup_epochs (int): Number of warmup epochs 24 | last_epoch (int): The index of last epoch 25 | """ 26 | 27 | def __init__(self, optimizer: torch.optim.Optimizer, T_max: int, 28 | eta_min: float = 0, warmup_epochs: int = 0, last_epoch: int = -1): 29 | self.T_max = T_max 30 | self.eta_min = eta_min 31 | self.warmup_epochs = warmup_epochs 32 | super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch) 33 | 34 | def get_lr(self) -> List[float]: 35 | """Compute learning rate for current epoch.""" 36 | if self.last_epoch < self.warmup_epochs: 37 | # Linear warmup 38 | warmup_factor = self.last_epoch / self.warmup_epochs 39 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 40 | else: 41 | # Cosine annealing 42 | adjusted_epoch = self.last_epoch - self.warmup_epochs 43 | adjusted_T_max = self.T_max - self.warmup_epochs 44 | 45 | return [ 46 | self.eta_min + (base_lr - self.eta_min) * 47 | (1 + math.cos(math.pi * adjusted_epoch / adjusted_T_max)) / 2 48 | for base_lr in self.base_lrs 49 | ] 50 | 51 | 52 | class PolynomialLR(_LRScheduler): 53 | """ 54 | Polynomial learning rate decay scheduler. 55 | 56 | Args: 57 | optimizer (torch.optim.Optimizer): Optimizer 58 | total_epochs (int): Total number of training epochs 59 | power (float): Power for polynomial decay 60 | warmup_epochs (int): Number of warmup epochs 61 | last_epoch (int): The index of last epoch 62 | """ 63 | 64 | def __init__(self, optimizer: torch.optim.Optimizer, total_epochs: int, 65 | power: float = 0.9, warmup_epochs: int = 0, last_epoch: int = -1): 66 | self.total_epochs = total_epochs 67 | self.power = power 68 | self.warmup_epochs = warmup_epochs 69 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 70 | 71 | def get_lr(self) -> List[float]: 72 | """Compute learning rate for current epoch.""" 73 | if self.last_epoch < self.warmup_epochs: 74 | # Linear warmup 75 | warmup_factor = self.last_epoch / self.warmup_epochs 76 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 77 | else: 78 | # Polynomial decay 79 | factor = (1 - (self.last_epoch - self.warmup_epochs) / 80 | (self.total_epochs - self.warmup_epochs)) ** self.power 81 | return [base_lr * factor for base_lr in self.base_lrs] 82 | 83 | 84 | class WarmupMultiStepLR(_LRScheduler): 85 | """ 86 | Multi-step learning rate scheduler with warmup. 
87 | 88 | Args: 89 | optimizer (torch.optim.Optimizer): Optimizer 90 | milestones (List[int]): List of epoch indices for learning rate decay 91 | gamma (float): Multiplicative factor of learning rate decay 92 | warmup_epochs (int): Number of warmup epochs 93 | last_epoch (int): The index of last epoch 94 | """ 95 | 96 | def __init__(self, optimizer: torch.optim.Optimizer, milestones: List[int], 97 | gamma: float = 0.1, warmup_epochs: int = 0, last_epoch: int = -1): 98 | self.milestones = sorted(milestones) 99 | self.gamma = gamma 100 | self.warmup_epochs = warmup_epochs 101 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 102 | 103 | def get_lr(self) -> List[float]: 104 | """Compute learning rate for current epoch.""" 105 | if self.last_epoch < self.warmup_epochs: 106 | # Linear warmup 107 | warmup_factor = self.last_epoch / self.warmup_epochs 108 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 109 | else: 110 | # Multi-step decay 111 | adjusted_epoch = self.last_epoch - self.warmup_epochs 112 | adjusted_milestones = [m - self.warmup_epochs for m in self.milestones if m > self.warmup_epochs] 113 | 114 | decay_factor = self.gamma ** len([m for m in adjusted_milestones if m <= adjusted_epoch]) 115 | return [base_lr * decay_factor for base_lr in self.base_lrs] 116 | 117 | 118 | class OneCycleLR(_LRScheduler): 119 | """ 120 | One Cycle learning rate policy as described in "Super-Convergence". 121 | 122 | Args: 123 | optimizer (torch.optim.Optimizer): Optimizer 124 | max_lr (float): Maximum learning rate 125 | total_steps (int): Total number of training steps 126 | pct_start (float): Percentage of cycle spent increasing learning rate 127 | anneal_strategy (str): Annealing strategy ('cos' or 'linear') 128 | div_factor (float): Determines initial learning rate (max_lr / div_factor) 129 | final_div_factor (float): Determines minimum learning rate (max_lr / final_div_factor) 130 | last_epoch (int): The index of last epoch 131 | """ 132 | 133 | def __init__(self, optimizer: torch.optim.Optimizer, max_lr: float, total_steps: int, 134 | pct_start: float = 0.3, anneal_strategy: str = 'cos', 135 | div_factor: float = 25.0, final_div_factor: float = 1e4, last_epoch: int = -1): 136 | self.max_lr = max_lr 137 | self.total_steps = total_steps 138 | self.pct_start = pct_start 139 | self.anneal_strategy = anneal_strategy 140 | self.div_factor = div_factor 141 | self.final_div_factor = final_div_factor 142 | 143 | self.initial_lr = max_lr / div_factor 144 | self.min_lr = max_lr / final_div_factor 145 | 146 | super(OneCycleLR, self).__init__(optimizer, last_epoch) 147 | 148 | def get_lr(self) -> List[float]: 149 | """Compute learning rate for current step.""" 150 | step_num = self.last_epoch 151 | 152 | if step_num <= self.pct_start * self.total_steps: 153 | # Increasing phase 154 | pct = step_num / (self.pct_start * self.total_steps) 155 | lr = self.initial_lr + pct * (self.max_lr - self.initial_lr) 156 | else: 157 | # Decreasing phase 158 | pct = (step_num - self.pct_start * self.total_steps) / ((1 - self.pct_start) * self.total_steps) 159 | 160 | if self.anneal_strategy == 'cos': 161 | lr = self.min_lr + (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * pct)) / 2 162 | else: # linear 163 | lr = self.max_lr - pct * (self.max_lr - self.min_lr) 164 | 165 | return [lr for _ in self.base_lrs] 166 | 167 | 168 | class CyclicLR(_LRScheduler): 169 | """ 170 | Cyclic learning rate scheduler. 
171 | 
172 |     Args:
173 |         optimizer (torch.optim.Optimizer): Optimizer
174 |         base_lr (float): Lower boundary of learning rate
175 |         max_lr (float): Upper boundary of learning rate
176 |         step_size_up (int): Number of training iterations in increasing half of cycle
177 |         step_size_down (Optional[int]): Number of training iterations in decreasing half of cycle
178 |         mode (str): One of 'triangular', 'triangular2', 'exp_range'
179 |         gamma (float): Constant in 'exp_range' scaling function
180 |         scale_fn (Optional[callable]): Custom scaling function
181 |         scale_mode (str): 'cycle' or 'iterations'
182 |         cycle_momentum (bool): Whether to cycle momentum inversely to learning rate
183 |         base_momentum (float): Lower boundary of momentum
184 |         max_momentum (float): Upper boundary of momentum
185 |         last_epoch (int): The index of last epoch
186 |     """
187 | 
188 |     def __init__(self, optimizer: torch.optim.Optimizer, base_lr: float, max_lr: float,
189 |                  step_size_up: int = 2000, step_size_down: Optional[int] = None,
190 |                  mode: str = 'triangular', gamma: float = 1.0, scale_fn: Optional[callable] = None,
191 |                  scale_mode: str = 'cycle', cycle_momentum: bool = True,
192 |                  base_momentum: float = 0.8, max_momentum: float = 0.9, last_epoch: int = -1):
193 | 
194 |         self.base_lr = base_lr
195 |         self.max_lr = max_lr
196 |         self.step_size_up = step_size_up
197 |         self.step_size_down = step_size_down or step_size_up
198 |         self.total_size = self.step_size_up + self.step_size_down
199 |         self.mode = mode
200 |         self.gamma = gamma
201 |         self.scale_fn = scale_fn
202 |         self.scale_mode = scale_mode
203 |         self.cycle_momentum = cycle_momentum
204 |         self.base_momentum = base_momentum
205 |         self.max_momentum = max_momentum
206 | 
207 |         super(CyclicLR, self).__init__(optimizer, last_epoch)
208 | 
209 |     def get_lr(self) -> List[float]:
210 |         """Compute learning rate for current step."""
211 |         cycle = math.floor(1 + self.last_epoch / self.total_size)
212 |         x = 1 + self.last_epoch / self.total_size - cycle
213 |         step_ratio = self.step_size_up / self.total_size
214 |         if x <= step_ratio:
215 |             scale_factor = x / step_ratio  # increasing half: rises from 0 to 1
216 |         else:
217 |             scale_factor = (x - 1) / (step_ratio - 1)  # decreasing half: falls from 1 back to 0
218 | 
219 |         # Apply scaling based on mode
220 |         if self.scale_fn is None:
221 |             if self.mode == 'triangular':
222 |                 lrs = [self.base_lr + (self.max_lr - self.base_lr) * scale_factor
223 |                        for _ in self.base_lrs]
224 |             elif self.mode == 'triangular2':
225 |                 lrs = [self.base_lr + (self.max_lr - self.base_lr) * scale_factor / (2 ** (cycle - 1))
226 |                        for _ in self.base_lrs]
227 |             else:  # 'exp_range'
228 |                 lrs = [self.base_lr + (self.max_lr - self.base_lr) * scale_factor * (self.gamma ** self.last_epoch)
229 |                        for _ in self.base_lrs]
230 |         else:
231 |             lrs = [self.base_lr + (self.max_lr - self.base_lr) * scale_factor *
232 |                    self.scale_fn(self.last_epoch if self.scale_mode == 'iterations' else cycle)
233 |                    for _ in self.base_lrs]
234 | 
235 |         return lrs
236 | 
237 | 
238 | def get_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> Optional[_LRScheduler]:
239 |     """
240 |     Factory function to create learning rate scheduler based on configuration.
241 | 242 | Args: 243 | optimizer (torch.optim.Optimizer): Optimizer 244 | scheduler_config (Dict[str, Any]): Scheduler configuration 245 | 246 | Returns: 247 | Optional[_LRScheduler]: Learning rate scheduler or None 248 | """ 249 | if not scheduler_config or scheduler_config.get('type') is None: 250 | return None 251 | 252 | scheduler_type = scheduler_config['type'].lower() 253 | 254 | if scheduler_type == 'cosine': 255 | T_max = scheduler_config.get('T_max', 100) 256 | eta_min = scheduler_config.get('min_lr', 0) 257 | warmup_epochs = scheduler_config.get('warmup_epochs', 0) 258 | 259 | return CosineAnnealingWithWarmup( 260 | optimizer, T_max=T_max, eta_min=eta_min, warmup_epochs=warmup_epochs 261 | ) 262 | 263 | elif scheduler_type == 'polynomial': 264 | total_epochs = scheduler_config.get('total_epochs', 100) 265 | power = scheduler_config.get('power', 0.9) 266 | warmup_epochs = scheduler_config.get('warmup_epochs', 0) 267 | 268 | return PolynomialLR( 269 | optimizer, total_epochs=total_epochs, power=power, warmup_epochs=warmup_epochs 270 | ) 271 | 272 | elif scheduler_type == 'step': 273 | step_size = scheduler_config.get('step_size', 30) 274 | gamma = scheduler_config.get('gamma', 0.1) 275 | 276 | return optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 277 | 278 | elif scheduler_type == 'multistep': 279 | milestones = scheduler_config.get('milestones', [60, 80]) 280 | gamma = scheduler_config.get('gamma', 0.1) 281 | warmup_epochs = scheduler_config.get('warmup_epochs', 0) 282 | 283 | if warmup_epochs > 0: 284 | return WarmupMultiStepLR( 285 | optimizer, milestones=milestones, gamma=gamma, warmup_epochs=warmup_epochs 286 | ) 287 | else: 288 | return optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma) 289 | 290 | elif scheduler_type == 'exponential': 291 | gamma = scheduler_config.get('gamma', 0.95) 292 | return optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) 293 | 294 | elif scheduler_type == 'reduce_on_plateau': 295 | mode = scheduler_config.get('mode', 'min') 296 | factor = scheduler_config.get('factor', 0.5) 297 | patience = scheduler_config.get('patience', 10) 298 | threshold = scheduler_config.get('threshold', 1e-4) 299 | 300 | return optim.lr_scheduler.ReduceLROnPlateau( 301 | optimizer, mode=mode, factor=factor, patience=patience, threshold=threshold 302 | ) 303 | 304 | elif scheduler_type == 'one_cycle': 305 | max_lr = scheduler_config.get('max_lr', 0.1) 306 | total_steps = scheduler_config.get('total_steps', 1000) 307 | pct_start = scheduler_config.get('pct_start', 0.3) 308 | anneal_strategy = scheduler_config.get('anneal_strategy', 'cos') 309 | div_factor = scheduler_config.get('div_factor', 25.0) 310 | final_div_factor = scheduler_config.get('final_div_factor', 1e4) 311 | 312 | return OneCycleLR( 313 | optimizer, max_lr=max_lr, total_steps=total_steps, pct_start=pct_start, 314 | anneal_strategy=anneal_strategy, div_factor=div_factor, final_div_factor=final_div_factor 315 | ) 316 | 317 | elif scheduler_type == 'cyclic': 318 | base_lr = scheduler_config.get('base_lr', 0.001) 319 | max_lr = scheduler_config.get('max_lr', 0.006) 320 | step_size_up = scheduler_config.get('step_size_up', 2000) 321 | mode = scheduler_config.get('mode', 'triangular') 322 | gamma = scheduler_config.get('gamma', 1.0) 323 | 324 | return CyclicLR( 325 | optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=step_size_up, 326 | mode=mode, gamma=gamma 327 | ) 328 | 329 | else: 330 | raise ValueError(f"Unsupported scheduler type: {scheduler_type}") 
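# --- Illustrative usage (editor's addition, not part of the original file) ----
# A minimal sketch of driving get_scheduler() from a plain config dict. The key
# names ('type', 'T_max', 'min_lr', 'warmup_epochs') match the lookups above;
# the toy model, learning rate, and epoch count are made up for the example.
def _example_cosine_schedule():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = get_scheduler(
        optimizer,
        {'type': 'cosine', 'T_max': 50, 'min_lr': 1e-6, 'warmup_epochs': 5},
    )
    lrs = []
    for _ in range(50):
        # one epoch of training (optimizer.step() calls) would go here
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    return lrs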
331 | 332 | 333 | class WarmupScheduler: 334 | """ 335 | Wrapper for adding warmup to any scheduler. 336 | 337 | Args: 338 | optimizer (torch.optim.Optimizer): Optimizer 339 | scheduler (_LRScheduler): Base scheduler 340 | warmup_epochs (int): Number of warmup epochs 341 | warmup_method (str): Warmup method ('linear' or 'constant') 342 | warmup_factor (float): Warmup factor for 'constant' method 343 | """ 344 | 345 | def __init__(self, optimizer: torch.optim.Optimizer, scheduler: _LRScheduler, 346 | warmup_epochs: int, warmup_method: str = 'linear', warmup_factor: float = 0.1): 347 | self.optimizer = optimizer 348 | self.scheduler = scheduler 349 | self.warmup_epochs = warmup_epochs 350 | self.warmup_method = warmup_method 351 | self.warmup_factor = warmup_factor 352 | self.base_lrs = [group['lr'] for group in optimizer.param_groups] 353 | self.last_epoch = 0 354 | 355 | def step(self, epoch: Optional[int] = None): 356 | """Step the scheduler.""" 357 | if epoch is None: 358 | epoch = self.last_epoch + 1 359 | self.last_epoch = epoch 360 | 361 | if epoch < self.warmup_epochs: 362 | # Warmup phase 363 | if self.warmup_method == 'linear': 364 | warmup_factor = epoch / self.warmup_epochs 365 | else: # constant 366 | warmup_factor = self.warmup_factor 367 | 368 | for i, param_group in enumerate(self.optimizer.param_groups): 369 | param_group['lr'] = self.base_lrs[i] * warmup_factor 370 | else: 371 | # Normal scheduling 372 | self.scheduler.step(epoch - self.warmup_epochs) 373 | 374 | def state_dict(self): 375 | """Return state dict.""" 376 | return { 377 | 'scheduler': self.scheduler.state_dict(), 378 | 'last_epoch': self.last_epoch, 379 | 'warmup_epochs': self.warmup_epochs, 380 | 'warmup_method': self.warmup_method, 381 | 'warmup_factor': self.warmup_factor, 382 | 'base_lrs': self.base_lrs 383 | } 384 | 385 | def load_state_dict(self, state_dict): 386 | """Load state dict.""" 387 | self.scheduler.load_state_dict(state_dict['scheduler']) 388 | self.last_epoch = state_dict['last_epoch'] 389 | self.warmup_epochs = state_dict['warmup_epochs'] 390 | self.warmup_method = state_dict['warmup_method'] 391 | self.warmup_factor = state_dict['warmup_factor'] 392 | self.base_lrs = state_dict['base_lrs'] 393 | 394 | 395 | if __name__ == "__main__": 396 | # Test schedulers 397 | import matplotlib.pyplot as plt 398 | 399 | # Create dummy optimizer 400 | model = torch.nn.Linear(10, 1) 401 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 402 | 403 | # Test different schedulers 404 | schedulers = { 405 | 'Cosine with Warmup': CosineAnnealingWithWarmup(optimizer, T_max=100, warmup_epochs=10), 406 | 'Polynomial': PolynomialLR(optimizer, total_epochs=100, power=0.9, warmup_epochs=10), 407 | 'One Cycle': OneCycleLR(optimizer, max_lr=0.1, total_steps=100), 408 | } 409 | 410 | # Plot learning rate schedules 411 | fig, axes = plt.subplots(1, len(schedulers), figsize=(15, 5)) 412 | if len(schedulers) == 1: 413 | axes = [axes] 414 | 415 | for i, (name, scheduler) in enumerate(schedulers.items()): 416 | lrs = [] 417 | # Reset optimizer 418 | for param_group in optimizer.param_groups: 419 | param_group['lr'] = 0.1 420 | scheduler.last_epoch = -1 421 | 422 | for epoch in range(100): 423 | scheduler.step() 424 | lrs.append(optimizer.param_groups[0]['lr']) 425 | 426 | axes[i].plot(lrs) 427 | axes[i].set_title(name) 428 | axes[i].set_xlabel('Epoch') 429 | axes[i].set_ylabel('Learning Rate') 430 | axes[i].grid(True) 431 | 432 | plt.tight_layout() 433 | 
plt.savefig('/workspace/semantic_segmentation_project/scheduler_comparison.png') 434 | print("Scheduler comparison plot saved!") 435 | 436 | print("Scheduler module test completed successfully!") 437 | -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training Engine for Semantic Segmentation 3 | 4 | This module contains the main training engine with support for 5 | mixed precision training, gradient accumulation, and comprehensive logging. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.cuda.amp import GradScaler, autocast 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | from PIL import Image 14 | import albumentations as A 15 | import time 16 | import os 17 | from typing import Dict, List, Optional, Tuple, Any 18 | from tqdm import tqdm 19 | import logging 20 | from .metrics import SegmentationMetrics 21 | from .visualization import save_predictions 22 | from .scheduler import get_scheduler 23 | 24 | 25 | class SegmentationTrainer: 26 | """ 27 | Comprehensive training engine for semantic segmentation models. 28 | 29 | Args: 30 | model (nn.Module): Segmentation model 31 | train_loader (DataLoader): Training data loader 32 | val_loader (DataLoader): Validation data loader 33 | criterion (nn.Module): Loss function 34 | optimizer (torch.optim.Optimizer): Optimizer 35 | config (Dict): Training configuration 36 | device (str): Device to run training on 37 | """ 38 | 39 | def __init__(self, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, 40 | criterion: nn.Module, optimizer: torch.optim.Optimizer, 41 | config: Dict[str, Any], device: str = 'cuda'): 42 | self.model = model 43 | self.train_loader = train_loader 44 | self.val_loader = val_loader 45 | self.val_transform_mean_std = self._extract_validation_transform_mean_std() 46 | self.criterion = criterion 47 | self.optimizer = optimizer 48 | self.config = config 49 | self.device = device 50 | 51 | # Training settings 52 | self.epochs = config['training']['epochs'] 53 | self.use_amp = config['training'].get('use_amp', False) 54 | self.accumulate_grad_batches = config['training'].get('accumulate_grad_batches', 1) 55 | self.clip_grad_norm = config['training'].get('clip_grad_norm', None) 56 | 57 | # Eval settings 58 | self.eval_interval = config['validation'].get('eval_interval', 1) 59 | log_dir = config['logging'].get('log_dir') 60 | experiment_name = config['logging'].get('experiment_name') 61 | self.log_dir = os.path.join(log_dir, experiment_name) 62 | self.checkpoint_dir = os.path.join(self.log_dir, 'checkpoints') 63 | 64 | # Visualization settings 65 | self.save_predictions_flag = config['visualization'].get('save_predictions', True) 66 | self.vis_save_interval = config['visualization'].get('save_interval', 10) 67 | self.num_vis_samples = config['visualization'].get('num_vis_samples', 8) 68 | 69 | # Initialize components 70 | self.scaler = GradScaler() if self.use_amp else None 71 | self.scheduler = get_scheduler(optimizer, config['training'].get('scheduler', None)) 72 | 73 | # Initialize metrics based on task type 74 | num_classes = config['dataset']['num_classes'] 75 | ignore_index = config['dataset'].get('ignore_index', 255) 76 | class_names = config['dataset'].get('class_names', None) 77 | self.metrics = SegmentationMetrics( 78 | num_classes=num_classes, 79 | ignore_index=ignore_index 80 | ) 81 | 82 | # Tracking variables 83 | 
self.current_epoch = 0
 84 |         self.global_step = 0
 85 |         self.best_miou = 0.0
 86 |         self.train_losses = []
 87 |         self.val_losses = []
 88 |         self.val_mious = []
 89 | 
 90 |         # Create checkpoint directory first (this also creates the log dir needed by the file handler)
 91 |         os.makedirs(self.checkpoint_dir, exist_ok=True)
 92 | 
 93 |         # Setup logging
 94 |         self._setup_logging()
 95 | 
 96 |     def _extract_validation_transform_mean_std(self):
 97 |         """
 98 |         Extract mean and standard deviation from the validation transforms if A.Normalize is present.
 99 | 
100 |         Returns:
101 |             Optional[Tuple[Tuple[float, ...], Tuple[float, ...]]]:
102 |                 A tuple (mean_tuple, std_tuple), e.g. ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
103 |                 or None if no Normalize transform is found.
104 |         """
105 |         validation_transform = self.val_loader.dataset.transform
106 |         for t in validation_transform.transforms:
107 |             if isinstance(t, A.Normalize):
108 |                 # directly return the tuple-of-floats structure
109 |                 return t.mean, t.std
110 |         return None
111 | 
112 |     def _denormalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
113 |         """
114 |         Denormalize a tensor using the mean and std from the validation transforms.
115 | 
116 |         Args:
117 |             tensor (torch.Tensor): Input tensor [C, H, W] or [B, C, H, W]
118 | 
119 |         Returns:
120 |             torch.Tensor: Denormalized tensor
121 |         """
122 |         if self.val_transform_mean_std is None:
123 |             return tensor
124 | 
125 |         mean, std = list(self.val_transform_mean_std[0]), list(self.val_transform_mean_std[1])  # each is a sequence like [0.485, 0.456, 0.406]
126 | 
127 |         mean = torch.tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
128 |         std = torch.tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
129 | 
130 |         return (tensor * std + mean) * 255.0
131 | 
132 | 
133 |     def _setup_logging(self):
134 |         """Setup logging configuration."""
135 | 
136 |         logging.basicConfig(
137 |             level=logging.INFO,
138 |             format='%(asctime)s - %(levelname)s - %(message)s',
139 |             handlers=[
140 |                 logging.FileHandler(os.path.join(self.log_dir, 'training.log')),
141 |                 logging.StreamHandler()
142 |             ]
143 |         )
144 |         self.logger = logging.getLogger(__name__)
145 | 
146 |     def train_epoch(self) -> float:
147 |         """
148 |         Train for one epoch.
149 | 150 | Returns: 151 | float: Average training loss for the epoch 152 | """ 153 | self.model.train() 154 | total_loss = 0.0 155 | num_batches = len(self.train_loader) 156 | 157 | # Progress bar 158 | pbar = tqdm(self.train_loader, desc=f'Epoch {self.current_epoch + 1}/{self.epochs}') 159 | 160 | self.optimizer.zero_grad() 161 | 162 | for batch_idx, batch in enumerate(pbar): 163 | # Move data to device 164 | images = batch['image'].to(self.device, non_blocking=True) 165 | labels = batch['label'].to(self.device, non_blocking=True) 166 | 167 | # Forward pass with optional mixed precision 168 | if self.use_amp: 169 | with autocast(): 170 | outputs = self.model(images) 171 | main_loss = self.criterion(outputs['out'], labels.long()) 172 | loss = main_loss 173 | 174 | # Normalize loss for gradient accumulation 175 | loss = loss / self.accumulate_grad_batches 176 | else: 177 | outputs = self.model(images) 178 | main_loss = self.criterion(outputs['out'], labels.long()) 179 | loss = main_loss 180 | 181 | loss = loss / self.accumulate_grad_batches 182 | 183 | # Backward pass 184 | if self.use_amp: 185 | self.scaler.scale(loss).backward() 186 | else: 187 | loss.backward() 188 | 189 | # Gradient accumulation and optimization step 190 | if (batch_idx + 1) % self.accumulate_grad_batches == 0: 191 | if self.use_amp: 192 | # Gradient clipping 193 | if self.clip_grad_norm: 194 | self.scaler.unscale_(self.optimizer) 195 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm) 196 | 197 | self.scaler.step(self.optimizer) 198 | self.scaler.update() 199 | else: 200 | # Gradient clipping 201 | if self.clip_grad_norm: 202 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm) 203 | 204 | self.optimizer.step() 205 | 206 | self.optimizer.zero_grad() 207 | 208 | # Update learning rate 209 | if self.scheduler and hasattr(self.scheduler, 'step_batch'): 210 | self.scheduler.step_batch(self.global_step) 211 | 212 | self.global_step += 1 213 | 214 | # Update metrics 215 | total_loss += loss.item() * self.accumulate_grad_batches 216 | 217 | # Update progress bar 218 | pbar.set_postfix({ 219 | 'Loss': f'{loss.item() * self.accumulate_grad_batches:.4f}', 220 | 'LR': f'{self.optimizer.param_groups[0]["lr"]:.6f}' 221 | }) 222 | 223 | 224 | # Update learning rate scheduler (epoch-based) 225 | if self.scheduler and hasattr(self.scheduler, 'step'): 226 | self.scheduler.step() 227 | 228 | avg_loss = total_loss / num_batches 229 | self.train_losses.append(avg_loss) 230 | 231 | return avg_loss 232 | 233 | def validate(self) -> Tuple[float, Dict[str, float]]: 234 | """ 235 | Validate the model. 
236 | 
237 |         Returns:
238 |             Tuple[float, Dict[str, float]]: Average validation loss and metrics
239 |         """
240 |         self.model.eval()
241 |         total_loss = 0.0
242 |         num_batches = len(self.val_loader)
243 | 
244 |         # Reset metrics
245 |         self.metrics.reset()
246 | 
247 |         # Collect predictions for visualization
248 |         vis_data = []
249 | 
250 |         with torch.no_grad():
251 |             pbar = tqdm(self.val_loader, desc='Validation')
252 | 
253 |             for batch_idx, batch in enumerate(pbar):
254 |                 # Move data to device
255 |                 images = batch['image'].to(self.device, non_blocking=True)
256 |                 labels = batch['label'].to(self.device, non_blocking=True)
257 | 
258 |                 # Forward pass
259 |                 if self.use_amp:
260 |                     with autocast():
261 |                         outputs = self.model(images)
262 |                         if isinstance(outputs, dict):
263 |                             predictions = outputs['out']
264 |                         else:
265 |                             predictions = outputs
266 |                         loss = self.criterion(predictions, labels.long())
267 |                 else:
268 |                     outputs = self.model(images)
269 |                     if isinstance(outputs, dict):
270 |                         predictions = outputs['out']
271 |                     else:
272 |                         predictions = outputs
273 |                     loss = self.criterion(predictions, labels.long())
274 | 
275 |                 # Update metrics
276 |                 total_loss += loss.item()
277 |                 self.metrics.update(predictions, labels)
278 | 
279 |                 # Collect data for visualization
280 |                 if (batch_idx < self.num_vis_samples // len(images) + 1 and
281 |                         len(vis_data) < self.num_vis_samples):
282 |                     batch_size = min(self.num_vis_samples - len(vis_data), len(images))
283 |                     for i in range(batch_size):
284 |                         vis_data.append({
285 |                             'image': self._denormalize_tensor(images[i].cpu()),
286 |                             'label': labels[i].cpu(),
287 |                             'prediction': torch.argmax(predictions[i], dim=0).cpu()
288 |                         })
289 | 
290 |                 # Update progress bar
291 |                 pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
292 | 
293 |         # Compute final metrics
294 |         avg_loss = total_loss / num_batches
295 |         metrics_dict = self.metrics.compute()
296 | 
297 |         self.val_losses.append(avg_loss)
298 |         # Use mDice for medical tasks, mIoU for standard tasks
299 |         primary_metric = metrics_dict.get('mDice', metrics_dict.get('mIoU', 0.0))
300 |         self.val_mious.append(primary_metric)
301 | 
302 |         # Save predictions visualization
303 |         if (self.save_predictions_flag and
304 |                 (self.current_epoch + 1) % self.vis_save_interval == 0):
305 |             vis_dir = os.path.join(self.log_dir, 'visualizations')
306 |             os.makedirs(vis_dir, exist_ok=True)
307 | 
308 |             save_predictions(
309 |                 vis_data[:self.num_vis_samples],
310 |                 save_dir=vis_dir,
311 |                 epoch=self.current_epoch + 1,
312 |                 class_colors=self.config['dataset'].get('class_colors', None)
313 |             )
314 | 
315 |         return avg_loss, metrics_dict
316 | 
317 |     def save_checkpoint(self, metrics: Dict[str, float], is_best: bool = False, only_best: bool = True):
318 |         """
319 |         Save model checkpoint.
320 | 321 | Args: 322 | metrics (Dict[str, float]): Current metrics 323 | is_best (bool): Whether this is the best checkpoint 324 | """ 325 | checkpoint = { 326 | 'epoch': self.current_epoch + 1, 327 | 'model_state_dict': self.model.state_dict(), 328 | 'optimizer_state_dict': self.optimizer.state_dict(), 329 | 'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None, 330 | 'scaler_state_dict': self.scaler.state_dict() if self.scaler else None, 331 | 'metrics': metrics, 332 | 'config': self.config, 333 | 'train_losses': self.train_losses, 334 | 'val_losses': self.val_losses, 335 | 'val_mious': self.val_mious 336 | } 337 | # Save best checkpoint 338 | if is_best: 339 | best_path = os.path.join(self.checkpoint_dir, 'best_model.pth') 340 | torch.save(checkpoint, best_path) 341 | self.logger.info(f'New best model saved with mIoU: {metrics["mIoU"]:.4f}') 342 | else: 343 | if only_best: 344 | self.logger.info("Skipping regular checkpoint save as only_best is True.") 345 | return 346 | # Save regular checkpoint 347 | checkpoint_path = os.path.join( 348 | self.checkpoint_dir, 349 | f'checkpoint_epoch_{self.current_epoch + 1:03d}.pth' 350 | ) 351 | torch.save(checkpoint, checkpoint_path) 352 | self.logger.info(f'Checkpoint saved: {checkpoint_path}') 353 | 354 | def load_checkpoint(self, checkpoint_path: str): 355 | """ 356 | Load model checkpoint. 357 | 358 | Args: 359 | checkpoint_path (str): Path to checkpoint file 360 | """ 361 | if not os.path.exists(checkpoint_path): 362 | self.logger.warning(f'Checkpoint not found: {checkpoint_path}') 363 | return 364 | 365 | checkpoint = torch.load(checkpoint_path, map_location=self.device) 366 | 367 | # Load model state 368 | self.model.load_state_dict(checkpoint['model_state_dict']) 369 | 370 | # Load optimizer state 371 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 372 | 373 | # Load scheduler state 374 | if self.scheduler and checkpoint.get('scheduler_state_dict'): 375 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 376 | 377 | # Load scaler state 378 | if self.scaler and checkpoint.get('scaler_state_dict'): 379 | self.scaler.load_state_dict(checkpoint['scaler_state_dict']) 380 | 381 | # Load training state 382 | self.current_epoch = checkpoint['epoch'] 383 | self.train_losses = checkpoint.get('train_losses', []) 384 | self.val_losses = checkpoint.get('val_losses', []) 385 | self.val_mious = checkpoint.get('val_mious', []) 386 | 387 | if self.val_mious: 388 | self.best_miou = max(self.val_mious) 389 | 390 | self.logger.info(f'Checkpoint loaded: {checkpoint_path}') 391 | self.logger.info(f'Resuming from epoch {self.current_epoch}') 392 | 393 | def train(self): 394 | """Main training loop.""" 395 | self.logger.info('Starting training...') 396 | self.logger.info(f'Total epochs: {self.epochs}') 397 | self.logger.info(f'Device: {self.device}') 398 | self.logger.info(f'Mixed precision: {self.use_amp}') 399 | self.logger.info(f'Gradient accumulation steps: {self.accumulate_grad_batches}') 400 | 401 | start_time = time.time() 402 | 403 | for epoch in range(self.current_epoch, self.epochs): 404 | self.current_epoch = epoch 405 | 406 | # Training 407 | train_loss = self.train_epoch() 408 | 409 | # Validation 410 | if (epoch + 1) % self.eval_interval == 0: 411 | val_loss, val_metrics = self.validate() 412 | 413 | # Check if this is the best model 414 | primary_metric_name = 'mIoU' 415 | current_metric = val_metrics.get(primary_metric_name, val_metrics.get('mIoU', 0.0)) 416 | is_best = current_metric 
> self.best_miou
417 |                 if is_best:
418 |                     self.best_miou = current_metric
419 | 
420 |                 # Logging
421 |                 self.logger.info(
422 |                     f'Epoch {epoch + 1}/{self.epochs} - '
423 |                     f'Train Loss: {train_loss:.4f}, '
424 |                     f'Val Loss: {val_loss:.4f}, '
425 |                     f'Val {primary_metric_name}: {current_metric:.4f}, '
426 |                     f'Best {primary_metric_name}: {self.best_miou:.4f}'
427 |                 )
428 | 
429 |                 # Print detailed metrics
430 |                 for metric_name, metric_value in val_metrics.items():
431 |                     if metric_name != 'mIoU':
432 |                         self.logger.info(f'{metric_name}: {metric_value:.4f}')
433 | 
434 |                 # Checkpoint only on validation epochs, where val_metrics and is_best are defined
435 |                 self.save_checkpoint(val_metrics, is_best=is_best, only_best=True)
436 |             else:
437 |                 self.logger.info(
438 |                     f'Epoch {epoch + 1}/{self.epochs} - Train Loss: {train_loss:.4f}'
439 |                 )
440 | 
441 |         total_time = time.time() - start_time
442 |         self.logger.info(f'Training completed in {total_time / 3600:.2f} hours')
443 |         self.logger.info(f'Best mIoU: {self.best_miou:.4f}')
444 | 
445 |     def get_training_stats(self) -> Dict[str, List[float]]:
446 |         """
447 |         Get training statistics.
448 | 
449 |         Returns:
450 |             Dict[str, List[float]]: Training statistics
451 |         """
452 |         return {
453 |             'train_losses': self.train_losses,
454 |             'val_losses': self.val_losses,
455 |             'val_mious': self.val_mious}
456 | 
457 | 
458 | if __name__ == "__main__":
459 |     # Test trainer functionality
460 |     print("Testing trainer module...")
461 | 
462 |     # This would require actual model, data loaders, etc.
463 |     # For now, just test imports
464 |     print("Trainer module loaded successfully!")
465 | 
--------------------------------------------------------------------------------
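For orientation, here is a minimal sketch of the configuration dictionary that SegmentationTrainer.__init__ reads, together with the construction call. The key names are taken directly from the lookups in the code above; the concrete values are illustrative only, and the model, data loaders, criterion, and optimizer are assumed to be built elsewhere (in practice by scripts/train.py from configs/config.yaml).

import torch

# Keys mirror the lookups in SegmentationTrainer.__init__; values are illustrative only.
config = {
    'dataset': {'num_classes': 2, 'ignore_index': 255, 'class_names': None, 'class_colors': None},
    'training': {'epochs': 50, 'use_amp': True, 'accumulate_grad_batches': 1, 'clip_grad_norm': 1.0,
                 'scheduler': {'type': 'cosine', 'T_max': 50, 'min_lr': 1e-6, 'warmup_epochs': 5}},
    'validation': {'eval_interval': 1},
    'logging': {'log_dir': './logs', 'experiment_name': 'demo_run'},
    'visualization': {'save_predictions': True, 'save_interval': 10, 'num_vis_samples': 8},
}

# Assuming `model`, `train_loader`, and `val_loader` are built elsewhere (see scripts/train.py):
# criterion = torch.nn.CrossEntropyLoss(ignore_index=config['dataset']['ignore_index'])
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# trainer = SegmentationTrainer(model, train_loader, val_loader, criterion, optimizer, config)
# trainer.train()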